## Predict Tm using embeddings from ProteinMPNN
- Then see if combining esm2 embeddings with proteinMPNN will further improve things

### First download structures for each protein

In [1]:
import json, time, os, sys, glob, re
import pandas as pd
import numpy as np
from tqdm import tqdm
tqdm.pandas()

# git clone https://github.com/dauparas/ProteinMPNN.git"
# add to path
base_dir = "/projects/bpms/jlaw/tools/ProteinMPNN"
sys.path.append(base_dir)

In [2]:
# Setup Model
import matplotlib.pyplot as plt
import shutil
import warnings
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split, Subset
import copy
import torch.nn as nn
import torch.nn.functional as F
import random
import os.path
from protein_mpnn_utils import loss_nll, loss_smoothed, gather_edges, gather_nodes, gather_nodes_t, cat_neighbors_nodes, _scores, _S_to_seq, tied_featurize, parse_PDB
from protein_mpnn_utils import StructureDataset, StructureDatasetPDB, StructureLoader, ProteinMPNN

device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
#v_48_010=version with 48 edges 0.10A noise
model_name = "v_48_010" #@param ["v_48_002", "v_48_010", "v_48_020", "v_48_030"]

# Standard deviation of Gaussian noise to add to backbone atoms
backbone_noise=0.0  

path_to_model_weights = f'{base_dir}/vanilla_model_weights'          
hidden_dim = 128
num_layers = 3 
model_folder_path = path_to_model_weights
if model_folder_path[-1] != '/':
    model_folder_path = model_folder_path + '/'
checkpoint_path = model_folder_path + f'{model_name}.pt'

checkpoint = torch.load(checkpoint_path, map_location=device) 
print('Number of edges:', checkpoint['num_edges'])
noise_level_print = checkpoint['noise_level']
print(f'Training noise level: {noise_level_print}A')
model = ProteinMPNN(num_letters=21, node_features=hidden_dim, edge_features=hidden_dim, hidden_dim=hidden_dim, num_encoder_layers=num_layers, num_decoder_layers=num_layers, augment_eps=backbone_noise, k_neighbors=checkpoint['num_edges'])
model.to(device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("Model loaded")

Number of edges: 48
Training noise level: 0.1A
Model loaded


In [3]:
model

ProteinMPNN(
  (features): ProteinFeatures(
    (embeddings): PositionalEncodings(
      (linear): Linear(in_features=66, out_features=16, bias=True)
    )
    (edge_embedding): Linear(in_features=416, out_features=128, bias=False)
    (norm_edges): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (W_e): Linear(in_features=128, out_features=128, bias=True)
  (W_s): Embedding(21, 128)
  (encoder_layers): ModuleList(
    (0): EncLayer(
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
      (dropout3): Dropout(p=0.1, inplace=False)
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (W1): Linear(in_features=384, out_features=128, bias=True)
      (W2): Linear(in_features=128, out_features=128, bias=True)
      (W3): Linear(in_features=128, out_features=128, bias=True)
     

In [3]:
import numpy as np
import pandas as pd
import os
from Bio import SeqIO
import itertools
from typing import List, Tuple
import string
from pathlib import Path
from tqdm.auto import tqdm, trange

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [4]:
inputs_dir = Path("/projects/robustmicrob/jlaw/inputs/")

In [5]:
data_file = "/projects/robustmicrob/jlaw/inputs/meltome/flip/github/full_dataset_sequences.csv.gz"
data = pd.read_csv(data_file)
print(len(data))
data.head(2)

201283


Unnamed: 0,uniprot,run_name,Tm,sequence
0,A0A023T4K3,Caenorhabditis_elegans_lysate,37.962947,MSGEEEKAADFYVRYYVGHKGKFGHEFLEFEFRPNGSLRYANNSNY...
1,A0A023T778,Mus_musculus_BMDC_lysate,54.425342,MSMGSDFYLRYYVGHKGKFGHEFLEFEFRPDGKLRYANNSNYKNDV...


In [6]:
# try using the same train/test splits that flip used
df_split = pd.read_csv(Path(inputs_dir, "meltome/flip/github/splits/mixed_split.csv"))
print(len(df_split))
df_split.head(2)

27951


Unnamed: 0,sequence,target,set,validation
0,MSGEEEKAADFYVRYYVGHKGKFGHEFLEFEFRPNGSLRYANNSNY...,37.962947,train,
1,MSMGSDFYLRYYVGHKGKFGHEFLEFEFRPDGKLRYANNSNYKNDV...,54.425342,train,


In [7]:
seq_to_uniprot = dict(zip(data.sequence, data.uniprot))
len(seq_to_uniprot)

34158

In [8]:
df_split['uniprot_id'] = df_split.sequence.apply(lambda seq: seq_to_uniprot[seq])

In [9]:
def read_embeddings(embed_file, sequence_idx_file):
    """ Read embeddings stored in an npz file
    Get the sequences at each index from the *sequence_idx_file
    """
    embeddings = np.load(embed_file, allow_pickle=True)['arr_0']
    sequences = pd.read_csv(sequence_idx_file)
    print(f"{len(embeddings) = } read from {embed_file}")
    print(f"{len(sequences) = } read from {sequence_idx_file}")
    return embeddings, sequences

In [10]:
embeddings, df_seq = read_embeddings(Path(inputs_dir, "meltome/embeddings/20230206_embeddings_esm2_t33_650M_UR50D.npz"),
                                     Path(inputs_dir, "meltome/embeddings/20230125_embeddings_seqs.csv"))
embeddings.shape

len(embeddings) = 32563 read from /projects/robustmicrob/jlaw/inputs/meltome/embeddings/20230206_embeddings_esm2_t33_650M_UR50D.npz
len(sequences) = 32563 read from /projects/robustmicrob/jlaw/inputs/meltome/embeddings/20230125_embeddings_seqs.csv


(32563, 1280)

In [11]:
df_split_w_embed = df_split[df_split.sequence.isin(df_seq.sequence)]
print(len(df_split_w_embed))

26082


In [12]:
prot_ids = df_split_w_embed.uniprot_id.unique()
len(prot_ids)

21721

In [13]:
import gzip
import tarfile

In [41]:
from importlib import reload  # Python 3.4+
import protein_mpnn_utils
protein_mpnn_utils = reload(protein_mpnn_utils)

In [44]:
# I downloaded the species-specific alphafold structures from here: 
# https://alphafold.ebi.ac.uk/download
# load the structures from the tar files

# prot_structure = {}
prot_with_structure = set()
# strc_tar_file = Path(inputs_dir, "structures/UP000000625_83333_ECOLI_v4.tar")
for strc_tar_file in glob.glob(f"{inputs_dir}/structures/*.tar"):
    prot_with_structure = set()
    print(strc_tar_file)
    with tarfile.open(strc_tar_file, 'r') as tar:
        for member in tqdm(tar.getmembers()):
            if 'cif' in member.name:
                continue
            u_id = member.name.split('-')[1]
            if u_id in prot_ids:
                prot_with_structure.add(u_id)
                pdb_file = tar.extractfile(member)
                file_ = gzip.decompress(pdb_file.read()).decode()
                pdb = protein_mpnn_utils.parse_PDB(file_handle=file_.split('\n'), alphafold=True)
                prot_structure[u_id] = pdb
    print(len(prot_with_structure))
    with open(strc_tar_file.replace('.tar','.p'), 'wb') as out:
        pickle.dump(prot_structures, out)

print(len(prot_with_structure))

/projects/robustmicrob/jlaw/inputs/structures/UP000000625_83333_ECOLI_v4.tar


  0%|          | 0/8726 [00:00<?, ?it/s]

/projects/robustmicrob/jlaw/inputs/structures/UP000000803_7227_DROME_v4.tar


  0%|          | 0/26916 [00:00<?, ?it/s]

/projects/robustmicrob/jlaw/inputs/structures/UP000002311_559292_YEAST_v4.tar


  0%|          | 0/12078 [00:00<?, ?it/s]

/projects/robustmicrob/jlaw/inputs/structures/UP000000589_10090_MOUSE_v4.tar


  0%|          | 0/43230 [00:00<?, ?it/s]

/projects/robustmicrob/jlaw/inputs/structures/UP000000437_7955_DANRE_v4.tar


  0%|          | 0/49328 [00:00<?, ?it/s]

/projects/robustmicrob/jlaw/inputs/structures/UP000006548_3702_ARATH_v4.tar


  0%|          | 0/54868 [00:00<?, ?it/s]

/projects/robustmicrob/jlaw/inputs/structures/UP000005640_9606_HUMAN_v4.tar


  0%|          | 0/46782 [00:00<?, ?it/s]

/projects/robustmicrob/jlaw/inputs/structures/UP000001940_6239_CAEEL_v4.tar


  0%|          | 0/39388 [00:00<?, ?it/s]

15879


In [46]:
prots_remaining = set(prot_ids) - prot_with_structure
len(prots_remaining)

5842

In [51]:
with open('prots_remaining.txt', 'w') as out:
    out.write('\n'.join(prots_remaining) + '\n')

In [None]:
import subprocess
out_dir = "/projects/robustmicrob/jlaw/inputs/structures/meltome"
# for uniprot-ids without a structure, try downloading the alphafold structure
failed = set()
for u_id in tqdm(prots_remaining):
    url = f"https://alphafold.ebi.ac.uk/files/AF-{u_id}-F1-model_v4.pdb"
    out_file = Path(out_dir, f"AF-{u_id}-F1-model_v4.pdb")
    if not out_file.is_file():
        command = f"wget -O {out_file} {url}"
        # print(command)
        try:
            subprocess.check_call(command, shell=True)
            subprocess.check_call(['gzip', out_file])
        except subprocess.CalledProcessError:
            failed.add(u_id)
print(f"{len(failed)} structures failed to download")

In [54]:
# now for each of these structures, extract the proteinMPNN embeddings
# example for one structure:
pdb_file = Path(inputs_dir, "structures/meltome/AF-Q72JN9-F1-model_v4.pdb.gz")


In [5]:
### Design Options
num_seqs = 1  #@param ["1", "2", "4", "8", "16", "32", "64"] {type:"raw"}
num_seq_per_target = num_seqs

# Sampling temperature for amino acids, T=0.0 means taking argmax, T>>1.0 means sample randomly.
sampling_temp = "0.2" #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5"]

save_score=0                      # 0 for False, 1 for True; save score=-log_prob to npy files
save_probs=0                      # 0 for False, 1 for True; save MPNN predicted probabilites per position
score_only=0                      # 0 for False, 1 for True; score input backbone-sequence pairs
conditional_probs_only=0          # 0 for False, 1 for True; output conditional probabilities p(s_i given the rest of the sequence and backbone)
conditional_probs_only_backbone=0 # 0 for False, 1 for True; if true output conditional probabilities p(s_i given backbone)
    
batch_size=1                      # Batch size; can set higher for titan, quadro GPUs, reduce this if running out of GPU memory
max_length=20000                  # Max sequence length
    
out_folder='results/mpnn_vanilla/'                    # Path to a folder to output sequences, e.g. /home/out/
jsonl_path=''                     # Path to a folder with parsed pdb into jsonl
omit_AAs='X'                      # Specify which amino acids should be omitted in the generated sequence, e.g. 'AC' would omit alanine and cystine.
   
pssm_multi=0.0                    # A value between [0.0, 1.0], 0.0 means do not use pssm, 1.0 ignore MPNN predictions
pssm_threshold=0.0                # A value between -inf + inf to restric per position AAs
pssm_log_odds_flag=0               # 0 for False, 1 for True
pssm_bias_flag=0                   # 0 for False, 1 for True

##############################################################
folder_for_outputs = out_folder

NUM_BATCHES = num_seq_per_target//batch_size
BATCH_COPIES = batch_size
temperatures = [float(item) for item in sampling_temp.split()]
omit_AAs_list = omit_AAs
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'

omit_AAs_np = np.array([AA in omit_AAs_list for AA in alphabet]).astype(np.float32)

chain_id_dict = None
fixed_positions_dict = None
pssm_dict = None
omit_AA_dict = None
bias_AA_dict = None
tied_positions_dict = None
bias_by_res_dict = None
bias_AAs_np = np.zeros(len(alphabet))

In [60]:
from importlib import reload  # Python 3.4+
import protein_mpnn_utils
protein_mpnn_utils = reload(protein_mpnn_utils)

In [63]:
homomer = False #@param {type:"boolean"}
designed_chain = "A" #@param {type:"string"}
fixed_chain = "" #@param {type:"string"}

designed_chain_list = ["A"]
fixed_chain_list = []
chain_list = list(set(designed_chain_list + fixed_chain_list))

###############################################################
pdb_dict_list = protein_mpnn_utils.parse_PDB(str(pdb_file), input_chain_list=chain_list)
dataset_valid = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)

chain_id_dict = {}
chain_id_dict[pdb_dict_list[0]['name']]= (designed_chain_list, fixed_chain_list)

print(chain_id_dict)
for chain in chain_list:
  l = len(pdb_dict_list[0][f"seq_chain_{chain}"])
  print(f"Length of chain {chain} is {l}")

tied_positions_dict = None

{'AF-Q72JN9-F1-model_v4.pd': (['A'], [])}
Length of chain A is 270


In [64]:
# X has the N, Ca, C, and O coordinates for each AA
# S has the sequence
batch_clones = [copy.deepcopy(protein) for ix, protein in enumerate(dataset_valid) for i in range(BATCH_COPIES)]
X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, \
    visible_list_list, masked_list_list, masked_chain_length_list_list, \
    chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, \
    tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, \
    bias_by_res_all, tied_beta = \
        tied_featurize(
    batch_clones, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, \
    tied_positions_dict, pssm_dict, bias_by_res_dict)

In [None]:
# I need to load the pdb_dict objects for each structure
pdb_dict_list = protein_mpnn_utils.parse_PDB(str(pdb_file))
# then create the dataset objects
dataset = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
loader = StructureLoader(dataset, batch_size=args.batch_size)
for _, batch in enumerate(loader_train):
    start_batch = time.time()
    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
    randn_1 = torch.randn(chain_M.shape, device=X.device)
    log_probs, h_V = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, return_embedding=True)

In [65]:
randn_1 = torch.randn(chain_M.shape, device=X.device)
log_probs, h_V = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, return_embedding=True)
log_probs

tensor([[[-2.0773, -5.1405, -3.3911,  ..., -5.2813, -4.7411, -5.4465],
         [-1.9160, -5.2629, -3.5025,  ..., -5.1482, -4.7408, -5.4474],
         [-2.5496, -5.3444, -3.9634,  ..., -5.3768, -5.1485, -5.5342],
         ...,
         [-2.7685, -4.5945, -4.3894,  ..., -4.5018, -2.6582, -5.2274],
         [-1.3811, -2.5970, -5.1208,  ..., -4.5383, -1.9913, -5.4775],
         [-1.8980, -5.1028, -3.4226,  ..., -5.3310, -5.2427, -5.4430]]],
       device='cuda:0', grad_fn=<LogSoftmaxBackward0>)

In [66]:
h_V.shape

torch.Size([1, 270, 128])

### Extracting ProteinMPNN Embeddings

In [6]:
# load the preprocessed structures
import pickle
danre_structures = pickle.load(open(Path(inputs_dir, "structures/UP000000437_7955_DANRE_v4.p"), 'rb'))
print(len(danre_structures))

55


In [7]:
pdb_dict_list = []
for u_id, dict_list in danre_structures.items():
    pdb_dict = dict_list[0]
    pdb_dict['uniprot_id'] = u_id
    pdb_dict['name'] = u_id
    pdb_dict_list.append(pdb_dict)

In [8]:
pdb_dict_list[0].keys()

dict_keys(['seq_chain_A', 'coords_chain_A', 'num_of_chains', 'seq', 'uniprot_id', 'name'])

In [9]:
set(key for pdb_dict in pdb_dict_list for key in pdb_dict.keys() if 'seq_chain' in key)

{'seq_chain_A'}

In [None]:
max_length

20000

In [20]:
# number of tokens for one batch
# batch_size = 10000
batch_size = 5000

In [12]:
def my_featurize(batch, device):
    homomer = False #@param {type:"boolean"}
    designed_chain = "A" #@param {type:"string"}
    fixed_chain = "" #@param {type:"string"}

    designed_chain_list = ["A"]
    fixed_chain_list = []
    chain_list = list(set(designed_chain_list + fixed_chain_list))

    chain_id_dict = {pdb_dict['name']: (designed_chain_list, fixed_chain_list) for pdb_dict in batch}
    tied_positions_dict = None

    X, S, mask, lengths, chain_M, chain_encoding_all, chain_list_list, \
    visible_list_list, masked_list_list, masked_chain_length_list_list, \
    chain_M_pos, omit_AA_mask, residue_idx, dihedral_mask, \
    tied_pos_list_of_lists_list, pssm_coef, pssm_bias, pssm_log_odds_all, \
    bias_by_res_all, tied_beta = \
        protein_mpnn_utils.tied_featurize(
    batch, device, chain_id_dict, fixed_positions_dict, omit_AA_dict, \
    tied_positions_dict, pssm_dict, bias_by_res_dict)
    
    return X, S, mask, lengths, chain_M, chain_M_pos, residue_idx, chain_encoding_all

In [24]:
from importlib import reload  # Python 3.4+
import protein_mpnn_utils
protein_mpnn_utils = reload(protein_mpnn_utils)
torch.cuda.empty_cache()

In [22]:
del X, S, mask, lengths, chain_M, chain_M_pos, residue_idx, chain_encoding_all, randn_1
torch.cuda.empty_cache()

In [23]:
del batch

In [21]:
# I need to load the pdb_dict objects for each structure
# pdb_dict_list = protein_mpnn_utils.parse_PDB(str(pdb_file))
# then create the dataset objects
dataset = StructureDatasetPDB(pdb_dict_list, truncate=None, max_length=max_length)
loader = StructureLoader(dataset, batch_size=batch_size)
for batch in loader:
    start_batch = time.time()
    X, S, mask, lengths, chain_M, chain_M_pos, residue_idx, chain_encoding_all = my_featurize(batch, device)
    randn_1 = torch.randn(chain_M.shape, device=X.device)
    log_probs, h_V = model(X, S, mask, chain_M*chain_M_pos, residue_idx, chain_encoding_all, randn_1, return_embedding=True)

OutOfMemoryError: CUDA out of memory. Tried to allocate 344.00 MiB (GPU 0; 31.75 GiB total capacity; 29.13 GiB already allocated; 124.94 MiB free; 30.42 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [112]:
len(batch)

10