# In [None]:
import os
import random
import time
import pickle
import logger
import numpy as np
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_transformers.tokenization_bert import BertTokenizer
from DataModule import process_mention_dataset, process_ontology
from datetime import datetime


from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler) 
from pytorch_transformers.optimization import WarmupLinearSchedule 
from scipy.sparse.csgraph import minimum_spanning_tree 
# csgraph = compressed sparse graph
from scipy.sparse import csr_matrix
# csr_matrix = compressed sparse row matrices
from collections import Counter 

import blink.biencoder.data_process_mult as data_process
import blink.biencoder.eval_cluster_linking as eval_cluster_linking
import blink.candidate_ranking.utils as utils
from blink.biencoder.biencoder import BiEncoderRanker
from blink.common.optimizer import get_bert_optimizer
from blink.common.params import BlinkParser

from IPython import embed 

def evaluate(
    reranker,
    valid_dict_vecs,
    valid_men_vecs,
    logger,
    n_gpu,
    entity_data,
    query_data,
    silent=False,
    use_types=False,
    embed_batch_size=768,
    force_exact_search=False,
    probe_mult_factor=1,
    within_doc=False,
    context_doc_ids=None,
):
    '''
    Description
    -----------
    1) Embeds entities and mentions and builds search indexes over them.
    2) Runs nearest-neighbour searches: for every mention, its single closest
       entity and its closest mentions.
    3) Builds one joint entity/mention graph per k in (0, 1, 2, 4, 8), where k
       is the number of mention-mention edges kept per mention.
    4) Partitions every graph, analyses the resulting clusters and reports the
       accuracy obtained for each k.

    Parameters
    ----------
    reranker : BiEncoderRanker
        Model handed to ``data_process.embed_and_index`` to compute embeddings.
    valid_dict_vecs : sequence
        Entity (dictionary) vectors to embed; ``len()`` gives the entity count.
    valid_men_vecs : sequence
        Mention vectors to embed; ``len()`` gives the mention count.
    logger : logging.Logger-like
        Object exposing ``info``; used for progress messages.
    n_gpu : int
        Forwarded to ``embed_and_index``.
    entity_data : list or dict
        Entity records; used as the ``corpus`` for typed indexes and passed to
        ``analyzeClusters``.
    query_data : list or dict
        Mention records; used as the ``corpus`` for typed indexes and passed to
        ``analyzeClusters``.
    silent : bool
        Kept for interface compatibility; not read by this function.
    use_types : bool
        When True, one index per entity type is built and every mention is only
        searched against the indexes of its own type.
    embed_batch_size : int
        Batch size forwarded to ``embed_and_index``.
    force_exact_search : bool
        Forwarded to ``embed_and_index``.
    probe_mult_factor : int
        Forwarded to ``embed_and_index``.
    within_doc : bool
        When True, mention-mention candidates are restricted to mentions whose
        document id equals the query mention's document id.
    context_doc_ids : sequence or None
        Document id of every mention, indexed by mention index. Required when
        ``within_doc`` is True.

    Returns
    -------
    max_eval_acc : float
        Highest accuracy over all evaluated k.
    dict_acc : dict
        Accuracy per graph, keyed ``'k0'``, ``'k1'``, ``'k2'``, ``'k4'``, ``'k8'``.
    dict_artifacts : dict
        Entity embeddings and index(es): keys ``dict_embeds``, ``dict_indexes``
        and ``dict_idxs_by_type`` when ``use_types`` is True, otherwise
        ``dict_embeds`` and ``dict_index``.
    '''
    torch.cuda.empty_cache()  # Free cached GPU memory before embedding

    n_entities = len(valid_dict_vecs)
    n_mentions = len(valid_men_vecs)
    n_nodes = n_entities + n_mentions
    max_knn = 8  # Maximum number of mention neighbours kept per mention
    graph_ks = (0, 1, 2, 4, 8)  # One joint graph is built per value

    # Arguments shared by every embed_and_index call below
    index_kwargs = dict(
        n_gpu=n_gpu,
        force_exact_search=force_exact_search,
        batch_size=embed_batch_size,
        probe_mult_factor=probe_mult_factor,
    )

    # 1) Compute embeddings and indexes for entities and mentions.
    if use_types:
        # With a corpus, embed_and_index returns per-type search indexes plus,
        # for every type, the global positions of the items of that type.
        # (idxs = positions / indexes = search indexes)
        logger.info("Eval: Dictionary: Embedding and building index")
        dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            encoder_type="candidate",
            corpus=entity_data,
            **index_kwargs,
        )
        logger.info("Eval: Queries: Embedding and building index")
        men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_men_vecs,
            encoder_type="context",
            corpus=query_data,
            **index_kwargs,
        )
    else:
        # Without a corpus a single index covers all embeddings.
        logger.info("Eval: Dictionary: Embedding and building index")
        dict_embeds, dict_index = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            "candidate",
            **index_kwargs,
        )
        logger.info("Eval: Queries: Embedding and building index")
        men_embeds, men_index = data_process.embed_and_index(
            reranker,
            valid_men_vecs,
            "context",
            **index_kwargs,
        )

    # 2) k-NN search between mentions and entities / mentions and mentions.
    logger.info("Eval: Starting KNN search...")
    # One extra mention neighbour is fetched because the query mention itself
    # is expected among its own neighbours and is filtered out later. For
    # within_doc every mention is fetched so the document filter can be
    # applied afterwards.
    n_men_to_fetch = len(men_embeds) if within_doc else max_knn + 1
    if not use_types:
        # A single index per side: one search each.
        nn_ent_dists, nn_ent_idxs = dict_index.search(men_embeds, 1)
        nn_men_dists, nn_men_idxs = men_index.search(men_embeds, n_men_to_fetch)
    else:
        # Per-type search; results are scattered back into global arrays that
        # are pre-filled with -1 ("no neighbour").
        n_queries = len(men_embeds)
        nn_ent_idxs = np.full((n_queries, 1), -1, dtype=int)
        nn_ent_dists = np.full((n_queries, 1), -1, dtype="float64")
        nn_men_idxs = np.full((n_queries, n_men_to_fetch), -1, dtype=int)
        nn_men_dists = np.full((n_queries, n_men_to_fetch), -1, dtype="float64")
        for entity_type in men_indexes:
            type_men_idxs = men_idxs_by_type[entity_type]
            type_dict_idxs = dict_idxs_by_type[entity_type]
            # Only the mentions of this type are searched.
            men_embeds_by_type = men_embeds[type_men_idxs]

            # Closest entity of the same type; the search returns positions
            # local to the typed index, so map them back to global positions.
            nn_ent_dists_by_type, local_ent_idxs = dict_indexes[entity_type].search(
                men_embeds_by_type, 1
            )
            nn_ent_idxs_by_type = np.array(
                [type_dict_idxs[row] for row in local_ent_idxs]
            )

            # Closest mentions of the same type. A type may hold fewer
            # mentions than requested, hence the min().
            nn_men_dists_by_type, local_men_idxs = men_indexes[entity_type].search(
                men_embeds_by_type, min(n_men_to_fetch, len(men_embeds_by_type))
            )
            nn_men_idxs_by_type = np.array(
                [type_men_idxs[row] for row in local_men_idxs]
            )

            for i, idx in enumerate(type_men_idxs):
                nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                nn_men_idxs[idx][: len(nn_men_idxs_by_type[i])] = nn_men_idxs_by_type[i]
                nn_men_dists[idx][: len(nn_men_dists_by_type[i])] = nn_men_dists_by_type[i]
    logger.info("Eval: Search finished")

    # 3) Build the joint graphs.
    # Node ids: entities occupy [0, n_entities), mentions are offset by
    # n_entities. "rows" are edge start nodes (always a mention), "cols" are
    # edge end nodes and "data" are the edge weights (search scores).
    # Edges are accumulated in Python lists and converted once at the end:
    # np.append copies the whole array on every call, which made the original
    # per-edge appends quadratic in the number of mentions.
    logger.info("Eval: Building graphs")
    edge_lists = {k: {"rows": [], "cols": [], "data": []} for k in graph_ks}
    # NOTE(review): `tqdm` is not imported in the visible part of this file -
    # confirm it is brought into scope elsewhere.
    for men_query_idx, _ in enumerate(
        tqdm(men_embeds, total=len(men_embeds), desc="Eval: Building graphs")
    ):
        # Nearest entity candidate ([0] extracts the scalar from the 1-wide row)
        dict_cand_idx = nn_ent_idxs[men_query_idx][0]
        dict_cand_score = nn_ent_dists[men_query_idx][0]

        # Drop the -1 placeholders left by the typed search
        valid_mask = nn_men_idxs[men_query_idx] != -1
        men_cand_idxs = nn_men_idxs[men_query_idx][valid_mask]
        men_cand_scores = nn_men_dists[men_query_idx][valid_mask]

        if within_doc:
            # Keep only candidates from the query mention's document
            men_cand_idxs, wd_mask = filter_by_context_doc_id(
                men_cand_idxs,
                context_doc_ids[men_query_idx],
                context_doc_ids,
                return_numpy=True,
            )
            men_cand_scores = men_cand_scores[wd_mask]

        # Remove the self-reference and cap the candidates at max_knn
        not_self = men_cand_idxs != men_query_idx
        men_cand_idxs = men_cand_idxs[not_self][:max_knn]
        men_cand_scores = men_cand_scores[not_self][:max_knn]

        men_node = n_entities + men_query_idx
        for k, edges in edge_lists.items():
            # Mention-entity edge (present in every graph)
            edges["rows"].append(men_node)
            edges["cols"].append(dict_cand_idx)
            edges["data"].append(dict_cand_score)
            if k > 0:
                # Mention-mention edges to the k best candidates
                top_idxs = men_cand_idxs[:k]
                edges["rows"].extend([men_node] * len(top_idxs))
                edges["cols"].extend(n_entities + top_idxs)
                edges["data"].extend(men_cand_scores[:k])

    # float64 matches what repeated np.append onto np.array([]) produced.
    joint_graphs = {
        k: {
            "rows": np.asarray(edges["rows"], dtype=np.float64),
            "cols": np.asarray(edges["cols"], dtype=np.float64),
            "data": np.asarray(edges["data"], dtype=np.float64),
            "shape": (n_nodes, n_nodes),
        }
        for k, edges in edge_lists.items()
    }

    # 4) Evaluate linking accuracy for every graph.
    dict_acc = {}
    max_eval_acc = -1.0
    for k, joint_graph in joint_graphs.items():
        logger.info(f"\nEval: Graph (k={k}):")
        # Partition graph based on cluster-linking constraints (inference procedure)
        _, clusters = eval_cluster_linking.partition_graph(
            joint_graph, n_entities, directed=True, return_clusters=True
        )
        # Infer predictions from clusters
        result = eval_cluster_linking.analyzeClusters(clusters, entity_data, query_data, k)
        # result['accuracy'] is a string whose first space-separated token is
        # the numeric accuracy; it must be parsed before any comparison.
        acc = float(result['accuracy'].split(' ')[0])
        dict_acc[f'k{k}'] = acc
        max_eval_acc = max(acc, max_eval_acc)
        logger.info(f"Eval: accuracy for graph@k={k}: {acc}%")
    logger.info(f"Eval: Best accuracy: {max_eval_acc}%")

    if use_types:
        dict_artifacts = {
            'dict_embeds': dict_embeds,
            'dict_indexes': dict_indexes,
            'dict_idxs_by_type': dict_idxs_by_type,
        }
    else:
        dict_artifacts = {'dict_embeds': dict_embeds, 'dict_index': dict_index}
    return max_eval_acc, dict_acc, dict_artifacts

def read_data(split, params, logger):
    '''
    Description
    -----------
    Loads the samples of one dataset split, optionally drops the samples that
    carry no gold label, and reports whether the dataset uses multiple labels
    per sample.

    Parameters
    ----------
    split : str
        Name of the split to load (e.g. "train", "valid", "test"); forwarded to
        ``utils.read_dataset`` and used in the log message.
    params : dict
        Configuration; reads ``params["data_path"]`` and
        ``params["filter_unlabeled"]``.
    logger : logging.Logger-like
        Object exposing ``info``; used to report the number of samples read.

    Returns
    -------
    samples : list
        The (possibly filtered) samples.
    has_mult_labels : bool
        True when the first sample has a "labels" key. False for an empty split.
    '''
    samples = utils.read_dataset(split, params["data_path"])
    # The multi-label format is detected from the first sample; an empty split
    # has no first sample, so treat it as single-label instead of crashing.
    has_mult_labels = len(samples) > 0 and "labels" in samples[0]
    if params["filter_unlabeled"]:
        # Filter samples without gold entities
        if has_mult_labels:
            samples = [sample for sample in samples if len(sample["labels"]) > 0]
        else:
            samples = [sample for sample in samples if sample["label"] is not None]
    # Report the split actually read (the message used to always say "train").
    logger.info("Read %d %s samples.", len(samples), split)
    return samples, has_mult_labels

# Utility function
def filter_by_context_doc_id(mention_idxs, doc_id, doc_id_list, return_numpy=False):
    '''
    Description
    -----------
    Keeps only the mention indices whose document id equals "doc_id", so that
    later processing stays within a single document.

    Parameters
    ----------
    - mention_idxs : list(int) or ndarray(int)
    Indices of the candidate mentions.
    - doc_id :
    Document id to keep; compared with "==" against entries of doc_id_list.
    - doc_id_list : sequence
    Document id of every mention; looked up as doc_id_list[i] for each i in
    mention_idxs.
    - return_numpy : bool
    If True the kept indices are returned as a NumPy array, otherwise as a list.
    -------
    Outputs:
    - mention_idxs : ndarray or list
    The entries of the input whose document id equals doc_id, in input order.
    - mask : list(bool)
    One entry per input mention index, True where the document id matched.
    '''
    mask = []
    for idx in mention_idxs:
        mask.append(doc_id_list[idx] == doc_id)

    # Boolean-mask indexing needs an array; plain lists are converted first.
    candidates = np.array(mention_idxs) if isinstance(mention_idxs, list) else mention_idxs
    kept = candidates[mask]

    if return_numpy:
        return kept, mask
    return list(kept), mask



def loss_function(reranker,
    params,
    forward_output,
    data_module,
    n_entities,
    knn_dict,
    batch_context_inputs,
    accumulate_grad_batches
    ):
    '''
    Compute the loss function during the training.

    The batch is split in two groups: mentions for which both negative
    entities and negative mentions are available ("dual" negatives) and
    "skipped" mentions for which only negative entities are used. One loss is
    computed per group and the two are averaged, weighted by group size.

    Parameters
    ----------
    - reranker : BiEncoderRanker
    NN-based ranking model. Called as
    reranker(context_inputs, label_input=..., mst_data=..., pos_neg_loss=...)
    and expected to return a (loss, _) pair.
    - params : dict
    Reads params["pos_neg_loss"] and params["within_doc_skip_strategy"].
    - forward_output : dict
    Output of the forward() method. Keys read here: 'context_inputs',
    'label_inputs', 'positive_embeds', 'negative_dict_inputs',
    'negative_men_inputs', 'skipped' and, only when mentions were skipped,
    'skipped_negative_dict_inputs', 'skipped_positive_idxs' and
    'context_inputs_mask'.
    - data_module : Instance of ArboelDataModule class
    Provides entity_dict_vecs and train_men_vecs.
    - n_entities : int
    Total number of entities. Positive indices below this value refer to
    entities, the others to training mentions (offset by n_entities).
    - knn_dict : int (self.knn_dict = self.hparams["knn"]//2)
    number of negative entities to fetch. It divides the k-nn evenly between entities and mentions
    - batch_context_inputs : tensor
    Context inputs of the whole batch; the skipped mentions are selected from
    it with the inverse of forward_output['context_inputs_mask'].
    - accumulate_grad_batches : int
    Number of steps to accumulate gradients; the loss is divided by it.

    Returns
    -------
    The size-weighted average of both losses divided by accumulate_grad_batches.
    '''
    context_inputs = forward_output['context_inputs']

    # Loss over the mentions that have both negative entities and negative
    # mentions (alongside the positive examples).
    loss_dual_negs, _ = reranker(context_inputs, label_input=forward_output['label_inputs'], mst_data={
        'positive_embeds': forward_output['positive_embeds'],
        'negative_dict_inputs': forward_output['negative_dict_inputs'],
        'negative_men_inputs': forward_output['negative_men_inputs']
    }, pos_neg_loss=params["pos_neg_loss"])

    # Defaults used when no mention was skipped: zero loss with zero weight.
    loss_ent_negs = 0
    skipped_context_inputs = []
    if forward_output['skipped'] > 0 and not params["within_doc_skip_strategy"]:
        # The original read `skipped_negative_dict_inputs` before it was ever
        # assigned (UnboundLocalError).
        # NOTE(review): assumes forward() exposes the skipped mentions' negative
        # entity indices under 'skipped_negative_dict_inputs', mirroring
        # 'skipped_positive_idxs' - confirm against forward().
        skipped_negative_dict_inputs = torch.tensor(np.array(
            [data_module.entity_dict_vecs[idx].numpy()
             for idx in forward_output['skipped_negative_dict_inputs']]))

        # Encode the positive of every skipped mention: an entity when the
        # index is below n_entities, otherwise a training mention.
        skipped_positive_embeds = []
        for pos_idx in forward_output['skipped_positive_idxs']:
            if pos_idx < n_entities:
                pos_embed = reranker.encode_candidate(data_module.entity_dict_vecs[pos_idx:pos_idx + 1],
                                                        requires_grad=True)
            else:
                men_idx = pos_idx - n_entities
                pos_embed = reranker.encode_context(
                    data_module.train_men_vecs[men_idx:men_idx + 1], requires_grad=True)
            skipped_positive_embeds.append(pos_embed)
        skipped_positive_embeds = torch.cat(skipped_positive_embeds)

        # The mask marks the mentions kept for the dual-negative loss; its
        # inverse selects the skipped ones.
        skipped_context_inputs = batch_context_inputs[~np.array(forward_output['context_inputs_mask'])]
        # One positive (label 1) followed by knn_dict negatives (label 0) per mention
        skipped_label_inputs = torch.tensor([[1] + [0] * (knn_dict)] * len(skipped_context_inputs),
                                    dtype=torch.float32)
        # Loss of the part of the batch that only includes negative entity inputs.
        loss_ent_negs, _ = reranker(skipped_context_inputs, label_input=skipped_label_inputs, mst_data={
            'positive_embeds': skipped_positive_embeds,
            'negative_dict_inputs': skipped_negative_dict_inputs,
            'negative_men_inputs': None
        }, pos_neg_loss=params["pos_neg_loss"])

    # n_dual: mentions in the batch that found negative entities and mentions.
    # n_skipped: mentions in the batch that only use negative entities.
    n_dual = len(context_inputs)
    n_skipped = len(skipped_context_inputs)
    loss = ((loss_dual_negs * n_dual + loss_ent_negs * n_skipped) / (n_dual + n_skipped)) / accumulate_grad_batches
    return loss




"Data module"
class ArboelDataModule(L.LightningDataModule):
    '''
    Attributes
    ----------
    
    - entity_dictionary : list of dict
    Stores the initial and raw entity dictionary
    - train_tensor_data : TensorDataset(context_vecs, label_idxs, n_labels, mention_idx) with :
        - “context_vecs” : tensor containing IDs of (mention + surrounding context) tokens 
        - “label_idxs” : tensor with indices pointing to the entities in the entity dictionary that are considered correct labels for the mention.
        - “n_labels” : Number of labels (=entities) associated with the mention
        - “mention_idx” : tensor containing a sequence of integers from 0 to N-1 (N = number of mentions in the dataset) serving as a unique identifier for each mention.
    - train_processed_data : list of dict
    Contains information about mentions (mention_id, mention_name, context, etc…)
    - valid_tensor_data : TensorDataset
    Same as "train_tensor_dataset" but for validation set
    - max_gold_cluster_len : int
    Maximum length of clusters inside gold_cluster
    - train_context_doc_ids : list
    # Store the context_doc_id (=context document indice) for every mention in the train set
    '''
    def __init__(self, params, dataset, ontology):
        '''
        Store the configuration, the dataset/ontology handles and the tokenizer,
        and declare the attributes that prepare_data() / setup() fill in later.

        Parameters
        ----------
        - params : dict
        Configuration; must contain 'batch_size'. Stored as self.hparams.params.
        - dataset :
        Passed to process_ontology / process_mention_dataset in prepare_data().
        - ontology :
        Passed to process_ontology / process_mention_dataset in prepare_data().
        '''
        super().__init__()
        # Save the dict under the single key "params". Passing the dict object
        # itself (save_hyperparameters(params)) would flatten its keys into
        # self.hparams, and every `self.hparams.params[...]` access used by
        # this class would then raise AttributeError.
        self.save_hyperparameters("params")

        self.dataset = dataset
        self.ontology = ontology

        # NOTE(review): the tokenizer is hard-coded. A previous, disabled
        # variant loaded it from params["bert_model"] (local vocab.txt when
        # present) and honoured params["lowercase"] - confirm which is intended.
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        self.batch_size = self.hparams.params['batch_size']

        # Filled in by prepare_data() (cached pickles) and/or setup()
        self.entity_dictionary = None
        self.entity_dictionary_loaded = False
        self.train_processed_data = None
        self.valid_processed_data = None
        self.test_processed_data = None
        self.train_tensor_data = None
        self.valid_tensor_data = None
        self.test_tensor_data = None
        self.entity_dict_vecs = None


    def prepare_data(self):
        '''
        Generate the raw data files, then load whatever has already been
        processed and cached as pickle files.

        Steps:
        1) process_ontology / process_mention_dataset prepare the entity
           dictionary (dictionary.pickle) and the mention files (train.jsonl,
           valid.jsonl, test.jsonl).
        2) The entity dictionary is loaded from the processed cache in
           params["pickle_src_path"] when present, otherwise the unprocessed
           dictionary.pickle in params["data_path"] is loaded.
           self.entity_dictionary_loaded records which of the two happened.
        3) For each of the train / valid / test splits, the cache paths are
           stored on the instance and, when both cache files exist, the tensor
           data and the processed data are loaded.

        NOTE(review): Lightning's documentation advises against assigning state
        in prepare_data() because it is not executed in every process in
        distributed runs - confirm this module is only used single-process.
        '''
        def load_pickle(path):
            # SECURITY: pickle.load can execute arbitrary code; these files are
            # assumed to be produced by this project only.
            with open(path, 'rb') as read_handle:  # 'rb' = binary read mode
                return pickle.load(read_handle)

        # prepare the entity data : dictionary.pickle
        process_ontology(self.ontology, self.dataset)

        # prepare the mentions data :  train.jsonl, valid.jsonl, test.jsonl
        process_mention_dataset(self.ontology, self.dataset)

        # Path to pickle files
        pickle_src_path = self.hparams.params["pickle_src_path"]

        # --- entity dictionary ---
        # if entity dictionary already tokenized, load it, else load the raw one
        entity_dictionary_pkl_path = os.path.join(pickle_src_path, 'entity_dictionary.pickle')
        self.entity_dictionary_loaded = os.path.isfile(entity_dictionary_pkl_path)
        if self.entity_dictionary_loaded:
            print("Loading stored processed entity dictionary...")
            self.entity_dictionary = load_pickle(entity_dictionary_pkl_path)
        else:
            self.entity_dictionary = load_pickle(
                os.path.join(self.hparams.params["data_path"], 'dictionary.pickle'))

        # --- training mention data ---
        # file where the training data, already processed into tensors, is saved
        self.train_tensor_data_pkl_path = os.path.join(pickle_src_path, 'train_tensor_data.pickle')
        # file where metadata / additional information about the training data is stored
        self.train_processed_data_pkl_path = os.path.join(pickle_src_path, 'train_processed_data.pickle')
        # load only when both cache files exist
        if os.path.isfile(self.train_tensor_data_pkl_path) and os.path.isfile(self.train_processed_data_pkl_path):
            print("Loading stored processed train data...")
            self.train_tensor_data = load_pickle(self.train_tensor_data_pkl_path)
            self.train_processed_data = load_pickle(self.train_processed_data_pkl_path)

        # --- validation mention data (same scheme as training data) ---
        self.valid_tensor_data_pkl_path = os.path.join(pickle_src_path, 'valid_tensor_data.pickle')
        self.valid_processed_data_pkl_path = os.path.join(pickle_src_path, 'valid_processed_data.pickle')
        if os.path.isfile(self.valid_tensor_data_pkl_path) and os.path.isfile(self.valid_processed_data_pkl_path):
            print("Loading stored processed valid data...")
            self.valid_tensor_data = load_pickle(self.valid_tensor_data_pkl_path)
            self.valid_processed_data = load_pickle(self.valid_processed_data_pkl_path)

        # --- test mention data (same scheme as training data) ---
        self.test_tensor_data_pkl_path = os.path.join(pickle_src_path, 'test_tensor_data.pickle')
        self.test_processed_data_pkl_path = os.path.join(pickle_src_path, 'test_processed_data.pickle')
        if os.path.isfile(self.test_tensor_data_pkl_path) and os.path.isfile(self.test_processed_data_pkl_path):
            # Message fixed: it used to announce "valid" data here.
            print("Loading stored processed test data...")
            self.test_tensor_data = load_pickle(self.test_tensor_data_pkl_path)
            self.test_processed_data = load_pickle(self.test_processed_data_pkl_path)



    def setup(self, stage=None):
        '''
        For processing and splitting. Called at the beginning of fit (train + validate), validate, test, or predict.
        '''
        
        
        'Entity dict : drop entity for discovery'
        # For discovery experiment: Drop entities used in training that were dropped randomly from dev/test set
        if self.hparams.params["drop_entities"]: #A12
            assert self.entity_dictionary 
            drop_set_path = self.hparams.params["drop_set"] if self.hparams.params["drop_set"] is not None else os.path.join(self.hparams.params["pickle_src_path"], 'drop_set_mention_data.pickle') #A12
            if not os.path.isfile(drop_set_path):
                raise ValueError("Invalid or no --drop_set path provided to dev/test mention data")
            with open(drop_set_path, 'rb') as read_handle:
                drop_set_data = pickle.load(read_handle)
            # gold cuis indices for each mention in drop_set_data
            drop_set_mention_gold_cui_idxs = list(map(lambda x: x['label_idxs'][0], drop_set_data))
            # Make the set unique
            ents_in_data = np.unique(drop_set_mention_gold_cui_idxs)
            # % of drop
            ent_drop_prop = 0.1
            logger.info(f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in drop set")
            # Number of entity indices to drop
            n_ents_dropped = int(ent_drop_prop*len(ents_in_data))
            # Random selection drop
            rng = np.random.default_rng(seed=17)
            # Indices of all entities that are dropped
            dropped_ent_idxs = rng.choice(ents_in_data, size=n_ents_dropped, replace=False)

            # Drop entities from dictionary (subsequent processing will automatically drop corresponding mentions)
            keep_mask = np.ones(len(self.entity_dictionary), dtype='bool')
            keep_mask[dropped_ent_idxs] = False
            self.entity_dictionary = np.array(self.entity_dictionary)[keep_mask]
        
        
        'Train mention data'
        # train_samples = list of dict. Each dict contains information about a mention (id, name, context, etc…). 
        # Each key can have a dictionary itself. Ex : mention["context"]["tokens"] or mention["context"]["ids"]
        train_samples, train_mult_labels = read_data("train", self.hparams.params, logger)
        
        if not os.path.isfile(self.train_tensor_data_pkl_path) : # Load and Process train data if not done yet
            # train_processed_data = (mention + surrounding context) tokens
            # entity_dictionary = tokenized entities
            # tensor_train_dataset = Dataset containing several tensors (IDs of mention + context / indices of correct entities etc..) # Go check "process_mention_data" for more info
            self.train_processed_data, self.entity_dictionary, self.train_tensor_data = data_process.process_mention_data(
                train_samples,
                self.entity_dictionary,
                self.tokenizer,
                self.hparams.params["max_context_length"],
                self.hparams.params["max_cand_length"],
                context_key=self.hparams.params["context_key"],
                multi_label_key="labels" if train_mult_labels else None,
                # silent=self.hparams.params["silent"], 
                logger=logger,
                debug=self.hparams.params["debug"], 
                knn=self.hparams.params['knn'],
                dictionary_processed=self.entity_dictionary_loaded
            )
            
            print("Saving processed train data...")
            with open(self.train_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(self.train_tensor_data, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.train_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(self.train_processed_data, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        
        # Prepare tensor containing only ID of (mention + surrounding context) tokens of training set'
        self.train_men_vecs = self.train_tensor_data[:][0] 

        # Store the IDs of the entity in entity_dictionary # It's the equivalent of train_men_vecs for entities
        # (Done here because data_process.process_mention_data will tokenize the entities in entity_dict)
        self.entity_dict_vecs = torch.tensor(list(map(lambda x: x['ids'], self.entity_dictionary)), dtype=torch.long)


        'Validation mention data'
        if not os.path.isfile(self.valid_tensor_data_pkl_path) : 
            # Load and Process validation data if not done yet
            valid_samples, valid_mult_labels = read_data("valid", self.hparams.params, logger)
            self.valid_processed_data, _, self.valid_tensor_data = data_process.process_mention_data(
                valid_samples,
                self.entity_dictionary,
                self.tokenizer,
                self.hparams.params["max_context_length"],
                self.hparams.params["max_cand_length"],
                context_key=self.hparams.params["context_key"],
                multi_label_key="labels" if valid_mult_labels else None,
                # silent=self.hparams.params["silent"],
                logger=logger,
                debug=self.hparams.params["debug"],
                knn=self.hparams.params["knn"],
                dictionary_processed=self.entity_dictionary_loaded
            )
            
            print("Saving processed valid data...")
            with open(self.valid_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(self.valid_tensor_data, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.valid_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(self.valid_processed_data, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        # Prepare tensor containing only ID of (mention + surrounding context) tokens of validation data'
        self.valid_men_vecs = self.valid_tensor_data[:][0]
        
        
        'Test mention data'
        if not os.path.isfile(self.train_tensor_data_pkl_path) :
            # Load and Process test data if not done yet
            test_samples, test_mult_labels = read_data("test", self.hparams.params, logger)
            self.test_processed_data, _, self.test_tensor_data = data_process.process_mention_data(
                test_samples,
                self.entity_dictionary,
                self.tokenizer,
                self.hparams.params["max_context_length"],
                self.hparams.params["max_cand_length"],
                context_key=self.hparams.params["context_key"],
                multi_label_key="labels" if test_mult_labels else None,
                # silent=self.hparams.params["silent"],
                logger=logger,
                debug=self.hparams.params["debug"],
                knn=self.hparams.params["knn"],
                dictionary_processed=self.entity_dictionary_loaded
            )
            
            print("Saving processed test data...")
            with open(self.test_tensor_data_pkl_path, 'wb') as write_handle:
                pickle.dump(self.test_tensor_data, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
            with open(self.test_processed_data_pkl_path, 'wb') as write_handle:
                pickle.dump(self.test_processed_data, write_handle,
                            protocol=pickle.HIGHEST_PROTOCOL)
        # Prepare tensor containing only ID of (mention + surrounding context) tokens of validation data'
        self.test_men_vecs = self.test_tensor_data[:][0]
        
        
        'Within_doc search'
        # Consider if it’s within_doc (=search only within the document)'
        self.train_context_doc_ids = self.valid_context_doc_ids = None
        if self.hparams.params["within_doc"]: 
            # RR9 : No need of those if conditions
            # Store the context_doc_id for every mention in the train and valid sets
            # if train_samples is None:
            #     train_samples, _ = read_data("train", self.hparams.params, logger)
            self.train_context_doc_ids = [s['context_doc_id'] for s in train_samples]
            # if valid_samples is None:
            #     valid_samples, _ = read_data("valid", self.hparams.params, logger)
            self.valid_context_doc_ids = [s['context_doc_id'] for s in valid_samples]
            # if test_samples is None:
            #     test_samples, _ = read_data("valid", self.hparams.params, logger)
            self.test_context_doc_ids = [s['context_doc_id'] for s in test_samples]
        
        # Get clusters of mentions that map to a gold entity
        self.train_gold_clusters = data_process.compute_gold_clusters(self.train_processed_data)
        # Maximum length of clusters inside gold_cluster
        self.max_gold_cluster_len = 0
        for ent in self.train_gold_clusters:
            if len(self.train_gold_clusters[ent]) > self.max_gold_cluster_len:
                self.max_gold_cluster_len = len(self.train_gold_clusters[ent])


    def train_dataloader(self): #RR5
        """Build the DataLoader iterating over the training tensor dataset.

        No sampler is passed, so the DataLoader defaults apply (an earlier
        sampler-based variant driven by a ``shuffle`` parameter is disabled).
        """
        dataset = self.train_tensor_data
        loader = DataLoader(dataset, batch_size=self.batch_size)
        return loader
    
    def val_dataloader(self):
        """Build the DataLoader iterating over the validation tensor dataset."""
        dataset = self.valid_tensor_data
        loader = DataLoader(dataset, batch_size=self.batch_size)
        return loader
    
    def test_dataloader(self):
        # Return the validation DataLoader
        return DataLoader(self.test_tensor_data, batch_size=self.batch_size)


'Model Training'
class LitArboel(L.LightningModule):
    def __init__(
        self, 
        params,  
        data_module
        ):
        '''
        Build the Lightning wrapper around the bi-encoder ranker.

        - params : dict
        Contains most of the relevant keys for training (embed_batch_size, train_batch_size, n_gpu, force_exact_search etc...)
        It is stored under ``self.hparams.params``.
        - data_module : Instance of ArboelDataModule class
        Kept as ``self.data_module``; the hooks read the tokenized data and entity dictionary from it.
        '''
        super().__init__()
        # Register the dict under the "params" key. Passing the dict itself would flatten its
        # entries into `self.hparams`, and `self.hparams.params[...]` (the access pattern used by
        # the hooks of this class) would then raise AttributeError.
        self.save_hyperparameters("params") #DD1
        
        self.data_module = data_module
        
        # Bi-encoder ranker and its underlying torch model (the latter is what gets optimized)
        self.reranker = BiEncoderRanker(params)
        self.model = self.reranker.model
        
        
    def forward(self, batch_context_inputs, candidate_idxs, n_gold, mention_idxs):
        """
        Description
        -----------
        For every mention of the batch, selects one positive example (its gold link, obtained from a minimum
        spanning tree built over the mention's gold cluster) and negative examples (pre-computed nearest
        entities and mentions that are not gold). Mentions without any valid negative mention are filtered out.

        Relies on state prepared by ``on_train_epoch_start`` (embeddings, k-NN results, ``gold_links``,
        ``n_entities``, ``n_mentions``, ``knn_dict``, ``knn_men`` and the skip counters).

        Parameters
        ----------
        - "batch_context_inputs" : Tensor
            Tensor containing IDs of (mention + surrounding context) tokens. Expected shape: (batch_size, context_length)
        - "candidate_idxs" : Tensor
            Indices of the entities considered correct labels for each mention. Not read by this method.
        - "n_gold" : Tensor
            Number of labels (=entities) associated with each mention. Currently asserted to be 1 for every mention.
        - "mention_idxs" : Tensor
            Index of each batch mention inside the training set, used to look up embeddings, processed data and neighbors.

        Return
        ------
        None if no mention of the batch has a valid negative mention. Otherwise a dict with the keys:
        - label_inputs : Tensor
            Binary labels, shape (filtered_batch_size, 1 + knn_dict + knn_men); the first column (the positive) is 1, the rest are 0.
        - context_inputs : Tensor
            Rows of batch_context_inputs kept after removing mentions with no negative mention.
        - negative_men_inputs : Tensor
            Token IDs of the negative mentions, one row per negative (filtered_batch_size * knn_men rows).
        - negative_dict_inputs : Tensor
            Token IDs of the negative entities, one row per negative (filtered_batch_size * knn_dict rows).
        - positive_embeds : Tensor
            Concatenated embeddings of the positive examples (one encoded example per kept mention).
        - skipped : int
            Number of mentions skipped because they have no valid negative mention.
        - skipped_positive_idxs : list(int)
            Gold link indices of the skipped mentions.
        - skipped_negative_dict_inputs : list(int)
            Negative entity indices collected for the skipped mentions.
        - context_inputs_mask : list(bool)
            True for the entries of batch_context_inputs that were kept.

        Note
        ----
        ``self.knn_men`` is lowered to the smallest number of negative mentions found in the batch and keeps
        that value until ``on_train_epoch_start`` resets it.
        """
        
        # Embeddings of the mentions within the batch
        mention_embeddings = self.train_men_embeddings[mention_idxs.cpu()]
        if len(mention_embeddings.shape) == 1:
            mention_embeddings = np.expand_dims(mention_embeddings, axis=0)

        positive_idxs = []
        negative_dict_inputs = []
        negative_men_inputs = []

        skipped_positive_idxs = []
        skipped_negative_dict_inputs = []

        min_neg_mens = float('inf')
        skipped = 0
        context_inputs_mask = [True]*len(batch_context_inputs)
        
        # IV.4.B) For each mention within the batch
        for m_embed_idx, _ in enumerate(mention_embeddings):
            mention_idx = int(mention_idxs[m_embed_idx])
            #CC11 ground truth entities of the mention "mention_idx"
            gold_idxs = set(self.data_module.train_processed_data[mention_idx]['label_idxs'][:n_gold[m_embed_idx]])
            
            # TEMPORARY: Assuming that there is only 1 gold label, TODO: Incorporate multiple case
            assert n_gold[m_embed_idx] == 1

            if mention_idx in self.gold_links:
                # Positive edge already computed while processing another mention of the same cluster
                gold_link_idx = self.gold_links[mention_idx]
            else:
                # IV.4.B.a) Create the graph with positive edges
                # Run MST on mention clusters of all the gold entities of the current query mention to find its positive edge
                rows, cols, data, shape = [], [], [], (self.n_entities+self.n_mentions, self.n_entities+self.n_mentions)
                seen = set()

                # Set whether the gold edge should be the nearest or the farthest neighbor
                sim_order = 1 if self.hparams.params["farthest_neighbor"] else -1 #A26

                for cluster_ent in gold_idxs:
                    #CC12 IDs of all the mentions inside the gold cluster with entity id = "cluster_ent"
                    # The gold clusters are computed by the data module, not stored in the hyper-parameters
                    cluster_mens = self.data_module.train_gold_clusters[cluster_ent]

                    if self.hparams.params["within_doc"]:
                        # Filter the gold cluster to within-doc
                        cluster_mens, _ = filter_by_context_doc_id(cluster_mens,
                                                                    self.data_module.train_context_doc_ids[mention_idx],
                                                                    self.data_module.train_context_doc_ids)
                    
                    # weights for all the mention-entity links inside the cluster of the current mention
                    to_ent_data = self.train_men_embeddings[cluster_mens] @ self.train_dict_embeddings[cluster_ent].T

                    # weights for all the mention-mention links inside the cluster of the current mention
                    to_men_data = self.train_men_embeddings[cluster_mens] @ self.train_men_embeddings[cluster_mens].T

                    if self.hparams.params['gold_arbo_knn'] is not None:
                        # Descending order of similarity if nearest-neighbor, else ascending order
                        sorti = np.argsort(sim_order * to_men_data, axis=1)
                        sortv = np.take_along_axis(to_men_data, sorti, axis=1)
                        if self.hparams.params["rand_gold_arbo"]:
                            randperm = np.random.permutation(sorti.shape[1])
                            sortv, sorti = sortv[:, randperm], sorti[:, randperm]

                    for i in range(len(cluster_mens)):
                        # Mention nodes are offset by n_entities so entities and mentions share one index space
                        from_node = self.n_entities + cluster_mens[i]
                        to_node = cluster_ent
                        # Add mention-entity link
                        rows.append(from_node)
                        cols.append(to_node)
                        data.append(-1 * to_ent_data[i])
                        if self.hparams.params['gold_arbo_knn'] is None:
                            # Add forward and reverse mention-mention links over the entire MST
                            for j in range(i+1, len(cluster_mens)):
                                to_node = self.n_entities + cluster_mens[j]
                                if (from_node, to_node) not in seen:
                                    score = to_men_data[i,j]
                                    rows.append(from_node)
                                    cols.append(to_node)
                                    data.append(-1 * score) # Negatives needed for SciPy's Minimum Spanning Tree computation
                                    seen.add((from_node, to_node))
                                    seen.add((to_node, from_node))
                        else:
                            # Approximate the MST using <gold_arbo_knn> nearest mentions from the gold cluster
                            added = 0
                            approx_k = min(self.hparams.params['gold_arbo_knn']+1, len(cluster_mens))
                            for j in range(approx_k):
                                if added == approx_k - 1:
                                    break
                                to_node = self.n_entities + cluster_mens[sorti[i, j]]
                                if to_node == from_node:
                                    continue
                                added += 1
                                if (from_node, to_node) not in seen:
                                    score = sortv[i, j]
                                    rows.append(from_node)
                                    cols.append(to_node)
                                    data.append(
                                        -1 * score)  # Negatives needed for SciPy's Minimum Spanning Tree computation
                                    seen.add((from_node, to_node))

                # IV.4.B.b) Fine tuning with inference procedure to get a mst
                # Creates MST with entity constraint (inference procedure)
                csr = csr_matrix((-sim_order * np.array(data), (rows, cols)), shape=shape)
                # Note: minimum_spanning_tree expects distances as edge weights
                mst = minimum_spanning_tree(csr).tocoo()
                # Note: cluster_linking_partition expects similarities as edge weights # Convert directed to undirected graph
                rows, cols, data = cluster_linking_partition(np.concatenate((mst.row, mst.col)), # cluster_linking_partition is imported from eval_cluster_linking
                                                                np.concatenate((mst.col, mst.row)),
                                                                np.concatenate((sim_order * mst.data, sim_order * mst.data)),
                                                                self.n_entities,
                                                                directed=True,
                                                                silent=True)
                assert np.array_equal(rows - self.n_entities, cluster_mens)
                
                for i in range(len(rows)):
                    men_idx = rows[i] - self.n_entities
                    if men_idx in self.gold_links:
                        continue
                    assert men_idx >= 0
                    add_link = True
                    # Store the computed positive edges for the mentions in the clusters only if they have the same gold entities as the query mention
                    for l in self.data_module.train_processed_data[men_idx]['label_idxs'][:self.data_module.train_processed_data[men_idx]['n_labels']]:
                        if l not in gold_idxs:
                            add_link = False
                            break
                    if add_link:
                        self.gold_links[men_idx] = cols[i]
                gold_link_idx = self.gold_links[mention_idx]
                
            # IV.4.B.c) Retrieve the pre-computed nearest neighbors
            knn_dict_idxs = self.dict_nns[mention_idx]
            knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
            # -1 marks unfilled neighbor slots (see on_train_epoch_start)
            knn_men_idxs = self.men_nns[mention_idx][self.men_nns[mention_idx] != -1]
            knn_men_idxs = knn_men_idxs.astype(np.int64).flatten()
            if self.hparams.params['within_doc']:
                knn_men_idxs, _ = filter_by_context_doc_id(knn_men_idxs,
                                                        self.data_module.train_context_doc_ids[mention_idx],
                                                        self.data_module.train_context_doc_ids, return_numpy=True)

            # IV.4.B.d) Add negative examples
            # Neighbors that belong to a gold cluster / are a gold entity cannot be negatives
            gold_cluster_mens = np.concatenate([self.data_module.train_gold_clusters[gi] for gi in gold_idxs])
            neg_men_idxs = knn_men_idxs[~np.isin(knn_men_idxs, gold_cluster_mens)][:self.knn_men]
            neg_dict_idxs = list(knn_dict_idxs[~np.isin(knn_dict_idxs, list(gold_idxs))][:self.knn_dict])

            # Track queries with no valid mention negatives
            if len(neg_men_idxs) == 0:
                context_inputs_mask[m_embed_idx] = False
                skipped_negative_dict_inputs += neg_dict_idxs
                skipped_positive_idxs.append(gold_link_idx)
                skipped += 1
                continue

            min_neg_mens = min(min_neg_mens, len(neg_men_idxs))
            negative_men_inputs.append(neg_men_idxs)
            negative_dict_inputs += neg_dict_idxs
            # Add the positive example
            positive_idxs.append(gold_link_idx)

        
        # IV.4.C) Skip this iteration if no suitable negative examples found
        if len(negative_men_inputs) == 0 :
            return None #DD8 instead of continue
        
        # Sets the minimum number of negative mentions found across all processed mentions in the current batch
        self.knn_men = min_neg_mens
        
        # This step ensures that each mention is compared against a uniform number of negative mentions
        filtered_negative_men_inputs = []
        for row in negative_men_inputs:
            filtered_negative_men_inputs += list(row[:self.knn_men])
        negative_men_inputs = filtered_negative_men_inputs

        # Assertions for Data Integrity
        assert len(negative_dict_inputs) == (len(mention_embeddings) - skipped) * self.knn_dict
        assert len(negative_men_inputs) == (len(mention_embeddings) - skipped) * self.knn_men

        self.total_skipped += skipped
        self.total_knn_men_negs += self.knn_men

        # Convert to tensors (indices -> token IDs)
        negative_dict_inputs = torch.tensor(list(map(lambda x: self.data_module.entity_dict_vecs[x].numpy(), negative_dict_inputs)))
        negative_men_inputs = torch.tensor(list(map(lambda x: self.data_module.train_men_vecs[x].numpy(), negative_men_inputs)))
        
        # Encode the positive examples: indices below n_entities are entities, the others are mentions
        positive_embeds = []
        for pos_idx in positive_idxs:
            if pos_idx < self.n_entities:
                pos_embed = self.reranker.encode_candidate(self.data_module.entity_dict_vecs[pos_idx:pos_idx + 1], requires_grad=True)
            else:
                pos_embed = self.reranker.encode_context(self.data_module.train_men_vecs[pos_idx - self.n_entities:pos_idx - self.n_entities + 1], requires_grad=True)
            positive_embeds.append(pos_embed)
        positive_embeds = torch.cat(positive_embeds)
        
        # Remove mentions with no negative examples
        context_inputs = batch_context_inputs[context_inputs_mask]
        
        # Tensor containing binary values that act as indicator variables in the paper:
        # Contains Indicator variable such that I_{u,m_i} = 1 if(u,mi) ∈ E'_{m_i} and I{u,m_i} = 0 otherwise.
        label_inputs = torch.tensor([[1]+[0]*(self.knn_dict+self.knn_men)]*len(context_inputs), dtype=torch.float32)
        
        return {'label_inputs':label_inputs, 'context_inputs' : context_inputs, "negative_men_inputs" : negative_men_inputs,
                'negative_dict_inputs' : negative_dict_inputs, 'positive_embeds' : positive_embeds, 'skipped' : skipped,
                'skipped_positive_idxs' : skipped_positive_idxs, 'skipped_negative_dict_inputs' : skipped_negative_dict_inputs,
                'context_inputs_mask' : context_inputs_mask
                }



    def training_step(self, batch, batch_idx):
        """Compute the training loss of one batch.

        Parameters
        ----------
        - batch : tuple
            (batch_context_inputs, candidate_idxs, n_gold, mention_idxs), a subsample of the training tensor dataset.
            Device placement is handled by Lightning. #DD5
        - batch_idx : int
            Index of the batch within the epoch (unused).

        Return
        ------
        The loss computed by ``loss_function``, or None when ``forward`` found no usable negative
        example in the batch (Lightning then skips the optimization step for this batch).
        """
        batch_context_inputs, candidate_idxs, n_gold, mention_idxs = batch
        
        f = self.forward(batch_context_inputs, candidate_idxs, n_gold, mention_idxs)

        # forward() returns None when every mention of the batch lacks negative mentions:
        # there is nothing to compute a loss on, so skip the batch instead of passing None along.
        if f is None:
            return None
        
        # Compute the loss
        loss = loss_function(self.reranker, 
            self.hparams.params, 
            f, 
            self.data_module, 
            self.n_entities, 
            self.knn_dict, 
            batch_context_inputs, 
            self.trainer.accumulate_grad_batches
        )

        return loss


    def configure_optimizers(self): #DD2
        """Create the BERT optimizer and its linear warm-up schedule.

        The total number of optimization steps is derived from the size of the training
        tensor dataset (held by the data module), the training batch size, the gradient
        accumulation factor and the number of epochs.

        Return
        ------
        ([optimizer], [scheduler_config]) where the scheduler is stepped after every optimizer step.
        """
        params = self.hparams.params

        # Define optimizer
        optimizer = get_bert_optimizer(
            [self.model],
            params["type_optimization"],
            params["learning_rate"],
            fp16=params.get("fp16"),
        )
        
        # Define scheduler
        batch_size = params["train_batch_size"]
        epochs = self.trainer.max_epochs
        # The training tensors belong to the data module; this module has no `train_tensor_data` attribute
        n_train_samples = len(self.data_module.train_tensor_data)

        num_train_steps = int(n_train_samples / batch_size / self.trainer.accumulate_grad_batches) * epochs
        num_warmup_steps = int(num_train_steps * params["warmup_proportion"])

        scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=num_warmup_steps, t_total=num_train_steps,
        )
        logger.info(" Num optimization steps = %d", num_train_steps)
        logger.info(" Num warmup steps = %d", num_warmup_steps)
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]



    def on_train_start(self):
        """Remember when training started; ``on_train_end`` uses it to report the duration."""
        started_at = time.time()
        self.start_time = started_at
    
    
    
    def on_fit_start(self):
        """Store the number of devices used by the trainer in ``self.n_gpu``.

        Computed once, at the start of fitting. ``Trainer.num_devices`` is used when the
        trainer exposes it; otherwise the legacy ``Trainer.devices`` value is interpreted
        (None -> 0, list -> its length, int -> itself, anything else -> number of parallel
        devices resolved by the accelerator connector).

        NOTE(review): this counts devices regardless of accelerator type, as the original
        logic did - confirm the value is only consumed when training on GPUs.
        """
        trainer = self.trainer

        # Current Lightning releases expose the resolved device count directly
        num_devices = getattr(trainer, "num_devices", None)
        if num_devices is not None:
            self.n_gpu = int(num_devices)
            return

        # Legacy trainers: interpret the raw `devices` setting
        devices = getattr(trainer, "devices", None)
        if devices is None:
            self.n_gpu = 0
        elif isinstance(devices, (list, tuple)):
            self.n_gpu = len(devices)
        elif isinstance(devices, int):
            self.n_gpu = devices
        else:
            # For other configurations, such as auto-select or when specific GPUs are selected as a string,
            # rely on the devices actually allocated by the trainer
            connector = getattr(trainer, "_accelerator_connector", None)
            if connector is None:
                connector = trainer.accelerator_connector
            self.n_gpu = len(connector.parallel_devices)
            
            
    def on_train_epoch_start(self):
        """Refresh the per-epoch state used by ``forward``.

        At the start of each epoch this hook:
        1) re-embeds and re-indexes the entity dictionary and the training mentions,
        2) resets the gold MST link cache and the number of negatives to fetch,
        3) runs the k-NN search giving, for each training mention, its nearest entities
           (``self.dict_nns``) and nearest mentions (``self.men_nns``),
        4) resets the running counters (``tr_loss``, ``total_skipped``, ``total_knn_men_negs``).
        """
        params = self.hparams.params
        data_module = self.data_module

        self.tr_loss = 0
        
        # IV.1) Compute mention and entity embeddings and indexes at the start of each epoch
        embed_kwargs = {
            "n_gpu": self.n_gpu,
            "force_exact_search": params['force_exact_search'],
            "batch_size": params['embed_batch_size'],
            "probe_mult_factor": params['probe_mult_factor'],
        }
        if params['use_types']: # type-specific indexes 
            self.train_dict_embeddings, self.train_dict_indexes, self.dict_idxs_by_type = data_process.embed_and_index(
                self.reranker, data_module.entity_dict_vecs, encoder_type="candidate",
                corpus=data_module.entity_dictionary, **embed_kwargs) #D11
            self.train_men_embeddings, self.train_men_indexes, self.men_idxs_by_type = data_process.embed_and_index(
                self.reranker, data_module.train_men_vecs, encoder_type="context",
                corpus=data_module.train_processed_data, **embed_kwargs)
        else: # general indexes
            self.train_dict_embeddings, self.train_dict_index = data_process.embed_and_index(
                self.reranker, data_module.entity_dict_vecs, encoder_type="candidate", **embed_kwargs)
            # The mention token IDs are held by the data module (this module has no `train_men_vecs`)
            self.train_men_embeddings, self.train_men_index = data_process.embed_and_index(
                self.reranker, data_module.train_men_vecs, encoder_type="context", **embed_kwargs)
        
        # Number of entities and mentions
        self.n_entities = len(data_module.entity_dictionary)
        self.n_mentions = len(data_module.train_processed_data)
        
        # Store golden MST links
        self.gold_links = {}
        # Calculate the number of negative entities and mentions to fetch # Divides the k-nn evenly between entities and mentions
        knn = params["knn"]
        self.knn_dict = knn//2
        self.knn_men = knn - self.knn_dict
        
        # IV.3) knn search : indice and distance of k closest mentions and entities
        logger.info("Starting KNN search...")
        # INFO: Fetching all sorted mentions to be able to filter to within-doc later
        n_men_to_fetch = len(self.train_men_embeddings) if params["use_types"] else self.knn_men + data_module.max_gold_cluster_len
        n_ent_to_fetch = self.knn_dict + 1 # +1 accounts for the possibility of self-reference
        if not params["use_types"]:
            _, self.dict_nns = self.train_dict_index.search(self.train_men_embeddings, n_ent_to_fetch)
            _, self.men_nns = self.train_men_index.search(self.train_men_embeddings, n_men_to_fetch)
        else:
            # -1 marks neighbor slots that stay unfilled (a type may hold fewer mentions than requested)
            self.dict_nns = -1 * np.ones((len(self.train_men_embeddings), n_ent_to_fetch))
            self.men_nns = -1 * np.ones((len(self.train_men_embeddings), n_men_to_fetch))
            for entity_type in self.train_men_indexes:
                # Search within the type-specific indexes, then map the local results back to global indices
                self.men_embeds_by_type = self.train_men_embeddings[self.men_idxs_by_type[entity_type]]
                _, self.dict_nns_by_type = self.train_dict_indexes[entity_type].search(self.men_embeds_by_type, n_ent_to_fetch)
                _, self.men_nns_by_type = self.train_men_indexes[entity_type].search(self.men_embeds_by_type, min(n_men_to_fetch, len(self.men_embeds_by_type)))
                self.dict_nns_idxs = np.array(list(map(lambda x: self.dict_idxs_by_type[entity_type][x], self.dict_nns_by_type)))
                self.men_nns_idxs = np.array(list(map(lambda x: self.men_idxs_by_type[entity_type][x], self.men_nns_by_type)))
                for i, idx in enumerate(self.men_idxs_by_type[entity_type]):
                    self.dict_nns[idx] = self.dict_nns_idxs[i]
                    self.men_nns[idx][:len(self.men_nns_idxs[i])] = self.men_nns_idxs[i]
        logger.info("Search finished")
        
        self.total_skipped = self.total_knn_men_negs = 0
        
    
    
    def on_train_epoch_end(self):
        """Hook called at the end of each training epoch; currently a placeholder doing nothing."""
        return None
    
    def on_after_backward(self): #DD11
        """Clip the gradient norm after ``.backward()``.

        The maximum norm is read from ``params["max_grad_norm"]`` and, if that key is absent
        or None, from the trainer's ``gradient_clip_val``. When neither is set, no clipping
        is performed. (The trainer has no ``grad_clip_norm`` attribute.)

        NOTE(review): the accumulation test relies on ``trainer.global_step`` - confirm it
        matches the intended "last micro-batch before the optimizer step" condition.
        """
        trainer = self.trainer
        if (trainer.global_step + 1) % trainer.accumulate_grad_batches != 0:
            return

        max_norm = self.hparams.params.get("max_grad_norm")
        if max_norm is None:
            max_norm = getattr(trainer, "gradient_clip_val", None)
        if max_norm is None:
            return

        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm)
            
    # RR2 : Most of the info is already in the Trainer's callback        
    # def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx = 0):
    #     # end of training batch
    #     'IV.4.E) Information about the training (step, epoch, average_loss)'
    #     n_print_iters = self.hparams.params["print_interval"] * self.trainer.accumulate_grad_batches #29
    #     if (batch_idx + 1) % n_print_iters == 0:
    #         # DD13
    #         self.log("train/average_loss", self.tr_loss / n_print_iters, on_step=False, on_epoch=True, prog_bar=True)
    #         if self.total_skipped > 0:
    #             self.log("train/queries_without_negs", self.total_skipped / n_print_iters, on_step=False, on_epoch=True)
    #             self.log("train/negative_mentions_per_query", self.total_knn_men_negs / n_print_iters, on_step=False, on_epoch=True)
            
    #         # Reset your tracking variables for the next interval
    #         self.total_skipped = 0
    #         self.total_knn_men_negs = 0
    #         self.tr_loss = 0

    #     # DD14
    #     '''
    #     # Regular checks on model performance against a validation dataset without interrupting the training more often than desired
    #     if self.hparams.params["eval_interval"] != -1: #A31
    #         if (batch_idx + 1) % (self.hparams.params["eval_interval"] * self.hparams.params["gradient_accumulation_steps"]) == 0:
    #             logger.info("Evaluation on the development dataset")
    #             evaluate(
    #                 self.reranker, self.data_module.entity_dict_vecs, self.data_module.valid_men_vecs, device=self.device, logger=logger, knn=self.hparams["knn"], n_gpu=self.n_gpu,
    #                 entity_data=self.data_module.entity_dictionary, query_data=self.valid_processed_data, silent=self.hparams.params["silent"],
    #                 use_types=self.hparams.params['use_types'] or self.hparams.params["use_types_for_eval"], embed_batch_size=self.hparams.params["embed_batch_size"],
    #                 force_exact_search=self.hparams.params['use_types'] or self.hparams.params["use_types_for_eval"] or self.hparams.params["force_exact_search"],
    #                 probe_mult_factor=self.hparams.params['probe_mult_factor'], within_doc=self.hparams.params['within_doc'],
    #                 context_doc_ids=self.data_module.valid_context_doc_ids
    #             )
    #             self.model.train()
    #             logger.info("\n")
    #     '''
        
    #     pass
    
    
    def validation_step(self, batch, batch_idx): #DD14
        """Evaluate the model on the validation mentions and log the accuracies.

        ``evaluate`` works on the whole validation set held by the data module and does not
        use ``batch``; it is therefore run only for the first validation batch rather than
        being repeated identically for every batch.

        Logged metrics: "max_acc", "dict_acc" and "embedding_and_indexing".
        """
        if batch_idx > 0:
            return

        params = self.hparams.params
        # `evaluate` no longer takes a `device` argument, so it must not be passed
        max_acc, dict_acc, embed_and_index_dict = evaluate(
            self.reranker, 
            self.data_module.entity_dict_vecs, 
            self.data_module.valid_men_vecs, 
            logger=logger, 
            n_gpu=self.n_gpu,
            entity_data=self.data_module.entity_dictionary, 
            query_data=self.data_module.valid_processed_data, 
            use_types=params['use_types'],
            embed_batch_size=params["embed_batch_size"],
            # Type-specific search is always run as an exact search
            force_exact_search=params['use_types'] or params["force_exact_search"],
            probe_mult_factor=params['probe_mult_factor'], 
            within_doc=params['within_doc'],
            context_doc_ids=self.data_module.valid_context_doc_ids
        )
        self.log("max_acc", max_acc, on_epoch=True)
        self.log("dict_acc", dict_acc, on_epoch=True)
        # NOTE(review): if `embed_and_index_dict` is a dict, `self.log_dict` may be required here - confirm
        self.log("embedding_and_indexing", embed_and_index_dict, on_epoch=True)
        
        
    def test_step(self, batch, batch_idx): #DD14
        """Evaluate the model on the test mentions and log the accuracies.

        ``evaluate`` works on the whole test set held by the data module and does not use
        ``batch``; it is therefore run only for the first test batch rather than being
        repeated identically for every batch.

        Logged metrics: "max_acc", "dict_acc" and "embedding_and_indexing".
        """
        if batch_idx > 0:
            return

        params = self.hparams.params
        # `evaluate` no longer takes a `device` argument, so it must not be passed
        max_acc, dict_acc, embed_and_index_dict = evaluate(
            self.reranker, 
            self.data_module.entity_dict_vecs, 
            self.data_module.test_men_vecs, 
            logger=logger, 
            n_gpu=self.n_gpu,
            entity_data=self.data_module.entity_dictionary, 
            query_data=self.data_module.test_processed_data,
            use_types=params['use_types'],
            embed_batch_size=params["embed_batch_size"],
            # Type-specific search is always run as an exact search
            force_exact_search=params['use_types'] or params["force_exact_search"],
            probe_mult_factor=params['probe_mult_factor'], 
            within_doc=params['within_doc'],
            # The data module only defines `test_context_doc_ids` for within-doc runs
            context_doc_ids=getattr(self.data_module, "test_context_doc_ids", None)
        )
        self.log("max_acc", max_acc, on_epoch=True)
        self.log("dict_acc", dict_acc, on_epoch=True)
        # NOTE(review): if `embed_and_index_dict` is a dict, `self.log_dict` may be required here - confirm
        self.log("embedding_and_indexing", embed_and_index_dict, on_epoch=True)
        

    def on_train_end(self):
        """Report the total training duration (in minutes) once training is over.

        Uses the timestamp stored by ``on_train_start``. The message goes to the module-level
        ``logger`` used by the other hooks: ``self.logger`` is Lightning's experiment logger
        and is not a text logger.
        """
        execution_time = (time.time() - self.start_time) / 60
        logger.info(f"The training took {execution_time} minutes")

    


def main(params):
    """Train the model and report the best checkpoint.

    Creates a time-stamped experiment directory under
    ``params["model_output_path"]``, builds the data module and the Lightning
    model, trains for ``params["num_train_epochs"]`` epochs while keeping the
    single checkpoint with the highest ``max_acc``, and finally logs the path
    and score of that checkpoint.

    Parameters
    ----------
    params : dict
        Run configuration. This function reads ``"model_output_path"`` and
        ``"num_train_epochs"`` itself and forwards the whole dict to
        ``ArboelDataModule`` and ``LitArboel``.
    """
    # The Trainer below comes from the `lightning` package, so the callback
    # must come from the same package: Lightning rejects `pytorch_lightning`
    # objects passed to a `lightning.pytorch` Trainer (mixed import styles).
    from lightning.pytorch.callbacks import ModelCheckpoint

    # One sub-directory per run, named after the launch time.
    root_output_dir = params["model_output_path"]
    experiment_name = f"experiment_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
    model_output_path = os.path.join(root_output_dir, experiment_name)
    os.makedirs(model_output_path, exist_ok=True)

    data_module = ArboelDataModule(params)
    model = LitArboel(params, data_module)

    # Keep only the best checkpoint according to the logged `max_acc` metric.
    model_checkpoint = ModelCheckpoint(
        monitor='max_acc',  # Metric to monitor
        dirpath=model_output_path,  # Directory to save the model
        filename='{epoch}-{max_acc:.2f}',  # Epoch and max_acc appear in the filename
        save_top_k=1,  # Number of best models to save; -1 means save all of them
        mode='max',  # The highest max_acc is considered the best model
        verbose=True,  # Logs a message whenever a model checkpoint is saved
    )

    trainer = L.Trainer(
        max_epochs=params["num_train_epochs"],
        devices=1,
        accelerator="gpu",
        strategy="ddp",
        # A float is a fraction of the training epoch (1.0 = validate once per
        # epoch); an int would mean "every N training batches".
        val_check_interval=1.0,
        enable_progress_bar=True,
        callbacks=[model_checkpoint],
    )

    trainer.fit(model, datamodule=data_module)

    best_model_path = model_checkpoint.best_model_path
    best_model_score = model_checkpoint.best_model_score
    logger.info(f"Best model saved at {best_model_path} with score {best_model_score}")
    

if __name__ == "__main__":
    # Build the command-line parser with the model and training options.
    arg_parser = BlinkParser(add_model_args=True)
    arg_parser.add_training_args()

    cli_args = arg_parser.parse_args()
    print(cli_args)

    # main() indexes its argument by key, so hand it the namespace as a dict.
    main(vars(cli_args))


In [None]:
# Reference table of the configuration keys: each value is a human-readable
# description (expected type and meaning), not a usable setting.
params = {
    "model_output_path": "(str) Path where the model outputs and logs should be saved",
    "data_path": "(str) Path to dictionary.pickle file",
    "knn": "(int) Number of nearest neighbors to consider",
    "use_types": "(bool) Whether to use type-specific indexes",
    "max_context_length": "(int) Maximum length of context tokens",
    "max_cand_length": "(int) Maximum length of candidate entity tokens",
    "context_key": "(str) Key for context in the processed data. (= “context”)",
    "debug": "(bool) If set to True, run in debug mode (test only on 200 first samples)",
    "gold_arbo_knn": "(int) Number of gold nearest neighbors to consider",
    "within_doc": "(bool) If True, constrain evaluation within documents",
    "within_doc_skip_strategy": "specific strategy used with within_doc",
    "filter_unlabeled": "(bool) If True, filter out samples without labels",
    "type_optimization": "(string) Type of optimization to use (e.g., AdamW)",
    "learning_rate": "(float) Learning rate for the optimizer",
    "warmup_proportion": "(float) Proportion of warmup steps in the total number of training steps",
    "fp16": "(bool) Whether to use mixed precision training",
    "embed_batch_size": "(int) Batch size for embedding during evaluation",
    "force_exact_search": "(bool) If True, use exact search methods during evaluation",
    "probe_mult_factor": "(int) Multiplier factor used in index building for probing in approximate search",
    "pos_neg_loss": "Specific positive + negative type of loss",
    "pickle_src_path": "(string) Path to the directory containing preprocessed data in pickle format",
    "use_types_for_eval": "(bool) Whether to use type-specific indexes for evaluation",
    "drop_entities": "(bool) If True, drop entities from training data based on some criterion",
    "drop_set": "(string) Path to the mention data from which entities are to be dropped",
    "farthest_neighbor": "(bool) If True, consider the farthest neighbor instead of the nearest for positive linkage",
    "rand_gold_arbo": "(bool) If True, use random permutation for gold MST approximation",
    "bert_model": "bert-base-uncased / bert-large-uncased / bert-base-cased",
    "out_dim": "(int = 768) Output dimension = length of the encoded mention/entity",
    "pull_from_layer": "(int=11 or 23) From which layer shall we pull the encoded embedding",
    "add_linear": "(bool : True) Whether an additional linear transformation is applied to the output",
}