In [1]:
import os 
import sys

sys.path.append('/hpc/compgen/users/mpages/babe/src')
sys.path.append('/hpc/compgen/users/mpages/babe/models')

from gridmodel import GridAnalysisModel
from normalization import normalize_signal_from_read_data, med_mad
from read import read_fast5, list_reads_ids
import uuid


import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Data

In [2]:
class BaseFast5Dataset(Dataset):
    """Base dataset class that iterates over fast5 files for basecalling
    """

    def __init__(self, 
        data_dir = None, 
        fast5_list = None, 
        recursive = True, 
        buffer_size = 100,
        window_size = 2000,
        window_overlap = 400,
        trim_signal = True,
        ):
        """
        Args:
            data_dir (str): dir where the fast5 file
            fast5_list (str): file with a list of files to be processed
            recursive (bool): if the data_dir should be searched recursively
            buffer_size (int): number of fast5 files to read 

        data_dir and fast5_list are esclusive
        """
        
        super(BaseFast5Dataset, self).__init__()
    
        self.data_dir = data_dir
        self.recursive = recursive
        self.buffer_size = buffer_size
        self.window_size = window_size
        self.window_overlap = window_overlap
        self.trim_signal = trim_signal

        if fast5_list is None:
            self.data_files = self.find_all_fast5_files()
        else:
            self.data_files = self.read_fast5_list(fast5_list)
    
    def __len__(self):
        return len(self.data_files)
    
    def __getitem__(self, idx):
        return self.process_reads(self.data_files[idx])

        
    def find_all_fast5_files(self):
        """Find all fast5 files in a dir recursively
        """
        # find all the files that we have to process
        files_list = list()
        for path in Path(self.data_dir).rglob('*.fast5'):
            files_list.append(str(path))
        files_list = self.buffer_list(files_list, self.buffer_size)
        return files_list

    def read_fast5_list(self, fast5_list):
        """Read a text file with the reads to be processed
        """

        files_list = list()
        with open(fast5_list, 'r') as f:
            for line in f:
                files_list.append(line.strip('\n'))
        files_list = self.buffer_list(files_list, self.buffer_size)
        return files_list

    def buffer_list(self, files_list, buffer_size):
        buffered_list = list()
        for i in range(0, len(files_list), buffer_size):
            buffered_list.append(files_list[i:i+buffer_size])
        return buffered_list

    def trim(self, signal, window_size=40, threshold_factor=2.4, min_elements=3):
        """

        from: https://github.com/nanoporetech/bonito/blob/master/bonito/fast5.py
        """

        min_trim = 10
        signal = signal[min_trim:]

        med, mad = med_mad(signal[-(window_size*100):])

        threshold = med + mad * threshold_factor
        num_windows = len(signal) // window_size

        seen_peak = False

        for pos in range(num_windows):
            start = pos * window_size
            end = start + window_size
            window = signal[start:end]
            if len(window[window > threshold]) > min_elements or seen_peak:
                seen_peak = True
                if window[-1] > threshold:
                    continue
                return min(end + min_trim, len(signal)), len(signal)

        return min_trim, len(signal)

    def chunk(self, signal, chunksize, overlap):
        """
        Convert a read into overlapping chunks before calling

        The first N datapoints will be cut out so that the window ends perfectly
        with the number of datapoints of the read.
        """
        if isinstance(signal, np.ndarray):
            signal = torch.from_numpy(signal)

        T = signal.shape[0]
        if chunksize == 0:
            chunks = signal[None, :]
        elif T < chunksize:
            chunks = torch.nn.functional.pad(signal, (chunksize - T, 0))[None, :]
        else:
            stub = (T - overlap) % (chunksize - overlap)
            chunks = signal[stub:].unfold(0, chunksize, chunksize - overlap)
        
        return chunks.unsqueeze(1)
    
    def normalize(self, read_data):
        return normalize_signal_from_read_data(read_data)

    def process_reads(self, read_list):
        """
        Args:
            read_list (list): list of files to be processed

        Returns:
            two arrays, the first one with the normalzized chunked data,
            the second one with the read ids of each chunk.
        """
        chunks_list = list()
        id_list = list()
        l_list = list()

        for read_file in read_list:
            reads_data = read_fast5(read_file)

            for read_id in reads_data.keys():
                read_data = reads_data[read_id]
                norm_signal = self.normalize(read_data)

                if self.trim_signal:
                    trim, _ = self.trim(norm_signal[:8000])
                    norm_signal = norm_signal[trim:]

                chunks = self.chunk(norm_signal, self.window_size, self.window_overlap)
                num_chunks = chunks.shape[0]
                
                uuid_fields = uuid.UUID(read_id).fields
                id_arr = np.zeros((num_chunks, 6), dtype = np.int)
                for i, uf in enumerate(uuid_fields):
                    id_arr[:, i] = uf
                
                id_list.append(id_arr)
                l_list.append(np.full((num_chunks,), len(norm_signal)))
                chunks_list.append(chunks)
        
        out = {
            'x': torch.vstack(chunks_list).squeeze(1), 
            'id': np.vstack(id_list),
            'len': np.concatenate(l_list)
        }
        return out

            


In [3]:
files_list = '/hpc/compgen/users/mpages/babe/doc/splits/human_task_test_reads.txt'

In [4]:
DS = BaseFast5Dataset(fast5_list= files_list, buffer_size = 10)

# Model

In [5]:
model_dir = '/hpc/compgen/projects/nanoxog/babe/analysis/mpages/models/grid_analysis'

In [6]:
task = 'human'
model_name = 'bonito_bonitorev_ctc_True_2000'

config = model_name.split('_')

cnn_type = config[0]
encoder_type = config[1]
decoder_type = config[2]
use_connector = config[3]
window_size = config[4]

In [7]:
# load model
log = df = pd.read_csv(os.path.join(model_dir, task, model_name, 'train.log'))
log = log[log['checkpoint'] == 'yes']
best_step = log['step'].iloc[np.argmax(log['metric.accuracy.val'])]
checkpoint_file = os.path.join(model_dir, task, model_name, 'checkpoints', 'checkpoint_' + str(best_step) + '.pt')

use_amp = True
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
model = GridAnalysisModel(
    cnn_type = cnn_type, 
    encoder_type = encoder_type, 
    decoder_type = decoder_type,
    use_connector = use_connector,
    device = device,
    dataloader_train = None, 
    dataloader_validation = None, 
    scaler = scaler,
    use_amp = use_amp
)

state_dict = torch.load(checkpoint_file)
model.load_state_dict(state_dict['model_state'])
model = model.to(device)

In [8]:
class BaseBasecaller:
    """A base Basecaller class that is used to basecall complete reads
    """
    
    def __init__(self, dataset, model, batch_size, n_cores = 1):
        """
        Args:
            model (nn.Module): a model that has the following methods:
                predict, decode
            chunk_size (int): length of the chunks that a read will be divided into
            overlap (int): amount of overlap between consecutive chunks
            batch_size (int): batch size to forward through the network
        """
        
        self.dataset = DataLoader(dataset, batch_size=1, shuffle=False, num_workers = n_cores)
        self.model = model
        self.batch_size = batch_size

    def basecall(self, verbose = True):

        # iterate over the data
        for batch in tqdm(self.dataset, disable = not verbose):
            
            x = batch['x']
            l = x.shape[0]
            ss = torch.arange(0, l, self.batch_size)
            nn = ss + self.batch_size

            p_list = list()
            for s, n in zip(ss, nn):
                p = self.model.predict_step(x[s:n, :])
                p_list.append(p)

            return p_list
    
    
    def stich(self, chunks, method, *args, **kwargs):
        """
        Stitch chunks together with a given overlap
        
        Args:
            chunks (tensor): predictions with shape [samples, length, classes]
        """

        if method == 'stride':
            return self.stich_by_stride(chunks, *args, **kwargs)
        else:
            raise NotImplementedError()
    
    def stitch_by_stride(self, chunks, chunksize, overlap, length, stride, reverse=False):
        """
        Stitch chunks together with a given overlap
        
        This works by calculating what the overlap should be between two outputed
        chunks from the network based on the stride and overlap of the inital chunks.
        The overlap section is divided in half and the outer parts of the overlap
        are discarded and the chunks are concatenated. There is no alignment.
        
        Chunk1: AAAAAAAAAAAAAABBBBBCCCCC
        Chunk2:               DDDDDEEEEEFFFFFFFFFFFFFF
        Result: AAAAAAAAAAAAAABBBBBEEEEEFFFFFFFFFFFFFF
        
        Args:
            chunks (tensor): predictions with shape [samples, length, classes]
            chunk_size (int): initial size of the chunks
            overlap (int): initial overlap of the chunks
            length (int): original length of the signal
            stride (int): stride of the model
            reverse (bool): if the chunks are in reverse order
            
        Copied from https://github.com/nanoporetech/bonito
        """
        if chunks.shape[0] == 1: return chunks.squeeze(0)

        semi_overlap = overlap // 2
        start, end = semi_overlap // stride, (chunksize - semi_overlap) // stride
        stub = (length - overlap) % (chunksize - overlap)
        first_chunk_end = (stub + semi_overlap) // stride if (stub > 0) else end

        if reverse:
            chunks = list(chunks)
            return torch.cat([
                chunks[-1][:-start], *(x[-end:-start] for x in reversed(chunks[1:-1])), chunks[0][-first_chunk_end:]
            ])
        else:
            return torch.cat([
                chunks[0, :first_chunk_end], *chunks[1:-1, start:end], chunks[-1, start:]
            ])

In [9]:
BB = BaseBasecaller(DS, model, 64)

In [10]:
for batch in tqdm(BB.dataset, disable = not True):
    
    x = batch['x'].squeeze(0)
    l = x.shape[0]
    ss = torch.arange(0, l, BB.batch_size)
    nn = ss + BB.batch_size

    p_list = list()
    for s, n in zip(ss, nn):
        p = BB.model.predict_step({'x':x[s:n, :]})
        p_list.append(p)
        
    p = torch.hstack(p_list)

    ids = batch['id'][0]
    ids_arr = np.zeros((ids.shape[0], ), dtype = 'U32')
    for i in range(ids.shape[0]):
        ids_arr[i] = str(uuid.UUID(fields=ids[i].tolist()))

    read_stacks = dict()
    read_lens = dict()
    for id in np.unique(ids_arr):
        w = np.where(ids_arr == id)[0]
        read_stacks[id] = p[:, w, :].permute(1, 0, 2).exp()
        read_lens[id] = batch['len'][0, w[0]].item()
        
    break

  0%|                                                                                                                                                              | 0/2500 [00:23<?, ?it/s]


In [11]:
kid = '4afc223a-939c-4e72-afe7-45801f55'

In [12]:
read_stacks[kid].shape

torch.Size([17, 400, 5])

In [13]:
stiched_p = BB.stitch_by_stride(read_stacks[kid], chunksize = 2000, overlap = 400, length = read_lens[kid], stride = 5, reverse=False)

In [14]:
from fast_ctc_decode import viterbi_search

seq, path = viterbi_search(stiched_p.cpu().numpy(), 'NACGT', qstring = True, qscale = 1.0, qbias = 0.0, collapse_repeats = True)

In [17]:
seq = model.decode(stiched_p.unsqueeze(1).cpu().numpy(), greedy = True, qstring = True, collapse_repeats = True, return_path = True)

In [20]:
len(seq[0][1])

2852

In [21]:
seq[0][0][:2852]

'TAATGGTCCCTTGGATAAATATCCTCAGATTGCTTTTGGGAAAGTTTCCCTAACCACTGCTGTTTACCTGTTTGCTACACTGCAAATCCATTAATTTTTAGTTTTGAAAATGCAGTTTTCTTGAAACCCTGATGCTTCAATCATTCAGCGATCTAGTTTAGACAAGGCCTTCCAAATGGGGGAGGGGCAAGCTGCGGCGGCTTCGGGCTGGATTTGGGGGCAGGAGGGTTAAAGTACCTGAGGGATAAAGCAGTGGTGACATTAATGATACTTTTGGATAAAAAAGAAAAGCCAAGAACAAATACGTAGGGGTTCCCACTCCCTTTTCTTTGATTCCATCAGGGTAGAAACACAGCTCTGTGTGCCTTTGAGATCTTTCACTGGTGCTTCTTTACCCTTTGCTCATCAAGGGCTGACCATGTTTTGTGGGTTTTTGTTCATTTAGATAATCAATCTAAGATTATCTCAGGGCAATGGGATACCCAAGCAATGGTGGTATATGTGCCTATAACTTGATACTGGTGGGTGGGGGTCTGGTAACCCTAACTTGTCCTCTAGCAAGACCTCTGATGGGTTTCATGCTCCCAGAAGGTTTAGTCTCTAGCTTTTCTAGATTGCTTCCAATTCAAACCAACCCCCGTATCTTCTTGCTTTCCAACAGGGCACTCCTAAGCACCCTCTAGAGACTCCTCAATAGAGAAGATTAAAGAATGAGAAATTCAAAAGTTGAAAAAAACTCAAAAGGAATTAGGATGCAGAAGAAGAAAGAAAAAATATTTACTACTAAGCTTTGAATTCTCAATGTCAGCCCTGAAAAGTTATTGCCAGCAAAAAAAAAAGCCCCTTTCTAAGAGACATAAATTTTTGCAAACCCCTATTGCATTTTGATGGGATTGAGAGGCAGACTGTGGGACCAAAGAAGTATTTACTAATCAGTCCCAAATGGGGGGCAGTTCCAACTTTCCCTTGGGGTATGTTTTTCTCTCAATTCTCCTCAAT

In [24]:
seq[0][0][2852:]

"&'&*&)&)'*+-%(%$(,+.*1./.0(-124&((62*1655:6&54++3'/40130.'&%(%(+%%.+0,04)(0+(/1-*00(5033%-32''(+,,++(141/,%2402.41'0576)11&674&,355..'371*3('$(./&&'$(%('++(-0.*'*1+%(&+%%&'**5-4.&)(%-(*)&-*/,750*-+.%.&'(&&)(),37.46,0250,(*)31'0-12&%&&%'4+77)'')+()04).(+-74-2,-28,745*)+)367:+5:8-()('&$.*),44520%)1.*153'&$+((+2+-&,%%%$1,,&)251-)53+()682:.;)260)/0*-(.(((+)'&%'%'%&*,*-0.*''%/'87871(&0.*0)66)61+&&&&%'&$&%%$.0*')-,/('%'&-86765&(4,)(&/-(%%*.($&&$'+43114.%%&''/-(%('&1-(.%%,'(,(1*%%,)%/./,(%)0/0).'-+./0)$%'$(%$)'1.4)*/'(,&)'(''%&').-,,5*.01+2/(&'+-(*&(((+,)/)1-444.-))-'(***.76(4(+*6/-.5.0+086,65521+646(%%'.,22%23.)&.,*,)'(%*(,+%14(1.2(,)%$&$+).'.,+$/*/161/-5.2)-%&0-22.$',-,.,221'3),+'*.1*33(3-&%&$%%1'.0*,36$4:=(()*+51693%/0''32(.4.,,-/*%&-%1()%&&**(16/*64*-/-(.*)*)&*+..//*0..2(&+(33(+'+)656+8799.*474)4776*.*+'2(/)*2=9146076'**&'(''&'+.+-351'3720)2/(.*'.*+$*,',*)$$&.0)%-5$/+.10,*887'444025)/4-(/'115+44+,,(')1),++(%%,+%&%&&&)+''&(+)/3/.%9741(-..+%,(&..,$2&222./1')(&&)).+3,.00+'('&&%+)/1-,)*-0,31