In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader

In [2]:
import sys
import os

module_path = os.path.abspath("/dcai/users/chache/smrt-foundation")

if module_path not in sys.path:
    sys.path.append(module_path)
    
import copy
import torch
from tqdm import tqdm
from smrt_foundation.probe import SingleIdxProbe
from smrt_foundation.model import Smrt2Vec, SmrtEncoder


device=torch.device('cuda')


In [3]:
encoder = SmrtEncoder()
encoder.cnn.r0

25

In [4]:
from smrt_foundation.dataset import LegacyMethylDataset, compute_log_normalization_stats
import polars as pl
TRAIN_PATH = '../data/01_processed/val_sets/pacbio_standard_train.parquet'
VAL_PATH =  '../data/01_processed/val_sets/pacbio_standard_test.parquet'
KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']

df = pl.read_parquet(TRAIN_PATH).head(1_000_000)
train_means, train_stds = compute_log_normalization_stats(df, KINETICS_FEATURES)

methyl_train_ds = LegacyMethylDataset(TRAIN_PATH, train_means, train_stds, context=32, restrict_row_groups=5)
methyl_train_dl = DataLoader(methyl_train_ds,
                             # num_workers=8,
                             batch_size=256,
                             drop_last=True,
                             persistent_workers=False,
                             # prefetch_factor=5
                            )

methyl_val_ds = LegacyMethylDataset(VAL_PATH,
                                      means=train_means,
                                      stds=train_stds,
                                      context=32,
                                      restrict_row_groups=5)
methyl_val_dl = DataLoader(methyl_val_ds,
                        batch_size=256,
                        drop_last=True,
                        persistent_workers=False,
                        prefetch_factor=None)



In [5]:


EPOCHS = 30
DEVICE = torch.device('cuda')


encoder = SmrtEncoder(max_len=32)
probe = SingleIdxProbe(encoder, freeze_encoder=False).to(DEVICE)
print(encoder.cnn.r0)
optimizer = torch.optim.AdamW([
    {'params': probe.encoder.parameters(), 'lr': 2e-4},
    {'params': probe.head.parameters(), 'lr': 2e-4}
])

criterion = torch.nn.BCEWithLogitsLoss()
loss_history = []

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}")
    probe.train()
    running_loss = 0.0
    for i, batch in enumerate(tqdm(methyl_train_dl)):
        inputs = batch['data'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        logits = probe(inputs)
        loss = criterion(logits, labels.unsqueeze(1).to(torch.float32))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            loss_history.append(running_loss / 100)
            running_loss = 0.0

    probe.eval()
    sample_count = 0
    sample_correct = 0
    # Added torch.no_grad() for validation efficiency
    with torch.no_grad():
        for batch in tqdm(methyl_val_dl):
            inputs = batch['data'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            logits = probe(inputs)
            preds = logits > 0
            correct = labels == preds.squeeze(-1)
            sample_count += correct.shape[0]
            sample_correct += correct.sum()
    
    print(f"epoch val top1_acc: {sample_correct/sample_count}")

25
Epoch 1


100%|██████████| 5129/5129 [00:41<00:00, 123.27it/s]
100%|██████████| 5166/5166 [00:16<00:00, 304.00it/s]


epoch val top1_acc: 0.7868257761001587
Epoch 2


100%|██████████| 5129/5129 [00:41<00:00, 123.04it/s]
100%|██████████| 5166/5166 [00:17<00:00, 300.19it/s]


epoch val top1_acc: 0.7983411550521851
Epoch 3


100%|██████████| 5129/5129 [00:41<00:00, 124.62it/s]
100%|██████████| 5166/5166 [00:17<00:00, 293.36it/s]


epoch val top1_acc: 0.8017566800117493
Epoch 4


100%|██████████| 5129/5129 [00:42<00:00, 121.77it/s]
100%|██████████| 5166/5166 [00:17<00:00, 296.34it/s]


epoch val top1_acc: 0.8032681941986084
Epoch 5


100%|██████████| 5129/5129 [00:41<00:00, 123.32it/s]
100%|██████████| 5166/5166 [00:17<00:00, 295.93it/s]


epoch val top1_acc: 0.8039865493774414
Epoch 6


100%|██████████| 5129/5129 [00:41<00:00, 124.30it/s]
100%|██████████| 5166/5166 [00:17<00:00, 299.71it/s]


epoch val top1_acc: 0.803377091884613
Epoch 7


100%|██████████| 5129/5129 [00:41<00:00, 122.60it/s]
100%|██████████| 5166/5166 [00:17<00:00, 294.79it/s]


epoch val top1_acc: 0.8025445342063904
Epoch 8


100%|██████████| 5129/5129 [00:40<00:00, 125.19it/s]
100%|██████████| 5166/5166 [00:17<00:00, 292.71it/s]


epoch val top1_acc: 0.8007698655128479
Epoch 9


100%|██████████| 5129/5129 [00:41<00:00, 123.85it/s]
100%|██████████| 5166/5166 [00:17<00:00, 293.97it/s]


epoch val top1_acc: 0.7993891835212708
Epoch 10


100%|██████████| 5129/5129 [00:43<00:00, 117.95it/s]
100%|██████████| 5166/5166 [00:17<00:00, 293.03it/s]


epoch val top1_acc: 0.7963456511497498
Epoch 11


100%|██████████| 5129/5129 [00:41<00:00, 122.97it/s]
100%|██████████| 5166/5166 [00:17<00:00, 294.67it/s]


epoch val top1_acc: 0.7924885749816895
Epoch 12


100%|██████████| 5129/5129 [00:41<00:00, 123.31it/s]
100%|██████████| 5166/5166 [00:17<00:00, 296.30it/s]


epoch val top1_acc: 0.7912008762359619
Epoch 13


100%|██████████| 5129/5129 [00:41<00:00, 124.74it/s]
100%|██████████| 5166/5166 [00:17<00:00, 297.62it/s]


epoch val top1_acc: 0.7875237464904785
Epoch 14


100%|██████████| 5129/5129 [00:41<00:00, 123.79it/s]
100%|██████████| 5166/5166 [00:22<00:00, 231.95it/s]


epoch val top1_acc: 0.7876273393630981
Epoch 15


100%|██████████| 5129/5129 [00:46<00:00, 111.19it/s]
100%|██████████| 5166/5166 [00:17<00:00, 295.42it/s]


epoch val top1_acc: 0.7860424518585205
Epoch 16


100%|██████████| 5129/5129 [00:41<00:00, 123.78it/s]
100%|██████████| 5166/5166 [00:17<00:00, 298.72it/s]


epoch val top1_acc: 0.7839940190315247
Epoch 17


100%|██████████| 5129/5129 [00:41<00:00, 124.28it/s]
100%|██████████| 5166/5166 [00:17<00:00, 300.07it/s]


epoch val top1_acc: 0.7808356285095215
Epoch 18


100%|██████████| 5129/5129 [00:41<00:00, 123.38it/s]
100%|██████████| 5166/5166 [00:17<00:00, 299.48it/s]


epoch val top1_acc: 0.7819842100143433
Epoch 19


100%|██████████| 5129/5129 [00:41<00:00, 123.97it/s]
100%|██████████| 5166/5166 [00:17<00:00, 290.46it/s]


epoch val top1_acc: 0.7797982096672058
Epoch 20


100%|██████████| 5129/5129 [00:46<00:00, 109.92it/s]
100%|██████████| 5166/5166 [00:17<00:00, 295.71it/s]


epoch val top1_acc: 0.778468906879425
Epoch 21


100%|██████████| 5129/5129 [00:41<00:00, 123.62it/s]
100%|██████████| 5166/5166 [00:17<00:00, 298.20it/s]


epoch val top1_acc: 0.7770941853523254
Epoch 22


100%|██████████| 5129/5129 [00:41<00:00, 123.32it/s]
100%|██████████| 5166/5166 [00:17<00:00, 294.13it/s]


epoch val top1_acc: 0.7748923301696777
Epoch 23


100%|██████████| 5129/5129 [00:41<00:00, 122.84it/s]
100%|██████████| 5166/5166 [00:17<00:00, 300.73it/s]


epoch val top1_acc: 0.7763592004776001
Epoch 24


100%|██████████| 5129/5129 [00:41<00:00, 124.19it/s]
100%|██████████| 5166/5166 [00:18<00:00, 286.37it/s]


epoch val top1_acc: 0.7747834324836731
Epoch 25


100%|██████████| 5129/5129 [00:40<00:00, 125.71it/s]
100%|██████████| 5166/5166 [00:17<00:00, 297.74it/s]


epoch val top1_acc: 0.7726851105690002
Epoch 26


100%|██████████| 5129/5129 [00:41<00:00, 123.09it/s]
100%|██████████| 5166/5166 [00:17<00:00, 295.72it/s]


epoch val top1_acc: 0.7717671394348145
Epoch 27


100%|██████████| 5129/5129 [00:41<00:00, 124.12it/s]
100%|██████████| 5166/5166 [00:17<00:00, 296.89it/s]


epoch val top1_acc: 0.7699138522148132
Epoch 28


100%|██████████| 5129/5129 [00:41<00:00, 124.02it/s]
100%|██████████| 5166/5166 [00:17<00:00, 297.87it/s]


epoch val top1_acc: 0.7716348171234131
Epoch 29


100%|██████████| 5129/5129 [00:40<00:00, 125.41it/s]
100%|██████████| 5166/5166 [00:17<00:00, 297.11it/s]


epoch val top1_acc: 0.7711841464042664
Epoch 30


100%|██████████| 5129/5129 [00:41<00:00, 123.34it/s]
100%|██████████| 5166/5166 [00:17<00:00, 293.58it/s]

epoch val top1_acc: 0.7702699899673462





In [6]:
probe.eval()
sample_count = 0
sample_correct = 0
with torch.no_grad():
    for batch in tqdm(methyl_train_dl):
        inputs = batch['data'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        logits = probe(inputs)
        preds = logits > 0
        correct = labels == preds.squeeze(-1)
        sample_count += correct.shape[0]
        sample_correct += correct.sum()
print(f"epoch train top1_acc: {sample_correct/sample_count}")

100%|██████████| 5129/5129 [00:16<00:00, 301.93it/s]

epoch train top1_acc: 0.8433158993721008



