# Package Installation

## Import necessary packages

In [None]:
!pip install ftfy
!pip install timm

In [None]:
import os # file management
import torch
from torch import nn, Tensor # neural network
from torch.nn import functional as F
import nltk
from nltk.tokenize import word_tokenize
import numpy as np 
from functools import partial
import timm

from numpy import ndarray
from ftfy import fix_encoding
# data/parameter loading
import pandas as pd 
from math import sqrt
from transformers import AutoConfig
import pickle
# visualization
from tqdm.notebook import trange, tqdm
from accelerate.utils import set_seed
# transfomers
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup, DataCollatorWithPadding
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers import AdamW
from nltk.corpus import stopwords
import re
from typing import Union, List, Dict
import gc
from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed
from accelerate import DistributedDataParallelKwargs
# filter out warnings
import warnings
warnings.simplefilter('ignore')  # silence every Python warning for the rest of the session
# Disable Hugging Face tokenizers parallelism (presumably to avoid fork warnings with
# DataLoader worker processes — confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# DDP option that tolerates parameters which receive no gradient in a step.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)

## Some useful functions

In [None]:
# Utils
def save_parameter(save_object, save_file):
    """Pickle `save_object` to the path `save_file` using the highest pickle protocol."""
    with open(save_file, mode='wb') as handle:
        pickle.dump(save_object, handle, pickle.HIGHEST_PROTOCOL)

def load_parameter(load_file):
    """Unpickle and return the object stored at the path `load_file`.

    Security: unpickling can execute arbitrary code — only load files you produced
    yourself (e.g. with `save_parameter`).
    """
    with open(load_file, mode='rb') as handle:
        return pickle.load(handle)

def sim_matrix(a, b, eps=1e-8):
    """Return the pairwise cosine-similarity matrix between the rows of `a` and `b`.

    Row norms are clamped to at least `eps`, so an all-zero row yields similarity 0
    instead of NaN. Result shape: (a.size(0), b.size(0)).
    """
    def _unit_rows(m):
        # Divide each row by its (clamped) L2 norm.
        return m / torch.clamp(m.norm(dim=1)[:, None], min=eps)

    return torch.mm(_unit_rows(a), _unit_rows(b).transpose(0, 1))

def count_parameters(model):
    """Return the total number of scalar values in `model`'s trainable parameters."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(torch.Tensor.numel, trainable))

def compute_metrics(logits, labels, topks=(1, 3, 5, 10)):
    """Compute top-k accuracy of `logits` against integer class `labels`.

    Args:
        logits: (N, C) class scores; a higher score means a more likely class.
        labels: N integer class ids (tensor or sequence).
        topks: the k values to evaluate (defaults to 1, 3, 5, 10).

    Returns:
        (scores, num_cnt): scores[k] is the fraction of rows whose label is among
        the k highest logits; num_cnt[k] is the matching raw count. With no rows,
        every entry is 0 (instead of dividing by zero).
    """
    labels = torch.as_tensor(labels, device=logits.device).reshape(-1)
    n_points = labels.numel()
    scores = {k: 0 for k in topks}
    num_cnt = {k: 0 for k in topks}
    if n_points == 0 or not topks:
        return scores, num_cnt

    # Rank the raw logits: exp() is monotonic so it never changes the order, but it
    # overflows to inf once logits exceed ~88 in float32 (logits here can reach
    # 1/0.01 = 100), and the resulting ties make the ranking arbitrary.
    max_k = min(max(topks), logits.size(1))  # k beyond the class count means "all classes"
    top_indices = torch.topk(logits, max_k, dim=1).indices
    hits = top_indices.eq(labels.unsqueeze(1))  # (N, max_k): is position j the true label?
    for k in topks:
        correct = int(hits[:, :k].any(dim=1).sum().item())
        scores[k] = correct / n_points
        num_cnt[k] = correct
    return scores, num_cnt

def cosine_distance(x):
    """Return the mean of (1 - x_i . x_j) over all distinct row pairs i < j of `x`.

    For unit-norm rows this is the mean pairwise cosine distance (a spread /
    "uniformity" diagnostic). Only the strictly-upper-triangular entries are
    averaged; the previous `triu(diagonal=1).mean()` also counted the zeroed
    diagonal and lower triangle, deflating the value by roughly 2x.
    Returns a 0-dim tensor; 0 when `x` has fewer than two rows.
    """
    n_rows = x.size(0)
    if n_rows < 2:
        return x.new_zeros(())
    dist = 1 - torch.mm(x, x.T)
    rows, cols = torch.triu_indices(n_rows, n_rows, offset=1, device=x.device)
    return dist[rows, cols].mean()

def preprocess_keywords(keyword):
    """Keep only comma-separated keywords longer than 3 characters.

    Each keyword is whitespace-stripped; the survivors are re-joined with ", ".
    """
    kept = []
    for piece in keyword.split(","):
        piece = piece.strip()
        if len(piece) > 3:
            kept.append(piece)
    return ", ".join(kept)

stop_words = {*stopwords.words('english')}  # NLTK English stopword list as a set for O(1) lookups

# Data preparation

## Build the set of paired samples (each paper joined with its journal's aims and name)

In [None]:
# Read the three preprocessed splits (train, test, valid — in that order), then the
# per-journal table.
data_train, data_test, data_valid = map(pd.read_csv, (
    "/kaggle/input/new-preprocessed-data/news_preprocessed_train (1).csv",
    "/kaggle/input/new-preprocessed-data/news_preprocessed_test (1).csv",
    "/kaggle/input/new-preprocessed-data/news_preprocessed_valid (1).csv",
))
data_aims = pd.read_csv("/kaggle/input/preprocessed-aims/preprocessed_aims (1) (1).csv")

# Attach the journal's 'Aims' and 'name' columns to every paper (inner join on "itr").
aims_lookup = data_aims[['itr', 'Aims', 'name']]
data_train = data_train.merge(aims_lookup, on="itr")
data_valid = data_valid.merge(aims_lookup, on="itr")
data_test = data_test.merge(aims_lookup, on="itr")

# Replace missing values with empty strings, in place, in every frame.
for frame in (data_train, data_valid, data_test, data_aims):
    frame.fillna("", inplace=True)

In [None]:
tokenizer = AutoTokenizer.from_pretrained("malteos/scincl")
# Sorted unique journal names; a journal's position in this list is its class id.
journal = sorted(set(data_aims['name'].tolist()))
label_dict = {name: class_id for class_id, name in enumerate(journal)}
for frame in (data_train, data_valid, data_test):
    frame['Label'] = frame['name'].map(label_dict)
n_classes = len(journal)

In [None]:
# Build the short per-class text fed to the encoder as label input, one per entry of
# `journal` (same order, so index == class id). The cleaning now works on a copy:
# `aims = journal` aliased the list, so the loop silently overwrote `journal` (the
# class-id -> journal-name mapping) with the cleaned strings.
# NOTE(review): these texts come from the journal *names*, not the 'Aims' column —
# confirm that is intended.
_BRACKETED_SPAN = re.compile(r"[\(\[].*?[\)\]]")  # non-greedy "(...)" / "[...]" spans


def _clean_label_text(name):
    """Strip bracketed spans, ',', ':', '.' and English stopwords from a journal name."""
    text = _BRACKETED_SPAN.sub("", name)
    text = text.replace(",", "").replace(":", "").replace(".", "").strip()
    tokens = word_tokenize(text)
    return " ".join(w for w in tokens if w.lower().strip() not in stop_words)


aims = [_clean_label_text(name) for name in journal]
aims_ids = tokenizer(aims, max_length=16, return_tensors='pt', padding="max_length", truncation=True)

## Dataset of (paper text, journal label) pairs for training

In [None]:
class DualPaperDataset(torch.utils.data.Dataset):
    """Map-style dataset that turns one paper row into a tokenized prompt + label.

    The prompt concatenates the 'title', 'abstract' and 'keywords' columns behind
    fixed prefixes; non-string cells become "". The tokenizer is called without
    padding (batch padding is left to the collate function), and the row's 'Label'
    is attached under the key 'label'.
    """

    def __init__(self, data, tokenizer, n_classes) -> None:
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.n_classes = n_classes

    def __len__(self):
        return len(self.data)

    def _text(self, column, index):
        # Stripped string value of data[column][index], or "" for non-strings (e.g. NaN).
        value = self.data[column][index]
        return value.strip() if isinstance(value, str) else ""

    def __getitem__(self, index):
        label = torch.tensor(self.data['Label'][index])
        prompt = "".join((
            f"main idea : {self._text('title', index)} . ",
            f"concise summary : {self._text('abstract', index)} . ",
            f"important words : {self._text('keywords', index)} .",
        ))
        encoded = self.tokenizer(prompt, max_length=320, truncation=True, padding=False)
        encoded['label'] = label
        return encoded

# Model definition

# Model for contrastive learning training

In [None]:
class StackAttentionLayer(nn.Module):
    """Label-aware attention pooling over token features, followed by an MLP block.

    Every (query-projected) token is scored against every (key-projected) label
    embedding; a learned linear map collapses the per-label scores into one score
    per token; the softmax over tokens pools the projected token features into one
    vector per sample. Two residual + LayerNorm steps follow, with the first
    token's un-projected feature as the residual.
    """

    def __init__(self, embeds_dim, n_classes):
        super().__init__()
        # Registration order is kept stable so saved optimizer states still line up.
        self.linear_query = nn.Linear(embeds_dim, embeds_dim, bias=False)
        self.linear_key = nn.Linear(embeds_dim, embeds_dim, bias=False)
        self.attn_linear = nn.Linear(n_classes, 1, bias=True)
        self.scale = sqrt(embeds_dim)
        self.layer_norm1 = nn.LayerNorm(embeds_dim, eps=1e-12)
        self.mlp = nn.Sequential(
            nn.Linear(embeds_dim, embeds_dim),
            nn.ReLU(),
            nn.Linear(embeds_dim, embeds_dim),
            nn.ReLU(),
            nn.Linear(embeds_dim, embeds_dim)
        )
        self.layer_norm2 = nn.LayerNorm(embeds_dim, eps=1e-12)

    def forward(self, input_feats, label_feats, attn_mask=None):
        '''
        input_feats size: BxSxD (token features)
        label_feats size: CxD (one embedding per class)
        attn_mask size: BxS, 0 marks positions to ignore; None attends to every token
        returns: BxD
        '''
        residual = input_feats[:, 0, :]
        input_feats = self.linear_query(input_feats)
        label_feats = self.linear_key(label_feats)
        # token-label scores: BxSxC
        dot_product = torch.div(torch.matmul(input_feats, label_feats.T), self.scale)
        # collapse the class axis -> one score per token: BxS. Squeeze only the last
        # axis; a bare squeeze() would also drop the batch axis when B == 1.
        attn = self.attn_linear(dot_product).squeeze(-1)
        if attn_mask is not None:
            attn = attn.masked_fill(attn_mask.eq(0), float('-inf'))
        attn = F.softmax(attn, dim=-1)
        out = torch.einsum('bc, bcd->bd', attn, input_feats)
        residual = self.layer_norm1(out + residual)
        # NOTE(review): the MLP consumes the pre-norm `out`, not `residual`; a standard
        # post-LN block would use mlp(residual). Kept as-is so trained weights behave the same.
        out = self.layer_norm2(self.mlp(out) + residual)
        return out

class PaperModel(nn.Module):
    """Scores papers against every journal label with a shared encoder.

    The paper tokens and the label texts (`aims_ids`) both go through `self.encoder`.
    Each label embedding is the first-token hidden state of its label text. A
    StackAttentionLayer pools the paper tokens. Logits are the dot products of the
    L2-normalised paper and label vectors divided by a learnable temperature.
    """

    def __init__(self, hidden_size, model_name_or_path, num_classes) -> None:
        super().__init__()
        # Accept either a checkpoint name/path or an already-built encoder module
        # (a passed-in module is used directly, not copied).
        self.encoder = (
            AutoModel.from_pretrained(model_name_or_path)
            if isinstance(model_name_or_path, str)
            else model_name_or_path
        )
        self.n_classes = num_classes
        # Learnable temperature, initialised to 0.07 and clamped to [0.01, 0.5] in forward().
        self.temperature = nn.Parameter(torch.ones([]) * 0.07)
        self.attn = StackAttentionLayer(hidden_size, num_classes)
        # Fine-tune every encoder weight.
        self.encoder.requires_grad_(True)

    def get_label_feats(self, aims_ids):
        """Encode the tokenized label texts; return each one's first-token hidden state."""
        label_states = self.encoder(**aims_ids).last_hidden_state
        return label_states[:, 0, :]

    def forward(self, inputs, aims_ids):
        # Keep the temperature in range without recording the clamp in autograd.
        with torch.no_grad():
            self.temperature.clamp_(0.01, 0.5)
        token_states = self.encoder(**inputs).last_hidden_state
        label_feats = self.get_label_feats(aims_ids)
        pooled = self.attn(token_states, label_feats, inputs['attention_mask'])
        cls_feats = F.normalize(pooled, dim=-1)
        label_feats = F.normalize(label_feats, dim=-1)
        logits = torch.einsum("bd, cd->bc", cls_feats, label_feats) / self.temperature
        return {"logits": logits, "cls_feats": cls_feats, "label_feats": label_feats}

# Pooler

## Pooling helpers and losses (cross-entropy, supervised contrastive, label regularization)

In [None]:
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the sequence axis, ignoring masked (0) positions.

    The token count is clamped to at least 1e-9, so an all-masked row yields zeros.
    """
    weights = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    summed = (model_output * weights).sum(dim=1)
    counts = weights.sum(dim=1).clamp(min=1e-9)
    return summed / counts

def sum_pooling(model_output, attention_mask):
    """Sum token embeddings over the sequence axis, ignoring masked (0) positions."""
    keep = attention_mask.unsqueeze(-1).expand(model_output.size()).float()
    return (model_output * keep).sum(dim=1)

class CELoss(nn.Module):
    """Plain (unsmoothed) cross-entropy between `outputs['logits']` and `targets`."""

    def __init__(self):
        super(CELoss, self).__init__()
        self.xent_loss = nn.CrossEntropyLoss()

    def forward(self, outputs, targets):
        logits = outputs['logits']
        return self.xent_loss(logits, targets)

class SupConLoss(nn.Module):
    """Label-smoothed cross-entropy blended with a supervised contrastive loss.

    loss = (1 - alpha) * CE(logits, targets) + alpha * SupCon(cls_feats, targets)
    """

    def __init__(self, alpha, temp):
        super().__init__()
        self.xent_loss = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.alpha = alpha
        self.temp = temp

    def nt_xent_loss(self, anchor, target, labels):
        """Supervised contrastive loss between the rows of `anchor` and `target`.

        Expects `anchor` and `target` to hold the same batch (so the diagonal holds
        self-pairs). For row i, the positives are rows j != i with
        labels[j] == labels[i]. The softmax denominator runs over every j != i, and
        self-pairs are excluded entirely. Zeroing the self logit, as before, still
        left exp(0 - max) in the denominator. Rows with no positive contribute 0.
        """
        with torch.no_grad():
            labels = labels.unsqueeze(-1)
            self_pairs = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
            # positives: same label, excluding the element itself
            mask = torch.eq(labels, labels.transpose(0, 1)) & ~self_pairs
        # compute logits
        anchor_dot_target = torch.einsum('bd,cd->bc', anchor, target) / self.temp
        # for numerical stability: a per-row constant, which cancels in log_prob
        logits_max, _ = torch.max(anchor_dot_target, dim=1, keepdim=True)
        logits = anchor_dot_target - logits_max.detach()
        # drop self-pairs from the denominator
        exp_logits = torch.exp(logits).masked_fill(self_pairs.to(logits.device), 0.0)
        log_prob = logits - torch.log(exp_logits.sum(dim=1, keepdim=True) + 1e-12)
        mask = mask.to(log_prob)
        # in case that mask.sum(1) is zero (no positive for that row)
        mask_sum = mask.sum(dim=1)
        mask_sum = torch.where(mask_sum == 0, torch.ones_like(mask_sum), mask_sum)
        # mean log-likelihood of the positives, per row
        pos_logits = (mask * log_prob).sum(dim=1) / mask_sum
        return -1 * pos_logits.mean()

    def forward(self, outputs, targets):
        normed_cls_feats = F.normalize(outputs['cls_feats'], dim=-1)
        ce_loss = (1 - self.alpha) * self.xent_loss(outputs['logits'], targets)
        cl_loss = self.alpha * self.nt_xent_loss(normed_cls_feats, normed_cls_feats, targets)
        return ce_loss + cl_loss

class DualLoss(SupConLoss):
    """Cross-entropy on `outputs['logits']` only, via the inherited `self.xent_loss`.

    `alpha` and `temp` are forwarded to SupConLoss and stored, but this forward()
    does not use them — no contrastive term is added.
    """

    def __init__(self, alpha=0.0, temp=0.1):
        super().__init__(alpha, temp)

    def forward(self, outputs, targets):
        # Only the logits are read; the feature entries of `outputs` are not needed.
        return self.xent_loss(outputs['logits'], targets)

class LabelRegLoss(nn.Module):
    """Hinge penalty that pushes each label embedding away from its nearest neighbour.

    For every row of `x`, take the largest dot product with any *other* row (the
    cosine similarity when rows are unit-norm). Apply relu(sim - threshold) and
    average over the rows.

    Args:
        threshold: similarity above which the penalty becomes active.
        is_normalize: True if the caller already L2-normalised `x`; if False, the
            rows are normalised here first.
    """

    def __init__(self, threshold=0.5, is_normalize=True):
        super().__init__()
        self.is_normalize = is_normalize
        self.threshold = threshold

    def forward(self, x):
        if not self.is_normalize:
            x = F.normalize(x, dim=-1)
        similarity = torch.mm(x, x.T)
        # Exclude self-similarity so the max picks the closest *other* row.
        self_pairs = torch.eye(similarity.size(0), dtype=torch.bool, device=similarity.device)
        similarity = similarity.masked_fill(self_pairs, float('-inf'))
        nearest, _ = torch.max(similarity, dim=-1)
        return F.relu(nearest - self.threshold).mean()

## Model declaration

In [None]:
model = AutoModel.from_pretrained("malteos/scincl")

# layers_to_keep = [0, 2, 4, 6, 9, 11]
# new_layers = torch.nn.ModuleList([layer_module for i, layer_module in enumerate(model.encoder.layer) if i in layers_to_keep])
# model.encoder.layer = new_layers
# model.config.num_hidden_layers = len(layers_to_keep)
# state = torch.load("/kaggle/input/distilscincl-state-dict/Epoch_1_steps8801_student_embedding.pth")
# model.load_state_dict(state['model_state_dict'])

# Constructor arguments for PaperModel. The already-built `model` is passed in, so
# every PaperModel created from these args wraps this same encoder instance.
model_args = {
    # Read the width from the encoder's config instead of hard-coding 768, so the
    # attention layer always matches whichever encoder is loaded above.
    "hidden_size": model.config.hidden_size,
    'model_name_or_path': model,
    "num_classes": n_classes
}

# Training

## Optimizer and configuration

In [None]:
def _evaluate(model, loader, aims_ids, criterion, accelerator, description, topks=(1, 3, 5, 10)):
    """Run one gradient-free pass of `model` over `loader` in eval mode.

    Leaves the model in eval mode; the caller switches it back to train mode.
    Returns (mean per-batch loss, {k: number of samples whose label is in the top k}).
    """
    model.eval()  # no dropout during validation
    total_loss = 0.0
    counts = {k: 0 for k in topks}
    loop = tqdm(loader, leave=True, disable=not accelerator.is_local_main_process)
    with torch.no_grad():
        for batch in loop:
            # No squeeze(): it would drop the batch axis of a final batch of size 1.
            inputs = {k: v for k, v in batch.items() if k != 'labels'}
            labels = batch['labels']
            outputs = model(inputs, aims_ids)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            score, num_cnt = compute_metrics(outputs['logits'], labels)
            for k in topks:
                counts[k] += num_cnt[k]
            loop.set_description(description)
            loop.set_postfix(loss=loss.item(), top01=score[1], top03=score[3], top05=score[5], top10=score[10])
    return total_loss / max(len(loader), 1), counts


def training_loop(aims_ids, mixed_precision="fp16", seed=42, batch_size=32, state=None, max_epochs=3):
    """Fine-tune a PaperModel on `data_train`, validate on `data_valid` after every epoch.

    Args:
        aims_ids: tokenized label texts (a mapping of tensors), moved to the device here.
        mixed_precision: forwarded to `Accelerator` (e.g. "fp16", "no").
        seed: passed to `set_seed`.
        batch_size: training batch size (validation always uses 32).
        state: optional checkpoint dict previously saved by this function; restores
            model/optimizer/scheduler/history and resumes after `state['epoch']`.
        max_epochs: total number of epochs, including resumed ones.

    Side effects: writes "./Epoch:NN SupCL-SciNCL.pth" after every epoch.
    Relies on module-level globals: tokenizer, model_args, data_train, data_valid,
    n_classes, ddp_kwargs.
    """
    set_seed(seed)
    torch.cuda.empty_cache()
    gc.collect()
    topks = (1, 3, 5, 10)
    history = {"cl_loss": [], "uniform_loss": [], "ce_loss": [], "accuracy": []}
    # Pass the module-level DDP options so find_unused_parameters=True actually takes
    # effect in multi-GPU runs (encoder outputs other than last_hidden_state are unused).
    accelerator = Accelerator(mixed_precision=mixed_precision, kwargs_handlers=[ddp_kwargs])
    data_collator = DataCollatorWithPadding(tokenizer)
    model = PaperModel(**model_args)
    dataset = DualPaperDataset(data_train, tokenizer, n_classes)
    data_loader = torch.utils.data.DataLoader(dataset, collate_fn=data_collator, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    min_loss = np.inf
    valid_dataset = DualPaperDataset(data_valid, tokenizer, n_classes)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, collate_fn=data_collator, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)
    print("Model summary:\n")
    print(">> Total params: ", count_parameters(model))

    # Standard "no weight decay on biases / LayerNorm" grouping.
    # NOTE(review): StackAttentionLayer's norms are named layer_norm1/2, so they do not
    # match 'LayerNorm.*' and still get weight decay — confirm whether that is intended.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    param_optimizer = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, weight_decay=1e-2, correct_bias=True)
    num_training_steps = len(data_loader) * max_epochs
    num_warmup_steps = int(num_training_steps * 0.1)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
    criterion = DualLoss()
    uniform_criterion = LabelRegLoss()

    saved_epochs = -1
    if state is not None:
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        lr_scheduler.load_state_dict(state['scheduler_state_dict'])
        saved_epochs = state['epoch']
        # Keep the earlier epochs' curves so the next checkpoint has the full history.
        history = state.get('history', history)
    aims_ids = {k: v.to(accelerator.device) for k, v in aims_ids.items()}
    print(f"Saved epochs: {saved_epochs+1}")
    model, uniform_criterion, optimizer, data_loader, valid_loader, lr_scheduler = accelerator.prepare(
        model, uniform_criterion, optimizer, data_loader, valid_loader, lr_scheduler)
    # Under DDP the model is wrapped and exposes the PaperModel as `.module`; use the
    # inner model for attributes and for un-prefixed state_dict keys (resumable).
    core_model = getattr(model, "module", model)

    for epoch in range(saved_epochs + 1, max_epochs):
        # from_pretrained() returns the encoder in eval mode, and validation switches
        # to eval mode, so enable training mode (dropout) explicitly every epoch.
        model.train()
        loop = tqdm(data_loader, leave=True, disable=not accelerator.is_local_main_process)
        train_loss = 0.0
        train_score = {k: 0 for k in topks}
        for idx, batch in enumerate(loop):
            # No squeeze(): it would drop the batch axis of a final batch of size 1.
            inputs = {k: v for k, v in batch.items() if k != 'labels'}
            labels = batch['labels']
            optimizer.zero_grad()
            outputs = model(inputs, aims_ids)
            label_feats = outputs['label_feats']
            ce_loss = criterion(outputs, labels)
            uniform_loss = uniform_criterion(label_feats)
            loss = ce_loss + 0.1 * uniform_loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            history['ce_loss'].append(ce_loss.item())
            history['uniform_loss'].append(uniform_loss.item())
            train_loss += loss.item()

            with torch.no_grad():
                cosine_dist = cosine_distance(label_feats.detach().clone())
                score, num_cnt = compute_metrics(outputs['logits'], labels)
                history['accuracy'].append(score[1])
                for k in topks:
                    train_score[k] += num_cnt[k]
            temperature = core_model.temperature.item()
            if idx % 300 == 0:
                accelerator.print(f"Loss: {loss.item()} || Temp: {temperature}|| Regularize loss: {uniform_loss.item():.5f} || Uniformity: {cosine_dist} || Top 1 acc: {score[1]} || Top 3 acc: {score[3]} || Top 5 acc: {score[5]} || Top 10 acc: {score[10]}")
            loop.set_description('Epoch: {} - lr: {}'.format(epoch+1, optimizer.param_groups[0]['lr']))
            loop.set_postfix(loss=round(loss.item(), 3), temp=temperature, top01=score[1], top03=score[3], top05=score[5], top10=score[10])
        # Mean of per-batch mean losses (previously divided by the sample count).
        train_loss /= max(len(data_loader), 1)
        accelerator.print(f"Train loss: {train_loss:.6f}")
        # NOTE(review): counts are per process; with several processes they are not
        # gathered before dividing by the full dataset size.
        for k in topks:
            accelerator.print(f"Train top {k} acc: {train_score[k]/(len(dataset))}", end=" || ")
        accelerator.print("")

        description = 'Epoch: {} - lr: {} '.format(epoch+1, optimizer.param_groups[0]['lr'])
        valid_loss, valid_score = _evaluate(model, valid_loader, aims_ids, criterion, accelerator, description, topks)
        for k in topks:
            accelerator.print(f"Valid top {k} acc: {valid_score[k]/len(valid_dataset)}", end=" || ")
        accelerator.print("")
        accelerator.print(f'Validation Loss ({min_loss:.6f}--->{valid_loss:.6f})')
        # Track the best validation loss seen so far, not merely the latest one.
        min_loss = min(min_loss, valid_loss)
        accelerator.wait_for_everyone()
        accelerator.save({
            "history": history,
            "model_state_dict": core_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": lr_scheduler.state_dict(),
            "epoch": epoch
                }, "./Epoch:{:0>2} SupCL-SciNCL.pth".format(epoch+1))

In [None]:
# Positional arguments for training_loop, in order:
# aims_ids, mixed_precision, seed, batch_size (no resume state).
args = aims_ids, "fp16", 42, 32
training_loop(*args)

In [None]:
# topks = (1, 3, 5, 10) 
# train_score = {k:0 for k in (1,3,5,10)}
# valid_score = {k:0 for k in (1,3,5,10)}
# test_score = {k:0 for k in (1,3,5,10)}
# model = PaperModel(**model_args)
# state = torch.load("/kaggle/input/scincl-base-best-checkpoint/Epoch_03 SupCL-DistilRoberta (4).pth")
# model.load_state_dict(state['model_state_dict'])
# model.eval()
# train_loss = 0.0
# valid_loss = 0.0
# test_loss = 0.0
# criterion_ce = nn.CrossEntropyLoss()
# accelerator = Accelerator()
# dataset = DualPaperDataset(data_train, tokenizer, n_classes)
# data_collator = DataCollatorWithPadding(tokenizer)
# valid_dataset = DualPaperDataset(data_valid, tokenizer, n_classes)
# valid_loader = torch.utils.data.DataLoader(valid_dataset, collate_fn=data_collator, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)
# test_dataset = DualPaperDataset(data_test, tokenizer, n_classes)
# test_loader = torch.utils.data.DataLoader(test_dataset, collate_fn=data_collator, batch_size=32, shuffle=False, num_workers=2, pin_memory=True)
# model, valid_loader, test_loader = accelerator.prepare(
# model, valid_loader, test_loader)
# aims_ids = {k: v.to(accelerator.device) for k, v in aims_ids.items()}

# valid_loop = tqdm(valid_loader, leave=True)
# for batch in valid_loop:
#     with torch.no_grad():
#         inputs = {k:v.squeeze() for k, v in batch.items() if k != 'labels'}
#         labels = batch['labels']
#         logits = model(inputs, aims_ids)
#         logits = logits['logits']
#         loss = criterion_ce(logits, labels)
#         test_loss += loss.item()
#         score, num_cnt = compute_metrics(logits, labels)
#         valid_loop.set_postfix(Top_01=score[1], Top_03=score[3], Top_05=score[5], Top_10=score[10])
#         for k in (1, 3, 5, 10):
#             valid_score[k] += num_cnt[k]

# test_loop = tqdm(test_loader, leave=True)
# for batch in test_loop:
#     with torch.no_grad():
#         inputs = {k:v.squeeze() for k, v in batch.items() if k != 'labels'}
#         labels = batch['labels']
#         logits = model(inputs, aims_ids)
#         logits = logits['logits']
#         loss = criterion_ce(logits, labels)
#         test_loss += loss.item()
#         score, num_cnt = compute_metrics(logits, labels)
#         test_loop.set_postfix(Top_01=score[1], Top_03=score[3], Top_05=score[5], Top_10=score[10])
#         for k in (1, 3, 5, 10):
#             test_score[k] += num_cnt[k]

# test_loss /= len(test_loader)

# print(f"Test loss: {test_loss:.3f}")
# print("")

# for k in topks:
#     print(f"Test top{k} acc: {test_score[k]/len(test_dataset)}")