# Imports

In [3]:
# !conda activate n2v

import numpy as np
from matplotlib import pyplot as plt
import sys
import torch
import random
import zarr
from skimage import data
from skimage import filters
from skimage import metrics

from funlib.learn.torch.models import UNet, ConvPass
import gunpowder as gp

# from this repo
import loser

# Set Parameters (including data source, training variables, destination, etc.)

## **Paths for training, predictions and results**

**`train_source`:** The path to the **zarr** container holding your training data (noisy images). To find the path, go to your Files on the left of the notebook, navigate to the container, right-click it, choose **Copy path**, and paste the path into the cell below.

**`model_name`:** Use underscores, not hyphens or spaces, in the name (e.g. `my_model`, not `my-model`). Do not reuse the name of an existing model saved in the same folder, otherwise it will be overwritten.

**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).


## **Training parameters**

**`num_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see *Examine Results* below). **Default value: 30**

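**`side_length`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and compatible with the U-Net defined below, which downsamples twice by a factor of 2 with `'valid'` padding, so not every size is accepted; if the network rejects a size, try a nearby value. **Default value: 100**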

**If you get an out-of-memory (OOM) error during training, decrease `side_length` (or `batch_size`) until the OOM error disappears.**

**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 1**

**`num_steps`:** The number of training steps per epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: number of patches / batch_size**

**`perc_validation`:**  Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** 

**`init_learn_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**

**`perc_hotPixels`:** Percent of output pixels to designate as targets and *heat* for training. **Default value: 0.198**


In [1]:
# --- Paths ---
train_source = ''  # zarr container with the noisy training data
model_name = 'noise2gun_mCTX30nm_450p'
model_path = ''  # folder the trained model is saved to

# --- Training parameters ---
num_epochs = 100
num_steps = 100  # training steps per epoch
batch_size = 1
side_length = 100  # patch side length
perc_validation = 10  # percent of the training data held out for validation
init_learn_rate = 4e-4
perc_hotPixels = 0.198  # percent of pixels heated as training targets

### Make sure data source is a **zarr** 

# Build Gunpowder Pipeline for Training

### Elements are:

- Data Source
- *(optional) Normalization*
- Random Patch Grab
- Pixel Heating (select and mutate *hotPixels*, i.e. training targets, and keep masks)
- Simple Augmentation (rotations/reflections)
- Stacking
- Training


In [None]:
# Build the Noise2Void training pipeline. Uses the values from the parameter
# cell (train_source, side_length, batch_size, init_learn_rate) and the
# modules from the imports cell.

# declare arrays to use in the pipeline
raw = gp.ArrayKey('RAW')  # raw (noisy) data
hot = gp.ArrayKey('HOT')  # data with random pixels heated (network input)
mask = gp.ArrayKey('MASK')  # mask marking which pixels were heated
prediction = gp.ArrayKey('PREDICTION')  # prediction of denoised data

# voxel size of the training data -- set for each dataset. Shared by the
# U-Net and by the request below so the two can never disagree.
voxel_size = [30, 30, 30]

source = gp.ZarrSource(  # add the data source
    train_source,  # the zarr container
    {raw: 'raw'},  # which dataset to associate to the array key
    {raw: gp.ArraySpec(interpolatable=True)},  # meta-information
)

# add normalization
# normalize = gp.Normalize(raw)

# add a RandomLocation node to the pipeline to randomly select a sample
random_location = gp.RandomLocation()

# add transpositions/reflections
simple_augment = gp.SimpleAugment()

# add pixel heater
# TODO: implement the pixel-heater node ("boilerPlate"): read `raw`, heat a
# fraction (perc_hotPixels) of its pixels into `hot`, record the heated
# pixels in `mask`, and insert it into the pipeline after `random_location`.

# stack for batches
stack = gp.Stack(batch_size)

# define our network model for training
num_fmaps = 12

unet = UNet(
    in_channels=1,
    num_fmaps=num_fmaps,
    fmap_inc_factor=5,
    downsample_factors=[[2, 2, 2], [2, 2, 2]],
    padding='valid',
    constant_upsample=True,
    voxel_size=voxel_size,
)

model = torch.nn.Sequential(
    unet,
    ConvPass(num_fmaps, 1, [(1, 1, 1)], activation=None),
    torch.nn.Sigmoid(),
)

# pick loss function
loss = loser.MaskedMSELoss()

# pick optimizer, honoring the learning rate set in the parameter cell
optimizer = torch.optim.Adam(model.parameters(), lr=init_learn_rate)

# create a train node using our model, loss, and optimizer
train = gp.torch.Train(
    model,
    loss,
    optimizer,
    inputs={
        'input': hot,
    },
    loss_inputs={
        'input': prediction,
        'mask': mask,  # TODO: provided by the pixel heater (not yet implemented)
        'target': raw,  # TODO: crop to the prediction ROI
    },
    outputs={
        0: prediction,
    },
)

# create request
# With 'valid' padding the network output is smaller than its input, so the
# output size is measured with a dummy forward pass on one patch. This also
# fails early if side_length is incompatible with the U-Net.
input_shape = gp.Coordinate((side_length,) * 3)  # in voxels
with torch.no_grad():
    dummy = torch.zeros((1, 1) + tuple(input_shape))
    output_shape = gp.Coordinate(tuple(model(dummy).shape[2:]))  # in voxels

# gunpowder ROIs are in world units: scale voxel counts by the voxel size.
# The context is computed in voxels first so the output ROI stays on-grid.
voxel = gp.Coordinate(voxel_size)
input_size = input_shape * voxel
output_size = output_shape * voxel
context = ((input_shape - output_shape) // 2) * voxel

request = gp.BatchRequest()
request[raw] = gp.Roi((0, 0, 0), input_size)
request[prediction] = gp.Roi(context, output_size)  # centered in the input
# TODO: also request `hot` and `mask` once the pixel heater provides them

# assemble pipeline
pipeline = (
    source
    + random_location
    + simple_augment
    + stack
    + train
)

# Train

# Examine Results

# Build Prediction Pipeline