In [116]:
import os
import pandas as pd
import numpy as np

import torch
import esm

import re
import gc

from utils.common import load_tab, save_np, load_np, read_json2list
from utils.alignmentParser import read_msa, greedy_select
# Disable autograd globally: the visible cells only run the model for inference.
# The trailing ';' suppresses the noisy repr of the returned context-manager object.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x714c417a1490>

In [2]:
import params as p

from Bio import SeqIO
from utils.common import dump_list2json
# Resolve the paths used throughout the notebook from the project's params module.
(path_input_dataset_json,
 path_output_features_msaTrans,
 path_hmm) = (p.path_input_dataset_json, p.path_output_features_msaTrans, p.path_hmm)

In [3]:
# get FASTA file: turn every record into an {'id', 'sequence'} dict.
# The `with` block closes the file handle once parsing is done; the list comprehension
# consumes the lazy SeqIO iterator while the file is still open.
with open(p.path_input) as fasta_handle:
    list_entity = [
        {'id': record.id, 'sequence': str(record.seq)}
        for record in SeqIO.parse(fasta_handle, 'fasta')
    ]

# save JSON file
dump_list2json(list_entity, path_input_dataset_json)
print(f'input JSON file is created under: {path_input_dataset_json}')

input JSON file is created under: /home/dimeng/caid3/embedding/data/dataset.json


In [52]:
# Scratch run of the MSA Transformer on a single protein.
# Load the model here if no earlier cell has done so, so this cell survives Restart & Run All.
if 'msa_transformer' not in globals() or 'msa_transformer_batch_converter' not in globals():
    msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    msa_transformer = msa_transformer.eval()
    msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()

name = 'DP02585'
print(name)
# This is where the data is actually read in
inputs = read_msa(os.path.join(path_hmm, f'{name}.a3m'))

inputs = greedy_select(inputs, num_seqs=128) # can change this to pass more/fewer sequences
msa_transformer_batch_labels, msa_transformer_batch_strs, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
# Run on whichever device the model's parameters live on (CPU or GPU).
msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
# Extract per-residue representations from layer 12
with torch.no_grad():
    results = msa_transformer(msa_transformer_batch_tokens, repr_layers=[12], return_contacts=False)
token_representations = results["representations"][12]
# Drop token column 0 (presumably the BOS token added by the batch converter) and average over dim 1.
seq_representation = token_representations[:, :, 1: ].mean(1)

DP02585


In [53]:
# Sanity-check the embedding shape; the output below is (1, number of residues, 768).
seq_representation.shape

torch.Size([1, 591, 768])

In [60]:
# Scratch: stack the embedding twice along axis 1 to mimic a longer sequence as a numpy array.
seq_rep = np.concatenate([seq_representation.detach().cpu().numpy()] * 2, axis=1)

In [65]:
# Scratch: preallocate an uninitialised (1, length, 768) float buffer.
# NOTE(review): `length` is not assigned in any cell above this one (the long-sequence loop
# further down sets it), so this cell fails on a fresh kernel; `n_empty` is not used by any
# visible cell — consider deleting it.
n_empty = np.empty([1, length, 768])

In [119]:
# Scratch: presumably a check of the overlap-averaging formula on the last two positions
# and first five features.
# NOTE(review): both operands are the same slice, so the result simply equals seq_rep[:, -2:, :5].
(seq_rep[:, -2:, :5] + seq_rep[:, -2:, :5])/2

array([[[-0.20227243, -0.11333056,  0.2251099 ,  0.24856019,
         -1.0699196 ],
        [-0.5465641 , -0.15934734,  0.3027207 ,  0.21070966,
         -1.1074893 ]]], dtype=float32)

In [125]:
# Scratch: the bounds 3*2-2 and 3*5-2 evaluate to 4 and 13, i.e. positions 4..12 of feature 0.
seq_rep[:, 3*2-2:3*5-2, 0]

array([[ 0.18389681,  0.13333188, -0.43770298, -0.78563946,  0.17937823,
        -0.6691764 ,  0.24237913,  0.26934707, -0.84384835]],
      dtype=float32)

In [126]:
# Same slice with explicit parentheses; the output matches the unparenthesised form (4:13).
seq_rep[:, (3*2-2):(3*5-2), 0]

array([[ 0.18389681,  0.13333188, -0.43770298, -0.78563946,  0.17937823,
        -0.6691764 ,  0.24237913,  0.26934707, -0.84384835]],
      dtype=float32)

In [112]:
def _msaTrans_long(length: int, name: str, path_hmm: str):
    '''
    params:
        length - int, sequence length
        name - protein ID, for reading NAME.a3m file
        path_hmm - folder dir to .a3m files

    return:
        embed_seq - numpy array, shape: (1,length ,768), where 768 is number of features
    '''
    embed_seq = np.empty([1, length, 768])
    
    # load the long sequence
    inputs = read_msa(os.path.join(path_hmm, f'{name}.a3m'))
    inputs = greedy_select(inputs, num_seqs=128)
    _, _, msa_transformer_batch_tokens = msa_transformer_batch_converter([inputs])
    msa_transformer_batch_tokens = msa_transformer_batch_tokens.to(next(msa_transformer.parameters()).device)
    # separate the long sequence
    for i in range(length//500):
        # tail
        if (i+2)*500>=length:
            batch_token = msa_transformer_batch_tokens[:, :, [0]+list(range(i*500+1, length+1))]
        else:
            batch_token = msa_transformer_batch_tokens[:, :, [0]+list(range(i*500+1, i*500+1000+1))]
        with torch.no_grad():
            results = msa_transformer(batch_token, repr_layers=[12], return_contacts=False)
        token_representations = results["representations"][12]
        seq_representation = token_representations[:, :, 1: ].mean(1)
        seq_rep = seq_representation.detach().cpu().numpy()
        # head
        if i==0:
            embed_seq = seq_rep
        else:
            seq_rep[:, 50:450, :] = (embed_seq[:, -450:-50, :] + seq_rep[:, 50:450, :])/2
            # Here!!!!!!
            embed_seq = np.concatenate((embed_seq[:, :-450, :], seq_rep[:, 50:, :]), axis=1) 
    return embed_seq

In [113]:
# Load the dataset, add sequence lengths, and keep only very long sequences for this experiment.
# NOTE(review): msaTrans treats anything above 1022 residues as "long"; 20000 here looks
# like a debugging value — confirm which threshold is intended.
df_dataset = (
    pd.DataFrame(read_json2list(path_input_dataset_json))
    .assign(length=lambda frame: frame['sequence'].map(len))
    .loc[lambda frame: frame['length'] > 20000]
)

In [114]:
# for sequence length greater than 1022(1022 + start/end tokens)
for idx, row in df_dataset.iterrows():
    length = row['length']
    name = row['id']
    embed_seq = _msaTrans_long(length, name, path_hmm)
    # save embedd sequences
    save_np(embed_seq, os.path.join(path_output_features_msaTrans, f'{name}.npy'))

In [115]:
# Rows/columns left after the length filter; (0, 3) means no sequence exceeds the threshold.
df_dataset.shape

(0, 3)

In [None]:
def msaTrans(path_input_dataset_json, path_output_features_msaTrans, path_hmm):
    '''
    Embed every sequence of the dataset with the MSA Transformer and save one NAME.npy per protein.

    params:
        path_input_dataset_json - JSON list of records with 'id' and 'sequence' keys
        path_output_features_msaTrans - folder where NAME.npy embeddings are written
        path_hmm - folder dir to NAME.a3m alignment files

    Sequences of at most 1022 residues are embedded in a single pass; longer ones are
    delegated to `_msaTrans_long` (overlapping windows). Each saved array is the layer-12
    representation averaged over dimension 1 (the MSA rows), as a CPU numpy array.
    '''
    # `_msaTrans_long` looks the model and batch converter up at module level, so the objects
    # loaded here must be published as globals; as plain locals, every long sequence would
    # raise NameError.
    global msa_transformer, msa_transformer_batch_converter

    # 1. model & tokenizer
    msa_transformer, msa_transformer_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()
    msa_transformer = msa_transformer.eval()
    msa_transformer_batch_converter = msa_transformer_alphabet.get_batch_converter()
    device = next(msa_transformer.parameters()).device

    # 2. predict & save embedded sequences
    df_dataset = pd.DataFrame(read_json2list(path_input_dataset_json))
    df_dataset['length'] = [len(seq) for seq in df_dataset['sequence']]
    is_long = df_dataset['length'] > 1022
    df_dataset_long = df_dataset[is_long]
    df_dataset_short = df_dataset[~is_long]

    # 2.1. short sequences: one forward pass over the whole alignment.
    # set() de-duplicates IDs so each protein is embedded once.
    for name in list(set(df_dataset_short['id'].tolist())):
        print(name)
        # This is where the data is actually read in
        inputs = read_msa(os.path.join(path_hmm, f'{name}.a3m'))
        inputs = greedy_select(inputs, num_seqs=128)  # can change this to pass more/fewer sequences
        _, _, batch_tokens = msa_transformer_batch_converter([inputs])
        batch_tokens = batch_tokens.to(device)
        with torch.no_grad():
            results = msa_transformer(batch_tokens, repr_layers=[12], return_contacts=False)
        token_representations = results["representations"][12]
        seq_representation = token_representations[:, :, 1:].mean(1)
        # Save a CPU numpy array, the same type `_msaTrans_long` produces for long sequences.
        save_np(seq_representation.detach().cpu().numpy(),
                os.path.join(path_output_features_msaTrans, f'{name}.npy'))

    # 2.2. long sequences (more than 1022 residues): sliding-window embedding
    for _, row in df_dataset_long.iterrows():
        name = row['id']
        embed_seq = _msaTrans_long(row['length'], name, path_hmm)
        save_np(embed_seq, os.path.join(path_output_features_msaTrans, f'{name}.npy'))
    print('Done!!!')