In [1]:
"""
CONNECT WITH DRIVE
"""

%pip install gdown
import gdown
import os
import shutil
from google.colab import drive

# Drive
drive.mount('/content/drive')
# url = 'https://drive.google.com/uc?id=1AvXhv8p9ZCUZ2dzbYoA5CP8VT7a8b6Gf' # full
url = 'https://drive.google.com/uc?id=1c4n7GXPn0OTgUsPJwuWWnIignXXJDm6W' # train
url_test = 'https://drive.google.com/uc?id=1PNLYjJ9lGwn2bie2EkVaF1Qw5IcY7rov' # test


output = 'dataset_nonrandom_responses.pth'

gdown.download(url, output, quiet=False)
gdown.download(url_test, f'{output}_test', quiet=False)

# Github
!git clone https://github.com/karolrogozinski/SCAE.git

source = '/content/cern_alice_fast_sim_corrvae'
destination = '/content'

for file in os.listdir(source):
    source_path = os.path.join(source, file)
    dest_path = os.path.join(destination, file)

    try:
        shutil.copy(source_path, dest_path)
    except IsADirectoryError:
        shutil.copytree(source_path, dest_path)


[33mDEPRECATION: nb-black 1.0.7 has a non-standard dependency specifier black>='19.3'; python_version >= "3.6". pip 24.0 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of nb-black or contact the author to suggest that they release a version with a conforming dependency specifiers. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0m^C
Note: you may need to restart the kernel to use updated packages.


ModuleNotFoundError: No module named 'google.colab'

In [2]:
"""
CorrVAE training based on [TODO]
"""

import time

import numpy as np
from tqdm import tqdm

from geomloss import SamplesLoss

import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim

from src.model import ControlVAE
from src.encoders import EncoderControlVAE
from src.decoders import DecoderControlVAE
from src.noise_generator import NoiseGenerator

from utils.helpers import plot_epoch, save_model
from utils.loss import get_losses


In [3]:
"""
PROPERTIES
All of them are normalized

0. - x coordinate of max pixel
1. - y coordinate of max pixel
2. - x coordinate of mass center
3. - y coordinate of mass center
4. - number of non zero pixels
5. - categorized number of non zero pixels
6. - sum of pixels over the threshold (DO NOT USE)
7. - sum of pixels
8. - max pixel value
"""

# Indices (into the numbered property list above) of the properties used in
# this run. len(properties) becomes `num_prop`, and each index selects one
# label column as a regression target in the training loop.
properties = [2, 3]


In [None]:
"""
ARGUMENTS
"""

# --- runtime / paths ---------------------------------------------------------
device = 'cpu'

# Only referenced by the (currently commented-out) save_model call.
results_dir = '/content/drive/MyDrive/models/'
data_source = './data/dataset_nonrandom_responses_train.pth'
data_source_test = './data/dataset_nonrandom_responses_test.pth'

# --- model dimensions --------------------------------------------------------
img_size = (1, 44, 44)
latent_dim = latent_dim_prop = latent_dim_cond = 8
num_prop = len(properties)

# Not referenced by any of the visible cells.
cond_dim = 9
hid_channels = 64

# --- optimisation ------------------------------------------------------------
lr = 1e-4          # handed to optim.Adam
batch_size = 128   # used by both DataLoaders
epochs = 250       # bound of the training loop's range()
samples = 10_000   # number of leading training rows kept

# --- loss schedule state (mutated by the training loop) ----------------------
beta = 1
taus = 0.2     # multiplied by 0.1 whenever (epoch + 1) % 10 == 0
idx_kl = 0     # incremented once per training batch
w_kl = 100     # passed to, and overwritten by the return value of, get_losses

# Loss weights passed to get_losses, in this fixed order.
lambdas = [
    1_000_000,  # reconstruction_loss
    1,          # pairwise_tc_loss
    1_000_000,  # reconstruction_prop_loss
    1,          # groupwise_wz_loss
    1,          # kl_loss
    1,          # sinkhorn_z
    1,          # sinkhorn_w
]


In [None]:
"""
WANDB LOGGER
"""
try:
    import wandb
except:
    %pip install wandb
    import wandb


In [None]:
import os
from getpass import getpass

# Never store the API key in the notebook: take it from the WANDB_API_KEY
# environment variable when set, otherwise prompt for it without echoing.
# (The previous placeholder `token = ...` passed the Ellipsis object as key.)
token = os.environ.get('WANDB_API_KEY') or getpass('W&B API key: ')
wandb.login(key=token, relogin=True)


In [None]:
# Hyper-parameters recorded with the W&B run. `beta` and `taus` are not
# logged (their entries were commented out in the previous version).
wandb_config = {
    "data size": samples,
    "batch_size": batch_size,
    "loss func": 'sigmoid',
    "lr": lr,
    "num prop": num_prop,
    "latent_dim": latent_dim,
    "latent_dim_prop": latent_dim_prop,
}

wandb.init(project="SCAE", config=wandb_config)


In [None]:
"""
DATA
"""
# NOTE(security): torch.load unpickles the file, and these files are fetched
# from a remote URL. Only load files from a source you trust.
#
# map_location='cpu' lets the files load on a CPU-only runtime even if the
# tensors were saved from a GPU; batches are moved to `device` later, inside
# the training loop.
data = torch.load(data_source, map_location='cpu')

# Both files are indexed with the keys 'features' and 'labels'.
# Only the first `samples` training rows are used.
dataset = TensorDataset(data['features'][:samples], data['labels'][:samples])

train_loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

data_test = torch.load(data_source_test, map_location='cpu')
dataset_test = TensorDataset(data_test['features'], data_test['labels'])

test_loader = DataLoader(dataset_test, shuffle=True, batch_size=batch_size)


In [None]:
"""
MODEL
"""

# The imported objects themselves (not instances) are handed to ControlVAE.
# They are referenced directly; eval() on a constant string adds nothing.
encoder = EncoderControlVAE
decoder = DecoderControlVAE
noise_generator_z = NoiseGenerator
noise_generator_w = NoiseGenerator

model = ControlVAE(img_size, encoder, decoder, noise_generator_z, noise_generator_w,
                   latent_dim, latent_dim_prop, num_prop, device=device)
model = model.to(device)
model.train()

optimizer = optim.Adam(model.parameters(), lr=lr)

# Scales the learning rate by `gamma` at each milestone; milestones are
# counted in lr_scheduler.step() calls, which the training loop issues once
# per epoch.
lr_scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[50, 100, 150, 200],
    gamma=0.1,
)

# Summed (not averaged) squared error; the training loop applies its own
# normalisation.
mse_loss = torch.nn.MSELoss(reduction="sum")
sinkhorn_loss = SamplesLoss("sinkhorn", blur=0.05, scaling=0.95, diameter=0.01, debias=True)


In [None]:
# Training loop.
#
# NOTE: this cell mutates the notebook-level `taus`, `idx_kl` and `w_kl`, and
# rebinds `data` to the last training batch. Re-run the ARGUMENTS and DATA
# cells before re-running this one to start from a clean state.
start_epoch = time.time()  # wall-clock start of the current epoch

# `epochs + 1` so that exactly `epochs` epochs are run (range(1, epochs)
# stopped one epoch short).
for epoch in range(1, epochs + 1):
    if (epoch + 1) % 10 == 0:
        taus = taus * 0.1

    # plot_epoch is defined elsewhere; in case it switches the model to eval
    # mode, make sure every epoch trains in training mode.
    model.train()

    # Per-batch values of the current epoch; averaged and logged below.
    epoch_loss = []
    epoch_rec_loss = []
    epoch_kl_loss = []
    epoch_rec_prop_loss = []
    epoch_sinkhorn_z = []
    epoch_sinkhorn_w = []
    epoch_pairwise_loss = []
    epoch_groupwise_loss = []
    epoch_l1_norm = []

    for (data, label) in tqdm(train_loader):
        idx_kl += 1
        data = data.to(device)
        n_batch = data.shape[0]

        reconstruct, y_reconstruct, z, w, z_, w_, w_mask = model(data, taus)

        # Repeat the model's averaged mean/std once per sample of the batch.
        latent_dist_z = (model.z_mean_avg.repeat([n_batch, 1]), model.z_std_avg.repeat([n_batch, 1]))
        latent_dist_w = (model.w_mean_avg.repeat([n_batch, 1]), model.w_std_avg.repeat([n_batch, 1]))

        # ---- Losses ----

        ###### Reconstruction loss ######
        # Summed squared error scaled by a hard-coded 64 and by the nominal
        # batch size (also for a smaller final batch).
        rec_loss = mse_loss(reconstruct, data) / 64
        rec_loss = rec_loss / batch_size

        ###### Reconstruction prop loss ######
        # One summed squared error per selected property: output column i of
        # the model is compared with label column properties[i].
        rec_loss_prop = []
        for i, prop in enumerate(properties):
            rec_loss_prop.append(mse_loss(y_reconstruct[:, i], label[:, prop].float().to(device)))

        rec_loss_prop_all = sum(rec_loss_prop)

        ###### Sinkhorn losses ######
        sinkhorn_z = sinkhorn_loss(z, z_)
        sinkhorn_w = sinkhorn_loss(w, w_)

        # get_losses combines everything into `loss` and returns an updated
        # `w_kl`, which is fed back in on the next batch.
        kl_loss, pairwise_tc_loss, groupwise_tc_loss, l1norm, loss, w_kl = \
            get_losses(w, latent_dist_w, beta,
                       z, latent_dist_z, w_mask, device, idx_kl,
                       rec_loss, rec_loss_prop_all, w_kl, len(train_loader.dataset),
                       lambdas, sinkhorn_z, sinkhorn_w)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss.append(float(loss))
        epoch_kl_loss.append(float(kl_loss))
        epoch_rec_loss.append(float(rec_loss))
        epoch_rec_prop_loss.append(float(rec_loss_prop_all))
        epoch_sinkhorn_z.append(float(sinkhorn_z))
        epoch_sinkhorn_w.append(float(sinkhorn_w))
        epoch_pairwise_loss.append(float(pairwise_tc_loss))
        epoch_groupwise_loss.append(float(groupwise_tc_loss))
        epoch_l1_norm.append(float(l1norm))

    # NOTE(review): the key "total_loss " carries a trailing space. It is kept
    # unchanged so existing W&B charts keep receiving data - confirm whether
    # it should be renamed.
    wandb.log(
        {"total_loss ": np.mean(epoch_loss),
         "KL_loss": np.mean(epoch_kl_loss),
         "rec_loss": np.mean(epoch_rec_loss),
         "rec_prop_loss": np.mean(epoch_rec_prop_loss),
         "sinkhorn_z": np.mean(epoch_sinkhorn_z),
         "sinkhorn_w": np.mean(epoch_sinkhorn_w),
         "wwi_loss": np.mean(epoch_pairwise_loss),
         "wz_loss": np.mean(epoch_groupwise_loss),
         "l1_norm": np.mean(epoch_l1_norm),
         })

    # save_model(model, optimizer, results_dir, epoch)
    plot_epoch(test_loader, model, device, taus, epoch, time.time() - start_epoch)

    lr_scheduler.step()
    start_epoch = time.time()
