# Importing libraries

In [57]:
# %% Importing Libraries
import os
import sys
import pickle
import argparse
import time
import datetime
import random
from pathlib import Path
from tqdm import tqdm

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler, random_split, TensorDataset

import torch.nn as nn

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

import lightning as L
import lightning.pytorch as pl
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
# from lightning.pytorch.strategies import DeepSpeedStrategy
# from lightning.pytorch.plugins.precision import DeepSpeedPrecisionPlugin

# Custom library
sys.path.append('../process/')
from loadData import HTContraDataModule
from contraUtilities import compute_sim_matrix, compute_target_matrix, kl_contrastive_loss, supervised_infoNCE_loss, supervised_infoNCE_loss_with_negatives
from contraUtilities import supervised_contrastive_loss, supervised_contrastive_loss_with_negatives, create_triplets, compute_triplet_loss


import warnings
warnings.filterwarnings('ignore')

In [58]:
# Create the parser
parser = argparse.ArgumentParser(description="Train a transformers-based classifier.")

# Register every supported option as (flag, type, default), in a fixed order.
for _flag, _type, _default in (
    ('--model_name_or_path', str, "johngiorgi/declutr-small"),
    ('--tokenizer_name_or_path', str, "johngiorgi/declutr-small"),
    ('--logged_entry_name', str, "declutr-small-contra-temp:0.5-seed:1111"),
    ('--data_dir', str, '../data/processed/'),
    ('--demography', str, 'south'),
    ('--save_dir', str, os.path.join(os.getcwd(), "../models/text-baselines/contra-learn")),
    ('--batch_size', int, 32),
    ('--nb_epochs', int, 100),
    ('--max_seq_length', int, 512),
    ('--sample_unit_size', int, 2),
    ('--emb_len', int, 768),
    ('--hidden_dim', int, 512),
    ('--patience', int, 5),
    ('--seed', int, 1111),
    ('--warmup_steps', int, 0),
    ('--grad_steps', int, 4),
    ('--learning_rate', float, 0.0001),
    ('--dropout', float, 0.3),
    ('--train_data_percentage', float, 1.0),
    ('--adam_epsilon', float, 1e-6),
    ('--min_delta_change', float, 0.5),
    ('--temp', float, 0.5),
    ('--weight_decay', float, 0.01),
    ('--nb_triplets', int, 1),
):
    parser.add_argument(_flag, type=_type, default=_default)
# Drop the loop temporaries so the cell leaves no extra names in the namespace.
del _flag, _type, _default

# Simulated command line: edit the string values below to try other settings.
# Each entry expands to the two argv tokens "--<name>" and "<value>".
input_args = [
    token
    for name, value in {
        'model_name_or_path': 'johngiorgi/declutr-small',
        'tokenizer_name_or_path': 'johngiorgi/declutr-small',
        'logged_entry_name': 'declutr-small-contra-temp:0.5-seed:1111',
        'data_dir': '../data/processed/',
        'demography': 'south',
        'save_dir': os.path.join(os.getcwd(), "../models/text-baselines/contra-learn"),
        'batch_size': '32',
        'nb_epochs': '2',
        'max_seq_length': '512',
        'sample_unit_size': '2',
        'emb_len': '768',
        'hidden_dim': '512',
        'patience': '5',
        'seed': '1111',
        'warmup_steps': '0',
        'grad_steps': '4',
        'learning_rate': '0.0001',
        'dropout': '0.3',
        'train_data_percentage': '1.0',
        'adam_epsilon': '1e-6',
        'min_delta_change': '0.5',
        'temp': '0.5',
        'weight_decay': '0.01',
        'nb_triplets': '1',
    }.items()
    for token in ('--' + name, value)
]

# Parse the simulated command line
args = parser.parse_args(input_args)

In [59]:
# Extra setting not exposed on the parser: use masked mean pooling in the classifier head.
vars(args).update(pooling=True)

In [60]:
# Environment flags go first, before anything touches CUDA or spawns worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# The interpreter reads PYTHONHASHSEED at startup, so this only affects child processes.
os.environ["PYTHONHASHSEED"] = str(args.seed)

# Seed every RNG in use and force deterministic cuDNN kernels for reproducibility.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Lightning's own seeding helper; the cell displays the returned seed.
seed_everything(args.seed)

Seed set to 1111


1111

In [61]:
# Use the first visible GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [62]:
# Output directory: <save_dir>/seed:<seed>/lr-<lr>/<demography>/<model short name>
directory = os.path.join(
    args.save_dir,
    f"seed:{args.seed}",
    f"lr-{args.learning_rate}",
    args.demography,
    args.model_name_or_path.rsplit("/", 1)[-1],
)
os.makedirs(directory, exist_ok=True)

In [63]:
# Data module over the selected demography's CSV; train and eval batches share one size.
dm = HTContraDataModule(
    file_dir=os.path.join(args.data_dir, f"{args.demography}.csv"),
    tokenizer_name_or_path=args.tokenizer_name_or_path,
    seed=args.seed,
    train_batch_size=args.batch_size,
    eval_batch_size=args.batch_size,
)
dm.setup(stage="fit")

In [64]:
# Number of distinct vendors (the classification targets) in the demography's CSV.
args.num_classes = pd.read_csv(os.path.join(args.data_dir, f"{args.demography}.csv"))["VENDOR"].nunique()
# Scheduler horizon: training batches per epoch times number of epochs.
# NOTE(review): if the Trainer accumulates gradients (args.grad_steps), the scheduler steps once
# per optimizer step, so this horizon may be too long — confirm against the Trainer setup.
args.num_training_steps = args.nb_epochs * len(dm.train_dataloader())
# Warmup = 10% of ONE epoch's training batches (not 10% of num_training_steps).
args.warmup_steps = len(dm.train_dataloader()) // 10

# Model Architecture

In [65]:
import sys
from sklearn.metrics import f1_score, balanced_accuracy_score

import torch
from torch import nn
import lightning.pytorch as pl

from transformers import RobertaModel
from transformers import AutoModel
from transformers import AdamW, get_linear_schedule_with_warmup

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions

In [66]:
class MLPLayer(nn.Module):
    """
    Dense + tanh projection head applied to a pooled sentence vector.

    Keeps the width unchanged (``config.hidden_size`` -> ``config.hidden_size``),
    mirroring the MLP pooler BERT/RoBERTa put on top of the [CLS] representation.
    """
    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(self, features, **kwargs):
        # Extra keyword arguments are accepted for call-site compatibility and ignored.
        return self.activation(self.dense(features))

class Similarity(nn.Module):
    """
    Temperature-scaled cosine similarity.

    Computes the cosine similarity along the last dimension and divides it by
    ``temp``. Only cosine similarity is implemented, not a raw dot product.
    """

    def __init__(self, temp):
        super().__init__()
        self.temp = temp
        self.cos = nn.CosineSimilarity(dim=-1)

    def forward(self, x, y):
        cosine = self.cos(x, y)
        return cosine / self.temp

class Pooler(nn.Module):
    """
    Parameter-free poolers that turn encoder outputs into one vector per sequence.

    Supported ``pooler_type`` values:
    'cls': first-token ([CLS]) vector; callers such as ``cl_forward`` apply the extra MLP head for this type.
    'cls_before_pooler': first-token ([CLS]) vector without that MLP head.
    'avg': mean of the last layer's hidden states over non-padding tokens.
    'avg_top2': accepted by the constructor but not implemented yet (raises NotImplementedError).
    'avg_first_last': accepted by the constructor but not implemented yet (raises NotImplementedError).
    """
    def __init__(self, pooler_type):
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in ["cls", "cls_before_pooler", "avg", "avg_top2", "avg_first_last"], "unrecognized pooling type %s" % self.pooler_type

    def forward(self, attention_mask, outputs):
        """
        Pool ``outputs.last_hidden_state`` according to ``self.pooler_type``.

        Args:
            attention_mask: padding mask, read only by 'avg'. Assumed (batch, seq) with
                non-zero for real tokens — TODO confirm against the tokenizer output.
            outputs: encoder output whose ``last_hidden_state`` is indexed as (batch, seq, hidden).

        Returns:
            torch.Tensor of shape (batch, hidden).

        Raises:
            NotImplementedError: for 'avg_top2' and 'avg_first_last'.
        """
        last_hidden = outputs.last_hidden_state

        if self.pooler_type in ("cls_before_pooler", "cls"):
            return last_hidden[:, 0]

        if self.pooler_type == "avg":
            # Cast the mask to the activations' dtype so the product and division stay floating point.
            mask = attention_mask.unsqueeze(-1).to(last_hidden.dtype)
            summed = (last_hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)  # guards against all-padding rows
            return summed / counts

        raise NotImplementedError("pooling type %s is not implemented" % self.pooler_type)

def cl_init(cls, pooler_type_, config, temp_):
    """
    Attach the contrastive-learning components to ``cls`` and re-run its weight init.

    Sets ``cls.pooler_type`` and installs a ``Pooler``, an ``MLPLayer`` built from
    ``config`` and a temperature-scaled ``Similarity``, then calls ``cls.init_weights()``.
    """
    cls.pooler_type = pooler_type_
    pooler = Pooler(pooler_type_)
    cls.pooler = pooler
    projection_head = MLPLayer(config)
    cls.mlp = projection_head
    similarity = Similarity(temp=temp_)
    cls.sim = similarity
    cls.init_weights()

def cl_forward(cls, encoder, pooler_type, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, 
                labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
    """
    Encode a batch and return one pooled sentence embedding per example.

    Args:
        cls: object providing ``pooler`` (called as ``pooler(attention_mask, outputs)``),
            ``pooler_type`` and, for the 'cls' type, the ``mlp`` head.
        encoder: transformer called with ``return_dict=True``; must expose ``last_hidden_state``.
        pooler_type: decides whether intermediate hidden states are requested.
        input_ids / inputs_embeds: one of them supplies the batch size.
        labels, output_hidden_states, return_dict: accepted for call compatibility and ignored.

    Returns:
        torch.Tensor of shape (batch, hidden).
    """
    # The batch size comes from whichever input representation was provided.
    reference = input_ids if input_ids is not None else inputs_embeds
    batch_size = int(reference.size(0))

    # Only the layer-averaging poolers read intermediate layers; skip them otherwise to save memory.
    needs_all_layers = pooler_type in ("avg_top2", "avg_first_last")
    outputs = encoder(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds,
                        output_attentions=output_attentions, output_hidden_states=needs_all_layers, return_dict=True)

    # Pooling -> (batch, hidden); reshape also tolerates non-contiguous pooler outputs.
    pooler_output = cls.pooler(attention_mask, outputs)
    pooler_output = pooler_output.reshape(batch_size, pooler_output.size(-1))

    # For "cls" an extra MLP layer is applied over the representation,
    # mirroring BERT's original pooler.
    if cls.pooler_type == "cls":
        pooler_output = cls.mlp(pooler_output)

    return pooler_output

def sentemb_forward(cls, encoder, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, 
                    inputs_embeds=None, labels=None, output_attentions=None, output_hidden_states=None, return_dict=None):
    """
    Sentence-embedding forward pass.

    Runs ``encoder`` with ``return_dict=True`` and pools with ``cls.pooler``. For the
    'cls' pooler type it also applies ``cls.mlp``. Hidden states are requested only for
    'avg_top2' and 'avg_first_last'. ``labels`` and ``output_hidden_states`` are
    accepted but ignored.

    Returns:
        A ``BaseModelOutputWithPoolingAndCrossAttentions`` when ``return_dict`` is truthy.
        ``return_dict`` defaults to ``cls.config.use_return_dict``. Otherwise the tuple
        ``(encoder_outputs[0], pooled) + encoder_outputs[2:]``.
    """
    if return_dict is None:
        return_dict = cls.config.use_return_dict

    wants_layers = cls.pooler_type in ['avg_top2', 'avg_first_last']
    encoded = encoder(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=wants_layers,
        return_dict=True,
    )

    sentence_embedding = cls.pooler(attention_mask, encoded)
    if cls.pooler_type == "cls":
        sentence_embedding = cls.mlp(sentence_embedding)

    if return_dict:
        return BaseModelOutputWithPoolingAndCrossAttentions(
            pooler_output=sentence_embedding,
            last_hidden_state=encoded.last_hidden_state,
            hidden_states=encoded.hidden_states,
        )
    return (encoded[0], sentence_embedding) + encoded[2:]

def compute_sim_matrix(feats):
    """
    Pairwise cosine-similarity matrix for a batch of feature vectors.

    Args:
        feats (torch.Tensor): features of shape (bs, feat_len).

    Returns:
        torch.Tensor: (bs, bs) matrix whose entry (i, j) is cos(feats[i], feats[j]).
    """
    rows = feats.unsqueeze(1)  # (bs, 1, feat_len)
    cols = feats.unsqueeze(0)  # (1, bs, feat_len)
    return F.cosine_similarity(rows, cols, dim=-1)

def compute_target_matrix(labels):
    """
    Build the binary "same label" target matrix used for contrastive learning.

    Entry (i, j) is 1.0 when samples i and j share a class label and 0.0 otherwise.
    The diagonal is always 1.0. Used as the similarity target, the matrix pulls
    same-label embeddings together and pushes different-label ones apart.

    Parameters:
        labels (torch.Tensor): 1D tensor of class labels with shape (bs,).

    Returns:
        torch.Tensor: float32 tensor of shape (bs, bs).
    """
    same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
    return same_label.float()


def contrastive_loss(pred_sim_matrix, target_matrix, temperature, labels):
    """
    KL-divergence contrastive loss between predicted and target similarity rows.

    Each row of ``pred_sim_matrix / temperature`` and of ``target_matrix / temperature``
    is turned into a distribution with a softmax over dim 1. The loss is
    KL(target || predicted), averaged over rows ("batchmean").

    Args:
        pred_sim_matrix (torch.Tensor): (bs, bs) predicted similarities.
        target_matrix (torch.Tensor): (bs, bs) targets, e.g. from ``compute_target_matrix``.
        temperature (float): softmax temperature applied to both matrices.
        labels: unused; kept for signature compatibility.

    Returns:
        torch.Tensor: scalar loss.
    """
    # log_softmax instead of softmax().log(): tiny probabilities no longer underflow to log(0) = -inf,
    # which previously made the KL term infinite. dim=1 matches softmax's implicit choice for 2D input.
    log_pred = F.log_softmax(pred_sim_matrix / temperature, dim=1)
    target = F.softmax(target_matrix / temperature, dim=1)
    return F.kl_div(log_pred, target, reduction="batchmean", log_target=False)

In [67]:
def supervised_info_nce_loss(features, labels, temperature=0.1):
    """
    Computes the Supervised InfoNCE (SupCon-style) loss, using class labels to define positives.

    For every anchor i with at least one positive P(i) (same label, j != i):

        loss_i = -1/|P(i)| * sum_{p in P(i)} log( exp(s_ip) / sum_{a != i} exp(s_ia) )

    where s = cosine similarity / temperature. The anchor's own similarity is excluded
    from both the positives and the denominator.

    Args:
        features (torch.Tensor): Embeddings of shape (N, D), where N is the batch size and D the embedding dimension.
        labels (torch.Tensor): Corresponding labels of shape (N,).
        temperature (float): Temperature scaling factor that softens the softmax.

    Returns:
        torch.Tensor: Scalar loss averaged over anchors that have positives. When no anchor
        has a positive, the result is 0 and stays attached to the graph.
    """
    device = features.device
    n_samples = features.size(0)
    labels = labels.contiguous().view(-1, 1).to(device)

    # Normalize so the matrix product is cosine similarity.
    features = F.normalize(features, p=2, dim=1)
    similarity_matrix = torch.mm(features, features.T) / temperature

    self_mask = torch.eye(n_samples, dtype=torch.bool, device=device)
    positive_mask = torch.eq(labels, labels.T) & ~self_mask

    has_positive = positive_mask.any(dim=1)
    if not bool(has_positive.any()):
        # Keep the graph connected so backward() still works.
        return similarity_matrix.sum() * 0.0

    # Self-similarity becomes -inf so it is truly removed from the denominator.
    # Only anchors with positives are scored, so every row's log-sum-exp is finite.
    logits = similarity_matrix.masked_fill(self_mask, float("-inf"))[has_positive]
    positive_mask = positive_mask[has_positive]

    # log-softmax via logsumexp is numerically stable for small temperatures.
    log_prob = logits - torch.logsumexp(logits, dim=1, keepdim=True)

    # masked_fill (not multiplication) avoids 0 * -inf = NaN on the diagonal.
    sum_log_prob_pos = log_prob.masked_fill(~positive_mask, 0.0).sum(dim=1)
    mean_log_prob_pos = sum_log_prob_pos / positive_mask.sum(dim=1)

    return -mean_log_prob_pos.mean()

In [68]:
def supervised_info_nce_loss_with_hard_negatives(features, labels, temperature=0.1, num_hard_negatives=5, epsilon=1e-6):
    """
    Supervised InfoNCE loss whose normalization uses all positives plus the hardest negatives.

    For every anchor i with at least one positive P(i) (same label, j != i):

        loss_i = -log( sum_{p in P(i)} exp(s_ip)
                       / (sum_{p in P(i)} exp(s_ip) + sum_{n in H(i)} exp(s_in)) )

    where s = cosine similarity / temperature and H(i) is the ``num_hard_negatives``
    most similar different-label samples. H(i) has fewer entries when the batch has
    fewer negatives for that anchor.

    Args:
        features (torch.Tensor): Embeddings of shape (N, D).
        labels (torch.Tensor): Labels of shape (N,).
        temperature (float): Temperature scaling factor.
        num_hard_negatives (int): Hard negatives kept per anchor; clipped to N - 1.
        epsilon (float): Kept for backward compatibility; the log-sum-exp form needs no epsilon.

    Returns:
        torch.Tensor: Scalar loss averaged over anchors that have positives. When no anchor
        has a positive, the result is 0 and stays attached to the graph.
    """
    device = features.device
    n_samples = features.size(0)
    labels = labels.view(-1, 1).to(device)
    neg_inf = float("-inf")

    # Normalize features to the unit sphere so the matrix product is cosine similarity.
    features = F.normalize(features, p=2, dim=1)
    similarity_matrix = torch.mm(features, features.T) / temperature

    self_mask = torch.eye(n_samples, dtype=torch.bool, device=device)
    same_label = torch.eq(labels, labels.T)
    positive_mask = same_label & ~self_mask
    negative_mask = ~same_label  # never includes the diagonal

    has_positive = positive_mask.any(dim=1)
    if not bool(has_positive.any()):
        return similarity_matrix.sum() * 0.0

    # Non-candidates become -inf (not 0): they can never be chosen as "hard" and add exp(-inf) = 0.
    positive_logits = similarity_matrix.masked_fill(~positive_mask, neg_inf)
    negative_logits = similarity_matrix.masked_fill(~negative_mask, neg_inf)

    # topk needs k <= row length; each anchor has at most n_samples - 1 negatives.
    k = max(0, min(int(num_hard_negatives), n_samples - 1))
    if k > 0:
        hard_negatives = torch.topk(negative_logits, k=k, dim=1).values
    else:
        hard_negatives = similarity_matrix.new_empty((n_samples, 0))

    # Score only anchors with positives so every log-sum-exp below is finite.
    positive_logits = positive_logits[has_positive]
    hard_negatives = hard_negatives[has_positive]

    log_numerator = torch.logsumexp(positive_logits, dim=1)
    log_denominator = torch.logsumexp(torch.cat([positive_logits, hard_negatives], dim=1), dim=1)

    # Negative log-likelihood of the positives, averaged over anchors.
    return (log_denominator - log_numerator).mean()

In [69]:
def supervised_contrastive_loss(features, labels, temperature=0.1):
    """
    Computes the supervised contrastive loss (positives summed inside the log).

    For every anchor i with at least one positive P(i) (same label, j != i):

        loss_i = -log( sum_{p in P(i)} exp(s_ip) / sum_{a != i} exp(s_ia) )

    where s = cosine similarity / temperature. The anchor's self-similarity is left
    out of the denominator, following the standard supervised contrastive formulation.

    Parameters:
        features (torch.Tensor): The embeddings for the batch, shape (N, D),
                                 where N is the batch size and D is the embedding dimension.
        labels (torch.Tensor): The class labels for the batch, shape (N,).
        temperature (float): Temperature scaling factor that sets the sharpness of the softmax.

    Returns:
        torch.Tensor: Scalar loss averaged over anchors that have positives. When no anchor
        has a positive, the result is 0 and stays attached to the graph.
    """
    device = features.device
    labels = labels.to(device)
    batch_size = features.shape[0]
    neg_inf = float("-inf")

    # Normalize the features to the unit sphere so the matrix product is cosine similarity.
    features = F.normalize(features, p=2, dim=1)
    sim_matrix = torch.mm(features, features.t()) / temperature

    self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
    pos_mask = (labels.unsqueeze(1) == labels.unsqueeze(0)) & ~self_mask

    # Anchors without any positive have an undefined loss (log 0). Skip them
    # instead of letting a ~27.6 constant from log(1e-12) leak into the mean.
    has_pos = pos_mask.any(dim=1)
    if not bool(has_pos.any()):
        return sim_matrix.sum() * 0.0

    # Work in log space (logsumexp) so small temperatures cannot overflow exp() into inf/NaN.
    logits = sim_matrix.masked_fill(self_mask, neg_inf)[has_pos]
    pos_logits = logits.masked_fill(~pos_mask[has_pos], neg_inf)

    log_num = torch.logsumexp(pos_logits, dim=1)
    log_den = torch.logsumexp(logits, dim=1)
    return (log_den - log_num).mean()

In [70]:
def supervised_contrastive_loss_with_hard_negatives(features, labels, temperature=0.1, num_hard_negatives=0, eps=1e-8):
    """
    Supervised contrastive loss in which the hardest negatives are up-weighted.

    For every anchor i with at least one positive P(i) (same label, j != i):

        loss_i = -log( sum_{p in P(i)} exp(s_ip)
                       / (sum_{a != i} exp(s_ia) + sum_{n in H(i)} exp(s_in)) )

    where s = cosine similarity / temperature and H(i) is the ``num_hard_negatives``
    most similar different-label samples. These samples are therefore counted twice in
    the denominator. With ``num_hard_negatives=0`` this reduces to the plain supervised
    contrastive loss.

    NOTE(review): formula reconstructed from the original comments ("combine positives
    and hard negatives in the denominator"); confirm it matches the intended objective.

    Parameters:
        features (torch.Tensor): Embeddings of shape (N, D).
        labels (torch.Tensor): Class labels of shape (N,).
        temperature (float): Temperature scaling factor.
        num_hard_negatives (int): Hardest negatives per anchor to up-weight; clipped to N - 1.
        eps (float): Kept for backward compatibility; the log-sum-exp form needs no epsilon.

    Returns:
        torch.Tensor: Scalar loss averaged over anchors that have positives. When no anchor
        has a positive, the result is 0 and stays attached to the graph.
    """
    device = features.device
    labels = labels.to(device)
    batch_size = features.shape[0]
    neg_inf = float("-inf")

    # Normalize the features to the unit sphere and apply temperature scaling.
    features = F.normalize(features, p=2, dim=1)
    sim_matrix = torch.mm(features, features.t()) / temperature

    self_mask = torch.eye(batch_size, dtype=torch.bool, device=device)
    same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
    pos_mask = same_label & ~self_mask
    neg_mask = ~same_label

    has_pos = pos_mask.any(dim=1)
    if not bool(has_pos.any()):
        return sim_matrix.sum() * 0.0

    # Denominator base: every non-self pair, as in the plain supervised contrastive loss.
    all_logits = sim_matrix.masked_fill(self_mask, neg_inf)
    den_logits = all_logits

    k = max(0, min(int(num_hard_negatives), batch_size - 1))
    if k > 0:
        # Non-negatives are -inf (not 0), so they can never be picked as "hard".
        # Rows with fewer than k negatives get -inf fillers, which contribute exp(-inf) = 0.
        hard_negatives = torch.topk(sim_matrix.masked_fill(~neg_mask, neg_inf), k=k, dim=1).values
        den_logits = torch.cat([all_logits, hard_negatives], dim=1)

    pos_logits = all_logits.masked_fill(~pos_mask, neg_inf)

    # Score only anchors with positives so every log-sum-exp below is finite.
    log_num = torch.logsumexp(pos_logits[has_pos], dim=1)
    log_den = torch.logsumexp(den_logits[has_pos], dim=1)
    return (log_den - log_num).mean()

In [71]:
class DeclutrClassifier(nn.Module):
    """
    Transformer encoder followed by a classification head.

    ``forward`` turns the encoder's last hidden state into one feature vector per
    example and feeds it to ``classifier`` (stored as ``self.fc``). With ``pooling``
    the vector is the masked mean over tokens; without it, the flattened hidden states.
    """
    def __init__(self, model, classifier):
        super().__init__()
        self.model = model
        self.fc = classifier

    def forward(self, x, pooling, return_feat=False):
        """
        Args:
            x: tokenized batch indexed as (input_ids, token_type_ids, attention_mask).
               token_type_ids (x[1]) is not passed to the encoder.
            pooling: truthy -> masked mean pooling; falsy -> flatten (the head must then
                     accept seq_len * hidden inputs).
            return_feat: also return the first-token vector and the sentence features.

        Returns:
            logits, or (logits, cls_representation, sentence_features) when return_feat is True.
        """
        input_ids, attention_mask = x[0], x[2]
        last_hidden_state = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # First-token ([CLS]/<s>) vector. It is computed for both branches so that
        # return_feat also works when pooling is disabled.
        cls_representation = last_hidden_state[:, 0, :]

        if pooling:
            # Mean over real tokens only; the mask is cast to the activations' dtype before the arithmetic.
            mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            summed = (last_hidden_state * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            sentence_features = summed / counts
        else:
            # Flatten to [batch_size, seq_len * hidden]; reshape also handles non-contiguous tensors.
            sentence_features = last_hidden_state.reshape(last_hidden_state.size(0), -1)

        out = self.fc(sentence_features)

        if return_feat:
            return out, cls_representation, sentence_features

        return out

# NOTE: earlier DeclutrClassifier variant, kept for reference only. It fed the flattened last
# hidden state to the head and printed its shape on every call. Because it is wrapped in a bare
# string literal it is never executed and does not shadow the class above.
"""class DeclutrClassifier(nn.Module):
    def __init__(self, model, classifier):
        super().__init__()
        self.model = model
        self.fc = classifier

    def forward(self, x, pooling, return_feat=False):
        # x is a tokenized input
        # feature = self.model(input_ids=x[0], token_type_ids=x[1], attention_mask=x[2])
        feature = self.model(input_ids=x[0], attention_mask=x[2])
        # out = self.fc(feature.pooler_output.flatten(1))       # not good for our task     # (BS, E)
        hidden_states = feature["hidden_states"]
        print("Output to Logistic Regression;", feature.last_hidden_state.flatten(1).shape)
        out = self.fc(feature.last_hidden_state.flatten(1))  # (BS, T, E)
        if return_feat:
            return out, hidden_states, feature.last_hidden_state.flatten(1)
        return out
"""

class LogisticRegression(nn.Module):
    """
    Two-layer MLP classification head (despite the name, it has a hidden layer).

    Layout: Dropout -> Linear(in_dim, hid_dim) -> LeakyReLU(0.2) -> Dropout -> Linear(hid_dim, out_dim).
    The layers sit in ``self.nn`` in that order, so the state-dict keys are
    ``nn.1.*`` and ``nn.4.*``.
    """
    def __init__(self, in_dim, hid_dim, out_dim, dropout=0):
        super().__init__()
        stages = [
            nn.Dropout(p=dropout),
            nn.Linear(in_dim, hid_dim, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hid_dim, out_dim, bias=True),
        ]
        self.nn = nn.Sequential(*stages)

    def forward(self, x, return_feat=False):
        logits = self.nn(x)
        # With return_feat, also hand back the unmodified input features.
        return (logits, x) if return_feat else logits

class HTContraClassifierModel(pl.LightningModule):
    """
    LightningModule that fine-tunes a pretrained transformer for classification with

        total_loss = loss_1 + coefficient * loss_contrastive

    ``loss_1`` is cross-entropy when ``loss1_type == "CE"``. Otherwise it is 0 and the
    contrastive term gets weight 1. ``loss_contrastive`` is selected by ``loss2_type``.

    The first positional argument is expected to be an argparse-style namespace whose
    fields are copied into ``self.hparams`` below. ``coefficient``, ``num_hard_negatives``,
    ``loss1_type``, ``loss2_type`` and ``pooling`` must be present on it. The notebook's
    parser does not register them, so they have to be set on ``args`` before construction.
    """
    def __init__(self, *args, **kwargs):
        super().__init__()

        self.save_hyperparameters()
        if isinstance(args, tuple) and len(args) > 0:
            self.args = args[0]
            self.hparams.learning_rate = self.args.learning_rate
            self.hparams.eps = self.args.adam_epsilon
            self.hparams.weight_decay = self.args.weight_decay
            self.hparams.model_name_or_path = self.args.model_name_or_path
            self.hparams.num_classes = self.args.num_classes
            self.hparams.num_training_steps = self.args.num_training_steps
            self.hparams.warmup_steps = self.args.warmup_steps
            self.hparams.emb_len = self.args.emb_len
            self.hparams.max_seq_length = self.args.max_seq_length
            self.hparams.hidden_dim = self.args.hidden_dim
            self.hparams.dropout = self.args.dropout
            self.hparams.nb_epochs = self.args.nb_epochs
            self.hparams.temp = self.args.temp
            self.hparams.coefficient = self.args.coefficient
            self.hparams.num_hard_negatives = self.args.num_hard_negatives
            self.hparams.loss1_type = self.args.loss1_type
            self.hparams.loss2_type = self.args.loss2_type
            self.hparams.pooling = self.args.pooling
            self.hparams.num_triplets_per_sample = self.args.nb_triplets

        # freeze state (see freeze/unfreeze/on_train_epoch_start)
        self._frozen = False
        self.criterion = nn.CrossEntropyLoss()

        # Pretrained encoder + classification head.
        encoder = AutoModel.from_pretrained(self.hparams.model_name_or_path)
        # Pooling gives one emb_len vector per example. Without pooling, the whole
        # (max_seq_length x emb_len) hidden state is flattened into the head.
        if self.hparams.pooling == True:
            head_in_dim = self.hparams.emb_len
        else:
            head_in_dim = self.hparams.emb_len * self.hparams.max_seq_length
        self.model = DeclutrClassifier(encoder, LogisticRegression(head_in_dim,
                                                                   self.hparams.hidden_dim, self.hparams.num_classes,
                                                                   dropout=self.hparams.dropout))

    def forward(self, batch):
        """
        Run the classifier on one batch and compute the combined loss.

        Args:
            batch: (input_ids, token_type_ids, attention_mask, labels).

        Returns:
            (logits, sentence_features, labels, total_loss)

        Raises:
            ValueError: for an unknown ``loss2_type``.
        """
        input_ids, token_ids, attention_mask, y = batch
        x = (input_ids, token_ids, attention_mask)

        pred, _, feats = self.model(x, self.hparams.pooling, return_feat=True)

        # A local weight avoids mutating self.hparams inside forward.
        if self.hparams.loss1_type == "CE":
            loss_1 = self.criterion(pred, y.long())
            coefficient = self.hparams.coefficient
        else:
            # Without the CE term, the contrastive loss is used at full weight.
            loss_1 = 0
            coefficient = 1

        loss2_type = self.hparams.loss2_type
        if loss2_type == "KL":
            sim_matrix = compute_sim_matrix(feats)
            target_matrix = compute_target_matrix(y)
            loss_contrastive = kl_contrastive_loss(sim_matrix, target_matrix, self.hparams.temp, y)

        elif loss2_type == "infoNCE":
            # NOTE(review): assumes the same (features, labels, temperature) call convention
            # as supervised_infoNCE_loss_with_negatives — TODO confirm in contraUtilities.
            loss_contrastive = supervised_infoNCE_loss(feats, y, self.hparams.temp)

        elif loss2_type == "infoNCE-negatives":
            loss_contrastive = supervised_infoNCE_loss_with_negatives(feats, y, self.hparams.temp)

        elif loss2_type == "SupCon":
            # NOTE(review): assumes the (features, labels, temperature) call convention — TODO confirm.
            loss_contrastive = supervised_contrastive_loss(feats, y, self.hparams.temp)

        elif loss2_type == "SupCon-negatives":
            loss_contrastive = supervised_contrastive_loss_with_negatives(feats, y, self.hparams.temp)

        elif loss2_type == "triplet":
            triplet_loss, _ = compute_triplet_loss(feats, y, self.hparams.num_triplets_per_sample)
            loss_contrastive = triplet_loss

        else:
            raise ValueError("Loss2 type can only be between KL, infoNCE, infoNCE-negatives (with in-batch negatives), SupCon, triplet, and SupCon-negatives (with in-batch negatives). Other losses have not been implemented.")

        total_loss = loss_1 + coefficient * loss_contrastive

        return pred, feats, y, total_loss

    @staticmethod
    def _classification_metrics(pred, y):
        """Chance-adjusted balanced accuracy and macro/micro/weighted F1 for one batch."""
        y_true = y.detach().cpu().numpy()
        y_pred = torch.argmax(pred, dim=1).detach().cpu().numpy()
        return {'accuracy': balanced_accuracy_score(y_true, y_pred, adjusted=True),
                'macro-F1': f1_score(y_true, y_pred, average='macro'),
                'micro-F1': f1_score(y_true, y_pred, average='micro'),
                'weighted-F1': f1_score(y_true, y_pred, average='weighted')}

    def training_step(self, batch, batch_nb):
        """Lightning training hook: log the epoch-averaged training loss and return it for backprop."""
        _, _, _, train_loss = self(batch)
        self.log_dict({"train_loss": train_loss}, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return train_loss

    def validation_step(self, batch, batch_nb):
        """Lightning validation hook: log the validation loss and per-batch classification metrics."""
        pred, _, y, val_loss = self(batch)

        # NOTE: these are per-batch metrics that Lightning averages over the epoch (on_epoch=True).
        # A batch-averaged macro-F1 is not the dataset-level macro-F1.
        self.log_dict({"val_loss": val_loss, **self._classification_metrics(pred, y)},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return val_loss

    def test_step(self, batch, batch_nb):
        """Lightning test hook: log the test loss and per-batch classification metrics."""
        pred, _, y, test_loss = self(batch)

        self.log_dict({"test_loss": test_loss, **self._classification_metrics(pred, y)},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return test_loss

    def predict_step(self, batch, batch_nb):
        """Prediction is not implemented for this module; always returns None."""
        return None

    def configure_optimizers(self):
        """
        AdamW plus a linear warmup/decay schedule stepped every optimizer step.

        Parameters whose names contain 'bias' or 'LayerNorm.weight' are excluded from weight decay.
        """
        no_decay = ['bias', 'LayerNorm.weight']

        optimizer_grouped_parameters = [{'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': self.hparams.weight_decay},
                                        {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.eps)

        # Linear warmup followed by linear decay (transformers scheduler).
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=self.hparams.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]

    def freeze(self) -> None:
        """Freeze the encoder while keeping the classification head (``self.model.fc``) trainable."""
        # DeclutrClassifier registers its head as `fc`, so head parameters are named "fc.*".
        # Everything else ("model.*") belongs to the pretrained encoder.
        for name, param in self.model.named_parameters():
            if not name.startswith('fc.'):
                param.requires_grad = False

        self._frozen = True

    def unfreeze(self) -> None:
        """Make the encoder trainable again after :meth:`freeze`."""
        if self._frozen:
            for name, param in self.model.named_parameters():
                if not name.startswith('fc.'):
                    param.requires_grad = True

        self._frozen = False

    def on_train_epoch_start(self):
        """Lightning hook: keep the encoder frozen for the first ``nr_frozen_epochs`` epochs (0 if unset)."""
        nr_frozen_epochs = self.hparams.get("nr_frozen_epochs", 0)
        if self.current_epoch < nr_frozen_epochs:
            self.freeze()
        else:
            self.unfreeze()

    def train_epoch_start(self):
        """Backward-compatible alias for :meth:`on_train_epoch_start` (not itself a Lightning hook)."""
        self.on_train_epoch_start()

In [72]:
""" 
Python version: 3.9
Description: Contains helper classes and functions to load a Declutr-small model with supervised contrastive loss into the LightningModule.
"""

import sys
from sklearn.metrics import f1_score, balanced_accuracy_score

import torch
from torch import nn
import lightning.pytorch as pl

from transformers import RobertaModel
from transformers import AutoModel
from transformers import AdamW, get_linear_schedule_with_warmup

# from deepspeed.ops.adam import DeepSpeedCPUAdam

# Custom library
sys.path.append('../process/')
from contraUtilities import compute_sim_matrix, compute_target_matrix, kl_contrastive_loss, supervised_infoNCE_loss, supervised_infoNCE_loss_with_negatives

def supervised_contrastive_loss_with_negatives(features, labels, temperature=0.5, num_hard_negatives=5, eps=1e-8):
    """Supervised contrastive loss that uses every in-batch negative in the denominator.

    With ``s = cos_sim / temperature``, ``P(i)`` the other samples sharing anchor ``i``'s label and
    ``N(i)`` the samples with a different label, each anchor contributes

        loss_i = -log( sum_{p in P(i)} exp(s_ip) / (sum_{p in P(i)} exp(s_ip) + sum_{n in N(i)} exp(s_in)) )

    and the batch loss is the mean over anchors that have at least one positive. Anchors without a
    positive carry no gradient, so they are excluded rather than contributing a constant penalty.

    Args:
        features: 2-D tensor (batch_size, dim); rows are L2-normalised here.
        labels: 1-D tensor (batch_size,) of class ids; moved to ``features``' device.
        temperature: similarity temperature.
        num_hard_negatives: unused, kept for backward compatibility. The earlier formulation split the
            negatives into the top-k "hard" ones plus the rest, which sums to all negatives anyway.
        eps: unused, kept for backward compatibility; the log-sum-exp form needs no epsilon.

    Returns:
        Scalar loss tensor. A zero attached to ``features``' graph when no anchor has a positive.
    """
    import torch.nn.functional as F

    labels = labels.to(features.device)
    batch_size = features.shape[0]

    # Normalize the features to the unit sphere, then temperature-scaled cosine similarities.
    features = F.normalize(features, p=2, dim=1)
    sim_matrix = torch.mm(features, features.t()) / temperature

    same_label = labels.unsqueeze(1) == labels.unsqueeze(0)
    not_self = ~torch.eye(batch_size, dtype=torch.bool, device=features.device)
    pos_mask = same_label & not_self

    has_pos = pos_mask.any(dim=1)
    if not bool(has_pos.any()):
        # No anchor has a positive: return a zero that still belongs to the graph.
        return features.sum() * 0.0

    # Only rows with a positive are used, so every logsumexp below has at least one finite term
    # (avoids -inf - -inf = NaN in the forward or backward pass).
    sim = sim_matrix[has_pos]
    neg_inf = float("-inf")
    log_numerator = torch.logsumexp(sim.masked_fill(~pos_mask[has_pos], neg_inf), dim=1)
    # Denominator = positives + negatives = every sample except the anchor itself.
    log_denominator = torch.logsumexp(sim.masked_fill(~not_self[has_pos], neg_inf), dim=1)

    return (log_denominator - log_numerator).mean()
    
def create_triplets(embeddings, labels, num_triplets_per_sample):
    """Randomly sample (anchor, positive, negative) embedding triplets from a labelled batch.

    Every sample that has at least one other same-label sample and at least one different-label
    sample becomes an anchor. One positive is drawn per anchor, and ``num_triplets_per_sample``
    negatives are drawn (with replacement), each paired with that same positive.

    Args:
        embeddings: tensor indexed by row (one embedding per sample).
        labels: 1-D tensor of class ids aligned with ``embeddings``.
        num_triplets_per_sample: number of triplets generated per eligible anchor.

    Returns:
        ``(anchors, positives, negatives)`` stacked tensors, or ``None`` if no triplet can be formed.
    """
    anchors, positives, negatives = [], [], []

    for idx, anchor_label in enumerate(labels):
        same = torch.where(labels == anchor_label)[0]
        different = torch.where(labels != anchor_label)[0]

        # The anchor needs another sample of its own class and at least one sample of another class.
        if len(same) <= 1 or len(different) == 0:
            continue

        other_same = same[same != idx]
        positive = embeddings[other_same[torch.randint(len(same) - 1, (1,)).item()]]

        for _ in range(num_triplets_per_sample):
            negative = embeddings[different[torch.randint(len(different), (1,)).item()]]
            anchors.append(embeddings[idx])
            positives.append(positive)
            negatives.append(negative)

    if not anchors:
        return None

    return torch.stack(anchors), torch.stack(positives), torch.stack(negatives)

def compute_triplet_loss(embeddings, labels, num_triplets_per_sample, margin=1.0):
    """Cosine-distance triplet margin loss over randomly sampled in-batch triplets.

    Triplets come from ``create_triplets``. For each (anchor a, positive p, negative n):
        loss = relu((1 - cos(a, p)) - (1 - cos(a, n)) + margin)

    Args:
        embeddings: per-sample embeddings (rows aligned with ``labels``).
        labels: 1-D tensor of class ids.
        num_triplets_per_sample: triplets generated per eligible anchor.
        margin: required gap between the negative and positive cosine distances.

    Returns:
        ``(mean_loss, per_triplet_loss)``. When no triplet can be formed, ``(zero, None)``, where
        ``zero`` is a scalar on ``embeddings``' device/dtype that requires grad.
    """
    import torch.nn.functional as F

    triplets = create_triplets(embeddings, labels, num_triplets_per_sample)
    if triplets is None:
        # Derive the zero from the embeddings so it lives on their device and stays connected to
        # the graph when they require grad; otherwise make it a grad-requiring leaf.
        zero = (embeddings * 0.0).sum()
        if not zero.requires_grad:
            zero.requires_grad_(True)
        return zero, None

    anchor, positive, negative = triplets

    pos_dist = 1 - F.cosine_similarity(anchor, positive)
    neg_dist = 1 - F.cosine_similarity(anchor, negative)
    triplet_loss = F.relu(pos_dist - neg_dist + margin)

    return triplet_loss.mean(), triplet_loss

class DeclutrClassifier(nn.Module):
    """Wrap a Hugging Face encoder with a head applied to a sentence representation.

    ``forward`` takes ``x = (input_ids, token_type_ids, attention_mask)``; only ``x[0]`` and ``x[2]``
    are passed to the encoder.
    """

    def __init__(self, model, classifier):
        super().__init__()
        self.model = model
        self.fc = classifier

    def forward(self, x, pooling, return_feat=False):
        """Encode ``x`` and apply the head.

        Args:
            x: tuple ``(input_ids, token_type_ids, attention_mask)``.
            pooling: when ``== True``, use attention-masked mean pooling over tokens; otherwise flatten
                every token's hidden state into one vector per sample.
            return_feat: also return the first-token representation and the head's input.

        Returns:
            ``out``, or ``(out, cls_representation, representation)`` when ``return_feat``.
        """
        input_ids, attention_mask = x[0], x[2]
        last_hidden_state = self.model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state

        # First-token ([CLS]/<s>) representation. Computed in both modes so that return_feat also
        # works without pooling (it used to be unbound there).
        cls_representation = last_hidden_state[:, 0, :]

        if pooling == True:
            # Mean over real tokens only. Cast the mask to the hidden dtype so the clamp epsilon is not
            # truncated to 0 on an integer tensor.
            mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            summed = (last_hidden_state * mask).sum(dim=1)
            token_counts = mask.sum(dim=1).clamp(min=1e-9)
            representation = summed / token_counts
        else:
            # Flatten to (batch_size, seq_len * hidden); reshape also handles non-contiguous tensors.
            representation = last_hidden_state.reshape(last_hidden_state.size(0), -1)

        out = self.fc(representation)

        if return_feat:
            return out, cls_representation, representation

        return out

class LogisticRegression(nn.Module):
    """Two-layer MLP head (despite the name).

    Layers: Dropout -> Linear(in_dim, hid_dim) -> LeakyReLU(0.2) -> Dropout -> Linear(hid_dim, out_dim).

    Args:
        in_dim: input feature size.
        hid_dim: hidden layer size.
        out_dim: number of output logits.
        dropout: dropout probability used by both dropout layers.
    """

    def __init__(self, in_dim, hid_dim, out_dim, dropout=0):
        super().__init__()
        # The attribute name ``nn`` and the layer order fix the state_dict keys
        # (``nn.1.*``, ``nn.4.*``) that checkpoints are loaded against.
        layers = [
            nn.Dropout(p=dropout),
            nn.Linear(in_dim, hid_dim, bias=True),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hid_dim, out_dim, bias=True),
        ]
        self.nn = nn.Sequential(*layers)

    def forward(self, x, return_feat=False):
        """Return the logits, or ``(logits, x)`` when ``return_feat`` is set."""
        logits = self.nn(x)
        return (logits, x) if return_feat else logits
        

class HTContraClassifierModel(pl.LightningModule):
    """DeCLUTR encoder + MLP head trained with an optional cross-entropy loss plus a contrastive loss.

    Construction:
        * ``HTContraClassifierModel(args)`` with an argparse namespace: the fields listed in
          ``_HPARAM_SOURCES`` are copied from ``args`` into ``self.hparams``.
        * ``HTContraClassifierModel.load_from_checkpoint(path)``: hyper-parameters come from the checkpoint.

    Loss (see ``forward``):
        * ``loss1_type == "CE"``: cross-entropy on the head logits plus ``hparams.coefficient`` times
          the contrastive term. Any other value trains on the contrastive term alone (weight 1).
        * ``loss2_type`` selects the contrastive term: "KL", "infoNCE-negatives", "SupCon-negatives"
          or "triplet".
    """

    # (hparams key, attribute of the argparse namespace it is copied from), in storage order.
    _HPARAM_SOURCES = (
        ("learning_rate", "learning_rate"),
        ("eps", "adam_epsilon"),
        ("weight_decay", "weight_decay"),
        ("model_name_or_path", "model_name_or_path"),
        ("num_classes", "num_classes"),
        ("num_training_steps", "num_training_steps"),
        ("warmup_steps", "warmup_steps"),
        ("emb_len", "emb_len"),
        ("max_seq_length", "max_seq_length"),
        ("hidden_dim", "hidden_dim"),
        ("dropout", "dropout"),
        ("nb_epochs", "nb_epochs"),
        ("temp", "temp"),
        ("coefficient", "coefficient"),
        ("num_hard_negatives", "num_hard_negatives"),
        ("loss1_type", "loss1_type"),
        ("loss2_type", "loss2_type"),
        ("pooling", "pooling"),
        ("num_triplets_per_sample", "nb_triplets"),
    )

    def __init__(self, *args, **kwargs):
        super().__init__()

        self.save_hyperparameters()
        if args:
            self.args = args[0]
            for hparam_name, arg_name in self._HPARAM_SOURCES:
                setattr(self.hparams, hparam_name, getattr(self.args, arg_name))

        # freeze state (see freeze / unfreeze)
        self._frozen = False
        self.criterion = nn.CrossEntropyLoss()

        # Loading model
        encoder = AutoModel.from_pretrained(self.hparams.model_name_or_path)

        # Pooled features are emb_len wide; without pooling every token's hidden state is flattened.
        if self.hparams.pooling == True:
            head_in_dim = self.hparams.emb_len
        else:
            head_in_dim = self.hparams.emb_len * self.hparams.max_seq_length
        head = LogisticRegression(head_in_dim, self.hparams.hidden_dim, self.hparams.num_classes,
                                  dropout=self.hparams.dropout)
        self.model = DeclutrClassifier(encoder, head)

    def forward(self, batch):
        """Run the model on ``batch = (input_ids, token_type_ids, attention_mask, labels)``.

        Returns:
            ``(pred, feats, y, total_loss)``: head logits, the head's input features, the labels and
            ``loss_1 + coefficient * contrastive_loss``.

        Raises:
            Exception: if ``hparams.loss2_type`` is not one of the implemented losses.
        """
        input_ids, token_ids, attention_mask, y = batch
        x = (input_ids, token_ids, attention_mask)

        pred, _, feats = self.model(x, self.hparams.pooling, return_feat=True)

        if self.hparams.loss1_type == "CE":
            loss_1 = self.criterion(pred, y.long())
            coefficient = self.hparams.coefficient
        else:
            # Contrastive-only training: weight the contrastive term by 1. This is a local value so
            # the configured hparams.coefficient is not overwritten (it is saved with checkpoints).
            loss_1 = 0
            coefficient = 1

        loss2_type = self.hparams.loss2_type
        if loss2_type == "KL":
            # All samples are used (the previous exclusion list was empty).
            sim_matrix = compute_sim_matrix(feats)
            target_matrix = compute_target_matrix(y)
            loss_contrastive = kl_contrastive_loss(sim_matrix, target_matrix, self.hparams.temp, y)
        elif loss2_type == "infoNCE-negatives":
            loss_contrastive = supervised_infoNCE_loss_with_negatives(feats, y, self.hparams.temp)
        elif loss2_type == "SupCon-negatives":
            loss_contrastive = supervised_contrastive_loss_with_negatives(feats, y, self.hparams.temp)
        elif loss2_type == "triplet":
            loss_contrastive, _ = compute_triplet_loss(feats, y, self.hparams.num_triplets_per_sample)
        else:
            raise Exception("Loss2 type can only be one of KL, infoNCE-negatives (with in-batch negatives), SupCon-negatives (with in-batch negatives), and triplet. Other losses have not been implemented.")

        total_loss = loss_1 + coefficient * loss_contrastive

        return pred, feats, y, total_loss

    def training_step(self, batch, batch_nb):
        """Compute and log the epoch-averaged training loss."""
        _, _, _, train_loss = self(batch)
        self.log_dict({"train_loss": train_loss}, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return train_loss

    def _log_classification_metrics(self, loss_key, pred, y, loss):
        """Log ``loss`` under ``loss_key`` together with per-batch classification metrics.

        The metrics are computed per batch. With ``on_epoch=True`` Lightning reports their
        epoch-level mean, which approximates (but is not identical to) whole-epoch metrics.
        """
        y_true = y.detach().cpu().numpy()
        y_pred = torch.argmax(pred, dim=1).detach().cpu().numpy()
        self.log_dict({loss_key: loss,
                       'accuracy': balanced_accuracy_score(y_true, y_pred, adjusted=True),
                       'macro-F1': f1_score(y_true, y_pred, average='macro'),
                       'micro-F1': f1_score(y_true, y_pred, average='micro'),
                       'weighted-F1': f1_score(y_true, y_pred, average='weighted')},
                      on_step=False, on_epoch=True, prog_bar=True, logger=True)

    def validation_step(self, batch, batch_nb):
        """Log validation loss, adjusted balanced accuracy and macro/micro/weighted F1."""
        pred, _, y, val_loss = self(batch)
        self._log_classification_metrics("val_loss", pred, y, val_loss)
        return val_loss

    def test_step(self, batch, batch_nb):
        """Log test loss, adjusted balanced accuracy and macro/micro/weighted F1."""
        pred, _, y, test_loss = self(batch)
        self._log_classification_metrics("test_loss", pred, y, test_loss)
        return test_loss

    def predict_step(self, batch, batch_nb):
        """Prediction is not implemented; always returns None."""
        return None

    def configure_optimizers(self):
        """AdamW with no weight decay on biases/LayerNorm weights, plus a per-step linear warmup/decay schedule."""
        no_decay = ['bias', 'LayerNorm.weight']

        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.hparams.weight_decay},
            {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.eps)

        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.hparams.warmup_steps,
                                                    num_training_steps=self.hparams.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]

    @staticmethod
    def _is_head_parameter(name):
        """True for parameters of the classification head (``DeclutrClassifier.fc``), which stay trainable."""
        # The head lives under ``fc.``; the old 'classifier' check never matched it.
        return name.startswith("fc.") or "classifier" in name

    def freeze(self) -> None:
        """Freeze every encoder parameter, keeping the classification head trainable."""
        for name, param in self.model.named_parameters():
            if not self._is_head_parameter(name):
                param.requires_grad = False

        self._frozen = True

    def unfreeze(self) -> None:
        """Re-enable gradients on the encoder parameters if they were frozen."""
        if self._frozen:
            for name, param in self.model.named_parameters():
                if not self._is_head_parameter(name):
                    param.requires_grad = True

        self._frozen = False

    def on_train_epoch_start(self):
        """Lightning hook: freeze the encoder for the first ``nr_frozen_epochs`` epochs (default 0)."""
        nr_frozen_epochs = self.hparams.get("nr_frozen_epochs", 0)
        if self.current_epoch < nr_frozen_epochs:
            self.freeze()
        else:
            self.unfreeze()

    def train_epoch_start(self):
        """Backward-compatible alias of ``on_train_epoch_start``; this name is not a Lightning hook."""
        self.on_train_epoch_start()


class SemiConstrativeModel(pl.LightningModule):
    """DeCLUTR encoder trained purely with a metric-learning objective; the head is ``nn.Identity``.

    ``hparams.loss_type`` selects the loss: "SupCon" (supervised contrastive with in-batch negatives)
    or "triplet". The loss is scaled by ``hparams.coefficient``.
    """

    # (hparams key, attribute of the argparse namespace it is copied from), in storage order.
    _HPARAM_SOURCES = (
        ("learning_rate", "learning_rate"),
        ("eps", "adam_epsilon"),
        ("weight_decay", "weight_decay"),
        ("model_name_or_path", "model_name_or_path"),
        ("num_classes", "num_classes"),
        ("num_training_steps", "num_training_steps"),
        ("warmup_steps", "warmup_steps"),
        ("emb_len", "emb_len"),
        ("max_seq_length", "max_seq_length"),
        ("hidden_dim", "hidden_dim"),
        ("dropout", "dropout"),
        ("nb_epochs", "nb_epochs"),
        ("temp", "temp"),
        ("coefficient", "coefficient"),
        ("num_hard_negatives", "num_hard_negatives"),
        ("loss_type", "loss_type"),
        ("pooling", "pooling"),
        ("num_triplets_per_sample", "nb_triplets"),
    )

    def __init__(self, *args, **kwargs):
        super().__init__()

        self.save_hyperparameters()
        if args:
            self.args = args[0]
            for hparam_name, arg_name in self._HPARAM_SOURCES:
                setattr(self.hparams, hparam_name, getattr(self.args, arg_name))

        # freeze
        self._frozen = False

        # Loading model
        encoder = AutoModel.from_pretrained(self.hparams.model_name_or_path)

        # No classification layer: with an identity head the "logits" are the representation itself
        # (mean-pooled or flattened, depending on hparams.pooling).
        self.model = DeclutrClassifier(encoder, nn.Identity())

    def _embed(self, batch):
        """Return ``(feats, y)`` for ``batch = (input_ids, token_type_ids, attention_mask, labels)``."""
        input_ids, token_ids, attention_mask, y = batch
        feats, _, _ = self.model((input_ids, token_ids, attention_mask), self.hparams.pooling, return_feat=True)
        return feats, y

    def _contrastive_loss(self, feats, y):
        """Scaled metric-learning loss selected by ``hparams.loss_type``."""
        if self.hparams.loss_type == "SupCon":
            loss_contrastive = supervised_contrastive_loss_with_negatives(feats, y, self.hparams.temp)
        elif self.hparams.loss_type == "triplet":
            loss_contrastive, _ = compute_triplet_loss(feats, y, self.hparams.num_triplets_per_sample)
        else:
            raise Exception("Loss type can only be between SupCon and triplet, both with in-batch negatives. Other losses have not been implemented.")

        return self.hparams.coefficient * loss_contrastive

    def _step_loss(self, batch):
        """Loss for a batch, computed regardless of train/eval mode (used by train/val/test steps)."""
        feats, y = self._embed(batch)
        return self._contrastive_loss(feats, y)

    def forward(self, batch, compute_loss=True):
        """Embed ``batch``; in training mode with ``compute_loss`` also return the loss.

        Returns:
            ``(feats, y)`` in eval mode or when ``compute_loss`` is False, else ``(feats, y, total_loss)``.
        """
        feats, y = self._embed(batch)

        # During inference, skip loss computation
        if not self.training or not compute_loss:
            return feats, y

        return feats, y, self._contrastive_loss(feats, y)

    def training_step(self, batch, batch_nb):
        train_loss = self._step_loss(batch)
        self.log_dict({"train_loss": train_loss}, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return train_loss

    def validation_step(self, batch, batch_nb):
        # Lightning runs validation in eval mode, where forward() returns only (feats, y); the loss is
        # therefore computed through _step_loss. No classification metrics for contrastive learning.
        val_loss = self._step_loss(batch)
        self.log_dict({"val_loss": val_loss}, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return val_loss

    def test_step(self, batch, batch_nb):
        test_loss = self._step_loss(batch)
        self.log_dict({"test_loss": test_loss}, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return test_loss

    def configure_optimizers(self):
        """AdamW with no weight decay on biases/LayerNorm weights, plus a per-step linear warmup/decay schedule."""
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': self.hparams.weight_decay},
            {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0},
        ]
        optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, eps=self.hparams.eps)
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.hparams.warmup_steps,
                                                    num_training_steps=self.hparams.num_training_steps)
        scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1}

        return [optimizer], [scheduler]

    def generate_embeddings(self, batch):
        """
        Function to generate embeddings for inference.
        Switches the wrapped model to eval mode (and leaves it there).
        Args:
            batch: Input batch containing tokenized inputs (input_ids, token_type_ids, attention_mask, labels)
        Returns:
            feats: Generated embeddings
        """
        self.model.eval()
        with torch.no_grad():
            feats, _ = self._embed(batch)
        return feats

In [None]:
# %% Loading the Classifier model
# Triplet-loss classifier checkpoint (south, pooled, seed 1111); loaded in eval mode onto `device`.
CLASSIFIER_CKPT_PATH = "/workspace/persistent/HTClipper/models/grouped-and-masked/text-baselines/contra-learn/triplet-loss/declutr-small/south/pooled/seed:1111/lr-0.0001/coeff-1.0/temp:0.1/triplet/final_model.ckpt"
model = HTContraClassifierModel.load_from_checkpoint(CLASSIFIER_CKPT_PATH).eval().to(device)

In [79]:
# Release the reference to any previously loaded model (e.g. the classifier) so it can be
# garbage-collected before another checkpoint is loaded into `model`.
model = None

In [None]:
# %% Loading the Metric learning model
# SupCon-only metric-learning checkpoint (south, pooled, seed 1111); loaded in eval mode onto `device`.
METRIC_CKPT_PATH = "/workspace/persistent/HTClipper/models/grouped-and-masked/text-baselines/contra-learn/semi-supervised/declutr-small/south/pooled/seed:1111/lr-0.0001/coeff-1.0/temp:0.1/SupCon/final_model.ckpt"
model = SemiConstrativeModel.load_from_checkpoint(METRIC_CKPT_PATH).eval().to(device)

# Extracting Embeddings

In [18]:
# City-level splits of the processed dataset, one CSV per city.
(chicago_df, atlanta_df, dallas_df, detroit_df,
 houston_df, ny_df, sf_df, canada_df) = map(pd.read_csv, [
    "../data/processed/chicago.csv",
    "../data/processed/atlanta.csv",
    "../data/processed/dallas.csv",
    "../data/processed/detroit.csv",
    "../data/processed/houston.csv",
    "../data/processed/ny.csv",
    "../data/processed/sf.csv",
    "../data/processed/canada.csv",
])

In [74]:
# Region-level splits of the processed dataset, one CSV per region.
south_df, midwest_df, west_df, northeast_df = map(pd.read_csv, [
    "../data/processed/south.csv",
    "../data/processed/midwest.csv",
    "../data/processed/west.csv",
    "../data/processed/northeast.csv",
])

In [75]:
from transformers import AutoTokenizer

# Tokenizer for the encoder, taken from the --tokenizer_name_or_path argument.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=args.tokenizer_name_or_path,
)

In [76]:
def extract_embedding_of_trained_checkpoints(df, model, tokenizer, city, model_name):
    """Embed one region's ads with ``model`` and save the train/test embeddings and labels as ``.pt`` files.

    Steps:
        * Deduplicate rows on (TEXT, VENDOR).
        * Remap vendor ids to contiguous class ids in order of first appearance.
        * Split 80/20 with ``train_test_split(random_state=1111)``.
        * Embed each split with ``extract_embeddings`` using the given ``tokenizer``.

    Files go to ``<cwd>/../models/pickled/embeddings/grouped-and-masked/trained_<model_name>_all/`` as
    ``<city>_data_{train,test}.pt`` and ``<city>_labels_{train,test}.pt``.
    NOTE(review): ``_all`` is always appended, so ``model_name="declutr_SupCononly_all"`` yields a
    ``..._all_all`` directory — confirm this is the layout downstream consumers expect.
    """
    df = df[["TEXT", "VENDOR"]].drop_duplicates()

    # Vendor ids are not contiguous class indices; remap them to 0..n-1 in first-appearance order to
    # avoid out-of-bounds labels. ``map`` applies the mapping in one pass, so a vendor id that equals
    # another vendor's new index cannot be re-mapped a second time.
    unique_vendors = df["VENDOR"].drop_duplicates().tolist()
    vendors_dict = {vendor: idx for idx, vendor in enumerate(unique_vendors)}
    df = df.assign(VENDOR=df["VENDOR"].map(vendors_dict))

    train_df, test_df = train_test_split(df, test_size=0.20, random_state=1111)

    directory = os.path.join(os.getcwd(), "../models/pickled/embeddings/grouped-and-masked", "trained_" + model_name + "_all")
    Path(directory).mkdir(parents=True, exist_ok=True)

    for split_name, split_df in (("train", train_df), ("test", test_df)):
        embeddings, labels = extract_embeddings(split_df, model, vendors_dict, tokenizer=tokenizer)
        torch.save(embeddings, os.path.join(directory, city + "_data_" + split_name + ".pt"))
        torch.save(labels, os.path.join(directory, city + "_labels_" + split_name + ".pt"))


def extract_embeddings(df, model, vendors_dict, device="cuda", pooling_type="mean",
                       tokenizer=None, batch_size=32, max_length=512):
    """Tokenize ``df.TEXT`` and run it through ``model`` without gradients, returning ``(features, labels)``.

    Args:
        df: frame with ``TEXT`` and integer ``VENDOR`` columns (VENDOR already remapped to class ids).
        model: module whose ``forward(batch, compute_loss=False)`` returns ``(feats, y)``
            (``SemiConstrativeModel``). For ``HTContraClassifierModel`` use ``_, feats, y, _ = model(batch)``.
        vendors_dict: unused; kept for backward compatibility.
        device: device each batch is moved to; must match the model's device.
        pooling_type: unused; pooling is controlled by the model's ``pooling`` hyper-parameter.
        tokenizer: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        batch_size: DataLoader batch size.
        max_length: tokenizer padding/truncation length.

    Returns:
        ``(pooled_outputs, labels)`` concatenated over all batches, in the order of ``df``.
    """
    if tokenizer is None:
        # Backward compatibility: older callers relied on the notebook-level ``tokenizer``.
        tokenizer = globals()["tokenizer"]

    text = df.TEXT.values.tolist()
    vendors = df.VENDOR.values.tolist()

    # Tokenizing the data with padding and truncation
    encodings = tokenizer(text, add_special_tokens=True, max_length=max_length, padding='max_length',
                          return_token_type_ids=True, truncation=True, return_attention_mask=True,
                          return_tensors='pt')

    # Keep the dataset on the CPU and move one batch at a time to ``device``, so the whole tokenized
    # corpus does not have to sit in GPU memory at once.
    dataset = TensorDataset(encodings['input_ids'], encodings['token_type_ids'],
                            encodings['attention_mask'], torch.tensor(vendors))
    test_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    pooled_output_list, labels_list = [], []
    was_training = model.training
    model.eval()  # disable dropout so the extracted embeddings are deterministic
    try:
        with torch.no_grad():
            for batch in tqdm(test_dataloader, total=len(test_dataloader)):
                batch = [tensor.to(device) for tensor in batch]
                feats, y = model(batch, compute_loss=False)
                pooled_output_list.append(feats)
                labels_list.append(y)
    finally:
        model.train(was_training)

    # Concatenate the pooled outputs and labels into tensors
    return torch.cat(pooled_output_list), torch.cat(labels_list)

In [82]:
# Extract and save train/test embeddings for every region with the loaded metric-learning model.
for region_df, region in ((south_df, "south"), (midwest_df, "midwest"),
                          (west_df, "west"), (northeast_df, "northeast")):
    extract_embedding_of_trained_checkpoints(region_df, model, tokenizer, city=region,
                                             model_name="declutr_SupCononly_all")


  0%|          | 0/353 [00:00<?, ?it/s][A
  3%|▎         | 10/353 [00:00<00:03, 91.85it/s][A
  6%|▌         | 20/353 [00:00<00:06, 54.10it/s][A
  8%|▊         | 27/353 [00:00<00:06, 48.88it/s][A
  9%|▉         | 33/353 [00:00<00:06, 46.50it/s][A
 11%|█         | 38/353 [00:00<00:06, 45.17it/s][A
 12%|█▏        | 43/353 [00:00<00:07, 44.24it/s][A
 14%|█▎        | 48/353 [00:01<00:07, 43.57it/s][A
 15%|█▌        | 53/353 [00:01<00:06, 43.09it/s][A
 16%|█▋        | 58/353 [00:01<00:06, 42.76it/s][A
 18%|█▊        | 63/353 [00:01<00:06, 42.51it/s][A
 19%|█▉        | 68/353 [00:01<00:06, 42.33it/s][A
 21%|██        | 73/353 [00:01<00:06, 42.20it/s][A
 22%|██▏       | 78/353 [00:01<00:06, 42.12it/s][A
 24%|██▎       | 83/353 [00:01<00:06, 42.07it/s][A
 25%|██▍       | 88/353 [00:01<00:06, 42.03it/s][A
 26%|██▋       | 93/353 [00:02<00:06, 42.01it/s][A
 28%|██▊       | 98/353 [00:02<00:06, 41.99it/s][A
 29%|██▉       | 103/353 [00:02<00:05, 41.98it/s][A
 31%|███       | 10