# Imports & Other Setup

In [9]:
# !conda activate n2v
%load_ext autoreload

import numpy as np
from matplotlib import pyplot as plt
import sys
import torch
import random
import zarr
from PIL import Image
from skimage import data
from skimage import filters
from skimage import metrics

from funlib.learn.torch.models import UNet, ConvPass
import gunpowder as gp
import logging
logging.basicConfig(level=logging.INFO)

# from this repo
import loser
from boilerPlate import BoilerPlate
# from segway.tasks.make_zarr_from_tiff import task_make_zarr_from_tiff_volume as tif2zarr

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def imshow(raw, ground_truth=None, prediction=None):
  rows = 1
  if ground_truth is not None:
    rows += 1
  if prediction is not None:
    rows += 1
  cols = raw.shape[0] if len(raw.shape) > 3 else 1
  fig, axes = plt.subplots(rows, cols, figsize=(10, 4), sharex=True, sharey=True, squeeze=False)
  if len(raw.shape) == 3:
    axes[0][0].imshow(raw.transpose(1, 2, 0))
  else:
    for i, im in enumerate(raw):
      axes[0][i].imshow(im.transpose(1, 2, 0))
  row = 1
  if ground_truth is not None:
    if len(ground_truth.shape) == 3:
      axes[row][0].imshow(ground_truth[0])
    else:
      for i, gt in enumerate(ground_truth):
        axes[row][i].imshow(gt[0])
    row += 1
  if prediction is not None:
    if len(prediction.shape) == 3:
      axes[row][0].imshow(prediction[0])
    else:
      for i, gt in enumerate(prediction):
        axes[row][i].imshow(gt[0])
  plt.show()

# Set Parameters (including data source, training variables, destination, etc.)

## **Paths for training, predictions and results**

**`train_source:`:** This is the path to your folders containing the Training_source (noisy images). To find the path of the folder containing your datasets, go to your Files on the left of the notebook, navigate to the folder containing your files and copy the path by right-clicking on the folder, **Copy path** and pasting it into the right box below.

**`model_name`:** Use only my_model -style, not my-model (Use "_" not "-"). Do not use spaces in the name. Do not re-use the name of an existing model (saved in the same folder), otherwise it will be overwritten.

**`model_path`**: Enter the path where your model will be saved once trained (for instance your result folder).


## **Training parameters**

**`num_epochs`:** Input how many epochs (rounds) the network will be trained. Preliminary results can already be observed after a few (10-30) epochs, but a full training should run for 100-200 epochs. Evaluate the performance after training (see 5.). **Default value: 30**

**`side_length`:** Noise2Void divides the image into patches for training. Input the size of the patches (length of a side). The value should be smaller than the dimensions of the image and divisible by 8. **Default value: 100**

**If you get an Out of memory (OOM) error during the training, manually decrease the patch_size values until the OOM error disappear.**

**`batch_size:`** This parameter defines the number of patches seen in each training step. Noise2Void requires a large batch size for stable training. Reduce this parameter if your GPU runs out of memory. **Default value: 1**

**`num_steps`:** Define the number of training steps by epoch. By default this parameter is calculated so that each image / patch is seen at least once per epoch. **Default value: Number of patch / batch_size**

**`perc_validation`:**  Input the percentage of your training dataset you want to use to validate the network during the training. **Default value: 10** 

**`init_learn_rate`:** Input the initial value to be used as learning rate. **Default value: 0.0004**

**`perc_hotPixels:`** Percent of output pixels to designate as targets and *heat* for training **Default value: 0.198**


In [3]:
train_source = '/n/groups/htem/ESRF_id16a/tomo_ML/ReducedAnglesXray/CARE/mCTX/450p_stacks/mCTX_17keV_30nm_512c_first256.zarr' #EXPECTS TIFF VOLUME
#TODO: MAKE ABLE TO HANDLE TIFF VOLUME OR STACK

data_name = 'mCTX_17keV_30nm_512c_first256'
data_path = '/n/groups/htem/ESRF_id16a/tomo_ML/ReducedAnglesXray/CARE/mCTX/450p_stacks/'
data_format = 'zarr'

model_name = 'noise2gun_mCTX30nm_450p'
model_path = ''
voxel_size=[30, 30, 30] # set for each dataset (may be able to get from zarr)

side_length = 12 # in voxels for prediction (i.e. network output) - actual used ROI for network input will be bigger for valid padding
unet_depth = 4 # number of layers in unet
downsample_factor = 2
conv_padding = 'valid'
num_fmaps = 12
fmap_inc_factor=5
perc_hotPixels = 0.198
constant_upsample=True

num_epochs = 100
batch_size = 1
num_steps = 100
perc_validation = 10
init_learn_rate = 0.0004

### Make sure data source is a **zarr** 

# Build Gunpowder Pipeline for Training

### Elements are:

- Data Source
- *(optional) Normalization*
- Random Patch Grab
- Pixel Heating (select and mutate *hotPixels*, i.e. training targets, and keep masks)
- Simple Augmentation (rotations/reflections)
- Stacking
- Training


In [23]:
# declare arrays to use in the pipeline
raw = gp.ArrayKey('RAW') # raw data
hot = gp.ArrayKey('HOT') # data with random pixels heated
mask = gp.ArrayKey('MASK') # data with random pixels heated
prediction = gp.ArrayKey('PREDICTION') # prediction of denoised data

source = gp.ZarrSource(    # add the data source
    train_source,  # the zarr container
    {raw: 'volumes/train'},  # which dataset to associate to the array key
    {raw: gp.ArraySpec(interpolatable=True)}  # meta-information
)

# add normalization
# normalize = gp.Normalize(raw)

# add a RandomLocation node to the pipeline to randomly select a sample
random_location = gp.RandomLocation()

# add transpositions/reflections
simple_augment = gp.SimpleAugment()

# stack for batches
stack = gp.Stack(batch_size)

# add pixel heater
boilerPlate = BoilerPlate(raw, mask, hot, plate_size=side_length, perc_hotPixels=perc_hotPixels, ndims=3)

# prepare tensors for UNet
unsqueeze_1 = gp.Unsqueeze([hot])
unsqueeze_2 = gp.Unsqueeze([hot])

# define our network model for training
unet = UNet(
  in_channels=1,
  num_fmaps=num_fmaps,
  fmap_inc_factor=fmap_inc_factor,
  downsample_factors=[(downsample_factor,)*3,] * (unet_depth - 1),
  padding=conv_padding,
  constant_upsample=constant_upsample,
  voxel_size=voxel_size # set for each dataset
  )

model = torch.nn.Sequential(
  unet,
  ConvPass(num_fmaps, 1, [(1, 1, 1)], activation=None),
  torch.nn.Sigmoid())

# pick loss function
loss = loser.MaskedMSELoss()

# pick optimizer
optimizer = torch.optim.Adam(model.parameters())

# create a train node using our model, loss, and optimizer
train = gp.torch.Train(
  model,
  loss,
  optimizer,
  inputs = {
    'input': hot
  },
  loss_inputs = {
    'src': prediction,
    'mask': mask,
    'target': raw
  },
  outputs = {
    0: prediction
  },
  log_dir='./tensorboard/'
  )

# figure out proper ROI padding for context
conv_passes = 2 # set by default in unet
kernel_size = 3 # set by default in unet
context_side_length = 2 * np.sum([(conv_passes * (kernel_size - 1)) * (2 ** scale) for scale in np.arange(unet_depth - 1)]) + (conv_passes * (kernel_size - 1)) * (2 ** (unet_depth - 1)) + side_length

# create request
request = gp.BatchRequest()
request[raw] = gp.Roi(tuple(0*np.array(voxel_size)), tuple(context_side_length*np.array(voxel_size)))
request[mask] = gp.Roi(tuple(0*np.array(voxel_size)), tuple(context_side_length*np.array(voxel_size)))
request[hot] = gp.Roi(tuple(0*np.array(voxel_size)), tuple(context_side_length*np.array(voxel_size)))
request[prediction] = gp.Roi(tuple(0*np.array(voxel_size)), tuple(side_length*np.array(voxel_size)))

# assemble pipeline
pipeline = (source +
            #normalize + 
            random_location +
            simple_augment + 
            boilerPlate +
            unsqueeze_1 + 
            #unsqueeze_2 +
            stack + 
            train)



In [17]:
unet = UNet(
  in_channels=1,
  num_fmaps=10,
  fmap_inc_factor=5,
  downsample_factors=[(2, 2, 2), (2, 2, 2), (2, 2, 2)],
  padding='valid',
  constant_upsample=True,
  voxel_size=[30, 30, 30] # set for each dataset
  )

level = 0
downsample_factors=[(downsample_factor,)*3,] * (unet_depth - 1)
num_levels = len(downsample_factors) + 1
kernel_size_down = [[(3, 3, 3), (3, 3, 3)]]*num_levels

test = ConvPass(
                1
                if level == 0
                else num_fmaps*fmap_inc_factor**(level - 1),
                num_fmaps*fmap_inc_factor**level,
                kernel_size_down[level],
                activation='ReLU',
                padding='valid')

unet.forward(torch.tensor(np.random.rand(12, 1, 100,100,100)))
#test.conv_pass(torch.tensor(np.random.rand(1, 1, 100,100,100)))

RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #3 'mat1' in call to _th_addmm_

# Train

In [22]:
%autoreload
from boilerPlate import BoilerPlate
import loser

In [24]:
with gp.build(pipeline):
  batch = pipeline.request_batch(request)

imshow(batch[raw].data, batch[hot].data, batch[prediction].data)

INFO:gunpowder.torch.nodes.train:Starting training from scratch
INFO:gunpowder.torch.nodes.train:Using device cuda


PipelineRequestError: Exception in pipeline:
ZarrSource[/n/groups/htem/ESRF_id16a/tomo_ML/ReducedAnglesXray/CARE/mCTX/450p_stacks/mCTX_17keV_30nm_512c_first256.zarr] -> RandomLocation -> SimpleAugment -> BoilerPlate -> Unsqueeze -> Stack -> Train
while trying to process request

	RAW: ROI: [0:3000, 0:3000, 0:3000] (3000, 3000, 3000), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
	MASK: ROI: [0:3000, 0:3000, 0:3000] (3000, 3000, 3000), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
	HOT: ROI: [0:3000, 0:3000, 0:3000] (3000, 3000, 3000), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False
	PREDICTION: ROI: [0:360, 0:360, 0:360] (360, 360, 360), voxel size: None, interpolatable: None, non-spatial: False, dtype: None, placeholder: False


In [21]:
x = np.zeros((3,3,3))
for i, y in enumerate(x):
    for j, z in enumerate(y):
        if i==j:
            x[i,j,:] = i+j
y

array([[0., 0., 0.],
       [0., 0., 0.],
       [4., 4., 4.]])

# Examine Results

# Build Prediction Pipeline