In [None]:
from txgnn import TxData, TxGNN, TxEval
import requests

# Keep the dataset wrapper under its own name so the imported `TxData` class stays
# callable; rebinding `TxData` to an instance breaks any later `TxData(...)` call.
TxData_inst = TxData(data_folder_path = '/Users/emmatysinger/Develop/meng/kg/')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Read the knowledge-graph tables from disk; node.csv is the only tab-separated one.
kg_df, edges_df = (pd.read_csv(csv_path) for csv_path in ('../kg/kg.csv', '../kg/edges.csv'))
nodes_df = pd.read_csv('../kg/node.csv', sep='\t')


In [None]:
from Bio import SeqIO

def count_sequences(file_name):
    """Return the number of records in a FASTA file.

    Args:
        file_name: Path of the FASTA file to read.

    Returns:
        int: How many records ``SeqIO.parse`` yields for the file.
    """
    with open(file_name, "r") as fasta_handle:
        # Stream the records and keep only the tally, never the sequences.
        return sum(1 for _ in SeqIO.parse(fasta_handle, "fasta"))

# Count the records in the FASTA file below; edit the path to point at your own file.
number_of_sequences = count_sequences('embeddings/gene_sequences_3.fasta')
print(f"Number of sequences: {number_of_sequences}")


## Node Exploration

In [None]:
print("Node Data:")
nodes_df.head()

In [None]:
# Imported here as well: this cell sits above the cell that first imports Entrez,
# so a top-to-bottom run would otherwise fail with NameError.
from Bio import Entrez

# Search the protein database for the term, keeping at most 10 hits.
handle = Entrez.esearch(db="protein", retmax=10, term='GPANK1')
record = Entrez.read(handle)
handle.close()
record["IdList"]

In [None]:
# Fetch the first id from the search result (`record`, set by the esearch cell) as FASTA text.
handle = Entrez.efetch(db="protein", id=record["IdList"][0], rettype="fasta", retmode="text")
gene_data = handle.read()
handle.close()
gene_data

In [None]:
# Same fetch for the second id of the search result; this indexes position 1, so it
# raises IndexError when the search returned fewer than two ids.
handle = Entrez.efetch(db="protein", id=record["IdList"][1], rettype="fasta", retmode="text")
gene_data = handle.read()
handle.close()
gene_data

In [None]:
from Bio import Entrez
import time

# Always tell NCBI who you are (email)
Entrez.email = "tysinger@mit.edu"

def fetch_sequences(id_list, output_file, gene_id_dict, delay=1):
    """Download the top protein hit for each search term and write it as FASTA.

    Every term in ``id_list`` is looked up with ``Entrez.esearch`` in the
    ``protein`` database. The first returned id is fetched with
    ``Entrez.efetch`` and its FASTA text is appended to ``output_file``.
    Terms whose search returns no ids are reported and skipped instead of
    aborting the whole run.

    Args:
        id_list: Iterable of search terms (the example cell passes node names).
        output_file: Path of the file to create or overwrite with the FASTA text.
        gene_id_dict: Dict updated in place with ``{first_hit_id: search_term}``.
        delay: Seconds to sleep after each term. Defaults to 1, the value that
            was previously hard-coded.

    Returns:
        The same ``gene_id_dict`` object that was passed in.
    """
    with open(output_file, 'w') as outfile:
        for i, gene_id in enumerate(id_list):
            # Lightweight progress indicator.
            if i % 10 == 0:
                print(i)

            handle = Entrez.esearch(db="protein", retmax=10, term=gene_id, idtype='acc')
            try:
                record = Entrez.read(handle)
            finally:
                handle.close()

            # Kept under its own name so the `id_list` argument is not shadowed.
            hit_ids = record["IdList"]
            if hit_ids:
                top_hit = hit_ids[0]
                gene_id_dict[top_hit] = gene_id

                handle = Entrez.efetch(db="protein", id=top_hit, rettype="fasta", retmode="text")
                try:
                    gene_data = handle.read()
                finally:
                    handle.close()

                # Write the sequence to a file
                outfile.write(gene_data)
            else:
                print(f"No protein record found for {gene_id!r}; skipping")

            # NCBI recommends not to send more than 3 requests per second. Each term
            # costs up to two requests (search + fetch), so pause `delay` seconds.
            time.sleep(delay)
    return gene_id_dict

# Example usage: fetch sequences for the first 200 gene/protein node names.
gene_ids = list(nodes_df.loc[nodes_df.node_type == 'gene/protein', 'node_name'].head(200))
output_filename = "gene_sequences.fasta"
gene_id_dict = fetch_sequences(gene_ids, output_filename, {})


In [None]:
import pathlib
import torch

from esm import FastaBatchedDataset, pretrained

In [None]:
def extract_embeddings(model_name, fasta_file, output_dir, gene_dict, tokens_per_batch=4096, seq_length=1022,repr_layers=[33]):
    """Embed every sequence of a FASTA file and save one ``.pt`` file per sequence.

    The model and alphabet are loaded with ``pretrained.load_model_and_alphabet``
    and the FASTA file is batched with ``FastaBatchedDataset``. For each sequence
    the requested layer representations are averaged over token positions
    ``1 .. truncate_len`` (position 0 is skipped, presumably a start token -
    TODO confirm) and saved to ``output_dir / f"{gene_id}.pt"``.

    Args:
        model_name: Name handed to ``pretrained.load_model_and_alphabet``.
        fasta_file: FASTA file handed to ``FastaBatchedDataset.from_file``.
        output_dir: ``pathlib.Path`` of the output directory (created if missing).
        gene_dict: Maps a FASTA entry id to the gene id used as the file name. The
            entry id is the first whitespace-separated token of the record label;
            if that is not a key, the second ``|``-separated field is tried.
        tokens_per_batch: Token budget handed to ``get_batch_indices``.
        seq_length: Passed to the batch converter and used as the upper bound
            on how many positions are averaged.
        repr_layers: Layers requested from the model.

    Returns:
        None. Each saved dict holds ``gene_id``, ``entry_id`` and
        ``mean_representations`` (layer -> averaged tensor).

    Raises:
        KeyError: If neither form of the entry id is present in ``gene_dict``.
        IndexError: If the fallback is needed but the entry id contains no ``|``.
    """
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()

    # Decide once where to run instead of re-querying CUDA for every batch.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model = model.cuda()

    dataset = FastaBatchedDataset.from_file(fasta_file)
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)

    data_loader = torch.utils.data.DataLoader(
        dataset,
        collate_fn=alphabet.get_batch_converter(seq_length),
        batch_sampler=batches
    )

    output_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):

            print(f'Processing batch {batch_idx + 1} of {len(batches)}')

            if use_cuda:
                toks = toks.to(device="cuda", non_blocking=True)

            out = model(toks, repr_layers=repr_layers, return_contacts=False)

            # Only the representations are saved, so the logits are not copied to the CPU.
            representations = {layer: t.to(device="cpu") for layer, t in out["representations"].items()}

            for i, label in enumerate(labels):
                entry_id = label.split()[0]
                try:
                    gene_id = gene_dict[entry_id]
                except KeyError:
                    # Labels of the form "db|accession|name": retry with the accession.
                    entry_id = entry_id.split('|')[1]
                    gene_id = gene_dict[entry_id]

                filename = output_dir / f"{gene_id}.pt"
                truncate_len = min(seq_length, len(strs[i]))

                result = {"gene_id": gene_id,
                          "entry_id": entry_id}
                result["mean_representations"] = {
                        layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                        for layer, t in representations.items()
                    }

                torch.save(result, filename)

In [None]:
gene_id_dict

In [None]:
# Embed the downloaded sequences; the FASTA path below is the file name that the
# fetch_sequences example cell writes.
model_name = 'esm2_t33_650M_UR50D'
fasta_file = pathlib.Path('./gene_sequences.fasta')
output_dir = pathlib.Path('train_embeddings')

# gene_id_dict (built by fetch_sequences) supplies the entry-id -> gene-id mapping.
extract_embeddings(model_name, fasta_file, output_dir, gene_id_dict)


In [None]:
import torch

# Load one saved embedding file and take its layer-33 mean representation as a NumPy array.
# NOTE(review): extract_embeddings names its files after the mapped gene id, while this
# file name looks like an accession - confirm the file exists in train_embeddings/.
embedding = torch.load('train_embeddings/XP_060544080.1.pt')
embedding = embedding['mean_representations'][33].numpy()

In [None]:
embedding.shape

In [None]:
len(set(nodes_df[nodes_df.node_type == 'gene/protein']['node_name']))

In [None]:
# Tally how many nodes exist for each node type.
unique_node_types = nodes_df['node_type'].value_counts()
unique_node_types_df = unique_node_types.rename_axis('node_type').reset_index(name='count')
unique_node_types_df


In [None]:
# Keep only the gene/protein rows and key them by node_index.
protein_nodes = nodes_df.loc[nodes_df['node_type'] == 'gene/protein'].set_index('node_index')
protein_nodes

In [None]:
# Tally the gene/protein nodes per node_source.
unique_sources = protein_nodes['node_source'].value_counts()
unique_sources_df = unique_sources.rename_axis('node_source').reset_index(name='count')
unique_sources_df

## Edge Exploration

In [None]:
print("Edges Data:")
edges_df.head()

In [None]:
print('Unique relations: ')
# Edge count per relation as a two-column frame.
unique_relations = edges_df['relation'].value_counts()
unique_relations_df = unique_relations.rename_axis('relation').reset_index(name='count')
unique_relations_df

In [None]:
# `unique_kg_df_proteins` is otherwise first built in the "Whole Knowledge Graph
# Exploration" section further down; it is rebuilt here with the same expressions
# so this cell also works in a top-to-bottom run.
unique_kg_df = kg_df.groupby(['relation', 'display_relation','x_type', 'y_type', 'y_source']).size().reset_index(name='count')
unique_kg_df_proteins = unique_kg_df[(unique_kg_df['x_type'] == 'gene/protein') & (unique_kg_df['relation'].str.contains('protein'))].reset_index(drop=True)

# One data source per relation: keep the first y_source listed for each relation.
relation_sources = (
    unique_kg_df_proteins[['relation', 'y_source']]
    .drop_duplicates(subset=['relation'], keep='first')
    .reset_index(drop=True)
    .rename(columns={'y_source': 'data_source'})
)
relation_sources

In [None]:
# unique relations containing protein
protein_relations = unique_relations_df[unique_relations_df.relation.str.contains('protein')].reset_index(drop=True)
# Attach the data source of each relation. The former `merged_df =` alias in this
# statement is dropped: it was never read and clashed with the unrelated
# `merged_df` built in the knowledge-graph section.
protein_relations = pd.merge(protein_relations, relation_sources, on='relation', how='left')
# Display labels are assigned by row position, so they rely on the row order and
# on there being exactly 11 protein relations; pandas raises ValueError otherwise.
protein_relations['relation'] = ['Anatomy Present', 'Protein', 'Biological Process', 'Cellular Component', 'Disease', 'Molecular Function',
                                 'Pathway', 'Drug', 'Anatomy Absent', 'Phenotype', 'Exposure']

plt.figure(figsize=(8, 4))  # Adjust the figure size as needed
sns.barplot(x='count', y='relation', data=protein_relations, hue='data_source',dodge=False)

# Set plot labels and title
plt.xlabel('Count')
plt.ylabel('')
plt.title('Relations containing Protein')
plt.legend(title='Data Source')

# Show the plot
plt.show()

In [None]:
protein_relations

In [None]:
print('Unique display relations: ')
# Edge count per display_relation as a two-column frame.
unique_display_relations = edges_df['display_relation'].value_counts()
unique_display_relations_df = unique_display_relations.rename_axis('display relation').reset_index(name='count')
unique_display_relations_df

In [None]:
# Edge count for every (relation, display_relation) pair. Both names below refer
# to the same frame, whose second column is labelled 'display relation'.
unique_pairs_counts = (
    edges_df.groupby(['relation', 'display_relation'])
    .size()
    .reset_index(name='count')
    .rename(columns={'display_relation': 'display relation'})
)
unique_pair_counts_df = unique_pairs_counts
# Show only the pairs whose relation name mentions protein.
unique_pair_counts_df.loc[unique_pair_counts_df.relation.str.contains('protein')].reset_index(drop=True)

## Whole Knowledge Graph Exploration

In [None]:
print("Knowledge Graph Data:")
kg_df.head()

In [None]:
# Edge count for each distinct (relation, display_relation, x_type, y_type, y_source).
unique_kg_df = (
    kg_df.groupby(['relation', 'display_relation','x_type', 'y_type', 'y_source'])
    .size()
    .reset_index(name='count')
)
# Restrict to protein relations whose x side is a gene/protein.
unique_kg_df_proteins = unique_kg_df.loc[
    (unique_kg_df['x_type'] == 'gene/protein') & (unique_kg_df['relation'].str.contains('protein'))
].reset_index(drop=True)
unique_kg_df_proteins

In [None]:
unique_molecular_function_kg = kg_df[(kg_df['relation'] == 'molfunc_protein')&(kg_df['x_type'] == 'gene/protein')][['y_name']].drop_duplicates().reset_index(drop=True)
unique_molecular_function_kg

In [None]:
def _relations_per_protein(relation, count_column):
    """Count, per gene/protein x_name, the kg_df edges that carry the given relation."""
    edge_mask = (kg_df['relation'] == relation)&(kg_df['x_type'] == 'gene/protein')
    return kg_df[edge_mask].groupby('x_name').size().reset_index(name=count_column)


protein_molecular_func = _relations_per_protein('molfunc_protein', 'number of molfunc relations')
protein_pathway = _relations_per_protein('pathway_protein', 'number of pathway relations')
protein_disease = _relations_per_protein('disease_protein', 'number of disease relations')


In [None]:
# Outer-join the three per-protein counts on x_name; a protein missing from one
# table gets 0 for that count.
merged_df = (
    protein_molecular_func
    .merge(protein_pathway, on='x_name', how='outer')
    .merge(protein_disease, on='x_name', how='outer')
    .fillna(0)
)
merged_df


In [None]:
def _proteins_per_relation_count(count_column, out_column):
    """Count merged_df rows for each value of count_column, under a shared key name."""
    return (
        merged_df.groupby(count_column)
        .size()
        .reset_index(name=out_column)
        .rename(columns={count_column: 'number of relations'})
    )


grouped_molfunc_df = _proteins_per_relation_count('number of molfunc relations', 'molfunc')
grouped_pathway_df = _proteins_per_relation_count('number of pathway relations', 'pathway')
grouped_disease_df = _proteins_per_relation_count('number of disease relations', 'disease')

In [None]:
# Line the three histograms up on the shared 'number of relations' key; a relation
# count missing from one table becomes 0.
merged_grouped_df = (
    grouped_molfunc_df
    .merge(grouped_pathway_df, on='number of relations', how='outer')
    .merge(grouped_disease_df, on='number of relations', how='outer')
    .fillna(0)
)
merged_grouped_df

In [None]:
# Group the data by the number of molfunc relations and count the number of proteins
grouped_df = protein_molecular_func.groupby('number of molfunc relations').size().reset_index(name='number of proteins')

# Create the bar chart
plt.bar(grouped_df['number of molfunc relations'], grouped_df['number of proteins'])

# Customize the chart labels and title
plt.xlabel('Number of Molfunc Relations')
plt.ylabel('Number of Proteins')
plt.title('Protein Molecular Function')

# Display the chart
plt.show()


In [None]:
# Group the data by the number of pathway relations and count the number of proteins
grouped_df = protein_pathway.groupby('number of pathway relations').size().reset_index(name='number of proteins')
# Only plot relation counts up to 60
grouped_df = grouped_df[grouped_df['number of pathway relations']<=60]

# Create the bar chart
plt.bar(grouped_df['number of pathway relations'], grouped_df['number of proteins'])
#plt.yscale('log')

# Customize the chart labels and title
plt.xlabel('Number of Pathway Relations')
plt.ylabel('Number of Proteins')
plt.title('Protein Pathway')

# Display the chart
plt.show()


In [None]:
# Group the data by the number of disease relations and count the number of proteins
grouped_df = protein_disease.groupby('number of disease relations').size().reset_index(name='number of proteins')
# Only plot relation counts up to 60
grouped_df = grouped_df[grouped_df['number of disease relations']<=60]

# Create the bar chart
plt.bar(grouped_df['number of disease relations'], grouped_df['number of proteins'])
#plt.yscale('log')

# Customize the chart labels and title
plt.xlabel('Number of Disease Relations')
plt.ylabel('Number of Proteins')
plt.title('Protein Disease')

# Display the chart
plt.show()

## Drug Disease Exploration

In [None]:
unique_kg_df = kg_df.groupby(['relation', 'display_relation','x_type', 'y_type']).size().reset_index(name='count')
unique_kg_df_drug_disease = unique_kg_df[((unique_kg_df['relation']=='contraindication')|(unique_kg_df['relation']=='indication'))&(unique_kg_df['x_type']=='drug')].reset_index(drop=True)
unique_kg_df_drug_disease

In [None]:
drug_disease_count = kg_df[((kg_df['relation']=='contraindication')|(kg_df['relation']=='indication'))&(kg_df['x_type'] == 'disease')].groupby('x_name').size().reset_index(name='number of relations')
grouped_drug_disease_count = drug_disease_count.groupby('number of relations').size().reset_index(name='number of diseases')

In [None]:
# Number of disease names that have at least one indication/contraindication edge.
diseases_nonzero_drugdisease = grouped_drug_disease_count['number of diseases'].sum()
# Total number of disease rows in the node table. Only the per-name sizes are
# summed: calling .sum() on the whole reset_index frame also "summed" the
# node_name column, concatenating every disease name into one huge string.
total_diseases = nodes_df[nodes_df['node_type'] == 'disease'].groupby('node_name').size().sum()
diseases_nonzero_drugdisease

In [None]:
total_diseases

In [None]:
# Create the bar chart
grouped_drug_disease_count = grouped_drug_disease_count[grouped_drug_disease_count['number of relations']<=100]
plt.bar(grouped_drug_disease_count['number of relations'], grouped_drug_disease_count['number of diseases'])
#plt.yscale('log')

# Customize the chart labels and title
plt.xlabel('Number of Drug-Disease Relations')
plt.ylabel('Number of Diseases')
plt.title('Drug Disease')

# Display the chart
plt.show()

In [None]:
grouped_drug_disease_count

In [None]:
unique_kg_df

# Loading embeddings for proteins

In [1]:
from txgnn.model import HeteroRGCN
from IPython import get_ipython
from importlib import reload

# Import the desired libraries
import txgnn

# Reload the libraries
reload(txgnn)

from txgnn import TxData, TxGNN, TxEval
import torch.nn as nn
import torch
import os
import pandas as pd
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
TxData_inst = TxData(data_folder_path = '/om/user/tysinger/kg/')
TxData_inst.prepare_split(split = 'random', seed = 42, no_kg = False)

Found local copy...
Found local copy...
Found local copy...
Found saved processed KG... Loading...
Splits detected... Loading splits....
Creating DGL graph....
{'anatomy': 14032.0, 'biological_process': 28641.0, 'cellular_component': 4175.0, 'disease': 17079.0, 'drug': 7956.0, 'effect/phenotype': 15310.0, 'exposure': 817.0, 'gene/protein': 27609.0, 'molecular_function': 11168.0, 'pathway': 2515.0}
Done!


In [5]:
TxGNN_model = txgnn.TxGNN(data = TxData_inst, 
              weight_bias_track = False,
              proj_name = 'TxGNN',
              exp_name = 'TxGNN',
              device = 'cpu'
              )



TxGNN_model.model_initialize(n_hid = 1280, 
                      n_inp = 1280, 
                      n_out = 1280, 
                      proto = False, #made this False
                      proto_num = 3,
                      attention = False,
                      sim_measure = 'all_nodes_profile',
                      bert_measure = 'disease_name',
                      agg_measure = 'rarity',
                      num_walks = 200,
                      walk_mode = 'bit',
                      path_length = 2,
                      esm = False)

In [4]:
def _esm_embedding_path(node_idx, idx2id, id2name, emb_dir):
    """Return the embedding file path expected for a node index, or None if its name cannot be resolved."""
    if node_idx not in idx2id:
        return None
    raw_id = idx2id[node_idx]

    # Try the id normalised to an integer string first (e.g. 101.0 -> '101'),
    # then the raw id exactly as stored in the edge table.
    candidate_keys = []
    try:
        candidate_keys.append(str(int(float(raw_id))))
    except (TypeError, ValueError, OverflowError):
        pass
    candidate_keys.append(raw_id)

    for key in candidate_keys:
        try:
            name = id2name[key]
        except (KeyError, TypeError):
            continue
        # Non-string names (e.g. NaN) cannot form a file name.
        if isinstance(name, str):
            return os.path.join(emb_dir, name + '.pt')
    return None


def initialize_node_embedding(g, n_inp, df, df_nodes, esm=False,
                              emb_dir='/om/user/tysinger/TxGNN/embeddings/esm_embeddings/'):
    """Attach an ``inp`` feature matrix to every node type of ``g``.

    Node types get a Xavier-uniform matrix of shape ``(num_nodes, n_inp)``.
    When ``esm`` is true, ``gene/protein`` nodes instead use the first tensor in
    ``mean_representations`` of ``<emb_dir>/<node name>.pt`` where that file
    exists, and a Xavier-uniform vector of length ``n_inp`` otherwise.

    Args:
        g: Graph exposing ``ntypes``, ``number_of_nodes(ntype)`` and
            ``nodes[ntype].data``.
        n_inp: Width of the randomly initialised embeddings.
        df: Edge table with ``x_type``/``x_idx``/``x_id`` and
            ``y_type``/``y_idx``/``y_id`` columns, used to map indices to ids.
        df_nodes: Node table with ``node_type``, ``node_id`` and ``node_name``.
        esm: Whether to load saved embeddings for ``gene/protein`` nodes.
        emb_dir: Directory holding the ``<node name>.pt`` files. Defaults to the
            path that was previously hard-coded.

    Returns:
        The same graph ``g``, with ``g.nodes[ntype].data['inp']`` set for every type.

    Note:
        Loaded embeddings are stacked with the random ones, so ``torch.stack``
        fails if a loaded tensor does not have length ``n_inp``.
    """
    for ntype in g.ntypes:
        if not (esm and ntype == 'gene/protein'):
            # initialize embedding xavier uniform
            emb = nn.Parameter(torch.Tensor(g.number_of_nodes(ntype), n_inp), requires_grad = False)
            nn.init.xavier_uniform_(emb)
            g.nodes[ntype].data['inp'] = emb
            continue

        # Map node index -> node id from both ends of the edges, then id -> name.
        x_prot = df[df.x_type == 'gene/protein']
        y_prot = df[df.y_type == 'gene/protein']
        idx2id = dict(zip(x_prot['x_idx'], x_prot['x_id']))
        idx2id.update(zip(y_prot['y_idx'], y_prot['y_id']))
        prot_nodes = df_nodes[df_nodes.node_type == 'gene/protein']
        id2name = dict(zip(prot_nodes.node_id, prot_nodes.node_name))

        prot_embs = []
        for i in tqdm(range(g.number_of_nodes(ntype))):
            emb_path = _esm_embedding_path(i, idx2id, id2name, emb_dir)
            if emb_path is not None and os.path.exists(emb_path):
                esm_emb = torch.load(emb_path)
                prot_embs.append(next(iter(esm_emb['mean_representations'].values())))
            else:
                # Create xavier embedding for those without esm embedding
                xavier_emb = nn.init.xavier_uniform_(torch.Tensor(1, n_inp)).squeeze()
                prot_embs.append(xavier_emb)

        # convert list to tensor
        emb = nn.Parameter(torch.stack(prot_embs), requires_grad=False)
        g.nodes[ntype].data['inp'] = emb
    return g

In [None]:
len(nodes[nodes.node_type == 'gene/protein'])

In [5]:
initialize_node_embedding(TxData_inst.G, 1280, TxData_inst.df, nodes, esm=True)

100%|██████████| 27610/27610 [00:38<00:00, 721.48it/s]


ESM embeddings found:  20873
ESM embeddings not found:  6737
Name not found:  0


Graph(num_nodes={'anatomy': 14033, 'biological_process': 28642, 'cellular_component': 4176, 'disease': 17080, 'drug': 7957, 'effect/phenotype': 15311, 'exposure': 818, 'gene/protein': 27610, 'molecular_function': 11169, 'pathway': 2516},
      num_edges={('anatomy', 'anatomy_anatomy', 'anatomy'): 23328, ('anatomy', 'rev_anatomy_protein_absent', 'gene/protein'): 16531, ('anatomy', 'rev_anatomy_protein_present', 'gene/protein'): 1262006, ('biological_process', 'bioprocess_bioprocess', 'biological_process'): 87924, ('biological_process', 'rev_bioprocess_protein', 'gene/protein'): 120369, ('biological_process', 'rev_exposure_bioprocess', 'exposure'): 1351, ('cellular_component', 'cellcomp_cellcomp', 'cellular_component'): 8056, ('cellular_component', 'rev_cellcomp_protein', 'gene/protein'): 69328, ('cellular_component', 'rev_exposure_cellcomp', 'exposure'): 9, ('disease', 'disease_disease', 'disease'): 53522, ('disease', 'disease_phenotype_negative', 'effect/phenotype'): 991, ('disease', '

In [3]:
# Paths under the TxData data folder: the knowledge graph and the node table.
# A split-specific path (<data_folder>/<split>_kg/kg.csv) used to be assigned to
# kg_path first, but it was overwritten on the next line; that dead store is removed.
kg_path = os.path.join(TxData_inst.data_folder, 'kg.csv')
nodes = pd.read_csv(os.path.join(TxData_inst.data_folder, 'nodes.csv'))
#df = pd.read_csv(kg_path)
#split_data_path = os.path.join(TxData_inst.data_folder, TxData_inst.split + '_' + str(42))
#df_train = pd.read_csv(os.path.join(split_data_path, 'train.csv'))

In [6]:
# Map gene/protein node ids to names. Later cells look ids up in `id2name`, so it
# is built here rather than left as a commented-out line.
protein_node_rows = nodes[nodes.node_type == 'gene/protein']
id2name = dict(zip(protein_node_rows.node_id, protein_node_rows.node_name))

# Map gene/protein graph indices to node ids from both ends of the training edges.
train_edges = TxData_inst.df_train
x_is_protein = train_edges.x_type == 'gene/protein'
y_is_protein = train_edges.y_type == 'gene/protein'
idx2id = dict(zip(train_edges.loc[x_is_protein, 'x_idx'], train_edges.loc[x_is_protein, 'x_id']))
# The y side must be filtered on y_type. Filtering on x_type here, as before, pulled
# in the y_idx/y_id of whatever node type the proteins point at and overwrote
# protein entries.
idx2id.update(zip(train_edges.loc[y_is_protein, 'y_idx'], train_edges.loc[y_is_protein, 'y_id']))

In [None]:
id = idx2id[8]
id

In [None]:
# Check whether the id looked up above appears among the integer-cast id2name keys.
id_keys = sorted(int(key) for key in id2name)
print(id in id_keys)

In [None]:
# Spot-check the first 100 indices: print the resolved name and embedding path,
# or the raw id when it cannot be resolved to a usable name.
for i in range(100):
    # `node_id`, not `id`, so the builtin is not shadowed.
    node_id = idx2id[i]
    try:
        name = id2name[str(int(node_id))]
        print(name)
        emb_path = os.path.join('/Users/emmatysinger/Develop/meng/TxGNN/embeddings/embeddings1/', name+'.pt')
        print(emb_path)
    except (KeyError, ValueError, TypeError):
        # Only the expected lookup/conversion failures are caught; anything else surfaces.
        print(node_id)

In [None]:
nodes[nodes.node_type == 'gene/protein']

In [4]:
def create_dgl_graph(df_train, df):
    """Build a DGL heterograph from the training edges.

    Args:
        df_train: Edge table with ``x_type``, ``relation``, ``y_type``, ``x_idx``
            and ``y_idx`` columns. One edge type is created per distinct
            ``(x_type, relation, y_type)`` triple.
        df: Edge table with ``x_type``/``x_idx``/``y_type``/``y_idx`` columns, used
            only to size each node type as (largest index seen) + 1.

    Returns:
        The heterograph. Every edge type carries an ``id`` tensor filled with an
        integer assigned in ``g.etypes`` order.
    """
    # No visible cell imports dgl, so import it where it is used.
    import dgl

    unique_graph = df_train[['x_type', 'relation', 'y_type']].drop_duplicates()
    DGL_input = {}
    for x_type, relation, y_type in unique_graph.values:
        # Select on the full triple so a relation name shared by several type
        # pairs cannot mix indices that belong to different node types.
        triple_rows = df_train[
            (df_train.x_type == x_type)
            & (df_train.relation == relation)
            & (df_train.y_type == y_type)
        ]
        o = triple_rows[['x_idx', 'y_idx']].values.T
        DGL_input[(x_type, relation, y_type)] = (o[0].astype(int), o[1].astype(int))

    temp = dict(df.groupby('x_type')['x_idx'].max())
    temp2 = dict(df.groupby('y_type')['y_idx'].max())
    # Guarantee an effect/phenotype entry without discarding a real x-side maximum
    # (a plain assignment of 0.0 would under-size that node type).
    temp.setdefault('effect/phenotype', 0.0)

    # Largest index per node type across both edge ends.
    output = {}
    for d in (temp, temp2):
        for k, v in d.items():
            output.setdefault(k, float('-inf'))
            output[k] = max(output[k], v)

    print(output)

    g = dgl.heterograph(DGL_input, num_nodes_dict={i: int(output[i])+1 for i in output.keys()})

    # Tag the edges of every edge type with that type's integer id.
    edge_dict = {}
    for etype in g.etypes:
        edge_dict[etype] = len(edge_dict)
        g.edges[etype].data['id'] = torch.ones(g.number_of_edges(etype), dtype=torch.long) * edge_dict[etype]

    return g

In [5]:
# `df_train` is never assigned in a live cell (its read_csv line is commented out),
# so pass the training split that TxData_inst already holds.
# NOTE(review): presumably the same table as <split folder>/train.csv - confirm.
create_dgl_graph(TxData_inst.df_train, pd.read_csv(kg_path))