In [None]:
!pip install git+https://github.com/Profluent-AI/E1.git

In [None]:
import torch

if torch.cuda.is_available():
    cuda_capabilities = torch.cuda.get_device_capability(0)
    if cuda_capabilities[0] >= 8:
        print("CUDA 8.0 or higher detected; installing flash-attention")
        !pip install flash-attn --no-build-isolation
    else:
        print("CUDA capability lower than 8.0; will not be using flash attention")
else:
    print("CUDA not available")

### In silico single-site saturation mutagenesis scan

In this notebook, we use the E1 model to score every single-site mutant of the wild-type protein from the AMIE_PSEAE_Wrenbeck_2017 dataset (UniProt entry https://www.uniprot.org/uniprotkb/P11436/entry) and return the score of each mutation, sorted by descending score. Positions in the results are 0-indexed.

In [None]:
import pandas as pd

from E1 import dist
from E1.modeling import E1ForMaskedLM
from E1.msa_sampling import ContextSpecification, sample_multiple_contexts
from E1.scorer import E1Scorer, EncoderScoreMethod


def ssm(
    parent_sequence: str,
    msa_path: str | None = None,
    model_name: str = "Profluent-Bio/E1-300m",
    scoring_method: EncoderScoreMethod = EncoderScoreMethod.MASKED_MARGINAL,
    max_batch_tokens: int = 4096,
    singleseq_mode: bool = False,
    context_token_lengths: tuple[int, ...] = (7168, 15360, 23552),
    max_query_similarities: tuple[float, ...] = (1.0, 0.95, 0.9, 0.7, 0.5),
    seed: int = 42,
) -> pd.DataFrame:
    """Score every single-site substitution of ``parent_sequence`` with an E1 model.

    The model is loaded with ``E1ForMaskedLM.from_pretrained`` in float32, moved to
    ``dist.get_device()`` and put in eval mode. Every position is passed to
    ``E1Scorer.get_position_scores`` in one call.

    Args:
        parent_sequence: Wild-type sequence to mutate. Must be non-empty.
        msa_path: Path to an MSA file from which context sequences are sampled.
            Required unless ``singleseq_mode`` is True; ignored when it is True.
        model_name: Checkpoint name passed to ``E1ForMaskedLM.from_pretrained``.
        scoring_method: Scoring method passed to ``E1Scorer``.
        max_batch_tokens: Batch token budget passed to ``E1Scorer``.
        singleseq_mode: If True, score without any context sequences.
        context_token_lengths: ``max_token_length`` values for the context
            specifications (retrieval mode only). One specification is built for
            every (token length, query similarity) pair.
        max_query_similarities: ``max_query_similarity`` values for the context
            specifications (retrieval mode only).
        seed: Seed passed to ``sample_multiple_contexts`` (retrieval mode only).

    Returns:
        DataFrame with columns ``position`` (0-indexed), ``mutant`` (one of the 20
        standard amino acids ``ACDEFGHIKLMNPQRSTVWY``) and ``score``. There is one row
        per substitution; the wild-type residue at each position is excluded. Rows are
        sorted by descending ``score``, and ties keep (position, amino-acid) order.
        NOTE(review): the score's interpretation depends on ``scoring_method``;
        confirm against the E1Scorer docs.

    Raises:
        ValueError: If ``parent_sequence`` is empty, or if ``msa_path`` is None
            while not in single-sequence mode.
    """
    # Validate before the (expensive) model load. An explicit raise, unlike
    # `assert`, is not stripped when Python runs with -O.
    if not parent_sequence:
        raise ValueError("parent_sequence must be a non-empty string")
    if not singleseq_mode and msa_path is None:
        raise ValueError("MSA path is required for retrieval augmented mode")

    model = E1ForMaskedLM.from_pretrained(model_name, dtype=torch.float).to(dist.get_device()).eval()
    scorer = E1Scorer(model, method=scoring_method, max_batch_tokens=max_batch_tokens)

    if singleseq_mode:
        context_seqs_dict = None
    else:
        # Cartesian product: token length (outer) x max query similarity (inner).
        context_specs = [
            ContextSpecification(
                max_num_samples=512,
                max_token_length=max_num_tokens,
                max_query_similarity=max_query_similarity,
                min_query_similarity=0.0,
                neighbor_similarity_lower_bound=0.8,
            )
            for max_num_tokens in context_token_lengths
            for max_query_similarity in max_query_similarities
        ]
        context_seqs, _ = sample_multiple_contexts(
            msa_path=msa_path, context_specifications=context_specs, seed=seed
        )
        context_seqs_dict = {f"context_{i}": seq for i, seq in enumerate(context_seqs)}

    positions_masked = list(range(len(parent_sequence)))
    position_scores, _ = scorer.get_position_scores(
        parent_sequence=parent_sequence,
        mutation_positions=positions_masked,
        context_seqs=context_seqs_dict,
        # Presumably averages scores over the sampled contexts; see E1Scorer.
        context_reduction="mean",
    )
    # Internal sanity check: one query row, one score per (position, vocab token).
    assert position_scores.shape == (1, len(parent_sequence), len(scorer.vocab))

    aa_string = "ACDEFGHIKLMNPQRSTVWY"
    aa_ids = [scorer.vocab[aa] for aa in aa_string]
    # Keep only the 20 standard amino-acid columns: (len(parent_sequence), 20).
    scores_heatmap = position_scores[0, :, aa_ids].cpu().numpy()

    rows = [
        (i, aa, scores_heatmap[i, j].item())
        for i, wt_aa in enumerate(parent_sequence)
        for j, aa in enumerate(aa_string)
        if aa != wt_aa  # skip the wild-type (non-)mutation
    ]

    # A stable sort makes the ranking of tied scores reproducible.
    return pd.DataFrame(rows, columns=["position", "mutant", "score"]).sort_values(
        by="score", ascending=False, kind="stable"
    )

In [None]:
# Wild-type sequence of the AMIE_PSEAE_Wrenbeck_2017 example protein (UniProt P11436),
# split across several literals only to keep line lengths manageable.
parent_sequence = "".join(
    [
        "MRHGDISSSNDTVGVAVVNYKMPRLHTAAEVLDNARKIAEMIVGMKQGLPGMDLVVFPEYSLQGIMYDPAEMMETAVAIPGEETEIFSRACRKANV",
        "WGVFSLTGERHEEHPRKAPYNTLVLIDNNGEIVQKYRKIIPWCPIEGWYPGGQTYVSEGPKGMKISLIICDDGNYPEIWRDCAMKGAELIVRCQGY",
        "MYPAKDQQVMMAKAMAWANNCYVAVANAAGFDGVYSYFGHSAIIGFDGRTLGECGEEEMGIQYAQLSLSQIRDARANDQSQNHLFKILHRGYSGLQA",
        "SGDGDRGLAECPFEFYRTWVTDAEKARENVERLTRSTTGVAQCPVGRLPYEGLEKEA",
    ]
)

#### Scoring in Single-Sequence Mode

In [None]:
# Single-sequence mode (no MSA context). The result is a DataFrame with the
# columns position, mutant and score, one row per single mutant.
mutant_scores = ssm(
    parent_sequence,
    singleseq_mode=True,
    model_name="Profluent-Bio/E1-300m",
)

In [None]:
# The 96 highest-scoring mutants (rows are already sorted); `position` is 0-indexed.
mutant_scores.iloc[:96]

In [None]:
# Look up a specific mutant: R2H in 1-indexed notation, i.e. position 1 in the 0-indexed column.
mutant_scores.loc[mutant_scores["position"].eq(1) & mutant_scores["mutant"].eq("H")]

#### Scoring in Retrieval-Augmented Mode

Note: you will very likely hit a CUDA out-of-memory (OOM) error on a T4 GPU (e.g. on Colab). We recommend an A100 or L40 GPU for retrieval-augmented mode.

In [None]:
# Download the example MSA once. Re-running this cell is a no-op when msa.a3m already
# exists; a repeated `wget` would instead create msa.a3m.1, msa.a3m.2, ...
import urllib.request
from pathlib import Path

MSA_URL = "https://huggingface.co/datasets/Profluent-Bio/AMIE_PSEAE_Wrenbeck_2017_example/resolve/main/msa.a3m"
msa_file = Path("msa.a3m")
if not msa_file.exists():
    partial_file = msa_file.with_name(msa_file.name + ".part")
    urllib.request.urlretrieve(MSA_URL, str(partial_file))
    # Rename only after the download completes, so an interrupted run is not
    # later mistaken for a finished file.
    partial_file.replace(msa_file)

In [None]:
# Retrieval-augmented mode: context sequences are sampled from msa.a3m, and mutants are
# scored with the wild-type-marginal method. The result is a DataFrame with the columns
# position, mutant and score. Takes about 2-3 min on an A100 GPU.
mutant_scores = ssm(
    parent_sequence,
    msa_path="msa.a3m",
    model_name="Profluent-Bio/E1-300m",
    scoring_method=EncoderScoreMethod.WILDTYPE_MARGINAL,
    singleseq_mode=False,
)

In [None]:
# The 96 highest-scoring mutants (rows are already sorted); `position` is 0-indexed.
mutant_scores.iloc[:96]

In [None]:
# Look up a specific mutant: R2H in 1-indexed notation, i.e. position 1 in the 0-indexed column.
mutant_scores.loc[mutant_scores["position"].eq(1) & mutant_scores["mutant"].eq("H")]