`Viru`, `Temp`, `Bact` 

### IMPORTS

In [1]:
import sys
print(sys.executable)

/home/ec2-user/SageMaker/project-gamma/evo_env/bin/python


In [None]:
!{sys.executable} -m pip install torchinfo

In [1]:
from Bio import SeqIO
from evo import Evo
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import random
from sklearn.metrics import accuracy_score
from torch.optim import AdamW
from torchinfo import summary
import gc
from torch.cuda.amp import autocast, GradScaler


device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
def parse_labeled_fasta(fasta_path):
    data = []
    i=0
    for record in SeqIO.parse(fasta_path, "fasta"):
        if i >= 1000:
            break
        seq = str(record.seq)
        if record.id.startswith("Viru"):
            label = "Viru"
        elif record.id.startswith("Temp"):
            label = "Temp"
        elif record.id.startswith("Bact"):
            label = "Bact"
        else:
            continue  # Skip unknown
        data.append({"sequence": seq, "label": label})
        i += 1
    return data

In [3]:
class GenomicDataset(Dataset):
    def __init__(self, data, tokenizer, label2id, max_length=4096):
        self.data = data
        self.tokenizer = tokenizer
        self.label2id = label2id
        self.max_length = max_length
        self.pad_token = tokenizer.pad_token if hasattr(tokenizer, 'pad_token') else 'N'  # fallback if undefined

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]["sequence"]
        seq = seq[:self.max_length]
        label = self.label2id[self.data[idx]["label"]]

        # Truncate or pad sequence BEFORE tokenization
        # if len(seq) < self.max_length:
        #     seq += self.pad_token * (self.max_length - len(seq))
        # else:
        #     seq = seq[:self.max_length]

        # Now tokenize the sequence
        tokens = self.tokenizer.tokenize(seq)

        input_ids = torch.tensor(tokens, dtype=torch.long)
        attention_mask = (input_ids != self.tokenizer.pad_id).long()

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "label": torch.tensor(label, dtype=torch.long)
        }

In [4]:
class EvoClassifier(nn.Module):
    def __init__(self, encoder, num_classes=3):
        super().__init__()
        self.encoder = encoder

        # Grab the true hidden dim
        cfg = encoder.config
        hidden_size = getattr(cfg, "d_model", None) or getattr(cfg, "vocab_size", None)
        assert hidden_size == 512
        if hidden_size is None:
            raise ValueError(f"No d_model or hidden_size in config: {cfg}")
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():
            x_out, _ = self.encoder(input_ids, None, attention_mask)
            pooled = x_out.mean(1).to(torch.float32)
            #print(pooled.shape)
            #print(pooled)
            # outputs = self.encoder(input_ids, attention_mask)
            # pooled = outputs.last_hidden_state[:, 0]  # CLS token
        return self.classifier(pooled)

### Load model and tokenizer

In [5]:
evo_model = Evo('evo-1.5-8k-base')
model, tokenizer = evo_model.model, evo_model.tokenizer
model.to(device)
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

StripedHyena(
  (embedding_layer): VocabParallelEmbedding(512, 4096)
  (norm): RMSNorm()
  (unembed): VocabParallelEmbedding(512, 4096)
  (blocks): ModuleList(
    (0-7): 8 x ParallelGatedConvBlock(
      (pre_norm): RMSNorm()
      (post_norm): RMSNorm()
      (filter): ParallelHyenaFilter()
      (projections): Linear(in_features=4096, out_features=12288, bias=True)
      (out_filter_dense): Linear(in_features=4096, out_features=4096, bias=True)
      (mlp): ParallelGatedMLP(
        (l1): Linear(in_features=4096, out_features=10928, bias=False)
        (l2): Linear(in_features=4096, out_features=10928, bias=False)
        (l3): Linear(in_features=10928, out_features=4096, bias=False)
      )
    )
    (8): AttentionBlock(
      (pre_norm): RMSNorm()
      (post_norm): RMSNorm()
      (inner_mha_cls): MHA(
        (rotary_emb): RotaryEmbedding()
        (Wqkv): Linear(in_features=4096, out_features=12288, bias=True)
        (inner_attn): FlashSelfAttention(
          (drop): Dropout(

### Load dataset and split train/test data

In [6]:
label2id = {"Viru": 0, "Temp": 1, "Bact": 2}
all_data = parse_labeled_fasta("datasets/fusion_sequences_shuffled.fasta")
random.shuffle(all_data)
split_idx = int(0.8 * len(all_data))
train_data = all_data[:split_idx]
val_data = all_data[split_idx:]

train_dataset = GenomicDataset(train_data, tokenizer, label2id, max_length=4096)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)

val_dataset = GenomicDataset(val_data, tokenizer, label2id, max_length=4096)
val_loader = DataLoader(val_dataset, batch_size=1)

In [7]:
print(train_dataset[0])
x = train_dataset[0]['input_ids']
with torch.no_grad():
    y = model(x.unsqueeze(0).to(device))
print(x.shape, y[0].size())

{'input_ids': tensor([65, 71, 67,  ..., 67, 84, 84]), 'attention_mask': tensor([1, 1, 1,  ..., 1, 1, 1]), 'label': tensor(1)}
torch.Size([4096]) torch.Size([1, 4096, 512])


### Setup hyperparameters

### Training loop 

In [10]:
m = EvoClassifier(model, num_classes=3).to(device)

# Optimizer and loss
optimizer = AdamW(m.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

In [11]:
summary(m)

Layer (type:depth-idx)                                  Param #
EvoClassifier                                           --
├─StripedHyena: 1-1                                     --
│    └─VocabParallelEmbedding: 2-1                      2,097,152
│    └─RMSNorm: 2-2                                     4,096
│    └─VocabParallelEmbedding: 2-3                      (recursive)
│    └─ModuleList: 2-4                                  --
│    │    └─ParallelGatedConvBlock: 3-1                 201,601,024
│    │    └─ParallelGatedConvBlock: 3-2                 201,601,024
│    │    └─ParallelGatedConvBlock: 3-3                 201,601,024
│    │    └─ParallelGatedConvBlock: 3-4                 201,601,024
│    │    └─ParallelGatedConvBlock: 3-5                 201,601,024
│    │    └─ParallelGatedConvBlock: 3-6                 201,601,024
│    │    └─ParallelGatedConvBlock: 3-7                 201,601,024
│    │    └─ParallelGatedConvBlock: 3-8                 201,601,024
│    │    └─Attenti

In [12]:
# del model
# gc.collect()
torch.cuda.empty_cache()

# for param in m.encoder.parameters():
#    param.requires_grad = False

num_epochs = 1
for epoch in range(num_epochs):
    m.train()
    train_loss = 0.0
    print(f"\nEpoch {epoch + 1}/{num_epochs} — Training")

    for batch in tqdm(train_loader, desc="Train", leave=False):
        input_ids = batch["input_ids"].to(device)
        padding_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()
        #with autocast():
        logits = m(input_ids, attention_mask=padding_mask)
        loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_loss += loss.item()
        torch.cuda.empty_cache()

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch {epoch + 1} | Avg Train Loss: {avg_train_loss:.4f}")

    # Phase validation
    m.eval()
    all_preds, all_labels = [], []
    val_loss = 0.0
    print("Validating...")

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Val", leave=False):
            input_ids = batch["input_ids"].to(device)
            padding_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            #with autocast():
            logits = m(input_ids, attention_mask=padding_mask)
            loss = criterion(logits, labels)

            val_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            torch.cuda.empty_cache()

    avg_val_loss = val_loss / len(val_loader)
    acc = accuracy_score(all_labels, all_preds)
    print(f"Validation Loss: {avg_val_loss:.4f} | Accuracy: {acc:.4f}")



Epoch 1/1 — Training


                                                        

Epoch 1 | Avg Train Loss: 0.7620
Validating...


                                                      

Validation Loss: 0.4722 | Accuracy: 0.8800




In [None]:
!pip install matplotlib

In [15]:
import pandas as pd
import matplotlib.pyplot as plt

# Build DataFrames
train_df = pd.DataFrame(train_data)
val_df = pd.DataFrame(val_data)


def label_distribution(df, split_name):
    counts = df['label'].value_counts().sort_index()
    percents = (counts / len(df) * 100).round(2)
    dist = pd.DataFrame({
        'count': counts,
        'percent': percents
    })
    dist.index.name = 'label'
    dist.reset_index(inplace=True)
    dist['split'] = split_name
    return dist[['split', 'label', 'count', 'percent']]


# Compute
train_dist = label_distribution(train_df, 'train')
val_dist = label_distribution(val_df,   'val')

# Combine and display
dist = pd.concat([train_dist, val_dist], ignore_index=True)
print(dist)


ModuleNotFoundError: No module named 'matplotlib'

In [14]:
train_counts = train_dist.set_index('label')['count']
val_counts   = val_dist.set_index('label')['count']

labels = train_counts.index.tolist()
x = np.arange(len(labels))
width = 0.35

fig, ax = plt.subplots()
# plot train and val bars side by side
ax.bar(x - width/2, train_counts.values, width, label='train')
ax.bar(x + width/2, val_counts.values, width, label='val')

ax.set_xlabel('Label')
ax.set_ylabel('Number of examples')
ax.set_title('Label Distribution in Train vs Validation')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
plt.tight_layout()
plt.show()

NameError: name 'np' is not defined