In [None]:
from itertools import zip_longest
import warnings

from keras.models import Sequential
from keras.layers import Dense, Dropout, Lambda
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.utils

%matplotlib inline

In [None]:
warnings.simplefilter('ignore', UserWarning)

In [None]:
def build_model(input_dim, mean=None, variance=None):
    """
    Build and compile a small fully-connected binary classifier.

    The network is a 64-unit ReLU hidden layer followed by a single
    sigmoid output unit. It is compiled with the Adam optimizer,
    binary cross-entropy loss and the accuracy metric.

    Arguments:
        input_dim: number of features in each input sample
        mean:      accepted but currently not used by the model
        variance:  accepted but currently not used by the model

    Return:
        the compiled Sequential model
    """
    # NOTE(review): an earlier experiment normalized the inputs with a Lambda
    # layer computing (x - mean) / variance; it is disabled, which is why
    # mean and variance are ignored here.
    hidden_layer = Dense(64, activation='relu', input_dim=input_dim)
    output_layer = Dense(1, activation='sigmoid')

    classifier = Sequential([hidden_layer, output_layer])
    classifier.compile(
        loss='binary_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return classifier

In [None]:
def train_and_validate_model(model, epochs, steps_per_epoch, training_generator, training_args, validation_generator, validation_args):
    """
    Train a model from a batch generator, validating after every epoch.

    The two generator factories are each called once with their keyword
    arguments, and the resulting generators are handed to
    model.fit_generator. The same step count is used for training and
    for validation.

    Arguments:
        model:                object providing fit_generator (e.g. a compiled Keras model)
        epochs:               number of epochs to train
        steps_per_epoch:      batches drawn per epoch, for training and for validation
        training_generator:   callable returning the training batch generator
        training_args:        dict of keyword arguments for training_generator
        validation_generator: callable returning the validation batch generator
        validation_args:      dict of keyword arguments for validation_generator

    Return:
        the object returned by model.fit_generator, so callers can inspect
        or plot its .history after training
    """

    epoch_training_generator = training_generator(**training_args)
    epoch_validation_generator = validation_generator(**validation_args)

    history = model.fit_generator(
        generator=epoch_training_generator,
        validation_data=epoch_validation_generator,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_steps=steps_per_epoch
    )
    print(history.history)
    # hand the history back instead of discarding it after printing
    return history


In [None]:
def grouper(iterable, n, fillvalue=None):
    """
    Collect data into fixed-length chunks or blocks.

    The last chunk is padded with fillvalue when the input runs out early:
        grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    """
    # every position of a chunk pulls from the same iterator, so each
    # tuple produced by zip_longest consumes n consecutive items
    shared_iterator = iter(iterable)
    chunk_positions = tuple(shared_iterator for _ in range(n))
    return zip_longest(*chunk_positions, fillvalue=fillvalue)

In [None]:
def load_kmer_batches_combined_h5(bacteria_dset, virus_dset, bacteria_subsample, virus_subsample, half_batch_size):
    """
    Endlessly yield batches of input and labels from a combined H5 file.

    The returned data is shuffled. This is very important. If the
    batches are returned with the first half bacteria data and
    the second half virus data the models train 'almost perfectly'
    and evaluate 'perfectly'.

    Only full batches are produced: row indices left over after the last
    full half-batch of either subsample are not used. After the last full
    batch the generator starts again from the first one.

    Arguments:
        bacteria_dset:       H5 dataset of bacteria data
        virus_dset:          H5 dataset of virus data
        bacteria_subsample:  sized, sliceable sequence of row indices into bacteria_dset
        virus_subsample:     sized, sliceable sequence of row indices into virus_dset
        half_batch_size:     each batch will have half_batch_size bacteria samples and half_batch_size virus samples

    Yield:
        (batch, labels) tuple of (half_batch_size*2, features) and (half_batch_size*2, 1) labels

    Raise:
        ValueError if half_batch_size is less than 1 or if either subsample
        is too short to fill a single half-batch (raised on the first next())
    """
    if half_batch_size < 1:
        raise ValueError('half_batch_size must be at least 1, got {}'.format(half_batch_size))

    print('reading bacteria dataset "{}" with shape {}'.format(bacteria_dset.name, bacteria_dset.shape))
    print('reading virus dataset "{}" with shape {}'.format(virus_dset.name, virus_dset.shape))

    batch_size = half_batch_size * 2
    batch_count = min(len(bacteria_subsample) // half_batch_size, len(virus_subsample) // half_batch_size)
    print('{} batches of {} samples will be yielded in each epoch'.format(batch_count, batch_size))

    # without at least one full batch the loop below would spin forever and never yield
    if batch_count == 0:
        raise ValueError(
            'subsamples of length {} and {} cannot fill a half batch of {} samples'.format(
                len(bacteria_subsample), len(virus_subsample), half_batch_size))

    # bacteria label is 0
    # virus label is 1
    labels = np.vstack((np.zeros((half_batch_size, 1)), np.ones((half_batch_size, 1))))

    # this is a never ending generator
    epoch = 0
    while True:
        epoch += 1
        # walk over full batches only: a short final group would neither match
        # the fixed-size labels nor be a valid list of row indices
        for batch_index in range(batch_count):
            start = batch_index * half_batch_size
            stop = start + half_batch_size
            # H5 wants a list index to be in ascending order
            bacteria_rows = sorted(bacteria_subsample[start:stop])
            virus_rows = sorted(virus_subsample[start:stop])
            batch = np.vstack((bacteria_dset[bacteria_rows, :], virus_dset[virus_rows, :]))
            # shuffle rows and labels together; tuple() makes the yielded
            # value the (batch, labels) tuple the docstring promises
            yield tuple(sklearn.utils.shuffle(batch, labels))
        print('generator epoch {} has ended'.format(epoch))


In [None]:
# go!
train_test_fp = '../data/training_testing.h5'
batch_size = 50

# Open the file read-only: nothing below writes to it, and an explicit mode
# keeps a mistyped path from creating a new empty file on h5py versions
# whose default mode is append.
with h5py.File(train_test_fp, 'r') as train_test_file:
    mean_dset = train_test_file['/clean-bact-vir/training1/extract/kmers/kmer_file1/mean']
    variance_dset = train_test_file['/clean-bact-vir/training1/extract/kmers/kmer_file1/variance']

    # copy the per-feature statistics out of the file into in-memory arrays
    mean = np.zeros(mean_dset.shape)
    variance = np.zeros(variance_dset.shape)

    mean_dset.read_direct(mean)
    variance_dset.read_direct(variance)
    # presumably guards against division by zero if variance is used to scale inputs
    variance[variance == 0.0] = 1.0

    model = build_model(input_dim=mean_dset.shape[1], mean=mean, variance=variance)

    bacteria_dset = train_test_file['/clean-bact/training1/extract/kmers/kmer_file1']
    virus_dset = train_test_file['/clean-vir/training1/extract/kmers/kmer_file1']

    # the first half of each dataset is used for training, the second half for validation
    bacteria_half_count = bacteria_dset.shape[0] // 2
    virus_half_count = virus_dset.shape[0] // 2

    training_sample_count = bacteria_half_count + virus_half_count
    batches_per_epoch = training_sample_count // batch_size
    print('{} training samples'.format(training_sample_count))
    print('batch size is {}'.format(batch_size))
    print('{} batches in training data'.format(batches_per_epoch))

    epochs = 20
    steps_per_epoch = 100
    print('{} epochs = {} training samples'.format(epochs, epochs * steps_per_epoch * batch_size))

    # NOTE(review): both generators read from the same open H5 file; confirm
    # that driving them with workers=2 is safe for the installed Keras and
    # h5py versions.
    history = model.fit_generator(
        generator=load_kmer_batches_combined_h5(
            bacteria_dset=bacteria_dset,
            bacteria_subsample=np.random.permutation(bacteria_half_count),
            virus_dset=virus_dset,
            virus_subsample=np.random.permutation(virus_half_count),
            half_batch_size=batch_size // 2
        ),
        # there is no advantage to permuting the validation samples
        # and there may be a speed advantage to reading them in order
        validation_data=load_kmer_batches_combined_h5(
            bacteria_dset=bacteria_dset,
            bacteria_subsample=np.arange(bacteria_half_count) + bacteria_half_count,
            virus_dset=virus_dset,
            virus_subsample=np.arange(virus_half_count) + virus_half_count,
            half_batch_size=batch_size // 2
        ),
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        validation_steps=steps_per_epoch,
        workers=2
    )

In [None]:
# one row per epoch and one column per value recorded in history.history;
# epochs are numbered starting from 1
epoch_index = pd.RangeIndex(start=1, stop=epochs + 1, name='epoch')
training_performance_df = pd.DataFrame(data=history.history, index=epoch_index)

In [None]:
# preview the values recorded for the first five epochs
training_performance_df.head(n=5)

In [None]:
# Draw each curve with its own explicit x values and label. Passing both
# series to a single plot(x, y1, y2) call leaves the second series without
# explicit x values, so its position depends on how matplotlib infers them.
plt.figure()
plt.plot(training_performance_df.index, training_performance_df['loss'], label='training')
plt.plot(training_performance_df.index, training_performance_df['val_loss'], label='validation')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend();

In [None]:
# Keras versions before 2.3 record the 'accuracy' metric as 'acc'/'val_acc';
# later versions use 'accuracy'/'val_accuracy'. Use whichever is present.
accuracy_column = 'acc' if 'acc' in training_performance_df.columns else 'accuracy'

# Draw each curve with its own explicit x values and label rather than
# passing both series to a single plot(x, y1, y2) call.
plt.figure()
plt.plot(training_performance_df.index, training_performance_df[accuracy_column], label='training')
plt.plot(training_performance_df.index, training_performance_df['val_' + accuracy_column], label='validation')
plt.title('Training and Validation Accuracy')
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.legend();