In [6]:
from collections import defaultdict
from dataclasses import dataclass
from functools import cache
import random

import torch
import torch.nn.functional as F

from evo import Evo
from evo.generation import Generator

from Bio import Seq, SeqIO
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import roc_auc_score, average_precision_score

In [2]:
# Pick the compute device once for the whole notebook: CUDA when torch can
# see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Download and load the pretrained "evo-1.5-8k-base" checkpoint; this prints a
# checkpoint-shard loading progress bar.
evo_model = Evo(
    "evo-1.5-8k-base"
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Pull the tokenizer and torch model out of the Evo wrapper. The model is moved
# to the selected device and switched to inference mode: `.eval()` disables
# training-only behavior (e.g. dropout) so generation/scores are not perturbed.
# Both `.to()` and `.eval()` return the module itself, so they chain.
tokenizer = evo_model.tokenizer
model = evo_model.model.to(device).eval()

In [5]:
# Wrap the model and tokenizer in Evo's sampling helper. Only the sampling
# temperature is set here; every other Generator option keeps its default.
generator = Generator(
    model,
    tokenizer,
    temperature=0.5,
)

In [None]:
@dataclass
class SeqStats:
    """Result record for one prompt-then-generate evaluation.

    Fields, in constructor order: four identifying labels, the prompt and
    target lengths, then the three scores.
    """

    # identification labels, passed through unchanged by the caller
    biotype: str  # category label (get_forward_stats defaults it to "unknown")
    index: int  # caller-chosen position of the sequence within its set
    mock: bool  # caller-supplied flag
    # lengths
    k_fit: int  # length of the prompt sequence
    k_test: int  # length of the ground-truth continuation
    # scores
    correctness: float  # fraction of matching translated amino acids
    auroc: float  # weighted ROC-AUC of per-position nucleotide probabilities
    auprc: float  # weighted average precision of the same probabilities

In [None]:
# Token ids of the nucleotides A, C, G, T (65, 67, 71, 84); used to pick the
# nucleotide columns out of the model's per-step logits.
_ACGT_IDS = [ord(base) for base in "ACGT"]


def _aa_identity(nuc_a, nuc_b) -> float:
    """Fraction of positions where the translations of two nucleotide strings agree.

    Amino acids are compared pairwise up to the shorter translation. Returns NaN
    when there is nothing to compare (e.g. both inputs shorter than one codon).
    """
    aa_a = str(Seq.Seq(nuc_a).translate())
    aa_b = str(Seq.Seq(nuc_b).translate())
    matches = [a == b for a, b in zip(aa_a, aa_b)]
    if not matches:
        return float("nan")
    return sum(matches) / len(matches)


def _weighted_label_score(metric, y_true, y_score, keep) -> float:
    """Apply a support-weighted sklearn metric to the label columns selected by ``keep``.

    ``keep`` is a boolean tensor over columns. Returns NaN when no column is
    kept. A single kept column is scored as a plain 1-D binary problem, because
    a multilabel indicator needs at least two columns.
    """
    cols = keep.nonzero().flatten().tolist()
    if not cols:
        return float("nan")
    if len(cols) == 1:
        return float(metric(y_true[:, cols[0]], y_score[:, cols[0]]))
    return float(metric(y_true[:, cols], y_score[:, cols], average="weighted"))


def _acgt_metrics(test_tokens, step_logits):
    """Weighted ROC-AUC and average precision of the model's A/C/G/T probabilities.

    Args:
        test_tokens: 1-D int64 tensor of the true token ids.
        step_logits: logits indexed as ``[step, token_id]`` (``scores[0]`` from
            the generator); assumes token ids up to 84 are valid columns — TODO confirm.

    Returns:
        ``(auroc, auprc)`` as floats; NaN where the metric is undefined.
    """
    # Guard against the generator returning fewer steps than requested.
    n = min(test_tokens.shape[0], step_logits.shape[0])
    ids = torch.tensor(_ACGT_IDS)
    # One-hot truth via direct comparison. F.one_hot would infer num_classes from
    # the largest token present and raise IndexError when e.g. no 'T' occurs.
    y_true = (test_tokens[:n, None].cpu() == ids).to(torch.int64)
    # Softmax over just the four nucleotide logits. Upcast to float32 so
    # half/bfloat16 logits can be handed to NumPy/sklearn.
    y_score = torch.softmax(step_logits[:n, _ACGT_IDS].detach().float(), dim=-1).cpu()
    support = y_true.sum(dim=0)
    y_true_np, y_score_np = y_true.numpy(), y_score.numpy()
    # ROC-AUC is undefined (sklearn raises) for a column that is all-0 or all-1,
    # so only columns containing both classes are scored.
    auroc = _weighted_label_score(roc_auc_score, y_true_np, y_score_np,
                                  (support > 0) & (support < n))
    # Zero-support columns have zero weight in the weighted average; dropping
    # them leaves the result unchanged and avoids sklearn's no-positive warning.
    auprc = _weighted_label_score(average_precision_score, y_true_np, y_score_np,
                                  support > 0)
    return auroc, auprc


@cache
def get_forward_stats(seq_fit, seq_test, generator, *, biotype="unknown", index=0, mock=False, seed=None) -> SeqStats:
    """Prompt the model with ``seq_fit``, sample ``len(seq_test)`` tokens, and score them.

    Args:
        seq_fit: nucleotide prompt (str or Bio.Seq.Seq), passed to
            ``generator.generate`` as ``input_string``.
        seq_test: ground-truth continuation (str or Bio.Seq.Seq) that the
            generation is compared against.
        generator: the sampling Generator wrapping the loaded model.
        biotype, index, mock: labels copied verbatim into the result.
        seed: if not None, ``torch.manual_seed(seed)`` is called before sampling.

    Returns:
        SeqStats with:
          * correctness: fraction of amino-acid positions where the translated
            generation equals the translated ``seq_test``.
          * auroc / auprc: support-weighted one-vs-rest ROC-AUC / average
            precision of the per-step A/C/G/T probabilities against the
            one-hot ``seq_test`` nucleotides (NaN where undefined).

    Note:
        Results are memoized by ``functools.cache`` on all arguments. Repeating
        a call with ``seed=None`` returns the first sample, not a fresh one.
        Relies on the notebook-level ``tokenizer`` and ``device``.
    """
    # Both str and Bio.Seq.Seq are accepted; the tokenizer always gets a str.
    seq_fit, seq_test = str(seq_fit), str(seq_test)
    k_fit = len(seq_fit)
    k_test = len(seq_test)

    if seed is not None:
        torch.manual_seed(seed)
    test_tokens = torch.tensor(tokenizer.tokenize(seq_test), dtype=torch.int64)
    # Run on the notebook's selected device rather than a hard-coded "cuda",
    # which would fail on CPU-only machines.
    with torch.no_grad():
        gen_tokens, scores, _ = generator.generate(str(device), input_string=seq_fit, num_tokens=k_test,
                                                   cached_generation=True)

    # scores is indexed [batch, step, vocab] below, so gen_tokens presumably has
    # a leading batch dim too (TODO confirm); flattening handles 1-D or [1, steps].
    gen_seq = tokenizer.detokenize(torch.as_tensor(gen_tokens).reshape(-1).tolist())

    correctness = _aa_identity(seq_test, gen_seq)
    auroc, auprc = _acgt_metrics(test_tokens, scores[0])

    return SeqStats(biotype=biotype,
                    index=index,
                    mock=mock,
                    k_fit=k_fit,
                    k_test=k_test,
                    correctness=correctness,
                    auroc=auroc,
                    auprc=auprc)

In [7]:
# Load the first record of the CDS FASTA file (SeqIO.parse yields SeqRecord
# objects; its `.seq` is used below). The file handle is closed on exit.
with open("seqs/cds_1.fa") as f:
    seqs = SeqIO.parse(f, "fasta")
    seq = next(seqs, None)
# Fail with a clear message instead of a bare StopIteration on an empty file.
if seq is None:
    raise ValueError("seqs/cds_1.fa contains no FASTA records")

In [10]:
# First 900 nt (= 300 codons) of the record — presumably the "fit" prompt
# region passed to get_forward_stats — and its protein translation.
fit_nuc = seq.seq[:900]
fit_aa = fit_nuc.translate()  # fit_nuc is already a Bio.Seq.Seq
fit_aa

Seq('MSAPSEEEEYARLVMEAQPEWLRAEVKRLSHELAETTREKIQAAEYGLAVLEEK...ALV')