## Config

In [None]:
import os

# Restrict CUDA to device index 0 for this notebook. It is set in the very first
# cell, presumably so it is in place before any CUDA initialisation happens.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from experiment import data_path, device

model_name = 'celeba-mf-sequential'
checkpoint_path = data_path / 'cef_models' / model_name
gen_path = data_path / 'generated' / model_name

## Load data

In [None]:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import data

# Every split shares one preprocessing pipeline: resize to 64x64, then convert to a tensor.
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])

# Build the train / validation / test splits in that order with identical settings.
train_data, val_data, test_data = (
    data.CelebA(root=data_path, split=split_name, transform=transform)
    for split_name in ('train', 'valid', 'test')
)

## Define model

In [None]:
from nflows import cef_models

# Instantiate the CelebA manifold flow, then move it to the configured `device`.
flow = cef_models.CelebAMFlow()
flow = flow.to(device)

## Train

In [None]:
import torch.optim as opt
from experiment import train_injective_flow

# Stage 1 of sequential training. Every epoch uses the weight pair (0, 10000);
# presumably (likelihood weight, reconstruction weight) -- confirm against
# experiment.train_injective_flow.
stage1_epochs = 30

optim = opt.Adam(flow.parameters(), lr=0.0001)
# T_max equals the number of epochs so the cosine schedule spans the whole stage.
# Previously T_max was 300 (copied from stage 2) and the scheduler was never
# stepped, so the learning rate never changed.
scheduler = opt.lr_scheduler.CosineAnnealingLR(optim, stage1_epochs)

def weight_schedule():
    """Yield one loss-weight pair per epoch, stepping the LR scheduler after each.

    This follows the same pattern as the stage-2 schedule: the scheduler is
    advanced when the trainer requests the next epoch's weights.
    """
    for _ in range(stage1_epochs):
        yield 0, 10000
        scheduler.step()

train_loader = DataLoader(train_data, batch_size=256, shuffle=True, num_workers=30)
# Validation does not need a random order. A fixed order keeps validation results
# reproducible between runs.
val_loader = DataLoader(val_data, batch_size=128, shuffle=False, num_workers=30)

# NOTE(review): checkpoint_frequency=100 exceeds the 30-epoch stage. Confirm that
# train_injective_flow still saves a checkpoint at the end of the stage.
train_injective_flow(flow, optim, scheduler, weight_schedule, train_loader, val_loader,
                     model_name, checkpoint_path=checkpoint_path, checkpoint_frequency=100)

In [None]:
# Stage 2: only the parameters of `flow.distribution` are handed to the optimizer,
# so only they are updated in this stage.
optim = opt.Adam(flow.distribution.parameters(), lr=0.0001)
scheduler = opt.lr_scheduler.CosineAnnealingLR(optim, 300)

def weight_schedule():
    """Yield the weight pair (0.001, 0) for 300 epochs, stepping the scheduler after each.

    The pair is presumably (likelihood weight, reconstruction weight); confirm
    against experiment.train_injective_flow.
    """
    epoch = 0
    while epoch < 300:
        yield 0.001, 0
        # Runs when the trainer asks for the next epoch's weights.
        scheduler.step()
        epoch += 1

train_injective_flow(flow, optim, scheduler, weight_schedule, train_loader, val_loader,
                     model_name, checkpoint_path=checkpoint_path, checkpoint_frequency=25)

## Generate some samples

In [None]:
from experiment import save_samples

# Generate as many samples as there are test images and write them to gen_path.
# checkpoint_epoch=-1 presumably selects the most recent checkpoint; confirm in
# experiment.save_samples.
num_test_images = len(test_data)
save_samples(flow, num_samples=num_test_images, gen_path=gen_path, checkpoint_epoch=-1, batch_size=512)