In [1]:
import gunpowder as gp
from funlib.persistence import open_ds #may need to change to ome later

In [2]:
# declare arrays to use in the pipeline
raw = gp.ArrayKey('RAW')
labels = gp.ArrayKey('LABELS')
mask = gp.ArrayKey('MASK')
gt_affs = gp.ArrayKey('GT_AFFS')
gt_affs_mask = gp.ArrayKey('GT_AFFS_MASK')
prediction = gp.ArrayKey('PREDICT')
aff_scale= gp.ArrayKey("SCALE")

In [3]:
raw_array = open_ds("/mnt/efs/aimbl_2025/student_data/S-EK/EK_transfer/GT_movie1/crop_1.zarr/raw")
labels_array = open_ds("/mnt/efs/aimbl_2025/student_data/S-EK/EK_transfer/GT_movie1/crop_1.zarr/labels")
mask_array = open_ds("/mnt/efs/aimbl_2025/student_data/S-EK/EK_transfer/GT_movie1/crop_1.zarr/mask")

In [4]:
print(raw_array.voxel_size)
print(labels_array.voxel_size)
print(mask_array.voxel_size)

(1000, 170, 170)
(1000, 170, 170)
(1000, 170, 170)


In [5]:
# create "pipeline" consisting only of a data source
source_raw = gp.ArraySource( key= raw, array= raw_array, interpolatable= True)
source_labels = gp.ArraySource( key= labels, array= labels_array, interpolatable= False)
source_mask = gp.ArraySource( key= mask, array= mask_array, interpolatable= False)

In [6]:
random_location = gp.RandomLocation()

In [7]:
snapshot = gp.Snapshot( every = 100,
    dataset_names={raw:"raw", labels:"labels", mask:"mask", gt_affs: "affs", gt_affs_mask: "affS_mask",prediction:"prediction", aff_scale:"scale"},
    dataset_dtypes={gt_affs: 'float32'}
)

In [8]:
normalization = gp.Normalize(array= raw) #looks at maximal possible value of your data type and divides by that 

In [9]:
deform = gp.DeformAugment(
    gp.Coordinate((5100, 5100)),
    (340,340),
    spatial_dims=2,
    #graph_raster_voxel_size=raw_array.voxel_size[1:]
)

In [10]:
intensity_augment = gp.IntensityAugment(array=raw, scale_min=0.9, scale_max=1.1, shift_min=-0.1, shift_max=0.1, z_section_wise=False, clip=True)
noise = gp.NoiseAugment(array=raw, mode='Gaussian', clip=True)

In [11]:
neighborhood = [
    (1, 0 ,0),
    (0, 1, 0),
    (0, 0, 1),
    (2, 0, 0),
    (0, 5, 0),
    (0, 0, 5)
]
add_affs = gp.AddAffinities(
    affinity_neighborhood=neighborhood,
    labels=labels,
    affinities=gt_affs,
    unlabelled=mask,
    affinities_mask=gt_affs_mask
)

In [12]:
from boundary_issues.model import UNet
import torch
import numpy as np

model = torch.nn.Sequential(UNet(
    in_channels=1,
    num_fmaps=64,
    fmap_inc_factor=3,
    downsample_factors=[
        [1, 2, 2],  
        [1, 2, 2]
    ],
    kernel_size_down=[[(1, 3, 3), (1, 3, 3)],[(1, 3, 3), (1, 3, 3)],[(3, 3, 3), (3, 3, 3)]],
    kernel_size_up=[[(3, 3, 3), (3, 3, 3),(3, 3, 3)],[(3, 3, 3), (3, 3, 3), (3, 3, 3)]],
    padding=("same", "valid", "valid"),
    voxel_size=(1000, 170, 170),
    fov=(1, 1, 1),  
    num_fmaps_out=None,
    constant_upsample=True
), torch.nn.Conv3d(in_channels = 64, out_channels= 6, kernel_size=(1,1,1)),torch.nn.Sigmoid())

In [13]:
import sys
sys.path.append("src")   # add src/ to Python path
from boundary_issues.loss import WeightedLoss

loss_fn = WeightedLoss()

In [14]:
balanced_labels=gp.BalanceLabels(gt_affs,scales=aff_scale, mask= gt_affs_mask)

In [15]:


train = gp.torch.Train(
    model = model,
    loss = loss_fn,
    optimizer = torch.optim.Adam(model.parameters()),
    inputs = {0:raw},
    loss_inputs = {0: prediction, 1: gt_affs, 2: gt_affs_mask},
    outputs = {0: prediction},
    save_every = 200,
    log_dir = "training_logs"
)

In [None]:
pipeline = (
    source_raw,
    source_labels,
    source_mask
) + gp.MergeProvider() 

pipeline += random_location 
pipeline += normalization 
pipeline += gp.SimpleAugment(transpose_only=[1,2]) 
pipeline += deform 
pipeline += intensity_augment
pipeline += noise
pipeline += add_affs
pipeline += balanced_labels 
pipeline += gp.Unsqueeze([raw],0)
pipeline += gp.Stack(1) 
pipeline += train 
pipeline += gp.Squeeze([raw,labels,mask, gt_affs,gt_affs_mask,prediction,aff_scale],0)
pipeline += snapshot


In [17]:
print (pipeline)

(ArraySource, ArraySource, ArraySource) -> MergeProvider -> RandomLocation -> Normalize -> SimpleAugment -> DeformAugment -> IntensityAugment -> NoiseAugment -> AddAffinities -> BalanceLabels -> Unsqueeze -> Stack -> Train -> Squeeze -> Snapshot


In [18]:
# formulate a request for "raw"

input_size = gp.Coordinate((3, 256, 256)) * raw_array.voxel_size
output_size = gp.Coordinate((3, 210, 210)) * raw_array.voxel_size

request = gp.BatchRequest()

request.add(raw, input_size)
request.add(labels, output_size)
request.add(mask, output_size)
request.add(gt_affs, output_size)
request.add(gt_affs_mask, output_size)
request.add(prediction, output_size)
request.add(aff_scale,output_size)

# request[raw] = gp.Roi((5000, 119000, 119000), (10000, 85000, 85000)) #always in world units
# request[labels] = gp.Roi((5000, 119000, 119000), (10000, 85000, 85000)) #always in world units
# request[mask] = gp.Roi((5000, 119000, 119000), (10000, 85000, 85000)) #always in world units
# request[gt_affs] = gp.Roi((5000, 119000, 119000), (10000, 85000, 85000)) #always in world units
# request[gt_affs_mask] = gp.Roi((5000, 119000, 119000), (10000, 85000, 85000)) #always in world units

In [19]:
# build the pipeline...
with gp.build(pipeline):
  for i in range(602):
    # ...and request a batch
    print(".", end="")
    batch = pipeline.request_batch(request)
    if i%50==0:
      print(f"Iteration {i}: {batch.loss}")
    

.Iteration 0: 0.0
..................................................Iteration 50: 0.0
..................................................Iteration 100: 0.18731670081615448
..................................................Iteration 150: 0.2824770510196686
..................................................Iteration 200: 0.28001946210861206
..................................................Iteration 250: 0.18596245348453522
..................................................Iteration 300: 0.20567458868026733
..................................................Iteration 350: 0.19245150685310364
..................................................Iteration 400: 0.20457245409488678
..................................................Iteration 450: 0.6000000238418579
..................................................Iteration 500: 0.47669631242752075
..................................................Iteration 550: 0.0
..................................................Iteration 600: 0.239317208528518

In [20]:
print(model)

Sequential(
  (0): UNet(
    (l_conv): ModuleList(
      (0): ConvPass(
        (conv_pass): Sequential(
          (0): Conv3d(1, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1))
          (1): ReLU()
          (2): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1))
          (3): ReLU()
        )
      )
      (1): ConvPass(
        (conv_pass): Sequential(
          (0): Conv3d(64, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1))
          (1): ReLU()
          (2): Conv3d(192, 192, kernel_size=(1, 3, 3), stride=(1, 1, 1))
          (3): ReLU()
        )
      )
      (2): ConvPass(
        (conv_pass): Sequential(
          (0): Conv3d(192, 576, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0))
          (1): ReLU()
          (2): Conv3d(576, 576, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 0, 0))
          (3): ReLU()
        )
      )
    )
    (l_down): ModuleList(
      (0-1): 2 x Downsample(
        (down): MaxPool3d(kernel_size=[1, 2, 2], stride=[1, 2, 2], pad