In [1]:
import gunpowder as gp
from funlib.learn.torch.models import UNet

import torch

import logging
import math

In [2]:
logging.basicConfig(level=logging.INFO)

# Constants
num_iterations = 1_000
snapshot_every = 500
checkpoint_every = 500

voxel_size = gp.Coordinate((4, 4, 4))
input_shape = gp.Coordinate((76, 76, 76))
input_size = input_shape * voxel_size
output_shape = gp.Coordinate((36, 36, 36))
output_size = output_shape * voxel_size
label_proportions = [
    (0, 0.9419152666666667),
    (1, 0.0516876),
    (2, 0.000764),
    (3, 0.005633133333333333),
]
batch_size = 10

In [3]:
# initialize model
unet = UNet(
    in_channels=1,
    num_fmaps=16,
    fmap_inc_factor=5,
    downsample_factors=[(2, 2, 2), (2, 2, 2)],
    activation="ReLU",
    voxel_size=voxel_size,
    num_heads=1,
    constant_upsample=True,
)
logits = torch.nn.Conv3d(16, 4, (1,) * 3)

# build your model from unet and logits using torch.nn.Sequential

semantic_model_loss = torch.nn.CrossEntropyLoss(
    weight=torch.tensor([1 / p for l, p in label_proportions])
)

semantic_model_optimizer = torch.optim.Adam(semantic_model.parameters(), lr=1e-5)

NameError: name 'semantic_model' is not defined

In [4]:
# Gunpowder array keys
raw = gp.ArrayKey("RAW")
predictions = gp.ArrayKey("PREDICTIONS")
prediction_gradients = gp.ArrayKey("PREDICTION_GRADIENTS")
labels = gp.ArrayKey("LABELS")

sources = tuple(
    [
        gp.ZarrSource(
            # Add the filename
            datasets={
                raw: f"{volume}/raw",
                labels: f"{volume}/intracellular_semantic",
            },
            array_specs={
                raw: gp.ArraySpec(interpolatable=True),
                labels: gp.ArraySpec(interpolatable=False),
            }
        )
        # You need to select a random location here!
        # Since requests come up from the bottom of the pipeline,
        # at this point it is clear to gunpowder which volume is
        # being requested, and thus gunpowder knows how to pick
        # a valid location.
        for volume in ["A", "B", "C"]
    ]
)
pipeline = (
    sources
    + gp.RandomProvider()
    + gp.Normalize(raw)
    + gp.ElasticAugment(
        control_point_spacing=(10, 10, 10),
        jitter_sigma=(0.1, 0.1, 0.1),
        rotation_interval=(0, math.pi / 2)
    )
    # add your simple augment here
    # add your intensity augment here
    + gp.PreCache()
    # shapes:
    # raw: [d, h, w]
    # labels: [d, h, w]
    + gp.Stack(batch_size)
    # raw: [s, d, h, w] where s is the batch size
    # labels: [s, d, h, w]
    
    # Add an Unsqueeze node here!
    
    # raw: [s, 1, d, h, w]
    # labels: [s, 1, d, h, w]
    + gp.torch.Train(
        model=semantic_model,
        loss=semantic_model_loss,
        optimizer=semantic_model_optimizer,
        inputs={"input": raw},
        outputs={0: predictions},
        # Define the loss inputs here!
        gradients={0: prediction_gradients},
        save_every=checkpoint_every,
    )
    + gp.Squeeze([raw], axis=1)
    + gp.Snapshot(
        output_filename="{iteration}.hdf",
        dataset_names={
            raw: "volumes/raw",
            labels: "volumes/labels",
            predictions: "volumes/predictions",
            prediction_gradients: "volumes/prediction_gradients",
        },
        every=snapshot_every,
    )
)

request = gp.BatchRequest()
request.add(raw, input_size)
# add the labels to the request here!
request.add(predictions, output_size)
request.add(prediction_gradients, output_size)

In [5]:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# main loop
if __name__ == "__main__":
    with gp.build(pipeline):
        for i in range(num_iterations):
            batch = pipeline.request_batch(request)

INFO:gunpowder.torch.nodes.train:Starting training from scratch
INFO:gunpowder.torch.nodes.train:Using device cuda
INFO:gunpowder.nodes.precache:starting new set of workers (20, cache size 50)...
INFO:gunpowder.nodes.generic_train:Train process: iteration=1 loss=1.320524 time=12.113995
INFO:gunpowder.nodes.snapshot:saving to snapshots/1.hdf
INFO:gunpowder.nodes.generic_train:Train process: iteration=2 loss=1.298346 time=1.801630
INFO:gunpowder.nodes.generic_train:Train process: iteration=3 loss=1.292961 time=1.799284
INFO:gunpowder.producer_pool:terminating workers...
INFO:gunpowder.producer_pool:joining workers...
INFO:gunpowder.producer_pool:done


KeyboardInterrupt: 