In [None]:
from re import escape
import pandas as pd
import csv

processed_train_df = pd.read_csv(
    'processed_train_updated.csv',
    quoting=csv.QUOTE_ALL,
    escapechar='\\',
    engine='python',
    encoding='utf-8-sig',
    on_bad_lines = 'skip'
)

processed_test_df = pd.read_csv(
    'processed_test_updated.csv',
    quoting=csv.QUOTE_ALL,
    escapechar='\\',
    engine='python',
    encoding='utf-8-sig'
)

In [None]:
print(processed_train_df.info())
print('\n')
print('=' * 65)
print('\n')
processed_test_df.info()
print('\n')
print('=' * 65)

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 50000 entries, 0 to 49999
Data columns (total 3 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   query          50000 non-null  object 
 1   product_input  50000 non-null  object 
 2   esci_label     50000 non-null  float64
dtypes: float64(1), object(2)
memory usage: 1.1+ MB
None




<class 'pandas.core.frame.DataFrame'>
RangeIndex: 10000 entries, 0 to 9999
Data columns (total 3 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   query          10000 non-null  object 
 1   product_input  10000 non-null  object 
 2   esci_label     10000 non-null  float64
dtypes: float64(1), object(2)
memory usage: 234.5+ KB




## Training Bi-Encoder

In [None]:
from transformers import AutoModel, AutoTokenizer

model_name = "microsoft/deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
encoder = AutoModel.from_pretrained(model_name)



In [None]:
from collections import defaultdict

# Creating dict for product info as prod_groups
# and esci_label as label_groups
prod_groups_train = defaultdict(list)
prod_groups_test  = defaultdict(list)
label_groups_train = defaultdict(list)
label_groups_test  = defaultdict(list)

def get_dicts(df, prod_groups, label_groups):
    for _, row in df.iterrows():
        query = row["query"]
        product = row["product_input"]
        relevance = float(row["esci_label"])

        prod_groups[query].append(product)
        label_groups[query].append(relevance)

get_dicts(processed_train_df, prod_groups_train, label_groups_train)
get_dicts(processed_test_df, prod_groups_test, label_groups_test)

In [None]:
from torch.utils.data import Dataset
import random

class ESCI_Dataset(Dataset):
    def __init__(self, tokenizer, prod_groups, label_groups, max_len=128):
        self.tokenizer = tokenizer
        self.pairs = []
        self.labels = []

        ## Labels are 0.0(I), 0.01(C), 0.1(S) and 1.0(E),
        ## Models would prefer to promote with labels 1.0 and 0.1
        ## over 0.01 and 0.0
        for query in prod_groups:
            product_info = prod_groups[query]
            labels = label_groups[query]

            E_indices = [i for i, l in enumerate(labels) if l > 0.1]
            S_indices = [i for i, l in enumerate(labels) if l > 0.01]
            C_indices = [i for i, l in enumerate(labels) if l > 0]
            I_indices = [i for i, l in enumerate(labels) if l == 0]

            for idx in E_indices:
                pos_product = product_info[idx]
                self.pairs.append((query, pos_product))
                self.labels.append(1.0)

            for idx in S_indices:
                pos_product = product_info[idx]
                self.pairs.append((query, pos_product))
                self.labels.append(0.1)

            for idx in C_indices:
                pos_product = product_info[idx]
                self.pairs.append((query, pos_product))
                self.labels.append(0.01)

            for idx in I_indices:
                pos_product = product_info[idx]
                self.pairs.append((query, pos_product))
                self.labels.append(0.0)


        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        query, product = self.pairs[idx]
        label = self.labels[idx]

        query_encoded = self.tokenizer(
            query,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )

        product_encoded = self.tokenizer(
            product,
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt"
        )

        return {
            "query_input_ids": query_encoded["input_ids"].squeeze(0),
            "query_attention_mask": query_encoded["attention_mask"].squeeze(0),
            "product_input_ids": product_encoded["input_ids"].squeeze(0),
            "product_attention_mask": product_encoded["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.float)
        }

In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

train_dataset = ESCI_Dataset(tokenizer, prod_groups_train, label_groups_train, max_len=128)
test_dataset = ESCI_Dataset(tokenizer, prod_groups_test, label_groups_test, max_len=128)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

In [None]:
import torch
import torch.nn as nn

class BiEncoder(nn.Module):
    def __init__(self, model_name, dropout_rate=0.1):
        super(BiEncoder, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout_rate)
        self.scorer = nn.Linear(hidden_size, 1)

    def encode(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # CLS token
        pooled_output = self.dropout(pooled_output)
        return pooled_output

    def forward(self, query_input_ids, query_attention_mask, product_input_ids, product_attention_masks):
        query_vec = self.encode(query_input_ids, query_attention_mask)
        product_vec = self.encode(product_input_ids, product_attention_masks)

        ## Cosine Similarity
        logits = F.cosine_similarity(query_vec, product_vec, dim=1)
        return logits

In [None]:
## Loss function as RCR Loss with Softplus function

def list_ce_loss(logits, labels):
  """
  logits: torch.Tensor,
  labels: torch.Tensor
  """
  true_dist = F.softmax(labels, dim = 0)
  log_pred_dist = F.log_softmax(logits, dim = 0)
  return -torch.sum(true_dist * log_pred_dist)

def rcr_loss_function(logits, labels, alpha):
  """
  logits: torch.Tensor,
  labels: torch.Tensor
  """
  reg_preds = F.softplus(logits)
  reg_loss = F.mse_loss(reg_preds, labels)
  listwise_loss = list_ce_loss(logits, labels)
  return (1 - alpha) * reg_loss + alpha * listwise_loss

In [None]:
import torch
from tqdm import tqdm
from transformers import get_scheduler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiEncoder(model_name).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=8e-6, weight_decay=0.01)
alpha = 0.5
num_epochs = 3

lr_scheduler = get_scheduler(
    name="linear",
    optimizer = optimizer,
    num_warmup_steps = 100,
    num_training_steps = num_epochs * len(train_loader)
)

global_step = 0.0
losses = []

model.train()
for epoch in range(1):
    loop = tqdm(train_loader, desc = f"Epoch {epoch + 1} / {num_epochs}")

    for batch in loop:
        query_input_ids = batch["query_input_ids"].to(device)
        query_attention_mask = batch["query_attention_mask"].to(device)

        product_input_ids = batch["product_input_ids"].to(device)
        product_attention_mask = batch["product_attention_mask"].to(device)

        labels = batch["label"].to(device)

        logits = model(
            query_input_ids,
            query_attention_mask,
            product_input_ids,
            product_attention_mask
        )

        optimizer.zero_grad()

        loss = rcr_loss_function(logits, labels, alpha)
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        global_step += 1

        loop.set_postfix(loss = loss.item())

        losses.append(loss.item())

Epoch 1 / 3: 100%|██████████| 13894/13894 [1:59:55<00:00,  1.93it/s, loss=0.736]


In [None]:
from sklearn.metrics import ndcg_score
import torch
import numpy as np

model.eval()
query_to_scores = defaultdict(list)
query_to_labels = defaultdict(list)

test_pairs = test_dataset.pairs
test_labels = test_dataset.labels

batch_size = 8
with torch.no_grad():
    for i in tqdm(range(0, len(test_pairs), batch_size), desc="Evaluating"):
        batch_pairs = test_pairs[i:i+batch_size]
        batch_labels = test_labels[i:i+batch_size]

        queries = [q for q, _ in batch_pairs]
        products = [p for _, p in batch_pairs]

        query_enc = tokenizer(
            queries,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        product_enc = tokenizer(
            products,
            padding="max_length",
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )

        query_input_ids = query_enc["input_ids"].to(device)
        query_attention_mask = query_enc["attention_mask"].to(device)
        product_input_ids = product_enc["input_ids"].to(device)
        product_attention_mask = product_enc["attention_mask"].to(device)

        scores = model(
            query_input_ids=query_input_ids,
            query_attention_mask=query_attention_mask,
            product_input_ids=product_input_ids,
            product_attention_masks=product_attention_mask
        ).cpu().tolist()

        for q, s, l in zip(queries, scores, batch_labels[i:i+batch_size]):
            query_to_scores[q].append(s)
            query_to_labels[q].append(l)

Evaluating: 100%|██████████| 2671/2671 [06:58<00:00,  6.38it/s]


In [None]:
ndcg_total = 0
count = 0

for q in query_to_labels:
    if sum(query_to_labels[q]) > 0:
        y_true = [query_to_labels[q]]
        y_score = [query_to_scores[q]]
        try:
            ndcg = ndcg_score(y_true, y_score, k=10)
            ndcg_total += ndcg
            count += 1
        except:
            pass  # In case of malformed input, skip

avg_ndcg_10 = ndcg_total / count if count > 0 else 0
print(f"Average NDCG@10: {avg_ndcg_10:.4f}")

Average NDCG@10: 0.7140
