In [1]:
from src.dataset import MethylIterableDataset, compute_log_normalization_stats
from src.evaluate import evaluate
from src.train import train
from src.model import MethylCNN, FeatureSet
import polars as pl
from torch.utils.data import DataLoader
from torch import nn

import torch
import numpy as np
from tqdm import tqdm

In [2]:
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(device)

mps


In [3]:
train_parquet = 'data/processed/standard_600k_32_train.parquet'
test_parquet = 'data/processed/standard_600k_32_test.parquet'


In [4]:
subset_q = (
    pl.scan_parquet(train_parquet,
                    schema = {'read_name': pl.String,
                              'cg_pos': pl.Int64,
                              'seq': pl.String,
                              'fi': pl.List(pl.UInt16),
                              'fp': pl.List(pl.UInt16),
                              'ri': pl.List(pl.UInt16),
                              'rp': pl.List(pl.UInt16),
                              'label': pl.Int32
                              })
    .head(1_000_000)
    )

subset_df = subset_q.collect()

KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']

train_means, train_stds = compute_log_normalization_stats(subset_df, KINETICS_FEATURES)
print(train_means, train_stds)

{'fi': 3.2361628985125575, 'fp': 3.0442874824956117, 'ri': 3.2246924771219243, 'rp': 3.0464773632625053} {'fi': 0.6563387717253599, 'fp': 0.44864215501330723, 'ri': 0.650989218017293, 'rp': 0.44927033472968564}


In [5]:
workers=8
batch_size=2**13

# train
train_ds = MethylIterableDataset(train_parquet,
                                    means=train_means,
                                    stds=train_stds,
                                    context = 32)
train_dl = DataLoader(train_ds,
                         batch_size=batch_size,
                         num_workers=workers,
                         pin_memory=True,
                         persistent_workers=True,
                         prefetch_factor=64)

# test
test_ds = MethylIterableDataset(test_parquet,
                                   means=train_means,
                                   stds=train_stds,
                                   context = 32)
test_dl = DataLoader(test_ds,
                        batch_size=batch_size,
                        num_workers=workers,
                        pin_memory=True,
                        persistent_workers=True,
                        prefetch_factor=64)

In [6]:
for batch in tqdm(train_dl):
  batch_i = batch
  break

  0%|          | 0/10097 [00:03<?, ?it/s]


In [7]:
model_all = MethylCNN(features = FeatureSet.ALL)
model_all.to(device)

model_all_total_params = sum(p.numel() for p in model_all.parameters() if p.requires_grad)

print(f'Total trainable parameters: {model_all_total_params}')

print(model_all)

criterion_all = nn.CrossEntropyLoss()
optimizer_all = torch.optim.Adam(model_all.parameters(), lr=0.002)

training_stats_all = train(model_all,
                                 train_dl,
                                 test_dl,
                                 epochs = 4,
                                 criterion=criterion_all,
                                 optimizer = optimizer_all,
                                 device=device)

Total trainable parameters: 512114
MethylCNN(
  (extractor): Sequential(
    (0): ResBlock(
      (bn1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv1d(8, 16, kernel_size=(7,), stride=(1,), padding=(3,))
      (bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(16, 16, kernel_size=(7,), stride=(1,), padding=(3,))
      (relu): ReLU(inplace=True)
      (residual): Sequential(
        (0): Conv1d(8, 16, kernel_size=(1,), stride=(1,), bias=False)
      )
    )
    (1): ResBlock(
      (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
      (bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
      (relu): ReLU(inplace=True)
      (residual): Sequential()
    )
  

  8%|â–Š         | 806/10097 [03:01<34:47,  4.45it/s] 


KeyboardInterrupt: 