This notebook is where the network is run and tested. All of the setup can be done here, while specialised code should be delegated to its own Python file. Ideally, the process run here will later be turned into a 'main' Python script that can be run from the command line.

In [None]:
# Import the (probably) necessary imports.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms, utils

from skimage import io, transform

import os

# Probable project code structure
from project_code.utils import preprocessing
from project_code.data.zebrafish_data_module import *
from project_code.networks.rnn import *

In [None]:
# Setup tensorboard for easy debugging.
import tensorboard

# This might be a little different for Pytorch lightning.
# For one, the logs are stored in lightning_logs.
# For two, I don't know if we should still remove them in between.

%load_ext tensorboard
%tensorboard --logdir lightning_logs

# If you run this notebook locally, you can also access Tensorbaord at 127.0.0.1:6006 now.

# Clean up old logs.
if os.path.isdir('./lightning_logs/'):
  import shutil
  shutil.rmtree('lightning_logs/')

from torch.utils.tensorboard import SummaryWriter

# default 'log_dir' is "lightning_logs"
writer = SummaryWriter('lightning_logs')

In [None]:
# Define a simple synthetic Dataset and a DataModule around it.
# The Dataset contains generated sequences of (smoothed) white
# noise, each drawn around one of two different means. The model
# should be able to tell the two apart and update its belief
# as more of the sequence is seen. If we can see this, then we
# know that the model works.

import torch
from torch.utils.data import Dataset

import pytorch_lightning as pl

class TestDataset(Dataset):
    """Synthetic two-class dataset of low-pass filtered Gaussian noise.

    Even indices are drawn around ``mean1`` and labelled 1; odd indices are
    drawn around ``mean2`` and labelled 0, so both classes are (almost)
    equally represented.

    Each item is a tuple ``(sequence, target)`` of float64 numpy arrays, both
    of shape ``(sequence_length, 1)``; the target repeats the class label at
    every time step.

    Args:
        mean1: Mean of the noise for the class labelled 1.
        mean2: Mean of the noise for the class labelled 0.
        num_samples: Number of sequences to generate.
        sequence_length: Number of time steps per sequence.
        noise_std: Standard deviation of the raw Gaussian noise.
        filter_width: Width of the moving-average (low-pass) filter. It is
            clipped to ``sequence_length``; a width of 1 disables smoothing.
        plot_examples: If True, plot the first two generated sequences.
    """

    def __init__(self,
                 mean1=1,
                 mean2=5,
                 num_samples=1000,
                 sequence_length=1000,
                 noise_std=1,
                 filter_width=10,
                 plot_examples=True):
        self.mean1 = mean1
        self.mean2 = mean2
        self.num_samples = num_samples
        self.sequence_length = sequence_length

        output_shape = (self.sequence_length, 1)

        # np.convolve(mode='same') returns max(len(a), len(v)) samples, so a
        # window wider than the sequence would make the sequence longer than
        # its target: clip the window to the sequence length.
        window = min(filter_width, self.sequence_length)
        kernel = np.ones(window)
        if window > 1:
            # Number of real (non-padded) samples under the window at each
            # position. Dividing by this instead of by `window` keeps the
            # implicit zero padding from dragging the first and last samples
            # of every sequence towards 0.
            counts = np.convolve(np.ones(self.sequence_length), kernel,
                                 mode='same')
        else:
            counts = None

        # Generate the sequences and put them in lists.
        self.sequences = []
        self.sequences_target = []
        for i in range(self.num_samples):
            # Make sure that half is of one mean and half of the other.
            if i % 2 == 0:
                mean, target = self.mean1, np.ones(output_shape)
            else:
                mean, target = self.mean2, np.zeros(output_shape)
            sequence = np.random.normal(mean, noise_std, self.sequence_length)

            # Use a low-pass filter on the noise such that the confidence
            # values also don't change as quickly (and we can hopefully see
            # better integration over time).
            if counts is not None:
                sequence = np.convolve(sequence, kernel, mode='same') / counts
            sequence = np.expand_dims(sequence, axis=1)

            self.sequences.append(sequence)
            self.sequences_target.append(target)

            if plot_examples and i < 2:
                plt.figure()
                plt.plot(sequence)
                plt.show()

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return (self.sequences[idx], self.sequences_target[idx])

class TestDataModule(pl.LightningDataModule):
    """LightningDataModule serving a synthetic ``TestDataset``.

    The generated data is split 80/10/10 into train/validation/test sets.

    Args:
        mean1: Noise mean of the class labelled 1.
        mean2: Noise mean of the class labelled 0.
        batch_size: Batch size used by all three dataloaders.
        num_samples: Total number of generated sequences.
        sequence_length: Number of time steps per sequence.
    """

    def __init__(self,
                 mean1=1,
                 mean2=5,
                 batch_size=32,
                 num_samples=1000,
                 sequence_length=1000):
        super().__init__()
        self.batch_size = batch_size
        self.mean1 = mean1
        self.mean2 = mean2
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        # Filled in by setup().
        self.train = None
        self.val = None
        self.test = None

    def setup(self, stage=None):
        # Make assignments here (val/train/test split).
        # Called on every process in Distributed Data Processing, and once
        # per stage (fit, test, ...). Generate and split the data only once,
        # so every stage sees the same disjoint splits and the example plots
        # are not repeated.
        if self.train is not None:
            return

        dataset = TestDataset(self.mean1,
                              self.mean2,
                              self.num_samples,
                              self.sequence_length)

        # Like in the RestingDataset we go for a 80, 10, 10 split.
        num_train = round(self.num_samples * 0.8)
        num_val = round(self.num_samples * 0.1)
        # Whatever is left over after rounding goes to the test set.
        num_test = self.num_samples - num_train - num_val

        self.train, self.val, self.test = torch.utils.data.random_split(
            dataset, [num_train, num_val, num_test])

    def train_dataloader(self):
        # Reshuffle every epoch so batches are not presented in a fixed order.
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

In [None]:
# Run the whole training process with PyTorch Lightning.

# Build the model; it has to run in double precision (double instead of float).
model = MutationNet().double()

# Data: the zebrafish recordings. Swap in the synthetic TestDataModule
# (commented out below) to sanity-check the model on generated data instead.
data_module = ZebrafishDataModule(batch_size=1, sampling_frequency=100)
# data_module = TestDataModule(mean1=1, mean2=1.50, num_samples=10000, sequence_length=100)

# Train for 25 epochs (use fast_dev_run=True instead for a quick smoke test).
trainer = pl.Trainer(max_epochs=25)
# trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(model, data_module)

# Evaluate on the test split:
# trainer.test(datamodule=data_module)

In [None]:
# Re-initialise the trainer so the test logs are stored under a new version.
# NOTE(review): the model above was fitted on ZebrafishDataModule but is
# evaluated here on synthetic TestDataModule data — confirm this is intended.
trainer = pl.Trainer()
test_data_module = TestDataModule(mean1=1, mean2=1.35)
trainer.test(model, datamodule=test_data_module)