#### In this notebook we try to use the **pyhmmer** package to realign the sequences sampled from the ESM inverse-folding model (ESM-IF1) onto the columns of the reference MSA.

In [33]:
import torch
import numpy as np
import scipy
import os
import sys
import biotite.structure
from biotite.structure.io import pdbx, pdb
from biotite.structure.residues import get_residues
from biotite.structure import filter_backbone
from biotite.structure import get_chains
from biotite.sequence import ProteinSequence
from typing import Sequence, Tuple, List
from Bio import SeqIO


git_folder = '/home/luchinoprince/Dropbox/Old_OneDrive/Phd/Second_year/research/Feinauer/Inverse_Folding'
esm_folder = '/home/luchinoprince/Dropbox/Old_OneDrive/Phd/Second_year/research/Feinauer/esm/'
sys.path.insert(1, os.path.join(git_folder, 'model'))
sys.path.insert(1, os.path.join(git_folder, 'util'))

sys.path.insert(1, esm_folder)
import esm
#from esm.inverse_folding import util
import esm.pretrained as pretrained

## I import this to try to get deeper on the sampling perplexity
from esm.inverse_folding.features import DihedralFeatures
from esm.inverse_folding.gvp_encoder import GVPEncoder
from esm.inverse_folding.gvp_utils import unflatten_graph
from esm.inverse_folding.gvp_transformer_encoder import GVPTransformerEncoder
from esm.inverse_folding.transformer_decoder import TransformerDecoder
from esm.inverse_folding.util import rotate, CoordBatchConverter 


#### Code for model with PLL ########
#from potts_decoder import PottsDecoder
#### Code for model with NCE ##########
from potts_decoder import PottsDecoder
from ioutils import read_fasta, read_encodings

from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.nn.functional import one_hot

from collections import defaultdict

import matplotlib.pyplot as plt
import pyhmmer



In [34]:
def load_structure(fpath, chain=None):
    """
    Load the backbone atoms of selected chains from a PDB file.

    Only the first model of the file is read and only the atoms selected by
    biotite's ``filter_backbone`` are kept. The file is parsed with
    ``pdb.PDBFile``, so only PDB-format input is handled (mmCIF is not).

    Args:
        fpath: path to a PDB file.
        chain: a single chain id, a list of chain ids, or None to keep all chains.
    Returns:
        biotite.structure.AtomArray containing the backbone atoms of the
        requested chains.
    Raises:
        ValueError: if the file has no chains, or a requested chain is absent.
    """
    with open(fpath) as handle:
        parsed = pdb.PDBFile.read(handle)
    backbone = pdb.get_structure(parsed, model=1)
    backbone = backbone[filter_backbone(backbone)]

    available = get_chains(backbone)
    if len(available) == 0:
        raise ValueError('No chains found in the input file.')

    # Normalise the `chain` argument into a list of chain ids.
    if chain is None:
        wanted = available
    elif isinstance(chain, list):
        wanted = chain
    else:
        wanted = [chain]

    # Fail on the first requested chain that does not exist.
    for chain_id in wanted:
        if chain_id not in available:
            raise ValueError(f'Chain {chain_id} not found in input file')

    keep = [atom.chain_id in wanted for atom in backbone]
    return backbone[keep]

def extract_coords_from_structure(structure: biotite.structure.AtomArray):
    """
    Extract per-residue backbone coordinates and the one-letter sequence.

    Args:
        structure: biotite AtomArray (in this notebook, the output of ``load_structure``).
    Returns:
        Tuple ``(coords, seq)``:
            - ``coords``: L x 3 x 3 array with the N, CA and C coordinates of each
              residue (NaN for an atom missing from a residue)
            - ``seq``: the one-letter amino-acid sequence
    """
    backbone_atoms = ["N", "CA", "C"]
    coords = get_atom_coords_residuewise(backbone_atoms, structure)
    _, residue_names = get_residues(structure)
    seq = "".join(ProteinSequence.convert_letter_3to1(res_name) for res_name in residue_names)
    return coords, seq

def get_atom_coords_residuewise(atoms: List[str], struct: biotite.structure.AtomArray):
    """
    Collect, residue by residue, the coordinates of the requested atom names.

    Example for atoms argument: ["N", "CA", "C"]

    For each residue the result holds one (x, y, z) row per requested atom, in
    the order of ``atoms``; an atom missing from the residue gets NaN coordinates.

    Raises:
        RuntimeError: if a residue contains the same atom name more than once.
    """
    def residue_coords(residue, axis=None):
        # hits[k, j] is True when atom k of this residue is named atoms[j].
        hits = np.stack([residue.atom_name == wanted for wanted in atoms], axis=1)
        counts = hits.sum(0)
        if np.any(counts > 1):
            raise RuntimeError("structure has multiple atoms with same name")
        # argmax picks the first matching atom (index 0 if there is no match;
        # those rows are overwritten with NaN right after).
        first_hit = hits.argmax(0)
        picked = residue[first_hit].coord
        picked[counts == 0] = float("nan")
        return picked

    return biotite.structure.apply_residue_wise(struct, struct, residue_coords)

def get_loss_new(decoder, inputs, eta):
    """Penalized negative pseudo-log-likelihood (npll) computed from the decoder's parameter embeddings.

    eta is the multiplicative term in front of the penalty that is added to the
    mean negative pseudo-log-likelihood.

    Args:
        decoder: object exposing ``forward_new(encodings, padding_mask)`` that
            returns ``(param_embeddings, fields)``.
        inputs: iterable of exactly three tensors ``(msas, encodings, padding_mask)``;
            each is moved to the global ``device``. ``msas`` is 3-D ``(B, M, N)``.
        eta: weight of the penalty term.

    Returns:
        Tuple ``(loss_penalty, npll_mean)``: the total loss tensor (npll mean +
        penalty) and the npll mean as a Python float (which does not depend on eta).

    NOTE(review): this function reads the module-level names ``device``,
    ``embedding``, ``get_npll2`` and ``q``. Only ``device`` is defined in this
    notebook; the others must be provided elsewhere (TODO confirm), otherwise a
    call raises NameError.
    """
    msas, encodings, padding_mask  = [input.to(device) for input in inputs]
    B, M, N = msas.shape
    #print(f"encodings' shape{encodings.shape}, padding mask:{padding_mask.shape}")
    param_embeddings, fields = decoder.forward_new(encodings, padding_mask)
    msas_embedded = embedding(msas)

    # get npll
    npll = get_npll2(msas_embedded, param_embeddings, fields, N, q)
    padding_mask_inv = (~padding_mask)
    # multiply with the padding mask to filter non-existing residues (this is probably not necessary)
    npll = npll * padding_mask_inv.unsqueeze(1)
    # padding_mask has no MSA axis, so the count of valid positions is multiplied by M
    npll_mean = torch.sum(npll) / (M * torch.sum(padding_mask_inv))
    
    # param_embeddings is 4-D (the 5-letter einsum subscripts after one unsqueeze fix the
    # rank); axis meaning assumed (B, K, N, A) — TODO confirm against PottsDecoder.forward_new.
    # Q[b, k, h, i] = sum_a E[b, k, i, a] * E[b, h, i, a]; the singleton axis u from the two
    # unsqueeze calls is summed out by einsum, and .sum(axis=-1) sums over a.
    Q = torch.einsum('bkuia, buhia->bkhia', 
                param_embeddings.unsqueeze(2), param_embeddings.unsqueeze(1)).sum(axis=-1)
    # sum_{k,h} [(sum_i Q)^2 - sum_i Q^2] equals sum_{i != j} ||J_ij||_F^2 if the couplings are
    # J_ij(a, c) = sum_k E[k, i, a] E[k, j, c] — presumably the intended L2 penalty on the
    # implied couplings (TODO confirm); plus the L2 penalty on the fields, averaged over B.
    penalty = eta*(torch.sum(torch.sum(Q,axis=-1)**2) - torch.sum(Q**2) + torch.sum(fields**2))/B
    loss_penalty = npll_mean + penalty
    return loss_penalty, npll_mean.item() 

def get_loss(decoder, inputs, eta):
    """
    Penalized negative pseudo-log-likelihood (npll) of a batch under the decoder's couplings/fields.

    eta is the multiplicative term in front of the L2 penalty on couplings and fields
    that is added to the mean npll.

    Args:
        decoder: callable ``decoder(encodings, padding_mask) -> (couplings, fields)``.
        inputs: iterable of exactly three tensors ``(msas, encodings, padding_mask)``;
            each is moved to the global ``device``. ``msas`` is 3-D ``(B, M, N)``.
        eta: weight of the penalty term.

    Returns:
        Tuple ``(loss_penalty, npll_mean)``: total loss tensor and the mean npll as a float.

    NOTE(review): reads the module-level names ``device``, ``embedding``, ``get_npll``
    and ``q``; only ``device`` is defined in this notebook — the others must come from
    elsewhere (TODO confirm) or a call raises NameError.
    """
    msas, encodings, padding_mask = [tensor.to(device) for tensor in inputs]
    batch_size, n_seqs, seq_len = msas.shape
    couplings, fields = decoder(encodings, padding_mask)

    # Embed the MSA and flatten the per-position axes: (B, M, N*q).
    flat_msas = embedding(msas).view(batch_size, n_seqs, -1)
    npll = get_npll(flat_msas, couplings, fields, seq_len, q)

    # Zero out padded positions (padding_mask is True where a residue does not exist).
    valid = ~padding_mask
    npll = npll * valid.unsqueeze(1)

    penalty = eta * (couplings.pow(2).sum() + fields.pow(2).sum()) / batch_size
    # padding_mask has no MSA axis, hence the extra factor n_seqs in the denominator.
    npll_mean = npll.sum() / (n_seqs * valid.sum())
    return npll_mean + penalty, npll_mean.item()

def get_loss_loader(decoder, loader, eta, loss_fn=None):
    """
    Average the negative pseudo-log-likelihood of ``decoder`` over all batches of ``loader``.

    Batches are evaluated under ``torch.no_grad()`` with the decoder in eval mode.
    The decoder's previous train/eval mode is restored afterwards. Without this, calling
    this from inside a training loop would silently leave the model in eval mode.

    Args:
        decoder: torch module (anything with a ``training`` flag, ``eval()`` and ``train(mode)``).
        loader: iterable of batches; each batch is passed unchanged to ``loss_fn``.
        eta: penalty weight forwarded to ``loss_fn``. Only the npll part of the loss is
            averaged, so eta does not change the returned value for ``get_loss_new``/``get_loss``.
        loss_fn: callable ``(decoder, inputs, eta) -> (loss, npll_float)``. Defaults to
            ``get_loss_new`` (resolved at call time); ``get_loss`` can also be passed.

    Returns:
        Unweighted mean over batches of the per-batch npll values.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if loss_fn is None:
        loss_fn = get_loss_new

    was_training = decoder.training
    decoder.eval()
    total = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for inputs in loader:
                _, npll = loss_fn(decoder, inputs, eta)
                total += npll
                n_batches += 1
    finally:
        # Restore the caller's mode, even if a batch raised.
        decoder.train(was_training)

    if n_batches == 0:
        raise ValueError("loader yielded no batches; cannot average the loss")
    return total / n_batches

def compute_covariance(msa, q):
    """
    Covariance (connected two-point correlation) matrix of an integer-encoded MSA.

    Args:
        msa: integer (int64) tensor of shape (M, N) with entries in [0, q).
        q: number of distinct symbols in the encoding.

    Returns:
        float32 tensor of shape (N*(q-1), N*(q-1)) holding f_ij(a, b) - f_i(a) f_j(b).
        The last symbol (index q-1) is dropped at every position. The q one-hot
        indicators of a position sum to one, so that column is redundant.
    """
    M, N = msa.shape
    # (M, N, q) one-hot -> drop the last symbol -> (M, N*(q-1)) float data matrix.
    data = one_hot(msa, num_classes=q)[:, :, : q - 1].reshape(M, N * (q - 1)).to(torch.float32)

    # Pair frequencies f_ij(a, b).
    pair_freqs = data.T @ data / M

    # For 0/1 data the diagonal of the pair-frequency matrix is the single-site frequency f_i(a).
    site_freqs = torch.diagonal(pair_freqs)
    return pair_freqs - torch.outer(site_freqs, site_freqs)

In [35]:
# Configuration: compute device and data locations, all derived from a single data root.
device = 'cpu'
data_root = '/media/luchinoprince/b1715ef3-045d-4bdf-b216-c211472fb5a2/Data/InverseFolding'
mutational_dir = os.path.join(data_root, 'Mutational_Data')
structure_folder = os.path.join(mutational_dir, 'alphafold_results_wildtype')
msas_folder = structure_folder  # same directory as the AlphaFold structures
folder_fasta = os.path.join(mutational_dir, 'alignments')

protein_original_DMS = 'YAP1_HUMAN_1_b0.5.a2m.wildtype.fasta'
structure_name = 'YAP1_HUMAN_1_b0.5.a2m_unrelaxed_rank_1_model_5.pdb'
native_path = os.path.join(folder_fasta, protein_original_DMS)
structure_path = os.path.join(structure_folder, structure_name)

# Wild-type sequence and the AlphaFold backbone (N, CA, C coordinates) of the same protein.
num_seq = read_fasta(native_path, mutated_exp=True)
structure = load_structure(structure_path)
coords, native_seq = extract_coords_from_structure(structure)
coords = torch.tensor(coords).to(device)

In [36]:
# Integer encoding of sequences: the 20 amino acids followed by the gap symbol (index 20).
alphabet = 'ACDEFGHIKLMNPQRSTVWY-'
default_index = alphabet.index('-')
# Any character outside the alphabet (e.g. 'X') silently falls back to the gap index.
aa_index = defaultdict(lambda: default_index, {letter: idx for idx, letter in enumerate(alphabet)})
# Inverse lookup: integer code -> character.
aa_index_inv = {idx: letter for letter, idx in aa_index.items()}

In [5]:
import re

# Reference MSA of YAP1 in a3m format.
fastapath = "/media/luchinoprince/b1715ef3-045d-4bdf-b216-c211472fb5a2/Data/InverseFolding/Mutational_Data/alphafold_results_wildtype/MSAS_new/YAP1_HUMAN_1_b0.5.a2m.a3m"
# Number of match (query) columns; every aligned row must have exactly this length.
ALIGNED_LENGTH = 36


def read_a3m_rows(path):
    """Return the aligned rows of an a3m file, with insertion columns removed.

    Blank lines, header lines ('>') and comment lines ('#') are skipped, and a
    sequence split over several lines is joined back together. In a3m, lowercase
    letters and '.' mark insertions relative to the query, so they are stripped:
    each returned row contains only match columns (upper case) and '-'.
    """
    rows, chunks = [], []
    with open(path, mode="r") as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith(">"):
                if chunks:
                    rows.append("".join(chunks))
                    chunks = []
                continue
            chunks.append(line)
    if chunks:
        rows.append("".join(chunks))
    return [re.sub(r"[a-z.]", "", row) for row in rows]


msa_rows = []
n_dropped = 0
for row in read_a3m_rows(fastapath):
    if len(row) == ALIGNED_LENGTH:
        msa_rows.append([aa_index[char] for char in row])
    else:
        n_dropped += 1
if n_dropped:
    print(f"failure: dropped {n_dropped} rows whose aligned length is not {ALIGNED_LENGTH}")

msa_true = torch.tensor(msa_rows)
# q is the size of the integer alphabet used by aa_index (20 amino acids + gap).
cov_true = compute_covariance(msa_true, q=len(alphabet))


In [6]:
# pyhmmer version used for this analysis (shown as the cell output).
pyhmmer.__version__

'0.8.1'

In [7]:
# Easel amino-acid alphabet, used to digitize both the reference MSA and the ESM samples.
alphabet_hmm = pyhmmer.easel.Alphabet.amino()
alphabet_hmm

pyhmmer.easel.Alphabet.amino()

In [8]:
# Read the same reference a3m file with pyhmmer, in digital mode, so it can be given to the HMM builder.
with pyhmmer.easel.MSAFile(fastapath, digital=True, alphabet=alphabet_hmm) as msa_file:
    msa = msa_file.read()

In [9]:
# Number of sequences in the reference MSA.
len(msa.sequences)

14484

In [10]:
# Name the MSA: pyhmmer's Builder.build_msa requires a named alignment.
msa.name = b"YAP-HUMAN"

In [11]:
# Look up Builder's keyword arguments and defaults (symfrac is overridden in the next cell).
?pyhmmer.plan7.Builder

Init signature:
pyhmmer.plan7.Builder(
    self,
    alphabet,
    *,
    architecture='fast',
    weighting='pb',
    effective_number='entropy',
    prior_scheme='alpha',
    symfrac=0.5,
    fragthresh=0.5,
    wid=0.62,
    esigma=45.0,
    eid=0.62

In [12]:
# symfrac=0.0 lowers the residue-occupancy threshold for match columns to zero, so every
# column of the reference MSA becomes a match state (the consensus below has length 36).
builder = pyhmmer.plan7.Builder(alphabet_hmm, symfrac=0.0)
background = pyhmmer.plan7.Background(alphabet_hmm)
# build_msa returns (hmm, profile, optimized_profile); only the HMM is needed for hmmalign.
hmm, _, _ = builder.build_msa(msa, background)

In [13]:
# Consensus sequence of the HMM, one character per match state.
hmm.consensus

'daPLPpGWeeavdpdGrvYYyNheTgettWedPreA'

In [37]:
# Load the ESM-IF1 inverse-folding model and encode the native backbone with it.
model, alphabet_esm = pretrained.esm_if1_gvp4_t16_142M_UR50()
model = model.eval().to(device)
rep = esm.inverse_folding.util.get_encoder_output(model, alphabet_esm, coords)

  F.pad(torch.tensor(cd), (0, 0, 0, 0, 1, 1), value=np.inf)


In [40]:
# Check which esm checkout the sampling code comes from (expected: the local esm_folder put on sys.path).
model.sample.__globals__['__file__']

'/home/luchinoprince/Dropbox/Old_OneDrive/Phd/Second_year/research/Feinauer/esm/esm/inverse_folding/gvp_transformer.py'

In [15]:
N_SAMPLES = 50
# Seed torch so the samples (and the alignment below) are reproducible.
# NOTE(review): assumes model.sample draws from torch's RNG — TODO confirm.
SEED = 0
torch.manual_seed(SEED)

samples_esm = []  # integer-encoded samples (aa_index)
samples_str = []  # raw sampled sequences
samples_hmm = []  # digitized pyhmmer sequences, the input to hmmalign
for attempt in range(N_SAMPLES):
    print(f"We are at sample {attempt} out of {N_SAMPLES}", end="\r")
    sample = model.sample(coords, temperature=1.0)
    name = f"sample{attempt}".encode('ASCII')
    samples_hmm.append(pyhmmer.easel.TextSequence(name=name, sequence=sample).digitize(alphabet_hmm))
    samples_str.append(sample)
    samples_esm.append([aa_index[char] for char in sample])

We are at sample 49 out of 50

In [98]:
# Unused alternative, kept for reference: build an MSA directly from the samples.
# NOTE(review): samples_hmm holds already-digitized sequences (.digitize is called when
# sampling), while TextMSA expects text sequences — likely why this was abandoned; confirm.
# Superseded by aligning the samples to the HMM with hmmalign in the next cell.
#msa_hmm  = pyhmmer.easel.TextMSA(name=b"msa", sequences=samples_hmm)
#msa_d = msa_hmm.digitize(alphabet_hmm)#

In [24]:
# Align the ESM samples to the profile HMM built from the reference MSA.
# trim=True drops residues falling outside the first/last match state; this is presumably why,
# e.g., the leading 'MKS' of sample0 is missing from its aligned row below.
msa_aligned = pyhmmer.hmmer.hmmalign(hmm, samples_hmm, trim=True)

In [32]:
# Width of an aligned row; it equals the number of HMM match states (36), so no insert columns appear.
len(msa_aligned.alignment[1])

36

In [27]:
# Raw (unaligned) sample0, for comparison with its aligned row below.
samples_str[0]

'MKSMPEGYLAISDNEGNRQYYNTTTDQISIADPRQE'

In [28]:
# Show each sample's name next to its aligned (gapped) row.
for sample_name, aligned_row in zip(msa_aligned.names, msa_aligned.alignment):
    print(sample_name, " ", aligned_row[:48], "...")

b'sample0'   ---MPEGYLAISDNEGNRQYYNTTTDQISIADPRQ- ...
b'sample1'   ERPLPEGYTAVSTAEGKTLFIDNNTKQATGIDPR-- ...
b'sample2'   --ALPKGWKKATTASGKQVYYDSKKATVTSKDPR-- ...
b'sample3'   --PLPDGYVEQYTKHGTKIYFDTETQTVTYTDPREA ...
b'sample4'   ---LPDGYVEITTLRGRLLYFDSSRRKVSLVDPR-- ...
b'sample5'   ---NPIGWIQTNTDDGTVVFYNSERDMVTRSDPR-- ...
b'sample6'   ----PIGWVEESDEEGVQFYWNTVQNTRSHEDPR-- ...
b'sample7'   ---LPAGWVAVKNDSGETFFFDSKTNTQSWEDPRQ- ...
b'sample8'   ---MPEGWRAHDNGNGTKFYFDGNNNTSSWFDPR-- ...
b'sample9'   ----PFGWTVVYTKTGKSLYVDKNQNTISGVDPR-- ...
b'sample10'   ---MPAGWLRLFTDQGDQIYFDMNTKTTTWQDPRQ- ...
b'sample11'   -TPLPEGYVEIYDGAGRKHYFDDNTKTATKDDPRD- ...
b'sample12'   --ELPDGFYQWHNSEGETWYYDTTTETSTKEDPR-- ...
b'sample13'   ----PPGWVDRVAPTGEKFFYDSRWGRETWTDPRQ- ...
b'sample14'   ----PFGWTEIYTDTGTLLYYNGVTHKASSVDPR-- ...
b'sample15'   ----PCGYQSRKSSSGQRFYYDANTQTSTWIDPRD- ...
b'sample16'   ----PTGWRILHTADGTAVYFDQSAFIVSRDDPRQ- ...
b'sample17'   --PPPSGWKRVYDKSGKRHWYNSNTNTTSWYDPRE- ...
b'sample18'   -IPMPA

In [68]:
# Check what TextMSA.alignment returns: the aligned rows as strings, gaps included.
?pyhmmer.easel.TextMSA.alignment

Type:        getset_descriptor
String form: <attribute 'alignment' of 'pyhmmer.easel.TextMSA' objects>
Docstring:  
`tuple` of `str`: A view of the aligned sequences as strings.

This property gives access to the aligned sequences, including gap
characters, so that they can be displayed or processed column by
column.

Examples:
    Use `TextMSA.alignment` to display an alignment in text
    format::

        >>> for name, aligned in zip(luxc.names, luxc.alignment):
        ...     print(name, " ", aligned[:40], "...")
        b'Q9KV99.1'   LANQPLEAILGLINEARKSWSST------------PELDP ...
        b'Q2WLE3.1'   IYSYPSEAMIEIINEYSKILCSD------------RKFLS ...
        b'Q97GS8.1'   VHDIKTEETIDLLDRCAKLWLDDNYSKK--HIETLAQITN ...
        b'Q3WCI9.1'   LLNVPLKEIIDFLVETGERIRDPRNTFMQDCIDRMAGTHV ...
        b'P08639.1'   LNDLNINNIINFLYTTGQRWKSEEYSRRRAYIRSLITYLG ...
        ...

    Use the splat operator (*) in combination with the `zip`
    builtin to iterate over the co

In [29]:
## `.sequences` returns only the residues kept after trimming, without gaps, not the aligned row:
## compare with sample0's gapped row printed above. Use `.alignment` to get the gapped rows.
msa_aligned.sequences[0].sequence

'MPEGYLAISDNEGNRQYYNTTTDQISIADPRQ'

In [21]:
# Raw sample1, to compare with its trimmed aligned row above.
samples_str[1]

'ERPLPEGYTAVSTAEGKTLFIDNNTKQATGIDPRAK'

In [23]:
# Length of sample0's aligned row. It is read from the alignment itself rather than from a string
# copied out of the printed output, which goes stale whenever the samples change.
len(msa_aligned.alignment[0])

36

In [68]:
# Easel symbols: the first 21 ('ACDEFGHIKLMNPQRSTVWY-') match `alphabet` above, followed by degenerate/special symbols.
alphabet_hmm.symbols

'ACDEFGHIKLMNPQRSTVWY-BJZOUX*~'

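#### We worked out how to use the library. The problem is that accessing `msa_aligned.sequences` does not give back the gaps, and we also need them because we use them to construct the PCA. The gapped rows are, however, available through `msa_aligned.alignment`, as printed above.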