# Dacapo

In [1]:
from dacapo.experiments.architectures import CNNectomeUNetConfig
from dacapo.experiments.trainers import GunpowderTrainerConfig
from dacapo.experiments.trainers.gp_augments import (
    ElasticAugmentConfig,
    IntensityAugmentConfig,
)
from dacapo.experiments.tasks import AffinitiesTaskConfig
from funlib.geometry.coordinate import Coordinate
import math

## Trainer

In [2]:
# Augmentations handed to the trainer: an elastic deformation (with
# `uniform_3d_rotation` enabled) and an intensity scale/shift.
elastic_augment = ElasticAugmentConfig(
    control_point_spacing=(100, 100, 100),
    control_point_displacement_sigma=(10.0, 10.0, 10.0),
    rotation_interval=(0, math.pi / 2.0),
    subsample=8,
    uniform_3d_rotation=True,
)
intensity_augment = IntensityAugmentConfig(
    scale=(0.7, 1.3),
    shift=(-0.2, 0.2),
    clip=True,
)

# Trainer configuration shared by every run created later in this notebook.
trainer_config = GunpowderTrainerConfig(
    name="default_v2_no_dataset_predictor_node_lr_5E-5",
    batch_size=2,
    learning_rate=5e-5,  # same value as the "lr_5E-5" suffix in the name
    augments=[elastic_augment, intensity_augment],
    clip_raw=True,
    num_data_fetchers=20,
    snapshot_interval=10000,
    min_masked=0.05,
    # Mirrors the "no_dataset_predictor_node" part of the name.
    add_predictor_nodes_to_dataset=False,
)

## Task

In [3]:
# Single source of truth for the ratio, so the config name and the value
# actually passed to the task cannot drift apart.
lsds_to_affs_weight_ratio = 0.50

# Affinities task with LSDs enabled. The neighborhood holds one offset per
# axis at distances of 1, 3 and 9.
task_config = AffinitiesTaskConfig(
    # Two decimals keep the stored name identical to the previous literal
    # ("3d_lsdaffs_weight_ratio_0.50").
    name=f"3d_lsdaffs_weight_ratio_{lsds_to_affs_weight_ratio:.2f}",
    neighborhood=[
        (1, 0, 0),
        (0, 1, 0),
        (0, 0, 1),
        (3, 0, 0),
        (0, 3, 0),
        (0, 0, 3),
        (9, 0, 0),
        (0, 9, 0),
        (0, 0, 9),
    ],
    lsds=True,
    lsds_to_affs_weight_ratio=lsds_to_affs_weight_ratio,
)

## Architecture

By default, I created the rasterization at the same resolution as the raw data. However, the default architecture (with the upsampling layer configured via `upsample_factors`) expects the rasterization, as well as the mask and validation data, to be at 2x the raw resolution. This mismatch resulted in an error when submitting. Since we don't really care about a higher resolution at the moment, we can simply comment out the upsampling layer (`constant_upsample` and `upsample_factors`).

In [4]:
# Network input size and the amount by which it is enlarged for evaluation.
unet_input_shape = Coordinate(216, 216, 216)
unet_eval_shape_increase = Coordinate(72, 72, 72)

# U-Net architecture configuration. The upsampling options are left disabled
# (commented out) on purpose.
architecture_config = CNNectomeUNetConfig(
    name="unet",
    input_shape=unet_input_shape,
    eval_shape_increase=unet_eval_shape_increase,
    fmaps_in=1,
    num_fmaps=12,
    fmaps_out=72,
    fmap_inc_factor=6,
    downsample_factors=[(2, 2, 2), (3, 3, 3), (3, 3, 3)],
    # constant_upsample=True,
    # upsample_factors=[(2, 2, 2)],
)

## Datasplit

EVERYTHING MUST BE IN Z,Y,X AND NM!

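Combined datasplit: gather the train and validate configs from the datasplits of previously stored per-dataset runs and merge them into a single datasplit.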

In [14]:
from dacapo.store.create_store import create_config_store, MongoConfigStore
from dacapo.experiments.datasplits import TrainValidateDataSplitConfig
from dacapo.experiments.datasplits import DataSplitConfig
from dacapo.options import Options
from dacapo.store.converter import converter

# Instantiated as before; the returned object is not used further in this cell.
options = Options.instance()

config_store = create_config_store()
annotation_name = "plasmodesmata"
dataset_names = (
    "jrc_22ak351-leaf-3m",
    "jrc_22ak351-leaf-3r",
    "jrc_22ak351-leaf-2l",
)

# Collect the train/validate configs of every per-dataset datasplit.
combined_train_configs = []
combined_validate_configs = []
for dataset in dataset_names:
    # Only the datasplit of the stored run is needed, so which repetition is
    # looked up does not matter; repetition __0 is used. The leaf-3m run
    # follows a different naming scheme than the other two datasets.
    run_name = (
        "finetuned_3d_lsdaffs_weight_ratio_0.50_plasmodesmata_pseudorandom_training_centers_maxshift_18_more_annotations_unet_default_v2_no_dataset_predictor_node_lr_5E-5__0"
        if dataset == "jrc_22ak351-leaf-3m"
        else f"finetuned_3d_lsdaffs_weight_ratio_0.50_{dataset}_plasmodesmata_pseudorandom_training_centers_unet_default_v2_no_dataset_predictor_node_lr_5E-5__0"
    )
    run_config = config_store.retrieve_run_config(run_name)
    datasplit_config = run_config.datasplit_config

    combined_train_configs += datasplit_config.train_configs
    combined_validate_configs += datasplit_config.validate_configs

# Merge everything into one datasplit and persist it in the config store.
combined_datasplit_name = f"combined_{annotation_name}_pseudorandom_training_centers"
combined_datasplit_config = TrainValidateDataSplitConfig(
    name=combined_datasplit_name,
    train_configs=combined_train_configs,
    validate_configs=combined_validate_configs,
)
config_store.store_datasplit_config(combined_datasplit_config)

## Run

In [20]:
from dacapo.experiments import RunConfig
from dacapo.experiments.starts import StartConfig
from dacapo.store.create_store import create_config_store

config_store = create_config_store()

# Start from an existing run. Presumably the two positional arguments are the
# run name and the checkpoint/criterion to load — confirm against StartConfig.
start_config = StartConfig(
    "finetuned_3d_lsdaffs_weight_ratio_0.50_plasmodesmata_pseudorandom_training_centers_maxshift_18_more_annotations_unet_default_v2_no_dataset_predictor_node_lr_5E-5__1",
    "140000",
)
iterations = 200000
# NOTE(review): the stated intent was a huge validation interval so that
# validation can be dealt with after the fact, but 5000 is small relative to
# the 200000 training iterations — confirm the intended value.
validation_interval = 5000
repetitions = 3

# Everything in the run name except the repetition suffix is the same for all
# repetitions, so it is assembled once.
run_name_prefix = "_".join(
    [
        "finetuned" if start_config is not None else "scratch",
        task_config.name,
        combined_datasplit_config.name,
        architecture_config.name,
        trainer_config.name,
    ]
)

# Create and store one run config per repetition.
for i in range(repetitions):
    run_config = RunConfig(
        name=f"{run_name_prefix}__{i}",
        task_config=task_config,
        datasplit_config=combined_datasplit_config,
        architecture_config=architecture_config,
        trainer_config=trainer_config,
        num_iterations=iterations,
        validation_interval=validation_interval,
        repetition=i,
        start_config=start_config,
    )
    config_store.store_run_config(run_config)
    # Presumably the stored run is launched with:
    #   dacapo run -r {run_config.name}
    visualize_command = f"visualize run: python /groups/scicompsoft/home/ackermand/Programming/ml_experiments/scripts/visualize_pipeline.py visualize-pipeline -r {run_config.name}"
    print(visualize_command)

visualize run: python /groups/scicompsoft/home/ackermand/Programming/ml_experiments/scripts/visualize_pipeline.py visualize-pipeline -r finetuned_3d_lsdaffs_weight_ratio_0.50_combined_plasmodesmata_pseudorandom_training_centers_unet_default_v2_no_dataset_predictor_node_lr_5E-5__0
visualize run: python /groups/scicompsoft/home/ackermand/Programming/ml_experiments/scripts/visualize_pipeline.py visualize-pipeline -r finetuned_3d_lsdaffs_weight_ratio_0.50_combined_plasmodesmata_pseudorandom_training_centers_unet_default_v2_no_dataset_predictor_node_lr_5E-5__1
visualize run: python /groups/scicompsoft/home/ackermand/Programming/ml_experiments/scripts/visualize_pipeline.py visualize-pipeline -r finetuned_3d_lsdaffs_weight_ratio_0.50_combined_plasmodesmata_pseudorandom_training_centers_unet_default_v2_no_dataset_predictor_node_lr_5E-5__2


In [22]:
# Echo the name of the run config most recently assigned to `run_config`.
latest_run_name = run_config.name
print(latest_run_name)

finetuned_3d_lsdaffs_weight_ratio_0.50_combined_plasmodesmata_pseudorandom_training_centers_unet_default_v2_no_dataset_predictor_node_lr_5E-5__2


# Prediction Mask

In [7]:
from funlib.persistence import open_ds, prepare_ds
from funlib.geometry import Roi, Coordinate
from scipy.ndimage import binary_dilation, distance_transform_edt
import numpy as np

# NOTE(review): `dataset` is not assigned in this cell; it is whatever value an
# earlier cell left behind (e.g. the last element of a loop over datasets).
# Confirm it names the intended dataset before running this cell.
source_container = f"/nrs/cellmap/ackermand/cellmap/leaf-gall/{dataset}.n5"

# The inputs do not depend on the dilation amount, so they are read once
# instead of once per loop pass.
ds = open_ds(source_container, "plasmodesmata_column_cells")
voxel_size = ds.voxel_size
labeled = ds.to_ndarray() > 0
ds = open_ds(source_container, "plasmodesmata_column_target_cells")
# Union of both label volumes (in-place OR on boolean arrays).
labeled |= ds.to_ndarray() > 0
# The mask is the complement of the labeled voxels.
unlabeled = np.logical_not(labeled)

# Write one mask per dilation amount (1, 2 and 3 iterations). The loop variable
# is deliberately not called `iterations`, so it cannot overwrite the training
# iteration count defined in the run cell.
for dilation_iterations in range(1, 4):
    data = binary_dilation(unlabeled, iterations=dilation_iterations)

    output_ds = prepare_ds(
        "/nrs/cellmap/ackermand/cellmap/leaf-gall/prediction_masks.zarr",
        f"dilation_iterations_{dilation_iterations}_{dataset}",
        # ROI comes from the second dataset opened above; the voxel size from
        # the first.
        total_roi=ds.roi,
        voxel_size=voxel_size,
        dtype=np.uint8,
        # NOTE(review): 64 * 256 per axis — confirm this is the intended write
        # size in the units prepare_ds expects.
        write_size=Coordinate(np.array([64, 64, 64]) * 256),
        delete=True,
        # force_exact_write_size=True
    )
    output_ds[ds.roi] = data