In [None]:
import time
import pandas as pd
import h5py
import numpy as np
from multiprocessing import Pool
from functools import partial, reduce


import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, MaxPooling1D, Flatten, Conv1D, CuDNNLSTM, Softmax
from tensorflow.nn import ctc_loss
from tensorflow.keras.callbacks import TensorBoard
import numpy as np

# labelBaseMap = {
#     0: "A",
#     1: "C",
#     2: "G",
#     3: "T"
# }

# HDF5 file holding the mapped reads; it is opened with h5py and its 'Reads'
# group is indexed by read ID in the cells below.
filename = "/mnt/nvme/taiyaki_aligned/mapped_umi16to9.hdf5"

# Number of consecutive signal samples in one training window.
RNN_LEN = 200

# Limit this process to half of the GPU memory.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.5)
session_config = tf.ConfigProto(gpu_options=gpu_options)
sess = tf.Session(config=session_config)

In [None]:
# Collect every read ID and show which datasets the first read contains.
with h5py.File(filename, 'r') as h5file:
    reads_group = h5file['Reads']
    readIDs = list(reads_group.keys())
    first_read_keys = list(reads_group[readIDs[0]].keys())
    print(f"{len(readIDs)} reads, keys: {first_read_keys}")

In [None]:
def processRead(readID, filename):
    """Cut one read into overlapping signal windows paired with their labels.

    For every reference position ``rtsidx`` the window start ``i`` begins at
    ``RTS[rtsidx]`` and advances in steps of 5 samples while it stays more
    than 5 samples before ``RTS[rtsidx + 1]`` and a full window still fits
    in the signal. Each window is ``DAC[i:i + RNN_LEN]``; its labels are the
    ``REF`` entries, starting at ``rtsidx``, whose ``RTS`` position lies
    inside that half-open window.

    Args:
        readID: key of the read inside the ``Reads`` group of the HDF5 file.
        filename: path of the HDF5 file to open read-only.

    Returns:
        A list of ``[window, labels]`` pairs, where ``window`` is a list of
        ``RNN_LEN`` signal values and ``labels`` is a list of ``REF`` values.
        Every pair owns its own ``labels`` list.
    """
    with h5py.File(filename, 'r') as h5file:
        read = h5file['Reads'][readID]
        DAC = list(read['Dacs'][()])
        RTS = list(read['Ref_to_signal'][()])
        REF = list(read['Reference'][()])

    step = 5  # window stride, and minimum distance kept to the next position
    # Label indices must be valid in BOTH RTS and REF before they are used.
    n_labels = min(len(RTS), len(REF))

    data = []
    for rtsidx in range(len(RTS) - 1):
        i = RTS[rtsidx]

        # Labels covered by the first window of this reference position.
        # The bounds check comes first so RTS is never indexed out of range.
        labels = []
        l = rtsidx
        while l < n_labels and RTS[l] < i + RNN_LEN:
            labels.append(REF[l])
            l += 1

        while i < (RTS[rtsidx + 1] - step) and (i + RNN_LEN) < len(DAC):
            # Sliding right may bring further positions into the window.
            # The slice end is exclusive, hence the strict '<' here as well.
            while l < n_labels and RTS[l] < i + RNN_LEN:
                labels.append(REF[l])
                l += 1
            # Store a copy: `labels` keeps growing for later windows and must
            # not retroactively change the samples already emitted.
            data.append([
                DAC[i:(i + RNN_LEN)],
                list(labels)
            ])
            i += step
    return data

# pp = partial(processRead, filename=filename)
# pp(readIDs[0])

In [None]:
%%time
# Window the first 16 reads in parallel, one worker process per read.
# The context manager shuts the workers down even if a worker raises, so a
# failed run does not leave orphaned processes behind in the kernel.
with Pool(16) as pool:
    results_prim = pool.map(partial(processRead, filename=filename), readIDs[:16])

In [None]:
# Merge the per-read lists into one flat list of [window, labels] samples.
results = [sample for per_read in results_prim for sample in per_read]

In [None]:
def normalise_dacs(dac):
    """Min-max scale a signal window into the range [0, 1].

    Args:
        dac: sequence of numeric signal samples.

    Returns:
        A list with one single-element list ``[value]`` per input sample, so
        the result has a trailing feature axis of size 1. An empty input
        gives ``[]``; a constant window (max == min) gives all zeros instead
        of dividing by zero.
    """
    if len(dac) == 0:
        return []
    # Work in float so the subtraction cannot wrap around if the samples are
    # fixed-width NumPy integers.
    dmin = float(min(dac))
    span = float(max(dac)) - dmin
    if span == 0:
        return [[0.0] for _ in dac]
    return [[(float(d) - dmin) / span] for d in dac]

# to test without CTC
def ohe(v):
    """One-hot encode the label index ``v`` into a length-4 integer vector.

    Used for the simplified, non-CTC experiment.
    """
    encoded = np.zeros(4, dtype=int)
    encoded[v] = 1
    return encoded


In [None]:
# Model input: one min-max-scaled window per sample; normalise_dacs wraps each
# value in its own list, which gives X a trailing feature axis of size 1.
X = np.array([normalise_dacs(r[0]) for r in results])
# The label lists can differ in length from window to window. Ask for an
# object array explicitly: implicit ragged arrays are deprecated and rejected
# by recent NumPy releases.
y = np.array([r[1] for r in results], dtype=object)
# Simplified target for the non-CTC experiment: one-hot of each window's LAST label.
simple_y = np.array([ohe(yy[-1]) for yy in y])

In [None]:
# Sanity check on the first sample: its first scaled signal value, its label
# list and its one-hot target. The cell's displayed value is the shape of X.
for preview in (X[0][0], y[0], simple_y[0]):
    print(preview)
X.shape

# Model definition and training

In [None]:
# Classifier for the simplified (non-CTC) task: an LSTM over the signal window,
# two 1-D convolutions, then dense layers ending in a 4-way softmax that is
# trained against the one-hot targets in simple_y.
model = Sequential()
# input_shape belongs on the FIRST layer. Sequential ignores it on later
# layers, which would leave the model unbuilt until fit() is called.
model.add(CuDNNLSTM(32, return_sequences=True, input_shape=X[0].shape))
model.add(Conv1D(32, 3,
          padding="valid",
          activation="relu"))
model.add(Conv1D(32, 10,
          padding="valid",
          activation="relu"))
model.add(Flatten())
model.add(Dense(32, activation="relu"))
model.add(Dense(4, activation="softmax"))
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=['accuracy'])

# Hold out the last 10% of the samples for validation.
model.fit(x=X, y=simple_y, batch_size=10, epochs=2, validation_split=0.1)

In [None]:
# Smoke-test inference on the first ten windows; the cell displays the shape
# of the returned predictions.
first_windows = X[:10]
p = model.predict(first_windows)
p.shape