# Dacapo

DaCapo is a framework that allows for easy configuration and execution of established machine learning techniques on arbitrarily large volumes of multi-dimensional images.

DaCapo has 4 major configurable components:
1. **dacapo.datasplits.DataSplit**

2. **dacapo.tasks.Task**

3. **dacapo.architectures.Architecture**

4. **dacapo.trainers.Trainer**

These are then combined in a single **dacapo.experiments.Run** that includes your starting point (whether you want to start training from scratch or continue off of a previously trained model) and stopping criterion (the number of iterations you want to train).

## Environment setup
If you have not already done so, you will need to install DaCapo. You can do this by first creating a new environment and then installing DaCapo using pip.

```bash
conda create -n dacapo python=3.10
conda activate dacapo
```

Then, you can install DaCapo using pip, via GitHub:

```bash
pip install git+https://github.com/janelia-cellmap/dacapo.git
```

Or you can clone the repository and install it locally:

```bash
git clone https://github.com/janelia-cellmap/dacapo.git
cd dacapo
pip install -e .
```

Be sure to select this environment in your Jupyter notebook or JupyterLab.

## Config Store
To define where the data goes, create a dacapo.yaml configuration file either in `~/.config/dacapo/dacapo.yaml` or in `./dacapo.yaml`. Here is a template:

```yaml
mongodbhost: mongodb://dbuser:dbpass@dburl:dbport/
mongodbname: dacapo
runs_base_dir: /path/to/my/data/storage
```
The runs_base_dir defines where your on-disk data will be stored. The mongodbhost and mongodbname define the mongodb host and database that will store your cloud data. If you want to store everything on disk, replace mongodbhost and mongodbname with a single type `files` and everything will be saved to disk:

```yaml
type: files
runs_base_dir: /path/to/my/data/storage
```


In [None]:
from dacapo.store.create_store import create_config_store

config_store = create_config_store()

In [None]:
from examples.random_source_pipeline import random_source_pipeline
import gunpowder as gp
import neuroglancer
import numpy as np
from IPython.display import IFrame
from scipy.ndimage import gaussian_filter

pipeline, request = random_source_pipeline(input_shape=(120, 120, 120))


def batch_generator():
    with gp.build(pipeline):
        while True:
            yield pipeline.request_batch(request)


batch_gen = batch_generator()
batch = next(batch_gen)
raw_array = batch.arrays[gp.ArrayKey("RAW")]
labels_array = batch.arrays[gp.ArrayKey("LABELS")]


labels_data = labels_array.data
raw_data = raw_array.data

neuroglancer.set_server_bind_address("0.0.0.0")
viewer = neuroglancer.Viewer()
with viewer.txn() as state:
    state.showSlices = False
    state.layers["segs"] = neuroglancer.SegmentationLayer(
        # segments=[str(i) for i in np.unique(data[data > 0])], # this line will cause all objects to be selected and thus all meshes to be generated...will be slow if lots of high res meshes
        source=neuroglancer.LocalVolume(
            data=labels_data,
            dimensions=neuroglancer.CoordinateSpace(
                names=["z", "y", "x"],
                units=["nm", "nm", "nm"],
                scales=labels_array.spec.voxel_size,
            ),
            # voxel_offset=ds.roi.begin / ds.voxel_size,
        ),
        segments=np.unique(labels_data[labels_data > 0]),
    )

    state.layers["raw"] = neuroglancer.ImageLayer(
        source=neuroglancer.LocalVolume(
            data=raw_data,
            dimensions=neuroglancer.CoordinateSpace(
                names=["z", "y", "x"],
                units=["nm", "nm", "nm"],
                scales=raw_array.spec.voxel_size,
            ),
        ),
    )

IFrame(src=viewer, width=1500, height=600)

## Datasplit
Where can you find your data? What format is it in? Does it need to be normalized? What data do you want to use for validation?

In [None]:
from dacapo.experiments.datasplits.datasets.arrays import (
    BinarizeArrayConfig,
    CropArrayConfig,
    ConcatArrayConfig,
    IntensitiesArrayConfig,
    MissingAnnotationsMaskConfig,
    ResampledArrayConfig,
    ZarrArrayConfig,
)
from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
from dacapo.experiments.datasplits.datasets import RawGTDatasetConfig
from pathlib import PosixPath
from funlib.geometry import Roi

datasplit_config = TrainValidateDataSplitConfig(
    name="example_synthetic_datasplit_config",
    train_configs=[
        RawGTDatasetConfig(
            name="example_raw_data",
            weight=1,
            raw_config=IntensitiesArrayConfig(
                name="jrc_mus-liver_s1_raw",
                source_array_config=ZarrArrayConfig(
                    name="jrc_mus-liver_raw_uint8",
                    file_name=PosixPath(
                        "/nrs/cellmap/data/jrc_mus-liver/jrc_mus-liver.n5"
                    ),
                    dataset="volumes/raw/s1",
                    snap_to_grid=(16, 16, 16),
                    axes=None,
                ),
                min=0.0,
                max=255.0,
            ),
            gt_config=BinarizeArrayConfig(
                name="jrc_mus-liver_124_mito_proxisome_many_8nm_gt",
                source_array_config=ResampledArrayConfig(
                    name="jrc_mus-liver_124_gt_resampled_8nm",
                    source_array_config=ZarrArrayConfig(
                        name="jrc_mus-liver_124_gt",
                        file_name=PosixPath(
                            "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver/jrc_mus-liver.n5"
                        ),
                        dataset="volumes/groundtruth/crop124/labels//all",
                        snap_to_grid=(16, 16, 16),
                        axes=None,
                    ),
                    upsample=(0, 0, 0),
                    downsample=(2, 2, 2),
                    interp_order=False,
                ),
                groupings=[("mito", [3, 4, 5]), ("peroxisome", [47, 48])],
                background=0,
            ),
            mask_config=MissingAnnotationsMaskConfig(
                name="jrc_mus-liver_124_mito_proxisome_many_8nm_mask",
                source_array_config=ResampledArrayConfig(
                    name="jrc_mus-liver_124_gt_resampled_8nm",
                    source_array_config=ZarrArrayConfig(
                        name="jrc_mus-liver_124_gt",
                        file_name=PosixPath(
                            "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver/jrc_mus-liver.n5"
                        ),
                        dataset="volumes/groundtruth/crop124/labels//all",
                        snap_to_grid=(16, 16, 16),
                        axes=None,
                    ),
                    upsample=(0, 0, 0),
                    downsample=(2, 2, 2),
                    interp_order=False,
                ),
                groupings=[("mito", [3, 4, 5]), ("peroxisome", [47, 48])],
            ),
            sample_points=None,
        ),
        RawGTDatasetConfig(
            name="jrc_mus-liver_125_mito_proxisome_many_8nm",
            weight=1,
            raw_config=IntensitiesArrayConfig(
                name="jrc_mus-liver_s1_raw",
                source_array_config=ZarrArrayConfig(
                    name="jrc_mus-liver_raw_uint8",
                    file_name=PosixPath(
                        "/nrs/cellmap/data/jrc_mus-liver/jrc_mus-liver.n5"
                    ),
                    dataset="volumes/raw/s1",
                    snap_to_grid=(16, 16, 16),
                    axes=None,
                ),
                min=0.0,
                max=255.0,
            ),
            gt_config=BinarizeArrayConfig(
                name="jrc_mus-liver_125_mito_proxisome_many_8nm_gt",
                source_array_config=ResampledArrayConfig(
                    name="jrc_mus-liver_125_gt_resampled_8nm",
                    source_array_config=ZarrArrayConfig(
                        name="jrc_mus-liver_125_gt",
                        file_name=PosixPath(
                            "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver/jrc_mus-liver.n5"
                        ),
                        dataset="volumes/groundtruth/crop125/labels//all",
                        snap_to_grid=(16, 16, 16),
                        axes=None,
                    ),
                    upsample=(0, 0, 0),
                    downsample=(2, 2, 2),
                    interp_order=False,
                ),
                groupings=[("mito", [3, 4, 5]), ("peroxisome", [47, 48])],
                background=0,
            ),
            mask_config=MissingAnnotationsMaskConfig(
                name="jrc_mus-liver_125_mito_proxisome_many_8nm_mask",
                source_array_config=ResampledArrayConfig(
                    name="jrc_mus-liver_125_gt_resampled_8nm",
                    source_array_config=ZarrArrayConfig(
                        name="jrc_mus-liver_125_gt",
                        file_name=PosixPath(
                            "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver/jrc_mus-liver.n5"
                        ),
                        dataset="volumes/groundtruth/crop125/labels//all",
                        snap_to_grid=(16, 16, 16),
                        axes=None,
                    ),
                    upsample=(0, 0, 0),
                    downsample=(2, 2, 2),
                    interp_order=False,
                ),
                groupings=[("mito", [3, 4, 5]), ("peroxisome", [47, 48])],
            ),
            sample_points=None,
        ),
    ],
    validate_configs=[
        RawGTDatasetConfig(
            name="jrc_mus-liver_145_mito_proxisome_many_8nm",
            weight=1,
            raw_config=IntensitiesArrayConfig(
                name="jrc_mus-liver_s1_raw",
                source_array_config=ZarrArrayConfig(
                    name="jrc_mus-liver_raw_uint8",
                    file_name=PosixPath(
                        "/nrs/cellmap/data/jrc_mus-liver/jrc_mus-liver.n5"
                    ),
                    dataset="volumes/raw/s1",
                    snap_to_grid=(16, 16, 16),
                    axes=None,
                ),
                min=0.0,
                max=255.0,
            ),
            gt_config=BinarizeArrayConfig(
                name="jrc_mus-liver_145_mito_proxisome_many_8nm_gt",
                source_array_config=ResampledArrayConfig(
                    name="jrc_mus-liver_145_gt_resampled_8nm",
                    source_array_config=ZarrArrayConfig(
                        name="jrc_mus-liver_145_gt",
                        file_name=PosixPath(
                            "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver/jrc_mus-liver.n5"
                        ),
                        dataset="volumes/groundtruth/crop145/labels//all",
                        snap_to_grid=(16, 16, 16),
                        axes=None,
                    ),
                    upsample=(0, 0, 0),
                    downsample=(2, 2, 2),
                    interp_order=False,
                ),
                groupings=[("mito", [3, 4, 5]), ("peroxisome", [47, 48])],
                background=0,
            ),
            mask_config=MissingAnnotationsMaskConfig(
                name="jrc_mus-liver_145_mito_proxisome_many_8nm_mask",
                source_array_config=ResampledArrayConfig(
                    name="jrc_mus-liver_145_gt_resampled_8nm",
                    source_array_config=ZarrArrayConfig(
                        name="jrc_mus-liver_145_gt",
                        file_name=PosixPath(
                            "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver/jrc_mus-liver.n5"
                        ),
                        dataset="volumes/groundtruth/crop145/labels//all",
                        snap_to_grid=(16, 16, 16),
                        axes=None,
                    ),
                    upsample=(0, 0, 0),
                    downsample=(2, 2, 2),
                    interp_order=False,
                ),
                groupings=[("mito", [3, 4, 5]), ("peroxisome", [47, 48])],
            ),
            sample_points=None,
        ),
        RawGTDatasetConfig(
            name="jrc_mus-liver-zon-1_386_mito_proxisome_many_8nm",
            weight=1,
            raw_config=IntensitiesArrayConfig(
                name="jrc_mus-liver-zon-1_s1_raw",
                source_array_config=ZarrArrayConfig(
                    name="jrc_mus-liver-zon-1_raw_uint8",
                    file_name=PosixPath(
                        "/nrs/cellmap/data/jrc_mus-liver-zon-1/jrc_mus-liver-zon-1.n5"
                    ),
                    dataset="em/fibsem-uint8/s1",
                    snap_to_grid=(16, 16, 16),
                    axes=None,
                ),
                min=0.0,
                max=255.0,
            ),
            gt_config=CropArrayConfig(
                name="jrc_mus-liver-zon-1_386_8nm_gt_cropped",
                source_array_config=ConcatArrayConfig(
                    name="jrc_mus-liver-zon-1_386_8nm_gt",
                    channels=["mito", "peroxisome"],
                    source_array_configs={
                        "peroxisome": BinarizeArrayConfig(
                            name="jrc_mus-liver-zon-1_386_peroxisome_8nm_binarized",
                            source_array_config=ResampledArrayConfig(
                                name="jrc_mus-liver-zon-1_386_peroxisome_resampled_8nm",
                                source_array_config=ZarrArrayConfig(
                                    name="jrc_mus-liver-zon-1_386_peroxisome",
                                    file_name=PosixPath(
                                        "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver-zon-1/jrc_mus-liver-zon-1.n5"
                                    ),
                                    dataset="volumes/groundtruth/crop386/labels//peroxisome",
                                    snap_to_grid=(16, 16, 16),
                                    axes=None,
                                ),
                                upsample=(4, 4, 4),
                                downsample=(0, 0, 0),
                                interp_order=False,
                            ),
                            groupings=[("peroxisome", [])],
                            background=0,
                        )
                    },
                    default_config=None,
                ),
                roi=Roi([145600, 59200, 147200], [3200, 3200, 6400]),
            ),
            mask_config=CropArrayConfig(
                name="jrc_mus-liver-zon-1_386_8nm_mask_cropped",
                source_array_config=CropArrayConfig(
                    name="jrc_mus-liver-zon-1_386_8nm_gt_cropped",
                    source_array_config=ConcatArrayConfig(
                        name="jrc_mus-liver-zon-1_386_8nm_gt",
                        channels=["mito", "peroxisome"],
                        source_array_configs={
                            "peroxisome": BinarizeArrayConfig(
                                name="jrc_mus-liver-zon-1_386_peroxisome_8nm_binarized",
                                source_array_config=ResampledArrayConfig(
                                    name="jrc_mus-liver-zon-1_386_peroxisome_resampled_8nm",
                                    source_array_config=ZarrArrayConfig(
                                        name="jrc_mus-liver-zon-1_386_peroxisome",
                                        file_name=PosixPath(
                                            "/nrs/cellmap/zouinkhim/data/tmp_data_v3/jrc_mus-liver-zon-1/jrc_mus-liver-zon-1.n5"
                                        ),
                                        dataset="volumes/groundtruth/crop386/labels//peroxisome",
                                        snap_to_grid=(16, 16, 16),
                                        axes=None,
                                    ),
                                    upsample=(4, 4, 4),
                                    downsample=(0, 0, 0),
                                    interp_order=False,
                                ),
                                groupings=[("peroxisome", [])],
                                background=0,
                            )
                        },
                        default_config=None,
                    ),
                    roi=Roi([145600, 59200, 147200], [3200, 3200, 6400]),
                ),
                roi=Roi([145600, 59200, 147200], [3200, 3200, 6400]),
            ),
            sample_points=None,
        ),
    ],
)

config_store.store_datasplit_config(datasplit_config)

## Task
What do you want to learn? An instance segmentation? If so, how? Affinities,
Distance Transform, Foreground/Background, etc. Each of these tasks are commonly learned
and evaluated with specific loss functions and evaluation metrics. Some tasks may
also require specific non-linearities or output formats from your model.

In [None]:
from dacapo.experiments.tasks import DistanceTaskConfig

task_config = DistanceTaskConfig(
    name="example_distances_8nm_mito_proxisome_many",
    channels=["mito", "peroxisome"],
    clip_distance=80.0,
    tol_distance=80.0,
    scale_factor=160.0,
    mask_distances=True,
    clipmin=0.05,
    clipmax=0.95,
)
config_store.store_task_config(task_config)

## Architecture

The setup of the network you will train. Biomedical image to image translation often utilizes a UNet, but even after choosing a UNet you still need to provide some additional parameters. How much do you want to downsample? How many convolutional layers do you want?

In [None]:
from dacapo.experiments.architectures import CNNectomeUNetConfig

architecture_config = CNNectomeUNetConfig(
    name="example_attention-upsample-unet",
    input_shape=(216, 216, 216),
    fmaps_out=72,
    fmaps_in=1,
    num_fmaps=12,
    fmap_inc_factor=6,
    downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)],
    kernel_size_down=None,
    kernel_size_up=None,
    eval_shape_increase=(72, 72, 72),
    upsample_factors=[(2, 2, 2)],
    constant_upsample=True,
    padding="valid",
    use_attention=True,
)
config_store.store_architecture_config(architecture_config)

## Trainer

How do you want to train? This config defines the training loop and how the other three components work together. What sort of augmentations to apply during training, what learning rate and optimizer to use, what batch size to train with.

In [None]:
from dacapo.experiments.trainers import GunpowderTrainerConfig
from dacapo.experiments.trainers.gp_augments import (
    ElasticAugmentConfig,
    GammaAugmentConfig,
    IntensityAugmentConfig,
    IntensityScaleShiftAugmentConfig,
)

trainer_config = GunpowderTrainerConfig(
    name="default",
    batch_size=2,
    learning_rate=0.0001,
    num_data_fetchers=20,
    augments=[
        ElasticAugmentConfig(
            control_point_spacing=[100, 100, 100],
            control_point_displacement_sigma=[10.0, 10.0, 10.0],
            rotation_interval=(0.0, 1.5707963267948966),
            subsample=8,
            uniform_3d_rotation=True,
        ),
        IntensityAugmentConfig(scale=(0.25, 1.75), shift=(-0.5, 0.35), clip=True),
        GammaAugmentConfig(gamma_range=(0.5, 2.0)),
        IntensityScaleShiftAugmentConfig(scale=2.0, shift=-1.0),
    ],
    snapshot_interval=10000,
    min_masked=0.05,
    clip_raw=True,
)
config_store.store_trainer_config(trainer_config)

## Run
Now that we have our components configured, we just need to combine them into a run and start training. We can have multiple repetitions of a single set of configs in order to increase our chances of finding an optimum.

In [None]:
from dacapo.experiments.starts import StartConfig
from dacapo.experiments import RunConfig
from dacapo.experiments.run import Run

start_config = None

# Uncomment to start from a pretrained model
# start_config = StartConfig(
#     "setup04",
#     "best",
# )

iterations = 200
validation_interval = 5
repetitions = 3
for i in range(repetitions):
    run_config = RunConfig(
        name=("_").join(
            [
                "example",
                "scratch" if start_config is None else "finetuned",
                datasplit_config.name,
                task_config.name,
                architecture_config.name,
                trainer_config.name,
            ]
        )
        + f"__{i}",
        datasplit_config=datasplit_config,
        task_config=task_config,
        architecture_config=architecture_config,
        trainer_config=trainer_config,
        num_iterations=iterations,
        validation_interval=validation_interval,
        repetition=i,
        start_config=start_config,
    )

    print(run_config.name)
    config_store.store_run_config(run_config)

## Train

To train one of the runs, you can either do it by first creating a **Run** directly from the run config

In [None]:
from dacapo.train import train_run

run = Run(config_store.retrieve_run_config(run_config.name))
train_run(run)

If you want to start your run on some compute cluster, you might want to use the command line interface: dacapo train -r {run_config.name}. This makes it particularly convenient to run on compute nodes where you can specify specific compute requirements.

In [None]:
from scipy.ndimage import (
    distance_transform_edt,
    binary_dilation,
    generate_binary_structure,
)
import numpy as np
from skimage.measure import label as relabel

labels = np.zeros((512, 512, 512), dtype=np.uint8)
generate_binary_structure(3, connectivity=2)

random_point_centers = np.random.randint(1, 255, (250, 3))

labels[random_point_centers] = 1
generate_binary_structure(3, connectivity=2)
labels = binary_dilation(labels)

relabeled = relabel(labels, connectivity=2).astype(labels.dtype)  # type: ignore
relabeled = relabel(relabeled, connectivity=2).astype(labels.dtype)  # type: ignore

In [None]:
generate_binary_structure(3, connectivity=2)
binary_dilation(arr)

In [None]:
import neuroglancer
from funlib.persistence import open_ds
import numpy as np
from IPython.display import IFrame

voxel_size = (8, 8, 8)
neuroglancer.set_server_bind_address("0.0.0.0")
viewer = neuroglancer.Viewer()
raw = open_ds("./tmp/validation.zarr", "RAW")
labels = open_ds("./tmp/validation.zarr", "LABELS")
labels_data = labels.to_ndarray()

with viewer.txn() as state:
    state.showSlices = False
    state.layers["segs"] = neuroglancer.SegmentationLayer(
        # segments=[str(i) for i in np.unique(data[data > 0])], # this line will cause all objects to be selected and thus all meshes to be generated...will be slow if lots of high res meshes
        source=neuroglancer.LocalVolume(
            data=labels_data,
            dimensions=neuroglancer.CoordinateSpace(
                names=["z", "y", "x"],
                units=["nm", "nm", "nm"],
                scales=labels.voxel_size,
            ),
            # voxel_offset=ds.roi.begin / ds.voxel_size,
        ),
        segments=np.unique(labels_data[labels_data > 0]),
    )

    state.layers["raw"] = neuroglancer.ImageLayer(
        source=neuroglancer.LocalVolume(
            data=raw.data,
            dimensions=neuroglancer.CoordinateSpace(
                names=["z", "y", "x"],
                units=["nm", "nm", "nm"],
                scales=raw.voxel_size,
            ),
        ),
    )
IFrame(src=viewer, width=1800, height=900)

# View run

## neuroglancer run viewer class

In [1]:
from funlib.persistence import open_ds
from threading import Thread
import neuroglancer
from neuroglancer.viewer_state import ViewerState
import os
from dacapo.experiments.run import Run
from dacapo.store.create_store import create_config_store, create_array_store
from IPython.display import IFrame
import time
import copy
import json

config_store = create_config_store()
run_name = "example_scratch_example_synthetic_datasplit_config_example_synthetic_distance_task_config_example_synthetic_unet_example_synthetic_trainer_config__0"
run = Run(config_store.retrieve_run_config(run_name))


class NeuroglancerRunViewer:
    def __init__(self, run):
        self.run: Run = run
        self.most_recent_iteration = 0
        self.prediction = None

    def updated_neuroglancer_layer(self, layer_name, ds):
        source = neuroglancer.LocalVolume(
            data=ds.data,
            dimensions=neuroglancer.CoordinateSpace(
                names=["c", "z", "y", "x"],
                units=["", "nm", "nm", "nm"],
                scales=[1] + list(ds.voxel_size),
            ),
            voxel_offset=[0] + list(ds.roi.offset),
        )
        new_state = copy.deepcopy(self.viewer.state)
        if len(new_state.layers) == 1:
            new_state.layers[layer_name] = neuroglancer.ImageLayer(source=source)
        else:
            # replace name everywhere to preserve state, like what is selected
            new_state_str = json.dumps(new_state.to_json())
            new_state_str = new_state_str.replace(new_state.layers[-1].name, layer_name)
            new_state = ViewerState(json.loads(new_state_str))
            new_state.layers[layer_name].source = source

        self.viewer.set_state(new_state)
        print(self.viewer.state)

    def deprecated_start_neuroglancer(self):
        neuroglancer.set_server_bind_address("0.0.0.0")
        self.viewer = neuroglancer.Viewer()

    def start_neuroglancer(self):
        neuroglancer.set_server_bind_address("0.0.0.0")
        self.viewer = neuroglancer.Viewer()
        # raw = open_ds("./tmp/validation.zarr", "RAW")
        # labels = open_ds("./tmp/validation.zarr", "LABELS")
        # labels_data = labels.to_ndarray()
        with self.viewer.txn() as state:
            state.showSlices = False

            state.layers["raw"] = neuroglancer.ImageLayer(
                source=neuroglancer.LocalVolume(
                    data=self.raw.data,
                    dimensions=neuroglancer.CoordinateSpace(
                        names=["z", "y", "x"],
                        units=["nm", "nm", "nm"],
                        scales=self.raw.voxel_size,
                    ),
                    voxel_offset=self.raw.roi.offset,
                ),
            )
        return IFrame(src=self.viewer, width=1800, height=900)

    def start(self):
        self.array_store = create_array_store()
        self.get_datasets()
        self.new_validation_checker()
        return self.start_neuroglancer()

    def open_from_array_identitifier(self, array_identifier):
        if os.path.exists(array_identifier.container / array_identifier.dataset):
            return open_ds(str(array_identifier.container), array_identifier.dataset)
        else:
            return None

    def get_datasets(self):
        for validation_dataset in self.run.datasplit.validate:
            (
                input_raw_array_identifier,
                input_gt_array_identifier,
            ) = self.array_store.validation_input_arrays(
                self.run.name, validation_dataset.name
            )

            self.raw = self.open_from_array_identitifier(input_raw_array_identifier)
            self.gt = self.open_from_array_identitifier(input_gt_array_identifier)
        print(self.raw)

    def update_best_info(self, iteration, validation_dataset_name):
        prediction_array_identifier = self.array_store.validation_prediction_array(
            self.run.name,
            iteration,
            validation_dataset_name,
        )
        self.prediction = self.open_from_array_identitifier(prediction_array_identifier)
        self.most_recent_iteration = iteration

    def update_neuroglancer(self, iteration):
        self.updated_neuroglancer_layer(
            f"prediction at iteration {iteration}", self.prediction
        )
        return None

    def update_best(self, iteration, validation_dataset_name):
        self.update_best_info(iteration, validation_dataset_name)
        self.update_neuroglancer(iteration)

    def new_validation_checker(self):
        self.process = Thread(  # multiprocessing.Process(
            target=self.update_with_new_validation_if_possible
        )
        self.process.daemon = True
        self.process.start()

    def update_with_new_validation_if_possible(self):
        while True:
            time.sleep(3)
            for validation_dataset in self.run.datasplit.validate:
                most_recent_iteration_previous = self.most_recent_iteration
                prediction_array_identifier = (
                    self.array_store.validation_prediction_array(
                        self.run.name,
                        self.most_recent_iteration,
                        validation_dataset.name,
                    )
                )

                container = prediction_array_identifier.container
                if os.path.exists(container):
                    iteration_dirs = [
                        name
                        for name in os.listdir(container)
                        if os.path.isdir(os.path.join(container, name))
                        and name.isnumeric()
                    ]

                    for iteration_dir in iteration_dirs:
                        if int(iteration_dir) > self.most_recent_iteration:
                            inference_dir = os.path.join(
                                container,
                                iteration_dir,
                                "validation_config",
                                "prediction",
                            )
                            if os.path.exists(inference_dir):
                                inference_dir_contents = [
                                    f
                                    for f in os.listdir(inference_dir)
                                    if not f.startswith(".") and not f.endswith(".json")
                                ]
                                if inference_dir_contents:
                                    # then it should have at least a chunk writtent out, assume it has all of it written out
                                    self.most_recent_iteration = int(iteration_dir)
                    if most_recent_iteration_previous != self.most_recent_iteration:
                        self.update_best(
                            self.most_recent_iteration,
                            validation_dataset.name,
                        )

Creating FileConfigStore:
	path: /nrs/cellmap/ackermand/dacapo_learnathon_examples/configs


# set up neuroglancer that tracks progress

In [None]:
nrv = NeuroglancerRunViewer(run)
nrv.start()

# pseudo training, really now just does inference very few seconds

In [1]:
from old_predict import predict
from dacapo.experiments.run import Run
from funlib.geometry import Roi
import shutil
from dacapo.store.create_store import (
    create_config_store,
    create_array_store,
    create_weights_store,
)

config_store = create_config_store()
array_store = create_array_store()
run_name = "example_scratch_example_synthetic_datasplit_config_example_synthetic_distance_task_config_example_synthetic_unet_example_synthetic_trainer_config__0"
run = Run(config_store.retrieve_run_config(run_name))
# create weights store and read weights
weights_store = create_weights_store()
for iteration in range(50, 2500, 50):
    print("hi")
    # time.sleep(3)
    #shutil.rmtree(f"/nrs/cellmap/ackermand/dacapo_learnathon_examples/{run_name}/validation.zarr/{iteration}/validation_config/prediction")
    # weights = weights_store.retrieve_weights(run, iteration)
    # run.model.load_state_dict(weights.model)
    # prediction_array_identifier = array_store.validation_prediction_array(
    #     run.name, iteration, run.datasplit.validate[0].name
    # )
    # # predict(run.model, run.datasplit.validate[0].raw, prediction_array_identifier)
    # # %%

    # predict(
    #     run.model,
    #     run.datasplit.validate[0].raw,
    #     prediction_array_identifier,
    #     output_roi=Roi((0, 0, 0), (864, 864, 864)),
    # )

Creating FileConfigStore:
	path: /nrs/cellmap/ackermand/dacapo_learnathon_examples/configs
Creating local weights store in directory %s /nrs/cellmap/ackermand/dacapo_learnathon_examples
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
hi
