# Preliminaries

## Imports

In [1]:
import multiprocessing
import torch

# Ask Python's multiprocessing module to use the 'spawn' start method.
# NOTE(review): `import torch` has already run by this point, and the recorded
# launch log further down still reports `start_method: fork`, so this setting
# does not appear to reach notebook_launcher's workers -- confirm it is needed.
try:
    multiprocessing.set_start_method('spawn', force=True)
except RuntimeError:
    print("Could not set start method (already set).")
else:
    # Only reached when the call above did not raise.
    print("Start method set to 'spawn'. This should fix the CUDA error.")

Start method set to 'spawn'. This should fix the CUDA error.


In [2]:
# local import
import sys
import os
import torch

# Make the local smrt-foundation checkout importable; the path is appended to
# sys.path only when it is not already listed, so re-running the cell is safe.
module_path = os.path.abspath("/dcai/users/chache/smrt-foundation")

if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from smrt_foundation.dataset import ShardedMemmapDataset
from smrt_foundation.model import Smrt2Vec
from smrt_foundation.loss import InfoNCE


In [3]:
# packages
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader

from accelerate import Accelerator, notebook_launcher
from accelerate.utils import set_seed

  from .autonotebook import tqdm as notebook_tqdm


## Load local module

## Optional copy step

Copying the data to `/tmp` makes it faster to read, but that storage is limited to 1 TB. If we want to do full runs later, we will need a different solution — for example, a loop that copies in new data in chunks of roughly 800 GB.

In [4]:

! du -h ../data/01_processed/ssl_sets/ob007_test.memmap/
! cd ../data/01_processed/ssl_sets/ob007_test.memmap/ && time find  -type f -name '*.npy' | xargs -P16 -IX cp -v X $TMPDIR/ 
! df -h ${TMPDIR:-/tmp}

16G	../data/01_processed/ssl_sets/ob007_test.memmap/
'./shard_00028.npy' -> '/tmp/shard_00028.npy'
'./shard_00021.npy' -> '/tmp/shard_00021.npy'
'./shard_00026.npy' -> '/tmp/shard_00026.npy'
'./shard_00013.npy' -> '/tmp/shard_00013.npy'
'./shard_00014.npy' -> '/tmp/shard_00014.npy'
'./shard_00005.npy' -> '/tmp/shard_00005.npy'
'./shard_00002.npy' -> '/tmp/shard_00002.npy'
'./shard_00030.npy' -> '/tmp/shard_00030.npy'
'./shard_00015.npy' -> '/tmp/shard_00015.npy'
'./shard_00012.npy' -> '/tmp/shard_00012.npy'
'./shard_00027.npy' -> '/tmp/shard_00027.npy'
'./shard_00020.npy' -> '/tmp/shard_00020.npy'
'./shard_00029.npy' -> '/tmp/shard_00029.npy'
'./shard_00031.npy' -> '/tmp/shard_00031.npy'
'./shard_00003.npy' -> '/tmp/shard_00003.npy'
'./shard_00004.npy' -> '/tmp/shard_00004.npy'
'./shard_00007.npy' -> '/tmp/shard_00007.npy'
'./shard_00000.npy' -> '/tmp/shard_00000.npy'
'./shard_00009.npy' -> '/tmp/shard_00009.npy'
'./shard_00018.npy' -> '/tmp/shard_00018.npy'
'./shard_00011.npy' -> '/tm

In [5]:
# Report whether this kernel process has already initialized CUDA.
print("Is CUDA initialized?", torch.cuda.is_initialized())

Is CUDA initialized? False


# Training function

In [6]:
# ! nvidia-smi

In [7]:

def training_loop(mixed_precision="bfloat16", seed: int = 42, batch_size: int = 64, epochs: int = 1,
                  data_dir: str = "/tmp"):
    """Pre-train a Smrt2Vec model with an InfoNCE loss under Hugging Face Accelerate.

    The function is self-contained (it builds its own dataset, model, optimizer
    and scheduler) so that it can be handed to ``notebook_launcher`` and run
    once in every worker process.

    Args:
        mixed_precision: Precision mode forwarded to ``Accelerator``. The long
            spellings ``"bfloat16"`` / ``"float16"`` are translated to
            ``"bf16"`` / ``"fp16"``; any other value is passed through as is.
        seed: Seed given to ``accelerate.utils.set_seed``.
        batch_size: Batch size given to the ``DataLoader``.
        epochs: Number of passes over the dataloader.
        data_dir: Directory passed to ``ShardedMemmapDataset``. Defaults to
            ``"/tmp"``, the location this function previously hard-coded.

    Returns:
        None. The trained model is not returned or saved by this function.
    """
    # Accelerator only recognises the short precision names ("no", "fp16",
    # "bf16", "fp8"); translate the long spellings instead of failing on them.
    precision_aliases = {"bfloat16": "bf16", "float16": "fp16"}
    mixed_precision = precision_aliases.get(mixed_precision, mixed_precision)

    set_seed(seed)
    accelerator = Accelerator(mixed_precision=mixed_precision)

    ds = ShardedMemmapDataset(data_dir)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        num_workers=4,
        pin_memory=True,
        prefetch_factor=4,
        shuffle=True
    )

    model = Smrt2Vec(d_model=128, n_layers=4, n_head=4, max_len=4096)

    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.02)
    criterion = InfoNCE(temperature=0.1)

    model, optimizer, dl = accelerator.prepare(model, optimizer, dl)

    # After prepare() the dataloader is sharded, so len(dl) counts the steps of
    # ONE process. Accelerate's prepared scheduler advances the wrapped schedule
    # once per process on every step() call (default split_batches=False), so
    # the schedule length must be scaled by the number of processes; otherwise
    # the one-cycle schedule is exhausted after 1/num_processes of training.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=6e-4,
        total_steps=len(dl) * epochs * accelerator.num_processes,
        pct_start=0.05
    )

    scheduler = accelerator.prepare(scheduler)

    for epoch in range(epochs):
        model.train()

        # Every rank iterates through tqdm, but only the main process renders it.
        progress_bar = tqdm(
            dl,
            desc=f"Epoch {epoch+1}/{epochs}",
            disable=not accelerator.is_main_process,
        )

        for batch in progress_bar:
            c_proj, targets, mask_idx = model(batch)
            loss = criterion(c_proj, targets, mask_idx)

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()

            # loss.item() is only evaluated on the process that displays it.
            if accelerator.is_main_process:
                progress_bar.set_postfix(loss=f"{loss.item():.4f}")

# Positional arguments for training_loop: (mixed_precision, seed, batch_size, epochs).
# NOTE(review): the recorded output of this cell shows the workers being forked
# (start_method: fork) and failing with "Cannot re-initialize CUDA in forked
# subprocess"; presumably something touched CUDA in the parent kernel before
# the launch -- TODO confirm on a freshly restarted kernel.
args = ("bf16", 42, 64, 1)
notebook_launcher(training_loop, args=args, num_processes=8)

Launching training on 8 CUDAs.


E0206 10:06:37.199000 579580 torch/distributed/elastic/multiprocessing/api.py:827] failed (exitcode: 1) local_rank: 0 (pid: 580483) of fn: training_loop (start_method: fork)
E0206 10:06:37.199000 579580 torch/distributed/elastic/multiprocessing/api.py:827] Traceback (most recent call last):
E0206 10:06:37.199000 579580 torch/distributed/elastic/multiprocessing/api.py:827]   File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/api.py", line 782, in _poll
E0206 10:06:37.199000 579580 torch/distributed/elastic/multiprocessing/api.py:827]     self._pc.join(-1)
E0206 10:06:37.199000 579580 torch/distributed/elastic/multiprocessing/api.py:827]   File "/usr/local/lib/python3.12/dist-packages/torch/multiprocessing/spawn.py", line 216, in join
E0206 10:06:37.199000 579580 torch/distributed/elastic/multiprocessing/api.py:827]     raise ProcessRaisedException(msg, error_index, failed_process.pid)
E0206 10:06:37.199000 579580 torch/distributed/elastic/multiproces

ChildFailedError: 
============================================================
training_loop FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2026-02-06_10:06:37
  host      : dgx170.cm.cluster
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 580483)
  error_file: /tmp/torchelastic_c_vp0srd/none_tbvzj2z1/attempt_0/0/error.json
  traceback : Traceback (most recent call last):
    File "/usr/local/lib/python3.12/dist-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
      return f(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^
    File "/tmp/ipykernel_579580/2407183425.py", line 3, in training_loop
      accelerator = Accelerator(mixed_precision=mixed_precision)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    File "/usr/local/lib/python3.12/dist-packages/accelerate/accelerator.py", line 461, in __init__
      self.state = AcceleratorState(
                   ^^^^^^^^^^^^^^^^^
    File "/usr/local/lib/python3.12/dist-packages/accelerate/state.py", line 912, in __init__
      PartialState(cpu, **kwargs)
    File "/usr/local/lib/python3.12/dist-packages/accelerate/state.py", line 301, in __init__
      self.set_device()
    File "/usr/local/lib/python3.12/dist-packages/accelerate/state.py", line 838, in set_device
      device_module.set_device(self.device)
    File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 584, in set_device
      torch._C._cuda_setDevice(device)
    File "/usr/local/lib/python3.12/dist-packages/torch/cuda/__init__.py", line 412, in _lazy_init
      raise RuntimeError(
  RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
  
============================================================

In [None]:
# Re-run of the multi-process launch with the same settings:
# (mixed_precision, seed, batch_size, epochs) on 8 processes.
args = ("bf16", 42, 64, 1)
notebook_launcher(training_loop, args=args, num_processes=8)

In [None]:

# Optimizer / loss / LR-schedule setup for the in-kernel (non-launcher) training path.
# NOTE(review): `model`, `device`, `ssl_dl` and `accelerator` are not defined in
# any cell visible in this notebook (training_loop builds its own copies as
# locals), so this cell presumably relies on state from an earlier session --
# it will raise NameError on a fresh kernel. TODO confirm or remove.
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.02)
criterion = InfoNCE(temperature=0.1).to(device)
EPOCHS = 4

# The schedule length is computed from len(ssl_dl) BEFORE the dataloader is
# passed through accelerator.prepare below.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=6e-4,
    total_steps=len(ssl_dl) * EPOCHS,
    pct_start=0.05
)
# Rebinds all four names to the objects returned by accelerator.prepare.
model, optimizer, ssl_dl, scheduler = accelerator.prepare(
        model, optimizer, ssl_dl, scheduler
)

In [None]:
# In-kernel pre-training loop: manual device transfer and bfloat16 autocast.
# It uses `model`, `ssl_dl`, `device`, `optimizer`, `criterion`, `scheduler`
# and `EPOCHS` from the kernel namespace.
model.train()

for epoch in range(EPOCHS):
    epoch_label = f"Epoch {epoch+1}/{EPOCHS}"
    progress_bar = tqdm(ssl_dl, desc=epoch_label)

    for i, batch in enumerate(progress_bar):
        batch = batch.to(device, non_blocking=True)

        # Forward pass and loss are computed under autocast.
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            c_proj, targets, mask_idx = model(batch)
            loss = criterion(c_proj, targets, mask_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Refresh the progress-bar readout on every 10th step only.
        if i % 10 != 0:
            continue
        progress_bar.set_postfix(
            loss=f"{loss.item():.4f}",
            lr=f"{scheduler.get_last_lr()[0]:.6f}",
        )

In [None]:
# torch.save(model.state_dict(), '../models/smrt2vec_8epoch.pt')

In [None]:
from smrt_foundation.dataset import LegacyMethylDataset, compute_log_normalization_stats
import polars as pl
# Paths and feature columns for the downstream methylation probe.
TRAIN_PATH = '../data/01_processed/val_sets/pacbio_standard_train.parquet'
VAL_PATH =  '../data/01_processed/val_sets/pacbio_standard_test.parquet'
KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']

# Normalization statistics are estimated from the first 1,000,000 training rows.
# n_rows stops the parquet reader after that many rows, instead of loading the
# whole file and discarding everything past the head.
df = pl.read_parquet(TRAIN_PATH, n_rows=1_000_000)
train_means, train_stds = compute_log_normalization_stats(df, KINETICS_FEATURES)

# Training split, normalized with the statistics computed above.
methyl_train_ds = LegacyMethylDataset(TRAIN_PATH, train_means, train_stds, context=32, restrict_row_groups=5)
methyl_train_dl = DataLoader(methyl_train_ds,
                             # num_workers=8,
                             batch_size=256,
                             drop_last=True,
                             persistent_workers=False,
                             # prefetch_factor=5
                            )

# Validation split reuses the TRAINING means/stds so both splits share one scale.
methyl_val_ds = LegacyMethylDataset(VAL_PATH,
                                      means=train_means,
                                      stds=train_stds,
                                      context=32,
                                      restrict_row_groups=5)
# NOTE(review): drop_last=True also discards the final partial validation batch,
# so accuracy is computed on slightly fewer samples -- confirm this is intended.
methyl_val_dl = DataLoader(methyl_val_ds,
                        batch_size=256,
                        drop_last=True,
                        persistent_workers=False,
                        prefetch_factor=None)



In [None]:
# Preview the first rows of the frame used for the normalization statistics.
df.head()

In [None]:
import copy
import torch
from tqdm import tqdm
from smrt_foundation.probe import SingleIdxProbe

# Fine-tune a SingleIdxProbe on top of a copy of the pre-trained encoder and
# report validation accuracy after every epoch.
EPOCHS = 5
DEVICE = torch.device('cuda')
# Deep-copy the encoder so fine-tuning does not modify the pre-trained weights.
# NOTE(review): `model` is not created at top level in the visible cells
# (training_loop keeps its model local); if it is a wrapped model, `.encoder`
# may need to be reached through the wrapper -- TODO confirm.
encoder_clone = copy.deepcopy(model.encoder)
probe = SingleIdxProbe(encoder_clone, freeze_encoder=False).to(DEVICE)

# Separate learning rates: a much smaller one for the encoder than for the head.
optimizer = torch.optim.AdamW([
    {'params': probe.encoder.parameters(), 'lr': 5e-7},
    {'params': probe.head.parameters(), 'lr': 3e-5}
])

criterion = torch.nn.BCEWithLogitsLoss()
# Mean training loss of each completed window of 100 steps.
loss_history = []

for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}")
    probe.train()
    running_loss = 0.0
    for i, batch in enumerate(tqdm(methyl_train_dl)):
        inputs = batch['data'].to(DEVICE)
        labels = batch['label'].to(DEVICE)

        optimizer.zero_grad()
        logits = probe(inputs)
        # The loss receives labels with a trailing axis added, cast to float32.
        loss = criterion(logits, labels.unsqueeze(1).to(torch.float32))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            loss_history.append(running_loss / 100)
            running_loss = 0.0

    probe.eval()
    sample_count = 0
    sample_correct = 0
    # Gradients are not needed while scoring the validation set.
    with torch.no_grad():
        for batch in tqdm(methyl_val_dl):
            inputs = batch['data'].to(DEVICE)
            labels = batch['label'].to(DEVICE)

            logits = probe(inputs)
            # A logit above 0 is counted as a positive prediction.
            preds = logits > 0
            correct = labels == preds.squeeze(-1)
            sample_count += correct.shape[0]
            # .item() keeps the counter a Python int, so the accuracy below
            # prints as a plain number rather than a CUDA tensor repr.
            sample_correct += correct.sum().item()

    # Avoid dividing by zero when the validation loader yields no batches.
    val_acc = sample_correct / sample_count if sample_count else float("nan")
    print(f"epoch val top1_acc: {val_acc}")