In [None]:
!pip install git+https://github.com/Profluent-AI/E1.git

In [None]:
import torch

if torch.cuda.is_available():
    cuda_capabilities = torch.cuda.get_device_capability(0)
    if cuda_capabilities[0] >= 8:
        print("CUDA 8.0 or higher detected; installing flash-attention")
        !pip install flash-attn --no-build-isolation
    else:
        print("CUDA capability lower than 8.0; will not be using flash attention")
else:
    print("CUDA not available")

### Scoring substitution variants of a protein

In this notebook, we will use a DMS assay (ID: AMIE_PSEAE_Wrenbeck_2017) from ProteinGym (https://proteingym.org/) containing substitution variants of a protein (UniProt entry https://www.uniprot.org/uniprotkb/P11436/entry) to show zero-shot fitness prediction. We will compute the score of each variant with the masked marginal method, using the `E1` model directly in both single-sequence and retrieval-augmented mode, and measure the correlation with experimental fitness values. For an explanation of the masked marginal method, please refer to this [paper](https://proceedings.neurips.cc/paper/2021/file/f51338d736f95dd42427296047067694-Supplemental.pdf). In short, we replace each mutated position with a mask token and compute the log probability of the actual residue at that position. We then compute the score of a single-substitution mutant as the difference in log probability between the mutant and the wildtype. The score for multiple substitutions is computed as the sum of the scores of the individual single substitutions.

While the scoring is similar to that in the basic notebook, here we will use the `E1Scorer` utility to compute scores, which makes things much easier and faster through auto-batching and KV caching.

In [None]:
import polars as pl

# Wildtype (parent) amino-acid sequence, assembled from consecutive chunks.
wildtype_sequence = "".join(
    [
        "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAI",
        "PGEETEIFSRACRKANVWGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKG",
        "MKISLIICDDGNYPEIWRDCAMKGAELIVRCQGYMYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIG",
        "FDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQASGDGDRGLAECPFEFYRTWVTDAEKARE",
        "NVERLTRSTTGVAQCPVGRLPYEGLEKEA",
    ]
)

# Location of the assay table (CSV) with one row per variant.
assay_url = "https://huggingface.co/datasets/Profluent-Bio/AMIE_PSEAE_Wrenbeck_2017_example/resolve/main/AMIE_PSEAE_Wrenbeck_2017.csv"
assay_data = pl.read_csv(assay_url)
print(assay_data.head())

# Variant identifiers and their full sequences, both in table order.
mutated_sequence_ids = assay_data["mutant"].to_list()
mutated_sequences = assay_data["mutated_sequence"].to_list()

In [None]:
import torch

from E1 import dist
from E1.modeling import E1ForMaskedLM
from E1.scorer import E1Scorer, EncoderScoreMethod

In [None]:
# Scoring strategy handed to the scorer. EncoderScoreMethod.WILDTYPE_MARGINAL is also
# available; it is generally faster but less accurate.
scoring_method = EncoderScoreMethod.MASKED_MARGINAL
# Token budget per batch given to the scorer.
max_batch_tokens = 4096
# Identifier of the pretrained checkpoint passed to `from_pretrained`.
model_name = "Profluent-Bio/E1-300m"

In [None]:
# Load the pretrained checkpoint, move it to the device reported by `dist.get_device()`,
# and switch it to eval mode -- one step per line instead of a single chained call.
model = E1ForMaskedLM.from_pretrained(model_name, dtype=torch.float)
model = model.to(dist.get_device())
model = model.eval()

We initialize the scorer with the model, the scoring method (either WILDTYPE_MARGINAL or MASKED_MARGINAL), and a max batch tokens limit (lower this if you encounter CUDA OOM errors). Then, we call the `score` method with the wildtype sequence and the list of mutated sequences. Note that the `score` method returns a list of dictionaries with the score for each mutated sequence, but the list may not be in the same order as the sequences passed in the arguments. Use the `id` field in the returned dictionaries to match each score with the corresponding mutated sequence.

In [None]:
# Wrap the model in a scorer configured with the chosen method and batch token budget.
scorer = E1Scorer(model, max_batch_tokens=max_batch_tokens, method=scoring_method)

# Score every variant against the wildtype (single-sequence mode, no context).
score_records = scorer.score(
    sequences=mutated_sequences,  # variants to score (substitutions only)
    sequence_ids=mutated_sequence_ids,  # one identifier per variant
    parent_sequence=wildtype_sequence,  # wildtype the variants derive from
)
print(score_records[:10])

# Turn the list of per-variant dictionaries into a DataFrame and preview it.
scores = pl.from_dicts(score_records)
scores.head()

In [None]:
from scipy.stats import spearmanr

# Attach the assay columns to the predicted scores by matching the score `id` to the assay
# `mutant` column. The joined table gets its own name instead of overwriting `scores`, so
# re-running this cell always starts from the un-joined scores.
single_sequence_results = scores.join(assay_data, left_on="id", right_on="mutant", how="left")
# Rank correlation between predicted scores and experimental fitness values.
print(spearmanr(single_sequence_results["score"], single_sequence_results["DMS_score"]))

### Scoring using Retrieval Augmented Mode

In this section, we will use an MSA to sample homologous sequences for the wildtype sequence used above and pass them to the model as part of the context.

Note: You will very likely encounter a CUDA OOM error if using a T4 GPU on Colab. We recommend using an A100 or L40 GPU when working in retrieval-augmented mode.

In [None]:
!wget https://huggingface.co/datasets/Profluent-Bio/AMIE_PSEAE_Wrenbeck_2017_example/resolve/main/msa.a3m

We generally find that ensembling scores over multiple context sequences sampled from the same protein family helps improve the performance of the model.

In [None]:
from E1.msa_sampling import ContextSpecification, sample_multiple_contexts

msa_path = "msa.a3m"
# Grid of sampling settings: every token budget is paired with every similarity cap.
context_token_lengths = [7168, 14336, 21504]
max_query_similarities = [1.0, 0.95, 0.9, 0.7, 0.5]

# One specification per (token budget, similarity cap) pair; token budgets vary slowest.
context_specs = [
    ContextSpecification(
        # maximum number of sequences that can be sampled from the MSA (should be <= 511)
        max_num_samples=511,
        # maximum number of concatenated tokens in the context sequences
        max_token_length=token_budget,
        # maximum similarity between the query and the context sequences
        max_query_similarity=similarity_cap,
        # minimum similarity between the query and the context sequences
        min_query_similarity=0.0,
        # minimum similarity between the two sequences for them to be considered neighbors during MSA Sampling
        neighbor_similarity_lower_bound=0.8,
    )
    for token_budget in context_token_lengths
    for similarity_cap in max_query_similarities
]

# Sample one context per specification with a fixed seed; the second return value is unused.
context_seqs, _ = sample_multiple_contexts(msa_path=msa_path, context_specifications=context_specs, seed=0)

# Key each sampled context by its position in the list.
context_seqs_dict = dict((f"context_{idx}", context) for idx, context in enumerate(context_seqs))

In [None]:
# Score the same variants again, this time passing the sampled context sequences.
# The scorer reuses `scoring_method` from the configuration cell (instead of a hard-coded
# method) so that both scoring runs in this notebook follow the same setting.
scorer = E1Scorer(model, method=scoring_method, max_batch_tokens=max_batch_tokens)
context_score_records = scorer.score(
    parent_sequence=wildtype_sequence,  # parent sequence
    sequences=mutated_sequences,  # list of mutated sequences we want to score (substitutions only)
    sequence_ids=mutated_sequence_ids,  # list of sequence ids for each mutated sequence
    context_seqs=context_seqs_dict,  # dictionary of context sequences
    # we ensemble scores over multiple context sequences by taking mean; set to "none" to return
    # scores with respect to each context sequence individually
    context_reduction="mean",
)
print(context_score_records[:10])

# Convert the per-variant dictionaries to a DataFrame and preview it.
scores = pl.from_dicts(context_score_records)
scores.head()

In [None]:
from scipy.stats import spearmanr

# Attach the assay columns to the retrieval-augmented scores by matching the score `id` to
# the assay `mutant` column. The joined table gets its own name instead of overwriting
# `scores`, so re-running this cell always starts from the un-joined scores.
retrieval_results = scores.join(assay_data, left_on="id", right_on="mutant", how="left")
# Rank correlation between predicted scores and experimental fitness values.
print(spearmanr(retrieval_results["score"], retrieval_results["DMS_score"]))