# Adversarial VAE (Adv.VAE)
- Dec 31, 2020


## Load libraries

In [None]:
# IPython setup: load the autoreload extension in mode 2 (re-import edited modules,
# e.g. the project's src.* code, before running a cell) and render matplotlib plots inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import math
import os
import re
import sys
import time
from datetime import datetime

# Do not write .pyc files when importing project modules from this notebook.
sys.dont_write_bytecode = True

In [None]:
import pandas as pd

import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from typing import List, Set, Dict, Tuple, Optional, Iterable, Mapping, Union, Callable, TypeVar

from pprint import pprint
from ipdb import set_trace as brpt

In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from  torch.linalg import norm as tnorm
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision import datasets, transforms

import pytorch_lightning as pl
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.tuner.tuning import Tuner


# Select visible GPU: number devices by PCI bus ID and expose only device 1 to this process.
# NOTE: these variables only take effect if CUDA has not been initialised yet in this
# process, so keep this cell before any code that touches the GPU.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

## Set Path
1. Add this notebook's folder and the project root to `sys.path` (so that `src.*` imports resolve)
2. Set `DATA_ROOT` to the `maptiles_v2` folder

In [None]:
# Resolve project locations relative to the notebook's working directory.
this_nb_path = Path.cwd()
ROOT = this_nb_path.parent
SRC = ROOT / 'src'
DATA_ROOT = Path("/data/hayley-old/maptiles_v2/")
paths2add = [this_nb_path, ROOT]

for label, folder in (
    ("Project root: ", ROOT),
    ('Src folder: ', SRC),
    ("This nb path: ", this_nb_path),
):
    print(label, str(folder))

# Prepend each folder to sys.path once (skip entries that are already present).
for p in paths2add:
    entry = str(p)
    if entry in sys.path:
        continue
    sys.path.insert(0, entry)
    print(f"\n{entry} added to the path.")

In [None]:
from src.data.transforms.transforms import Identity, Unnormalizer, LinearRescaler
from src.data.transforms.functional import unnormalize

from src.visualize.utils import show_timgs, show_batch
from src.utils.misc import info
from collections import OrderedDict


## Start experiment 
Given a maptile, predict its style as either OSM or CartoVoyager.
(Note: the cells below currently train a `BiVAE` on MNIST.)

In [None]:
from src.models.plmodules.vanilla_vae import VanillaVAE
from src.models.plmodules.bilatent_vae import BiVAE
from src.models.plmodules.three_fcs import ThreeFCs


In [None]:
# # For reproducibility, set seed like following:
# seed = 100
# pl.seed_everything(seed)
# # sets seeds for numpy, torch, python.random and PYTHONHASHSEED.
# model = Model()
# trainer = pl.Trainer(deterministic=True)

## Adversarial model

TODO:

---
### For batch in dataloader:
- x: (BS, C, h, w): a mini-batch of (c,h,w) tensor

### mu, log_var = model.encoder(x) 
- mu: (BS, latent_dim)
- log_var: (BS, latent_dim)

### z = model.rsample(mu, log_var, self.n_samples) 
- z: (BS, n_samples, latent_dim)
- `z[n]` contains `n_samples` latent codes, all sampled from the same distribution `N(mu[n], logvar[n])`
 
### recon = model.decoder(z) 
- recon: (BS, n_samples, c, h, w)
- `recon[n]` contains `n_samples` (c,h,w)-sized means $\mu_{x}$. The $l$-th one, $\mu_{x}^{(n,l)}$, is the center of the factorized Gaussian for the latent code $z^{(n,l)}$, which is the $l$-th sample from $N(\mu[n], \text{logvar}[n])$.

### out = model.forward(x)
- out (dict): keys are "mu", "logvar", "recon"

### loss_dict = loss_function(out, x, self.n_samples)
- loss_dict (dict): keys are "loss", "kl", "recon_loss"
- kl is computed the same way as in the VanillaVAE model's `loss_function`
- recon_loss is generalized to `self.n_samples` (>= 1) samples: each datapoint's MSE loss is estimated as the average of the losses over its `n_samples` latent samples $z^{(n,l)}$.


In [None]:
from src.data.datamodules import BaseDataModule, USPSDataModule, MNISTMDataModule, MNISTDataModule
from src.models.plmodules.bilatent_vae import BiVAE

# --- Init DataModule ---
data_root = ROOT / 'data'
in_shape = (1, 32, 32)
batch_size = 32

dm = MNISTDataModule(data_root=data_root, in_shape=in_shape, batch_size=batch_size)
dm.setup('fit')
# show_batch(dm, cmap='gray')

# --- Init plModule ---
latent_dim = 10
hidden_dims = [32, 64, 128, 256]  # optionally also 512
adversary_dims = [30, 20, 15]
lr = 1e-3
act_fn = nn.ReLU()
# If True, use the adversarial loss from both the content and style codes; otherwise only the style codes.
# NOTE(review): this flag is not passed to BiVAE below, so changing it has no effect in
# this cell -- confirm whether BiVAE takes it as a constructor argument.
is_contrasive = True

model_kwargs = dict(
    in_shape=dm.size(),
    n_classes=dm.n_classes,
    latent_dim=latent_dim,
    hidden_dims=hidden_dims,
    adversary_dims=adversary_dims,
    learning_rate=lr,
    act_fn=act_fn,
    size_average=False,
)
model = BiVAE(**model_kwargs)

# model
    

In [None]:
# Callbacks (all currently disabled; to use them, uncomment the entries and the
# 'callbacks' key in trainer_config below)
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from src.callbacks.hist_logger import HistogramLogger
from src.callbacks.recon_logger import ReconLogger

# Model wrapper for graph visualization
from src.models.model_wrapper import ModelWrapper

callbacks = [
    # HistogramLogger(hist_epoch_interval=1),
    # ReconLogger(recon_epoch_interval=1),
    # EarlyStopping('val_loss', patience=10),
]

# TensorBoard logger: runs are written under {ROOT}/temp-logs, grouped by "<model>_<datamodule>"
exp_name = f'{model.name}_{dm.name}'
tb_logger = pl_loggers.TensorBoardLogger(
    save_dir=f'{ROOT}/temp-logs',
    name=exp_name,
    log_graph=False,
    default_hp_metric=False,
)
print(tb_logger.log_dir)

# Log computational graph (disabled)
# model_wrapper = ModelWrapper(model)
# tb_logger.experiment.add_graph(model_wrapper, model.example_input_array.to(model.device))
# tb_logger.log_graph(model)

trainer_config = dict(
    gpus=1,
    max_epochs=300,
    progress_bar_refresh_rate=20,
    # auto_lr_find=True,
    terminate_on_nan=True,
    # num_sanity_val_steps=0.25,
    check_val_every_n_epoch=10,
    logger=tb_logger,
    # callbacks=callbacks,
)

# trainer = pl.Trainer(fast_dev_run=3)
trainer = pl.Trainer(**trainer_config)
# trainer.tune(model=model, datamodule=dm)

# Fit the model, then report the epoch/batch index where training stopped
trainer.fit(model, dm)
print(f"Finished at ep {trainer.current_epoch, trainer.batch_idx}")

TODO: OPTIMIZER — add an optimizer for the discriminator in `configure_optimizers`:

```python
def configure_optimizers(self):
    # TODO: ADD optimizer for discriminator
    return torch.optim.Adam(self.parameters(), lr=self.hparams.get("learning_rate"))
```

TODO:
- [ ] Check the output shapes of BiVAE's methods
    - [x] encode
    - [x] rsample
    - [x] combine_content_style
    - [x] decode
    - [x] forward
- [ ] Check losses 

In [None]:
# Pull one training batch to sanity-check the model's intermediate outputs below
train_loader = dm.train_dataloader()
x, y = next(iter(train_loader))
info(x), info(y)

- check `encode` and `rsample`

In [None]:
# Show each posterior parameter's shape and its first row. Entries whose key does not
# contain "mu" are shown after exp() (presumably log-variances -- TODO confirm key names).
dict_qparams = model.encode(x)
for name, param in dict_qparams.items():
    print(f"\n{name}:  {param.shape}")
    print(param[0] if 'mu' in name else param[0].exp())

In [None]:
# Sample latent codes from the posterior parameters and show each sample's shape and first row
dict_z = model.rsample(dict_qparams)
for name, sample in dict_z.items():
    print(f"\n{name}:  {sample.shape}")
    print(sample[0])

- check `combine_content_style` and `decode`

In [None]:
# The combined latent code should be (BS, latent_dim)
z = model.combine_content_style(dict_z)
expected_z_shape = (batch_size, latent_dim)
assert z.shape == expected_z_shape
print("z shape: ", z.shape)

In [None]:
# The decoded mean should have shape (BS, *in_shape).
# NOTE(review): the model was built with in_shape=dm.size(); this check assumes that equals in_shape.
mu_x_pred = model.decode(z)
expected_x_shape = (batch_size, *in_shape)
assert mu_x_pred.shape == expected_x_shape
print("mu_x_pred shape: ", mu_x_pred.shape)

- Check the entire forward pass

In [None]:
# Run the full forward pass and list the shape of every output entry
out_dict = model(x)
for key in out_dict:
    print(f"\n{key}:  {out_dict[key].shape}")


- Check the components of the optimization objective (i.e. the loss)
    - [x] partition_z: z -> dict_z (keys are "c" and "s")
    - [ ] predict_y: z_partition -> scores
    - [ ]

In [None]:
# Split z into its content ("c") and style ("s") parts and check their shapes.
# NOTE(review): both parts are checked against model.content_dim, so this only passes
# when the style part has the same width as the content part -- confirm that is intended.
dict_z = model.partition_z(z)
for part, tensor in dict_z.items():
    print(f"{part}: {tensor.shape}")
    assert tensor.shape == (batch_size, model.content_dim)

In [None]:
c = dict_z["c"]  # content code
s = dict_z["s"]  # style code
c.shape, s.shape

TODO: 
- [ ] Show how the scores based on c and the scores based on s change as the model learns; this will be super interesting to see!!!

In [None]:
# Class scores predicted from the content code (c) and from the style code (s);
# both should be (BS, n_classes)
scores_c = model.predict_y(c)
scores_s = model.predict_y(s)
for scores in (scores_c, scores_s):
    assert scores.shape == (batch_size, model.n_classes)

# TODO: Showing the changes in the scores based on c and scores based on s will be super interesting to see as the model learns!!!
print(scores_c[0])
print(scores_s[0])

In [None]:
# Label of the first example in the batch, for comparison with scores_c[0] / scores_s[0]
y[0]

- check `compute_loss_c` and `compute_loss_s`


In [None]:
# Loss term computed from the content code alone
# (presumably the adversarial loss on c -- confirm against BiVAE.compute_loss_c)
loss_c = model.compute_loss_c(c)
print("loss_c: ", loss_c)

In [None]:
# Loss term computed from the style code and the true labels y
# (presumably a classification loss on s -- confirm against BiVAE.compute_loss_s)
loss_s = model.compute_loss_s(s, y)
print("loss_s: ", loss_s)

- Full loss workflow

In [None]:
# End-to-end: forward pass, then the training-mode loss on the (x, y) batch
out_dict = model(x)
batch = [x, y]
loss_dict = model.loss_function(out_dict, batch, 'train')
pprint(loss_dict)

In [None]:
# Small toy tensors for trying out torch.cat / LogSoftmax
a = torch.ones(5, 2)
b = torch.zeros(5, 3)

In [None]:
# Concatenate along columns: (5, 2) + (5, 3) -> (5, 5)
torch.cat((a, b), 1)

In [None]:
# LogSoftmax over the column dimension; exp() maps the log-probabilities back to
# probabilities, so each row of the result sums to 1.
# dim is given explicitly: omitting it relies on deprecated implicit-dim inference
# (which emits a warning); dim=1 is what that inference picks for 2-D input.
m = nn.LogSoftmax(dim=1)
m(a).exp()

# TODO: 
Show how the scores based on c and the scores based on s change as the model learns; this will be super interesting to see!!!

---
## Play with MNISTM and USPS datasets

- MNISTM
    - original size of an image: (3, 28, 28)
    - labels: {0, ..., 9}
- USPS
    - original size of an image: (1, 16, 16)
    - labels: {0, ..., 9}
    

In [None]:
from src.data.datasets.mnistm import MNISTM
from torchvision.datasets import USPS


In [None]:
# MNISTM Dataset: images converted with ToTensor only; pull one shuffled batch
bs, num_workers, pin_memory = 16, 16, True
xforms = transforms.Compose([transforms.ToTensor()])
# target_xforms =
ds = MNISTM(ROOT / 'data', transform=xforms, download=True)

dl = DataLoader(
    ds,
    batch_size=bs,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

x, y = next(iter(dl))
info(x)
info(y)

In [None]:
# Display the MNISTM batch loaded above
show_timgs(x)

In [None]:
# USPS Dataset: images converted with ToTensor only; pull one shuffled batch and show it
bs, num_workers, pin_memory = 16, 16, True
xforms = transforms.Compose([transforms.ToTensor()])
# target_xforms =
ds = USPS(ROOT / 'data', transform=xforms, download=True)

dl = DataLoader(
    ds,
    batch_size=bs,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

x, y = next(iter(dl))
info(x)
info(y)
# single-channel digits, hence the gray colormap
show_timgs(x, cmap='gray')