In [None]:
import os, sys
sys.path.append("..")

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import numpy as np
import torch
import torch.nn as nn
import torchvision
import gc

from src import distributions
import torch.nn.functional as F

# NOT networks
from src.resnet2 import ResNet_D
from src.unet import UNet

from src.tools import unfreeze, freeze
from src.tools import weights_init_D
from src.tools_paired import load_paired_dataset, get_pushed_loader_stats, get_pushed_loader_metrics
from src.fid_score import calculate_frechet_distance
from src.plotters import plot_random_paired_images, plot_images
from src.u2net import U2NET
from src.losses import VGGPerceptualLoss as VGGLoss

from copy import deepcopy
import json

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

import wandb
from src.tools import fig2data, fig2img # for wandb

# This is needed to use dataloaders for some datasets:
# set PIL's PNG MAX_TEXT_CHUNK limit to LARGE_ENOUGH_NUMBER mebibytes.
from PIL import PngImagePlugin
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

## Main Config

In [None]:
# CUDA device indices. The first entry becomes the current device; listing more
# than one wraps T and D in nn.DataParallel.
DEVICE_IDS = [0]

# Paired dataset: name, root folder and direction flag handed to load_paired_dataset.
# Other configurations that have been used:
#   'comic_faces_v1', '../datasets/face2comics_v1.0.0_by_Sxela', False
#   'edges2shoes',    '../datasets/Edges2Shoes',                 False
DATASET = 'celeba_mask'
DATASET_PATH = '../datasets/CelebAMask-HQ'
REVERSE = False

# Transport map architecture: 'U2Net', 'UNet' or 'Unet_pix2pix'
# ('ResNet_pix2pix' is not implemented for T).
T_TYPE = 'U2Net'
# Critic architecture: 'ResNet' or 'ResNet_pix2pix'
# (the latter DOES NOT WORK WELL; it is actually not a resnet).
D_TYPE = 'ResNet'

# These two are used only by the pix2pix transport network.
T_DROPOUT = False
T_NORM = 'batch' # 'instance' or 'none'

# Used only by ResNet_D.
D_DROPOUT = False

# ResNet_pix2pix uses the given norm layer; our ResNet_D only switches batchnorm on/off.
D_NORM = 'none' # 'instance' or 'none'

# NOTE(review): GP and LAMBDA are not referenced by any other cell of this notebook.
GP = 10
LAMBDA = 0.1

# Number of T updates per D update, and the learning rates of both optimizers.
T_ITERS = 10
D_LR = 0.0001
T_LR = 0.0001

IMG_SIZE = 256
BATCH_SIZE = 32

CONDITIONAL = False # Test conditional NOT (not needed anymore)
NOT = True # Train Neural optimal transport or pure regression

# Cost between Y and T(X): 'mse', 'mae', 'rmse' or 'vgg'.
COST = 'vgg'

# Intervals (in outer steps) for plotting and checkpointing, and the loop's upper bound.
PLOT_INTERVAL = 1000
CPKT_INTERVAL = 5000
MAX_STEPS = 50001
SEED = 0x000000

# EMAS = [0.99, 0.999, 0.9999]
# EMA_START = 70000

# Step of the checkpoint to resume from; -1 starts from scratch.
CONTINUE = -1

# Run name for wandb and the folder that receives the checkpoints.
EXP_NAME = f'NOT_ours_{DATASET}_T{T_ITERS}_{COST}_{IMG_SIZE}_{T_TYPE}_{D_TYPE}_{BATCH_SIZE}_PixelNorm'
OUTPUT_PATH = '../checkpoints/{}/gnot/{}_{}_{}_{}_{}_{}_{}_{}/'.format(
    COST,
    DATASET, IMG_SIZE, NOT, CONDITIONAL,
    T_TYPE, D_TYPE, BATCH_SIZE, 'PixelNorm',
)

## Preparation

In [None]:
# Hyper-parameters attached to the wandb run (passed as `config` to wandb.init).
config = dict(
    DATASET=DATASET,
    T_TYPE=T_TYPE, D_TYPE=D_TYPE,
    T_ITERS=T_ITERS,
    T_DROPOUT=T_DROPOUT, D_DROPOUT=D_DROPOUT,
    D_LR=D_LR, T_LR=T_LR,
    BATCH_SIZE=BATCH_SIZE,
    CONDITIONAL=CONDITIONAL,
    NOT=NOT, COST=COST
)

# The conditional critic only makes sense when NOT training is enabled.
assert NOT or not CONDITIONAL, 'CONDITIONAL=True requires NOT=True'
FID_EPOCHS = 50  # only used by the (currently commented-out) FID evaluation

# Everything below is moved to the GPU with .cuda(), so fail early without one.
assert torch.cuda.is_available(), 'a CUDA device is required'
torch.cuda.set_device(f'cuda:{DEVICE_IDS[0]}')
torch.manual_seed(SEED); np.random.seed(SEED)

# The perceptual loss module is only built when it is the selected cost.
if COST == 'vgg':
    vgg_loss = VGGLoss().cuda()

# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(OUTPUT_PATH, exist_ok=True)

## Loading data stats for testing

In [None]:
# filename = '../stats/{}_{}_{}_test.json'.format(DATASET, IMG_SIZE, REVERSE)
# with open(filename, 'r') as fp:
#     data_stats = json.load(fp)
#     mu_data, sigma_data = data_stats['mu'], data_stats['sigma']
# del data_stats

## Prepare Samplers (X, Y)

In [None]:
# Build the train and test samplers of paired (X, Y) images.
XY_sampler, XY_test_sampler = load_paired_dataset(
    DATASET,
    DATASET_PATH,
    img_size=IMG_SIZE,
    reverse=REVERSE,
)
# The plotting samplers are aliases of the training/testing ones.
XY_sampler_plt = XY_sampler
XY_test_sampler_plt = XY_test_sampler

# Release cached GPU memory, run the garbage collector and wipe the cell output.
torch.cuda.empty_cache()
gc.collect()
clear_output()

# Initializing Networks

In [None]:
# Build the critic D and the transport map T selected by D_TYPE / T_TYPE.
#
# NOTE(review): get_norm_layer, NLayerDiscriminator, UnetGenerator and init_weights are
# not imported anywhere in this notebook, so the '*_pix2pix' options raise NameError
# until they are imported (presumably from the pix2pix network definitions - TODO
# confirm the module). The norm layers are therefore resolved only inside the branches
# that need them, which keeps the 'ResNet' / 'UNet' / 'U2Net' options runnable on a
# fresh kernel.

if D_TYPE == 'ResNet':
    # In CONDITIONAL mode the critic receives 6 input channels instead of 3.
    D = ResNet_D(IMG_SIZE, nc=3 if not CONDITIONAL else 6, bn=D_NORM != 'none', use_dropout=D_DROPOUT).cuda()
    D.apply(weights_init_D)
elif D_TYPE == 'ResNet_pix2pix':
    D_norm_layer = get_norm_layer(D_NORM)
    D = NLayerDiscriminator(
        3 if not CONDITIONAL else 6, n_layers=3,
        norm_layer=D_norm_layer).cuda()
    init_weights(D)
else:
    raise NotImplementedError('Unknown D_TYPE: {}'.format(D_TYPE))

if T_TYPE == 'UNet':
    T = UNet(3, 3, base_factor=48).cuda()
elif T_TYPE == 'Unet_pix2pix':
    T_norm_layer = get_norm_layer(T_NORM)
    T = UnetGenerator(
        3, 3, num_downs=np.log2(IMG_SIZE).astype(int),
        use_dropout=T_DROPOUT, norm_layer=T_norm_layer
    ).cuda()
    init_weights(T)
elif T_TYPE == 'U2Net':
    T = U2NET(out_ch=3).cuda()
else:
    raise NotImplementedError('Unknown T_TYPE: {}'.format(T_TYPE))

# Replicate both networks across the listed GPUs when more than one is given.
if len(DEVICE_IDS) > 1:
    T = nn.DataParallel(T, device_ids=DEVICE_IDS)
    D = nn.DataParallel(D, device_ids=DEVICE_IDS)

# Total number of parameter elements of each network.
print('T params:', sum(p.numel() for p in T.parameters()))
print('D params:', sum(p.numel() for p in D.parameters()))

In [None]:
# Re-seed torch and numpy with a dedicated value before drawing the batches that are
# plotted during training (presumably so the same 10 pairs are picked on every run).
torch.manual_seed(0xBADBEEF)
np.random.seed(0xBADBEEF)

# 10 fixed pairs from the train sampler and 10 from the test sampler.
X_fixed, Y_fixed = XY_sampler_plt.sample(10)
X_test_fixed, Y_test_fixed = XY_test_sampler_plt.sample(10)

### Plots Test

In [None]:
# Try out the plotting helpers with the freshly initialised T before training starts:
# fixed train pairs, random train pairs, fixed test pairs, random test pairs.
# `fig` / `axes` are rebound by every call, so only the last figure stays referenced.
fig, axes = plot_images(X_fixed, Y_fixed, T)
fig, axes = plot_random_paired_images(XY_sampler, T)
fig, axes = plot_images(X_test_fixed, Y_test_fixed, T)
fig, axes = plot_random_paired_images(XY_test_sampler, T)

# Run Training

In [None]:
# Alternative run target kept for reference:
#wandb.init(name=EXP_NAME, project='pairedot', entity='gunsandroses', config=config)

# Open the wandb run that receives the losses and figures logged by the training loop.
wandb.init(
    name=EXP_NAME,
    project='ICLR GNOT (testing)',
    entity='rock-and-roll',
    config=config,
    reinit=True,
)
pass  # presumably here so the cell does not echo the value returned by wandb.init

In [None]:
def _restore(target, filename):
    """Load the state dict saved as OUTPUT_PATH/filename into `target`."""
    target.load_state_dict(torch.load(os.path.join(OUTPUT_PATH, filename)))


# Adam optimizers for the transport map T and the critic D.
T_opt = torch.optim.Adam(
    T.parameters(), lr=T_LR, weight_decay=1e-10
)
D_opt = torch.optim.Adam(
    D.parameters(), lr=D_LR, weight_decay=1e-10
)

# Both learning rates are halved at the same scheduler milestones.
T_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    T_opt, milestones=[15000, 30000, 45000, 70000], gamma=0.5
)
D_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    D_opt, milestones=[15000, 30000, 45000, 70000], gamma=0.5
)

# Resume networks, optimizers and schedulers from the checkpoint written at step CONTINUE.
if CONTINUE > -1:
    _restore(T_opt, f'T_opt_{SEED}_{CONTINUE}.pt')
    _restore(T_scheduler, f'T_scheduler_{SEED}_{CONTINUE}.pt')
    _restore(T, f'T_{SEED}_{CONTINUE}.pt')
    _restore(D_opt, f'D_opt_{SEED}_{CONTINUE}.pt')
    _restore(D, f'D_{SEED}_{CONTINUE}.pt')
    _restore(D_scheduler, f'D_scheduler_{SEED}_{CONTINUE}.pt')


In [None]:
def _log_and_show(fig, title, step):
    """Log `fig` to wandb as an image under key `title` at `step`, then display and close it."""
    wandb.log({title : [wandb.Image(fig2img(fig))]}, step=step)
    # pyplot.show() does not take a figure argument, so it is called without one;
    # the figure is closed explicitly afterwards to release it.
    plt.show(); plt.close(fig)


# Main training loop: every outer step runs T_ITERS updates of the transport map T
# and, when NOT is enabled, one update of the critic D.
for step in tqdm(range(CONTINUE+1, MAX_STEPS)):
    # T optimization: minimise cost(Y, T(X)) - D(T(X)) (the critic term only if NOT).
    unfreeze(T); freeze(D)
    for t_iter in range(T_ITERS):
        T_opt.zero_grad()
        X, Y = XY_sampler.sample(BATCH_SIZE)
        T_X = T(X)

        # The conditional critic sees the mapped image concatenated with its input.
        if CONDITIONAL:
            T_X = torch.cat([T_X, X], dim=1)

        # The cost always compares Y with the first 3 channels, i.e. the mapped image.
        if COST == 'rmse':
            T_loss = (Y-T_X[:, :3]).flatten(start_dim=1).norm(dim=1).mean()
        elif COST == 'mse':
            T_loss = (Y-T_X[:, :3]).flatten(start_dim=1).square().sum(dim=1).mean()
        elif COST == 'mae':
            T_loss = (Y-T_X[:, :3]).flatten(start_dim=1).abs().sum(dim=1).mean()
        elif COST == 'vgg':
            T_loss = vgg_loss(Y, T_X[:, :3]).mean()
        else:
            raise Exception('Unknown COST')

        if NOT:
            T_loss -= D(T_X).mean()

        T_loss.backward(); T_opt.step()
    # The scheduler advances once per outer step, not once per T update.
    T_scheduler.step()
    del T_loss, T_X, X, Y; gc.collect(); torch.cuda.empty_cache()

    if NOT:
        # D optimization: minimise D(T(X)) - D(Y) with T held fixed.
        freeze(T); unfreeze(D)
        X, _ = XY_sampler.sample(BATCH_SIZE)
        with torch.no_grad():
            T_X = T(X)
        _, Y = XY_sampler.sample(BATCH_SIZE) # We may use the previous batch here
        if CONDITIONAL:
            # NOTE(review): Y comes from a separate sample() call than X, so the
            # concatenated (Y, X) is presumably not a dataset pair - confirm intended.
            with torch.no_grad():
                T_X = torch.cat([T_X, X], dim=1)
                Y = torch.cat([Y, X], dim=1)
        D_opt.zero_grad()
        D_loss = D(T_X).mean() - D(Y).mean()
        D_loss.backward(); D_opt.step(); D_scheduler.step()
        wandb.log({'D_loss' : D_loss.item()}, step=step)
        del D_loss, Y, X, T_X, _; gc.collect(); torch.cuda.empty_cache()

    if step % PLOT_INTERVAL == 0:
        print('Plotting')
        clear_output(wait=True)

        fig, axes = plot_images(X_fixed, Y_fixed, T)
        _log_and_show(fig, 'Fixed Images', step)

        fig, axes = plot_random_paired_images(XY_sampler_plt, T)
        _log_and_show(fig, 'Random Images', step)

        fig, axes = plot_images(X_test_fixed, Y_test_fixed, T)
        _log_and_show(fig, 'Fixed Test Images', step)

        fig, axes = plot_random_paired_images(XY_test_sampler_plt, T)
        _log_and_show(fig, 'Random Test Images', step)

    if step % CPKT_INTERVAL == 0:
        # Save networks, optimizers and schedulers; the file names match the ones
        # read back when resuming with CONTINUE set to this step.
        freeze(T);
        torch.save(T.state_dict(), os.path.join(OUTPUT_PATH, f'T_{SEED}_{step}.pt'))
        torch.save(D.state_dict(), os.path.join(OUTPUT_PATH, f'D_{SEED}_{step}.pt'))
        torch.save(D_opt.state_dict(), os.path.join(OUTPUT_PATH, f'D_opt_{SEED}_{step}.pt'))
        torch.save(T_opt.state_dict(), os.path.join(OUTPUT_PATH, f'T_opt_{SEED}_{step}.pt'))
        torch.save(D_scheduler.state_dict(), os.path.join(OUTPUT_PATH, f'D_scheduler_{SEED}_{step}.pt'))
        torch.save(T_scheduler.state_dict(), os.path.join(OUTPUT_PATH, f'T_scheduler_{SEED}_{step}.pt'))

        #print('Computing FID')
        #mu, sigma = get_pushed_loader_stats(T, XY_test_sampler.loader,  n_epochs=FID_EPOCHS)
        #fid = calculate_frechet_distance(mu_data, sigma_data, mu, sigma)
        #wandb.log({f'FID (Test)' : fid}, step=step)
        #del mu, sigma

    gc.collect(); torch.cuda.empty_cache()