<a href="https://colab.research.google.com/github/kangjunseo/DeepLearning/blob/main/S_Pred/S_pred_ASA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# S_pred_ASA
공부용 코드라 주석이 필요 이상으로 많을 수 있습니다.
난잡해 보일 수는 있지만, 주석만으로도 논문의 주요 흐름을 파악할 수 있을 정도로 자세히 적어놓았습니다.

## Installing Requirements

In [None]:
!pip install einops fair-esm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


## Imports

In [None]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import esm
import json
import numpy as np
import argparse
from einops import rearrange
import string

In [None]:
# Mount Google Drive at /content/drive (Colab only); the paths used later in
# this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Defining models and functions

In [None]:
# Human-readable name of the predicted property.
PROTEIN_PROPERTY = "accessible surface area"
# The network output is multiplied by this factor to obtain the final ASA values
# (i.e. the model works on ASA divided by 200).
ASA_NORM_SCALE = 200

# Upper bound on the number of MSA rows; larger requests are capped to this value.
MAX_MSA_ROW_NUM = 256
# Upper bound on the number of token columns fed to the MSA Transformer;
# longer inputs are truncated to this many columns (they are not skipped).
MAX_MSA_COL_NUM = 1023

# Inference only: disable autograd globally to save memory and time.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f55d71947c0>

In [None]:
class lstm_net(nn.Module):
    """Prediction head: MLP projection -> 2-layer bidirectional LSTM -> per-position output.

    Args:
        input_feature_size: size of the last dimension of the query embeddings.
        hidden_node: hidden size of each LSTM direction.
        dropout: dropout probability between the two stacked LSTM layers.
        need_row_attention: if True, pooled row-attention features are
            concatenated to the projected embeddings before the LSTM.
        class_num: size of the last dimension of the output.
    """

    def __init__(self, input_feature_size=768, hidden_node=256, dropout=0.25, need_row_attention=False, class_num=8):
        super().__init__()
        self.need_row_attention = need_row_attention

        # NOTE(review): forward() feeds (batch, length, channels) tensors to the
        # InstanceNorm1d layers, whereas InstanceNorm1d documents its input as
        # (N, C, L). Left unchanged on purpose; confirm this is intended.
        self.linear_proj = nn.Sequential(  # MLP applied to the MSA query embeddings
            nn.Linear(input_feature_size, input_feature_size // 2),  # 768 -> 384 with defaults
            nn.InstanceNorm1d(input_feature_size // 2),
            nn.ReLU(),
            nn.Linear(input_feature_size // 2, input_feature_size // 4),  # 384 -> 192
            nn.InstanceNorm1d(input_feature_size // 4),
            nn.ReLU(),
            nn.Linear(input_feature_size // 4, input_feature_size // 4),  # 192 -> 192
        )

        if self.need_row_attention:
            # 144 channels for each of the two pooled attention features
            # (presumably 12 layers x 12 heads of the MSA Transformer; confirm).
            lstm_input_feature_size = input_feature_size // 4 + 144 * 2
        else:
            lstm_input_feature_size = input_feature_size // 4

        self.lstm = nn.LSTM(
            input_size=lstm_input_feature_size,
            hidden_size=hidden_node,
            num_layers=2,
            bidirectional=True,
            dropout=dropout,
        )

        self.to_property = nn.Sequential(  # final output layers
            nn.Linear(hidden_node * 2, hidden_node * 2),  # * 2: both LSTM directions
            nn.InstanceNorm1d(hidden_node * 2),
            nn.ReLU(),
            nn.Linear(hidden_node * 2, class_num),
        )

    # forward() must be defined at class level. When nested inside __init__ it is
    # only a local function, and calling the module raises NotImplementedError.
    def forward(self, msa_query_embeddings, msa_row_attentions):
        """Predict one `class_num`-sized vector per sequence position.

        Args:
            msa_query_embeddings: tensor of shape (batch, length, input_feature_size).
            msa_row_attentions: 5-D tensor (batch, layers, heads, length, length);
                only read when `need_row_attention` is True.

        Returns:
            Tensor of shape (batch, length, class_num).
        """
        msa_query_embeddings = self.linear_proj(msa_query_embeddings)

        if self.need_row_attention:
            # Merge the layer and head axes into a single channel axis.
            msa_row_attentions = rearrange(msa_row_attentions, 'b l h i j -> b (l h) i j')
            # Mean-pool the attention maps over each of the two length axes and
            # stack both results along the channel axis.
            msa_attention_features = torch.cat(
                (torch.mean(msa_row_attentions, dim=2), torch.mean(msa_row_attentions, dim=3)),
                dim=1,
            )
            # (batch, channels, length) -> (batch, length, channels)
            msa_attention_features = msa_attention_features.permute((0, 2, 1))

            lstm_input = torch.cat([msa_query_embeddings, msa_attention_features], dim=2)
        else:
            lstm_input = msa_query_embeddings

        # nn.LSTM is not batch_first: go to (length, batch, features) and back.
        lstm_input = lstm_input.permute((1, 0, 2))
        lstm_output, lstm_hidden = self.lstm(lstm_input)
        lstm_output = lstm_output.permute((1, 0, 2))
        label_output = self.to_property(lstm_output)

        return label_output

In [None]:
def read_msa_file(filepath, msa_row_num):
    """Read an a3m/FASTA-style alignment and return its first rows and the query sequence.

    Every line starting with ">" opens a new record; the following non-empty
    lines (until the next header) are concatenated into that record's sequence.
    Lower-case letters are removed from the sequences. Blank lines are ignored.

    Args:
        filepath: path of the alignment file.
        msa_row_num: number of rows to keep; capped at MAX_MSA_ROW_NUM.

    Returns:
        A tuple (seqs, query_seq). `seqs` is a list of [label, sequence] pairs
        and `query_seq` is the sequence string of the first record.

    Raises:
        ValueError: if the file contains no record.
    """
    lowercase_remover = str.maketrans(dict.fromkeys(string.ascii_lowercase))

    seqs = []
    with open(filepath, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            if line.startswith(">"):
                # Store the label without its trailing newline.
                seqs.append([line, ""])
            elif seqs:
                seqs[-1][1] += line.translate(lowercase_remover)

    if not seqs:
        raise ValueError(f"no sequence record found in {filepath}")

    if msa_row_num > MAX_MSA_ROW_NUM:
        msa_row_num = MAX_MSA_ROW_NUM
        print(f"The MSA row num is larger than {MAX_MSA_ROW_NUM}. This program force the msa row to under {MAX_MSA_ROW_NUM}")

    seqs = seqs[: msa_row_num]
    # Return the query *sequence*, not the [label, sequence] pair, so that
    # len(query_seq) is the number of residues.
    return seqs, seqs[0][1]

In [None]:
def extract_msa_transformer_features(msa_seq, msa_transformer, msa_batch_converter, device=torch.device("cpu")):
    """Run the MSA Transformer on one MSA and return the features of its first row.

    Args:
        msa_seq: the MSA, passed to `msa_batch_converter` as a one-element batch.
        msa_transformer: callable returning a dict with the keys
            'representations' and 'row_attentions'.
        msa_batch_converter: callable returning (labels, strings, tokens).
        device: device the token tensor is moved to before the forward pass.

    Returns:
        (query_representation, row_attentions): the layer-12 representation of
        row 0 and the row attentions, both with index 0 of the column axes
        (the start token) dropped.
    """
    labels, _strings, tokens = msa_batch_converter([msa_seq])
    tokens = tokens.to(device)
    row_count = tokens.shape[1]
    col_count = tokens.shape[2]
    print(f"{labels[0][0]}, msa_row: {row_count}, msa_col: {col_count}")

    too_wide = col_count > MAX_MSA_COL_NUM
    if too_wide:
        print(f"msa col num should less than {MAX_MSA_COL_NUM}. This program force the msa col to under {MAX_MSA_COL_NUM}")
    # Slicing is a no-op when the input is already narrow enough.
    tokens = tokens[:, :, :MAX_MSA_COL_NUM]

    # Output keys: 'logits', 'representations', 'col_attentions', 'row_attentions', 'contacts'
    outputs = msa_transformer(
        tokens,
        repr_layers=[12],
        need_head_weights=True,
        return_contacts=True,
    )

    # Keep only row 0 of the MSA and skip the start-token column.
    query_representation = outputs['representations'][12][:, 0, 1:, :]
    # Skip the start token on both trailing axes of the attention maps.
    row_attentions = outputs['row_attentions'][..., 1:, 1:]

    return query_representation, row_attentions


In [None]:
def save_property_to_json(out_property_json, output_property, query_seq):
    """Write the predicted values and the query sequence to a JSON file.

    Args:
        out_property_json: path of the JSON file to create or overwrite.
        output_property: object whose `.tolist()` yields the per-position values.
        query_seq: stored unchanged under the "query_seq" key.
    """
    # NOTE(review): round(x) yields integers although the metadata below
    # declares "precision": 4; confirm which of the two is intended.
    rounded_values = list(map(round, output_property.tolist()))

    metadata = {
        "precision": 4,
        "title": "accessible surface area",
        "data-min": 0,
        "data-max": 200,
    }
    payload = {
        "asa_data": rounded_values,
        "query_seq": query_seq,
        "metadata": metadata,
    }

    with open(out_property_json, "w") as json_file:
        json_file.write(json.dumps(payload, indent=4))

## Main code

In [None]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU
# and say so.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    print("gpu is not available, run on cpu")


In [None]:
# Locations on the mounted Google Drive (all under MyDrive/S_pred).
# MSA file that is handed to read_msa_file.
input_path = '/content/drive/MyDrive/S_pred/examples/s_pred_asa.a3m'
# Presumably the base path of the result files; verify against the saving cell.
output_path = '/content/drive/MyDrive/S_pred/s_pred_asa.out'
# Checkpoint file that is handed to torch.load.
conv_model_path = '/content/drive/MyDrive/S_pred/s_pred_asa_weights.pth'

In [None]:
# Fetch the pretrained MSA Transformer together with its alphabet, place the
# model on the selected device and derive the batch converter from the alphabet.
msa_transformer, msa_alphabet = esm.pretrained.esm_msa1_t12_100M_UR50S()
msa_transformer = msa_transformer.to(device)
msa_batch_converter = msa_alphabet.get_batch_converter()

# Inference mode (e.g. disables dropout).
msa_transformer.eval()

In [None]:
# Build the prediction head with row-attention features enabled and move it to
# the selected device in one step.
conv_model = lstm_net(
    input_feature_size=768,
    hidden_node=256,
    dropout=0.25,
    need_row_attention=True,
    class_num=1,  # one output value per position for ASA
).to(device)

In [None]:
# Map every tensor of the checkpoint onto the selected device, regardless of
# the device it was saved from. This replaces the separate CPU/GPU branches.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
ch = torch.load(conv_model_path, map_location=device)

In [None]:
# Restore the trained head weights stored under the 'conv_model' key, then put
# the head on the device and in inference mode.
conv_model.load_state_dict(ch['conv_model'])
conv_model = conv_model.to(device).eval()

# Freeze every parameter of both networks.
for frozen_net in (msa_transformer, conv_model):
    for param in frozen_net.parameters():
        param.requires_grad = False

In [None]:
# Load at most 256 rows of the alignment and report the sizes of both
# returned values.
msa_seq, query_seq = read_msa_file(input_path, 256)
msa_row_num, msa_col_num = len(msa_seq), len(query_seq)

for size_report in (
    f"msa row number: {msa_row_num}",
    f"msa column number: {msa_col_num}",
):
    print(size_report)

msa row number: 256
msa column number: 2


In [None]:
msa_query_representation, msa_row_attentions = extract_msa_transformer_features(
    msa_seq,
    msa_transformer,
    msa_batch_converter,
    device=device,
)
# Tensor.to() returns a tensor instead of moving it in place, so the result
# has to be rebound. Assigning also keeps the cell from dumping a large tensor.
msa_query_representation = msa_query_representation.to(device)
msa_row_attentions = msa_row_attentions.to(device)

>5G6UA
, msa_row: 256, msa_col: 262


OutOfMemoryError: ignored

In [None]:
# Tensor.to() is not in-place: rebind so both inputs are on `device`.
msa_query_representation = msa_query_representation.to(device)
msa_row_attentions = msa_row_attentions.to(device)

# Run the prediction head and drop the trailing size-1 class axis.
output_property = conv_model(msa_query_representation, msa_row_attentions)
output_property = output_property.squeeze(dim=2)

output_property_np = output_property.detach().cpu().numpy().squeeze()
output_property_np[output_property_np < 0] = 0  # clamp negative predictions to 0
output_property_np = output_property_np * ASA_NORM_SCALE  # undo the normalisation

# `output_path` is the variable defined in this notebook; there is no `args`
# namespace here, so `args.output_path` raised a NameError.
output_property_json_path = output_path + '.asa.json'
save_property_to_json(output_property_json_path, output_property_np, query_seq)

NameError: ignored