In [566]:
import torch
import polars as pl
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
import altair as alt
alt.data_transformers.enable("vegafusion")
from typing import Dict, List, Any, Optional
import torch.nn.functional as F
import os
from torch import nn
from tqdm import tqdm
import numpy as np

device = torch.device(
    "mps" if torch.backends.mps.is_available() 
    else "cuda" if torch.cuda.is_available() 
    else "cpu"
)
print(device)


mps


In [567]:
train_df = pl.read_parquet('../data/processed/train_cg_32_200.parquet')
str_seq = train_df.row(0, named=True)['window_seq']
vocab = {'A':0, 'T':1, 'C':2, 'G':3}
[vocab.get(char) for char in str_seq]

[2,
 2,
 2,
 3,
 2,
 0,
 3,
 1,
 1,
 1,
 1,
 2,
 1,
 2,
 2,
 2,
 3,
 1,
 2,
 2,
 1,
 3,
 3,
 3,
 3,
 2,
 0,
 3,
 1,
 2,
 0,
 3]

In [568]:
def compute_normalization_stats(df):
    kinetic_features = ['window_fi', 'window_fp', 'window_ri', 'window_rp']
    means = {col: df[col].explode().mean() for col in kinetic_features}
    stds = {col: df[col].explode().explode().std() for col in kinetic_features}
    return means, stds
train_means, train_stds = compute_normalization_stats(train_df)
print(train_means, train_stds)

{'window_fi': 30.789639956541873, 'window_fp': 21.96058935632102, 'window_ri': 30.19122800377587, 'window_rp': 22.001033021052258} {'window_fi': 24.00485941139194, 'window_fp': 11.301933232238992, 'window_ri': 23.627035898853727, 'window_rp': 11.337388102877114}


In [569]:
class MethylDataset(Dataset):
    """
    Dataset class for metylation data stored in a parquet file
    Reads data using polars and converts samples to pytorch tensors on get_item call
    Nucleotide sequence is one-hot encoded
    """
    def __init__(self, data_path: Path, 
                 transform=None, 
                 means: Optional[Dict[str, float]] = train_means, 
                 stds: Optional[Dict[str, float]] = train_stds):
        '''
        Arguments:
        data_path: the path for the parquet file that contains either the training or test data (pos and neg)
        transform: an optional transform callable on a single sample
        '''
        self.tranform = transform
        self.means = means
        self.stds = stds

        try:
             self.data = pl.read_parquet(data_path)[:500000]
        except:
             print(f"failed to read data given path: {data_path}")
             self.data = pl.DataFrame()

        self._dataset_len = len(self.data)
        self.vocab = {'A':0, 'T':1, 'C':2, 'G':3}
        self.vocab_size = len(self.vocab)
        self.kinetic_features = ['fi', 'fp', 'ri', 'rp']

    def __len__(self):
          return self._dataset_len
    
    def __getitem__(self, idx):
         if idx >= len(self):
              raise IndexError("Index out of range")
         sample = self.data.row(idx, named=True)
         # sequence data (requires one-hot encoding)
         str_seq = sample['window_seq']
         int_seq_tensor = torch.tensor([self.vocab.get(char) for char in str_seq], dtype=torch.long)
         seq_tensor = F.one_hot(int_seq_tensor, num_classes=self.vocab_size)
         # kinetic data
         fi_tensor = (torch.tensor(sample['window_fi'], dtype=torch.long) - self.means['window_fi'])/self.stds['window_fi']
         fp_tensor = (torch.tensor(sample['window_fp'], dtype=torch.long) - self.means['window_fp'])/self.stds['window_fp']
         ri_tensor = (torch.tensor(sample['window_ri'], dtype=torch.long) - self.means['window_ri'])/self.stds['window_ri']
         rp_tensor = (torch.tensor(sample['window_rp'], dtype=torch.long) - self.means['window_rp'])/self.stds['window_rp']
         label_tensor = torch.tensor(sample['label'], dtype=torch.long) 
         
         return {
              'seq': seq_tensor,
              'fi': fi_tensor,
              'fp': fp_tensor,
              'ri': ri_tensor,
              'rp': rp_tensor,
              'label': label_tensor
              }
    

In [570]:
train_ds = MethylDataset('../data/processed/train_cg_32_10000.parquet')
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

In [571]:
train_ds.data['label'].mean(), len(train_ds)

(0.503004, 500000)

In [587]:
len_list = np.array([len(train_ds.data.row(i, named=True)['window_fi']) for i in range(len(train_ds))])
np.argmin(len_list)

np.int64(766)

In [572]:
class MethylCNN(nn.Module):
    def __init__(self, sequence_length: int = 32, in_channels:int = 8, num_classes: int = 2):
        super().__init__()
        self.in_channels = in_channels
        self.sequence_length = sequence_length
        self.num_classes = num_classes

        # Convolution layers
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, padding=1)

        # calculate fc layer input with dummy passthrough
        fc_input_features = self._get_conv_output_size(sequence_length)

        # Linear layers
        self.fc1 = nn.Linear(in_features=fc_input_features, out_features=64)
        self.fc2 = nn.Linear(in_features=64, out_features=32)
        self.fc3 = nn.Linear(in_features=32, out_features=num_classes)

    def _extract_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)

        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)


        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool1d(x, kernel_size=2, stride=2)
        
        return x
    
    def _get_conv_output_size(self, sequence_length: int) -> int:
        """
        Calculates the flattened output size of the convolutional layers
        by performing a forward pass on random data of the right shape. 
        """
        dummy_input = torch.randn(1, self.in_channels, sequence_length)
        output = self._extract_features(dummy_input)
        return output.numel()

    def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Assuming batch['seq'] is [B, L, 4]
        seq = batch['seq'].permute(0, 2, 1) # -> [B, 4, L]
        fi = batch['fi'].unsqueeze(1)      # -> [B, 1, L]
        fp = batch['fp'].unsqueeze(1)       # -> [B, 1, L]
        ri = batch['ri'].unsqueeze(1)       # -> [B, 1, L]
        rp = batch['rp'].unsqueeze(1)       # -> [B, 1, L]
        
        # the input is a dictionary, so convert to a tensor
        x = torch.cat([seq, fi, fp, ri, rp], dim=1).to(self.conv1.weight.dtype) # -> [B, 8, L]
        
        x = self._extract_features(x)

        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)

        x = self.fc2(x)
        x = F.relu(x)

        logits = self.fc3(x)

        return logits

In [573]:
model = MethylCNN(sequence_length=32)
model.to(device) 

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)



epoch_losses = []
for epoch in range(5):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, batch in enumerate(tqdm(train_dl), 0):
        # get the inputs; data is a list of [inputs, labels]
        labels = batch.pop('label').to(device)
        inputs = {k: v.to(device) for k, v in batch.items()}

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
    avg_epoch_loss = running_loss/len(train_dl)
    epoch_losses.append(avg_epoch_loss)
    print(avg_epoch_loss)

print('Finished Training')

  0%|          | 58/15625 [00:00<01:42, 152.54it/s]


RuntimeError: stack expects each tensor to be equal size, but got [32] at entry 0 and [0] at entry 11

In [None]:
with torch.no_grad():
    model.eval()
    batch = next(iter(train_dl))
    labels = batch.pop('label').to(device)
    inputs: Dict[str, torch.Tensor] = {
                k: v.to(device) for k, v in batch.items()
            }

    print(model(inputs), labels)

tensor([[ 1.2136, -0.3909],
        [-0.1507,  1.0480],
        [-0.6870,  1.2956],
        [ 0.8742, -0.1396],
        [-0.4239,  1.1524],
        [ 0.2829,  0.3885],
        [ 0.0895,  0.5631],
        [ 1.0709, -0.3407],
        [ 0.4605,  0.3100],
        [ 1.3321, -0.2390],
        [ 0.8626, -0.1533],
        [-0.2730,  1.2336],
        [-0.5103,  1.2349],
        [-0.4987,  1.2338],
        [ 0.5909,  0.2851],
        [ 0.5141,  0.1391]], device='mps:0') tensor([0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1], device='mps:0')


In [None]:
loss_df = pl.DataFrame({
    'batch': np.arange(len(epoch_losses)),
    'loss' : epoch_losses
})

alt.Chart(loss_df).mark_circle(opacity=0.4).encode(
    alt.X('batch'),
    alt.Y('loss')
).properties(
    width=500,
    height=500
)

In [None]:
test_ds = MethylDataset('../data/processed/test_cg_32_200.parquet')
test_dl = DataLoader(test_ds, batch_size=4, shuffle=True)

In [None]:
print(model)

MethylCNN(
  (conv1): Conv1d(8, 16, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv3): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=2, bias=True)
)


In [None]:
def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Dict[str, float]:
    model.eval()
    running_loss: float = 0.0
    correct_predictions: int = 0
    total_samples: int = 0

    with torch.no_grad():
        for batch in tqdm(data_loader):
            labels: torch.Tensor = batch.pop("label").to(device)
            inputs: Dict[str, torch.Tensor] = {
                k: v.to(device) for k, v in batch.items()
            }

            logits: torch.Tensor = model(inputs)
            loss: torch.Tensor = criterion(logits, labels)

            running_loss += loss.item() * labels.size(0)

            _, predicted = torch.max(logits.data, 1)
            total_samples += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    epoch_loss: float = running_loss / total_samples
    epoch_acc: float = correct_predictions / total_samples
    return {"loss": epoch_loss, "accuracy": epoch_acc}

In [None]:
evaluate_model(model, test_dl, criterion, device)

100%|██████████| 1755/1755 [00:03<00:00, 475.86it/s]


{'loss': 0.5112555802679415, 'accuracy': 0.7486821484541958}

In [None]:
len(train_dl)

1755