# Preliminaries

## Imports

In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader

## Load local module

In [2]:
import sys
import os

module_path = os.path.abspath("/dcai/users/chache/smrt-foundation")

if module_path not in sys.path:
    sys.path.append(module_path)

device=torch.device('cuda')


## Optional copy step

In [3]:

! du -h ../data/01_processed/ssl_sets/ob007_test.memmap/
! cd ../data/01_processed/ssl_sets/ob007_test.memmap/ && time find  -type f -name '*.npy' | xargs -P16 -IX cp -v X $TMPDIR/ 
! df -h ${TMPDIR:-/tmp}

16G	../data/01_processed/ssl_sets/ob007_test.memmap/
'./shard_00028.npy' -> '/tmp/shard_00028.npy'
'./shard_00021.npy' -> '/tmp/shard_00021.npy'
'./shard_00026.npy' -> '/tmp/shard_00026.npy'
'./shard_00013.npy' -> '/tmp/shard_00013.npy'
'./shard_00014.npy' -> '/tmp/shard_00014.npy'
'./shard_00005.npy' -> '/tmp/shard_00005.npy'
'./shard_00002.npy' -> '/tmp/shard_00002.npy'
'./shard_00030.npy' -> '/tmp/shard_00030.npy'
'./shard_00015.npy' -> '/tmp/shard_00015.npy'
'./shard_00012.npy' -> '/tmp/shard_00012.npy'
'./shard_00027.npy' -> '/tmp/shard_00027.npy'
'./shard_00020.npy' -> '/tmp/shard_00020.npy'
'./shard_00029.npy' -> '/tmp/shard_00029.npy'
'./shard_00031.npy' -> '/tmp/shard_00031.npy'
'./shard_00003.npy' -> '/tmp/shard_00003.npy'
'./shard_00004.npy' -> '/tmp/shard_00004.npy'
'./shard_00007.npy' -> '/tmp/shard_00007.npy'
'./shard_00000.npy' -> '/tmp/shard_00000.npy'
'./shard_00009.npy' -> '/tmp/shard_00009.npy'
'./shard_00018.npy' -> '/tmp/shard_00018.npy'
'./shard_00011.npy' -> '/tm

# Dataset

In [4]:
from smrt_foundation.dataset import ShardedMemmapDataset
B = 64
# ssl_ds = ShardedMemmapDataset("../data/01_processed/ssl_sets/ob007_test.memmap/")
ssl_ds = ShardedMemmapDataset("/tmp")
ssl_dl = DataLoader(ssl_ds, batch_size=B, num_workers=8, pin_memory=True, prefetch_factor=4, shuffle=2)

In [5]:
batch = next(iter(ssl_dl))
batch[0,0:10,:]

tensor([[ 3.0000, -0.0933,  1.5625,  0.0000],
        [ 1.0000,  1.0781,  0.2344,  0.0000],
        [ 3.0000, -0.1582,  0.9375,  0.0000],
        [ 1.0000,  0.2432,  0.7305,  0.0000],
        [ 1.0000,  0.3848,  2.5000,  0.0000],
        [ 3.0000, -1.0625,  0.7305,  0.0000],
        [ 2.0000,  0.7383,  1.0078,  0.0000],
        [ 0.0000,  0.2432,  1.0078,  0.0000],
        [ 2.0000,  0.2432,  0.7305,  0.0000],
        [ 0.0000,  0.0281,  0.0378,  0.0000]], dtype=torch.bfloat16)

In [6]:
for batch in tqdm(ssl_dl):
    x=batch.to(device)

100%|██████████| 7976/7976 [00:08<00:00, 995.60it/s] 


# Model

In [7]:
from smrt_foundation.model import Smrt2Vec
from smrt_foundation.loss import InfoNCE
model = Smrt2Vec().to(device)
model.train()
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {round(total_params/1e6,2)}m")

trainable parameters: 2.08m


In [8]:
loss = InfoNCE()
c_proj, targets, mask = model(batch.to(device))
loss(c_proj, targets, mask)

tensor(9.8000, device='cuda:0', grad_fn=<NllLossBackward0>)

# Pretrain

In [9]:
! nvidia-smi

Fri Feb  6 10:10:43 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.126.09             Driver Version: 580.126.09     CUDA Version: 13.1     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:1B:00.0 Off |                    0 |
| N/A   32C    P0            127W /  700W |    7773MiB /  81559MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00

In [10]:
#if torch.cuda.device_count() > 1:
#    print(f"Using {torch.cuda.device_count()} GPUs via DataParallel")
#    model = nn.DataParallel(Smrt2Vec())
#    model.to(device)

In [11]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.02)
criterion = InfoNCE(temperature=0.1).to(device)
EPOCHS = 4

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=6e-4,
    total_steps=len(ssl_dl) * EPOCHS,
    pct_start=0.05
)
    
model.train()

for epoch in range(EPOCHS):
    progress_bar = tqdm(ssl_dl, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for i, batch in enumerate(progress_bar):
        batch = batch.to(device, non_blocking=True)

        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            c_proj, targets, mask_idx = model(batch)
            loss = criterion(c_proj, targets, mask_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if i % 10 == 0:
            progress_bar.set_postfix(
                loss=f"{loss.item():.4f}",
                lr=f"{scheduler.get_last_lr()[0]:.6f}"
            )

Epoch 1/4: 100%|██████████| 7976/7976 [06:31<00:00, 20.35it/s, loss=3.8523, lr=0.000537]
Epoch 2/4: 100%|██████████| 7976/7976 [06:29<00:00, 20.48it/s, loss=3.3509, lr=0.000325]
Epoch 3/4: 100%|██████████| 7976/7976 [06:29<00:00, 20.49it/s, loss=4.4448, lr=0.000097]
Epoch 4/4: 100%|██████████| 7976/7976 [06:29<00:00, 20.49it/s, loss=5.5813, lr=0.000000]


In [12]:
# torch.save(model.state_dict(), '../models/smrt2vec_8epoch.pt')

In [13]:
from smrt_foundation.dataset import LegacyMethylDataset, compute_log_normalization_stats
import polars as pl
TRAIN_PATH = '../data/01_processed/val_sets/pacbio_standard_train.parquet'
VAL_PATH =  '../data/01_processed/val_sets/pacbio_standard_test.parquet'
KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']

df = pl.read_parquet(TRAIN_PATH).head(1_000_000)
train_means, train_stds = compute_log_normalization_stats(df, KINETICS_FEATURES)

methyl_train_ds = LegacyMethylDataset(TRAIN_PATH, train_means, train_stds, context=32, restrict_row_groups=5)
methyl_train_dl = DataLoader(methyl_train_ds,
                             # num_workers=8,
                             batch_size=256,
                             drop_last=True,
                             persistent_workers=False,
                             # prefetch_factor=5
                            )

methyl_val_ds = LegacyMethylDataset(VAL_PATH,
                                      means=train_means,
                                      stds=train_stds,
                                      context=32,
                                      restrict_row_groups=5)
methyl_val_dl = DataLoader(methyl_val_ds,
                        batch_size=256,
                        drop_last=True,
                        persistent_workers=False,
                        prefetch_factor=None)



In [14]:
df.head()

read_name,cg_pos,seq,qual,np,fi,fp,ri,rp,label
str,i64,str,list[u8],u8,list[u16],list[u16],list[u16],list[u16],i32
"""m64168_200820_000733/48169889/…",3058,"""GATGTCCTGGGGATTCGGGGGCATAACTGC…","[60, 67, … 69]",8,"[15, 29, … 35]","[7, 19, … 23]","[10, 10, … 5]","[34, 39, … 33]",0
"""m64168_200820_000733/45943110/…",8167,"""TCTCCACGTTGGCCACGCTGGTCTCGAACT…","[93, 73, … 93]",13,"[33, 18, … 26]","[20, 46, … 27]","[48, 70, … 20]","[20, 16, … 51]",0
"""m64168_200823_191315/50332760/…",1413,"""AATTTCTTGAAGAGACGAAAGTCTGTGGGT…","[93, 93, … 93]",32,"[32, 12, … 18]","[21, 9, … 13]","[16, 17, … 13]","[16, 23, … 22]",1
"""m64168_200823_191315/177537981…",4708,"""CAACCCACTGCCAAGCGCTTCCTGCCACCT…","[93, 82, … 93]",9,"[13, 19, … 19]","[23, 14, … 34]","[9, 21, … 31]","[21, 13, … 34]",1
"""m64168_200823_191315/49154585/…",5695,"""CCTCCCTACCGAAAACGGGGATCGTGTGAA…","[13, 58, … 53]",3,"[6, 35, … 10]","[12, 34, … 19]","[17, 13, … 20]","[24, 12, … 43]",1


In [15]:
import copy
import torch
from tqdm import tqdm
from smrt_foundation.probe import SingleIdxProbe

EPOCHS = 5
DEVICE = torch.device('cuda')
encoder_clone = copy.deepcopy(model.encoder)
# Fixed: used DEVICE instead of device
probe = SingleIdxProbe(encoder_clone, freeze_encoder=False).to(DEVICE)

optimizer = torch.optim.AdamW([
    {'params': probe.encoder.parameters(), 'lr': 5e-7},
    {'params': probe.head.parameters(), 'lr': 3e-5}
])

criterion = torch.nn.BCEWithLogitsLoss()
loss_history = []

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}")
    probe.train()
    running_loss = 0.0
    for i, batch in enumerate(tqdm(methyl_train_dl)):
        inputs = batch['data'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        logits = probe(inputs)
        loss = criterion(logits, labels.unsqueeze(1).to(torch.float32))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            loss_history.append(running_loss / 100)
            running_loss = 0.0

    probe.eval()
    sample_count = 0
    sample_correct = 0
    # Added torch.no_grad() for validation efficiency
    with torch.no_grad():
        for batch in tqdm(methyl_val_dl):
            inputs = batch['data'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            logits = probe(inputs)
            preds = logits > 0
            correct = labels == preds.squeeze(-1)
            sample_count += correct.shape[0]
            sample_correct += correct.sum()
    
    print(f"epoch val top1_acc: {sample_correct/sample_count}")

Epoch 1


100%|██████████| 5129/5129 [00:56<00:00, 90.96it/s] 
100%|██████████| 5166/5166 [00:24<00:00, 213.06it/s]


epoch val top1_acc: 0.586393415927887
Epoch 2


100%|██████████| 5129/5129 [00:57<00:00, 89.58it/s]
100%|██████████| 5166/5166 [00:23<00:00, 217.99it/s]


epoch val top1_acc: 0.6199345588684082
Epoch 3


100%|██████████| 5129/5129 [00:57<00:00, 89.88it/s]
100%|██████████| 5166/5166 [00:23<00:00, 219.91it/s]


epoch val top1_acc: 0.6427958607673645
Epoch 4


100%|██████████| 5129/5129 [00:56<00:00, 90.31it/s]
100%|██████████| 5166/5166 [00:23<00:00, 220.33it/s]


epoch val top1_acc: 0.6594099402427673
Epoch 5


100%|██████████| 5129/5129 [00:56<00:00, 90.58it/s] 
100%|██████████| 5166/5166 [00:23<00:00, 218.74it/s]

epoch val top1_acc: 0.6715256571769714



