In [2]:
import argparse
import random
import torch
import pandas as pd
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim import Adam
from tqdm import tqdm

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Helper functions to load and process the data into desired format needed for MF
# For MF we need a "model ID" either in the form of name or index and so we use the tabular data instead of tensors
def load_and_process_data(train_data, test_data, batch_size=64 , num_model_use = 112):
    # NOTE: Due to the nature of the embedding layer we need to take max prompt ID from both train and test data
    # But during training we won't be using test question
    num_prompts = int(max(max(train_data["prompt_id"]), max(test_data["prompt_id"]))) + 1

    model_ids = test_data["model_id"].unique()
    selected_model_ids = model_ids[:num_model_use]

    train_data = train_data[train_data["model_id"].isin(selected_model_ids)]
    test_data = test_data[test_data["model_id"].isin(selected_model_ids)]

    class CustomDataset(Dataset):
        def __init__(self, data):
            model_ids = torch.tensor(data["model_id"], dtype=torch.int64)
            unique_ids, inverse_indices = torch.unique(model_ids, sorted=True, return_inverse=True)
            id_to_rank = {id.item(): rank for rank, id in enumerate(unique_ids)}
            ranked_model_ids = torch.tensor([id_to_rank[id.item()] for id in model_ids])
            self.models = ranked_model_ids
            self.prompts = torch.tensor(data["prompt_id"], dtype=torch.int64)
            self.labels = torch.tensor(data["label"], dtype=torch.int64)
            self.num_models = len(data["model_id"].unique())
            self.num_prompts = num_prompts
            self.num_classes = len(data["label"].unique())

        def get_num_models(self):
            return self.num_models

        def get_num_prompts(self):
            return self.num_prompts

        def get_num_classes(self):
            return self.num_classes

        def __len__(self):
            return len(self.models)

        def __getitem__(self, index):
            return self.models[index], self.prompts[index], self.labels[index]

        def get_dataloaders(self, batch_size):
            return DataLoader(self, batch_size, shuffle=False)

    train_dataset = CustomDataset(train_data)
    test_dataset = CustomDataset(test_data)

    train_loader = train_dataset.get_dataloaders(batch_size)
    test_loader = test_dataset.get_dataloaders(batch_size)

    return train_loader, test_loader

class TextMF(nn.Module):
    def __init__(self, question_embeddings, model_embedding_dim, alpha, num_models, num_prompts, text_dim=768, num_classes=2):
        super(TextMF, self).__init__()
        # Model embedding network
        self.P = nn.Embedding(num_models, model_embedding_dim)

        # Question embedding network
        self.Q = nn.Embedding(num_prompts, text_dim).requires_grad_(False)
        self.Q.weight.data.copy_(question_embeddings)
        self.text_proj = nn.Linear(text_dim, model_embedding_dim)

        # Noise/Regularization level
        self.alpha = alpha
        self.classifier = nn.Linear(model_embedding_dim, num_classes)

    def forward(self, model, prompt, test_mode=False):
        p = self.P(model)
        q = self.Q(prompt)
        if not test_mode:
            # Adding a small amount of noise in question embedding to reduce overfitting
            q += torch.randn_like(q) * self.alpha
        q = self.text_proj(q)
        return self.classifier(p * q)
    
    @torch.no_grad()
    def predict(self, model, prompt):
        logits = self.forward(model, prompt, test_mode=True) # During inference no noise is applied
        return torch.argmax(logits, dim=1)
    
def evaluate(net, test_loader, device):
    net.eval()
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0
    correct = 0
    num_samples = 0

    with torch.no_grad():
        for models, prompts, labels in test_loader:
            models, prompts, labels = models.to(device), prompts.to(device), labels.to(device)
            logits = net(models, prompts)
            loss = loss_fn(logits, labels)
            pred_labels = net.predict(models, prompts)
            correct += (pred_labels == labels).sum().item()
            total_loss += loss.item()
            num_samples += labels.shape[0]

    mean_loss = total_loss / num_samples
    accuracy = correct / num_samples
    net.train()
    return mean_loss, accuracy

# Main training loop
def train(net, train_loader, test_loader, num_epochs, lr, device, weight_decay=1e-5, save_path=None):
    optimizer = Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    progress_bar = tqdm(total=num_epochs)

    for epoch in range(num_epochs):
        net.train()
        total_loss = 0
        for models, prompts, labels in train_loader:
            models, prompts, labels = models.to(device), prompts.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = net(models, prompts)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss}")

        test_loss, test_accuracy = evaluate(net, test_loader, device)
        print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")

        progress_bar.set_postfix(train_loss=train_loss, test_loss=test_loss, test_acc=test_accuracy)
        progress_bar.update(1)
    
    if save_path:
        torch.save(net.state_dict(), save_path)
        print(f"Model saved to {save_path}")

def load_model(net, path, device):
    net.load_state_dict(torch.load(path, map_location=device))
    print(f"Model loaded from {path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--embedding_dim", type=int, default=232)
    parser.add_argument("--alpha", type=float, default=0.05)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--learning_rate", type=float, default=1e-4)

    parser.add_argument("--train_data_path", type=str, default="data/train.csv")
    parser.add_argument("--test_data_path", type=str, default="data/test.csv")
    parser.add_argument("--question_embedding_path", type=str, default="data/question_embeddings.pth")

    parser.add_argument("--embedding_save_path", type=str, default="data/model_embeddings.pth")
    parser.add_argument("--model_save_path", type=str, default="data/model_with_fir90.pth")
    parser.add_argument("--model_load_path", type=str, default=None)
    args,unknown = parser.parse_known_args()

    print("Loading dataset...")
    train_data = pd.read_csv(args.train_data_path)
    test_data = pd.read_csv(args.test_data_path)
    question_embeddings = torch.load(args.question_embedding_path)
    num_prompts = question_embeddings.shape[0]
    num_models = len(test_data["model_id"].unique())
    model_names = list(np.unique(list(test_data["model_name"])))

    train_loader, test_loader = load_and_process_data(train_data, test_data, batch_size=args.batch_size , num_model_use=90)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Initializing model...")
    model = TextMF(question_embeddings=question_embeddings, 
                   model_embedding_dim=args.embedding_dim, alpha=args.alpha,
                   num_models=num_models, num_prompts=num_prompts)
    model.to(device)

    if args.model_load_path:
        model.load_state_dict(torch.load(args.model_load_path, map_location=device))
        print(f"Model loaded from {args.model_load_path}")

    print("Training model...")
    train(model, train_loader, test_loader, num_epochs=args.num_epochs, lr=args.learning_rate,
          device=device, save_path=args.model_save_path)
    # torch.save(model.P.weight.detach().to("cpu"), args.embedding_save_path) # Save model embeddings if needed

Loading dataset...


  question_embeddings = torch.load(args.question_embedding_path)


Initializing model...
Training model...


  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1/50, Train Loss: 0.6807110990057329


  2%|▏         | 1/50 [01:22<1:07:21, 82.47s/it, test_acc=0.628, test_loss=0.654, train_loss=0.681]

Test Loss: 0.6544452969152248, Test Accuracy: 0.6278978892412158
Epoch 2/50, Train Loss: 0.6504853306090584


  4%|▍         | 2/50 [02:47<1:07:22, 84.21s/it, test_acc=0.648, test_loss=0.633, train_loss=0.65] 

Test Loss: 0.6326087845849073, Test Accuracy: 0.6477435059735435
Epoch 3/50, Train Loss: 0.6315780720893112


  6%|▌         | 3/50 [04:48<1:19:00, 100.86s/it, test_acc=0.668, test_loss=0.615, train_loss=0.632]

Test Loss: 0.614596012873384, Test Accuracy: 0.6680727715319816
Epoch 4/50, Train Loss: 0.6152618961905191


  8%|▊         | 4/50 [08:24<1:52:08, 146.26s/it, test_acc=0.683, test_loss=0.6, train_loss=0.615]  

Test Loss: 0.5996233367846612, Test Accuracy: 0.683440632907338
Epoch 5/50, Train Loss: 0.6022179891904673


 10%|█         | 5/50 [12:00<2:08:27, 171.28s/it, test_acc=0.694, test_loss=0.588, train_loss=0.602]

Test Loss: 0.5879719311161383, Test Accuracy: 0.6937445949841453
Epoch 6/50, Train Loss: 0.5922690026845308


 12%|█▏        | 6/50 [15:39<2:17:29, 187.50s/it, test_acc=0.702, test_loss=0.58, train_loss=0.592] 

Test Loss: 0.5798487531585561, Test Accuracy: 0.7022645014573524
Epoch 7/50, Train Loss: 0.5850809007116613


 14%|█▍        | 7/50 [19:20<2:22:18, 198.58s/it, test_acc=0.709, test_loss=0.574, train_loss=0.585]

Test Loss: 0.5738293990338659, Test Accuracy: 0.708692866980558
Epoch 8/50, Train Loss: 0.5797644727673623


 16%|█▌        | 8/50 [23:00<2:23:53, 205.56s/it, test_acc=0.714, test_loss=0.569, train_loss=0.58] 

Test Loss: 0.5693673394926116, Test Accuracy: 0.7141827616027674
Epoch 9/50, Train Loss: 0.575835395910458


 18%|█▊        | 9/50 [26:42<2:23:47, 210.42s/it, test_acc=0.718, test_loss=0.566, train_loss=0.576]

Test Loss: 0.5664414430106764, Test Accuracy: 0.7182313186637199
Epoch 10/50, Train Loss: 0.572728900971736


 20%|██        | 10/50 [30:22<2:22:23, 213.59s/it, test_acc=0.721, test_loss=0.564, train_loss=0.573]

Test Loss: 0.5638742909670174, Test Accuracy: 0.7208673649146408
Epoch 11/50, Train Loss: 0.5703616764896619


 22%|██▏       | 11/50 [34:03<2:20:16, 215.81s/it, test_acc=0.723, test_loss=0.562, train_loss=0.57] 

Test Loss: 0.561941974828422, Test Accuracy: 0.7230453861183178
Epoch 12/50, Train Loss: 0.5688361604918857


 24%|██▍       | 12/50 [37:44<2:17:43, 217.47s/it, test_acc=0.725, test_loss=0.561, train_loss=0.569]

Test Loss: 0.5608106185821693, Test Accuracy: 0.7245667979885334
Epoch 13/50, Train Loss: 0.567524060945393


 26%|██▌       | 13/50 [41:27<2:15:00, 218.93s/it, test_acc=0.726, test_loss=0.56, train_loss=0.568] 

Test Loss: 0.5597873613823002, Test Accuracy: 0.7258640017936645
Epoch 14/50, Train Loss: 0.566496181486544


 28%|██▊       | 14/50 [45:08<2:11:51, 219.75s/it, test_acc=0.727, test_loss=0.56, train_loss=0.566]

Test Loss: 0.5596813886255195, Test Accuracy: 0.7266102943531597
Epoch 15/50, Train Loss: 0.5656624946994588


 30%|███       | 15/50 [48:51<2:08:41, 220.62s/it, test_acc=0.727, test_loss=0.558, train_loss=0.566]

Test Loss: 0.5584514787144345, Test Accuracy: 0.7273565869126549
Epoch 16/50, Train Loss: 0.5654482879292073


 32%|███▏      | 16/50 [52:34<2:05:27, 221.39s/it, test_acc=0.728, test_loss=0.558, train_loss=0.565]

Test Loss: 0.5583079181464774, Test Accuracy: 0.7281541270298837
Epoch 17/50, Train Loss: 0.5648581151842434


 34%|███▍      | 17/50 [56:18<2:02:06, 222.01s/it, test_acc=0.728, test_loss=0.558, train_loss=0.565]

Test Loss: 0.557717568641981, Test Accuracy: 0.7284295826527017
Epoch 18/50, Train Loss: 0.5645519622595969


 36%|███▌      | 18/50 [1:00:02<1:58:44, 222.63s/it, test_acc=0.729, test_loss=0.558, train_loss=0.565]

Test Loss: 0.5578546424284861, Test Accuracy: 0.7288651868934372
Epoch 19/50, Train Loss: 0.5641841499175958


 38%|███▊      | 19/50 [1:03:44<1:55:01, 222.62s/it, test_acc=0.729, test_loss=0.557, train_loss=0.564]

Test Loss: 0.557464870172958, Test Accuracy: 0.729057365234938
Epoch 20/50, Train Loss: 0.5638193437430209


 40%|████      | 20/50 [1:07:28<1:51:28, 222.96s/it, test_acc=0.729, test_loss=0.557, train_loss=0.564]

Test Loss: 0.5571592817040464, Test Accuracy: 0.7294128951667147
Epoch 21/50, Train Loss: 0.5637638591299594


 42%|████▏     | 21/50 [1:11:27<1:50:09, 227.90s/it, test_acc=0.73, test_loss=0.557, train_loss=0.564] 

Test Loss: 0.5567995712325726, Test Accuracy: 0.729617885397649
Epoch 22/50, Train Loss: 0.5638358660773906


 44%|████▍     | 22/50 [1:15:38<1:49:30, 234.66s/it, test_acc=0.73, test_loss=0.557, train_loss=0.564]

Test Loss: 0.5566405293057269, Test Accuracy: 0.7298484994074501
Epoch 23/50, Train Loss: 0.5632933183469291


 46%|████▌     | 23/50 [1:19:58<1:49:03, 242.36s/it, test_acc=0.73, test_loss=0.556, train_loss=0.563]

Test Loss: 0.5563967492692594, Test Accuracy: 0.7300118509977259
Epoch 24/50, Train Loss: 0.5632681323883528


 48%|████▊     | 24/50 [1:24:30<1:48:47, 251.07s/it, test_acc=0.73, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5567031293907934, Test Accuracy: 0.7300022420806509
Epoch 25/50, Train Loss: 0.5630081539019275


 50%|█████     | 25/50 [1:29:29<1:50:37, 265.52s/it, test_acc=0.73, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5565068861442481, Test Accuracy: 0.7301271580026264
Epoch 26/50, Train Loss: 0.5630562276558909


 52%|█████▏    | 26/50 [1:35:01<1:54:15, 285.63s/it, test_acc=0.73, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5569067992558031, Test Accuracy: 0.7303033214823356
Epoch 27/50, Train Loss: 0.5630477965326013


 54%|█████▍    | 27/50 [1:40:22<1:53:34, 296.27s/it, test_acc=0.73, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5568529366736706, Test Accuracy: 0.7303994106530861
Epoch 28/50, Train Loss: 0.5629619428354179


 56%|█████▌    | 28/50 [1:45:49<1:51:56, 305.30s/it, test_acc=0.73, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5565942162099866, Test Accuracy: 0.730389801736011
Epoch 29/50, Train Loss: 0.5628376626717543


 58%|█████▊    | 29/50 [1:51:24<1:50:01, 314.34s/it, test_acc=0.73, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5565483781644679, Test Accuracy: 0.7304154255148778
Epoch 30/50, Train Loss: 0.5627488585819821


 60%|██████    | 30/50 [1:56:52<1:46:06, 318.31s/it, test_acc=0.731, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5565201823698642, Test Accuracy: 0.7306204157458122
Epoch 31/50, Train Loss: 0.5625591369653115


 62%|██████▏   | 31/50 [2:02:20<1:41:43, 321.26s/it, test_acc=0.731, test_loss=0.557, train_loss=0.563]

Test Loss: 0.5567251446451682, Test Accuracy: 0.730729316805996
Epoch 32/50, Train Loss: 0.5625625770058764


 64%|██████▍   | 32/50 [2:08:04<1:38:27, 328.22s/it, test_acc=0.731, test_loss=0.556, train_loss=0.563]

Test Loss: 0.5561183796140944, Test Accuracy: 0.7307453316677877
Epoch 33/50, Train Loss: 0.5624753180804177


 66%|██████▌   | 33/50 [2:13:52<1:34:36, 333.93s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5563099460596308, Test Accuracy: 0.7306940841100541
Epoch 34/50, Train Loss: 0.5625047735974401


 68%|██████▊   | 34/50 [2:19:28<1:29:15, 334.73s/it, test_acc=0.731, test_loss=0.556, train_loss=0.563]

Test Loss: 0.5564064124358967, Test Accuracy: 0.7308574357003299
Epoch 35/50, Train Loss: 0.5621841845905773


 70%|███████   | 35/50 [2:25:07<1:23:57, 335.84s/it, test_acc=0.731, test_loss=0.557, train_loss=0.562]

Test Loss: 0.5565407528264354, Test Accuracy: 0.7309535248710803
Epoch 36/50, Train Loss: 0.5623430688887258


 72%|███████▏  | 36/50 [2:31:00<1:19:37, 341.22s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.556368538519371, Test Accuracy: 0.7309150892027801
Epoch 37/50, Train Loss: 0.5623356988609992


 74%|███████▍  | 37/50 [2:36:43<1:14:02, 341.71s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5562099158292333, Test Accuracy: 0.7308830594791967
Epoch 38/50, Train Loss: 0.5622418082767753


 76%|███████▌  | 38/50 [2:42:26<1:08:24, 342.08s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5560800403242959, Test Accuracy: 0.7310079754011722
Epoch 39/50, Train Loss: 0.5622603011005866


 78%|███████▊  | 39/50 [2:48:08<1:02:42, 342.07s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5562964686111525, Test Accuracy: 0.7309695397328722
Epoch 40/50, Train Loss: 0.5622362530775961


 80%|████████  | 40/50 [2:53:53<57:10, 343.01s/it, test_acc=0.731, test_loss=0.557, train_loss=0.562]  

Test Loss: 0.5567399949496883, Test Accuracy: 0.7311072675442811
Epoch 41/50, Train Loss: 0.5624198709815521


 82%|████████▏ | 41/50 [2:59:36<51:26, 342.94s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.556025513356082, Test Accuracy: 0.7311200794337145
Epoch 42/50, Train Loss: 0.5623677641616197


 84%|████████▍ | 42/50 [3:05:21<45:48, 343.60s/it, test_acc=0.731, test_loss=0.557, train_loss=0.562]

Test Loss: 0.5566726987097151, Test Accuracy: 0.7311328913231478
Epoch 43/50, Train Loss: 0.5621560989650751


 86%|████████▌ | 43/50 [3:11:11<40:18, 345.43s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5562287329985977, Test Accuracy: 0.7311008615995644
Epoch 44/50, Train Loss: 0.5622403922434163


 88%|████████▊ | 44/50 [3:17:07<34:50, 348.45s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5558239844427191, Test Accuracy: 0.7310496140418308
Epoch 45/50, Train Loss: 0.5621033234916174


 90%|█████████ | 45/50 [3:23:18<29:36, 355.37s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.556477982759407, Test Accuracy: 0.7311328913231478
Epoch 46/50, Train Loss: 0.5621574140672057


 92%|█████████▏| 46/50 [3:29:26<23:56, 359.14s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5561946300620101, Test Accuracy: 0.7311136734889978
Epoch 47/50, Train Loss: 0.5621021206931783


 94%|█████████▍| 47/50 [3:35:24<17:56, 358.80s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.556087314457504, Test Accuracy: 0.731152109157298
Epoch 48/50, Train Loss: 0.5621113390073867


 96%|█████████▌| 48/50 [3:41:22<11:57, 358.60s/it, test_acc=0.731, test_loss=0.557, train_loss=0.562]

Test Loss: 0.5569461842425708, Test Accuracy: 0.7312578072451235
Epoch 49/50, Train Loss: 0.5619463625033657


 98%|█████████▊| 49/50 [3:47:18<05:57, 357.77s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5562493885927703, Test Accuracy: 0.7311296883507895
Epoch 50/50, Train Loss: 0.5618920626545557


100%|██████████| 50/50 [3:53:15<00:00, 357.41s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Test Loss: 0.5562164059340818, Test Accuracy: 0.7312353864386151


100%|██████████| 50/50 [3:53:15<00:00, 279.91s/it, test_acc=0.731, test_loss=0.556, train_loss=0.562]

Model saved to data/model_with_fir90.pth



