# Ave, CelebA! experiment with MMD regularization

In [None]:
import torch
import torch.nn as nn
import torchvision

from tqdm import tqdm
import wandb
import itertools
import os
import sys
sys.path.append("../..")
from src.models import UNet, ResNet_D
from src.utils import Config, weights_init_D, freeze
from src.data import DatasetSampler
from src.train_gauss_kernel import train_gauss_kernel

sys.path.append("..")
import dnnlib
import legacy

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import warnings

# Silence every warning category for the rest of the session. This keeps the
# notebook output clean, but it also hides potentially useful warnings.
warnings.simplefilter('ignore')

In [None]:
GAMMA = 0.1
BATCH_SIZE = 28
RESNET_ENCODER_LATENT = 2
INNER_ITERATIONS = 10
NUM_EPOCHS = 1
GPU_DEVICE = 0
WEIGHT_NORM = 0.1
Z_STD = 0.1

In [None]:
# Experiment configuration. Gathers the hyperparameters from the previous cell
# plus dataset/model/optimizer settings into one Config object. That object is
# later passed to wandb.init and train_gauss_kernel.
CONFIG = Config()

# K: number of data distributions. One potential, one encoder and one data
# sampler are built per k in the cells below. LAMBDAS holds their weights.
CONFIG.K = 3
CONFIG.LAMBDAS = [0.25, 0.5, 0.25]

# Data: ImageFolder roots are DATASET_PATH/ave_celeba_<k>/.
CONFIG.DATASET_PATH = '../../../data/ave_celeba_green_v2/'
CONFIG.DATASET = 'ave_celeba'
CONFIG.BATCH_SIZE = BATCH_SIZE
CONFIG.IMG_SIZE = 64
CONFIG.NC = 3          # image channels

# Adam settings for the potential and encoder optimizers.
CONFIG.LR_POTENTIAL = 2e-4
CONFIG.LR_ENCODER = 2e-4
CONFIG.BETAS = (0.2, 0.99)

# Latent-space setup. When FLAG_LATENT is set, the pretrained StyleGAN2-ADA
# checkpoint at GENERATOR_PATH is loaded. ZC extra input channels are added to
# each encoder. Each encoder outputs 2 * LATENT_SIZE values.
CONFIG.FLAG_LATENT = True
CONFIG.ZC = 1
CONFIG.LATENT_SIZE = 512
CONFIG.GENERATOR_PATH = "../../../stylegan2_ada_pytorch_before/training-runs/00011-aligned_celeba-stylegan2/network-snapshot-008800.pkl"

# Training schedule and regularization (values from the hyperparameter cell).
CONFIG.NUM_EPOCHS = NUM_EPOCHS
CONFIG.INNER_ITERATIONS = INNER_ITERATIONS
CONFIG.GAMMA = GAMMA

CONFIG.WEIGHT_NORM = WEIGHT_NORM
CONFIG.Z_STD = Z_STD
CONFIG.RESNET_ENCODER_LATENT = RESNET_ENCODER_LATENT

# Fail fast on an inconsistent weight vector.
# NOTE(review): presumably exactly one lambda per k is expected.
if len(CONFIG.LAMBDAS) != CONFIG.K:
    raise ValueError(
        f"CONFIG.LAMBDAS has {len(CONFIG.LAMBDAS)} entries but CONFIG.K = {CONFIG.K}"
    )

CONFIG.GPU_DEVICE = GPU_DEVICE
# Explicit check rather than `assert`, which is stripped under `python -O`.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required for this experiment but is not available.")
CONFIG.DEVICE = f'cuda:{CONFIG.GPU_DEVICE}'

## 2. Generator

In [None]:
# Load the pretrained generator used for latent-space training.
# NOTE(review): the checkpoint is a .pkl file, so load_network_pkl presumably
# unpickles it. Unpickling can execute arbitrary code, so only point
# GENERATOR_PATH at trusted checkpoints.
if CONFIG.FLAG_LATENT:
    with dnnlib.util.open_url(CONFIG.GENERATOR_PATH) as f:
        # Take the 'G_ema' entry of the checkpoint and move it to the
        # training device.
        G = legacy.load_network_pkl(f)['G_ema'].to(CONFIG.DEVICE)
else:
    # Bind G in both branches. The training call passes G unconditionally,
    # so leaving it undefined would raise NameError there.
    # NOTE(review): confirm train_gauss_kernel accepts G=None when
    # FLAG_LATENT is False.
    G = None

## 3. Potentials and Conditional Encoders

### 3.1 Potentials

In [None]:
# CONFIG.K potential networks. Each maps an NC-channel image of side IMG_SIZE
# to a single output value (n_output=1), with bn_flag and pn_flag disabled.
nets_for_pot = []
for _ in range(CONFIG.K):
    nets_for_pot.append(
        ResNet_D(
            size=CONFIG.IMG_SIZE,
            nc=CONFIG.NC,
            nfilter=64,
            nfilter_max=512,
            res_ratio=0.1,
            n_output=1,
            bn_flag=False,
            pn_flag=False,
        ).to(CONFIG.DEVICE)
    )

# Initialization runs in a separate pass, after every net has been built.
for pot_net in nets_for_pot:
    weights_init_D(pot_net)

# A single Adam optimizer shared by the parameters of all potentials.
nets_for_pot_opt = torch.optim.Adam(
    itertools.chain.from_iterable(pot_net.parameters() for pot_net in nets_for_pot),
    lr=CONFIG.LR_POTENTIAL,
    betas=CONFIG.BETAS,
)

### 3.2 Conditional Encoder

In [None]:
# CONFIG.K conditional encoders. Each takes NC image channels plus ZC extra
# channels and produces 2 * LATENT_SIZE output values (bn_flag/pn_flag on).
# NOTE(review): the doubled output size suggests two LATENT_SIZE-long
# quantities per sample; confirm the split against train_gauss_kernel.
encoder = [
    ResNet_D(
        size=CONFIG.IMG_SIZE,
        nc=CONFIG.NC + CONFIG.ZC,
        nfilter=64,
        nfilter_max=512,
        res_ratio=0.1,
        n_output=2 * CONFIG.LATENT_SIZE,
        bn_flag=True,
        pn_flag=True,
    ).to(CONFIG.DEVICE)
    for _ in range(CONFIG.K)
]

# Initialization: iterate the encoders directly instead of by index.
for enc in encoder:
    weights_init_D(enc)

# One Adam optimizer over the parameters of all K encoders.
encoder_opt = torch.optim.Adam(
    itertools.chain.from_iterable(enc.parameters() for enc in encoder),
    lr=CONFIG.LR_ENCODER,
    betas=CONFIG.BETAS,
)

## 4. Data samplers

In [None]:
# Shared preprocessing for all K datasets:
#   1. resize (an int size matches the shorter image edge to IMG_SIZE);
#   2. convert to a tensor;
#   3. clamp values into [0, 1].
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(CONFIG.IMG_SIZE),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(lambda x: torch.clip(x, 0, 1)),
    ]
)

# One labelled DatasetSampler per k, reading from DATASET_PATH/ave_celeba_<k>/.
data_samplers = []
for k in tqdm(range(CONFIG.K)):
    folder = os.path.join(CONFIG.DATASET_PATH, f"ave_celeba_{k}/")
    dataset = torchvision.datasets.ImageFolder(folder, transform=transform)
    sampler = DatasetSampler(dataset, flag_label=True, batch_size=256, num_workers=40)
    data_samplers.append(sampler)

## 5. Weights & Biases (wandb) logging

In [None]:
# The run name encodes the key hyperparameters. It is also stored on CONFIG so
# that it is logged with the run config.
name_exp = CONFIG.NAME_EXP = f"gauss_Dist_GAMMA_{GAMMA}_Z_STD_{Z_STD}_BS_{BATCH_SIZE}_NZ_{RESNET_ENCODER_LATENT}"

In [None]:
# Start a W&B run in project "BNOT" under the hyperparameter-derived name,
# with CONFIG recorded as the run configuration.
wandb.init(
    project="BNOT",
    name=name_exp,
    config=CONFIG,
)

## 6. Train

In [None]:
# Launch training with the following (positional) arguments:
#   - the potentials and their optimizer;
#   - the encoders and their optimizer;
#   - the K data samplers;
#   - the generator G;
#   - the run configuration.
train_gauss_kernel(
    nets_for_pot,
    nets_for_pot_opt,
    encoder,
    encoder_opt,
    data_samplers,
    G,
    CONFIG,
)