In [1]:
# Import modules
from __future__ import print_function
import numpy as np
import tensorflow as tf

import matplotlib.pyplot as plt
tf.enable_eager_execution()

# Notebook auto reloads code.
%load_ext autoreload
%autoreload 2

# NeuroTorch Tutorial

**NeuroTorch** is a framework for reconstructing neuronal morphology from
optical microscopy images. It interfaces PyTorch with different
automated neuron tracing algorithms for fast, accurate, scalable
neuronal reconstructions. It uses deep learning to generate an initial
segmentation of neurons in optical microscopy images. This
segmentation is then traced using various automated neuron tracing
algorithms to convert the segmentation into an SWC file—the most
common neuronal morphology file format. NeuroTorch is designed with
scalability in mind and can handle teravoxel-sized images.

This IPython notebook will outline a brief tutorial for using NeuroTorch
to train and predict on image volume datasets.

## Creating image datasets

One of NeuroTorch’s key features is its dynamic approach to volumetric datasets, which allows it to handle teravoxel-sized images without worrying about memory concerns and efficiency. Everything is loaded just-in-time based on when it is needed or expected to be needed. To load an image dataset, we need
to specify the voxel coordinates of each image file as shown in files `inputs_spec.json` and `labels_spec.json`.

### `inputs_spec.json`

```json
[
    {
	"filename" : "inputs.tif",
	"bounding_box" : [[0, 0, 0], [1024, 512, 50]]
    },
    {
	"filename" : "inputs.tif",
	"bounding_box" : [[0, 0, 50], [1024, 512, 100]]
    }
]

```

### `labels_spec.json`

```json
[
    {
	"filename" : "labels.tif",
	"bounding_box" : [[0, 0, 0], [1024, 512, 50]]
    },
    {
	"filename" : "labels.tif",
	"bounding_box" : [[0, 0, 50], [1024, 512, 100]]
    }
]
```

## Loading image datasets

Now that the image datasets for the inputs and labels have been specified,
these datasets can be loaded with NeuroTorch.

In [2]:
from neurotorch.datasets.specification import JsonSpec
import os

IMAGE_PATH = '/data/gornet/annotated-neurons'

json_spec = JsonSpec() # Initialize the JSON specification

# Create a dataset containing the inputs
inputs = json_spec.open(os.path.join(IMAGE_PATH,
                                     "input_spec.json")) 

# Create a dataset containing the labels
labels = json_spec.open(os.path.join(IMAGE_PATH,
                                     "label_spec.json"))


FileNotFoundError: [Errno 2] No such file or directory: '/data/gornet/annotated-neurons/input_spec.json'

In [4]:
from neurotorch.datasets.specification import JsonSpec
import os

IMAGE_PATH = './tests/images/'

json_spec = JsonSpec() # Initialize the JSON specification

# Create a dataset containing the inputs
inputs = json_spec.open(os.path.join(IMAGE_PATH,
                                     "inputs_spec.json")) 

# Create a dataset containing the labels
labels = json_spec.open(os.path.join(IMAGE_PATH,
                                     "labels_spec.json"))


In [5]:
from neurotorch.datasets.dataset import AlignedVolume

volume = AlignedVolume([inputs, labels])

## Training with the image datasets

To train a neural network using these image datasets, load the 
neural network architecture and initialize a `Trainer`. To save
training checkpoints, add a `CheckpointWriter` to the `Trainer` object.
Lastly, call the `Trainer` object to run training.

In [6]:
from neurotorch.core.trainer import Trainer
from neurotorch.training.logging import LossWriter, ImageWriter
from neurotorch.training.checkpoint import CheckpointWriter
from neurotorch.nets.RSLSTMUnet import RSUNet

net = RSUNet([4, 16], [128], (16, 64, 64)) # Initialize the U-Net architecture

# Setup the trainer
trainer = Trainer(net, volume, max_epochs=10,
                  gpu_device=1)

# Setup the trainer the add a checkpoint every 500 epochs
trainer = CheckpointWriter(trainer, checkpoint_dir='BMEN4000',
                           checkpoint_period=2000)
trainer = LossWriter(trainer, ".", "BMEN4000")
trainer = ImageWriter(trainer, ".", "BMEN4000")

trainer.run_training()


Opening /home/jamesgornet/Downloads/NeuroTorch/tests/images/labels.tif
Opening /home/jamesgornet/Downloads/NeuroTorch/tests/images/labels.tif
Opening /home/jamesgornet/Downloads/NeuroTorch/tests/images/inputs.tif
Opening /home/jamesgornet/Downloads/NeuroTorch/tests/images/inputs.tif
Iteration: 1
Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Use tf.cast instead.
Iteration: 2
Iteration: 3
Iteration: 4
Iteration: 5
Iteration: 6
Iteration: 7
Iteration: 8
Iteration: 9
Iteration: 10
Iteration: 10 Epoch 1/10  Loss: 1.9436 Accuracy: 6.51
Iteration: 11
Iteration: 12
Iteration: 13
Iteration: 14
Iteration: 15
Iteration: 16
Iteration: 17
Iteration: 18
Iteration: 19
Iteration: 20
Iteration: 20 Epoch 1/10  Loss: 0.7016 Accuracy: 5.08
Iteration: 21
Iteration: 22
Iteration: 23
Iteration: 24
Iteration: 25
Iteration: 26
Iteration: 27
Iteration: 28
Iteration: 29
Iteration: 30
Iteration: 30 Epoch 1/10  Loss: 0.6669 Accuracy: 0.41
Iteration: 31
Iteration

Iteration: 366
Iteration: 367
Iteration: 368
Iteration: 369
Iteration: 370
Iteration: 370 Epoch 1/10  Loss: 0.3655 Accuracy: 0.00
Iteration: 371
Iteration: 372
Iteration: 373
Iteration: 374
Iteration: 375
Iteration: 376
Iteration: 377
Iteration: 378
Iteration: 379
Iteration: 380
Iteration: 380 Epoch 1/10  Loss: 0.3592 Accuracy: 0.00
Iteration: 381
Iteration: 382
Iteration: 383
Iteration: 384
Iteration: 385
Iteration: 386
Iteration: 387
Iteration: 388
Iteration: 389
Iteration: 390
Iteration: 390 Epoch 1/10  Loss: 0.3532 Accuracy: 0.00
Iteration: 391
Iteration: 392
Iteration: 393
Iteration: 394
Iteration: 395
Iteration: 396
Iteration: 397
Iteration: 398
Iteration: 399
Iteration: 400
Iteration: 400 Epoch 1/10  Loss: 0.3475 Accuracy: 0.00
Iteration: 401
Iteration: 402
Iteration: 403
Iteration: 404
Iteration: 405
Iteration: 406
Iteration: 407
Iteration: 408
Iteration: 409
Iteration: 410
Iteration: 410 Epoch 1/10  Loss: 0.3420 Accuracy: 0.00
Iteration: 411
Iteration: 412
Iteration: 413
Itera

Iteration: 766
Iteration: 767
Iteration: 768
Iteration: 769
Iteration: 770
Iteration: 770 Epoch 1/10  Loss: 0.2610 Accuracy: 0.00
Iteration: 771
Iteration: 772
Iteration: 773
Iteration: 774
Iteration: 775
Iteration: 776
Iteration: 777
Iteration: 778
Iteration: 779
Iteration: 780
Iteration: 780 Epoch 1/10  Loss: 0.2606 Accuracy: 0.00
Iteration: 781
Iteration: 782
Iteration: 783
Iteration: 784
Iteration: 785
Iteration: 786
Iteration: 787
Iteration: 788
Iteration: 789
Iteration: 790
Iteration: 790 Epoch 1/10  Loss: 0.2602 Accuracy: 0.00
Iteration: 791
Iteration: 792
Iteration: 793
Iteration: 794
Iteration: 795
Iteration: 796
Iteration: 797
Iteration: 798
Iteration: 799
Iteration: 800
Iteration: 800 Epoch 1/10  Loss: 0.2599 Accuracy: 0.00
Iteration: 801
Iteration: 802
Iteration: 803
Iteration: 804
Iteration: 805
Iteration: 806
Iteration: 807
Iteration: 808
Iteration: 809
Iteration: 810
Iteration: 810 Epoch 1/10  Loss: 0.2596 Accuracy: 0.00
Iteration: 811
Iteration: 812
Iteration: 813
Itera

Iteration: 1158
Iteration: 1159
Iteration: 1160
Iteration: 1160 Epoch 1/10  Loss: 0.2545 Accuracy: 0.00
Iteration: 1161
Iteration: 1162
Iteration: 1163
Iteration: 1164
Iteration: 1165
Iteration: 1166
Iteration: 1167
Iteration: 1168
Iteration: 1169
Iteration: 1170
Iteration: 1170 Epoch 1/10  Loss: 0.2544 Accuracy: 0.00
Iteration: 1171
Iteration: 1172
Iteration: 1173
Iteration: 1174
Iteration: 1175
Iteration: 1176
Iteration: 1177
Iteration: 1178
Iteration: 1179
Iteration: 1180
Iteration: 1180 Epoch 1/10  Loss: 0.2543 Accuracy: 0.00
Iteration: 1181
Iteration: 1182
Iteration: 1183
Iteration: 1184
Iteration: 1185
Iteration: 1186
Iteration: 1187
Iteration: 1188
Iteration: 1189
Iteration: 1190
Iteration: 1190 Epoch 1/10  Loss: 0.2542 Accuracy: 0.00
Iteration: 1191
Iteration: 1192
Iteration: 1193
Iteration: 1194
Iteration: 1195
Iteration: 1196
Iteration: 1197
Iteration: 1198
Iteration: 1199
Iteration: 1200
Iteration: 1200 Epoch 1/10  Loss: 0.2541 Accuracy: 0.00
Iteration: 1201
Iteration: 1202


Iteration: 1538
Iteration: 1539
Iteration: 1540
Iteration: 1540 Epoch 1/10  Loss: 0.2518 Accuracy: 0.00
Iteration: 1541
Iteration: 1542
Iteration: 1543
Iteration: 1544
Iteration: 1545
Iteration: 1546
Iteration: 1547
Iteration: 1548
Iteration: 1549
Iteration: 1550
Iteration: 1550 Epoch 1/10  Loss: 0.2518 Accuracy: 0.00
Iteration: 1551
Iteration: 1552
Iteration: 1553
Iteration: 1554
Iteration: 1555
Iteration: 1556
Iteration: 1557
Iteration: 1558
Iteration: 1559
Iteration: 1560
Iteration: 1560 Epoch 1/10  Loss: 0.2517 Accuracy: 0.00
Iteration: 1561
Iteration: 1562
Iteration: 1563
Iteration: 1564
Iteration: 1565
Iteration: 1566
Iteration: 1567
Iteration: 1568
Iteration: 1569
Iteration: 1570
Iteration: 1570 Epoch 1/10  Loss: 0.2517 Accuracy: 0.00
Iteration: 1571
Iteration: 1572
Iteration: 1573
Iteration: 1574
Iteration: 1575
Iteration: 1576
Iteration: 1577
Iteration: 1578
Iteration: 1579
Iteration: 1580
Iteration: 1580 Epoch 1/10  Loss: 0.2516 Accuracy: 0.00
Iteration: 1581
Iteration: 1582


Iteration: 1920 Epoch 1/10  Loss: 0.2502 Accuracy: 0.00
Iteration: 1921
Iteration: 1922
Iteration: 1923
Iteration: 1924
Iteration: 1925
Iteration: 1926
Iteration: 1927
Iteration: 1928
Iteration: 1929
Iteration: 1930
Iteration: 1930 Epoch 1/10  Loss: 0.2501 Accuracy: 0.00
Iteration: 1931
Iteration: 1932
Iteration: 1933
Iteration: 1934
Iteration: 1935
Iteration: 1936
Iteration: 1937
Iteration: 1938
Iteration: 1939
Iteration: 1940
Iteration: 1940 Epoch 1/10  Loss: 0.2500 Accuracy: 0.00
Iteration: 1941
Iteration: 1942
Iteration: 1943
Iteration: 1944
Iteration: 1945
Iteration: 1946
Iteration: 1947
Iteration: 1948
Iteration: 1949
Iteration: 1950
Iteration: 1950 Epoch 1/10  Loss: 0.2499 Accuracy: 0.00
Iteration: 1951
Iteration: 1952
Iteration: 1953
Iteration: 1954
Iteration: 1955
Iteration: 1956
Iteration: 1957
Iteration: 1958
Iteration: 1959
Iteration: 1960
Iteration: 1960 Epoch 1/10  Loss: 0.2499 Accuracy: 0.00
Iteration: 1961
Iteration: 1962
Iteration: 1963
Iteration: 1964
Iteration: 1965


Iteration: 2300 Epoch 2/10  Loss: 0.2488 Accuracy: 0.00
Iteration: 2301
Iteration: 2302
Iteration: 2303
Iteration: 2304
Iteration: 2305
Iteration: 2306
Iteration: 2307
Iteration: 2308
Iteration: 2309
Iteration: 2310
Iteration: 2310 Epoch 2/10  Loss: 0.2487 Accuracy: 0.00
Iteration: 2311
Iteration: 2312
Iteration: 2313
Iteration: 2314
Iteration: 2315
Iteration: 2316
Iteration: 2317
Iteration: 2318
Iteration: 2319
Iteration: 2320
Iteration: 2320 Epoch 2/10  Loss: 0.2487 Accuracy: 0.00
Iteration: 2321
Iteration: 2322
Iteration: 2323
Iteration: 2324
Iteration: 2325
Iteration: 2326
Iteration: 2327
Iteration: 2328
Iteration: 2329
Iteration: 2330
Iteration: 2330 Epoch 2/10  Loss: 0.2486 Accuracy: 0.00
Iteration: 2331
Iteration: 2332
Iteration: 2333
Iteration: 2334
Iteration: 2335
Iteration: 2336
Iteration: 2337
Iteration: 2338
Iteration: 2339
Iteration: 2340
Iteration: 2340 Epoch 2/10  Loss: 0.2486 Accuracy: 0.00
Iteration: 2341
Iteration: 2342
Iteration: 2343
Iteration: 2344
Iteration: 2345


Iteration: 2680 Epoch 2/10  Loss: 0.2477 Accuracy: 0.00
Iteration: 2681
Iteration: 2682
Iteration: 2683
Iteration: 2684
Iteration: 2685
Iteration: 2686
Iteration: 2687
Iteration: 2688
Iteration: 2689
Iteration: 2690
Iteration: 2690 Epoch 2/10  Loss: 0.2477 Accuracy: 0.00
Iteration: 2691
Iteration: 2692
Iteration: 2693
Iteration: 2694
Iteration: 2695
Iteration: 2696
Iteration: 2697
Iteration: 2698
Iteration: 2699
Iteration: 2700
Iteration: 2700 Epoch 2/10  Loss: 0.2476 Accuracy: 0.00
Iteration: 2701
Iteration: 2702
Iteration: 2703
Iteration: 2704
Iteration: 2705
Iteration: 2706
Iteration: 2707
Iteration: 2708
Iteration: 2709
Iteration: 2710
Iteration: 2710 Epoch 2/10  Loss: 0.2476 Accuracy: 0.00
Iteration: 2711
Iteration: 2712
Iteration: 2713
Iteration: 2714
Iteration: 2715
Iteration: 2716
Iteration: 2717
Iteration: 2718
Iteration: 2719
Iteration: 2720
Iteration: 2720 Epoch 2/10  Loss: 0.2476 Accuracy: 0.00
Iteration: 2721
Iteration: 2722
Iteration: 2723
Iteration: 2724
Iteration: 2725


Iteration: 3060 Epoch 2/10  Loss: 0.2473 Accuracy: 0.00
Iteration: 3061
Iteration: 3062
Iteration: 3063
Iteration: 3064
Iteration: 3065
Iteration: 3066
Iteration: 3067
Iteration: 3068
Iteration: 3069
Iteration: 3070
Iteration: 3070 Epoch 2/10  Loss: 0.2472 Accuracy: 0.00
Iteration: 3071
Iteration: 3072
Iteration: 3073
Iteration: 3074
Iteration: 3075
Iteration: 3076
Iteration: 3077
Iteration: 3078
Iteration: 3079
Iteration: 3080
Iteration: 3080 Epoch 2/10  Loss: 0.2470 Accuracy: 0.00
Iteration: 3081
Iteration: 3082
Iteration: 3083
Iteration: 3084
Iteration: 3085
Iteration: 3086
Iteration: 3087
Iteration: 3088
Iteration: 3089
Iteration: 3090
Iteration: 3090 Epoch 2/10  Loss: 0.2470 Accuracy: 0.00
Iteration: 3091
Iteration: 3092
Iteration: 3093
Iteration: 3094
Iteration: 3095
Iteration: 3096
Iteration: 3097
Iteration: 3098
Iteration: 3099
Iteration: 3100
Iteration: 3100 Epoch 2/10  Loss: 0.2470 Accuracy: 0.00
Iteration: 3101
Iteration: 3102
Iteration: 3103
Iteration: 3104
Iteration: 3105


KeyboardInterrupt: 

## Predicting using NeuroTorch

Once training has completed, we can use the training checkpoints
to predict on image datasets. We first have to 
load the neural network architecture and image volume.
We then have to initialize a `Predictor` object and an output volume.
Once these have been specified, we can begin prediction.

In [23]:
from neurotorch.nets.RSLSTMUnet import RSUNet
from neurotorch.core.predictor import Predictor
from neurotorch.datasets.filetypes import TiffVolume
from neurotorch.datasets.dataset import Array
from neurotorch.datasets.datatypes import (BoundingBox, Vector)
import numpy as np
import tifffile as tif
import os

IMAGE_PATH = 'tests/images/'

net = RSUNet([4, 16], [128], (8, 32, 32)) # Initialize the U-Net architecture

checkpoint = './checkpoints/best.ckpt-39' # Specify the checkpoint path

with TiffVolume(os.path.join(IMAGE_PATH,
                             "inputs.tif"),
                BoundingBox(Vector(0, 0, 0),
                            Vector(1024, 512, 50)),
                iteration_size=BoundingBox(Vector(0, 0, 0),
                                          Vector(32, 32, 8)),
                stride=Vector(4, 16, 16)) as inputs:
    predictor = Predictor(net, checkpoint, gpu_device=0)

    output_volume = Array(np.zeros(inputs.getBoundingBox()
                                   .getNumpyDim(), dtype=np.float32))

    predictor.run(inputs, output_volume, batch_size=16)

    tif.imsave("test_prediction.tif",
               output_volume.getArray().astype(np.float32))


Opening tests/images/inputs.tif


ValueError: could not broadcast input array from shape (8,32,32) into shape (2,32,32)

In [None]:
with TiffVolume(os.path.join(IMAGE_PATH,
                             "inputs.tif"),
                BoundingBox(Vector(0, 0, 0),
                            Vector(1024, 512, 50))) as inputs:
    print(inputs.getArray().getArray().dtype)
    print(type(inputs[1]))
    plt.imshow(inputs.get(
    BoundingBox(Vector(0, 0, 0),
                Vector(256, 256, 40))).getArray()[25],
           cmap='gray'
          )

## Displaying the prediction

Predictions are output in logits form. To map this to a
probability distribution, we need to apply a sigmoid function
to the prediction. We can then evaluate the prediction and 
ground-truth.

In [None]:
# Apply sigmoid function
probability_map = 1/(1+np.exp(-output_volume.getArray()))

# Plot prediction and ground-truth
plt.subplot(2, 1, 1)
plt.title('Prediction')
plt.imshow(np.amax(output_volume.getArray(), axis=0))
plt.axis('off')

plt.subplot(2, 1, 2)
plt.title('Ground-Truth')
plt.imshow(labels.get(
    BoundingBox(Vector(0, 0, 0),
                Vector(1024, 512, 50))).getArray()[25],
           cmap='gray'
          )
plt.axis('off')

plt.show()