In [1]:
import datetime
import os.path as osp
import platform

import matplotlib
import mplhep
import numpy as np
import torch
import torch_geometric
from pyg_ssl.args import parse_args
from pyg_ssl.mlpf import MLPF
from pyg_ssl.training_mlpf import training_loop_mlpf
from pyg_ssl.training_VICReg import training_loop_VICReg
from pyg_ssl.utils import CLUSTERS_X, TRACKS_X, data_split, load_VICReg, save_MLPF, save_VICReg
from pyg_ssl.VICReg import DECODER, ENCODER

matplotlib.use("Agg")
mplhep.style.use(mplhep.styles.CMS)

# Ignore divide by 0 errors
np.seterr(divide="ignore", invalid="ignore")

# define the global base device
if torch.cuda.device_count():
    device = torch.device("cuda:0")
    print(f"Will use {torch.cuda.get_device_name(device)}")
else:
    device = "cpu"
    print("Will use cpu")

Will use NVIDIA GeForce GTX 1080 Ti


In [14]:
import random
def data_split(dataset, data_split_mode):
    """
    Depending on the data split mode chosen, the function returns different data splits.

    Choices for data_split_mode
        1. `quick`: uses only 1 datafile of each sample for quick debugging. Nothing interesting there.
        2. `domain_adaptation`: uses QCD samples to train/validate VICReg and TTbar samples to train/validate MLPF.
        3. `mix`: uses a mix of both QCD and TTbar samples to train/validate VICReg and MLPF.

    Returns (each as a list)
        data_VICReg_train, data_VICReg_valid, data_mlpf_train, data_mlpf_valid, data_test_qcd, data_test_ttbar

    """
    print(f"Will use data split mode `{data_split_mode}`")

    if data_split_mode == "quick":
        data_qcd = torch.load(f"{dataset}/p8_ee_qcd_ecm365/processed/data_0.pt")
        data_ttbar = torch.load(f"{dataset}/p8_ee_tt_ecm365/processed/data_0.pt")

        data_test_qcd = data_qcd[: round(0.1 * len(data_qcd))]
        data_test_ttbar = data_ttbar[: round(0.1 * len(data_ttbar))]

        # label remaining data as `rem`
        rem_qcd = data_qcd[round(0.1 * len(data_qcd)) :]
        rem_ttbar = data_ttbar[round(0.1 * len(data_qcd)) :]
        
        data_VICReg = rem_qcd[: round(0.8 * len(rem_qcd))] + rem_ttbar[: round(0.8 * len(rem_ttbar))]
        data_mlpf = rem_qcd[round(0.8 * len(rem_qcd)) :] + rem_ttbar[round(0.8 * len(rem_ttbar)) :]

        # shuffle the samples after mixing (not super necessary since the DataLoaders will shuffle anyway)
        random.shuffle(data_VICReg)
        random.shuffle(data_mlpf)

        data_VICReg_train = data_VICReg[: round(0.9 * len(data_VICReg))]
        data_VICReg_valid = data_VICReg[round(0.9 * len(data_VICReg)) :]

        data_mlpf_train = data_mlpf[: round(0.9 * len(data_mlpf))]
        data_mlpf_valid = data_mlpf[round(0.9 * len(data_mlpf)) :]

    else:  # actual meaningful data splits
        # load the qcd and ttbar samples seperately
        qcd_files = glob.glob(f"{dataset}/p8_ee_qcd_ecm365/processed/*")
        ttbar_files = glob.glob(f"{dataset}/p8_ee_tt_ecm365/processed/*")

        data_qcd = []
        for file in list(qcd_files):
            data_qcd += torch.load(f"{file}")

        data_ttbar = []
        for file in list(ttbar_files):
            data_ttbar += torch.load(f"{file}")

        # use 10% of each sample for testing
        data_test_qcd = data_qcd[: round(0.1 * len(data_qcd))]
        data_test_ttbar = data_ttbar[: round(0.1 * len(data_ttbar))]

        # label remaining data as `rem`
        rem_qcd = data_qcd[round(0.1 * len(data_qcd)) :]
        rem_ttbar = data_ttbar[round(0.1 * len(data_qcd)) :]

        if data_split_mode == "domain_adaptation":
            """
            use QCD samples for VICReg with an 80-20 split.
            use TTbar samples for MLPF with an 80-20 split.
            """
            data_VICReg_train = rem_qcd[: round(0.8 * len(rem_qcd))]
            data_VICReg_valid = rem_qcd[round(0.8 * len(rem_qcd)) :]

            data_mlpf_train = rem_ttbar[: round(0.8 * len(rem_ttbar))]
            data_mlpf_valid = rem_ttbar[round(0.8 * len(rem_ttbar)) :]

        elif data_split_mode == "mix":
            """
            use (80% of QCD + 80% of TTbar) samples for VICReg with a 90-10 split.
            use (20% of QCD + 20% of TTbar) samples for MLPF with a 90-10 split.
            """
            data_VICReg = rem_qcd[: round(0.8 * len(rem_qcd))] + rem_ttbar[: round(0.8 * len(rem_ttbar))]
            data_mlpf = rem_qcd[round(0.8 * len(rem_qcd)) :] + rem_ttbar[round(0.8 * len(rem_ttbar)) :]

            # shuffle the samples after mixing (not super necessary since the DataLoaders will shuffle anyway)
            random.shuffle(data_VICReg)
            random.shuffle(data_mlpf)

            data_VICReg_train = data_VICReg[: round(0.9 * len(data_VICReg))]
            data_VICReg_valid = data_VICReg[round(0.9 * len(data_VICReg)) :]

            data_mlpf_train = data_mlpf[: round(0.9 * len(data_mlpf))]
            data_mlpf_valid = data_mlpf[round(0.9 * len(data_mlpf)) :]

    print(f"Will use {len(data_VICReg_train)} events to train VICReg")
    print(f"Will use {len(data_VICReg_valid)} events to validate VICReg")
    print(f"Will use {len(data_mlpf_train)} events to train MLPF")
    print(f"Will use {len(data_mlpf_valid)} events to validate MLPF")

    return data_VICReg_train, data_VICReg_valid, data_mlpf_train, data_mlpf_valid, data_test_qcd, data_test_ttbar

In [15]:
# load the clic dataset
data_VICReg_train, data_VICReg_valid, data_mlpf_train, data_mlpf_valid, data_test_qcd, data_test_ttbar = data_split(
    "/pfclicvol/data/clic_edm4hep", "quick"
)

Will use data split mode `quick`
Will use 25843 events to train VICReg
Will use 2871 events to validate VICReg
Will use 6461 events to train MLPF
Will use 718 events to validate MLPF


In [16]:
embedding_dim_VICReg = 256
width_encoder = 256 
num_convs = 3
expand_dim = 256
width_decoder = 256

encoder_model_kwargs = {
    "embedding_dim": embedding_dim_VICReg,
    "width": width_encoder,
    "num_convs": num_convs,
    "space_dim": 4,
    "propagate_dim": 22,
    "k": 32,
}

decoder_model_kwargs = {
    "input_dim": embedding_dim_VICReg,
    "output_dim": expand_dim,
    "width": width_decoder,
}

encoder = ENCODER(**encoder_model_kwargs).to(device)
decoder = DECODER(**decoder_model_kwargs).to(device)

print("Encoder", encoder)
print("Decoder", decoder)

Encoder ENCODER(
  (nn1): Sequential(
    (0): Linear(in_features=14, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=256, out_features=256, bias=True)
  )
  (nn2): Sequential(
    (0): Linear(in_features=15, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=256, out_features=256, bias=True)
  )
  (conv): ModuleList(
    (0): GravNetConv(256, 256, k=32)
    (1): GravNetConv(256, 256, k=32)
    (2): GravNetConv(256, 256, k=32)
  )
)
Decoder DECODER(
  (expander): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features

In [17]:
import json
import pickle as pkl
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool

from pyg_ssl.utils import distinguish_PFelements

# VICReg loss function
def criterion(x, y, device="cuda", lmbd=25, epsilon=1e-3):
    bs = x.size(0)
    emb = x.size(1)

    std_x = torch.sqrt(x.var(dim=0) + epsilon)
    std_y = torch.sqrt(y.var(dim=0) + epsilon)
    var_loss = torch.mean(F.relu(1 - std_x)) + torch.mean(F.relu(1 - std_y))

    invar_loss = F.mse_loss(x, y)

    xNorm = (x - x.mean(0)) / x.std(0)
    yNorm = (y - y.mean(0)) / y.std(0)
    crossCorMat = (xNorm.T @ yNorm) / bs
    cross_loss = (crossCorMat * lmbd - torch.eye(emb, device=torch.device(device)) * lmbd).pow(2).sum()

    return var_loss, invar_loss, cross_loss


@torch.no_grad()
def validation_run(
    device,
    encoder,
    decoder,
    train_loader,
    valid_loader,
    lmbd,
    u,
    v,
):
    with torch.no_grad():
        optimizer = None
        ret = train(
            device,
            encoder,
            decoder,
            train_loader,
            valid_loader,
            optimizer,
            lmbd,
            u,
            v,
        )
    return ret


def train(
    device,
    encoder,
    decoder,
    train_loader,
    valid_loader,
    optimizer,
    lmbd,
    u,
    v,
):

    is_train = not (optimizer is None)

    if is_train:
        print("---->Initiating a training run")
        encoder.train()
        decoder.train()
        loader = train_loader
    else:
        print("---->Initiating a validation run")
        encoder.eval()
        decoder.eval()
        loader = valid_loader

    # initialize loss counters
    losses = 0.0
    var_losses, invar_losses, cross_losses = 0.0, 0.0, 0.0

    for i, batch in tqdm.tqdm(enumerate(loader), total=len(loader)):

        # make transformation
        tracks, clusters = distinguish_PFelements(batch.to(device))

        # ENCODE
        embedding_tracks, embedding_clusters = encoder(tracks, clusters)
        # POOLING
        pooled_tracks = global_mean_pool(embedding_tracks, tracks.batch)
        pooled_clusters = global_mean_pool(embedding_clusters, clusters.batch)
        # DECODE
        out_tracks, out_clusters = decoder(pooled_tracks, pooled_clusters)

        # compute loss
        var_loss, invar_loss, cross_loss = criterion(out_tracks, out_clusters, device, lmbd)
        loss = u * var_loss + v * invar_loss + cross_loss

        # update parameters
        if is_train:
            for param in encoder.parameters():
                param.grad = None
            for param in decoder.parameters():
                param.grad = None
            loss.backward()
            optimizer.step()

        losses += loss.detach()
        var_losses += var_loss.detach()
        invar_losses += invar_loss.detach()
        cross_losses += cross_loss.detach()

    losses = losses.cpu().item() / (len(loader))
    var_losses = var_losses.cpu().item() / (len(loader))
    invar_losses = invar_losses.cpu().item() / (len(loader))
    cross_losses = cross_losses.cpu().item() / (len(loader))

    return losses, var_losses, invar_losses, cross_losses

In [49]:
batch_size_VICReg = 1000
train_loader = torch_geometric.loader.DataLoader(data_VICReg_train, batch_size_VICReg)
valid_loader = torch_geometric.loader.DataLoader(data_VICReg_valid, batch_size_VICReg)

In [None]:
import tqdm

lr = 1e-4
n_epochs = 30
lmbd = 0.1
u = 0.01
v = 0.01
patience = 50

optimizer = torch.optim.SGD(list(encoder.parameters()) + list(decoder.parameters()), lr=lr)

torch.cuda.empty_cache()

t0_initial = time.time()

losses_train_tot, losses_train_var, losses_train_invar, losses_train_cross = [], [], [], []
losses_valid_tot, losses_valid_var, losses_valid_invar, losses_valid_cross = [], [], [], []

best_val_loss_tot, best_val_loss_var, best_val_loss_invar, best_val_cross = 99999.9, 99999.9, 99999.9, 99999.9
stale_epochs = 0

for epoch in range(n_epochs):
    t0 = time.time()

    if stale_epochs > patience:
        print("breaking due to stale epochs")
        break

    # training step
    losses_t_tot, losses_t_var, losses_t_invar, losses_t_cross = train(
        device,
        encoder,
        decoder,
        train_loader,
        valid_loader,
        optimizer,
        lmbd,
        u,
        v,
    )

    losses_train_tot.append(losses_t_tot)
    losses_train_var.append(losses_t_var)
    losses_train_invar.append(losses_t_invar)
    losses_train_cross.append(losses_t_cross)

    # validation step
    losses_v_tot, losses_v_var, losses_v_invar, losses_v_cross = validation_run(
        device,
        encoder,
        decoder,
        train_loader,
        valid_loader,
        lmbd,
        u,
        v,
    )

    losses_valid_tot.append(losses_v_tot)
    losses_valid_var.append(losses_v_var)
    losses_valid_invar.append(losses_v_invar)
    losses_valid_cross.append(losses_v_cross)

    # save the lowest value of each component of the loss to print it on the legend of the loss plots
    if losses_v_var < best_val_loss_var:
        best_val_loss_var = losses_v_var
        best_train_loss_var = losses_t_var

    if losses_v_invar < best_val_loss_invar:
        best_val_loss_invar = losses_v_invar
        best_train_loss_invar = losses_t_invar

    if losses_v_cross < best_val_cross:
        best_val_loss_cross = losses_v_cross
        best_train_loss_cross = losses_t_cross

    # early-stopping
    if losses_v_tot < best_val_loss_tot:
        best_val_loss_tot = losses_v_tot
        best_train_loss_tot = losses_t_tot

        stale_epochs = 0

    else:
        stale_epochs += 1

    t1 = time.time()

    epochs_remaining = n_epochs - (epoch + 1)
    time_per_epoch = (t1 - t0_initial) / (epoch + 1)
    eta = epochs_remaining * time_per_epoch / 60

    print(
        f"epoch={epoch + 1} / {n_epochs} "
        + f"train_loss={round(losses_train_tot[epoch], 4)} "
        + f"valid_loss={round(losses_valid_tot[epoch], 4)} "
        + f"stale={stale_epochs} "
        + f"time={round((t1-t0)/60, 2)}m "
        + f"eta={round(eta, 1)}m"
    )

print("----------------------------------------------------------")
print(f"Done with training. Total training time is {round((time.time() - t0_initial)/60,3)}min")

---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.13it/s]


epoch=1 / 30 train_loss=3.2089 valid_loss=3.2996 stale=0 time=0.29m eta=8.3m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.15it/s]


epoch=2 / 30 train_loss=3.1654 valid_loss=3.2897 stale=0 time=0.28m eta=8.0m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.16it/s]


epoch=3 / 30 train_loss=3.1164 valid_loss=3.2899 stale=1 time=0.28m eta=7.7m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.18it/s]


epoch=4 / 30 train_loss=3.1001 valid_loss=3.2833 stale=0 time=0.28m eta=7.4m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.15it/s]


epoch=5 / 30 train_loss=3.0678 valid_loss=3.2753 stale=0 time=0.28m eta=7.1m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.17it/s]


epoch=6 / 30 train_loss=3.2201 valid_loss=3.2975 stale=1 time=0.28m eta=6.8m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.16it/s]


epoch=7 / 30 train_loss=3.2181 valid_loss=3.2711 stale=0 time=0.29m eta=6.6m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.14it/s]


epoch=8 / 30 train_loss=3.0841 valid_loss=3.2647 stale=0 time=0.29m eta=6.3m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.18it/s]


epoch=9 / 30 train_loss=3.0369 valid_loss=3.2759 stale=1 time=0.29m eta=6.0m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.18it/s]


epoch=10 / 30 train_loss=3.0331 valid_loss=3.26 stale=0 time=0.29m eta=5.7m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.16it/s]


epoch=11 / 30 train_loss=2.9995 valid_loss=3.273 stale=1 time=0.29m eta=5.4m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.22it/s]


epoch=12 / 30 train_loss=3.0125 valid_loss=3.2483 stale=0 time=0.28m eta=5.1m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.17it/s]


epoch=13 / 30 train_loss=3.0231 valid_loss=3.2429 stale=0 time=0.29m eta=4.8m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.67it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.19it/s]


epoch=14 / 30 train_loss=2.9908 valid_loss=3.2481 stale=1 time=0.28m eta=4.6m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.17it/s]


epoch=15 / 30 train_loss=2.9585 valid_loss=3.2348 stale=0 time=0.29m eta=4.3m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.21it/s]


epoch=16 / 30 train_loss=2.9431 valid_loss=3.2435 stale=1 time=0.29m eta=4.0m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.18it/s]


epoch=17 / 30 train_loss=2.9832 valid_loss=3.2428 stale=2 time=0.29m eta=3.7m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.19it/s]


epoch=18 / 30 train_loss=2.97 valid_loss=3.2299 stale=0 time=0.29m eta=3.4m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.22it/s]


epoch=19 / 30 train_loss=2.9445 valid_loss=3.2382 stale=1 time=0.28m eta=3.1m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.17it/s]


epoch=20 / 30 train_loss=2.9457 valid_loss=3.2386 stale=2 time=0.29m eta=2.9m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.19it/s]


epoch=21 / 30 train_loss=2.9281 valid_loss=3.228 stale=0 time=0.29m eta=2.6m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.17it/s]


epoch=22 / 30 train_loss=2.9114 valid_loss=3.245 stale=1 time=0.29m eta=2.3m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.20it/s]


epoch=23 / 30 train_loss=2.9046 valid_loss=3.2469 stale=2 time=0.29m eta=2.0m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.21it/s]


epoch=24 / 30 train_loss=2.8903 valid_loss=3.2432 stale=3 time=0.29m eta=1.7m
---->Initiating a training run


100%|██████████████████████████████████████████████████████████████████████████████████████| 26/26 [00:15<00:00,  1.66it/s]


---->Initiating a validation run


100%|████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.16it/s]


epoch=25 / 30 train_loss=2.9265 valid_loss=3.25 stale=4 time=0.29m eta=1.4m
---->Initiating a training run


 12%|██████████                                                                             | 3/26 [00:01<00:13,  1.69it/s]

In [None]:
%matplotlib inline
plt.rcParams.update({"font.size": 15})

# make total loss plot
fig, ax = plt.subplots(2,2, figsize=(10,10))
ax[0,0].plot(range(len(losses_train_tot)), losses_train_tot, label="training ({:.2f})".format(best_train_loss_tot))
ax[0,0].plot(range(len(losses_valid_tot)), losses_valid_tot, label="validation ({:.2f})".format(best_val_loss_tot))
ax[0,0].set_xlabel("Epochs")
ax[0,0].set_ylabel("Total Loss")
ax[0,0].legend(title="VICReg", loc="best", title_fontsize=15, fontsize=10)

# make variance loss plot
ax[1,0].plot(range(len(losses_train_var)), losses_train_var, label="training ({:.2f})".format(best_train_loss_var))
ax[1,0].plot(range(len(losses_valid_var)), losses_valid_var, label="validation ({:.2f})".format(best_val_loss_var))
ax[1,0].set_xlabel("Epochs")
ax[1,0].set_ylabel("Variance Loss")
ax[1,0].legend(title="VICReg", loc="best", title_fontsize=15, fontsize=10)

# make invariance loss plot
ax[0,1].plot(range(len(losses_train_invar)), losses_train_invar, label="training ({:.2f})".format(best_train_loss_invar))
ax[0,1].plot(range(len(losses_valid_invar)), losses_valid_invar, label="validation ({:.2f})".format(best_val_loss_invar))
ax[0,1].set_xlabel("Epochs")
ax[0,1].set_ylabel("Invariance Loss")
ax[0,1].legend(title="VICReg", loc="best", title_fontsize=15, fontsize=10)

# make covariance loss plot
ax[1,1].plot(range(len(losses_train_cross)), losses_train_cross, label="training ({:.2f})".format(best_train_loss_cross))
ax[1,1].plot(range(len(losses_valid_cross)), losses_valid_cross, label="validation ({:.2f})".format(best_val_loss_cross))
ax[1,1].set_xlabel("Epochs")
ax[1,1].set_ylabel("Covariance Loss")
ax[1,1].legend(title="VICReg", loc="best", title_fontsize=15, fontsize=10)