In [None]:
"""
CONNECT WITH DRIVE
"""

%pip install gdown
import gdown

from google.colab import drive


drive.mount('/content/drive')

url = 'https://drive.google.com/uc?id=1IrhQFVlV8aIQ9EV4J-pHhCR1tNk3noTN'
output = '/data/dataset_nonrandom_responses.pth'

gdown.download(url, output, quiet=False)

In [None]:
"""
CorrVAE training based on [TODO]
"""

import time

import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim

from model import ControlVAE
from encoders import EncoderControlVAE
from decoders import DecoderControlVAE

from utils.helpers import prepare_dataloader, plot_epoch, save_model
from utils.loss import get_losses

In [None]:
"""
ARGUMENTS
"""

# Use the GPU when one is available; ControlVAE, get_losses and plot_epoch all
# receive this device, and training on CPU is much slower.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fixed seed so weight initialisation and DataLoader shuffling are reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

# Directory passed to save_model; it lives on the mounted Google Drive.
results_dir = '/content/drive/MyDrive/models/'
# Must match the file name the download cell writes into the working directory.
data_source = 'dataset_nonrandom_responses.pth'

img_size = (1, 44, 44)   # presumably (channels, height, width) -- TODO confirm against the dataset
latent_dim = 8           # passed to ControlVAE
num_prop = 3             # passed to ControlVAE; only properties 0 and 1 are supervised in the training loop
lr = 1e-4                # initial Adam learning rate (decayed by the MultiStepLR scheduler)

batch_size = 64
epochs = 100

beta = 1                 # passed to get_losses
taus = 0.2               # initial value passed to model(); divided by 10 every 10 epochs in the training loop
idx_kl = 0               # step counter passed to get_losses; incremented once per batch
w_kl = 100               # initial weight passed to get_losses, which returns an updated value every step

In [None]:
"""
WANDB LOGGER
"""
try:
    import wandb
except:
    %pip install wandb
    import wandb


wandb.init(
    project="CorrVAE_64x64",

    config={
        # "data size": train_size,
        "batch_size": batch_size,
        "beta": beta,
        "taus": taus,
        "loss func": 'sigmoid',
        "lr": lr,
        "num prop": num_prop,
        "latent_dim": latent_dim,
    }
)

In [None]:
"""
DATA
"""

# map_location='cpu' lets a file that was saved from a GPU session load on a
# CPU-only runtime; batches are moved to `device` inside the training loop.
# NOTE(review): torch.load unpickles the file -- only load trusted data.
# Named dataset_dict (not `data`) so the training loop's batch variable does
# not silently shadow it.
dataset_dict = torch.load(data_source, map_location='cpu')
dataset = TensorDataset(dataset_dict['features'], dataset_dict['labels'])

train_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)


In [None]:
"""
MODEL
"""

# Reference the encoder/decoder classes directly; the previous
# eval("EncoderControlVAE") string lookup only added indirection.
encoder = EncoderControlVAE
decoder = DecoderControlVAE

model = ControlVAE(img_size, encoder, decoder, latent_dim, num_prop, device=device)
model = model.to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=lr)

# The learning rate is multiplied by 0.1 at each milestone; the scheduler is
# stepped once per epoch in the training loop, so with epochs = 100 the
# milestones >= 100 never affect a training epoch.
lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[10, 20, 30, 50, 70, 100, 130, 160, 200],
    gamma=0.1,
)

# Summed (not averaged) MSE, used for the property-prediction losses.
mse_loss = torch.nn.MSELoss(reduction="sum")

# NOTE(review): these history lists are never appended to by the training loop
# in this notebook; kept in case other cells rely on them.
recon_loss_prop_rec = []
recon_loss_rec = []
kl_loss_rec = []
pwwi_loss_rec = []
pwz_loss_rec = []
l1_loss_rec = []
mask_rec = []

In [None]:
start_epoch = time.time()

for epoch in range(epochs):
    # The value passed to the model is divided by 10 at the start of
    # epochs 9, 19, 29, ... (the global is updated, so re-running this cell
    # continues the decay, consistent with continuing training).
    if (epoch + 1) % 10 == 0:
        taus = taus * 0.1

    # plot_epoch receives the model at the end of every epoch and may leave it
    # in eval mode (TODO confirm); make sure every epoch trains in train mode.
    model.train()

    epoch_loss = []
    epoch_kl_loss = []
    epoch_rec_loss = []
    epoch_prop_loss = []
    epoch_pairwise_loss = []
    epoch_groupwise_loss = []
    epoch_l1_norm = []

    # `images` / `labels` rather than `data` / `label`, so the dataset dict
    # loaded in the DATA cell is not shadowed.
    for images, labels in tqdm(train_loader):
        idx_kl += 1
        images = images.to(device)
        labels = labels.float().to(device)
        # The last batch of an epoch is smaller than batch_size (the DataLoader
        # does not drop it), so normalise by the real number of samples.
        n_batch = images.size(0)

        (reconstruct, y_reconstruct), latent_dist_z, latent_dist_w, \
            latent_sample_z, latent_sample_w, w_mask, mask_ori = model(images, taus)

        latent_sample = torch.cat([latent_sample_w, latent_sample_z], dim=-1)
        latent_dist = (torch.cat([latent_dist_w[0], latent_dist_z[0]], dim=-1),
                       torch.cat([latent_dist_w[1], latent_dist_z[1]], dim=-1))

        ###### Reconstruction loss ######
        # NOTE(review): the extra /64 may be a leftover from 64x64 images
        # (img_size is 44x44); kept so the loss scale is unchanged -- confirm.
        rec_loss = F.mse_loss(reconstruct, images, reduction="sum") / 64
        rec_loss = rec_loss / n_batch

        ###### Property losses (summed MSE over the batch) ######
        # Only properties 0 and 1 are supervised; add 2 to the tuple to also
        # supervise the third property.
        rec_loss_prop = [mse_loss(y_reconstruct[:, i], labels[:, i]) for i in (0, 1)]
        rec_loss_prop_all = sum(rec_loss_prop)

        ###### Other losses ######
        kl_loss, pairwise_tc_loss, groupwise_tc_loss, l1norm, loss, w_kl = get_losses(
            latent_dist, latent_sample_w, latent_dist_w, beta,
            latent_sample_z, latent_dist_z, w_mask, device, idx_kl,
            rec_loss, rec_loss_prop_all, w_kl, len(train_loader.dataset)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss.append(float(loss))
        epoch_kl_loss.append(float(kl_loss))
        epoch_rec_loss.append(float(rec_loss))
        epoch_prop_loss.append(float(rec_loss_prop_all))
        epoch_pairwise_loss.append(float(pairwise_tc_loss))
        epoch_groupwise_loss.append(float(groupwise_tc_loss))
        epoch_l1_norm.append(float(l1norm))

    # Per-epoch means of the per-batch losses. "total_loss" previously had a
    # stray trailing space in its key.
    wandb.log({
        "total_loss": np.mean(epoch_loss),
        "KL_loss": np.mean(epoch_kl_loss),
        "rec_loss": np.mean(epoch_rec_loss),
        "rec_prop_loss": np.mean(epoch_prop_loss),
        "wwi_loss": np.mean(epoch_pairwise_loss),
        "wz_loss": np.mean(epoch_groupwise_loss),
        "l1_norm": np.mean(epoch_l1_norm),
    })

    save_model(model, optimizer, results_dir, epoch)
    plot_epoch(train_loader, model, device, taus, epoch, time.time() - start_epoch)

    lr_scheduler.step()
    start_epoch = time.time()