<a href="https://colab.research.google.com/github/matdjohnson-at-umass-dot-edu/cs646-final-project/blob/main/CS646_Final_Project_Classifier_Trainer_Continued_Effort_Instance_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install datasets
! pip install transformers

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl 

In [None]:
from datasets import concatenate_datasets, Dataset, disable_caching, disable_progress_bars, load_dataset
from tqdm import tqdm
from google.colab import drive
import os
import torch
import torch.nn.functional as torch_func
import gc
import time
from threading import Lock
from concurrent.futures import ThreadPoolExecutor
from transformers import AutoTokenizer, AutoModel
import logging
import psutil
import numpy as np
from collections import Counter
import random
import math
import psutil

# Mount Google Drive: the training parquet file is read from, and the trained model is
# written to, paths under /content/drive/MyDrive/CS646-FinalProject.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:

class DatasetHolder:
    """Training pairs built from MS MARCO (query, positive doc, negative doc) triples.

    Every row of the parquet file is expanded into two consecutive examples:

    * ``(query, pos_doc)`` with label 0 -- the classifier should favour logit index 0;
    * ``(query, neg_doc)`` with label 1 -- the classifier should favour logit index 1.

    Ids are kept as Python lists; embeddings and labels are kept as CPU tensors with one
    row per example, and each requested batch is copied to ``self.device``.  Keeping the
    embeddings as float32 tensors (instead of nested Python lists) avoids allocating one
    Python float object per embedding value, which also keeps ``gc.collect()`` cheap.

    Args:
        batch_size: number of examples per batch.
        dataset_path: parquet file with the columns ``query_id``, ``pos_doc_id``,
            ``neg_doc_id``, ``query_emb``, ``pos_doc_emb`` and ``neg_doc_emb``.
    """

    DEFAULT_TRAIN_PATH = (
        "/content/drive/MyDrive/CS646-FinalProject/datasets"
        "/ms_marco_final_dataset_avg/ms_marco_final_dataset_avg_train.parquet"
    )

    def __init__(self, batch_size=100, dataset_path=DEFAULT_TRAIN_PATH):
        self.batch_size = batch_size
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Read every column as numpy arrays rather than via to_dict(), which would
        # materialise each embedding value as a separate Python float.
        triples = Dataset.from_parquet(dataset_path).with_format("numpy")[:]
        self._load_triples(triples)
        del triples
        gc.collect()

    @staticmethod
    def _to_float_matrix(column):
        """Return an embedding column as a ``(rows, dim)`` float32 CPU tensor.

        ``column`` may be a 2-D array, a sequence of equal-length vectors, or an object
        array whose elements are 1-D arrays.
        """
        if len(column) == 0:
            return torch.zeros((0, 0), dtype=torch.float32)
        matrix = np.asarray(column)
        if matrix.dtype == object:
            # rows came back as separate arrays; stack them into one matrix
            matrix = np.stack(list(matrix))
        return torch.from_numpy(np.ascontiguousarray(matrix, dtype=np.float32))

    def _load_triples(self, triples):
        """Populate the per-example storage from a mapping of triple columns.

        Example ``2k`` is (query k, positive doc k, label 0) and example ``2k + 1`` is
        (query k, negative doc k, label 1).
        """
        # .tolist() turns numpy scalars back into plain Python ids
        query_ids = np.asarray(triples["query_id"]).tolist()
        pos_doc_ids = np.asarray(triples["pos_doc_id"]).tolist()
        neg_doc_ids = np.asarray(triples["neg_doc_id"]).tolist()
        num_triples = len(query_ids)

        self.query_ids = [query_id for query_id in query_ids for _ in range(2)]
        self.doc_ids = [doc_id for pair in zip(pos_doc_ids, neg_doc_ids) for doc_id in pair]
        self.query_embs = self._to_float_matrix(triples["query_emb"]).repeat_interleave(2, dim=0)
        # (n, 2, dim) -> (2n, dim) interleaves pos_0, neg_0, pos_1, neg_1, ...
        self.doc_embs = torch.stack(
            [self._to_float_matrix(triples["pos_doc_emb"]), self._to_float_matrix(triples["neg_doc_emb"])],
            dim=1,
        ).flatten(0, 1)
        self.labels = torch.tensor([0, 1], dtype=torch.int64).repeat(num_triples)

        assert len(self.query_ids) == len(self.doc_ids) == len(self.query_embs) == len(self.doc_embs) == len(self.labels)
        self.total_elements = len(self.query_ids)

    def shuffle(self):
        """Shuffle all examples in place, keeping ids, embeddings and labels aligned.

        The permutation is drawn with ``random.shuffle`` (Python's global RNG), so
        seeding ``random`` makes the order reproducible.
        """
        assert len(self.query_ids) == len(self.doc_ids) == len(self.query_embs) == len(self.doc_embs) == len(self.labels)
        order = list(range(len(self.query_ids)))
        random.shuffle(order)
        index = torch.tensor(order, dtype=torch.long)
        self.query_ids = [self.query_ids[k] for k in order]
        self.doc_ids = [self.doc_ids[k] for k in order]
        self.query_embs = self.query_embs[index]
        self.doc_embs = self.doc_embs[index]
        self.labels = self.labels[index]

    def get_batch_count(self):
        """Number of batches needed to cover every example (last batch may be short)."""
        return (self.total_elements + self.batch_size - 1) // self.batch_size

    def _batch_slice(self, batch_idx):
        """Slice selecting the examples of batch ``batch_idx``."""
        start = self.batch_size * batch_idx
        return slice(start, start + self.batch_size)

    def get_query_ids_for_batch_idx(self, batch_idx):
        return self.query_ids[self._batch_slice(batch_idx)]

    def get_doc_ids_for_batch_idx(self, batch_idx):
        return self.doc_ids[self._batch_slice(batch_idx)]

    def get_query_embs_for_batch_idx(self, batch_idx):
        """float32 tensor ``(batch, dim)`` on ``self.device`` (a fresh copy, never a view)."""
        return self.query_embs[self._batch_slice(batch_idx)].to(self.device, copy=True)

    def get_doc_embs_for_batch_idx(self, batch_idx):
        """float32 tensor ``(batch, dim)`` on ``self.device`` (a fresh copy, never a view)."""
        return self.doc_embs[self._batch_slice(batch_idx)].to(self.device, copy=True)

    def get_labels_for_batch_idx(self, batch_idx):
        """int64 tensor ``(batch,)`` on ``self.device``: 0 = positive doc, 1 = negative doc."""
        return self.labels[self._batch_slice(batch_idx)].to(self.device, copy=True)


In [None]:

class SimpleAttentionModel_1AttnModule(torch.nn.Module):
    """Two-class relevance classifier over a (query embedding, document embedding) pair.

    The two vectors are treated as a length-2 token sequence and passed through one
    multi-head self-attention module; the two attended tokens are mean-pooled and fed to
    a two-layer ReLU MLP whose output is a log-softmax over two classes.  The training
    labels in this notebook use index 0 for a positive document and index 1 for a
    negative one.

    Args:
        num_attn_heads: number of attention heads; must divide ``embed_dim``.
        linear_layer_dim: width of the MLP hidden layer.
        embed_dim: size of each input embedding (default 1024).
    """

    def __init__(self, num_attn_heads=1, linear_layer_dim=1024, embed_dim=1024):
        super().__init__()
        self.layer_1 = torch.nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_attn_heads, batch_first=True)
        # stateless activation, applied both before and after layer_3
        self.layer_2 = torch.nn.ReLU()
        self.layer_3 = torch.nn.Linear(in_features=embed_dim, out_features=linear_layer_dim)
        self.layer_4 = torch.nn.Linear(in_features=linear_layer_dim, out_features=2)
        self.layer_5 = torch.nn.LogSoftmax(dim=1)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device=device)

    def forward(self, query_embs, doc_embs):
        """Score each (query, doc) pair.

        Args:
            query_embs: ``(batch, embed_dim)`` query embeddings.
            doc_embs: ``(batch, embed_dim)`` document embeddings.

        Returns:
            ``(batch, 2)`` log-probabilities.
        """
        assert query_embs.dim() == 2 and query_embs.shape == doc_embs.shape
        # Attending from one query token to one doc token is a softmax over a single key:
        # the weight is always 1, so the output would not depend on the query at all.
        # Self-attention over the sequence [query, doc] makes it depend on both.
        pair_tokens = torch.stack([query_embs, doc_embs], dim=1)  # (batch, 2, embed_dim)
        attended, _ = self.layer_1(pair_tokens, pair_tokens, pair_tokens, need_weights=False)
        pooled = attended.mean(dim=1)  # (batch, embed_dim)
        hidden = self.layer_2(self.layer_3(self.layer_2(pooled)))
        # ReLU between the linear layers; without it layer_4(layer_3(x)) is a single linear map
        return self.layer_5(self.layer_4(hidden))


In [None]:
# Hyper-parameter set for this run.
hyper_parameters_2024_12_08_01 = dict(
    epochs=10,                            # passes over the data; the LR scheduler steps once per epoch
    max_lr=0.002,                         # peak Adam learning rate
    fraction_of_epochs_as_warmup=0.1,     # share of epochs (floored) spent in linear warmup
    fraction_of_max_lr_at_init=0.25,      # warmup start LR, as a fraction of max_lr
    fraction_of_max_lr_at_end=0.0625,     # cooldown end LR, as a fraction of max_lr
    batch_size=2000,                      # (query, doc) examples per optimizer step
    num_attn_modules=1,                   # recorded only; the model constructor does not take it
    num_attn_heads=4,                     # attention heads in the model
    linear_layer_dim=4096,                # hidden width of the model's MLP head
    parameter_set_name="hyper_parameters_2024_12_08_01",  # also the saved model's file name
)

# Select the active parameter set and load the training pairs with its batch size.
hyper_parameters = hyper_parameters_2024_12_08_01
dataset_holder = DatasetHolder(batch_size=hyper_parameters["batch_size"])


Generating train split: 0 examples [00:00, ? examples/s]

100%|██████████| 201396/201396 [00:00<00:00, 746893.87it/s]


In [None]:
# Seed the RNGs the run uses (torch for weight init, `random` for batch shuffling) so a
# Restart & Run All reproduces the same training run.
seed = hyper_parameters.get("seed", 0)
random.seed(seed)
torch.manual_seed(seed)

model = SimpleAttentionModel_1AttnModule(
    num_attn_heads=hyper_parameters['num_attn_heads'],
    linear_layer_dim=hyper_parameters['linear_layer_dim']
)

# Per-epoch LR schedule: linear warmup from fraction_of_max_lr_at_init * max_lr up to
# max_lr, then linear decay towards fraction_of_max_lr_at_end * max_lr.
iters_warmup = math.floor(hyper_parameters["epochs"] * hyper_parameters["fraction_of_epochs_as_warmup"])
iters_cooldown = hyper_parameters['epochs'] - iters_warmup

optimizer = torch.optim.Adam(model.parameters(), lr=hyper_parameters['max_lr'])
if iters_warmup > 0:
    warmup_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=(hyper_parameters['fraction_of_max_lr_at_init']), end_factor=(1.0), total_iters=iters_warmup)
    cooldown_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=(1.0), end_factor=(hyper_parameters['fraction_of_max_lr_at_end']), total_iters=iters_cooldown)
    scheduler = torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=[warmup_scheduler, cooldown_scheduler], milestones=[iters_warmup])
else:
    # With no warmup epochs, SequentialLR(milestones=[0]) would still apply the warmup
    # scheduler's start factor and scale the whole schedule; decay directly from max_lr.
    cooldown_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=(1.0), end_factor=(hyper_parameters['fraction_of_max_lr_at_end']), total_iters=iters_cooldown)
    scheduler = cooldown_scheduler
loss_fcn = torch.nn.NLLLoss(reduction='mean')


In [None]:
# Train for the configured number of epochs; the LR scheduler is stepped once per epoch.
batch_count = dataset_holder.get_batch_count()  # fixed for the run; shuffling keeps the size
for epoch in range(hyper_parameters["epochs"]):
    model.train()  # a later cell calls model.eval(); make re-runs train in training mode
    dataset_holder.shuffle()
    for batch_idx in range(batch_count):
        query_embs = dataset_holder.get_query_embs_for_batch_idx(batch_idx)
        doc_embs = dataset_holder.get_doc_embs_for_batch_idx(batch_idx)
        labels = dataset_holder.get_labels_for_batch_idx(batch_idx)
        optimizer.zero_grad()
        logits = model(query_embs, doc_embs)
        loss = loss_fcn(logits, labels)
        loss.backward()
        optimizer.step()
        timestamp = time.strftime("%Y-%m-%dT%H:%M:%S", time.localtime())
        # mem_get_info() raises on machines without CUDA
        memory = torch.cuda.mem_get_info() if torch.cuda.is_available() else "n/a"
        print(f"{timestamp} epoch:{epoch}/{hyper_parameters['epochs']} batch:{batch_idx}/{batch_count} lr:{scheduler.get_last_lr()} loss:{loss} memory:{memory}")
        if batch_idx == batch_count - 1:
            print(f"{timestamp} logits_and_labels:{torch.concatenate([logits, labels.unsqueeze(-1)], dim=1)}")
        # Dropping the references frees this batch's tensors and autograd graph by
        # reference counting.  A per-batch gc.collect()/torch.cuda.empty_cache() is not
        # needed for that and only slows every step down.
        del query_embs, doc_embs, labels, logits, loss
    scheduler.step()


2024-12-08T10:00:16 epoch:0/10 batch:0/202 lr:[0.0005] loss:0.6935302019119263 memory:(41669165056, 42481811456)
2024-12-08T10:00:29 epoch:0/10 batch:1/202 lr:[0.0005] loss:1.6596875190734863 memory:(41581084672, 42481811456)
2024-12-08T10:00:42 epoch:0/10 batch:2/202 lr:[0.0005] loss:0.7593177556991577 memory:(41589473280, 42481811456)
2024-12-08T10:00:54 epoch:0/10 batch:3/202 lr:[0.0005] loss:0.9450475573539734 memory:(41589473280, 42481811456)
2024-12-08T10:01:07 epoch:0/10 batch:4/202 lr:[0.0005] loss:0.7361346483230591 memory:(41589473280, 42481811456)
2024-12-08T10:01:20 epoch:0/10 batch:5/202 lr:[0.0005] loss:0.6815588474273682 memory:(41589473280, 42481811456)
2024-12-08T10:01:33 epoch:0/10 batch:6/202 lr:[0.0005] loss:0.7125224471092224 memory:(41589473280, 42481811456)
2024-12-08T10:01:46 epoch:0/10 batch:7/202 lr:[0.0005] loss:0.7115239500999451 memory:(41589473280, 42481811456)
2024-12-08T10:01:59 epoch:0/10 batch:8/202 lr:[0.0005] loss:0.6958856582641602 memory:(415894732

In [None]:
# Save the trained model to Drive, named after the hyper-parameter set.
models_dir = "/content/drive/MyDrive/CS646-FinalProject/models"
# torch.save does not create missing parent directories.
os.makedirs(models_dir, exist_ok=True)
model.zero_grad()  # drop gradients left over from the last training step
model.eval()
# NOTE: this pickles the whole nn.Module, so loading it requires the model class to be
# defined (and, on newer PyTorch versions, torch.load(..., weights_only=False)).
torch.save(model, f"{models_dir}/{hyper_parameters['parameter_set_name']}.pt")
