This notebook is where the network is run and tested. All setup can be done here, while specialised code should be delegated to its own Python module. Ideally, the workflow run here will later be turned into a `main` script that can be executed from the command line.

In [None]:
# Imports (probably) needed by the cells below.
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision import transforms, utils

from skimage import io, transform

import os

# Probable project code structure
from project_code.utils import preprocessing
from project_code.data.zebrafish_data_module import *
from project_code.networks.rnn import *

In [None]:
# Setup tensorboard for easy debugging.
import tensorboard

# This might be a little different for Pytorch lightning.
# For one, the logs are stored in lightning_logs.
# For two, I don't know if we should still remove them in between.

%load_ext tensorboard
%tensorboard --logdir lightning_logs

# If you run this notebook locally, you can also access Tensorbaord at 127.0.0.1:6006 now.

# Clean up old logs.
if os.path.isdir('./lightning_logs/'):
  import shutil
  shutil.rmtree('lightning_logs/')

from torch.utils.tensorboard import SummaryWriter

# default 'log_dir' is "lightning_logs"
writer = SummaryWriter('lightning_logs')

In [None]:
# Train the network end to end with PyTorch Lightning.
# NOTE(review): `pl` is not imported explicitly in this notebook; it presumably
# comes in through one of the `from project_code... import *` lines. Consider
# an explicit `import pytorch_lightning as pl` in the import cell.

# Build the network and cast all of its parameters to double precision
# (float64) instead of the default single precision (float32).
model = MutationNet().double()

# Data pipeline: one sample per batch.
data_module = ZebrafishDataModule(batch_size=1)

# Fit for a fixed number of epochs, feeding batches from the DataModule.
trainer = pl.Trainer(max_epochs=25)
trainer.fit(model, datamodule=data_module)

# Evaluation on the test split is currently disabled; uncomment to run it:
# trainer.test(datamodule=data_module)