### A simple training demo

In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from netam.framework import (
    SHMoofDataset,
    RSSHMBurrito,
)
from epam.torch_common import pick_device

from shmex.shm_data import load_shmoof_dataframes, dataset_dict
from shmex import shm_zoo

Using Metal Performance Shaders


We're just going to use shmoof training data, the same data used to train the context NT model.

In [2]:
# Site count forwarded to every SHMoofDataset constructed in the next cell.
site_count = 500

# Load the shmoof train/validation split (validation set selected by the
# "small" nickname) and give both frames a clean 0..n-1 index.
train_df, val_df = load_shmoof_dataframes(dataset_dict["shmoof"], val_nickname="small")
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)

# Evenly-strided subsample targeting roughly 5000 training rows.
# The stride is clamped to >= 1: with fewer than 5000 rows the unclamped
# value would be 0 and `iloc[::0]` raises "slice step cannot be zero".
# In that case every row is kept.
subsample_stride = max(1, len(train_df) // 5000)
subsampled_train_df = (
    train_df.iloc[::subsample_stride].copy().reset_index(drop=True)
)

In [3]:
# Choose the compute device once; every dataset below is moved onto it.
device = pick_device()

# Build (train, validation) dataset pairs for each k-mer length. Construction
# order is train-then-validation, 5-mer first and 3-mer second.
train_data_5mer, val_data_5mer = (
    SHMoofDataset(split_df, kmer_length=5, site_count=site_count)
    for split_df in (train_df, val_df)
)
train_data_3mer, val_data_3mer = (
    SHMoofDataset(split_df, kmer_length=3, site_count=site_count)
    for split_df in (train_df, val_df)
)

# `.to(device)` is called for its effect on the dataset object; the return
# value is deliberately ignored.
for data in (train_data_5mer, val_data_5mer, train_data_3mer, val_data_3mer):
    data.to(device)

Using Metal Performance Shaders


In [4]:
# Instantiate the "fivemer" architecture from the model zoo and move it to the
# device selected earlier.
fivemer_model = shm_zoo.create_model("fivemer")
fivemer_model.to(device)

# Pair the model with the 5-mer train/validation datasets.
fivemer_burrito = RSSHMBurrito(
    train_data_5mer,
    val_data_5mer,
    fivemer_model,
    name="fivemer",
)

# Train with epochs=100 and keep whatever train() returns for later inspection.
fivemer_loss_history = fivemer_burrito.train(epochs=100)

# Persist the trained model under "trained_models/fivemer", then also obtain
# an in-memory crepe from the same burrito.
# NOTE(review): presumably the "trained_models" directory must already
# exist - confirm against save_crepe.
fivemer_burrito.save_crepe("trained_models/fivemer")
fivemer_crepe = fivemer_burrito.to_crepe()

Epoch: 100%|██████████| 100/100 [18:22<00:00, 11.03s/it, loss_diff=9.313e-07, lr=0.00313, val_loss=0.06544]


In [5]:
# Single source of truth for the zoo key, the burrito name, and the saved
# crepe path, so the artifact on disk is labelled with the architecture that
# was actually trained.
# NOTE(review): this cell previously built "cnn_joi_sml" but named and saved it
# as "cnn_ind_lrg". If the "cnn_ind_lrg" architecture was the one intended,
# change this one string instead.
cnn_model_name = "cnn_joi_sml"

cnn_model = shm_zoo.create_model(cnn_model_name)
cnn_model.to(device)

# The CNN is paired with the 3-mer datasets (the fivemer model above uses the
# 5-mer ones).
cnn_burrito = RSSHMBurrito(
    train_data_3mer,
    val_data_3mer,
    cnn_model,
    name=cnn_model_name,
)

# Per the recorded output, joint_train() alternates "Finding optimal branch
# lengths" passes over the train/validation sets with blocks of epochs.
cnn_burrito.joint_train()

# Persist the trained model, then also obtain an in-memory crepe.
cnn_burrito.save_crepe(f"trained_models/{cnn_model_name}")
cnn_crepe = cnn_burrito.to_crepe()

Finding optimal branch lengths: 100%|██████████| 46391/46391 [03:23<00:00, 228.41it/s]
Finding optimal branch lengths: 100%|██████████| 2625/2625 [00:10<00:00, 240.01it/s]
Epoch: 100%|██████████| 100/100 [23:44<00:00, 14.24s/it, loss_diff=-1.386e-06, lr=0.00313, val_loss=0.06532]
Finding optimal branch lengths: 100%|██████████| 46391/46391 [03:10<00:00, 243.91it/s]
Finding optimal branch lengths: 100%|██████████| 2625/2625 [00:09<00:00, 282.21it/s]
Epoch: 100%|██████████| 100/100 [22:20<00:00, 13.41s/it, loss_diff=-3.546e-06, lr=0.00156, val_loss=0.06539]
