<a href="https://colab.research.google.com/github/matdjohnson-at-umass-dot-edu/cs646-final-project/blob/main/CS646_Final_Project_Classifier_Trainer_Instance_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
! pip install datasets
! pip install transformers



In [4]:
from datasets import concatenate_datasets, Dataset, disable_caching, disable_progress_bars, load_dataset
from tqdm import tqdm
from google.colab import drive
import os
import torch
import torch.nn.functional as torch_func
import gc
import time
from threading import Lock
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, AutoModel
import logging
import psutil
import numpy as np
from collections import Counter
import random
import math

# Mount Google Drive at /content/drive; the dataset and model paths used by the
# cells below live under /content/drive/MyDrive.
drive.mount('/content/drive')

Mounted at /content/drive


In [5]:

class DatasetHolder:
    """In-memory store of (query, document) embedding pairs with binary labels.

    Every row of the training parquet file is expanded into two examples that
    share the same query: ``(query, pos_doc)`` labelled ``1`` and
    ``(query, neg_doc)`` labelled ``0``. The five sequences ``query_ids``,
    ``doc_ids``, ``query_embs``, ``doc_embs`` and ``labels`` are index-aligned,
    and the ``get_*_for_batch_idx`` accessors return consecutive slices of at
    most ``batch_size`` elements from them (the last batch may be shorter).
    """

    def __init__(self, batch_size=100):
        """Load the training split from Google Drive and build the example lists.

        Args:
            batch_size: Number of examples per batch; must be positive.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # Fail before the expensive parquet load instead of with a
        # ZeroDivisionError (or nonsensical slices) once batching starts.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.batch_size = batch_size
        datasets_base_dir = "/content/drive/MyDrive/CS646-FinalProject/datasets"
        final_dataset_train_file_path_and_name = f"{datasets_base_dir}/ms_marco_final_dataset_avg/ms_marco_final_dataset_avg_train.parquet"
        final_dataset_train = Dataset.from_parquet(final_dataset_train_file_path_and_name).to_dict(batch_size=10000)
        self.query_ids = list()
        self.doc_ids = list()
        self.query_embs = list()
        self.doc_embs = list()
        self.labels = list()
        rows = zip(
            final_dataset_train['query_id'],
            final_dataset_train['query_emb'],
            final_dataset_train['pos_doc_id'],
            final_dataset_train['pos_doc_emb'],
            final_dataset_train['neg_doc_id'],
            final_dataset_train['neg_doc_emb'],
        )
        for query_id, query_emb, pos_doc_id, pos_doc_emb, neg_doc_id, neg_doc_emb in rows:
            # One positive and one negative example per source row, in that order.
            self.query_ids.extend((query_id, query_id))
            self.doc_ids.extend((pos_doc_id, neg_doc_id))
            self.query_embs.extend((query_emb, query_emb))
            self.doc_embs.extend((pos_doc_emb, neg_doc_emb))
            self.labels.extend((1, 0))
        # Drop the column dict early; the embeddings are now referenced by the lists above.
        del rows, final_dataset_train
        gc.collect()
        assert len(self.query_ids) == len(self.doc_ids) == len(self.query_embs) == len(self.doc_embs) == len(self.labels)
        self.total_elements = len(self.query_ids)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def shuffle(self):
        """Shuffle all examples in place, keeping the five sequences aligned.

        Uses the global ``random`` module state. An empty holder is left
        untouched, and the attributes remain lists after shuffling.
        """
        assert len(self.query_ids) == len(self.doc_ids) == len(self.query_embs) == len(self.doc_embs) == len(self.labels)
        if len(self.labels) == 0:
            # zip(*[]) yields nothing, so the five-way unpacking below would raise.
            return
        zipped_list = list(zip(self.query_ids, self.doc_ids, self.query_embs, self.doc_embs, self.labels))
        random.shuffle(zipped_list)
        # Convert each column back to a list so the attribute types do not
        # silently change from list to tuple after the first shuffle.
        self.query_ids, self.doc_ids, self.query_embs, self.doc_embs, self.labels = (
            list(column) for column in zip(*zipped_list)
        )

    def get_batch_count(self):
        """Return the number of batches, counting a trailing partial batch."""
        # Ceiling division in integer arithmetic.
        return (self.total_elements + self.batch_size - 1) // self.batch_size

    def _batch_slice(self, batch_idx):
        """Return the slice of example positions covered by batch ``batch_idx``."""
        start = self.batch_size * batch_idx
        return slice(start, start + self.batch_size)

    def get_query_ids_for_batch_idx(self, batch_idx):
        """Return the query ids of batch ``batch_idx``."""
        return self.query_ids[self._batch_slice(batch_idx)]

    def get_doc_ids_for_batch_idx(self, batch_idx):
        """Return the document ids of batch ``batch_idx``."""
        return self.doc_ids[self._batch_slice(batch_idx)]

    def get_query_embs_for_batch_idx(self, batch_idx):
        """Return the query embeddings of batch ``batch_idx`` as a tensor on ``self.device``."""
        return torch.tensor(self.query_embs[self._batch_slice(batch_idx)], device=self.device)

    def get_doc_embs_for_batch_idx(self, batch_idx):
        """Return the document embeddings of batch ``batch_idx`` as a tensor on ``self.device``."""
        return torch.tensor(self.doc_embs[self._batch_slice(batch_idx)], device=self.device)

    def get_labels_for_batch_idx(self, batch_idx):
        """Return the 0/1 labels of batch ``batch_idx`` as a tensor on ``self.device``."""
        return torch.tensor(self.labels[self._batch_slice(batch_idx)], device=self.device)


In [6]:

class SimpleAttentionModel(torch.nn.Module):
    """Two-class scorer: attention -> ReLU -> linear -> log-softmax.

    The query embeddings are used as the attention query and the document
    embeddings as both key and value. The output holds two log-probabilities
    per input row, suitable for ``torch.nn.NLLLoss``.

    NOTE(review): ``forward`` only accepts 2-D inputs, and
    ``torch.nn.MultiheadAttention`` interprets a 2-D input as a single
    unbatched sequence. The rows of a batch therefore act as sequence
    positions, so each query row attends over every document row passed in the
    same call. Confirm this is intended rather than per-pair attention.
    """

    def __init__(self, inner_dim=1024, number_heads=1):
        """Build the layers and move the module to CUDA when it is available.

        Args:
            inner_dim: Embedding width expected for both queries and documents.
            number_heads: Number of attention heads.
        """
        super().__init__()
        # The attribute names double as state_dict keys, so they stay as layer_1..layer_4.
        self.layer_1 = torch.nn.MultiheadAttention(
            embed_dim=inner_dim,
            num_heads=number_heads,
            batch_first=True,
        )
        self.layer_2 = torch.nn.ReLU()
        self.layer_3 = torch.nn.Linear(in_features=inner_dim, out_features=2)
        self.layer_4 = torch.nn.LogSoftmax(dim=1)
        target_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.to(device=target_device)

    def forward(self, query_embs, doc_embs):
        """Return per-row log-probabilities over the two classes.

        Args:
            query_embs: 2-D tensor of query embeddings.
            doc_embs: 2-D tensor of document embeddings with the same shape.

        Returns:
            Tensor with one row per input row and two columns.

        Raises:
            AssertionError: If either input is not 2-D or the shapes differ.
        """
        assert query_embs.dim() == 2 and doc_embs.dim() == 2 and query_embs.shape == doc_embs.shape
        # MultiheadAttention returns (output, attention weights); only the output is used.
        attended, _attention_weights = self.layer_1(query_embs, doc_embs, doc_embs)
        activated = self.layer_2(attended)
        class_scores = self.layer_3(activated)
        return self.layer_4(class_scores)


In [None]:
# Hyper-parameters for this run.
#   epochs          number of passes over the training examples
#   max_lr          peak learning rate, reached when warmup ends
#   warmup          fraction of the epochs spent warming up
#   warmup_gamma    per-epoch LR multiplier during warmup (> 1 grows the LR)
#   cooldown_gamma  per-epoch LR multiplier after warmup (< 1 shrinks the LR)
hyper_parameters_2024_12_05_01 = {
    "epochs": 100,
    "max_lr": 0.002,
    "warmup": 0.1,
    "warmup_gamma": 2.0,
    "cooldown_gamma": 0.5,
    "batch_size": 20000,
    "attention_inner_dim": 1024,
    "attention_heads": 1,
    "parameter_set_name": "hyper_parameters_2024_12_05_01"
}

hyper_parameters = hyper_parameters_2024_12_05_01

dataset_holder = DatasetHolder(batch_size=hyper_parameters['batch_size'])

model = SimpleAttentionModel(inner_dim=hyper_parameters["attention_inner_dim"], number_heads=hyper_parameters['attention_heads'])

# Epoch index at which warmup ends and the learning rate equals max_lr.
milestone = math.floor(hyper_parameters["epochs"] * hyper_parameters["warmup"])

# Learning rate used for epoch 0; kept for reference and logging.
initial_lr = (hyper_parameters["max_lr"]) / (hyper_parameters["warmup_gamma"] ** (milestone))


def lr_factor(epoch):
    """Return the multiplier applied to max_lr for the given epoch index.

    The factor grows by ``warmup_gamma`` per epoch until it reaches 1 at
    ``milestone``, then shrinks by ``cooldown_gamma`` per epoch.
    """
    epochs_from_peak = epoch - milestone
    if epochs_from_peak < 0:
        return hyper_parameters["warmup_gamma"] ** epochs_from_peak
    return hyper_parameters["cooldown_gamma"] ** epochs_from_peak


# A single LambdaLR expresses the whole warmup/cooldown curve relative to max_lr.
# The previous SequentialLR over two ExponentialLR schedulers restarted the
# cooldown scheduler from the optimizer's base LR (initial_lr) at the milestone,
# so the LR fell back to initial_lr after warmup and never reached max_lr.
optimizer = torch.optim.Adam(model.parameters(), lr=hyper_parameters["max_lr"])
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
loss_fcn = torch.nn.NLLLoss()

# The number of batches does not change between epochs (shuffling keeps the size).
batch_count = dataset_holder.get_batch_count()

for i in range(0, hyper_parameters["epochs"]):
    dataset_holder.shuffle()
    for j in range(0, batch_count):
        query_embs = dataset_holder.get_query_embs_for_batch_idx(j)
        doc_embs = dataset_holder.get_doc_embs_for_batch_idx(j)
        labels = dataset_holder.get_labels_for_batch_idx(j)
        model.zero_grad()
        # Call the module rather than .forward() so registered hooks run.
        logits = model(query_embs, doc_embs)
        loss = loss_fcn(logits, labels)
        loss.backward()
        optimizer.step()
        print(f"epoch:{i}/{hyper_parameters['epochs']} batch:{j}/{batch_count} lr:{scheduler.get_last_lr()} loss:{loss.item()}")
        # Release the batch tensors before the next batch is materialised on the device.
        del query_embs, doc_embs, labels, logits, loss
        gc.collect()
        torch.cuda.empty_cache()
    # The learning rate is updated once per epoch, not per batch.
    scheduler.step()

models_dir = "/content/drive/MyDrive/CS646-FinalProject/models"
# Make sure the target directory exists so a long run is not lost at save time.
os.makedirs(models_dir, exist_ok=True)
model.zero_grad()
model.eval()
# Saves the whole pickled module (not just a state_dict); loading it requires
# the SimpleAttentionModel class to be importable under the same name.
torch.save(model, f"{models_dir}/{hyper_parameters['parameter_set_name']}.pt")

Generating train split: 0 examples [00:00, ? examples/s]

epoch:0/100 batch:0/21 lr:[1.953125e-06] loss:0.6938859820365906
epoch:0/100 batch:1/21 lr:[1.953125e-06] loss:0.6938720941543579
epoch:0/100 batch:2/21 lr:[1.953125e-06] loss:0.6930404901504517
epoch:0/100 batch:3/21 lr:[1.953125e-06] loss:0.6933808922767639
epoch:0/100 batch:4/21 lr:[1.953125e-06] loss:0.6932299733161926
epoch:0/100 batch:5/21 lr:[1.953125e-06] loss:0.6931371688842773
epoch:0/100 batch:6/21 lr:[1.953125e-06] loss:0.693168044090271
epoch:0/100 batch:7/21 lr:[1.953125e-06] loss:0.6931893825531006
epoch:0/100 batch:8/21 lr:[1.953125e-06] loss:0.6931453347206116
epoch:0/100 batch:9/21 lr:[1.953125e-06] loss:0.6931268572807312
epoch:0/100 batch:10/21 lr:[1.953125e-06] loss:0.6932315230369568
epoch:0/100 batch:11/21 lr:[1.953125e-06] loss:0.6932040452957153
epoch:0/100 batch:12/21 lr:[1.953125e-06] loss:0.6933361887931824
epoch:0/100 batch:13/21 lr:[1.953125e-06] loss:0.6931359171867371
epoch:0/100 batch:14/21 lr:[1.953125e-06] loss:0.6930447220802307
epoch:0/100 batch:15/