In [14]:
from src.dataset import MethylIterableDataset
import polars as pl
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import pyarrow.parquet as pq
import altair as alt
import itertools
import torch

In [15]:
train_parquet = 'data/processed/standard_600k_32_train.parquet'
test_parquet = 'data/processed/standard_600k_32_test.parquet'

In [16]:
subset_q = (
    pl.scan_parquet(train_parquet,
                    schema = {'read_name': pl.String,
                              'cg_pos': pl.Int64,
                              'seq': pl.String,
                              'fi': pl.List(pl.UInt16),
                              'fp': pl.List(pl.UInt16),
                              'ri': pl.List(pl.UInt16),
                              'rp': pl.List(pl.UInt16),
                              'label': pl.Int32
                              })
    .head(1_000_000)
    )


subset_df = subset_q.collect()

KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']
def compute_log_normalization_stats(df, features, epsilon=1):
    means = {col: (df[col].explode() + epsilon).log().mean() for col in features}
    stds = {col: (df[col].explode() + epsilon).log().explode().std() for col in features}
    return means, stds

train_means, train_stds = compute_log_normalization_stats(subset_df, KINETICS_FEATURES)

In [17]:
subset_df.head()

read_name,cg_pos,seq,fi,fp,ri,rp,label
str,i64,str,list[u16],list[u16],list[u16],list[u16],i32
"""m64168_200820_000733/25101908/…",6898,"""CTCCAACAAACAAAACGGACCAAAACAAAG…","[35, 39, … 15]","[27, 12, … 20]","[37, 48, … 32]","[42, 15, … 24]",0
"""m64168_200823_191315/38798620/…",7404,"""CTCTCCCAGGTGCAACGTGGTTCTGAATCT…","[27, 38, … 36]","[15, 10, … 27]","[15, 22, … 28]","[21, 28, … 28]",1
"""m64168_200820_000733/5441406/c…",6298,"""TTTTAAAAGTGAACTCGGACACCACAGACT…","[18, 30, … 32]","[17, 30, … 21]","[18, 35, … 11]","[14, 20, … 29]",0
"""m64168_200820_000733/3081708/c…",5254,"""AAATTTAACCCTAAACGCATTTGAAACAGA…","[19, 23, … 29]","[13, 19, … 22]","[29, 37, … 37]","[21, 23, … 27]",0
"""m64168_200823_191315/168626404…",3944,"""CAAAATATTGAAAACCGCATAAATATTCAT…","[19, 48, … 29]","[10, 24, … 25]","[16, 12, … 46]","[34, 78, … 37]",1


In [21]:
print(subset_df.schema)

Schema([('read_name', String), ('cg_pos', Int64), ('seq', String), ('fi', List(UInt16)), ('fp', List(UInt16)), ('ri', List(UInt16)), ('rp', List(UInt16)), ('label', Int32)])


In [18]:
it_workers=8
batch_size=1
restrict_row_groups = 0

train_ds = MethylIterableDataset(train_parquet,
                                    means=train_means,
                                    stds=train_stds,
                                    context = 32,
                                    restrict_row_groups = 1,
                                    single_strand=False)
train_dl = DataLoader(train_ds,
                      batch_size=batch_size,
                      drop_last=True,
                      num_workers=it_workers,pin_memory=False,
                      persistent_workers=True,
                      prefetch_factor=32)


In [23]:
next(iter(train_dl))['read_name']

KeyError: 'read_name'

In [None]:
all_labels = [sample['label'].item() for sample in train_ds]
distribution = pl.Series("labels", all_labels).value_counts()
print(distribution)

shape: (2, 2)
┌────────┬────────┐
│ labels ┆ count  │
│ ---    ┆ ---    │
│ i64    ┆ u32    │
╞════════╪════════╡
│ 0      ┆ 128175 │
│ 1      ┆ 134396 │
└────────┴────────┘


: 

In [None]:
all_labels = []
for batch in train_dl:
    all_labels.append(batch['label'])
full_label_tensor = torch.cat(all_labels)
distribution = pl.Series("labels", full_label_tensor.numpy()).value_counts()
print(distribution)

In [None]:
sweep_df = pl.read_parquet('output/v0.4.4-sweep_results.parquet')
sweep_df

lr,weight_decay,train_losses,test_losses,test_accs,best_test_loss
f64,f64,list[f64],list[f64],list[f64],f64
0.01,0.01,"[0.734044, NaN]","[0.692952, 901.922286]","[0.510083, 0.489917]",0.692952
0.01,0.001,"[0.731824, 0.692888]","[0.692954, 0.692952]","[0.510083, 0.510083]",0.692952
0.01,0.0001,"[0.737333, 0.692888]","[0.692955, 0.692955]","[0.510083, 0.510083]",0.692955
0.01,0.00001,"[0.704972, 0.692887]","[0.692953, 0.692952]","[0.510083, 0.510083]",0.692952
0.003162,0.01,"[0.316287, 0.275236]","[0.285029, 0.271616]","[0.873582, 0.880099]",0.271616
…,…,…,…,…,…
0.000003,0.00001,"[0.53968, 0.46895]","[0.488932, 0.453133]","[0.760943, 0.783554]",0.453133
0.000001,0.01,"[0.603318, 0.532352]","[0.554535, 0.513639]","[0.712748, 0.743301]",0.513639
0.000001,0.001,"[0.599527, 0.535426]","[0.556, 0.5181]","[0.712634, 0.74088]",0.5181
0.000001,0.0001,"[0.598603, 0.534307]","[0.554308, 0.517027]","[0.713056, 0.740919]",0.517027


In [None]:
alt.Chart(sweep_df).mark_circle(size=200).encode(
    alt.X('lr').scale(type='log'),
    alt.Y('weight_decay').scale(type='log'),
    alt.Size('best_test_loss:Q')#.scale(scheme="turbo")
)

In [None]:
min_idx = sweep_df['best_test_loss'].arg_min()
sweep_df.row(min_idx)

(0.001,
 0.01,
 [0.3012624694368021, 0.26937795044609275],
 [0.27567068234517866, 0.2672272899667627],
 [0.8780596567648372, 0.8824523465412862],
 0.2672272899667627)

In [None]:
np.geomspace(5e-2, 5e-3, 5)

array([0.05      , 0.02811707, 0.01581139, 0.0088914 , 0.005     ])

In [None]:
np.geomspace(5e-3, 5e-4, 10)

array([0.005     , 0.00387132, 0.00299742, 0.00232079, 0.00179691,
       0.00139128, 0.00107722, 0.00083405, 0.00064577, 0.0005    ])