In [1]:
import gunpowder as gp
from funlib.learn.torch.models import UNet, ConvPass

import torch

import logging
import math

In [2]:
logging.basicConfig(level=logging.INFO)

# Constants
num_iterations = 1_000
snapshot_every = 500
checkpoint_every = 500

voxel_size = gp.Coordinate((4, 4, 4))
input_shape = gp.Coordinate((76, 76, 76))
input_size = input_shape * voxel_size
output_shape = gp.Coordinate((36, 36, 36))
output_size = output_shape * voxel_size
label_proportions = [
    (0, 0.9419152666666667),
    (1, 0.0516876),
    (2, 0.000764),
    (3, 0.005633133333333333),
]
batch_size = 10

In [3]:
# initialize model
unet = UNet(
    in_channels=1,
    num_fmaps=16,
    fmap_inc_factor=5,
    downsample_factors=[(2, 2, 2), (2, 2, 2)],
    activation="ReLU",
    voxel_size=voxel_size,
    num_heads=1,
    constant_upsample=True,
)
logits = ConvPass(
    in_channels=16,
    out_channels=4,
    kernel_sizes=[(1, 1, 1), (1, 1, 1)],
    activation=None,
)

semantic_model = torch.nn.Sequential(unet, logits)

semantic_model_loss = torch.nn.CrossEntropyLoss(
    weight=torch.tensor([1 / p for l, p in label_proportions])
)

semantic_model_optimizer = torch.optim.Adam(semantic_model.parameters(), lr=1e-5)

In [4]:
# Gunpowder array keys
raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")
prediction_gradients = gp.ArrayKey("PREDICTION_GRADIENTS")
labels = gp.ArrayKey("LABELS")

sources = tuple(
    [
        # Create a ZarrSource and RandomLocation node to get data
        for volume in ["A", "B", "C"]
    ]
)
pipeline = (
    sources
    + gp.RandomProvider()
)

request = gp.BatchRequest()
# request what you need from upstream

In [5]:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# main loop
if __name__ == "__main__":
    with gp.build(pipeline):
        for i in range(num_iterations):
            batch = pipeline.request_batch(request)

INFO:gunpowder.torch.nodes.train:Starting training from scratch
INFO:gunpowder.torch.nodes.train:Using device cuda
INFO:gunpowder.nodes.precache:starting new set of workers (20, cache size 50)...
INFO:gunpowder.nodes.generic_train:Train process: iteration=1 loss=1.320524 time=12.113995
INFO:gunpowder.nodes.snapshot:saving to snapshots/1.hdf
INFO:gunpowder.nodes.generic_train:Train process: iteration=2 loss=1.298346 time=1.801630
INFO:gunpowder.nodes.generic_train:Train process: iteration=3 loss=1.292961 time=1.799284
INFO:gunpowder.producer_pool:terminating workers...
INFO:gunpowder.producer_pool:joining workers...
INFO:gunpowder.producer_pool:done


KeyboardInterrupt: 