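## Basic examples with PlantCaduceus

This notebook loads the pretrained PlantCaduceus model, extracts per-nucleotide embeddings for a 512-bp DNA sequence, and predicts the identity of a masked nucleotide.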

### Set up the environment

In [1]:
# Uncomment to install the pinned `mamba-ssm` release used by this notebook.
# `%pip` installs into the environment of the running kernel, unlike `!pip`,
# which may target a different Python on PATH.
# %pip install --quiet mamba-ssm==1.1.3

In [2]:
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
import torch
import numpy as np
import pandas as pd

### Load the model

In [3]:
# Checkpoint identifier passed to `from_pretrained`.
model_path = 'kuleshov-group/PlantCaduceus_l20'

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# trust_remote_code lets transformers run the model code published with the
# checkpoint (the repr below shows a custom CaduceusForMaskedLM class).
model = AutoModelForMaskedLM.from_pretrained(
    model_path,
    device_map=device,
    trust_remote_code=True,
)
# Inference only: put the module in evaluation mode.
model.eval()

CaduceusForMaskedLM(
  (caduceus): Caduceus(
    (backbone): CaduceusMixerModel(
      (embeddings): CaduceusEmbeddings(
        (word_embeddings): RCPSEmbedding(
          (embedding): Embedding(8, 384)
        )
      )
      (layers): ModuleList(
        (0-19): 20 x RCPSMambaBlock(
          (mixer): RCPSWrapper(
            (submodule): BiMambaWrapper(
              (mamba_fwd): Mamba(
                (in_proj): Linear(in_features=384, out_features=1536, bias=False)
                (conv1d): Conv1d(768, 768, kernel_size=(4,), stride=(1,), padding=(3,), groups=768)
                (act): SiLU()
                (x_proj): Linear(in_features=768, out_features=56, bias=False)
                (dt_proj): Linear(in_features=24, out_features=768, bias=True)
                (out_proj): Linear(in_features=768, out_features=384, bias=False)
              )
              (mamba_rev): Mamba(
                (in_proj): Linear(in_features=384, out_features=1536, bias=False)
                (conv1

### Tokenize the sequence

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
sequence = "CTTAATTAATATTGCCTTTGTAATAACGCGCGAAACACAAATCTTCTCTGCCTAATGCAGTAGTCATGTGTTGACTCCTTCAAAATTTCCAAGAAGTTAGTGGCTGGTGTGTCATTGTCTTCATCTTTTTTTTTTTTTTTTTAAAAATTGAATGCGACATGTACTCCTCAACGTATAAGCTCAATGCTTGTTACTGAAACATCTCTTGTCTGATTTTTTCAGGCTAAGTCTTACAGAAAGTGATTGGGCACTTCAATGGCTTTCACAAATGAAAAAGATGGATCTAAGGGATTTGTGAAGAGAGTGGCTTCATCTTTCTCCATGAGGAAGAAGAAGAATGCAACAAGTGAACCCAAGTTGCTTCCAAGATCGAAATCAACAGGTTCTGCTAACTTTGAATCCATGAGGCTACCTGCAACGAAGAAGATTTCAGATGTCACAAACAAAACAAGGATCAAACCATTAGGTGGTGTAGCACCAGCACAACCAAGAAGGGAAAAGATCGATGATCG"
# Calling the tokenizer directly is the current API; `encode_plus` is deprecated.
encoding = tokenizer(
    sequence,
    return_tensors="pt",
    return_attention_mask=False,
    return_token_type_ids=False,
)
input_ids = encoding["input_ids"].to(device)
# Later cells use the same index into `sequence` and `input_ids`, which is only
# valid if every nucleotide becomes exactly one token and no special tokens are added.
assert input_ids.shape == (1, len(sequence)), input_ids.shape
input_ids.shape

torch.Size([1, 512])

### Extract per-nucleotide embeddings

In [5]:
with torch.inference_mode():
    outputs = model(input_ids=input_ids, output_hidden_states=True)
emeddings = outputs.hidden_states[-1]

In [6]:
print(emeddings.shape)

torch.Size([1, 512, 768])


#### Averaging forward and reverse embeddings

In [7]:
emeddings = emeddings.to(torch.float32).cpu().numpy()

In [8]:
hidden_size = emeddings.shape[-1] // 2
forward = emeddings[..., 0:hidden_size]
reverse = emeddings[..., hidden_size:]
reverse = reverse[..., ::-1]
averaged_embeddings = (forward + reverse) / 2
print(averaged_embeddings.shape)

(1, 512, 384)


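### Masked token prediction

Replace one nucleotide with the mask token and let the model predict it. The logits of the four nucleotide tokens are then turned into probabilities with a softmax.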

In [9]:
# 0-based index of the nucleotide that will be masked below.
pos = 255
reference_base = sequence[pos]
reference_base

'A'

In [10]:
# Mask a copy so `input_ids` keeps the original sequence; mutating it in place
# would silently change the embedding cell's result if that cell were re-run.
masked_input_ids = input_ids.clone()
masked_input_ids[0, pos] = tokenizer.mask_token_id
with torch.inference_mode():
    outputs = model(input_ids=masked_input_ids)

In [11]:
nucleotides = list('acgt')
# Look the vocabulary up once and map each nucleotide to its token id.
vocab = tokenizer.get_vocab()
nucleotide_ids = [vocab[nc] for nc in nucleotides]
# Keep only the four nucleotide logits at the masked position, so the softmax
# renormalises over a/c/g/t alone.
logits = outputs.logits[:, pos, nucleotide_ids]
# Cast to float32 (as done for the embeddings) so `.numpy()` also works when the
# model runs in reduced precision; NumPy has no bfloat16 dtype.
probs = torch.nn.functional.softmax(logits.float().cpu(), dim=1).numpy()

In [12]:
# Probabilities for a, c, g, t (in that order) at the masked position.
probs

array([[0.96960527, 0.00782286, 0.01123959, 0.01133224]], dtype=float32)

In [13]:
df = pd.DataFrame(dict(nucleotides = nucleotides, probs = probs[0]))

In [14]:
df

Unnamed: 0,nucleotides,probs
0,a,0.969605
1,c,0.007823
2,g,0.01124
3,t,0.011332
