In [5]:
from src.dataset import MethylIterableDataset
import polars as pl
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import pyarrow.parquet as pq

In [6]:
# Parquet inputs (under data/processed/) for the train and test splits.
# `train_parquet` feeds both the statistics subset and the training dataset;
# `test_parquet` is only inspected via its metadata in this notebook section.
train_parquet = 'data/processed/standard_600k_32_train.parquet'
test_parquet = 'data/processed/standard_600k_32_test.parquet'


In [7]:
# Row count of row group 1 in the test parquet, taken from the file metadata
# (read_metadata does not load any column data).
# NOTE(review): pyarrow row-group indices start at 0, so this is the *second*
# row group — confirm that is the one intended.
pq.read_metadata(test_parquet).row_group(1).num_rows

264720

In [8]:
# Explicit schema for the training parquet. The four kinetics columns all
# share the same list-of-UInt16 dtype, so they are generated in one go
# (key order matches the original column order).
train_schema = {
    'read_name': pl.String,
    'cg_pos': pl.Int64,
    'seq': pl.String,
    **{col: pl.List(pl.UInt16) for col in ('fi', 'fp', 'ri', 'rp')},
    'label': pl.Int32,
}

# Lazy query: scan the file with the schema above and keep only the first
# 1,000,000 rows (a subset rather than the whole training file).
subset_q = pl.scan_parquet(train_parquet, schema=train_schema).head(1_000_000)

# Materialise the subset as an in-memory DataFrame.
subset_df = subset_q.collect()

# List-valued kinetics columns that get log-normalised.
KINETICS_FEATURES = ['fi', 'fp', 'ri', 'rp']


def compute_log_normalization_stats(df, features, epsilon=1):
    """Compute per-column mean and std of ``log(value + epsilon)``.

    Each column named in ``features`` is expected to be a list column of
    ``df``; its lists are flattened so the statistics are taken over every
    individual element, not over rows.

    Args:
        df: polars DataFrame containing the columns listed in ``features``.
        features: iterable of column names to summarise.
        epsilon: offset added before the logarithm so zero values stay finite.

    Returns:
        ``(means, stds)`` — two dicts keyed by column name. ``stds`` uses
        polars' default ``Series.std()`` (sample standard deviation).
    """
    means, stds = {}, {}
    for col in features:
        # Flatten once and reuse for both statistics. Cast to Float64 *before*
        # adding epsilon: the columns are narrow unsigned ints, and adding to
        # the maximum representable value in the integer dtype could wrap
        # around and turn into log(0) = -inf.
        log_values = (df[col].explode().cast(pl.Float64) + epsilon).log()
        means[col] = log_values.mean()
        # No second explode here: log_values is already a flat float series.
        stds[col] = log_values.std()
    return means, stds

# Log-space normalisation statistics for every kinetics column, estimated on
# the in-memory subset; printed as two dicts (means first, then stds).
train_means, train_stds = compute_log_normalization_stats(
    df=subset_df,
    features=KINETICS_FEATURES,
)
print(train_means, train_stds)

{'fi': 3.2361628985125575, 'fp': 3.0442874824956117, 'ri': 3.2246924771219243, 'rp': 3.0464773632625053} {'fi': 0.6563387717253599, 'fp': 0.44864215501330723, 'ri': 0.650989218017293, 'rp': 0.44927033472968564}


In [9]:
# Integer code assigned to each nucleotide.
vocab = {'A': 0, 'T': 1, 'C': 2, 'G': 3}

# Split every sequence into single characters and map each one through
# `vocab` (replace_strict rejects characters that are not in the mapping).
seq_codes = (
    subset_df['seq']
    .str.split("")
    .list.eval(pl.element().replace_strict(vocab))
)

# Stack the per-row arrays into one 2-D array and display its shape.
np.stack(seq_codes.to_numpy()).shape

(1000000, 32)

In [9]:
# Loader sizing.
it_workers = 8
batch_size = 2**13  # 8192

# train
train_ds = MethylIterableDataset(
    train_parquet,
    means=train_means,
    stds=train_stds,
    context=32,
)

# DataLoader options gathered in one place. prefetch_factor is per worker,
# so up to it_workers * 64 batches can be queued ahead of the consumer.
train_dl_kwargs = dict(
    batch_size=batch_size,
    num_workers=it_workers,
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=64,
)
train_dl = DataLoader(train_ds, **train_dl_kwargs)

In [None]:
# Iterate over the whole training loader once; nothing is done with the
# batches, so the tqdm read-out is effectively a loading-throughput measurement.
# `batch_i` just retains the most recently yielded batch for later inspection.
for batch in tqdm(train_dl):
  batch_i = batch

10103it [03:05, 54.54it/s]                            


In [10]:
# Length reported by the training dataset's __len__
# (presumably the total sample count — verify against MethylIterableDataset).
len(train_ds)

82710056