## Config

In [None]:
import os

# Restrict this notebook to a single GPU. CUDA reads this variable when it is
# first initialised, so it must be set before torch touches the GPU -- hence
# this is the very first cell. `setdefault` keeps GPU 3 as the default but
# respects a device mask already provided by the launcher (e.g. a job
# scheduler or `CUDA_VISIBLE_DEVICES=0 jupyter ...`) instead of overriding it.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '3')

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from experiment import data_path, device

model_name = 'mnist-cef-joint'
checkpoint_path = data_path / 'cef_models' / model_name
gen_path = data_path / 'generated' / model_name

## Data

In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets

# Pad images from 28x28 to 32x32 so the spatial size is a power of 2.
transform = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
])

# Size of the held-out validation split carved out of the MNIST training set.
num_val = 10000
# Fixed seed so the train/val partition is identical every time the notebook
# is re-run (e.g. after a kernel restart when resuming from a checkpoint).
# Without it, images the model already trained on could end up in the
# validation set.
split_seed = 0

train_data = datasets.MNIST(
    root=data_path, train=True, download=True, transform=transform)
train_data, val_data = random_split(
    train_data,
    [len(train_data) - num_val, num_val],
    generator=torch.Generator().manual_seed(split_seed),
)

test_data = datasets.MNIST(
    root=data_path, train=False, download=True, transform=transform)

## Model

In [None]:
from nflows import cef_models

# Build the MNIST conformal embedding flow; 128 is forwarded unchanged to
# MNISTCEFlow (presumably the latent dimension -- TODO confirm), then move
# the model onto the experiment's device.
flow = cef_models.MNISTCEFlow(128)
flow = flow.to(device)

## Train

In [None]:
import torch.optim as opt
from experiment import train_injective_flow

# Single source of truth for the run length: both the cosine LR schedule and
# the loss-weight schedule span exactly this many epochs.
num_epochs = 1000
batch_size = 512
num_workers = 30

optim = opt.Adam(flow.parameters(), lr=0.001)
# T_max equals the schedule length, so the cosine decay ends at the last epoch.
scheduler = opt.lr_scheduler.CosineAnnealingLR(optim, num_epochs)

def weight_schedule():
    '''Yield (likelihood weight, reconstruction weight) once per epoch.

    Produces the constant pair (0.01, 100000) for ``num_epochs`` epochs.
    After each yielded pair the LR scheduler is stepped when the consumer
    asks for the next pair, i.e. once per epoch, assuming
    train_injective_flow pulls exactly one pair per epoch.
    '''
    for _ in range(num_epochs):
        yield 0.01, 100000
        # NOTE(review): `scheduler` is also passed to train_injective_flow;
        # confirm that function does not step it as well, otherwise the
        # learning rate would decay twice per epoch.
        scheduler.step()

train_loader = DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
# Validation order does not affect aggregate metrics; keeping it fixed makes
# evaluation deterministic.
val_loader = DataLoader(
    val_data, batch_size=batch_size, shuffle=False, num_workers=num_workers)

train_injective_flow(flow, optim, scheduler, weight_schedule, train_loader, val_loader,
                     model_name, checkpoint_path=checkpoint_path, checkpoint_frequency=10)

## Generate Images

In [None]:
from experiment import save_samples

# Draw 10,000 samples from the flow in batches of 512 and write them under
# gen_path. checkpoint_epoch=-1 is forwarded to save_samples as-is
# (presumably "use the latest checkpoint" -- TODO confirm).
sampling_options = dict(
    num_samples=10000,
    gen_path=gen_path,
    checkpoint_epoch=-1,
    batch_size=512,
)
save_samples(flow, **sampling_options)