In [5]:
import torch
import numpy as np
import argparse
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from train import create_parser
from network import ToyNet
from diffusion import Follmer
from data import get_target_fn
from misc import dict2obj
from data import get_target_fn, compute_target_mean_and_sd
# ---------------------------------------------------------------------------
# Experiment configuration, RNG seeding and synthetic training data.
# ---------------------------------------------------------------------------
use_cuda = torch.cuda.is_available()

# All hyper-parameters of the run live in one dict; it is converted below into
# an object so that they can be read as attributes (``args.lr`` etc.).
defaults = dict(device=0, seed=42, data_name='m1', nsample=5000, output='logs/m1/ckpt',
                data_dim=1, cond_dim=5, sigma_data=1, M=1, bsz=1000, train_steps=10000, lr=1e-3,
                dump_freq=1000, print_freq=1000, sde_solver='alpha-maruyama',
                eps0=1e-3, eps1=1e-3, num_steps=1000, heun_steps=13, ntest=5000, nMC=200)

args = dict2obj(defaults)
os.makedirs(args.output, exist_ok=True)

# Pick the device exactly once: the configured CUDA index when CUDA is
# available, otherwise the CPU.
device = torch.device(f"cuda:{args.device}") if use_cuda else torch.device("cpu")

# Seed both RNGs. NumPy drives the synthetic covariates below; torch drives
# everything that draws from the torch generator (e.g. DataLoader shuffling).
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Draw covariates uniformly on [-3, 3) and evaluate the target function on them
# to obtain the responses used for training.
target_fn = get_target_fn(args.data_name)
cond = np.random.uniform(-3, 3, [args.nsample, args.cond_dim])
data = target_fn(cond)
data = torch.from_numpy(data).float()
cond = torch.from_numpy(cond).float()

# Each dataset item is a (response, covariate) pair. Incomplete final batches
# are dropped so every training batch has exactly ``args.bsz`` rows.
dataset = TensorDataset(data, cond)
loader = DataLoader(dataset, batch_size=args.bsz, shuffle=True, drop_last=True)
def create_infinite_dataloader(loader):
    """Yield the items of ``loader`` forever, restarting it after each pass.

    Args:
        loader: A re-iterable collection of batches (for example a
            ``DataLoader``). It is iterated from the start again every time
            it is exhausted.

    Yields:
        The items produced by ``loader``, pass after pass, without end.

    Raises:
        ValueError: If a complete pass over ``loader`` produces no items.
            Without this check an empty loader would make the generator spin
            forever without ever yielding, so ``next()`` would hang.
    """
    while True:
        produced_any = False
        for item in loader:
            produced_any = True
            yield item
        if not produced_any:
            raise ValueError(
                "loader yielded no batches; cannot build an infinite iterator "
                "from an empty loader"
            )
# Wrap the finite DataLoader so that training can pull batches by step count
# rather than by epoch.
loader = create_infinite_dataloader(loader)
model = ToyNet(args.data_dim, args.cond_dim, hidden_dims=[32, 16]).to(device)
optim = torch.optim.Adam(model.parameters(), lr=args.lr)
sde = Follmer(args)

# Training: one optimizer update per step on the loss returned by
# ``sde.compute_schr_loss``.
for step in range(1, args.train_steps+1):
    batch, cond = next(loader)
    batch = batch.to(device)
    cond = cond.to(device)
    optim.zero_grad()
    loss = sde.compute_schr_loss(model, batch, cond)
    loss.backward()
    optim.step()
    if step % args.print_freq == 0:
        print(f"Step[{step}/{args.train_steps}], Loss {loss.item():.4f}")
    # Checkpoint periodically and always after the final step.
    if step % args.dump_freq == 0 or step == args.train_steps:
        # ``state_dict`` must be *called*: storing the bound method itself
        # would leave the checkpoint without any usable optimizer state.
        torch.save(dict(model=model.state_dict(), optim=optim.state_dict(), step=step),
                f"{args.output}/{step}.pth")

Step[1000/10000], Loss 0.9279
Step[2000/10000], Loss 0.2692
Step[3000/10000], Loss 0.1039
Step[4000/10000], Loss 0.0750
Step[5000/10000], Loss 0.0582
Step[6000/10000], Loss 0.0342
Step[7000/10000], Loss 0.0291
Step[8000/10000], Loss 0.0270
Step[9000/10000], Loss 0.0236
Step[10000/10000], Loss 0.0197


In [6]:
# Evaluation: for each of ``args.ntest`` fresh covariate vectors, draw
# ``args.nMC`` samples from the trained model and compare their empirical mean
# and standard deviation with the values returned by
# ``compute_target_mean_and_sd``.

# Time grid handed to the SDE solver, from eps0 to 1 - eps1.
grid = torch.linspace(args.eps0, 1-args.eps1, args.num_steps, device=device)
# NOTE(review): a tensor is passed as ``shape``; presumably ``sampling_prior``
# only uses its dimensions -- confirm against the Follmer implementation.
x1 = sde.sampling_prior(shape=torch.Tensor(args.nMC*args.ntest, args.data_dim), device=device)
# Test covariates are drawn from the same uniform range as the training ones.
cond = np.random.uniform(-3,3,[args.ntest, args.cond_dim])
mean, sd = compute_target_mean_and_sd(args.data_name, cond)
mean_ = np.empty_like(mean)
sd_ = np.empty_like(sd)
# Repeat every covariate row nMC times in place, so rows
# [j*nMC, (j+1)*nMC) of the model output all belong to test point j.
cond = np.repeat(cond, args.nMC, axis=0)
cond = torch.from_numpy(cond).float().to(device)
with torch.no_grad():
    x0 = sde.solve_sde(model, x1, grid, args.sde_solver, cond).cpu().numpy()
pred = x0
for j in range(args.ntest):
    # Monte-Carlo samples belonging to test point j.
    samples_j = pred[j*args.nMC:(j+1)*args.nMC, :]
    mean_[j] = np.mean(samples_j)
    sd_[j] = np.std(samples_j)
mse_mean = np.mean(np.square(mean - mean_))
mse_sd = np.mean(np.square(sd - sd_))
print(f'MSE(mean):{mse_mean:.3f}, MSE(SD)={mse_sd:.3f}')

MSE(mean):0.022, MSE(SD)=0.013
