In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from netam import framework, models
from netam.framework import calculate_loss
from netam.common import pick_device, PositionalEncoding

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
shmoof_data_path = "/Users/matsen/data/shmoof_edges_11-Jan-2023_NoNode0_iqtree_K80+R_masked.csv"
train_df, val_df = framework.load_shmoof_dataframes(shmoof_data_path, val_nickname="59") #, sample_count=5000)


In [3]:
kmer_length = 3
max_length = 410

train_dataset = framework.SHMoofDataset(train_df, kmer_length=kmer_length, max_length=max_length)
val_dataset = framework.SHMoofDataset(val_df, kmer_length=kmer_length, max_length=max_length)

device = pick_device()
train_dataset.to(device)
val_dataset.to(device)

print(f"we have {len(train_dataset)} training examples and {len(val_dataset)} validation examples")

Using Metal Performance Shaders
we have 44330 training examples and 4686 validation examples


In [4]:
class CNNXfModel(nn.Module):
    def __init__(self, dataset, embedding_dim, num_filters, kernel_size, nhead, dim_feedforward, num_transformer_layers, dropout_prob=0.1):
        super(CNNXfModel, self).__init__()
        self.kmer_count = len(dataset.kmer_to_index)
        self.kmer_embedding = nn.Embedding(self.kmer_count, embedding_dim)
        self.pos_encoder = PositionalEncoding(embedding_dim, dropout=dropout_prob) 
        self.conv = nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters, kernel_size=kernel_size, padding='same')
        self.transformer_encoder_layer = TransformerEncoderLayer(
            d_model=num_filters,  # This should match the number of filters in the last conv layer
            nhead=nhead,
            dim_feedforward=dim_feedforward
        )
        self.transformer_encoder = TransformerEncoder(self.transformer_encoder_layer, num_layers=num_transformer_layers)
        self.dropout = nn.Dropout(dropout_prob)
        self.linear = nn.Linear(in_features=num_filters, out_features=1)

    def forward(self, encoded_parents, masks):
        kmer_embeds = self.kmer_embedding(encoded_parents)
        kmer_embeds = self.pos_encoder(kmer_embeds)
        kmer_embeds = kmer_embeds.permute(0, 2, 1)
        conv_out = F.relu(self.conv(kmer_embeds))
        conv_out = self.dropout(conv_out)
        conv_out = conv_out.permute(0, 2, 1)
        transformer_out = self.transformer_encoder(conv_out)
        log_rates = self.linear(transformer_out).squeeze(-1)
        rates = torch.exp(log_rates * masks)

        return rates

model = CNNXfModel(train_dataset, embedding_dim=10, num_filters=10, kernel_size=11, nhead=2, dim_feedforward=64, num_transformer_layers=2)

model.to(device)



CNNXfModel(
  (kmer_embedding): Embedding(65, 10)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (conv): Conv1d(10, 10, kernel_size=(11,), stride=(1,), padding=same)
  (transformer_encoder_layer): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=10, out_features=10, bias=True)
    )
    (linear1): Linear(in_features=10, out_features=64, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=64, out_features=10, bias=True)
    (norm1): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((10,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuant

In [5]:
burrito = framework.SHMBurrito(train_dataset, val_dataset, model, batch_size=1024, learning_rate=0.1, min_learning_rate=1e-4, l2_regularization_coeff=1e-6)
print("starting training...")
losses = burrito.train(epochs=100)

starting training...


Epoch:  26%|██▌       | 26/100 [22:29<1:04:01, 51.91s/it, loss_diff=2.097e-09, lr=3.2e-5]  


In [6]:
losses

Unnamed: 0,train_loss,val_loss
0,0.067418,0.062881
1,0.065834,0.061159
2,0.065424,0.061157
3,0.065423,0.061157
4,0.065422,0.061157
5,0.065422,0.061157
6,0.065423,0.061157
7,0.065423,0.061157
8,0.065421,0.061157
9,0.065425,0.061157
