In [95]:

# Parquet files holding the train and test splits (paths are relative).
train_parquet = '../data/processed/train_cg_32_10000.parquet'
test_parquet = '../data/processed/test_cg_32_10000.parquet'


In [96]:
import torch
import polars as pl
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import altair as alt
alt.data_transformers.enable("vegafusion")
from typing import Dict, List, Any, Optional
from operator import itemgetter
import torch.nn.functional as F
import os
from torch import nn
from tqdm import tqdm
import numpy as np

# Show up to 50 characters of each string value when polars renders a frame.
pl.Config(fmt_str_lengths=50)

# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    _backend = "mps"
elif torch.cuda.is_available():
    _backend = "cuda"
else:
    _backend = "cpu"
device = torch.device(_backend)
print(device)


mps


In [97]:
# Explicit column dtypes for the training parquet file.
train_schema = {
    'read_name': pl.String,
    'cg_pos': pl.Int64,
    'seq': pl.String,
    'fi': pl.List(pl.Int64),
    'fp': pl.List(pl.Int64),
    'ri': pl.List(pl.Int64),
    'rp': pl.List(pl.Int64),
    'label': pl.Int32,
}
train_df = pl.read_parquet(train_parquet, schema=train_schema)
train_df.head(10)

read_name,cg_pos,seq,fi,fp,ri,rp,label
str,i64,str,list[i64],list[i64],list[i64],list[i64],i32
"""m64168_200823_191315/1837517/ccs""",10282,"""AGATCTATTACACAACGTGGTGACCATAGCTA""","[52, 29, … 21]","[19, 10, … 14]","[62, 14, … 12]","[29, 25, … 13]",1
"""m64168_200820_000733/263719/ccs""",5955,"""GAAGGGGCTGATGCCCGGCCTCAGAGGTTAAG""","[19, 60, … 39]","[12, 49, … 24]","[31, 52, … 28]","[12, 17, … 54]",0
"""m64168_200823_191315/263868/ccs""",2321,"""GACGGGGCAGCTGGCCGGGCGGGGGGGCTGAC""","[23, 59, … 69]","[12, 25, … 9]","[35, 48, … 68]","[18, 11, … 28]",1
"""m64168_200823_191315/1442515/ccs""",7697,"""TGGAATGCAATGGAACGGAATGGAGTGGGATG""","[16, 11, … 21]","[16, 33, … 33]","[22, 18, … 12]","[37, 12, … 11]",1
"""m64168_200823_191315/1638483/ccs""",7803,"""GATGTACTCCACTTTCGAGCCTGATTCAGAAA""","[21, 21, … 22]","[44, 21, … 13]","[17, 17, … 21]","[11, 26, … 41]",1
"""m64168_200820_000733/742/ccs""",7290,"""GCCTGGGCGACAGAGCGAGACTCCATCTCAAA""","[8, 11, … 24]","[11, 19, … 23]","[16, 28, … 60]","[40, 36, … 33]",0
"""m64168_200820_000733/327687/ccs""",7671,"""TGGCCTTAAGTGATCCGCCCACCTTGGCCTCC""","[24, 33, … 141]","[15, 17, … 41]","[29, 23, … 27]","[23, 22, … 15]",0
"""m64168_200820_000733/132160/ccs""",4461,"""GAAAACTGAGTCCCCCGTGAGGATCTTGTTTT""","[16, 11, … 17]","[15, 16, … 17]","[16, 23, … 14]","[23, 9, … 38]",0
"""m64168_200823_191315/1444273/ccs""",6638,"""GGAATAATTCCTTTCCGTCCTCTCGGGAACAG""","[6, 12, … 10]","[27, 10, … 9]","[17, 22, … 17]","[27, 10, … 13]",1
"""m64168_200823_191315/264947/ccs""",5944,"""CTTAAGAGAAACAAACGGCCTCAGGAAGGGCA""","[29, 37, … 18]","[19, 9, … 10]","[11, 44, … 27]","[21, 67, … 20]",1


# Background

Methylation is a modification of DNA that controls gene expression. There are many different types of methylation, but one of the most prominent is methylation at CpG sites (see [this Wikipedia article](https://en.wikipedia.org/wiki/CpG_site)). When a CpG site is methylated, it is chemically changed, meaning that the biophysical properties of that section of the DNA molecule are slightly different. Therefore, if we had data that carried a signal of the molecule's properties, we could detect methylation.

The data that I'm using in this notebook was sequenced using PacBio Sequel technology. Here's PacBio's [advertisement](https://youtu.be/_lD8JyAbwEo?si=zILsY74u6tJyXsmJ)... it's not the most detailed, and it fails to mention that the data you get from the system also includes temporal information from the polymerase that does the sequencing. This includes how long each light flash lasted (from the nucleotide incorporation) and the time until the next light flash. Both are recorded for the forward and the reverse strand of the molecule that was sequenced.

The polymerase spans around 8 nucleotides, and since methylation adds some resistance, there is a difference in these signals between methylated and unmethylated CpG sites. The issue, of course, is that the kinetics are not just affected by methylation, but also by the regular nucleotide context. So if we want to predict methylation based on kinetics, the model needs to learn how to disentangle the kinetics signal associated with methylation from the effect that different nucleotides have on the polymerase behaviour.

PacBio actually already figured this out with [primrose](https://github.com/mattoslmp/primrose). But I thought it would be a nice exercise, before setting off into the uncharted territory of genetics models, to (1) reimplement primrose with a CNN as PacBio did, and (2) pretrain with contrastive learning and try to beat that with a downstream classification task. Yes, I should get to the science, but I also know that deep learning is hard even on easy problems, so I think I'll spend a bit more time on problems I know are solvable. Furthermore, Primrose is not public.

Regarding the dataset that's present here: I filtered 32 base windows centered on CpG sites. They come from two BAM files, one fully methylated and the other fully unmethylated. Below is a summary of the features:

    read_name: The unique identifier for the read. This corresponds to one zero-mode-waveguide's measurements in the video.

    cg_pos: The position in the read, indexed from 0, where the CpG site occurs

    window_seq: The 32 bp window of nucleotides. This is a consensus as a result of between 5 and 30 passes over the same section of DNA by the polymerase.

    window_fi: The interpulse duration at each position in the window. This is how long it took the polymerase to reach the next base. This is for the forward strand

    window_fp: The pulse width at each position in the window. This is how long the light flash lasted. This is for the forward strand

    window_ri: same as window_fi, but the reverse strand

    window_rp: same as window_fp, but the reverse strand

Note: Units for the kinetics features (window_fi, window_fp, etc) are in "frames" so we can think of it as an arbitrary time unit.

So the intended structure of the model is to take a (C, H, W) = (1, 8, 32) tensor and make a binary prediction about the methylation status.



# EDA

The chart below shows that, at least on average, we should be able to tell the difference between methylated and unmethylated CpG sites. The chart averages IPD at each index across all of the training samples, independently for the forward and reverse strands (indices 15 and 16 are the CpG site).

In [98]:
def _mean_ipd_by_position(frame: pl.DataFrame, label: int) -> pl.DataFrame:
    """Mean IPD per (window index, strand) over the rows of `frame` with the given label.

    The forward ('fi') and reverse ('ri') IPD lists are stacked into one long
    column tagged with the strand, each list is exploded alongside a matching
    0..len-1 position index, and the values are averaged per (index, strand).
    """
    return (
        frame
        .filter(pl.col('label') == label)
        .select(pl.col("fi").alias("fwd"), pl.col("ri").alias("rev"))
        .unpivot(on=["fwd", "rev"], variable_name="strand", value_name="ipd_list")
        .with_columns(index=pl.int_ranges(start=0, end=pl.col("ipd_list").list.len()))
        .explode("index", "ipd_list")
        .rename({"ipd_list": "ipd"})
        .group_by("index", "strand").agg(pl.col("ipd").mean())
    )


def _mean_ipd_chart(data: pl.DataFrame, title: str) -> alt.Chart:
    """Line chart of mean IPD per window index, one line per strand."""
    return alt.Chart(data).mark_line().encode(
        alt.X('index:Q', scale=alt.Scale(domain=(0, 31))),
        alt.Y('ipd:Q', title='mean IPD', scale=alt.Scale(domain=(24, 44), clamp=True)),
        alt.Color('strand')
    ).properties(
        title=title,
        width=800,
        height=600
    )


# Read the training file once and derive both per-class summaries from it.
eda_df = pl.read_parquet(train_parquet)
pos_means = _mean_ipd_by_position(eda_df, label=1)
neg_means = _mean_ipd_by_position(eda_df, label=0)

# Per-position difference between the label==1 and label==0 mean IPD.
means = pos_means.join(
    neg_means, on=['index', 'strand'], suffix='_neg'
    ).with_columns((pl.col('ipd')-pl.col('ipd_neg')).alias('residual'))

# NOTE: residual_chart is built here but is not part of the cell's output below;
# evaluate `residual_chart` in a cell of its own to render it.
residual_chart = alt.Chart(means).mark_line().encode(
    alt.X("index:Q", title="Position", axis=alt.Axis(tickCount=16)),
    alt.Y("residual:Q", title="IPD Mean Difference", scale=alt.Scale(domain=(-3, 15), clamp=True)),
    alt.Color("strand:N", title="Strand"),
    ).properties(
    title = "Mean IPD Residual (Meth-Unmeth) Across CG Context",
    width=800,
    height=600
    ).configure_axis(
    labelFontSize=12,
    titleFontSize=14,
    grid=True
    ).configure_title(
    fontSize=16,
    anchor='middle'
    ).configure_legend(
    titleFontSize=12,
    labelFontSize=11
    )

neg_chart = _mean_ipd_chart(neg_means, "Mean Unmethylated IPD Across CG Context")
pos_chart = _mean_ipd_chart(pos_means, "Mean Methylated IPD Across CG Context")

alt.hconcat(neg_chart,pos_chart)




# Normalization
Precompute on the training set.

In [99]:
def compute_normalization_stats(df):
    """Compute the mean and standard deviation of each kinetic feature.

    Each kinetic column holds one list of values per row. The lists are
    flattened once, and both statistics are taken over that same flat series,
    so the mean and the std always describe the same set of values.

    Arguments:
    df: frame indexable by column name, with list-valued columns
        'fi', 'fp', 'ri' and 'rp'

    Returns:
    (means, stds): two dicts keyed by column name
    """
    kinetic_features = ['fi', 'fp', 'ri', 'rp']
    means, stds = {}, {}
    for col in kinetic_features:
        # A single explode() flattens a one-level list column; exploding the
        # already-flat result again is redundant and can fail on a non-list dtype.
        values = df[col].explode()
        means[col] = values.mean()
        stds[col] = values.std()
    return means, stds
# Per-feature statistics computed from the training frame.
train_means, train_stds = compute_normalization_stats(train_df)
print(train_means, train_stds)

{'fi': 30.594821621789382, 'fp': 21.945176340699103, 'ri': 30.13117185101172, 'rp': 22.007529335023822} {'fi': 24.25535836619395, 'fp': 11.27049294825807, 'ri': 23.845802361655878, 'rp': 11.310865492241938}


# Dataset Definition

In [100]:
class MethylDataset(Dataset):
    """
    Dataset class for methylation data stored in a parquet file.

    The file is read with polars and converted to tensors once, at construction
    time; __getitem__ only indexes those tensors and one-hot encodes the
    nucleotide sequence.
    """
    def __init__(self, data_path: Path,
                 transform=None,
                 means: Optional[Dict[str, float]] = train_means,
                 stds: Optional[Dict[str, float]] = train_stds,
                 context: int = 32,
                 normalize: bool = False):
        '''
        Arguments:
        data_path: the path for the parquet file that contains either the training or test data (pos and neg)
        transform: an optional callable applied to each sample dict returned by __getitem__
        means: per-feature means keyed by kinetic column name; only used when normalize is True
        stds: per-feature standard deviations keyed by kinetic column name; only used when normalize is True
        context: number of positions per window; every kinetic list is converted to a
                 fixed-width array of this size
        normalize: when True, kinetics are returned as float32 z-scores, (x - mean) / std,
                   instead of raw integer values. Off by default.

        Raises:
        ValueError: normalize is True but means or stds is missing
        RuntimeError: the parquet file could not be read or converted
        '''
        if normalize and (means is None or stds is None):
            raise ValueError("normalize=True requires both `means` and `stds`")

        self.transform = transform
        # Alias kept for code that still reads the attribute's earlier, misspelled name.
        self.tranform = transform
        self.means = means
        self.stds = stds
        self.context = context
        self.normalize = normalize
        self.kinetics_features = ['fi', 'fp', 'ri', 'rp']
        # Both spellings have been used for this attribute; keep them pointing at one list.
        self.kinetic_features = self.kinetics_features
        self.vocab = {'A':0, 'T':1, 'C':2, 'G':3}
        self.vocab_size = len(self.vocab)

        # Load the frame, converting the kinetic list columns into fixed-width arrays.
        try:
            self.df = pl.read_parquet(data_path).with_columns([
                pl.col(col).list.to_array(self.context)
                for col in self.kinetics_features
            ])
        except Exception as exc:
            # Fail here with the real cause; continuing with an empty frame would only
            # produce a misleading missing-column error in the tensor preparation below.
            raise RuntimeError(f"failed to read data given path: {data_path}") from exc
        self._dataset_len = len(self.df)

        # define the sequence tensor
        self.seq_tensor = self._prep_seq_tensor()
        # define the kinetics tensor
        self.kinetics_tensor = self._prep_kinetics_tensor()
        # define the label tensor
        self.label_tensor = self._prep_label_tensor()

    def _prep_seq_tensor(self):
        """Map every nucleotide of every sequence to its vocab index (long tensor)."""
        seq_ints = (
            self.df['seq']
            .str.split("")
            .list.eval(
                pl.element().replace_strict(self.vocab)
            )
            .to_numpy()
        )
        return torch.tensor(np.stack(seq_ints), dtype=torch.long)

    def _prep_kinetics_tensor(self):
        """Stack the four kinetic columns along axis 1, optionally z-scoring each one."""
        kinetics_array = np.stack(
            [self.df[col].to_numpy() for col in self.kinetics_features], axis=1
        )
        if not self.normalize:
            return torch.tensor(kinetics_array, dtype=torch.long)

        # One mean/std per kinetic feature, shaped to broadcast over axis 1.
        mean = torch.tensor(
            [self.means[col] for col in self.kinetics_features], dtype=torch.float32
        ).view(1, -1, 1)
        std = torch.tensor(
            [self.stds[col] for col in self.kinetics_features], dtype=torch.float32
        ).view(1, -1, 1)
        return (torch.tensor(kinetics_array, dtype=torch.float32) - mean) / std

    def _prep_label_tensor(self):
        """Labels as a long tensor."""
        label_array = self.df['label'].to_numpy()
        return torch.tensor(label_array, dtype=torch.long)

    def __len__(self):
        return self._dataset_len

    def __getitem__(self, idx):
        if idx >= len(self):
            raise IndexError("Index out of range")
        seq = self.seq_tensor[idx]
        sample = {
            # one-hot encode, then transpose so the vocab dimension comes first
            'seq': F.one_hot(seq, num_classes=self.vocab_size).T,
            'kinetics': self.kinetics_tensor[idx],
            'label': self.label_tensor[idx],
        }
        if self.transform is not None:
            sample = self.transform(sample)
        return sample



In [101]:
# Shared batch size for both loaders.
BATCH_SIZE = 4096

train_ds = MethylDataset(train_parquet)
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

test_ds = MethylDataset(test_parquet)
# Evaluation only accumulates totals over all samples, so the test order is left unshuffled.
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [102]:
# Pull one training batch and join its one-hot sequence rows with its kinetics
# rows along dim 1, the same way the model's forward pass does.
batch_example = next(iter(train_dl))
seq_part, kinetics_part = batch_example['seq'], batch_example['kinetics']
torch.cat([seq_part, kinetics_part], dim=1)

tensor([[[  0,   1,   0,  ...,   0,   1,   1],
         [  0,   0,   0,  ...,   0,   0,   0],
         [  1,   0,   0,  ...,   0,   0,   0],
         ...,
         [ 33,  24,  32,  ...,  18,  22,  19],
         [ 76,  19,  23,  ...,  57,  26,  35],
         [ 25,  32,  95,  ...,  20,  14,  32]],

        [[  1,   0,   0,  ...,   0,   1,   0],
         [  0,   0,   1,  ...,   0,   0,   1],
         [  0,   0,   0,  ...,   0,   0,   0],
         ...,
         [ 18,  15,  23,  ...,  22,  35,  12],
         [ 25,  17,  23,  ...,  58,  18,   9],
         [ 15,  19,  31,  ...,  18,  22,  11]],

        [[  1,   0,   0,  ...,   0,   0,   1],
         [  0,   1,   1,  ...,   0,   0,   0],
         [  0,   0,   0,  ...,   1,   1,   0],
         ...,
         [ 22,  15,  24,  ...,  24,  21,  12],
         [ 32,  24,  22,  ...,  21,  19,  27],
         [ 21,  48,  30,  ...,  22,  14,  14]],

        ...,

        [[  0,   0,   0,  ...,   0,   0,   1],
         [  0,   1,   1,  ...,   0,   1,   0]

### Sanity Check 1
Are the classes evenly distributed? -> Yes they are.

In [103]:
# Mean of the label column (the share of label==1 rows if labels are 0/1) and the number of training samples.
train_ds.df['label'].mean(), len(train_ds)

(0.5023785352982572, 1361342)

In [104]:
# Length of every forward-IPD window, computed column-wise instead of one Python
# call per row. The dataset stores 'fi' as a fixed-width Array column, so it is
# converted back to a List column to measure each entry.
len_list = train_ds.df['fi'].arr.to_list().list.len().to_numpy()
np.min(len_list),np.argmin(len_list)

(np.int64(32), np.int64(0))

# Model Definition

In [105]:
class MethylCNN(nn.Module):
    """
    1D CNN over the stacked sequence and kinetics channels of a window.

    Four conv -> ReLU -> max-pool stages feed a three-layer fully connected
    head that returns one logit per class.
    """
    def __init__(self, sequence_length: int = 32, in_channels:int = 8, num_classes: int = 2, dropout_p: float = 0.3):
        super().__init__()
        self.in_channels = in_channels
        self.sequence_length = sequence_length
        self.num_classes = num_classes

        # Feature extractor. Note conv2 uses a width-5 kernel with padding 1,
        # so it shortens its input by two positions.
        self.conv1 = nn.Conv1d(in_channels, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=5, padding=1)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=3, padding=1)

        # The width of the first linear layer is measured with a dummy pass
        # through the conv stack rather than derived by hand.
        flat_features = self._get_conv_output_size(sequence_length)

        # Classifier head
        self.fc1 = nn.Linear(flat_features, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, num_classes)

        self.dropout = nn.Dropout(p=dropout_p)

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Apply each conv layer followed by ReLU and a stride-2 max-pool."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool1d(F.relu(conv(x)), kernel_size=2, stride=2)
        return x

    def _get_conv_output_size(self, sequence_length: int) -> int:
        """
        Number of elements the conv stack produces for one sample of the given
        length, found by pushing a random (1, in_channels, sequence_length)
        tensor through it.
        """
        probe = torch.randn(1, self.in_channels, sequence_length)
        return self._extract_features(probe).numel()

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Concatenate batch['seq'] and batch['kinetics'] along dim 1, cast to the
        conv weights' dtype, and return the class logits.
        """
        stacked = torch.cat([batch['seq'], batch['kinetics']], dim=1) # -> [B, 8, L]
        features = self._extract_features(stacked.to(self.conv1.weight.dtype))

        hidden = torch.flatten(features, 1)
        # Two hidden layers, each followed by ReLU and dropout.
        for fc in (self.fc1, self.fc2):
            hidden = self.dropout(F.relu(fc(hidden)))

        return self.fc3(hidden)

# Eval Loop Function

In [106]:
def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    """Compute the sample-weighted loss and the accuracy of `model` over `data_loader`.

    The model is put in eval mode for the duration of the call and returned to
    whatever mode (train or eval) it was in beforehand, so calling this between
    training epochs does not leave dropout switched off.

    Arguments:
    model: module called with a dict of input tensors, returning [batch, classes] logits
    data_loader: iterable of dict batches; each batch has a "label" entry, and every
                 other entry is moved to `device` and passed to the model
    criterion: loss called as criterion(logits, labels)
    device: device the tensors are moved to

    Returns:
    {"loss": ..., "accuracy": ...} averaged over all samples seen

    Raises:
    ValueError: the loader yielded no samples
    """
    was_training = model.training
    model.eval()
    running_loss: float = 0.0
    correct_predictions: int = 0
    total_samples: int = 0

    try:
        with torch.no_grad():
            for batch in tqdm(data_loader):
                labels: torch.Tensor = batch["label"].to(device)
                # Everything except the label is model input; the batch dict is left intact.
                inputs: Dict[str, torch.Tensor] = {
                    k: v.to(device) for k, v in batch.items() if k != "label"
                }

                logits: torch.Tensor = model(inputs)
                loss: torch.Tensor = criterion(logits, labels)

                batch_size = labels.size(0)
                # Weight by batch size so a smaller final batch does not skew the average
                # (assumes the criterion returns a per-batch mean - TODO confirm for custom losses).
                running_loss += loss.item() * batch_size
                total_samples += batch_size
                correct_predictions += (logits.argmax(dim=1) == labels).sum().item()
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples to evaluate")

    epoch_loss: float = running_loss / total_samples
    epoch_acc: float = correct_predictions / total_samples
    return {"loss": epoch_loss, "accuracy": epoch_acc}

# Training Loop

In [107]:
model = MethylCNN(sequence_length=32)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001)

# Per-epoch history, one entry appended after each completed epoch.
epoch_train_losses = []
epoch_test_losses = []
epoch_test_acc = []
for epoch in range(20):  # loop over the dataset multiple times
    # Make sure dropout is active for every training epoch, whatever mode an
    # earlier evaluation or inference call left the model in.
    model.train()
    running_loss = 0.0
    seen_samples = 0
    for batch in tqdm(train_dl):
        # each batch is a dict; split the label off from the model inputs
        labels = batch.pop('label').to(device)
        inputs = {k: v.to(device) for k, v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate a sample-weighted loss so a smaller final batch does not
        # skew the epoch average (same weighting as evaluate_model)
        running_loss += loss.item() * labels.size(0)
        seen_samples += labels.size(0)
    avg_epoch_loss = running_loss/seen_samples
    epoch_train_losses.append(avg_epoch_loss)

    eval_dict = evaluate_model(model, test_dl, criterion, device)
    test_loss = eval_dict['loss']
    test_acc = eval_dict['accuracy']
    epoch_test_losses.append(test_loss)
    epoch_test_acc.append(test_acc)

    print(f' avg epoch train loss: {round(avg_epoch_loss, 4)}\n test set loss: {round(test_loss,4)}\n test set accuracy: {round(test_acc,4)}')

print('Finished Training')

100%|██████████| 333/333 [00:19<00:00, 16.91it/s]
100%|██████████| 84/84 [00:03<00:00, 21.44it/s]


 avg epoch train loss: 0.6369
 test set loss: 0.5908
 test set accuracy: 0.6809


100%|██████████| 333/333 [00:19<00:00, 17.40it/s]
100%|██████████| 84/84 [00:03<00:00, 22.07it/s]


 avg epoch train loss: 0.5706
 test set loss: 0.5562
 test set accuracy: 0.7114


100%|██████████| 333/333 [00:19<00:00, 17.49it/s]
100%|██████████| 84/84 [00:03<00:00, 22.20it/s]


 avg epoch train loss: 0.5492
 test set loss: 0.543
 test set accuracy: 0.7217


 67%|██████▋   | 224/333 [00:13<00:06, 17.18it/s]


KeyboardInterrupt: 

In [None]:
# Look at the raw logits next to the true labels for a single training batch.
model.eval()
with torch.no_grad():
    batch = next(iter(train_dl))
    labels = batch.pop('label').to(device)
    inputs: Dict[str, torch.Tensor] = {
        name: tensor.to(device) for name, tensor in batch.items()
    }
    logits = model(inputs)

    print(logits, labels)

tensor([[-0.0212, -0.3862],
        [-0.8658,  0.4607],
        [ 0.7770, -1.2230],
        [ 0.1373, -0.5195],
        [-0.3249, -0.1833],
        [-2.2178,  1.5798],
        [-0.8361,  0.4366],
        [ 0.4960, -0.9493],
        [ 0.1698, -0.5474],
        [-1.6839,  1.2186],
        [-0.2005, -0.1638],
        [ 0.7057, -1.2251],
        [-0.6920,  0.3115],
        [ 0.5345, -0.9589],
        [ 1.2524, -1.7707],
        [-0.3460, -0.0945],
        [-0.2464, -0.2449],
        [ 0.7045, -1.2033],
        [ 0.4360, -0.8862],
        [-1.6306,  1.0301],
        [-0.4275, -0.0155],
        [-0.1117, -0.2579],
        [-0.7127,  0.2365],
        [-1.0861,  0.7080],
        [-0.3432, -0.0681],
        [-1.0777,  0.6654],
        [-0.5358, -0.0751],
        [-1.1178,  0.7123],
        [ 0.2318, -0.6866],
        [-0.8599,  0.4081],
        [-1.5715,  1.1297],
        [ 0.4727, -0.8858]], device='mps:0') tensor([0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
       

## Plot the epoch losses

In [None]:
# One row per completed epoch. If training was interrupted between the train and
# test bookkeeping the two histories can differ in length, so plot the common prefix.
n_epochs = min(len(epoch_train_losses), len(epoch_test_losses))
loss_df = pl.DataFrame({
    'epoch': np.arange(n_epochs),
    'train_loss' : epoch_train_losses[:n_epochs],
    'test_loss' : epoch_test_losses[:n_epochs]
}).unpivot(index='epoch', value_name='loss')

alt.Chart(loss_df).mark_line(opacity=0.8).encode(
    alt.X('epoch:Q'),
    alt.Y('loss:Q'),
    alt.Color('variable')
).properties(
    width=500,
    height=500
)

In [None]:
# Show the module's layer summary.
print(model)

MethylCNN(
  (conv1): Conv1d(8, 16, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(16, 32, kernel_size=(5,), stride=(1,), padding=(1,))
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv4): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=2, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)


# Evaluation

In [None]:
# Loss and accuracy on the test loader.
evaluate_model(model, test_dl, criterion, device)

100%|██████████| 220/220 [00:00<00:00, 445.37it/s]


{'loss': 0.5772895166160272, 'accuracy': 0.7075081920501496}

In [None]:
# Loss and accuracy on the training loader, for comparison with the test numbers.
evaluate_model(model, train_dl, criterion, device)

100%|██████████| 878/878 [00:01<00:00, 459.80it/s]


{'loss': 0.4392545590189055, 'accuracy': 0.7947494033412887}

# Notes

### test 1:
0.815 test set accuracy
### test 2:
move from adam -> adamW, change conv1 to size 5
0.808 test set accuracy
0.807 train set accuracy (why is this lower than the test set accuracy??)
### test 3:
add another convolutional layer (64-> 128)
test set
train set