In [1]:
import os

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

from netam import framework, models
from netam.common import pick_device

In [2]:
# Location of the SHMoof edges CSV. Set the SHMOOF_DATA_PATH environment
# variable to point elsewhere. The default expands "~", so on a machine whose
# home directory is /Users/matsen it resolves to the original location instead
# of hard-coding one user's home directory.
shmoof_data_path = os.environ.get(
    "SHMOOF_DATA_PATH",
    os.path.expanduser(
        "~/data/shmoof_edges_11-Jan-2023_NoNode0_iqtree_K80+R_masked.csv"
    ),
)

# val_nickname chooses which subset is held out for validation (semantics are
# defined in framework.load_shmoof_dataframes). For a quick subsampled run, the
# original cell had `sample_count=5000` commented out as an extra argument.
train_df, val_df = framework.load_shmoof_dataframes(shmoof_data_path, val_nickname="59")

# Fail fast on a degenerate split (e.g. a val_nickname that matches nothing).
assert len(train_df) > 0 and len(val_df) > 0, "empty train or validation split"

In [3]:
# Encoding settings shared by both dataset splits; the model cells below also
# read kmer_length.
kmer_length = 3
site_count = 410

# Build the training and validation datasets with identical encoding settings.
train_dataset = framework.SHMoofDataset(
    train_df,
    kmer_length=kmer_length,
    site_count=site_count,
)
val_dataset = framework.SHMoofDataset(
    val_df,
    kmer_length=kmer_length,
    site_count=site_count,
)

# Move both datasets onto whichever device pick_device() selects
# (Metal Performance Shaders in the run shown below).
device = pick_device()
train_dataset.to(device)
val_dataset.to(device)

print(
    f"we have {len(train_dataset)} training examples "
    f"and {len(val_dataset)} validation examples"
)

Using Metal Performance Shaders
we have 44330 training examples and 4686 validation examples


In [4]:
# Seed torch's global RNG so the random weight initialization (and torch-side
# randomness such as dropout during the training cell that follows) is
# reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

# CNN with positional encoding over k-mer embeddings.
model = models.CNNPEModel(
    kmer_length=kmer_length,
    embedding_dim=10,
    filter_count=9,
    kernel_size=11,
    dropout_rate=0.1,
)
# Last expression: displays the model's module structure.
model.to(device)

CNNPEModel(
  (kmer_embedding): Embedding(65, 10)
  (conv): Conv1d(10, 9, kernel_size=(11,), stride=(1,), padding=same)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear): Linear(in_features=9, out_features=1, bias=True)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [5]:
# Bundle datasets, model and optimization settings into a Burrito trainer.
burrito = framework.Burrito(
    train_dataset,
    val_dataset,
    model,
    batch_size=1024,
    learning_rate=0.1,
    min_learning_rate=1e-4,
    l2_regularization_coeff=1e-6,
)

print("starting training...")
# train() may stop before `epochs`. In the run below it ended at epoch 75,
# presumably because the learning rate decayed below min_learning_rate
# (the progress bar shows lr=3.2e-5).
losses = burrito.train(epochs=100)

# Per-epoch train/val losses; show the final few epochs.
losses.tail()

starting training...


Epoch:  75%|███████▌  | 75/100 [09:51<03:17,  7.88s/it, loss_diff=-4.884e-09, lr=3.2e-5] 


Unnamed: 0,train_loss,val_loss
71,0.065463,0.061172
72,0.06545,0.061172
73,0.065455,0.061172
74,0.065458,0.061172
75,0.065459,0.061172


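For comparison, here is a plain CNN (`CNNModel`) trained with the same hyperparameters and training settings as the `CNNPEModel` above, which additionally has a positional encoder.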

In [6]:
# Seed torch's global RNG so this comparison run (weight init and torch-side
# randomness during training) is reproducible under Restart Kernel -> Run All.
torch.manual_seed(0)

# NOTE: this cell rebinds `model`, `burrito`, and `losses`, so the
# CNNPEModel objects from the cells above are no longer reachable by these names.
model = models.CNNModel(
    kmer_length=kmer_length,
    embedding_dim=10,
    filter_count=9,
    kernel_size=11,
    dropout_rate=0.1,
)
model.to(device)

burrito = framework.Burrito(
    train_dataset,
    val_dataset,
    model,
    batch_size=1024,
    learning_rate=0.1,
    min_learning_rate=1e-4,
    l2_regularization_coeff=1e-6,
)
print("starting training...")
losses = burrito.train(epochs=100)

# Per-epoch train/val losses; show the final few epochs.
losses.tail()

starting training...


Epoch:  38%|███▊      | 38/100 [06:34<10:43, 10.38s/it, loss_diff=1.563e-06, lr=3.2e-5]  


Unnamed: 0,train_loss,val_loss
34,0.06068,0.057939
35,0.060667,0.057939
36,0.060669,0.05794
37,0.060678,0.057939
38,0.060672,0.05794
