In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
plt.rcParams['figure.dpi'] = 150

from netam.common import print_parameter_count, pick_device

from netam import framework, models

from shmex.shm_data import train_val_dfs_of_nickname
from shmex.shm_zoo import default_burrito_params
from shmex.local import localify

# Output location for figures, looked up via shmex's `localify` helper
# (presumably maps the 'FIGURES_DIR' key to a machine-specific path — verify).
figures_dir = localify('FIGURES_DIR')

# Compute device chosen by netam; the datasets and models below are moved onto it.
device = pick_device()

Using Metal Performance Shaders
Using Metal Performance Shaders


In [2]:
# Shared hyperparameters for the datasets and both models trained below.
kmer_length = 3
site_count = 500
# NOTE(review): `weight_decay` is not passed to anything in this notebook; training
# uses `default_burrito_params` only. Confirm whether it was meant to reach SHMBurrito.
weight_decay = 1e-6

model_params = dict(
    kmer_length=kmer_length,
    kernel_size=11,
    # Changed slightly from the `lrg` config so the embedding dimension is even,
    # which the positional encoding requires.
    embedding_dim=6,
    filter_count=19,
    dropout_prob=0.3,
)

train_df, val_df = train_val_dfs_of_nickname('shmoof_small')

# Build one SHMoofDataset per split (train first, then validation), then move each to the device.
train_dataset, val_dataset = (
    framework.SHMoofDataset(split_df, kmer_length=kmer_length, site_count=site_count)
    for split_df in (train_df, val_df)
)
for split_dataset in (train_dataset, val_dataset):
    split_dataset.to(device)

print(f"we have {len(train_dataset)} training examples and {len(val_dataset)} validation examples")

Interpreting shmoof_small as a shmoof dataset
we have 46391 training examples and 2625 validation examples


In [3]:
# Train the CNN + positional-encoding model.
# Re-seed torch right before building the model so this run (weight init, dropout,
# and presumably batch shuffling) is reproducible on its own, regardless of what ran
# earlier in the kernel. NOTE: accelerator kernels (e.g. MPS) may still add small
# nondeterminism.
torch.manual_seed(0)

# Architecture-specific names, so these results survive later cells that train other models.
cnnpe_model = models.CNNPEModel(**model_params)
cnnpe_model.to(device)

cnnpe_burrito = framework.SHMBurrito(train_dataset, val_dataset, cnnpe_model, **default_burrito_params)
cnnpe_losses = cnnpe_burrito.train(epochs=100)
cnnpe_losses.tail()

Epoch: 100%|██████████| 100/100 [13:13<00:00,  7.94s/it, loss_diff=1.36e-06, lr=6.25e-5, val_loss=0.05566]  


Unnamed: 0,train_loss,val_loss
95,0.062211,0.055658
96,0.062208,0.055659
97,0.06221,0.055658
98,0.06221,0.055657
99,0.062226,0.055659


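For comparison, here's a plain CNN (`CNNModel`, without the positional encoding). It is trained on the same datasets with the same `model_params` and burrito settings.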

In [4]:
# Train the plain CNN baseline.
# Re-seed torch right before building the model so this run (weight init, dropout,
# and presumably batch shuffling) is reproducible on its own, regardless of what ran
# earlier in the kernel. NOTE: accelerator kernels (e.g. MPS) may still add small
# nondeterminism.
torch.manual_seed(0)

model = models.CNNModel(**model_params)
model.to(device)

burrito = framework.SHMBurrito(train_dataset, val_dataset, model, **default_burrito_params)
losses = burrito.train(epochs=100)
losses.tail()

Epoch: 100%|██████████| 100/100 [13:31<00:00,  8.11s/it, loss_diff=-6.184e-07, lr=0.000125, val_loss=0.05553]


Unnamed: 0,train_loss,val_loss
95,0.061249,0.055531
96,0.061259,0.055534
97,0.061256,0.055533
98,0.061233,0.055529
99,0.06124,0.055529
