# Ave, CELEBA! Experiments with L2 cost

In [1]:
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm
import itertools
import wandb
import os
import sys

sys.path.append("../../")
from src.utils import Config, weights_init_D, freeze, unfreeze
from src.data import DatasetSampler
from src.models import ResNet_D, UNet
from src.train import train
from src.cost import strong_cost

%load_ext autoreload
%autoreload 2

## 1. Config

In [None]:
# Top-level training knobs; each one is copied into CONFIG in the next cell.
BATCH_SIZE = 128        # stored as CONFIG.BATCH_SIZE and used in the run name
NUM_EPOCHS = 100        # stored as CONFIG.NUM_EPOCHS
INNER_ITERATIONS = 10   # stored as CONFIG.INNER_ITERATIONS; meaning defined by train()
LR = 1e-4               # learning rate shared by both Adam optimizers
GPU_DEVICE = 0          # CUDA device index -> CONFIG.DEVICE = "cuda:<index>"

In [None]:
CONFIG = Config()

# --- Data / distribution settings ------------------------------------------
CONFIG.FLAG_LATENT = False
CONFIG.FLAG_LATENT_CRITIC = False
CONFIG.IMG_SIZE = 64
CONFIG.NC = 3  # number of image channels
# Flattened image dimensionality. It is derived from NC and IMG_SIZE so it
# cannot drift out of sync with them (it was previously hard-coded as 3*64*64).
CONFIG.DIM = CONFIG.NC * CONFIG.IMG_SIZE * CONFIG.IMG_SIZE
CONFIG.CLASSES = [0, 1, 2]
CONFIG.K = 3  # number of input distributions
CONFIG.LAMBDAS = [0.25, 0.5, 0.25]  # one weight per input distribution
CONFIG.DATASET_PATH = '../../data/ave_celeba_green_v2/'
CONFIG.DATASET = 'ave_celeba'

# Fail fast on a mismatched weight vector instead of deep inside train().
if len(CONFIG.LAMBDAS) != CONFIG.K:
    raise ValueError(
        f"CONFIG.LAMBDAS has {len(CONFIG.LAMBDAS)} entries, "
        f"expected one per distribution (K={CONFIG.K})."
    )

# --- Optimization settings --------------------------------------------------
CONFIG.LR = LR
CONFIG.CLIP_GRADS_NORM = False
CONFIG.BETAS = (0.2, 0.99)  # Adam betas for both optimizers
CONFIG.BATCH_SIZE = BATCH_SIZE
CONFIG.NUM_EPOCHS = NUM_EPOCHS
CONFIG.INNER_ITERATIONS = INNER_ITERATIONS

# --- Device -----------------------------------------------------------------
CONFIG.GPU_DEVICE = GPU_DEVICE
# Explicit checks instead of `assert`: asserts are stripped under `python -O`
# and give no hint about what went wrong.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available; this experiment requires a GPU.")
if not 0 <= CONFIG.GPU_DEVICE < torch.cuda.device_count():
    raise RuntimeError(
        f"GPU_DEVICE={CONFIG.GPU_DEVICE} is out of range; "
        f"{torch.cuda.device_count()} CUDA device(s) visible."
    )
CONFIG.DEVICE = f'cuda:{CONFIG.GPU_DEVICE}'

## 2. Data samplers

In [None]:
# Resize to an explicit (IMG_SIZE, IMG_SIZE) square. Passing a single int to
# Resize only rescales the shorter edge, so a non-square source image would
# yield a non-square tensor. The potentials are built with
# ResNet_D(size=CONFIG.IMG_SIZE), so inputs must be square.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((CONFIG.IMG_SIZE, CONFIG.IMG_SIZE)),
    torchvision.transforms.ToTensor(),
    # ToTensor already scales PIL images to [0, 1]; the clip is a cheap safeguard.
    torchvision.transforms.Lambda(lambda x: torch.clip(x, 0, 1)),
])

data_samplers = []
for k in tqdm(range(CONFIG.K)):
    # One ImageFolder per input distribution: <DATASET_PATH>/ave_celeba_<k>/
    dataset = torchvision.datasets.ImageFolder(
        os.path.join(CONFIG.DATASET_PATH, f"ave_celeba_{k}/"),
        transform=transform,
    )
    # NOTE(review): batch_size=256 is hard-coded and differs from
    # CONFIG.BATCH_SIZE (128). Confirm whether DatasetSampler's batch_size is
    # only a loading chunk size or the actual training batch size.
    data_samplers.append(
        DatasetSampler(dataset, flag_label=True, batch_size=256, num_workers=40)
    )

## 3. Potentials and maps

In [None]:
# Build K ResNet_D networks (n_output=1), one potential per input distribution,
# each moved to the configured CUDA device.
nets_for_pot = [
    ResNet_D(
        size=CONFIG.IMG_SIZE,
        nc=CONFIG.NC,
        nfilter=64,
        nfilter_max=512,
        res_ratio=0.1,
        n_output=1,
        bn_flag=False,
        pn_flag=False,
    ).to(CONFIG.DEVICE)
    for _ in range(CONFIG.K)
]

# Apply the project's discriminator initializer to every potential, after all
# of them have been constructed (same order as before).
for net in nets_for_pot:
    weights_init_D(net)

# A single Adam optimizer updates the parameters of all K potentials together.
param_nets = [net.parameters() for net in nets_for_pot]
nets_for_pot_opt = torch.optim.Adam(
    itertools.chain(*param_nets),
    CONFIG.LR,
    betas=CONFIG.BETAS,
)


In [None]:
# Build K UNet maps, one per input distribution, with n_channels and n_classes
# both set to CONFIG.NC.
maps = [
    UNet(
        n_channels=CONFIG.NC,
        n_classes=CONFIG.NC,
        base_factor=48,
        bilinear=True,
    ).to(CONFIG.DEVICE)
    for _ in range(CONFIG.K)
]

# As with the potentials, one Adam optimizer drives all K maps jointly.
param_maps = [unet.parameters() for unet in maps]
maps_opt = torch.optim.Adam(
    itertools.chain(*param_maps),
    CONFIG.LR,
    betas=CONFIG.BETAS,
)

## 4. Train

In [None]:
# No generator network in this run; train() receives None for this argument.
generator = None

In [None]:
# Run name: the L2 experiment tag plus inner-iteration count, learning rate
# and batch size. It is also stored on CONFIG so it is logged with the config.
name_exp = "_".join([
    "L2_SPACE_INN", str(CONFIG.INNER_ITERATIONS),
    "LR", str(CONFIG.LR),
    "BS", str(CONFIG.BATCH_SIZE),
])
CONFIG.NAME_EXP = name_exp

In [None]:
# Open a Weights & Biases run in the "BNOT" project, named after this
# experiment, with CONFIG recorded as the run configuration.
wandb.init(
    project="BNOT",
    name=name_exp,
    config=CONFIG,
)

In [None]:
# Launch training with: the K potentials, the K maps, their two optimizers,
# one data sampler per distribution, the generator (None here) and the config.
train(
    nets_for_pot,
    maps,
    nets_for_pot_opt,
    maps_opt,
    data_samplers,
    generator,
    CONFIG,
)