## Config

In [None]:
import os

# Restrict CUDA to device 0 for this notebook. This cell runs before any
# torch-based import in the later cells.
# NOTE(review): presumably placed first so the variable is set before CUDA
# is initialised -- confirm if cells are ever reordered.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
%load_ext autoreload
%autoreload 2

from pathlib import Path
from experiment import data_path, device

# Manifold dimension. It is embedded in the model name and the dataset folder
# name, and is passed to the flow constructor in a later cell.
m = 512 # Manifold dimension
# Run identifier; reused as the sub-directory name for checkpoints and samples.
model_name = f'cifar10-boat-manifold-{m}-cef-sequential'
# Both locations are rooted at data_path, which is imported from experiment.
checkpoint_path = data_path / 'cef_models' / model_name
gen_path = data_path / 'generated' / model_name

## Load data

In [None]:
import torchvision
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

import torch

# Images are only converted to tensors; no augmentation or normalisation
# is applied.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# The dataset folder is selected by the manifold dimension m.
image_folder = data_path / f'cifar10-manifold-{m}-boat'
train_data = torchvision.datasets.ImageFolder(root=str(image_folder), transform=transform)

# Hold out a tenth of the images (rounded down) for validation.
held_out = len(train_data) // 10
# Seed the split so the train/validation partition is identical across
# kernel restarts. With an unseeded split, re-running the notebook against
# existing checkpoints could move previously held-out images into training.
train_data, val_data = random_split(
    train_data,
    [len(train_data) - held_out, held_out],
    generator=torch.Generator().manual_seed(0),
)

## Define model

In [None]:
from nflows import cef_models

# Instantiate the CIFAR-10 CEF model for manifold dimension m, then move it
# to the device exported by the experiment module.
flow = cef_models.Cifar10CEFlow(m)
flow = flow.to(device)

## Train

In [None]:
import torch.optim as opt
from experiment import train_injective_flow

# First training stage: Adam over every parameter of the flow.
optim = opt.Adam(flow.parameters(), lr=1e-3)
# NOTE(review): the cosine period is 1000 steps, but the schedule defined in
# this cell yields only 50 epochs and does not call scheduler.step() itself.
# Whether the learning rate is annealed at all in this stage depends on
# train_injective_flow -- confirm.
scheduler = opt.lr_scheduler.CosineAnnealingLR(optim, T_max=1000)

def weight_schedule():
    '''Yield one (likelihood weight, reconstruction weight) pair per epoch.

    The schedule is constant: 50 epochs, each with a likelihood weight of 0
    and a reconstruction weight of 100000.
    '''
    epoch_weights = (0, 100000)
    num_epochs = 50
    yield from [epoch_weights] * num_epochs
        
# Only the training batches are reshuffled each epoch. The validation loader
# iterates in a fixed order, so validation batches are the same from epoch
# to epoch and no permutation is generated for data that is only evaluated.
train_loader = DataLoader(train_data, batch_size=512, shuffle=True, num_workers=30)
val_loader = DataLoader(val_data, batch_size=512, shuffle=False, num_workers=30)

# First training stage: all flow parameters, using the 50-epoch schedule above
# (likelihood weight 0, reconstruction weight 100000).
# NOTE(review): checkpoint_frequency=100 is larger than the 50 scheduled
# epochs -- confirm that train_injective_flow still writes a checkpoint for
# this stage.
train_injective_flow(flow, optim, scheduler, weight_schedule, train_loader, val_loader,
                     model_name, checkpoint_path=checkpoint_path, checkpoint_frequency=100)

In [None]:
# Second training stage: a fresh Adam optimiser over flow.distribution's
# parameters only; the rest of the flow is not handed to this optimiser.
optim = opt.Adam(flow.distribution.parameters(), lr=1e-3)
# Cosine annealing over 1000 steps, matching the 1000 epochs yielded by the
# schedule defined below in this cell.
scheduler = opt.lr_scheduler.CosineAnnealingLR(optim, T_max=1000)

def weight_schedule():
    '''Yield one (likelihood weight, reconstruction weight) pair per epoch.

    Runs for 1000 epochs with a likelihood weight of 0.01 and a
    reconstruction weight of 0. Each time the generator is resumed after a
    yield, the module-level scheduler is stepped once.
    '''
    epochs_left = 1000
    while epochs_left > 0:
        yield 0.01, 0
        # Runs only when the consumer asks for the next epoch.
        scheduler.step()
        epochs_left -= 1

# Second training stage: the optimiser built in this cell covers only
# flow.distribution, and the schedule uses likelihood weight 0.01 with
# reconstruction weight 0. The model_name and checkpoint_path are the same
# as in the first stage; checkpoint_frequency is 25 here.
train_injective_flow(flow, optim, scheduler, weight_schedule, train_loader, val_loader,
                     model_name, checkpoint_path=checkpoint_path, checkpoint_frequency=25)

## Generate some samples

In [None]:
from experiment import save_samples

# Request 10000 samples from the flow in batches of 512, with gen_path as the
# output location.
# NOTE(review): checkpoint_epoch=-1 presumably selects the most recent
# checkpoint -- confirm in experiment.save_samples.
save_samples(flow, num_samples=10000, gen_path=gen_path, checkpoint_epoch=-1, batch_size=512)