### A simple training demo

In [1]:
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from netam.framework import (
    SHMoofDataset,
    RSSHMBurrito,
)
from epam.torch_common import pick_device

from shmex.shm_data import load_shmoof_dataframes, dataset_dict
from shmex import shm_zoo

Using Metal Performance Shaders


We use only the shmoof training data (the same data used to train the context NT model), with the `small` validation split.

In [2]:
# Site count passed to every SHMoofDataset built below
# (presumably the per-sequence length the datasets pad/truncate to — confirm in netam).
site_count = 500
# Target size of the subsampled training frame. The stride-based subsample below
# keeps at least min(len(train_df), subsample_target) rows.
subsample_target = 5000

train_df, val_df = load_shmoof_dataframes(dataset_dict["shmoof"], val_nickname="small")
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

# Keep every `subsample_stride`-th row. Clamp to 1: when the training frame has
# fewer than `subsample_target` rows, the unclamped stride would be 0 and
# `iloc[::0]` raises "slice step cannot be zero".
subsample_stride = max(1, len(train_df) // subsample_target)
subsampled_train_df = train_df.iloc[::subsample_stride].copy().reset_index(drop=True)

In [3]:
device = pick_device()

# Build the train/validation datasets for two k-mer context widths (5 and 3)
# from the frames loaded above. Each pair is created train first, then val.
train_data_5mer, val_data_5mer = (
    SHMoofDataset(frame, kmer_length=5, site_count=site_count)
    for frame in (train_df, val_df)
)
train_data_3mer, val_data_3mer = (
    SHMoofDataset(frame, kmer_length=3, site_count=site_count)
    for frame in (train_df, val_df)
)

# Move every dataset onto the device chosen by pick_device().
for data in (train_data_5mer, val_data_5mer, train_data_3mer, val_data_3mer):
    data.to(device)

Using Metal Performance Shaders


In [8]:
# Build the `cnn_joi_sml` model from shmex's model zoo, jointly train it on the
# 3-mer datasets, then save it and convert the trained burrito to a crepe.
model_name = "cnn_joi_sml"
training_method = "full"
epochs = 100
# Derive the run name from `epochs` so the saved artifact's name cannot drift
# from the training configuration actually used.
burrito_name = f"{model_name}_{training_method}_{epochs}"

# Seed torch's global RNG so a fresh-kernel re-run is reproducible
# (assumes create_model / joint_train draw randomness from torch's global RNG — confirm).
torch.manual_seed(0)
cnn_model = shm_zoo.create_model(model_name)
cnn_model.to(device)
cnn_burrito = RSSHMBurrito(train_data_3mer, val_data_3mer, cnn_model, name=burrito_name)
cnn_burrito.joint_train(training_method=training_method, epochs=epochs, cycle_count=5)

# Create the output directory before saving so a long training run is not
# lost to a missing folder at the very end.
Path("trained_models").mkdir(parents=True, exist_ok=True)
cnn_burrito.save_crepe(f"trained_models/{burrito_name}")
cnn_crepe = cnn_burrito.to_crepe()

Finding optimal branch lengths: 100%|██████████| 46391/46391 [04:14<00:00, 182.41it/s]
Finding optimal branch lengths: 100%|██████████| 2625/2625 [00:15<00:00, 164.07it/s]
Epoch: 100%|██████████| 100/100 [26:52<00:00, 16.13s/it, loss_diff=-5.521e-06, lr=0.000781, val_loss=0.06538]
Finding optimal branch lengths: 100%|██████████| 46391/46391 [16:40<00:00, 46.35it/s]
Finding optimal branch lengths: 100%|██████████| 2625/2625 [00:57<00:00, 45.57it/s]
Epoch:  78%|███████▊  | 78/100 [20:28<05:46, 15.76s/it, loss_diff=1.013e-06, lr=6.91e-5, val_loss=0.06531]  
Finding optimal branch lengths: 100%|██████████| 46391/46391 [07:17<00:00, 106.02it/s]
Finding optimal branch lengths: 100%|██████████| 2625/2625 [00:25<00:00, 103.33it/s]
Epoch:  58%|█████▊    | 58/100 [15:36<11:17, 16.14s/it, loss_diff=7.972e-07, lr=8.21e-5, val_loss=0.06529]  
Finding optimal branch lengths: 100%|██████████| 46391/46391 [00:57<00:00, 805.59it/s]
Finding optimal branch lengths: 100%|██████████| 2625/2625 [00:03<00:00