## CNNMLPModel

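Here I test whether we do better by inserting a hidden layer between the convolution and the output layer of the CNN model.
The answer is no: the validation loss is essentially the same as that of the plain CNN trained for comparison at the end of this notebook.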

In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

from netam import framework, models
from netam.framework import calculate_loss
from netam.common import pick_device

from netam.common import PositionalEncoding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# NOTE(review): this is an absolute path on the author's machine; point it at
# a local copy of the CSV before re-running the notebook elsewhere.
shmoof_data_path = "/Users/matsen/data/shmoof_edges_11-Jan-2023_NoNode0_iqtree_K80+R_masked.csv"

# The original call carried a commented-out `sample_count=5000` argument,
# presumably for subsampling during quick experiments; the full data is used here.
train_df, val_df = framework.load_shmoof_dataframes(
    shmoof_data_path,
    val_nickname="59",
)


In [3]:
# Settings shared by the training and validation datasets.
kmer_length = 3
max_length = 410

# Build both datasets with identical settings (training first, then validation).
train_dataset, val_dataset = (
    framework.SHMoofDataset(df, kmer_length=kmer_length, max_length=max_length)
    for df in (train_df, val_df)
)

# Choose the compute device and move both datasets onto it.
device = pick_device()
train_dataset.to(device)
val_dataset.to(device)

print(f"we have {len(train_dataset)} training examples and {len(val_dataset)} validation examples")

Using Metal Performance Shaders
we have 44330 training examples and 4686 validation examples


In [4]:
class CNNMLPModel(nn.Module):
    """CNN over k-mer embeddings with a hidden layer before the per-site output.

    Pipeline: k-mer embedding -> 1D convolution with 'same' padding -> ReLU ->
    dropout -> linear hidden layer -> ReLU -> linear output giving one
    log-rate per site, which is masked and exponentiated.

    Args:
        dataset: object exposing ``kmer_to_index``; its length sets the number
            of rows in the embedding table.
        embedding_dim: width of each k-mer embedding vector.
        num_filters: number of convolution output channels.
        kernel_size: width of the convolution window.
        hidden_dim: width of the hidden layer applied at each site.
        dropout_rate: dropout probability applied to the convolution output.
    """

    def __init__(self, dataset, embedding_dim, num_filters, kernel_size, hidden_dim, dropout_rate=0.1):
        super().__init__()
        self.kmer_count = len(dataset.kmer_to_index)
        # Layers are created in the same order as they are applied in forward().
        self.kmer_embedding = nn.Embedding(self.kmer_count, embedding_dim)
        self.conv = nn.Conv1d(embedding_dim, num_filters, kernel_size, padding="same")
        self.dropout = nn.Dropout(dropout_rate)
        self.hidden_linear = nn.Linear(num_filters, hidden_dim)
        self.output_linear = nn.Linear(hidden_dim, 1)

    def forward(self, encoded_parents, masks):
        """Return per-site rates for a batch of encoded parent sequences.

        Args:
            encoded_parents: rank-2 integer tensor of k-mer indices
                (presumably batch x sites).
            masks: tensor multiplied elementwise with the log-rates; where it
                is 0 the returned rate is exp(0) = 1.

        Returns:
            Tensor of rates with the same leading two axes as ``encoded_parents``.
        """
        embedded = self.kmer_embedding(encoded_parents)
        # Conv1d expects channels on axis 1, so swap the last two axes going in...
        features = torch.relu(self.conv(embedded.transpose(1, 2)))
        # ...and swap them back so the linear layers act on the channel axis.
        features = self.dropout(features).transpose(1, 2)
        hidden = torch.relu(self.hidden_linear(features))
        log_rates = self.output_linear(hidden).squeeze(-1)
        return torch.exp(log_rates * masks)



# Instantiate the CNN+MLP model and move it to the selected device.
# nn.Module.to returns the module itself, so the call can be chained.
model = CNNMLPModel(
    train_dataset,
    embedding_dim=10,
    num_filters=9,
    kernel_size=11,
    hidden_dim=5,
    dropout_rate=0.1,
).to(device)
model


CNNMLPModel(
  (kmer_embedding): Embedding(65, 10)
  (conv): Conv1d(10, 9, kernel_size=(11,), stride=(1,), padding=same)
  (dropout): Dropout(p=0.1, inplace=False)
  (hidden_linear): Linear(in_features=9, out_features=5, bias=True)
  (output_linear): Linear(in_features=5, out_features=1, bias=True)
)

In [5]:
# Train the CNNMLPModel held in `model`; these settings are identical to the
# ones used for the plain CNN baseline so the two runs are comparable.
burrito = framework.Burrito(
    train_dataset,
    val_dataset,
    model,
    batch_size=1024,
    learning_rate=0.1,
    min_learning_rate=1e-4,
    l2_regularization_coeff=1e-6,
)
print("starting training...")
losses = burrito.train(epochs=100)
# Display the per-epoch loss table as the cell output.
losses

starting training...


Epoch:  37%|███▋      | 37/100 [05:44<09:46,  9.31s/it, loss_diff=-1.351e-06, lr=3.2e-5] 


Unnamed: 0,train_loss,val_loss
0,0.065654,0.061373
1,0.063354,0.058455
2,0.061467,0.058362
3,0.061244,0.058237
4,0.061101,0.058254
5,0.06108,0.05829
6,0.061052,0.058312
7,0.061026,0.058234
8,0.061034,0.058239
9,0.06084,0.05816


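For comparison, here's a simple CNN (`models.CNNModel`) with the same embedding, filter, kernel, and dropout settings, trained with the same settings as above.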

In [6]:
# Baseline: the plain CNN from netam, with the same embedding, filter, kernel
# and dropout settings as the CNNMLPModel above but no hidden layer argument.
model = models.CNNModel(
    train_dataset,
    embedding_dim=10,
    num_filters=9,
    kernel_size=11,
    dropout_rate=0.1,
)
model.to(device)

# Same training configuration as for the CNNMLPModel run.
burrito = framework.Burrito(
    train_dataset,
    val_dataset,
    model,
    batch_size=1024,
    learning_rate=0.1,
    min_learning_rate=1e-4,
    l2_regularization_coeff=1e-6,
)
print("starting training...")
losses = burrito.train(epochs=100)
# Display the per-epoch loss table as the cell output.
losses

starting training...


Epoch:  34%|███▍      | 34/100 [04:57<09:38,  8.76s/it, loss_diff=7.136e-07, lr=3.2e-5]  


Unnamed: 0,train_loss,val_loss
0,0.065707,0.061421
1,0.062443,0.058267
2,0.061315,0.058295
3,0.061204,0.058136
4,0.061114,0.058215
5,0.061115,0.058206
6,0.061119,0.058082
7,0.061072,0.058211
8,0.061078,0.058248
9,0.061053,0.058174
