### This notebook prepares EDC pipelines vs SRVP pipelines for different classifiers

In [None]:
import torch
import random
import os
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 300
import numpy as np
from torchinfo import summary

import sys

# Put the parent directory first on the module search path so that modules
# located there can be imported by the cells below.
sys.path.insert(0, '..')
# Release cached GPU memory held by PyTorch's allocator before starting.
torch.cuda.empty_cache()

# Set the random seed for reproducibility. Every RNG library imported by this
# notebook is seeded (Python's `random`, NumPy and PyTorch) so that no source
# of randomness is left unseeded.
seed = 1123
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
print("Random Seed: ", seed)

# Use the first GPU if available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device, "will be used.")

In [None]:
# --- data geometry ---------------------------------------------------------
# Shape of a single input; its last entry is reused elsewhere in the notebook
# as the image size (presumably the layout is channels, height, width).
batch_size = 6
input_shape = [3, 64, 64]
batch_input_shape = [batch_size] + input_shape

# --- classification task ---------------------------------------------------
# One extra class is added on top of the listed attributes.
classes = ["Male"]
num_classes = 1 + len(classes)

# --- latent space ----------------------------------------------------------
latent_dims = 32

# --- load-vs-train switches --------------------------------------------------
# Each flag selects between loading saved checkpoints (True) and training
# from scratch (False) in the corresponding preparation section below.
# EDCs: when set to False, also set the params in the EDCs preparation section.
edc_load = True
# Classifier: set to False to train it in the Classifier preparation section.
cla_load = True
# SRVPs: when set to False, also set the params in the SRVPs preparation section.
srvp_load = True

In [None]:
# dataset preparation
from torch.utils.data import random_split, DataLoader
from notebook_utils import CelebADataset, flog
from utils import get_transforms

dset = CelebADataset("../data/CelebA", "list_attr_celeba.txt", get_transforms(input_shape[-1], [], True), classes)
data_len, test_ratio = len(dset), 0.1

# random_split requires the lengths to add up to len(dset) exactly. Rounding
# the two sides independently (ceil of one float product, floor of another)
# can miss that by one for some ratios because of floating-point error, so the
# test size is rounded down once and the train size is taken as the remainder.
test_len = int(np.floor(test_ratio * data_len))
train_len = data_len - test_len
train_dataset, test_dataset = random_split(dset, [train_len, test_len])

# Only the training loader is shuffled; the test loader keeps a fixed order.
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

### Encoder->Decoder->Classifier (EDC) pipelines preparation

In [None]:
# EDC pipelines: one generative model per decoder type, each paired with its
# own freshly built "big" encoder.
from notebook_utils import get_encoder, Encoder, GenModel, get_decoder  # , test_gen_model
from utils import denormalize

# All EDC artefacts (checkpoints, notes) live under this experiment directory.
exp_disc = "baseline"
edc_dir = f"./models/gen_models_{input_shape[-1]}/EDC_{latent_dims}/{exp_disc}"
os.makedirs(edc_dir, exist_ok=True)

dec_types = ["tiny", "small", "deeper", "resnet"]
edc_gens = {}
for edc_dec_type in dec_types:
    # test_gen_model("big", edc_dec_type, batch_input_shape, latent_dims)
    # Build the decoder first and the encoder second; keeping this order keeps
    # any seeded weight initialisation reproducible across runs.
    dec = get_decoder(edc_dec_type, input_shape, latent_dims)
    edc_enc = Encoder(get_encoder("big", input_shape, latent_dims), device)
    # A checkpoint path is passed only when loading; None is passed otherwise.
    edc_ckpt = f"{edc_dir}/enc_big_edc_{edc_dec_type}.tar" if edc_load else None
    edc_gens[edc_dec_type] = GenModel(edc_enc, dec, device, edc_ckpt)
    if edc_load:
        continue
    # When training from scratch, record the decoder type, its repr and its
    # torchinfo summary in the experiment notes.
    dec_report = [edc_dec_type, dec, summary(dec, input_size=(1, latent_dims), device=device)]
    flog(f"{edc_dir}/notes.txt", dec_report)

In [None]:
# EDC training: runs only when the EDCs are not loaded from checkpoints.
if not edc_load:
    from notebook_utils import train_gen_model

    dec_types = ["tiny", "small", "deeper", "resnet"]
    # Number of training epochs per decoder type.
    nepochs = {"tiny": 6, "small": 8, "deeper": 12, "resnet": 18}
    for dec_type in dec_types:
        print(f"Training EDC with Encoder: big, Decoder: {dec_type}.")
        gen_path = f"{edc_dir}/enc_big_edc_{dec_type}"
        edc_gen = edc_gens[dec_type]
        # One Adam optimizer over all parameters of this pipeline.
        opt_gen = torch.optim.Adam(edc_gen.parameters(), lr=2.5e-3)
        # Replace the stored pipeline with whatever the training helper returns.
        edc_gens[dec_type] = train_gen_model(
            edc_gen,
            opt_gen,
            gen_path,
            train_dl,
            num_epochs=nepochs[dec_type],
            device=device,
        )

In [None]:
from notebook_utils import denormalize, save_batch_images # for FID score computation using 3rd party lib
from torchvision.utils import make_grid, save_image

# Reconstruction checks on the EDC decoders: save originals and
# reconstructions side by side and collect the per-batch MSE of every decoder.
edc_fig = f"{edc_dir}/figures"
os.makedirs(edc_fig, exist_ok=True)

# The loss module is stateless, so it is created once instead of per batch.
mse_loss = torch.nn.MSELoss()
recons_errs = {k: [] for k in dec_types}
# This is evaluation only: disabling autograd avoids building a graph for
# every forward pass and lowers memory use.
with torch.no_grad():
    for bi, (x, y) in enumerate(train_dl):
        save_image(make_grid(denormalize("CelebA", x[:32])), f"{edc_fig}/original_{bi}.png")
        save_batch_images(denormalize("CelebA", x), f"{edc_fig}/original", bi)
        for dec_type in dec_types:
            # The last output of the pipeline is taken as the reconstruction and
            # clamped to [-1, 1] (presumably the normalized image range).
            x_hat = torch.clamp(edc_gens[dec_type](x.to(device))[-1], min=-1, max=1).cpu()
            recons_errs[dec_type].append(mse_loss(x_hat, x).item())
            save_image(make_grid(denormalize("CelebA", x_hat[:32])), f"{edc_fig}/enc_big_{dec_type}_{bi}.png")
            save_batch_images(denormalize("CelebA", x_hat), f"{edc_fig}/dec_{dec_type}", bi)
        # Stop after batch index 21, i.e. 22 batches are processed in total.
        if bi > 20:
            break

# Per-decoder statistics over the processed batches: min, mean, median, max.
print("Reconstruction errors")
for k, v in recons_errs.items():
    v = np.array(v)
    print(k, np.min(v), np.mean(v), np.median(v), np.max(v))

### Classifiers (Networks) preparation

In [None]:
from notebook_utils import get_classifier, train_cla, test_cla

# Classifier architecture and the path stem used for its checkpoint and notes.
cla_type = "deeper"
cla_path = f"./models/classifiers_{input_shape[-1]}/{cla_type}_{classes[0]}"

# A checkpoint path is passed only when loading; None is passed otherwise.
if cla_load:
    cla_ckpt = f"{cla_path}.tar"
else:
    cla_ckpt = None
cla = get_classifier(cla_type, latent_dims, num_classes, device, cla_ckpt)
cla = cla.to(device)

In [None]:
# CLA training code: runs only when the classifier is not loaded from a checkpoint.
if not cla_load:
    # Make sure the directory that will hold the checkpoint exists.
    os.makedirs(os.path.dirname(cla_path), exist_ok=True)
    cla = train_cla(cla, cla_path, train_dl, num_epochs=6, device=device)
# Log the model summary and its accuracy. The value is labelled "Test accuracy",
# so it is measured on the held-out test loader rather than the training loader.
flog(f"{cla_path}_notes.txt", [summary(cla, input_size=(1, *input_shape), device=device), f"Test accuracy: {test_cla(cla, test_dl, device)}"])

### Semantic Robustness Verification Problem (SRVP) pipelines preparation for the above Classifier

In [None]:
from notebook_utils import build_srvp_pipeline

srvp_disc = "baseline"
srvp_dir = f"./models/SRVP_{input_shape[-1]}/cla_{cla_type}"

# Latent dimensions for which an SRVP pipeline is built; every slot starts empty.
lds = [32, 64, 192, 392]
srvps = dict.fromkeys(lds)
# Per latent dimension: whether the classifier is initialised from the saved
# classifier checkpoint (True) or created without one (False).
scla_loads = {
    32: False,
    64: False,
    192: False,
    392: True,
    512: False,
}

for ld in lds:
    # A fresh classifier instance is created for every pipeline.
    cla_ckpt = f"{cla_path}.tar" if scla_loads[ld] else None
    cla = get_classifier(cla_type, latent_dims, num_classes, device, cla_ckpt)
    # Path stem of this pipeline's checkpoint; only used for loading here.
    srvp_path = f"{srvp_dir}/ld{ld}"
    srvp_ckpt = f"{srvp_path}.tar" if srvp_load else None
    # The helper returns a pair; only its first element is kept.
    pipeline, _ = build_srvp_pipeline(device, cla, input_shape, ld, "deeper", srvp_ckpt)
    srvps[ld] = pipeline

In [None]:
# Training the SRVP pipelines: runs only when they are not loaded from checkpoints.
if not srvp_load:
    from notebook_utils import train_gen_model, train_gen_model_all

    # Make sure the directory that will hold the checkpoints exists.
    os.makedirs(srvp_dir, exist_ok=True)
    for ld in lds:
        # Bind the pipeline and its checkpoint path for THIS latent dimension.
        # Otherwise every iteration would train the same (or an undefined)
        # model and write to a single stale path.
        srvp = srvps[ld]
        srvp_path = f"{srvp_dir}/ld{ld}"
        if scla_loads[ld]:
            # The classifier came from a checkpoint: optimise only the
            # encoding head and the decoder.
            gen_params = list(srvp.encoding_head.parameters()) + list(srvp.decoder.parameters())
        else:
            # Otherwise optimise all parameters of the pipeline.
            gen_params = list(srvp.parameters())
        opt_gen = torch.optim.Adam(gen_params, lr=5e-4)
        srvps[ld] = train_gen_model_all(srvp, opt_gen, srvp_path, train_dl, num_epochs=12, device=device, with_cla=not scla_loads[ld])

In [None]:
from notebook_utils import denormalize, save_batch_images # for FID score computation using 3rd party lib
from torchvision.utils import make_grid, save_image

# Save the SRVP reconstructions in a folder for FID and collect the per-batch
# MSE of every pipeline.
srvp_fig = f"{srvp_dir}/figures"
os.makedirs(srvp_fig, exist_ok=True)

# The loss module is stateless, so it is created once instead of per batch.
mse_loss = torch.nn.MSELoss()
recons_errs = {k: [] for k in lds}
# This is evaluation only: disabling autograd avoids building a graph for
# every forward pass and lowers memory use.
with torch.no_grad():
    for bi, (x, y) in enumerate(train_dl):
        save_image(make_grid(denormalize("CelebA", x[:32])), f"{srvp_fig}/original_{bi}.png")
        save_batch_images(denormalize("CelebA", x), f"{srvp_fig}/original", bi)
        for ld in lds:
            x_out = srvps[ld](x.to(device))
            # The last output of the pipeline is taken as the reconstruction and
            # clamped to [-1, 1] (presumably the normalized image range).
            x_hat = torch.clamp(x_out[-1], min=-1, max=1).cpu()
            recons_errs[ld].append(mse_loss(x_hat, x).item())
            save_image(make_grid(denormalize("CelebA", x_hat[:32])), f"{srvp_fig}/ld{ld}_{bi}.png")
            save_batch_images(denormalize("CelebA", x_hat), f"{srvp_fig}/ld{ld}", bi)
        # Stop after batch index 21, i.e. 22 batches are processed in total.
        if bi > 20:
            break

# Per-pipeline statistics over the processed batches: min, mean, median, max.
print("Reconstruction errors")
for k, v in recons_errs.items():
    v = np.array(v)
    print(k, np.min(v), np.mean(v), np.median(v), np.max(v))

### Training a Real-vs-Fake discriminator for Generative model evaluation and Latent space traversal

In [None]:
from notebook_utils import real_fake_discriminator, get_classifier

# Location handed to the discriminator training helper.
disc_root = f"./models/discriminators_{input_shape[-1]}"
disc_path = f"{disc_root}/general_mid"
# NOTE(review): this creates a directory named "general_mid", while the
# traversal sample further down reads "general_mid.tar" next to it - confirm
# how real_fake_discriminator uses disc_path.
os.makedirs(disc_path, exist_ok=True)

# A "mid" classifier with two outputs serves as the discriminator; no
# checkpoint is passed to the factory here.
disc = get_classifier("mid", 32, 2, device)
real_fake_discriminator(train_dl, disc, srvps, device, disc_path)

### Sample code for finding a traversal between 2 endpoints guided by the above discriminator

In [None]:
from notebook_utils import disc_guided_interpolation

# Classifier for this sample: created without a checkpoint and moved to the device.
cla_type = "deep"
cla_path = f"./models/classifiers_{input_shape[-1]}/{cla_type}_{classes[0]}"
cla = get_classifier(cla_type, latent_dims, num_classes, device, None)
cla = cla.to(device)

# SRVP pipeline built around that classifier, with a checkpoint path supplied.
srvp_disc = "baseline"
srvp_path = f"./models/SRVP_{input_shape[-1]}/cla_{cla_type}/ld{latent_dims}.tar"
srvp_and_rest = build_srvp_pipeline(device, cla, input_shape, latent_dims, "deeper", srvp_path)
srvp, _ = srvp_and_rest

# Two random latent vectors act as the endpoints of the traversal.
endpts = torch.randn((2, latent_dims))
disc_ckpt = f"./models/discriminators_{input_shape[-1]}/general_mid.tar"
disc_guided_interpolation(endpts, disc, srvp.decoder, device, disc_ckpt)

#### For the simple EDC and SRVP pipelines trained with this notebook, use the verification comparison.py script (with the right constants updated at the top of the script and kept consistent with the 2nd cell of this notebook) to run the verification experiments for these pipelines. The notebook and the Python script are kept separate to make it easier to run the verification experiments on a server.