In [None]:
# Mount Google Drive so the saved cross-encoder checkpoint under /content/drive is reachable.
from google.colab import drive
drive.mount('/content/drive')

# List the saved checkpoint directory to confirm the expected files (weights, config, tokenizer) are present
!ls -lh /content/drive/MyDrive/saved_crossencoder

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
total 712M
-rw------- 1 root root   23 Sep  1 16:04 added_tokens.json
-rw------- 1 root root  783 Sep  1 16:04 config.json
-rw------- 1 root root 702M Sep  1 16:06 pytorch_model.bin
-rw------- 1 root root  286 Sep  1 16:04 special_tokens_map.json
-rw------- 1 root root 2.4M Sep  1 16:04 spm.model
-rw------- 1 root root 1.3K Sep  1 16:04 tokenizer_config.json
-rw------- 1 root root 8.3M Sep  1 16:04 tokenizer.json


In [None]:
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import torch, os

# Directory on Google Drive holding the saved cross-encoder checkpoint and its tokenizer files.
model_path = "/content/drive/MyDrive/saved_crossencoder"

class CrossEncoder(nn.Module):
    """Transformer encoder with a regression head and a classification head on the first token.

    The hidden state at sequence position 0 feeds ``reg_head`` (one relevance
    score per pair) and ``cls_head`` (``num_classes`` logits). The attribute
    names ``encoder``, ``reg_head`` and ``cls_head`` form the state-dict keys
    of saved checkpoints, so they must not be renamed.

    Args:
        encoder: module exposing ``config.hidden_size`` and, when called with
            ``input_ids``/``attention_mask``, returning an object with a
            ``last_hidden_state`` attribute (e.g. a Hugging Face ``AutoModel``).
        num_classes: output size of the classification head. Defaults to 4,
            matching the four ESCI classes produced by ``ESCIDataset``.
    """

    def __init__(self, encoder, num_classes=4):
        super().__init__()
        self.encoder = encoder
        hidden_size = encoder.config.hidden_size
        self.reg_head = nn.Linear(hidden_size, 1)
        self.cls_head = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return ``(reg_logits, cls_logits)`` with shapes ``(batch,)`` and ``(batch, num_classes)``.

        Assumes ``last_hidden_state`` is ``(batch, seq_len, hidden_size)``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # First-token representation ([CLS] for BERT/DeBERTa-style tokenizers).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        reg_logits = self.reg_head(pooled_output).squeeze(-1)
        cls_logits = self.cls_head(pooled_output)
        return reg_logits, cls_logits

# The checkpoint in `model_path` is a full CrossEncoder state dict: the strict
# `load_state_dict` below succeeds, so its keys are "encoder.*", "reg_head.*"
# and "cls_head.*". AutoModel.from_pretrained therefore matched none of the
# encoder weights (it reported every parameter as "newly initialized") while
# still reading the ~700 MB file. Build the encoder architecture from
# config.json instead, and load all weights exactly once below.
from transformers import AutoConfig

encoder_config = AutoConfig.from_pretrained(model_path, local_files_only=True)
encoder = AutoModel.from_config(encoder_config)

# Load tokenizer from Drive
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)

# Build the model, then load encoder + head weights in one strict pass.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state dict needs.
model = CrossEncoder(encoder)
state_dict = torch.load(
    os.path.join(model_path, "pytorch_model.bin"),
    map_location="cpu",
    weights_only=True,
)
model.load_state_dict(state_dict, strict=True)

# Trailing ';' suppresses the very long module repr in the notebook output.
model.eval();

Some weights of DebertaV2Model were not initialized from the model checkpoint at /content/drive/MyDrive/saved_crossencoder and are newly initialized: ['embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.word_embeddings.weight', 'encoder.LayerNorm.bias', 'encoder.LayerNorm.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key_proj.bias', 'encoder.layer.0.attention.self.key_proj.weight', 'encoder.layer.0.attention.self.query_proj.bias', 'encoder.layer.0.attention.self.query_proj.weight', 'encoder.layer.0.attention.self.value_proj.bias', 'encoder.layer.0.attention.self.value_proj.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.0.output.

CrossEncoder(
  (encoder): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): Dropout(p=0.1, inplace=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-07

In [None]:
import pandas as pd
# Task A evaluation split (relative path, resolved against the notebook's working directory).
test_A = pd.read_csv("filtered_test.csv")

In [None]:
from collections import defaultdict

# Per-query buckets for Task A: product texts (prod_groups_*) and their
# float ESCI relevance labels (label_groups_*), filled by get_dicts.
prod_groups_train_A, prod_groups_test_A = defaultdict(list), defaultdict(list)
label_groups_train_A, label_groups_test_A = defaultdict(list), defaultdict(list)

def get_dicts(df, prod_groups, label_groups):
  """Group a (query, product, label) frame by query, appending into the given mappings.

  Args:
    df: DataFrame with "query", "product_input" and "esci_label" columns.
    prod_groups: mapping query -> list of product texts. Must create missing
      keys on access (e.g. ``defaultdict(list)``). Mutated in place.
    label_groups: mapping query -> list of float relevance labels, with the
      same requirements as ``prod_groups``. Mutated in place.

  Returns:
    None. Rows are appended in ``df`` order, so for every query the i-th
    product lines up with the i-th label.
  """
  # Zipping the columns avoids DataFrame.iterrows(), which builds a pandas
  # Series for every row and is far slower on large frames.
  for query, product, relevance in zip(df["query"], df["product_input"], df["esci_label"]):
    prod_groups[query].append(product)
    label_groups[query].append(float(relevance))

# Only the Task A test split is grouped here; no Task A training frame is
# loaded in the cells above, so prod_groups_train_A / label_groups_train_A stay empty.
get_dicts(test_A, prod_groups_test_A, label_groups_test_A)

In [None]:
import torch
from torch.utils.data import Dataset

class ESCIDataset(Dataset):
    """Flat dataset of (query, product) pairs built from per-query groups.

    Each item is tokenized on access as a query/product text pair and is
    returned with its float relevance label and its integer class index.

    Relevance labels are the ESCI scores 0.0 (I), 0.01 (C), 0.1 (S) and
    1.0 (E); larger scores mean more relevant, so E and S should rank above
    C and I. They map to class indices 0, 1, 2 and 3 respectively. Any other
    label value raises KeyError during construction.
    """

    def __init__(self, tokenizer, prod_groups, label_groups, max_len=128):
        self.tokenizer = tokenizer
        self.pairs = []
        self.reg_labels = []
        self.cls_labels = []

        # Relevance score -> classification target index.
        class_of_score = {0.0: 0, 0.01: 1, 0.1: 2, 1.0: 3}

        for query, products in prod_groups.items():
            scores = label_groups[query]
            for position, score in enumerate(scores):
                self.pairs.append((query, products[position]))
                self.reg_labels.append(score)
                self.cls_labels.append(class_of_score[score])

        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        query, product = self.pairs[idx]
        score = self.reg_labels[idx]
        class_index = self.cls_labels[idx]

        tokenizer_options = {
            "padding": "max_length",
            "truncation": True,
            "max_length": self.max_len,
            "return_tensors": "pt",
        }
        tokenized = self.tokenizer(query, product, **tokenizer_options)

        # Remove the leading batch axis (size 1 for a single pair).
        item = {name: tensor.squeeze(0) for name, tensor in tokenized.items()}
        item["reg_label"] = torch.tensor(score, dtype=torch.float)
        item["cls_label"] = torch.tensor(class_index, dtype=torch.long)
        return item

In [None]:
from torch.utils.data import DataLoader

# Task A evaluation data. Note: these names are reassigned to Task B data further down.
test_dataset = ESCIDataset(
    tokenizer,
    prod_groups_test_A,
    label_groups_test_A,
    max_len=128,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=8, shuffle=False)

In [None]:
# from sklearn.metrics import ndcg_score
# from collections import defaultdict
# import torch.nn.functional as F
# import torch
# from tqdm import tqdm

# model.eval()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)

# query_to_scores = defaultdict(list)
# query_to_labels = defaultdict(list)

# all_cls_preds = []
# all_cls_trues = []

# test_pairs = test_dataset.pairs
# reg_labels = test_dataset.reg_labels
# cls_labels = test_dataset.cls_labels

# batch_size = 16
# with torch.no_grad():
#     for i in tqdm(range(0, len(test_pairs), batch_size), desc="Evaluating"):
#         batch_pairs = test_pairs[i:i+batch_size]
#         batch_reg_labels = reg_labels[i:i+batch_size]
#         batch_cls_labels = cls_labels[i:i+batch_size]

#         queries = [q for q, _ in batch_pairs]
#         products = [p for _, p in batch_pairs]

#         encoded = tokenizer(
#             queries,
#             products,
#             padding="max_length",
#             truncation=True,
#             max_length=128,
#             return_tensors="pt"
#         )

#         input_ids = encoded["input_ids"].to(device)
#         attention_mask = encoded["attention_mask"].to(device)

#         reg_logits, cls_logits = model(input_ids=input_ids, attention_mask=attention_mask)

#         reg_scores = F.softplus(reg_logits).cpu().tolist()
#         cls_preds = torch.argmax(F.softmax(cls_logits, dim=-1), dim=-1).cpu().tolist()
#         cls_trues = batch_cls_labels

#         for q, s, l in zip(queries, reg_scores, batch_reg_labels):
#             query_to_scores[q].append(s)
#             query_to_labels[q].append(l)

#         all_cls_preds.extend(cls_preds)
#         all_cls_trues.extend(cls_trues)

# ndcg_total = 0
# qualifiable_count = 0

# for q in query_to_labels:
#     labels = query_to_labels[q]
#     scores = query_to_scores[q]
#     if len(labels) > 1 and sum(labels) > 0:
#         try:
#             ndcg = ndcg_score([labels], [scores], k=10)
#             ndcg_total += ndcg
#             qualifiable_count += 1
#         except ValueError:
#             continue

# avg_ndcg_10 = ndcg_total / qualifiable_count if qualifiable_count > 0 else 0

# from sklearn.metrics import accuracy_score
# accuracy = accuracy_score(all_cls_trues, all_cls_preds)

# print(f"Average NDCG@10 (for {qualifiable_count} qualifiable queries): {avg_ndcg_10:.4f}")
# print(f"Classification Accuracy: {accuracy:.4f}")

### Retraining on the Task B Sub-Sample

In [None]:
import pandas as pd

# Task B splits: the sub-sample the model is fine-tuned on below, plus its held-out test split.
train_B = pd.read_csv("filtered_train_B.csv")
test_B = pd.read_csv("filtered_test_B.csv")

# Display (rows, columns) for both splits.
train_B.shape, test_B.shape

((7227, 3), (1831, 3))

In [None]:
from collections import defaultdict

# Per-query buckets for Task B: product texts (prod_groups_*) and their
# float ESCI relevance labels (label_groups_*), filled by get_dicts.
prod_groups_train_B, prod_groups_test_B = defaultdict(list), defaultdict(list)
label_groups_train_B, label_groups_test_B = defaultdict(list), defaultdict(list)

def get_dicts(df, prod_groups, label_groups):
  """Group a (query, product, label) frame by query, appending into the given mappings.

  NOTE(review): this is a second definition of get_dicts. It keeps the
  Task B cells runnable even when the Task A cells are skipped; consider
  defining it once near the top of the notebook.

  Args:
    df: DataFrame with "query", "product_input" and "esci_label" columns.
    prod_groups: mapping query -> list of product texts. Must create missing
      keys on access (e.g. ``defaultdict(list)``). Mutated in place.
    label_groups: mapping query -> list of float relevance labels, with the
      same requirements as ``prod_groups``. Mutated in place.

  Returns:
    None. Rows are appended in ``df`` order, so for every query the i-th
    product lines up with the i-th label.
  """
  # Zipping the columns avoids DataFrame.iterrows(), which builds a pandas
  # Series for every row and is far slower on large frames.
  for query, product, relevance in zip(df["query"], df["product_input"], df["esci_label"]):
    prod_groups[query].append(product)
    label_groups[query].append(float(relevance))

# Fill the Task B per-query groupings for both the training and test splits.
get_dicts(train_B, prod_groups_train_B, label_groups_train_B)
get_dicts(test_B, prod_groups_test_B, label_groups_test_B)

In [None]:
import torch.nn.functional as F

def list_ce_loss(logits, labels):
    """ListNet-style cross-entropy between softmax(labels) and softmax(logits) along dim 0.

    Both tensors are normalised into distributions over dim 0; the result is
    the scalar cross-entropy ``-sum(p_true * log p_pred)``.
    """
    target = F.softmax(labels, dim=0)
    log_pred = F.log_softmax(logits, dim=0)
    return (target * log_pred).sum().neg()

def rcr_loss_function(logits, reg_labels, alpha=0.3):
    """Return the regression and listwise loss terms computed on softplus scores.

    ``softplus(logits)`` keeps the scores non-negative; both terms compare
    them with ``reg_labels``:
      "mse"      - mean squared error,
      "listwise" - list_ce_loss, i.e. cross-entropy over dim 0.

    ``alpha`` is accepted for compatibility but currently unused: the terms are
    returned separately rather than as ``(1 - alpha) * mse + alpha * listwise``.
    """
    scores = F.softplus(logits)
    mse_term = F.mse_loss(scores, reg_labels)
    listwise_term = list_ce_loss(scores, reg_labels)
    return {"mse": mse_term, "listwise": listwise_term}

def multitask_loss(reg_logits, cls_logits, reg_labels, cls_labels, x=1/3, alpha=0.5):
    """Compute the three task-loss terms without combining them.

    Returns a dict with:
      "mse"      - regression term from rcr_loss_function,
      "listwise" - listwise ranking term from rcr_loss_function,
      "ce"       - cross-entropy of ``cls_logits`` against integer ``cls_labels``.

    ``x`` (the classification-vs-regression weight) is currently unused; the
    caller weights the returned terms. ``alpha`` is forwarded unchanged to
    rcr_loss_function.
    """
    ranking_terms = rcr_loss_function(reg_logits, reg_labels, alpha)
    classification_term = F.cross_entropy(cls_logits, cls_labels)
    return {
        "mse": ranking_terms["mse"],
        "listwise": ranking_terms["listwise"],
        "ce": classification_term,
    }

In [None]:
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Task B datasets. `test_dataset` / `test_loader` replace the Task A ones defined earlier.
train_dataset = ESCIDataset(
    tokenizer, prod_groups_train_B, label_groups_train_B, max_len=128
)
test_dataset = ESCIDataset(
    tokenizer, prod_groups_test_B, label_groups_test_B, max_len=128
)

# NOTE(review): list_ce_loss normalises over dim 0 (the batch), while this
# loader shuffles pairs across queries with batch_size=2. Each "list" in the
# listwise term can therefore mix products from different queries; confirm
# this is intended.
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=2, shuffle=False)

In [None]:
from collections import OrderedDict
from copy import deepcopy

def get_grads(loss, model):
    """Flatten d(loss)/d(param) for every trainable parameter into one 1-D tensor.

    Parameters are visited in ``model.parameters()`` order, keeping only those
    with ``requires_grad``. A parameter that ``loss`` does not depend on
    contributes a block of zeros of its size, so the output length always
    equals the total number of trainable elements. The autograd graph is
    retained so further gradients can still be taken from the same ``loss``.
    """
    trainable = [param for param in model.parameters() if param.requires_grad]

    grads = torch.autograd.grad(
        loss,
        trainable,
        retain_graph=True,
        create_graph=False,
        allow_unused=True,
    )

    pieces = [
        torch.zeros_like(param).view(-1) if grad is None else grad.contiguous().view(-1)
        for param, grad in zip(trainable, grads)
    ]
    return torch.cat(pieces)

In [None]:
import torch.optim as optim

def _nash_objective(w, G):
  """Nash-MTL objective 0.5 * ||w @ G||^2 - sum(log w).

  Args:
    w: (T,) task weights.
    G: (T, P) matrix whose rows are the per-task gradient vectors.

  Its gradient is (G G^T) w - 1/w, so a stationary point satisfies the Nash
  bargaining condition (G G^T) w = 1/w elementwise. The squared-norm term
  must enter with a + sign; negating it would reward large combined
  gradients instead of penalising them. The 1e-8 keeps log finite.
  """
  agg_grad = torch.matmul(w, G)
  return 0.5 * torch.sum(agg_grad ** 2) - torch.sum(torch.log(w + 1e-8))


def solve_nash_weights(task_grads, lr=8e-6, weight_decay=0.01, steps=20):
  """Find per-task weights for combining task gradients (Nash-MTL style).

  Minimises ``_nash_objective`` with AdamW, starting from w = 1. After every
  step, w is clamped to >= 1e-4 and rescaled to sum to 1.

  Args:
    task_grads: non-empty list of T gradient vectors of equal length
      (e.g. as produced by get_grads).
    lr: AdamW learning rate for the weight search.
    weight_decay: AdamW weight decay applied to w.
    steps: number of optimizer steps.

  Returns:
    Detached (T,) tensor of positive weights that sum to 1.

  Raises:
    ValueError: if ``task_grads`` is empty.

  NOTE(review): AdamW moves each weight by roughly ``lr`` per step, so with
  the defaults (lr=8e-6, 20 steps) the result stays very close to uniform.
  """
  if len(task_grads) == 0:
    raise ValueError("task_grads must contain at least one gradient vector")

  G = torch.stack(task_grads)
  T = G.shape[0]

  w = torch.ones(T, device=G.device, requires_grad=True)
  optimizer = optim.AdamW([w], lr=lr, weight_decay=weight_decay)

  for _ in range(steps):
    optimizer.zero_grad()
    loss = _nash_objective(w, G)
    loss.backward()
    optimizer.step()

    # Keep weights strictly positive (log domain) and on the unit-sum scale.
    with torch.no_grad():
      w.clamp_(min=1e-4)
      w /= w.sum()

  return w.detach()

In [None]:
def save_old_params(model):
    """Snapshot every named parameter as a detached copy.

    Returns a dict name -> tensor that shares neither storage nor autograd
    history with the model, so later optimizer steps leave it unchanged.
    """
    snapshot = {}
    for name, param in model.named_parameters():
        snapshot[name] = param.clone().detach()
    return snapshot

def compute_fisher(model, dataloader, num_samples=200, device="cuda"):
    """Estimate a diagonal Fisher information from the classification head.

    For each of ``num_samples`` draws, the next batch is taken (restarting the
    loader when it is exhausted). A class is sampled from the model's predicted
    distribution for the FIRST example of that batch, and log p(sampled class)
    is back-propagated. Squared parameter gradients are accumulated and then
    averaged.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` and
            returning ``(reg_logits, cls_logits)``.
        dataloader: re-iterable source of dict batches with "input_ids" and
            "attention_mask" tensors.
        num_samples: number of single-example gradient samples (>= 1).
        device: device for the inputs and the returned Fisher tensors.

    Returns:
        dict parameter name -> tensor shaped like the parameter. Parameters
        that never receive a gradient stay zero.

    Raises:
        ValueError: if ``num_samples < 1`` or ``dataloader`` yields no batches.

    The model is run in eval mode; its previous train/eval mode is restored
    and its ``.grad`` buffers are cleared before returning.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    was_training = model.training
    model.eval()  # no dropout: measure the deterministic predictive distribution
    fisher = {n: torch.zeros_like(p, device=device) for n, p in model.named_parameters()}

    data_iter = iter(dataloader)
    for _ in range(num_samples):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(dataloader)
            batch = next(data_iter, None)
            if batch is None:
                raise ValueError("dataloader yielded no batches")

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        _, cls_logits = model(input_ids, attention_mask)

        # Sample the target from the model's own prediction (true Fisher, not
        # empirical), using only the first example so gradients are per-sample.
        log_probs = torch.log_softmax(cls_logits, dim=-1)
        sampled_class = torch.multinomial(torch.exp(log_probs[0]), 1).item()

        model.zero_grad()
        log_probs[0, sampled_class].backward()

        for n, p in model.named_parameters():
            if p.grad is not None:
                fisher[n] += p.grad.detach() ** 2

    for n in fisher:
        fisher[n] /= num_samples

    # Leave no stale gradients behind and give the caller back its mode.
    model.zero_grad()
    model.train(was_training)
    return fisher

In [None]:
def ewc_loss(model, multitask_losses, fisher, old_params, lambda_ewc=1.0, weights=None):
    """Combine the task losses and add the Elastic Weight Consolidation penalty.

    Task part: ``mse + listwise + ce`` when ``weights`` is None, otherwise
    ``weights[0]*mse + weights[1]*listwise + weights[2]*ce``.

    Penalty: ``(lambda_ewc / 2) * sum_i F_i * (theta_i - theta_old_i)^2``,
    summed over every named parameter that has an entry in ``fisher``
    (``old_params`` must contain the same names).
    """
    mse = multitask_losses["mse"]
    listwise = multitask_losses["listwise"]
    ce = multitask_losses["ce"]

    if weights is not None:
        task_loss = weights[0]*mse + weights[1]*listwise + weights[2]*ce
    else:
        task_loss = mse + listwise + ce

    penalty = sum(
        (
            (fisher[name] * (param - old_params[name])**2).sum()
            for name, param in model.named_parameters()
            if name in fisher
        ),
        0.0,
    )

    return task_loss + (lambda_ewc / 2) * penalty

In [None]:
import torch.optim as optim
from tqdm import tqdm

# Fine-tune on Task B with an EWC penalty anchored at the weights loaded above.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Snapshot the current weights (the model as it stands after Task A) as the EWC anchor.
old_params = save_old_params(model)
# NOTE(review): the Fisher is estimated on `train_loader`, i.e. Task B training
# inputs. EWC usually estimates it on data from the task being protected
# (Task A); confirm this is intended.
fisher = compute_fisher(model, train_loader, num_samples=200, device=device)
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

lambda_ewc = 0.4  # EWC penalty strength; tune this
num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    total_batches = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", ncols=100)

    for batch in progress_bar:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        reg_labels = batch["reg_label"].to(device)
        cls_labels = batch["cls_label"].to(device)

        reg_logits, cls_logits = model(input_ids, attention_mask)
        multitask_losses = multitask_loss(reg_logits, cls_logits, reg_labels, cls_labels)

        # Total loss = unweighted mse + listwise + ce (weights=None) plus the EWC penalty
        loss = ewc_loss(model, multitask_losses, fisher, old_params, lambda_ewc)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_batches += 1

        progress_bar.set_postfix({"Loss": f"{loss.item():.4f}"})

    # Reported average includes the EWC penalty, not just the task losses.
    print(f"Epoch [{epoch+1}], Avg Loss={total_loss/total_batches:.4f}")

Epoch 1/1: 100%|███████████████████████████████████| 3614/3614 [12:36<00:00,  4.78it/s, Loss=0.0518]

Epoch [1], Avg Loss=1.2814





In [None]:
import torch
import torch.nn.functional as F
from sklearn.metrics import ndcg_score, accuracy_score, f1_score, classification_report
from collections import defaultdict
from tqdm import tqdm

def evaluate_model(model, tokenizer, dataset, device, batch_size=16, max_len=128):
    """Score every (query, product) pair in ``dataset`` and report ranking and classification metrics.

    Pairs are tokenized in batches of ``batch_size`` with padding/truncation
    to ``max_len``. Ranking uses ``softplus(reg_logits)`` grouped by query
    string; classification uses the argmax of ``cls_logits``.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            that returns ``(reg_logits, cls_logits)``; it is put in eval mode.
        tokenizer: tokenizer callable on lists of (queries, products).
        dataset: object exposing parallel ``pairs``, ``reg_labels`` and
            ``cls_labels`` lists (e.g. ``ESCIDataset``).
        device: device the input tensors are moved to.
        batch_size: number of pairs per forward pass.
        max_len: tokenizer ``max_length``.

    Returns:
        ``(avg_ndcg_10, accuracy, macro_f1)``:
        - avg_ndcg_10: mean NDCG@10 over queries with more than one pair and
          at least one positive label (0 when none qualify);
        - accuracy: classification accuracy;
        - macro_f1: macro-averaged F1 over the classes present (0 for classes
          that are never predicted).
    """
    model.eval()

    query_to_scores = defaultdict(list)
    query_to_labels = defaultdict(list)

    all_cls_preds = []
    all_cls_trues = []

    test_pairs = dataset.pairs
    reg_labels = dataset.reg_labels
    cls_labels = dataset.cls_labels

    with torch.no_grad():
        for i in tqdm(range(0, len(test_pairs), batch_size), desc="Evaluating"):
            batch_pairs = test_pairs[i:i+batch_size]
            batch_reg_labels = reg_labels[i:i+batch_size]
            batch_cls_labels = cls_labels[i:i+batch_size]

            queries = [q for q, _ in batch_pairs]
            products = [p for _, p in batch_pairs]

            encoded = tokenizer(
                queries,
                products,
                padding="max_length",
                truncation=True,
                max_length=max_len,
                return_tensors="pt"
            )

            input_ids = encoded["input_ids"].to(device)
            attention_mask = encoded["attention_mask"].to(device)

            reg_logits, cls_logits = model(input_ids=input_ids, attention_mask=attention_mask)

            # Regression scores: same softplus transform as the training loss.
            reg_scores = F.softplus(reg_logits).cpu().tolist()
            # argmax of the logits equals argmax of their softmax.
            cls_preds = torch.argmax(cls_logits, dim=-1).cpu().tolist()

            for q, s, l in zip(queries, reg_scores, batch_reg_labels):
                query_to_scores[q].append(float(s))
                query_to_labels[q].append(float(l))

            all_cls_preds.extend(cls_preds)
            all_cls_trues.extend(batch_cls_labels)

    # ----- Ranking Metric -----
    ndcg_total = 0
    qualifiable_count = 0

    for q in query_to_labels:
        labels = query_to_labels[q]
        scores = query_to_scores[q]
        # NDCG is undefined/trivial for single-item lists or all-zero relevance.
        if len(labels) > 1 and sum(labels) > 0:
            try:
                ndcg = ndcg_score([labels], [scores], k=10)
                ndcg_total += ndcg
                qualifiable_count += 1
            except ValueError:
                continue

    avg_ndcg_10 = ndcg_total / qualifiable_count if qualifiable_count > 0 else 0

    # ----- Classification Metrics -----
    accuracy = accuracy_score(all_cls_trues, all_cls_preds)
    macro_f1 = f1_score(all_cls_trues, all_cls_preds, average="macro", zero_division=0)

    print(f"Average NDCG@10 (for {qualifiable_count} queries): {avg_ndcg_10:.4f}")
    print(f"Classification Accuracy: {accuracy:.4f}")
    print(f"Classification Macro-F1: {macro_f1:.4f}")

    return avg_ndcg_10, accuracy, macro_f1


# Evaluate the fine-tuned model on `test_dataset`, which at this point holds the Task B test split.
# Unpacks three results (NDCG@10, accuracy, F1), so evaluate_model must return all three.
avg_ndcg, acc, f1 = evaluate_model(
    model,
    tokenizer,
    test_dataset,
    device,
    batch_size=16,
    max_len=128
)

Evaluating: 100%|██████████| 115/115 [00:17<00:00,  6.68it/s]

Average NDCG@10 (for 100 queries): 0.8963
Classification Accuracy: 0.7231





ValueError: not enough values to unpack (expected 3, got 2)