In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('../../..')
import ujson
import json
import os
import random
import time
import pickle
import logger
import numpy as np
import pandas as pd
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_transformers.tokenization_bert import BertTokenizer
from transformers import AutoTokenizer, AutoModel
from datetime import datetime
from typing import Optional, Union

from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler) 
from pytorch_transformers.optimization import WarmupLinearSchedule 
from scipy.sparse.csgraph import minimum_spanning_tree 
# csgraph = compressed sparse graph
from scipy.sparse import csr_matrix
# csr_matrix = compressed sparse row matrices
from collections import Counter, defaultdict

import blink.biencoder.data_process_mult as data_process
# import blink.biencoder.eval_cluster_linking as eval_cluster_linking
import blink.candidate_ranking.utils as utils
from blink.biencoder.biencoder import BiEncoderRanker
from blink.common.optimizer import get_bert_optimizer
from blink.common.params import BlinkParser

from IPython import embed 

  warn(f"Failed to load image Python extension: {e}")


[27/Mar/2024 23:06:45] INFO - Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .
[27/Mar/2024 23:06:45] INFO - Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .
[27/Mar/2024 23:06:45] INFO - Loading faiss with AVX2 support.
[27/Mar/2024 23:06:45] INFO - Successfully loaded faiss with AVX2 support.


In [3]:

sys.path.append('../../../..')
sys.path.append('..')
from DataModule import process_mention_dataset

from bigbio.dataloader import BigBioConfigHelpers
from umls_utils import UmlsMappings
from bigbio_utils import CUIS_TO_REMAP, CUIS_TO_EXCLUDE, DATASET_NAMES, VALIDATION_DOCUMENT_IDS
from bigbio_utils import dataset_to_documents, dataset_to_df, resolve_abbreviation


# Build the BigBio config helpers object.
# NOTE(review): `conhelps` is not referenced by any cell visible in this notebook — confirm it is still needed.
conhelps = BigBioConfigHelpers()

[27/Mar/2024 23:06:45] INFO - PyTorch version 2.2.0 available.


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `d

In [4]:
import logging

# Configure the logger
# NOTE(review): logging.basicConfig is a no-op when the root logger already has handlers.
# The timestamped INFO lines emitted by the earlier import cells suggest logging may already
# be configured by then — confirm this call has the intended effect.
logging.basicConfig(level=logging.INFO)
# Rebinds the module-level name `logger` (the import cell also binds it through `import logger`);
# every later cell that uses `logger` gets this logging.Logger instance.
logger = logging.getLogger(__name__)

In [5]:
def read_data(split, params, logger):
    '''
    Description 
    -----------
    Loads the samples of one dataset split through utils.read_dataset,
    detects whether the samples carry several gold labels ("labels" key) or a single one ("label" key),
    and optionally removes the samples that have no gold label.
    
    Parameters 
    ----------
    split : str
        Portion of the dataset to load ("train", "test", "valid"); forwarded to utils.read_dataset.
    params : dict(str)
        Configuration options. Keys read here:
        - "data_path" : location forwarded to utils.read_dataset
        - "filter_unlabeled" : when truthy, samples without gold label are dropped
    logger : 
        Object exposing a logging-style "info(msg, *args)" method, used to report the number of samples read.
    
    Returns
    -------
    samples : list
        The (possibly filtered) samples of the split.
    has_mult_labels : bool
        True when the first sample has a "labels" key. False for an empty split.
    '''
    samples = utils.read_dataset(split, params["data_path"]) #DD21
    
    # The label layout is inferred from the first sample only.
    # An empty split is reported as single-label instead of raising IndexError on samples[0].
    has_mult_labels = len(samples) > 0 and "labels" in samples[0].keys()
    
    if params["filter_unlabeled"]:
        # Keep only the samples that have at least one gold entity
        if has_mult_labels:
            samples = [sample for sample in samples if len(sample["labels"]) > 0]
        else:
            samples = [sample for sample in samples if sample["label"] is not None]
    
    # Lazy %-style arguments: the message is only formatted if it is emitted, and a "%"
    # inside `split` can no longer break the formatting (it could with f-string + "%").
    logger.info("Read %d %s samples.", len(samples), split)
    return samples, has_mult_labels


# Utility function
def filter_by_context_doc_id(mention_idxs, doc_id, doc_id_list, return_numpy=False):
    '''
    Description 
    -----------
    Keeps only the mention indices whose document ID equals "doc_id",
    so that later processing can be restricted to a single document.
    
    Parameters 
    ----------
    - mention_idxs : list(int) or ndarray(int)
    Indices of the candidate mentions. A list is converted to a NumPy array before masking.
    - doc_id : 
    ID of the target document, compared with "==" against the entries of doc_id_list
    - doc_id_list : sequence indexable by the values of mention_idxs
    doc_id_list[i] is the document ID of mention i.
    - return_numpy : bool
    If True the kept indices are returned as a NumPy array, otherwise as a list.
    -------
    Outputs (in this order): 
    - mention_idxs : 
    The mention indices that belong to the target document (=doc_id).
    - mask : list(bool), one entry per input mention index
    mask[k] tells whether the k-th input mention index belongs to the target document.
    '''
    # One flag per candidate mention: does it come from the target document?
    in_target_doc = [doc_id_list[idx] == doc_id for idx in mention_idxs]
    # Boolean masking needs an array; anything that is not a list is used as given
    candidates = np.array(mention_idxs) if isinstance(mention_idxs, list) else mention_idxs
    kept = candidates[in_target_doc]
    if return_numpy:
        return kept, in_target_doc
    return list(kept), in_target_doc

In [6]:
"Data module"
class ArboelDataModule2(L.LightningDataModule):
    '''
    Description
    -----------
    Lightning data module for the arboel mention / entity data.
    - prepare_data : creates the raw data files when "train.jsonl" is missing
    - setup : tokenizes the entity dictionary and the train / valid / test mentions,
      caching every result as a pickle under "data_path" and reloading the cache when present
    - *_dataloader : wrap the tensor datasets into DataLoaders
    
    Attributes
    ----------
    
    - entity_dictionary : list of dict (ndarray of dict after an entity drop)
    Stores the entity dictionary (raw when loaded from "dictionary.pickle", tokenized afterwards)
    - entity_dictionary_loaded : bool
    True once the entity dictionary is known to be tokenized (loaded from cache or processed with the train split)
    - train_tensor_data : TensorDataset(context_vecs, label_idxs, n_labels, mention_idx) with :
        - “context_vecs” : tensor containing IDs of (mention + surrounding context) tokens 
        - “label_idxs” : tensor with indices pointing to the entities in the entity dictionary that are considered correct labels for the mention.
        - “n_labels” : Number of labels (=entities) associated with the mention
        - “mention_idx” : tensor containing a sequence of integers from 0 to N-1 (N = number of mentions in the dataset) serving as a unique identifier for each mention.
    - train_processed_data : list of dict
    Contains information about mentions (mention_id, mention_name, context, etc…)
    - valid_tensor_data / test_tensor_data : TensorDataset
    Same as "train_tensor_data" but for the validation / test set
    - valid_processed_data / test_processed_data : list of dict
    Same as "train_processed_data" but for the validation / test set
    - train_samples / valid_samples / test_samples : list of dict or None
    Raw samples returned by read_data. Only populated when the split had to be processed, or when "within_doc" is set.
    - train_men_vecs / valid_men_vecs / test_men_vecs : tensor
    First tensor of the corresponding tensor dataset (IDs of mention + context tokens)
    - entity_dict_vecs : tensor (torch.long)
    The 'ids' entry of every entity of the entity dictionary
    - train_gold_clusters : 
    Result of data_process.compute_gold_clusters on the processed train data (mentions grouped by gold entity)
    - max_gold_cluster_len : int
    Maximum length of clusters inside train_gold_clusters (0 when there is none)
    - train_context_doc_ids / valid_context_doc_ids / test_context_doc_ids : list or None
    context_doc_id (=context document index) of every sample of the split. None unless "within_doc" is set.
    '''
    
    # Dataset splits handled by setup(), in processing order
    _SPLITS = ("train", "valid", "test")
    
    def __init__(self, params):
        '''
        Parameters 
        ----------
        - params : dict(str)
        Configuration options, registered through save_hyperparameters and read back from self.hparams.
        Keys read by this constructor:
            - "dataset" : str, name of the dataset
            - "ontology" : str, ontology associated with the dataset
            - "data_path" : str, directory holding the raw and cached data files
            - "ontology_dir" : str, path to the ontology
            - "ontology_type" : str, optional ('obo' or 'umls' and possibly others). None when absent.
            - "bert_model" : str, model name or path given to AutoTokenizer.from_pretrained
            - "train_batch_size" : int, batch size of the DataLoaders; "scoring_batch_size" is used when it is absent
        '''
        super().__init__()
        self.save_hyperparameters(params)
        
        self.dataset = self.hparams["dataset"]
        self.ontology = self.hparams["ontology"]
        self.data_path = self.hparams["data_path"]
        self.ontology_dir = self.hparams["ontology_dir"]
        # prepare_data forwards this attribute to process_mention_dataset; it used to be read
        # without ever being set, which raised AttributeError.
        # NOTE(review): None is passed along when the key is not configured — confirm that
        # process_mention_dataset accepts it for the ontologies in use.
        self.ontology_type = self.hparams.get("ontology_type")
        
        self.tokenizer = AutoTokenizer.from_pretrained(self.hparams["bert_model"])
        
        # Fall back to the *value* of "scoring_batch_size" (the previous code fell back to the
        # literal string "scoring_batch_size"). None when neither key is configured.
        self.batch_size = self.hparams.get("train_batch_size", self.hparams.get("scoring_batch_size"))
        
        self.train_processed_data = None
        self.valid_processed_data = None
        self.test_processed_data = None
        self.train_tensor_data = None
        self.valid_tensor_data = None
        self.test_tensor_data = None
        self.train_samples = None
        self.valid_samples = None
        self.test_samples = None
        self.entity_dict_vecs = None


    def prepare_data(self):
        '''
        Use this to download and prepare data.
        Calls process_mention_dataset when "train.jsonl" is not present under data_path.
        (Per the original notes this creates dictionary.pickle and train.jsonl / valid.jsonl / test.jsonl.)
        '''
        # path to a file where the training data is stored
        # NOTE(review): Lightning advises against assigning state in prepare_data; this attribute
        # is kept for backward compatibility and is not read by setup().
        self.train_data = os.path.join(self.data_path, "train.jsonl")

        # if the file already exists, the data files are considered ready
        if not os.path.isfile(self.train_data):
            process_mention_dataset(
                ontology=self.ontology,
                dataset=self.dataset,
                data_path=self.data_path,
                ontology_type=self.ontology_type,
                ontology_dir=self.ontology_dir,
                mention_id=True,
                context_doc_id=True,
            )


    @staticmethod
    def _load_pickle(path):
        'Return the object stored in the pickle file "path".'
        # SECURITY: pickle.load can execute arbitrary code; only point data_path / drop_set
        # at files produced by this pipeline.
        with open(path, 'rb') as read_handle:
            return pickle.load(read_handle)


    @staticmethod
    def _dump_pickle(obj, path):
        'Write "obj" to the pickle file "path" with the highest pickle protocol.'
        with open(path, 'wb') as write_handle:
            pickle.dump(obj, write_handle, protocol=pickle.HIGHEST_PROTOCOL)


    def _drop_entities_for_discovery(self):
        '''
        Discovery experiment: removes from self.entity_dictionary 10% of the entities that are
        gold labels in the drop set. The selection is random with a fixed seed (17).
        
        Raises
        ------
        ValueError when the drop set file does not exist.
        '''
        # len() works for both a list and the ndarray produced by a previous drop
        assert len(self.entity_dictionary) > 0
        # .get : a missing "drop_set" key now selects the default path, like an explicit None
        drop_set_path = self.hparams.get("drop_set")
        if drop_set_path is None:
            drop_set_path = os.path.join(self.data_path, 'drop_set_mention_data.pickle') #A12
        if not os.path.isfile(drop_set_path):
            raise ValueError("Invalid or no --drop_set path provided to dev/test mention data")
        drop_set_data = self._load_pickle(drop_set_path)
        # first gold cui index of each mention in drop_set_data
        drop_set_mention_gold_cui_idxs = [mention['label_idxs'][0] for mention in drop_set_data]
        # Make the set unique
        ents_in_data = np.unique(drop_set_mention_gold_cui_idxs)
        # proportion of entities to drop
        ent_drop_prop = 0.1
        logger.info(f"Dropping {ent_drop_prop*100}% of {len(ents_in_data)} entities found in drop set")
        # Number of entity indices to drop
        n_ents_dropped = int(ent_drop_prop*len(ents_in_data))
        # Random selection with a fixed seed, so the same entities are picked on every run
        rng = np.random.default_rng(seed=17)
        # Indices of all entities that are dropped
        dropped_ent_idxs = rng.choice(ents_in_data, size=n_ents_dropped, replace=False)

        # Drop entities from dictionary (subsequent processing will automatically drop corresponding mentions)
        # NOTE(review): the indices refer to the dictionary as loaded; if the cached
        # entity_dictionary.pickle was written after a previous drop they may no longer line up — verify.
        keep_mask = np.ones(len(self.entity_dictionary), dtype='bool')
        keep_mask[dropped_ent_idxs] = False
        self.entity_dictionary = np.array(self.entity_dictionary)[keep_mask]


    def _load_or_process_split(self, split):
        '''
        Fills self.<split>_tensor_data and self.<split>_processed_data.
        
        - When both pickles of the split exist they are loaded.
        - Otherwise the split is read with read_data, processed with
          data_process.process_mention_data and both pickles are (re)written.
          In that case self.<split>_samples and self.<split>_mult_labels are set as well.
        
        For the "train" split only, the entity dictionary returned by the processing replaces
        self.entity_dictionary and is cached if it was not loaded from the cache.
        
        Parameters 
        ----------
        - split : str
        "train", "valid" or "test"
        '''
        tensor_path = getattr(self, f"{split}_tensor_data_pkl_path")
        processed_path = getattr(self, f"{split}_processed_data_pkl_path")
        
        # The cache is only usable when BOTH files exist. The previous code decided to process
        # from the tensor file alone, so a lone tensor pickle left the data set to None.
        if os.path.isfile(tensor_path) and os.path.isfile(processed_path):
            print(f"Loading stored processed {split} data...")
            setattr(self, f"{split}_tensor_data", self._load_pickle(tensor_path))
            setattr(self, f"{split}_processed_data", self._load_pickle(processed_path))
            return
        
        # samples = list of dict. Each dict contains information about a mention (id, name, context, etc…).
        samples, mult_labels = read_data(split, self.hparams, logger)
        setattr(self, f"{split}_samples", samples)
        setattr(self, f"{split}_mult_labels", mult_labels)
        
        # Go check "process_mention_data" for more info on the returned values
        processed_data, entity_dictionary, tensor_data = data_process.process_mention_data(
            samples,
            self.entity_dictionary,
            self.tokenizer,
            self.hparams["max_context_length"],
            self.hparams["max_cand_length"],
            context_key=self.hparams["context_key"],
            multi_label_key="labels" if mult_labels else None,
            logger=logger,
            debug=self.hparams["debug"],
            knn=self.hparams["knn"],
            dictionary_processed=self.entity_dictionary_loaded
        )
        setattr(self, f"{split}_processed_data", processed_data)
        setattr(self, f"{split}_tensor_data", tensor_data)
        
        if split == "train":
            # Only the train pass keeps the dictionary returned by the processing
            self.entity_dictionary = entity_dictionary
            # Save the entity dictionary if not already done
            if not self.entity_dictionary_loaded:
                print("Saving entity dictionary...")
                self._dump_pickle(self.entity_dictionary, self.entity_dictionary_pkl_path)
            self.entity_dictionary_loaded = True
        
        print(f"Saving processed {split} data...")
        self._dump_pickle(tensor_data, tensor_path)
        self._dump_pickle(processed_data, processed_path)


    def setup(self, stage=None):
        '''
        For processing and splitting. Called at the beginning of fit (train + validate), validate, test, or predict.
        "stage" is ignored: the three splits are always loaded / processed.
        '''
        self.entity_dictionary_pkl_path = os.path.join(self.data_path, 'entity_dictionary.pickle')
        self.train_tensor_data_pkl_path = os.path.join(self.data_path, 'train_tensor_data.pickle')
        self.train_processed_data_pkl_path = os.path.join(self.data_path, 'train_processed_data.pickle')
        self.valid_tensor_data_pkl_path = os.path.join(self.data_path, 'valid_tensor_data.pickle')
        self.valid_processed_data_pkl_path = os.path.join(self.data_path, 'valid_processed_data.pickle')
        self.test_tensor_data_pkl_path = os.path.join(self.data_path, 'test_tensor_data.pickle')
        self.test_processed_data_pkl_path = os.path.join(self.data_path, 'test_processed_data.pickle')
        
        # Entity dictionary : load the cached (already tokenized) one if present, else the raw one
        self.entity_dictionary_loaded = os.path.isfile(self.entity_dictionary_pkl_path)
        if self.entity_dictionary_loaded:
            print("Loading stored processed entity dictionary...")
            self.entity_dictionary = self._load_pickle(self.entity_dictionary_pkl_path) # DD12B
        else:
            self.entity_dictionary = self._load_pickle(os.path.join(self.data_path, 'dictionary.pickle')) #A11
        
        # For discovery experiment: Drop entities used in training that were dropped randomly from dev/test set
        if self.hparams["drop_entities"]: #A12
            self._drop_entities_for_discovery()
        
        # Train mention data
        self._load_or_process_split("train")
        # Tensor containing only the IDs of (mention + surrounding context) tokens of the training set
        self.train_men_vecs = self.train_tensor_data[:][0]
        
        # Store the IDs of the entities in entity_dictionary. It's the equivalent of train_men_vecs for entities
        # (Done here because data_process.process_mention_data will tokenize the entities in entity_dict)
        self.entity_dict_vecs = torch.tensor([entity['ids'] for entity in self.entity_dictionary], dtype=torch.long)
        
        # Validation mention data
        self._load_or_process_split("valid")
        self.valid_men_vecs = self.valid_tensor_data[:][0]
        
        # Test mention data
        self._load_or_process_split("test")
        self.test_men_vecs = self.test_tensor_data[:][0]
        
        # Within_doc search (=search only within the document)
        self.train_context_doc_ids = self.valid_context_doc_ids = self.test_context_doc_ids = None
        if self.hparams["within_doc"]:
            for split in self._SPLITS:
                # RR9 : when the split came from the cache its samples haven't been read yet
                samples = getattr(self, f"{split}_samples")
                if samples is None:
                    samples, _ = read_data(split, self.hparams, logger)
                    setattr(self, f"{split}_samples", samples)
                # Store the context_doc_id for every sample of the split
                setattr(self, f"{split}_context_doc_ids", [sample['context_doc_id'] for sample in samples])
        
        # Get clusters of mentions that map to a gold entity
        self.train_gold_clusters = data_process.compute_gold_clusters(self.train_processed_data)
        # Maximum length of clusters inside gold_cluster
        self.max_gold_cluster_len = max(
            (len(self.train_gold_clusters[ent]) for ent in self.train_gold_clusters),
            default=0,
        )


    def _make_loader(self, tensor_data, **loader_kwargs):
        '''
        Build a DataLoader over "tensor_data" with the configured batch size.
        
        Raises
        ------
        ValueError when neither "train_batch_size" nor "scoring_batch_size" was configured
        (batch_size=None would otherwise silently disable batching).
        '''
        if self.batch_size is None:
            raise ValueError('No batch size configured: set "train_batch_size" or "scoring_batch_size" in params')
        return DataLoader(dataset=tensor_data, batch_size=self.batch_size, **loader_kwargs)


    def train_dataloader(self): #RR5
        # Return the training DataLoader (shuffled, last incomplete batch dropped)
        return self._make_loader(self.train_tensor_data, shuffle=True, drop_last=True)
    
    def val_dataloader(self):
        # Return the validation DataLoader
        return self._make_loader(self.valid_tensor_data)
    
    def test_dataloader(self):
        # Return the test DataLoader
        return self._make_loader(self.test_tensor_data)

In [23]:
# Experiment selection: these three names drive both data_path and params_test below.
# Alternative setup used previously: ontology = "MeSH", dataset = "bc5cdr", with
#   abs_path = "/home2/cye73/data", ontology_type = "umls",
#   ontology_dir = "/mitchell/entity-linking/2017AA/META/",
#   model_output_path = os.path.join("/home2/cye73/results", model, dataset)
ontology = "medic"
dataset = "ncbi_disease"
model = "arboel"
# NOTE(review): machine-specific absolute paths (here and "ontology_dir" below) — adjust when running elsewhere.
abs_path = "/home2/cye73/data_test"
data_path = os.path.join(abs_path, model, dataset)
print(data_path)


# "dataset", "model" and "ontology" reuse the variables above so that they can never
# disagree with the directory encoded in data_path.
params_test = {"data_path" : data_path, 
               "train_batch_size" : 64,
               "max_context_length": 128 ,
               "max_cand_length" : 128 ,
               "context_key" : "context",
               "debug" : False,
               "knn" : 4,
               # "bert_model": 'michiyasunaga/BioLinkBERT-base',
               "bert_model": "dmis-lab/biobert-base-cased-v1.1",
               "out_dim": 768 ,
               "pull_from_layer":11,
               "add_linear":True,
               "use_types" : True,
               "force_exact_search" : True,
               "probe_mult_factor" : 1,
               "embed_batch_size" : 768,
               "drop_entities" : False,
               "within_doc" : True,
               "filter_unlabeled" : True,
               "dataset" : dataset,
               "model" : model,
               "ontology_dir" : '/mitchell/entity-linking/kbs/medic.tsv',
               "ontology" : ontology
               }


/home2/cye73/data_test/arboel/ncbi_disease


In [24]:
from LightningModule import LitArboel
from LightningDataModule import ArboelDataModule

In [25]:
# Build the data module from the `params_test` configuration dict.
data_module = ArboelDataModule(params = params_test)

In [26]:
# Run the data module's preparation step (setup() is called in a later cell).
data_module.prepare_data()


prepare_data is being executed


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


Loading stored processed entity dictionary...


100%|██████████| 13189/13189 [00:00<00:00, 1882099.74it/s]


Max labels on one doc: 5


Creating correct mention format for train dataset: 100%|██████████| 5065/5065 [00:00<00:00, 216188.04it/s]
Creating correct mention format for validation dataset: 100%|██████████| 780/780 [00:00<00:00, 278311.96it/s]
Creating correct mention format for test dataset: 100%|██████████| 960/960 [00:00<00:00, 360897.36it/s]


In [27]:
# Run setup(); later cells read data_module.train/valid/test_tensor_data.
data_module.setup()

setup() is being executed
Read 4782 samples.
[25/Mar/2024 18:37:08] INFO - Read 4782 train samples..


Tokenizing dictionary: 100%|██████████| 13189/13189 [00:04<00:00, 3037.52it/s]


[25/Mar/2024 18:37:18] INFO - ====Processed samples: ====
[25/Mar/2024 18:37:18] INFO - Context tokens : [CLS] identification of a ##p ##c ##2 , a ho ##mo ##logue of the [unused1] ad ##eno ##mat ##ous p ##oly ##po ##sis co ##li t ##umour [unused2] suppress ##or . the ad ##eno ##mat ##ous p ##oly ##po ##sis co ##li ( a ##p ##c ) t ##umour - suppress ##or protein controls the w ##nt signalling pathway by forming a complex with g ##ly ##co ##gen s ##ynth ##ase kinase 3 ##bet ##a ( g ##sk - 3 ##bet ##a ) , a ##xin / conduct ##in and beta ##cate ##nin . complex formation induce ##s the rapid degradation of beta ##cate ##nin . in co ##lon car ##cin ##oma cells , loss of a ##p ##c leads to the accumulation of beta ##cate ##nin [SEP]
[25/Mar/2024 18:37:18] INFO - Context ids : 101 9117 1104 170 1643 1665 1477 117 170 16358 3702 12733 1104 1103 1 8050 26601 21943 2285 185 23415 5674 4863 1884 2646 189 27226 2 17203 1766 119 1103 8050 26601 21943 2285 185 23415 5674 4863 1884 2646 113 170 1643 1

Tokenizing dictionary: 100%|██████████| 13189/13189 [00:00<00:00, 1673331.78it/s]


[25/Mar/2024 18:37:19] INFO - ====Processed samples: ====
[25/Mar/2024 18:37:19] INFO - Context tokens : [CLS] br ##ca ##1 is secret ##ed and exhibits properties of a g ##rani ##n . g ##er ##m ##line mutations in br ##ca ##1 are responsible for most cases of [unused1] inherited breast and o ##var ##ian cancer [unused2] . however , the function of the br ##ca ##1 protein has remained el ##usive . we now show that br ##ca ##1 en ##codes a 190 - k ##d protein with sequence ho ##mology and bio ##chemical analogy to the g ##rani ##n protein family . interesting ##ly , br ##ca ##2 also includes a motif similar to the g ##rani ##n consensus at the c terminus of the protein . both br ##ca ##1 and the g ##rani ##ns local ##ize to secret ##ory ve [SEP]
[25/Mar/2024 18:37:19] INFO - Context ids : 101 9304 2599 1475 1110 3318 1174 1105 10877 4625 1104 170 176 23851 1179 119 176 1200 1306 2568 17157 1107 9304 2599 1475 1132 2784 1111 1211 2740 1104 1 7459 7209 1105 184 8997 1811 4182 2 119 1649 117

Tokenizing dictionary: 100%|██████████| 13189/13189 [00:00<00:00, 1851671.14it/s]


[25/Mar/2024 18:37:20] INFO - ====Processed samples: ====
[25/Mar/2024 18:37:20] INFO - Context tokens : [CLS] cluster ##ing of miss ##ense mutations in the [unused1] at ##ax ##ia - te ##lang ##ie ##ct ##asi ##a [unused2] gene in a s ##poradic t - cell le ##uka ##emia . at ##ax ##ia - te ##lang ##ie ##ct ##asi ##a ( a - t ) is a re ##cess ##ive multi - system disorder caused by mutations in the at ##m gene at 11 ##q ##22 - q ##23 ( re ##f . 3 ) . the risk of cancer , especially l ##ymph ##oid neo ##p ##lasia ##s , is substantially elevated in a - t patients and has long been associated with ch ##rom ##oso ##mal instability . by anal ##ys ##ing t ##umour d ##na from patients with s ##poradic t [SEP]
[25/Mar/2024 18:37:20] INFO - Context ids : 101 10005 1158 1104 5529 22615 17157 1107 1103 1 1120 7897 1465 118 21359 19514 1663 5822 17506 1161 2 5565 1107 170 188 27695 189 118 2765 5837 12658 20504 119 1120 7897 1465 118 21359 19514 1663 5822 17506 1161 113 170 118 189 114 1110 170 1231 2

In [28]:
# Number of training examples in the tensor dataset (shown as the cell output).
# Only the last expression of a cell is displayed, so the bare attribute access
# that used to precede this line produced nothing and was dropped.
len(data_module.train_tensor_data)



4782

In [29]:
# Load the train DataLoader
train_dataloader = data_module.train_dataloader()
# next(..., None) makes an empty loader fail the assertion with the message below;
# a bare next() would raise StopIteration first and the message would never be shown.
assert next(iter(train_dataloader), None) is not None, "Training DataLoader is empty!"

train_dataloader() is being executed


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [30]:
# Grab a single training batch for inspection in the following cells.
# `mentions_id` is only consumed by the disabled max-mention-id check:
#     batch_context_inputs, candidate_idxs, n_gold, mention_idxs = batch
#     mentions_id.append(max(mention_idxs))
mentions_id = []
# Take exactly one batch; fail loudly instead of leaving `first_batch` undefined
# when the loader yields nothing.
first_batch = next(iter(train_dataloader), None)
assert first_batch is not None, "Training DataLoader is empty!"


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [31]:
# print(max(mentions_id))

tensor(4781)


In [None]:
# Unpack the inspected batch and print the first element of every field with its label.
batch_context_inputs, candidate_idxs, n_gold, mention_idxs = first_batch
labelled_fields = (
    ("batch_context_inputs: Indices of the 64 tokens of the first element (mention + context)\n", batch_context_inputs),
    ("\ncandidate_idxs: Indices of the correct_entity\n", candidate_idxs),
    ("\nn_gold: Number of correct entity\n", n_gold),
    ("\nmention_idxs: Unique mention identifier id\n", mention_idxs),
)
for label, field_values in labelled_fields:
    print(label, field_values[0])

In [120]:
# Number of examples in the validation tensor dataset.
len(data_module.valid_tensor_data)

722

In [56]:
# Build the validation DataLoader and print the size of every batch it yields
# (the final batch may be smaller than the configured batch size).
val_dataloader = data_module.val_dataloader()
for batch in val_dataloader:
    context_inputs = batch[0]
    print(len(context_inputs))

# To inspect field contents instead, unpack a batch as
# (batch_context_inputs, candidate_idxs, n_gold, mention_idxs) and print slices of each.


64
64
64
8


In [38]:
# Number of examples in the test tensor dataset.
len(data_module.test_tensor_data)

878

In [57]:
# Build the test DataLoader and print the size of every batch it yields
# (the final batch may be smaller than the configured batch size).
test_dataloader = data_module.test_dataloader()
for batch in test_dataloader:
    context_inputs = batch[0]
    print(len(context_inputs))

# To inspect field contents instead, unpack a batch as
# (batch_context_inputs, candidate_idxs, n_gold, mention_idxs) and print slices of each.


64
64
64
64
64
64
64
64
64
64
64
64
64
46


# Comparing my preprocessed ncbi_disease data with David's version

In [7]:
# Load the same dictionary pickle from two preprocessing outputs and compare one entry.
# path_entity = '/home2/cye73/data_test2/arboel/ncbi_disease/dictionary.pickle'
path_entity = '/home2/cye73/data/arboel/ncbi_disease/dictionary.pickle'
path_entity2 = '/home2/cye73/arboEL2/data/arboel/ncbi_disease/dictionary.pickle'
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
# The first object is deliberately not bound to the name `dict`: rebinding the builtin
# at notebook scope breaks later uses such as `field(default_factory=dict)`.
with open(path_entity, 'rb') as read_handle:
    dictionary = pickle.load(read_handle)
with open(path_entity2, 'rb') as read_handle:
    dict2 = pickle.load(read_handle)

print("dict :\n", dictionary[3])
print("dict2 :\n", dict2[3])

dict :
 {'type': 'Disease', 'cui': 'MESH:C579850', 'title': '16p11.2 Deletion Syndrome', 'cuis': ['MESH:C579850'], 'description': '16p11.2 Deletion Syndrome ( Disease)'}
dict2 :
 {'type': 'Disease', 'cui': 'MESH:C579850', 'title': '16p11.2 Deletion Syndrome', 'cuis': ['MESH:C579850'], 'description': '16p11.2 Deletion Syndrome ( Disease :  )'}


In [9]:
# Load the tokenized entity dictionaries from both preprocessing outputs and
# stack each entry's 'ids' list into a LongTensor.
path_entity = '/home2/cye73/data_test2/arboel/ncbi_disease/entity_dictionary.pickle'
# path_entity = '/home2/cye73/data/arboel/ncbi_disease/entity_dictionary.pickle'
path_entity2 = '/home2/cye73/arboEL2/data/arboel/ncbi_disease/entity_dictionary.pickle'

with open(path_entity, 'rb') as read_handle:
    entity_dict = pickle.load(read_handle)
entity_dict_vecs = torch.tensor([entry['ids'] for entry in entity_dict], dtype=torch.long)

with open(path_entity2, 'rb') as read_handle:
    entity_dict2 = pickle.load(read_handle)
entity_dict_vecs2 = torch.tensor([entry['ids'] for entry in entity_dict2], dtype=torch.long)

In [10]:
# Show the first entry of each entity dictionary for a side-by-side comparison.
print("entity_dict :\n", entity_dict[0])
print("entity_dict2 :\n", entity_dict2[0])

entity_dict :
 {'type': 'Disease', 'cui': 'MESH:C538288', 'title': '10p Deletion Syndrome (Partial)', 'cuis': ['MESH:C538288'], 'description': '10p Deletion Syndrome (Partial) ( Disease : Chromosome 10, 10p- Partial|Chromosome 10, monosomy 10p|Chromosome 10, Partial Deletion (short arm)|Monosomy 10p )', 'tokens': ['[CLS]', '10', '##p', 'del', '##eti', '##on', 'syndrome', '(', 'partial', ')', '[unused3]', '(', 'disease', ':', 'chromosome', '10', ',', '10', '##p', '-', 'partial', '|', 'chromosome', '10', ',', 'mon', '##oso', '##my', '10', '##p', '|', 'chromosome', '10', ',', 'partial', 'del', '##eti', '##on', '(', 'short', 'arm', ')', '|', 'mon', '##oso', '##my', '10', '##p', ')', '[SEP]'], 'ids': [101, 1275, 1643, 3687, 26883, 1320, 9318, 113, 7597, 114, 3, 113, 3653, 131, 18697, 1275, 117, 1275, 1643, 118, 7597, 197, 18697, 1275, 117, 19863, 22354, 4527, 1275, 1643, 197, 18697, 1275, 117, 7597, 3687, 26883, 1320, 113, 1603, 1981, 114, 197, 19863, 22354, 4527, 1275, 1643, 114, 102, 0, 0

In [78]:
# Show the token-id tensor of the first entity from each dictionary.
print("entity_dict_vecs :\n", entity_dict_vecs[0])
print("entity_dict_vecs2 :\n", entity_dict_vecs2[0])

entity_dict_vecs :
 tensor([  101,  1275,  1643,  3687, 26883,  1320,  9318,   113,  7597,   114,
            3,   113,  3653,   131, 18697,  1275,   117,  1275,  1643,   118,
         7597,   132, 18697,  1275,   117, 19863, 22354,  4527,  1275,  1643,
          132, 18697,  1275,   117,  7597,  3687, 26883,  1320,   113,  1603,
         1981,   114,   132, 19863, 22354,  4527,  1275,  1643,   114,   102,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,  

In [19]:
# Convert the entity id tensor into the (tokens, segments, mask) triple returned by
# to_bert_input. NOTE(review): NULL_IDX is presumably the padding id — confirm against to_bert_input.
NULL_IDX = 0
bert_inputs = to_bert_input(entity_dict_vecs, NULL_IDX)
token_idx_cands, segment_idx_cands, mask_cands = bert_inputs

In [20]:
# Show the three tensors returned by to_bert_input for the first candidate.
print("token_idx_cands:\n",token_idx_cands[0])
print("segment_idx_cands:\n",segment_idx_cands[0])
print("mask_cands:\n",mask_cands[0])

token_idx_cands:
 tensor([    2,  2073,  1014,  6749,  3328,    11,  4782,    12,     1,    11,
         6414, 10281,    69,  3056,  2174,    11,  1682,  8328,    12,    69,
         6029,    11,  2377,    12,    29,  5206,  2073,    15,  2073,  1014,
           16,  4782,    69,  5206,  2073,    15, 16639,  2508,  2073,  1014,
           69,  5206,  2073,    15,  4782,  6749,    11,  3274,  5996,    12,
           69, 16639,  2508,  2073,  1014,    12,     3,     0,     0,     0,
            0,     0,     0,     0])
segment_idx_cands:
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
mask_cands:
 tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True, 

In [17]:
# Instantiate the bi-encoder from the test configuration.
# NOTE(review): `model` is also used in this notebook as the string "arboel" for path
# building; this cell rebinds it to a module, so cell execution order matters.
model = BiEncoderModule(params_test)

In [25]:
# Embed the first 5000 candidates (the context inputs are passed as None).
# This is inference only, so run under no_grad() to avoid building an autograd
# graph for the whole candidate batch.
with torch.no_grad():
    _, embedding_cands = model(
        None, None, None, token_idx_cands[:5000], segment_idx_cands[:5000], mask_cands[:5000])
# Full-dictionary variant:
# _, embedding_cands = model(
#             None, None, None, token_idx_cands, segment_idx_cands, mask_cands)

In [None]:
# Display the embedding computed for the first candidate.
embedding_cands[0]

In [10]:
# Print every field of the first training mention stored under data_test2 for bc5cdr.
ontology = "MEDIC"
model = "arboel"
# dataset = "ncbi_disease"
dataset = "bc5cdr"
abs_path = "/home2/cye73/data_test2"
data_path = os.path.join(abs_path, model, dataset)

# train.jsonl holds one JSON object per line.
with open(os.path.join(data_path, "train.jsonl"), 'r') as read_handle:
    mentions = [json.loads(line) for line in read_handle]

for i in range(1):
    print("------")
    first_mention_fields = mentions[i].items()
    for key, value in first_mention_fields:
        print(f"{key}: {value}")

------
mention: Selegiline
mention_id: 10091617.1
context_left: 
context_right: -induced postural hypotension in Parkinson's disease: a longitudinal study on the effects of drug withdrawal.
OBJECTIVES: The United Kingdom Parkinson's Disease Research Group (UKPDRG) trial found an increased mortality in patients with Parkinson's disease (PD) randomized to receive 10 mg selegiline per day and L-dopa compared with those taking L-dopa alone. Recently, we found that therapy with selegiline and L-dopa was associated with selective systolic orthostatic hypotension which was abolished by withdrawal of selegiline. This unwanted effect on postural blood pressure was not the result of underlying autonomic failure. The aims of this study were to confirm our previous findings in a separate cohort of patients and to determine the time course of the cardiovascular consequences of stopping selegiline in the expectation that this might shed light on the mechanisms by which the drug causes orthostatic hy

In [12]:
# Print every field of the first training mention stored under arboEL2/data for ncbi_disease.
ontology = "MEDIC"
model = "arboel"
dataset = "ncbi_disease"
# dataset = "bc5cdr"
abs_path = "/home2/cye73/arboEL2/data"
data_path = os.path.join(abs_path, model, dataset)

# train.jsonl holds one JSON object per line.
with open(os.path.join(data_path, "train.jsonl"), 'r') as read_handle:
    mentions = [json.loads(line) for line in read_handle]

for i in range(1):
    print("------")
    first_mention_fields = mentions[i].items()
    for key, value in first_mention_fields:
        print(f"{key}: {value}")

------
mention: adenomatous polyposis coli tumour
context_left: Identification of APC2, a homologue of the 
context_right:  suppressor.
The adenomatous polyposis coli (APC) tumour-suppressor protein controls the Wnt signalling pathway by forming a complex with glycogen synthase kinase 3beta (GSK-3beta), axin/conductin and betacatenin. Complex formation induces the rapid degradation of betacatenin. In colon carcinoma cells, loss of APC leads to the accumulation of betacatenin in the nucleus, where it binds to and activates the Tcf-4 transcription factor (reviewed in [1] [2]). Here, we report the identification and genomic structure of APC homologues. Mammalian APC2, which closely resembles APC in overall domain structure, was functionally analyzed and shown to contain two SAMP domains, both of which are required for binding to conductin. Like APC, APC2 regulates the formation of active betacatenin-Tcf complexes, as demonstrated using transient transcriptional activation assays in APC -/

In [21]:
# Load the processed training samples for `dataset` from both output trees
# and print the first sample of each.
path_train_processed_data = f'/home2/cye73/data_test2/arboel/{dataset}/train_processed_data.pickle'
with open(path_train_processed_data, 'rb') as read_handle:
    train_processed_data = pickle.load(read_handle)

path_train_processed_data2 = f'/home2/cye73/arboEL2/data/arboel/{dataset}/train_processed_data.pickle'
with open(path_train_processed_data2, 'rb') as read_handle:
    train_processed_data2 = pickle.load(read_handle)

first_sample, first_sample2 = train_processed_data[0], train_processed_data2[0]
print("dict :\n", first_sample)
print("dict2 :\n", first_sample2)

dict :
 {'mention_id': '10091617.1', 'mention_name': 'Selegiline', 'context': {'tokens': ['[CLS]', '[unused1]', 'se', '##leg', '##ili', '##ne', '[unused2]', '-', 'induced', 'post', '##ural', 'h', '##y', '##pot', '##ens', '##ion', 'in', 'park', '##ins', '##on', "'", 's', 'disease', ':', 'a', 'longitudinal', 'study', 'on', 'the', 'effects', 'of', 'drug', 'withdrawal', '.', 'objectives', ':', 'the', 'united', 'kingdom', 'park', '##ins', '##on', "'", 's', 'disease', 'research', 'group', '(', 'uk', '##p', '##dr', '##g', ')', 'trial', 'found', 'an', 'increased', 'mortality', 'in', 'patients', 'with', 'park', '##ins', '##on', "'", 's', 'disease', '(', 'p', '##d', ')', 'random', '##ized', 'to', 'receive', '10', 'mg', 'se', '##leg', '##ili', '##ne', 'per', 'day', 'and', 'l', '-', 'do', '##pa', 'compared', 'with', 'those', 'taking', 'l', '-', 'do', '##pa', 'alone', '.', 'recently', ',', 'we', 'found', 'that', 'therapy', 'with', 'se', '##leg', '##ili', '##ne', 'and', 'l', '-', 'do', '##pa', 'was'

In [22]:
# Load the training tensor datasets for `dataset` from both output trees
# and print the first item of each.
path_train_tensor_data = f'/home2/cye73/data_test2/arboel/{dataset}/train_tensor_data.pickle'
with open(path_train_tensor_data, 'rb') as read_handle:
    train_tensor_data = pickle.load(read_handle)

path_train_tensor_data2 = f'/home2/cye73/arboEL2/data/arboel/{dataset}/train_tensor_data.pickle'
with open(path_train_tensor_data2, 'rb') as read_handle:
    train_tensor_data2 = pickle.load(read_handle)

first_item, first_item2 = train_tensor_data[0], train_tensor_data2[0]
print("dict :\n", first_item)
print("dict2 :\n", first_item2)

dict :
 (tensor([  101,     1, 14516, 27412, 18575,  1673,     2,   118, 10645,  2112,
        12602,   177,  1183, 11439,  5026,  1988,  1107,  2493,  4935,  1320,
          112,   188,  3653,   131,   170, 23191,  2025,  1113,  1103,  3154,
         1104,  3850, 10602,   119, 11350,   131,  1103, 10280,  6139,  2493,
         4935,  1320,   112,   188,  3653,  1844,  1372,   113, 26006,  1643,
        23632,  1403,   114,  3443,  1276,  1126,  2569, 14471,  1107,  4420,
         1114,  2493,  4935,  1320,   112,   188,  3653,   113,   185,  1181,
          114,  7091,  2200,  1106,  3531,  1275, 17713, 14516, 27412, 18575,
         1673,  1679,  1285,  1105,   181,   118,  1202,  4163,  3402,  1114,
         1343,  1781,   181,   118,  1202,  4163,  2041,   119,  3055,   117,
         1195,  1276,  1115,  7606,  1114, 14516, 27412, 18575,  1673,  1105,
          181,   118,  1202,  4163,  1108,  2628,  1114, 14930,   188,  6834,
         2430,  8031,  1137,  1582, 15540,  7698,   177

In [26]:
# Load the tokenized entity dictionaries for `dataset` from both output trees
# and print the first entity of each.
path_entity_dictionary= f'/home2/cye73/data_test2/arboel/{dataset}/entity_dictionary.pickle'
with open(path_entity_dictionary, 'rb') as read_handle:
    entity_dict = pickle.load(read_handle)

path_entity2_dictionary = f'/home2/cye73/arboEL2/data/arboel/{dataset}/entity_dictionary.pickle'
with open(path_entity2_dictionary, 'rb') as read_handle:
    entity_dict2 = pickle.load(read_handle)

first_entity, first_entity2 = entity_dict[0], entity_dict2[0]
print("entity_dict :\n", first_entity)
print("entity_dict2 :\n", first_entity2)

entity_dict :
 {'cui': 'MESH:C000002', 'title': 'bevonium', 'type': 'CHEM', 'description': 'bevonium ( CHEM : 2-(hydroxymethyl)-N,N-dimethylpiperidinium benzilate ; piribenzil methyl sulfate ; bevonium methylsulfate ; bevonium metilsulfate ; CG 201 ; Acabel ; bevonium sulfate (1:1) ; bevonium methyl sulfate )', 'tokens': ['[CLS]', 'be', '##von', '##ium', '[unused3]', '(', 'ch', '##em', ':', '2', '-', '(', 'h', '##ydro', '##xy', '##met', '##hyl', ')', '-', 'n', ',', 'n', '-', 'dim', '##eth', '##yl', '##pipe', '##rid', '##ini', '##um', 'ben', '##zi', '##late', ';', 'p', '##iri', '##ben', '##zi', '##l', 'met', '##hyl', 'su', '##lf', '##ate', ';', 'be', '##von', '##ium', 'met', '##hyl', '##sul', '##fa', '##te', ';', 'be', '##von', '##ium', 'met', '##ils', '##ulf', '##ate', ';', 'c', '##g', '201', ';', 'a', '##ca', '##bel', ';', 'be', '##von', '##ium', 'su', '##lf', '##ate', '(', '1', ':', '1', ')', ';', 'be', '##von', '##ium', 'met', '##hyl', 'su', '##lf', '##ate', ')', '[SEP]'], 'ids': [1

In [9]:
import sys
sys.path.append('/home/cye73/biomedical-entity-linking')
from dataclasses import dataclass, field
from typing import List, Optional, Union
from umls_utils import UmlsMappings

@dataclass
class BiomedicalEntity:
    """
    Class for keeping track of all relevant fields in an ontology

    NOTE(review): the loaders below do not always honour these annotations.
    ``load_medic`` stores ``types`` as the plain string "Disease" and ``aliases``
    as the raw pipe-separated synonym string; ``load_umls`` stores a single
    ``tui`` value in ``types``. Confirm what downstream consumers expect before
    normalising these to lists.
    """
    cui: str                                        # Identifier of the entity, e.g. 'MESH:C538288'
    name: str                                       # Canonical name of the entity
    types: List[str]                                # Type(s) of the entity (see NOTE above)
    aliases: List[str]                              # Alternative names (see NOTE above)
    definition: Optional[str]                       # Free-text definition, if any
    equivalant_cuis: Optional[List[str]] = None     # Other identifiers treated as equivalent (spelling kept: public field name)
    taxonomy: Optional[str] = None
    extra_data: Optional[dict] = None               # Loader-specific extras (load_umls stores {'group': ...})

@dataclass
class BiomedicalOntology:
    """
    Container for the entities of one ontology plus a cui -> position index.

    The ``load_*`` methods append to ``entities`` and update ``mappings`` in
    place; the remaining ``get_*`` / ``from_*`` methods are unimplemented stubs.
    """
    name: str
    abbrev: Optional[str] = None                                           # Abbreviated name of ontology if different than name
    types: List[str] = field(default_factory=list)                                          # List of all types in the ontology                                        
    entities: List[BiomedicalEntity] = field(default_factory=list)                          # List Containing all Biomedical Entity Objects
    mappings: dict = field(default_factory=dict)                                            # Dict mapping a cui to the index in entities

    def get_aliases(self, cui=None):
        '''
        Get aliases for a particular CUI.  If cui=None, provide a mapping of {cui: [aliases]}

        Not implemented yet (returns None).
        '''
        pass

    def get_entities_with_alias(self, alias=None):
        '''
        Get all entities sharing a particular alias.  If alias=None, return a mapping of {alias: [cuis]}

        Not implemented yet (returns None).
        '''
        pass

    def get_definitions(self, cui):
        '''Not implemented yet (returns None).'''
        pass

    def from_obo(self, filepath=None):
        '''Not implemented yet (returns None).'''
        pass

    def load_umls(self, path = None, api_key = ""):
        '''
        Load English UMLS concepts into ``entities`` and index them in ``mappings``.

        path : str
        Directory handed to ``UmlsMappings`` as ``umls_dir``
        api_key : str
        Key handed to ``UmlsMappings`` as ``umls_api_key``
        '''
        umls = UmlsMappings(umls_dir = path, umls_api_key=api_key)

        # Get the Canonial Names
        lowercase = False
        umls_to_name = umls.get_canonical_name(
            ontologies_to_include="all",
            use_umls_curies=True,
            lowercase=lowercase,
        )

        # Group by the canonical names to group the alias and types 
        all_umls_df = umls.umls.query('lang == "ENG"').groupby('cui').agg({'alias': lambda x: list(set(x)), 'tui':'first', 'group': 'first', 'def':'first'}).reset_index()
        all_umls_df['name'] = all_umls_df.cui.map(umls_to_name)
        # Drop the canonical name from the alias list. Label-based access is used because
        # integer keys on a label-indexed Series (x[0], x[1]) are deprecated in pandas.
        all_umls_df['alias'] = all_umls_df[['name','alias']].apply(lambda x: list(set(x['alias']) - {x['name']}) , axis=1)
        all_umls_df['cui'] = all_umls_df['cui'].map(lambda x: 'UMLS' + x)
        all_umls_df['has_definition'] = all_umls_df['def'].map(lambda x: x is not None)
        all_umls_df['num_aliases'] = all_umls_df['alias'].map(len)

        for index, row in all_umls_df.iterrows():
            entity = BiomedicalEntity(
                cui = row['cui'],
                name = row['name'],
                types = row['tui'],
                aliases = row['alias'],
                definition = row['def'],
                extra_data = {
                    'group': row['group'],
                }
            )
            self.entities.append(entity)
            self.mappings[row['cui']] = index 
            
    def load_medic(self, path, num_header_rows=29):
        '''
        Load the MEDIC disease vocabulary into ``entities`` and index it in ``mappings``.

        path : str
        Path to medic.tsv dataset
        num_header_rows : int
        Number of leading rows of the file to skip before the data starts
        (default 29, the value previously hard-coded here)
        '''
        # Imported locally so the method does not depend on a later notebook cell
        # having imported csv.
        import csv

        key_dict = [
            "DiseaseName",
            "DiseaseID",
            "AltDiseaseIDs",
            "Definition",
            "ParentIDs",
            "TreeNumbers",
            "ParentTreeNumbers",
            "Synonyms",
            "SlimMappings",
        ]

        ontology = []
        # Open the TSV file
        with open(path, newline="") as tsvfile:
            # Create a CSV reader specifying the delimiter as a tab character
            reader = csv.reader(tsvfile, delimiter="\t")
            for row_number, row in enumerate(reader):
                # Skip the header rows and blank lines (csv yields [] for those).
                if row_number < num_header_rows or not row:
                    continue
                # zip() pairs columns with their names without indexing past key_dict
                # when a row carries extra columns.
                ontology.append({key: value for key, value in zip(key_dict, row)})

        for element in ontology:
            # The primary id always comes first, followed by the alternative ids
            # that do not start with "DO".
            equivalant_cuis = [element['DiseaseID']]
            alt_ids = element['AltDiseaseIDs'].split('|') if element['AltDiseaseIDs'] else []
            for alt_id in alt_ids:
                if alt_id not in equivalant_cuis and not alt_id.startswith("DO"):
                    equivalant_cuis.append(alt_id)
            entity = BiomedicalEntity(
                cui = element['DiseaseID'],
                name = element['DiseaseName'],
                types = "Disease",
                aliases = element['Synonyms'],
                definition = element['Definition'],
                equivalant_cuis = equivalant_cuis
            )
            # Keep `mappings` in sync with `entities`, as load_umls does.
            self.mappings[entity.cui] = len(self.entities)
            self.entities.append(entity)


    def from_mesh(self):
        '''Not implemented yet (returns None).'''
        pass

    def from_ncbi_taxon(self):
        '''Not implemented yet (returns None).'''
        pass

    def from_csv(self):
        '''Not implemented yet (returns None).'''
        pass

    def from_json(self):
        '''Not implemented yet (returns None).'''
        pass


# if __name__ == "__main__":
#     ontology = BiomedicalOntology(name="UMLS")
#     ontology.load_umls(path="/mitchell/entity-linking/2017AA/META/", api_key="")
#     print(ontology.entities[0])
#     print(ontology.entities[0].__dict__)
#     print(ontology.entities[0].cui)
#     print(ontology.mappings[ontology.entities[0].cui])

In [10]:
import csv
# Name the ontology after the knowledge base that is actually loaded below (MEDIC).
ontology = BiomedicalOntology(name="MEDIC")
ontology.load_medic(path="/mitchell/entity-linking/kbs/medic.tsv")

In [11]:
# Spot-check the first loaded entity and its CUI.
print('ontology.entities[0] :', ontology.entities[0])
print('ontology.entities[0].cui :', ontology.entities[0].cui)

ontology.entities[0] : BiomedicalEntity(cui='MESH:C538288', name='10p Deletion Syndrome (Partial)', types='Disease', aliases='Chromosome 10, 10p- Partial|Chromosome 10, monosomy 10p|Chromosome 10, Partial Deletion (short arm)|Monosomy 10p', definition='', equivalant_cuis=['MESH:C538288'], taxonomy=None, extra_data=None)
ontology.entities[0].cui : MESH:C538288
