# Program 4
Adapted from previous iteration by Parth.

## Base Setup

This section contains the basic environment set up for this notebook, including imports, constants, and any variable that needs to be easily accessed for changing.

In [None]:
#Import modules
import torch
import torchvision
import cv2
import datetime
import os
import re
import shutil
import gc
import platform
import time

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn as nn

from torchvision import models, transforms
from torchvision.datasets import VisionDataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
from ast import literal_eval
from ipylab import JupyterFrontEnd
from IPython.display import clear_output

This is a set of constants used mainly for workspace setup.

`CHECKPOINT`: The base name of the checkpoint file(s) being used to make predictions on the sample image(s). Current version of the program expects "epoch_" or some other similar 6 characters at the end when not using multiple epochs for predictions.\
`OUTPUT_DIR`: The path (relative to the current directory, or absolute) of the directory to be used for output files made by this notebook. It must end with a trailing `/`, because later code builds sub-directory paths by appending directly to it. <b>Note:</b> The directory structure may already exist, but it does not need to. A later function will make it if it does not exist.\
`SAMPLE_DIR`: The relative or absolute path to the directory containing the sample image(s) having predictions made on them.\
`SAMPLE_IMAGE_DICT`: A dictionary containing keys, which are the complete names of the sample image files, and their corresponding values, which are the complete names of the annotation files associated with those sample images. Currently, used mainly for testing purposes.\
`COLOR_PALETTE`: A list containing colors to be used when visually plotting the prediction heatmaps. Each class has its own color.\
`NOTEBOOK_NAME`: The exact name of this notebook, including the file extension. Needed later for programmatic html conversion and copying of the notebook.\
`SAVED_FILES`: Not technically a constant, but should not be altered by user. Used to keep track of any non-checkpoint file that gets saved to later move to output.\
`RUNTIMES`: Not technically a constant, but should not be altered by the user. Used to track the prediction and total time it takes for each image being processed.\
`APP`: JupyterFrontEnd instance that is used to save the notebook programmatically later.

In [None]:
#Define constants
#Workspace-level settings. Everything here is read as a module-level global by the functions below.

#Base name shared by the checkpoint file(s); the epoch/checkpoint suffix is appended elsewhere
CHECKPOINT = "ResNet34_Small_Follicles_Epoch_"
#Root folder for everything this notebook writes. Must end with "/" because sub-paths are built by concatenation
OUTPUT_DIR = "Output/Small Follicles/ResNet34/Run Final 1/"
#Folder the sample images are read from. read_image concatenates the file name directly onto this string
SAMPLE_DIR = "../../Data/Original/"
#Maps each sample image file name to the file name of its annotation table
SAMPLE_IMAGE_DICT = {
    "14736_UN_050a.ome.tif": "14736_UN_050a.annotations.txt",
    "16418_UN_140b.ome.tif": "16418_UN_140b.annotations.txt",
    "19006_UN_020a.ome.tif": "19006_UN_020a.annotations.txt",
    "21930_LT_060a.ome.tif": "21930_LT_060a_annotationsTable.txt",
    "21930_LT_120b.ome.tif": "21930_LT_120b_annotationsTable.txt",
    "25058_LT_005a.ome.tif": "25058_LT_005a.annotations.txt",
    "25065_LT_010a.ome.tif": "25065_LT_010a.annotations.txt",
    "25081_LT_010a.ome.tif": "25081_LT_010a.annotations.txt",
    "27570_UN_110a.ome.tif": "27570_UN_110a.annotations.txt",
    "30381_RT_070b.ome.tif": "30381_RT_070b.ome.annotationsTable.txt",
    "30381_RT_140b.ome.tif": "30381_RT_140b.ome.annotationsTable.txt",
    "30381_RT_200c.ome.tif": "30381_RT_200c.ome.annotationsTable.txt",
    "32002_RT_050a.ome.tif": "32002_RT_050a.ome.annotationsTable.txt",
    "32002_RT_110b.ome.tif": "32002_RT_110b.ome.annotationsTable.txt",
    "32002_RT_160c.ome.tif": "32002_RT_160c.ome.annotationsTable.txt",
    "33564_RT_060a.ome.tif": "33564_RT_060a.ome.annotationsTable.txt",
    "33564_RT_120b.ome.tif": "33564_RT_120b.ome.annotationsTable.txt",
    "33564_RT_180b.ome.tif": "33564_RT_180b.ome.annotationsTable.txt",
    "DP28_25081_Section3_10X_ome_copy.tif": "DP28_25081_Section3_10X_ome_copy.annotations.txt",
    "32002_LT_180a.ome.tif": "32002_LT_180a.ome.annotationsTable.txt",
    "KY_PS_LB40SDwk16601_7_a.ome.tif": "LB40_SDwk16601_7a_annotationsTable.txt"
}
#One colour per class, used as the colormap of the prediction heatmap (7 entries, matching vmin = 0 / vmax = 6 in plot_heatmap)
COLOR_PALETTE = ['white', 'red', 'gold', 'blue', 'green', 'darkviolet', 'dimgray']
NOTEBOOK_NAME = "Program 4 - Final v2.ipynb" #Make sure this is identical to the name of THIS notebook
#Names of the non-checkpoint files written so far; the saving functions append to this list
SAVED_FILES = [] #Leave as an empty list
#Per-image timing information, filled in while the notebook runs
RUNTIMES = {} #Leave as an empty dictionary
APP = JupyterFrontEnd() #Needed to save the notebook programmatically later, do not change.

These are variables and flags for functions that get used later.

`transform`: The set of torchvision transforms to be applied to the images before they are input into the model before prediction. <b>Note:</b> This variable should be identical to the one in Program 3 that was used to train the model, otherwise the predictions will be very incorrect.\
`batch_size`: The amount of images per DataLoader batch. Heavily affects VRAM usage. Speed testing has determined 128 to be optimal for my hardware.\
`num_workers`: The amount of workers for the DataLoader to use to parallelize training. Affects system RAM usage. Speed testing has determined 8 to be optimal for my hardware. <b>Note:</b> Windows is incapable of parallelizing Jupyter Notebooks like this; therefore, this variable will be set to 0 if on Windows.\
`freeze_model`: Flag to determine whether or not to freeze all layers of the model except the final layer. Testing has shown better prediction performance with this set to False. <b>Note:</b> This should be set to the same value as was used in Program 3 to train the model.\
`num_classes`: The number of output classes to be added to the new final layer of the model. <b>Note:</b> This should be the same as the number of output classes used in Program 3 when training the model.\
`do_entire_image`: Flag to determine whether to make predictions over an entire image or a small section. <b>Warning:</b> If this is False, any image that will be predicted on needs to have both a row and column slice. The program will throw a reminder error if this is not done.\
`row_slices`: Dictionary whose keys are sample image file names, like `SAMPLE_IMAGE_DICT`. The values are lists containing ranges that denote the slices of rows to be predicted on for a particular slice on the image in the key. Only used if `do_entire_image` is False. <b>Note:</b> Matches rows to columns based on location in the list.\
`col_slices`: Dictionary whose keys are sample image file names, like `SAMPLE_IMAGE_DICT`. The values are lists containing ranges that denote the slices of columns to be predicted on for a particular slice on the image in the key. Only used if `do_entire_image` is False. <b>Note:</b> Matches columns to rows based on location in the list.\
`do_certain_images`: Flag to determine whether to do every image contained in `SAMPLE_IMAGE_DICT`. If True, only images in `images_to_do` will have predictions made. Otherwise, all sample images are done.\
`images_to_do`: A list containing the full sample image file names as strings of every sample image you want to be predicted on. <b>Warning:</b> If `do_entire_image` is False and `do_certain_images` is True, any image in this list must have row and column slices denoted in `row_slices` and `col_slices` respectively.\
`window_size`: The size of one side of the square window to use for predictions. <b>Note:</b> Default is 200. Should be the same number used in Program 1 to generate the training images.\
`window_radius`: The window radius to use when making predictions. Half of the `window_size`. <b>Note:</b> If the next variable, `image_scaling_testing` is `True`, this value will be multiplied by the value in `multiplier` as well.\
`image_scaling_testing`: Flag to determine whether to scale the image when making predictions. Testing has found that scaling an image down speeds up prediction times at the cost of accuracy due to the data loss involved with lowering resolution. <b>Deprecated:</b> Pixel skip does better in both time and accuracy metrics. This functionality may or may not even work.\
`multiplier`: The multiplier to use when scaling an image. >1 for up-scaling. >0 and <1 for down-scaling.\
`enable_pixel_skip`: Flag to determine whether to use pixel-skipping when making predictions. Testing has found it significantly speeds up prediction times with minimal cost to accuracy as long as `pixel_skip` is small.\
`pixel_skip`: The number of pixels to skip both horizontally and vertically. 5 is what we currently use. <b>Warning:</b> Large numbers will lose prediction accuracy as entire smaller follicles could be missed and prediction blobs become blockier.\
`save_figs`: Flag to determine whether to save any extra figures created in this notebook for analysis purposes.\
`save_dataframes`: Flag to determine whether to save the prediction dataframes that are created in this program. <b>Note:</b> Requires `save_figs` to be True for this to have an effect.\
`use_amp`: Flag to determine whether or not to use PyTorch's Automatic Mixed Precision. Can significantly boost performance with minimal cost to calculation accuracy.\
`multiple_epochs`: Flag to determine if predictions are going to be made using checkpoints from every epoch generated by Program 3 when its `save_each_epoch` flag is True.\
`num_epochs`: The number of epochs that Program 3 was run with if the `save_each_epoch` flag was True. Only used when `multiple_epochs` is True.\
`checkpoint_name`: The unique addition to the end of `CHECKPOINT` to know the checkpoint file to use. Currently, the program expects "epoch_" or another 6 characters to be at the end of `CHECKPOINT` to be replaced by this variable. Only used when `multiple_epochs` is False.

In [None]:
#Define variables
#Tunable settings and flags. The functions below read these as module-level globals.

#Preprocessing applied to every sub-image before it reaches the model:
#resize to 224x224, convert to a tensor, then normalise each channel with mean 0.5 and std 0.5
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.5, 0.5, 0.5], std = [0.5, 0.5, 0.5])
])

#DataLoader settings; worker processes are disabled (0) when running on Windows
batch_size = 128
num_workers = 0 if platform.system() == "Windows" else 8

#Model settings read by setup_model
freeze_model = False
num_classes = 7

#Region selection. The slice dictionaries pair up by position: row_slices[name][i] goes with col_slices[name][i]
do_entire_image = True
row_slices = {
    "14736_UN_050a.ome.tif": [range(1000, 2000)],
    "27570_UN_110a.ome.tif": [range(500, 1500)],
    "32002_LT_180a.ome.tif": [range(1500, 2500), range(3300, 4300)],
    "KY_PS_LB40SDwk16601_7_a.ome.tif": []
}
col_slices = {
    "14736_UN_050a.ome.tif": [range(1500, 2500)],
    "27570_UN_110a.ome.tif": [range(2000, 3000)],
    "32002_LT_180a.ome.tif": [range(5500, 6500), range(8000, 9000)],
    "KY_PS_LB40SDwk16601_7_a.ome.tif": []
}

#Restrict the run to a subset of SAMPLE_IMAGE_DICT
do_certain_images = True
images_to_do = ["14736_UN_050a.ome.tif", "DP28_25081_Section3_10X_ome_copy.tif"]

#Side length of the square prediction window, and half of it (truncated to an int)
window_size = 200
window_radius = int(window_size / 2)

#Optional image rescaling; when enabled the window radius is scaled by the same factor
image_scaling_testing = False
multiplier = 0.5
if image_scaling_testing:
    window_radius = int(window_radius * multiplier)

#Predict only every pixel_skip-th pixel in each direction and fill the gaps with that prediction
enable_pixel_skip = True
pixel_skip = 5

#Output switches; make_predictions saves its dataframes only when both are True
save_figs = True
save_dataframes = True

#Passed as enabled= to torch.autocast in make_predictions
use_amp = True

#Checkpoint selection
multiple_epochs = True
num_epochs = 25
checkpoint_name = 20

### Class and Function Definitions

This section contains all the Classes and Functions used by this notebook.

`read_image`: Loads an image file and makes any modifications, if necessary.

<b>Parameters:</b>\
&emsp;`file`: The image file to be loaded in.\
&emsp;`image_slice`: A dictionary containing the row and column slices to modify the image with. <b>Default:</b> None

Loads in the image using `cv2` and converts it to the correct color layout.\
If `image_scaling_testing` is True, resizes the image based on `multiplier`.\
If `image_slice` parameter was passed a value, create a slice of the loaded image.

<b>Returns:</b>\
&emsp;The loaded image with any potential modifications made.

In [None]:
def read_image(file, image_slice = None):
    '''Load a sample image as an RGB array, optionally rescaled and cropped.

    Parameters:
        file: Name of the image file inside SAMPLE_DIR.
        image_slice: Optional dictionary with 'row' and 'col' entries (range objects, or any
            sequence whose first and last items are the inclusive bounds). When given, only that
            region of the image is returned. Default: None.

    Returns:
        The image as a numpy array in RGB channel order. If image_scaling_testing is True the
        image is resized by multiplier first, so image_slice is applied to the resized image.

    Raises:
        FileNotFoundError: If the file does not exist.
        OSError: If the file exists but cv2 could not decode it.
    '''
    path = os.path.join(SAMPLE_DIR, file)
    image = cv2.imread(path)

    #cv2.imread reports failure by returning None rather than raising, which would otherwise
    #surface as an unrelated-looking error inside cvtColor
    if image is None:
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Sample image not found: {path}")
        raise OSError(f"cv2 could not decode sample image: {path}")

    #cv2 loads images as BGR; the rest of the notebook works in RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if image_scaling_testing:
        image = cv2.resize(image, dsize = None, fx = multiplier, fy = multiplier, interpolation = cv2.INTER_CUBIC)

    if image_slice is not None:
        rows, cols = image_slice['row'], image_slice['col']
        #Bounds are inclusive, hence the + 1 on the upper end
        image = image[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

    return image

`read_probability_csv`: Reads in a class probability csv file. Mainly used for running analyses later without having to remake the predictions. <b>Deprecated:</b> After switching from csv files to parquet files, this is no longer needed as parquet files are binary and thus don't lose the data types when being saved and loaded.

<b>Parameters:</b>\
&emsp;`file`: The probabilities csv file to be loaded.

Reads in the csv file. Then, processes each item in the dataframe to remake the original list that was converted to a string when saving.

<b>Returns:</b>\
&emsp;The newly loaded dataframe containing lists of probabilities at each pixel.

In [None]:
def _parse_probability_cell(cell):
    '''Turn one stringified list/array cell (e.g. "[0.1 0.9]") back into a Python list.'''
    text = cell.strip()
    #Drop padding just inside the brackets so it does not become a stray leading/trailing comma
    text = re.sub(r"\[\s+", "[", text)
    text = re.sub(r"\s+\]", "]", text)
    #Any run of whitespace and/or commas separates two values
    return literal_eval(re.sub(r"[\s,]+", ",", text))


def read_probability_csv(file):
    '''Read a class probabilities csv and rebuild the per-pixel probability lists.

    Saving to csv turns every cell into a string, so each cell is parsed back into a list.

    Parameters:
        file: Path of the probabilities csv file to load.

    Returns:
        The loaded DataFrame, with every cell converted from its string form to a list.
    '''
    #Read in csv file
    df = pd.read_csv(file)

    #Convert string values to lists, one whole column at a time
    for col in df.columns:
        df[col] = df[col].map(_parse_probability_cell)

    return df

`setup_model`: Prepares the image classification model being used.

<b>Parameters:</b>\
&emsp;`model`: The image classification model being setup.

If `freeze_model` flag is True, loops through the existing layers of the model and freezes them.\
Then, replaces the fully connected layer of the model with a Linear layer that has a number of output features equal to `num_classes`.\
Finally, sends the model to the Torch device `device`.

<b>Returns:</b>\
&emsp;The model set to eval mode.

In [None]:
def setup_model(model):
    '''Prepare a classification model for inference.

    Optionally freezes the existing weights (freeze_model), swaps the fully connected head for a
    new Linear layer with num_classes outputs, and moves the model to the torch device, device.

    Parameters:
        model: The model to set up. It must expose its final layer as model.fc.

    Returns:
        The same model, switched to eval mode.
    '''
    if freeze_model:
        #Lock every weight that exists at this point
        for weight in model.parameters():
            weight.requires_grad = False

    #The new head is built after the freeze loop, so the loop never touches its parameters
    head_inputs = model.fc.in_features
    model.fc = nn.Linear(head_inputs, num_classes)

    model.to(device)
    model.eval()

    return model

`check_coord_bounds`: Checks if the coordinates contained in a dataframe are within certain bounds.

<b>Parameters:</b>\
&emsp;`df`: The dataframe containing coordinates.\
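&emsp;`x`: A range object giving the row bound; it is compared against the `Centroid Y px` column.\
&emsp;`y`: A range object giving the column bound; it is compared against the `Centroid X px` column.\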
&emsp;`row_num`: The row number of the coordinates in the dataframe that are currently being looked at.

<b>Returns:</b>\
&emsp;A boolean of whether the coordinates in the desired row of the dataframe are within the x, y bounds.

In [None]:
def check_coord_bounds(df, x, y, row_num):
    '''Report whether the centroid in one row of df lies inside the given bounds.

    Parameters:
        df: DataFrame with 'Centroid X px' and 'Centroid Y px' columns.
        x: Range whose first and last values bound 'Centroid Y px' (inclusive).
        y: Range whose first and last values bound 'Centroid X px' (inclusive).
        row_num: Key of the row to test; it is looked up with df[column][row_num].

    Returns:
        A truthy value when both coordinates are within their bounds, a falsy value otherwise.
    '''
    centroid_row = df['Centroid Y px'][row_num]
    within_rows = x[0] <= centroid_row <= x[-1]

    #No need to look at the other axis once this one is out of bounds
    if not within_rows:
        return within_rows

    centroid_col = df['Centroid X px'][row_num]
    return y[0] <= centroid_col <= y[-1]

`get_coords`: Gets the modified coordinates contained in a dataframe as a list.

<b>Parameters:</b>\
&emsp;`df`: The dataframe containing coordinates.\
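&emsp;`x`: A range object giving the row bound; its first value is subtracted from `Centroid Y px`.\
&emsp;`y`: A range object giving the column bound; its first value is subtracted from `Centroid X px`.\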
&emsp;`row_num`: The row number of the coordinates in the dataframe that are currently being looked at.

<b>Returns:</b>\
&emsp;A list containing the coordinates in a row of the dataframe that have been modified to fit in the x and y range bounds where the beginning of those bounds are the new 0, 0.

In [None]:
def get_coords(df, x, y, row_num):
    '''Return one centroid from df, shifted so the start of the bounds becomes the origin.

    Parameters:
        df: DataFrame with 'Centroid X px' and 'Centroid Y px' columns.
        x: Range whose first value is subtracted from 'Centroid Y px'.
        y: Range whose first value is subtracted from 'Centroid X px'.
        row_num: Position of the row to read (used with .iloc).

    Returns:
        A two-item list: the shifted X value followed by the shifted Y value.
    '''
    centroid = df[['Centroid X px', 'Centroid Y px']].iloc[row_num].to_numpy()
    origin = np.array([y[0], x[0]])

    return list(centroid - origin)

`ImageSet`: Class that extends `torchvision.datasets.VisionDataset`

This class extends the base VisionDataset class in torchvision to create a dataset that can store the image slices of the full image to have predictions made.

<b>Methods:</b>\
&emsp;`__init__`: Initializes the dataset. Override.

&emsp;<b>Parameters:</b>\
&emsp;&emsp;`image_file`: The file path for the full image that will be having predictions made on it.\
&emsp;&emsp;`transforms`: The torchvision transforms to apply to the subimage before it is sent to the model.\
&emsp;&emsp;`centers`: A list of tuples containing the valid center points for subimages.\
&emsp;&emsp;`radius`: The window radius used to cut out a subimage.

&emsp;Calls `read_image` to load in the full image to `self.full_image`.\
&emsp;Stores the center points in `self.images`.\
&emsp;Stores the radius in `self.radius`.\
&emsp;Finally, calls `super().__init__()` passing the `transforms` parameter to finish setting up the dataset.

&emsp;------------------------------------------------------------

&emsp;`__getitem__`: Gets the item in the dataset at an index. Override.

&emsp;<b>Parameters:</b>\
&emsp;&emsp;`index`: An index value to be used to get an item in the dataset.

&emsp;Gets the row and column center point located at `index` in `self.images`.\
&emsp;Creates a slice of `self.full_image` using the center point and the `self.radius` value to create the subimage.

&emsp;<b>Returns:</b>\
&emsp;&emsp;A tuple containing the center point of the subimage at `index` as well as the transformed subimage which has been cut out of the full image using `self.radius`.

&emsp;------------------------------------------------------------

&emsp;`__len__`: Gets the length of the dataset. Override.

&emsp;<b>Returns:</b>\
&emsp;&emsp;The length of `self.images`.

In [None]:
class ImageSet(VisionDataset):
    '''Dataset of square windows cut out of a single full image, one window per center point.'''
    def __init__(self, image_file, transforms, centers, radius):
        '''Load the full image and remember where the windows are.

        Parameters:
            image_file: File name of the full image; it is loaded with read_image.
            transforms: Callable applied to each window (as a PIL image) before it is returned.
            centers: Sequence of (row, col) center points, one per window.
            radius: Half the side length of each window.
        '''
        super().__init__(transforms = transforms)

        self.radius = radius
        self.images = centers
        self.full_image = read_image(image_file)

    def __getitem__(self, index):
        '''Return (row, col, transformed window) for the center stored at index.'''
        center_row, center_col = self.images[index]
        reach = self.radius

        #NOTE(review): a center closer than radius to the top/left border produces a negative
        #slice start here - confirm that callers only pass centers at least radius from the edges
        row_span = slice(center_row - reach, center_row + reach)
        col_span = slice(center_col - reach, center_col + reach)
        window = self.full_image[row_span, col_span, :]

        return (center_row, center_col, self.transforms(Image.fromarray(window)))

    def __len__(self):
        '''Return the number of center points, which is the number of windows.'''
        return len(self.images)

`make_predictions`: Makes predictions on a given slice of an image.

<b>Parameters:</b>\
&emsp;`row`: Range object for slice of rows to be predicted on.\
&emsp;`col`: Range object for slice of columns to be predicted on.\
&emsp;`image_file`: The file path to the image to be predicted on.

Sets up empty numpy arrays to store predictions and probability data.\
If `enable_pixel_skip` is True, sets up variables needed to handle pixel skipping.\
Loops through ranges to create a list containing all the subimage centers that will need to have predictions made.\
Creates dataset and dataloader from list of centers.\
Makes predictions on the subimages, and stores the prediction and probability values into the numpy arrays created earlier. If `enable_pixel_skip` is True, the data is stored in a `pixel_skip` by `pixel_skip` box that is centered on the center value that was just predicted on. Otherwise, the data is stored in its respective pixel.\
Converts numpy arrays to pandas dataframes and saves dataframes if `save_figs` and `save_dataframes` are both True.

<b>Returns:</b>\
&emsp;Dataframe of the probabilities for each class at each pixel in the prediction range, dataframe of the predicted class value at each pixel in the prediction range, and the time when predictions were finished.

In [None]:
def make_predictions(row, col, image_file):
    '''Predict a class for every pixel in the requested region of an image.

    A window is cut out around each center pixel and classified by the global model. With
    enable_pixel_skip only every pixel_skip-th pixel in each direction is classified, and its
    result is copied into the box of pixels around it; the boxes along the bottom and right
    edges are stretched to reach the end of the region.

    Parameters:
        row: Range of rows to predict on.
        col: Range of columns to predict on.
        image_file: File name of the image to predict on (loaded through ImageSet).

    Returns:
        A tuple of (class_probability, prediction, pred_time):
        class_probability is a DataFrame holding, for each pixel, an array of per-class
        probabilities; prediction is a DataFrame holding the predicted class index for each
        pixel; pred_time is the datetime at which the prediction loop finished.
        Both DataFrames are also written to parquet when save_figs and save_dataframes are True.
    '''
    n_rows, n_cols = len(row), len(col)
    class_probability = np.empty((n_rows, n_cols), dtype = object)
    prediction = np.zeros((n_rows, n_cols))

    #Origin of the region, captured BEFORE row/col are replaced by the strided center ranges.
    #Using the strided range's first value as the origin shifts every box by padding pixels and
    #makes the first row/column of boxes wrap around through negative indices.
    row_start = row[0] if n_rows else 0
    col_start = col[0] if n_cols else 0

    #Setup for pixel skipping
    if enable_pixel_skip:
        padding = pixel_skip // 2
        row_last, col_last = row[-1], col[-1]

        #First center sits padding pixels in, so its box starts exactly at the region edge
        row = range(row_start + padding, row_last + 1, pixel_skip)
        col = range(col_start + padding, col_last + 1, pixel_skip)

        row_final, col_final = row[-1], col[-1]

        #Distance from the last center to the region edge; the edge boxes are stretched by this much
        padding_final = (row_last - row_final, col_last - col_final)

    #Create list of image centers
    centers = [(i, j) for i in row for j in col]

    #Create image dataset and dataloader
    dataset = ImageSet(image_file = image_file, transforms = transform, centers = centers, radius = window_radius)
    loader = DataLoader(dataset, batch_size = batch_size, shuffle = False, num_workers = num_workers, pin_memory = True)

    autocast_device = "cuda" if torch.cuda.is_available() else "cpu"

    #Make predictions
    for row_inds, col_inds, images in tqdm(loader):
        with torch.no_grad():
            with torch.autocast(device_type = autocast_device, dtype = torch.float16, enabled = use_amp):
                outputs = model(images.to(device))

            #Softmax in float32, once per batch: exp() of half-precision logits overflows easily,
            #and a single transfer per batch replaces one transfer per sample
            batch_probabilities = torch.softmax(outputs.float(), dim = 1).cpu().numpy()
            batch_predictions = outputs.argmax(dim = 1).cpu().numpy()

        for row_ind, col_ind, probabilities, predicted in zip(row_inds.tolist(), col_inds.tolist(), batch_probabilities, batch_predictions):
            #Position of this center inside the output arrays
            r = row_ind - row_start
            c = col_ind - col_start

            if not enable_pixel_skip:
                class_probability[r, c] = probabilities
                prediction[r, c] = predicted
                continue

            #Boxes reach padding pixels past the center, except the last row/column of boxes,
            #which reach all the way to the region edge
            reach_down = padding_final[0] if row_ind == row_final else padding
            reach_right = padding_final[1] if col_ind == col_final else padding

            r_begin, r_end = r - padding, r + reach_down + 1
            c_begin, c_end = c - padding, c + reach_right + 1

            prediction[r_begin:r_end, c_begin:c_end] = predicted

            #Object cells must be set one by one; a slice assignment would try to broadcast the array
            for box_r in range(r_begin, r_end):
                for box_c in range(c_begin, c_end):
                    class_probability[box_r, box_c] = probabilities

    #Memory management
    del dataset, loader

    #Get time for finished predictions
    pred_time = datetime.datetime.now()

    #Convert to dataframes
    class_probability = pd.DataFrame(class_probability)
    prediction = pd.DataFrame(prediction)

    #Save dataframes
    if save_figs and save_dataframes:
        files = ["Class_Probabilities.parquet", "Predictions.parquet"]

        #Keep track of files
        SAVED_FILES.extend(files)

        class_probability.to_parquet(files[0], index = False)
        prediction.to_parquet(files[1], index = False)

    return class_probability, prediction, pred_time

`make_output_dir`: Create output directory structure to store files created by this program.

<b>Parameters:</b>\
&emsp;`slice_num`: The slice number that is currently being predicted on for a particular image.\
&emsp;`image_file`: The file name of the image that is currently being predicted on.\
&emsp;`epoch`: The epoch number of the checkpoint that is currently being used to make predictions. <b>Default:</b> None.

Makes the base output directory if it does not already exist.\
If `epoch` is passed a value, creates a directory within the base output directory for that epoch number. Otherwise, moves to the next step.\
If `do_entire_image` is True, creates a directory within the epoch directory that is just the image name. Otherwise, the new directory contains the image name and the slice number that is currently being predicted on.\
Creates a time-stamped directory within the previously created directory and saves it to a global constant `TIME_STAMP_OUTPUT_DIR`.

In [None]:
def make_output_dir(slice_num, image_file, epoch = None):
    '''Create the output directory tree for one image (or image slice) and a time-stamped run folder.

    The layout is OUTPUT_DIR / [Epoch <epoch>] / <image name>[_slice_<n>] / <timestamp>, where the
    image name is the part of image_file before its first ".". The path of the time-stamped
    folder is stored in the global TIME_STAMP_OUTPUT_DIR.

    Parameters:
        slice_num: Zero-based number of the slice being predicted on; it is shown one-based in
            the folder name and only used when do_entire_image is False.
        image_file: File name of the image being predicted on.
        epoch: Epoch number of the checkpoint in use, or None to skip the epoch folder. Default: None.

    Raises:
        FileExistsError: If the time-stamped folder already exists, i.e. the function was called
            twice for the same image within one second.
    '''
    global TIME_STAMP_OUTPUT_DIR

    #Not called "time", so the time module imported by this notebook is not shadowed
    created_at = datetime.datetime.now()

    #Epoch level is optional
    if epoch is not None:
        epoch_dir = OUTPUT_DIR + "Epoch " + str(epoch) + "/"
    else:
        epoch_dir = OUTPUT_DIR

    #Image/slice level
    image_name = image_file.split(".")[0]
    if do_entire_image:
        image_dir = epoch_dir + image_name + "/"
    else:
        image_dir = epoch_dir + image_name + "_slice_" + str(slice_num + 1) + "/"

    #Creates every missing level (base, epoch, image) in one call and tolerates existing ones
    os.makedirs(image_dir, exist_ok = True)

    #Make time-stamped output directory; os.mkdir is kept so a collision is reported, not reused
    TIME_STAMP_OUTPUT_DIR = image_dir + created_at.strftime("%Y-%m-%d_%H-%M-%S")
    os.mkdir(TIME_STAMP_OUTPUT_DIR)

`create_overlay`: Creates and adds an overlay showing where the model made predictions to the original image that was predicted on.

<b>Parameters:</b>\
&emsp;`predictions`: The dataframe containing prediction values for each pixel in the slice of the image that was predicted on.\
&emsp;`image_file`: The file name of the image that was predicted on.\
&emsp;`image_slice`: The slice of the image that was predicted on.

Defines a function to convert the prediction values to a binary overlay where white means a non-zero class was predicted (anything other than the negative class, 0) and black means the negative class was predicted.\
Uses pandas dataframe map method to apply the previously created function to every datapoint in the dataframe.\
Merge the overlay with the original image to make a new image that shows where on the original image non-negative predictions were made.\
If `save_figs` is True, save the newly created image.

<b>Returns:</b>\
&emsp;The newly created overlay image.

In [None]:
def create_overlay(predictions, image_file, image_slice):
    '''Blend a white-on-black mask of the predictions over the original image region.

    Parameters:
        predictions: Predicted class per pixel (DataFrame or 2D array). Class 0 is drawn black,
            every other value white.
        image_file: File name of the image that was predicted on.
        image_slice: Region of the image that was predicted on, passed through to read_image.

    Returns:
        The blended RGB image (70% original image, 30% mask). It is also written to
        "Overlay.png" when save_figs is True.
    '''
    #Vectorised form of "255 where a non-zero class was predicted, otherwise 0";
    #avoids calling a Python function once per pixel
    mask = np.where(np.asarray(predictions) != 0.0, 255, 0).astype(np.uint8)

    #Memory management: drop this function's reference to the predictions
    del predictions

    #Merge the overlay into the original image
    overlay = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB)
    overlay = cv2.addWeighted(read_image(image_file, image_slice = image_slice), 0.7, overlay, 0.3, 0.0)

    #Save overlay image
    if save_figs:
        file = "Overlay.png"

        #Keep track of file
        SAVED_FILES.append(file)

        #cv2 writes BGR, so convert back before saving
        cv2.imwrite(file, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

    return overlay

`create_probability_data`: Converts the probability dataframe to just contain the maximum probability at each pixel, regardless of which class was predicted.

<b>Parameters:</b>\
&emsp;`probabilities`: The dataframe containing each class probability across every pixel.

Uses the pandas dataframe map method to apply the max function to every list in the dataframe.\
If `save_figs` is True, saves the new max probabilities dataframe.

<b>Returns:</b>\
&emsp;The newly created dataframe that contains only the max probability at each pixel.

In [None]:
def create_probability_data(probabilities):
    '''Reduce the per-class probabilities at each pixel to the single largest one.

    Parameters:
        probabilities: DataFrame whose cells each hold the class probabilities of one pixel.

    Returns:
        A DataFrame of the same layout holding only the maximum probability of each cell. It is
        also written to "Max_Probabilities.parquet" when save_figs is True.
    '''
    #Apply max to every cell
    strongest = probabilities.map(max)

    #Memory management: drop this function's reference to the input
    del probabilities

    if not save_figs:
        return strongest

    target = "Max_Probabilities.parquet"

    #Keep track of file
    SAVED_FILES.append(target)
    strongest.to_parquet(target, index = False)

    return strongest

`plot_annotations`: Plots the given annotation coordinates on a given plot.

<b>Parameters:</b>\
&emsp;`annotations`: Dictionary containing the coordinates of human annotated follicles with what class they are as the key.\
&emsp;`ax`: The matplotlib axis object to plot the coordinates on.\
&emsp;`marker_size`: A number to determine how large the 'x' markers will be on the plot.

Loops through each key, value pair in the `annotations` dictionary and plots the coordinate pairs on the given axis. Each scatter plot is given a label based on the class of follicle it belongs to.

In [None]:
def plot_annotations(annotations, ax, marker_size):
    '''Draw annotation points on an existing axis, one scatter series per follicle type.

    Parameters:
        annotations: Dictionary keyed by follicle type. Each value holds the list of points at
            index 0 and the marker colour at index 1.
        ax: Axis object to draw on; only its scatter method is used.
        marker_size: Size passed to scatter for the 'x' markers.
    '''
    for follicle_type, value in annotations.items():
        points = value[0]

        #Nothing to draw for follicle types without coordinates
        if len(points) == 0:
            continue

        #First item of each point goes to the first positional argument of scatter, second to the second
        firsts = [point[0] for point in points]
        seconds = [point[1] for point in points]

        ax.scatter(firsts, seconds, s = marker_size, c = value[1], marker = 'x', label = follicle_type)

`plot_heatmap`: Plots the predictions as a Seaborn heatmap.

<b>Parameters:</b>\
&emsp;`heatmap`: The predictions dataframe to be plotted as a heatmap.

Uses the Seaborn heatmap function to plot the passed predictions dataframe, `heatmap`, as a heatmap.\
If `save_figs` is True, saves the heatmap to a file.

In [None]:
def plot_heatmap(heatmap):
    '''Draw the predictions DataFrame as a standalone Seaborn heatmap.

    heatmap: The predictions dataframe to plot.

    When save_figs is set, the figure is written to "Heatmap.png" and that file name is recorded in SAVED_FILES.
    The figure is shown in either case.'''
    global SAVED_FILES

    shape = heatmap.shape

    #One tick every 1000 pixels along each axis
    x_ticks = range(0, shape[1] + 1, 1000)
    y_ticks = range(0, shape[0] + 1, 1000)

    #Create the figure, then let Seaborn draw onto it (the new axis is the current one)
    _, ax = plt.subplots(figsize = (10, 8))

    sns.heatmap(heatmap, cmap = COLOR_PALETTE, vmin = 0, vmax = 6)

    ax.set_xticks(x_ticks)
    ax.set_yticks(y_ticks)
    ax.set_xticklabels(x_ticks, rotation = 90)
    ax.set_yticklabels(y_ticks)
    ax.set_title("Heatmap of Predictions")

    #Memory management: drop this function's reference to the dataframe
    del heatmap

    #Save heatmap and keep track of the file so it can be moved later
    if save_figs:
        output_name = "Heatmap.png"

        SAVED_FILES.append(output_name)

        plt.savefig(output_name)

    plt.show()

`plot_analysis`: Plots a four panel analysis of the original image and the predictions made on it.

<b>Parameters:</b>\
&emsp;`image_file`: The file name of the image being predicted on.\
&emsp;`image_slice`: The slice of the image being predicted on.\
&emsp;`heatmap`: The prediction dataframe to be plotted as a heatmap.\
&emsp;`overlay`: The overlay image showing where on the original image predictions were made.\
&emsp;`pred_strength`: The dataframe containing maximum probabilities at each pixel, regardless of class.\
&emsp;`annotations`: Dictionary containing the coordinates of each human annotation for the image being predicted on, separated by follicle type.

Loads in the original image using `read_image`. Then, plots the original image on the top-left panel.\
Uses the Seaborn heatmap function to plot the class predictions in `heatmap` as a heatmap on the top-right panel.\
Uses the Seaborn heatmap function to plot the maximum probabilities in `pred_strength` as a heatmap on the bottom-left panel.\
Plots the image overlay on the bottom-right panel. Additionally, uses `plot_annotations` to plot the human annotations in `annotations` on top of the overlay.

In [None]:
def plot_analysis(image_file, image_slice, heatmap, overlay, pred_strength, annotations):
    '''Build the four panel analysis figure for one image slice.

    Panels: original image (top-left), class prediction heatmap (top-right), prediction strength
    heatmap (bottom-left), and the overlay with the human annotations on top (bottom-right).

    image_file: File name of the image being predicted on.
    image_slice: The slice of the image being predicted on; forwarded to read_image.
    heatmap: The predictions dataframe drawn as the class heatmap.
    overlay: The overlay image drawn in the last panel.
    pred_strength: The dataframe drawn as the prediction strength heatmap.
    annotations: Annotation dictionary handed to plot_annotations.

    When save_figs is set, the figure is written to "Analysis_Plot.png" and that file name is recorded in SAVED_FILES.'''
    global SAVED_FILES

    img = read_image(image_file, image_slice = image_slice)

    #Every panel shares the same ticks: one every 1000 pixels, based on the original image's shape
    x_ticks = range(0, img.shape[1] + 1, 1000)
    y_ticks = range(0, img.shape[0] + 1, 1000)

    def label_panel(panel, title):
        '''Apply the shared ticks, tick labels, and the given title to one panel.'''
        panel.set_xticks(x_ticks)
        panel.set_yticks(y_ticks)
        panel.set_xticklabels(x_ticks, rotation = 90)
        panel.set_yticklabels(y_ticks)
        panel.set_title(title)

    #2 x 4 grid: two wide plot columns, each followed by a narrow column used for a colorbar
    fig, axes = plt.subplots(2, 4, figsize = (12, 12), gridspec_kw = {'width_ratios': [100, 5, 100, 5], 'height_ratios': [100, 100]})

    image_ax = axes[0, 0]
    heatmap_ax, heatmap_cbar_ax = axes[0, 2], axes[0, 3]
    strength_ax, strength_cbar_ax = axes[1, 0], axes[1, 1]
    overlay_ax = axes[1, 2]

    #Plot original image
    image_ax.imshow(img)
    label_panel(image_ax, "Original Image")

    #Memory management
    del img
    gc.collect()

    #Plot heatmap, with its colorbar in the narrow column to its right
    sns.heatmap(heatmap, ax = heatmap_ax, cmap = COLOR_PALETTE, vmin = 0, vmax = 6, cbar_ax = heatmap_cbar_ax)
    label_panel(heatmap_ax, "Heatmap")
    heatmap_cbar_ax.yaxis.set_ticks_position("left")

    #Memory management
    del heatmap
    gc.collect()

    #Plot prediction strengths, with its colorbar in the narrow column to its right
    sns.heatmap(pred_strength, ax = strength_ax, vmin = 0.0, vmax = 1.0, cbar_ax = strength_cbar_ax)
    label_panel(strength_ax, "Prediction Strength")
    strength_cbar_ax.yaxis.set_ticks_position("left")

    #Memory management
    del pred_strength
    gc.collect()

    #Plot heatmap overlay and annotations; the marker size shrinks as the overlay gets taller
    overlay_ax.imshow(overlay)
    plot_annotations(annotations, overlay_ax, int((1 / overlay.shape[0]) * 100000))
    label_panel(overlay_ax, "Heatmap Overlay with Annotations")

    #Memory management
    del overlay
    gc.collect()

    #Final touches: remove the two narrow columns that do not hold a colorbar
    fig.delaxes(axes[0, 1])
    fig.delaxes(axes[1, 3])
    fig.suptitle("Prediction Analysis")

    #Save plot and keep track of the file so it can be moved later
    if save_figs:
        output_name = "Analysis_Plot.png"

        SAVED_FILES.append(output_name)

        plt.savefig(output_name)

    plt.show()

### Model Setup

This section contains the setup for the image classification model.

Set the device to use for PyTorch. Then print out the device to make sure it is using the correct one.\
&emsp;"cuda" = GPU, "cpu" = CPU

In [None]:
#Initialize torch device: use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#Show which device was picked so it can be checked by eye
print(f"Device: {device}")

Call `setup_model` and pass it the pretrained ResNet34 model to setup.

In [None]:
#Model setup
#Build torchvision's ResNet34 with its default pretrained weights and pass it to setup_model;
#whatever setup_model returns is kept in the module-level name `model` used by later cells
model = setup_model(models.resnet34(weights = "ResNet34_Weights.DEFAULT"))

### Sample Image Setup

This section sets up the sample image(s) to have predictions made on them.

Creates a dictionary of dictionaries that contains the images being predicted on and every slice being done on each image.

<b>Dictionary Structure:</b>
```python
sample_imgs = {
    "image_file": {
        slice_number: {
            'row': range object,
            'col': range object
        }
    }
}
```

In [None]:
#Maps each image file to a dictionary of {slice number: {'row': range, 'col': range}}
sample_imgs = {}

#Loop through all sample images
for image_file in SAMPLE_IMAGE_DICT:
    #If only doing some images, skip any images that are not in images_to_do
    if do_certain_images and image_file not in images_to_do:
        continue

    #If doing the entire image, overwrite the row and column slices with a single range that covers the
    #whole image minus a border of window_radius pixels on every side (buffer for subimages)
    if do_entire_image:
        image = cv2.imread(SAMPLE_DIR + image_file)

        #cv2.imread returns None instead of raising when the file is missing or cannot be decoded,
        #so fail here with a clear message rather than on the shape lookup below
        if image is None:
            raise FileNotFoundError("Image \'" + image_file + "\' could not be read from \'" + SAMPLE_DIR + "\'.")

        row_slices[image_file] = [range(window_radius, image.shape[0] - window_radius)]
        col_slices[image_file] = [range(window_radius, image.shape[1] - window_radius)]

        #Memory management
        del image

    #Check if there is a row and column slice for an image if not doing entire image
    if image_file not in row_slices or image_file not in col_slices:
        raise KeyError("Image \'" + image_file + "\' has not been given a slice and do_entire_image flag is set to False.")

    #Create a nested dictionary tied to the image file
    sample_imgs[image_file] = {}

    #Loop through all row and col slices for the image; slices are paired up by position
    for i, (row_slice, col_slice) in enumerate(zip(row_slices[image_file], col_slices[image_file])):
        if image_scaling_testing: #Rescale slices if doing image scaling
            #Scale the first index and one-past-the-last index of each range by multiplier
            row_slice = range(int(row_slice[0] * multiplier), int((row_slice[-1] + 1) * multiplier))
            col_slice = range(int(col_slice[0] * multiplier), int((col_slice[-1] + 1) * multiplier))

        #Add row and column slices to another nested dictionary
        sample_imgs[image_file][i] = {'row': row_slice, 'col': col_slice}

Plots all base images being predicted on. Uses `sample_imgs` dictionary made previously.

In [None]:
#One panel per base image
count = len(sample_imgs)

#Lay the panels out in a single row, 5 inches wide each
_, axes = plt.subplots(1, count, figsize = (count * 5, 5))

#A single panel comes back as one axis rather than a collection, so wrap it to keep the loop uniform
if count < 2:
    axes = [axes]

#Pair each axis with its image file and draw the full base image on it
for axis, image_file in zip(axes, sample_imgs):
    image = read_image(image_file)

    axis.imshow(image)
    axis.set_title(image_file)

    #Memory management
    del image

plt.show()

Plots all sliced images being predicted on. Uses `sample_imgs` dictionary made previously.

In [None]:
#Total number of slices across every image being done
count = sum(len(image_slices) for image_slices in sample_imgs.values())

#Lay out one panel per slice in a single row, 5 inches wide each
_, axes = plt.subplots(1, count, figsize = (count * 5, 5))

#A single panel comes back as one axis rather than a collection, so wrap it to keep indexing uniform
if count < 2:
    axes = [axes]

#Index of the next unused panel; runs across all images
axis_num = 0

for image_file, image_slices in sample_imgs.items():
    #Read the base image once and cut every slice out of it
    image = read_image(image_file)

    for slice_num, image_slice in image_slices.items():
        rows = image_slice['row']
        cols = image_slice['col']

        #Crop from the first through the last index of each range, inclusive
        image_sliced = image[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

        panel = axes[axis_num]
        panel.imshow(image_sliced)
        #Slice numbers are shown starting from 1
        panel.set_title(f"{image_file} Slice {slice_num + 1}")

        axis_num += 1

    #Memory management
    del image, image_sliced

plt.show()

#### Create Annotations Overlay for Sample Image Slice

Finds and stores every coordinate for a human annotation on each image being predicted on that exists within each slice of the base image being done.

<b>Dictionary Structure:</b>
```python
image_annot_coords = {
    (slice_num, "image_file"): {
        "Primordial": [[list of valid coordinates], COLOR_PALETTE[1]],
        "Transitional Primordial": [[list of valid coordinates], COLOR_PALETTE[2]],
        "Primary": [[list of valid coordinates], COLOR_PALETTE[3]],
        "Transitional Primary": [[list of valid coordinates], COLOR_PALETTE[4]],
        "Secondary": [[list of valid coordinates], COLOR_PALETTE[5]],
        "Multilayer": [[list of valid coordinates], COLOR_PALETTE[6]]
    }
}
```

In [None]:
#Maps (slice number, image file) to the annotation coordinates that fall inside that slice
image_annot_coords = {}

#Loop through every image and slice being done
for image_file, image_slices in sample_imgs.items():
    #Load in base image's annotations (tab separated file named in SAMPLE_IMAGE_DICT)
    annotations = pd.read_csv(SAMPLE_DIR + SAMPLE_IMAGE_DICT[image_file], sep = '\t')

    #Convert microns to pixels; one of the sample images has a different scale
    if image_file == "DP28_25081_Section3_10X_ome_copy.tif":
        microns_per_pixel = 0.1725
    else:
        microns_per_pixel = 0.69

    annotations[["Centroid X px", "Centroid Y px"]] = annotations[['Centroid X µm', 'Centroid Y µm']] / microns_per_pixel

    if image_scaling_testing: #If scaling images
        annotations[['Centroid X px', 'Centroid Y px']] = annotations[['Centroid X px', 'Centroid Y px']] * multiplier

    #Keep only the needed columns and drop rows with na values. Reassigning the result, instead of
    #calling dropna in place on a column selection, avoids pandas' SettingWithCopyWarning.
    #ignore_index renumbers the rows 0..n-1, which the positional index i below relies on.
    annotations = annotations[['Name', 'Centroid X px', 'Centroid Y px']].dropna(axis = 0, ignore_index = True)

    #Loop through each image slice
    for slice_num, image_slice in image_slices.items():
        row_slice, col_slice = image_slice['row'], image_slice['col']

        #Each follicle class starts with an empty coordinate list and its own color
        annot_coords = {
            "Primordial": [[], COLOR_PALETTE[1]],
            "Transitional Primordial": [[], COLOR_PALETTE[2]],
            "Primary": [[], COLOR_PALETTE[3]],
            "Transitional Primary": [[], COLOR_PALETTE[4]],
            "Secondary": [[], COLOR_PALETTE[5]],
            "Multilayer": [[], COLOR_PALETTE[6]]
        }

        #Loop through each annotation and keep those that are within the bounds of the image slice
        for i, annot_class in enumerate(annotations['Name']):
            #Annotations whose name is not one of the six follicle classes above are ignored
            if annot_class not in annot_coords:
                continue

            if check_coord_bounds(annotations, row_slice, col_slice, i):
                annot_coords[annot_class][0].append(get_coords(annotations, row_slice, col_slice, i))

        image_annot_coords[(slice_num, image_file)] = annot_coords

Plots the annotation coordinates as a scatter plot over each image slice, respectively. Uses both `sample_imgs` and `image_annot_coords`.

In [None]:
#Total number of slices across every image being done
count = sum(len(image_slices) for image_slices in sample_imgs.values())

#Lay out one panel per slice in a single row, 5 inches wide each
_, axes = plt.subplots(1, count, figsize = (count * 5, 5))

#A single panel comes back as one axis rather than a collection, so wrap it to keep indexing uniform
if count < 2:
    axes = [axes]

#Index of the next unused panel; runs across all images
axis_num = 0

for image_file, image_slices in sample_imgs.items():
    #Read the base image once and cut every slice out of it
    image = read_image(image_file)

    for slice_num, image_slice in image_slices.items():
        rows = image_slice['row']
        cols = image_slice['col']

        #Crop from the first through the last index of each range, inclusive
        image_sliced = image[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1, :]

        panel = axes[axis_num]
        panel.imshow(image_sliced)
        #Slice numbers are shown starting from 1
        panel.set_title(f"{image_file} Slice {slice_num + 1}")

        #Scatter the slice's annotation coordinates on top; the marker size shrinks as the slice gets taller
        slice_annotations = image_annot_coords[(slice_num, image_file)]
        plot_annotations(slice_annotations, panel, int((1 / image_sliced.shape[0]) * 100000))

        axis_num += 1

    #Memory management
    del image, image_sliced

plt.show()

## Main Code

This section contains the main code of the program that makes predictions on the images and creates analysis plots to examine those predictions.

If the `multiple_epochs` flag is set to False, load the checkpoint that is being used to make predictions. Then, loop through every image/slice that is having predictions made. Make predictions on the image/slice. Then, make an overlay out of the predictions and make a dataframe of maximum probabilities out of the class probabilities for each prediction. Plot the predictions as a heatmap on its own, then plot the full, four panel analysis plot using the original image/slice, the predictions, the maximum probabilities, and the human annotations. Finally, make an output directory for this image/slice and move any related files that were created into the output directory. Once the image/slice is finished, move on to the next one. If all images/slices have been finished, "Done" will be printed to the cell output.

If the `multiple_epochs` flag is set to True, loop through every epoch, loading the checkpoint for that epoch at the beginning of the iteration. Then, loop through every image/slice that is having predictions made. Make predictions on the image/slice. Then, make an overlay out of the predictions and make a dataframe of maximum probabilities out of the class probabilities for each prediction. Plot the predictions as a heatmap on its own, then plot the full, four panel analysis plot using the original image/slice, the predictions, the maximum probabilities, and the human annotations. Finally, make an output directory for this image/slice and move any related files that were created into the output directory. Once the image/slice is finished, move on to the next one. If all images/slices have been finished, move on to the next epoch. Once all epochs are finished, "Done" will be printed to the cell output.

In [None]:
def _predict_and_save_slice(image_file, slice_num, image_slice, checkpoint_file, epoch = None):
    '''Run prediction, plotting, and file saving for one image slice and record its runtimes.

    image_file: File name of the image being predicted on.
    slice_num: Number of the slice within sample_imgs[image_file].
    image_slice: Dictionary holding the 'row' and 'col' ranges of the slice.
    checkpoint_file: File name of the checkpoint currently loaded into the model; copied into the output directory.
    epoch: Epoch number when doing multiple epochs, otherwise None. It adds an epoch banner, is passed on to
           make_output_dir, and becomes part of the RUNTIMES key.

    The model is expected to already hold the weights from checkpoint_file when this is called.'''
    #Messages are indented one extra level when they sit under an epoch banner
    indent = "" if epoch is None else "  "

    #Memory management
    gc.collect()

    if epoch is not None:
        print("---------- Epoch: {} of {} ----------".format(epoch, num_epochs))

    #Track start time
    start_time = datetime.datetime.now()
    print(indent + "---------- Slice: {} of Image: {} ----------".format(slice_num + 1, image_file))

    #Make predictions on the image slice
    probabilities, predictions, pred_time = make_predictions(image_slice['row'], image_slice['col'], image_file)

    #Create an overlay showing where the model made predictions
    overlay_img = create_overlay(predictions, image_file, image_slice)

    #Create a dataframe that shows how strong the model's predictions were at each pixel, regardless of predicted class
    prediction_strengths = create_probability_data(probabilities)

    #Memory management
    del probabilities

    #Plot heatmap and full analysis plots
    plot_heatmap(predictions)
    plot_analysis(image_file, image_slice, predictions, overlay_img, prediction_strengths, image_annot_coords[(slice_num, image_file)])

    #Memory management
    del predictions, overlay_img, prediction_strengths

    print(indent + "  ---------- Saving and Moving Files ----------")

    #Create output directory and move files
    #NOTE(review): TIME_STAMP_OUTPUT_DIR is presumably updated by make_output_dir - confirm
    if epoch is None:
        make_output_dir(slice_num, image_file)
    else:
        make_output_dir(slice_num, image_file, epoch)

    print(indent + "    Output Dir: {}".format(TIME_STAMP_OUTPUT_DIR))

    shutil.copy2(checkpoint_file, TIME_STAMP_OUTPUT_DIR)

    for file in SAVED_FILES:
        shutil.move(file, TIME_STAMP_OUTPUT_DIR)

    SAVED_FILES.clear()

    #Save the notebook through the front end before copying it, so the copy includes this iteration's output
    APP.commands.execute("docmanager:save")

    shutil.copy2(NOTEBOOK_NAME, TIME_STAMP_OUTPUT_DIR)

    #Prevents html errors in notebook
    clear_output(wait = True)

    #Track end time
    end_time = datetime.datetime.now()

    #Save prediction time and total time taken to a dictionary; fractions of a second are cut off
    if epoch is None:
        runtime_key = (slice_num, image_file)
    else:
        runtime_key = ("Epoch " + str(epoch), slice_num, image_file)

    RUNTIMES[runtime_key] = {"prediction": str(pred_time - start_time).split(".")[0], "total": str(end_time - start_time).split(".")[0]}


if not multiple_epochs: #Only one epoch
    #Load checkpoint
    checkpoint_file = CHECKPOINT + str(checkpoint_name) + ".pt"
    model.load_state_dict(torch.load(checkpoint_file, map_location = device, weights_only = True))

    #Loop through all images and slices being done
    for image_file, image_slices in sample_imgs.items():
        for slice_num, image_slice in image_slices.items():
            _predict_and_save_slice(image_file, slice_num, image_slice, checkpoint_file)
else: #Doing each saved epoch from program 3 when save_each_epoch is True
    #Loop through each epoch, 0 through num_epochs inclusive
    for epoch in range(0, num_epochs + 1):
        #Load checkpoint for this epoch
        checkpoint_file = CHECKPOINT + str(epoch) + ".pt"
        model.load_state_dict(torch.load(checkpoint_file, map_location = device, weights_only = True))

        #Loop through all images and slices being done
        for image_file, image_slices in sample_imgs.items():
            for slice_num, image_slice in image_slices.items():
                _predict_and_save_slice(image_file, slice_num, image_slice, checkpoint_file, epoch)

print("Done")

Prints out the `RUNTIMES` dictionary, showing how long each image/slice took. Each image/slice has a prediction time, which is how long it took to only make predictions on the image, and a total time, which is how long the entire iteration of the loop for that image/slice took including prediction time.

In [None]:
#Print one line per image/slice: its RUNTIMES key followed by its prediction and total times
for run_key, timings in RUNTIMES.items():
    print(f"{run_key}: {timings}")

### Convert Notebook to HTML and Move to Output Directory / Clean Up Working Directory

Programmatically save the notebook, convert it to html and move the html file to the base output directory.

In [None]:
#The steps below depend on their order: wait, save the notebook, convert the saved file, then move the result

#For some reason, this is needed for the save command just below it to function
time.sleep(1)

#Save notebook through the JupyterLab front end so the file on disk is up to date before converting
APP.commands.execute("docmanager:save")

#Convert notebook to html
#This is an IPython shell escape, not Python; IPython substitutes the value of NOTEBOOK_NAME for $NOTEBOOK_NAME
!jupyter nbconvert --to html "$NOTEBOOK_NAME"

#Move the html file into OUTPUT_DIR with a timestamp appended to its name (the notebook itself is not copied here)
#NOTEBOOK_NAME[:-6] drops the last six characters, presumably the ".ipynb" extension - confirm
#OUTPUT_DIR is joined by plain concatenation, so it presumably ends with a path separator - confirm
shutil.move(NOTEBOOK_NAME[:-6] + ".html", OUTPUT_DIR + NOTEBOOK_NAME[:-6] + "_" + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".html")

Final cleanup of current working directory.

In [None]:
#Remove remaining checkpoint(s) from the working directory
if multiple_epochs:
    #One checkpoint file per epoch, numbered 0 through num_epochs inclusive
    for epoch in range(num_epochs + 1):
        os.remove(CHECKPOINT + str(epoch) + ".pt")
else:
    #Only the single named checkpoint was used
    os.remove(CHECKPOINT + str(checkpoint_name) + ".pt")