In [None]:
import pandas as pd
import numpy as np
import torch
import os 
import sys
from Bio import SeqIO
from esm import pretrained
from esm import data
from torch_geometric.data import Dataset

class StringDB_Dataset(Dataset):
    """Delimited interaction table (STRING / STITCH dump) plus ESM model settings.

    Parameters
    ----------
    data_path : str
        Path to a delimited interaction table whose first line is a header row.
    fasta_path : str or None
        Path to a FASTA file of protein sequences. Only stored here; it is read
        by the disabled ESM embedding draft further down.
    esm_model : str, default 'ESM3_OPEN_SMALL'
        Pretrained model name or local checkpoint path. A local model path
        must end in ``.pt``.
    sep : str, default is a tab character
        Column delimiter of ``data_path``. The STITCH dump is tab-separated;
        the STRING dump previewed in this notebook is space-separated, so use
        ``sep=' '`` for it.

    Raises
    ------
    FileNotFoundError
        If ``data_path`` does not exist.
    """

    # NOTE(review): the torch_geometric Dataset base __init__ is not called, so
    # base-class attributes are presumably unset; confirm before using this
    # class with PyG loaders (len()/get() are also not implemented yet).
    def __init__(self, data_path, fasta_path, esm_model='ESM3_OPEN_SMALL', sep='\t'):
        self.data_path = data_path
        self.fasta_path = fasta_path
        # esm_model must end in .pt for a local model.
        self.esm_model = esm_model
        self.sep = sep
        self.data_cols = self.get_column_names()
        # TODO: load the table via load_data() into self.data and clean it
        # (dropna, drop_duplicates, drop the 'Unnamed: 0' column, reset_index).

    def get_column_names(self):
        """Return the header row of ``self.data_path`` split on ``self.sep``.

        Raises
        ------
        FileNotFoundError
            If ``self.data_path`` does not exist.
        """
        try:
            with open(self.data_path) as f:
                header = f.readline()
        except FileNotFoundError as err:
            raise FileNotFoundError(f'File not found, check filepath: {self.data_path}') from err
        return header.strip().split(self.sep)

    # NOTE(review): the string below is a disabled draft of an ESM embedding
    # extractor adapted from esm's scripts/extract.py. It is never executed and
    # would not run as written: it references an undefined `args` namespace
    # (toks_per_batch, truncation_seq_length, repr_layers, include, output_dir, ...).
    '''
    def get_esm_embeddings(self):
        # Adapted from https://github.com/facebookresearch/esm/blob/main/scripts/extract.py
        model, alphabet = pretrained.load_model_and_alphabet(self.esm_model,device='cuda')
        model.eval()
        dataset = data.FastaBatchedDataset.from_file(self.fasta_path)
        batches = dataset.get_batch_indices(args.toks_per_batch, extra_toks_per_seq=1)
        data_loader = torch.utils.data.DataLoader(
            dataset, collate_fn=alphabet.get_batch_converter(args.truncation_seq_length), batch_sampler=batches)
        print(f"Read {args.fasta_file} with {len(dataset)} sequences")

        args.output_dir.mkdir(parents=True, exist_ok=True)
        return_contacts = "contacts" in args.include

        assert all(-(model.num_layers + 1) <= i <= model.num_layers for i in args.repr_layers)
        repr_layers = [(i + model.num_layers + 1) % (model.num_layers + 1) for i in args.repr_layers]

        with torch.no_grad():
            for batch_idx, (labels, strs, toks) in enumerate(data_loader):
                print(
                    f"Processing {batch_idx + 1} of {len(batches)} batches ({toks.size(0)} sequences)"
                )
                if torch.cuda.is_available() and not args.nogpu:
                    toks = toks.to(device="cuda", non_blocking=True)

                out = model(toks, repr_layers=repr_layers, return_contacts=return_contacts)

                logits = out["logits"].to(device="cpu")
                representations = {
                    layer: t.to(device="cpu") for layer, t in out["representations"].items()
                }
                if return_contacts:
                    contacts = out["contacts"].to(device="cpu")

                for i, label in enumerate(labels):
                    args.output_file = args.output_dir / f"{label}.pt"
                    args.output_file.parent.mkdir(parents=True, exist_ok=True)
                    result = {"label": label}
                    truncate_len = min(args.truncation_seq_length, len(strs[i]))
                    # Call clone on tensors to ensure tensors are not views into a larger representation
                    # See https://github.com/pytorch/pytorch/issues/1995
                    if "per_tok" in args.include:
                        result["representations"] = {
                            layer: t[i, 1 : truncate_len + 1].clone()
                            for layer, t in representations.items()
                        }
                    if "mean" in args.include:
                        result["mean_representations"] = {
                            layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                            for layer, t in representations.items()
                        }
                    if "bos" in args.include:
                        result["bos_representations"] = {
                            layer: t[i, 0].clone() for layer, t in representations.items()
                        }
                    if return_contacts:
                        result["contacts"] = contacts[i, : truncate_len, : truncate_len].clone()

                    torch.save(
                        result,
                        args.output_file,
                    )
    '''

    def load_data(self, data_path, sep=None):
        """Read a delimited interaction table into a DataFrame.

        Parameters
        ----------
        data_path : str
            Path of the table to read (header row first).
        sep : str, optional
            Column delimiter. Defaults to ``self.sep`` so parsing matches
            ``get_column_names``.

        Returns
        -------
        pandas.DataFrame

        Raises
        ------
        FileNotFoundError
            If ``data_path`` does not exist.
        """
        if not os.path.exists(data_path):
            raise FileNotFoundError(f'File not found, check filepath: {data_path}')
        if sep is None:
            sep = self.sep
        return pd.read_csv(data_path, sep=sep)

In [26]:
# Build the STITCH (chemical-protein) dataset and show its header columns.
DATA_DIR = '/mnt/d/data'
path = os.path.join(DATA_DIR, 'stitchdb.txt')
local_model = os.path.join(DATA_DIR, 'esm3_sm_open_v1.pt')
# Pass the checkpoint by keyword: the second positional parameter is
# `fasta_path`, not `esm_model`.
# TODO: set fasta_path to the protein FASTA once it is available.
# Named `stitch_dataset` rather than `data` so the `esm.data` module stays usable.
stitch_dataset = StringDB_Dataset(path, fasta_path=None, esm_model=local_model)
stitch_dataset.data_cols

['chemical',
 'protein',
 'experimental_direct',
 'experimental_transferred',
 'prediction_direct',
 'prediction_transferred',
 'database_direct',
 'database_transferred',
 'textmining_direct',
 'textmining_transferred',
 'combined_score']

In [6]:
# Preview the first lines of the STITCH file (tab-separated, header row first).
with open(path, 'r') as file:
    # zip stops at EOF, so a short file does not print trailing blank lines.
    for _, line in zip(range(10), file):
        print(line.strip())

chemical	protein	experimental_direct	experimental_transferred	prediction_direct	prediction_transferred	database_direct	database_transferred	textmining_direct	textmining_transferred	combined_score
CIDm91758680	190486.XAC0787	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC0788	0	0	0	0	0	0	0	187	187
CIDm91758680	190486.XAC1728	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC1855	0	0	0	0	0	0	0	210	210
CIDm91758680	190486.XAC2361	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC2462	0	0	0	0	0	0	0	173	173
CIDm91758680	190486.XAC2928	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC3041	0	0	0	0	0	0	0	161	161
CIDm91758680	190486.XAC3368	0	0	0	0	0	0	0	161	161


In [7]:
# Preview the STRING (protein-protein) file. It is space-separated, unlike the
# tab-separated STITCH dump. Use a dedicated name instead of reassigning
# `path`, which earlier cells use for the STITCH file.
string_path = '/mnt/d/data/stringdb.txt'
with open(string_path, 'r') as file:
    # zip stops at EOF, so a short file does not print trailing blank lines.
    for _, line in zip(range(10), file):
        print(line.strip())

protein1 protein2 homology experiments experiments_transferred database database_transferred textmining textmining_transferred combined_score
23.BEL05_00025 23.BEL05_06890 0 0 738 0 194 0 0 779
23.BEL05_00025 23.BEL05_19855 0 0 264 0 0 0 0 264
23.BEL05_00025 23.BEL05_17340 0 0 0 0 134 0 64 154
23.BEL05_00025 23.BEL05_06420 0 0 597 0 194 0 66 670
23.BEL05_00030 23.BEL05_09555 0 0 208 0 0 0 0 208
23.BEL05_00030 23.BEL05_04075 0 0 317 0 0 0 0 317
23.BEL05_00030 23.BEL05_05440 0 0 270 0 0 0 0 270
23.BEL05_00035 23.BEL05_00165 0 0 49 0 229 0 0 235
23.BEL05_00035 23.BEL05_09055 0 0 47 0 339 0 0 343
