This notebook takes in a table of GO term, LLM name, and human-curated name, and computes the semantic sentence similarity between each pair of names.


Run separately for NeST and MSigDB

In [1]:
from transformers import AutoTokenizer, AutoModel
from semanticSimFunctions import getSentenceEmbedding, getSentenceSimilarity, getNameSimilarities

In [2]:
# The tokenizer and the encoder are both loaded from the same checkpoint id,
# so the id is named once and reused.
sapbert_checkpoint = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
SapBERT_tokenizer = AutoTokenizer.from_pretrained(sapbert_checkpoint)
SapBERT_model = AutoModel.from_pretrained(sapbert_checkpoint)

In [3]:
import pandas as pd

In [4]:
# User-editable parameters.
#   runVersion: "test" selects the '_subset' input/output files; any other
#               value selects the files without that infix.
#   dataType:   which gene-set table to process; "MSigDB" and "NeST" are the
#               values handled further down in this notebook.
runVersion = "test"
dataType = "MSigDB"

In [5]:
# File-name infix: test runs use the '_subset' tables, everything else uses
# the file names with no infix.
infix = '_subset' if runVersion == "test" else ''

In [12]:
# Pick the column names that identify a gene set and hold its human-curated
# name; they differ between the supported tables.
if dataType == "MSigDB":
    # For MSigDB the same column serves as both identifier and human name.
    geneSetID = 'Name'
    humanNameCol = 'Name'
elif dataType == "NeST":
    geneSetID = 'NEST ID'
    humanNameCol = 'name_new'
else:
    # Stop here: merely printing would leave geneSetID / humanNameCol
    # undefined and the notebook would fail later with an unrelated NameError.
    raise ValueError(f"Data type not implemented yet: {dataType!r}")

In [13]:
# Input table path, assembled from the data type and the test/full infix.
inputFile = ''.join(['data/', dataType, '_table', infix, '_LLM_Enrichr_DF.tsv'])

In [15]:
# Read the tab-separated input table into a DataFrame.
expanded_LLM_genes_geneSetDF = pd.read_csv(
    inputFile,
    sep="\t",
)

In [16]:
# Keep one row per gene set: the row with the smallest adjusted p-value.
# idxmin() returns index *labels*, so the groupby must run on the same frame
# that .loc is applied to. Grouping on a reset_index() copy only selected the
# right rows while the frame carried a default 0..n-1 index.
best_row_labels = (
    expanded_LLM_genes_geneSetDF
    .groupby(geneSetID)['Adjusted P-value']
    .idxmin()
)
reduced_LLM_genes_DF = expanded_LLM_genes_geneSetDF.loc[best_row_labels]

In [17]:
# Renumber the rows from 0; the previous row labels are kept as a column
# (drop=False is the default, spelled out here).
# NOTE(review): this cell rebinds its own input, so re-running it without
# re-running the previous cell will presumably add another index column.
reduced_LLM_genes_DF = reduced_LLM_genes_DF.reset_index(drop=False)

In [18]:
# Inspect which columns are available after reducing to one row per gene set.
reduced_LLM_genes_DF.columns

Index(['index', 'Unnamed: 0', 'Name', 'Genes', 'Genes.1', 'LLM Name',
       'LLM Analysis', 'Rank', 'Overlap', 'P-value', 'Adjusted P-value',
       'Genes.2', 'Genes.3', 'GO term', 'GO ID'],
      dtype='object')

In [19]:
# Inspect (rows, columns) of the reduced table before computing similarities.
reduced_LLM_genes_DF.shape

(3, 15)

In [20]:
# Compute name similarities with the SapBERT tokenizer/model.
# NOTE(review): presumably scores the LLM name and the GO term against the
# human-curated name; confirm against semanticSimFunctions.getNameSimilarities.
names_DF = getNameSimilarities(
    reduced_LLM_genes_DF,
    'LLM Name',
    'GO term',
    humanNameCol,
    SapBERT_tokenizer,
    SapBERT_model,
    "cosine_similarity",
)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


0
1
2


In [21]:
# Output table path; mirrors the input file name with a '_simVals' marker.
outputFile = ''.join(['data/', dataType, '_table', infix, '_LLM_Enrichr_simVals_DF.tsv'])

In [22]:
# Write the similarity table as a tab-separated file.
names_DF.to_csv(
    outputFile,
    sep="\t",
)

## Get performance measure

In [23]:
# Tally how many systems fall under each value of the `winner` column.
# These are raw counts, not percentages.
names_DF["winner"].value_counts()

Tied    1
LLM     1
GO      1
Name: winner, dtype: int64

In [24]:
# Inspect (rows, columns) of the similarity table that was written out.
names_DF.shape

(3, 18)