In [2]:
from functools import wraps
from time import perf_counter_ns

import torch
from evo import Evo, generate
from stripedhyena.sample import sample

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
def timeit(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        start = perf_counter_ns()
        result = f(*args, **kwargs)
        print(f"{f.__name__}: {(perf_counter_ns() - start) / 10 ** 6:.3f}ms")
        return result
    return wrapper

# Setup

In [4]:
# run models on GPU if GPU node is used
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
evo_model = Evo("evo-1.5-8k-base")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
torch.cuda.empty_cache()

In [7]:
evo_model.model

StripedHyena(
  (embedding_layer): VocabParallelEmbedding(512, 4096)
  (norm): RMSNorm()
  (unembed): VocabParallelEmbedding(512, 4096)
  (blocks): ModuleList(
    (0-7): 8 x ParallelGatedConvBlock(
      (pre_norm): RMSNorm()
      (post_norm): RMSNorm()
      (filter): ParallelHyenaFilter()
      (projections): Linear(in_features=4096, out_features=12288, bias=True)
      (out_filter_dense): Linear(in_features=4096, out_features=4096, bias=True)
      (mlp): ParallelGatedMLP(
        (l1): Linear(in_features=4096, out_features=10928, bias=False)
        (l2): Linear(in_features=4096, out_features=10928, bias=False)
        (l3): Linear(in_features=10928, out_features=4096, bias=False)
      )
    )
    (8): AttentionBlock(
      (pre_norm): RMSNorm()
      (post_norm): RMSNorm()
      (inner_mha_cls): MHA(
        (rotary_emb): RotaryEmbedding()
        (Wqkv): Linear(in_features=4096, out_features=12288, bias=True)
        (inner_attn): FlashSelfAttention(
          (drop): Dropout(

In [8]:
tokenizer = evo_model.tokenizer
model = evo_model.model.to(device)

# Test generation

In [9]:
seq = "ACGTACGT"
input_ids = torch.tensor(
        tokenizer.tokenize(seq),
        dtype=torch.int,
    ).to(device).unsqueeze(0)

In [10]:
input_ids

tensor([[65, 67, 71, 84, 65, 67, 71, 84]], device='cuda:0', dtype=torch.int32)

In [11]:
with torch.inference_mode():
    logits, _ = model(input_ids)

In [12]:
logits

tensor([[[-11.7500, -19.3750, -19.2500,  ..., -19.2500, -19.2500, -19.2500],
         [ -8.8125, -21.1250, -21.1250,  ..., -21.1250, -21.1250, -21.1250],
         [ -8.4375, -21.8750, -21.8750,  ..., -21.8750, -21.8750, -21.8750],
         ...,
         [ -9.0625, -22.6250, -22.5000,  ..., -22.5000, -22.5000, -22.5000],
         [ -8.7500, -22.5000, -22.3750,  ..., -22.3750, -22.3750, -22.3750],
         [ -8.6875, -22.5000, -22.3750,  ..., -22.3750, -22.5000, -22.3750]]],
       device='cuda:0', dtype=torch.bfloat16)

In [13]:
logits.shape

torch.Size([1, 8, 512])

In [14]:
(logits[:,-1] > -10.).nonzero()

tensor([[ 0,  0],
        [ 0, 65],
        [ 0, 67],
        [ 0, 71],
        [ 0, 84]], device='cuda:0')

In [15]:
next_id = sample(logits[:,-1], top_k=5, top_p=0)

In [16]:
next_id

tensor([67], device='cuda:0')

# Larger example with evo.generate

In [17]:
long_seq = """
ATGTCGGCGCCGTCGGAGGAGGAGGAGTACGCGCGGCTGGTGATGGAGGCGCAGCCGGAG
TGGCTGCGCGCCGAGGTGAAGCGGCTGTCCCACGAGCTGGCCGAGACCACGCGTGAGAAG
ATCCAGGCGGCCGAGTACGGGCTGGCGGTGCTCGAGGAGAAGCACCAGCTCAAGCTGCAG
TTCGAGGAGCTCGAGGTGGACTATGAGGCTATCCGCAGCGAGATGGAGCAGCTCAAGGAG
GCCTTTGGACAAGCACACACAAACCACAAGAAGGTGGCTGCTGACGGAGAGAGCCGGGAG
GAGAGCCTGATCCAGGAGTCGGCCTCCAAGGAGCAGTACTACGTGCGGAAGGTGCTAGAG
CTGCAGACGGAGCTGAAGCAGTTGCGCAATGTCCTCACCAACACGCAGTCGGAGAATGAG
CGCCTGGCCTCTGTGGCCCAGGAGCTGAAGGAGATCAACCAGAATGTGGAGATCCAGCGT
GGCCGCCTGCGGGATGACATCAAGGAGTACAAATTCCGGGAAGCTCGTCTGCTGCAGGAC
TACTCGGAACTGGAGGAGGAGAACATCAGCCTGCAGAAGCAAGTGTCTGTGCTCAGACAG
AACCAGGTGGAGTTTGAGGGCCTCAAGCATGAGATCAAGCGTCTGGAGGAGGAGACCGAG
TACCTCAACAGCCAGCTGGAGGATGCCATCCGCCTCAAGGAGATCTCAGAGCGGCAGCTG
GAGGAGGCGCTGGAGACCCTGAAGACGGAGCGCGAACAGAAGAACAGCCTGCGCAAGGAG
CTGTCACACTACATGAGCATCAATGACTCCTTCTACACCAGCCACCTGCATGTCTCGCTG
GATGGCCTCAAGTTCAGTGACGATGCTGCCGAGCCCAACAACGATGCCGAGGCCCTGGTC
AATGGCTTTGAGCACGGCGGCCTGGCCAAGCTGCCACTGGACAACAAGACCTCCACGCCC
AAGAAGGAGGGCCTCGCACCGCCCTCCCCCAGCCTCGTCTCCGACCTACTCAGTGAGCTC
AACATCTCTGAGATCCAGAAGCTGAAGCAGCAGCTGATGCAGATGGAGCGGGAAAAGGCG
GGCCTGCTGGCAACGCTGCAGGACACACAGAAGCAGCTGGAGCACACGCGGGGCTCCCTG
TCAGAACAGCAGGAGAAGGTGACCCGCCTCACAGAGAATCTGAGTGCCCTGCGGCGCCTG
CAGGCCAGCAAGGAGCGGCAGACAGCCCTGGACAACGAGAAGGACCGTGACAGCCATGAG
GATGGGGACTACTACGAGGTGGACATCAACGGGCCTGAGATCTTGGCCTGCAAGTACCAT
GTGGCTGTGGCTGAGGCTGGCGAGCTCCGCGAGCAGCTCAAGGCACTGCGCAGCACGCAC
GAGGCTCGTGAGGCCCAGCACGCCGAGGAGAAGGGCCGCTATGAGGCTGAGGGCCAGGCA
CTCACGGAGAAGGTCTCCCTGCTAGAGAAGGCCAGCCGCCAGGACCGCGAGCTGCTGGCC
CGGCTGGAGAAGGAGCTAAAGAAGGTGAGCGACGTCGCCGGCGAGACACAGGGCAGCCTG
AGTGTGGCCCAGGATGAGCTGGTGACCTTCAGTGAGGAGCTGGCCAATCTCTACCACCAC
GTGTGCATGTGCAACAATGAGACACCCAACCGTGTCATGCTGGACTACTACCGCGAGGGC
CAGGGCGGGGCCGGCCGCACCAGTCCCGGGGGCCGCACCAGCCCCGAGGCGCGTGGCCGG
CGCTCACCCATCCTCCTACCCAAGGGGCTGCTGGCTCCTGAGGCGGGCCGAGCAGATGGT
GGGACGGGGGACAGCAGCCCCTCGCCTGGCTCCTCACTGCCATCACCCCTGAGTGACCCA
CGCCGGGAGCCCATGAACATCTACAACCTGATCGCTATCATCCGTGACCAGATCAAGCAC
CTGCAGGCAGCCGTGGACCGCACCACGGAGCTGTCACGCCAGCGCATTGCCTCTCAGGAG
CTGGGCCCCGCCGTGGACAAGGACAAGGAAGCGCTTATGGAGGAGATCCTCAAGCTGAAG
TCGCTGCTCAGCACCAAGCGGGAGCAGATCACCACGCTGCGCACTGTGCTCAAGGCCAAC
AAGCAGACGGCCGAGGTGGCCCTTGCCAACCTGAAGAGCAAGTATGAGAATGAGAAGGCC
ATGGTTACCGAGACCATGATGAAGCTGCGCAATGAGCTCAAGGCCCTCAAGGAGGACGCA
GCCACCTTCTCCTCGCTGCGTGCTATGTTTGCCACCAGGTGTGACGAGTACATTACACAG
CTGGATGAGATGCAGCGGCAGCTGGCGGCTGCTGAGGACGAGAAGAAGACGCTGAACTCG
CTGCTGCGCATGGCCATCCAGCAGAAGCTGGCGCTGACCCAGCGGCTGGAGCTGCTCGAG
CTGGACCATGAGCAGACCCGGCGTGGCCGTGCCAAAGCCGCCCCGAAGACCAAGCCAGCC
ACACCGAGCCTGTAG
""".replace("\n", "")

In [18]:
len(long_seq)

2475

In [19]:
seed = long_seq[:1000]

In [20]:
@timeit
def generate_tokens(seed, n_seqs, n_tokens, model, tokenizer, **kwargs):
    return generate(
        [ seed ] * n_seqs, 
        model, 
        tokenizer, 
        n_tokens = n_tokens,
        **kwargs
    )

In [21]:
next_seqs, next_scores = generate_tokens(
        seed,
        1,
        100,
        model,
        tokenizer,
        cached_generation = True,
        top_p = 1.,
        temperature = 20.
    )

Prompt: "ATGTCGGCGCCGTCGGAGGAGGAGGAGTACGCGCGGCTGGTGATGGAGGCGCAGCCGGAGTGGCTGCGCGCCGAGGTGAAGCGGCTGTCCCACGAGCTGGCCGAGACCACGCGTGAGAAGATCCAGGCGGCCGAGTACGGGCTGGCGGTGCTCGAGGAGAAGCACCAGCTCAAGCTGCAGTTCGAGGAGCTCGAGGTGGACTATGAGGCTATCCGCAGCGAGATGGAGCAGCTCAAGGAGGCCTTTGGACAAGCACACACAAACCACAAGAAGGTGGCTGCTGACGGAGAGAGCCGGGAGGAGAGCCTGATCCAGGAGTCGGCCTCCAAGGAGCAGTACTACGTGCGGAAGGTGCTAGAGCTGCAGACGGAGCTGAAGCAGTTGCGCAATGTCCTCACCAACACGCAGTCGGAGAATGAGCGCCTGGCCTCTGTGGCCCAGGAGCTGAAGGAGATCAACCAGAATGTGGAGATCCAGCGTGGCCGCCTGCGGGATGACATCAAGGAGTACAAATTCCGGGAAGCTCGTCTGCTGCAGGACTACTCGGAACTGGAGGAGGAGAACATCAGCCTGCAGAAGCAAGTGTCTGTGCTCAGACAGAACCAGGTGGAGTTTGAGGGCCTCAAGCATGAGATCAAGCGTCTGGAGGAGGAGACCGAGTACCTCAACAGCCAGCTGGAGGATGCCATCCGCCTCAAGGAGATCTCAGAGCGGCAGCTGGAGGAGGCGCTGGAGACCCTGAAGACGGAGCGCGAACAGAAGAACAGCCTGCGCAAGGAGCTGTCACACTACATGAGCATCAATGACTCCTTCTACACCAGCCACCTGCATGTCTCGCTGGATGGCCTCAAGTTCAGTGACGATGCTGCCGAGCCCAACAACGATGCCGAGGCCCTGGTCAATGGCTTTGAGCACGGCGGCCTGGCCAAGCTGCCACTGGACAACAAGACCTCCACGCCCAAGAAGGAGGGCCTCGCACCGCCCTCCCCCA

Generation speed is roughly 6 tokens/s (150ms/token)

In [22]:
next_seqs[0]

'CCAAGGAGGAGGAGCTGGAGCGCCTGCGCAAGGAGCTGGAGGAGCTGCGCAAGCAGCTGGAGGAGCTGCGCAAGGAGCTGGAGGAGCTGCGCAAGGAGCT'

In [23]:
real_seq = long_seq[1000:1100]
real_seq

'CCGACCTACTCAGTGAGCTCAACATCTCTGAGATCCAGAAGCTGAAGCAGCAGCTGATGCAGATGGAGCGGGAAAAGGCGGGCCTGCTGGCAACGCTGCA'

In [24]:
sum([a == b for a,b in zip(list(next_seq[0]), list(real_seq))])

NameError: name 'next_seq' is not defined

# PWM?

In [None]:
long_tokens = torch.tensor(tokenizer.tokenize(seed), dtype=torch.int).unsqueeze(0).to(device)

In [None]:
long_tokens

In [None]:
with torch.inference_mode():
    logits, _ = model(long_tokens)

In [None]:
logits

In [None]:
logits.shape

In [None]:
acgt_logits = logits[0,:,(65,67,71,84)]

In [None]:
acgt_logits.shape

In [None]:
acgt_logits

In [None]:
acgt_probs = torch.softmax(acgt_logits, dim=-1)

In [None]:
acgt_probs

In [None]:
probs_t = acgt_probs.transpose(1,0).to("cpu").to(torch.float32)[:,500:530]

In [None]:
probs_dict = {
    'A': probs_t[0,:],
    'C': probs_t[1,:],
    'G': probs_t[2,:],
    'T': probs_t[3,:]
}

In [None]:
x = np.linspace(0,probs_t.shape[1],probs_t.shape[1],dtype=np.int32)
plt.bar(x, probs_dict['A'], label = 'A')
plt.bar(x, probs_dict['C'], bottom = probs_dict['A'], label = 'C')
plt.bar(x, probs_dict['G'], bottom = probs_dict['A'] + probs_dict['C'], label = 'G')
plt.bar(x, probs_dict['T'], bottom = probs_dict['A'] + probs_dict['C'] + probs_dict['G'], label = 'T')
plt.legend(loc = "center left", bbox_to_anchor = (1,0.5))
plt.show()

In [None]:
seed[1:51]