In [1]:
import sys
from pathlib import Path

# add parent folder to the path
module_path = str(Path.cwd().parents[0])
if module_path not in sys.path:
    sys.path.append(module_path)

___

In [2]:
import numpy as np
import torch
import torch.nn as nn
from src.models.base import SimpleNet
from src.models.rodan import Rodan
from src.dataloader import DatasetONT
from torch.utils.data import DataLoader

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
dataset_train = DatasetONT(recfile="../data/subsample_train.hdf5")
dataset_val   = DatasetONT(recfile="../data/subsample_val.hdf5")

In [5]:
loader = iter(DataLoader(dataset_val, batch_size=2, shuffle=True))

In [6]:
x,y = next(loader)
print("Input [batch size, channels, length]:" , x.shape)
print("Input [batch size, length]:" , y.shape)

Input [batch size, channels, length]: torch.Size([2, 1, 4096])
Input [batch size, length]: torch.Size([2, 271])


___
SimpleNet

In [7]:
simplenet = SimpleNet(n_channels = 1, n_classes = 271)
simplenet.eval()

SimpleNet(
  (conv1): Conv1d(1, 20, kernel_size=(5,), stride=(1,))
  (relu1): ReLU()
  (maxpool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(20, 50, kernel_size=(5,), stride=(1,))
  (relu2): ReLU()
  (maxpool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten1): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=51050, out_features=500, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=500, out_features=271, bias=True)
)

In [8]:
output = simplenet(x)
output.shape

torch.Size([2, 271])

___
RODAN

In [9]:
from collections import namedtuple
DEFAULTCONFIG = dict(
    vocab=["<PAD>", "A", "C", "G", "T"],
    activation_layer="mish", # options: mish, swish, relu, gelu
    sqex_activation="mish", # options: mish, swish, relu, gelu
    dropout=0.1,
    sqex_reduction=32
)

Config=namedtuple("CONFIG",["vocab", "activation", "sqex_activation", "dropout", "sqex_reduction"])


In [10]:
rodan = Rodan(config=Config(*DEFAULTCONFIG.values()))

In [11]:
Config(*DEFAULTCONFIG.values())

CONFIG(vocab=['<PAD>', 'A', 'C', 'G', 'T'], activation='mish', sqex_activation='mish', dropout=0.1, sqex_reduction=32)

In [12]:
rodan.eval()

Rodan(
  (convlayers): Sequential(
    (conv0): ConvBlockRodan(
      (conv): Conv1d(1, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): Mish()
      (sqex): SqueezeExcite(
        (avg): AdaptiveAvgPool1d(output_size=1)
        (fc1): Linear(in_features=256, out_features=32, bias=True)
        (activation): Mish()
        (fc2): Linear(in_features=32, out_features=256, bias=True)
        (sigmoid): Sigmoid()
      )
    )
    (conv1): ConvBlockRodan(
      (depthwise): Conv1d(256, 256, kernel_size=(10,), stride=(1,), padding=(5,), groups=256, bias=False)
      (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): Mish()
      (sqex): SqueezeExcite(
        (avg): AdaptiveAvgPool1d(output_size=1)
        (fc1): Linear(in_features=256, out_features=32, bias=True)
        (activation): Mish()
        (fc2): Linear(in_feat

In [15]:
x,y = next(loader)
print("Input [batch size, channels, length]:" , x.shape)
print("Input [batch size, length]:" , y.shape)

output_rodan = rodan(x)
print("output rodan:", output_rodan.shape)

Input [batch size, channels, length]: torch.Size([2, 1, 4096])
Input [batch size, length]: torch.Size([2, 271])
output rodan: torch.Size([420, 2, 5])


In [16]:
output_rodan[:,1,:]

tensor([[-1.5983, -1.6233, -1.6262, -1.6126, -1.5873],
        [-1.5983, -1.6233, -1.6262, -1.6126, -1.5873],
        [-1.5983, -1.6233, -1.6262, -1.6126, -1.5873],
        ...,
        [-1.5983, -1.6233, -1.6262, -1.6126, -1.5873],
        [-1.5983, -1.6233, -1.6262, -1.6126, -1.5873],
        [-1.5983, -1.6233, -1.6262, -1.6126, -1.5873]],
       grad_fn=<SliceBackward0>)

___
## basecalling 


In [17]:
import numpy as np
from fast_ctc_decode import beam_search, viterbi_search
alphabet = "NACGT"
posteriors = np.random.rand(100, len(alphabet)).astype(np.float32)
posteriors

array([[5.37786186e-01, 3.06539237e-01, 2.38115355e-01, 8.75939310e-01,
        8.38041663e-01],
       [7.73312628e-01, 6.22576654e-01, 2.04668492e-01, 4.43479717e-01,
        9.15474519e-02],
       [3.43880862e-01, 7.86546409e-01, 1.32471815e-01, 8.20510030e-01,
        3.81418914e-02],
       [1.73744440e-01, 7.93899536e-01, 4.90396172e-01, 1.80202708e-01,
        4.00521338e-01],
       [2.82824725e-01, 7.08019912e-01, 4.88876671e-01, 2.18415156e-01,
        3.83160233e-01],
       [1.30367368e-01, 9.97157633e-01, 2.63530105e-01, 2.52526820e-01,
        7.39629745e-01],
       [6.09270334e-01, 4.79452312e-01, 8.21151674e-01, 9.36129928e-01,
        9.30963993e-01],
       [9.81746078e-01, 3.63575488e-01, 8.97481799e-01, 1.72297060e-01,
        4.54032719e-01],
       [9.59204018e-01, 4.28185672e-01, 3.48673701e-01, 6.61315203e-01,
        5.36692321e-01],
       [6.28796220e-01, 2.18091756e-02, 7.18856275e-01, 9.70183536e-02,
        5.35664201e-01],
       [3.61510813e-01, 1.6897

In [19]:
seq, path = viterbi_search(posteriors, alphabet)

In [24]:
seq, len(seq)


('GATCTCTATATGTGTATCACAGCAGTCATCTCATCGACGCACTCACT', 47)

In [21]:
path

[0,
 2,
 3,
 6,
 9,
 14,
 16,
 17,
 18,
 20,
 21,
 22,
 23,
 24,
 25,
 28,
 29,
 32,
 35,
 37,
 40,
 41,
 43,
 46,
 48,
 50,
 52,
 53,
 54,
 56,
 58,
 59,
 60,
 61,
 64,
 65,
 66,
 68,
 70,
 72,
 75,
 78,
 82,
 83,
 84,
 85,
 86,
 88,
 90,
 91,
 93,
 94,
 96,
 97,
 98,
 99]

In [25]:
seq, path = beam_search(posteriors, alphabet, beam_size=5, beam_cut_threshold=0.1)
seq, len(seq)

('GATCTCTATATGTGTATCACAGCAGTCATCTCATCGACGCACTCACT', 47)