## In this file I test how well the covariances of samples generated by ESM and by the Potts decoder correlate with the covariance of the true MSA

In [1]:
import numpy as np
import scipy
import matplotlib.pyplot as plt
from encoded_protein_dataset import EncodedProteinDataset#, collate_fn, get_embedding
from encoded_protein_dataset_new import get_embedding, EncodedProteinDataset_new, collate_fn_new
from pseudolikelihood import get_npll, get_npll_new, get_npll2
import torch, torchvision
from torch.nn.functional import one_hot
from potts_decoder import PottsDecoder
from torch.utils.data import DataLoader, RandomSampler
from functools import partial
import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
from typing import Sequence, Tuple, List
import scipy
from tqdm import tqdm
import pandas as pd
import csv
import time

#import pytorch_warmup as warmup
from collections import defaultdict
import os
import sys
##TURIN HPC
#sys.path.insert(1, "/Data/silva/esm/")

## EUROPA
sys.path.insert(1, "/home/lucasilva/esm/")

##Lucas computer
sys.path.insert(1, "D:/esm/")
import esm
#from esm.inverse_folding import util
import esm.pretrained as pretrained
from ioutils import read_fasta, read_encodings
from torch.nn.utils.rnn import pad_sequence
from collections import defaultdict


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Data locations; max_msas is forwarded unchanged to every dataset constructor.
max_msas = 100
msa_dir = "D:/Data/InverseFolding/msas/"
encoding_dir = "D:/Data/InverseFolding/structure_encodings/"


def _make_split(subdir, noise):
    """Build an EncodedProteinDataset_new for one sub-folder of msa_dir."""
    return EncodedProteinDataset_new(
        os.path.join(msa_dir, subdir), encoding_dir, noise=noise, max_msas=max_msas
    )


# The training split is built with noise=0.02 (the default value used so far);
# all test splits are built with noise=0.0.
train_dataset = _make_split('train', 0.02)
sequence_test_dataset = _make_split('test/sequence', 0.0)
structure_test_dataset = _make_split('test/structure', 0.0)
superfamily_test_dataset = _make_split('test/superfamily', 0.0)


Counter is:36 length data:36

  encodings = torch.tensor(read_encodings(encoding_path, trim=False))


Counter is:100 length data:99

In [3]:
# Structures per batch: kept at one for the moment while generating models.
batch_structure_size = 1
# Fraction of each test dataset to evaluate on. Training used a random 10%;
# it is now 100% because the evaluation metrics were too noisy.
perc_subset_test = 1.0
batch_msa_size = 16
# Alphabet size handed to the collate function (open question from the author:
# is it always 21?).
q = 21
collate_fn = partial(collate_fn_new, batch_msa_size=batch_msa_size, q=q)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_structure_size,
    shuffle=True,
    collate_fn=collate_fn,
)

In [4]:
def get_loss_new(decoder, inputs, eta):
    """Penalized negative pseudo-log-likelihood using ``decoder.forward_new``.

    Args:
        decoder: model exposing ``forward_new(encodings, padding_mask)``, which
            returns ``(param_embeddings, fields)``.
        inputs: iterable of three tensors ``(msas, encodings, padding_mask)``;
            ``msas`` is unpacked as a 3-D tensor of shape ``(B, M, N)``.
        eta: multiplicative factor in front of the penalty term.

    Returns:
        Tuple ``(loss, npll)``: ``loss`` is the tensor ``npll_mean + penalty``
        and ``npll`` is the un-penalized mean as a Python float.

    Relies on the module-level names ``device``, ``embedding`` and ``q``.
    """
    msas, encodings, padding_mask = (tensor.to(device) for tensor in inputs)
    B, M, N = msas.shape
    param_embeddings, fields = decoder.forward_new(encodings, padding_mask)

    # Per-position negative pseudo-log-likelihood of the embedded MSA.
    site_npll = get_npll2(embedding(msas), param_embeddings, fields, N, q)

    # Zero out padded (non-existing) residues; this is probably not necessary.
    valid = ~padding_mask
    site_npll = site_npll * valid.unsqueeze(1)
    # The mask has no MSA dimension, hence the extra factor M in the normalizer.
    npll_mean = torch.sum(site_npll) / (M * torch.sum(valid))

    # Pairwise products of the parameter embeddings, summed over the last axis.
    pair_terms = torch.einsum(
        'bkuia, buhia->bkhia',
        param_embeddings.unsqueeze(2),
        param_embeddings.unsqueeze(1),
    ).sum(axis=-1)
    squared_row_sums = torch.sum(torch.sum(pair_terms, axis=-1) ** 2)
    squared_entries = torch.sum(pair_terms ** 2)
    squared_fields = torch.sum(fields ** 2)
    penalty = eta * (squared_row_sums - squared_entries + squared_fields) / B

    return npll_mean + penalty, npll_mean.item()

def get_loss(decoder, inputs, eta):
    """Penalized negative pseudo-log-likelihood using the decoder's plain forward pass.

    Args:
        decoder: callable ``decoder(encodings, padding_mask)`` returning
            ``(couplings, fields)``.
        inputs: iterable of three tensors ``(msas, encodings, padding_mask)``;
            ``msas`` is unpacked as a 3-D tensor of shape ``(B, M, N)``.
        eta: multiplicative factor in front of the L2 penalty on couplings and
            fields.

    Returns:
        Tuple ``(loss, npll)``: ``loss`` is the tensor ``npll_mean + penalty``
        and ``npll`` is the un-penalized mean as a Python float.

    Relies on the module-level names ``device``, ``embedding`` and ``q``.
    """
    msas, encodings, padding_mask = (tensor.to(device) for tensor in inputs)
    B, M, N = msas.shape
    couplings, fields = decoder(encodings, padding_mask)

    # Embed the MSA and collapse the trailing axes into one: (B, M, N*q).
    flat_msas = embedding(msas).view(B, M, -1)

    site_npll = get_npll(flat_msas, couplings, fields, N, q)

    # Zero out padded (non-existing) residues; this is probably not necessary.
    valid = ~padding_mask
    site_npll = site_npll * valid.unsqueeze(1)

    squared_couplings = torch.sum(couplings ** 2)
    squared_fields = torch.sum(fields ** 2)
    penalty = eta * (squared_couplings + squared_fields) / B

    # The mask has no MSA dimension, hence the extra factor M in the normalizer.
    npll_mean = torch.sum(site_npll) / (M * torch.sum(valid))

    return npll_mean + penalty, npll_mean.item()

def get_loss_loader(decoder, loader, eta):
    """Average the un-penalized NPLL returned by ``get_loss_new`` over a loader.

    The decoder is switched to evaluation mode (and left there) and gradients
    are disabled while iterating.

    Args:
        decoder: model passed through to ``get_loss_new``; must provide ``eval()``.
        loader: iterable of batches accepted by ``get_loss_new``.
        eta: penalty factor forwarded to ``get_loss_new`` (only the un-penalized
            value is accumulated here).

    Returns:
        Sum of the per-batch NPLL values divided by the number of batches.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    decoder.eval()
    npll_total = 0
    n_batches = 0
    with torch.no_grad():
        for n_batches, batch in enumerate(loader, start=1):
            npll_total += get_loss_new(decoder, batch, eta)[1]

    return npll_total / n_batches

def compute_covariance(msa, q):
    """
    Covariance matrix of a one-hot encoded MSA with q symbols per position.

    msa is an (M, N) integer tensor of symbol indices. The last symbol
    (index q-1) is dropped from the encoding, so the result has shape
    (N*(q-1), N*(q-1)) with entries f_ij(a, b) - f_i(a) * f_j(b).
    """
    M, N = msa.shape
    kept = q - 1

    # One-hot encode, drop the last amino acid, and lay each sequence out as a row.
    data = one_hot(msa, num_classes=q)[:, :, :kept].reshape(M, N * kept).to(torch.float32)

    # Pair (bivariate) frequencies; the diagonal holds the single-site frequencies.
    pair_freqs = (data.T @ data) / M
    site_freqs = torch.diagonal(pair_freqs)

    # Outer product of the single-site frequencies.
    independent = site_freqs.unsqueeze(1) @ site_freqs.unsqueeze(0)

    return pair_freqs - independent

## Let us load a model

In [5]:
# Device index handed to ``.to()`` and the on-disk location of the checkpoint.
device = 0
bk_dir = 'D:/Data/InverseFolding/Intermediate_Models/'
fname_par = 'model_17_01_2023.pt'
checkpoint = torch.load(os.path.join(bk_dir, fname_par))

## Load the hyper-parameters the model was saved with
q = 21
args = checkpoint['args']
# The same stored value is used for both d_model and param_embed_dim.
d_model = param_embed_dim = args['param_embed_dim']
n_layers, input_encoding_dim, n_heads, n_param_heads, dropout = (
    args[key]
    for key in ('n_layers', 'input_encoding_dim', 'n_heads', 'n_param_heads', 'dropout')
)

decoder = PottsDecoder(
    q,
    n_layers,
    d_model,
    input_encoding_dim,
    param_embed_dim,
    n_heads,
    n_param_heads,
    dropout=dropout,
)
decoder.to(device)
decoder.load_state_dict(checkpoint['model_state_dict'])
# Evaluation mode: to generate data we only need the forward pass of the model.
# The trailing ';' keeps the notebook from echoing the return value.
decoder.eval();

512
512


Here we use the old forward pass, because it directly outputs the coupling matrix and the fields

In [6]:
# Draw a single batch from the training loader and move its tensors to the device.
inputs = next(iter(train_loader))
msas, encodings, padding_mask = (tensor.to(device) for tensor in inputs)

## Author's note: these outputs still have to be normalized!
couplings, fields = decoder(encodings, padding_mask)

  encodings = torch.tensor(read_encodings(encoding_path, trim=False))


In [7]:
msas.shape

torch.Size([1, 16, 348])

Now we have to save them in a _txt_ file in the format that the _C++_ library can read

In [51]:
B, N, _ = encodings.shape

# Copy the parameters of the first structure to the CPU once. Pulling every
# scalar individually with .detach().to('cpu').item() forces one transfer per
# element and dominates the runtime.
J_cpu = couplings[0].detach().to('cpu')
h_cpu = fields[0].detach().to('cpu')

with open("couplings_fields.txt", "w") as f:
    ## write J: one line "J i j aa1 aa2 value" per residue pair i < j and
    ## per pair of amino acids
    for i in range(N):
        for j in range(i + 1, N):
            # q x q block holding the couplings between positions i and j.
            J_block = J_cpu[i*q:(i+1)*q, j*q:(j+1)*q].tolist()
            f.writelines(
                f"J {i} {j} {aa1} {aa2} {J_el}\n"
                for aa1, J_row in enumerate(J_block)
                for aa2, J_el in enumerate(J_row)
            )

    ## write h: one line "h i aa value" per position and amino acid.
    ## The field is indexed with this loop's own amino-acid index `aa`; the
    ## previous version reused the stale `aa1` left over from the J loop, so
    ## every amino acid of a site was written with the same value.
    for i in range(N):
        h_site = h_cpu[i*q:(i+1)*q].tolist()
        f.writelines(f"h {i} {aa} {h_el}\n" for aa, h_el in enumerate(h_site))

## TODO: find a faster way to do this; it currently takes roughly 4 minutes

## Now I will load the sequences generated by _bmDCA_ and compute their covariance.

In [84]:
# Read every line of the sampled-sequence file (line endings are kept).
samples_dir = "../Samples_Potts/YAP1_HUMAN"
file = 'samples_numerical.txt'
samples_path = os.path.join(samples_dir, file)
with open(samples_path, mode='r') as f:
    lines = list(f)

len(lines)

10001

In [85]:
alphabet = 'ACDEFGHIKLMNPQRSTVWY-'
default_index = alphabet.index('-')
# Letter -> index lookup; any symbol outside the alphabet falls back to the gap index.
aa_index = defaultdict(lambda: default_index, {aa: idx for idx, aa in enumerate(alphabet)})
# Index -> letter lookup.
aa_index_inv = {idx: aa for aa, idx in aa_index.items()}

char_seq = []  # one list of amino-acid letters per sampled sequence (36 is the length of YAP)

# The first line of the file is skipped (presumably a header; confirm against
# the sampler's output format).
for raw_line in lines[1:]:
    # split() copes with the trailing newline, "\r\n" endings and a final line
    # without a newline. The previous lines[i][0:-1] always removed the last
    # character, which truncates the final index when the newline is missing.
    tokens = raw_line.split()
    if not tokens:
        # Ignore blank lines (e.g. a trailing empty line at the end of the file).
        continue
    char_seq.append([aa_index_inv[int(idx)] for idx in tokens])

In [86]:
## Now re-translate
# Translate the letters back to numerical indices, updating each row in place.
for residues in char_seq:
    residues[:] = [aa_index[symbol] for symbol in residues]

In [87]:
msa = np.array(char_seq)

In [88]:
msa.shape

(10000, 36)

In [89]:
msa_t = torch.tensor(msa, dtype=torch.long)

In [90]:
msa_t

tensor([[20, 20, 20,  ..., 20, 20, 20],
        [20, 20, 20,  ..., 20, 20, 20],
        [20, 20, 20,  ...,  3, 20, 20],
        ...,
        [20, 20,  2,  ..., 14, 20, 20],
        [20, 20, 20,  ...,  9, 20, 20],
        [20, 20, 20,  ...,  0, 20, 20]])

In [91]:
cov = compute_covariance(msa_t, 21)

In [92]:
cov_4d = cov.reshape(36,20,36, 20)

In [93]:
cov_4d

tensor([[[[ 1.6971e-03, -4.7600e-06, -1.7170e-05,  ..., -2.8900e-06,
           -3.4000e-06, -2.5500e-06],
          [-1.3090e-05, -6.1200e-06, -9.1800e-06,  ..., -8.1600e-06,
           -7.1400e-06,  9.5580e-05],
          [ 8.3510e-05,  1.9082e-04,  1.3081e-04,  ..., -1.1900e-05,
           -9.0100e-06, -1.1390e-05],
          ...,
          [ 1.5975e-04,  8.5040e-05, -2.9920e-05,  ..., -3.4850e-05,
           -9.0100e-06, -8.5000e-06],
          [-1.2240e-05, -6.2900e-06,  3.3700e-05,  ..., -8.3300e-06,
           -7.6500e-06, -6.6300e-06],
          [-3.5700e-06, -3.2300e-06, -2.2100e-06,  ..., -3.0600e-06,
            9.6940e-05, -4.2500e-06]],

         [[-4.7600e-06,  2.7922e-03, -2.8280e-05,  ..., -4.7600e-06,
           -5.6000e-06, -4.2000e-06],
          [-2.1560e-05,  1.8992e-04,  1.8488e-04,  ...,  1.8656e-04,
           -1.1760e-05,  1.9272e-04],
          [ 7.2840e-05, -1.5120e-05, -1.1396e-04,  ...,  1.8040e-04,
            8.5160e-05, -1.8760e-05],
          ...,
     

## Now let us load the MSA of YAP and see if we are able to recover the covariance of the MSA

In [19]:
# Raw string: in a normal literal the backslashes form invalid escape sequences
# (\D, \I, \M, \m), which newer Python versions warn about. The path value
# itself is unchanged.
msa_YAP_path = r"D:\Data\InverseFolding\Mutational_Data\msa_full_YAP"
with open(msa_YAP_path, mode="r") as f:
    lines = f.readlines()
# Drop the first line so that sequence lines sit at even indices and header
# lines at odd indices.
lines = lines[1:]


In [20]:
lines[1][0:-1]

'>UniRef50_UPI000D50330B WW domain with PPxY motif n=1 Tax=Homo sapiens TaxID=9606 RepID=UPI000D50330B'

In [21]:
# Encode the sequence lines (even indices of `lines`). The last character of
# each line (the line break) is removed and the fixed flanks 'DVP' / 'RKA' are
# added before each letter is mapped to its numerical index.
msa = [
    [aa_index[residue] for residue in 'DVP' + raw[0:-1] + 'RKA']
    for raw in lines[::2]
]

In [22]:
msa_true = torch.tensor(msa)
msa_true.shape

torch.Size([9456, 36])

## Let us see if the other file has similar data

In [52]:
# Read every line of the a3m alignment (line endings are kept).
fpath = "D:/Data/InverseFolding/Mutational_Data/alphafold_results_wildtype/YAP1_HUMAN_1_b0.5.a2m.a3m"
with open(fpath, mode="r") as f:
    lines = list(f)

In [53]:
lines = lines[2:]

In [54]:
# Encode the lines at even indices, keeping only sequences of exactly 36 residues.
msa_new = []
for raw in lines[::2]:
    # Remove the last character of the line (the line break) before encoding.
    encoded = [aa_index[residue] for residue in raw[0:-1]]
    if len(encoded) != 36:
        continue
    msa_new.append(encoded)

In [55]:
msa_new = torch.tensor(msa_new)

In [56]:
msa_new

tensor([[ 2, 17, 12,  ..., 14,  8,  0],
        [ 2,  7, 12,  ..., 14, 14,  0],
        [11, 16, 12,  ..., 14,  8, 15],
        ...,
        [20, 20, 20,  ..., 14, 20, 20],
        [ 2,  8,  3,  ..., 14, 20, 20],
        [20, 20, 20,  ..., 14, 20, 20]])

In [57]:
msa_true

tensor([[ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0],
        ...,
        [ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0]])

In [58]:
msa_true

tensor([[ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0],
        ...,
        [ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0],
        [ 2, 17, 12,  ..., 14,  8,  0]])

In [75]:
cov_true = compute_covariance(msa_true[::,::], q=21)
cov_true_new = compute_covariance(msa_new[::,::], q=21)

In [76]:
corr = torch.sum(cov_true * cov_true_new)/torch.sqrt(torch.sum(cov_true**2)*torch.sum(cov_true_new**2))
corr

tensor(0.7970)

In [82]:
cov_true = compute_covariance(msa_true, q=21)

In [94]:
corr = torch.sum(cov_true * cov)/torch.sqrt(torch.sum(cov_true**2)*torch.sum(cov**2))

In [95]:
corr

tensor(0.5333)

In [97]:
corr_new = torch.sum(cov_true_new * cov)/torch.sqrt(torch.sum(cov_true_new**2)*torch.sum(cov**2))
corr_new

tensor(0.5323)

In [98]:
# Read every line of the fasta file with the ESM samples (line endings are kept).
with open("C:/Users/Luca/OneDrive/Phd/Second_year/research/Feinauer/Samples_esm/output/sampled_seqs.fasta", mode="r") as f:
    lines = list(f)

In [99]:
# In this fasta file the sequences are on the odd-indexed lines. Remove the
# last character of each line (the line break) and map each letter to its index.
msa_esm = [
    [aa_index[residue] for residue in raw[0:-1]]
    for raw in lines[1::2]
]

In [100]:
msa_esm = torch.tensor(msa_esm)
msa_esm.shape

torch.Size([1000, 36])

In [101]:
lines[1]

'GMEPPEGWEKRKTKEGDEVWFHKGTNTWTYTDPRTQ\n'

In [106]:
cov_esm = compute_covariance(msa_esm, q=21)

In [107]:
corr_new = torch.sum(cov_true_new * cov_esm)/torch.sqrt(torch.sum(cov_true_new**2)*torch.sum(cov_esm**2))
corr_new


tensor(0.4115)

In [104]:
cov_esm.shape

torch.Size([720, 720])

In [105]:
cov_true.shape

torch.Size([720, 720])