In [4]:
%load_ext autoreload
%autoreload 2

In [5]:
import sys
sys.path.append('../../..')
sys.path.append('..')

import os
import random
import time
import pickle
import logger
import numpy as np
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_transformers.tokenization_bert import BertTokenizer
from transformers import AutoTokenizer, AutoModel
from datetime import datetime
from typing import Optional, Union

from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler) 
from pytorch_transformers.optimization import WarmupLinearSchedule 
from scipy.sparse.csgraph import minimum_spanning_tree 
# csgraph = compressed sparse graph
from scipy.sparse import csr_matrix
# csr_matrix = compressed sparse row matrices
from collections import Counter 

import blink.biencoder.data_process_mult as data_process
# import blink.biencoder.eval_cluster_linking as eval_cluster_linking
import blink.candidate_ranking.utils as utils
from blink.biencoder.biencoder import BiEncoderRanker
from blink.common.optimizer import get_bert_optimizer
from blink.common.params import BlinkParser

from IPython import embed 
from tqdm import tqdm

import logging

# Configure the logger
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Now you can use logger.info()
logger.info("This is an info message.")

[17/Mar/2024 11:25:29] INFO - This is an info message.


In [6]:
import blink.biencoder.eval_cluster_linking as eval_cluster_linking
from special_partition.special_partition import cluster_linking_partition

In [7]:
from LightningDataModule import ArboelDataModule

### Evaluate / filter_by_context_doc_id / loss functions

In [73]:
def evaluate(
    reranker,
    valid_dict_vecs, 
    valid_men_vecs, 
    # device, #no longer used
    logger, 
    # knn,  #not even used
    # n_gpu, 
    entity_data, 
    query_data,
    silent=False, 
    use_types=False, 
    embed_batch_size=768, 
    force_exact_search=False, 
    probe_mult_factor=1,
    within_doc=False, 
    context_doc_ids=None ):
    '''
    Description 
    -----------
    1) Computes embeddings and indexes for entities and mentions. 
    2) Performs k-nearest neighbors (k-NN) search to establish relationships between them.
    3) Constructs graphs based on these relationships.
    4) Evaluates the model's accuracy by analyzing how effectively the model can link mentions to the correct entities.
    
    Parameters 
    ----------
    reranker : BiEncoderRanker
        NN-based ranking model
    valid_dict_vec : list or ndarray
        Ground truth dataset containing the entities
    valid_men_vecs : list or ndarray
        Dataset containing mentions
    device : str
        cpu or gpu
    logger : 'Logger' object
        Logging object used to record messages
    knn : int
        Number of neighbors
    n_gpu : int
        Number of gpu
    entity_data : list or dict
        Entities from the data
    query_data : list or dict
        Queries / mentions against which the entities are evaluated
    silent=False : bool 
        When set to "True", likely suppresses the output or logging of progress updates to keep the console output clean.
    use_types=False : bool
        A boolean flag that indicates whether or not to use type-specific indexes for entities and mentions
    embed_batch_size=128 : int
        The batch size to use when processing embeddings.
    force_exact_search=False : bool
        force the embedding process to use exact search methods rather than approximate methods.
    probe_mult_factor=1 : int
        A multiplier factor used in index building for probing in case of approximate search (bigger = better but slower)
    within_doc=False : bool
        Boolean flag that indicates whether the evaluation should be constrained to within-document contexts
    context_doc_ids=None : bool
        This would be used in conjunction with within_doc to limit evaluations within the same document.
    '''
    torch.cuda.empty_cache() # Empty the CUDA cache to free up GPU memory
    
    n_entities = len(valid_dict_vecs) # total number of entities
    n_mentions = len(valid_men_vecs) # total number of mentions
    max_knn = 8 # max number of neighbors
    
    joint_graphs = {} # Store results of the NN search and distance between entities and mentions
    
    for k in [0, 1, 2, 4, 8]:
        joint_graphs[k] = { #DD3
            "rows": np.array([]),
            "cols": np.array([]),
            "data": np.array([]),
            "shape": (n_entities + n_mentions, n_entities + n_mentions),
        }
        
    
    '1) Computes embeddings and indexes for entities and mentions. '
    '''
    This block is preparing the data for evaluation by transforming raw vectors into a format that can be efficiently used for retrieval and comparison operations
    '''
    if use_types: # corpus = entity data
                  # corpus is a collection of entities, which is used to build type-specific search indexes if provided.
        '''
        With a Corpus : Multiple type-specific indexes are created, allowing for more targeted and efficient searches within specific categories of entities.
        'dict_embeds' and 'men_embeds': The resulting entity and mention embeddings.
        'dict_indexes' and 'men_indexes': Dictionary that will store search indexes (!= indices)for each unique entity type found in the corpus
        'dict_idxs_by_type' and 'men_idxs_by_type': Dictionary to store indices of the corpus elements, grouped by their entity type.
        !!! idxs = indices / indexes = indexes !!!
        '''
        logger.info("Eval: Dictionary: Embedding and building index") # For entities
        dict_embeds, dict_indexes, dict_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            encoder_type="candidate",
            # n_gpu=n_gpu,
            corpus=entity_data, 
            force_exact_search=force_exact_search, 
            batch_size=embed_batch_size, 
            probe_mult_factor=probe_mult_factor, 
            )
        logger.info("Eval: Queries: Embedding and building index") # For mentions
        men_embeds, men_indexes, men_idxs_by_type = data_process.embed_and_index(
            reranker,
            valid_men_vecs,
            encoder_type="context",
            # n_gpu=n_gpu,
            corpus=query_data,
            force_exact_search=force_exact_search,
            batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor,
        )
    else: # corpus = None
        '''
        Without a Corpus: A single, general index is created for all embeddings, suitable for broad searches across the entire dataset.
        'dict_embeds' and 'men_embeds': The resulting entity and mention embeddings.
        'dict_index' and 'men_index': Dictionary that will store search index
        '''
        logger.info("Eval: Dictionary: Embedding and building index")
        dict_embeds, dict_index = data_process.embed_and_index(
            reranker,
            valid_dict_vecs,
            "candidate",
            # n_gpu=n_gpu,
            force_exact_search=force_exact_search,
            batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor,
        )
        logger.info("Eval: Queries: Embedding and building index")
        men_embeds, men_index = data_process.embed_and_index(
            reranker,
            valid_men_vecs,
            "context",
            # n_gpu=n_gpu,
            force_exact_search=force_exact_search,
            batch_size=embed_batch_size,
            probe_mult_factor=probe_mult_factor,
        )


    '2) Performs k-nearest neighbors (k-NN) search to establish relationships between mentions and entities.'
    logger.info("Eval: Starting KNN search...") # An informational message is logged to indicate that the k-NN search is starting.
    # Fetch recall_k (default 16) knn entities for all mentions
    # Fetch (k+1) NN mention candidates; fetching all mentions for within_doc to filter down later
    n_men_to_fetch = len(men_embeds) if within_doc else max_knn + 1 # Number of mentions to fetch
    if not use_types: # Only one index so only need one search
        nn_ent_dists, nn_ent_idxs = dict_index.search(men_embeds, 1) #DD4/DD5 #return the distance and the indice of the closest entity for all mentions in men_embeds
        nn_men_dists, nn_men_idxs = men_index.search(men_embeds, n_men_to_fetch) # return the distances and the indices of the k closest mentions for all mentions in men_embeds
    else: #C Several indexes corresponding to the different entities in entity_data so we can use the specific search index
        # DD6
        # DD7
        nn_ent_idxs = -1 * np.ones((len(men_embeds), 1), dtype=int) # Indice of the closest entity for all mentions in men_embeds
        nn_ent_dists = -1 * np.ones((len(men_embeds), 1), dtype="float64") # Distance of the closest entity for all mentions in men_embeds
        nn_men_idxs = -1 * np.ones((len(men_embeds), n_men_to_fetch), dtype=int) # Indice of k closest mentions for all mentions in men_embeds
        nn_men_dists = -1 * np.ones((len(men_embeds), n_men_to_fetch), dtype="float64") # Distance of the k closest mentions for all mentions in men_embeds
        for entity_type in men_indexes:
            #CC3 Creates a new list only containing the mentions for which type = entity_types
            men_embeds_by_type = men_embeds[men_idxs_by_type[entity_type]] # Only want to search the mentions that belongs to a specific type of entity.
            # Returns the distance and the indice of the closest entity for all mentions in men_embeds by entity type
            nn_ent_dists_by_type, nn_ent_idxs_by_type = dict_indexes[entity_type].search(men_embeds_by_type, 1) 
            nn_ent_idxs_by_type = np.array( #CC4 DD8
                list( #DD9
                    map( # lambda x : acts as a function
                        lambda x: dict_idxs_by_type[entity_type][x], nn_ent_idxs_by_type
                    ) # nn_ent_idxs_by_type is the iterable being processed by the map function
                    # Each element within nn_ent_idxs_by_type is passed to the lambda function as x.
                ) # map alone would return an object, that's why need a list
            )
            # Returns the distance and the indice of the k closest mentions for all mention in men_embeds by entity type
            # Note that here we may not necessarily have k mentions in each entity type which is why we use min(k,len(men_embeds_by_type))
            nn_men_dists_by_type, nn_men_idxs_by_type = men_indexes[entity_type].search(
                men_embeds_by_type, min(n_men_to_fetch, len(men_embeds_by_type))
            )
            nn_men_idxs_by_type = np.array(
                list(
                    map(lambda x: men_idxs_by_type[entity_type][x], nn_men_idxs_by_type)
                )
            )
            for i, idx in enumerate(men_idxs_by_type[entity_type]): #CC5
                nn_ent_idxs[idx] = nn_ent_idxs_by_type[i]
                nn_ent_dists[idx] = nn_ent_dists_by_type[i]
                nn_men_idxs[idx][: len(nn_men_idxs_by_type[i])] = nn_men_idxs_by_type[i]
                nn_men_dists[idx][: len(nn_men_dists_by_type[i])] = nn_men_dists_by_type[i]
    logger.info("Eval: Search finished") # An informational message is logged to indicate that the k-NN search is finished

    '3) Constructs graphs based on these relationships.'
    '''
    nn_ent_dists contain information about distance of the closest entity
    nn_ent_idxs contain information about indice of the closest entity
    nn_men_dists contain information about distance of the k nearest mentions
    nn_men_idxs contain information about indice of the k nearest mentions
    - We can fill in the "rows" part (=start nodes) of the graph in the order of the mentions
    - We can fill in the "cols" part (=end nodes) of the graph with nn_ent_idxs and nn_men_idxs
    - We can fill in the "data" part (=weights) of the graph with nn_ent_dists and nn_men_dists
    '''
    logger.info("Eval: Building graphs")
    for men_query_idx, men_embed in enumerate(
        tqdm(men_embeds, total=len(men_embeds), desc="Eval: Building graphs")
    ):
        # Get nearest entity candidate
        dict_cand_idx = nn_ent_idxs[men_query_idx][0] # Use of [0] to retrieve a scalar and not an 1D array
        dict_cand_score = nn_ent_dists[men_query_idx][0]

        # Filter candidates to remove -1s, mention query, within doc (if reqd.), and keep only the top k candidates
        filter_mask_neg1 = nn_men_idxs[men_query_idx] != -1 # bool ndarray. Ex : np.array([True, False, True, False])
        men_cand_idxs = nn_men_idxs[men_query_idx][filter_mask_neg1] # Only keep the elements != -1
        men_cand_scores = nn_men_dists[men_query_idx][filter_mask_neg1]

        if within_doc:
            men_cand_idxs, wd_mask = filter_by_context_doc_id(
                men_cand_idxs,
                context_doc_ids[men_query_idx],
                context_doc_ids,
                return_numpy=True,
            )
            men_cand_scores = men_cand_scores[wd_mask]
        
        # Filter self-reference + limits the number of candidate to 'max_knn'
        filter_mask = men_cand_idxs != men_query_idx
        men_cand_idxs, men_cand_scores = (
            men_cand_idxs[filter_mask][:max_knn],
            men_cand_scores[filter_mask][:max_knn],
        )

        # Add edges to the graphs
        for k in joint_graphs:
            joint_graph = joint_graphs[k] # There is no "s" in "joint_graph", it's not the same ! 
            # Add mention-entity edge
            joint_graph["rows"] = np.append( # Mentions are offset by the total number of entities to differentiate mention nodes from entity nodes
                joint_graph["rows"], [n_entities + men_query_idx]
            )  
            joint_graph["cols"] = np.append(joint_graph["cols"], dict_cand_idx)
            joint_graph["data"] = np.append(joint_graph["data"], dict_cand_score)
            if k > 0:
                # Add mention-mention edges
                joint_graph["rows"] = np.append(
                    joint_graph["rows"],
                    [n_entities + men_query_idx] * len(men_cand_idxs[:k]), # creates an array where the starting node (current mention) is repeated len(men_cand_idxs[:k]) times
                ) 
                joint_graph["cols"] = np.append(
                    joint_graph["cols"], n_entities + men_cand_idxs[:k]
                )
                joint_graph["data"] = np.append(
                    joint_graph["data"], men_cand_scores[:k]
                )
    
    "4) Evaluates the model's accuracy by analyzing how effectively the model can link mentions to the correct entities."    
    
    best_result = {'accuracy': 0}
    
    dict_acc = {}
    max_eval_acc = -1.
    for k in joint_graphs:
        logger.info(f"\nEval: Graph (k={k}):")
        # Partition graph based on cluster-linking constraints (inference procedure)
        partitioned_graph, clusters = eval_cluster_linking.partition_graph(
            joint_graphs[k], n_entities, directed=True, return_clusters=True)
        # Infer predictions from clusters
        result = eval_cluster_linking.analyzeClusters(clusters, entity_data, query_data, k)
        acc = float(result['accuracy'].split(' ')[0])
        best_result['accuracy'] = acc if acc >= best_result['accuracy'] else best_result['accuracy']
        dict_acc[f'k{k}'] = acc
        max_eval_acc = max(acc, max_eval_acc)
        logger.info(f"Eval: accuracy for graph@k={k}: {acc}%")
    logger.info(f"Eval: Best accuracy: {max_eval_acc}%")
    embed_and_index_dict = {'dict_embeds': dict_embeds, 'dict_indexes': dict_indexes, 'dict_idxs_by_type': dict_idxs_by_type} if use_types else {'dict_embeds': dict_embeds, 'dict_index': dict_index}
    return max_eval_acc, dict_acc, embed_and_index_dict



def filter_by_context_doc_id(mention_idxs, doc_id, doc_id_list, return_numpy=False):
    '''
    Description 
    -----------
    Filters and returns mention indices that belong to a specific document identified by "doc_id".
    Ensures that the analysis are constrained within the context of that particular document.
    
    Parameters 
    ----------
    - mention_idxs : ndarray(int) of dim = (number of mentions)
    Represents the indices of mentions
    - doc_id : int 
    Indice of the target document
    - doc_id_list : ndarray(int) of dim = (number of mentions)
    Array of integers, where each element is a document ID associated with the corresponding mention in mention_idxs. 
    The length of doc_id_list should match the total number of mentions referenced in mention_idxs.
    - return_numpy : bool
    A flag indicating whether to return the filtered list of mention indices as a NumPy array. 
    If True, the function returns a NumPy array; otherwise, it returns a list
    -------
    Outputs: 
    - mask : ndarray(bool) of dim = (number of mentions)
    Mask indicating where each mention's document ID (from doc_id_list) matches the target doc_id
    - mention_idxs : 
    Only contains mention indices that belong to the target document (=doc_id).
    '''
    mask = [doc_id_list[i] == doc_id for i in mention_idxs]
    if isinstance(mention_idxs, list): # Test if mention_idxs = list. Return a bool
        mention_idxs = np.array(mention_idxs) 
    mention_idxs = mention_idxs[mask] # possible only if mention_idxs is an array, not a list
    if not return_numpy:
        mention_idxs = list(mention_idxs)
    return mention_idxs, mask



def loss_function(reranker, 
    params, 
    forward_output, 
    data_module, 
    n_entities, 
    knn_dict, 
    batch_context_inputs, 
    accumulate_grad_batches
    ):
    '''
    Compute the loss function during the training.
    
    Parameters
    ----------
    - reranker : BiEncoderRanker
    NN-based ranking model
    - params : dict
    Contains most of the relevant keys for training (embed_batch_size, train_batch_size, n_gpu, force_exact_search etc...)
    - forward_output : dict
    Output of the forward() method
    - data_module : Instance of ArboelDataModule class
    - n_entities : int
    Total number of entities
    - knn_dict : int (self.knn_dict = self.hparams["knn"]//2)
    number of negative entities to fetch. It divides the k-nn evenly between entities and mentions 
    - accumulate_grad_batches : int
    Number of steps to accumulate gradients
    '''
    
    # Compute the loss
    loss_dual_negs = loss_ent_negs = 0
    # loss of a batch includes both negative mention and entity inputs (alongside positive examples ofc)
    loss_dual_negs, _ = reranker(forward_output['context_inputs'], label_input=forward_output['label_inputs'], mst_data={
        'positive_embeds': forward_output['positive_embeds'],
        'negative_dict_inputs': forward_output['negative_dict_inputs'],
        'negative_men_inputs': forward_output['negative_men_inputs']
    }, pos_neg_loss=params["pos_neg_loss"]) #A27
    skipped_context_inputs = []
    if forward_output['skipped'] > 0 and not params["within_doc_skip_strategy"]: #A28
        skipped_negative_dict_inputs = torch.tensor(
            list(map(lambda x: data_module.entity_dict_vecs[x].numpy(), skipped_negative_dict_inputs)))
        skipped_positive_embeds = []
        for pos_idx in forward_output['skipped_positive_idxs']:
            if pos_idx < n_entities:
                pos_embed = reranker.encode_candidate(data_module.entity_dict_vecs[pos_idx:pos_idx + 1],
                                                        requires_grad=True)
            else:
                pos_embed = reranker.encode_context(
                    data_module.train_men_vecs[pos_idx - n_entities:pos_idx - n_entities + 1], requires_grad=True)
            skipped_positive_embeds.append(pos_embed)
        skipped_positive_embeds = torch.cat(skipped_positive_embeds)
        skipped_context_inputs = batch_context_inputs[~np.array(forward_output['context_inputs_mask'])]
        skipped_context_inputs = skipped_context_inputs
        skipped_label_inputs = torch.tensor([[1] + [0] * (knn_dict)] * len(skipped_context_inputs),
                                    dtype=torch.float32)
        #DD18 loss of a batch that only includes negative entity inputs.
        loss_ent_negs, _ = reranker(skipped_context_inputs, label_input=skipped_label_inputs, mst_data={
            'positive_embeds': skipped_positive_embeds,
            'negative_dict_inputs': skipped_negative_dict_inputs,
            'negative_men_inputs': None
        }, pos_neg_loss=params["pos_neg_loss"])
            
    # len(context_input) = Number of mentions in the batch that successfully found negative entities and mentions.
    # len(skipped_context_inputs): Number of mentions in the batch that only found negative entities.
    loss = ((loss_dual_negs * len(forward_output['context_inputs']) + loss_ent_negs * len(skipped_context_inputs)) / (len(forward_output['context_inputs']) + len(skipped_context_inputs))) / accumulate_grad_batches
    return loss

In [74]:
'Model Training'
class LitArboel(L.LightningModule):
    def __init__(
        self, 
        params
        ):
        '''
        - params : dict
        Contains most of the relevant keys for training (embed_batch_size, train_batch_size, n_gpu, force_exact_search etc...)
        - data_module : Instance of ArboelDataModule class
        '''
        super(LitArboel, self).__init__()
        self.save_hyperparameters(params) #DD1
        
        self.reranker = BiEncoderRanker(params)
        # self.tokenizer = self.reranker.tokenizer
        self.model = self.reranker.model
        
    # def setup(self, stage: Optional[str] = None):
    #     self.entity_dict_vecs = self.trainer.datamodule.entity_dict_vecs.to(self.device)
    #     self.train_men_vecs = self.trainer.datamodule.train_men_vecs.to(self.device)
        
        
    def forward(self, batch_context_inputs, candidate_idxs, n_gold, mention_idxs):
        """
        Description
        -----------
        Processes a batch of input data to generate embeddings, and identifies positive and negative examples for training. 
        It handles the construction of mention-entity graphs, computes nearest neighbors, and organizes the data for subsequent loss calculation.
        
        Parameters
        ----------
        - “batch_context_inputs” : Tensor
            Tensor containing IDs of (mention + surrounding context) tokens. Shape: (batch_size, context_length) 
        - “candidate_idxs” : Tensor
            Tensor with indices pointing to the entities in the entity dictionary that are considered correct labels for the mention. Shape: (batch_size, candidate_count)
        - “n_gold” : Tensor
            Number of labels (=entities) associated with the mention. Shape: (batch_size,)
        - “mention_idx” : Tensor
            Tensor containing a sequence of integers from 0 to N-1 (N = number of mentions in the dataset) serbing as a unique identifier for each mention.
        
        Return
        ------
        - label_inputs : Tensor
            Tensor of binary labels indicating the correct candidates. Shape: (batch_size, 1 + knn_dict + knn_men), where 1 represents the positive example and the rest are negative examples.
        - context_inputs : Tensor
            Processed batch context inputs, filtered to remove mentions with no negative examples. Shape: (filtered_batch_size, context_length).
        - negative_men_inputs : Tensor
            Tensor of negative mention inputs. Shape: (filtered_batch_size * knn_men,)
        - negative_dict_inputs : Tensor
            Tensor of negative dictionary (entity) inputs. Shape: (filtered_batch_size * knn_dict,)
        - positive_embeds : Tensor
            Tensor of embeddings for the positive examples. Shape: (filtered_batch_size, embedding_dim)
        - skipped : int 
            The number of mentions skipped due to lack of valid negative examples.
        - skipped_positive_idxs : list(int)
            List of indices for positive examples that were skipped.
        - skipped_negative_dict_inputs :
            Tensor of negative dictionary inputs for skipped examples. Shape may vary based on the number of skipped examples and available negative dictionary entries.
        - context_inputs_mask : list(bool)
            Mask indicating which entries in batch_context_inputs were retained after filtering out mentions with no negative examples.
        """
        
        # mentions within the batch
        mention_embeddings = self.train_men_embeddings[mention_idxs.cpu()]
        if len(mention_embeddings.shape) == 1:
            mention_embeddings = np.expand_dims(mention_embeddings, axis=0)
        # Convert Back to Tensor and Move to GPU
        mention_embeddings = torch.from_numpy(mention_embeddings).to(self.device)

        positive_idxs = []
        negative_dict_inputs = []
        negative_men_inputs = []

        skipped_positive_idxs = []
        skipped_negative_dict_inputs = []

        min_neg_mens = float('inf')
        skipped = 0
        context_inputs_mask = [True]*len(batch_context_inputs)
        
        'IV.4.B) For each mention within the batch'
        # For each mention within the batch
        for m_embed_idx, m_embed in enumerate(mention_embeddings):
            mention_idx = int(mention_idxs[m_embed_idx])
            #CC11 ground truth entities of the mention "mention_idx"
            gold_idxs = set(self.trainer.datamodule.train_processed_data[mention_idx]['label_idxs'][:n_gold[m_embed_idx]])
            
            # TEMPORARY: Assuming that there is only 1 gold label, TODO: Incorporate multiple case
            assert n_gold[m_embed_idx] == 1

            if mention_idx in self.gold_links:
                gold_link_idx = self.gold_links[mention_idx]
            else:
                'IV.4.B.a) Create the graph with positive edges'
                # This block creates all the positive edges of the mention in this iteration
                # Run MST on mention clusters of all the gold entities of the current query mention to find its positive edge
                rows, cols, data, shape = [], [], [], (self.n_entities+self.n_mentions, self.n_entities+self.n_mentions)
                seen = set()

                # Set whether the gold edge should be the nearest or the farthest neighbor
                sim_order = 1 if self.hparams["farthest_neighbor"] else -1 #A26

                for cluster_ent in gold_idxs:
                    #CC12 IDs of all the mentions inside the gold cluster with entity id = "cluster_ent"
                    cluster_mens = self.trainer.datamodule.train_gold_clusters[cluster_ent]

                    if self.hparams["within_doc"]:
                        # Filter the gold cluster to within-doc
                        cluster_mens, _ = filter_by_context_doc_id(cluster_mens,
                                                                    self.trainer.datamodule.train_context_doc_ids[mention_idx],
                                                                    self.trainer.datamodule.train_context_doc_ids)
                    
                    # weights for all the mention-entity links inside the cluster of the current mention
                    to_ent_data = self.train_men_embeddings[cluster_mens] @ self.train_dict_embeddings[cluster_ent].T

                    # weights for all the mention-mention links inside the cluster of the current mention
                    to_men_data = self.train_men_embeddings[cluster_mens] @ self.train_men_embeddings[cluster_mens].T

                    if self.hparams['gold_arbo_knn'] is not None:
                        # Descending order of similarity if nearest-neighbor, else ascending order
                        sorti = np.argsort(sim_order * to_men_data, axis=1)
                        sortv = np.take_along_axis(to_men_data, sorti, axis=1)
                        if self.hparams["rand_gold_arbo"]:
                            randperm = np.random.permutation(sorti.shape[1])
                            sortv, sorti = sortv[:, randperm], sorti[:, randperm]

                    for i in range(len(cluster_mens)):
                        from_node = self.n_entities + cluster_mens[i]
                        to_node = cluster_ent
                        # Add mention-entity link
                        rows.append(from_node)
                        cols.append(to_node)
                        data.append(-1 * to_ent_data[i])
                        if self.hparams['gold_arbo_knn'] is None:
                            # Add forward and reverse mention-mention links over the entire MST
                            for j in range(i+1, len(cluster_mens)):
                                to_node = self.n_entities + cluster_mens[j]
                                if (from_node, to_node) not in seen:
                                    score = to_men_data[i,j]
                                    rows.append(from_node)
                                    cols.append(to_node)
                                    data.append(-1 * score) # Negatives needed for SciPy's Minimum Spanning Tree computation
                                    seen.add((from_node, to_node))
                                    seen.add((to_node, from_node))
                        else:
                            # Approximate the MST using <gold_arbo_knn> nearest mentions from the gold cluster
                            added = 0
                            approx_k = min(self.hparams['gold_arbo_knn']+1, len(cluster_mens))
                            for j in range(approx_k):
                                if added == approx_k - 1:
                                    break
                                to_node = self.n_entities + cluster_mens[sorti[i, j]]
                                if to_node == from_node:
                                    continue
                                added += 1
                                if (from_node, to_node) not in seen:
                                    score = sortv[i, j]
                                    rows.append(from_node)
                                    cols.append(to_node)
                                    data.append(
                                        -1 * score)  # Negatives needed for SciPy's Minimum Spanning Tree computation
                                    seen.add((from_node, to_node))

                'IV.4.B.b) Fine tuning with inference procedure to get a mst'
                # Creates MST with entity constraint (inference procedure)
                csr = csr_matrix((-sim_order * np.array(data), (rows, cols)), shape=shape)
                # Note: minimum_spanning_tree expects distances as edge weights
                mst = minimum_spanning_tree(csr).tocoo()
                # Note: cluster_linking_partition expects similarities as edge weights # Convert directed to undirected graph
                rows, cols, data = cluster_linking_partition(np.concatenate((mst.row, mst.col)), # cluster_linking_partition is imported from eval_cluster_linking
                                                                np.concatenate((mst.col, mst.row)),
                                                                np.concatenate((sim_order * mst.data, sim_order * mst.data)),
                                                                self.n_entities,
                                                                directed=True,
                                                                silent=True)
                assert np.array_equal(rows - self.n_entities, cluster_mens)
                
                for i in range(len(rows)):
                    men_idx = rows[i] - self.n_entities
                    if men_idx in self.gold_links:
                        continue
                    assert men_idx >= 0
                    add_link = True
                    # Store the computed positive edges for the mentions in the clusters only if they have the same gold entities as the query mention
                    for l in self.trainer.datamodule.train_processed_data[men_idx]['label_idxs'][:self.trainer.datamodule.train_processed_data[men_idx]['n_labels']]:
                        if l not in gold_idxs:
                            add_link = False
                            break
                    if add_link:
                        self.gold_links[men_idx] = cols[i]
                gold_link_idx = self.gold_links[mention_idx]
                
            'IV.4.B.c) Retrieve the pre-computed nearest neighbors'
            knn_dict_idxs = self.dict_nns[mention_idx]
            knn_dict_idxs = knn_dict_idxs.astype(np.int64).flatten()
            knn_men_idxs = self.men_nns[mention_idx][self.men_nns[mention_idx] != -1]
            knn_men_idxs = knn_men_idxs.astype(np.int64).flatten()
            if self.hparams['within_doc']:
                knn_men_idxs, _ = filter_by_context_doc_id(knn_men_idxs,
                                                        self.trainer.datamodule.train_context_doc_ids[mention_idx],
                                                        self.trainer.datamodule.train_context_doc_ids, return_numpy=True)
            'IV.4.B.d) Add negative examples'
            neg_mens = list(knn_men_idxs[~np.isin(knn_men_idxs, np.concatenate([self.trainer.datamodule.train_gold_clusters[gi] for gi in gold_idxs]))][:self.knn_men])
            # Track queries with no valid mention negatives
            if len(neg_mens) == 0:
                context_inputs_mask[m_embed_idx] = False
                skipped_negative_dict_inputs += list(knn_dict_idxs[~np.isin(knn_dict_idxs, list(gold_idxs))][:self.knn_dict])
                skipped_positive_idxs.append(gold_link_idx)
                skipped += 1
                continue
            else:
                min_neg_mens = min(min_neg_mens, len(neg_mens))
            negative_men_inputs.append(knn_men_idxs[~np.isin(knn_men_idxs, np.concatenate([self.trainer.datamodule.train_gold_clusters[gi] for gi in gold_idxs]))][:self.knn_men])
            negative_dict_inputs += list(knn_dict_idxs[~np.isin(knn_dict_idxs, list(gold_idxs))][:self.knn_dict])
            # Add the positive example
            positive_idxs.append(gold_link_idx)

        
        'IV.4.C) Skip this iteration if no suitable negative examples found'
        if len(negative_men_inputs) == 0 :
            return None #DD8 instead of continue
        
        # Sets the minimum number of negative mentions found across all processed mentions in the current batch
        self.knn_men = min_neg_mens
        
        # This step ensures that each mention is compared against a uniform number of negative mentions
        filtered_negative_men_inputs = []
        for row in negative_men_inputs:
            filtered_negative_men_inputs += list(row[:self.knn_men])
        negative_men_inputs = filtered_negative_men_inputs

        # Assertions for Data Integrity
        assert len(negative_dict_inputs) == (len(mention_embeddings) - skipped) * self.knn_dict
        assert len(negative_men_inputs) == (len(mention_embeddings) - skipped) * self.knn_men

        self.total_skipped += skipped
        self.total_knn_men_negs += self.knn_men

        # Convert to tensors
        negative_dict_inputs = torch.tensor(list(map(lambda x: self.trainer.datamodule.entity_dict_vecs[x].numpy(), negative_dict_inputs)))
        negative_men_inputs = torch.tensor(list(map(lambda x: self.trainer.datamodule.train_men_vecs[x].numpy(), negative_men_inputs)))
        
        # Labels indicating the correct candidates. Used for computing loss.
        positive_embeds = []
        for pos_idx in positive_idxs:
            if pos_idx < self.n_entities:
                # print("Device of input tensors: entity_dict_vecs / train_men_vecs", self.entity_dict_vecs.device, self.train_men_vecs.device)
                # pos_embed = self.reranker.encode_candidate(self.trainer.datamodule.entity_dict_vecs.to(self.device)[pos_idx:pos_idx + 1], requires_grad=True)
                pos_embed = self.reranker.encode_candidate(self.trainer.datamodule.entity_dict_vecs[pos_idx:pos_idx + 1], requires_grad=True)
            else:
                # print("Device of input tensors: entity_dict_vecs / train_men_vecs", self.entity_dict_vecs.device, self.train_men_vecs.device)
                pos_embed = self.reranker.encode_context(self.trainer.datamodule.train_men_vecs[pos_idx - self.n_entities:pos_idx - self.n_entities + 1], requires_grad=True)
            positive_embeds.append(pos_embed)
        positive_embeds = torch.cat(positive_embeds)
        
        # Remove mentions with no negative examples
        context_inputs = batch_context_inputs[context_inputs_mask]
        context_inputs = context_inputs
        
        # Tensor containing binary values that act as indicator variables in the paper:
        # Contains Indicator variable such that I_{u,m_i} = 1 if(u,mi) ∈ E'_{m_i} and I{u,m_i} = 0 otherwise.
        label_inputs = torch.tensor([[1]+[0]*(self.knn_dict+self.knn_men)]*len(context_inputs), dtype=torch.float32)
        
        return {'label_inputs':label_inputs, 'context_inputs' : context_inputs, "negative_men_inputs" : negative_men_inputs,
                'negative_dict_inputs' : negative_dict_inputs, 'positive_embeds' : positive_embeds, 'skipped' : skipped,
                'skipped_positive_idxs' : skipped_positive_idxs, 'skipped_negative_dict_inputs' : skipped_negative_dict_inputs,
                'context_inputs_mask' : context_inputs_mask
                }



    def training_step(self, batch, batch_idx):
        
        # batch = tuple(t.to(device) for t in batch) : automated in pytorch #DD5
        
        # Initialize the parameters
        # batch is a subsample from tensor_dataset
        batch_context_inputs, candidate_idxs, n_gold, mention_idxs = batch
        
        f = self.forward(batch_context_inputs, candidate_idxs, n_gold, mention_idxs)
        
        # Compute the loss
        loss = loss_function(self.reranker, 
            self.hparams, 
            f, 
            self.trainer.datamodule, 
            self.n_entities, 
            self.knn_dict, 
            batch_context_inputs, 
            self.trainer.accumulate_grad_batches
        )

        return loss


    def configure_optimizers(self): #DD2
        
        # Define optimizer
        optimizer = get_bert_optimizer(
        [self.model],
        self.hparams["type_optimization"],
        self.hparams["learning_rate"],
        fp16=self.hparams.get("fp16"),
        )
        
        # Define scheduler
        num_train_steps = int(len(self.trainer.datamodule.train_tensor_data) / self.hparams["train_batch_size"] / self.trainer.accumulate_grad_batches) * self.trainer.max_epochs
        num_warmup_steps = int(num_train_steps * self.hparams["warmup_proportion"])

        scheduler = WarmupLinearSchedule(
            optimizer, warmup_steps=num_warmup_steps, t_total=num_train_steps,
        )
        logger.info(" Num optimization steps = %d" % num_train_steps)
        logger.info(" Num warmup steps = %d", num_warmup_steps)
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]



    def on_train_start(self):
        self.start_time = time.time()
    
    
    # Don't need this for now
    # def on_fit_start(self):
    #     # Compute n_gpu once, at the start of training, and store it as an instance attribute
    #     if self.trainer.devices is None:
    #         self.n_gpu = 0
    #     elif isinstance(self.trainer.devices, list):
    #         self.n_gpu = len(self.trainer.devices)
    #     elif isinstance(self.trainer.devices, int):
    #         self.n_gpu = self.trainer.devices
    #     else:
    #         # For other configurations, such as auto-select or when specific GPUs are selected as a string
    #         # It's safer to rely on the actual allocated devices by the trainer
    #         self.n_gpu = len(self.trainer.accelerator_connector.parallel_devices)
            
            
    def on_train_epoch_start(self):
        # To do at the start of each epoch
        self.tr_loss = 0
        
        'IV.1) Compute mention and entity embeddings and indexes at the start of each epoch'
        # Compute mention and entity embeddings and indexes at the start of each epoch
        if self.hparams['use_types']: # type-specific indexes 
            self.train_dict_embeddings, self.train_dict_indexes, self.dict_idxs_by_type = data_process.embed_and_index(
                self.reranker, 
                self.trainer.datamodule.entity_dict_vecs, 
                encoder_type="candidate", 
                # n_gpu=self.n_gpu, 
                corpus= self.trainer.datamodule.entity_dictionary, 
                force_exact_search= self.hparams['force_exact_search'], 
                batch_size= self.hparams['embed_batch_size'], 
                probe_mult_factor= self.hparams['probe_mult_factor']) #D11
            self.train_men_embeddings, self.train_men_indexes, self.men_idxs_by_type = data_process.embed_and_index(
                self.reranker, 
                self.trainer.datamodule.train_men_vecs, 
                encoder_type="context", 
                #n_gpu=self.n_gpu, 
                corpus= self.trainer.datamodule.train_processed_data, 
                force_exact_search= self.hparams['force_exact_search'], 
                batch_size= self.hparams['embed_batch_size'], 
                probe_mult_factor= self.hparams['probe_mult_factor'])
        
        else: # general indexes
            self.train_dict_embeddings, self.train_dict_index = data_process.embed_and_index(
                self.reranker, 
                self.trainer.datamodule.entity_dict_vecs, 
                encoder_type="candidate", 
                # n_gpu=self.n_gpu, 
                force_exact_search= 
                self.hparams['force_exact_search'], 
                batch_size= self.hparams['embed_batch_size'], 
                probe_mult_factor= self.hparams['probe_mult_factor'])
            self.train_men_embeddings, self.train_men_index = data_process.embed_and_index(
                self.reranker, 
                self.trainer.datamodule.train_men_vecs, 
                encoder_type="context", 
                # n_gpu=self.n_gpu, 
                force_exact_search = self.hparams['force_exact_search'], 
                batch_size= self.hparams['embed_batch_size'], 
                probe_mult_factor= self.hparams['probe_mult_factor'])
        
        # Number of entities and mentions
        self.n_entities = len(self.trainer.datamodule.entity_dictionary)
        self.n_mentions = len(self.trainer.datamodule.train_processed_data)
        
        # Store golden MST links
        self.gold_links = {}
        # Calculate the number of negative entities and mentions to fetch # Divides the k-nn evenly between entities and mentions
        self.knn_dict = self.hparams["knn"]//2
        self.knn_men = self.hparams["knn"] - self.knn_dict
        
        'IV.3) knn search : indice and distance of k closest mentions and entities'
        logger.info("Starting KNN search...")
        # INFO: Fetching all sorted mentions to be able to filter to within-doc later=
        n_men_to_fetch = len(self.train_men_embeddings) if self.hparams["use_types"] else self.knn_men + self.trainer.datamodule.max_gold_cluster_len
        n_ent_to_fetch = self.knn_dict + 1 # +1 accounts for the possibility of self-reference
        if not self.hparams["use_types"]:
            _, self.dict_nns = self.train_dict_index.search(self.train_men_embeddings, n_ent_to_fetch)
            _, self.men_nns = self.train_men_index.search(self.train_men_embeddings, n_men_to_fetch)
        else:
            self.dict_nns = -1 * np.ones((len(self.train_men_embeddings), n_ent_to_fetch))
            self.men_nns = -1 * np.ones((len(self.train_men_embeddings), n_men_to_fetch))
            for entity_type in self.train_men_indexes:
                self.men_embeds_by_type = self.train_men_embeddings[self.men_idxs_by_type[entity_type]]
                _, self.dict_nns_by_type = self.train_dict_indexes[entity_type].search(self.men_embeds_by_type, n_ent_to_fetch)
                _, self.men_nns_by_type = self.train_men_indexes[entity_type].search(self.men_embeds_by_type, min(n_men_to_fetch, len(self.men_embeds_by_type)))
                self.dict_nns_idxs = np.array(list(map(lambda x: self.dict_idxs_by_type[entity_type][x], self.dict_nns_by_type)))
                self.men_nns_idxs = np.array(list(map(lambda x: self.men_idxs_by_type[entity_type][x], self.men_nns_by_type)))
                for i, idx in enumerate(self.men_idxs_by_type[entity_type]):
                    self.dict_nns[idx] = self.dict_nns_idxs[i]
                    self.men_nns[idx][:len(self.men_nns_idxs[i])] = self.men_nns_idxs[i]
        logger.info("Search finished")
        
        self.total_skipped = self.total_knn_men_negs = 0
        
        pass
    
    def on_train_epoch_end(self):
        # To do at the end of each epoch
        # May not need it 
        pass
    
    def on_after_backward(self): #DD11
        # After .backward()
        if (self.trainer.global_step + 1) % self.trainer.accumulate_grad_batches == 0:
            torch.nn.utils.clip_grad_norm_(
                self.parameters(), self.trainer.grad_clip_norm
            )
            
    # RR2 : Most of the info is already in the Trainer's callback        
    # def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx = 0):
    #     # end of training batch
    #     'IV.4.E) Information about the training (step, epoch, average_loss)'
    #     n_print_iters = self.hparams["print_interval"] * self.trainer.accumulate_grad_batches #29
    #     if (batch_idx + 1) % n_print_iters == 0:
    #         # DD13
    #         self.log("train/average_loss", self.tr_loss / n_print_iters, on_step=False, on_epoch=True, prog_bar=True)
    #         if self.total_skipped > 0:
    #             self.log("train/queries_without_negs", self.total_skipped / n_print_iters, on_step=False, on_epoch=True)
    #             self.log("train/negative_mentions_per_query", self.total_knn_men_negs / n_print_iters, on_step=False, on_epoch=True)
            
    #         # Reset your tracking variables for the next interval
    #         self.total_skipped = 0
    #         self.total_knn_men_negs = 0
    #         self.tr_loss = 0

    #     # DD14
    #     '''
    #     # Regular checks on model performance against a validation dataset without interrupting the training more often than desired
    #     if self.hparams["eval_interval"] != -1: #A31
    #         if (batch_idx + 1) % (self.hparams["eval_interval"] * self.hparams["gradient_accumulation_steps"]) == 0:
    #             logger.info("Evaluation on the development dataset")
    #             evaluate(
    #                 self.reranker, self.trainer.datamodule.entity_dict_vecs, self.trainer.datamodule.valid_men_vecs, device=self.device, logger=logger, knn=self.hparams["knn"], n_gpu=self.n_gpu,
    #                 entity_data=self.trainer.datamodule.entity_dictionary, query_data=self.valid_processed_data, silent=self.hparams["silent"],
    #                 use_types=self.hparams['use_types'] or self.hparams["use_types_for_eval"], embed_batch_size=self.hparams["embed_batch_size"],
    #                 force_exact_search=self.hparams['use_types'] or self.hparams["use_types_for_eval"] or self.hparams["force_exact_search"],
    #                 probe_mult_factor=self.hparams['probe_mult_factor'], within_doc=self.hparams['within_doc'],
    #                 context_doc_ids=self.trainer.datamodule.valid_context_doc_ids
    #             )
    #             self.model.train()
    #             logger.info("\n")
    #     '''
        
    #     pass
    
    
    def validation_step(self, batch, batch_idx): #DD14

        max_acc, dict_acc, embed_and_index_dict = evaluate(
            self.reranker, 
            self.trainer.datamodule.entity_dict_vecs, 
            self.trainer.datamodule.valid_men_vecs, 
            # device=self.device, 
            logger=logger, 
            # knn=self.hparams["knn"], 
            # n_gpu=self.n_gpu,
            entity_data=self.trainer.datamodule.entity_dictionary, 
            query_data=self.trainer.datamodule.valid_processed_data, 
            # silent=self.hparams["silent"],
            use_types=self.hparams['use_types'], #use_types=self.hparams['use_types'] or self.hparams["use_types_for_eval"] 
            embed_batch_size=self.hparams["embed_batch_size"],
            force_exact_search=self.hparams['use_types'] or self.hparams["force_exact_search"], #force_exact_search= self.hparams['use_types'] or use_types or self.hparams["use_types_for_eval"] or self.hparams["force_exact_search"]
            probe_mult_factor=self.hparams['probe_mult_factor'], 
            within_doc=self.hparams['within_doc'],
            context_doc_ids=self.trainer.datamodule.valid_context_doc_ids
        )
        self.log("max_acc", max_acc, on_epoch=True)
        for key, value in dict_acc.items():
            self.log(f"dict_acc_{key}", value, on_epoch=True)
        for key, value in dict_acc.items():
            self.log(f"embed_and_index_dict{key}", value, on_epoch=True)
        
        
    def test_step(self, batch, batch_idx): #DD14
        max_acc, dict_acc, embed_and_index_dict = evaluate(
            self.reranker, 
            self.trainer.datamodule.entity_dict_vecs, 
            self.trainer.datamodule.test_men_vecs, 
            # device=self.device, 
            logger=logger, 
            # knn=self.hparams["knn"], 
            # n_gpu=self.n_gpu,
            entity_data=self.trainer.datamodule.entity_dictionary, 
            query_data=self.trainer.datamodule.test_processed_data,
            # silent=self.hparams["silent"],
            use_types=self.hparams['use_types'], #use_types=self.hparams['use_types'] or self.hparams["use_types_for_eval"] 
            embed_batch_size=self.hparams["embed_batch_size"],
            force_exact_search=self.hparams['use_types'] or self.hparams["force_exact_search"], #force_exact_search= self.hparams['use_types'] or use_types or self.hparams["use_types_for_eval"] or self.hparams["force_exact_search"]
            probe_mult_factor=self.hparams['probe_mult_factor'], 
            within_doc=self.hparams['within_doc'],
            context_doc_ids=self.trainer.datamodule.test_context_doc_ids
        )
        self.log("max_acc", max_acc, on_epoch=True)
        self.log("dict_acc", dict_acc, on_epoch=True)
        self.log("embedding_and_indexing", embed_and_index_dict, on_epoch=True)
        

    def on_train_end(self):
        execution_time = (time.time() - self.start_time) / 60
        self.logger.info(f"The training took {execution_time} minutes")

### Experiments

In [75]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.trainer import Trainer
torch.set_float32_matmul_precision('medium')

In [76]:
ontology = "medic"
model = "arboel"
dataset = "ncbi_disease"
abs_path = "/home2/cye73/data"
data_path = os.path.join(abs_path, model, dataset)
print(data_path)
abs_path2 = "/home2/cye73/results"
model_output_path = os.path.join(abs_path2, model, dataset)

ontology_type = "umls"
umls_dir="/mitchell/entity-linking/2017AA/META/"

params_test = {"model_output_path" : model_output_path,
               "data_path" : data_path,  
               "knn" : 4,
               "use_types" : False,
               "max_context_length": 64 ,
               "max_cand_length" : 64 ,
               "context_key" : "context", # to specify context_left or context_right
               "debug" : True,
               "gold_arbo_knn": 4,
               "within_doc" : True,
               "within_doc_skip_strategy" : False,
               "batch_size" : 128,#batch_size = embed_batch_size
               "train_batch_size" : 128,
               "filter_unlabeled" : False,
               "type_optimization" : "all",
               # 'additional_layers', 'top_layer', 'top4_layers', 'all_encoder_layers', 'all'
               "learning_rate" : 3e-5,
               "warmup_proportion" : 464,
               "fp16" : False,
               "embed_batch_size" : 128,
               "force_exact_search" : True,
               "probe_mult_factor" : 1,
               "pos_neg_loss" : True,
               "use_types_for_eval" : True,
               "drop_entities" : False,
               "drop_set" : False,
               "farthest_neighbor" : True,
               "rand_gold_arbo" : True,
               "bert_model": 'michiyasunaga/BioLinkBERT-base',
               "out_dim": 768 ,
               "pull_from_layer":11, #11 for base and 23 for large
               "add_linear":True,
               }

/home2/cye73/data/arboel/ncbi_disease


### Data Module

In [77]:
data_module = ArboelDataModule(params = params_test,
                                ontology = ontology,
                                dataset = dataset,
                                ontology_type = ontology_type,
                                umls_dir = umls_dir)

### Model

In [78]:
model = LitArboel(params = params_test)

###Model checkpoints

In [79]:
model_checkpoint = ModelCheckpoint(
    monitor='max_acc',  # Metric to monitor
    dirpath=params_test["model_output_path"],  # Directory to save the model
    filename='{epoch}-{max_acc:.2f}',  # Saves the model with epoch and val_loss in the filename
    save_top_k= 1,  # Number of best models to save; -1 means save all of them
    mode='max',  # 'max' means the highest max_acc will be considered as the best model
    verbose=True,  # Logs a message whenever a model checkpoint is saved
    )

###Trainer

In [80]:
trainer = L.Trainer(
    max_epochs=2,
    devices=[0, 1, 2, 3],
    accelerator="gpu",
    strategy="ddp_notebook",
    enable_progress_bar=True,
    # callbacks=[model_checkpoint],
    log_every_n_steps=10
)

[15/Mar/2024 16:20:00] INFO - GPU available: True (cuda), used: True
[15/Mar/2024 16:20:00] INFO - TPU available: False, using: 0 TPU cores
[15/Mar/2024 16:20:00] INFO - IPU available: False, using: 0 IPUs
[15/Mar/2024 16:20:00] INFO - HPU available: False, using: 0 HPUs


### Training

In [81]:
trainer.fit(model, datamodule=data_module )

INFO: Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4


[15/Mar/2024 16:20:02] INFO - Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4


INFO: Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4


[15/Mar/2024 16:20:02] INFO - Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4


INFO: Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4


[15/Mar/2024 16:20:02] INFO - Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4


INFO: Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4


[15/Mar/2024 16:20:02] INFO - Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
[15/Mar/2024 16:20:02] INFO - ----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------



You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loading stored processed entity dictionary...


100%|██████████| 13189/13189 [00:00<00:00, 2553012.53it/s]


Max labels on one doc: 5


Creating correct mention format for train dataset: 100%|██████████| 5065/5065 [00:00<00:00, 10306.00it/s]
Creating correct mention format for validation dataset: 100%|██████████| 780/780 [00:00<00:00, 302027.06it/s]
Creating correct mention format for test dataset: 100%|██████████| 960/960 [00:00<00:00, 375469.21it/s]


Loading stored processed entity dictionary...
Loading stored processed entity dictionary...Loading stored processed entity dictionary...Loading stored processed entity dictionary...


Loading stored processed train data...
Loading stored processed valid data...Loading stored processed train data...

Loading stored processed train data...
Loading stored processed valid data...
Loading stored processed valid data...
Loading stored processed train data...
Loading stored processed valid data...
[15/Mar/2024 16:20:06] INFO - within_doc
[15/Mar/2024 16:20:06] INFO - within_doc
[15/Mar/2024 16:20:06] INFO - within_doc
[15/Mar/2024 16:20:06] INFO - within_doc
[15/Mar/2024 16:20:06] INFO - Read 4783 train samples..
[15/Mar/2024 16:20:06] INFO - Read 722 valid samples..
[15/Mar/2024 16:20:06] INFO - Read 4783 train samples..
[15/Mar/2024 16:20:06] INFO - Read 877 test samples..
[15/Mar/2024 16:20:06] INFO - Read 4783 train samples..
[15/Mar/2024 16:20:06] INFO - Read 722 valid samples..
[15/Mar/

INFO: LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
INFO: LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
INFO: LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]


[15/Mar/2024 16:20:06] INFO - LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
[15/Mar/2024 16:20:06] INFO - LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
[15/Mar/2024 16:20:06] INFO - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
[15/Mar/2024 16:20:06] INFO - LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
The following parameters will be optimized WITH decay:The following parameters will be optimized WITH decay:The following parameters will be optimized WITH decay:


context_encoder.bert_model.embeddings.word_embeddings.weight , context_encoder.bert_model.embeddings.position_embeddings.weight , context_encoder.bert_model.embeddings.token_type_embeddings.weight , context_encoder.bert_model.embeddings.LayerNorm.weight , context_encoder.bert_model.encoder.layer.0.attention.self.query.weight , ...and 197 morecontext_encoder.bert_model.embeddings.word_embeddings.weight , context_encoder.bert_model.embeddings.position_embeddings.weight , context_encoder.bert_model.embeddings.token_type_

INFO: 
  | Name     | Type            | Params
---------------------------------------------
0 | reranker | BiEncoderRanker | 217 M 
1 | model    | BiEncoderModule | 217 M 
---------------------------------------------
217 M     Trainable params
0         Non-trainable params
217 M     Total params
870.586   Total estimated model params size (MB)


[15/Mar/2024 16:20:06] INFO - 
  | Name     | Type            | Params
---------------------------------------------
0 | reranker | BiEncoderRanker | 217 M 
1 | model    | BiEncoderModule | 217 M 
---------------------------------------------
217 M     Trainable params
0         Non-trainable params
217 M     Total params
870.586   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


[15/Mar/2024 16:20:07] INFO - Eval: Dictionary: Embedding and building index
[15/Mar/2024 16:20:07] INFO - Eval: Dictionary: Embedding and building index
[15/Mar/2024 16:20:07] INFO - Eval: Dictionary: Embedding and building index
[15/Mar/2024 16:20:07] INFO - Eval: Dictionary: Embedding and building index


Embedding in batches: 100%|██████████| 104/104 [00:06<00:00, 15.13it/s]
Embedding in batches:  99%|█████████▉| 103/104 [00:06<00:00, 15.10it/s]

[15/Mar/2024 16:20:14] INFO - Eval: Queries: Embedding and building index


Embedding in batches:   0%|          | 0/2 [00:00<?, ?it/s], 15.05it/s]


[15/Mar/2024 16:20:14] INFO - Eval: Queries: Embedding and building index


Embedding in batches: 100%|██████████| 104/104 [00:06<00:00, 15.00it/s]
Embedding in batches:  42%|████▏     | 44/104 [00:06<00:08,  7.12it/s]

[15/Mar/2024 16:20:14] INFO - Eval: Queries: Embedding and building index


Embedding in batches: 100%|██████████| 2/2 [00:00<00:00, 25.42it/s]


[15/Mar/2024 16:20:14] INFO - Eval: Starting KNN search...
[15/Mar/2024 16:20:14] INFO - Eval: Search finished
[15/Mar/2024 16:20:14] INFO - Eval: Building graphs


Embedding in batches: 100%|██████████| 2/2 [00:00<00:00, 25.15it/s]


[15/Mar/2024 16:20:14] INFO - Eval: Starting KNN search...
[15/Mar/2024 16:20:14] INFO - Eval: Search finished
[15/Mar/2024 16:20:14] INFO - Eval: Building graphs


Embedding in batches: 100%|██████████| 2/2 [00:00<00:00, 25.21it/s]


[15/Mar/2024 16:20:14] INFO - Eval: Starting KNN search...


Eval: Building graphs: 100%|██████████| 200/200 [00:00<00:00, 3867.43it/s]


[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=0):


Paritioning joint graph: 100%|██████████| 200/200 [00:00<00:00, 102512.62it/s]

[15/Mar/2024 16:20:14] INFO - Eval: Search finished


Embedding in batches:  43%|████▎     | 45/104 [00:07<00:08,  7.11it/s]


[15/Mar/2024 16:20:14] INFO - Eval: Building graphs


Eval: Building graphs: 100%|██████████| 200/200 [00:00<00:00, 4274.75it/s]


[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=0):


Paritioning joint graph: 100%|██████████| 200/200 [00:00<00:00, 105358.05it/s]
Eval: Building graphs: 100%|██████████| 200/200 [00:00<00:00, 4187.16it/s]


[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=0):


Paritioning joint graph:   0%|          | 0/200 [00:00<?, ?it/s]

Analyzing clusters...


Paritioning joint graph: 100%|██████████| 200/200 [00:00<00:00, 68775.99it/s]


Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=0: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=1):


Paritioning joint graph:   0%|          | 0/400 [00:00<?, ?it/s]

Analyzing clusters...

Paritioning joint graph: 100%|██████████| 400/400 [00:00<00:00, 91091.41it/s]







Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=0: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=1):


Paritioning joint graph: 100%|██████████| 400/400 [00:00<00:00, 100013.21it/s]


Analyzing clusters...Analyzing clusters...

Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=1: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=2):
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=0: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=1):


Paritioning joint graph: 100%|██████████| 600/600 [00:00<00:00, 55475.32it/s]
Paritioning joint graph:   0%|          | 0/400 [00:00<?, ?it/s]

Analyzing clusters...


Paritioning joint graph: 100%|██████████| 400/400 [00:00<00:00, 78142.60it/s]

Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=1: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=2):



Paritioning joint graph: 100%|██████████| 600/600 [00:00<00:00, 71842.83it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=2: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=4):
Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=1: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=2):
Analyzing clusters...


Paritioning joint graph:   0%|          | 0/1000 [00:00<?, ?it/s]

Accuracy = 0.0 %


Paritioning joint graph:   0%|          | 0/600 [00:00<?, ?it/s]

[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=2: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=4):



Paritioning joint graph: 100%|██████████| 1000/1000 [00:00<00:00, 57866.84it/s]
Paritioning joint graph: 100%|██████████| 1000/1000 [00:00<00:00, 97632.77it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=2: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=4):


Paritioning joint graph:   0%|          | 0/1000 [00:00<?, ?it/s]

Analyzing clusters...
Accuracy = 0.0 %

Embedding in batches:  45%|████▌     | 47/104 [00:07<00:08,  7.03it/s]


[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=4: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=8):
Analyzing clusters...

Paritioning joint graph: 100%|██████████| 1000/1000 [00:00<00:00, 93310.43it/s]







Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=4: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=8):




Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=4: 0.0%
[15/Mar/2024 16:20:14] INFO - 
Eval: Graph (k=8):


Paritioning joint graph: 100%|██████████| 1800/1800 [00:00<00:00, 11912.72it/s]
Paritioning joint graph: 100%|██████████| 1800/1800 [00:00<00:00, 11565.41it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=8: 0.0%
[15/Mar/2024 16:20:14] INFO - Eval: Best accuracy: 0.0%
Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=8: 0.0%
[15/Mar/2024 16:20:14] INFO - Eval: Best accuracy: 0.0%


Paritioning joint graph: 100%|██████████| 1800/1800 [00:00<00:00, 10382.83it/s]
/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:441: It is recommended to use `self.log('max_acc', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:441: It is recommended to use `self.log('dict_acc_k0', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py:441: It is recommended to use `self.log('dict_acc_k1', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
/nethome/cye73/conda_envs/arboel_2/l

Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:14] INFO - Eval: accuracy for graph@k=8: 0.0%
[15/Mar/2024 16:20:14] INFO - Eval: Best accuracy: 0.0%


Embedding in batches: 100%|██████████| 104/104 [00:15<00:00,  6.69it/s]


[15/Mar/2024 16:20:22] INFO - Eval: Queries: Embedding and building index


Embedding in batches: 100%|██████████| 2/2 [00:00<00:00,  9.40it/s]


[15/Mar/2024 16:20:22] INFO - Eval: Starting KNN search...
[15/Mar/2024 16:20:22] INFO - Eval: Search finished
[15/Mar/2024 16:20:22] INFO - Eval: Building graphs


Eval: Building graphs: 100%|██████████| 200/200 [00:00<00:00, 4289.18it/s]


[15/Mar/2024 16:20:22] INFO - 
Eval: Graph (k=0):


Paritioning joint graph: 100%|██████████| 200/200 [00:00<00:00, 93175.70it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:23] INFO - Eval: accuracy for graph@k=0: 0.0%
[15/Mar/2024 16:20:23] INFO - 
Eval: Graph (k=1):


Paritioning joint graph: 100%|██████████| 400/400 [00:00<00:00, 79093.04it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:23] INFO - Eval: accuracy for graph@k=1: 0.0%
[15/Mar/2024 16:20:23] INFO - 
Eval: Graph (k=2):


Paritioning joint graph: 100%|██████████| 600/600 [00:00<00:00, 99152.22it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:23] INFO - Eval: accuracy for graph@k=2: 0.0%
[15/Mar/2024 16:20:23] INFO - 
Eval: Graph (k=4):


Paritioning joint graph: 100%|██████████| 1000/1000 [00:00<00:00, 48660.08it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:23] INFO - Eval: accuracy for graph@k=4: 0.0%
[15/Mar/2024 16:20:23] INFO - 
Eval: Graph (k=8):


Paritioning joint graph: 100%|██████████| 1800/1800 [00:00<00:00, 11425.14it/s]


Analyzing clusters...
Accuracy = 0.0 %
[15/Mar/2024 16:20:23] INFO - Eval: accuracy for graph@k=8: 0.0%
[15/Mar/2024 16:20:23] INFO - Eval: Best accuracy: 0.0%


/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Embedding in batches:   0%|          | 0/104 [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Embedding in batches: 100%|██████████| 104/104 [00:05<00:00, 18.35it/s]
Embedding in batches: 100%|██████████| 104/104 [00:05<00:00, 18.35it/s]
Embedding in batches:  40%|████      | 42/104 [00:05<00:08,  7.28it/s]]
Embedding in batches: 100%|██████████| 2/2 [00:00<00:00, 25.31it/s]


[15/Mar/2024 16:20:29] INFO - Starting KNN search...
[15/Mar/2024 16:20:29] INFO - Search finished


Embedding in batches: 100%|██████████| 2/2 [00:00<00:00, 25.05it/s]


[15/Mar/2024 16:20:29] INFO - Starting KNN search...
[15/Mar/2024 16:20:29] INFO - Search finished


Embedding in batches: 100%|██████████| 2/2 [00:00<00:00, 25.20it/s]


[15/Mar/2024 16:20:29] INFO - Starting KNN search...
[15/Mar/2024 16:20:29] INFO - Search finished


  negative_dict_inputs = torch.tensor(list(map(lambda x: self.trainer.datamodule.entity_dict_vecs[x].numpy(), negative_dict_inputs)))
  negative_dict_inputs = torch.tensor(list(map(lambda x: self.trainer.datamodule.entity_dict_vecs[x].numpy(), negative_dict_inputs)))
  negative_dict_inputs = torch.tensor(list(map(lambda x: self.trainer.datamodule.entity_dict_vecs[x].numpy(), negative_dict_inputs)))
Embedding in batches:  44%|████▍     | 46/104 [00:06<00:08,  7.25it/s]INFO: [rank: 3] Received SIGTERM: 15


[15/Mar/2024 16:20:29] INFO - [rank: 3] Received SIGTERM: 15


Embedding in batches: 100%|██████████| 104/104 [00:14<00:00,  7.21it/s]
Embedding in batches: 100%|██████████| 2/2 [00:00<00:00,  9.52it/s]


[15/Mar/2024 16:20:38] INFO - Starting KNN search...
[15/Mar/2024 16:20:38] INFO - Search finished


  negative_dict_inputs = torch.tensor(list(map(lambda x: self.trainer.datamodule.entity_dict_vecs[x].numpy(), negative_dict_inputs)))


ProcessRaisedException: 

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py", line 173, in _wrapping_function
    results = function(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 986, in _run
    results = self._run_stage()
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py", line 1032, in _run_stage
    self.fit_loop.run()
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 138, in run
    self.advance(data_fetcher)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 242, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 191, in run
    self._optimizer_step(batch_idx, closure)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 269, in _optimizer_step
    call._call_lightning_module_hook(
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/core/module.py", line 1303, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/core/optimizer.py", line 152, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/strategies/ddp.py", line 270, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/optim/lr_scheduler.py", line 75, in wrapper
    return wrapped(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/pytorch_transformers/optimization.py", line 139, in step
    loss = closure()
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
    step_output = self._step_fn()
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
    return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py", line 642, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py", line 635, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/tmp/ipykernel_715265/1537706544.py", line 281, in training_step
    f = self.forward(batch_context_inputs, candidate_idxs, n_gold, mention_idxs)
  File "/tmp/ipykernel_715265/1537706544.py", line 250, in forward
    pos_embed = self.reranker.encode_candidate(self.trainer.datamodule.entity_dict_vecs[pos_idx:pos_idx + 1], requires_grad=True)
  File "/home2/cye73/arboEL/blink/biencoder/../../blink/biencoder/biencoder.py", line 238, in encode_candidate
    _, embedding_cands = self.model(
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/cye73/arboEL/blink/biencoder/../../blink/biencoder/biencoder.py", line 90, in forward
    embedding_cands = self.cand_encoder(
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home2/cye73/arboEL/blink/biencoder/../../blink/common/ranker_base.py", line 43, in forward
    output_bert, output_pooler = self.bert_model(
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 1006, in forward
    embedding_output = self.embeddings(
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 232, in forward
    inputs_embeds = self.word_embeddings(input_ids)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
  File "/nethome/cye73/conda_envs/arboel_2/lib/python3.9/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)
