In [1]:
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
from rdkit import Chem
import requests
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
import umap
from sklearn.cluster import DBSCAN

from tkgdti.data.GraphBuilder import GraphBuilder
import os 

from tkgdti.data.utils import get_protein_sequence_uniprot
from tkgdti.embed.AA2EMB import AA2EMB


import re
import pandas as pd

# Seed the global PyTorch and NumPy random number generators so that any
# stochastic step in this notebook starts from the same state on every run.
torch.manual_seed(0)
np.random.seed(0)



  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Hugging Face model identifier handed to AutoTokenizer / AutoModel.from_pretrained below.
MODEL_NAME = "Rostlab/prot_bert"
# Directory whose entries are listed and passed to GraphBuilder as relation names.
# The path is relative, so it resolves against the kernel's working directory.
ROOT = '../../extdata/relations/'

In [3]:

# Every entry found in ROOT is treated as a relation name.
# NOTE(review): os.listdir returns entries in arbitrary order and also returns
# any stray files in the directory — consider sorting/filtering if GraphBuilder
# is order-sensitive (confirm against GraphBuilder).
relnames = os.listdir(ROOT)
# No validation/test indices are supplied (both are None).
GB = GraphBuilder(root=ROOT, relnames=relnames, val_idxs=None, test_idxs=None)
print('building...')
# Populates the builder; GB.relations is read in the next cell.
GB.build() 

Node types: ['dbgap_subject' 'disease' 'drug' 'gene' 'pathway']
building...


In [4]:
# Sorted, de-duplicated set of every node that appears as a 'gene' on either
# the source or the destination side of a relation.
genespace = np.unique(
    GB.relations.loc[lambda rel: rel['src_type'] == 'gene', 'src'].values.tolist()
    + GB.relations.loc[lambda rel: rel['dst_type'] == 'gene', 'dst'].values.tolist()
)


In [5]:

def parse_fasta_to_dataframe(path):
    """
    Parse a FASTA file with UniProt-style headers into a pandas DataFrame.

    Each header line (starting with '>') opens a new record; all following
    non-header lines are concatenated to form that record's sequence.

    Parameters
    ----------
    path : str or path-like
        Location of the FASTA file to read.

    Returns
    -------
    pandas.DataFrame
        One row per FASTA entry, in file order, with columns:
        - 'id': first whitespace-delimited token of the header, or None when
          the header contains nothing after '>'.
        - 'organism': text between 'OS=' and the following 'OX=' field, or
          None when the header has no such pair.
        - 'gene_name': the non-whitespace token following 'GN=', or None when
          the header has no 'GN=' field.
        - 'sequence': the record's sequence lines joined without separators
          (empty string when the header is not followed by any sequence).

    Notes
    -----
    Blank lines are skipped, and sequence lines that appear before the first
    header are ignored.
    """
    # Regex patterns to extract organism and gene name from the header
    organism_pattern = re.compile(r'OS=(.+?)\s+OX=')
    gene_pattern = re.compile(r'GN=(\S+)')

    records = []
    current_record = None

    with open(path, 'r') as fasta_file:
        for line in fasta_file:
            line = line.strip()
            if not line:
                continue  # skip empty lines

            if line.startswith('>'):
                header_line = line[1:].strip()  # remove '>'
                # A bare '>' leaves an empty header; guard the token lookup so
                # such an entry yields id=None instead of raising IndexError.
                header_tokens = header_line.split(maxsplit=1)
                organism_match = organism_pattern.search(header_line)
                gene_match = gene_pattern.search(header_line)

                current_record = {
                    'id': header_tokens[0] if header_tokens else None,
                    'organism': organism_match.group(1) if organism_match else None,
                    'gene_name': gene_match.group(1) if gene_match else None,
                    'sequence': [],
                }
                # Register the record immediately; its sequence list keeps
                # growing through the shared reference, so no separate
                # "flush the previous record" step is needed.
                records.append(current_record)
            elif current_record is not None:
                # Lines not starting with '>' are part of the current sequence
                current_record['sequence'].append(line)

    # Collapse each record's list of sequence lines into a single string
    for record in records:
        record['sequence'] = ''.join(record['sequence'])

    return pd.DataFrame(records, columns=['id', 'organism', 'gene_name', 'sequence'])

In [6]:
# Parse the proteome FASTA into one row per entry (id, organism, gene_name, sequence).
# NOTE(review): the file name suggests UniProt proteome UP000005640 / taxon 9606 (human) — confirm provenance.
gene2aa = parse_fasta_to_dataframe('../../data/tkg_raw/UP000005640_9606.fasta')


In [7]:
# Keep only FASTA entries whose gene name occurs in the graph, then collapse to
# a single row per gene. GroupBy.first() takes the first non-null value of each
# column within a group (in file order), and the result is sorted by gene_name.
gene2aa = (
    gene2aa
    .loc[lambda df: df['gene_name'].isin(genespace)]
    .groupby('gene_name')
    .first()
    .reset_index()
)

In [8]:
# Persist the per-gene sequence table as CSV, omitting the DataFrame index.
gene2aa.to_csv('../../extdata/meta/gene2aa.csv', index=False)

In [9]:
# Load the pretrained tokenizer (lower-casing disabled) and model named by MODEL_NAME.
# NOTE(review): neither `tokenizer` nor `model` is referenced by the cells visible
# in this notebook (embedding goes through AA2EMB below) — confirm whether this
# load is still needed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
model = AutoModel.from_pretrained(MODEL_NAME)

In [10]:
# Sequence strings as a NumPy array, in the same row order as gene2aa.
aas = gene2aa.sequence.values

In [11]:
# Embed every sequence with the project's AA2EMB helper (default constructor arguments).
AA2E = AA2EMB()
# The result is moved to CPU and converted to a NumPy array.
# NOTE(review): presumably embed() returns one embedding row per input sequence,
# in input order (the recorded run produced shape (13053, 1024)) — confirm in AA2EMB.
outputs = AA2E.embed(aas).cpu().numpy()

Progress: 1116/13053

You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset


Progress: 13020/13053

In [12]:
# Display the dimensions of the embedding array.
outputs.shape

(13053, 1024)

In [13]:
# Preview the first rows of the per-gene sequence table.
gene2aa.head()

Unnamed: 0,gene_name,id,organism,sequence
0,A1BG,sp|P04217|A1BG_HUMAN,Homo sapiens,MSMLVVFLLLWGVTWGPVTEAAIFYETQPSLWAESESLLKPLANVT...
1,A1CF,sp|Q9NQ94|A1CF_HUMAN,Homo sapiens,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...
2,A2M,sp|P01023|A2MG_HUMAN,Homo sapiens,MGKNKLLHPSLVLLLLVLLPTDASVSGKPQYMVLVPSLLHTETTEK...
3,A3GALT2,sp|U3KPV4|A3LT2_HUMAN,Homo sapiens,MALKEGLRAWKRIFWRQILLTLGLLGLFLYGLPKFRHLEALIPMGV...
4,A4GALT,sp|Q9NPC4|A4GAT_HUMAN,Homo sapiens,MSKPPDLLLRLLRGAPRQRVCTLFIIGFKFTFFVSIMIYWHVVGEP...


In [14]:
# Bundle the sequences, their embeddings and the metadata table into one dict
# and serialize it with torch.save.
# NOTE(review): the dict holds NumPy arrays and a pandas DataFrame rather than
# tensors, so reading it back with torch.load presumably requires
# weights_only=False on recent PyTorch versions — confirm for the target version.
aas_dict = {'amino_acids':aas, 'embeddings':outputs, 'meta_df':gene2aa}
torch.save(aas_dict, '../../extdata/meta/aas_dict.pt')

In [17]:
# Inspect the entries whose gene name starts with 'MIR'.
gene2aa[lambda x: x.gene_name.str.startswith('MIR')]

Unnamed: 0,gene_name,id,organism,sequence
6582,MIR1-1HG,sp|Q9H1L0|MI1HG_HUMAN,Homo sapiens,MPSCSCALMAPCGPAAGPAAVERTQQVARGEPGSARGQLQVSPEMS...
6583,MIR17HG,sp|Q75NE6|MIRH1_HUMAN,Homo sapiens,MFCHVDVKISSKRYTWTKLPLNVPKLVLIYLQSHFVLFFFSMCQSI...
6584,MIR22HG,sp|Q0VDD5|CQ091_HUMAN,Homo sapiens,MGWEGPNSRVDDTFWASWRAFAQIGPARSGFRLETLAGLRSRRLKQ...
