In [1]:
from src.dataset import MethylIterableDataset, compute_log_normalization_stats
from src.evaluate import evaluate
from src.train import train
from src.model import MethylCNN, FeatureSet
import polars as pl
from torch.utils.data import DataLoader
from torch import nn

import torch
import numpy as np
from tqdm import tqdm

In [2]:
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(device)

mps


In [3]:
train_parquet = 'data/processed/standard_600k_32_train.parquet'
test_parquet = 'data/processed/standard_600k_32_test.parquet'


In [4]:
subset_q = (
    pl.scan_parquet(train_parquet,
                    schema = {'read_name': pl.String,
                              'cg_pos': pl.Int64,
                              'seq': pl.String,
                              'fi': pl.List(pl.UInt16),
                              'fp': pl.List(pl.UInt16),
                              'ri': pl.List(pl.UInt16),
                              'rp': pl.List(pl.UInt16),
                              'label': pl.Int32
                              })
    .head(1_000_000)
    )

subset_df = subset_q.collect()

KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']

train_means, train_stds = compute_log_normalization_stats(subset_df, KINETICS_FEATURES)
print(train_means, train_stds)

{'fi': 3.2361628985125575, 'fp': 3.0442874824956117, 'ri': 3.2246924771219243, 'rp': 3.0464773632625053} {'fi': 0.6563387717253599, 'fp': 0.44864215501330723, 'ri': 0.650989218017293, 'rp': 0.44927033472968564}


In [5]:
subset_df.head()

read_name,cg_pos,seq,fi,fp,ri,rp,label
str,i64,str,list[u16],list[u16],list[u16],list[u16],i32
"""m64168_200820_000733/25101908/…",6898,"""CTCCAACAAACAAAACGGACCAAAACAAAG…","[35, 39, … 15]","[27, 12, … 20]","[37, 48, … 32]","[42, 15, … 24]",0
"""m64168_200823_191315/38798620/…",7404,"""CTCTCCCAGGTGCAACGTGGTTCTGAATCT…","[27, 38, … 36]","[15, 10, … 27]","[15, 22, … 28]","[21, 28, … 28]",1
"""m64168_200820_000733/5441406/c…",6298,"""TTTTAAAAGTGAACTCGGACACCACAGACT…","[18, 30, … 32]","[17, 30, … 21]","[18, 35, … 11]","[14, 20, … 29]",0
"""m64168_200820_000733/3081708/c…",5254,"""AAATTTAACCCTAAACGCATTTGAAACAGA…","[19, 23, … 29]","[13, 19, … 22]","[29, 37, … 37]","[21, 23, … 27]",0
"""m64168_200823_191315/168626404…",3944,"""CAAAATATTGAAAACCGCATAAATATTCAT…","[19, 48, … 29]","[10, 24, … 25]","[16, 12, … 46]","[34, 78, … 37]",1


In [8]:
workers=8
batch_size=2**13

# train
train_ds = MethylIterableDataset(train_parquet,
                                    means=train_means,
                                    stds=train_stds,
                                    context = 32,
                                    restrict_row_groups=2)
train_dl = DataLoader(train_ds,
                         batch_size=batch_size,
                         num_workers=workers,
                         pin_memory=True,
                         persistent_workers=True,
                         prefetch_factor=64)

# test
test_ds = MethylIterableDataset(test_parquet,
                                   means=train_means,
                                   stds=train_stds,
                                   context = 32,
                                   restrict_row_groups=2)
test_dl = DataLoader(test_ds,
                        batch_size=batch_size,
                        num_workers=workers,
                        pin_memory=True,
                        persistent_workers=True,
                        prefetch_factor=64)

In [9]:
len(train_ds)

525142

In [11]:
for batch in tqdm(train_dl):
  batch_i = batch
  break

  0%|          | 0/65 [00:00<?, ?it/s]


In [12]:
model_all = MethylCNN(features = FeatureSet.ALL)
model_all.to(device)

model_all_total_params = sum(p.numel() for p in model_all.parameters() if p.requires_grad)

print(f'Total trainable parameters: {model_all_total_params}')

print(model_all)

criterion_all = nn.CrossEntropyLoss()
optimizer_all = torch.optim.Adam(model_all.parameters(), lr=0.002)

training_stats_all = train(model_all,
                                 train_dl,
                                 test_dl,
                                 epochs = 1,
                                 criterion=criterion_all,
                                 optimizer = optimizer_all,
                                 device=device)

Total trainable parameters: 512114
MethylCNN(
  (extractor): Sequential(
    (0): ResBlock(
      (bn1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv1d(8, 16, kernel_size=(7,), stride=(1,), padding=(3,))
      (bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(16, 16, kernel_size=(7,), stride=(1,), padding=(3,))
      (relu): ReLU(inplace=True)
      (residual): Sequential(
        (0): Conv1d(8, 16, kernel_size=(1,), stride=(1,), bias=False)
      )
    )
    (1): ResBlock(
      (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
      (bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(16, 16, kernel_size=(5,), stride=(1,), padding=(2,))
      (relu): ReLU(inplace=True)
      (residual): Sequential()
    )
  

66it [00:15,  4.17it/s]                        
66it [00:06, 10.54it/s]                        

 avg epoch train loss: 0.6667
         test set loss: 0.4929
 test set accuracy: 0.7615
Completed training for 1 epochs



