In [4]:
import argparse
import random
import torch
import pandas as pd
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim import Adam
from tqdm import tqdm

# Set seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Helper functions to load and process the data into desired format needed for MF
# For MF we need a "model ID" either in the form of name or index and so we use the tabular data instead of tensors
def load_and_process_data(train_data, test_data, batch_size=64):
    # NOTE: Due to the nature of the embedding layer we need to take max prompt ID from both train and test data
    # But during training we won't be using test question
    num_prompts = int(max(max(train_data["prompt_id"]), max(test_data["prompt_id"]))) + 1
    class CustomDataset(Dataset):
        def __init__(self, data):
            model_ids = torch.tensor(data["model_id"], dtype=torch.int64)
            unique_ids, inverse_indices = torch.unique(model_ids, sorted=True, return_inverse=True)
            id_to_rank = {id.item(): rank for rank, id in enumerate(unique_ids)}
            ranked_model_ids = torch.tensor([id_to_rank[id.item()] for id in model_ids])
            self.models = ranked_model_ids
            self.prompts = torch.tensor(data["prompt_id"], dtype=torch.int64)
            self.labels = torch.tensor(data["label"], dtype=torch.int64)
            self.num_models = len(data["model_id"].unique())
            self.num_prompts = num_prompts
            self.num_classes = len(data["label"].unique())

        def get_num_models(self):
            return self.num_models

        def get_num_prompts(self):
            return self.num_prompts

        def get_num_classes(self):
            return self.num_classes

        def __len__(self):
            return len(self.models)

        def __getitem__(self, index):
            return self.models[index], self.prompts[index], self.labels[index]

        def get_dataloaders(self, batch_size):
            return DataLoader(self, batch_size, shuffle=False)

    train_dataset = CustomDataset(train_data)
    test_dataset = CustomDataset(test_data)

    train_loader = train_dataset.get_dataloaders(batch_size)
    test_loader = test_dataset.get_dataloaders(batch_size)

    return train_loader, test_loader

class TextMF(nn.Module):
    def __init__(self, question_embeddings, model_embedding_dim, alpha, num_models, num_prompts, text_dim=768, num_classes=2):
        super(TextMF, self).__init__()
        # Model embedding network
        self.P = nn.Embedding(num_models, model_embedding_dim)

        # Question embedding network
        self.Q = nn.Embedding(num_prompts, text_dim).requires_grad_(False)
        self.Q.weight.data.copy_(question_embeddings)
        self.text_proj = nn.Linear(text_dim, model_embedding_dim)

        # Noise/Regularization level
        self.alpha = alpha
        self.classifier = nn.Linear(model_embedding_dim, num_classes)

    def forward(self, model, prompt, test_mode=False):
        p = self.P(model)
        q = self.Q(prompt)
        if not test_mode:
            # Adding a small amount of noise in question embedding to reduce overfitting
            q += torch.randn_like(q) * self.alpha
        q = self.text_proj(q)
        return self.classifier(p * q)
    
    @torch.no_grad()
    def predict(self, model, prompt):
        logits = self.forward(model, prompt, test_mode=True) # During inference no noise is applied
        return torch.argmax(logits, dim=1)
    
def evaluate(net, test_loader, device):
    net.eval()
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    total_loss = 0
    correct = 0
    num_samples = 0

    with torch.no_grad():
        for models, prompts, labels in test_loader:
            models, prompts, labels = models.to(device), prompts.to(device), labels.to(device)
            logits = net(models, prompts)
            loss = loss_fn(logits, labels)
            pred_labels = net.predict(models, prompts)
            correct += (pred_labels == labels).sum().item()
            total_loss += loss.item()
            num_samples += labels.shape[0]

    mean_loss = total_loss / num_samples
    accuracy = correct / num_samples
    net.train()
    return mean_loss, accuracy

# Main training loop
def train(net, train_loader, test_loader, num_epochs, lr, device, weight_decay=1e-5, save_path=None):
    optimizer = Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()
    progress_bar = tqdm(total=num_epochs)

    for epoch in range(num_epochs):
        net.train()
        total_loss = 0
        for models, prompts, labels in train_loader:
            models, prompts, labels = models.to(device), prompts.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = net(models, prompts)
            loss = loss_fn(logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        train_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss}")

        test_loss, test_accuracy = evaluate(net, test_loader, device)
        print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")

        progress_bar.set_postfix(train_loss=train_loss, test_loss=test_loss, test_acc=test_accuracy)
        progress_bar.update(1)
    
    if save_path:
        torch.save(net.state_dict(), save_path)
        print(f"Model saved to {save_path}")

def load_model(net, path, device):
    net.load_state_dict(torch.load(path, map_location=device))
    print(f"Model loaded from {path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--embedding_dim", type=int, default=232)
    parser.add_argument("--alpha", type=float, default=0.05)
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--learning_rate", type=float, default=1e-4)

    parser.add_argument("--train_data_path", type=str, default="data/train.csv")
    parser.add_argument("--test_data_path", type=str, default="data/test.csv")
    parser.add_argument("--question_embedding_path", type=str, default="data/question_embeddings.pth")

    parser.add_argument("--embedding_save_path", type=str, default="data/model_embeddings.pth")
    parser.add_argument("--model_save_path", type=str, default="data/saved_model.pth")
    parser.add_argument("--model_load_path", type=str, default=None)
    args, unknown = parser.parse_known_args()

    print("Loading dataset...")
    train_data = pd.read_csv(args.train_data_path)
    test_data = pd.read_csv(args.test_data_path)
    question_embeddings = torch.load(args.question_embedding_path)
    num_prompts = question_embeddings.shape[0]
    num_models = len(test_data["model_id"].unique())
    model_names = list(np.unique(list(test_data["model_name"])))

    train_loader, test_loader = load_and_process_data(train_data, test_data, batch_size=args.batch_size)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Initializing model...")
    model = TextMF(question_embeddings=question_embeddings, 
                   model_embedding_dim=args.embedding_dim, alpha=args.alpha,
                   num_models=num_models, num_prompts=num_prompts)
    model.to(device)

    if args.model_load_path:
        model.load_state_dict(torch.load(args.model_load_path, map_location=device))
        print(f"Model loaded from {args.model_load_path}")

    print("Training model...")
    train(model, train_loader, test_loader, num_epochs=args.num_epochs, lr=args.learning_rate,
          device=device, save_path=args.model_save_path)
    # torch.save(model.P.weight.detach().to("cpu"), args.embedding_save_path) # Save model embeddings if needed

Loading dataset...


  question_embeddings = torch.load(args.question_embedding_path)


Initializing model...
Training model...




Epoch 1/50, Train Loss: 0.6762879089968727




Test Loss: 0.6474284955286161, Test Accuracy: 0.6350559033068401
Epoch 2/50, Train Loss: 0.6434537410032218




Test Loss: 0.6255477968424955, Test Accuracy: 0.6510187167977598
Epoch 3/50, Train Loss: 0.6247365225388424




Test Loss: 0.6079742108970931, Test Accuracy: 0.6711279495943664
Epoch 4/50, Train Loss: 0.609402252416168




Test Loss: 0.5937827536445451, Test Accuracy: 0.6865862949388462
Epoch 5/50, Train Loss: 0.5970634579242587




Test Loss: 0.5827348188577706, Test Accuracy: 0.6986986780875509
Epoch 6/50, Train Loss: 0.5876314163216007




Test Loss: 0.5747202365150886, Test Accuracy: 0.7073930321624181
Epoch 7/50, Train Loss: 0.5805213826114108




Test Loss: 0.5690599276658873, Test Accuracy: 0.7138121319441585
Epoch 8/50, Train Loss: 0.575341832222666




Test Loss: 0.5647261700117475, Test Accuracy: 0.7183883787011489
Epoch 9/50, Train Loss: 0.5714738889568862




Test Loss: 0.5614229027975461, Test Accuracy: 0.7216957336408186
Epoch 10/50, Train Loss: 0.5686679466643035




Test Loss: 0.5593264401904029, Test Accuracy: 0.72428241979986
Epoch 11/50, Train Loss: 0.566294658291231




Test Loss: 0.5579921871807864, Test Accuracy: 0.7261175513733887
Epoch 12/50, Train Loss: 0.5649243101789564




Test Loss: 0.5565771379310983, Test Accuracy: 0.7277570728493185
Epoch 13/50, Train Loss: 0.5636633304250189




Test Loss: 0.5555766776086903, Test Accuracy: 0.7289744883251658
Epoch 14/50, Train Loss: 0.5628745317423113




Test Loss: 0.5549774452689864, Test Accuracy: 0.729641106947247
Epoch 15/50, Train Loss: 0.5621980420137309




Test Loss: 0.5550406726626181, Test Accuracy: 0.73016873944735
Epoch 16/50, Train Loss: 0.5616287498010529




Test Loss: 0.5544893709108984, Test Accuracy: 0.7305110571181485
Epoch 17/50, Train Loss: 0.5612606944509303




Test Loss: 0.5540888555035267, Test Accuracy: 0.730979491825557
Epoch 18/50, Train Loss: 0.5609369792436089




Test Loss: 0.5536369839310769, Test Accuracy: 0.7311982662768192
Epoch 19/50, Train Loss: 0.5606305065154389




Test Loss: 0.5532011798641383, Test Accuracy: 0.7314839599719969
Epoch 20/50, Train Loss: 0.5605250936655415




Test Loss: 0.5536361672028685, Test Accuracy: 0.7317748013013219
Epoch 21/50, Train Loss: 0.5603049891154516




Test Loss: 0.55338346483444, Test Accuracy: 0.7319601161306264
Epoch 22/50, Train Loss: 0.5600986965040622




Test Loss: 0.5527947087345438, Test Accuracy: 0.7321866120331095
Epoch 23/50, Train Loss: 0.5600409645625198




Test Loss: 0.5533968988274132, Test Accuracy: 0.7324414199234032
Epoch 24/50, Train Loss: 0.5598665665649177




Test Loss: 0.5535178170174512, Test Accuracy: 0.7324517151916979
Epoch 25/50, Train Loss: 0.5597555725342858




Test Loss: 0.5528820510535618, Test Accuracy: 0.7325932751307499
Epoch 26/50, Train Loss: 0.5595726559675482




Test Loss: 0.5528167368321709, Test Accuracy: 0.7325932751307499
Epoch 27/50, Train Loss: 0.5595055115931596




Test Loss: 0.5529616576258661, Test Accuracy: 0.7326962278136968
Epoch 28/50, Train Loss: 0.559521394347351




Test Loss: 0.5526350156698207, Test Accuracy: 0.7328789688259276
Epoch 29/50, Train Loss: 0.559272213316594




Test Loss: 0.5527886876614969, Test Accuracy: 0.7329381666186221
Epoch 30/50, Train Loss: 0.5592880370040383




Test Loss: 0.5531709628055482, Test Accuracy: 0.7330668574723057
Epoch 31/50, Train Loss: 0.5592431642167106




Test Loss: 0.5526194224290644, Test Accuracy: 0.7331106123625581
Epoch 32/50, Train Loss: 0.5592394346132339




Test Loss: 0.5524753232334707, Test Accuracy: 0.7331209076308529
Epoch 33/50, Train Loss: 0.5590785770180625




Test Loss: 0.5528595395797669, Test Accuracy: 0.7331106123625581
Epoch 34/50, Train Loss: 0.5590265284997765




Test Loss: 0.5524744730443836, Test Accuracy: 0.733182679240621
Epoch 35/50, Train Loss: 0.5591408219738347




Test Loss: 0.552412824595784, Test Accuracy: 0.7331801054235474
Epoch 36/50, Train Loss: 0.5591138136251698




Test Loss: 0.5523948298693342, Test Accuracy: 0.7331517934357369
Epoch 37/50, Train Loss: 0.5590246870325927




Test Loss: 0.5526748017542534, Test Accuracy: 0.7332856319235679
Epoch 38/50, Train Loss: 0.5589386517810873




Test Loss: 0.552393742940603, Test Accuracy: 0.7332959271918627
Epoch 39/50, Train Loss: 0.559018915132192




Test Loss: 0.5523358532640444, Test Accuracy: 0.7333911584235885
Epoch 40/50, Train Loss: 0.5589705215915998




Test Loss: 0.5523577852205552, Test Accuracy: 0.7333628464357781
Epoch 41/50, Train Loss: 0.5589518952120938




Test Loss: 0.5525391213471604, Test Accuracy: 0.733450356216283
Epoch 42/50, Train Loss: 0.5589167899979549




Test Loss: 0.5524609224881617, Test Accuracy: 0.7335121278260511
Epoch 43/50, Train Loss: 0.5588229959885631




Test Loss: 0.5525541160623397, Test Accuracy: 0.7335198492772722
Epoch 44/50, Train Loss: 0.5587898035540225




Test Loss: 0.5523667381263685, Test Accuracy: 0.7334838158382407
Epoch 45/50, Train Loss: 0.5587628871858344




Test Loss: 0.5525770944175034, Test Accuracy: 0.7335301445455669
Epoch 46/50, Train Loss: 0.5587521786239125




Test Loss: 0.5522969494157375, Test Accuracy: 0.7335404398138615
Epoch 47/50, Train Loss: 0.5587277928976141




Test Loss: 0.5526786153518657, Test Accuracy: 0.73342719186262
Epoch 48/50, Train Loss: 0.558720946115988




Test Loss: 0.5526795536703094, Test Accuracy: 0.7335430136309352
Epoch 49/50, Train Loss: 0.5586703493411318




Test Loss: 0.5524839638377494, Test Accuracy: 0.7335404398138615
Epoch 50/50, Train Loss: 0.5587358072188133




Test Loss: 0.5525112900649314, Test Accuracy: 0.7335404398138615


100%|██████████| 50/50 [4:35:57<00:00, 331.14s/it, test_acc=0.734, test_loss=0.553, train_loss=0.559]

Model saved to data/saved_model.pth



