In [1]:
import pandas as pd
import numpy as np
import os
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, average_precision_score
import gc
from tqdm import tqdm
import time
from collections import defaultdict
import copy
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import autocast, GradScaler
from transformers import AutoModel, AutoTokenizer


import pdb


# Defining global parameters

In [2]:
def get_available_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"
    
CONFIG = {
    "seed": 42,
    "epochs": 12, # 42, ~MAX 20 hours of training
    "train_batch_size": 16,
    "valid_batch_size": 64,
    "learning_rate": 5e-5,
    "scheduler": 'CosineAnnealingLR',
    "min_lr": 5e-7,
    "T_max": 12,
    "weight_decay": 1e-6,
    "fold" : 0,
    "n_fold": 5,
    "n_accumulate": 1,
    "device": get_available_device(),
}

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(seed=CONFIG['seed'])

# Read the already preprocessed and split data

In [4]:
train_df = pd.read_pickle('../data/train_split_remove_duplicates_all_embeddings.pkl')
valid_df = pd.read_pickle('../data/valid_split_remove_duplicates_all_embeddings.pkl')

## Wrapping in Pytorch Dataset

In [5]:
class EnvedaDataset(Dataset):
    def __init__(self, dataframe, labels = ['unable_to_assess', 'close_match', 
                                            'near_exact_match', 'exact_match']):
        """
        Args:
            dataframe (pd.DataFrame): A DataFrame containing 'ground_truth_embeddings', 
                                       'predicted_embeddings', and output columns.
        """
        self.dataframe = dataframe
        
        # Convert embeddings to tensors
        self.ground_truth_embeddings = torch.tensor(dataframe['ground_truth_embeddings'].tolist(), dtype=torch.float32)
        self.predicted_embeddings = torch.tensor(dataframe['predicted_embeddings'].tolist(), dtype=torch.float32)
        
        # Get molecular smiles
        self.ground_truth_smiles = dataframe['ground_truth_smiles'].to_list()
        self.predicted_smiles = dataframe['predicted_smiles'].to_list()

        # Convert labels to tensor
        self.labels = torch.tensor(dataframe[labels].values, dtype=torch.float32)

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Generates one sample of data."""
        return self.ground_truth_embeddings[idx].squeeze(0), self.predicted_embeddings[idx].squeeze(0), self.labels[idx]
        # return self.ground_truth_smiles[idx], self.predicted_smiles[idx], self.labels[idx]

In [6]:
#---
trainset_unable = EnvedaDataset(dataframe=train_df, labels=['unable_to_assess'])
validset_unable = EnvedaDataset(dataframe=valid_df, labels=['unable_to_assess'])

#---
trainset_not_close = EnvedaDataset(dataframe=train_df, labels=['not_close_match'])
validset_not_close = EnvedaDataset(dataframe=valid_df, labels=['not_close_match'])

#---
trainset_close = EnvedaDataset(dataframe=train_df, labels=['close_match'])
validset_close = EnvedaDataset(dataframe=valid_df, labels=['close_match'])

#--
trainset_near = EnvedaDataset(dataframe=train_df, labels=['near_exact_match'])
validset_near = EnvedaDataset(dataframe=valid_df, labels=['near_exact_match'])

#--
trainset_exact = EnvedaDataset(dataframe=train_df, labels=['exact_match'])
validset_exact = EnvedaDataset(dataframe=valid_df, labels=['exact_match'])

#--
trainset_prioritization = EnvedaDataset(dataframe=train_df, labels=['good_enough_for_prioritization'])
validset_prioritization = EnvedaDataset(dataframe=valid_df, labels=['good_enough_for_prioritization'])


  self.ground_truth_embeddings = torch.tensor(dataframe['ground_truth_embeddings'].tolist(), dtype=torch.float32)


## Wrapping in Pytorch DataLoader

In [7]:
trainloader_unable = DataLoader(trainset_unable, batch_size=CONFIG['train_batch_size'], shuffle=True)
validloader_unable = DataLoader(validset_unable, batch_size=CONFIG['valid_batch_size'], shuffle=False)

trainloader_not_close = DataLoader(trainset_not_close, batch_size=CONFIG['train_batch_size'], shuffle=True)
validloader_not_close = DataLoader(validset_not_close, batch_size=CONFIG['valid_batch_size'], shuffle=False)

trainloader_close = DataLoader(trainset_close, batch_size=CONFIG['train_batch_size'], shuffle=True)
validloader_close = DataLoader(validset_close, batch_size=CONFIG['valid_batch_size'], shuffle=False)

trainloader_near = DataLoader(trainset_near, batch_size=CONFIG['train_batch_size'], shuffle=True)
validloader_near = DataLoader(validset_near, batch_size=CONFIG['valid_batch_size'], shuffle=False)

trainloader_exact = DataLoader(trainset_exact, batch_size=CONFIG['train_batch_size'], shuffle=True)
validloader_exact = DataLoader(validset_exact, batch_size=CONFIG['valid_batch_size'], shuffle=False)

trainloader_prioritization = DataLoader(trainset_prioritization, batch_size=CONFIG['train_batch_size'], shuffle=True)
validloader_prioritization = DataLoader(validset_prioritization, batch_size=CONFIG['valid_batch_size'], shuffle=False)

In [8]:
gt, pred, labels = next(iter(trainloader_unable))
print(gt.shape, pred.shape, labels.shape)


torch.Size([16, 768]) torch.Size([16, 768]) torch.Size([16, 1])


### Use larger Siamese network as in notebook 2.2_feature_aggregation_siamese_multiclass_FP

In [9]:
# Define the Siamese Network
class SiameseNetwork(nn.Module):
    def __init__(self, input_dim=768, output_dim=4):
        super(SiameseNetwork, self).__init__()
        
        self.fc = nn.Sequential(
        nn.Linear(input_dim, 1024),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(1024, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(512, 256),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2),
        nn.Linear(256, 128),
        nn.ReLU(inplace=True),
        nn.Dropout(0.2))

        self.fc5 = nn.Linear(128, output_dim)

        
    def forward_one(self, x):
        return self.fc(x)

    def forward(self, input1, input2):
        # pdb.set_trace()
        # input1, input2 = self.smiler_embedder(**input1, output_hidden_states=True).pooler_output, \
        #                  self.smiler_embedder(**input2, output_hidden_states=True).pooler_output
        
        out1 = self.forward_one(input1)
        out2 = self.forward_one(input2)
        # pdb.set_trace()
        # Combine both outputs by subtraction
        combined = torch.sub(out1, out2)  # maybe torch.abs(out1 - out2)
        output = self.fc5(combined)                  # Outputs raw logits
        return output
    
    
# model(gt, pred)

In [None]:
model = SiameseNetwork(output_dim=1)
model.to(CONFIG['device'])
model(gt.to(CONFIG['device']), pred.to(CONFIG['device']))

# Training and validation regime

## Training with mixed precision, gradient accumulation, learning with scheduler
## Validation logging loss, AUROC, and F1 metrics

In [10]:
def train_one_epoch(model, optimizer, criterion, scheduler, dataloader, epoch=CONFIG['epochs']):
    model.train()
    
    scaler = GradScaler()
    dataset_size = 0
    running_loss = 0.0
    running_auroc  = 0.0
    running_f1 = 0.0
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        
        gt, preds, targets = data
        # gt, preds = tokenizer(gt, return_tensors='pt', padding=True, truncation=True)['input_ids'], \
        #             tokenizer(preds, return_tensors='pt', padding=True, truncation=True)['input_ids']
        gt, preds, targets = gt.to(CONFIG['device']), preds.to(CONFIG['device']), targets.to(CONFIG['device'])
        batch_size = targets.shape[0]

        with autocast():
            outputs = model(gt, preds)
            loss = criterion(outputs, targets)
            loss = loss / CONFIG['n_accumulate']
            
        # Backward pass with scaling
        scaler.scale(loss).backward()
    
        if (step + 1) % CONFIG['n_accumulate'] == 0:
            # Step the optimizer
            scaler.step(optimizer)

            # Update the scale for next iteration
            scaler.update()
            # optimizer.step()

            # zero the parameter gradients
            optimizer.zero_grad()

            if scheduler is not None:
                scheduler.step()
        
        # pdb.set_trace()
        probabilities = torch.softmax(outputs, dim=1).detach().cpu().numpy() if outputs.shape[1] > 1 else torch.sigmoid(outputs).detach().cpu().numpy()
        preds = np.eye(outputs.shape[1])[np.argmax(probabilities, axis=1)] if outputs.shape[1] > 1 else (probabilities > 0.5).astype(float)
        # pdb.set_trace()
        auroc = average_precision_score(targets.cpu().numpy(), probabilities)
        f1 = f1_score(targets.cpu().numpy(), preds, average='weighted')
        
        running_loss += (loss.item() * batch_size)
        running_auroc  += (auroc * batch_size)
        running_f1 += (f1 * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        epoch_auroc = running_auroc / dataset_size
        epoch_f1 = running_f1 / dataset_size
        
        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss, Train_Auroc=epoch_auroc, Train_F1=epoch_f1,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    
    return epoch_loss, epoch_auroc, epoch_f1

@torch.no_grad()
def valid_one_epoch(model, dataloader, criterion, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    running_auroc = 0.0
    running_f1 = 0.0 
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, data in bar:
        
        gt, preds, targets = data
        # gt, preds = tokenizer(gt, return_tensors='pt', padding=True, truncation=True)['input_ids'], \
        #             tokenizer(preds, return_tensors='pt', padding=True, truncation=True)['input_ids']
        gt, preds, targets = gt.to(CONFIG['device']), preds.to(CONFIG['device']), targets.to(CONFIG['device'])
        batch_size = targets.shape[0]
        
        with autocast():
            outputs = model(gt, preds)
            loss = criterion(outputs, targets)
            loss = loss / CONFIG['n_accumulate']
        
        probabilities = torch.softmax(outputs, dim=1).detach().cpu().numpy() if outputs.shape[1] > 1 else torch.sigmoid(outputs).detach().cpu().numpy()
        preds = np.eye(outputs.shape[1])[np.argmax(probabilities, axis=1)] if outputs.shape[1] > 1 else (probabilities > 0.5).astype(float)
        
        auroc = average_precision_score(targets.cpu().numpy(), probabilities)
        f1 = f1_score(targets.cpu().numpy(), preds, average='weighted')
        
        running_loss += (loss.item() * batch_size)
        running_auroc  += (auroc * batch_size)
        running_f1 += (f1 * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss / dataset_size
        epoch_auroc = running_auroc / dataset_size
        epoch_f1 = running_f1 / dataset_size
        
        bar.set_postfix(Epoch=epoch, Valid_Loss=epoch_loss, Valid_Auroc=epoch_auroc, 
                        Valid_F1=epoch_f1,
                        )   
    gc.collect()
    
    return epoch_loss, epoch_auroc, epoch_f1

## Initializing components
1. Model
2. AdamW optimizer
3. Cosine annealing scheduler
4. Weighted cross entropy loss to handle class imbalance

In [11]:
# Define models, optimizers, and schedulers for each dataset
model_unable = SiameseNetwork(output_dim=1).to(CONFIG['device'])
optimizer_unable = optim.AdamW(model_unable.parameters(), lr=CONFIG['learning_rate'], 
                               weight_decay=CONFIG['weight_decay'])
scheduler_unable = lr_scheduler.CosineAnnealingLR(optimizer_unable, T_max=CONFIG['T_max'], 
                                                   eta_min=CONFIG['min_lr'])

model_close = SiameseNetwork(output_dim=1).to(CONFIG['device'])
optimizer_close = optim.AdamW(model_close.parameters(), lr=CONFIG['learning_rate'], 
                               weight_decay=CONFIG['weight_decay'])
scheduler_close = lr_scheduler.CosineAnnealingLR(optimizer_close, T_max=CONFIG['T_max'], 
                                                 eta_min=CONFIG['min_lr'])

model_not_close = SiameseNetwork(output_dim=1).to(CONFIG['device'])
optimizer_not_close = optim.AdamW(model_not_close.parameters(), lr=CONFIG['learning_rate'], 
                                   weight_decay=CONFIG['weight_decay'])
scheduler_not_close = lr_scheduler.CosineAnnealingLR(optimizer_not_close, T_max=CONFIG['T_max'], 
                                                     eta_min=CONFIG['min_lr'])

model_near = SiameseNetwork(output_dim=1).to(CONFIG['device'])
optimizer_near = optim.AdamW(model_near.parameters(), lr=CONFIG['learning_rate'], 
                             weight_decay=CONFIG['weight_decay'])
scheduler_near = lr_scheduler.CosineAnnealingLR(optimizer_near, T_max=CONFIG['T_max'], 
                                                eta_min=CONFIG['min_lr'])

model_exact = SiameseNetwork(output_dim=1).to(CONFIG['device'])
optimizer_exact = optim.AdamW(model_exact.parameters(), lr=CONFIG['learning_rate'], 
                               weight_decay=CONFIG['weight_decay'])
scheduler_exact = lr_scheduler.CosineAnnealingLR(optimizer_exact, T_max=CONFIG['T_max'], 
                                                 eta_min=CONFIG['min_lr'])

model_prioritization = SiameseNetwork(output_dim=1).to(CONFIG['device'])
optimizer_prioritization = optim.AdamW(model_prioritization.parameters(), lr=CONFIG['learning_rate'], 
                                       weight_decay=CONFIG['weight_decay'])
scheduler_prioritization = lr_scheduler.CosineAnnealingLR(optimizer_prioritization, T_max=CONFIG['T_max'],
                                                          eta_min=CONFIG['min_lr'])


In [13]:

print(train_df[['unable_to_assess','not_close_match','close_match', 'near_exact_match','exact_match', 'good_enough_for_prioritization']].sum())
class_distribution = {'unable_to_assess': 10,
                      'not_close_match': 109, 
                      'close_match': 125, 
                      'near_exact_match': 66, 
                      'exact_match': 12,
                      'good_enough_for_prioritization': 224}

# Calculate class weights
total_samples = sum(class_distribution.values())
class_weights = {label: total_samples / (len(class_distribution) * count) for label, count in class_distribution.items()}
print(class_weights)

# Convert class weights to tensors for each dataset
pos_weights_unable = torch.tensor([class_weights['unable_to_assess']]).to(CONFIG['device'])
pos_weights_not_close = torch.tensor([class_weights['not_close_match']]).to(CONFIG['device'])
pos_weights_close = torch.tensor([class_weights['close_match']]).to(CONFIG['device'])
pos_weights_near = torch.tensor([class_weights['near_exact_match']]).to(CONFIG['device'])
pos_weights_exact = torch.tensor([class_weights['exact_match']]).to(CONFIG['device'])
pos_weights_prioritization = torch.tensor([class_weights['good_enough_for_prioritization']]).to(CONFIG['device'])

# Define the BCEWithLogitsLoss for each dataset
criterion_unable = nn.BCEWithLogitsLoss(pos_weight=pos_weights_unable)
criterion_not_close = nn.BCEWithLogitsLoss(pos_weight=pos_weights_not_close)
criterion_close = nn.BCEWithLogitsLoss(pos_weight=pos_weights_close)
criterion_near = nn.BCEWithLogitsLoss(pos_weight=pos_weights_near)
criterion_exact = nn.BCEWithLogitsLoss(pos_weight=pos_weights_exact)
criterion_prioritization = nn.BCEWithLogitsLoss(pos_weight=pos_weights_prioritization)



unable_to_assess                    7
not_close_match                    87
close_match                        99
near_exact_match                   48
exact_match                        11
good_enough_for_prioritization    176
dtype: int64
{'unable_to_assess': 9.1, 'not_close_match': 0.8348623853211009, 'close_match': 0.728, 'near_exact_match': 1.378787878787879, 'exact_match': 7.583333333333333, 'good_enough_for_prioritization': 0.40625}


In [None]:
# train_one_epoch(model=model_exact, optimizer=optimizer_exact, criterion=criterion_exact, scheduler=scheduler_exact, dataloader=trainloader_exact)
# valid_one_epoch(model=model_prioritization, dataloader=trainloader_prioritization, criterion=criterion_prioritization, epoch=CONFIG['epochs'])

## Putting all together into training code
Training code includes:
1. Early stopping
2. Saving best model weights according to supplied name
3. Original code adapted from Kaggle

In [None]:

def run_training(model, optimizer, scheduler, criterion, num_epochs, train_loader, valid_loader, name):
    if torch.cuda.is_available():
        print("[INFO] Using GPU: {}\n".format(torch.cuda.get_device_name()))
    
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch_f1 = -np.inf
    best_valid_loss = np.inf
    history = defaultdict(list)
    
    for epoch in range(1, num_epochs + 1): 
        gc.collect()
        train_epoch_loss, train_epoch_auroc, train_epoch_f1 = train_one_epoch(model=model, optimizer=optimizer, scheduler=scheduler, 
                                           criterion=criterion, dataloader=train_loader, 
                                           epoch=epoch)
        
        val_epoch_loss, val_epoch_auroc, val_epoch_f1 = valid_one_epoch(model=model, dataloader=valid_loader, criterion=criterion, 
                                         epoch=epoch)
    
        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(val_epoch_loss)
        history['Train AUROC'].append(train_epoch_auroc)
        history['Valid AUROC'].append(val_epoch_auroc)
        history['Valid F1'].append(val_epoch_f1)
        history['lr'].append( scheduler.get_lr()[0] )
        if val_epoch_loss <= best_valid_loss:
            print(f"Validation Loss Improved ({best_valid_loss} ---> {val_epoch_loss})")
            best_valid_loss = val_epoch_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"best_VAL_LOSS_model_{name}.bin"
            torch.save(model.state_dict(), PATH)
            # Save a model file from the current directory
            print(f"Model Saved")
        
        if best_epoch_f1 <= val_epoch_f1:
            print(f"Validation F1 Improved ({best_epoch_f1} ---> {val_epoch_f1})")
            best_epoch_f1 = val_epoch_f1
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = f"best_F1_model_{name}.bin"
            torch.save(model.state_dict(), PATH)
            # Save a model file from the current directory
            print(f"Model Saved")
            
        print()
    
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best F1: {:.4f}".format(best_epoch_f1))
    print("Best Loss: {:.4f}".format(best_valid_loss))
    
    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, history

In [None]:
import concurrent.futures

# Define a function to run the training for a single model
def train_model(model, optimizer, scheduler, criterion, train_loader, valid_loader, name):
    return run_training(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        criterion=criterion,
        num_epochs=200,
        train_loader=train_loader,
        valid_loader=valid_loader,
        name=name
    )

# Prepare model-specific parameters
model_params = [
    (model_unable, optimizer_unable, scheduler_unable, criterion_unable, trainloader_unable, validloader_unable, 'siamese_unable'),
    (model_not_close, optimizer_not_close, scheduler_not_close, criterion_not_close, trainloader_not_close, validloader_not_close, 'siamese_not_close'),
    (model_close, optimizer_close, scheduler_close, criterion_close, trainloader_close, validloader_close, 'siamese_close'),
    (model_near, optimizer_near, scheduler_near, criterion_near, trainloader_near, validloader_near, 'siamese_near'),
    (model_exact, optimizer_exact, scheduler_exact, criterion_exact, trainloader_exact, validloader_exact, 'siamese_exact'),
    (model_prioritization, optimizer_prioritization, scheduler_prioritization, criterion_prioritization, trainloader_prioritization, validloader_prioritization, 'siamese_prioritization')
]

# Use ThreadPoolExecutor to parallelize the training
with concurrent.futures.ThreadPoolExecutor() as executor:
    futures = {executor.submit(train_model, *params): name for params, name in zip(model_params, ['siamese_unable', 'siamese_close', 'siamese_near', 'siamese_exact', 'siamese_prioritization'])}
    
    for future in concurrent.futures.as_completed(futures):
        model_name = futures[future]
        try:
            trained_model, history = future.result()
            print(f"Training completed for model: {model_name}")
        except Exception as e:
            print(f"Model {model_name} generated an exception: {e}")


In [14]:
model_unable.load_state_dict(torch.load('../results/one_vs_all/best_F1_model_siamese_unable.bin', map_location=torch.device(CONFIG['device'])))
model_near.load_state_dict(torch.load('../results/one_vs_all/best_F1_model_siamese_near.bin', map_location=torch.device(CONFIG['device'])))
model_not_close.load_state_dict(torch.load('../results/one_vs_all/best_F1_model_siamese_not_close.bin', map_location=torch.device(CONFIG['device'])))

model_close.load_state_dict(torch.load('../results/one_vs_all/best_F1_model_siamese_close.bin', map_location=torch.device(CONFIG['device'])))
model_exact.load_state_dict(torch.load('../results/one_vs_all/best_F1_model_siamese_exact.bin', map_location=torch.device(CONFIG['device'])))
# model_prioritization.load_state_dict(torch.load('../results/best_F1_model_siamese_prioritization.bin', map_location=torch.device(CONFIG['device'])))

<All keys matched successfully>

In [15]:
models = [model_unable.eval(), model_not_close.eval(), model_close.eval(), model_near.eval(), model_exact.eval()]
# SAME ORDER OF PUTTING MODELS TOGETHER AS ORDER OF LABELS

In [16]:
trainset = EnvedaDataset(dataframe=train_df, labels=['unable_to_assess', 'not_close_match', 'close_match', 
                                                     'near_exact_match', 'exact_match'])
validset = EnvedaDataset(dataframe=valid_df,  labels=['unable_to_assess', 'not_close_match', 'close_match', 
                                                     'near_exact_match', 'exact_match'])

trainloader = DataLoader(trainset, batch_size=CONFIG['train_batch_size'], shuffle=True)
validloader = DataLoader(validset, batch_size=CONFIG['valid_batch_size'], shuffle=False)

criterion = nn.BCEWithLogitsLoss()


# One Vs. All evaluation

In [20]:
    
# Initialize metrics
dataset_size = 0
running_loss = 0.0
running_auroc = 0.0
running_f1 = 0.0 

bar = tqdm(enumerate(validloader), total=len(validloader))
for step, data in bar:
    
    gt, preds, targets = data
    gt, preds, targets = gt.to(CONFIG['device']), preds.to(CONFIG['device']), targets.to(CONFIG['device'])
    batch_size = gt.shape[0]
    
    # Initialize outputs for all models
    all_outputs = []

    with autocast():
        for model in models:
            outputs = model(gt, preds)  # Forward pass through each model
            
            # Store outputs for each model
            all_outputs.append(outputs)

        # Stack outputs from all models
        combined_outputs = torch.cat(all_outputs, dim=1)  # Shape: (num_models, batch_size, num_classes)
        # pdb.set_trace()
        # Apply sigmoid or softmax based on the output shape
        # if combined_outputs.shape[1] > 1:
        #     probabilities = torch.softmax(combined_outputs, dim=1)  # Average over models
        # else:
        #     probabilities = torch.sigmoid(combined_outputs).mean(dim=0)  # Average over models
        probabilities = torch.sigmoid(combined_outputs)/torch.sigmoid(combined_outputs).sum(dim=1, keepdim=True)
        # Convert probabilities to predictions
        preds = np.eye(probabilities.shape[1])[np.argmax(probabilities.detach().cpu().numpy(), axis=1)] if probabilities.shape[1] > 1 else (probabilities > 0.5).float()
        # pdb.set_trace()
        # Calculate loss (assuming you have a criterion defined)
        loss = criterion(probabilities, targets)
        loss = loss / CONFIG['n_accumulate']
    
    # Compute metrics
    auroc = average_precision_score(targets.cpu().numpy(), probabilities.detach().cpu().numpy(), average='macro')
    f1 = f1_score(targets.cpu().numpy(), preds, average='weighted')
    
    # Update running totals
    running_loss += (loss.item() * batch_size)
    running_auroc += (auroc * batch_size)
    running_f1 += (f1 * batch_size)
    dataset_size += batch_size
    
    # Calculate epoch metrics
    epoch_loss = running_loss / dataset_size
    epoch_auroc = running_auroc / dataset_size
    epoch_f1 = running_f1 / dataset_size
    
    bar.set_postfix(Epoch=1, Valid_Loss=epoch_loss, Valid_Auroc=epoch_auroc, 
                    Valid_F1=epoch_f1,
                    )   

gc.collect()

  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
100%|██████████| 2/2 [00:00<00:00, 55.34it/s, Epoch=1, Valid_Auroc=0.308, Valid_F1=0.29, Valid_Loss=0.768]


0

# Plotting training metrics

In [None]:
def plot_metrics(metrics_dict):
    epochs = range(1, len(metrics_dict['Train Loss']) + 1)
    
    # Create a figure with subplots
    fig, axs = plt.subplots(3, 1, figsize=(10, 15), sharex=True)

    # Plot Train and Validation Loss
    axs[0].plot(epochs, metrics_dict['Train Loss'], label='Train Loss', color='blue', marker='o')
    axs[0].plot(epochs, metrics_dict['Valid Loss'], label='Valid Loss', color='orange', marker='o')
    axs[0].set_title('Loss Over Epochs')
    axs[0].set_ylabel('Loss')
    axs[0].legend()
    axs[0].grid()

    # Plot Train and Validation AUROC
    axs[1].plot(epochs, metrics_dict['Train AUROC'], label='Train AUROC', color='green', marker='o')
    axs[1].plot(epochs, metrics_dict['Valid AUROC'], label='Valid AUROC', color='red', marker='o')
    axs[1].set_title('AUROC Over Epochs')
    axs[1].set_ylabel('AUROC Score')
    axs[1].legend()
    axs[1].grid()

    
    # Plot Validation F1 Score
    axs[2].plot(epochs, metrics_dict['Valid F1'], label='Valid F1 Score', color='purple', marker='o')
    axs[2].set_title('Validation F1 Score Over Epochs')
    axs[2].set_ylabel('F1 Score')
    
    
    # Create a separate plot for Learning Rate
    fig_lr, ax_lr = plt.subplots(figsize=(10, 5))
    ax_lr.plot(epochs, metrics_dict['lr'], label='Learning Rate', color='cyan', linestyle='--', marker='o')
    ax_lr.set_title('Learning Rate Over Epochs')
    ax_lr.set_ylabel('Learning Rate')
    ax_lr.grid()
    fontsize = 12
    # Label points for each plot
    for i in range(0, len(metrics_dict['lr']), 10):
        axs[0].text(epochs[i], metrics_dict['Train Loss'][i], f"{metrics_dict['Train Loss'][i]:.2f}", 
                    fontsize=fontsize, ha='right', color='k')
        axs[0].text(epochs[i], metrics_dict['Valid Loss'][i], f"{metrics_dict['Valid Loss'][i]:.2f}", 
                    fontsize=fontsize, ha='right', color='k')
        
        axs[1].text(epochs[i], metrics_dict['Train AUROC'][i], f"{metrics_dict['Train AUROC'][i]:.2f}", 
                    fontsize=fontsize, ha='right', color='k')
        axs[1].text(epochs[i], metrics_dict['Valid AUROC'][i], f"{metrics_dict['Valid AUROC'][i]:.2f}", 
                    fontsize=fontsize, ha='right', color='k')

        axs[2].text(epochs[i], metrics_dict['Valid F1'][i], f"{metrics_dict['Valid F1'][i]:.2f}", 
                    fontsize=fontsize, ha='right', color='k')

        ax_lr.text(epochs[i], metrics_dict['lr'][i], f"{metrics_dict['lr'][i]:.4f}", 
                   fontsize=fontsize, ha='right', color='k')

    # Set common x-label
    axs[-1].set_xlabel('Epochs')
    
    plt.tight_layout()
    
    # Show plots
    plt.show()

In [None]:
plot_metrics(history)