In [None]:
"""Adapted from: https://github.com/facebookresearch/esm/blob/main/examples/contact_prediction.ipynb"""

## Imports

In [None]:
# standard lib
from typing import Tuple, Optional, Dict, Union, Callable
import string
from pathlib import Path

# scientific libs
import numpy as np
import pandas as pd
from scipy.spatial.distance import squareform, pdist
from Bio import SeqIO
import biotite.structure as bs
from biotite.structure.info import residue
from biotite.structure.io.pdbx import PDBxFile, get_structure
from biotite.database import rcsb
from biotite.sequence.io.fasta import convert


# DL libs
import torch
import torch.nn.functional as F
import esm

# graph libs
import matplotlib.pyplot as plt
import matplotlib as mpl
from tqdm import tqdm

# Inference only: turn off autograd globally so no gradient graphs are built.
torch.set_grad_enabled(False)

## Define Functions

### Parsing alignments

In [None]:
# Translation table for ``str.translate`` (via ``str.maketrans``) that deletes every
# lowercase letter as well as "." and "*" from a string.
deletekeys = {**dict.fromkeys(string.ascii_lowercase), ".": None, "*": None}

def read_sequence(filename: str) -> Tuple[str, str]:
    """Read the first (reference) record of a FASTA file.

    Returns:
        ``(description, sequence)`` of that first record.
    """
    records = SeqIO.parse(filename, "fasta")
    first = next(records)
    return first.description, str(first.seq)

### Converting structures to contacts

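There are many ways to define a protein contact. Here two residues are in contact when their carbon beta (Cβ) atoms are less than 15 Å apart. This is a looser cutoff than the commonly used 8 Å, chosen to accommodate flexible/intrinsically disordered proteins. Note that the Cβ position is imputed from the positions of the N, CA, and C atoms of each residue, so every residue (including glycine, which has no Cβ) gets one.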

In [None]:
def extend(a, b, c, L, A, D):
    """Place a fourth point ``d`` from three reference points and internal coordinates.

    The result satisfies ``|d - c| == L`` and the angle between ``b - c`` and
    ``d - c`` is ``A`` (radians). ``D`` (radians) sets the rotation of ``d`` about the
    b–c axis, i.e. the a-b-c-d dihedral (up to sign convention).
    Works on single points of shape (3,) or stacks of shape (..., 3).
    """
    def _unit(v):
        return v / np.linalg.norm(v, ord=2, axis=-1, keepdims=True)

    u_bc = _unit(b - c)
    u_n = _unit(np.cross(b - a, u_bc))
    # Orthonormal local frame: along b-c, in the a-b-c plane, and its normal.
    basis = (u_bc, np.cross(u_n, u_bc), u_n)
    coeffs = (
        L * np.cos(A),
        L * np.sin(A) * np.cos(D),
        -L * np.sin(A) * np.sin(D),
    )
    offset = sum(vec * k for vec, k in zip(basis, coeffs))
    return c + offset


def contacts_from_pdb(
    structure: bs.AtomArray,
    # Increase distances for flexible IDPs
    distance_threshold: float = 15.0,
    chain: Optional[str] = None,
) -> np.ndarray:
    """Binary Cβ–Cβ contact map of the non-hetero residues of ``structure``.

    Cβ coordinates are imputed from each residue's N, CA and C atoms via ``extend``,
    so residues without a modelled Cβ (e.g. glycine) are included too.

    Args:
        structure: atom array to read coordinates from; hetero atoms are ignored.
        distance_threshold: pairs whose Cβ–Cβ distance is strictly below this value
            are contacts.
        chain: if given, only atoms with this chain ID are used.

    Returns:
        (n, n) int64 array: 1 = contact, 0 = no contact, -1 = undefined (NaN) distance.
        NOTE(review): assumes every selected residue has exactly one N, CA and C atom;
        missing backbone atoms would misalign the three coordinate arrays.
    """
    selected = ~structure.hetero
    if chain is not None:
        selected = selected & (structure.chain_id == chain)

    def backbone(atom_name: str) -> np.ndarray:
        return structure.coord[selected & (structure.atom_name == atom_name)]

    c_beta = extend(backbone("C"), backbone("N"), backbone("CA"), 1.522, 1.927, -2.143)
    distances = squareform(pdist(c_beta))
    contact_map = (distances < distance_threshold).astype(np.int64)
    contact_map[np.isnan(distances)] = -1
    return contact_map

### Compute contact precisions

In [None]:
def compute_precisions(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    src_lengths: Optional[torch.Tensor] = None,
    minsep: int = 12, # increased to only consider medium-long range contacts
    maxsep: Optional[int] = None,
    override_length: Optional[int] = None,  # for CASP
):
    """Top-k precision of predicted contacts against a target contact map.

    Pairs (i, j) with ``j - i >= minsep`` (and ``j - i < maxsep`` if given) and a
    non-negative target are ranked by predicted score. Precision is then measured on
    the top ``0.1 * L``, ``0.2 * L``, ..., ``L`` predictions, where ``L`` is the
    reference length.

    Args:
        predictions: (L, L) or (B, L, L) predicted scores (tensor or ndarray).
        targets: same shape as ``predictions``; 1 = contact, 0 = no contact,
            negative = unknown (ignored).
        src_lengths: optional (B,) valid lengths used to mask padded positions.
        minsep: minimum sequence separation ``j - i`` (inclusive).
        maxsep: maximum sequence separation (exclusive); None = unbounded.
        override_length: reference length ``L`` for the precision bins. If None, it
            defaults to the number of non-negative entries in the first row of the
            first target map.

    Returns:
        dict with "AUC" (mean over the 10 bins), "P@L", "P@L2" and "P@L5",
        each a (B,) tensor.

    Raises:
        ValueError: if ``predictions`` and ``targets`` have different shapes.
    """
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    if predictions.dim() == 2:
        predictions = predictions.unsqueeze(0)
    if targets.dim() == 2:
        targets = targets.unsqueeze(0)
    # Bool/int scores cannot hold the -inf used for masking below.
    if not torch.is_floating_point(predictions):
        predictions = predictions.float()
    if override_length is None:
        # Previously this value unconditionally replaced the caller's argument.
        override_length = int((targets[0, 0] >= 0).sum())

    if predictions.size() != targets.size():
        raise ValueError(
            f"Size mismatch. Received predictions of size {predictions.size()}, "
            f"targets of size {targets.size()}"
        )
    device = predictions.device

    batch_size, seqlen, _ = predictions.size()
    seqlen_range = torch.arange(seqlen, device=device)

    # sep[0, i, j] == j - i
    sep = (seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)).unsqueeze(0)
    valid_mask = (sep >= minsep) & (targets >= 0)  # negative targets are unknown
    if maxsep is not None:
        valid_mask &= sep < maxsep
    if src_lengths is not None:
        valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
        valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)

    predictions = predictions.masked_fill(~valid_mask, float("-inf"))

    # Upper triangle starting `minsep` above the diagonal.
    x_ind, y_ind = np.triu_indices(seqlen, minsep)
    predictions_upper = predictions[:, x_ind, y_ind]
    # Invalid pairs must never count as hits, even if they reach the top-k because
    # there are fewer valid pairs than k.
    targets_upper = targets[:, x_ind, y_ind].masked_fill(
        ~valid_mask[:, x_ind, y_ind], 0
    )

    # Only ranks below `override_length` are read below, so ranking more than
    # max(L, override_length) predictions cannot change any metric.
    topk = max(seqlen, override_length)
    indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
    topk_targets = targets_upper[
        torch.arange(batch_size, device=device).unsqueeze(1), indices
    ]
    if topk_targets.size(1) < topk:
        topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])

    cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)

    gather_lengths = torch.full(
        (batch_size, 1), override_length, dtype=torch.long, device=device
    )
    # Bin k (k = 1..10) reads the cumulative hit count at rank floor(0.1 * k * L).
    # Clamp so that reference lengths below 10 read rank 1 instead of index -1
    # (which made `gather` raise).
    gather_indices = (
        (torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths)
        .type(torch.long)
        .sub(1)
        .clamp(min=0)
    )

    binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
    binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
        binned_cumulative_dist
    )

    pl5 = binned_precisions[:, 1]
    pl2 = binned_precisions[:, 4]
    pl = binned_precisions[:, 9]
    auc = binned_precisions.mean(-1)

    return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5}


def evaluate_prediction(
    predictions: torch.Tensor,
    targets: torch.Tensor,
) -> Dict[str, float]:
    """Contact precision metrics per sequence-separation range.

    Ranges are M = [12, 24), L = [24, 48), XL = [48, 96) and XXL = [96, inf).
    Keys are ``"<range>_<metric>"`` (e.g. ``"M_P@L"``) with float values from
    ``compute_precisions``. ``predictions`` must be a tensor; ``targets`` may be a
    tensor or ndarray and is moved to the predictions' device.
    """
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    targets = targets.to(predictions.device)

    # (label, minsep inclusive, maxsep exclusive / None = unbounded)
    separation_bins = (
        ("M", 12, 24),
        ("L", 24, 48),
        ("XL", 48, 96),
        ("XXL", 96, None),
    )
    return {
        f"{label}_{metric}": value.item()
        for label, lo, hi in separation_bins
        for metric, value in compute_precisions(
            predictions, targets, minsep=lo, maxsep=hi
        ).items()
    }

### Getting contact pairs from prediction

In [None]:
def get_contact_pairs_from_prediction(prediction, sequence):
    """Label every nonzero entry (i, j) of ``prediction`` with residue names.

    Returns a list of ``(label_i, label_j)`` tuples in ``np.where`` order, where a
    label is the capitalized three-letter code plus the 1-based position,
    e.g. ``"Glu 12"``.
    """
    def label(idx):
        three_letter = convert.ProteinSequence.convert_letter_1to3(sequence[idx])
        return f"{three_letter.capitalize()} {idx+1}"

    rows, cols = np.where(prediction)
    return [(label(i), label(j)) for i, j in zip(rows, cols)]

def filter_pairs(contact_pairs, residue_1, residue_2):
    """Select contact pairs by residue-name prefix.

    Keeps, in order, every pair whose first label starts with ``residue_1`` and
    whose second label starts with ``residue_2``. This is a prefix match, so ``"Gl"``
    matches "Gln", "Glu" and "Gly". Orientation matters: ``("Leu 5", "Glu 1")`` is
    not kept for ``residue_1="Glu", residue_2="Leu"``.
    """
    def matches(pair):
        return pair[0].startswith(residue_1) and pair[1].startswith(residue_2)

    return list(filter(matches, contact_pairs))

### Plotting Results

In [None]:
def plot_contacts_and_predictions(
    predictions: Union[torch.Tensor, np.ndarray],
    contacts: Union[torch.Tensor, np.ndarray],
    ax: Optional[mpl.axes.Axes] = None,
    cmap: str = "Blues",
    ms: float = 1,
    title: Union[bool, str, Callable[[float], str]] = True,
    animated: bool = False,
    prediction_threshold: float = 0.5,
) -> None:
    """Plot a predicted contact map with true/false-positive markers.

    The image shows ``predictions`` with the region ``j - i < 6`` masked out. On top
    of it, the top-L predictions (L = sequence length) that also exceed
    ``prediction_threshold`` are compared with ``contacts``:
    true positives in midnight blue, false positives in red, and true contacts that
    were not predicted in light grey. Pairs with a negative (unknown) contact value
    are never marked.

    Args:
        predictions: (L, L) contact scores or a boolean contact map.
        contacts: (L, L) true contacts; 1 = contact, 0 = none, negative = unknown.
        ax: axes to draw on; defaults to ``plt.gca()``.
        cmap: colormap for the prediction image.
        ms: marker size of the overlay points.
        title: True for an automatic long-range (minsep=12) P@L title, a string for
            a fixed title, a callable that receives the P@L value, or False for none.
        animated: passed through to ``ax.imshow``.
        prediction_threshold: minimum score for a top-L prediction to be marked.
    """
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.detach().cpu().numpy()
    if isinstance(contacts, torch.Tensor):
        contacts = contacts.detach().cpu().numpy()
    if ax is None:
        ax = plt.gca()

    seqlen = contacts.shape[0]
    # relative_distance[i, j] == j - i
    relative_distance = np.add.outer(-np.arange(seqlen), np.arange(seqlen))
    bottom_mask = relative_distance < 6
    masked_image = np.ma.masked_where(bottom_mask, predictions)

    # Work on a float copy: a boolean array cannot hold -inf (it silently becomes
    # True), and the scores must stay continuous until after the top-L ranking.
    scores = predictions.astype(np.float64)
    invalid_mask = np.abs(np.add.outer(np.arange(seqlen), -np.arange(seqlen))) < 6
    scores[invalid_mask] = float("-inf")

    # Top-L predictions, restricted to those that can reasonably be called contacts.
    topl_val = np.sort(scores.reshape(-1))[-seqlen]
    pred_contacts = (scores >= topl_val) & (scores > prediction_threshold)

    upper = ~bottom_mask
    is_contact = contacts > 0
    is_known = contacts >= 0  # contacts_from_pdb marks undefined pairs with -1
    true_positives = is_contact & pred_contacts & upper
    false_positives = is_known & ~is_contact & pred_contacts & upper
    other_contacts = is_contact & ~pred_contacts & upper

    if isinstance(title, str):
        title_text: Optional[str] = title
    elif title:
        long_range_pl = compute_precisions(scores, contacts, minsep=12)["P@L"].item()
        if callable(title):
            title_text = title(long_range_pl)
        else:
            title_text = f"Long Range P@L: {100 * long_range_pl:0.1f}"
    else:
        title_text = None

    ax.imshow(masked_image, cmap=cmap, animated=animated)
    ax.plot(*np.where(other_contacts), "o", c="lightgrey", ms=ms)
    ax.plot(*np.where(false_positives), "o", c="red", ms=ms)
    ax.plot(*np.where(true_positives), "o", c="midnightblue", ms=ms)
    if title_text is not None:
        ax.set_title(title_text)

    ax.axis("square")
    ax.set_xlim([0, seqlen])
    ax.set_ylim([0, seqlen])

In [None]:
def get_sequence_from_structure(structure, chain_id):
    """Protein sequence of the non-hetero residues of one chain.

    One residue is emitted per CA atom, in atom order, using the same atom
    selection (non-hetero, matching chain ID) as ``contacts_from_pdb``.
    """
    is_ca = (~structure.hetero) & (structure.chain_id == chain_id) & (structure.atom_name == 'CA')
    residue_names = structure.res_name[is_ca].tolist()
    return convert.ProteinSequence(residue_names)

## Predict and Visualize

### Read Data

In [None]:
# Read in data: fetch each structure from RCSB and derive its sequence and contact map.

PDB_IDS = ["6shx"]  # other sets used: ["4e4w"], ["1a3a", "5ahw", "1xcr"]
chain = 'A'  # only this chain is used for both the sequence and the contacts

structures = {}
sequences = {}
contacts = {}

# For MMR examples
for name in PDB_IDS:
    # First model of the mmCIF entry.
    structures[name] = get_structure(PDBxFile.read(rcsb.fetch(name, "cif")))[0]
    sequences[name] = (name, str(get_sequence_from_structure(structures[name], chain)))
    contacts[name] = contacts_from_pdb(structures[name], chain=chain)
    # The sequence (scored by the model) and the contact map (the target) must
    # describe the same residues, or later comparisons will not line up.
    seq_len = len(sequences[name][1])
    assert contacts[name].shape == (seq_len, seq_len), (
        f"{name}: sequence length {seq_len} != contact map shape {contacts[name].shape}"
    )

### ESM-2 Predictions

In [None]:
# Pretrained ESM-2 (esm2_t33_650M_UR50D) plus the batch converter for its alphabet.
esm2, esm2_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
esm2_batch_converter = esm2_alphabet.get_batch_converter()

# Optional fine-tuned model to evaluate instead of the pretrained weights.
# Previously used: "trained_model_6ep.pth" (V1),
# "trained_model_tristan_esm2_t33_650M_UR50D_v2.pth" (V2). None = pretrained.
ESM2_CHECKPOINT = None
if ESM2_CHECKPOINT is not None:
    # torch.load unpickles a full model object; only load checkpoints you trust.
    esm2 = torch.load(ESM2_CHECKPOINT)

esm2 = esm2.eval()

# Examples

# Model O

In [None]:
esm2_predictions = {}
esm2_results = []

for name, inputs in sequences.items():
    # `inputs` is the (label, sequence) pair the batch converter expects; the original
    # cell referenced an undefined `b` here and failed on a fresh kernel.
    esm2_batch_labels, esm2_batch_strs, esm2_batch_tokens = esm2_batch_converter([inputs])

    esm2_batch_tokens = esm2_batch_tokens.to(next(esm2.parameters()).device)
    esm2_predictions[name] = esm2.predict_contacts(esm2_batch_tokens)[0].cpu()

    metrics = {"id": name, "model": "ESM-2 (Unsupervised)"}
    metrics.update(evaluate_prediction(esm2_predictions[name], contacts[name]))
    esm2_results.append(metrics)

esm2_results = pd.DataFrame(esm2_results)
display(esm2_results)

In [None]:
# One panel per structure; squeeze=False keeps `axes` 2-D even for a single PDB id.
fig, axes = plt.subplots(
    figsize=(6 * len(PDB_IDS), 6), ncols=len(PDB_IDS), squeeze=False
)
for ax, name in zip(axes[0], PDB_IDS):
    prediction = esm2_predictions[name] > 0.5
    target = contacts[name]
    plot_contacts_and_predictions(
        prediction,
        target,
        ax=ax,
        # bind `name` now so the title cannot pick up a later loop value
        title=lambda prec, pdb_id=name: f"{pdb_id}: Long Range P@L: {100 * prec:0.1f}",
    )
plt.show()

In [None]:
# Predict contacts with the currently loaded `esm2` and score each structure.
model_device = next(esm2.parameters()).device
esm2_predictions = {}
esm2_results = []
for name, inputs in sequences.items():
    _labels, _strs, tokens = esm2_batch_converter([inputs])
    esm2_predictions[name] = esm2.predict_contacts(tokens.to(model_device))[0].cpu()
    esm2_results.append({
        "id": name,
        "model": "ESM-2 (Unsupervised)",
        **evaluate_prediction(esm2_predictions[name], contacts[name]),
    })
esm2_results = pd.DataFrame(esm2_results)
display(esm2_results)

In [None]:
# Single structure with a looser 0.3 threshold on the predicted contact scores.
name = PDB_IDS[-1]  # explicit: the structure the preceding prediction loop ended on
fig, axes = plt.subplots(figsize=(7, 5), ncols=1)
prediction = esm2_predictions[name] > 0.3
target = contacts[name]
plot_contacts_and_predictions(
    prediction, target, ax=axes, title=lambda prec: f"{name}: Long Range P@L: {100 * prec:0.1f}"
)
plt.show()

In [None]:
# Label the thresholded contacts from the previous cell using the sequence of the same
# structure (`name`), instead of a hard-coded PDB id that may not match `prediction`.
contact_pairs = get_contact_pairs_from_prediction(prediction, sequences[name][1])

In [None]:
# Display the labelled (residue i, residue j) contact pairs from the previous cell.
contact_pairs

In [None]:
# Keep only pairs whose first residue label starts with "Glu" and second with "Leu".
filter_pairs(contact_pairs, 'Glu', 'Leu')

In [None]:
# Predict contacts with the currently loaded `esm2` and score each structure.
model_device = next(esm2.parameters()).device
esm2_predictions = {}
esm2_results = []
for name, inputs in sequences.items():
    _labels, _strs, tokens = esm2_batch_converter([inputs])
    esm2_predictions[name] = esm2.predict_contacts(tokens.to(model_device))[0].cpu()
    esm2_results.append({
        "id": name,
        "model": "ESM-2 (Unsupervised)",
        **evaluate_prediction(esm2_predictions[name], contacts[name]),
    })
esm2_results = pd.DataFrame(esm2_results)
display(esm2_results)

In [None]:
# Single structure at the default 0.5 threshold on the predicted contact scores.
name = PDB_IDS[-1]  # explicit: the structure the preceding prediction loop ended on
fig, axes = plt.subplots(figsize=(7, 5), ncols=1)
prediction = esm2_predictions[name] > 0.5
target = contacts[name]
plot_contacts_and_predictions(
    prediction, target, ax=axes, title=lambda prec: f"{name}: Long Range P@L: {100 * prec:0.1f}"
)
plt.show()

# Model V1

In [None]:
# Predict contacts with the currently loaded `esm2` and score each structure.
# NOTE(review): no cell under this heading loads a different model, so this evaluates
# whatever `esm2` currently holds and labels it "ESM-2 (Unsupervised)" — confirm intent.
model_device = next(esm2.parameters()).device
esm2_predictions = {}
esm2_results = []
for name, inputs in sequences.items():
    _labels, _strs, tokens = esm2_batch_converter([inputs])
    esm2_predictions[name] = esm2.predict_contacts(tokens.to(model_device))[0].cpu()
    esm2_results.append({
        "id": name,
        "model": "ESM-2 (Unsupervised)",
        **evaluate_prediction(esm2_predictions[name], contacts[name]),
    })
esm2_results = pd.DataFrame(esm2_results)
display(esm2_results)

In [None]:
# One panel per structure; squeeze=False keeps `axes` 2-D even for a single PDB id.
fig, axes = plt.subplots(
    figsize=(6 * len(PDB_IDS), 6), ncols=len(PDB_IDS), squeeze=False
)
for ax, name in zip(axes[0], PDB_IDS):
    prediction = esm2_predictions[name] > 0.5
    target = contacts[name]
    plot_contacts_and_predictions(
        prediction,
        target,
        ax=ax,
        # bind `name` now so the title cannot pick up a later loop value
        title=lambda prec, pdb_id=name: f"{pdb_id}: Long Range P@L: {100 * prec:0.1f}",
    )
plt.show()

In [None]:
# Predict contacts with the currently loaded `esm2` and score each structure.
# NOTE(review): no cell under this heading loads a different model, so this evaluates
# whatever `esm2` currently holds and labels it "ESM-2 (Unsupervised)" — confirm intent.
model_device = next(esm2.parameters()).device
esm2_predictions = {}
esm2_results = []
for name, inputs in sequences.items():
    _labels, _strs, tokens = esm2_batch_converter([inputs])
    esm2_predictions[name] = esm2.predict_contacts(tokens.to(model_device))[0].cpu()
    esm2_results.append({
        "id": name,
        "model": "ESM-2 (Unsupervised)",
        **evaluate_prediction(esm2_predictions[name], contacts[name]),
    })
esm2_results = pd.DataFrame(esm2_results)
display(esm2_results)

In [None]:
# Single structure at the default 0.5 threshold (fixes the `contacts[name]d` syntax error).
name = PDB_IDS[-1]  # explicit: the structure the preceding prediction loop ended on
fig, axes = plt.subplots(figsize=(7, 5), ncols=1)
prediction = esm2_predictions[name] > 0.5
target = contacts[name]
plot_contacts_and_predictions(
    prediction, target, ax=axes, title=lambda prec: f"{name}: Long Range P@L: {100 * prec:0.1f}"
)
plt.show()

In [None]:
# Predict contacts with the currently loaded `esm2` and score each structure.
# NOTE(review): no cell under this heading loads a different model, so this evaluates
# whatever `esm2` currently holds and labels it "ESM-2 (Unsupervised)" — confirm intent.
model_device = next(esm2.parameters()).device
esm2_predictions = {}
esm2_results = []
for name, inputs in sequences.items():
    _labels, _strs, tokens = esm2_batch_converter([inputs])
    esm2_predictions[name] = esm2.predict_contacts(tokens.to(model_device))[0].cpu()
    esm2_results.append({
        "id": name,
        "model": "ESM-2 (Unsupervised)",
        **evaluate_prediction(esm2_predictions[name], contacts[name]),
    })
esm2_results = pd.DataFrame(esm2_results)
display(esm2_results)

# Model V2

In [None]:
# Predict contacts with the currently loaded `esm2` and score each structure.
# NOTE(review): no cell under this heading loads a different model, so this evaluates
# whatever `esm2` currently holds and labels it "ESM-2 (Unsupervised)" — confirm intent.
model_device = next(esm2.parameters()).device
esm2_predictions = {}
esm2_results = []
for name, inputs in sequences.items():
    _labels, _strs, tokens = esm2_batch_converter([inputs])
    esm2_predictions[name] = esm2.predict_contacts(tokens.to(model_device))[0].cpu()
    esm2_results.append({
        "id": name,
        "model": "ESM-2 (Unsupervised)",
        **evaluate_prediction(esm2_predictions[name], contacts[name]),
    })
esm2_results = pd.DataFrame(esm2_results)
display(esm2_results)

In [None]:
# One panel per structure; squeeze=False keeps `axes` 2-D even for a single PDB id.
fig, axes = plt.subplots(
    figsize=(6 * len(PDB_IDS), 6), ncols=len(PDB_IDS), squeeze=False
)
for ax, name in zip(axes[0], PDB_IDS):
    prediction = esm2_predictions[name] > 0.5
    target = contacts[name]
    plot_contacts_and_predictions(
        prediction,
        target,
        ax=ax,
        # bind `name` now so the title cannot pick up a later loop value
        title=lambda prec, pdb_id=name: f"{pdb_id}: Long Range P@L: {100 * prec:0.1f}",
    )
plt.show()