In [None]:
import datetime
import os
import pandas as pd
import h5py
import numpy as np
from multiprocessing import Pool
from multiprocessing.pool import ThreadPool
from functools import partial, reduce
from collections import deque
from IPython.core.debugger import set_trace
from tensorflow.keras.utils import Sequence
import matplotlib.pyplot as plt

# Integer class index -> nucleotide letter.
labelBaseMap = dict(enumerate("ACGT"))

# Candidate locations of the mapped-read HDF5 file, tried in order.
possible_filenames = [
    "/mnt/sdb/taiyaki_mapped/mapped_umi16to9.hdf5",
    "/mnt/nvme/taiyaki_aligned/mapped_umi16to9.hdf5",
    "/Users/felix/MsC/DNA/mapped_umi16to9.hdf5",
]

# Select the first candidate that exists on this machine; `the_filename`
# stays "" when none does. The loop variable `filename` remains bound at
# module level afterwards, exactly as a plain for-loop leaves it.
the_filename = ""
for filename in possible_filenames:
    if not os.path.isfile(filename):
        continue
    the_filename = filename
    print(f"Using {filename}")
    break
if the_filename == "":
    print("Error, no filename valid!")

# Length, in signal samples, of the sliding window built by PrepData.processRead.
RNN_LEN = 300

In [None]:
class PrepData(Sequence):
    """Turn mapped reads from an HDF5 file into sliding-window training batches.

    Every entry of the file's ``Reads`` group is treated as one batch. For a
    read, ``processRead`` loads its ``Dacs``, ``Ref_to_signal`` and
    ``Reference`` datasets, slides a window of ``RNN_LEN`` signal samples over
    the normalised signal and pairs each window with the reference labels whose
    signal positions fall inside it.

    NOTE(review): ``__getitem__`` ignores its index and simply advances an
    internal cursor (``self.pos``), so batches are always produced in file
    order and the object is exhausted after one pass -- confirm this is what
    the training loop expects.
    """

    # Number of new signal samples pushed into the window per step.
    _STRIDE = 5
    # Value reported as 'input_length' for every window.
    # NOTE(review): presumably the number of time steps the model hands to the
    # CTC loss for one RNN_LEN-sample window -- confirm against the model.
    _INPUT_LEN = 95
    # Filler appended to label rows so that they all reach max_label_len.
    _PAD_VALUE = 5
    # Maximum number of windows kept per read when train_gen(full=False).
    _PREVIEW_ROWS = 100

    def __init__(self, filename, train_validate_split=0.8, min_labels=5):
        """Open ``filename`` once to collect the read IDs.

        Args:
            filename: Path of the HDF5 file containing a ``Reads`` group.
            train_validate_split: Fraction used to derive, per read, how many
                trailing reference positions are routed to the test lists.
            min_labels: A window is kept only if it holds more than this many
                labels.
        """
        # Harmless for a plain base class, and lets a base class that defines
        # its own __init__ set itself up.
        super().__init__()
        self.filename = filename
        self.train_validate_split = train_validate_split
        self.min_labels = min_labels
        self.pos = 0                          # index of the next read to process
        self.test_gen_data = ([], [])         # test windows of the last processed read
        self.last_train_gen_data = ({}, {})   # last (X, y) yielded by train_gen
        self.max_label_len = 50               # width of the padded label matrix
        with h5py.File(filename, 'r') as h5file:
            self.readIDs = list(h5file['Reads'].keys())
        self.raw = []                         # normalised signal of the last processed read

    def get_len(self):
        """Return the number of reads found in the file."""
        return len(self.readIDs)

    def get_max_label_len(self):
        """Return the width label rows are padded to."""
        return self.max_label_len

    def normalise(self, dac):
        """Min-max scale a 1-D signal into the range [0, 1].

        The computation is done in float64 on the whole array at once, which
        avoids both the per-element Python loop and any wrap-around that
        integer scalar arithmetic could cause.

        Args:
            dac: Sequence or array of numeric samples.

        Returns:
            A list of floats of the same length. An empty input gives an empty
            list; a constant signal (max == min) gives all zeros instead of
            dividing by zero.
        """
        samples = np.asarray(dac, dtype=np.float64)
        if samples.size == 0:
            return []
        lowest = samples.min()
        span = samples.max() - lowest
        if span == 0:
            return [0.0] * samples.size
        return ((samples - lowest) / span).tolist()

    def processRead(self, readID):
        """Cut one read into overlapping windows with their labels.

        The window holds up to ``RNN_LEN`` samples and advances ``_STRIDE``
        samples per step. Labels are the ``Reference`` entries whose
        ``Ref_to_signal`` position lies inside the current window. A window is
        emitted only when it contains more than ``self.min_labels`` labels. It
        goes to the training lists while more than
        ``round(len(Reference) * (1 - train_validate_split))`` signal positions
        are still unconsumed, and to the test lists afterwards.

        Side effect: ``self.raw`` is set to the read's normalised signal.

        Args:
            readID: Key of the read inside the file's ``Reads`` group.

        Returns:
            ``(train_X, train_y, test_X, test_y)`` as plain lists. Each X entry
            is a list of one-element lists (one per sample), each y entry a
            list of labels. All four are empty when the read is too short to
            fill a window.
        """
        train_X, train_y, test_X, test_y = [], [], [], []
        with h5py.File(self.filename, 'r') as h5file:
            read = h5file['Reads'][readID]
            DAC = list(self.normalise(read['Dacs'][()]))
            RTS = deque(list(read['Ref_to_signal'][()]))
            REF = deque(read['Reference'][()])
            self.raw = DAC

        if not RTS:
            # Nothing is mapped, so no window can be labelled.
            return train_X, train_y, test_X, test_y

        train_validate_split = round(len(REF) * (1 - self.train_validate_split))
        stride = self._STRIDE

        # Pre-fill the window one stride short of full, so that the first
        # iteration of the main loop tops it up to exactly RNN_LEN samples.
        first_fill = RNN_LEN - stride
        curdacs = deque([[x] for x in DAC[RTS[0]:RTS[0] + first_fill]], RNN_LEN)
        curdacts = RTS[0] + first_fill   # signal index just past the window
        labels = deque([])
        labelts = deque([])

        def pull_labels_before(limit):
            # Move every pending reference entry whose signal position is
            # below `limit` into the current window's label queues. The
            # emptiness guards stop cleanly on reads that run out of entries.
            while RTS and REF and RTS[0] < limit:
                labels.append(REF.popleft())
                labelts.append(RTS.popleft())

        pull_labels_before(curdacts)

        # Stop RNN_LEN samples before the last mapped position.
        while RTS and curdacts + stride < RTS[-1] - RNN_LEN:
            # The deque's maxlen drops the oldest samples automatically.
            curdacs.extend([[x] for x in DAC[curdacts:curdacts + stride]])
            curdacts += stride

            pull_labels_before(curdacts)

            # Forget labels that have slid out of the left edge of the window.
            while len(labelts) > 0 and labelts[0] < curdacts - RNN_LEN:
                labels.popleft()
                labelts.popleft()

            if len(labels) > self.min_labels:
                if len(RTS) > train_validate_split:
                    train_X.append(list(curdacs))
                    train_y.append(list(labels))
                else:
                    test_X.append(list(curdacs))
                    test_y.append(list(labels))

        return train_X, train_y, test_X, test_y

    @staticmethod
    def _as_object_array(rows):
        """Return a 1-D object array whose elements are the given rows.

        Filling element by element keeps the result 1-D even when all rows
        happen to have the same length, and never triggers NumPy's ragged
        array error.
        """
        boxed = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            boxed[i] = row
        return boxed

    def train_gen(self, full=True):
        """Yield one ``(X, y)`` training batch per remaining read.

        Starts at ``self.pos`` and advances it, so successive generators
        continue where the previous one stopped. The test windows of the read
        just processed are stored in ``self.test_gen_data`` and the yielded
        batch in ``self.last_train_gen_data``.

        Args:
            full: When False, keep only the first ``_PREVIEW_ROWS`` train and
                test windows of each read.

        Yields:
            ``(X, y)`` where ``X`` has the keys ``'the_input'``,
            ``'the_labels'``, ``'input_length'``, ``'label_length'`` and
            ``'unpadded_labels'``, and ``y`` is ``{'ctc': zeros(len(batch))}``.

        Raises:
            ValueError: If a window carries more labels than
                ``get_max_label_len()``, since it cannot be padded to that
                width.
        """
        while self.pos < len(self.readIDs):
            print(f"Processing {self.pos}")
            train_X, train_y, test_X, test_y = self.processRead(self.readIDs[self.pos])
            self.pos += 1

            if not full:
                limit = self._PREVIEW_ROWS
                train_X, train_y = train_X[:limit], train_y[:limit]
                test_X, test_y = test_X[:limit], test_y[:limit]

            # Label rows differ in length, so they are stored as an object
            # array of lists rather than handed to np.array directly.
            self.test_gen_data = (np.array(test_X), self._as_object_array(test_y))

            width = self.get_max_label_len()
            longest = max((len(row) for row in train_y), default=0)
            if longest > width:
                raise ValueError(
                    f"window with {longest} labels exceeds max_label_len={width}"
                )

            # Lengths and padding are computed from the Python lists, before
            # any array conversion, so `+` is always list concatenation.
            train_X_lens = np.array([[self._INPUT_LEN] for _ in train_X], dtype="float32")
            train_y_lens = np.array([[len(row)] for row in train_y], dtype="float32")
            train_y_padded = np.array(
                [list(row) + [self._PAD_VALUE] * (width - len(row)) for row in train_y],
                dtype='float32',
            )

            X = {'the_input': np.array(train_X),
                 'the_labels': train_y_padded,
                 'input_length': train_X_lens,
                 'label_length': train_y_lens,
                 'unpadded_labels': self._as_object_array(train_y)
                 }
            y = {'ctc': np.zeros([len(train_X)])}
            self.last_train_gen_data = (X, y)
            yield (X, y)

    def test_gen(self):
        """Endlessly yield the stored test data, resetting it after each hand-out.

        Each ``next()`` returns the current ``self.test_gen_data`` and replaces
        it with ``([], [])``, so the same test windows are never yielded twice.
        """
        while True:
            tgd, self.test_gen_data = self.test_gen_data, ([], [])
            yield tgd

    def __len__(self):
        """Return the number of batches, i.e. the number of reads."""
        return len(self.readIDs)

    def __getitem__(self, idx):
        """Return the next training batch; ``idx`` is not used.

        NOTE(review): once every read has been consumed the underlying
        generator is empty and ``next`` raises ``StopIteration`` -- confirm the
        caller never asks for more than ``len(self)`` batches in total.
        """
        return next(self.train_gen())
    
# Use the path chosen by the file-selection loop, not the loop variable that
# happens to be left over from it (which names a non-existent file when no
# candidate was found).
prepData = PrepData(the_filename)

In [None]:
# Run one read through the generator; as a side effect prepData.raw then
# holds that read's normalised signal.
next(prepData.train_gen())

a = np.array(prepData.raw)
print(a.shape)

# Report the mean, remove it, and confirm the centred signal averages ~0.
mn = a.mean()
print(mn)
a = a - mn
print(a.mean())

In [None]:
from scipy import signal

# Spectrogram of a 500-sample excerpt of the centred signal, using
# 10-sample segments that overlap by 8 samples.
excerpt = a[500:1000]
f, t, Sxx = signal.spectrogram(excerpt, nperseg=10, noverlap=8)
print(Sxx.shape)

In [None]:
# Global extremes of the spectrogram, used as the colour-scale limits.
# Sxx is an ndarray, so reduce it in NumPy instead of looping over rows and
# elements with the Python built-ins.
sxxmax = Sxx.max()
sxxmin = Sxx.min()
print(sxxmax, sxxmin)

In [None]:
import matplotlib
import matplotlib.colors as colors

# Reversed grey colormap on a logarithmic colour scale: dark means high power.
cmap = matplotlib.cm.gray.reversed()
fig, ax = plt.subplots(figsize=(20, 4))
ax.pcolormesh(t, f, Sxx, cmap=cmap, norm=colors.LogNorm(vmin=sxxmin, vmax=sxxmax))
# signal.spectrogram was called without `fs`, so its default fs=1.0 applies:
# `f` is in cycles per sample and `t` in samples, not Hz and seconds.
ax.set_ylabel('Frequency [cycles/sample]')
ax.set_xlabel('Time [samples]')
plt.show()