In [None]:
import os, sys
sys.path.append("..")

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import numpy as np
import torch
import torch.nn as nn
import torchvision
import gc

from src import distributions
import torch.nn.functional as F


from src.resnet2 import ResNet_G, ResNet_D
from src.unet import UNet

from src.losses import InjectiveVGGPerceptualLoss, L2Loss

from src.tools import unfreeze, freeze
from src.tools import weights_init_D, weights_init_G
from src.tools import h5py_to_dataset

from src.plotters import plot_bar_random_images, plot_bar_images

from src.tools import get_generated_stats, SumSequential
from src.fid_score import calculate_frechet_distance

from copy import deepcopy
import json

from tqdm import tqdm
from IPython.display import clear_output

import wandb
from src.tools import fig2data, fig2img # for wandb

## Main Config

In [None]:
# GPUs to use: the first id becomes the default CUDA device; all ids are handed to DataParallel.
DEVICE_IDS = [0, 1, 2, 3]

# --- Dataset: keep exactly one (name, root path) pair uncommented ---
# DATASET, DATASET_PATH = 'ave_celeba', '../../data/ave_celeba/'
# DATASET, DATASET_PATH = 'celeba', '../../data/celeba_aligned/' # 202k Celeba Images resized to 64x64
# DATASET, DATASET_PATH = 'mnist01', '../../data/'
# DATASET, DATASET_PATH = 'fashionmnist_all', '../../data/'
# DATASET, DATASET_PATH = 'fashionmnist017', '../../data/'
DATASET, DATASET_PATH = 'handbag_shoes_fruit360', '../../data/'  # shoes, handbags, fruit360

# --- Inner iteration counts per outer round, and learning rates ---
G_ITERS, D_ITERS, T_ITERS = 50, 50, 10
G_LR = D_LR = T_LR = 3e-4

# --- StepLR schedules: LR is multiplied by GAMMA every STEP scheduler steps ---
D_SCHEDULER_GAMMA, D_SCHEDULER_STEP = 0.5, 10000
G_SCHEDULER_GAMMA, G_SCHEDULER_STEP = 0.5, 10000

BATCH_SIZE = 64
PLOT_INTERVAL = 400   # steps between sample plots
FD_INTERVAL = 4000    # steps between FD evaluations / checkpoints
MAX_STEPS = 500001
SEED = 0

REGRESSION = 'VGG'  # loss used when fitting G: 'VGG' (perceptual) or 'L2'
# Shift each dataset's mean onto the weighted mean of all datasets (enabled only for ave_celeba).
SHIFT = DATASET == 'ave_celeba'
EXP_NAME = '_'.join([DATASET, f'G{G_ITERS}', f'T{T_ITERS}', f'D{D_ITERS}', REGRESSION, str(SHIFT)])
OUTPUT_PATH = f'../checkpoints/{DATASET}/'

## Preparation

In [None]:
# Hyper-parameters recorded with the wandb run.
config = dict(
    DATASET=DATASET, 
    G_ITERS=G_ITERS, D_ITERS=D_ITERS, T_ITERS=T_ITERS,
    G_LR=G_LR, D_LR=D_LR, T_LR=T_LR,
    BATCH_SIZE=BATCH_SIZE,
    REGRESSION=REGRESSION,
    D_SCHEDULER_GAMMA=D_SCHEDULER_GAMMA,
    D_SCHEDULER_STEP=D_SCHEDULER_STEP,
    G_SCHEDULER_GAMMA=G_SCHEDULER_GAMMA,
    G_SCHEDULER_STEP=G_SCHEDULER_STEP,
    SHIFT=SHIFT,
)

assert torch.cuda.is_available()
torch.cuda.set_device(f'cuda:{DEVICE_IDS[0]}')
torch.manual_seed(SEED); np.random.seed(SEED)

os.makedirs(OUTPUT_PATH, exist_ok=True)

# Regression loss used for the generator update. An unknown REGRESSION value fails
# loudly here instead of hitting a NameError (or silently reusing a stale `vgg_loss`
# left over in the kernel from an earlier run).
if REGRESSION == 'VGG':
    vgg_loss = InjectiveVGGPerceptualLoss().cuda()
    if len(DEVICE_IDS) > 1:
        vgg_loss = nn.DataParallel(vgg_loss, device_ids=DEVICE_IDS)
    G_criterion = vgg_loss
elif REGRESSION == 'L2':
    G_criterion = L2Loss()
else:
    raise ValueError(f'Unknown REGRESSION: {REGRESSION!r}')

# Per-dataset settings:
#   IMG_SIZE - side length images are resized to
#   SIZE     - number of generated samples used for FD statistics
#   NC       - image channels; Z_DIM - latent dimension
#   ALPHAS   - weights of the input distributions (one entry per distribution)
#   CLASSES  - class label kept for each distribution (None: datasets are loaded from files)
if DATASET == 'ave_celeba':
    IMG_SIZE, SIZE, NC, Z_DIM = 64, 200000, 3, 128
    ALPHAS = [0.25, 0.5, 0.25]
    CLASSES = [0, 1, 2]
elif DATASET == 'celeba':
    IMG_SIZE, SIZE, NC, Z_DIM = 64, 200000, 3, 128
    ALPHAS = [1.]
    CLASSES = [0]
elif DATASET == 'fashionmnist_all':
    IMG_SIZE, SIZE, NC, Z_DIM = 32, 6000, 1, 16
    ALPHAS = [0.1 for _ in range(10)]
    CLASSES = list(range(10))
elif DATASET == 'fashionmnist017':
    IMG_SIZE, SIZE, NC, Z_DIM = 32, 6000, 1, 16
    ALPHAS = [1/3. for _ in range(3)]
    CLASSES = [0,1,7]
elif DATASET == 'mnist01':
    IMG_SIZE, SIZE, NC, Z_DIM = 32, 6000, 1, 16
    ALPHAS = [0.5, .5]
    CLASSES = [0,1]
elif DATASET == 'handbag_shoes_fruit360':
    IMG_SIZE, SIZE, NC, Z_DIM = 64, 6000, 3, 128
    ALPHAS = [1./3, 1./3, 1./3]
    CLASSES = None
else:
    raise Exception('Unknown dataset')

K = len(ALPHAS)  # number of input distributions
# Inception features for FD only on the CelebA variants.
INCEPTION = 'celeba' in DATASET

## Prepare Samplers (Z, Y)

In [None]:
Z_sampler = distributions.StandardNormalSampler(dim=Z_DIM)
Y_samplers = []

# Preprocessing for torchvision-loaded images. For handbag_shoes_fruit360 (where it is
# only used for Fruit360) a 14-px white border is added before resizing; other datasets
# pass through unchanged. torchvision.transforms has no `Identity` class, so the no-op
# is a Lambda. Pixel values are finally mapped from [0, 1] to [-1, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.Pad(14, fill=(255, 255, 255)) if DATASET == 'handbag_shoes_fruit360'
    else torchvision.transforms.Lambda(lambda x: x),
    torchvision.transforms.Resize(IMG_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(lambda x: 2 * x - 1)
])

if DATASET != 'handbag_shoes_fruit360': 
    # One sampler per class: load the dataset and keep only class CLASSES[k].
    for k in range(K):
        if DATASET in ['ave_celeba', 'celeba']:
            dataset = torchvision.datasets.ImageFolder(DATASET_PATH, transform=transform)
        elif 'fashionmnist' in DATASET:
            dataset = torchvision.datasets.FashionMNIST(root=DATASET_PATH, download=True, transform=transform, train=True)
        elif 'mnist' in DATASET:
            dataset = torchvision.datasets.MNIST(root=DATASET_PATH, download=True, transform=transform)
        else:
            raise Exception('Unknown dataset')

        if hasattr(dataset, 'samples'):
            # ImageFolder-style datasets: filter the (path, class) list.
            dataset.samples = [s for s in dataset.samples if s[1] == CLASSES[k]]
        else:
            # Datasets with `targets`/`data` attributes (e.g. MNIST / FashionMNIST).
            mask = np.asarray(dataset.targets) == CLASSES[k]
            dataset.targets = np.asarray(dataset.targets)[mask]
            if 'mnist' in DATASET:
                dataset.data = torch.as_tensor(dataset.data)[torch.from_numpy(mask)]
            else:
                dataset.data = np.asarray(dataset.data)[mask]

        Y_samplers.append(distributions.DatasetSampler(dataset))
        
elif DATASET == 'handbag_shoes_fruit360':
    dataset = h5py_to_dataset(os.path.join(DATASET_PATH, 'handbag_64.hdf5'))
    Y_samplers.append(distributions.DatasetSampler(dataset))
    dataset = h5py_to_dataset(os.path.join(DATASET_PATH, 'shoes_64.hdf5'))
    Y_samplers.append(distributions.DatasetSampler(dataset))
    dataset = torchvision.datasets.ImageFolder(os.path.join(DATASET_PATH, 'fruit360'), transform=transform)
    Y_samplers.append(distributions.DatasetSampler(dataset))
    
if SHIFT:
    # Move every dataset's mean onto the ALPHAS-weighted mean of all dataset means.
    # NOTE(review): assumes DatasetSampler.dataset is a tensor of images — confirm in src.distributions.
    with torch.no_grad():
        Y_bar_mean = 0.
        Y_means = []
        for k in range(K):
            Y_means.append(Y_samplers[k].dataset.mean(dim=0))
            Y_bar_mean += ALPHAS[k] * Y_means[-1]
        for k in range(K):
            Y_samplers[k].dataset += Y_bar_mean - Y_means[k]
    
torch.cuda.empty_cache(); gc.collect()
clear_output()

## Load Dataset Stats (reference statistics for FD evaluation)

In [None]:
# Name of the precomputed reference statistics file in ../stats/ for each dataset.
# handbag_shoes_fruit360 has no reference statistics, so FD is not computed for it.
stats_names = {
    'ave_celeba': 'celeba',
    'celeba': 'celeba',
    'fashionmnist_all': 'fashionmnist_bar',
    'fashionmnist017': 'fashionmnist017',
    'mnist01': 'mnist01',
    'handbag_shoes_fruit360': None,
}
if DATASET not in stats_names:
    raise Exception('Unknown dataset')
stats_name = stats_names[DATASET]

if stats_name is None:
    mu_data, sigma_data = None, None
else:
    stats_path = '../stats/{}.json'.format(stats_name)
    with open(stats_path, 'r') as f:
        stats = json.load(f)
        mu_data, sigma_data = stats['mu'], stats['sigma']
        # Only mu/sigma are kept; release the parsed JSON.
        del stats
        gc.collect()

## Initialize Networks

In [None]:
# Generator G maps Z_DIM-dimensional latents to images.
G = ResNet_G(Z_DIM, IMG_SIZE, nc=NC).cuda()
# NOTE(review): G now uses weights_init_G, which is imported for the generator but was
# otherwise unused; applying weights_init_D to G looked like a copy-paste slip.
G.apply(weights_init_G)

# One potential D_k and one map T_k per input distribution.
Ds = []
for k in range(K):
    Ds.append(ResNet_D(IMG_SIZE, nc=NC).cuda())
    Ds[-1].apply(weights_init_D)

# Wider UNet for RGB data than for grayscale data.
Ts = [UNet(NC, NC, base_factor=48 if NC == 3 else 16).cuda() for _ in range(K)]

if len(DEVICE_IDS) > 1:
    G = nn.DataParallel(G, device_ids=DEVICE_IDS)
    for k in range(K):
        Ts[k] = nn.DataParallel(Ts[k], device_ids=DEVICE_IDS)
        Ds[k] = nn.DataParallel(Ds[k], device_ids=DEVICE_IDS)

# Parameter counts (T and D report the first of the K networks).
print('G params:', sum(p.numel() for p in G.parameters()))
print('T params:', sum(p.numel() for p in Ts[0].parameters()))
print('D params:', sum(p.numel() for p in Ds[0].parameters()))

In [None]:
# Inputs sampled once and reused by every plot so successive figures are comparable:
# latent codes, a batch from each of the K datasets, and one image per dataset.
n_fixed = 10
Z_fixed = Z_sampler.sample(n_fixed)
fixed_samplers = [Y_samplers[k] for k in range(K)]
Ys_fixed = [sampler.sample(n_fixed) for sampler in fixed_samplers]
Y_pack_fixed = [sampler.sample(1)[0] for sampler in fixed_samplers]

### Plots Test

In [None]:
# Sanity-check plots with the freshly initialised networks: one figure for the fixed
# inputs and one for freshly sampled inputs.
fig, axes = plot_bar_images(Z_fixed, Ys_fixed, G, Ts, ALPHAS)
fig, axes = plot_bar_random_images(Z_sampler, Y_samplers, G, Ts, ALPHAS)
# Map plots for invertible T's are intentionally not drawn here: the helpers
# `plot_bar_maps` / `plot_random_bar_maps` are not imported in this notebook, so
# calling them would raise NameError.
torch.cuda.empty_cache()
gc.collect()

## Run Training

In [None]:
# Start the wandb run. The entity defaults to the original team but can be overridden
# with the WANDB_ENTITY environment variable. Assigning the run keeps the cell from
# echoing the Run object.
wandb_run = wandb.init(
    name=EXP_NAME,
    project='wasserstein2iterativenetworks',
    entity=os.environ.get('WANDB_ENTITY', 'gunsandroses'),
    config=config,
)

In [None]:
# Adam optimisers: one for G, one for each map T_k and one for each potential D_k.
G_opt = torch.optim.Adam(G.parameters(), lr=G_LR, weight_decay=1e-7)
Ts_opt = [torch.optim.Adam(Ts[k].parameters(), lr=T_LR, weight_decay=1e-10) for k in range(K)]
Ds_opt = [torch.optim.Adam(Ds[k].parameters(), lr=D_LR, weight_decay=1e-10) for k in range(K)]

# Step-decay LR schedules for G and each D_k; the T_k keep a constant LR.
G_sch = torch.optim.lr_scheduler.StepLR(G_opt, step_size=G_SCHEDULER_STEP, gamma=G_SCHEDULER_GAMMA)
Ds_sch = [
    torch.optim.lr_scheduler.StepLR(opt, step_size=D_SCHEDULER_STEP, gamma=D_SCHEDULER_GAMMA)
    for opt in Ds_opt
]

# Loop bookkeeping: -inf forces a plot on the first outer iteration; the first
# FD evaluation/checkpoint happens once FD_INTERVAL steps have elapsed.
last_plot_step, last_fd_step = -np.inf, 0

In [None]:
# Global step counter (incremented for every D and G iteration and used as the wandb
# x-axis). It lives in its own cell so re-running the training cell resumes from the
# current step instead of restarting at 0.
step = 0

In [None]:
from IPython.display import display  # explicit import instead of relying on Jupyter's injected builtin

# Alternating optimisation. Each outer iteration:
#   1. for every input distribution k, runs D_ITERS rounds of
#      (T_ITERS updates of the map T_k, then one update of the potential D_k);
#   2. plots every PLOT_INTERVAL steps; every FD_INTERVAL steps computes FD scores
#      (when reference stats exist) and saves checkpoints of G and the T_k;
#   3. runs G_ITERS updates fitting G to sum_k ALPHAS[k] * T_k(G_old(Z)), where
#      G_old is a frozen copy of G taken before these updates.
while step < MAX_STEPS:
    freeze(G)
    
    for k in range(K):
        # D and T optimization cycle
        for d_iter in tqdm(range(D_ITERS)):
            step += 1

            # T_k update: minimise 0.5 * mse_loss(X, T_k(X)) - mean D_k(T_k(X)) on X = G(Z).
            unfreeze(Ts[k]); freeze(Ds[k])
            for t_iter in range(T_ITERS): 
                with torch.no_grad():
                    X = G(Z_sampler.sample(BATCH_SIZE))
                Ts_opt[k].zero_grad()
                T_X = Ts[k](X)
                T_loss = .5 * F.mse_loss(X, T_X) - Ds[k](T_X).mean()
                T_loss.backward(); Ts_opt[k].step()
            del T_loss, T_X, X; gc.collect(); torch.cuda.empty_cache()

            # D_k update: minimise mean D_k(T_k(X)) - mean D_k(Y), Y from the k-th dataset.
            with torch.no_grad():
                X = G(Z_sampler.sample(BATCH_SIZE))
            Y = Y_samplers[k].sample(BATCH_SIZE)
            unfreeze(Ds[k]); freeze(Ts[k])
            with torch.no_grad():
                T_X = Ts[k](X)  
            Ds_opt[k].zero_grad()
            D_loss = Ds[k](T_X).mean() - Ds[k](Y).mean()
            D_loss.backward(); Ds_opt[k].step(); Ds_sch[k].step()
            wandb.log({f'D_loss_{k}' : D_loss.item()}, step=step) 
            # Release all of this round's tensors, including T_X.
            del D_loss, Y, X, T_X; gc.collect(); torch.cuda.empty_cache()
        
    if step >= last_plot_step + PLOT_INTERVAL:
        print('Plotting')
        last_plot_step = step; clear_output(wait=True)
        
        # plt.show() does not take a figure argument; display this figure
        # explicitly, then close it so figures do not accumulate in the long-running cell.
        fig, axes = plot_bar_images(Z_fixed, Ys_fixed, G, Ts, ALPHAS)
        wandb.log({'Fixed Images' : [wandb.Image(fig2img(fig))]}, step=step) 
        display(fig); plt.close(fig)

        fig, axes = plot_bar_random_images(Z_sampler, Y_samplers, G, Ts, ALPHAS)
        wandb.log({'Random Images' : [wandb.Image(fig2img(fig))]}, step=step) 
        display(fig); plt.close(fig)
    
    if step >= last_fd_step + FD_INTERVAL:
        last_fd_step = step
        
        # FD needs reference statistics; datasets without them skip evaluation but still checkpoint.
        if mu_data is not None:
            print('Computing FD score')
            m, s = get_generated_stats(G, Z_sampler, size=SIZE, batch_size=8, inception=INCEPTION, verbose=True)
            FD_G = calculate_frechet_distance(m, s, mu_data, sigma_data)
            del m, s; gc.collect(); torch.cuda.empty_cache()

            # NOTE(review): SumSequential presumably combines G with the ALPHAS-weighted
            # T_k outputs — confirm in src.tools.
            T_G = SumSequential(G, Ts, ALPHAS)
            m, s = get_generated_stats(T_G, Z_sampler, size=SIZE, inception=INCEPTION, batch_size=8, verbose=True)
            FD_T_G = calculate_frechet_distance(m, s, mu_data, sigma_data)
            del m, s; gc.collect(); torch.cuda.empty_cache()

            wandb.log({'FD_G' : FD_G, 'FD_T_G' : FD_T_G}, step=step)
        
        print('Creating a checkpoint')
        freeze(G); torch.save(G.state_dict(), os.path.join(OUTPUT_PATH, f'G_{step}.pt'))
        for k in range(K):
            freeze(Ts[k])
            torch.save(Ts[k].state_dict(), os.path.join(OUTPUT_PATH, f'T_{k}_{step}.pt'))
    
    # G optimization
    if G_ITERS > 0:
        for k in range(K):
            freeze(Ts[k])
        # Frozen snapshot of G: regression targets come from this copy while G is updated.
        G_old = deepcopy(G); freeze(G_old)
        unfreeze(G)
        for g_iter in range(G_ITERS):
            step += 1
            Z = Z_sampler.sample(BATCH_SIZE)
            with torch.no_grad():
                # Evaluate G_old once and reuse it for every T_k, so all maps see the
                # same samples and the snapshot runs a single forward pass per batch.
                G_old_Z = G_old(Z)
                T_G_old_Z = torch.zeros_like(G_old_Z)
                for k in range(K):
                    T_G_old_Z += ALPHAS[k] * Ts[k](G_old_Z)

            G_opt.zero_grad()
            G_loss = .5 * G_criterion(G(Z), T_G_old_Z).mean()
            G_loss.backward(); G_opt.step(); G_sch.step()

            wandb.log({"G_loss" : G_loss.item()}, step=step)

        del G_old, G_old_Z, G_loss, T_G_old_Z, Z
        gc.collect(); torch.cuda.empty_cache()