In [1]:
import numpy as np
import torch

In [2]:
def gen_data(bias=0, N=500, rng=None):
    """Generate a 1-D heteroscedastic regression toy dataset.

    Inputs are ``N`` evenly spaced points on ``[bias, bias + 5]``. Targets are
    ``exp(-(X - bias))`` plus Gaussian noise whose standard deviation grows
    linearly with the distance from ``bias``: ``0.05 * (X - bias + 0.5)``.

    Parameters
    ----------
    bias : float
        Left end of the input interval; the whole curve is shifted by it.
    N : int
        Number of samples.
    rng : numpy.random.Generator, optional
        Source of the noise. ``None`` (default) keeps the original behaviour of
        drawing from NumPy's global ``np.random`` state; pass a seeded
        ``np.random.default_rng(seed)`` for data that is reproducible
        independently of any other global-RNG use.

    Returns
    -------
    X, y : numpy.ndarray
        Arrays of shape ``(N, 1)``.
    """
    X = np.linspace(bias, bias + 5, N).reshape(-1, 1)
    mu = np.exp(-X + bias)
    # Both np.random and Generator expose normal(loc, scale, size).
    sampler = np.random if rng is None else rng
    eps = sampler.normal(0, 1, X.shape)
    # Noise scale increases linearly across the interval (heteroscedastic).
    sigma = 0.05 * (X - bias + 0.5)
    return X, mu + eps * sigma

In [3]:
# Fix the random seeds so that data generation, flow initialisation and
# sampling give the same results under Restart Kernel -> Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# 500 inputs on [0, 5] with noisy exponential-decay targets (float64 numpy),
# converted to float32 tensors for the flow models below.
X_np, y_np = gen_data(0, 500)
X = torch.as_tensor(X_np, dtype=torch.float32)
y = torch.as_tensor(y_np, dtype=torch.float32)

In [4]:
import os
import sys

# Make the parent of the working directory importable so that the
# `probaforms` package can be found when the notebook is run from a
# subdirectory (assumes the package lives in that parent — TODO confirm).
cwd = os.getcwd()
repo_root = os.path.dirname(cwd)

# Guard the append so re-running this cell does not pile up duplicate entries.
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [5]:
from probaforms.models import ResidualUnconditional
from probaforms.models import ResidualConditional
from probaforms.models import ResidualFlowModel

## Conditional model: $y$ given $X$

In [6]:
# Conditional flow: the modelled variable is y and the conditioning input is
# X. Their widths are read from the tensors so the config cannot drift out of
# sync with the data (both are 1 here).
flow_args_dict = {
    'var_dim': y.shape[1],
    'cond_dim': X.shape[1],
    'hid_dim': 16,
    'n_block_layers': 3,
    'n_layers': 3,
    # presumably the spectral-norm bound on each residual block (< 1 keeps
    # the block invertible) — confirm in probaforms
    'spnorm_coeff': 0.95,
    # presumably the number of iterations used to invert the flow — confirm
    'n_backward_iters': 100,
}

flow = ResidualFlowModel(**flow_args_dict)

In [7]:
# Geometric learning-rate decay from start_lr towards final_lr: each
# scheduler step multiplies lr by sched_lambda, so final_lr is reached after
# n_epochs steps (assumes the wrapper steps the scheduler once per epoch —
# TODO confirm in probaforms).
n_epochs = 100
start_lr, final_lr = 1e-1, 1e-2
sched_lambda = (final_lr / start_lr) ** (1 / n_epochs)  # ExponentialLR gamma

optim = torch.optim.Adam(
    flow.parameters(),
    lr=start_lr,
    weight_decay=1e-5,
)
sched = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=sched_lambda)
wrapper = ResidualConditional(
    flow,
    optim,
    n_epochs=n_epochs,
    batch_size=100,
    scheduler=sched,
)

In [8]:
# Fit the conditional flow, then draw samples conditioned on every row of X.
# The samples get a real name instead of `_`: binding `_` yourself makes
# IPython stop updating its `_`/`__`/`___` output-history variables for the
# rest of the session.
wrapper.fit(y, X)
y_sampled = wrapper.sample(X, batched=None).cpu()

Epoch: 100%|█████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.43it/s]
batch: 100%|█████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  9.73it/s]


## Unconditional model: joint distribution of $(X, y)$

In [9]:
# Unconditional flow over the joint (X, y) points: there is no conditioning
# input, and the modelled width is the combined width of X and y (2 here),
# read from the tensors so it stays in sync with the data.
flow_args_dict = {
    'var_dim': X.shape[1] + y.shape[1],
    'cond_dim': None,
    'hid_dim': 16,
    'n_block_layers': 3,
    'n_layers': 3,
    # presumably the spectral-norm bound on each residual block — confirm
    'spnorm_coeff': 0.95,
    # presumably the number of iterations used to invert the flow — confirm
    'n_backward_iters': 100,
}

flow = ResidualFlowModel(**flow_args_dict)

In [10]:
# Same optimiser / schedule recipe as the conditional run: lr decays
# geometrically by sched_lambda per scheduler step, reaching final_lr after
# n_epochs steps (assumes one scheduler step per epoch — TODO confirm).
start_lr = 1e-1
final_lr = 1e-2
n_epochs = 100
decay_ratio = final_lr / start_lr
sched_lambda = decay_ratio ** (1 / n_epochs)

optim = torch.optim.Adam(flow.parameters(), lr=start_lr, weight_decay=1e-5)
sched = torch.optim.lr_scheduler.ExponentialLR(optimizer=optim, gamma=sched_lambda)
wrapper = ResidualUnconditional(
    flow, optim,
    n_epochs=n_epochs, batch_size=100, scheduler=sched,
)

In [11]:
# Stack X and y column-wise into joint points, fit the unconditional flow on
# them, and draw as many samples as there are training points. The sample
# count follows the data instead of being hard-coded. The samples are kept
# under a real name rather than `_`, which would freeze IPython's
# output-history variables.
data = torch.cat([X, y], dim=1)
wrapper.fit(data)
data_sampled = wrapper.sample(len(data)).cpu()

Epoch: 100%|█████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.72it/s]
batch: 100%|█████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.11it/s]
