In [77]:
import sys
from pathlib import Path

# add parent folder to the path
module_path = str(Path.cwd().parents[0])
if module_path not in sys.path:
    sys.path.append(module_path)

___

In [78]:
from tqdm import tqdm
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader # load batches to the network

# architectures
from src.models import (
    SimpleNet,
    Rodan
)

# loss function and dataset loader
from src.dataloader import DatasetONT # custom loader (used with DataLoader)
from src.loss_functions import ctc_label_smoothing # custom CTC loss function

In [79]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

___
### Check dataset

In [80]:
output_simplenet = 501
dataset_example = DatasetONT(recfile="../data/subsample_val.hdf5", output_network_len=output_simplenet)
loader = iter(DataLoader(dataset_example, batch_size=2, shuffle=True))
x,y, input_len, target_len = next(loader)
print("Input [batch size, channels, length]:" , x.shape)
print("Output [batch size, length]:" , y.shape)
print("CTC Loss")
print("Input CTC Loss [batch size]", input_len.shape)
print("Target CTC Loss [batch size]", target_len.shape)

Input [batch size, channels, length]: torch.Size([2, 1, 4096])
Output [batch size, length]: torch.Size([2, 271])
CTC Loss
Input CTC Loss [batch size] torch.Size([2])
Target CTC Loss [batch size] torch.Size([2])


In [81]:
input_len, target_len

(tensor([501, 501]), tensor([271, 271]))

___
SimpleNet

In [82]:
simplenet = SimpleNet(n_channels = 1, n_classes = 271)
simplenet.eval()

SimpleNet(
  (conv1): Conv1d(1, 20, kernel_size=(20,), stride=(2,))
  (relu1): ReLU()
  (maxpool1): MaxPool1d(kernel_size=10, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(20, 50, kernel_size=(5,), stride=(1,))
  (relu2): ReLU()
  (maxpool2): MaxPool1d(kernel_size=10, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=50, out_features=5, bias=True)
  (relu3): ReLU()
)

In [83]:
output = simplenet(x)
output.shape

torch.Size([501, 2, 5])

___
RODAN

In [84]:
# from collections import namedtuple
# DEFAULTCONFIG = dict(
#     vocab=["<PAD>", "A", "C", "G", "T"],
#     activation_layer="mish", # options: mish, swish, relu, gelu
#     sqex_activation="mish", # options: mish, swish, relu, gelu
#     dropout=0.1,
#     sqex_reduction=32
# )

# Config=namedtuple("CONFIG",["vocab", "activation", "sqex_activation", "dropout", "sqex_reduction"])

# rodan = Rodan(config=Config(*DEFAULTCONFIG.values()))
# rodan.eval()

In [85]:
# x,y, input_len, target_len = next(loader)
# print("Input [batch size, channels, length]:" , x.shape)
# print("Input [batch size, length]:" , y.shape)

# output_rodan = rodan(x)
# print("output rodan:", output_rodan.shape)

In [86]:
len(dataloader_train), len(dataloader_train.dataset)

(13, 200)

___
## Training

In [87]:
def train(dataloader, model, loss_fn, optimizer, epoch: Optional[int]):
    """
    training function
    """
    size = len(dataloader.dataset) # number of datapoints in the dataset
    n_batches = len(dataloader)    # number of batches
    model.train() # set model in training mode

    with tqdm(total=n_batches) as pbar:
        
        for batch, (X,y, input_len, target_len) in enumerate(dataloader):
            
            # Description for progress bar
            if epoch:
                pbar.set_description(f"Epoch: {epoch} | batch: {batch}/{n_batches}")
            else:
                pbar.set_description(f"batch {batch}/{n_batches}")
            
                
            X, y, input_len, target_len = X.to(device), y.to(device), input_len.to(device), target_len.to(device)

            # Compute prediction error
            pred = model(X)
            loss = loss_fn(pred, y, input_lengths=input_len, target_lengths=target_len)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch % 100 == 0: 
                loss, current = loss.item(), (batch + 1) * len(X)
                # print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")

            # update progress bar
            pbar.update(1)


In [88]:
def test(dataloader, model, loss_fn, model_metadata=None):
    """
    test function

    Accuracy is measured here
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval() # let the model know that is in evaluation model
    test_loss, correct = 0, 0
    with torch.no_grad():
        for batch, (X,y, input_len, target_len) in enumerate(dataloader):
            X, y, input_len, target_len = X.to(device), y.to(device), input_len.to(device), target_len.to(device)

            # Compute prediction error
            pred = model(X)
            test_loss += loss_fn(pred, y, input_lengths=input_len, target_lengths=target_len).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [89]:
# architecture
model = simplenet
# simplenet params
simplenet_output_len = 501
rodan_output_len = 420 

# params
epochs = 5

# dataset
dataset_train = DatasetONT(recfile="../data/subsample_train.hdf5", output_network_len=simplenet_output_len)
dataset_val   = DatasetONT(recfile="../data/subsample_val.hdf5", output_network_len=simplenet_output_len)

dataloader_train = DataLoader(dataset_val, batch_size=16, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=16, shuffle=False)

# loss function and optimizer
loss_fn = nn.CTCLoss() #ctc_label_smoothing # 
optimizer = torch.optim.SGD(model.parameters(), lr = 1e-3)

In [90]:
loss_fn = nn.CTCLoss() #ctc_label_smoothing # 

for batch, (X,y, input_len, target_len) in enumerate(dataloader_train):
    X, y, input_len, target_len = X.to(device), y.to(device), input_len.to(device), target_len.to(device)

    # Compute prediction error
    pred = model(X)
    loss = loss_fn(pred, y, input_lengths=input_len, target_lengths=target_len)

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [91]:
pred.shape, y.shape, input_len

(torch.Size([501, 8, 5]),
 torch.Size([8, 271]),
 tensor([501, 501, 501, 501, 501, 501, 501, 501]))

In [92]:
torch.from_numpy(np.array(561))

tensor(561)

In [93]:
# run training 
for t in range(epochs):
    # print(f"Epoch {t+1}\n-------------------------------")
    epoch = t+1
    train(dataloader_train, model=model, loss_fn=loss_fn, optimizer=optimizer, epoch=epoch)
    # test(dataloader_val, model=model, loss_fn=loss_fn)
print("Done!")

Epoch: 1| batch: 12/13: 100%|██████████| 13/13 [00:08<00:00,  1.54it/s]
Epoch: 2| batch: 12/13: 100%|██████████| 13/13 [00:08<00:00,  1.45it/s]
Epoch: 3| batch: 12/13: 100%|██████████| 13/13 [00:09<00:00,  1.32it/s]
Epoch: 4| batch: 12/13: 100%|██████████| 13/13 [00:08<00:00,  1.51it/s]
Epoch: 5| batch: 12/13: 100%|██████████| 13/13 [00:08<00:00,  1.62it/s]

Done!





In [23]:
# rodan(x)
x.shape

torch.Size([2, 1, 4096])

___
## basecalling 


In [None]:
import numpy as np
from fast_ctc_decode import beam_search, viterbi_search
alphabet = "NACGT"

In [None]:
posteriors = output_rodan[:,1,:].detach().numpy()
seq, path = viterbi_search(posteriors, alphabet)
seq

'T'

In [None]:

posteriors = np.random.rand(100, len(alphabet)).astype(np.float32)
posteriors

In [None]:
seq, path = viterbi_search(posteriors, alphabet)

In [None]:
seq, len(seq)


('GATCTCTATATGTGTATCACAGCAGTCATCTCATCGACGCACTCACT', 47)

In [None]:
path

[0,
 2,
 3,
 6,
 9,
 14,
 16,
 17,
 18,
 20,
 21,
 22,
 23,
 24,
 25,
 28,
 29,
 32,
 35,
 37,
 40,
 41,
 43,
 46,
 48,
 50,
 52,
 53,
 54,
 56,
 58,
 59,
 60,
 61,
 64,
 65,
 66,
 68,
 70,
 72,
 75,
 78,
 82,
 83,
 84,
 85,
 86,
 88,
 90,
 91,
 93,
 94,
 96,
 97,
 98,
 99]

In [None]:
seq, path = beam_search(posteriors, alphabet, beam_size=5, beam_cut_threshold=0.1)
seq, len(seq)

('GATCTCTATATGTGTATCACAGCAGTCATCTCATCGACGCACTCACT', 47)