In [None]:
!pip install git+https://github.com/Profluent-AI/E1.git

In [None]:
import torch

if torch.cuda.is_available():
    cuda_capabilities = torch.cuda.get_device_capability(0)
    if cuda_capabilities[0] >= 8:
        print("CUDA 8.0 or higher detected; installing flash-attention")
        !pip install flash-attn --no-build-isolation
    else:
        print("CUDA capability lower than 8.0; will not be using flash attention")
else:
    print("CUDA not available")

### Scoring substitution variants of a protein

In this notebook, we will use a DMS assay (ID: AMIE_PSEAE_Wrenbeck_2017) from ProteinGym (https://proteingym.org/) containing substitution variants of a protein (UniProt entry https://www.uniprot.org/uniprotkb/P11436/entry) to demonstrate zero-shot fitness prediction. We will compute the score of each variant with the masked marginal method, using the `E1` model directly in both single-sequence and retrieval-augmented mode, and measure the correlation with experimental fitness values. For an explanation of the masked marginal method, please refer to this [paper](https://proceedings.neurips.cc/paper/2021/file/f51338d736f95dd42427296047067694-Supplemental.pdf). In short, we replace each mutated position with a mask token and compute the log probabilities of the amino acids at that position. The score of a single-substitution mutant is then the difference between the log probability of the mutant residue and that of the wildtype residue. The score for multiple substitutions is computed as the sum of the scores of the individual single substitutions.

In [None]:
import torch
import numpy as np

from E1.batch_preparer import E1BatchPreparer
from E1.modeling import E1ForMaskedLM

# Run on the current CUDA device when one is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda", torch.cuda.current_device())
else:
    device = torch.device("cpu")

# Load the 300M-parameter E1 checkpoint in inference mode, plus the helper that turns
# raw sequence strings into model inputs.
model = E1ForMaskedLM.from_pretrained("Profluent-Bio/E1-300m")
model = model.to(device).eval()
batch_preparer = E1BatchPreparer()

### Working in Single Sequence Mode

In this section, we will use the model in single-sequence mode (i.e., we will not pass any homolog sequences as context) to compute the scores.

In [None]:
import polars as pl
wildtype_sequence = (
    "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETE"
    "IFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNY"
    "PEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQL"
    "SLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA"
)

# Read the DMS assay data
mutated_sequences = pl.read_csv("https://huggingface.co/datasets/Profluent-Bio/AMIE_PSEAE_Wrenbeck_2017_example/resolve/main/AMIE_PSEAE_Wrenbeck_2017.csv")
print(mutated_sequences.head())

# Each `mutant` code is parsed below as <wildtype aa><1-indexed position><mutant aa> (e.g. "M1A").
# The scoring cells only handle single substitutions, so verify that assumption up front, and
# verify that every code agrees with `wildtype_sequence`. The scoring code takes the wildtype
# residue from the mutant string itself, so an off-by-one or sequence mismatch would otherwise
# go unnoticed.
for mutant in mutated_sequences["mutant"]:
    assert ":" not in mutant, f"{mutant}: multi-substitution variants are not supported here"
    position = int(mutant[1:-1]) - 1
    assert 0 <= position < len(wildtype_sequence), f"{mutant}: position lies outside the wildtype sequence"
    assert wildtype_sequence[position] == mutant[0], f"{mutant}: wildtype residue at this position is {wildtype_sequence[position]}"

# Figure out the positions where substitutions are made in the wildtype sequence (0-indexed)
mutated_positions = sorted({int(mutant[1:-1]) - 1 for mutant in mutated_sequences["mutant"]})
len(mutated_positions)

In [None]:
# One query per mutated position: the wildtype sequence with that single residue swapped for
# the "?" mask token.
masked_sequences = [
    "".join((wildtype_sequence[:position], "?", wildtype_sequence[position + 1 :]))
    for position in mutated_positions
]

In [None]:
# Tokenize the first few masked sequences and decode them back (keeping special tokens) to
# show exactly what the model receives in its forward pass.
batch = batch_preparer.get_batch_kwargs(masked_sequences[:4], device=device)
input_ids = batch["input_ids"]

num_to_show = min(4, input_ids.shape[0])
for token_ids in input_ids[:num_to_show]:
    print(batch_preparer.tokenizer.decode(token_ids.tolist(), skip_special_tokens=False))

Since we have a large number of masked sequences, we will process them in batches of 4 so as not to overwhelm the GPU memory.
For each sequence, we will extract the log-probabilities over the vocabulary at every residue position of the masked sequence.

In [None]:
from tqdm import tqdm
# Number of masked sequences per forward pass; lower it if you run out of GPU memory.
BATCH_SIZE = 4

# One (L_query, V) array of log-probabilities per masked sequence, in `masked_sequences` order.
per_sequence_log_probs = []

for batch_start in tqdm(range(0, len(masked_sequences), BATCH_SIZE)):
    batch = batch_preparer.get_batch_kwargs(masked_sequences[batch_start : batch_start + BATCH_SIZE], device=device)
    input_ids = batch["input_ids"]
    with torch.no_grad():
        with torch.autocast(device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"):
            outputs = model(
                input_ids=input_ids,
                within_seq_position_ids=batch["within_seq_position_ids"],
                global_position_ids=batch["global_position_ids"],
                sequence_ids=batch["sequence_ids"],
                past_key_values=None,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
            )

    # Upcast before the softmax. Under autocast the logits may come back as bfloat16, which loses
    # precision in the log-probabilities and cannot be converted by `.numpy()`.
    log_probs = torch.log_softmax(outputs.logits.float(), dim=-1)  # (B, L, V)

    # Boolean Selectors of shape (B, L) to get relevant tokens from logits/embeddings
    # last_sequence_selector: True for tokens that are part of the last sequence (including boundary tokens) in case of multi-sequence input.
    last_sequence_selector = batch["sequence_ids"] == batch["sequence_ids"].max(dim=1)[0][:, None]
    # residue_selector: True for tokens that are part of the input sequence i.e not boundary tokens like 1, 2, <bos>, <eos>, <pad>, etc.
    residue_selector = ~(batch_preparer.get_boundary_token_mask(input_ids))
    # last_sequence_residue_selector: True for residues that are part of the last sequence (excluding boundary tokens)
    last_sequence_residue_selector = last_sequence_selector & residue_selector

    per_sequence_log_probs.extend(
        log_probs[i, last_sequence_residue_selector[i]].cpu().numpy() for i in range(input_ids.shape[0])
    )

# Every query has the same length (only one residue is masked), so the rows stack into (N, L_query, V).
log_probs_for_all = np.stack(per_sequence_log_probs, axis=0)
log_probs_for_all.shape

In [None]:
from scipy.stats import spearmanr
from tqdm import tqdm

vocab = batch_preparer.tokenizer.get_vocab()

# Map each mutated position to its row in `log_probs_for_all` once. Calling `list.index`
# (a linear scan) for every variant would make scoring quadratic.
row_for_position = {position: row_idx for row_idx, position in enumerate(mutated_positions)}

predicted_scores, fitness_scores = [], []
for variant in mutated_sequences.iter_rows(named=True):
    mutant = variant["mutant"]
    # Mutant codes are parsed as <wildtype aa><1-indexed position><mutant aa>.
    wildtype_aa, mutated_position, mutant_aa = mutant[0], int(mutant[1:-1]) - 1, mutant[-1]
    position_log_probs = log_probs_for_all[row_for_position[mutated_position]]  # (L_query, V)
    # Masked-marginal score: log p(mutant aa) - log p(wildtype aa) at the masked position.
    mutant_score = (position_log_probs[mutated_position, vocab[mutant_aa]] - position_log_probs[mutated_position, vocab[wildtype_aa]]).item()
    predicted_scores.append(mutant_score)
    fitness_scores.append(variant["DMS_score"])

print(spearmanr(predicted_scores, fitness_scores))

### Working in Retrieval Augmented Mode

In this section, we will use an MSA to sample homolog sequences for the wildtype sequence used above and pass them to the model as part of the context. The context returned by the `sample_context` function is simply a string of homolog protein sequences separated by commas. It is concatenated with each of our masked sequences from above and passed to the model as part of the input.

We use a PoET-style strategy to sample the homologs from the MSA. Homologs are sampled with weights inversely proportional to the number of their neighbors (sequences in the MSA that are at least 80% identical to them), and they are additionally constrained to satisfy a specified maximum similarity to the wildtype sequence.

However, the PoET-style strategy is not required. You can use any other strategy to sample the homologs (for example, experimentally derived high-fitness homologs).

Note: You will very likely encounter a CUDA OOM error if using a T4 GPU on Colab. We recommend using an A100 or L40 GPU when working in retrieval-augmented mode.

In [None]:
!wget https://huggingface.co/datasets/Profluent-Bio/AMIE_PSEAE_Wrenbeck_2017_example/resolve/main/msa.a3m

In [None]:
from E1.msa_sampling import sample_context

# Homolog-sampling settings, gathered in one mapping so they are easy to inspect and tweak.
context_sampling_options = dict(
    msa_path="msa.a3m",
    max_num_samples=511,  # maximum number of sequences in the context (hard limit of 511)
    max_token_length=14784,  # total number of residues allowed in the context
    max_query_similarity=0.95,  # upper bound on any context sequence's similarity to the query
    min_query_similarity=0.0,  # lower bound on any context sequence's similarity to the query
    neighbor_similarity_lower_bound=0.8,  # two context sequences at least this similar count as neighbors
    seed=0,  # random seed passed to the sampler
)
context, _ = sample_context(**context_sampling_options)

In [None]:
# Prepend the comma-separated homolog context to every single-position masked query,
# so the masked marginal scores below are computed with retrieval augmentation.
masked_sequences = [
    f"{context},{wildtype_sequence[:position]}?{wildtype_sequence[position + 1 :]}"
    for position in mutated_positions
]

In [None]:
# Here we show an example of how the input sequences are converted to relevant information for the model's forward pass.
# Only the first input is prepared, because only it is printed and every input now carries the whole homolog context.
batch = batch_preparer.get_batch_kwargs(masked_sequences[:1], device=device)
input_ids = batch["input_ids"]

for token_ids in input_ids[:1]:
    print(batch_preparer.tokenizer.decode(token_ids.tolist(), skip_special_tokens=False))

In [None]:
from tqdm import tqdm
# Number of masked sequences per forward pass; each input carries the full homolog context,
# so lower this if you run out of GPU memory.
BATCH_SIZE = 4

# One (L_query, V) array of log-probabilities per masked query, in `masked_sequences` order.
per_sequence_log_probs = []

for batch_start in tqdm(range(0, len(masked_sequences), BATCH_SIZE)):
    batch = batch_preparer.get_batch_kwargs(masked_sequences[batch_start : batch_start + BATCH_SIZE], device=device)
    input_ids = batch["input_ids"]
    with torch.no_grad():
        with torch.autocast(device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"):
            outputs = model(
                input_ids=input_ids,
                within_seq_position_ids=batch["within_seq_position_ids"],
                global_position_ids=batch["global_position_ids"],
                sequence_ids=batch["sequence_ids"],
                past_key_values=None,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False,
            )

    # Upcast before the softmax. Under autocast the logits may come back as bfloat16, which loses
    # precision in the log-probabilities and cannot be converted by `.numpy()`.
    log_probs = torch.log_softmax(outputs.logits.float(), dim=-1)  # (B, L, V)

    # Boolean Selectors of shape (B, L) to get relevant tokens from logits/embeddings
    # last_sequence_selector: True for tokens that are part of the last sequence (including boundary tokens) in case of multi-sequence input.
    last_sequence_selector = batch["sequence_ids"] == batch["sequence_ids"].max(dim=1)[0][:, None]
    # residue_selector: True for tokens that are part of the input sequence i.e not boundary tokens like 1, 2, <bos>, <eos>, <pad>, etc.
    residue_selector = ~(batch_preparer.get_boundary_token_mask(input_ids))
    # last_sequence_residue_selector: True for residues of the query (the last sequence), so context positions are dropped
    last_sequence_residue_selector = last_sequence_selector & residue_selector

    per_sequence_log_probs.extend(
        log_probs[i, last_sequence_residue_selector[i]].cpu().numpy() for i in range(input_ids.shape[0])
    )

# Every query has the same length (only one residue is masked), so the rows stack into (N, L_query, V).
log_probs_for_all = np.stack(per_sequence_log_probs, axis=0)
log_probs_for_all.shape

In [None]:
from scipy.stats import spearmanr

vocab = batch_preparer.tokenizer.get_vocab()

# Map each mutated position to its row in `log_probs_for_all` once. Calling `list.index`
# (a linear scan) for every variant would make scoring quadratic.
row_for_position = {position: row_idx for row_idx, position in enumerate(mutated_positions)}

predicted_scores, fitness_scores = [], []
for variant in mutated_sequences.iter_rows(named=True):
    mutant = variant["mutant"]
    # Mutant codes are parsed as <wildtype aa><1-indexed position><mutant aa>.
    wildtype_aa, mutated_position, mutant_aa = mutant[0], int(mutant[1:-1]) - 1, mutant[-1]
    position_log_probs = log_probs_for_all[row_for_position[mutated_position]]  # (L_query, V)
    # Masked-marginal score: log p(mutant aa) - log p(wildtype aa) at the masked position.
    mutant_score = (position_log_probs[mutated_position, vocab[mutant_aa]] - position_log_probs[mutated_position, vocab[wildtype_aa]]).item()
    predicted_scores.append(mutant_score)
    fitness_scores.append(variant["DMS_score"])

print(spearmanr(predicted_scores, fitness_scores))