# Ave, Celeba!: latent-space experiment with entropic regularization

In [None]:
import torch
import torchvision

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm import tqdm
import itertools
import wandb
import os

import sys
sys.path.append("../../")
from src.utils import Config, weights_init_D, freeze, unfreeze, normalize_out_to_0_1
from src.data import DatasetSampler
from src.models import ResNet_D, UNet, linear_model
from src.train_entropic import train_entropic
from src.cost import strong_cost

# for generator
import sys 
sys.path.append("..")
import dnnlib
import legacy

import warnings
warnings.filterwarnings('ignore')

%load_ext autoreload
%autoreload 2

## 1. Config

In [None]:
# --- Experiment hyperparameters (copied into CONFIG in the next cell) ---

# Training schedule.
NUM_EPOCHS = 1
INNER_ITERATIONS = 10
BATCH_SIZE = 32

# Stored as CONFIG.EPSILON; presumably the entropic-regularization
# coefficient used by train_entropic (TODO confirm).
EPSILON = 1.0

# Adam learning rates for each trainable component.
LR_LATENT_MLP = 1e-4
LR_ENCODER = 1e-4
LR_POTENTIAL = 1e-4

# Index of the CUDA device to run on (becomes CONFIG.DEVICE = "cuda:<index>").
GPU_DEVICE = 0

# Boolean experiment switches, recorded into CONFIG and the wandb run name.
ALAE = False
RESNET_D_ENCODER = True
CHI = False
SPHERE = False

In [None]:
# Build the experiment Config from the hyperparameters of the previous cell.
# The object is passed to train_entropic and logged to wandb as the run config.
CONFIG = Config()

# Number of input distributions: one dataset folder (ave_celeba_<k>) and one
# potential network are created per distribution further below.
CONFIG.K = 3
CONFIG.EPSILON = EPSILON

# NOTE(review): LAMBDAS look like per-distribution weights (they sum to 1)
# — confirm against train_entropic.
CONFIG.LAMBDAS = [0.25, 0.5, 0.25]
CONFIG.CLASSES = [0, 1, 2]
CONFIG.DATASET_PATH = '../../../data/ave_celeba_green_v2/'
CONFIG.DATASET = 'ave_celeba'

# The per-distribution lists must line up with K; fail here with a clear
# message instead of deep inside training.
if len(CONFIG.LAMBDAS) != CONFIG.K or len(CONFIG.CLASSES) != CONFIG.K:
    raise ValueError(
        f"LAMBDAS ({len(CONFIG.LAMBDAS)} entries) and CLASSES "
        f"({len(CONFIG.CLASSES)} entries) must each have K={CONFIG.K} entries"
    )

# Latent-space settings; the projection switches come from the previous cell.
CONFIG.SPHERE_PROJECTION = SPHERE
CONFIG.CHI_PROJECTION = CHI
CONFIG.FLAG_LATENT = True
# Implicit string concatenation instead of a backslash continuation inside the
# literal; the resulting path is identical.
CONFIG.GENERATOR_PATH = (
    "../../../stylegan2_ada_pytorch_before/training-runs"
    "/00011-aligned_celeba-stylegan2/network-snapshot-008800.pkl"
)
CONFIG.LATENT_ENCODER_SIZE = 256
CONFIG.LATENT_SIZE = 512
CONFIG.ALAE = ALAE
CONFIG.RESNET_D_ENCODER = RESNET_D_ENCODER

# Data / image settings (NC is passed as `nc` to the ResNet_D models).
CONFIG.BATCH_SIZE = BATCH_SIZE
CONFIG.IMG_SIZE = 64
CONFIG.NC = 3
CONFIG.NUM_EPOCHS = NUM_EPOCHS
CONFIG.INNER_ITERATIONS = INNER_ITERATIONS

# NOTE(review): presumably hidden-layer widths for the latent MLP; no MLP is
# built in this notebook (latent_mlp is None) — confirm usage in train_entropic.
CONFIG.HIDDEN_SIZE = [max(2 * CONFIG.LATENT_SIZE, 128),
                      max(2 * CONFIG.LATENT_SIZE, 128)]

# Optimizer settings shared by every Adam optimizer below.
CONFIG.LR_LATENT_MLP = LR_LATENT_MLP
CONFIG.LR_ENCODER = LR_ENCODER
CONFIG.LR_POTENTIAL = LR_POTENTIAL
CONFIG.BETAS = (0.2, 0.99)

CONFIG.GPU_DEVICE = GPU_DEVICE
# Explicit checks rather than `assert`: assert statements are removed when
# Python runs with -O, which would silently skip this validation.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required but torch.cuda.is_available() is False")
if not 0 <= CONFIG.GPU_DEVICE < torch.cuda.device_count():
    raise RuntimeError(
        f"GPU_DEVICE={CONFIG.GPU_DEVICE} is out of range; "
        f"{torch.cuda.device_count()} CUDA device(s) visible"
    )
CONFIG.DEVICE = f'cuda:{CONFIG.GPU_DEVICE}'

## 2. Generator

In [None]:
# Load the pretrained StyleGAN2-ADA generator when training in latent space.
# G defaults to None so the training cell (which passes `generator=G`) does not
# fail with a NameError when CONFIG.FLAG_LATENT is False.
# NOTE(review): whether train_entropic accepts generator=None is not visible
# here — confirm before running with FLAG_LATENT = False.
# NOTE(review): legacy.load_network_pkl presumably unpickles the snapshot;
# only load checkpoints from trusted sources.
G = None
if CONFIG.FLAG_LATENT:
    with dnnlib.util.open_url(CONFIG.GENERATOR_PATH) as f:
        # 'G_ema' entry of the snapshot dict, moved to the training device.
        G = legacy.load_network_pkl(f)['G_ema'].to(CONFIG.DEVICE)

## 3. Potential and Encoder

### 3.1 Potentials

In [None]:
# One potential network per input distribution (CONFIG.K in total): a ResNet_D
# over CONFIG.IMG_SIZE images with a single output, bn_flag/pn_flag disabled.
potential_kwargs = dict(
    size=CONFIG.IMG_SIZE,
    nc=CONFIG.NC,
    nfilter=64,
    nfilter_max=512,
    res_ratio=0.1,
    n_output=1,
    bn_flag=False,
    pn_flag=False,
)
nets_for_pot = [ResNet_D(**potential_kwargs).to(CONFIG.DEVICE)
                for _ in range(CONFIG.K)]

# Re-initialise every potential after all of them are constructed
# (same order as construction).
for pot_net in nets_for_pot:
    weights_init_D(pot_net)

# A single Adam optimizer over the parameters of all K potentials.
param_nets = [pot_net.parameters() for pot_net in nets_for_pot]
nets_for_pot_opt = torch.optim.Adam(
    itertools.chain.from_iterable(param_nets),
    lr=CONFIG.LR_POTENTIAL,
    betas=CONFIG.BETAS,
)
 

### 3.2 Encoder

In [None]:

# Image encoder: a ResNet_D with 2 * LATENT_SIZE outputs and bn_flag/pn_flag
# enabled.
# NOTE(review): the doubled output size presumably encodes two LATENT_SIZE
# vectors (e.g. mean and scale) — confirm in train_entropic.
# NOTE(review): the ALAE / RESNET_D_ENCODER flags are only recorded in CONFIG;
# this cell always builds the ResNet_D encoder.
encoder_kwargs = dict(
    size=CONFIG.IMG_SIZE,
    nc=CONFIG.NC,
    nfilter=64,
    nfilter_max=512,
    res_ratio=0.1,
    n_output=2 * CONFIG.LATENT_SIZE,
    bn_flag=True,
    pn_flag=True,
)
encoder = ResNet_D(**encoder_kwargs).to(CONFIG.DEVICE)
weights_init_D(encoder)

encoder_opt = torch.optim.Adam(
    encoder.parameters(),
    lr=CONFIG.LR_ENCODER,
    betas=CONFIG.BETAS,
)

# No latent MLP in this configuration; train_entropic receives None for both.
latent_mlp = None
latent_mlp_opt = None

## 4. Data Samplers

In [None]:
# Preprocessing shared by all datasets: resize, convert to a tensor, clip to [0, 1].
tv_transforms = torchvision.transforms
transform = tv_transforms.Compose([
    tv_transforms.Resize(CONFIG.IMG_SIZE),
    tv_transforms.ToTensor(),
    tv_transforms.Lambda(lambda x: torch.clip(x, 0, 1)),
])

# One labelled sampler per distribution k, read from
# <DATASET_PATH>/ave_celeba_<k>/ as an ImageFolder.
# NOTE(review): batch_size=256 and num_workers=40 are hard-coded rather than
# taken from CONFIG (CONFIG.BATCH_SIZE is 32) — confirm this is intentional.
data_samplers = [
    DatasetSampler(
        torchvision.datasets.ImageFolder(
            os.path.join(CONFIG.DATASET_PATH, f"ave_celeba_{k}/"),
            transform=transform,
        ),
        flag_label=True,
        batch_size=256,
        num_workers=40,
    )
    for k in tqdm(range(CONFIG.K))
]

## 5. Training

In [None]:
# Run name built from the main hyperparameters; also stored on CONFIG so it is
# logged with the rest of the config.
# NOTE(review): EPSILON appears twice in the name ("KL_EPS_" and "_EPS_"); this
# looks unintended but is kept so run names stay comparable with earlier runs.
name_exp = f"KL_EPS_{EPSILON}_ALAE_{ALAE}_ENC_{RESNET_D_ENCODER}_EPS_{EPSILON}"
CONFIG.NAME_EXP = name_exp

wandb.init(
    project="BNOT",
    name=name_exp,
    config=CONFIG,
)

In [None]:
# Launch training: the K potentials and their shared optimizer, the encoder and
# its optimizer, the (absent) latent MLP, the per-distribution data samplers,
# the pretrained generator G and the full experiment config.
train_entropic(
    nets_for_pot,
    nets_for_pot_opt,
    encoder,
    encoder_opt,
    latent_mlp=latent_mlp,
    latent_mlp_opt=latent_mlp_opt,
    data_samplers=data_samplers,
    generator=G,
    config=CONFIG,
)