In [1]:
import os

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F

from netam import framework, models
from netam.common import pick_device

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the SHMoof edge table (CSV). Set the SHMOOF_DATA_PATH environment
# variable to point at a local copy of the file; when it is unset, the path the
# notebook was originally run with is used, so existing setups keep working.
shmoof_data_path = os.environ.get(
    "SHMOOF_DATA_PATH",
    "/Users/matsen/data/shmoof_edges_11-Jan-2023_NoNode0_iqtree_K80+R_masked.csv",
)

# Split the table into training and validation dataframes.
# NOTE(review): val_nickname="59" presumably names the sample held out for
# validation -- confirm in framework.load_shmoof_dataframes.
# An earlier revision of this call also passed sample_count=5000 (presumably to
# subsample for quick experiments); add it back here if a faster run is needed.
train_df, val_df = framework.load_shmoof_dataframes(shmoof_data_path, val_nickname="59")


In [3]:
# Dataset configuration, forwarded unchanged to framework.SHMoofDataset below.
kmer_length = 3
max_length = 410

# Seed PyTorch's global RNG before any model is built or trained, so that steps
# later in the notebook that draw from it (e.g. weight initialization, dropout)
# repeat on a fresh "Restart Kernel -> Run All". Bitwise-identical results on an
# accelerator backend are still not guaranteed.
seed = 0
torch.manual_seed(seed)

train_dataset = framework.SHMoofDataset(train_df, kmer_length=kmer_length, max_length=max_length)
val_dataset = framework.SHMoofDataset(val_df, kmer_length=kmer_length, max_length=max_length)

# Pick the compute device and move both datasets onto it.
# NOTE(review): the return values of .to() are discarded, so these calls are
# expected to move the dataset in place -- confirm in framework.SHMoofDataset.to.
device = pick_device()
train_dataset.to(device)
val_dataset.to(device)

print(f"we have {len(train_dataset)} training examples and {len(val_dataset)} validation examples")

Using Metal Performance Shaders
we have 44330 training examples and 4686 validation examples


In [4]:
# Build the CNN model that includes a positional encoder (CNNPEModel).
# The hyperparameters are kept identical to the plain CNNModel trained later
# in the notebook so the two runs can be compared directly.
model = models.CNNPEModel(
    train_dataset,
    embedding_dim=10,
    num_filters=9,
    kernel_size=11,
    dropout_rate=0.1,
)

# Move the model to the selected device. This is left as the cell's trailing
# expression so the notebook displays the module summary.
model.to(device)

CNNPEModel(
  (kmer_embedding): Embedding(65, 10)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (conv): Conv1d(10, 9, kernel_size=(11,), stride=(1,), padding=same)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear): Linear(in_features=9, out_features=1, bias=True)
)

In [5]:
# Bundle the datasets, the model and the optimization settings into a Burrito,
# the training driver provided by netam's framework module.
burrito = framework.Burrito(
    train_dataset,
    val_dataset,
    model,
    batch_size=1024,
    learning_rate=0.1,
    min_learning_rate=1e-4,
    l2_regularization_coeff=1e-6,
)

print("starting training...")

# Request up to 100 epochs and keep the returned loss history.
# NOTE(review): the recorded progress bar ends well before epoch 100, so
# training presumably stops early -- confirm the stopping rule in Burrito.train.
losses = burrito.train(epochs=100)

# Trailing expression: display the train/validation loss table.
losses

starting training...


Epoch:  28%|██▊       | 28/100 [02:38<06:46,  5.65s/it, loss_diff=-2.097e-09, lr=3.2e-5] 


Unnamed: 0,train_loss,val_loss
0,0.067697,0.063182
1,0.067706,0.061345
2,0.065516,0.061148
3,0.065414,0.061138
4,0.065418,0.06125
5,0.065562,0.061381
6,0.065622,0.061158
7,0.065432,0.061158
8,0.065434,0.061157
9,0.065423,0.061157


For comparison, here's a simple CNN.

In [6]:
# Baseline for comparison: CNNModel with the same hyperparameters that were
# used for the CNNPEModel above.
# Note that this cell rebinds `model`, `burrito` and `losses`, replacing the
# objects created for the previous model.
model = models.CNNModel(
    train_dataset,
    embedding_dim=10,
    num_filters=9,
    kernel_size=11,
    dropout_rate=0.1,
)
model.to(device)

# Same training configuration as for the previous model.
burrito = framework.Burrito(
    train_dataset,
    val_dataset,
    model,
    batch_size=1024,
    learning_rate=0.1,
    min_learning_rate=1e-4,
    l2_regularization_coeff=1e-6,
)

print("starting training...")
losses = burrito.train(epochs=100)

# Trailing expression: display the train/validation loss table.
losses

starting training...


Epoch:  39%|███▉      | 39/100 [04:52<07:37,  7.50s/it, loss_diff=1.038e-06, lr=3.2e-5]  


Unnamed: 0,train_loss,val_loss
0,0.065794,0.06138
1,0.063026,0.058338
2,0.061412,0.058284
3,0.061225,0.058224
4,0.061164,0.058191
5,0.061087,0.058243
6,0.061083,0.058225
7,0.061105,0.05822
8,0.061054,0.058221
9,0.061059,0.05826
