# IFM Generation: CelebA 128x128 (set `image_size` to 64 for the 64x64 setup)

In [None]:
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, TensorDataset, Subset

import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

#import wandb
import sys
sys.path.append("/trinity/home/a.kolesov/EFM/")
from src.models import DDPM, ExponentialMovingAverage
from src.ifm_field import IFM
from src.utils import Config, optimization_manager, random_color

## 1. Base Config

In [None]:
config = Config()

# --- Field geometry ----------------------------------------------------------
config.L = 20.  # important parameter
config.SCALE = 1.
config.K = math.pi / config.L
config.D = math.pi / (2 * config.K)  # algebraically equal to config.L / 2
config.device = 'cuda'
config.experiment = 'generation'

# --- Data --------------------------------------------------------------------
config.data = Config()
config.data.name = 'Celeba'
config.data.num_channels = 3
config.data.img_resize = 128  # set to 64 for the 64x64 setup
config.data.image_size = 128  # set to 64 for the 64x64 setup
config.data.centered = True
# Flattened image size (C*H*W) plus one extra coordinate.
# NOTE(review): presumably the extra axis along which p (x_loc = 0) and
# q (x_loc = L) are placed -- confirm in src.ifm_field.
config.DIM = config.data.num_channels * config.data.image_size * config.data.image_size + 1

# --- Endpoint locations --------------------------------------------------------
config.p = Config()
config.p.x_loc = 0.

config.q = Config()
config.q.x_loc = config.L

# --- Training ------------------------------------------------------------------
config.training = Config()
config.training.small_batch_size = 64  # important parameter
config.training.batch_size = 64        # important parameter
config.training.field_type = "Shifted"
config.training.plan_type = "Independent"
config.training.field_form = "exponential"
config.training.n_iters = 1_000_000
config.training.sde = 'poisson'
config.training.eval_freq = 1_000
config.training.snapshot_freq = 5_000
# NOTE(review): config.model.sigma_end below is set to a different value
# (0.01); confirm which one the training code actually reads.
config.training.sigma_end = 0.1
config.training.M = 191  # important parameter
config.training.tau = 0.03
config.training.epsilon = 1e-3  # important parameter
config.training.interpolation = 'Uniform'
config.training.noised_interpolation = False

config.training.restrict_M = False
config.training.gamma = 5.  # important parameter

# --- Network (NCSN++ style settings) ---------------------------------------------
# (The former `config.model = config.model` self-assignment was a no-op and
# has been removed.)
config.model = Config()
config.model.name = 'ncsnpp'
config.model.scale_by_sigma = False
config.model.ema_rate = 0.9999
config.model.normalization = 'GroupNorm'
config.model.nonlinearity = 'swish'
config.model.nf = 128
config.model.ch_mult = (1, 1, 2, 2, 4, 4)
config.model.num_res_blocks = 2
config.model.attn_resolutions = (16,)
config.model.resamp_with_conv = True
config.model.conditional = True
config.model.fir = False
config.model.fir_kernel = [1, 3, 3, 1]
config.model.skip_rescale = True
config.model.resblock_type = 'biggan'
config.model.progressive = 'none'
config.model.progressive_input = 'none'
config.model.progressive_combine = 'sum'
config.model.attention_type = 'ddpm'
config.model.init_scale = 0.
config.model.fourier_scale = 16
config.model.embedding_type = 'positional'
config.model.conv_size = 3
config.model.sigma_end = 0.01
config.model.dropout = 0.1
config.model.class_cond = False

# --- Optimizer -------------------------------------------------------------------
config.optim = Config()
config.optim.weight_decay = 0
config.optim.optimizer = 'Adam'
config.optim.lr = 2e-4
config.optim.beta1 = 0.9
config.optim.eps = 1e-8
config.optim.warmup = 5000
config.optim.grad_clip = 1.

# --- ODE -------------------------------------------------------------------------
config.ode = Config()
config.ode.gamma = 1e-7
config.ode.step = 0.25

# --- Sampling --------------------------------------------------------------------
config.sampling = Config()
config.sampling.method = 'ode'
config.sampling.ode_solver = 'rk45'
config.sampling.N = 100
# Alternative previously tried: config.L - config.training.epsilon
config.sampling.z_max = config.L
config.sampling.z_min = config.training.epsilon
config.sampling.upper_norm = 3000
config.sampling.z_exp = 1
config.sampling.vs = False
config.sampling.visual_iterations = 10

## 2. Data

In [None]:
# Preprocessing: center-crop each image to 140x140, resize to
# image_size x image_size, then convert to a tensor (ToTensor maps PIL
# images to floats in [0, 1]).
TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.CenterCrop(140),
    torchvision.transforms.Resize((config.data.image_size, config.data.image_size)),
    torchvision.transforms.ToTensor(),
])

path = '/trinity/home/a.kolesov/data/celeba-dataset/img_align_celeba/'
dataset = ImageFolder(path, transform=TRANSFORM)

# NOTE(review): train_ratio / train_size are not used by the split below;
# training uses every image except the last `test_size` ones.
train_ratio, test_ratio = 0.45, 0.1
train_size, test_size = int(len(dataset) * train_ratio), int(len(dataset) * test_ratio)
idx = np.arange(len(dataset))

# Hold out the last `test_size` images. Use an explicit stop index:
# `idx[:-test_size]` would yield an EMPTY array when test_size == 0 (-0 == 0).
train_idx = idx[:len(idx) - test_size]
# Evaluate on (at most) the last 10,000 held-out images. Capping at
# `test_size` guarantees the eval set never overlaps the training set, even
# for datasets with fewer than 100k images (where 10% < 10,000).
eval_size = min(10_000, test_size)
test_idx = idx[len(idx) - eval_size:]

train_set = Subset(dataset, train_idx)
eval_set = Subset(dataset, test_idx)

train_loader = DataLoader(train_set, batch_size=config.training.batch_size,
                          shuffle=True)
eval_loader = DataLoader(eval_set, batch_size=config.training.batch_size,
                         shuffle=True)

train_iter = iter(train_loader)
eval_iter = iter(eval_loader)

## 3. Models

In [None]:
# Instantiate the DDPM-architecture network from `config` on config.device.
net = DDPM(config).to(config.device)
params = net.parameters()

# Adam with learning rate, beta1, eps and weight decay taken from
# config.optim; beta2 is fixed at 0.999.
# NOTE(review): config.optim.optimizer is not consulted -- Adam is hard-coded.
optimizer = torch.optim.Adam(
    params,
    lr=config.optim.lr,
    betas=(config.optim.beta1, 0.999),
    eps=config.optim.eps,
    weight_decay=config.optim.weight_decay,
)

# Exponential moving average of the network weights.
ema = ExponentialMovingAverage(net.parameters(), decay=config.model.ema_rate)

# Mutable training state passed to IFM.train, starting at step 0.
state = {'optimizer': optimizer, 'model': net, 'ema': ema, 'step': 0}
optimize_fn = optimization_manager(config)

## 4. Training

In [None]:
# Experiment/run name: resolution, field length L, scale, batch sizes,
# plan type and interpolation scheme, joined with underscores.
config.name_exp = (
    f"CelebaGeneration{config.data.image_size}x{config.data.image_size}"
    f"_L={config.L}_sc={config.SCALE}"
    f"_BS_{config.training.batch_size}"
    f"_SBS={config.training.small_batch_size}"
    f"_plan_{config.training.plan_type}"
    f"_Int={config.training.interpolation}"
)
# Optional experiment tracking (requires `import wandb` above):
#wandb.init(project="IFMGenerationCelebA", name=config.name_exp)

In [None]:
# Build the IFM trainer from the config and run training; the call returns
# the network and training state, which replace the current globals.
ifm = IFM(config)
net, state = ifm.train(
    train_loader,
    eval_loader,
    net,
    optimizer,
    optimize_fn,
    state,
)