In [10]:

import torch
import random
import numpy as np
from tqdm import tqdm
from time import time

from utils import load_from_hf, load_from_mila

In [11]:
# Model
source = "mila"
model_name = "AMPLIFY350M"
model_path = "../outputs/MILA_PLM_350M_UR100P/checkpoint/pytorch_model.pt"
tokenizer_path = None 
config_path = "../outputs/MILA_PLM_350M_UR100P/checkpoint/config.yaml"
batch_size = 32
device = "cuda"
compile = False
fp16 = True

# Dataset
n_samples = 1000
seq_length = 512
seed = 0

In [12]:
# Get model and tokenizer
if source == "hf":
    model, tokenizer = load_from_hf(model_path, tokenizer_path, fp16=fp16)
elif source == "mila":
    model, tokenizer = load_from_mila(model_path, config_path)
else:
    raise Exception("Only 'hf' and 'mila' sources are supported, not {source}.")
model.to(device)
torch.compile(model, disable=~compile)

AMPLIFY(
  (encoder): Embedding(27, 960, padding_idx=0)
  (transformer_encoder): ModuleList(
    (0-31): 32 x EncoderBlock(
      (q): Linear(in_features=960, out_features=960, bias=False)
      (k): Linear(in_features=960, out_features=960, bias=False)
      (v): Linear(in_features=960, out_features=960, bias=False)
      (wo): Linear(in_features=960, out_features=960, bias=False)
      (resid_dropout): Dropout(p=0, inplace=False)
      (ffn): SwiGLU(
        (w12): Linear(in_features=960, out_features=5120, bias=False)
        (w3): Linear(in_features=2560, out_features=960, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
      (ffn_dropout): Dropout(p=0, inplace=False)
    )
  )
  (layer_norm_2): RMSNorm()
  (decoder): Linear(in_features=960, out_features=27, bias=True)
)

In [13]:
vocab = ["L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C"]
random.seed(seed)

with torch.no_grad(), torch.autocast(device_type=device, dtype=torch.float16, enabled=fp16):
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    
    times = []
    for i in tqdm(range(n_samples + 100)):
        x = torch.stack([tokenizer.encode(random.choices(vocab, k=seq_length), return_tensors="pt").squeeze() for _ in range(batch_size)])
        x = x.to(device)
        
        # Time the forward pass (inference)
        start = time()
        y = model(x)
        stop = time()
        
        # Burn-in period of 100 samples (benefit compiled ESM)
        if i >= 100:
            times.append(stop - start)
   
times = np.array(times)
print(f"{np.mean(times * 1000 / batch_size):.2f} ± {np.std(times * 1000 / batch_size):.2f} ms/protein")
print(f"{np.mean(batch_size / times):.2f} ± {np.std(batch_size / times):.2f} protein/s")
print(f"{np.mean(times * 1000 / (seq_length * batch_size)):.2f} ± {np.std(times * 1000 / (seq_length * batch_size)):.2f} ms/token")
print(f"{np.mean(seq_length / times):.2f} ± {np.std(seq_length / times):.2f} token/s")
print(f"{torch.cuda.max_memory_allocated(device=device)/1e6:.0f} MB")

100%|██████████| 1100/1100 [02:10<00:00,  8.42it/s]

0.81 ± 0.01 ms/protein
1239.09 ± 13.74 protein/s
0.00 ± 0.00 ms/token
19825.47 ± 219.76 token/s
2388 MB



