# Deep Learning Project
How to run this Notebook:
- Run every cell up to the Training Chapter

For training:
- Decide whether to use the Generative or the Discriminative training approach, then run only the corresponding cell.

For testing:
- Run the "DSI Model with Foundation Model" chapter, then run the "Restore a Checkpoint and run a Test" one.

## Install required libraries

In [None]:
try:
    from google.colab import drive
    usingColab = True
    print("Using Colab. Downloading necessary libraries...")
    # -- WINDOWS --
    # !pip uninstall bitsandbytes bitsandbytes-windows
    # !pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.2.post2-py3-none-win_amd64.whl # 0.41.2
    # !pip install https://github.com/habibzadeh/bitsandbytes/releases/download/0.42.0_win_cuda_12.1/bitsandbytes-0.42.0-py39-none-any.whl # 0.42.0

    # -- LINUX --
    !pip install bitsandbytes -q

    # Other simple libraries
    !pip install accelerate sentencepiece peft pytorch_lightning scikit-learn wandb lightning-bolts langchain -q

    # Update transformers to latest version
    !pip install -U gdown transformers -q
except:
    usingColab = False
    print("Not using Colab. Assuming all libraries have been already downloaded...")
    pass

Using Colab. Downloading necessary libraries...
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m280.0/280.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m183.4/183.4 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m800.9/800.9 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m300.8/300.8 kB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m815.9/815.9 kB[0m [31m43.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.4/840.4 kB[0m [31m45.1 MB/s

## Imports

In [None]:
import gzip
import json
import random
import shutil
import tarfile
import warnings
from os.path import join as pathjoin

import bitsandbytes as bnb
import gdown
import numpy as np
import pandas as pd
import pl_bolts
import pytorch_lightning as pl
import torch
import torch.nn as nn
from peft import (LoraConfig, TaskType, get_peft_model,
                  prepare_model_for_kbit_training)
from sklearn.cluster import KMeans
from torch import nn
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm
from transformers import (AutoModelForSeq2SeqLM, AutoTokenizer,
                          BitsAndBytesConfig, EncoderDecoderModel,
                          SwitchTransformersForConditionalGeneration)
from transformers.trainer_pt_utils import get_parameter_names

import wandb

# Hide the UserWarning whose message starts with 'TypedStorage is deprecated'
warnings.filterwarnings(
    'ignore',
    category=UserWarning,
    message='TypedStorage is deprecated',
)

# Few fixes for Linux (applied only when running on Colab)
if usingColab:
    import locale
    import os

    def getpreferredencoding(do_setlocale = True):
        """Replacement for locale.getpreferredencoding that always answers UTF-8."""
        return "UTF-8"

    # Monkey-patch the locale module so every caller sees UTF-8
    locale.getpreferredencoding = getpreferredencoding

    # Java and CUDA
    os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"

  if not hasattr(numpy, tp_name):
  if not hasattr(numpy, tp_name):
  "lr_options": generate_power_seq(LEARNING_RATE_CIFAR, 11),
  contrastive_task: Union[FeatureMapContrastiveTask] = FeatureMapContrastiveTask("01, 02, 11"),
  self.nce_loss = AmdimNCELoss(tclip)


## Check CUDA support

In [None]:
# Select the compute device once and report what was found.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

if not cuda_available:
    print('Cuda not available, so using CPU. Please consider switching to a GPU runtime before running the notebook!')
else:
    print('Cuda available: {}'.format(cuda_available))

    gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print("GPU: " + gpu_name)

    if 'RTX 3060' in gpu_name:
        print("\t- Using RTX 3060. Setting MatMul precision to High.")
        # Faster, but less precise
        torch.set_float32_matmul_precision("high")

    # Total memory of device 0, converted from bytes to GiB
    total_memory_gb = float(torch.cuda.get_device_properties(0).total_memory / (1024 ** 3))
    print("Total memory: {:.1f} GB".format(total_memory_gb))
    print("===================================================")

Cuda available: True
GPU: Tesla T4
Total memory: 14.7 GB


  and should_run_async(code)


## Dataset

First of all, configure the next cell by selecting the right student: if your name is not among the keys of the dictionary, add it and configure the corresponding path.

In [None]:
# Student
student = "Professor" # For Professor, please insert "Professor"

# Special case for professor's notebook setup
if(student == 'Professor'):
  print("Using Professor's Notebook Setup.")
  msmarco_v2_doc_dir = "/content"
  msmarco_v2_passage_dir = "/content"
else:
  import os

  # SECURITY: the WandB API key must never be stored in the notebook.
  # It is read from the WANDB_API_KEY environment variable instead (empty
  # string when unset, like the entries that never had a key).
  # Any key that was previously committed here should be considered leaked
  # and revoked from the WandB account settings.
  wandb_api_key = os.environ.get("WANDB_API_KEY", "")

  # Project directory: student -> [path relative to MyDrive, WandB key]
  proj_dict = {"Emanuele": ["ColabNotebooks/DeepLearning/Progetto", ''],
              "NN": ["DeepLearningProject/Progetto", ''],
              "Gianmarco_Uni": ["Università/2. Magistrale/II ANNO/I° Semestre/Deep Learning/Progetto", wandb_api_key], # Shortcut to Emanuele Project directory + WandB key
              "Gianmarco_Personal": ["DeepLearningProj/Progetto", wandb_api_key],
              }

  # Fail with an actionable message instead of a bare KeyError further down
  if student not in proj_dict:
    raise KeyError(f"Unknown student '{student}': add an entry for it to proj_dict (known: {sorted(proj_dict)}).")

  msmarco_v2_doc_dir = f"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_doc"
  msmarco_v2_passage_dir = f"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_passage"

  # LOCAL VARIABLES!! Use only for local execution
  if (not usingColab):
    msmarco_v2_doc_dir = "msmarco_v2_doc"
    msmarco_v2_passage_dir = "msmarco_v2_passage"

Using Professor's Notebook Setup.


### Drive connection
We put the Drive connection down here, since we have a special case for the Professor, where we don't want to use Drive

In [None]:
# Mount Google Drive only on Colab, and never for the Professor setup
needs_drive = usingColab and student != 'Professor'
if needs_drive:
  drive.mount("/content/drive/", force_remount=True)

### Download
Official download of MSMarco Dataset from https://microsoft.github.io/msmarco/TREC-Deep-Learning.html
- msmarco_v2_doc.tar = 32.27GB (Unpacked: 112GB)
- msmarco_v2_passage.tar = 20.27GB

In [None]:
# Special case for the Professor. Only download the Chunk 00 and the Train bits.
#
# Only Chunk 00 is downloaded: MSMarco is composed of 59 chunks for a total of 112 GB.
# Besides, only a portion of Chunk 00 is later used to train the model, due to the limited computational resources.
if student == 'Professor':
    GDRIVE_URL_PREFIX = 'https://drive.google.com/uc?id='
    GDRIVE_URL_SUFFIX = '&export=download&confirm=t'

    Chunk_URL = GDRIVE_URL_PREFIX + '1B8t6dxkDZ2DLomMN5CikmgTeerLiJgGU' + GDRIVE_URL_SUFFIX
    Doc_Train_Queries_URL = GDRIVE_URL_PREFIX + '1Pn3oOQG1-6Y1zD01xgiPi0wgPKWT4WiT' + GDRIVE_URL_SUFFIX
    Doc_Train_Qrels_URL = GDRIVE_URL_PREFIX + '1PvLVMMkSmB91BO29GOS1Xa0mhN0RxzVV' + GDRIVE_URL_SUFFIX
    Doc_Train_Top100_URL = GDRIVE_URL_PREFIX + '1VgwtS-6s5lT4eKHE_uKEie_8zqmQJkca' + GDRIVE_URL_SUFFIX

    # Fetch the four files one after the other, showing gdown's progress output
    for gdrive_url in (Chunk_URL, Doc_Train_Queries_URL, Doc_Train_Qrels_URL, Doc_Train_Top100_URL):
        gdown.download(gdrive_url, quiet=False)

    print(f"Files downloaded to /content/")

In [None]:
# download_dataset = True
#   - Starts the whole download process and store it in Google Drive. This takes approximately 2 hours between download/extracting.
# download_dataset = False
#   - It assumes that the Dataset has already been downloaded and extracted to Drive, so it simply loads from it
download_dataset = False

if(download_dataset and student != 'Professor'):
      !wget --header "X-Ms-Version: 2019-12-12" https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco_v2_doc.tar
      !wget --header "X-Ms-Version: 2019-12-12" https://msmarco.z22.web.core.windows.net/msmarcoranking/msmarco_v2_passage.tar
      !wget --header "X-Ms-Version: 2019-12-12" https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_train_queries.tsv
      !wget --header "X-Ms-Version: 2019-12-12" https://msmarco.z22.web.core.windows.net/msmarcoranking/docv2_train_qrels.tsv

      # Copy the Dataset to Drive for easier access
      shutil.copy("/content/msmarco_v2_doc.tar",F"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_doc.tar")
      shutil.copy("/content/msmarco_v2_passage.tar",F"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_passage.tar")
      shutil.copy("/content/docv2_train_queries.tsv",F"/content/drive/MyDrive/{proj_dict[student][0]}/docv2_train_queries.tsv")
      shutil.copy("/content/docv2_train_qrels.tsv",F"/content/drive/MyDrive/{proj_dict[student][0]}/docv2_train_qrels.tsv")

      # Extract both Documents and Passages into .gz files
      tar = tarfile.open(F"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_doc.tar")
      tar.extractall()
      tar.close()
      shutil.copytree("/content/msmarco_v2_doc",F"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_doc")
      !rm -rf "/content/msmarco_v2_doc" # Remove it from Colab otherwise disk will get full. Further loading will be done by Drive
      # ------------------------- #
      tar = tarfile.open(F"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_passage.tar")
      tar.extractall()
      tar.close()
      shutil.copytree("/content/msmarco_v2_passage",F"/content/drive/MyDrive/{proj_dict[student][0]}/msmarco_v2_passage")
      !rm -rf "/content/msmarco_v2_passage" # Remove it from Colab otherwise disk will get full. Further loading will be done by Drive

      # Extract Documents starting from .gz files
      for doc in tqdm(os.listdir(msmarco_v2_doc_dir), desc="Extracting docs .gz files"):
        with gzip.open(pathjoin(msmarco_v2_doc_dir, doc), 'rb') as f_in:
          with open(pathjoin(msmarco_v2_doc_dir, doc.split(".")[0]), 'wb') as f_out:
              shutil.copyfileobj(f_in, f_out)
              os.remove(pathjoin(msmarco_v2_doc_dir, doc))

      # Extract Passages starting from .gz files
      for doc in tqdm(os.listdir(msmarco_v2_passage_dir), desc="Extracting passages .gz files"):
        with gzip.open(pathjoin(msmarco_v2_passage_dir, doc), 'rb') as f_in:
          with open(pathjoin(msmarco_v2_passage_dir, doc.split(".")[0]), 'wb') as f_out:
              shutil.copyfileobj(f_in, f_out)
              os.remove(pathjoin(msmarco_v2_passage_dir, doc))
else:
  if(usingColab) and student != 'Professor':
    print(F"Not downloading Dataset.\nAssuming it is already present in: /content/drive/MyDrive/{proj_dict[student][0]}/")
  else:
    print(F"Not downloading Dataset.\nAssuming it is already present in: {msmarco_v2_doc_dir}/")

### Link queries with documents

In [None]:
# Obtain queries
# The Professor setup reads both TSV files from /content, every other student from
# the project folder on Drive.
if student == 'Professor':
  queries_root = "/content"
else:
  queries_root = F"/content/drive/MyDrive/{proj_dict[student][0]}"

docv2_train_queries = pathjoin(queries_root, "docv2_train_queries.tsv") # query id - query
docv2_train_qrels = pathjoin(queries_root, "docv2_train_qrels.tsv") # query id - ? - doc id - ?

# LOCAL VARIABLES!! Use only for local execution
if not usingColab:
  docv2_train_queries, docv2_train_qrels = "docv2_train_queries.tsv", "docv2_train_qrels.tsv"

In [None]:
# Both files are tab-separated and read without a header row, so column names are given explicitly
tsv_options = dict(sep='\t', header=None)
df_docv2_train_queries = pd.read_csv(
    docv2_train_queries,
    names=["Query_ID", "Query"],
    **tsv_options,
)
df_docv2_train_qrels = pd.read_csv(
    docv2_train_qrels,
    names=["Query_ID", "Iteration", "Doc_ID", "Relevance"],
    **tsv_options,
)

In [None]:
# Join queries with their relevance judgements on the shared 'Query_ID' column
query_doc_df = df_docv2_train_queries.merge(df_docv2_train_qrels, on='Query_ID')

### Show a simple example of document retrieval

In [None]:
# Preview the first two rows of the queries frame (columns: Query_ID, Query)
df_docv2_train_queries.head(2)

In [None]:
# Preview the first two rows of the qrels frame (columns: Query_ID, Iteration, Doc_ID, Relevance)
df_docv2_train_qrels.head(2)

In [None]:
# Preview the first two rows of the merged query/document frame
query_doc_df.head(2)

In [None]:
def get_document(document_id):
    """Load one MSMarco document from its bundle file.

    Args:
        document_id: identifier shaped like ``msmarco_doc_<bundle>_<position>``.
            ``<bundle>`` selects the file ``msmarco_doc_<bundle>`` inside
            ``msmarco_v2_doc_dir`` and ``<position>`` is the offset passed to
            ``seek`` before reading one line.

    Returns:
        The JSON object parsed from that line.

    Raises:
        ValueError: if the identifier does not split into exactly four parts.
        AssertionError: if the prefix is not ``msmarco_doc`` or the parsed
            document carries a different ``docid``.
    """
    (prefix, kind, bundlenum, position) = document_id.split('_')
    assert prefix == 'msmarco' and kind == 'doc'

    bundle_path = f'{msmarco_v2_doc_dir}/msmarco_doc_{bundlenum}'
    with open(bundle_path, 'rt', encoding='utf8') as bundle:
        bundle.seek(int(position))
        document = json.loads(bundle.readline())

    # Sanity check: the line found at that offset must be the requested document
    assert document['docid'] == document_id
    return document

# Fetch one known document and print each of its fields, separated by a dashed line
document = get_document('msmarco_doc_00_2823093')
field_separator = "-------------------------"

print(document.keys())
for label, key in (("Document ID", 'docid'),
                   ("URL", 'url'),
                   ("Title", 'title'),
                   ("Headings", 'headings'),
                   ("Body", 'body')):
    print(field_separator)
    print(F"{label}: {document[key]}")
print("## ---------------------------------- ##")

In [None]:
# Specify the Query_ID
target_query_id = 100015

# Keep only the rows of the merged frame that belong to the target query
matching_query = query_doc_df.loc[query_doc_df['Query_ID'] == target_query_id]

# Collect the Doc_IDs linked to that query
matched_doc_ids = matching_query['Doc_ID'].tolist()

# Print the results
print(F"Query_ID: {target_query_id}")
print(F"Matched Doc_IDs: {matched_doc_ids}")

# Assuming it has found multiple documents related to this query
field_separator = "-------------------------"
print("\n## ---------------------------------- ##")
for doc in matched_doc_ids:
  document = get_document(F'{doc}')
  # The ID and the URL are printed together, the remaining fields are separated by a dashed line
  print(F"Document ID: {document['docid']}")
  print(F"URL: {document['url']}")
  for label, key in (("Title", 'title'), ("Headings", 'headings'), ("Body", 'body')):
    print(field_separator)
    print(F"{label}: {document[key]}")
  print("## ---------------------------------- ##")

### Documents/Doc_ID Representation Engines
- For Documents, the only representation needed is on the Document itself. For this reason, we work on the "Body" of the Document.
- For Doc_IDs, the only representation needed is the DocID, hence solutions are developed only for it.

In [None]:
class DocRepresentationEngine():
    """Builds the textual representation of each document from its 'Body'.

    Supported strategies: 'direct_indexing', 'set_indexing', 'summarization'
    and 'default'. Every strategy method takes a DataFrame chunk with at least
    the columns 'Doc_ID' and 'Body' and returns a DataFrame indexed like the
    chunk, whose representation column is named after the strategy.
    `doc_count` holds the number of rows of the last frame produced.
    """

    def __init__(self, strategy):
        self.strategy = strategy
        self.columns = ['Doc_ID', self.strategy]
        self.doc_count = 0
        assert strategy in ['direct_indexing', 'set_indexing', 'summarization', 'default'], 'Strategy not recognized! Select the correct one (direct_indexing, set_indexing, summarization, default).'

        # Only set_indexing needs the NLTK English stopwords
        if strategy == 'set_indexing':
            import nltk
            nltk.download('stopwords')
            from nltk.corpus import stopwords
            self.stop_words = set(stopwords.words('english'))

    def forward(self, chunk):
        """Run the strategy chosen at construction time on `chunk`.

        'summarization' deliberately goes through `default` here: the actual
        summarization is triggered elsewhere (see DatasetDSI_DocQuery).
        """
        handlers = {
            'direct_indexing': self.direct_indexing,
            'set_indexing': self.set_indexing,
        }
        return handlers.get(self.strategy, self.default)(chunk)

    def _build_frame(self, records, columns):
        """Turn {index: row_dict} into a DataFrame and remember its length in `doc_count`."""
        frame = pd.DataFrame.from_dict(records, orient='index', columns=columns)
        self.doc_count = len(frame)
        return frame

    def default(self, chunk):
        """Default strategy: keep the whole body of every document."""
        records = {
            idx: {'Doc_ID': row['Doc_ID'], self.strategy: row['Body']}
            for idx, row in chunk.iterrows()
        }
        return self._build_frame(records, self.columns)

    def direct_indexing(self, chunk):
        """Direct Indexing: keep the first 32 space-separated tokens of the body."""
        max_tokens = 32
        records = {
            idx: {
                'Doc_ID': row['Doc_ID'],
                self.strategy: " ".join(row['Body'].split(" ")[:max_tokens]),
            }
            for idx, row in chunk.iterrows()
        }
        return self._build_frame(records, self.columns)

    def summarization(self, chunk):
        """Summarization: replace the body with a summary of at most `max_length` tokens.

        Bodies with fewer than `max_length` whitespace-separated words are kept
        as they are. The fifth column of `chunk` is carried over to the result.
        """
        from transformers import pipeline, T5ForConditionalGeneration

        model_id = "Falconsai/text_summarization"
        max_length = 32
        limit_length = 20000

        # Load the model in 4-bit for faster inference
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        model_4bit = T5ForConditionalGeneration.from_pretrained(
            model_id,
            device_map="auto",
            quantization_config=bnb_config,
        )
        model_4bit.eval() # Set the model to inference mode
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        hugging_pipe = pipeline(
            "summarization",
            model=model_4bit,
            tokenizer=tokenizer,
            use_cache=True,
            device_map="auto",
            max_length=max_length,
            min_length=16,
            do_sample=True,
            top_k=10,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

        carried_column = chunk.columns[4]
        records = {}
        for idx, row in tqdm(chunk.iterrows(), total=len(chunk), desc="Summarizing documents"):
            if len(row['Body'].split()) < max_length:
                # Already shorter than a summary would be: nothing to summarize
                summary = row['Body']
            else:
                # Cap the input at 20k characters: without the cap the pipeline
                # crashed on the document with index 60 (745k characters long)
                if len(row['Body']) > limit_length:
                    row['Body'] = row['Body'][:limit_length]
                with torch.no_grad():
                    summary = hugging_pipe(row['Body'])[0]['summary_text']

            records[idx] = {
                'Doc_ID': row['Doc_ID'],
                self.strategy: summary,
                carried_column: row[carried_column],
            }

        frame = self._build_frame(records, ['Doc_ID', self.strategy, carried_column])
        del bnb_config, model_4bit, tokenizer, hugging_pipe
        return frame

    def set_indexing(self, chunk):
        """Set Indexing: the set of body words minus the stopwords (duplicates collapse)."""
        records = {
            idx: {
                'Doc_ID': row['Doc_ID'],
                self.strategy: set(row['Body'].split()) - self.stop_words,
            }
            for idx, row in chunk.iterrows()
        }
        return self._build_frame(records, self.columns)

In [None]:
class DocIDRepresentationEngine():
    """Builds the identifier the model has to generate for each document.

    Supported strategies: 'unstructured_atomic', 'naively_structured' and
    'semantically_structured' (the latter is not implemented yet).

    Args:
        strategy: name of the DocID representation strategy.
        doc_count: number of documents, as computed by DocRepresentationEngine.
            It fixes the zero-padding width of the atomic ids and the upper
            bound of the random ids.
    """

    def __init__(self, strategy, doc_count):
        self.strategy = strategy
        self.used_doc_ids = set()
        self.columns = ['Doc_ID',self.strategy] # Do not confuse! Doc_ID is the MsMarco Original identifier, while self.strategy returns the doc_id generated by the algorithm
        self.doc_count = doc_count # This got passed from the DocRepresentationEngine
        assert strategy in ['unstructured_atomic', 'naively_structured', 'semantically_structured'], 'Strategy not recognized! Select the correct one (unstructured_atomic, naively_structured, semantically_structured).'

    def forward(self, chunk):
        """Run the strategy chosen at construction time on `chunk`.

        Raises:
            NotImplementedError: for 'semantically_structured', which is not
                wired in yet.
        """
        if self.strategy == 'unstructured_atomic':
            return self.unstructred_atomic(chunk)
        elif self.strategy == 'naively_structured':
            return self.naively_structured(chunk)
        else:
            raise NotImplementedError("To implement")

    def generate_unique_doc_id(self):
        """Draw a random integer in [0, doc_count] never returned before by this engine.

        Needed for the Naively Structured Identifiers.

        Raises:
            RuntimeError: when all doc_count + 1 identifiers are already used
                (the rejection loop below would otherwise never terminate).
        """
        if len(self.used_doc_ids) > self.doc_count:
            raise RuntimeError(
                f"All {self.doc_count + 1} document ids in [0, {self.doc_count}] are already in use."
            )

        doc_id = random.randint(0, self.doc_count) # From 0 to the max number of possible documents

        # If the doc_id is already used, generate a new one
        while doc_id in self.used_doc_ids:
            doc_id = random.randint(0, self.doc_count)

        self.used_doc_ids.add(doc_id)
        return doc_id

    def unstructred_atomic(self,chunk):
        """Unstructured Atomic Identifiers: one arbitrary unique identifier per document.

        The identifier is the row index of the chunk, zero-padded to the number
        of digits of `doc_count` (example: 0001728).
        """
        width = len(str(self.doc_count))
        full_dict = {
            index: {'Doc_ID': row['Doc_ID'], self.strategy: str(index).zfill(width)}
            for index, row in chunk.iterrows()
        }

        df = pd.DataFrame.from_dict(full_dict, orient='index', columns=self.columns)
        return df

    def naively_structured(self,chunk):
        """Naively Structured Identifiers: a random unique integer split into its digits."""
        full_dict = {}

        for index, row in chunk.iterrows():
            doc_id = self.generate_unique_doc_id() #                            | Example: 1728
            doc_id_str = str(doc_id)  # Convert it into a string                | Example: '1728'
            tokens = list(doc_id_str) # Get a tokenizable version of the Doc_ID | Example: ['1', '7', '2', '8']
            full_dict[index] = {'Doc_ID': row['Doc_ID'], self.strategy: tokens}

        # Passing the columns keeps the frame well-formed even for an empty chunk
        df = pd.DataFrame.from_dict(full_dict, orient='index', columns=self.columns)
        return df

    def cluster_documents(self, document_embeddings, number_of_clusters=10):
        """Split the embeddings into `number_of_clusters` k-means clusters.

        Needed for Semantically Structured Identifiers.

        Returns:
            A list with one entry per cluster, holding the embeddings assigned to it.
        """
        kmeans = KMeans(n_clusters=number_of_clusters, random_state=2047315) # For reproducibility
        clusters = kmeans.fit_predict(document_embeddings)
        clustered_documents = [document_embeddings[clusters == i] for i in range(number_of_clusters)]
        return clustered_documents

    def generate_semantic_ids(self, cluster, max_docs_per_cluster):
        """Recursive step of the semantic id generation. Not implemented yet."""
        raise NotImplementedError("To implement")

    # Needs to be properly tested!!
    def semantically_structured(self):
      """Draft of the Semantically Structured Identifiers algorithm (not usable yet).

      NOTE: `document_embeddings` is still a placeholder (None), so this method
      cannot run until the embeddings are actually computed.
      """
      number_of_clusters = 10
      max_docs_per_cluster = 100

      # The problem here is that we first need the embeddings of all documents.
      # That means running a "small 8-layer BERT model" over every body in the DataFrame
      document_embeddings = None

      clusters = self.cluster_documents(document_embeddings, number_of_clusters=number_of_clusters) # C_1:10 ← Cluster(X_1:N , k = 10)
      doc_ids = []                                                                                  # J ← empty list

      # As per pseudo-code
      for i in range(number_of_clusters):                                                           # for i = 0 to 9 do
          current_cluster = [str(i)] * len(clusters[i])                                             # J_current ← [i] ∗ |C_i+1|

          if len(clusters[i]) > max_docs_per_cluster:                                               # if |C_i+1| > c then
              rest_cluster = self.generate_semantic_ids(clusters[i], max_docs_per_cluster)              # J_rest ← GENERATESEMANTICIDS(C_i+1)
          else:                                                                                     # else
              rest_cluster = [str(j) for j in range(len(clusters[i]))]                                  # J_rest ← [0, . . . , |C_i+1| − 1]

          cluster_ids = [current + rest for current, rest in zip(current_cluster, rest_cluster)]    # J_cluster ←elementwiseStrConcat(J_current, J_rest)
          doc_ids.extend(cluster_ids)                                                               # J ← J.appendElements(Jcluster)

          # The reordering step is missing: J ← reorderToOriginal(J, X1:N , C1:10)
          # This most likely means the doc_ids must be reordered to follow the
          # order in which the initial embeddings arrived

      return doc_ids

### Datasets

This section is responsible for creating the DataLoader in a simple manner, ready to be used for training.<br>
The `DatasetDSI` class has to cover all the combinations of Document Representation and DocID Representation.<br>

In general, we identify a single sample as:<br>
`[ms_marco_original_doc_id, document_representation, doc_id_representation]`<br>

More on `DatasetDSI_DocQuery` in the `DataLoaders` Chapter.

In [None]:
class DatasetDSI(Dataset):
  '''
    Dataset pairing every document of one MSMarco chunk with its Document
    representation and its DocID representation.

    train_dataset_chunk: suffix of the MSMarco file to process (msmarco_doc_<chunk>)
    doc_rep_strategy: used to bootstrap the DocRepresentationEngine
    doc_ids_rep_strategy: used to bootstrap the DocIDRepresentationEngine
    MAX_DOCS: maximum number of documents read from the chunk
  '''
  def __init__(self, train_dataset_chunk='00', doc_rep_strategy='direct_indexing', doc_ids_rep_strategy="unstructured_atomic", MAX_DOCS = 1000):
    self.train_dataset_chunk = train_dataset_chunk
    self.MAX_DOCS = MAX_DOCS      # Max: 200000 | Default: 1000
    self.df = self.generate_df()  # At most MAX_DOCS documents taken from the chunk

    # Document representations, computed on the loaded frame
    self.doc_rep_engine = DocRepresentationEngine(doc_rep_strategy)
    self.docs_rep = self.doc_rep_engine.forward(self.df)

    # DocID representations; the engine needs the document count found above
    self.docID_rep_engine = DocIDRepresentationEngine(doc_ids_rep_strategy, self.doc_rep_engine.doc_count)
    self.docs_id_rep = self.docID_rep_engine.forward(self.df)
    print(f"-- Considering as dataset only the first {self.MAX_DOCS} rows")

  def generate_df(self):
    '''
      Read the chosen MSMarco chunk line by line (one JSON document per line)
      and return a DataFrame with the columns Doc_ID, Title and Body, indexed
      by the line number as a string. Reading stops after MAX_DOCS documents.
    '''
    chunk_path = f'{msmarco_v2_doc_dir}/msmarco_doc_{self.train_dataset_chunk}'
    rows = {}

    with open(chunk_path, 'rt', encoding='utf8') as in_fh:
      progress = tqdm(in_fh, desc=f"Generating chunks for file msmarco_doc_{self.train_dataset_chunk}")
      for i, line in enumerate(progress):
        # TODO // Remove this check if we want to load the whole file
        if i >= self.MAX_DOCS:
          break

        line_dict = json.loads(line)
        rows[str(i)] = {
            "Doc_ID":  line_dict['docid'],
            "Title": line_dict['title'],
            "Body":  line_dict['body'],
        }

    df = pd.DataFrame.from_dict(rows, orient='index', columns=["Doc_ID", "Title", "Body"])
    print("Chunk correctly generated. Processing DocIDs and Documents now...")
    return df

  def __len__(self):
      return len(self.df)

  def __getitem__(self, index):
    '''
      Called by the DataLoader to fetch one sample, returned as a dict with
      the original MSMarco Doc_ID, the Document representation and the DocID
      representation (the last two keyed by their strategy names).
    '''
    doc_key = self.doc_rep_engine.strategy
    doc_id_key = self.docID_rep_engine.strategy

    return {
        "Doc_ID": self.df.iloc[index]['Doc_ID'],
        doc_key: self.docs_rep.iloc[index][doc_key],
        doc_id_key: self.docs_id_rep.iloc[index][doc_id_key],
    }

  def collate_fn(self, batch):
    '''
      Collate function for the DataLoader: called after __getitem__, it turns
      a list of samples into one dict of lists with the same three keys.
    '''
    keys = ("Doc_ID", self.doc_rep_engine.strategy, self.docID_rep_engine.strategy)
    return {key: [item[key] for item in batch] for key in keys}

In [None]:
class DatasetDSI_DocQuery(Dataset):
  '''Dataset restricted to the documents that have at least one query (retrieval task).

  After construction ``self.filtered_docs_with_queries`` holds one row per
  document and ``self.columns`` holds its column names. ``__getitem__`` and
  ``collate_fn`` read ``self.columns[1]`` and ``self.columns[2]`` as the
  document representation and the document-id representation, and always
  expose the document id under the literal key "Doc_ID".
  '''
  def __init__(self, train_dataset: "DatasetDSI", query_doc_df: pd.DataFrame, strategy: str):
    """
    Args:
        train_dataset: dataset providing ``docs_rep``, ``docs_id_rep`` and ``df``,
            all of which must contain a 'Doc_ID' column (they are merged on it).
        query_doc_df: frame with (at least) a 'Doc_ID' column listing the
            documents that have queries.
        strategy: document representation strategy. The value 'summarization'
            replaces the merged frame with a pre-computed CSV downloaded
            from Google Drive.
    """
    self.train_dataset_docs_rep = train_dataset.docs_rep      # columns ['Doc_ID', 'Doc Repres. Strategy']
    self.train_dataset_docs_id = train_dataset.docs_id_rep    # columns ['Doc_ID', 'Doc ID Repres. Strategy']
    self.query_doc_df = query_doc_df # columns ['Query_ID', 'Query', 'Iteration', 'Doc_ID', 'Relevance']

    # Step 1: representations side by side -> ['Doc_ID', 'Doc Repres. Strategy', 'Doc ID Repres. Strategy']
    self.filtered_docs_with_queries = pd.merge(self.train_dataset_docs_rep, self.train_dataset_docs_id, on='Doc_ID')

    # Step 2: attach the raw document columns -> ['Doc_ID', 'Title', 'Body', 'Doc Repres. Strategy', 'Doc ID Repres. Strategy']
    self.filtered_docs_with_queries = pd.merge(train_dataset.df, self.filtered_docs_with_queries, on='Doc_ID')

    # Step 3: keep only the documents that appear in query_doc_df (i.e. have at least one query)
    self.filtered_docs_with_queries = self.filtered_docs_with_queries[self.filtered_docs_with_queries['Doc_ID'].isin(self.query_doc_df['Doc_ID'])]

    if strategy == 'summarization':
      # Summarizing the whole dataset would mean summarizing up to MAX_DOCS documents;
      # only the documents with queries are needed, so a pre-computed file is downloaded
      # and used INSTEAD of the frame built above.
      summarizations = 'https://drive.google.com/uc?id=' + '1S_M69b2SAN2WyvudEKgeHBkMcU2vlerv' + '&export=download&confirm=t'
      gdown.download(summarizations, "summarization.csv", quiet=False)

      self.filtered_docs_with_queries = pd.read_csv("summarization.csv", sep=",", dtype={
            'Doc_ID': 'string',
            'summarization': 'string',
            'unstructured_atomic': 'string',
      })
      # Repair rows through fix_row, then use the 'idx' column as the index.
      # NOTE(review): this requires the CSV to contain an 'idx' column - confirm against the file.
      self.filtered_docs_with_queries = self.filtered_docs_with_queries.apply(self.fix_row, column_name='idx', axis=1)
      self.filtered_docs_with_queries = self.filtered_docs_with_queries.set_index('idx')
    else:
      # The raw text columns are not needed once the representations are available
      self.filtered_docs_with_queries = self.filtered_docs_with_queries.drop(columns=['Title', 'Body'])

    self.columns = self.filtered_docs_with_queries.columns.values.tolist()

  def fix_row(self, row: pd.Series, column_name: str) -> pd.Series:
    """Re-split the value of ``row[column_name]`` on commas and spread it over the row.

    If the value (as a string) contains no comma, ``row`` is returned unchanged.
    Otherwise the value is split on ','. The first piece starting with a double
    quote is re-joined (with ',') with the following pieces up to and including
    the first later piece ending with a double quote; only that first quoted
    field is handled and the quote characters are kept. The resulting pieces
    are assigned to the row's keys in order (``zip`` stops at the shorter of
    the two, so surplus keys or pieces are dropped).

    Presumably this repairs rows whose fields all ended up in a single column
    when the CSV was read - TODO confirm against summarization.csv.
    """
    pieces = str(row[column_name]).split(',')
    if len(pieces) <= 1:
      return row

    # Position of the first piece that opens a quoted field, if any
    start = next((pos for pos, piece in enumerate(pieces) if piece.startswith('"')), None)
    if start is not None:
      # Position of the first later piece that closes it, if any
      end = next((pos for pos in range(start + 1, len(pieces)) if pieces[pos].endswith('"')), None)
      if end is not None:
        pieces = pieces[:start] + [','.join(pieces[start:end + 1])] + pieces[end + 1:]

    return pd.Series(dict(zip(row.keys(), pieces)))

  def __len__(self):
    """Return the number of documents that have at least one query."""
    return len(self.filtered_docs_with_queries)

  def __getitem__(self, index):
    """Return the sample at position ``index`` as a dict.

    The document id is always stored under "Doc_ID" (the key ``collate_fn``
    reads); the two representations are stored under their column names.
    """
    sample = self.filtered_docs_with_queries.iloc[index]

    return {
        "Doc_ID": sample['Doc_ID'],
        self.columns[1]: sample[self.columns[1]],
        self.columns[2]: sample[self.columns[2]],
    }

  def collate_fn(self, batch):
    '''
      Group a list of samples into a dict of lists (one list per field),
      preserving the batch order.
    '''
    doc_id = [item["Doc_ID"] for item in batch]
    doc_rep = [item[self.columns[1]] for item in batch]
    doc_id_rep = [item[self.columns[2]] for item in batch]

    return {
        "Doc_ID":  doc_id,
        self.columns[1]: doc_rep,
        self.columns[2]: doc_id_rep
    }

### DataLoaders
We have two main Datasets:
- `DatasetDSI`,  which contains the full set of documents (from the chunk chosen)
- `DatasetDSI_DocQuery`, which includes documents associated with queries for retrieval tasks.

In MSMarco, not all documents have a query linked to them, so we can't properly perform the retrieval task, hence the creation of the second Dataset.

When training, we want a dataset that combines the complete set of documents along with retrieval examples.<br>
To achieve this, the training set includes the full `DatasetDSI` plus a portion (80%) of the retrieval examples from `DatasetDSI_DocQuery`.<br>

Train Dataset:
  - Full `DatasetDSI`
  - 0.8 from `DatasetDSI_DocQuery`
  
Validation Dataset:
  - 0.1 from `DatasetDSI_DocQuery`
  
Test Dataset:
  - 0.1 from `DatasetDSI_DocQuery`

In this way, we ensure that the model is never trained on retrieval examples for the Validation and Test documents: retrieval training is limited to the 0.8 portion.

In [None]:
dsi_config = {
    # ---------------- GENERAL CONFIGURATIONS ----------------
    # Seed for reproducibility
    "seed": 42,
    # Max length used by the Tokenizer when tokenizing the input
    "MAX_LENGTH": 256,
    # Max number of documents to consider
    "MAX_DOCS": 10000,
    # Batch size | Default 4, otherwise it will crash
    "batch_size": 4,

    # ---------------- TRAINING CONFIGURATIONS ----------------
    # The single MSMarco file we want to process
    "train_dataset_chunk": "00",
    # Document representation strategy, one of:
    #   "direct_indexing", "set_indexing", "inverted_indexing", "summarization", "default"
    # NOTE: "summarization" only works with MAX_DOCS = 10000
    "document_representation_strategy": "direct_indexing",
    # Document ID representation strategy, one of:
    #   "unstructured_atomic", "naively_structured", "semantically_structured"
    "document_id_representation_strategy": "unstructured_atomic",
    # Training type: "discriminative" / "generative"
    "training_type": "generative",
    # If True, enables the multitask prompting ("Generate a doc_id etc..")
    "enable_multitask_prompting" : True,
    # Ratio between indexing and retrieval tasks | Default: 0.3 (70% for indexing, 30% for retrieval)
    "train_indexing_retrieval_ratio": 0.3,

    # ---------------- W&B CONFIGURATIONS ----------------
    # WandB configurations
    "wandb_configs": {"enable": False, "group_id": "v0.9.3"},
    # Student name
    "student": student,
}

In [None]:
# Read the strategies and the chunk to process from the configuration
doc_rep_strategy = dsi_config["document_representation_strategy"]
doc_ids_rep_strategy = dsi_config["document_id_representation_strategy"]
train_dataset_chunk = dsi_config["train_dataset_chunk"]

print(f"Document Representation Strategy: {doc_rep_strategy}")
print(f"Document ID Representation Strategy: {doc_ids_rep_strategy}")
print(f"Train Dataset Chunk: {train_dataset_chunk}")
print("-----------------------------------")

# Dataset with ALL the documents of the chosen chunk.
train_dataset = DatasetDSI(
    train_dataset_chunk=train_dataset_chunk,
    doc_rep_strategy=doc_rep_strategy,
    doc_ids_rep_strategy=doc_ids_rep_strategy,
    MAX_DOCS=dsi_config['MAX_DOCS'],
)

# Only the documents that have at least one query: these are the ones usable for the Retrieval Task.
train_docs_with_queries_dataset = DatasetDSI_DocQuery(train_dataset, query_doc_df, doc_rep_strategy)

# Number of labels, i.e. the number of rows in the DocID-representation column.
# Needed to size the Linear Layer of the model when the training type is "discriminative".
dsi_config['labels_count'] = len(
    train_docs_with_queries_dataset.filtered_docs_with_queries[doc_ids_rep_strategy]
)

Document Representation Strategy: direct_indexing
Document ID Representation Strategy: unstructured_atomic
Train Dataset Chunk: 00
-----------------------------------


Generating chunks for file msmarco_doc_00: 0it [00:00, ?it/s]

Generating chunks for file msmarco_doc_00: 200000it [00:22, 8707.35it/s] 


Chunk correctly generated. Processing DocIDs and Documents now...
-- Considering as dataset only the first 200000 rows


In [None]:
# -------------- RETRIEVAL SPLIT -------------- #
# Splits the documents-with-queries dataset into three DISJOINT subsets
# (train / validation / test) from a single seeded shuffle of the indices.
np.random.seed(dsi_config["seed"])

# Ratios of the retrieval documents assigned to each split
train_ratio = 0.8       # 80% of retrieval Docs for training
validation_ratio = 0.1  # 10% of retrieval Docs for validation
test_ratio = 0.1        # 10% of retrieval Docs for testing

# Number of samples per split (int() truncates, so a few samples may stay unassigned)
num_samples = len(train_docs_with_queries_dataset)
num_train_samples = int(train_ratio * num_samples)
num_val_samples = int(validation_ratio * num_samples)
num_test_samples = int(test_ratio * num_samples)

# Shuffled positions into the dataset
indices = np.arange(num_samples)
np.random.shuffle(indices)

# Consecutive, non-overlapping slices: [ train | validation | test | (leftover) ]
# The test slice must start AFTER train + validation; starting it at
# num_val_samples would take it from inside the training range.
val_start = num_train_samples
test_start = val_start + num_val_samples
train_indices = indices[:val_start]
val_indices = indices[val_start:test_start]
test_indices = indices[test_start:test_start + num_test_samples]

# Guard against leakage: no index may belong to more than one split
assert len(np.unique(np.concatenate([train_indices, val_indices, test_indices]))) == \
    len(train_indices) + len(val_indices) + len(test_indices), "train/val/test splits overlap"

# Subsets backed by the same dataset
train_subset_with_queries_filtered = Subset(train_docs_with_queries_dataset, train_indices)
val_subset = Subset(train_docs_with_queries_dataset, val_indices)
test_subset = Subset(train_docs_with_queries_dataset, test_indices)

In [None]:
# Number of workers for the dataloaders (0 = load in the main process)
#num_workers = int(os.cpu_count() / 2) if os.cpu_count() > 2 else 1 # Not setting maximum number of workers to avoid overloading the system
num_workers = 0

# Arguments shared by every dataloader built in this cell
loader_kwargs = dict(
    batch_size=dsi_config['batch_size'],
    collate_fn=train_docs_with_queries_dataset.collate_fn,
    num_workers=num_workers,
)

# Seeded permutation of the training indices.
# np.random.permutation works on a COPY, so `train_indices` (a view into `indices`)
# is left untouched and re-running this cell gives the same result every time.
np.random.seed(dsi_config["seed"])
shuffled_train_indices = np.random.permutation(train_indices)

# Full retrieval training split. It is built from the indices (and not from
# `train_subset_with_queries_filtered`, which is rebound below to the reduced subset)
# so that its length does not shrink when the cell is executed more than once.
train_subset_full = Subset(train_docs_with_queries_dataset, shuffled_train_indices)
train_dataloader = DataLoader(train_subset_full, shuffle=True, **loader_kwargs)

# Keep only a fraction of the training batches for the Retrieval Task
# [---------------------------------|---]
# [                I                | R ]
batches_to_keep = int(len(train_dataloader) * dsi_config["train_indexing_retrieval_ratio"])
train_subset_with_queries_filtered = Subset(train_docs_with_queries_dataset, shuffled_train_indices[:batches_to_keep * dsi_config['batch_size']]) # This will be passed to the model
train_dataloader_filtered = DataLoader(train_subset_with_queries_filtered, shuffle=True, **loader_kwargs)

# Validation / test are never shuffled
val_dataloader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
test_dataloader = DataLoader(test_subset, shuffle=False, **loader_kwargs)

print(f"Train Dataloader length: {len(train_dataloader)}")
print(f"Validation Dataloader length: {len(val_dataloader)}")
print(f"Test Dataloader length: {len(test_dataloader)}")
print("------------------------------------")
print(f"Train Docs with Queries Dataloader length: {len(train_dataloader_filtered)}")

# Update some configurations
dsi_config["train_begin_retrieval"] = len(train_dataloader) - len(train_dataloader_filtered)    # The index where the retrieval task begins | Dynamic
dsi_config["token_tokenization_max_length"] = train_dataset.docID_rep_engine.doc_count          # The max length for the tokenization. Used by tokenizing the labels and the generate() function

if dsi_config["enable_multitask_prompting"]:
    # Both prompts state how many digits the identifier has (number of digits of MAX_DOCS)
    dsi_config["multitask_prompting_indexing"] = f"Generate a document identifier with {len(str(train_dataset.MAX_DOCS))} digits between 0 and 9 for the following document:"
    dsi_config["multitask_prompting_retrieval"] = f"Generate a document identifier with {len(str(train_dataset.MAX_DOCS))} digits between 0 and 9 as response to the following query:"

Train Dataloader length: 117
Validation Dataloader length: 15
Test Dataloader length: 15
------------------------------------
Train Docs with Queries Dataloader length: 35


## Model
- Quantization: https://huggingface.co/docs/transformers/quantization#8-bit <br>
- Notebook on FineTuning using Lora & BitsAndBytes: https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb

BitsAndBytes is needed to load up large models into small RAM systems through model quantization to 8 or 4-bits precision.<br>
However, after a model is quantized it isn’t typically further trained for downstream tasks because training can be unstable due to the lower precision of the weights and activations.<br>
But since PEFT methods only add extra trainable parameters, this allows you to train a quantized model with a PEFT adapter on top!<br>
Combining quantization with PEFT can be a good strategy for training even the largest models on a single GPU. For example, QLoRA is a method that quantizes a model to 4/8-bits and then trains it with LoRA. <br>

- LoRa = The base model is kept frozen in memory in 32- or 16-bit precision, and only the small low-rank adapter weights added on top of it are trained.
- QLoRa = Apply LoRa to a quantized model (like a 4-bit model)

The various models tested can be found in:
- SwitchTransformers = https://huggingface.co/collections/google/switch-transformers-release-6548c35c6507968374b56d1f
- Flan-T5 = https://huggingface.co/collections/google/flan-t5-release-65005c39e3201fff885e22fb

| Model Name               | All Parameters | 8-Bit Parameters | QLoRa Trainable Params (Percentage) | Hidden Size Encoder  | Target modules                        |
|--------------------------|----------------|------------------|-------------------------------------|----------------------|---------------------------------------|
| 'google/switch-base-8'   | 619M           | 24.722.688	   | 1,327,104 (0.0214%)                 | 768                  | target_modules=["q", "k", "v"]        |
| 'flan-t5-base'           | 248M           | /                | 1,327,104 (0.0533%)         	     | 768                  | target_modules=["q", "k", "v"]        |
| 'bert-base-uncased'      | 139M           | /                | 1,327,104 (0.0929%)         	     | 768                  | target_modules=["q", "k", "v"]        |

### Foundation model

In [None]:
class FoundationModel(pl.LightningModule):
    """Wrapper that loads the backbone model and its tokenizer.

    Two loading paths exist:
      - ``AutoModelForSeq2SeqLM`` when ``model_id`` is not 'bert-base-uncased'
        AND both ``encoder_id`` and ``decoder_id`` are ``None``;
      - ``EncoderDecoderModel`` (built from ``encoder_id`` / ``decoder_id``) otherwise.

    Two training modes exist:
      - ``fine_tuning=False``: the model is loaded 8-bit quantized and wrapped
        with LoRA adapters (PEFT);
      - ``fine_tuning=True``: no quantization and no LoRA; only parameters whose
        name matches ``encoder.block.N`` / ``decoder.block.N`` with N >= 8 are
        left trainable.

    Attributes set: ``MAX_LENGTH``, ``model``, ``tokenizer``, ``is_encoder_decoder``
    (True only on the ``EncoderDecoderModel`` path).
    """
    def __init__(self, model_id="google/flan-t5-base", encoder_id='bert-base-uncased', decoder_id='bert-base-uncased', MAX_LENGTH=512, fine_tuning=False):
        super().__init__()
        # Misc variables
        self.MAX_LENGTH = MAX_LENGTH

        if not fine_tuning:
            # 8-bit quantization through Bits and Bytes.
            # Double quantization, the "nf4" quant type and the compute dtype are options of
            # the 4-bit mode only (`bnb_4bit_*`); there are no `bnb_8bit_*` equivalents, so
            # they are not passed here.
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=True,                      # Quantize the model to 8-bits when you load it
            )

            # LoRA configuration (PEFT): only the adapter weights are trained.
            lora_config = LoraConfig(
                task_type=TaskType.SEQ_2_SEQ_LM,        # We are using the model for a seq2seq task
                inference_mode=False,
                r=8,                                    # Lora attention dimension
                lora_alpha=32,                          # The alpha parameter for Lora scaling
                lora_dropout=0.1,                       # The dropout probability for Lora layers
                target_modules=["q", "k", "v"],         # The names of the modules to apply Lora to | Can also be ["q", "k", "v", "o", "wi", "wo"] if we want to include weights_in / weights_out of FFN
            )
        else:
            quantization_config = None
            lora_config = None

        # Load the model and tokenizer
        if (model_id != 'bert-base-uncased') and (encoder_id is None) and (decoder_id is None):
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map="auto", use_cache=False, quantization_config=quantization_config)
            self.tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=self.MAX_LENGTH) # Set reasonable default for models without max length
            self.is_encoder_decoder = False
            print("Using AutoModelForSeq2SeqLM")
        else:
            self.model = EncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, quantization_config=quantization_config, tie_encoder_decoder=True)
            self.tokenizer = AutoTokenizer.from_pretrained(encoder_id, model_max_length=self.MAX_LENGTH) # Set reasonable default for models without max length
            self.is_encoder_decoder = True
            # Configurations of the EncoderDecoderModel: special tokens come from the tokenizer
            self.model.config.decoder_start_token_id = self.tokenizer.cls_token_id
            self.model.config.eos_token_id = self.tokenizer.sep_token_id
            self.model.config.pad_token_id = self.tokenizer.pad_token_id
            self.model.config.vocab_size = self.model.config.encoder.vocab_size
            # The attention projections have different names on this path.
            # lora_config is None when fine_tuning=True, so it must be guarded.
            if lora_config is not None:
                lora_config.target_modules = {'query', 'key', 'value'}
            print("Using EncoderDecoderModel, hence bert-base-uncased.")

        if fine_tuning:
            # Leave trainable only the encoder/decoder blocks with index >= 8; freeze everything else.
            # NOTE(review): the name pattern "encoder.block.N" / "decoder.block.N" is matched literally;
            # confirm it matches the parameter names of the architecture actually loaded.
            for name, param in self.model.named_parameters():
                if name.startswith("encoder.block") and int(name.split(".")[2]) >= 8:
                    param.requires_grad = True
                elif name.startswith("decoder.block") and int(name.split(".")[2]) >= 8:
                    param.requires_grad = True
                else:
                    param.requires_grad = False

        # Preprocess the quantized model for training.
        if quantization_config is not None:
            self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False)

        # Should fix: UserWarning: None of the inputs have requires_grad=True. Gradients will be None.
        # Should fix: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
        if hasattr(self.model, "enable_input_require_grads"):
            self.model.enable_input_require_grads()
        else:
            # Fallback: force the output of the input embeddings to require gradients
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            self.model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if lora_config is not None:
            self.model = get_peft_model(self.model, lora_config)
            print("Encoder info:")
            self.model.print_trainable_parameters()
        else:
            print(f"Encoder info: {self.get_n_trainable_parameters()} trainable parameters")

    def forward(self, input):
        """Intentionally a no-op: the wrapped ``self.model`` is called directly by the DSI model."""
        pass

    def get_n_trainable_parameters(self):
        """Return the total number of elements of the parameters that require gradients."""
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

### Custom Encoder-Decoder
Here we use SwitchTransformer as an Encoder and FlanT5 as Decoder, both instantiated with QLoRa.<br>

The mechanism follows this scheme:
- Forward pass in Encoder for document representation
- Forward pass in Decoder with loss computation
- Backpropagation of the loss

This is done for both indexing and retrieval tasks, but in the latter, the Query is used for the Forward Pass of the Encoder.

#### Encoder

In [None]:
class SwitchTransformer(pl.LightningModule):
    """Encoder of the custom Encoder-Decoder: a Switch Transformer loaded in 8-bit with LoRA adapters.

    Attributes set: ``MAX_LENGTH``, ``tokenizer``, ``model`` (the PEFT-wrapped model).
    """
    def __init__(self, model_id="google/switch-base-8", MAX_LENGTH=512):
        super().__init__()
        # Misc variables
        self.MAX_LENGTH = MAX_LENGTH

        # 8-bit quantization through Bits and Bytes.
        # Double quantization, the "nf4" quant type and the compute dtype are options of
        # the 4-bit mode only (`bnb_4bit_*`); there are no `bnb_8bit_*` equivalents, so
        # they are not passed here.
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,                      # Quantize the model to 8-bits when you load it
        )

        # LoRA configuration (PEFT): only the adapter weights are trained.
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,        # We are using the model for a seq2seq task
            inference_mode=False,
            r=8,                                    # Lora attention dimension
            lora_alpha=32,                          # The alpha parameter for Lora scaling
            lora_dropout=0.1,                       # The dropout probability for Lora layers
            target_modules=["q", "k", "v"],         # The names of the modules to apply Lora to | Can also be ["q", "k", "v", "o", "wi", "wo"] if we want to include weights_in / weights_out of FFN
        )

        # Load up Switch Transformer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=self.MAX_LENGTH) # Set reasonable default for models without max length
        self.model = SwitchTransformersForConditionalGeneration.from_pretrained(model_id, device_map="auto", use_cache=False, quantization_config=quantization_config)

        # Preprocess the quantized model for training.
        self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False)

        # Should fix: UserWarning: None of the inputs have requires_grad=True. Gradients will be None.
        # Should fix: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
        if hasattr(self.model, "enable_input_require_grads"):
            self.model.enable_input_require_grads()
        else:
            # Fallback: force the output of the input embeddings to require gradients
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            self.model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        self.model = get_peft_model(self.model, lora_config)
        print("Encoder info:")
        self.model.print_trainable_parameters()

    def forward(self, input):
        """Tokenize ``input`` and run it through the model, keeping the encoder states.

        Args:
            input: Documents (or Queries) to tokenize; passed as-is to the tokenizer.

        Returns:
            A 3-tuple ``((last_hidden_states, encoder_hidden_states), docs_input, outputs)``:
              - ``last_hidden_states``: ``outputs.encoder_last_hidden_state``
              - ``encoder_hidden_states``: ``outputs.encoder_hidden_states``
              - ``docs_input``: the tokenizer output, moved to ``device``
              - ``outputs``: the raw model output
        """
        # Padded to the tokenizer max length -> presumably [batch_size, MAX_LENGTH]
        docs_input = self.tokenizer(input, add_special_tokens=False, return_tensors='pt', padding='max_length', truncation=True)

        # `.to()` returns a NEW object (it does not move tensors in place), so the result must be kept
        docs_input = docs_input.to(device)

        # The whole model is called (with all-zero decoder inputs), but only the
        # encoder hidden states are of interest here
        outputs = self.model(**docs_input, decoder_input_ids=torch.zeros_like(docs_input["input_ids"], device=device), output_hidden_states=True)

        last_hidden_states = outputs.encoder_last_hidden_state
        encoder_hidden_states = outputs.encoder_hidden_states

        return (last_hidden_states, encoder_hidden_states), docs_input, outputs

    def get_n_trainable_parameters(self):
        """Return the total number of elements of the parameters that require gradients."""
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

#### Decoder

In [None]:
class FlanT5(pl.LightningModule):
    """Decoder of the custom Encoder-Decoder: a Flan-T5 model loaded in 8-bit with LoRA adapters.

    Attributes set: ``MAX_LENGTH``, ``tokenizer``, ``model`` (the PEFT-wrapped model).
    """
    def __init__(self, model_id="google/flan-t5-base", MAX_LENGTH=512):
        super().__init__()
        # Misc variables
        self.MAX_LENGTH = MAX_LENGTH

        # 8-bit quantization through Bits and Bytes.
        # Double quantization, the "nf4" quant type and the compute dtype are options of
        # the 4-bit mode only (`bnb_4bit_*`); there are no `bnb_8bit_*` equivalents, so
        # they are not passed here.
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,                      # Quantize the model to 8-bits when you load it
        )

        # LoRA configuration (PEFT): only the adapter weights are trained.
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_2_SEQ_LM,        # We are using the model for a seq2seq task
            inference_mode=False,
            r=8,                                    # Lora attention dimension
            lora_alpha=32,                          # The alpha parameter for Lora scaling
            lora_dropout=0.1,                       # The dropout probability for Lora layers
            target_modules=["q", "k", "v"],         # The names of the modules to apply Lora to | Can also be ["q", "k", "v", "o", "wi", "wo"] if we want to include weights_in / weights_out of FFN
        )

        # Standard call for initializing FlanT5
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, model_max_length=self.MAX_LENGTH)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_id, device_map="auto", use_cache=False, quantization_config=quantization_config)

        # Preprocess the quantized model for training.
        self.model = prepare_model_for_kbit_training(self.model, use_gradient_checkpointing=False)

        # Should fix: UserWarning: None of the inputs have requires_grad=True. Gradients will be None.
        # Should fix: `use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
        if hasattr(self.model, "enable_input_require_grads"):
            self.model.enable_input_require_grads()
        else:
            # Fallback: force the output of the input embeddings to require gradients
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            self.model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        self.model = get_peft_model(self.model, lora_config)
        print("Decoder info:")
        self.model.print_trainable_parameters()

    def forward(self, embeddings, decoder_input_ids, decoder_attention_mask, labels):
        """Run the decoder on top of encoder states computed elsewhere.

        Args:
            embeddings: encoder outputs coming from the Encoder module; forwarded as
                ``encoder_outputs`` so the model's own encoder is skipped.
            decoder_input_ids: token ids fed to the decoder.
            decoder_attention_mask: attention mask matching ``decoder_input_ids``.
            labels: tokenized doc ids used to compute the loss (labels set to -100 are ignored).

        Returns:
            The output of the wrapped model call.
        """
        # https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L1709 | We can skip directly to Decoder by passing encoder_outputs.
        # encoder_outputs is a tuple composed of: (last_hidden_state, optional: hidden_states, optional: attentions)
        #     -> last_hidden_state : Shape (batch_size, sequence_length, hidden_size)
        #        is a sequence of hidden states at the output of the last layer of the encoder.
        #        Used in the cross-attention of the decoder.
        return self.model(
            decoder_input_ids=decoder_input_ids,              # Indices of decoder input sequence tokens in the vocabulary
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=embeddings,                       # Encoder states from the Encoder module (used in place of running this model's encoder)
            labels=labels                                     # All labels set to -100 are ignored (masked); the loss is only computed for labels in [0, ..., model.vocab_size]
        )

    def get_n_trainable_parameters(self):
        """Return the total number of elements of the parameters that require gradients."""
        return sum(p.numel() for p in self.model.parameters() if p.requires_grad)

### Configs & WandB support

In [None]:
# proj_dict[student] is indexed as:
#   [0] -> Project directory
#   [1] -> WandB key
# Log in to W&B only when logging is enabled in the config and the student is not 'Professor'.
if dsi_config["wandb_configs"]["enable"] and student != 'Professor':
    wandb.login(key=proj_dict[student][1])

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mslimshadys[0m ([33msapienza_ml_2022_23[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\Gianmarco\.netrc


In [None]:
"""
This method edits a configuration parameter given a run_id.

:param run_id: It's the ID of the run (13qvc9uo, 83cvs1ff, etc.)
:param project_name: The name of the WANDB project
:param parameter: The name of the parameter to change
:param new_value: The value to give to the parameter passed above

:return: Edit the needed parameter and returns

@ Gianmarco Scarano
"""
def configChanger(run_id='None', project_name='None', parameter='None', new_value = 'None'):
    """
    Overwrite a single config entry of an existing WandB run.

    :param run_id: ID of the run to edit
    :param project_name: Name of the WandB project the run belongs to
    :param parameter: Config key to overwrite
    :param new_value: Value stored under that key

    :raises NotImplementedError: if any argument is left at its 'None' string sentinel
    :return: None
    """
    # The string 'None' (not the None object) marks an argument as "not provided"
    for supplied in (run_id, project_name, parameter, new_value):
        if supplied == 'None':
            raise NotImplementedError("Check parameters! You should provide:\n- run_ID\n- Project name\n- Parameter\n- New Value")

    # Fetch the run through the public API using the "<project>/<run_id>" path
    target_run = wandb.Api().run(f"{project_name}/{run_id}")

    # Edit the requested key and push the change
    target_run.config[parameter] = new_value
    target_run.update()

    wandb.finish()
    return

# configChanger(run_id='8gxx5lpy', project_name='DeepLearning-DSI', parameter='model_name', new_value='google/switch-base-8')

## Training

### Generative Approach

#### DSI Model with Foundation Model

In [None]:
class DSI_Model(pl.LightningModule):
    '''
      -- DSI Model --

      LightningModule that wraps a foundation model and runs it on two tasks:
        - "indexing":  document text -> DocID label
        - "retrieval": query text    -> DocID label

      Parameters:
        - config: The configuration dictionary populated throughout the whole Code
        - model: The model to use (populated when instantiating the Foundation Model).
            The attributes `.model`, `.tokenizer`, `.is_encoder_decoder` and the method
            `.get_n_trainable_parameters()` of this object are used below.
        - query_doc_df: The dataframe containing the queries
            (columns 'Doc_ID', 'Query_ID' and 'Query' are read).
            This is needed for retrieving DocIDs during the retrieval task
        - training_dataloader_with_queries_filtered: The dataloader containing the filtered documents with queries
            Due to Lightning's limitations, we need to pass the dataloader to the model in order to be able to iterate over it

      Notebook-level globals read by this class: `device`, `train_dataset`,
      `train_docs_with_queries_dataset`.
    '''
    def __init__(self,
                 config: dict,
                 model: pl.LightningModule,
                 query_doc_df: pd.DataFrame,
                 training_dataloader_with_queries_filtered: DataLoader,
                 ):
        super().__init__()

        # Variables
        self.config = config
        self.MAX_LENGTH = config["MAX_LENGTH"]

        self.token_tokenization_max_length = config['token_tokenization_max_length'] # Define the label max_length of token for the tokenization
        self.train_indexing_retrieval_ratio = config["train_indexing_retrieval_ratio"] # Docs with Query for Training
        self.train_begin_retrieval = config["train_begin_retrieval"] # At which index the retrieval task begins

        self.enable_multitask_prompting = config["enable_multitask_prompting"]
        if self.enable_multitask_prompting:
            self.multitask_prompting_indexing = config["multitask_prompting_indexing"]
            self.multitask_prompting_retrieval = config["multitask_prompting_retrieval"]

        # Foundation Model
        self.model = model.to(device)

        # Dataloader
        # We keep a reference to the original dataloader so that a fresh iterator
        # can be created whenever the current one is exhausted (end of epoch, or
        # earlier if it is shorter than the number of retrieval steps).
        self.copy_training_dataloader = training_dataloader_with_queries_filtered
        self.training_dataloader_with_queries_filtered = iter(training_dataloader_with_queries_filtered)

        # Log losses
        self.training_indexing_losses = []
        self.training_retrieval_losses = []
        self.validation_retrieval_losses = []

        self.training_loss = []
        self.validation_loss = []
        self.test_accuracy = []

        self.best_training_loss = float('inf')
        self.best_validation_loss = float('inf')
        self.best_test_accuracy = 0.0

        # WandB
        self.enable_wandb = config["wandb_configs"]["enable"]
        if self.enable_wandb:
            self.wandb_run = wandb.init(
                project='DeepLearning-DSI',
                group=config["wandb_configs"]["group_id"],
                config=config,)
        else:
            self.wandb_run = None

        # Dataframes
        self.query_doc_df = query_doc_df  # Query - Doc ID Dataframe
        self.top_100_df = None            # Populated only if the test_type is "top_k"

        # The current task
        self.current_task = "indexing" # "indexing" / "retrieval"
        self.test_type = "accuracy" # "accuracy" / "top_k"

        # Create a Dict DocID2Index and Index2DocID
        # Useful when we need to retrieve the DocID from the index and vice versa (Test phase)
        self.docID2label = {}
        self.label2DocID = {}
        label_column = config['document_id_representation_strategy']
        filtered_docs = train_docs_with_queries_dataset.filtered_docs_with_queries
        for _, row in filtered_docs.iterrows():
            label = str(row[label_column])
            self.docID2label[row['Doc_ID']] = label
            self.label2DocID[label] = row['Doc_ID']

        # Number of generated labels that mapped to a known DocID during the test phase.
        # Initialised here so the counter always exists; it can still be reset to 0
        # from outside before starting a new test run.
        self.model_found_label = 0

        # Generation variables
        self.num_return_sequences = 5
        self.num_beams = 5

    def generate(self, docs_input, labels_att_mask):
        '''
          Generate method:
            - docs_input: The tokenized input for the model ('input_ids' and 'attention_mask' are read)
            - labels_att_mask: The attention mask for the labels (currently unused; kept so callers do not change)

          Returns the list of decoded strings produced by `batch_decode`. With
          num_return_sequences > 1 the sequences generated for one input are expected
          to be consecutive in that list (see `_group_predictions`).
        '''
        docs_input_ids = docs_input['input_ids'].to(device)
        docs_input_att_mask = docs_input["attention_mask"].to(device)

        # NOTE: generation could be constrained to start with specific tokens by passing
        # `force_words_ids=self.model.tokenizer(["0"], add_special_tokens=False).input_ids`;
        # this is currently disabled.

        # If we are using the BertUncased Model, we need to pass the decoder_start_token_id
        with torch.no_grad():
            generated_ids = self.model.model.generate(
                inputs=docs_input_ids, # [bs, 256]
                attention_mask=docs_input_att_mask, # [bs, 256]
                # decoder_attention_mask=labels_att_mask, # [bs, 4]
                max_length=len(str(self.token_tokenization_max_length)),
                decoder_start_token_id = self.model.tokenizer.cls_token_id if self.model.is_encoder_decoder else None,
                num_beams=self.num_beams,
                num_return_sequences=self.num_return_sequences,
                do_sample=True,
                top_k=50,
                temperature=0.6,
            )

        return self.model.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

    def forward(self, batch, training=True):
        '''
          The forward method of the DSI model that differentiate slightly the procedure to
          follow according to the task that has to be performed (indexing/retrieval).

            - batch: dict-like batch. For indexing the document representation column is read;
                     for retrieval the 'Query' entry (filled by `prepare_sample_for_retrieval`) is read.
            - training: if True the wrapped model is called with labels (output carries `.loss`);
                        otherwise `generate` is used and decoded strings are returned.
        '''
        torch.cuda.empty_cache() # Clear CUDA cache before forward pass

        if self.current_task == "indexing": # Indexing Task
            # T5 Prompt Setup as suggested in the paper
            if self.enable_multitask_prompting:
                texts = [f"{self.multitask_prompting_indexing} {item}" for item in batch[train_dataset.doc_rep_engine.strategy]]
            else:
                texts = batch[train_dataset.doc_rep_engine.strategy]
        else:
            # Retrieval Task. The process is the exact same as the indexing task,
            # but we need to consider the queries associated to the documents, not the documents themselves.
            # Only the first query of each document is used.
            texts = [str(item[0]) for item in batch["Query"]]
            if self.enable_multitask_prompting:
                texts = [f"{self.multitask_prompting_retrieval} {query}" for query in texts]

        # Labels are the Doc_ID(s) processed with the strategy chosen
        labels = [str(item) for item in batch[train_dataset.docID_rep_engine.strategy]]

        text_inputs = self.model.tokenizer(texts, add_special_tokens=False, return_tensors='pt', max_length=self.MAX_LENGTH, padding="max_length", truncation=True)
        labels_inputs = self.model.tokenizer(labels, add_special_tokens=False, return_tensors='pt', max_length=len(str(self.token_tokenization_max_length)), padding="max_length", truncation=True)

        labels_input_ids = labels_inputs["input_ids"].to(device)
        labels_att_mask = labels_inputs['attention_mask'].to(device)

        # Padding positions are set to -100 so that they are ignored by the loss
        labels_input_ids[labels_input_ids == self.model.tokenizer.pad_token_id] = -100

        if not training: # For validation and test -> decoded generations
            return self.generate(text_inputs, labels_att_mask)

        # For training -> Includes loss computation as well
        input_ids = text_inputs['input_ids'].to(device)
        attention_mask = text_inputs['attention_mask'].to(device)
        return self.model.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=None,
            decoder_attention_mask=labels_att_mask,
            labels=labels_input_ids
        )

    def prepare_sample_for_retrieval(self, batch):
        '''
          A retrival batch is created starting from a batch considering the questions associated to doc
          in the current batch.

          Adds 'Queries_ID' and 'Query' to `batch` (one array per document, in the order of
          batch['Doc_ID']) and returns the same batch object.
          Raises Exception if a document has no query in `query_doc_df`.
        '''
        batch['Queries_ID'] = []
        batch['Query'] = []

        for doc_id in batch['Doc_ID']:
            # Filter the dataframe once per document and read both columns from the result
            doc_rows = self.query_doc_df[self.query_doc_df['Doc_ID'] == doc_id]
            query_ids = doc_rows['Query_ID'].values

            if len(query_ids) == 0:
                raise Exception(f"Document {doc_id} without queries!")

            batch['Queries_ID'].append(query_ids)
            batch['Query'].append(doc_rows['Query'].values)

        return batch

    def _next_retrieval_batch(self):
        '''Return the next batch of the filtered dataloader, restarting its iterator when exhausted.'''
        try:
            return next(self.training_dataloader_with_queries_filtered)
        except StopIteration:
            # The filtered dataloader can be shorter than the number of retrieval
            # steps in one epoch: start over instead of letting StopIteration escape.
            self.training_dataloader_with_queries_filtered = iter(self.copy_training_dataloader)
            return next(self.training_dataloader_with_queries_filtered)

    def training_step(self, batch, batch_idx):
        '''
          Indexing step on `batch`, followed (from batch index `train_begin_retrieval` on)
          by a retrieval step on a batch drawn from the filtered dataloader.
          Returns the indexing loss, or the mean of indexing and retrieval loss.
        '''
        # Indexing step
        self.current_task = 'indexing'

        output = self.forward(batch, training=True)
        loss_idx = output.loss
        self.training_indexing_losses.append(loss_idx.item())

        total_step_loss = loss_idx

        # Switch to retrieval task if the batch index is greater than the train_begin_retrieval value
        # This is done to alternate between indexing and retrieval tasks as suggested in the paper
        if batch_idx >= self.train_begin_retrieval:
            self.current_task = 'retrieval'

            # If we are in the retrieval task, we need to prepare the sample for retrieval
            # Hence, we need to consider the queries associated to the documents of the retrieval batch
            retrieval_batch = self.prepare_sample_for_retrieval(self._next_retrieval_batch())

            output = self.forward(retrieval_batch, training=True)
            loss_retrieval = output.loss
            self.training_retrieval_losses.append(loss_retrieval.item())

            total_step_loss = (total_step_loss + loss_retrieval) / 2

        step_loss_value = total_step_loss.item()
        self.log("training_loss", step_loss_value)
        self.training_loss.append(step_loss_value)

        return total_step_loss

    # Validation => Retrieval step for all the documents chosen to be in validation split
    def validation_step(self, batch, batch_idx):
        '''Compute and log the retrieval loss on a validation batch.'''
        self.current_task = 'retrieval'

        # Retrieval step (same as retrieval step in training phase)
        batch = self.prepare_sample_for_retrieval(batch)
        bs = len(batch['Doc_ID'])

        # training=True here only selects the loss-computing path of forward()
        output = self.forward(batch, training=True)
        loss = output.loss
        loss_value = loss.item()

        self.log("validation_loss", loss_value, batch_size=bs)
        self.validation_loss.append(loss_value)
        self.validation_retrieval_losses.append(loss_value)

        return loss

    # Test => Retrieval step for all the documents chosen to be in test split
    def test_step(self, batch, batch_idx):
        '''Generate DocID labels for a test batch and score them according to `self.test_type`.'''
        self.current_task = 'retrieval'

        # Retrieval step (same as retrieval step in training/validation phase)
        batch = self.prepare_sample_for_retrieval(batch)
        bs = len(batch['Doc_ID'])

        output = self.forward(batch, training=False)

        if self.test_type == "accuracy":
            accuracy = self.run_accuracy(output, batch)
        else:
            accuracy = self.run_top_k(output, batch)

        self.log("test_accuracy", accuracy, batch_size=bs)
        self.test_accuracy.append(accuracy)

        return accuracy

    def _lookup_doc_id(self, prediction):
        '''
          Map a generated label to its DocID. Returns '' when the label is unknown.
          Every successful lookup increments `self.model_found_label`.
        '''
        # Only a missing key means "unknown label"; any other error must surface.
        try:
            doc_id = self.label2DocID[prediction]
        except KeyError:
            return ''
        self.model_found_label += 1
        return doc_id

    def _group_predictions(self, output):
        '''
          Split the flat list of generations into one list per query:
          chunks of `num_return_sequences` items, or single-item lists when it is <= 1.
        '''
        group_size = self.num_return_sequences
        if group_size > 1:
            return [output[start:start + group_size] for start in range(0, len(output), group_size)]
        return [[prediction] for prediction in output]

    def run_top_k(self, output, batch):
        '''
          Checks if the output from the model are in the top-k relevant document of the queries from the batch.

          For every query, each generated DocID found among the 'Doc_ID' values that
          `self.top_100_df` lists for the query's first Query_ID counts as one hit.
          Returns hits / number of queries.
          NOTE(review): with several returned sequences per query the value can exceed 1.0,
          since every matching sequence is counted - confirm this is the intended metric.
        '''
        batch_queries = batch['Queries_ID']
        grouped_predictions = self._group_predictions(output)
        if len(grouped_predictions) == 0:
            return 0.0

        correct = 0
        for idx, predictions in enumerate(grouped_predictions):
            predicted_doc_ids = [self._lookup_doc_id(prediction) for prediction in predictions]

            # Relevant documents are read straight from the top-100 dataframe; a query
            # that is absent from it simply yields an empty set (no concat needed).
            query_id = batch_queries[idx][0]
            relevant_docs = set(self.top_100_df.loc[self.top_100_df['Query_ID'] == query_id, 'Doc_ID'].tolist())

            correct += sum(1 for doc_id in predicted_doc_ids if doc_id in relevant_docs)

        return float(correct / len(grouped_predictions))

    def run_accuracy(self, output, batch):
        '''
          Checks if the output from the model is equal to the label present in the batch.

          Returns hits / number of queries, where a hit is a generated label whose DocID
          equals str(batch['Doc_ID'][i]) for the corresponding query i.
          NOTE(review): with several returned sequences per query, repeated correct
          generations are each counted - confirm this is the intended metric.
        '''
        # Grouping follows num_return_sequences (as in run_top_k) rather than comparing
        # len(output) with the configured batch size, which fails on a smaller last batch.
        grouped_predictions = self._group_predictions(output)
        if len(grouped_predictions) == 0:
            return 0.0

        correct = 0
        for idx, predictions in enumerate(grouped_predictions):
            expected_doc_id = str(batch['Doc_ID'][idx])
            for prediction in predictions:
                if self._lookup_doc_id(prediction) == expected_doc_id:
                    correct += 1

        return correct / len(grouped_predictions)

    @staticmethod
    def _mean(values):
        '''Arithmetic mean of a list of numbers, or NaN for an empty list.'''
        return sum(values) / len(values) if len(values) > 0 else float('nan')

    def on_train_epoch_end(self) -> None:
        '''Log/print the epoch-average training losses, then reset iterator and accumulators.'''
        epoch_loss = self._mean(self.training_loss)
        self.log("avg_training_loss", epoch_loss, batch_size=self.config['batch_size'])

        print(f"| Epoch {self.current_epoch} | TRAINING")

        # Update best loss if current loss is better
        if epoch_loss < self.best_training_loss:
            self.best_training_loss = epoch_loss

        has_indexing = len(self.training_indexing_losses) != 0
        has_retrieval = len(self.training_retrieval_losses) != 0

        if has_indexing:
            loss_indexing = self._mean(self.training_indexing_losses)
            self.log("avg_training_indexing_loss", loss_indexing, batch_size=self.config['batch_size'])
            print(f"\t- Indexing Loss => {loss_indexing:.4f}")

        if has_retrieval:
            loss_retrieval = self._mean(self.training_retrieval_losses)
            self.log("avg_training_retrieval_loss", loss_retrieval, batch_size=self.config['batch_size'])
            print(f"\t- Retrieval Loss => {loss_retrieval:.4f}")

        print(f"\t- Total Loss => {epoch_loss:.4f}")

        if self.enable_wandb and not self.wandb_run._is_finished:
            self.wandb_run.log({"avg_training_loss": epoch_loss})
            if has_indexing:
                self.wandb_run.log({"avg_training_indexing_loss": loss_indexing})
            if has_retrieval:
                self.wandb_run.log({"avg_training_retrieval_loss": loss_retrieval})
            self.wandb_run.log({"epoch": self.current_epoch})

        # At the end of each training epoch the iterator over the filtered dataloader is
        # re-created, so that the next epoch starts again from its beginning.
        self.training_dataloader_with_queries_filtered = iter(self.copy_training_dataloader)

        # Empty the list of the losses for the next epoch
        self.training_loss = []
        self.training_indexing_losses = []
        self.training_retrieval_losses = []

    def on_validation_epoch_end(self) -> None:
        '''Log/print the epoch-average validation losses, then reset the accumulators.'''
        epoch_loss = self._mean(self.validation_loss)
        self.log("avg_validation_loss", epoch_loss, batch_size=self.config['batch_size'])

        print(f"| Epoch {self.current_epoch} | VALIDATION")

        # Update best loss if current loss is better
        if epoch_loss < self.best_validation_loss:
            self.best_validation_loss = epoch_loss

        has_retrieval = len(self.validation_retrieval_losses) != 0
        if has_retrieval:
            loss_retrieval = self._mean(self.validation_retrieval_losses)
            self.log("avg_validation_retrieval_loss", loss_retrieval, batch_size=self.config['batch_size'])
            print(f"\t- Retrieval Loss => {loss_retrieval:.4f}")

        print(f"\t- Total Loss => {epoch_loss:.4f}")

        if self.enable_wandb and not self.wandb_run._is_finished:
            self.wandb_run.log({"avg_validation_loss": epoch_loss})
            if has_retrieval:
                self.wandb_run.log({"avg_validation_retrieval_loss": loss_retrieval})

        # Empty the list of the losses for the next epoch
        self.validation_loss = []
        self.validation_retrieval_losses = []

    def on_test_epoch_end(self) -> None:
        '''Log/print the average test accuracy, track the best value and reset the accumulator.'''
        epoch_accuracy = self._mean(self.test_accuracy)
        self.log("avg_test_accuracy", epoch_accuracy, batch_size=self.config['batch_size'])

        print(f"| Epoch {self.current_epoch} | TEST")

        # Update best accuracy if current accuracy is better (higher is better)
        if epoch_accuracy > self.best_test_accuracy:
            self.best_test_accuracy = epoch_accuracy

        print(f"\t- Total Accuracy => {epoch_accuracy:.4f}")

        if self.enable_wandb and not self.wandb_run._is_finished:
            self.wandb_run.log({"avg_test_accuracy": epoch_accuracy})

        # Empty the list of the accuracies for the next epoch
        self.test_accuracy = []

    def save_model_pytorch_api(self, name='best', epoch=''):
        '''Save this module's state_dict to '<name>_ep-<epoch>.pt' in the current directory.'''
        torch.save(self.state_dict(), f'{name}_ep-{epoch}.pt')

    def configure_optimizers(self):
        '''
          Build the optimizer and the LR scheduler.

          Currently: standard PyTorch AdamW + linear warmup / cosine annealing scheduler.
          Alternatives that were used in place of AdamW (same `eps`/`lr` arguments):
            - bnb.optim.Adam8bit   (Adam 8-Bit)
            - bnb.optim.AdamW8bit  (AdamW 8-Bit)
            - bnb.optim.Adam       (Normal Adam)
            - torch.optim.Adam     (Standard PyTorch Adam)
          Instead of `self.parameters()`, grouped parameters with a per-group weight decay
          can also be passed, e.g. built from
          `get_parameter_names(self.model, [nn.LayerNorm])` with the "bias" names removed.
          An alternative scheduler is
          `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.trainer.max_epochs)`.
        '''
        # Configs
        epsilon = 1e-8  # Default
        lr = 3e-4       # Default: 1e-3 | Can be changed also to 3e-4

        # Standard PyTorch AdamW
        adam_bnb_optim = torch.optim.AdamW(self.parameters(), eps=epsilon, lr=lr)

        # Linear warmup for the first 2 epochs, then cosine annealing down to 0
        scheduler = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(adam_bnb_optim, warmup_epochs=2, warmup_start_lr=0.0, eta_min=0.0, max_epochs=self.trainer.max_epochs)

        return {"optimizer": adam_bnb_optim, "lr_scheduler": scheduler}

    def get_n_trainable_parameters(self):
        '''Return the number of trainable parameters reported by the wrapped foundation model.'''
        return self.model.get_n_trainable_parameters()

In [None]:
# Choose and load the foundation model.
#
# Accepted values for 'model_name':
#   - 'bert-base-uncased'  : The system will load the Bert-Uncased model
#   - 'google/flan-t5-base': The system will load the Flan-T5 model
#
# The remaining configs (hidden size, encoder/decoder names) are derived from this choice.
dsi_config['model_name'] = 'bert-base-uncased' # 'google/flan-t5-base' | 'bert-base-uncased'
dsi_config['fine_tuning'] = False

# Load the model
if dsi_config['model_name'] == 'bert-base-uncased':
    # Bert-Uncased Model: the same checkpoint is used for the encoder and the decoder
    dsi_config['encoder_model'] = 'bert-base-uncased'
    dsi_config['decoder_model'] = 'bert-base-uncased'
    model = FoundationModel(
        model_id=None,
        encoder_id=dsi_config['encoder_model'],
        decoder_id=dsi_config['decoder_model'],
        MAX_LENGTH=dsi_config["MAX_LENGTH"],
        fine_tuning=dsi_config['fine_tuning'],
    )
    dsi_config["hidden_size"] = model.model.encoder.config.hidden_size
    # In this case the config keeps 'encoder_model'/'decoder_model' and drops 'model_name'
    dsi_config.pop("model_name")
else:
    # Flan-T5 Model (or any other single model id)
    model = FoundationModel(
        model_id=dsi_config['model_name'],
        encoder_id=None,
        decoder_id=None,
        MAX_LENGTH=dsi_config["MAX_LENGTH"],
        fine_tuning=dsi_config['fine_tuning'],
    )
    dsi_config["hidden_size"] = model.model.config.hidden_size

In [None]:
# Wrap the foundation model into the actual DSI Model
dsi_model = DSI_Model(
    dsi_config,
    model=model,
    query_doc_df=query_doc_df,
    training_dataloader_with_queries_filtered=train_dataloader_filtered,
)

# Report the number of trainable parameters, using '.' as thousands separator
total_params = dsi_model.get_n_trainable_parameters()
total_params = f"{total_params:,}".replace(",", ".")
print("--------------------------------------------------------------")
print(f"Total number of trainable parameters for DSI Model: {total_params}")

#### DSI Model with Encoder-Decoder
SwitchTransformer + FlanT5

In [None]:
'''
  -- DSI Model for the custom Encoder Decoder model--

  Parameters:
    - config: The configuration dictionary populated throughout the whole Code
    - encoder: The encoder to use (SwitchTransformer)
    - decoder: The decoder to use (Flan-T5)
    - query_doc_df: The dataframe containing the queries
        This is needed for retrieving DocIDs during the retrieval task
    - training_dataloader_with_queries_filtered: The dataloader containing the filtered documents with queries
        Due to Lightning's limitations, we need to pass the dataloader to the model in order to be able to iterate over it
'''
class DSI_EncoderDecoder(pl.LightningModule):
    def __init__(self,
                 config: dict,
                 encoder: pl.LightningModule,
                 decoder: pl.LightningModule,
                 query_doc_df: pd.DataFrame,
                 training_dataloader_with_queries_filtered: DataLoader,
                 ):
        """
          Build the Encoder-Decoder DSI model.

          Parameters:
            - config: The configuration dictionary. The keys read here are "MAX_LENGTH",
                "token_tokenization_max_length", "train_indexing_retrieval_ratio",
                "train_begin_retrieval", "enable_multitask_prompting" (plus the two
                "multitask_prompting_*" entries when prompting is enabled),
                "wandb_configs" and "document_id_representation_strategy"
            - encoder: The encoder module, moved to the notebook-level `device`
            - decoder: The decoder module, moved to the notebook-level `device`
            - query_doc_df: The dataframe containing the queries (Query - Doc ID)
            - training_dataloader_with_queries_filtered: The dataloader over the filtered
                documents with queries; an iterator over it is kept for the retrieval task

          Besides its arguments, this constructor reads the notebook globals `device`
          and `train_docs_with_queries_dataset`.
        """
        super().__init__()

        # Variables
        self.config = config
        self.MAX_LENGTH = config["MAX_LENGTH"]

        self.token_tokenization_max_length = config['token_tokenization_max_length'] # Define the label max_length of token for the tokenization
        self.train_indexing_retrieval_ratio = config["train_indexing_retrieval_ratio"] # Docs with Query for Training
        self.train_begin_retrieval = config["train_begin_retrieval"] # At which index the retrieval task begins

        self.enable_multitask_prompting = config["enable_multitask_prompting"]
        if self.enable_multitask_prompting:
            self.multitask_prompting_indexing = config["multitask_prompting_indexing"]
            self.multitask_prompting_retrieval = config["multitask_prompting_retrieval"]

        # Encoder and Decoder
        self.encoder = encoder.to(device)
        self.decoder = decoder.to(device)

        # Dataloader
        # We need to keep a copy of the training dataloader, due to the fact that we
        # must reset the training dataloader every time we finish an epoch
        # If we don't do this, we will have a StopIteration error
        self.copy_training_dataloader = training_dataloader_with_queries_filtered
        self.training_dataloader_with_queries_filtered = iter(training_dataloader_with_queries_filtered)

        # Log losses
        self.training_indexing_losses = []
        self.training_retrieval_losses = []
        self.validation_retrieval_losses = []
        self.test_retrieval_losses = []

        self.training_loss = []
        self.validation_loss = []
        self.test_accuracy = []

        self.best_training_loss = float('inf')
        self.best_validation_loss = float('inf')
        self.best_test_accuracy = 0.0

        # Number of generated labels that could be mapped back to a DocID.
        # The metric methods increment this counter, so it must exist even when the
        # caller does not set it before starting the test phase (it can still be reset
        # to 0 from outside before each test run).
        self.model_found_label = 0

        # WandB
        self.enable_wandb = config["wandb_configs"]["enable"]
        if self.enable_wandb:
            self.wandb_run = wandb.init(
                project='DeepLearning-DSI',
                group=config["wandb_configs"]["group_id"],
                config=config,)
        else:
            self.wandb_run = None

        # Dataframes
        self.query_doc_df = query_doc_df  # Query - Doc ID Dataframe
        self.top_100_df = None            # Populated only if the test_type is "top_k"

        # The current task
        self.current_task = "indexing"  # "indexing" / "retrieval"
        self.test_type = "accuracy"     # "accuracy" / "top_k"

        # Create a Dict DocID2Index and Index2DocID
        # Useful when we need to retrieve the DocID from the index and vice versa (Test phase)
        self.docID2label = {}
        self.label2DocID = {}
        label_column = config['document_id_representation_strategy']
        filtered_docs = train_docs_with_queries_dataset.filtered_docs_with_queries
        for _, row in filtered_docs.iterrows():
            label = str(row[label_column])
            self.docID2label[row['Doc_ID']] = label
            self.label2DocID[label] = row['Doc_ID']

        # Generation variables
        self.num_return_sequences = 5
        self.num_beams = 5

    '''
      Generate method:
        - outputs_encoder: The outputs coming from the Encoder model
        - enc_input_ids: The input_ids coming from the Encoder Tokenizer
    '''
    def generate(self, outputs_encoder, enc_input_ids):
        """
          Generate DocID strings with the decoder from already computed encoder outputs.

          Parameters:
            - outputs_encoder: The outputs coming from the Encoder model; only its
                `encoder_hidden_states` attribute is read here
            - enc_input_ids: The input_ids coming from the Encoder Tokenizer, passed to
                the decoder as `decoder_input_ids`

          Returns:
            The strings obtained by decoding the generated ids with the decoder
            tokenizer (special tokens are skipped).

          Generation is sampling based (do_sample=True, top_k=50, temperature=0.6).
          `self.num_beams` and `self.num_return_sequences` are NOT forwarded to
          `generate`, so this presumably yields one sequence per input row
          (TODO confirm against the Transformers defaults).
        """
        # Switch to CUDA first
        enc_input_ids = enc_input_ids.to(device)
        self.decoder = self.decoder.to(device)

        # NOTE: a `force_words_ids` constraint (forcing e.g. the token "0") used to be
        # tokenized here on every call although it was never passed to `generate`.
        # That dead computation has been removed; re-introduce it together with the
        # `force_words_ids=` argument if constrained generation is needed.

        # In this case, we have already computed the Encoder part from SwitchTransformer (Encoder Model)
        # Hence, we just want to generate the output from the Decoder part (Flan-T5)
        with torch.no_grad():
            output = self.decoder.model.generate(
                input_ids=None,                                               # No input_ids are needed for the Encoder part the Flan-T5 model
                decoder_input_ids=enc_input_ids,                              # Use generated SwitchTransformer input_ids for the Decoder (Flan-T5) input ids
                encoder_hidden_states=outputs_encoder.encoder_hidden_states,  # Use generated SwitchTransformer hidden states as the Encoder (Flan-T5) hidden states (skipping Encoding of Flan-T5)
                bos_token_id=self.decoder.tokenizer.pad_token_id,             # The beginning of the sequence token. Needed for generation
                # Limits the generated length to the number of characters of the configured value.
                # NOTE(review): this is len(str(value)), i.e. the digit count of
                # `token_tokenization_max_length`, not the value itself - confirm this is intended.
                max_length=len(str(self.token_tokenization_max_length)),
                do_sample=True,
                use_cache=True,
                top_k=50,
                temperature=0.6,
            )

        return self.decoder.tokenizer.batch_decode(
            output,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

    '''
      The forward method of the DSI model that differentiate slightly the procedure to
      follow according to the task that has to be performed (indexing/retrieval)
    '''
    def forward(self, batch, training=True):
        """
          Run the model on a batch for the task stored in `self.current_task`.

          Parameters:
            - batch: The batch dictionary. For the indexing task the document text is
                read from the column named by `train_dataset.doc_rep_engine.strategy`;
                for any other task (retrieval) the first query of every entry of
                batch["Query"] is used. Labels always come from the column named by
                `train_dataset.docID_rep_engine.strategy`
            - training: When True the decoder is called with labels (its output is the
                one whose `.loss` the step methods read); when False the DocIDs are
                generated through `self.generate`

          Both tasks share the exact same encode/decode pipeline, implemented once in
          `_encode_then_decode`; only the input text (and its prompt) differs.
        """
        torch.cuda.empty_cache() # Clear CUDA cache before forward pass

        if self.current_task == "indexing": # Indexing Task
            texts = batch[train_dataset.doc_rep_engine.strategy]
            # T5 Prompt Setup as suggested in the paper
            if self.enable_multitask_prompting:
                texts = [f"{self.multitask_prompting_indexing} {item}" for item in texts]
        else:
            # Retrieval Task: we consider the queries associated to the documents,
            # not the documents themselves (only the first query of each document)
            texts = [str(item[0]) for item in batch["Query"]]
            if self.enable_multitask_prompting:
                texts = [f"{self.multitask_prompting_retrieval} {query}" for query in texts]

        # Labels are the Doc_ID(s) processed with the strategy chosen
        labels = [str(item) for item in batch[train_dataset.docID_rep_engine.strategy]]

        return self._encode_then_decode(texts, labels, training)

    def _encode_then_decode(self, texts, labels, training):
        """
          Encode `texts`, then either compute the decoder output with labels
          (training=True) or generate DocID strings (training=False).
        """
        # Calling the encoder
        # Returns:
        #   - embeddings: (last_hidden_states, encoder_hidden_states)
        #   - inputs_tokenized: The tokenized input texts
        #   - outputs_encoder: The outputs of the Encoder model
        embeddings, inputs_tokenized, outputs_encoder = self.encoder(texts)

        # Switch to CUDA
        # The former `embeddings[0].to(device)` statement discarded its result
        # (`Tensor.to` is not in-place), so it never moved anything and was dropped.
        input_ids = inputs_tokenized["input_ids"].to(device)
        attention_mask = inputs_tokenized["attention_mask"].to(device)
        outputs_encoder['encoder_hidden_states'] = [item.to(device) for item in outputs_encoder['encoder_hidden_states']]

        # For validation and test -> generation only
        if not training:
            return self.generate(outputs_encoder, input_ids)

        # For training -> Includes loss computation as well
        # Treat the labels as tokenizable strings
        labels_inputs = self.encoder.tokenizer(
            labels,
            add_special_tokens=False,
            return_tensors='pt',
            max_length=self.MAX_LENGTH,
            padding="max_length",
            truncation=True,
        )
        labels_input_ids = labels_inputs["input_ids"].to(device)

        # Mask the labels corresponding to the padding tokens of the Encoder Tokenizer with -100
        # Such that they are not taken into account in the loss computation
        labels_input_ids[labels_input_ids == self.encoder.tokenizer.pad_token_id] = -100

        return self.decoder(
            embeddings,                             # Embeddings coming from the Encoder Model
            decoder_input_ids=input_ids,            # Input_ids coming from Encoder Tokenizer
            decoder_attention_mask=attention_mask,  # Attention mask coming from Encoder Tokenizer
            labels=labels_input_ids,                # Input_ids of labels tokenized with Encoder Tokenizer
        )

    '''
      A retrieval batch is created from the given batch by attaching the queries associated
      with each document in that batch.
    '''
    def prepare_sample_for_retrieval(self, batch):
        """
          Attach to `batch` the queries of every document it contains.

          For each entry of batch['Doc_ID'] the matching rows of `self.query_doc_df`
          are looked up, and their 'Query_ID' and 'Query' values are appended (as
          arrays) to batch['Queries_ID'] and batch['Query'].

          The batch is modified in place and also returned.

          Raises:
            ValueError: if a document of the batch has no row in `self.query_doc_df`
        """
        batch['Queries_ID'] = []
        batch['Query'] = []

        for doc_id in batch['Doc_ID']:
            # Filter the dataframe once per document and reuse the rows for both columns
            doc_rows = self.query_doc_df[self.query_doc_df['Doc_ID'] == doc_id]

            if len(doc_rows) == 0:
                raise ValueError(f"Document {doc_id} without queries!")

            batch['Queries_ID'].append(doc_rows['Query_ID'].values)
            batch['Query'].append(doc_rows['Query'].values)

        return batch

    def training_step(self, batch, batch_idx):
        """
          One training step.

          The incoming batch is always used for the indexing task. Once `batch_idx`
          reaches `self.train_begin_retrieval`, a second batch is drawn from the
          filtered documents-with-queries dataloader and used for the retrieval task;
          the returned loss is then the mean of the indexing and retrieval losses.

          Returns:
            The loss tensor of the step (indexing only, or the mean of both tasks).
        """
        # Indexing step
        self.current_task = 'indexing'

        loss_idx = self.forward(batch, training=True).loss
        self.training_indexing_losses.append(loss_idx.cpu().detach().item())

        total_step_loss = loss_idx

        # Add the retrieval task if the batch index has reached the train_begin_retrieval value
        if batch_idx >= self.train_begin_retrieval:
            self.current_task = 'retrieval'

            # The retrieval batch comes from the documents that have queries.
            # That dataloader can be shorter than the number of retrieval steps of an
            # epoch: instead of letting a StopIteration escape from the training step,
            # restart the iterator from the stored dataloader.
            retrieval_batch = next(self.training_dataloader_with_queries_filtered, None)
            if retrieval_batch is None:
                self.training_dataloader_with_queries_filtered = iter(self.copy_training_dataloader)
                retrieval_batch = next(self.training_dataloader_with_queries_filtered, None)
                if retrieval_batch is None:
                    raise RuntimeError("The training dataloader with queries is empty: cannot run the retrieval task")

            # Attach the queries associated to the documents of the retrieval batch
            retrieval_batch = self.prepare_sample_for_retrieval(retrieval_batch)

            loss_retrieval = self.forward(retrieval_batch, training=True).loss
            self.training_retrieval_losses.append(loss_retrieval.cpu().detach().item())

            total_step_loss = (total_step_loss + loss_retrieval) / 2

        step_loss_value = total_step_loss.cpu().detach().item()
        self.log("training_loss", step_loss_value)
        self.training_loss.append(step_loss_value)

        return total_step_loss

    # Validation => Retrieval step for all the documents chosen to be in validation split
    def validation_step(self, batch, batch_idx):
        """
          Retrieval step on a validation batch.

          The queries of the batch documents are attached, the forward pass is run with
          training=True so that its output carries a loss, and that loss is logged as
          "validation_loss" and stored in both validation histories.

          Returns:
            The loss tensor of the batch.
        """
        self.current_task = 'retrieval'

        retrieval_batch = self.prepare_sample_for_retrieval(batch)
        n_docs = len(retrieval_batch['Doc_ID'])

        loss = self.forward(retrieval_batch, training=True).loss
        loss_value = loss.cpu().detach().item()

        self.log("validation_loss", loss_value, batch_size=n_docs)
        for history in (self.validation_loss, self.validation_retrieval_losses):
            history.append(loss_value)

        return loss

    # Test => Retrieval step for all the documents chosen to be in test split
    def test_step(self, batch, batch_idx):
        """
          Retrieval step on a test batch.

          DocIDs are generated (forward with training=False, so no loss is available)
          and scored with `run_accuracy` or `run_top_k` depending on `self.test_type`.

          Returns:
            The accuracy of the batch.
        """
        self.current_task = 'retrieval'

        # Retrieval step (same as retrieval step in training/validation phase)
        batch = self.prepare_sample_for_retrieval(batch)
        bs = len(batch['Doc_ID'])

        output = self.forward(batch, training=False)

        if self.test_type == "accuracy":
            accuracy = self.run_accuracy(output, batch)
        else:
            accuracy = self.run_top_k(output, batch)

        self.log("accuracy", accuracy, batch_size=bs)

        # Keep the batch accuracy: `on_test_epoch_end` averages this list, which
        # previously stayed empty and made the epoch accuracy always NaN
        self.test_accuracy.append(accuracy)

        return accuracy

    '''
      Checks whether the model outputs are among the top-k relevant documents of the queries in the batch
    '''
    def run_top_k(self, output, batch):
        """
          Fraction of batch queries for which at least one generated DocID belongs to
          the relevant documents listed in `self.top_100_df` for that query.

          Parameters:
            - output: The generated labels (strings), grouped query by query
            - batch: The batch dictionary; batch['Queries_ID'] holds, per document, the
                array of its query ids (only the first one is used)

          Returns:
            A float in [0, 1]; 0.0 when there is nothing to evaluate.

          `self.top_100_df` must be populated before this method is called.
          `self.model_found_label` is increased by the number of generated labels that
          map to a known DocID.
        """
        batch_queries = batch['Queries_ID']
        n_queries = len(batch_queries)
        if n_queries == 0 or len(output) == 0:
            return 0.0

        # Number of generated sequences per query, derived from the actual output
        # instead of `self.num_return_sequences`: `generate` is not guaranteed to
        # return that many sequences, and a wrong group size mixes the predictions
        # of different queries.
        per_query = max(len(output) // n_queries, 1)

        top_100 = self.top_100_df
        found = 0
        correct = 0
        for idx, query_ids in enumerate(batch_queries):
            # Relevant documents of the first query of this entry; an empty array when
            # the query is not present in the Top100 dataframe
            relevant = top_100.loc[top_100['Query_ID'] == query_ids[0], 'Doc_ID'].values

            hit = False
            for prediction in output[idx * per_query:(idx + 1) * per_query]:
                doc_id = self.label2DocID.get(prediction)
                if doc_id is None:
                    # The model produced a label that is not a known DocID
                    continue
                found += 1
                if doc_id in relevant:
                    hit = True

            # A query counts once, no matter how many of its predictions are relevant
            if hit:
                correct += 1

        # Number of labels found; defaults to 0 when nobody initialised the counter
        self.model_found_label = getattr(self, "model_found_label", 0) + found

        return float(correct / n_queries)

    '''
      Checks if the output from the model is equal to the label present in the batch.
    '''
    def run_accuracy(self, output, batch):
        """
          Fraction of batch documents whose DocID is matched by at least one of the
          labels generated for them.

          Parameters:
            - output: The generated labels (strings), grouped document by document
            - batch: The batch dictionary; batch['Doc_ID'] holds the expected DocIDs

          Returns:
            A float in [0, 1]; 0.0 when there is nothing to evaluate.

          `self.model_found_label` is increased by the number of generated labels that
          map to a known DocID.
        """
        doc_ids = batch['Doc_ID']
        n_docs = len(doc_ids)
        if n_docs == 0 or len(output) == 0:
            return 0.0

        # Number of generated sequences per document, derived from the actual sizes.
        # Comparing len(output) with the configured batch size mis-detects the last,
        # smaller batch of an epoch.
        per_doc = max(len(output) // n_docs, 1)

        found = 0
        correct = 0
        for idx in range(n_docs):
            expected = str(doc_ids[idx])

            hit = False
            for prediction in output[idx * per_doc:(idx + 1) * per_doc]:
                doc_id = self.label2DocID.get(prediction)
                if doc_id is None:
                    # The model did not output a known DocID label
                    continue
                found += 1
                if str(doc_id) == expected:
                    hit = True

            # A document counts once even if several sampled sequences match it
            if hit:
                correct += 1

        # Number of labels found; defaults to 0 when nobody initialised the counter
        self.model_found_label = getattr(self, "model_found_label", 0) + found

        # Normalise by the number of documents: len(batch) would be the number of
        # keys of the batch dictionary, not the batch size
        return correct / n_docs

    def on_train_epoch_end(self) -> None:
        """
          Summarise the training epoch.

          Averages the step losses collected during the epoch (total, indexing and
          retrieval), logs and prints them, mirrors them to WandB when enabled,
          restarts the iterator over the documents-with-queries dataloader and clears
          the loss histories for the next epoch.
        """
        history = self.training_loss
        epoch_loss = sum(history) / len(history) if history else float('nan')
        batch_size = self.config['batch_size']
        self.log("avg_training_loss", epoch_loss, batch_size=batch_size)

        print(f"| Epoch {self.current_epoch} | TRAINING")

        # Keep track of the best (lowest) epoch loss seen so far
        if epoch_loss < self.best_training_loss:
            self.best_training_loss = epoch_loss

        # Per-task averages, reported only for the tasks that actually ran
        task_histories = (
            ("avg_training_indexing_loss", "Indexing", self.training_indexing_losses),
            ("avg_training_retrieval_loss", "Retrieval", self.training_retrieval_losses),
        )
        task_averages = []
        for metric_name, label, values in task_histories:
            if not values:
                continue
            average = sum(values) / len(values)
            self.log(metric_name, average, batch_size=batch_size)
            print(f"\t- {label} Loss => {average:.4f}")
            task_averages.append((metric_name, average))

        print(f"\t- Total Loss => {epoch_loss:.4f}")

        if self.enable_wandb and not self.wandb_run._is_finished:
            self.wandb_run.log({"avg_training_loss": epoch_loss})
            for metric_name, average in task_averages:
                self.wandb_run.log({metric_name: average})
            self.wandb_run.log({"epoch": self.current_epoch})

        # Start a fresh pass over the documents-with-queries dataloader, so the next
        # epoch can draw retrieval batches from the beginning again
        self.training_dataloader_with_queries_filtered = iter(self.copy_training_dataloader)

        # Empty the list of the losses for the next epoch
        self.training_loss = []
        self.training_indexing_losses = []
        self.training_retrieval_losses = []

    def on_validation_epoch_end(self) -> None:
        """
          Summarise the validation epoch.

          Averages the validation losses collected during the epoch (total and
          retrieval), logs and prints them, mirrors them to WandB when enabled and
          clears the loss histories for the next epoch.
        """
        history = self.validation_loss
        epoch_loss = sum(history) / len(history) if history else float('nan')
        batch_size = self.config['batch_size']
        self.log("avg_validation_loss", epoch_loss, batch_size=batch_size)

        print(f"| Epoch {self.current_epoch} | VALIDATION")

        # Keep track of the best (lowest) epoch loss seen so far
        if epoch_loss < self.best_validation_loss:
            self.best_validation_loss = epoch_loss

        # The retrieval average is reported only when retrieval losses were collected
        retrieval_history = self.validation_retrieval_losses
        has_retrieval = len(retrieval_history) != 0
        if has_retrieval:
            loss_retrieval = sum(retrieval_history) / len(retrieval_history)
            self.log("avg_validation_retrieval_loss", loss_retrieval, batch_size=batch_size)
            print(f"\t- Retrieval Loss => {loss_retrieval:.4f}")

        print(f"\t- Total Loss => {epoch_loss:.4f}")

        if self.enable_wandb and not self.wandb_run._is_finished:
            self.wandb_run.log({"avg_validation_loss": epoch_loss})
            if has_retrieval:
                self.wandb_run.log({"avg_validation_retrieval_loss": loss_retrieval})

        # Empty the list of the losses for the next epoch
        self.validation_loss = []
        self.validation_retrieval_losses = []

    def on_test_epoch_end(self) -> None:
        """
          Summarise the test epoch.

          Averages the accuracies stored in `self.test_accuracy` (NaN when the list is
          empty), logs and prints the result, mirrors it to WandB when enabled, updates
          `self.best_test_accuracy` and clears the list for the next epoch.
        """
        accuracies = self.test_accuracy
        epoch_accuracy = sum(accuracies) / len(accuracies) if len(accuracies) > 0 else float('nan')
        self.log("avg_test_accuracy", epoch_accuracy, batch_size=self.config['batch_size'])

        print(f"| Epoch {self.current_epoch} | {'TEST'}")

        # Update best accuracy if the current one is better. For an accuracy, better
        # means HIGHER: with the former `<` comparison the initial value of 0.0 could
        # never be replaced. A NaN average compares False and leaves the best untouched.
        if epoch_accuracy > self.best_test_accuracy:
            self.best_test_accuracy = epoch_accuracy

        print(f"\t- Total Accuracy => {epoch_accuracy:.4f}")

        if self.enable_wandb and not self.wandb_run._is_finished:
            self.wandb_run.log({"avg_test_accuracy": epoch_accuracy})

        # Empty the list of the accuracies for the next epoch
        self.test_accuracy = []

    def save_model_pytorch_api(self, name='best', epoch=''):
        """Save this module's state_dict with `torch.save` to the file '<name>_ep-<epoch>.pt'."""
        weights = self.state_dict()
        checkpoint_path = f'{name}_ep-{epoch}.pt'
        torch.save(weights, checkpoint_path)

    # In this case, the optimizer can be configured in one of four ways:
    #   - Adam 8-Bit
    #   - AdamW 8-Bit
    #   - Normal Adam
    #   - Normal AdamW
    # There is also an option to set the weight decay for the parameters
    def configure_optimizers(self):
        """
          Build the optimizer and the learning rate scheduler used by the trainer.

          Optimizer: standard PyTorch AdamW over all the parameters of this module.
          Scheduler: linear warmup (2 epochs, starting from lr 0.0) followed by cosine
          annealing down to 0.0 over `self.trainer.max_epochs` epochs.

          Alternatives that were tried and can be swapped in for the optimizer:
          `bnb.optim.Adam8bit`, `bnb.optim.AdamW8bit`, `bnb.optim.Adam` and
          `torch.optim.Adam`, optionally fed with parameter groups carrying a
          per-group weight decay instead of `self.parameters()`. A plain
          `torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)`
          can replace the warmup scheduler.

          Returns:
            A dict with the "optimizer" and "lr_scheduler" entries.
        """
        # Configs
        learning_rate = 3e-4  # Default: 1e-3 | Can be changed also to 3e-4
        adam_epsilon = 1e-8   # Default

        optimizer = torch.optim.AdamW(self.parameters(), eps=adam_epsilon, lr=learning_rate)

        scheduler = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=2,
            warmup_start_lr=0.0,
            eta_min=0.0,
            max_epochs=self.trainer.max_epochs,
        )

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    # Return the number of trainable parameters (encoder + decoder)
    def get_n_trainable_parameters(self):
        """Return the trainable parameter count of the encoder plus that of the decoder."""
        n_encoder = self.encoder.get_n_trainable_parameters()
        n_decoder = self.decoder.get_n_trainable_parameters()
        return n_encoder + n_decoder

In [None]:
# Our Model: a SwitchTransformer encoder paired with a Flan-T5 decoder
dsi_config.update({
    'encoder_model': "google/switch-base-8",
    'decoder_model': "google/flan-t5-base",
})

encoder = SwitchTransformer(
    model_id=dsi_config['encoder_model'],
    MAX_LENGTH=dsi_config["MAX_LENGTH"],
)
decoder = FlanT5(
    model_id=dsi_config['decoder_model'],
    MAX_LENGTH=dsi_config["MAX_LENGTH"],
)

# The hidden size stored in the config is the one of the decoder
dsi_config["hidden_size"] = decoder.model.config.hidden_size

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Encoder info:
trainable params: 1,327,104 || all params: 620,666,112 || trainable%: 0.21381931030898624
Decoder info:
trainable params: 1,327,104 || all params: 248,904,960 || trainable%: 0.5331770005708203


In [None]:
# Wrap the encoder and decoder created above into the actual DSI Model
dsi_model = DSI_EncoderDecoder(
    dsi_config,
    encoder=encoder,
    decoder=decoder,
    query_doc_df=query_doc_df,
    training_dataloader_with_queries_filtered=train_dataloader_filtered,
)

# Report the trainable parameters, using "." as thousands separator
total_params = f"{dsi_model.get_n_trainable_parameters():,}".replace(",", ".")
print("--------------------------------------------------------------")
print(f"Total number of trainable parameters for DSI Model: {total_params}")

### Discriminative Approach
The Discriminative approach adds a linear projection (a Linear layer) that maps the Decoder embeddings into the Doc_ID space.<br>
A SoftMax over this projection then gives the likelihood of each Doc_ID, given the query.

In [None]:
'''
  -- DSI Model for Discriminative Training --

  Parameters:
    - config: The configuration dictionary populated throughout the whole Code
    - model: The model to use (populated when instantiating the Foundation Model)
    - query_doc_df: The dataframe containing the queries
        This is needed for retrieving DocIDs during the retrieval task
    - training_dataloader_with_queries_filtered: The dataloader containing the filtered documents with queries
        Due to Lightning's limitations, we need to pass the dataloader to the model in order to be able to iterate over it
'''
class DSI_Discriminative(pl.LightningModule):
    '''
      DSI model trained with a discriminative (classification) objective.

      The decoder hidden states of the wrapped Foundation Model are pooled and projected by a
      Linear layer onto `config['labels_count']` classes. The class index of a document is its
      position among the keys of `self.docID2label`, and training minimises the cross entropy
      between the projected scores and that index.

      Besides its constructor arguments, the class reads the notebook globals `device`,
      `train_dataset` (document representation strategy used by `forward`) and
      `train_docs_with_queries_dataset` (source of the Doc_ID <-> label dictionaries).

      Parameters:
        - config: The configuration dictionary populated throughout the whole Code
        - model: The Foundation Model wrapper (exposes `.model`, `.tokenizer`, `.get_n_trainable_parameters()`)
        - query_doc_df: The dataframe with the columns 'Doc_ID', 'Query_ID' and 'Query'
        - training_dataloader_with_queries_filtered: The dataloader over the training documents that have queries.
            It is iterated manually inside `training_step` for the retrieval half of each step.
            May be None if the model is never trained (e.g. test only).
    '''
    def __init__(self,
                 config: dict,
                 model: pl.LightningModule = None,
                 query_doc_df: pd.DataFrame = None,
                 training_dataloader_with_queries_filtered: DataLoader = None,
                 ):
        super().__init__()

        # Fail early with a readable message instead of "'NoneType' object has no attribute 'to'"
        if model is None:
            raise ValueError("DSI_Discriminative requires a Foundation Model, but model=None was given")

        # Variables
        self.config = config
        self.MAX_LENGTH = config["MAX_LENGTH"]

        self.token_tokenization_max_length = config['token_tokenization_max_length'] # Define the label max_length of token for the tokenization
        self.train_indexing_retrieval_ratio = config["train_indexing_retrieval_ratio"] # Docs with Query for Training (stored only, not read by this class)
        self.train_begin_retrieval = config["train_begin_retrieval"] # At which batch index the retrieval task begins

        self.enable_multitask_prompting = config["enable_multitask_prompting"]
        if self.enable_multitask_prompting:
            self.multitask_prompting_indexing = config["multitask_prompting_indexing"]
            self.multitask_prompting_retrieval = config["multitask_prompting_retrieval"]

        # Model
        self.model = model.to(device)

        # Classification head (attribute names are part of the checkpoint's state_dict, do not rename them)
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.batch_norm = nn.BatchNorm1d(config["hidden_size"])
        self.relu = nn.ReLU()
        self.linear_layer = nn.Linear(config["hidden_size"], config['labels_count'], device=device)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(p=0.2)

        # Dataloader
        # Lightning only drives the main training dataloader, so the one with the queries is iterated by hand.
        # The dataloader itself is kept so that a fresh iterator can be created whenever the current one is exhausted.
        self.copy_training_dataloader = training_dataloader_with_queries_filtered
        if training_dataloader_with_queries_filtered is not None:
            self.training_dataloader_with_queries_filtered = iter(training_dataloader_with_queries_filtered)
        else:
            self.training_dataloader_with_queries_filtered = None

        # Log losses
        self.training_indexing_losses = []
        self.training_retrieval_losses = []
        self.validation_retrieval_losses = []
        self.test_retrieval_losses = []

        self.training_loss = []
        self.validation_loss = []
        self.test_accuracy = []

        self.best_training_loss = float('inf')
        self.best_validation_loss = float('inf')
        self.best_test_accuracy = 0.0

        # WandB
        self.enable_wandb = config["wandb_configs"]["enable"]
        if self.enable_wandb:
            self.wandb_run = wandb.init(
                project='DeepLearning-DSI',
                group=config["wandb_configs"]["group_id"],
                config=config,)
        else:
            self.wandb_run = None

        # Dataframes
        self.query_doc_df = query_doc_df  # Query - Doc ID Dataframe
        self.top_100_df = None            # Populated only if the test_type is "top_k"

        # The current task
        self.current_task = "indexing"  # "indexing" / "retrieval"
        self.test_type = "accuracy"     # "accuracy" / "top_k"

        # DocID -> label and label -> DocID dictionaries
        self.docID2label = {}
        self.label2DocID = {}
        label_column = config['document_id_representation_strategy']
        filtered_docs = train_docs_with_queries_dataset.filtered_docs_with_queries
        for _, row in filtered_docs.iterrows():
            self.docID2label[row['Doc_ID']] = str(row[label_column])
            self.label2DocID[str(row[label_column])] = row['Doc_ID']

        # Class index <-> DocID lookup tables, built once.
        # The class index of a document is its position among the keys of docID2label; the very same
        # tables are used to build the training targets and to decode the predictions, so that both
        # sides always agree (even if two documents happen to share the same label).
        self._index2DocID = list(self.docID2label.keys())
        self._docID2index = {doc_id: index for index, doc_id in enumerate(self._index2DocID)}

        # Generation variables (not read by the discriminative model, kept for parity with the generative one)
        self.num_return_sequences = 5
        self.num_beams = 5

    '''
      The forward method of the DSI model. The two tasks only differ in the text that is fed to the model:
        - indexing:  the document representation (batch[<document representation strategy>])
        - retrieval: the first query associated to each document (batch["Query"])

      Returns the loss when training=True, the SoftMax predictions [batch_size, labels_count] otherwise.
    '''
    def forward(self, batch, training=True):

        torch.cuda.empty_cache() # Clear CUDA cache before forward pass

        if self.current_task == "indexing":
            texts = batch[train_dataset.doc_rep_engine.strategy]
            # T5 Prompting Setup as suggested in the paper
            if self.enable_multitask_prompting:
                texts = [f"{self.multitask_prompting_indexing} {item}" for item in texts]
        else:
            # Retrieval Task: only the first query of each document is used
            texts = [str(item[0]) for item in batch["Query"]]
            if self.enable_multitask_prompting:
                texts = [f"{self.multitask_prompting_retrieval} {query}" for query in texts]

        tokens = self.model.tokenizer(texts, add_special_tokens=False, return_tensors='pt', max_length=self.MAX_LENGTH, padding="max_length", truncation=True)
        input_ids = tokens['input_ids'].to(device)
        attention_mask = tokens['attention_mask'].to(device)

        # The model is only used to obtain the decoder hidden states: no labels are passed,
        # the loss is computed by the discriminative head
        output = self.model.model(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=torch.zeros_like(input_ids, device=device),
            decoder_attention_mask=None,
            labels=None,
            output_hidden_states=True
        )

        loss, prediction = self.discriminative_step(batch['Doc_ID'], output)
        return loss if training else prediction

    '''
      A retrieval batch is created starting from a batch considering the queries associated to the docs
      in the current batch. The batch is modified in place (keys 'Queries_ID' and 'Query') and returned.

      Raises an Exception if a document of the batch has no query in self.query_doc_df.
    '''
    def prepare_sample_for_retrieval(self, batch):
        batch['Queries_ID'] = []
        batch['Query'] = []

        for doc_id in batch['Doc_ID']:
            # Select the rows of this document once and read both columns from them
            doc_rows = self.query_doc_df[self.query_doc_df['Doc_ID'] == doc_id]

            if len(doc_rows) == 0:
                raise Exception(f"Document {doc_id} without queries!")

            batch['Queries_ID'].append(doc_rows['Query_ID'].values)
            batch['Query'].append(doc_rows['Query'].values)

        return batch

    # Class index of a Doc_ID (its position among the keys of docID2label)
    def _class_index(self, doc_id):
        try:
            return self._docID2index[doc_id]
        except (KeyError, TypeError):
            # Equality-based search for values that compare equal to a stored Doc_ID but hash differently
            # or are not hashable. Raises ValueError for an unknown Doc_ID.
            return self._index2DocID.index(doc_id)

    # Target class indices (LongTensor [batch_size]) for the given Doc_IDs
    def _target_indices(self, document_ids):
        indices = [self._class_index(doc_id) for doc_id in document_ids]

        labels_count = self.config['labels_count']
        if len(indices) > 0 and max(indices) >= labels_count:
            raise IndexError(f"Class index {max(indices)} is out of range for labels_count={labels_count}")

        return torch.tensor(indices, dtype=torch.long, device=device)

    # DocIDs for the predicted class indices. None is returned for an index that has no document associated.
    def _decode_predictions(self, prediction_indices):
        docs = []
        for index in prediction_indices:
            index = int(index)
            docs.append(self._index2DocID[index] if index < len(self._index2DocID) else None)
        return docs

    # Generate the one hot encoded target tensor [len(document_ids), labels_count]
    # with 1s at the class index of each Doc_ID
    def generate_target_tensor(self, document_ids: list):
        x = torch.zeros((len(document_ids), self.config['labels_count']), device=device)

        for row, doc_id in enumerate(document_ids):
            x[row, self._class_index(doc_id)] = 1

        return x

    # Discriminative head: returns (loss, prediction) where prediction is [batch_size, labels_count]
    def discriminative_step(self, list_doc_ids, output):
        # Average the last 4 Decoder Hidden Layers, then take the max over the sequence dimension
        decoder_hidden_layer = torch.mean(torch.stack(output.decoder_hidden_states[-4:]), dim=0)
        decoder_hidden_layer = torch.max(decoder_hidden_layer, dim=1)[0] # Dim = [batch_size, hidden_size]
        decoder_hidden_layer = decoder_hidden_layer.to(device)

        # BatchNorm + ReLU + Linear Layer + Dropout
        # NOTE(review): the Dropout acts on the output of the Linear Layer (the class scores), not on its input - confirm this is intended
        x = self.batch_norm(decoder_hidden_layer)   # [batch_size, hidden_size]
        x = self.relu(x)                            # [batch_size, hidden_size]
        x = self.linear_layer(x)                    # [batch_size, labels_count]
        x = self.dropout(x)                         # [batch_size, labels_count]
        prediction = self.softmax(x)                # Apply the SoftMax on top of the Linear Layer

        # CrossEntropyLoss between the (pre-SoftMax) scores and the class indices of the DocIDs in the batch
        target_indices = self._target_indices(list_doc_ids)
        loss = self.cross_entropy_loss(x, target_indices)

        return loss, prediction  # [batch_size, labels_count]

    # Next batch of the dataloader with the queries, restarting the dataloader when it is exhausted
    def _next_retrieval_batch(self):
        if self.copy_training_dataloader is None:
            raise RuntimeError("The retrieval task requires training_dataloader_with_queries_filtered, but None was given")

        try:
            return next(self.training_dataloader_with_queries_filtered)
        except StopIteration:
            # This dataloader can be shorter than the number of retrieval steps of an epoch:
            # start it over instead of letting the StopIteration escape from training_step
            self.training_dataloader_with_queries_filtered = iter(self.copy_training_dataloader)

        try:
            return next(self.training_dataloader_with_queries_filtered)
        except StopIteration:
            raise RuntimeError("training_dataloader_with_queries_filtered does not yield any batch") from None

    def training_step(self, batch, batch_idx):
        # ========================================================================
        # Training step => Indexing of the current batch of documents and, from the
        #                  batch index self.train_begin_retrieval onwards, retrieval
        #                  on a batch of the training documents with queries.
        # ========================================================================

        # Indexing step
        self.current_task = 'indexing'

        loss_idx = self.forward(batch, training=True)

        self.training_indexing_losses.append(loss_idx.item())
        total_step_loss = loss_idx

        # Add the retrieval task if the batch index is at least the train_begin_retrieval value
        if batch_idx >= self.train_begin_retrieval:
            self.current_task = 'retrieval'

            # The retrieval batch comes from the dataloader with the queries, not from the batch given by Lightning
            batch = self._next_retrieval_batch()
            batch = self.prepare_sample_for_retrieval(batch)

            loss_retrieval = self.forward(batch, training=True)

            self.training_retrieval_losses.append(loss_retrieval.item())
            total_step_loss = (total_step_loss + loss_retrieval) / 2

        self.log("training_loss", total_step_loss.item())
        self.training_loss.append(total_step_loss.item())

        return total_step_loss

    # Validation => Retrieval step for all the documents chosen to be in validation split
    def validation_step(self, batch, batch_idx):
        self.current_task = 'retrieval'

        # Retrieval step (same as retrieval step in training phase)
        batch = self.prepare_sample_for_retrieval(batch)
        bs = len(batch['Doc_ID'])

        # training=True makes forward return the loss
        loss = self.forward(batch, training=True)
        loss_value = loss.item()

        self.log("validation_loss", loss_value, batch_size=bs)
        self.validation_loss.append(loss_value)
        self.validation_retrieval_losses.append(loss_value)

        return loss

    # Test => Retrieval step for all the documents chosen to be in test split
    def test_step(self, batch, batch_idx):
        self.current_task = 'retrieval'

        # Retrieval step (same as retrieval step in training/validation phase)
        batch = self.prepare_sample_for_retrieval(batch)
        bs = len(batch['Doc_ID'])

        predictions = self.forward(batch, training=False)

        # The predicted class is the one with the highest likelihood
        prediction_indices = torch.argmax(predictions, dim=1)

        if self.test_type == "accuracy":
            accuracy = self.run_accuracy(prediction_indices, batch)
        else:
            accuracy = self.run_top_k(prediction_indices, batch)

        self.log("accuracy", accuracy, batch_size=bs)
        self.test_accuracy.append(accuracy)

        return accuracy

    '''
      Fraction of the predictions that are among the documents listed in self.top_100_df for the
      first query of the corresponding batch element. A query without rows in self.top_100_df
      is reported and counted as a miss.
    '''
    def run_top_k(self, prediction_indices, batch):
        if self.top_100_df is None:
            raise RuntimeError("top_100_df must be populated before running a 'top_k' test")

        if len(prediction_indices) == 0:
            return 0.0

        # Extract DocIDs from the output
        docs = self._decode_predictions(prediction_indices)

        correct = 0
        for idx, query in enumerate(batch['Queries_ID']):
            relevant_docs = self.top_100_df[self.top_100_df['Query_ID'] == query[0]]

            if len(relevant_docs) == 0:
                print(f"There are no queries associated with batch query {query[0]}!")
                continue

            if docs[idx] in relevant_docs['Doc_ID'].values:
                correct += 1

        return float(correct / len(prediction_indices))

    '''
      Fraction of the predictions whose DocID is equal to the DocID present in the batch.
    '''
    def run_accuracy(self, prediction_indices, batch):
        expected_docs = batch['Doc_ID']
        if len(expected_docs) == 0:
            return 0.0

        # Extract DocIDs from the output
        docs = self._decode_predictions(prediction_indices)

        correct = 0
        for doc_id_pred, doc_id_true in zip(docs, expected_docs):
            # Both sides are compared as strings, so that the stored type of the DocID does not matter
            if doc_id_pred is not None and str(doc_id_pred) == str(doc_id_true):
                correct += 1

        return correct / len(expected_docs)

    # Average of a list of values (NaN if the list is empty)
    @staticmethod
    def _mean(values):
        return sum(values) / len(values) if len(values) > 0 else float('nan')

    def on_train_epoch_end(self) -> None:
        epoch_loss = self._mean(self.training_loss)
        self.log("avg_training_loss", epoch_loss, batch_size=self.config['batch_size'])

        print(f"| Epoch {self.current_epoch} | TRAINING")

        # Update best loss if current loss is better
        if epoch_loss < self.best_training_loss:
            self.best_training_loss = epoch_loss

        has_indexing_losses = len(self.training_indexing_losses) != 0
        has_retrieval_losses = len(self.training_retrieval_losses) != 0

        if has_indexing_losses:
            loss_indexing = self._mean(self.training_indexing_losses)
            self.log("avg_training_indexing_loss", loss_indexing, batch_size=self.config['batch_size'])
            print(f"\t- Indexing Loss => {loss_indexing:.4f}")

        if has_retrieval_losses:
            loss_retrieval = self._mean(self.training_retrieval_losses)
            self.log("avg_training_retrieval_loss", loss_retrieval, batch_size=self.config['batch_size'])
            print(f"\t- Retrieval Loss => {loss_retrieval:.4f}")

        print(f"\t- Total Loss => {epoch_loss:.4f}")

        if self.enable_wandb and self.wandb_run is not None:
            self.wandb_run.log({"avg_training_loss": epoch_loss})
            if has_indexing_losses:
                self.wandb_run.log({"avg_training_indexing_loss": loss_indexing})
            if has_retrieval_losses:
                self.wandb_run.log({"avg_training_retrieval_loss": loss_retrieval})
            self.wandb_run.log({"epoch": self.current_epoch})

        # Restart the dataloader with the queries, so that every epoch begins from a fresh iterator
        if self.copy_training_dataloader is not None:
            self.training_dataloader_with_queries_filtered = iter(self.copy_training_dataloader)

        # Empty the list of the losses for the next epoch
        self.training_loss = []
        self.training_indexing_losses = []
        self.training_retrieval_losses = []

    def on_validation_epoch_end(self) -> None:
        epoch_loss = self._mean(self.validation_loss)
        self.log("avg_validation_loss", epoch_loss, batch_size=self.config['batch_size'])

        print(f"| Epoch {self.current_epoch} | VALIDATION")

        # Update best loss if current loss is better
        if epoch_loss < self.best_validation_loss:
            self.best_validation_loss = epoch_loss

        has_retrieval_losses = len(self.validation_retrieval_losses) != 0

        if has_retrieval_losses:
            loss_retrieval = self._mean(self.validation_retrieval_losses)
            self.log("avg_validation_retrieval_loss", loss_retrieval, batch_size=self.config['batch_size'])
            print(f"\t- Retrieval Loss => {loss_retrieval:.4f}")

        print(f"\t- Total Loss => {epoch_loss:.4f}")

        if self.enable_wandb and self.wandb_run is not None:
            self.wandb_run.log({"avg_validation_loss": epoch_loss})
            if has_retrieval_losses:
                self.wandb_run.log({"avg_validation_retrieval_loss": loss_retrieval})

        # Empty the list of the losses for the next epoch
        self.validation_loss = []
        self.validation_retrieval_losses = []

    def on_test_epoch_end(self) -> None:
        epoch_accuracy = self._mean(self.test_accuracy)
        self.log("avg_test_accuracy", epoch_accuracy, batch_size=self.config['batch_size'])

        print(f"| Epoch {self.current_epoch} | TEST")

        # Update best accuracy if current accuracy is better (higher is better, unlike the losses)
        if epoch_accuracy > self.best_test_accuracy:
            self.best_test_accuracy = epoch_accuracy

        print(f"\t- Total Accuracy => {epoch_accuracy:.4f}")

        # The WandB run may have been closed already (it is finished right after the training)
        wandb_run_is_open = self.wandb_run is not None and not getattr(self.wandb_run, "_is_finished", False)
        if self.enable_wandb and wandb_run_is_open:
            self.wandb_run.log({"avg_test_accuracy": epoch_accuracy})

        # Empty the list of the accuracies for the next epoch
        self.test_accuracy = []

    # Save the state_dict of the whole module in '<name>_ep-<epoch>.pt'
    def save_model_pytorch_api(self, name='best', epoch=''):
        torch.save(self.state_dict(), f'{name}_ep-{epoch}.pt')

    # Standard PyTorch AdamW over all the parameters of the module, with a
    # Linear Warmup (2 epochs) + Cosine Annealing learning rate schedule.
    #
    # The 8-bit optimizers of bitsandbytes (bnb.optim.Adam8bit / bnb.optim.AdamW8bit) are
    # possible replacements for the optimizer below.
    def configure_optimizers(self):
        # Configs
        epsilon = 1e-8  # Default value of AdamW
        lr = 3e-4

        optimizer = torch.optim.AdamW(self.parameters(), eps=epsilon, lr=lr)

        scheduler = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=2, warmup_start_lr=0.0, eta_min=0.0, max_epochs=self.trainer.max_epochs)

        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    # Number of trainable parameters: the ones reported by the Foundation Model
    # plus the ones of the discriminative head (BatchNorm and Linear Layer)
    def get_n_trainable_parameters(self):
        head_params = sum(
            param.numel()
            for module in (self.batch_norm, self.linear_layer)
            for param in module.parameters()
            if param.requires_grad
        )
        return self.model.get_n_trainable_parameters() + head_params

In [None]:
# ---- DISCRIMINATIVE TRAINING ---- #
# Discriminative training is meant to run on the Flan-T5 model only, so the model name is forced here
dsi_config['model_name'] = 'google/flan-t5-base'

model = FoundationModel(
    model_id=dsi_config['model_name'],
    encoder_id=None,
    decoder_id=None,
    MAX_LENGTH=dsi_config["MAX_LENGTH"],
)

# The discriminative head reads the hidden size of the model from the configuration
dsi_config["hidden_size"] = model.model.config.hidden_size

Using AutoModelForSeq2SeqLM
Encoder info:
trainable params: 1,327,104 || all params: 248,904,960 || trainable%: 0.5331770005708203


In [None]:
# Wrap the Foundation Model with the discriminative DSI Model
dsi_model = DSI_Discriminative(
    dsi_config,
    model=model,
    query_doc_df=query_doc_df,
    training_dataloader_with_queries_filtered=train_dataloader_filtered,
)

# Report the trainable-parameter count, using "." as thousands separator
total_params = f"{dsi_model.get_n_trainable_parameters():,}".replace(",", ".")
print("--------------------------------------------------------------")
print(f"Total number of trainable parameters for DSI Model: {total_params}")

--------------------------------------------------------------
Total number of trainable parameters for DSI Model: 1.327.104


### Show a simple example of training/inference
- To try a forward pass for training (it returns the loss), set `training = True`; set it to `False` for inference purposes (it returns the predictions)

In [None]:
'''
    Obtain the top 100 documents for each query ID in the training set.
'''
def obtain_top_100_df():
    '''
        Load the "top100.train.txt" file and keep only the rows that are useful for our training subset.

        Reads the notebook globals `usingColab`, `student`, `proj_dict`, `dsi_config`, `query_doc_df`
        and `train_docs_with_queries_dataset`.

        Returns:
            A dataframe with the columns ["Query_ID", "Q0", "Doc_ID", "Rank", "Score", "Relevance"] restricted to:
              - the queries associated to a document of the filtered training dataset
              - the documents whose ID starts with 'msmarco_doc_<train_dataset_chunk>_'
    '''
    # Location of the file. The Drive path is only built when it is really needed, so that a
    # local execution does not depend on the Colab-only entries of proj_dict.
    if not usingColab:
        docv2_train_top = "top100.train.txt"  # LOCAL VARIABLES!! Use only for local execution
    elif student == 'Professor':
        docv2_train_top = "/content/top100.train.txt"
    else:
        docv2_train_top = pathjoin(F"/content/drive/MyDrive/{proj_dict[student][0]}", "top100.train.txt")

    # Here we merge our filtered dataset with the query_doc_df to obtain the dataset with the query IDs
    # columns = ['Doc_ID', 'Doc Repres. Strategy', 'Doc ID Repres. Strategy', 'Query_ID']
    filtered_docs = train_docs_with_queries_dataset.filtered_docs_with_queries
    dataset_with_queryIDs = pd.merge(filtered_docs, query_doc_df[['Query_ID', 'Doc_ID']], on='Doc_ID')

    # Read the file: one row per (query, document) pair, space separated, without header.
    # File layout: qid, "Q0", docid, rank, score, runstring (the last column is loaded under the name "Relevance")
    df_docv2_train_top100 = pd.read_csv(docv2_train_top, header=None, sep=" ", names=["Query_ID", "Q0", "Doc_ID", "Rank", "Score", "Relevance"])

    # Keep only the queries that we have in our dataset_with_queryIDs
    known_queries = df_docv2_train_top100['Query_ID'].isin(dataset_with_queryIDs['Query_ID'])
    df_docv2_train_top100_filtered = df_docv2_train_top100[known_queries]

    # Here we might have two options:
    #   - Keep all the 100 documents for each query ID
    #   - Keep only the documents of the chunk we trained on
    # This is the second option
    chunk_prefix = f'msmarco_doc_{dsi_config["train_dataset_chunk"]}_'
    in_train_chunk = df_docv2_train_top100_filtered['Doc_ID'].str.startswith(chunk_prefix)
    return df_docv2_train_top100_filtered[in_train_chunk]

In [None]:
# Initialize the Dataloaders such that in the cell below we can have different samples each time
sample_doc_dataloader = iter(train_dataloader)
sample_doc_query_dataloader = iter(val_dataloader)

dsi_model = dsi_model.to(device) # Move the model to the GPU
training = True

In [None]:
# 'retrieval' => a validation batch with its queries | anything else => an indexing batch of documents
task = 'retrieval'

if(task != 'retrieval'):
  dsi_model.current_task = "indexing"
  batch = next(sample_doc_dataloader)
else:
  dsi_model.current_task = "retrieval"
  batch = dsi_model.prepare_sample_for_retrieval(next(sample_doc_query_dataloader))

# Loss if training is True, predictions otherwise
out = dsi_model.forward(batch, training=training)

print(out)

# The predictions can be scored only for the discriminative model, and only in inference mode
if(training == False and dsi_model.config['training_type'] == 'discriminative'):
    dsi_model.test_type = 'top_k' # 'accuracy' / 'top_k'

    # Predicted class of each sample.
    # An error here most probably means that dsi_config['training_type'] is set to 'generative'
    # while this test phase requires 'discriminative'
    prediction_indices = torch.argmax(out, dim=1)

    if(dsi_model.test_type != "accuracy"):
        # The top 100 dataframe is loaded just once, the first time it is needed
        if(dsi_model.top_100_df is None) and (dsi_model.test_type == "top_k"):
            dsi_model.top_100_df = obtain_top_100_df()
        accuracy = dsi_model.run_top_k(prediction_indices, batch)
    else:
        accuracy = dsi_model.run_accuracy(prediction_indices, batch)

    print(accuracy)

tensor(0.1489, device='cuda:0',
       grad_fn=<BinaryCrossEntropyWithLogitsBackward0>)


### Actual Training

In [None]:
# Folder in which the checkpoints are saved
if(not usingColab):
    parent_path = "saved_models/"
elif(dsi_config['student'] == "Professor"):
    parent_path = "/content/saved_models/"
else:
    parent_path = f"/content/drive/MyDrive/{proj_dict[student][0]}/saved_models/"

# Name of the checkpoint folder: <doc strategy>-<doc ID strategy>-<n. of training rows>_rows-checkpoint
model_name = f"{dsi_config['document_representation_strategy']}-{dsi_config['document_id_representation_strategy']}-{len(train_dataset)}_rows-checkpoint"

# The WandB run name is appended when WANDB is enabled, to tell the runs apart
if(dsi_config['wandb_configs']['enable']):
    model_name += f"-runID_{dsi_model.wandb_run.name}"

# Checkpoint Callback: keeps only the checkpoint with the lowest average validation loss
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=parent_path + model_name,
    monitor="avg_validation_loss",
    mode="min",
    save_top_k=1,
)

# Early Stopping Callback (defined, but not passed to the Trainer below)
early_stop_callback = pl.callbacks.EarlyStopping(
    monitor="avg_validation_loss",
    min_delta=0.0,
    patience=5,
    verbose=True,
    mode="min",
)

# Train parameters
train_parameters = {
    "accelerator": 'auto',
    "max_epochs": 50,
    "callbacks": [checkpoint_callback], # [checkpoint_callback, early_stop_callback]
    "fast_dev_run": False,
    "gradient_clip_val": 0.5,
}

trainer = pl.Trainer(**train_parameters)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [None]:
# Fit the model! :)
try:
    trainer.fit(
        model=dsi_model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader
    )
finally:
    # Close the WandB run if it's enabled.
    # This is done also when the training raises an error, so that the run is not left open.
    if(dsi_model.enable_wandb and dsi_model.wandb_run is not None):
        dsi_model.wandb_run.finish()

## Test

In [None]:
def obtain_top_100_df():
    """
    Obtain the top 100 documents for each query ID in the training set.

    Reads the space-separated ``top100.train.txt`` file and keeps only the rows:
      - whose ``Query_ID`` belongs to a query attached to a document of
        ``train_docs_with_queries_dataset.filtered_docs_with_queries``
        (matched through ``query_doc_df``), and
      - whose ``Doc_ID`` starts with ``msmarco_doc_<train_dataset_chunk>_``.

    Relies on the notebook globals ``usingColab``, ``student``, ``proj_dict``,
    ``dsi_config``, ``query_doc_df`` and ``train_docs_with_queries_dataset``.

    Returns:
        pandas.DataFrame with the columns
        ``Query_ID, Q0, Doc_ID, Rank, Score, Relevance``.
    """
    # Locate the file. The local path is checked first so that a local run never
    # touches the Colab/Drive-specific globals (`student`, `proj_dict`, `pathjoin`).
    if not usingColab:
        docv2_train_top = "top100.train.txt"
    elif student == 'Professor':
        docv2_train_top = "/content/top100.train.txt"
    else:
        docv2_train_top = pathjoin(F"/content/drive/MyDrive/{proj_dict[student][0]}", "top100.train.txt")

    # Merge the filtered dataset with query_doc_df to attach the query IDs to our documents.
    dataset_with_queryIDs = pd.merge(
        train_docs_with_queries_dataset.filtered_docs_with_queries,
        query_doc_df[['Query_ID', 'Doc_ID']],
        on='Doc_ID',
    )

    # File layout: qid, "Q0", docid, rank, score, runstring.
    # (The last column is named "Relevance" here, which callers may rely on.)
    df_docv2_train_top100 = pd.read_csv(
        docv2_train_top,
        header=None,
        sep=" ",
        names=["Query_ID", "Q0", "Doc_ID", "Rank", "Score", "Relevance"],
    )

    # Keep the top-100 rows only for the queries present in our dataset.
    known_query = df_docv2_train_top100['Query_ID'].isin(dataset_with_queryIDs['Query_ID'])
    df_docv2_train_top100_filtered = df_docv2_train_top100[known_query]

    # Two options were possible here:
    #   - keep all the 100 documents of each query ID, or
    #   - keep only the documents coming from the configured training chunk.
    # The second option is implemented.
    chunk_prefix = f'msmarco_doc_{dsi_config["train_dataset_chunk"]}_'
    in_chunk = df_docv2_train_top100_filtered['Doc_ID'].str.startswith(chunk_prefix)
    return df_docv2_train_top100_filtered[in_chunk]

In [None]:
# Evaluation mode used by the test loop: "accuracy" or "top_k"
dsi_model.test_type = "accuracy"
# Reset the counter tracking at least how many Documents the Model can retrieve
dsi_model.model_found_label = 0

# The top-100 table is only required by the "top_k" mode and is built once
if dsi_model.top_100_df is None:
    if dsi_model.test_type == "top_k":
        print("Obtaining Top 100 Documents for each Query ID in the Training Set...")
        dsi_model.top_100_df = obtain_top_100_df()

trainer.test(dsi_model, dataloaders=test_dataloader)

# Upper bound on the retrievable documents: batches * batch size * returned sequences.
# NOTE(review): this assumes every batch is full; a smaller last batch would make the
# percentage slightly pessimistic.
max_retrievable = (len(test_dataloader) * dsi_config['batch_size']) * dsi_model.num_return_sequences
print(f"In total, the Model was able to find {dsi_model.model_found_label / max_retrievable * 100:.2f}% of the Documents, starting from its predictions.")

## Restore a Checkpoint and run a Test
This section restores a saved checkpoint into a freshly built DSI model; in this case the underlying Foundation Model is BERT (`bert-base-uncased`).

In [None]:
# Choose the Foundation Model to rebuild before restoring the checkpoint:
#
#   'bert-base-uncased'   -> BERT encoder/decoder pair
#   'google/flan-t5-base' -> Flan-T5
#
# Every dependent config entry is derived from this choice.
dsi_config['model_name'] = 'bert-base-uncased' # 'google/flan-t5-base' | 'bert-base-uncased'
dsi_config['fine_tuning'] = False

# ------------------------------------------------- #

if dsi_config['model_name'] == 'bert-base-uncased':
    # Bert-Uncased Model: the same checkpoint is used as encoder and decoder
    dsi_config['encoder_model'] = 'bert-base-uncased'
    dsi_config['decoder_model'] = 'bert-base-uncased'
    model = FoundationModel(
        model_id=None,
        encoder_id=dsi_config['encoder_model'],
        decoder_id=dsi_config['decoder_model'],
        MAX_LENGTH=dsi_config["MAX_LENGTH"],
        fine_tuning=dsi_config['fine_tuning'],
    )
    dsi_config["hidden_size"] = model.model.encoder.config.hidden_size
    # In this configuration the 'model_name' key is dropped from the config
    dsi_config.pop("model_name")
else:
    # Flan-T5 Model: a single model id, no separate encoder/decoder ids
    model = FoundationModel(
        model_id=dsi_config['model_name'],
        encoder_id=None,
        decoder_id=None,
        MAX_LENGTH=dsi_config["MAX_LENGTH"],
        fine_tuning=dsi_config['fine_tuning'],
    )
    dsi_config["hidden_size"] = model.model.config.hidden_size

# WandB logging is switched off for the restored model
dsi_config['wandb_configs']['enable'] = False

# Build the actual DSI Model around the Foundation Model
checkpoint_model = DSI_Model(
    dsi_config,
    model=model,
    query_doc_df=query_doc_df,
    training_dataloader_with_queries_filtered=train_dataloader_filtered,
)

# Move the model onto `device` (original note: CUDA is expected by QLoRa)
checkpoint_model = checkpoint_model.to(device)

Some weights of BertLMHeadModel were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.0.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.0.crossattention.output.dense.bias', 'bert.encoder.layer.0.crossattention.output.dense.weight', 'bert.encoder.layer.0.crossattention.self.key.bias', 'bert.encoder.layer.0.crossattention.self.key.weight', 'bert.encoder.layer.0.crossattention.self.query.bias', 'bert.encoder.layer.0.crossattention.self.query.weight', 'bert.encoder.layer.0.crossattention.self.value.bias', 'bert.encoder.layer.0.crossattention.self.value.weight', 'bert.encoder.layer.1.crossattention.output.LayerNorm.bias', 'bert.encoder.layer.1.crossattention.output.LayerNorm.weight', 'bert.encoder.layer.1.crossattention.output.dense.bias', 'bert.encoder.layer.1.crossattention.output.dense.weight', 'bert.encoder.layer.1.crossattention.self.key.bias', 'bert.e

Using EncoderDecoderModel, hence bert-base-uncased.
Encoder info:
trainable params: 1,327,104 || all params: 139,798,842 || trainable%: 0.9492954169105349


In [None]:
# Folder (inside `saved_models/`) and file name of the checkpoint to restore
model_dir = 'summarization-unstructured_atomic-10000_rows-checkpoint-runID_beautiful-caress-72'
checkpoint_filename = 'epoch=6-step=1064.ckpt'

# Resolve where the saved models live, based on the environment and on the student
if(usingColab):
    if(dsi_config['student'] == "Professor"):
        parent_path = "/content/saved_models/"
        DSI_URL = 'https://drive.google.com/uc?id=' + '1ehIC2ZsFalCwP1GfhE3vZa09rUBK_AO3' + '&export=download&confirm=t'
        # exist_ok avoids the separate (racy) existence check
        os.makedirs(pathjoin(parent_path, model_dir), exist_ok=True)
        # Download the checkpoint only once: re-running the cell reuses the local copy
        if(not os.path.exists(pathjoin(parent_path, model_dir, checkpoint_filename))):
            gdown.download(DSI_URL, output=pathjoin(parent_path, model_dir, checkpoint_filename), quiet=False)
    else:
        parent_path = f"/content/drive/MyDrive/{proj_dict[student][0]}/saved_models/"
else:
    parent_path = "saved_models/"

# Retrieve the checkpoint path
model_path = pathjoin(parent_path, model_dir, checkpoint_filename)

# Load the state dict from the PyTorch Lightning checkpoint.
# - A Lightning `.ckpt` stores the weights under the 'state_dict' key.
# - map_location="cpu" lets a checkpoint saved on GPU be read on any machine;
#   load_state_dict then copies the values into the parameters of `checkpoint_model`,
#   wherever they currently live.
# - SECURITY: torch.load unpickles the file, so only load checkpoints from a trusted source.
print(f"Loading checkpoint model from {model_path}...")
checkpoint = torch.load(model_path, map_location="cpu")
checkpoint_model.load_state_dict(checkpoint['state_dict'])

Loading checkpoint model from saved_models\summarization-unstructured_atomic-10000_rows-checkpoint-runID_beautiful-caress-72\epoch=6-step=1064.ckpt...


<All keys matched successfully>

Before running the test, make sure that `test_dataloader` is populated.<br>
Also make sure that the `obtain_top_100_df()` function has been defined in the current session, i.e. that its cell in the "Test" section has been run.

In [None]:
# Evaluation mode used by the test loop: "accuracy" or "top_k"
checkpoint_model.test_type = "accuracy"
# Reset the counter tracking at least how many Documents the Model can retrieve
checkpoint_model.model_found_label = 0

# The top-100 table is only required by the "top_k" mode and is built once
if checkpoint_model.top_100_df is None:
    if checkpoint_model.test_type == "top_k":
        print("Obtaining Top 100 Documents for each Query ID in the Training Set...")
        checkpoint_model.top_100_df = obtain_top_100_df()

# A fresh Trainer is enough for testing (no callbacks, no epoch limit needed)
trainer = pl.Trainer(accelerator='auto', fast_dev_run=False)
trainer.test(checkpoint_model, dataloaders=test_dataloader)

# Upper bound on the retrievable documents: batches * batch size * returned sequences.
# NOTE(review): this assumes every batch is full; a smaller last batch would make the
# percentage slightly pessimistic.
max_retrievable = (len(test_dataloader) * dsi_config['batch_size']) * checkpoint_model.num_return_sequences
print(f"In total, the Model was able to find {checkpoint_model.model_found_label / max_retrievable * 100:.2f}% of the Documents, starting from its predictions.")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

| Epoch 0 | TEST
	- Total Accuracy => 0.0000


In total, the Model was able to find 15.79% of the Documents, starting from its predictions.
