In [None]:
import pandas as pd
import numpy as np
import torch
import os 
import sys
from Bio import SeqIO
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, SamplingConfig
from esm.utils.constants.models import ESM3_OPEN_SMALL
from torch_geometric.data import Dataset
from tqdm import tqdm
from huggingface_hub import login
import torch, gc

# Authenticate with the Hugging Face Hub (presumably required for the gated
# ESM3 weights fetched by ESM3.from_pretrained below).
# A token in the HF_TOKEN environment variable lets "Restart & Run All" run
# unattended; when it is unset, login(token=None) falls back to the same
# interactive prompt as a bare login().
login(token=os.environ.get("HF_TOKEN"))

class StringDB_Dataset(Dataset):
    """Pairs a STRING/STITCH links table with a FASTA file of protein sequences.

    Holds the path to a whitespace-delimited links table (first line is a
    header) and to a FASTA file whose records are embedded with ESM3.

    NOTE(review): this subclasses ``torch_geometric.data.Dataset`` but never
    calls ``super().__init__()`` and does not implement ``len``/``get``, so it
    is not usable as a PyG dataset yet. Confirm the intended design.
    """

    def __init__(self, data_path, fasta_path, esm_model=ESM3_OPEN_SMALL):
        """Store the input paths and read the table header.

        Args:
            data_path: path to the links table. Its first line is read into
                ``self.data_cols``.
            fasta_path: path to the FASTA file embedded by ``get_esm_embeddings``.
            esm_model: model name passed unchanged to ``ESM3.from_pretrained``.
                Defaults to the imported ``ESM3_OPEN_SMALL`` name constant. The
                literal string ``'ESM3_OPEN_SMALL'`` is not a model name.
                NOTE(review): loading a local ``.pt`` checkpoint is not handled
                here; this value only ever reaches ``from_pretrained``.

        Raises:
            FileNotFoundError: if ``data_path`` does not exist.
        """
        self.data_path = data_path
        self.fasta_path = fasta_path
        self.esm_model = esm_model
        self.data_cols = self.get_column_names()
        # TODO: cleaning steps (dropna, drop_duplicates, dropping 'Unnamed: 0',
        # reset_index) were sketched here against a ``self.data`` attribute that
        # is never assigned; apply them to the frame returned by ``load_data``.

    def get_column_names(self):
        """Return the header fields of ``self.data_path``.

        The header is split on any run of whitespace, so both the tab-separated
        STITCH table and the space-separated STRING table are handled.

        Raises:
            FileNotFoundError: if ``self.data_path`` does not exist.
        """
        if not os.path.exists(self.data_path):
            # Raise rather than silently returning None for a bad path.
            raise FileNotFoundError(
                f'File not found, check filepath: {self.data_path}'
            )
        with open(self.data_path, encoding='utf-8') as f:
            return f.readline().split()

    def get_esm_embeddings(self, out_dir, cleanup_every=50, skip_existing=False):
        """Write one ESM3 mean embedding per FASTA record to ``out_dir``.

        For every record in ``self.fasta_path`` the file
        ``<out_dir>/<record.id>.pt`` is written with ``torch.save``. It holds
        ``{"label": record.id, "embeddings": output.mean_embedding}`` from
        ``forward_and_sample(..., SamplingConfig(return_mean_embedding=True))``.
        A tensor embedding is detached and moved to CPU before saving.

        Args:
            out_dir: output directory, created if missing.
            cleanup_every: run ``gc.collect()`` and ``torch.cuda.empty_cache()``
                every this many records (default 50).
            skip_existing: if True, records whose output file already exists are
                not recomputed, which makes an interrupted run resumable.
        """
        os.makedirs(out_dir, exist_ok=True)
        client = ESM3.from_pretrained(self.esm_model)
        client.eval()  # inference only
        # no_grad: no autograd graph is needed for embedding extraction.
        with open(self.fasta_path) as handle, torch.no_grad():
            for i, record in enumerate(tqdm(SeqIO.parse(handle, "fasta"))):
                output_file = os.path.join(out_dir, f"{record.id}.pt")
                if skip_existing and os.path.exists(output_file):
                    continue
                protein = ESMProtein(sequence=str(record.seq))
                protein_tensor = client.encode(protein)
                output = client.forward_and_sample(
                    protein_tensor, SamplingConfig(return_mean_embedding=True)
                )
                embedding = output.mean_embedding
                if isinstance(embedding, torch.Tensor):
                    # Save a CPU copy so the files load on machines without a GPU.
                    embedding = embedding.detach().cpu()
                torch.save({"label": record.id, "embeddings": embedding}, output_file)
                if i % cleanup_every == 0:
                    # Periodically release cached GPU memory during long runs.
                    gc.collect()
                    torch.cuda.empty_cache()

    def load_data(self, data_path=None, sep=r'\s+'):
        """Load a links table as a DataFrame.

        Args:
            data_path: file to read; defaults to ``self.data_path``.
            sep: separator passed to ``pd.read_csv``. The default splits on any
                whitespace, matching both the tab-separated STITCH table and the
                space-separated STRING table. pandas' default comma would read
                either file as a single column.

        Returns:
            The parsed ``pandas.DataFrame`` (header taken from the first line).

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        if data_path is None:
            data_path = self.data_path
        if not os.path.exists(data_path):
            raise FileNotFoundError(f'File not found, check filepath: {data_path}')
        return pd.read_csv(data_path, sep=sep)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# Input/output locations for this run; change DATA_DIR to run on another machine.
DATA_DIR = '/mnt/d/data'
path = os.path.join(DATA_DIR, 'stitchdb.txt')  # chemical-protein links (tab-separated)
fasta_path = os.path.join(DATA_DIR, 'stringdb_seq.fa')  # sequences to embed
# NOTE(review): local_model is not read by StringDB_Dataset, which loads weights
# via ESM3.from_pretrained(<model name>); kept in case other cells refer to it.
local_model = os.path.join(DATA_DIR, 'esm3_sm_open_v1.pt')
out_dir = os.path.join(DATA_DIR, 'stringdb_embeddings')  # one <record id>.pt per sequence

# Use the imported model-name constant instead of re-typing its value.
data = StringDB_Dataset(path, fasta_path, esm_model=ESM3_OPEN_SMALL)
data.get_esm_embeddings(out_dir)

Fetching 22 files:   0%|          | 0/22 [00:00<?, ?it/s]

  state_dict = torch.load(
2495it [04:02, 12.42it/s]

In [None]:
# Peek at the first 10 lines (header + 9 rows) of the STITCH links table.
with open(path, 'r') as stitch_file:
    for line in (stitch_file.readline() for _ in range(10)):
        print(line.strip())

chemical	protein	experimental_direct	experimental_transferred	prediction_direct	prediction_transferred	database_direct	database_transferred	textmining_direct	textmining_transferred	combined_score
CIDm91758680	190486.XAC0787	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC0788	0	0	0	0	0	0	0	187	187
CIDm91758680	190486.XAC1728	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC1855	0	0	0	0	0	0	0	210	210
CIDm91758680	190486.XAC2361	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC2462	0	0	0	0	0	0	0	173	173
CIDm91758680	190486.XAC2928	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC3041	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC3368	0	0	0	0	0	0	0	161	161


In [None]:
# Peek at the first 10 lines (header + 9 rows) of the STRING protein-protein table.
# NOTE: this rebinds `path`, which previously pointed at the STITCH table.
path = '/mnt/d/data/stringdb.txt'
with open(path, 'r') as string_file:
    preview = [string_file.readline().strip() for _ in range(10)]
print('\n'.join(preview))

protein1 protein2 homology experiments experiments_transferred database database_transferred textmining textmining_transferred combined_score
23.BEL05_00025 23.BEL05_06890 0 0 738 0 194 0 0 779
23.BEL05_00025 23.BEL05_19855 0 0 264 0 0 0 0 264
23.BEL05_00025 23.BEL05_17340 0 0 0 0 134 0 64 154
23.BEL05_00025 23.BEL05_06420 0 0 597 0 194 0 66 670
23.BEL05_00030 23.BEL05_09555 0 0 208 0 0 0 0 208
23.BEL05_00030 23.BEL05_04075 0 0 317 0 0 0 0 317
23.BEL05_00030 23.BEL05_05440 0 0 270 0 0 0 0 270
23.BEL05_00035 23.BEL05_00165 0 0 49 0 229 0 0 235
23.BEL05_00035 23.BEL05_09055 0 0 47 0 339 0 0 343


In [None]:
# Inspect the first two FASTA records.
# Use ONE parser for both: creating a second SeqIO.parse() on the same, partly
# consumed handle restarts mid-file, and because the first parser may already
# have read past the next header line, a record can be silently skipped.
with open(fasta_path) as handle:
    records = SeqIO.parse(handle, "fasta")
    fa = next(records)
    print(fa.seq)
    print(type(str(fa.seq)))
    print(next(records))

MSLPRCNSYYNATINQETDFDQLQGEVDVDVVIIGGGFTGVATAVELSEQGYRVAIVEANKIGWGATGRNGGQVTGSLSGDGAMTKQLRNQIGSEAEAFVWNLRWRGHDIIKNRVAKYGIDCDLKFGHLHTAYKFAHMGEMQKTFDEGVNRGMGDELILLSKADIPQYLDTPLYHGGLLNKRNMHLHSVNLCIGEARAAVGNGAQIFEHSSVLDIIEGDRPVVKTAKGQITANSVVLAGNAYHKLARKKLSGLLFPASLGNCATVKLDSALAKQLNPHDVAVYDSRFVLDYYRMTADHRLMFGGGTNYSGRDSKDVAAELRPALERTFPQLKGVEIEFDWTGMAGIVVNRIPQLGKVSPNVFYCQGYSGHGVATSHIMGEIMAAAVVGQHKEFDLFANMKQIRLPVGEWLGNQGMAIGMLYYRMMENFR
<class 'str'>
ID: 23.BEL05_00035
Name: 23.BEL05_00035
Description: 23.BEL05_00035
Number of features: 0
Seq('MNLTIIVGVIAVLYVSLLFLLAWGAERWFGGITKKIQTWIYGLSLAVYCSSWSF...LVK')
