# EFM Translation: CelebA Female ↔ Male

In [1]:
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, TensorDataset, Subset

import math
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import wandb
import sys
sys.path.append("/trinity/home/a.kolesov/EFM/")
from src.models import DDPM, ExponentialMovingAverage
from src.efm_field import EFM
from src.utils import Config, optimization_manager, random_color

## 1. Base Config

In [29]:
config = Config()

config.device = 'cuda'
config.experiment = 'translation'

# --- Data: two CelebA gender domains used as the translation endpoints ---
config.data = Config()
config.data.name = 'Celeba'
config.data.name_sets = ['Female', 'Male']
config.data.num_channels = 3
config.data.img_resize = 64
config.data.image_size = 64
config.data.centered = True
# Flattened image dimension C*H*W plus one extra coordinate
# (presumably the z-axis along which the field is defined — confirm in src.efm_field).
config.DIM = config.data.num_channels * config.data.image_size * config.data.image_size + 1

# --- Source (p) and target (q) distribution placement ---
config.p = Config()
config.p.x_loc = 0.

config.q = Config()
# NOTE(review): config.L is read here (and for config.sampling.z_max below) but is
# never assigned in this notebook. It must already be provided by Config or set
# before this cell runs; if not, this line fails. Confirm the intended value.
config.q.x_loc = config.L

# --- Training ---
config.training = Config()
config.training.small_batch_size = 64  # important parameter
config.training.batch_size = 64        # important parameter
config.training.field_type = "Shifted"
config.training.plan_type = "Independent"
config.training.field_form = "exponential"
config.training.n_iters = 1_000_000
config.training.sde = 'poisson'
config.training.eval_freq = 1_000
config.training.snapshot_freq = 5_000
config.training.sigma_end = 0.1
config.training.M = 191  # important parameter
config.training.tau = 0.03
config.training.epsilon = 1e-3  # important parameter
config.training.interpolation = 'Uniform_mixing'
config.training.noised_interpolation = False
config.training.restrict_M = False
config.training.gamma = 5.  # important parameter

# --- Network architecture (consumed by DDPM(config)) ---
config.model = Config()
config.model.name = 'ncsnpp'
config.model.scale_by_sigma = False
config.model.ema_rate = 0.9999
config.model.normalization = 'GroupNorm'
config.model.nonlinearity = 'swish'
config.model.nf = 128
config.model.ch_mult = (1, 2, 2, 2)
config.model.num_res_blocks = 4
config.model.attn_resolutions = (16,)
config.model.resamp_with_conv = True
config.model.conditional = True
config.model.fir = False
config.model.fir_kernel = [1, 3, 3, 1]
config.model.skip_rescale = True
config.model.resblock_type = 'biggan'
config.model.progressive = 'none'
config.model.progressive_input = 'none'
config.model.progressive_combine = 'sum'
config.model.attention_type = 'ddpm'
config.model.init_scale = 0.
config.model.fourier_scale = 16
config.model.embedding_type = 'positional'
config.model.conv_size = 3
config.model.sigma_end = 0.01
config.model.dropout = 0.1

# --- Optimizer (Adam; see the Models cell) ---
config.optim = Config()
config.optim.weight_decay = 0
config.optim.optimizer = 'Adam'
config.optim.lr = 2e-4
config.optim.beta1 = 0.9
config.optim.eps = 1e-8
config.optim.warmup = 5000
config.optim.grad_clip = 1.

# --- ODE integration ---
config.ode = Config()
config.ode.gamma = 1e-7
config.ode.step = 0.25

# --- Sampling ---
config.sampling = Config()
config.sampling.method = 'ode'
config.sampling.ode_solver = 'rk45'
config.sampling.N = 100
# Integrate up to the target plane; an alternative tried earlier was
# config.L - config.training.epsilon.
config.sampling.z_max = config.L
config.sampling.z_min = config.training.epsilon
config.sampling.upper_norm = 3000
config.sampling.z_exp = 1
config.sampling.vs = False
config.sampling.visual_iterations = 10

## 2. Data

In [30]:
# Preprocessing: center-crop to 140 px, resize to image_size x image_size,
# then convert to a float tensor with values in [0, 1].
TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.CenterCrop(140),
        torchvision.transforms.Resize(
            (config.data.image_size, config.data.image_size)
        ),
        torchvision.transforms.ToTensor(),
    ]
)

# Root of the CelebA gender split; each domain lives in Train/<name> and Test/<name>.
CELEBA_GENDER_ROOT = '/trinity/home/a.kolesov/data/Celeba_gender/Dataset'


def _shuffled_loader(dataset):
    """Wrap *dataset* in a shuffling DataLoader using the training batch size."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=config.training.batch_size,
        shuffle=True,
    )


# Per-domain datasets, loaders and live iterators, keyed by the names in
# config.data.name_sets. The eval loader is shuffled as well, so each draw
# from eval_iter yields a random batch.
train_set, eval_set = {}, {}
train_loader, eval_loader = {}, {}
train_iter, eval_iter = {}, {}
for name_set in config.data.name_sets:
    path = f'{CELEBA_GENDER_ROOT}/Train/{name_set}/'
    train_set[name_set] = ImageFolder(path, transform=TRANSFORM)
    path = f'{CELEBA_GENDER_ROOT}/Test/{name_set}'
    eval_set[name_set] = ImageFolder(path, transform=TRANSFORM)

    train_loader[name_set] = _shuffled_loader(train_set[name_set])
    eval_loader[name_set] = _shuffled_loader(eval_set[name_set])

    # Iterators are created per domain, train before eval.
    train_iter[name_set] = iter(train_loader[name_set])
    eval_iter[name_set] = iter(eval_loader[name_set])
                                           


## 3. Models

In [6]:
# Score network placed on the configured device.
net = DDPM(config).to(config.device)

# Adam over the network parameters; beta2 is fixed at 0.999.
params = net.parameters()
optimizer = torch.optim.Adam(
    params,
    lr=config.optim.lr,
    betas=(config.optim.beta1, 0.999),
    eps=config.optim.eps,
    weight_decay=config.optim.weight_decay,
)

# Exponential moving average tracker over the same parameters.
ema = ExponentialMovingAverage(net.parameters(), decay=config.model.ema_rate)

# Mutable training state handed to EFM.train; the step counter starts at 0.
state = {'optimizer': optimizer, 'model': net, 'ema': ema, 'step': 0}
optimize_fn = optimization_manager(config)

## 4. Training

In [6]:
# Human-readable W&B run name built from the key experiment settings.
# Adjacent f-strings inside parentheses are joined directly. A backslash line
# continuation *inside* the literal would instead embed the next line's
# indentation whitespace into the run name.
# NOTE(review): config.L and config.SCALE are read here but not assigned in the
# visible notebook cells; confirm that Config provides them.
config.name_exp = (
    f"Int={config.training.interpolation}_Noise={config.training.noised_interpolation}_"
    f"L={config.L}_sc={config.SCALE}_BS_{config.training.batch_size}_"
    f"SBS={config.training.small_batch_size}"
)
wandb.init(project="EFMTranslationCelebA", name=config.name_exp)

[34m[1mwandb[0m: Currently logged in as: [33memfalafeli[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.21.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [None]:
# Build the EFM trainer from the config and run training. The call returns
# the network and the training state, which replace the current globals.
efm = EFM(config)
net, state = efm.train(
    train_loader,
    eval_loader,
    net,
    optimizer,
    optimize_fn,
    state,
)