# MODEL TRAINING AND VALIDATION

Google published two contrastive learning papers which I based my contrastive loss based off of. The first paper was the "Unsupervised Contrastive Learning" paper: http://proceedings.mlr.press/v119/chen20j/chen20j.pdf

The second one was the "Supervised Contrastive Learning" paper: https://proceedings.neurips.cc/paper_files/paper/2020/file/d89a66c7c80a29b1bdbab0f2a1a94af8-Paper.pdf

In this notebook, I implemented the supervised contrastive loss and modified it to learn the positive class samples, whereas the supervised contrastive loss from the second paper was designed to train the model for all classes.

As this is a small example, I'll only train the encoder module where the contrastive loss is used. If you check out the papers, you'll see that the encoder module is trained first. Then, the module is frozen and a classifier network is trained with a typical loss function, such as cross entropy loss or NLLLoss. In the papers, the authors used a simple multilayer perceptron (MLP) network as the classifier network. Attaching an MLP downstream of the encoder network is pretty trivial. So, I skipped that part in this example. What we would like to see in this example is a declining training loss as an indicator of the successfull training.

In [1]:
%cd ..

/Users/mk/Projects/contrastive-learning


In [2]:
import os
import sys
sys.path.append(os.getcwd() + "src")
import pandas as pd
import warnings
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import trange
from torch.utils.data import DataLoader
from torch.nn import functional as F
from hydra import initialize, compose
from tqdm.notebook import tqdm
from src.models import supervisedCL, LSTM, Dataset
warnings.filterwarnings("ignore", category=UserWarning, message='TypedStorage is depreciated') # Silence TypedStorage warnings



Import parameters

In [3]:
with initialize(config_path="../", version_base=None):
    cfg = compose(config_name="params.yaml")

Some reproducibility parameters.

In [4]:
torch.manual_seed(cfg.common.randomseed)
np.random.seed(cfg.common.randomseed)
torch.cuda.manual_seed(cfg.common.randomseed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.is_available
torch.backends.cudnn.benchmark = False #selects fasterst conv algo
torch.backends.cudnn.deterministic = True # for reproducibility

Load data

In [5]:
df_train = pd.read_parquet("data/03_processed/train.pq")

In [6]:
df_train

Unnamed: 0,ecfp,y
0,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0
1,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0
2,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0
3,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0
4,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0
...,...,...
6834,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0
6835,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0
6836,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1
6837,"[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",0


Create PyTorch dataset.

In [7]:
training_set = Dataset(torch.tensor(df_train['ecfp'], dtype=torch.float),
                       torch.tensor(df_train['y'], dtype=torch.int64))

  training_set = Dataset(torch.tensor(df_train['ecfp'], dtype=torch.float),


Create PyTorch dataloader.

In [8]:
train_dataloader = DataLoader(training_set,
                              batch_size=cfg.train.bsz,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=False)

Initialize the encoder network, the loss criterion and optimizer.

In [9]:
encoder = LSTM(cfg.train.insz,
               cfg.train.outsz,
               cfg.train.hiddensz,
               cfg.train.nlayer)
criterion = supervisedCL(temp=cfg.train.temp)

if torch.cuda.is_available():
    encoder = encoder.to(cfg.train.device)
    criterion = criterion.to(cfg.train.device)

encoder_optimizer = torch.optim.Adam(encoder.parameters(), lr=cfg.train.lr)

Training loop

In [10]:
pbar = trange(cfg.train.epoch, desc=f"Train")

batch_num = 0
running_loss = 0.0

for epoch_num in pbar:
    epoch_loss = 0.0

    encoder.train()
    encoder_optimizer.zero_grad()

    for x, y in train_dataloader:
        elem, ct = torch.unique(y, return_counts=True)
        batch_num += 1
        x = x.to(cfg.train.device)
        x = torch.unsqueeze(x,1)
        y = y.to(cfg.train.device)
        learned_embeddings = encoder(x)
        learned_embeddings = F.normalize(learned_embeddings, dim=1)
        loss = criterion(learned_embeddings, y)
        loss.backward()
        encoder_optimizer.step()
        epoch_loss += loss.item()
        running_loss += loss.item()

    pbar.set_description(f"Epoch {epoch_num} - Loss {loss.item():.4f}")

final_loss = running_loss / batch_num
print(f"Loss: {final_loss:.5f}")


Epoch 19 - Loss 0.1627: 100%|██████████| 20/20 [00:16<00:00,  1.20it/s]

Loss: 0.51561





The loss is declining, indicating the contrastive loss function was coded successfully and the encoder module is training properly. That said, users should be careful with overtraining the encoder module. A proper optimization is necessary for acceptable performance on the validation set.