In [None]:
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure
from ipywidgets import interact, Dropdown
import pandas as pd
import numpy as np
import tarfile
import trec
import pprint
import json
import copy
import pickle
import re
from bertviz import model_view, head_view
from transformers import *
from sklearn.decomposition import PCA
import torch


%matplotlib notebook
%matplotlib inline
plt.style.use('ggplot')
    
pp = pprint.PrettyPrinter(width=120, compact=True)

# Load Patient Case Descriptions

In [None]:
Queries = "topics-2014_2015-summary.topics"
Qrels = "qrels-clinical_trials.txt"

# Read the topics file in one go and parse it as a single XML document.
with open(Queries, 'r') as queries_reader:
    txt = queries_reader.read()
root = ET.fromstring(txt)

# Map each topic number (<NUM>) to its description text (<TITLE>).
cases = {
    topic.find('NUM').text: topic.find('TITLE').text
    for topic in root.iter('TOP')
}

# NOTE: this name shadows the builtin `eval`; later cells use it as the evaluator object.
eval = trec.TrecEvaluation(cases, Qrels)

In [None]:
# Print every loaded case as {topic number: case description}.
pp.pprint(cases)

# Loading stop words file

In [None]:
# Read the comma-separated stop-word list. The context manager closes the file even
# when reading fails. Entries are stripped (and empty ones dropped) so that stray
# spaces or a trailing newline in the file cannot produce words that never match.
with open("stopWords.txt", "r") as gist_file:
    content = gist_file.read()
stop_words = [word.strip() for word in content.split(",") if word.strip()]


# Define Clinical Trial Document Structure

In [None]:
class Trial:
    """Plain container for the fields extracted from one clinical-trial XML record.

    Every attribute receives a default in ``__init__`` so a freshly created
    instance can always be inspected with ``vars()`` / ``show()`` and read by the
    age and gender filter, even before (or without) a loader filling it in.
    """

    _nct_id: str
    _brief_title: str
    _detailed_description: str
    _brief_summary: str
    _criteria: str
    _phase: str
    _study_type: str
    _study_design: str
    _condition: str
    _intervention: dict   # intervention child tag -> text
    _gender: str
    _minimum_age: int     # name matches the attribute the loader and filter use
    _maximum_age: int
    _healthy_volunteers: str
    _mesh_terms: list     # list of MeSH term strings

    def __init__(self):
        """Create an empty trial with neutral defaults for every field."""
        self._nct_id = ""
        self._brief_title = ""
        self._detailed_description = ""
        self._brief_summary = ""
        self._criteria = ""
        self._phase = ""
        self._study_type = ""
        self._study_design = ""
        self._condition = ""
        # Mutable containers are created per instance so trials never share them.
        self._intervention = {}
        self._gender = "Both"
        # Widest possible range: no age restriction until a loader says otherwise.
        self._minimum_age = 0
        self._maximum_age = 150
        self._healthy_volunteers = ""
        self._mesh_terms = []

    def show(self):
        """Print all instance attributes as indented JSON."""
        print(json.dumps(self.__dict__, indent=4))

def cleanstr(txt):
    """Collapse every run of whitespace in ``txt`` into a single space and trim the ends.

    Newlines are replaced by a space (not deleted), so words on consecutive
    lines are never glued together. ``None`` (e.g. the ``.text`` of an empty
    XML element) yields an empty string instead of raising.
    """
    if txt is None:
        return ""
    return re.sub(r'\s+', ' ', txt).strip()


# Load the clinical trials

Load from bin files

In [None]:
# Load ids and documents
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open("doc_ids.bin", "rb") as ids_file:
    ids = pickle.load(ids_file)
with open("full_documents.bin", "rb") as docs_file:
    full_docs = pickle.load(docs_file)

# Lookup from trial id to its parsed Trial object (same order as the two lists).
doc_dict = dict(zip(ids, full_docs))

Load from tar file

In [None]:
# Parse the trial archive and keep only the trials that appear in the judged set.
# `ids` and `full_docs` are parallel lists: exactly one id is appended per kept trial.
i = 0
ids = []
full_docs = []

# The context manager closes the archive even if one record fails to parse.
with tarfile.open("clinicaltrials.gov-16_dec_2015.tgz", "r:gz") as tar:
    for tarinfo in tar:
        # Members of 500 bytes or less are skipped.
        if tarinfo.size <= 500:
            continue

        txt = tar.extractfile(tarinfo).read().decode("utf-8", "strict")
        root = ET.fromstring(txt)

        # Remember WHICH id was judged instead of relying on a leftover loop variable.
        judged_id = None
        for doc_id in root.iter('nct_id'):
            if doc_id.text in eval.judged_docs:
                judged_id = doc_id.text
                break

        if judged_id is None:
            continue

        i = i + 1

        trial = Trial()

        # Record the id once per trial so ids[k] always describes full_docs[k],
        # whatever the number of <brief_title> elements in the record.
        ids.append(judged_id)
        trial._nct_id = cleanstr(judged_id)

        trial._brief_title = ""
        for brief_title in root.iter('brief_title'):
            trial._brief_title = cleanstr(brief_title.text)

        # Text fields fall back to the brief title when the element is missing.
        trial._detailed_description = trial._brief_title
        for detailed_description in root.iter('detailed_description'):
            for child in detailed_description:
                trial._detailed_description = cleanstr(child.text)

        trial._brief_summary = trial._brief_title
        for brief_summary in root.iter('brief_summary'):
            for child in brief_summary:
                trial._brief_summary = cleanstr(child.text)

        trial._criteria = trial._brief_title
        for criteria in root.iter('criteria'):
            for child in criteria:
                trial._criteria = cleanstr(child.text)

        trial._phase = trial._brief_title
        for phase in root.iter('phase'):
            trial._phase = cleanstr(phase.text)

        for study_type in root.iter('study_type'):
            trial._study_type = cleanstr(study_type.text)

        for study_design in root.iter('study_design'):
            trial._study_design = cleanstr(study_design.text)

        trial._condition = trial._brief_title
        for condition in root.iter('condition'):
            trial._condition = cleanstr(condition.text)

        for interventions in root.iter('intervention'):
            for child in interventions:
                trial._intervention[cleanstr(child.tag)] = cleanstr(child.text)

        # Capitalised like the value the age/gender filter compares against ('Both').
        trial._gender = "Both"
        for gender in root.iter('gender'):
            trial._gender = cleanstr(gender.text)

        # Ages: first integer found in the text; the unit in the text is not interpreted.
        # NOTE(review): a value such as "6 Months" is therefore stored as 6 - confirm intended.
        trial._minimum_age = 0
        for minimum_age in root.iter('minimum_age'):
            age = re.findall('[0-9]+', cleanstr(minimum_age.text))
            if age:
                trial._minimum_age = int(age[0])
            else:
                trial._minimum_age = 0

        trial._maximum_age = 150
        for maximum_age in root.iter('maximum_age'):
            age = re.findall('[0-9]+', cleanstr(maximum_age.text))
            if age:
                trial._maximum_age = int(age[0])
            else:
                trial._maximum_age = 150

        trial._healthy_volunteers = trial._brief_title
        for healthy_volunteers in root.iter('healthy_volunteers'):
            trial._healthy_volunteers = cleanstr(healthy_volunteers.text)

        for mesh_term in root.iter('mesh_term'):
            trial._mesh_terms.append(cleanstr(mesh_term.text))

        full_docs.append(trial)

print("Total of clinical trials: ", i)

# Cache the parsed collection; `with` guarantees the files are flushed and closed.
with open("doc_ids.bin", "wb") as ids_file:
    pickle.dump(ids, ids_file)
with open("full_documents.bin", "wb") as docs_file:
    pickle.dump(full_docs, docs_file)
doc_dict = dict(zip(ids, full_docs))

In [None]:
# Example of a document: print every attribute stored on the first loaded trial
pp.pprint(vars(full_docs[0]))

# Retrieval Models

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import pairwise_distances
from scipy.sparse import csr_matrix

class VSMindex:
    """Vector-space index: TF-IDF document vectors ranked by cosine similarity."""

    def __init__(self, corpus, _ngram_range=(1, 1), _analyzer='word', _stop_words=None):
        """Fit a TF-IDF vectorizer on ``corpus`` and keep the resulting document matrix.

        The three underscore-prefixed arguments are forwarded to ``TfidfVectorizer``
        as ``ngram_range``, ``analyzer`` and ``stop_words``.
        """
        tfidf = TfidfVectorizer(
            ngram_range=_ngram_range,
            analyzer=_analyzer,
            stop_words=_stop_words,
        )
        self.vectorizer = tfidf
        # Despite the attribute name, this holds the output of TfidfVectorizer.fit_transform.
        self.count_matrix = tfidf.fit_transform(corpus)

    def search(self, query):
        """Return a flat array with one cosine similarity per indexed document."""
        query_vector = self.vectorizer.transform([query])
        cosine_distance = pairwise_distances(query_vector, self.count_matrix, metric='cosine')
        similarity = 1 - cosine_distance
        return similarity.flatten()


class LMJMindex:
    """Query-likelihood index with Jelinek-Mercer (linear interpolation) smoothing.

    A document is scored by summing, over the query terms (weighted by their
    count in the query), ``log(lambda * P(t|doc) + (1 - lambda) * P(t|collection))``.
    """

    def __init__(self, corpus, _ngram_range=(1,1), _analyzer='word', _stop_words=None):
        """Count terms in ``corpus`` and precompute the smoothed log-probabilities.

        The underscore-prefixed arguments are forwarded to ``CountVectorizer``
        as ``ngram_range``, ``analyzer`` and ``stop_words``.
        """
        self.vectorizer = CountVectorizer(ngram_range=_ngram_range, analyzer=_analyzer, stop_words=_stop_words)
        self.count_matrix = self.vectorizer.fit_transform(corpus)

        # Total occurrences of each term over the whole collection.
        term_freq = np.sum(self.count_matrix, axis=0)

        # Number of counted terms in each document.
        doc_len = np.sum(self.count_matrix, axis=1)
        # A document left with no terms (e.g. only stop words) would give 0/0 below;
        # a length of 1 leaves its term probabilities at zero, so it is scored by the
        # collection model alone instead of producing NaN.
        doc_len[doc_len == 0] = 1

        self.prob_term_col = term_freq / np.sum(term_freq)

        self.prob_term_doc = self.count_matrix / doc_len

        # Set initial lambda value
        params = {'lambda' : 0.3}
        self.set_params(params)


    def set_params(self, params):
        """Update the smoothing weight from ``params['lambda']`` and rebuild the log table.

        Does nothing when ``params`` has no ``'lambda'`` key.
        """
        if 'lambda' in params:
            self.lbd = params['lambda']
            self._log_lmjm = np.log(self.lbd * self.prob_term_doc + (1 - self.lbd) * self.prob_term_col)
            print("LMJM lambda ", self.lbd)


    def search(self, query):
        """Return a flat array with one smoothed log-likelihood score per indexed document."""
        query_vector = self.vectorizer.transform([query])

        # np.array(...).flatten() turns the (n_docs x 1) result of the row sum into a 1-D array.
        doc_scores = np.array(np.sum(query_vector.multiply(self._log_lmjm), axis=1)).flatten()

        return doc_scores

# LETOR Model

In [None]:
# Checkpoint identifier passed to the Auto* loaders below.
model_path = 'dmis-lab/biobert-v1.1'

# Use the GPU when torch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Request hidden states and attention maps in the model outputs; later cells read both.
config = AutoConfig.from_pretrained(model_path, output_hidden_states=True, output_attentions=True)
model = AutoModel.from_pretrained(model_path, config=config)
model = model.to(device)

In [None]:
def extract_cls(query_pairs, embeddings, batch_size=32):
    """Fill ``embeddings`` with one L2-normalised first-token vector per entry of ``query_pairs``.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        query_pairs: sequence of inputs handed to ``tokenizer.batch_encode_plus``
            (the notebook passes ``(query, document)`` tuples).
        embeddings: pre-allocated array with one row per pair; written in place.
        batch_size: number of pairs encoded and forwarded per model call.

    Returns:
        The same ``embeddings`` object, after every row has been written.
    """
    total_pairs = len(query_pairs)

    for start in range(0, total_pairs, batch_size):
        stop = start + batch_size
        chunk = query_pairs[start:stop]

        # Tokenise the chunk: special tokens added, truncated to 512 tokens,
        # padded so all sequences in the chunk share one length.
        encoded = tokenizer.batch_encode_plus(
            chunk,
            return_tensors='pt',
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            padding=True,
        )

        # Inference only: no gradients are tracked.
        with torch.no_grad():
            encoded.to(device)
            model_out = model(**encoded)

        # Position 0 of the last hidden layer, then scaled to unit L2 norm per row.
        first_token_vectors = model_out['hidden_states'][-1][:, 0, :]
        first_token_vectors = torch.nn.functional.normalize(first_token_vectors, p=2, dim=1)

        # Copy back to host memory into the rows that belong to this chunk.
        embeddings[start:stop] = first_token_vectors.cpu()

    return embeddings

In [None]:
from sklearn.preprocessing import StandardScaler

class LETORindex:
    """Learning-to-rank index: a linear classifier fitted on per-pair feature rows.

    ``dataset`` must contain the columns ``case_id``, ``doc_id`` and ``y``; every
    other column is used as a feature.
    """

    def __init__(self, model, dataset):
        """Fit ``model`` on the whole ``dataset`` with class-balancing sample weights.

        Args:
            model: estimator exposing ``fit(X, y, sample_weight=...)`` and, after
                fitting, ``coef_`` (read by ``search``).
            dataset: DataFrame of feature rows with ``case_id``, ``doc_id``, ``y``.
        """
        self.model = model
        self.dataset = dataset

        X = self.dataset.drop(columns=['case_id', 'doc_id', 'y'])
        y = self.dataset['y']

        # Weight grades 1 and 2 by (#grade-0 rows / #rows of that grade); grade 0 keeps
        # weight 1. A grade with no rows is skipped: it needs no weight, and dividing by
        # its zero count used to abort the fit.
        n_negative = int((y == 0).sum())
        weights = {0.0: 1}
        for grade in (1.0, 2.0):
            n_grade = int((y == grade).sum())
            if n_grade:
                weights[grade] = n_negative / n_grade

        sample_weight = y.map(weights)
        # The classifier is trained on a binary target: grade 2 is merged into grade 1.
        y = y.replace(2.0,1.0)

        self.model.fit(X, y, sample_weight=sample_weight)


    def split_dataset_on_queries(self, dataset, queries):
        """Split ``dataset`` by ``case_id`` membership in ``queries``.

        Rows whose ``case_id`` is NOT in ``queries`` form the training part, the
        others the test part. Both parts are also written to ``csv/train.csv`` and
        ``csv/test.csv``.

        Returns:
            ``(train_dataset, test_dataset)``.
        """
        # Training dataset, filters rows that do not have test_queries
        train_dataset = dataset[~dataset['case_id'].isin(queries)]

        train_dataset.to_csv("csv/train.csv", index=False)

        # Test dataset, filters rows with test_queries
        test_dataset = dataset[dataset['case_id'].isin(queries)]

        test_dataset.to_csv("csv/test.csv", index=False)

        return train_dataset, test_dataset


    def search(self, query):
        """Score every stored row of the case whose description equals ``query``.

        The case id is recovered by reverse lookup in the notebook-level ``cases``
        dict, so ``query`` must be one of its values (``KeyError`` otherwise).

        Returns:
            The dot product of each feature row with ``model.coef_[0]``; the
            intercept is not added.
        """
        # Reverse the dictionary: create a new dictionary mapping values to keys
        reverse_dict = {v: k for k, v in cases.items()}

        # Get query id from query
        query_id = reverse_dict[query]

        query_embeddings = self.dataset[self.dataset['case_id'] == query_id].drop(columns=['case_id', 'doc_id', 'y'])

        doc_scores = query_embeddings.dot(self.model.coef_[0].T)

        return doc_scores

Load dataset from csv

In [None]:
# Load the cached dataset; case ids are converted to strings after reading,
# matching the string keys of `cases`.
dataset = pd.read_csv('csv/dataset.csv').astype({'case_id': str})

# Load embeddings if they already exist
#embeddings_df = pd.read_csv('csv/embeddings.csv')
# Change type of case_id to string
#embeddings_df['case_id'] = embeddings_df['case_id'].astype(str)

Create dataset if it does not exist

In [None]:
from itertools import product

# One row per (case, trial) combination: cases in `cases` order, trials in `ids` order.
query_document_pairs_df = pd.DataFrame(product(cases.keys(), ids), columns=['case_id', 'doc_id'])

# Text pairs built from the SAME product as the frame above, so pair k always
# belongs to row k (iterating `doc_dict` instead could differ from `ids`, e.g.
# when `ids` contains a repeated entry).
query_document_pairs = [
    (cases[case_id], doc_dict[doc_id]._detailed_description)
    for case_id, doc_id in product(cases.keys(), ids)
]


# Create dataset from text file with y values (relevant or non relevant) for each query/document pair
# `sep=r'\s+'` splits on any run of whitespace and replaces the deprecated `delim_whitespace=True`.
relevancy_df = pd.read_csv('qrels-clinical_trials.txt',
                           sep=r'\s+',
                           names=['case_id', 'x', 'doc_id','y'],
                           dtype={'case_id': str}).drop('x', axis=1)

relevancy_df.to_csv("csv/relevancy.csv", index=False)

# Merge dataset with text file for y values (relevant or non relevant) for each query/document pair
# Left join: pairs without a judgement keep a missing `y`.
dataset_without_embeddings = query_document_pairs_df.merge(relevancy_df, on=['case_id', 'doc_id'], how='left')

In [None]:
# WARNING, EXTREMELY LONG EXECUTION TIME
# Pre-allocate (in RAM) one 768-wide row per (query, doc) pair and let extract_cls
# fill it batch by batch.
n_pairs = len(query_document_pairs)
embeddings = extract_cls(
    query_document_pairs,
    embeddings=np.zeros((n_pairs, 768)),
    batch_size=8,
)

# Put the pair identifiers next to their embedding columns and cache the result.
embeddings_df = pd.concat([query_document_pairs_df, pd.DataFrame(embeddings)], axis=1)
embeddings_df.to_csv("csv/embeddings.csv", index=False)

In [None]:
# Join the judgements with the embeddings on (case_id, doc_id).
dataset_with_gaps = dataset_without_embeddings.merge(
    embeddings_df,
    on=['case_id', 'doc_id'],
    how='left',
)
# Every missing value becomes 0 - this includes `y` for pairs that have no judgement.
dataset = dataset_with_gaps.fillna(0)
dataset.to_csv("csv/dataset.csv", index=False)

# Indexing

VSM and LMJM

In [None]:
import pickle
# Reload the cached collection; `with` closes the files once they are read.
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open("doc_ids.bin", "rb") as ids_file:
    ids = pickle.load(ids_file)
with open("full_documents.bin", "rb") as docs_file:
    full_docs = pickle.load(docs_file)

# One corpus per trial field, all in `full_docs` order.
corpus_brief_title = [trial._brief_title for trial in full_docs]
corpus_brief_summary = [trial._brief_summary for trial in full_docs]
corpus_detailed_description = [trial._detailed_description for trial in full_docs]
corpus_criteria = [trial._criteria for trial in full_docs]

# Fields are joined with a space so the last word of one field and the first
# word of the next stay separate tokens.
corpus_full = [
    " ".join((trial._brief_title, trial._brief_summary, trial._detailed_description, trial._criteria))
    for trial in full_docs
]

corpora = {
    'brief_title': corpus_brief_title,
    'brief_summary': corpus_brief_summary,
    'detailed_description': corpus_detailed_description,
    'criteria': corpus_criteria,
    'full': corpus_full,
}

# Build "<model>_<field>" indexes: all VSM indexes first, then all LMJM indexes.
indexes_list = {}
for prefix, index_class in (('vsm', VSMindex), ('lmjm', LMJMindex)):
    for field, corpus in corpora.items():
        indexes_list[f'{prefix}_{field}'] = index_class(
            corpus, _ngram_range=(1,1), _analyzer='word', _stop_words = stop_words)

#pickle.dump(indexes_list, open( "indexes_list.bin", "wb" ) )

LETOR

In [None]:
from sklearn.linear_model import LogisticRegression
# Value passed as the `C` argument of LogisticRegression.
C = 0.5

# Fit the classifier on `dataset` (done inside LETORindex) and register it with the other indexes.
letor_classifier = LogisticRegression(C=C, max_iter=1000)
indexes_list['LETOR_model'] = LETORindex(letor_classifier, dataset)

# Retrieval

In [None]:
# Score every case against every index; raw scores and the ranked table are
# stored on the index object itself, keyed by case id.
for index_name, index in indexes_list.items():
    print()
    print(index_name)
    index.doc_scores = {}
    index.results_ord = {}

    for caseid, query in cases.items():
        doc_scores = index.search(query)
        index.doc_scores[caseid] = doc_scores

        # Pair scores with trial ids positionally, then rank by descending score.
        scored = pd.DataFrame(list(zip(ids, doc_scores)), columns = ['_id', 'score'])
        index.results_ord[caseid] = scored.sort_values(by=['score'], ascending = False)

    indexes_list[index_name] = index

#pickle.dump(indexes_list, open( "indexes_list.bin", "wb" ) )

LETOR only

In [None]:
# Same retrieval step as the cell for all indexes, restricted to the LETOR model.
# The key is looked up directly instead of scanning the whole dict; nothing runs
# when the LETOR index has not been built.
index_name = 'LETOR_model'
if index_name in indexes_list:
    print()
    print(index_name)
    index = indexes_list[index_name]
    index.doc_scores = {}
    index.results_ord = {}

    for caseid, query in cases.items():
        doc_scores = index.search(query)
        index.doc_scores[caseid] = doc_scores

        # Pair scores with trial ids positionally, then rank by descending score.
        results = pd.DataFrame(list(zip(ids, doc_scores)), columns = ['_id', 'score'])
        index.results_ord[caseid] = results.sort_values(by=['score'], ascending = False)

#pickle.dump(indexes_list, open( "indexes_list.bin", "wb" ) )

## Filters

In [None]:
import re

def check_words_in_query(query, target_words):
    """Return True when any of ``target_words`` occurs in ``query`` as a whole word.

    Matching is case-insensitive. A word also counts when it is directly followed
    by a hyphen that starts another word (``year-old``). Target words are matched
    literally: regex metacharacters in them carry no special meaning.

    Args:
        query: text to search.
        target_words: iterable of words to look for.

    Returns:
        True on the first word found, False when none is found (or the iterable is empty).
    """
    for target_word in target_words:
        # Escape the word so characters such as '.' or '+' cannot alter the pattern.
        literal = re.escape(target_word)
        pattern = re.compile(r'\b{}\b|\b{}-\b'.format(literal, literal), re.IGNORECASE)

        if pattern.search(query):
            return True

    return False


def extract_age_and_gender(query):
    """Extract the patient's age (in years) and gender from a case description.

    Age is resolved in three steps:
      1. An explicit age expression - a number followed by ``year(s) old``,
         ``month(s) old`` (hyphens or spaces allowed) or ``yo``. The unit that is
         attached to the number decides the conversion (months are divided by 12).
      2. Otherwise, the first standalone 1-3 digit number, accepted only when the
         text contains the word ``year``, ``month`` or ``yo``; it is treated as
         months when the text mentions "month" and never "year".
      3. Otherwise 21 when the word ``young`` is present, else 150.

    Gender is the first of man/boy/male/woman/women/girl/female found in the text,
    normalised to ``'Male'`` / ``'Female'``, or ``'Unknown'``.

    Returns:
        ``(age, gender)`` - ``age`` is an int or float.
    """
    default_age = 150

    # Step 1: number and unit must belong to the same expression, so an unrelated
    # duration elsewhere in the text ("... for 2 months") cannot change the unit.
    explicit_age = re.compile(
        r'\b(\d{1,3})\s*-?\s*(?:(year|month)s?\s*-?\s*old|yo)\b', re.IGNORECASE)
    # Step 2 patterns.
    first_number = re.compile(r'\b(\d{1,3})\b')
    age_terms = re.compile(r'\b(?:year|month|yo)\b', re.IGNORECASE)
    gender_patterns = re.compile(r'\b(?:man|boy|male|woman|women|girl|female)\b', re.IGNORECASE)

    explicit = explicit_age.search(query)
    if explicit:
        age = int(explicit.group(1))
        # group(2) is None for the "yo" form, which is always in years.
        if (explicit.group(2) or '').lower() == 'month':
            age = age / 12
    else:
        number = first_number.search(query)
        if number and age_terms.search(query):
            age = int(number.group(1))
            lowered = query.lower()
            if 'month' in lowered and 'year' not in lowered:
                age = age / 12
        elif re.search(r'\byoung\b', query, re.IGNORECASE):
            age = 21
        else:
            age = default_age

    # Extract gender
    default_gender = 'Unknown'
    gender_match = gender_patterns.search(query)
    gender = gender_match.group(0).lower() if gender_match else default_gender

    # Standardize gender
    if gender in ['man', 'boy', 'male']:
        gender = 'Male'
    elif gender in ['woman', 'women', 'girl', 'female']:
        gender = 'Female'

    return age, gender



def filter_by_age_and_gender(doc_scores, query):
    """Push trials the patient is not eligible for to the bottom of the ranking.

    Age and gender are extracted from ``query`` with ``extract_age_and_gender``.
    Position ``i`` of ``doc_scores`` is taken to describe ``full_docs[i]`` (the
    notebook-level list of trials).

    Args:
        doc_scores: iterable with one score per trial, in ``full_docs`` order.
        query: patient case description.

    Returns:
        A new list of scores in which every trial failing the age or the gender
        check has its score replaced by ``-1000000``. The input is not modified.
    """
    age, gender = extract_age_and_gender(query)
    patient_gender = gender.lower()

    doc_scores = list(doc_scores)

    for i, doc in enumerate(full_docs):
        trial = vars(doc)

        age_check = trial['_minimum_age'] <= age <= trial['_maximum_age']

        # Compared case-insensitively: the tar loader's fallback value is the
        # lower-case "both", which an exact comparison with 'Both' never matched.
        # "all" is accepted as another spelling of "no gender restriction".
        trial_gender = str(trial['_gender']).lower()
        gender_check = trial_gender in ('both', 'all') or trial_gender == patient_gender
        # NOTE(review): a patient whose gender is 'Unknown' only passes for
        # unrestricted trials - confirm this is the intended policy.

        # Filter unwanted documents
        if not (age_check and gender_check):
            doc_scores[i] = -1000000

    return doc_scores

------------------------------
# Compute evaluation metric results
For each patient, search each index and rank clinical trials by their similarity to the patient case description

In [None]:
# Evaluate every index: for each case, filter the stored scores by age/gender,
# re-rank, compute the metrics, then store per-query values and their means on
# the index object.
for index_name, index in indexes_list.items():
    print()
    print(index_name)

    m_ap = 0
    m_p10 = 0
    m_mrr = 0
    m_ndcg5 = 0
    m_recall = 0
    mean_precision_11point = np.zeros(11)

    index.p10_per_query = {}
    index.ndcg5_per_query = {}
    index.recall_per_query = {}
    index.mrr_per_query = {}
    index.ap_per_query = {}
    index.precision_11point_per_query = {}

    for caseid, query in cases.items():
        # Work on a copy so the scores stored by the retrieval step stay unfiltered.
        doc_scores = copy.deepcopy(index.doc_scores[caseid])
        doc_scores = filter_by_age_and_gender(doc_scores, query)
        results = pd.DataFrame(list(zip(ids, doc_scores)), columns = ['_id', 'score'])
        results_ord = results.sort_values(by=['score'], ascending = False)

        [p10, recall, ap, ndcg5, mrr] = eval.eval(results_ord, caseid)
        [precision_11point, recall_11point, total_relv_ret] = eval.evalPR(results_ord, caseid)

        # Only accumulate the curve when evalPR returned recall points for this case.
        if (np.shape(recall_11point) != (0,)):
            mean_precision_11point = mean_precision_11point + precision_11point

        index.p10_per_query[caseid] = p10
        index.ndcg5_per_query[caseid] = ndcg5
        index.recall_per_query[caseid] = recall
        index.mrr_per_query[caseid] = mrr
        index.ap_per_query[caseid] = ap
        index.precision_11point_per_query[caseid] = precision_11point

        m_ap = m_ap + ap
        m_p10 = m_p10 + p10
        m_mrr = m_mrr + mrr
        m_ndcg5 = m_ndcg5 + ndcg5
        m_recall = m_recall + recall

    # Means are taken over all cases.
    index.m_ap = m_ap / len(cases)
    index.m_p10 = m_p10 / len(cases)
    index.m_mrr = m_mrr / len(cases)
    index.m_ndcg5 = m_ndcg5 / len(cases)
    index.m_recall = m_recall / len(cases)
    index.mean_precision_11point = mean_precision_11point/len(cases)

    print("   P10    ", index.m_p10)
    print("   NDCG@5 ", index.m_ndcg5)
    print("   MRR    ", index.m_mrr)
    print("   MAP    ", index.m_ap)
    print("   Recall ", index.m_recall)

# Persist the evaluated indexes; `with` guarantees the file is flushed and closed
# before another cell reads it back.
with open("indexes_results.bin", "wb") as results_file:
    pickle.dump(indexes_list, results_file)

LETOR only

In [None]:
# Same evaluation as the cell for all indexes, restricted to the LETOR model.
# The key is looked up directly; nothing runs when the LETOR index is missing.
index_name = 'LETOR_model'
if index_name in indexes_list:
    index = indexes_list[index_name]
    print()
    print(index_name)

    m_ap = 0
    m_p10 = 0
    m_mrr = 0
    m_ndcg5 = 0
    m_recall = 0
    mean_precision_11point = np.zeros(11)

    index.p10_per_query = {}
    index.ndcg5_per_query = {}
    index.recall_per_query = {}
    index.mrr_per_query = {}
    index.ap_per_query = {}
    index.precision_11point_per_query = {}

    for caseid, query in cases.items():
        # Work on a copy so the scores stored by the retrieval step stay unfiltered.
        doc_scores = copy.deepcopy(index.doc_scores[caseid])
        doc_scores = filter_by_age_and_gender(doc_scores, query)
        results = pd.DataFrame(list(zip(ids, doc_scores)), columns = ['_id', 'score'])
        results_ord = results.sort_values(by=['score'], ascending = False)

        [p10, recall, ap, ndcg5, mrr] = eval.eval(results_ord, caseid)
        [precision_11point, recall_11point, total_relv_ret] = eval.evalPR(results_ord, caseid)

        # Only accumulate the curve when evalPR returned recall points for this case.
        if (np.shape(recall_11point) != (0,)):
            mean_precision_11point = mean_precision_11point + precision_11point

        index.p10_per_query[caseid] = p10
        index.ndcg5_per_query[caseid] = ndcg5
        index.recall_per_query[caseid] = recall
        index.mrr_per_query[caseid] = mrr
        index.ap_per_query[caseid] = ap
        index.precision_11point_per_query[caseid] = precision_11point

        m_ap = m_ap + ap
        m_p10 = m_p10 + p10
        m_mrr = m_mrr + mrr
        m_ndcg5 = m_ndcg5 + ndcg5
        m_recall = m_recall + recall

    # Means are taken over all cases.
    index.m_ap = m_ap / len(cases)
    index.m_p10 = m_p10 / len(cases)
    index.m_mrr = m_mrr / len(cases)
    index.m_ndcg5 = m_ndcg5 / len(cases)
    index.m_recall = m_recall / len(cases)
    index.mean_precision_11point = mean_precision_11point/len(cases)

    print("   P10    ", index.m_p10)
    print("   NDCG@5 ", index.m_ndcg5)
    print("   MRR    ", index.m_mrr)
    print("   MAP    ", index.m_ap)
    print("   Recall ", index.m_recall)

#pickle.dump(indexes_list, open( "indexes_results.bin", "wb" ) )

# Results and discussion

In [None]:
# Load previously computed results
# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
with open("indexes_results.bin", "rb") as results_file:
    indexes_list = pickle.load(results_file)

All

In [None]:
# Recall levels shared by every curve (x axis).
recall_11point = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# One mean-precision curve per index on a single, labelled set of axes.
fig, ax = plt.subplots()
for index_name, index in indexes_list.items():
    ax.plot(recall_11point, index.mean_precision_11point, label = index_name)

ax.set_xlabel("Recall")
ax.set_ylabel("Mean precision")
ax.set_title("11-point precision/recall per model")
# The legend is built once, after all curves exist (and only if there are any).
if indexes_list:
    ax.legend()

plt.show()

In [None]:
# Build the summary table in one step from a list of records instead of growing
# a frame row by row through the private `DataFrame._append`.
# Rows are numbered 0..n-1 in `indexes_list` order.
results_table = pd.DataFrame(
    [
        {
            'model': index_name,
            'p10': index.m_p10,
            'ndcg5': index.m_ndcg5,
            'mrr': index.m_mrr,
            'map': index.m_ap,
            'recall': index.m_recall,
        }
        for index_name, index in indexes_list.items()
    ],
    columns =['model', 'p10', 'ndcg5', 'mrr', 'map', 'recall'],
)

results_table

LETOR

In [None]:
# Recall levels of the curve (x axis).
recall_11point = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

# Mean-precision curve of the LETOR model only, on labelled axes.
fig, ax = plt.subplots()
if 'LETOR_model' in indexes_list:
    index_name = 'LETOR_model'
    index = indexes_list[index_name]
    ax.plot(recall_11point, index.mean_precision_11point, label = index_name)
    ax.legend()

ax.set_xlabel("Recall")
ax.set_ylabel("Mean precision")
ax.set_title("11-point precision/recall - LETOR_model")

plt.show()

In [None]:
# Summary table restricted to the LETOR model, built in one step from a list of
# records instead of the private `DataFrame._append`. The table is empty (columns
# only) when the LETOR index is missing.
results_table = pd.DataFrame(
    [
        {
            'model': index_name,
            'p10': index.m_p10,
            'ndcg5': index.m_ndcg5,
            'mrr': index.m_mrr,
            'map': index.m_ap,
            'recall': index.m_recall,
        }
        for index_name, index in indexes_list.items()
        if index_name == 'LETOR_model'
    ],
    columns =['model', 'p10', 'ndcg5', 'mrr', 'map', 'recall'],
)

results_table

------------------------------

In [None]:
# ORIGINAL CELL MADE BY THE TEACHER
index = indexes_list['lmjm_full']
#index = indexes_list['lmjm_brief_title']

# Sort (score, label) pairs TOGETHER so every bar keeps the label of its own
# query; sorting scores and labels independently breaks that pairing.
# The case id prefixes each label so two cases that share their first 30
# characters still get separate bars.
ranked = sorted(
    ((index.p10_per_query[caseid], f"{caseid}: {cases[caseid][0:30]}") for caseid in cases),
    key=lambda pair: pair[0],
)
query_result = np.array([score for score, _ in ranked])
query_text = np.array([label for _, label in ranked])

figure(figsize=(8, 19), dpi=80)

plt.barh(query_text, query_result)

In [None]:
metrics = ['p10', 'ndcg5', 'recall', 'mrr', 'ap']

def get_query_result(selected_index, selected_metric):
    """Return the per-query values of one metric for one index, sorted ascending.

    Args:
        selected_index: key into the notebook-level ``indexes_list``.
        selected_metric: one of the entries of ``metrics``.

    Returns:
        ``np.sort`` of the metric values over all cases (the values are sorted on
        their own, so the position no longer identifies the case), or ``None``
        when ``selected_metric`` matches none of the five metric slots.
    """
    # Attribute holding the per-query dict for metrics[0] .. metrics[4], in that order.
    per_query_attributes = (
        'p10_per_query',
        'ndcg5_per_query',
        'recall_per_query',
        'mrr_per_query',
        'ap_per_query',
    )

    for metric_name, attribute in zip(metrics, per_query_attributes):
        if selected_metric == metric_name:
            per_query = getattr(indexes_list[selected_index], attribute)
            return np.sort([per_query[caseid] for caseid in cases])

    return None

    

In [None]:
# CELL CREATED TO VISUALIZE THE METRICS OF EACH MODEL INDIVIDUALLY WITH EASE

# Kept as a notebook-level name for other cells; plot_barh builds its own labels.
query_texts = np.sort([cases[caseid][0:30] for caseid in cases])

def plot_barh(selected_index, selected_metric):
    """Draw one horizontal bar per case for the chosen index and metric.

    Values and labels are read per case and sorted together (ascending value),
    so each bar carries the label of the case it was computed for.

    Args:
        selected_index: key into the notebook-level ``indexes_list``.
        selected_metric: metric name; the values are read from the index
            attribute ``<selected_metric>_per_query``.
    """
    per_query = getattr(indexes_list[selected_index], f'{selected_metric}_per_query')

    # The case id prefix keeps labels unique even when two descriptions share
    # their first 30 characters.
    ranked = sorted(
        ((per_query[caseid], f'{caseid}: {cases[caseid][0:30]}') for caseid in cases),
        key=lambda pair: pair[0],
    )
    values = [value for value, _ in ranked]
    labels = [label for _, label in ranked]

    # A single figure per call.
    fig, ax = plt.subplots(figsize=(8, 19), dpi=80)
    ax.barh(labels, values)
    ax.set_xlabel(selected_metric)
    ax.set_title(f'{selected_metric} per query - {selected_index}')
    plt.show()


index_selector = Dropdown(options=list(indexes_list.keys()))
metric_selector = Dropdown(options=metrics)

interact(plot_barh, selected_index=index_selector, selected_metric = metric_selector)

In [None]:
# CELL CREATED TO VISUALIZE THE METRICS OF ALL MODELS AT ONCE (MIGHT NEED CHANGING SO MARKERS IN THE CHART DONT OVERLAP)
import plotly.express as px
# Kept as a notebook-level name for other cells; update_scatter builds its own labels.
query_texts = np.sort([cases[caseid][0:30] for caseid in cases])

# Define a function to update the scatter plot
def update_scatter(selected_metric):
    """Show one marker per (model, case) with the chosen metric on the x axis.

    Each value is read together with the case it belongs to, so a marker's row
    label is always the case the value was computed for.

    Args:
        selected_metric: metric name; the values are read from the index
            attribute ``<selected_metric>_per_query``.
    """
    data = []

    for key, model_index in indexes_list.items():
        per_query = getattr(model_index, f'{selected_metric}_per_query')
        # The case id prefix keeps labels unique even when two descriptions
        # share their first 30 characters.
        data.extend(
            (per_query[caseid], f'{caseid}: {cases[caseid][0:30]}', key)
            for caseid in cases
        )

    df = pd.DataFrame(data, columns=[selected_metric, "Cases", "Model"])

    fig = px.scatter(
        df,
        x=selected_metric,
        y="Cases",
        color="Model",
        title=f"Scatter Plot of {selected_metric}"
    )

    fig.update_layout(
        height=1000,
        width=1000,
    )

    fig.show()

# Create a dropdown menu for selecting metrics
metric_selector = Dropdown(options=metrics)

# Use the interact function to update the scatter plot based on the selected metric
interact(update_scatter, selected_metric=metric_selector)

Load model and tokenizer

In [None]:
# Checkpoint identifier passed to the Auto* loaders below.
model_path = 'dmis-lab/biobert-v1.1'

# Pick the device name first, then build the torch.device from it.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

tokenizer = AutoTokenizer.from_pretrained(model_path)

# Hidden states and attention maps are both requested in the model outputs.
config = AutoConfig.from_pretrained(
    model_path,
    output_hidden_states=True,
    output_attentions=True,
)
model = AutoModel.from_pretrained(model_path, config=config)
model = model.to(device)

Layer Embeddings Visualization

In [None]:
import nltk
from nltk.corpus import stopwords
import torch
import string

# Make sure nltk's stopwords are downloaded
# (the corpus is read through `stopwords.words('english')` in this notebook)
nltk.download('stopwords')

def get_tokens_and_outputs(query):
    """Tokenise ``query``, run the model on it and return ``(tokens, outputs)``.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.

    Args:
        query: input handed to ``tokenizer.encode_plus`` (the notebook passes a
            ``(query, document)`` tuple); truncated to 512 tokens.

    Returns:
        tokens: token strings of the first (only) encoded sequence.
        outputs: whatever the model call returns for the encoded input.
    """
    encoded = tokenizer.encode_plus(
        query,
        return_tensors='pt',
        add_special_tokens=True,
        max_length=512,
        truncation=True,
    )

    # Row 0 of the id tensor is the single sequence that was encoded.
    sequence_ids = encoded['input_ids'][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(sequence_ids)

    # Inference only: no gradients are tracked.
    with torch.no_grad():
        encoded.to(device)
        outputs = model(**encoded)

    return tokens, outputs


def extract_top_words_from_query(query, top_k, remove_stop_words=True):
    """Return the positions of the ``top_k`` highest-scoring distinct tokens of ``query``.

    A token's score is the last layer's attention (first item of the batch)
    summed over its first two dimensions.
    # presumably those dimensions are (heads, attending position) - TODO confirm

    ``[CLS]``/``[SEP]``, punctuation and (optionally) English stop words are
    ignored. When the same token string occurs several times only its
    best-scoring occurrence is kept.

    Args:
        query: input forwarded to ``get_tokens_and_outputs``.
        top_k: number of positions to return.
        remove_stop_words: drop nltk English stop words when True.

    Returns:
        List of token positions ordered by decreasing score.
    """
    excluded_words = set(stopwords.words('english')) if remove_stop_words else set()

    tokens, outputs = get_tokens_and_outputs(query)

    # Sum (not mean) of the last layer's attention over its first two dimensions.
    last_layer_attention = outputs['attentions'][-1][0]
    attention_scores = torch.sum(last_layer_attention, dim=(0,1))

    # Keep (position, score) for every token that survives the filters.
    candidates = []
    for position, (token, score) in enumerate(zip(tokens, attention_scores)):
        if token in ['[CLS]', '[SEP]']:
            continue
        if token.lower() in excluded_words:
            continue
        if token in string.punctuation:
            continue
        candidates.append((position, score))

    # Highest score first (stable for equal scores).
    candidates.sort(key=lambda item: item[1], reverse=True)

    # Walk down the ranking, skipping token strings that were already taken.
    top_token_positions = []
    seen_tokens = set()
    for position, _ in candidates:
        token = tokens[position]
        if token in seen_tokens:
            continue
        top_token_positions.append(position)
        seen_tokens.add(token)
        if len(top_token_positions) == top_k:
            break

    return top_token_positions


def plot_embeddings(embeddings1, embeddings2, tokens):
    """Scatter two embedding sets in 2-D and label every point with its token.

    Each set is projected with its own PCA fit and the first two components are
    kept. ``embeddings1`` is drawn in red, ``embeddings2`` in blue; row ``k`` of
    both sets is labelled ``tokens[k]``.
    """
    projected_first = PCA().fit_transform(embeddings1)[:,:2]
    projected_second = PCA().fit_transform(embeddings2)[:,:2]

    plt.figure(figsize=(12, 12))

    for points, colour in ((projected_first, 'r'), (projected_second, 'b')):
        plt.scatter(points[:,0], points[:,1], edgecolors='k', c=colour)

    # Same token text next to the point from each set, slightly offset.
    for token, first_xy, second_xy in zip(tokens, projected_first, projected_second):
        for x, y in (first_xy, second_xy):
            plt.text(x+0.05, y+0.05, token)

    


In [None]:
# Example pair: one patient case and the detailed description of one trial.
query = cases["20141"]
doc = doc_dict["NCT00000492"]._detailed_description
query_doc_pair = (query, doc)

query_tokens, outputs = get_tokens_and_outputs(query_doc_pair)
top_positions = extract_top_words_from_query(query_doc_pair, top_k=10, remove_stop_words=True)
sorted_top_positions = sorted(top_positions)

top_tokens = [query_tokens[i] for i in sorted_top_positions]

# Get the embeddings of the first and last layers for the top tokens.
# The SAME sorted positions are used for labels and embeddings, so label k
# always describes embedding row k.
embeddings_first_layer = outputs['hidden_states'][0][:, sorted_top_positions, :].cpu().numpy().squeeze()
embeddings_last_layer = outputs['hidden_states'][-1][:, sorted_top_positions, :].cpu().numpy().squeeze()
#embeddings_first_layer = outputs['hidden_states'][0][:, :, :].cpu().numpy().squeeze()
#embeddings_last_layer = outputs['hidden_states'][-1][:, :, :].cpu().numpy().squeeze()

pp.pprint(query_doc_pair)
pp.pprint(top_tokens)

plot_embeddings(embeddings_first_layer, embeddings_last_layer, top_tokens)
#plot_embeddings(embeddings_first_layer, embeddings_last_layer, query_tokens)

Layer Embeddings Similarity Visualization

In [None]:
# Per-layer attention weights for the query/document pair encoded above;
# indexed as attention[layer][0][head] by plot_matrix below.
attention = outputs['attentions']

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_matrix(layer, tokens, attention, specific_positions=None, cols=4):
    """Plot one self-attention heatmap per head of a single layer.

    All heads share one colour scale, so the single colorbar on the right is
    valid for every panel.

    Args:
        layer: Index into ``attention`` selecting the layer (negative indices
            count from the end).
        tokens: Tick labels, one per row/column of the plotted matrices. When
            ``specific_positions`` is given, ``tokens[i]`` must be the token
            found at ``specific_positions[i]``.
        attention: Sequence indexed as ``attention[layer][0][head]``, each
            entry being a 2-D tensor of weights (rows: queries, columns: keys).
        specific_positions: Optional token positions; when given, only these
            rows and columns of each head are plotted.
        cols: Number of panels per grid row. The number of rows is derived
            from the number of heads (12 heads with the default give 3 x 4).
    """
    # Every head of the layer for batch element 0, indexed [head, query, key].
    head_maps = attention[layer][0].cpu().numpy()

    if specific_positions is not None:
        positions = list(specific_positions)
        head_maps = head_maps[:, positions, :][:, :, positions]

    # pcolormesh draws row 0 at the bottom; flip the query axis so the first
    # token ends up on the top row (y tick labels are reversed to match).
    head_maps = head_maps[:, ::-1, :]

    n_heads = head_maps.shape[0]
    rows = -(-n_heads // cols)  # ceiling division

    multiplier = len(tokens) / 3
    fig, axes = plt.subplots(rows, cols, squeeze=False)
    fig.set_figheight(rows * multiplier)
    fig.set_figwidth(cols * multiplier + 3)
    plt.rcParams.update({'font.size': 10})

    # One normalisation for all heads; otherwise each panel autoscales and the
    # shared colorbar would only describe the last head drawn.
    vmin = float(head_maps.min())
    vmax = float(head_maps.max())

    # pcolormesh cell i spans [i, i + 1], so labels belong at the cell centres.
    tick_centers = np.arange(len(tokens)) + 0.5

    mesh = None
    for head in range(rows * cols):
        r, c = divmod(head, cols)
        ax = axes[r, c]

        if head >= n_heads:
            # Grid cell without a head (n_heads not a multiple of cols).
            ax.set_visible(False)
            continue

        mesh = ax.pcolormesh(head_maps[head], cmap='gnuplot', vmin=vmin, vmax=vmax)
        ax.set_title("Head " + str(head))

        # Query labels only on the left-most column.
        if c == 0:
            ax.set_yticks(tick_centers)
            ax.set_yticklabels(list(reversed(tokens)))
            ax.set_ylabel("Queries")
        else:
            ax.set_yticks([])

        # Key labels only on the bottom-most panel of each column.
        if head + cols >= n_heads:
            ax.set_xticks(tick_centers)
            ax.set_xticklabels(tokens)
            ax.set_xlabel("Keys")

            # Rotate the tick labels and set their alignment.
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
                     rotation_mode="anchor")
        else:
            ax.set_xticks([])

    fig.suptitle("Layer " + str(layer) + " Multi-head Self-attentions")
    cbar = fig.colorbar(mesh, ax=axes, location='right', shrink=0.5)
    cbar.ax.set_ylabel("Self-attention", rotation=-90, va="bottom")
    plt.show()

In [None]:
# Plot the heads of the last layer (index -1), restricted to the rows/columns
# of the top-token positions; top_tokens supplies the matching tick labels.
layer = -1
plot_matrix(layer, top_tokens, attention, specific_positions=sorted_top_positions)

Self-attention head visualization

In [None]:
def call_html():
    """Emit a RequireJS bootstrap snippet into the current cell's output.

    The HTML loads ``require.js`` from the notebook server's static files and
    registers module paths for ``base``, ``d3`` (cdnjs) and ``jquery``
    (Google CDN).
    NOTE(review): presumably needed by the bertviz views rendered in the
    following cells — confirm.
    """
    # Public IPython API; avoids the internal IPython.core.display path and
    # the reliance on the `display` name being injected by the notebook kernel.
    from IPython.display import HTML, display

    display(HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/5.7.0/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))

call_html()

In [None]:
# Render bertviz's head_view for the attention weights and tokens of the
# query/document pair encoded above.
head_view(attention, query_tokens)

In [None]:
# Render bertviz's model_view for the same attention weights and tokens.
model_view(attention, query_tokens)

# Bonus Question

In [None]:
# TODO (maybe)