In [1]:
import sys
from pathlib import Path

# add parent folder to the path
module_path = str(Path.cwd().parents[0])
if module_path not in sys.path:
    sys.path.append(module_path)

___

In [41]:
# builtin libraries
import re
import logging
from typing import Union, Optional
from pathlib import Path
from collections import namedtuple, defaultdict
from functools import partial

# secondary libraries
import numpy as np
import parasail
import torch
import torch.nn as nn
from torch.utils.data import DataLoader # load batches to the network
from fast_ctc_decode import beam_search, viterbi_search
from tqdm import tqdm

# feito 
# from basecaller_tester import BasecallerTester as Tester
from feito.models import SimpleNet, Rodan
from feito.dataloaders import DatasetONT, DatasetBasecalling
from feito.api.tester import BasecallerTester
# ---- 


In [53]:
MODEL="Rodan"
PATH_CHECKPOINT="../output-rodan/training/checkpoints/Rodan-epoch29.pt"
BATCH_SIZE=8
NUM_WORKERS=4
device=torch.device("cpu")

model=eval(f"{MODEL}()")
model.to(device)
if device.type == "cpu":
    model.load_state_dict(torch.load(PATH_CHECKPOINT, map_location=torch.device('cpu')))
else: 
    model.load_state_dict(torch.load(PATH_CHECKPOINT))

model_output_len=model.output_len
model.eval()

CONFIG(vocab=['<PAD>', 'A', 'C', 'G', 'T'], activation='mish', sqex_activation='mish', dropout=0.1, sqex_reduction=32)
Activation Function is: mish
Activation Function is: mish


Rodan(
  (convlayers): Sequential(
    (conv0): ConvBlockRodan(
      (conv): Conv1d(1, 256, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
      (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): Mish()
      (sqex): SqueezeExcite(
        (avg): AdaptiveAvgPool1d(output_size=1)
        (fc1): Linear(in_features=256, out_features=32, bias=True)
        (activation): Mish()
        (fc2): Linear(in_features=32, out_features=256, bias=True)
        (sigmoid): Sigmoid()
      )
    )
    (conv1): ConvBlockRodan(
      (depthwise): Conv1d(256, 256, kernel_size=(10,), stride=(1,), padding=(5,), groups=256, bias=False)
      (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): Mish()
      (sqex): SqueezeExcite(
        (avg): AdaptiveAvgPool1d(output_size=1)
        (fc1): Linear(in_features=256, out_features=32, bias=True)
        (activation): Mish()
        (fc2): Linear(in_feat

### 1. Try with hdf5 files


In [54]:
PATH_TEST="../data/subsample_val.hdf5"

# dataset
dataset_test = DatasetONT(recfile=PATH_TEST, output_network_len=model_output_len)
dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

In [55]:
X,y, input_lens , target_lens  = next(iter(dataloader_test))

In [115]:
X

tensor([[[-0.8231, -0.8631, -0.7259,  ..., -1.0803, -1.1375, -1.1318]],

        [[-1.1146, -1.1318, -1.0918,  ..., -0.3658, -0.1658, -0.2801]],

        [[-0.4973, -0.2401, -0.3144,  ...,  0.4744,  0.4573,  0.6345]],

        ...,

        [[-0.4417, -0.7461, -0.6864,  ..., -0.3880, -0.2865, -0.4119]],

        [[-0.2865, -0.1492, -0.4119,  ..., -0.0478,  0.2686,  0.2328]],

        [[ 0.1970,  0.2089,  0.1552,  ..., -0.3820, -0.2149,  0.0716]]])

In [56]:
X.shape, y.shape , input_lens.shape, target_lens.shape


(torch.Size([8, 1, 4096]),
 torch.Size([8, 271]),
 torch.Size([8]),
 torch.Size([8]))

In [57]:
pred = model(X)

In [58]:
pred

tensor([[[-6.0028e-02, -4.3367e+00, -4.6725e+00, -3.4958e+00, -5.2015e+00],
         [-1.0876e-02, -5.4211e+00, -6.4891e+00, -5.6675e+00, -6.5578e+00],
         [-1.7277e-02, -4.7086e+00, -6.0435e+00, -6.7250e+00, -5.3955e+00],
         ...,
         [-3.3899e-02, -3.9264e+00, -5.9860e+00, -4.9651e+00, -5.4905e+00],
         [-2.0324e-02, -5.5067e+00, -5.9351e+00, -4.4076e+00, -6.7003e+00],
         [-2.4959e-02, -4.8906e+00, -5.2342e+00, -4.7601e+00, -5.7331e+00]],

        [[-6.1985e-02, -4.3117e+00, -4.7206e+00, -3.4071e+00, -5.3717e+00],
         [-1.2549e-02, -5.2570e+00, -6.3048e+00, -5.5417e+00, -6.4945e+00],
         [-1.7589e-02, -4.5786e+00, -5.9410e+00, -6.5581e+00, -5.7705e+00],
         ...,
         [-5.4407e-02, -3.2913e+00, -6.2717e+00, -4.5012e+00, -5.8914e+00],
         [-2.2678e-02, -5.8233e+00, -5.5463e+00, -4.2953e+00, -6.2500e+00],
         [-2.8894e-02, -4.5379e+00, -5.1566e+00, -4.7875e+00, -5.6021e+00]],

        [[-7.3424e-02, -4.4914e+00, -3.9847e+00, -3.4179

In [62]:
pred.min(), pred.max()

(tensor(-51.3859, grad_fn=<MinBackward1>), tensor(0., grad_fn=<MaxBackward1>))

In [92]:
torch.softmax(pred, dim=-1)

tensor([[[9.4174e-01, 1.3080e-02, 9.3486e-03, 3.0325e-02, 5.5083e-03],
         [9.8918e-01, 4.4222e-03, 1.5199e-03, 3.4564e-03, 1.4190e-03],
         [9.8287e-01, 9.0176e-03, 2.3733e-03, 1.2005e-03, 4.5370e-03],
         ...,
         [9.6667e-01, 1.9714e-02, 2.5137e-03, 6.9770e-03, 4.1258e-03],
         [9.7988e-01, 4.0596e-03, 2.6448e-03, 1.2184e-02, 1.2305e-03],
         [9.7535e-01, 7.5169e-03, 5.3311e-03, 8.5650e-03, 3.2371e-03]],

        [[9.3990e-01, 1.3411e-02, 8.9102e-03, 3.3136e-02, 4.6463e-03],
         [9.8753e-01, 5.2109e-03, 1.8276e-03, 3.9199e-03, 1.5118e-03],
         [9.8256e-01, 1.0269e-02, 2.6293e-03, 1.4186e-03, 3.1182e-03],
         ...,
         [9.4705e-01, 3.7206e-02, 1.8890e-03, 1.1096e-02, 2.7630e-03],
         [9.7758e-01, 2.9578e-03, 3.9019e-03, 1.3633e-02, 1.9305e-03],
         [9.7152e-01, 1.0696e-02, 5.7614e-03, 8.3337e-03, 3.6902e-03]],

        [[9.2921e-01, 1.1205e-02, 1.8598e-02, 3.2781e-02, 8.2091e-03],
         [9.8129e-01, 7.5825e-03, 2.9351e-03,

In [117]:
y[0]

tensor([2, 4, 3, 2, 4, 3, 2, 4, 1, 4, 1, 3, 2, 1, 3, 2, 3, 3, 2, 3, 3, 2, 3, 2,
        2, 2, 2, 4, 3, 2, 3, 2, 4, 2, 3, 4, 2, 4, 4, 4, 3, 3, 1, 2, 2, 1, 2, 3,
        1, 2, 2, 4, 1, 1, 3, 1, 2, 4, 3, 4, 3, 3, 4, 1, 2, 2, 4, 4, 1, 2, 4, 4,
        2, 4, 3, 2, 4, 4, 2, 1, 1, 2, 4, 4, 2, 1, 3, 2, 3, 3, 2, 4, 4, 2, 3, 4,
        2, 1, 1, 3, 3, 1, 3, 1, 4, 3, 2, 4, 4, 4, 4, 4, 1, 3, 2, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0])

In [118]:
torch.softmax(pred, dim=-1).argmax(dim=2)[:,0]

tensor([0, 0, 0, 0, 0, 3, 3, 0, 0, 0, 0, 0, 4, 4, 0, 4, 3, 3, 0, 0, 0, 0, 0, 2,
        3, 3, 0, 0, 0, 0, 4, 4, 0, 4, 1, 1, 0, 0, 0, 0, 4, 1, 1, 3, 3, 0, 0, 0,
        0, 2, 4, 1, 1, 3, 3, 3, 0, 0, 0, 2, 2, 3, 3, 0, 0, 0, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 0, 4, 4, 3, 3, 0, 0, 0, 0, 0, 4, 3,
        3, 3, 0, 0, 0, 0, 0, 4, 4, 4, 0, 0, 0, 0, 0, 0, 4, 4, 4, 4, 3, 0, 0, 0,
        0, 0, 0, 4, 4, 0, 0, 0, 0, 4, 4, 0, 0, 4, 4, 3, 3, 0, 0, 0, 3, 3, 0, 0,
        0, 0, 1, 2, 2, 2, 0, 0, 4, 3, 3, 0, 0, 0, 0, 3, 0, 0, 0, 2, 2, 2, 2, 2,
        2, 0, 0, 0, 0, 0, 0, 4, 4, 1, 1, 1, 0, 0, 1, 1, 3, 3, 3, 0, 0, 0, 3, 2,
        2, 0, 2, 4, 3, 3, 0, 0, 0, 4, 4, 3, 3, 3, 0, 0, 3, 3, 0, 4, 4, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 2, 2, 1, 1, 1, 0, 0, 0, 0, 2, 2, 2, 2, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 0, 4, 3, 3, 0, 0, 0, 0, 2, 2, 0, 0, 0, 2,
        2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 0, 0, 1, 1, 3, 3, 0, 0, 0, 2, 2,
        3, 3, 3, 0, 0, 1, 1, 1, 0, 2, 0,

### 2. Try with fast5 files



In [85]:
dataset_basecalling = DatasetBasecalling(
    ["/projects5/basecalling-jorge/basecalling/data/RODAN/test/mouse-dataset/0/0a0bf68b-3b64-4fc6-ba34-d853db589f4b.fast5",
     "/projects5/basecalling-jorge/basecalling/data/RODAN/test/mouse-dataset/0/0a8787dc-a4b9-45da-b4e0-8711ec36897e.fast5"
     ], path_save_index="../output/basecalling/index.csv")
basecalling_dataloader = DataLoader(dataset_basecalling, batch_size=BATCH_SIZE, shuffle=False)

Creatind Index for reads:   0%|          | 0/2 [00:00<?, ?it/s]

Creatind Index for reads: 100%|██████████| 2/2 [00:00<00:00,  3.38it/s]


In [86]:
X = next(iter(basecalling_dataloader))
X.shape

torch.Size([8, 1, 4096])

In [87]:
predb = model(X)
predb

tensor([[[-2.3039e-03, -7.0188e+00, -8.3547e+00, -9.4399e+00, -6.8200e+00],
         [-6.3417e-05, -1.1120e+01, -1.4007e+01, -1.2508e+01, -1.0030e+01],
         [-6.0873e-02, -4.1004e+00, -4.5835e+00, -4.5066e+00, -3.8520e+00],
         ...,
         [-5.4221e-03, -6.8140e+00, -6.5006e+00, -8.5874e+00, -5.9445e+00],
         [-4.1723e-06, -1.3853e+01, -1.6901e+01, -1.5875e+01, -1.2671e+01],
         [-1.1572e-01, -3.0363e+00, -3.9706e+00, -5.1144e+00, -3.3134e+00]],

        [[-1.6385e-03, -7.5275e+00, -8.5046e+00, -9.7650e+00, -7.0832e+00],
         [-5.3643e-05, -1.1421e+01, -1.4264e+01, -1.2837e+01, -1.0140e+01],
         [-9.0112e-02, -4.1333e+00, -3.9173e+00, -4.9427e+00, -3.1440e+00],
         ...,
         [-5.5851e-03, -6.5755e+00, -6.5116e+00, -8.5503e+00, -5.9932e+00],
         [-4.6492e-06, -1.3574e+01, -1.7253e+01, -1.5772e+01, -1.2639e+01],
         [-1.3911e-01, -2.8214e+00, -4.2555e+00, -4.5428e+00, -3.0897e+00]],

        [[-1.1274e-03, -8.0260e+00, -8.6139e+00, -1.0201

In [88]:
predb.min(), predb.max()


(tensor(-29.9956, grad_fn=<MinBackward1>), tensor(0., grad_fn=<MaxBackward1>))

In [106]:
torch.softmax(predb[:,0,:], dim=-1)# .sum(axis=1)

tensor([[9.9770e-01, 8.9490e-04, 2.3528e-04, 7.9492e-05, 1.0917e-03],
        [9.9836e-01, 5.3806e-04, 2.0254e-04, 5.7429e-05, 8.3907e-04],
        [9.9887e-01, 3.2685e-04, 1.8156e-04, 3.7140e-05, 5.8124e-04],
        ...,
        [9.7050e-01, 8.3990e-03, 4.5087e-03, 2.3445e-03, 1.4250e-02],
        [9.5921e-01, 9.4525e-03, 5.5158e-03, 2.6470e-03, 2.3177e-02],
        [9.4715e-01, 1.0992e-02, 6.5677e-03, 3.1322e-03, 3.2159e-02]],
       grad_fn=<SoftmaxBackward0>)

In [103]:
torch.softmax(predb, dim=-1).argmax(dim=2)[:,3] #.sum(axis=2)
# torch.sum(
#     torch.softmax(predb, dim=-1)
#           , dim=2)

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 4, 3, 3, 3, 0, 0, 0, 3, 3, 3, 3, 0, 4,
        4, 0, 2, 3, 3, 0, 0, 0, 0, 0, 2, 2, 3, 3, 0, 0, 0, 0, 4, 0, 0, 3, 0, 0,
        0, 2, 2, 0, 2, 2, 3, 3, 0, 0, 3, 2, 2, 2, 2, 3, 0, 0, 0, 1, 1, 0, 3, 3,
        3, 3, 3, 3, 3, 3, 0, 0, 0, 3, 0, 0, 2, 2, 3, 0, 0, 0, 0, 4, 0, 4, 1, 3,
        3, 0, 0, 0, 1, 0, 0, 4, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        4, 4, 4, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 3, 3, 0, 0,
        0, 0, 4, 4, 0, 1, 1, 0, 0, 0, 4, 0, 0, 4, 3, 3, 0, 0, 2, 2, 0, 0, 0, 0,
        0, 0, 4, 4, 0, 2, 2, 1, 0, 0, 0, 0, 1, 1, 1, 3, 3, 0, 0, 0, 0, 2, 2, 2,
        0, 0, 2, 2, 3, 3, 0, 0, 0, 0, 0, 0, 0, 1, 1, 3, 3, 0, 0, 0, 0, 0, 0, 4,
        4, 0, 4, 4, 1, 1, 1, 1, 1, 1, 1, 2, 0, 0, 2, 2, 3, 3, 0, 1, 1, 1, 3, 3,
        3, 3, 3, 3, 2, 2, 2, 0, 0, 2, 2, 3, 3, 3, 0, 1, 1, 2, 2, 4, 0, 0, 4, 3,
        3, 0, 0, 0, 0, 2, 2, 1, 3, 3, 0, 0, 0, 0, 0, 0, 2, 3, 3, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 4, 4, 1, 1, 0, 0,

In [84]:
np.transpose(np.argmax(predb.detach().numpy(), -1), (1, 0))

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [114]:
item = 0
# signal = torch.softmax(predb[:,item,:], dim=-1).detach().numpy()
signal = predb[:,item,:].detach().numpy()
seq , path = beam_search( signal, "NACGT", beam_size=5, beam_cut_threshold=0.1)
seq

RuntimeError: Ran out of search space (beam_cut_threshold too high)

___

In [None]:

# types
_Path = Union[Path,str]

class BasecallerTester:
    
    split_cigar = re.compile(r"(?P<len>\d+)(?P<op>\D+)")
    
    def __init__(self, model, device, test_loader, path_fasta: Optional[_Path] = None, rna: bool = True, use_viterbi = True):
        self.model  = model.to(device) # model with pretrained weigths loaded
        self.device = device
        self.test_loader = test_loader # load signals
        self.batch_size  = test_loader.batch_size
        self.path_fasta  = path_fasta # to save basecalled raw-reads (if not None)
        self.rna = rna 
        self.alphabet    = "NACGU" if rna else "NACGT"
        self.use_viterbi = use_viterbi
        self.search_algo = viterbi_search if use_viterbi else beam_search

        # set evaluation/inference mode
        self.model.eval()

        # map integers to characters in the alphabet
        self.int2char = {i:c for i,c in enumerate(self.alphabet.replace("N",""), start=1)}

    def __call__(self, return_basecalled_signals: bool=True):
        print("Que me dice")
        # inference
        accuracies, basecalled_signals = self.accuracy_all_dataset()

        accuracy = np.array(accuracies).mean()

        if self.path_fasta:
            # create parent directory if it does not exists
            Path(self.path_fasta).parent.mkdir(exist_ok=True, parents=True)

            # save basecalled signals to a fasta file
            
        
        if return_basecalled_signals:
            return accuracy, basecalled_signals
        
        return accuracy
    
    def basecall_one_batch(self, X):
        "Return basecalled signals in the chosen alphabet"

        preds  = self.model(X) # preds shape: (len-signal, item, size-alphabet)
        basecalled_signals = list(
            self.signal_to_read(signal=preds[:,item,:].detach().numpy(), use_viterbi=self.use_viterbi, rna=self.rna) 
            for item in range(preds.shape[1])
            )

        return basecalled_signals

    def label_to_alphabet(self, label):
        "Map vector of integers to sequence in DNA or RNA alphabet"
        
        return "".join([self.int2char[i] for i in label if i > 0])
    
    def accuracy_one_batch(self, batch):

        X, y, output_len, target_len = (x.to(self.device) for x in batch)

        basecalled_signals = self.basecall_one_batch(X)
        ground_truth = np.apply_along_axis(lambda l: self.label_to_alphabet(l), 1, y.detach().numpy()) 
        accuracy_batch = [self.accuracy(ref=gt, seq=bs) for gt,bs in zip(basecalled_signals, ground_truth)]
        
        return accuracy_batch, basecalled_signals
    
    def accuracy_all_dataset(self,):
        "Returns a list with accuracies and another list with basecalled signals"
        basecalled_signals = []
        accuracies = []
        n_batches=len(self.test_loader)
    
        with tqdm(total=n_batches, leave=True, ncols=100, bar_format='{l_bar}{bar}| [{elapsed}{postfix}]') as progress_bar:

            for n_batch, batch in enumerate(self.test_loader):

                progress_bar.set_description(f"Evaluating | Batch: {n_batch+1}/{n_batches}")
                accuracy_batch, basecalled_signals_batch = self.accuracy_one_batch(batch)
                
                accuracies.extend(accuracy_batch)
                basecalled_signals.extend(basecalled_signals_batch)
                
                # progress_bar.set_postfix(train_loss='%.4f' % current_avg_loss)
                progress_bar.update(1)

        return accuracies, basecalled_signals

    def signal_to_read(self, signal, use_viterbi: bool = True, rna: bool = True):
        "Apply viterbi or beam search to a signal"
        
        if use_viterbi is True:
            seq, path = viterbi_search(signal, self.alphabet) 
        else:
            seq, path = beam_search(signal, self.alphabet, beam_size=5, beam_cut_threshold=0.1)

        return seq

    def accuracy(self, ref, seq, balanced=False, min_coverage=0.0):
        # From https://github.com/nanoporetech/bonito/blob/655feea4bca17feb77957c7f8be5077502292bcf/bonito/util.py#L354
        """
        Calculate the accuracy between `ref` and `seq`
        """
        # alignment = parasail.sw_trace_striped_32(seq, ref, 8, 4, parasail.dnafull) # this crashed, no meaningful error message
        alignment = parasail.sw_trace(seq, ref, 8, 4, parasail.dnafull)
        counts = defaultdict(int)

        q_coverage = len(alignment.traceback.query) / len(seq)
        r_coverage = len(alignment.traceback.ref) / len(ref)

        if r_coverage < min_coverage:
            return 0.0

        _, cigar = self.parasail_to_sam(alignment, seq)

        for count, op  in re.findall(self.split_cigar, cigar):
            counts[op] += int(count)

        if balanced:
            accuracy = (counts['='] - counts['I']) / (counts['='] + counts['X'] + counts['D'])
        else:
            accuracy = counts['='] / (counts['='] + counts['I'] + counts['X'] + counts['D'])
        return accuracy * 100


    def parasail_to_sam(self, result, seq):
        # From https://github.com/nanoporetech/bonito/blob/655feea4bca17feb77957c7f8be5077502292bcf/bonito/util.py#L321
        """
        Extract reference start and sam compatible cigar string.

        :param result: parasail alignment result.
        :param seq: query sequence.

        :returns: reference start coordinate, cigar string.
        """
        cigstr = result.cigar.decode.decode()
        first = re.search(self.split_cigar, cigstr)

        first_count, first_op = first.groups()
        prefix = first.group()
        rstart = result.cigar.beg_ref
        cliplen = result.cigar.beg_query

        clip = '' if cliplen == 0 else '{}S'.format(cliplen)
        if first_op == 'I':
            pre = '{}S'.format(int(first_count) + cliplen)
        elif first_op == 'D':
            pre = clip
            rstart = int(first_count)
        else:
            pre = '{}{}'.format(clip, prefix)

        mid = cigstr[len(prefix):]
        end_clip = len(seq) - result.end_query - 1
        suf = '{}S'.format(end_clip) if end_clip > 0 else ''
        new_cigstr = ''.join((pre, mid, suf))
        return rstart, new_cigstr        

In [50]:
Args=namedtuple("Args", ["path_test", "batch_size", "model", "path_checkpoint", "device"])
args = Args(
"../data/subsample_val.hdf5",
16,
"SimpleNet",
"../output/training/checkpoints/SimpleNet-epoch1.pt",
None,
)
    
PATH_TEST=args.path_test
BATCH_SIZE=args.batch_size
MODEL=args.model
DEVICE=args.device
PATH_CHECKPOINT=args.path_checkpoint

if DEVICE is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else: 
    device = DEVICE
print("Device" , device)

model=eval(f"{MODEL}()")
model.to(device)
if device.type == "cpu":
    model.load_state_dict(torch.load(PATH_CHECKPOINT, map_location=torch.device('cpu')))
else: 
    model.load_state_dict(torch.load(PATH_CHECKPOINT))
model_output_len = model.output_len

# dataset
dataset_test = DatasetONT(recfile=PATH_TEST, output_network_len=model_output_len)
dataloader_test = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=True)

Device cpu


In [51]:
dataloader_test.batch_size

16

In [52]:
device.type

'cpu'

In [53]:
tester=BasecallerTester(
    model=model, 
    device=device,
    test_loader=dataloader_test,
    path_fasta="output/testing/basecalled_reads.fa"
)

In [54]:
# inference
accuracies, basecalled_signals = tester.accuracy_all_dataset()


Evaluating | Batch: 13/13: 100%|███████████████████████████████████████████████████████████| [00:02]


In [55]:
accuracies


[100.0,
 81.81818181818183,
 85.71428571428571,
 77.77777777777779,
 100.0,
 73.33333333333333,
 85.71428571428571,
 64.70588235294117,
 72.72727272727273,
 65.0,
 87.5,
 100.0,
 85.71428571428571,
 73.91304347826086,
 72.72727272727273,
 66.66666666666666,
 78.57142857142857,
 64.70588235294117,
 63.63636363636363,
 66.66666666666666,
 100.0,
 66.66666666666666,
 75.0,
 85.71428571428571,
 70.96774193548387,
 66.66666666666666,
 80.0,
 77.77777777777779,
 77.77777777777779,
 70.0,
 100.0,
 68.75,
 75.0,
 62.06896551724138,
 87.5,
 68.96551724137932,
 61.29032258064516,
 68.75,
 78.57142857142857,
 100.0,
 62.5,
 78.57142857142857,
 60.86956521739131,
 75.0,
 76.47058823529412,
 68.0,
 70.83333333333334,
 87.5,
 72.72727272727273,
 88.88888888888889,
 83.33333333333334,
 65.21739130434783,
 100.0,
 75.0,
 71.42857142857143,
 90.9090909090909,
 86.66666666666667,
 77.77777777777779,
 77.77777777777779,
 70.0,
 87.5,
 81.81818181818183,
 66.66666666666666,
 66.66666666666666,
 80.0,
 58.

In [56]:
basecalled_signals

['AGAGAGAGAAGAGAAGAGAGAGAAG',
 'GAAGAAAGAAGAGAAAAAAAAAGAAAG',
 'AGAAGAGAGAGAGAAGAGAGAGAGAGAGAGAGAGA',
 'GAAAGAGAGAAGAGAAAAAAAGAAAG',
 'GAGGGAGAAGAGAGGAGAGAGAGAGAGAGAG',
 'AGGAGGAGAGGGAAAGAGACGAGAGAGAGAGAGAGAGAAGAGAGGAGAAAAAGAG',
 'AGAGAAAGAGAGAAGAGAAAGAAAGAGAGAAGAGAGAGAGAGA',
 'AGAGACGAGAGAGGGAGAGAGGAAGGAGAGAGAG',
 'AGAGAGAAGAAAGGAAGAGAGAGAAAAG',
 'GAGAGAGGAGAGAGAGAAGAGAGAAGAAAA',
 'GAGAAGAGAAAGAAGAGAGAGAAGAGA',
 'GAGAGAGAGAGGAGAAGAGAAGAAGAGAGAGAAGAAG',
 'GAAAAAAAAAGAGAGAGAGAAAGAGAAGACGAGAGAGGG',
 'AGAAAGAGAGAGAGAGAAGAGAGAAGAGA',
 'GAGAGAAAGAGAGAGAGAAAGAGAAGAGAAGAGAAGAGAGAGAAG',
 'GAAGAGAGAGAAAAGAGAGAGAGAGAAGAGAAGAGAGAGAGAGAAGAGAA',
 'AGAGAGAAGGAGGGAGAAAGAGAGAGAGAGAAAAAAAGAAGAAG',
 'GAGAGAGAGAAAAAGAAGAGAGACGAGAGGAGGGAGAGAG',
 'GAGAGAAAGAGAGAGAAAGAAGAGAAAAAG',
 'AGAAAGAACGAGAAGAAAAA',
 'AAGAGAGAAGAAGAGAGAGAG',
 'AGAGAGAGAAAGAAGAGGAAAGAGAAAGAGGAGAGAAGAGAGA',
 'GAAGAGAGAGAGAGAGAAAGAGAGAGAGAAAGAGAA',
 'GAGAGAAGAGAAGAGAGAGACAGAAAGACGAGAAGAGAGAG',
 'AGAAAGGAGAGAAAGAGAGAAGAGAGAGAAGAGAAGAAGAG'

In [61]:
path_fasta = "../output/test/basecalled_signals.fa"
Path(path_fasta).parent.mkdir(exist_ok=True, parents=True)
with open(path_fasta, "a") as fp:
    for j,read in enumerate(basecalled_signals):
        fp.write(f">signal_{j}\n")
        fp.write(read + "\n")

In [None]:
seq, ref = ("ACGTACGTACGTACGAGCAT","ACGACTACGACTACACACAC")
result = parasail.sw_trace(seq, ref, 8, 4, parasail.dnafull)

In [12]:
rstart, new_cigstr = tester.parasail_to_sam(result, seq)

In [13]:
rstart, new_cigstr

(0, '4S3=1I2=1I5=4S')

In [8]:
batch = next(iter(dataloader_test))
model.eval()
X, y, output_len, target_len = (x.to(device) for x in batch)
preds = model(X)

In [9]:
X.shape, preds.shape, y.shape

(torch.Size([16, 1, 4096]), torch.Size([501, 16, 5]), torch.Size([16, 271]))

In [10]:
batch_signals = preds.detach().numpy()
print(batch_signals.shape)


(501, 16, 5)
