In [725]:
import polars as pl

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from operator import itemgetter
import numpy as np

from pathlib import Path

import altair as alt
alt.data_transformers.enable("vegafusion")

from typing import Dict, Optional



# Pick the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f'using {device}')


using mps


In [727]:
# Load the training table and keep the first row's sequence string around for inspection.
train_df = pl.read_parquet('../data/processed/train_cg_32_200.parquet')
_first_row = train_df.row(0, named=True)
str_seq = _first_row['seq']

# Nucleotide -> integer code; same mapping and insertion order as MethylDataset.vocab.
vocab = dict(zip('ATCG', range(4)))
# Example encoding of the sequence: [vocab.get(char) for char in str_seq]

In [729]:
def compute_normalization_stats(df):
    """
    Compute the global mean and standard deviation of each kinetic feature.

    Each of the columns 'fi', 'fp', 'ri' and 'rp' is flattened with a single
    ``explode()`` and the statistics are taken over all flattened values.

    Arguments:
    df: table indexable by column name whose kinetic columns support
        ``explode()``, ``mean()`` and ``std()`` (a polars DataFrame in this notebook)

    Returns:
    (means, stds): two dicts keyed by 'fi', 'fp', 'ri', 'rp'
    """
    kinetic_features = ['fi', 'fp', 'ri', 'rp']
    means, stds = {}, {}
    for col in kinetic_features:
        # Flatten once and reuse the result for both statistics. The previous
        # version exploded a second time for the std only, which is redundant on
        # an already-flat series and inconsistent with how the mean was computed.
        values = df[col].explode()
        means[col] = values.mean()
        stds[col] = values.std()
    return means, stds
# Per-feature normalization statistics of the training table (used as MethylDataset defaults).
_normalization_stats = compute_normalization_stats(train_df)
train_means, train_stds = _normalization_stats
print(*_normalization_stats)

{'fi': 30.789639956541873, 'fp': 21.96058935632102, 'ri': 30.19122800377587, 'rp': 22.001033021052258} {'fi': 24.00485941139194, 'fp': 11.301933232238992, 'ri': 23.627035898853727, 'rp': 11.337388102877114}


In [None]:
class MethylDataset(Dataset):
    """
    Dataset class for methylation data stored in a parquet file.

    The file is read with polars when the dataset is constructed; every
    ``__getitem__`` call converts one row to PyTorch tensors. The nucleotide
    window is one-hot encoded and each kinetic window is standardised as
    ``(x - mean) / std``.
    """
    def __init__(self, data_path: Path,
                 transform=None,
                 means: Optional[Dict[str, float]] = train_means,
                 stds: Optional[Dict[str, float]] = train_stds):
        '''
        Arguments:
        data_path: the path for the parquet file that contains either the training or test data (pos and neg)
        transform: an optional callable applied to the dict of tensors built for a single sample
        means: per-feature means; a feature is looked up as 'window_<feature>' first and
               '<feature>' second. None skips centring.
        stds: per-feature standard deviations, looked up like `means`. None skips scaling.
        '''
        self.transform = transform
        self.tranform = transform  # misspelled legacy attribute name, kept as an alias
        self.means = means
        self.stds = stds

        try:
            self.data = pl.read_parquet(data_path)#[:500000] # can shorten the dataset here
        except Exception as err:
            # Best-effort fallback to an empty dataset, as before, but no longer
            # swallowing KeyboardInterrupt/SystemExit and now reporting the cause.
            print(f"failed to read data given path: {data_path}")
            print(f"reason: {err!r}")
            self.data = pl.DataFrame()

        self._dataset_len = len(self.data)
        self.vocab = {'A':0, 'T':1, 'C':2, 'G':3}
        self.vocab_size = len(self.vocab)
        self.kinetic_features = ['fi', 'fp', 'ri', 'rp']

    def __len__(self):
        return self._dataset_len

    @staticmethod
    def _lookup_stat(stats, feature, default):
        """Return the statistic for `feature`, accepting 'window_<feature>' or '<feature>' as key."""
        if stats is None:
            return default
        for key in (f'window_{feature}', feature):
            if key in stats:
                return stats[key]
        raise KeyError(
            f"no normalization statistic for '{feature}' "
            f"(tried keys 'window_{feature}' and '{feature}')"
        )

    def __getitem__(self, idx):
        """
        Build the tensors for row `idx`.

        Returns a dict with keys 'seq' (one-hot, long), 'fi', 'fp', 'ri', 'rp'
        (standardised floats) and 'label' (long), passed through `transform` when
        one was given.

        Raises IndexError for an index outside [-len, len) and ValueError when
        the sequence contains a character that is not in `vocab`.
        """
        if not -len(self) <= idx < len(self):
            raise IndexError("Index out of range")
        row = self.data.row(idx, named=True)

        # sequence data (requires one-hot encoding)
        str_seq = row['window_seq']
        try:
            codes = [self.vocab[char] for char in str_seq]
        except KeyError as err:
            # vocab.get() used to return None here, which surfaced as an opaque tensor error.
            raise ValueError(
                f"unknown nucleotide {err.args[0]!r} in sample {idx}"
            ) from None
        int_seq_tensor = torch.tensor(codes, dtype=torch.long)
        item = {'seq': F.one_hot(int_seq_tensor, num_classes=self.vocab_size)}

        # kinetic data: read as float so non-integer values are not truncated
        for feature in self.kinetic_features:
            values = torch.tensor(row[f'window_{feature}'], dtype=torch.float32)
            mean = self._lookup_stat(self.means, feature, 0.0)
            std = self._lookup_stat(self.stds, feature, 1.0)
            item[feature] = (values - mean) / std

        item['label'] = torch.tensor(row['label'], dtype=torch.long)

        if self.transform is not None:
            item = self.transform(item)
        return item
    

In [None]:
# Training split: shuffled mini-batches of 32 samples.
train_ds = MethylDataset(data_path='../data/processed/train_cg_32_200.parquet')
train_dl = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)

In [None]:
# Class-balance check: the mean of the label column next to the number of samples.
_label_mean = train_ds.data['label'].mean()
_label_mean, len(train_ds)

(0.5005877533573184, 28073)

In [None]:
# Length of every sample's 'window_fi' list; report the shortest one and where it occurs.
_rows = (train_ds.data.row(i, named=True) for i in range(len(train_ds)))
len_list = np.array([len(row['window_fi']) for row in _rows])
len_list.min(), len_list.argmin()

(np.int64(32), np.int64(0))

In [None]:
class MethylCNN(nn.Module):
    """
    1-D CNN classifier: four conv/ReLU/max-pool stages followed by three linear layers.

    The input to the conv stack has `in_channels` channels, built in `forward`
    by concatenating the one-hot sequence with the four kinetic tracks.
    """

    def __init__(self, sequence_length: int = 32, in_channels:int = 8, num_classes: int = 2):
        """
        Arguments:
        sequence_length: window length used to size the first linear layer
        in_channels: number of channels fed to the first convolution
        num_classes: number of output logits
        """
        super().__init__()
        self.in_channels = in_channels
        self.sequence_length = sequence_length
        self.num_classes = num_classes

        # Convolution layers
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1)
        # NOTE(review): kernel_size=5 with padding=1 shortens the sequence by 2;
        # padding=2 would preserve the length - confirm this is intended.
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=5, padding=1)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv4 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=3, padding=1)

        # calculate fc layer input with dummy passthrough
        fc_input_features = self._get_conv_output_size(sequence_length)

        # Linear layers
        self.fc1 = nn.Linear(in_features=fc_input_features, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=num_classes)

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv -> ReLU -> max-pool (kernel 2, stride 2) for each of the four conv layers."""
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = F.max_pool1d(F.relu(conv(x)), kernel_size=2, stride=2)
        return x

    def _get_conv_output_size(self, sequence_length: int) -> int:
        """
        Calculates the flattened output size of the convolutional layers
        by performing a forward pass on a dummy input of the right shape.

        The probe uses zeros under `torch.no_grad()`: only the output shape
        matters, so it should neither draw from the global RNG (which would
        shift the seeded initialisation of the layers created afterwards) nor
        record an autograd graph.
        """
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.in_channels, sequence_length)
            output = self._extract_features(dummy_input)
        return output.numel()

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Compute class logits for a batch.

        Arguments:
        batch: dict with entries 'seq', 'fi', 'fp', 'ri' and 'rp'

        Returns:
        logits with `num_classes` values per sample
        """
        # Assuming batch['seq'] is [B, L, 4]
        seq = batch['seq'].permute(0, 2, 1) # -> [B, 4, L]
        # each kinetic track gains a channel axis: [B, L] -> [B, 1, L]
        kinetics = [batch[name].unsqueeze(1) for name in ('fi', 'fp', 'ri', 'rp')]

        # the input is a dictionary, so convert to a single tensor in the conv weights' dtype
        x = torch.cat([seq, *kinetics], dim=1).to(self.conv1.weight.dtype) # -> [B, 8, L]

        x = self._extract_features(x)
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc3(x)

        return logits

In [None]:
# Training hyper-parameters (same values as before, now named).
NUM_EPOCHS = 20
LEARNING_RATE = 0.0001

model = MethylCNN(sequence_length=32)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

epoch_losses = []  # mean per-batch training loss, one entry per epoch
for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    # Evaluation cells switch the model to eval mode; switch back explicitly so
    # that re-running this cell always trains in training mode.
    model.train()
    running_loss = 0.0
    for batch in tqdm(train_dl):
        # each batch is a dict of tensors; 'label' is the target, the rest are model inputs
        labels = batch.pop('label').to(device)
        inputs = {k: v.to(device) for k, v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # average of the per-batch losses over this epoch
    avg_epoch_loss = running_loss/len(train_dl)
    epoch_losses.append(avg_epoch_loss)
    print(avg_epoch_loss)

print('Finished Training')

100%|██████████| 878/878 [00:05<00:00, 165.41it/s]


0.6670178258473346


100%|██████████| 878/878 [00:05<00:00, 160.09it/s]


0.5740320823390826


100%|██████████| 878/878 [00:05<00:00, 162.55it/s]


0.538193117648974


100%|██████████| 878/878 [00:05<00:00, 157.55it/s]


0.51678927809611


100%|██████████| 878/878 [00:05<00:00, 162.80it/s]


0.5028713926978426


100%|██████████| 878/878 [00:05<00:00, 159.41it/s]


0.4909298901270082


100%|██████████| 878/878 [00:06<00:00, 135.56it/s]


0.4839951020913678


100%|██████████| 878/878 [00:06<00:00, 130.21it/s]


0.4754205398719783


100%|██████████| 878/878 [00:05<00:00, 148.71it/s]


0.47028748000333415


100%|██████████| 878/878 [00:05<00:00, 150.43it/s]


0.46363196163894377


100%|██████████| 878/878 [00:05<00:00, 165.91it/s]


0.45956149539466873


100%|██████████| 878/878 [00:05<00:00, 162.53it/s]


0.4544522482774524


100%|██████████| 878/878 [00:05<00:00, 158.04it/s]


0.44936236543983965


100%|██████████| 878/878 [00:05<00:00, 156.62it/s]


0.44493991314753856


100%|██████████| 878/878 [00:05<00:00, 158.13it/s]


0.4416767768137547


100%|██████████| 878/878 [00:05<00:00, 165.24it/s]


0.43695019696391524


100%|██████████| 878/878 [00:05<00:00, 166.69it/s]


0.4339111344146294


100%|██████████| 878/878 [00:05<00:00, 157.57it/s]


0.42813630558227894


100%|██████████| 878/878 [00:05<00:00, 162.79it/s]


0.4237457142637911


100%|██████████| 878/878 [00:05<00:00, 166.19it/s]

0.4212837063607702
Finished Training





In [None]:
# Print the logits next to the labels for a single training batch.
model.eval()
batch = next(iter(train_dl))
labels = batch.pop('label').to(device)
inputs: Dict[str, torch.Tensor] = {
    name: tensor.to(device) for name, tensor in batch.items()
}

with torch.no_grad():
    print(model(inputs), labels)

tensor([[ 1.3210e+00, -1.2048e+00],
        [-1.4418e+00,  1.8942e+00],
        [ 1.3418e+00, -1.2189e+00],
        [ 6.5940e-01, -4.0190e-01],
        [ 2.8530e-02,  3.0001e-01],
        [ 1.1157e+00, -9.8696e-01],
        [-4.7015e-01,  8.3435e-01],
        [-2.7035e-01,  5.7051e-01],
        [ 1.5051e+00, -1.4167e+00],
        [-1.3569e+00,  1.7928e+00],
        [ 7.2928e-01, -5.5585e-01],
        [-5.3890e-01,  9.2664e-01],
        [-2.1676e-01,  5.5794e-01],
        [ 1.6882e+00, -1.5067e+00],
        [ 2.1765e-01,  4.4489e-02],
        [-2.5903e-01,  6.1229e-01],
        [ 4.9300e-01, -2.4271e-01],
        [ 4.3383e-01, -1.8364e-01],
        [ 1.5126e+00, -1.4137e+00],
        [-3.6736e-01,  7.2130e-01],
        [ 1.8543e-01,  6.9353e-02],
        [ 8.6622e-01, -6.8850e-01],
        [ 9.4464e-01, -7.2723e-01],
        [ 1.6170e-01,  1.0657e-01],
        [ 1.5844e-03,  3.0410e-01],
        [-1.7206e-01,  4.7717e-01],
        [ 1.2194e+00, -1.1332e+00],
        [-1.2380e+00,  1.678

In [None]:
# epoch_losses holds one averaged loss per epoch, so the x axis is the epoch
# index (it was previously mislabelled as 'batch').
loss_df = pl.DataFrame({
    'epoch': np.arange(len(epoch_losses)),
    'loss' : epoch_losses
})

alt.Chart(loss_df).mark_line(opacity=0.4).encode(
    alt.X('epoch'),
    alt.Y('loss')
).properties(
    title='Mean training loss per epoch',
    width=500,
    height=500
)

In [None]:
# Held-out split, served in shuffled mini-batches of 4 samples.
test_ds = MethylDataset(data_path='../data/processed/test_cg_32_200.parquet')
test_dl = DataLoader(dataset=test_ds, batch_size=4, shuffle=True)

In [None]:
# Print the module's layer summary (submodule names and their constructor settings).
print(model)

MethylCNN(
  (conv1): Conv1d(8, 16, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(16, 32, kernel_size=(5,), stride=(1,), padding=(1,))
  (conv3): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv4): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=2, bias=True)
)


In [None]:
def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    """
    Compute the sample-weighted mean loss and the accuracy of `model` over `data_loader`.

    Args:
        model: module called with a dict of input tensors and returning per-class logits.
        data_loader: iterable of dict batches. Each batch must contain a "label"
            entry; every other entry is moved to `device` and passed to the model.
        criterion: loss called as `criterion(logits, labels)`. Its value is
            weighted by the batch size, i.e. treated as a per-batch mean.
        device: device the batch tensors are moved to.

    Returns:
        {"loss": mean loss per sample, "accuracy": fraction of correct predictions}

    Raises:
        ValueError: if `data_loader` yields no samples.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()
    running_loss: float = 0.0
    correct_predictions: int = 0
    total_samples: int = 0

    with torch.no_grad():
        for batch in tqdm(data_loader):
            # Read the label without popping so the caller's batch dict is not mutated.
            labels: torch.Tensor = batch["label"].to(device)
            inputs: Dict[str, torch.Tensor] = {
                k: v.to(device) for k, v in batch.items() if k != "label"
            }

            logits: torch.Tensor = model(inputs)
            loss: torch.Tensor = criterion(logits, labels)

            batch_size: int = labels.size(0)
            # weight by batch size so a smaller final batch does not skew the mean
            running_loss += loss.item() * batch_size

            predicted = logits.argmax(dim=1)
            total_samples += batch_size
            correct_predictions += (predicted == labels).sum().item()

    if total_samples == 0:
        # Previously this surfaced as an unexplained ZeroDivisionError.
        raise ValueError("evaluate_model received no samples: data_loader is empty")

    epoch_loss: float = running_loss / total_samples
    epoch_acc: float = correct_predictions / total_samples
    return {"loss": epoch_loss, "accuracy": epoch_acc}

### test 1: 
0.815 test set accuracy
### test 2: 
move from adam -> adam2, change conv1 to size 5
0.808 test set accuracy 
0.807 train set accuracy (why is this lower than the test set accuracy?)
### test 3: 
add another convolutional layer (64-> 128)
test set: 0.79
train set: 0.79
### test 4:
added 5 more epochs of training
test set: 0.818
train set: 0.825
So this was 10 epochs with batch size 32, on the 10k reads datafile

In [None]:
# Loss and accuracy on the held-out test split.
evaluate_model(model=model, data_loader=test_dl, criterion=criterion, device=device)

  0%|          | 0/1755 [00:00<?, ?it/s]

100%|██████████| 1755/1755 [00:03<00:00, 454.81it/s]


{'loss': 0.4925670400442332, 'accuracy': 0.7700527140618322}

In [None]:
# Loss and accuracy on the training split, for comparison with the test split.
evaluate_model(model=model, data_loader=train_dl, criterion=criterion, device=device)

100%|██████████| 878/878 [00:03<00:00, 249.27it/s]


{'loss': 0.4169128398978833, 'accuracy': 0.8102803405407331}

In [None]:
# Scratch check: all() over a list that contains False evaluates to False.
all([True, False])

False