# Background

Methylation is a modification of DNA that controls gene expression. There are many different types of methylation, but one of the most prominent is methylation at CpG sites (see [this wikipedia article](https://en.wikipedia.org/wiki/CpG_site)). When a CpG site is methylated, it is chemically changed, meaning that the biophysical properties of that section of the DNA molecule are slightly different than "canonical structure." Therefore, if we had data which contained a signal of the properties of the molecule, we could detect methylation.

The data that I'm using in this notebook was sequenced using PacBio Sequell technology. Here's PacBio's [advertisement](https://youtu.be/_lD8JyAbwEo?si=zILsY74u6tJyXsmJ)... it's not the most detailed, and fails to mention that the data you get from the system also includes the temporal information from the polymerase that does the sequencing. These include how long the light flash lasted (from the nucleotide incorporation), and the time until the next light flash. This is included for both the forward and reverse strand of the molecule that was sequenced.

The polymerase spans around 8 nucleotides, and since methylation adds some resistance, there is a difference in these signals for the methylated an unmethylated CpG sites. The issue of course, is that the kinetics are not just affected by methylation, but also by regular nucleotide context. So if we want to predict methlyation based on kinetics, the model needs to learn how to deconflate the kinetics signal associated with methylation from the affect that different nucleotides have on the polymerase behaviour.

PacBio actually already figured this out with [primrose](https://github.com/mattoslmp/primrose). But I thought it would be a nice exercise before setting off into uncharted territory of genetics models to

1. reimplement primrose with a CNN as PacBio did

2. Pretrain with contrastive learning and try to beat that with a classification downstream task.

Yes, I should get to the science, but I also know that deep learning is hard even on easy problems, so I think I'll spend a bit more time on problems I know are solvable: this model did not learn at all until I added a final linear layer for classification, which I did on a hunch. Another reason that reimplementing Primrose is interesting is that the source code for Primrose is not public.

## Dataset
Regarding the dataset that's present here: I filtered 32 base windows centered on CpG sites. They come from two [BAM](https://en.wikipedia.org/wiki/BAM_(file_format)) files, one artificially fully methylated and the other artificially fully unmethylated. I don't have enough background in geneitcs/chemistry to know whether this imbues the data with some artifacts that make a model non-generalizeable to real data, but it's implied that PacBio believes that it's generalizeable (PrimRose and its successors run on-device to make methylation calls).

I really don't like interacting with BAM files, so I made a script to take out the CpG sites and place them into parquet columnar files. Below is a summary of the features:

  ``read_name``: The unique identifier for the read. This corresponds to one zero-mode-waveguide's measurements in the video.

  ``cg_pos``: The position in the read, indexed from 0, where the CpG site occurs

  ``seq``: The 32 bp window of nucleotides. This is a consensus as a result of between 5 and 30 passes over the same section of DNA by the polymerase. PacBio is not specific about how it chooses which base to call for each position based on the 5-30 passes. But someo of them are called by the google DeepConsensus algorithm, which has a little more background.

  ``fi``: The interpulse duration at each position in the window. This is how long it took the polymerase to reach the next base. This is for the forward strand. A 'kinetics' feature.

  ``fp``: The pulse width at each position in the window. This is how long the light flash lasted. A 'kinetics' feature.

  ``ri``: same as window_fi, but the reverse strand. A 'kinetics' feature.

  ``rp``: same as window_fp, but the reverse strand. A 'kinetics' feature.

Note: Units for the kinetics features (window_fi, window_fp, etc) are in "frames" so we can think of it as an arbitrary time unit.

# Implementation plan
## Now
### Architecture
The model should take a  (Channels, Features, Context)=(1,8,32) tensor and make a binary prediction about the methylation status. 4 of the features are one-hot encoded nucleotides, and the other 4 are the kinetic features as listed above.  At first shis should just be a convolutional model, with a linear classification head. To start with, I used 200 reads of data (about 8k CpG training samples), and then moved to 10k reads of data (1.3m CpG training samples). A read has c. 12k bases each, but CpG sites are under-represented in comparison to what we'd expect randomly since they have a higher mutation rate (due to their methylation).
### Results
The model achieved 83 percent accuracy on the 10k reads dataset test partition. It seems to plateau there, achieving similar results for the last 8 epochs, and I suspect that before adding more data, which I have, I could make some model-architecture changes that would improve this. PacBio's primrose achieve's 85 percent accuracy, but they don't specify what part of the data was test/train. A group in china was able to get to 0.90 test set accuracy. So for a first attempt, I'm not too dissapointed (assuming I haven't polluted my data in some way), but it would be nice to see if we can improve past 0.85 and then past 0.90.

## Future
### Contrastive pretraining
Train a JEPA with my groups abundant pacbio data. Focus on learning the associations between nucletides and kinetics. Take the encoder, add a classification head, and use it to classify methylation.
### Transformers
My intution is that a transformer encoder would be much better for the nucleotide section than a CNN. So for future implementations it would be nice to use this for the section of the model that processes the nucleotides.
### JAX
Yep. I still want to pursue jax. Maybe I'll outgrow it, but after writing this first implementation in pytorch, I still enjoy jax more

# Ideas
## Test the independent efficacy of kinetics and nucleotides as features
Premise: What if one of the nucleotides/kinetics data streams is dominating the prediction? The model could be learning just that some nucleotide contexts occur for methylation, and some don't.

Test: Make a nucleotide only and a kinetics only classifier to sanity test whether one is contributing all the predictive information.

Result:


## Treat nucleotides and kinetics seperately
Make a new kinetics+nucleotides model that has seperate towers for the two datatypes. Guessing that architectural choices probably don't generalize between the two. A pretrained transformer would probably be the best for the nucleotides, but for now I'd like to try with just the convolutional type since that was effective for primrose. The next step might be a transformer based JEPA model for pretraining that I could use for downstream methylation detection.

## How much context do we need?
I chose 32 since that seemed like it woud definitely be large enough (based on the eda plots below). But smaller would make training faster, and carry less risk of running into multiple CG sites per sample. Could be nice to test different context sizes and compare performance.

In [1]:
# Mount google drive
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


# Questions
1. Do we need to normalize categorical (nucleotide) data? Right now I normalized the kinetics features, but not the one-hot encodings. To me, this was the intuitive obvious solution, but a prof in my deparment keeps bugging me with "how do you know that the nucleotides are not contributing all the signal information?? how are you choosing to normalize those??" And I keep thinking... I don't want to normalize one-hot encodings. I mean, I could use label smoothing, but that doesn't seem like the same thing.

2. How can we change the architecture so that the model can learn associations between nucleotide contexts and their kinetic signatures? Based on the fact that the kinetics-only model performs almost as well as the kinetics+nucleotides model, I think that the nucleotides are being under-utilized. Simply looking at the kinetics might allow the model to predict methylation to limited accuracy, but as we can see from the EDA plots, CG has as particular kinetic signature distinct from the average surroundings even when it's unmethylated.

3. How much risk is there in the fact that sometimes we get two CG instances in the same sample? My hunch is that it would make our test set prediction optimistic, since there is a greater methylation signal (in reality I don't think adjacent CG sites have P(both methylated)=1). On the otherhand, when we move to contrastive pretraining, I suspect that a larger context window will be really beneficial, so I think this is something worth thinking about.

4. There's a text box down by training explaining that I needed to switch to a larger batch size to get good epoch training times. But the question is, how do we balance large batch sizes with generalizability? Is this dependent on the optimizer? Or do we just test things and see what works? I would guess the last option is most common, but there must be limits since I was previously told to "always use a batch size less than 128."

5. Should we expect that a pretrained contrastive learning model will outperform this, when it is used in downstream finetuning? For me it makes sense that it would help, but I'd be curious to know what your intuition is.

In [2]:
!cp /content/gdrive/MyDrive/methylation/train_cg_32_10000.parquet /content/
!cp /content/gdrive/MyDrive/methylation/test_cg_32_10000.parquet /content/



train_parquet = '/content/train_cg_32_10000.parquet'
test_parquet = '/content/test_cg_32_10000.parquet'


In [3]:
!pip install polars "vegafusion[embed]>=1.5.0" vl-convert-python



In [4]:
import torch
import polars as pl
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import altair as alt
alt.data_transformers.enable("vegafusion")
from typing import Dict, List, Any, Optional
from operator import itemgetter
import torch.nn.functional as F
import os
from torch import nn
from tqdm import tqdm
import numpy as np

pl.Config(fmt_str_lengths=50)

device = torch.device(
    "mps" if torch.backends.mps.is_available() # this actually works pretty well locally
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(device)


cuda


# EDA

## Glimpse of the dataset

In [5]:
train_df = pl.read_parquet(train_parquet,  schema = {'read_name': pl.String,
                                                     'cg_pos': pl.Int64,
                                                     'seq': pl.String,
                                                     'fi': pl.List(pl.Int64),
                                                     'fp': pl.List(pl.Int64),
                                                     'ri': pl.List(pl.Int64),
                                                     'rp': pl.List(pl.Int64),
                                                     'label': pl.Int32
                                                     }
                           )
train_df.head(10)

read_name,cg_pos,seq,fi,fp,ri,rp,label
str,i64,str,list[i64],list[i64],list[i64],list[i64],i32
"""m64168_200823_191315/1837517/ccs""",10282,"""AGATCTATTACACAACGTGGTGACCATAGCTA""","[52, 29, … 21]","[19, 10, … 14]","[62, 14, … 12]","[29, 25, … 13]",1
"""m64168_200820_000733/263719/ccs""",5955,"""GAAGGGGCTGATGCCCGGCCTCAGAGGTTAAG""","[19, 60, … 39]","[12, 49, … 24]","[31, 52, … 28]","[12, 17, … 54]",0
"""m64168_200823_191315/263868/ccs""",2321,"""GACGGGGCAGCTGGCCGGGCGGGGGGGCTGAC""","[23, 59, … 69]","[12, 25, … 9]","[35, 48, … 68]","[18, 11, … 28]",1
"""m64168_200823_191315/1442515/ccs""",7697,"""TGGAATGCAATGGAACGGAATGGAGTGGGATG""","[16, 11, … 21]","[16, 33, … 33]","[22, 18, … 12]","[37, 12, … 11]",1
"""m64168_200823_191315/1638483/ccs""",7803,"""GATGTACTCCACTTTCGAGCCTGATTCAGAAA""","[21, 21, … 22]","[44, 21, … 13]","[17, 17, … 21]","[11, 26, … 41]",1
"""m64168_200820_000733/742/ccs""",7290,"""GCCTGGGCGACAGAGCGAGACTCCATCTCAAA""","[8, 11, … 24]","[11, 19, … 23]","[16, 28, … 60]","[40, 36, … 33]",0
"""m64168_200820_000733/327687/ccs""",7671,"""TGGCCTTAAGTGATCCGCCCACCTTGGCCTCC""","[24, 33, … 141]","[15, 17, … 41]","[29, 23, … 27]","[23, 22, … 15]",0
"""m64168_200820_000733/132160/ccs""",4461,"""GAAAACTGAGTCCCCCGTGAGGATCTTGTTTT""","[16, 11, … 17]","[15, 16, … 17]","[16, 23, … 14]","[23, 9, … 38]",0
"""m64168_200823_191315/1444273/ccs""",6638,"""GGAATAATTCCTTTCCGTCCTCTCGGGAACAG""","[6, 12, … 10]","[27, 10, … 9]","[17, 22, … 17]","[27, 10, … 13]",1
"""m64168_200823_191315/264947/ccs""",5944,"""CTTAAGAGAAACAAACGGCCTCAGGAAGGGCA""","[29, 37, … 18]","[19, 9, … 10]","[11, 44, … 27]","[21, 67, … 20]",1


## Mean IPD at Context Indices
The chart below shows that at least on average, we should be able to tell the difference between methylated and unmethylted CpG sites. The chart averages IPD at each index across all of the training samples independantly for the forward and reverse strands (indices 15 and 16 are the CpG site).

In [6]:
# df of means for methlated data at each index
pos_means = (
    pl.read_parquet(train_parquet)
    .filter(pl.col('label')==1)
    .select(pl.col("fi").alias("fwd"), pl.col("ri").alias("rev"))
    .unpivot(on=["fwd", "rev"], variable_name="strand", value_name="ipd_list")
    .with_columns(index=pl.int_ranges(start=0, end=pl.col("ipd_list").list.len()))
    .explode("index", "ipd_list")
    .rename({"ipd_list": "ipd"})
    .group_by("index", "strand").agg(pl.col("ipd").mean())
)
# df of means for unmethlated data at each index
neg_means = (
    pl.read_parquet(train_parquet)
    .filter(pl.col('label')==0)
    .select(pl.col("fi").alias("fwd"), pl.col("ri").alias("rev"))
    .unpivot(on=["fwd", "rev"], variable_name="strand", value_name="ipd_list")
    .with_columns(index=pl.int_ranges(start=0, end=pl.col("ipd_list").list.len()))
    .explode("index", "ipd_list")
    .rename({"ipd_list": "ipd"})
    .group_by("index", "strand").agg(pl.col("ipd").mean())
)

# join on index/strand
means = pos_means.join(
    neg_means, on=['index', 'strand'], suffix='_neg'
    ).with_columns((pl.col('ipd')-pl.col('ipd_neg')).alias('residual'))

# make a chart of the difference between pos/neg mean at each index
residual_chart = alt.Chart(means).mark_line().encode(
    alt.X("index:Q", title="Position", axis=alt.Axis(tickCount=16)),
    alt.Y("residual:Q", title="IPD Mean Difference", scale=alt.Scale(domain=(-3, 15), clamp=True)),
    alt.Color("strand:N", title="Strand"),
    ).properties(
    title = "Mean IPD Residual (Meth-Unmeth) Across CG Context",
    width=800,
    height=600
    ).configure_axis(
    labelFontSize=12,
    titleFontSize=14,
    grid=True
    ).configure_title(
    fontSize=16,
    anchor='middle'
    ).configure_legend(
    titleFontSize=12,
    labelFontSize=11
    )
# chart of unmethylated mean at each index
neg_chart = alt.Chart(neg_means).mark_line().encode(
    alt.X('index:Q',scale=alt.Scale(domain=(0, 31))),
    alt.Y('ipd:Q', title = 'mean IPD' ,scale=alt.Scale(domain=(24, 44), clamp=True)),
    alt.Color('strand')
    ).properties(
    title = "Mean Unmethylated IPD Across CG Context",
    width=800,
    height=600
    )
# chart of methylated mean at each index
pos_chart = alt.Chart(pos_means).mark_line().encode(
    alt.X('index:Q',scale=alt.Scale(domain=(0, 31))),
    alt.Y('ipd:Q', title = 'mean IPD' ,scale=alt.Scale(domain=(24, 44), clamp=True)),
    alt.Color('strand')
    ).properties(
    title = "Mean Methylated IPD Across CG Context",
    width=800,
    height=600
    )

# display the charts
alt.hconcat(neg_chart,pos_chart)




## How many CPG sites per sample?
Every sample has a CG at the center, but this does not guarantee that there are no CG's in the rest of the sample.

In [7]:
train_df_counts= train_df.with_columns(
    pl.col("seq").str.count_matches("CG").alias("cg_count")
)
print(train_df_counts.head())
alt.Chart(train_df_counts).mark_bar().encode(
    alt.X('cg_count:O'),
    alt.Y('count():Q')
).properties(
    width=700,
    height=500,
    title=f'Distribution of CG Instances in Training Context Windows ({len(train_df)} samples)'
)

shape: (5, 9)
┌─────────────┬────────┬─────────────┬────────────┬───┬────────────┬────────────┬───────┬──────────┐
│ read_name   ┆ cg_pos ┆ seq         ┆ fi         ┆ … ┆ ri         ┆ rp         ┆ label ┆ cg_count │
│ ---         ┆ ---    ┆ ---         ┆ ---        ┆   ┆ ---        ┆ ---        ┆ ---   ┆ ---      │
│ str         ┆ i64    ┆ str         ┆ list[i64]  ┆   ┆ list[i64]  ┆ list[i64]  ┆ i32   ┆ u32      │
╞═════════════╪════════╪═════════════╪════════════╪═══╪════════════╪════════════╪═══════╪══════════╡
│ m64168_2008 ┆ 10282  ┆ AGATCTATTAC ┆ [52, 29, … ┆ … ┆ [62, 14, … ┆ [29, 25, … ┆ 1     ┆ 1        │
│ 23_191315/1 ┆        ┆ ACAACGTGGTG ┆ 21]        ┆   ┆ 12]        ┆ 13]        ┆       ┆          │
│ 837517/ccs  ┆        ┆ ACCATAGCTA  ┆            ┆   ┆            ┆            ┆       ┆          │
│ m64168_2008 ┆ 5955   ┆ GAAGGGGCTGA ┆ [19, 60, … ┆ … ┆ [31, 52, … ┆ [12, 17, … ┆ 0     ┆ 1        │
│ 20_000733/2 ┆        ┆ TGCCCGGCCTC ┆ 39]        ┆   ┆ 28]        ┆ 54]     

# Normalization
Precompute on the training set.

In [8]:
def compute_normalization_stats(df):
    kinetic_features = ['fi', 'fp', 'ri', 'rp']
    means = {col: df[col].explode().mean() for col in kinetic_features}
    stds = {col: df[col].explode().explode().std() for col in kinetic_features}
    return means, stds
train_means, train_stds = compute_normalization_stats(train_df)
print(train_means, train_stds)

{'fi': 30.594821621789382, 'fp': 21.945176340699103, 'ri': 30.13117185101172, 'rp': 22.007529335023822} {'fi': 24.25535836619395, 'fp': 11.27049294825807, 'ri': 23.845802361655878, 'rp': 11.310865492241938}


# Dataset Definition
## Notes
1. this does not include a transform implementation.
2. turning the entire dataset into tensor inside __init__ resultined in a 2x speedup in training (before the tensor initiation was in getitem)

In [9]:
class MethylDataset(Dataset):
    """
    Dataset class for metylation data stored in a parquet file
    Reads data using polars and converts samples to pytorch tensors before get_item call
    Nucleotide sequence is one-hot encoded
    """
    def __init__(self, data_path: Path,
                 means: Optional[Dict[str, float]] = train_means,
                 stds: Optional[Dict[str, float]] = train_stds,
                 context: int = 32):
        '''
        Arguments:
        data_path: the path for the parquet file that contains either the training or test data (pos and neg)
        transform: an optional transform callable on a single sample
        '''
        self.means = means
        self.stds = stds
        self.context = 32
        self.kinetics_features = ['fi', 'fp', 'ri', 'rp']

        # initialize the dataframe, converting kinetic columns in to pl.array
        try:
             self.df = pl.read_parquet(data_path).with_columns([
                 pl.col("fi").list.to_array(self.context),
                 pl.col("fp").list.to_array(self.context),
                 pl.col("ri").list.to_array(self.context),
                 pl.col("rp").list.to_array(self.context),
                 ])
        # intialize with empty df if loading fails
        except:
             print(f"failed to read data given path: {data_path}")
             self.df = pl.DataFrame()

        self._dataset_len = len(self.df)
        self.vocab = {'A':0, 'T':1, 'C':2, 'G':3}
        self.vocab_size = len(self.vocab)
        self.kinetic_features = ['fi', 'fp', 'ri', 'rp']
        # define the sequence tensor
        self.seq_tensor = self._prep_seq_tensor()
        # define the kinetics tensor
        self.kinetics_tensor = self._prep_kinetics_tensor()
        # define the label tensor
        self.label_tensor = self._prep_label_tensor()

    def _prep_seq_tensor(self):
      seq_ints = (
          self.df['seq']
          .str.split("")
          .list.eval(
              pl.element().replace_strict(self.vocab)
          )
          .to_numpy()
      )
      return torch.tensor(np.stack(seq_ints), dtype=torch.long)

    def _prep_kinetics_tensor(self):
      kinetics_array = np.stack([self.df[col].to_numpy() for col in self.kinetics_features], axis=1)
      return torch.tensor(kinetics_array, dtype=torch.long)

    def _prep_label_tensor(self):
      label_array = self.df['label'].to_numpy()
      return torch.tensor(label_array, dtype=torch.long)

    def __len__(self):
          return self._dataset_len

    def __getitem__(self, idx):
         if idx >= len(self):
              raise IndexError("Index out of range")
         seq, kinetics, label = self.seq_tensor[idx], self.kinetics_tensor[idx], self.label_tensor[idx]
         # sequence data (requires one-hot encoding)
         seq_tensor_one_hot = F.one_hot(seq, num_classes=self.vocab_size).T
         return {
              'seq': seq_tensor_one_hot,
              'kinetics': kinetics,
              'label': label
              }



## Test out the dataset/dataloader

In [10]:
batch_size = 512

train_ds = MethylDataset(train_parquet)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

test_ds = MethylDataset(test_parquet)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True, num_workers=4)

In [11]:
batch_example = next(iter(train_dl))
torch.cat((batch_example['seq'], batch_example['kinetics']), axis=1)

tensor([[[  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   0,   1,   0],
         [  0,   0,   0,  ...,   1,   0,   0],
         ...,
         [ 43,  11,  14,  ...,  17,  10,  30],
         [ 17,  27,  23,  ..., 116,  50,  64],
         [ 59,  12,  18,  ...,  56,  40,  21]],

        [[  0,   0,   0,  ...,   1,   0,   1],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  0,   0,   1,  ...,   0,   0,   0],
         ...,
         [ 25,  22,  42,  ...,  18,  23,  21],
         [ 69,  21,  21,  ...,  12,  13,  17],
         [ 11,  45,  14,  ...,  28,  23,  28]],

        [[  0,   1,   1,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   1,   0,   0],
         [  0,   0,   0,  ...,   0,   0,   1],
         ...,
         [ 15,  21,  17,  ...,  21,  20,  24],
         [ 31,  24,  14,  ...,  13,  31,  25],
         [ 40,  19,  13,  ...,  20,  24,  47]],

        ...,

        [[  1,   1,   1,  ...,   0,   1,   1],
         [  0,   0,   0,  ...,   0,   0,   0]

### Label distribution check

Question: Are the classes evenly distributed?

Test: Mean of label column should be 0.5

Result: mean is very close to 0.5 -> classes are evenly distributed

In [12]:
train_ds.df['label'].mean(), len(train_ds)

(0.5023785352982572, 1361342)

In [13]:
len_list = np.array([len(train_ds.df.row(i, named=True)['fi']) for i in range(len(train_ds))])
np.min(len_list),np.argmin(len_list)

(np.int64(32), np.int64(0))

# Model Definitions

## Full model (nucleotides + kinetics)

In [14]:
class MethylCNN(nn.Module):
    def __init__(self, sequence_length: int = 32, in_channels:int = 8, num_classes: int = 2, dropout_p: float = 0.5):
        super().__init__()
        self.in_channels = in_channels
        self.sequence_length = sequence_length
        self.num_classes = num_classes

        # Convolution layers
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5, padding=1)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)

        # calculate fc layer input with dummy passthrough
        self.fc_input_features = self._get_conv_output_size(sequence_length)

        # Linear layers
        self.fc1 = nn.Linear(in_features=self.fc_input_features, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=num_classes)

        # dropout
        self.dropout = nn.Dropout(p=dropout_p)

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)


        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        x = self.conv4(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        return x

    def _get_conv_output_size(self, sequence_length: int) -> int:
        """
        Calculates the flattened output size of the convolutional layers
        by performing a forward pass on random data of the right shape.
        """
        dummy_input = torch.randn(1, self.in_channels, sequence_length)
        output = self._extract_features(dummy_input)
        return output.numel()

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        seq = batch['seq']#.permute(0, 2, 1)
        kinetics = batch['kinetics']#.unsqueeze(1)

        # the input is a dictionary, so convert to a tensor
        x = torch.cat([seq, kinetics], dim=1).to(self.conv1.weight.dtype) # -> [B, 8, L]

        x = self._extract_features(x)

        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)

        logits = self.fc3(x)

        return logits

## Resticted model (nucleotides only)

In [15]:
class MethylCNN_Nucleotides(nn.Module):
    def __init__(self, sequence_length: int = 32, in_channels:int = 4, num_classes: int = 2, dropout_p: float = 0.5):
        super().__init__()
        self.in_channels = in_channels
        self.sequence_length = sequence_length
        self.num_classes = num_classes

        # Convolution layers
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=5, padding=1)
        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        # calculate fc layer input with dummy passthrough
        self.fc_input_features = self._get_conv_output_size(sequence_length)

        # Linear layers
        self.fc1 = nn.Linear(in_features=self.fc_input_features, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=num_classes)

        # dropout
        self.dropout = nn.Dropout(p=dropout_p)

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)


        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        x = self.conv4(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        return x

    def _get_conv_output_size(self, sequence_length: int) -> int:
        """
        Calculates the flattened output size of the convolutional layers
        by performing a forward pass on random data of the right shape.
        """
        dummy_input = torch.randn(1, self.in_channels, sequence_length)
        output = self._extract_features(dummy_input)
        return output.numel()

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        seq = batch['seq']#.permute(0, 2, 1)
        kinetics = batch['kinetics']#.unsqueeze(1)

        # the input is a dictionary, so convert to a tensor
        x = seq.to(self.conv1.weight.dtype)

        x = self._extract_features(x)

        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)

        logits = self.fc3(x)

        return logits

## Restricted model (kinetics only)

In [16]:
class MethylCNN_Kinetics(nn.Module):
    def __init__(self, sequence_length: int = 32, in_channels:int = 4, num_classes: int = 2, dropout_p: float = 0.5):
        super().__init__()
        self.in_channels = in_channels
        self.sequence_length = sequence_length
        self.num_classes = num_classes

        # Convolution layers
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=8, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=5, padding=1)
        self.conv3 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)

        # calculate fc layer input with dummy passthrough
        self.fc_input_features = self._get_conv_output_size(sequence_length)

        # Linear layers
        self.fc1 = nn.Linear(in_features=self.fc_input_features, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=num_classes)

        # dropout
        self.dropout = nn.Dropout(p=dropout_p)

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)


        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        x = self.conv4(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        return x

    def _get_conv_output_size(self, sequence_length: int) -> int:
        """
        Calculates the flattened output size of the convolutional layers
        by performing a forward pass on random data of the right shape.
        """
        dummy_input = torch.randn(1, self.in_channels, sequence_length)
        output = self._extract_features(dummy_input)
        return output.numel()

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        seq = batch['seq']#.permute(0, 2, 1)
        kinetics = batch['kinetics']#.unsqueeze(1)

        # the input is a dictionary, so convert to a tensor
        x = kinetics.to(self.conv1.weight.dtype)

        x = self._extract_features(x)

        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)

        x = self.fc2(x)
        x = F.relu(x)
        x = self.dropout(x)

        logits = self.fc3(x)

        return logits

# Eval Loop Function

In [17]:
def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    model.eval()
    running_loss: float = 0.0
    correct_predictions: int = 0
    total_samples: int = 0

    with torch.no_grad():
        for batch in tqdm(data_loader):
            labels: torch.Tensor = batch.pop("label").to(device)
            inputs: Dict[str, torch.Tensor] = {
                k: v.to(device) for k, v in batch.items()
            }

            logits: torch.Tensor = model(inputs)
            loss: torch.Tensor = criterion(logits, labels)

            running_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(logits.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    epoch_loss: float = running_loss / total_samples
    epoch_acc: float = correct_predictions / total_samples
    return {"loss": epoch_loss, "accuracy": epoch_acc}

# Training Loop
## Things I'm thinking about
### Batch sizes and training time
At first training was really slow. It took around 4 minutes for one epoch at batch size 32, with a training dataset length of 1.3m examples. Gains were also slow, taking around 10 epochs to reach test set accuracy of 0.8. Not only that, I discovered that using an A100 GPU (big fancy gpu) was not any faster than a CPU! After thinking about it, I realized that my data is very small per sample, and so that gpu is completely unsaturated at this small of a batch size. That said, I was told that using a batch size above 64 leads to worse generalizability, since we miss out on a lot of the stochasticity in gradients that is inherent at smaller batch sizes. Despite that, I switched to larger and larger batch sizes, all the way to size 1024, which lowered the epoch training time to around 15 seconds, keeping everything else constant. The result is that we get test set accuracy of 0.8 after the same number of epochs, but in under 10 minutes. So it seems like generlizability was not damaged too bad, but I'm still wondering: What can we do to regains some of that stochasticity/generalizability that was lost? Some ideas:

1. Increase the learning rate. I scaled the batch size by approx 100x, so maybe start with 10x learning rate.
2. More dropout? I'm already using dropout on the linear layers with p=0.5
3. Warmup with the learning rate, and then decay. Seems like the test set acc plateus around 0.834. Perhaps that could be solved by using a lower training rate at late

In [18]:
def train_model(
    model: nn.Module,
    data_loader: DataLoader,
    epochs,
    criterion: nn.Module,
    optimizer,
    device: torch.device,
    ) -> Dict[str, float]:

    epoch_train_losses = []
    epoch_test_losses = []
    epoch_test_acc = []
    for epoch in range(epochs):
        running_loss = 0.0
        for i, batch in enumerate(tqdm(data_loader), 0):
            # remove the label from batch
            labels = batch.pop('label').to(device)
            # dictionary of features, with features on device
            inputs = {k: v.to(device) for k, v in batch.items()}
            # zero grads
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # store training loss
            running_loss += loss.item()
        # calculate avg training epoch loss
        avg_epoch_loss = running_loss/len(train_dl)
        # add to running list
        epoch_train_losses.append(avg_epoch_loss)
        # get test set evaluation stats
        eval_dict = evaluate_model(model, test_dl, criterion, device)
        test_loss = eval_dict['loss']
        test_acc = eval_dict['accuracy']
        epoch_test_losses.append(test_loss)
        epoch_test_acc.append(test_acc)
        # print stats after each epoch
        print(f' avg epoch train loss: {round(avg_epoch_loss, 4)}\n test set loss: {round(test_loss,4)}\n test set accuracy: {round(test_acc,4)}')

    print(f'Completed training for {epochs} epochs')
    return {'train_losses': epoch_train_losses, 'test_losses': epoch_test_losses, 'test_acc': epoch_test_acc}



## Loss Plot function

In [19]:
def make_loss_plot(loss_stats: dict):
  loss_df = pl.DataFrame({
      'epoch': np.arange(len(loss_stats['train_losses'])),
      'train_loss' : loss_stats['train_losses'],
      'test_loss' : loss_stats['test_losses'],
      'test_acc': loss_stats['test_acc']
    })

  loss_df_long = loss_df.drop(pl.col('test_acc')).unpivot(index='epoch', value_name='loss')

  loss_chart = alt.Chart(loss_df_long).mark_line().encode(
    alt.X('epoch:O'),
    alt.Y('loss:Q', scale=alt.Scale(domain=(0.3, 0.57))),
    alt.Color('variable')
  ).properties(
    width=700,
    height=500,
    title = 'Train and Test Loss Per Epoch'
  )
  acc_chart = alt.Chart(loss_df).mark_line().encode(
    alt.X('epoch:O'),
    alt.Y('test_acc:Q', scale=alt.Scale(domain=(0.4, 1.0)))
  ).properties(
    width=700,
    height=500,
    title = 'Test Set Prediction Accuracy Per Epoch'
  )
  return acc_chart | loss_chart

## Nucleotide only model training
As hoped, the train/test loss stays at 0.69, indicating that even after 30 epochs, the model is still just guessing. This is good because we shouldn't be able to tell if a CG site is methylated unless we've polluted our data in some way making it not generalizable wild data.

In [20]:
train_ds = MethylDataset(train_parquet)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

test_ds = MethylDataset(test_parquet)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True, num_workers=4)

model_nucleotides = MethylCNN_Nucleotides(sequence_length=32)
model_nucleotides.to(device)

criterion_nucleotides = nn.CrossEntropyLoss()
optimizer_nucleotides = torch.optim.Adam(model_nucleotides.parameters(), lr=0.002)

training_stats_nucleotides = train_model(model_nucleotides, train_dl, epochs = 30, criterion=criterion_nucleotides, optimizer = optimizer_nucleotides, device=device)

100%|██████████| 2659/2659 [00:17<00:00, 154.78it/s]
100%|██████████| 665/665 [00:03<00:00, 173.35it/s]


 avg epoch train loss: 0.6931
 test set loss: 0.6929
 test set accuracy: 0.5084


100%|██████████| 2659/2659 [00:16<00:00, 160.34it/s]
100%|██████████| 665/665 [00:04<00:00, 157.39it/s]


 avg epoch train loss: 0.6929
 test set loss: 0.6928
 test set accuracy: 0.5116


100%|██████████| 2659/2659 [00:16<00:00, 162.12it/s]
100%|██████████| 665/665 [00:04<00:00, 166.24it/s]


 avg epoch train loss: 0.6928
 test set loss: 0.6927
 test set accuracy: 0.5122


100%|██████████| 2659/2659 [00:16<00:00, 159.01it/s]
100%|██████████| 665/665 [00:03<00:00, 176.70it/s]


 avg epoch train loss: 0.6927
 test set loss: 0.6927
 test set accuracy: 0.5102


100%|██████████| 2659/2659 [00:16<00:00, 165.02it/s]
100%|██████████| 665/665 [00:04<00:00, 165.76it/s]


 avg epoch train loss: 0.6927
 test set loss: 0.6927
 test set accuracy: 0.5113


100%|██████████| 2659/2659 [00:16<00:00, 163.07it/s]
100%|██████████| 665/665 [00:03<00:00, 179.02it/s]


 avg epoch train loss: 0.6927
 test set loss: 0.6927
 test set accuracy: 0.5096


100%|██████████| 2659/2659 [00:16<00:00, 163.23it/s]
100%|██████████| 665/665 [00:04<00:00, 161.03it/s]


 avg epoch train loss: 0.6926
 test set loss: 0.6926
 test set accuracy: 0.5109


100%|██████████| 2659/2659 [00:16<00:00, 161.45it/s]
100%|██████████| 665/665 [00:03<00:00, 170.16it/s]


 avg epoch train loss: 0.6926
 test set loss: 0.6926
 test set accuracy: 0.5123


100%|██████████| 2659/2659 [00:16<00:00, 161.24it/s]
100%|██████████| 665/665 [00:03<00:00, 168.29it/s]


 avg epoch train loss: 0.6926
 test set loss: 0.6928
 test set accuracy: 0.5114


100%|██████████| 2659/2659 [00:16<00:00, 163.67it/s]
100%|██████████| 665/665 [00:04<00:00, 158.51it/s]


 avg epoch train loss: 0.6925
 test set loss: 0.6926
 test set accuracy: 0.5114


100%|██████████| 2659/2659 [00:16<00:00, 161.58it/s]
100%|██████████| 665/665 [00:04<00:00, 159.60it/s]


 avg epoch train loss: 0.6925
 test set loss: 0.6926
 test set accuracy: 0.5112


100%|██████████| 2659/2659 [00:16<00:00, 164.94it/s]
100%|██████████| 665/665 [00:03<00:00, 177.81it/s]


 avg epoch train loss: 0.6925
 test set loss: 0.6926
 test set accuracy: 0.5113


100%|██████████| 2659/2659 [00:16<00:00, 164.05it/s]
100%|██████████| 665/665 [00:04<00:00, 160.51it/s]


 avg epoch train loss: 0.6925
 test set loss: 0.6926
 test set accuracy: 0.5111


100%|██████████| 2659/2659 [00:16<00:00, 159.83it/s]
100%|██████████| 665/665 [00:03<00:00, 172.05it/s]


 avg epoch train loss: 0.6924
 test set loss: 0.6927
 test set accuracy: 0.5114


100%|██████████| 2659/2659 [00:16<00:00, 160.99it/s]
100%|██████████| 665/665 [00:03<00:00, 174.33it/s]


 avg epoch train loss: 0.6924
 test set loss: 0.6926
 test set accuracy: 0.511


100%|██████████| 2659/2659 [00:16<00:00, 164.89it/s]
100%|██████████| 665/665 [00:04<00:00, 158.80it/s]


 avg epoch train loss: 0.6924
 test set loss: 0.6925
 test set accuracy: 0.5119


100%|██████████| 2659/2659 [00:16<00:00, 160.98it/s]
100%|██████████| 665/665 [00:03<00:00, 175.24it/s]


 avg epoch train loss: 0.6923
 test set loss: 0.6927
 test set accuracy: 0.5114


100%|██████████| 2659/2659 [00:16<00:00, 162.70it/s]
100%|██████████| 665/665 [00:03<00:00, 174.42it/s]


 avg epoch train loss: 0.6923
 test set loss: 0.6925
 test set accuracy: 0.5113


100%|██████████| 2659/2659 [00:16<00:00, 164.01it/s]
100%|██████████| 665/665 [00:04<00:00, 155.95it/s]


 avg epoch train loss: 0.6922
 test set loss: 0.6924
 test set accuracy: 0.5109


100%|██████████| 2659/2659 [00:16<00:00, 160.43it/s]
100%|██████████| 665/665 [00:04<00:00, 166.17it/s]


 avg epoch train loss: 0.6921
 test set loss: 0.6924
 test set accuracy: 0.512


100%|██████████| 2659/2659 [00:17<00:00, 153.07it/s]
100%|██████████| 665/665 [00:03<00:00, 175.82it/s]


 avg epoch train loss: 0.6921
 test set loss: 0.6925
 test set accuracy: 0.5115


100%|██████████| 2659/2659 [00:17<00:00, 151.09it/s]
100%|██████████| 665/665 [00:04<00:00, 163.27it/s]


 avg epoch train loss: 0.692
 test set loss: 0.6924
 test set accuracy: 0.5118


100%|██████████| 2659/2659 [00:16<00:00, 162.37it/s]
100%|██████████| 665/665 [00:04<00:00, 147.70it/s]


 avg epoch train loss: 0.6919
 test set loss: 0.6923
 test set accuracy: 0.5121


100%|██████████| 2659/2659 [00:16<00:00, 160.36it/s]
100%|██████████| 665/665 [00:03<00:00, 171.01it/s]


 avg epoch train loss: 0.6919
 test set loss: 0.6923
 test set accuracy: 0.511


100%|██████████| 2659/2659 [00:16<00:00, 156.65it/s]
100%|██████████| 665/665 [00:03<00:00, 172.44it/s]


 avg epoch train loss: 0.6919
 test set loss: 0.6923
 test set accuracy: 0.5117


100%|██████████| 2659/2659 [00:16<00:00, 162.43it/s]
100%|██████████| 665/665 [00:04<00:00, 159.99it/s]


 avg epoch train loss: 0.6918
 test set loss: 0.6924
 test set accuracy: 0.5113


100%|██████████| 2659/2659 [00:16<00:00, 164.12it/s]
100%|██████████| 665/665 [00:03<00:00, 175.19it/s]


 avg epoch train loss: 0.6918
 test set loss: 0.6923
 test set accuracy: 0.512


100%|██████████| 2659/2659 [00:16<00:00, 161.17it/s]
100%|██████████| 665/665 [00:04<00:00, 164.95it/s]


 avg epoch train loss: 0.6918
 test set loss: 0.6922
 test set accuracy: 0.5118


100%|██████████| 2659/2659 [00:16<00:00, 163.57it/s]
100%|██████████| 665/665 [00:03<00:00, 167.72it/s]


 avg epoch train loss: 0.6917
 test set loss: 0.6923
 test set accuracy: 0.5122


100%|██████████| 2659/2659 [00:17<00:00, 155.78it/s]
100%|██████████| 665/665 [00:03<00:00, 169.54it/s]

 avg epoch train loss: 0.6917
 test set loss: 0.6923
 test set accuracy: 0.5106
Completed training for 30 epochs





In [21]:
make_loss_plot(training_stats_nucleotides)

## Kinetics only model training

In [22]:
train_ds = MethylDataset(train_parquet)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

test_ds = MethylDataset(test_parquet)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True, num_workers=4)

model_kinetics = MethylCNN_Kinetics(sequence_length=32)
model_kinetics.to(device)

criterion_kinetics = nn.CrossEntropyLoss()
optimizer_kinetics = torch.optim.Adam(model_kinetics.parameters(), lr=0.002)

training_stats_kinetics = train_model(model_kinetics, train_dl, epochs = 30, criterion=criterion_kinetics, optimizer = optimizer_kinetics, device=device)

100%|██████████| 2659/2659 [00:17<00:00, 155.83it/s]
100%|██████████| 665/665 [00:03<00:00, 168.22it/s]


 avg epoch train loss: 0.5465
 test set loss: 0.5184
 test set accuracy: 0.7442


100%|██████████| 2659/2659 [00:16<00:00, 163.74it/s]
100%|██████████| 665/665 [00:03<00:00, 174.28it/s]


 avg epoch train loss: 0.5085
 test set loss: 0.506
 test set accuracy: 0.7506


100%|██████████| 2659/2659 [00:16<00:00, 158.97it/s]
100%|██████████| 665/665 [00:03<00:00, 170.74it/s]


 avg epoch train loss: 0.5026
 test set loss: 0.4983
 test set accuracy: 0.7546


100%|██████████| 2659/2659 [00:16<00:00, 162.36it/s]
100%|██████████| 665/665 [00:04<00:00, 164.64it/s]


 avg epoch train loss: 0.4986
 test set loss: 0.4964
 test set accuracy: 0.7563


100%|██████████| 2659/2659 [00:16<00:00, 158.82it/s]
100%|██████████| 665/665 [00:03<00:00, 171.34it/s]


 avg epoch train loss: 0.496
 test set loss: 0.4999
 test set accuracy: 0.7537


100%|██████████| 2659/2659 [00:16<00:00, 162.42it/s]
100%|██████████| 665/665 [00:03<00:00, 167.59it/s]


 avg epoch train loss: 0.4949
 test set loss: 0.4977
 test set accuracy: 0.7545


100%|██████████| 2659/2659 [00:16<00:00, 162.62it/s]
100%|██████████| 665/665 [00:04<00:00, 164.64it/s]


 avg epoch train loss: 0.4934
 test set loss: 0.493
 test set accuracy: 0.7595


100%|██████████| 2659/2659 [00:16<00:00, 158.47it/s]
100%|██████████| 665/665 [00:03<00:00, 170.13it/s]


 avg epoch train loss: 0.4924
 test set loss: 0.4906
 test set accuracy: 0.7605


100%|██████████| 2659/2659 [00:17<00:00, 153.61it/s]
100%|██████████| 665/665 [00:03<00:00, 174.35it/s]


 avg epoch train loss: 0.4918
 test set loss: 0.4897
 test set accuracy: 0.7612


100%|██████████| 2659/2659 [00:16<00:00, 158.73it/s]
100%|██████████| 665/665 [00:03<00:00, 168.07it/s]


 avg epoch train loss: 0.4915
 test set loss: 0.497
 test set accuracy: 0.7566


100%|██████████| 2659/2659 [00:16<00:00, 161.56it/s]
100%|██████████| 665/665 [00:03<00:00, 170.30it/s]


 avg epoch train loss: 0.4905
 test set loss: 0.491
 test set accuracy: 0.7603


100%|██████████| 2659/2659 [00:16<00:00, 161.94it/s]
100%|██████████| 665/665 [00:03<00:00, 173.43it/s]


 avg epoch train loss: 0.4901
 test set loss: 0.4978
 test set accuracy: 0.7569


100%|██████████| 2659/2659 [00:16<00:00, 159.98it/s]
100%|██████████| 665/665 [00:03<00:00, 166.27it/s]


 avg epoch train loss: 0.4899
 test set loss: 0.4896
 test set accuracy: 0.7607


100%|██████████| 2659/2659 [00:16<00:00, 159.87it/s]
100%|██████████| 665/665 [00:04<00:00, 155.24it/s]


 avg epoch train loss: 0.4893
 test set loss: 0.4886
 test set accuracy: 0.762


100%|██████████| 2659/2659 [00:16<00:00, 157.66it/s]
100%|██████████| 665/665 [00:04<00:00, 162.28it/s]


 avg epoch train loss: 0.4891
 test set loss: 0.4913
 test set accuracy: 0.7601


100%|██████████| 2659/2659 [00:17<00:00, 153.24it/s]
100%|██████████| 665/665 [00:03<00:00, 168.90it/s]


 avg epoch train loss: 0.4888
 test set loss: 0.4891
 test set accuracy: 0.7618


100%|██████████| 2659/2659 [00:16<00:00, 156.83it/s]
100%|██████████| 665/665 [00:04<00:00, 152.96it/s]


 avg epoch train loss: 0.4887
 test set loss: 0.4873
 test set accuracy: 0.7628


100%|██████████| 2659/2659 [00:17<00:00, 152.82it/s]
100%|██████████| 665/665 [00:03<00:00, 169.96it/s]


 avg epoch train loss: 0.4883
 test set loss: 0.4902
 test set accuracy: 0.7608


100%|██████████| 2659/2659 [00:17<00:00, 155.64it/s]
100%|██████████| 665/665 [00:03<00:00, 173.58it/s]


 avg epoch train loss: 0.4881
 test set loss: 0.4872
 test set accuracy: 0.7625


100%|██████████| 2659/2659 [00:16<00:00, 159.80it/s]
100%|██████████| 665/665 [00:04<00:00, 162.32it/s]


 avg epoch train loss: 0.4878
 test set loss: 0.493
 test set accuracy: 0.758


100%|██████████| 2659/2659 [00:16<00:00, 159.94it/s]
100%|██████████| 665/665 [00:03<00:00, 175.97it/s]


 avg epoch train loss: 0.4876
 test set loss: 0.489
 test set accuracy: 0.7623


100%|██████████| 2659/2659 [00:16<00:00, 160.41it/s]
100%|██████████| 665/665 [00:03<00:00, 172.56it/s]


 avg epoch train loss: 0.4877
 test set loss: 0.4884
 test set accuracy: 0.7613


100%|██████████| 2659/2659 [00:16<00:00, 157.45it/s]
100%|██████████| 665/665 [00:04<00:00, 156.16it/s]


 avg epoch train loss: 0.4874
 test set loss: 0.4867
 test set accuracy: 0.7628


100%|██████████| 2659/2659 [00:16<00:00, 158.46it/s]
100%|██████████| 665/665 [00:03<00:00, 169.52it/s]


 avg epoch train loss: 0.487
 test set loss: 0.4867
 test set accuracy: 0.7631


100%|██████████| 2659/2659 [00:17<00:00, 156.36it/s]
100%|██████████| 665/665 [00:03<00:00, 171.66it/s]


 avg epoch train loss: 0.4867
 test set loss: 0.4871
 test set accuracy: 0.761


100%|██████████| 2659/2659 [00:16<00:00, 160.16it/s]
100%|██████████| 665/665 [00:04<00:00, 161.16it/s]


 avg epoch train loss: 0.4864
 test set loss: 0.4856
 test set accuracy: 0.764


100%|██████████| 2659/2659 [00:16<00:00, 160.82it/s]
100%|██████████| 665/665 [00:04<00:00, 160.23it/s]


 avg epoch train loss: 0.4861
 test set loss: 0.4856
 test set accuracy: 0.764


100%|██████████| 2659/2659 [00:16<00:00, 158.63it/s]
100%|██████████| 665/665 [00:03<00:00, 168.77it/s]


 avg epoch train loss: 0.486
 test set loss: 0.4868
 test set accuracy: 0.7626


100%|██████████| 2659/2659 [00:16<00:00, 159.73it/s]
100%|██████████| 665/665 [00:04<00:00, 162.62it/s]


 avg epoch train loss: 0.4859
 test set loss: 0.4866
 test set accuracy: 0.7641


100%|██████████| 2659/2659 [00:16<00:00, 159.68it/s]
100%|██████████| 665/665 [00:04<00:00, 162.50it/s]

 avg epoch train loss: 0.4858
 test set loss: 0.4892
 test set accuracy: 0.7619
Completed training for 30 epochs





In [23]:
make_loss_plot(training_stats_kinetics)

## Combined data model Training

In [24]:
train_ds = MethylDataset(train_parquet)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)

test_ds = MethylDataset(test_parquet)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=True, num_workers=4)

model_full = MethylCNN(sequence_length=32)
model_full.to(device)

criterion_full = nn.CrossEntropyLoss()
optimizer_full = torch.optim.Adam(model_full.parameters(), lr=0.001)

training_stats_full = train_model(model_full, train_dl, epochs = 30, criterion=criterion_full, optimizer = optimizer_full, device=device)

100%|██████████| 2659/2659 [00:17<00:00, 155.02it/s]
100%|██████████| 665/665 [00:03<00:00, 171.46it/s]


 avg epoch train loss: 0.5249
 test set loss: 0.4743
 test set accuracy: 0.7723


100%|██████████| 2659/2659 [00:16<00:00, 160.14it/s]
100%|██████████| 665/665 [00:04<00:00, 166.16it/s]


 avg epoch train loss: 0.4637
 test set loss: 0.452
 test set accuracy: 0.7874


100%|██████████| 2659/2659 [00:16<00:00, 158.37it/s]
100%|██████████| 665/665 [00:04<00:00, 165.66it/s]


 avg epoch train loss: 0.4444
 test set loss: 0.4502
 test set accuracy: 0.7876


100%|██████████| 2659/2659 [00:16<00:00, 161.50it/s]
100%|██████████| 665/665 [00:03<00:00, 170.16it/s]


 avg epoch train loss: 0.4309
 test set loss: 0.4248
 test set accuracy: 0.8041


100%|██████████| 2659/2659 [00:16<00:00, 159.93it/s]
100%|██████████| 665/665 [00:04<00:00, 161.37it/s]


 avg epoch train loss: 0.4211
 test set loss: 0.4218
 test set accuracy: 0.8045


100%|██████████| 2659/2659 [00:16<00:00, 160.85it/s]
100%|██████████| 665/665 [00:04<00:00, 163.41it/s]


 avg epoch train loss: 0.4134
 test set loss: 0.4086
 test set accuracy: 0.8123


100%|██████████| 2659/2659 [00:16<00:00, 159.61it/s]
100%|██████████| 665/665 [00:03<00:00, 170.72it/s]


 avg epoch train loss: 0.4071
 test set loss: 0.4015
 test set accuracy: 0.8167


100%|██████████| 2659/2659 [00:16<00:00, 157.21it/s]
100%|██████████| 665/665 [00:03<00:00, 174.39it/s]


 avg epoch train loss: 0.4012
 test set loss: 0.3953
 test set accuracy: 0.8192


100%|██████████| 2659/2659 [00:16<00:00, 158.10it/s]
100%|██████████| 665/665 [00:04<00:00, 158.03it/s]


 avg epoch train loss: 0.3958
 test set loss: 0.3961
 test set accuracy: 0.819


100%|██████████| 2659/2659 [00:16<00:00, 160.22it/s]
100%|██████████| 665/665 [00:03<00:00, 175.97it/s]


 avg epoch train loss: 0.3911
 test set loss: 0.3924
 test set accuracy: 0.8213


100%|██████████| 2659/2659 [00:16<00:00, 156.74it/s]
100%|██████████| 665/665 [00:03<00:00, 172.32it/s]


 avg epoch train loss: 0.3876
 test set loss: 0.3854
 test set accuracy: 0.8245


100%|██████████| 2659/2659 [00:16<00:00, 161.69it/s]
100%|██████████| 665/665 [00:04<00:00, 163.09it/s]


 avg epoch train loss: 0.3843
 test set loss: 0.3886
 test set accuracy: 0.8223


100%|██████████| 2659/2659 [00:16<00:00, 158.41it/s]
100%|██████████| 665/665 [00:04<00:00, 165.05it/s]


 avg epoch train loss: 0.3821
 test set loss: 0.3831
 test set accuracy: 0.8275


100%|██████████| 2659/2659 [00:16<00:00, 157.48it/s]
100%|██████████| 665/665 [00:03<00:00, 172.47it/s]


 avg epoch train loss: 0.3804
 test set loss: 0.3858
 test set accuracy: 0.8239


100%|██████████| 2659/2659 [00:16<00:00, 157.26it/s]
100%|██████████| 665/665 [00:04<00:00, 157.36it/s]


 avg epoch train loss: 0.3781
 test set loss: 0.3794
 test set accuracy: 0.829


100%|██████████| 2659/2659 [00:16<00:00, 158.48it/s]
100%|██████████| 665/665 [00:03<00:00, 168.00it/s]


 avg epoch train loss: 0.3768
 test set loss: 0.3799
 test set accuracy: 0.8275


100%|██████████| 2659/2659 [00:16<00:00, 158.60it/s]
100%|██████████| 665/665 [00:04<00:00, 163.14it/s]


 avg epoch train loss: 0.3753
 test set loss: 0.3771
 test set accuracy: 0.8298


100%|██████████| 2659/2659 [00:16<00:00, 160.88it/s]
100%|██████████| 665/665 [00:04<00:00, 162.05it/s]


 avg epoch train loss: 0.3737
 test set loss: 0.379
 test set accuracy: 0.8285


100%|██████████| 2659/2659 [00:16<00:00, 158.40it/s]
100%|██████████| 665/665 [00:03<00:00, 169.53it/s]


 avg epoch train loss: 0.3727
 test set loss: 0.3759
 test set accuracy: 0.83


100%|██████████| 2659/2659 [00:16<00:00, 158.00it/s]
100%|██████████| 665/665 [00:03<00:00, 171.81it/s]


 avg epoch train loss: 0.3709
 test set loss: 0.375
 test set accuracy: 0.8298


100%|██████████| 2659/2659 [00:16<00:00, 160.77it/s]
100%|██████████| 665/665 [00:03<00:00, 169.55it/s]


 avg epoch train loss: 0.3701
 test set loss: 0.3728
 test set accuracy: 0.8312


100%|██████████| 2659/2659 [00:16<00:00, 160.98it/s]
100%|██████████| 665/665 [00:03<00:00, 175.30it/s]


 avg epoch train loss: 0.3693
 test set loss: 0.3735
 test set accuracy: 0.8311


100%|██████████| 2659/2659 [00:17<00:00, 156.11it/s]
100%|██████████| 665/665 [00:03<00:00, 170.23it/s]


 avg epoch train loss: 0.3681
 test set loss: 0.3762
 test set accuracy: 0.8306


100%|██████████| 2659/2659 [00:16<00:00, 157.08it/s]
100%|██████████| 665/665 [00:04<00:00, 161.65it/s]


 avg epoch train loss: 0.3673
 test set loss: 0.3734
 test set accuracy: 0.8317


100%|██████████| 2659/2659 [00:16<00:00, 158.01it/s]
100%|██████████| 665/665 [00:04<00:00, 154.37it/s]


 avg epoch train loss: 0.3662
 test set loss: 0.3716
 test set accuracy: 0.8309


100%|██████████| 2659/2659 [00:16<00:00, 162.48it/s]
100%|██████████| 665/665 [00:04<00:00, 161.89it/s]


 avg epoch train loss: 0.3657
 test set loss: 0.3727
 test set accuracy: 0.8313


100%|██████████| 2659/2659 [00:17<00:00, 155.52it/s]
100%|██████████| 665/665 [00:04<00:00, 164.11it/s]


 avg epoch train loss: 0.3652
 test set loss: 0.3906
 test set accuracy: 0.8205


100%|██████████| 2659/2659 [00:17<00:00, 156.00it/s]
100%|██████████| 665/665 [00:04<00:00, 160.30it/s]


 avg epoch train loss: 0.3646
 test set loss: 0.372
 test set accuracy: 0.8315


100%|██████████| 2659/2659 [00:17<00:00, 156.15it/s]
100%|██████████| 665/665 [00:04<00:00, 163.15it/s]


 avg epoch train loss: 0.3638
 test set loss: 0.3686
 test set accuracy: 0.8334


100%|██████████| 2659/2659 [00:16<00:00, 158.62it/s]
100%|██████████| 665/665 [00:03<00:00, 170.83it/s]

 avg epoch train loss: 0.3631
 test set loss: 0.3677
 test set accuracy: 0.8339
Completed training for 30 epochs





In [25]:
make_loss_plot(training_stats_full)

In [26]:
# with torch.no_grad():
#     model_full.eval()
#     batch = next(iter(train_dl))
#     labels = batch.pop('label').to(device)
#     inputs: Dict[str, torch.Tensor] = {
#                 k: v.to(device) for k, v in batch.items()
#             }

#     print(model_full(inputs), labels)

In [27]:
# torch.save(model, '/content/gdrive/MyDrive/methylation/models/methyl_cnn_v0.pt')

In [28]:
# torch.save({
#             'epoch': epoch,
#             'model_state_dict': model.state_dict(),
#             'optimizer_state_dict': optimizer.state_dict(),
#             'loss': loss
#             }, '/content/gdrive/MyDrive/methylation/models/methyl_cnn_v0_full.pt')

# Notes

### test 1:
0.815 test set accuracy
### test 2:
move from adam -> adam2, change conv1 to size 5
0.808 test set accuracy
0.807 train set accuracy (why is this lower than the train set accuracy??)
### test 3:
add another convolutional layer (64-> 128)
test set
train set

In [29]:
model_load_test = model = MethylCNN(sequence_length=32)