In [2]:
%load_ext autoreload
%autoreload 2
from awesome.run.awesome_config import AwesomeConfig
from awesome.run.awesome_runner import AwesomeRunner
from awesome.util.reflection import class_name
import os
import torch
from typing import List, Tuple, Union, Dict, Literal, Optional
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.dataset.convexity_segmentation_dataset import ConvexitySegmentationDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.model.net import Net
import awesome
from awesome.util.path_tools import get_project_root_path
import copy

# Work from the repository root: every relative path used below ("./data", "./config", "./runs") is resolved against it.
os.chdir(
    get_project_root_path()
)

## Building Run Configurations for the different scenarios and models

This notebook generates all the configs in the [config](../config) folder, which are used to run the different scenarios and models.

### Models from Hannah Dröge's Paper 
Dröge, Hannah; Moeller, Michael (2021): Learning or Modelling? An Analysis of Single Image Segmentation Based on Scribble Information. In: 2021 IEEE International Conference on Image Processing (ICIP). IEEE, pp. 2274–2278.

#### 1. FCNET

##### FCNET with several feature combinations


In [None]:
# 1. FCNET without a prior model: one config per input feature type (xy, featxy, feat).

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.fc_net import FCNet
from awesome.measures.tv import TV

xytypes = ["xy", "featxy", "feat"]

for xytype in xytypes:
    cfg = AwesomeConfig(
        name_experiment=f"FCNET_benchmark+{xytype}",
        dataset_type=class_name(SISBOSIDataset),
        dataset_args={
            "dataset": SISBOSIConvexityDataset(
                    dataset_path="./data/datasets/convexity_dataset",
                    transform=False, # Augmentations disabled
                    semantic=False
                ),
            "xytransform": "xy",
            "xytype": xytype,
            # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
            "mode": "model_input",
            "feature_dir": "./data/datasets/convexity_dataset/Feat",
            "bs" : None,
            "dimension": "2d", # 2d for fcnet
            "model_input_requires_grad": False, # True is used for the 3d (CNN) nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(FCNet),
        segmentation_model_args={
            'width': 16,
            'depth': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=False,
        prior_model_args=None,
        prior_model_type=None,
        loss_type=class_name(RegularizerLoss),
        loss_args={
            "criterion": torch.nn.BCELoss(),
            "tau": 0.,
            "regularizer": TV(),
        },
        use_binary_classification=True, 
        num_epochs=3000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs",
        use_progress_bar=False,
    )
    cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

##### FCNET with ConvexNet

In [None]:
# 1. FCNET with a ConvexNet prior (AwesomeLoss): one config per feature type, plus one seeded copy per seed.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_net import ConvexNet
from awesome.model.fc_net import FCNet
from awesome.measures.tv import TV

xytypes = ["xy", "featxy", "feat"]
seeds = [47, 131]

for xytype in xytypes:
    cfg = AwesomeConfig(
        name_experiment=f"FCNET_benchmark+{xytype}+convex",
        dataset_type=class_name(SISBOSIDataset),
        dataset_args={
            "dataset": SISBOSIConvexityDataset(
                    dataset_path="./data/datasets/convexity_dataset",
                    transform=False, # Augmentations disabled
                    semantic=False
                ),
            "xytransform": "xy",
            "xytype": xytype,
            # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
            "mode": "model_input",
            "feature_dir": "./data/datasets/convexity_dataset/Feat",
            "bs" : None,
            "dimension": "2d", # 2d for fcnet
            "model_input_requires_grad": False, # True is used for the 3d (CNN) nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(FCNet),
        segmentation_model_args={
            'width': 16,
            'depth': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(),
        prior_model_type=class_name(ConvexNet),
        loss_type=class_name(AwesomeLoss),
        loss_args={
                "criterion": torch.nn.BCELoss(),
                "tau": 0.,
                "regularizer": TV(),
            },
        use_extra_penalty_hook=True, # Penalty hook for the penalty term requiring the models' outputs to match
        use_binary_classification=True, 
        num_epochs=3000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs",
        use_progress_bar=False,
    )
    cfg.save_to_file(f"./config/rerun_unireps/{cfg.name_experiment}.yaml", no_uuid=True, override=True)
    # Seeded variants: deepcopy so the base cfg saved above is left untouched.
    for seed in seeds:
        c = copy.deepcopy(cfg)
        c.seed = seed
        c.name_experiment = cfg.name_experiment + f"+seed{seed}"
        c.save_to_file(f"./config/rerun_unireps/{c.name_experiment}.yaml", no_uuid=True, override=True)

##### FCNET with ConvexNet Joint Training

In [None]:
# 1. FCNET with a ConvexNet prior, joint training (AwesomeLossJoint): one config per feature type, plus seeded copies.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_loss_joint import AwesomeLossJoint
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_net import ConvexNet
from awesome.model.fc_net import FCNet
from awesome.measures.tv import TV

xytypes = ["xy", "featxy", "feat"]
seeds = [47, 131]
for xytype in xytypes:
    cfg = AwesomeConfig(
        name_experiment=f"FCNET_benchmark+{xytype}+convex+joint",
        dataset_type=class_name(SISBOSIDataset),
        dataset_args={
            "dataset": SISBOSIConvexityDataset(
                    dataset_path="./data/datasets/convexity_dataset",
                    transform=False, # Augmentations disabled
                    semantic=False
                ),
            "xytransform": "xy",
            "xytype": xytype,
            # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
            "mode": "model_input",
            "feature_dir": "./data/datasets/convexity_dataset/Feat",
            "bs" : None,
            "dimension": "2d", # 2d for fcnet
            "model_input_requires_grad": False, # True is used for the 3d (CNN) nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(FCNet),
        segmentation_model_args={
            'width': 16,
            'depth': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(),
        prior_model_type=class_name(ConvexNet),
        loss_type=class_name(AwesomeLossJoint),
        loss_args={
                "criterion": torch.nn.BCELoss(),
                "tau": 0.,
                "regularizer": TV(),
            },
        use_extra_penalty_hook=True, # Penalty hook for the penalty term requiring the models' outputs to match
        use_reduce_lr_in_extra_penalty_hook=True,
        use_binary_classification=True, 
        num_epochs=3000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs",
        use_progress_bar=False,
    )
    cfg.save_to_file(f"./config/benchmarks/{cfg.name_experiment}.yaml", no_uuid=True, override=True)
    # Seeded variants: deepcopy so the base cfg saved above is left untouched.
    for seed in seeds:
        c = copy.deepcopy(cfg)
        c.seed = seed
        c.name_experiment = cfg.name_experiment + f"+seed{seed}"
        c.save_to_file(f"./config/benchmarks_seeds/{c.name_experiment}.yaml", no_uuid=True, override=True)

##### FCNET with Diffeomorphism

In [None]:
# 1. FCNET using "xy" input features with a ConvexDiffeomorphismNet prior.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.model.fc_net import FCNet
from awesome.measures.tv import TV

cfg = AwesomeConfig(
    name_experiment="FCNET_benchmark+xy+diffeo",
    dataset_type=class_name(SISBOSIDataset),
    dataset_args={
        "dataset": SISBOSIConvexityDataset(
                dataset_path="./data/datasets/convexity_dataset",
                transform=False, # Augmentations disabled
                semantic=False
            ),
        "xytransform": "xy",
        "xytype": "xy",
        # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
        "mode": "model_input",
        "feature_dir": "./data/datasets/convexity_dataset/Feat",
        "bs" : None,
        "dimension": "2d", # 2d for fcnet
        "model_input_requires_grad": False, # True is used for the 3d (CNN) nets
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
    },
    segmentation_model_type=class_name(FCNet),
    segmentation_model_args={
        'width': 16,
        'depth': 3,
        'input': 'rgbxy',
    },
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeLoss),
    loss_args={
            "criterion": torch.nn.BCELoss(),
            "tau": 0.,
            "regularizer": TV(),
            "name": "BCE",
        },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

In [None]:
# 1. FCNET using "featxy" input features with a ConvexDiffeomorphismNet prior.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.model.fc_net import FCNet
from awesome.measures.tv import TV

cfg = AwesomeConfig(
    name_experiment="FCNET_benchmark+featxy+diffeo",
    dataset_type=class_name(SISBOSIDataset),
    dataset_args={
        "dataset": SISBOSIConvexityDataset(
                dataset_path="./data/datasets/convexity_dataset",
                transform=False, # Augmentations disabled
                semantic=False
            ),
        "xytransform": "xy",
        "xytype": "featxy",
        # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
        "mode": "model_input",
        "feature_dir": "./data/datasets/convexity_dataset/Feat",
        "bs" : None,
        "dimension": "2d", # 2d for fcnet
        "model_input_requires_grad": False, # True is used for the 3d (CNN) nets
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
    },
    segmentation_model_type=class_name(FCNet),
    segmentation_model_args={
        'width': 16,
        'depth': 3,
        'input': 'rgbxy',
    },
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeLoss),
    loss_args={
            "criterion": torch.nn.BCELoss(),
            "tau": 0.,
            "regularizer": TV(),
            "name": "BCE",
        },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

In [None]:
# 1. FCNET using "feat" input features with a ConvexDiffeomorphismNet prior.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.model.fc_net import FCNet
from awesome.measures.tv import TV

cfg = AwesomeConfig(
    name_experiment="FCNET_benchmark+feat+diffeo",
    dataset_type=class_name(SISBOSIDataset),
    dataset_args={
        "dataset": SISBOSIConvexityDataset(
                dataset_path="./data/datasets/convexity_dataset",
                transform=False, # Augmentations disabled
                semantic=False
            ),
        "xytransform": "xy",
        "xytype": "feat",
        # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
        "mode": "model_input",
        "feature_dir": "./data/datasets/convexity_dataset/Feat",
        "bs" : None,
        "dimension": "2d", # 2d for fcnet
        "model_input_requires_grad": False, # True is used for the 3d (CNN) nets
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
    },
    segmentation_model_type=class_name(FCNet),
    segmentation_model_args={
        'width': 16,
        'depth': 3,
        'input': 'rgbxy',
    },
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeLoss),
    loss_args={
            "criterion": torch.nn.BCELoss(),
            "tau": 0.,
            "regularizer": TV(),
            "name": "BCE",
        },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

#### 2. CNNet with ConvexNet

In [None]:
# CNNet with a ConvexNet prior (AwesomeImageLoss + gradient penalty): one config per feature type, plus seeded copies.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet

xytypes = ["xy", "featxy", "feat"]
seeds = [47, 131]

for xytype in xytypes:
    cfg = AwesomeConfig(
        name_experiment=f"CNNET_benchmark+{xytype}+convex",
        dataset_type=class_name(SISBOSIDataset),
        dataset_args={
            "dataset": SISBOSIConvexityDataset(
                    dataset_path="./data/datasets/convexity_dataset",
                    transform=False, # Augmentations disabled
                    semantic=False
                ),
            "xytransform": "xy",
            "xytype": xytype,
            # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
            "mode": "model_input",
            "feature_dir": "./data/datasets/convexity_dataset/Feat",
            "bs" : None,
            "dimension": "3d", # 3d for the CNN (fcnet uses 2d)
            "model_input_requires_grad": True, # Enabled for the 3d (CNN) nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(CNNNet),
        segmentation_model_args={
            'width': 16,
            'depth': 2,
            'kernel_size': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(),
        prior_model_type=class_name(ConvexNet),
        loss_type=class_name(AwesomeImageLoss),
        loss_args={
            "criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=True,
                noneclass=2.,
                xygrad=0.01,
                rgbgrad=0.01,
                featgrad=0.0,
                xytype=xytype,
            ),
        },
        use_extra_penalty_hook=True, # Penalty hook for the penalty term requiring the models' outputs to match
        use_binary_classification=True, 
        num_epochs=3000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs",
        use_progress_bar=False,
    )
    cfg.save_to_file(f"./config/rerun_unireps/{cfg.name_experiment}.yaml", no_uuid=True, override=True)
    # Seeded variants: deepcopy so the base cfg saved above is left untouched.
    for seed in seeds:
        c = copy.deepcopy(cfg)
        c.seed = seed
        c.name_experiment = cfg.name_experiment + f"+seed{seed}"
        c.save_to_file(f"./config/rerun_unireps/{c.name_experiment}.yaml", no_uuid=True, override=True)

In [None]:
# CNNet with a ConvexNet prior, joint training (AwesomeImageLossJoint): one config per feature type, plus seeded copies.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet

xytypes = ["xy", "featxy", "feat"]
seeds = [47, 131]
for xytype in xytypes:
    cfg = AwesomeConfig(
        name_experiment=f"CNNET_benchmark+{xytype}+convex+joint",
        dataset_type=class_name(SISBOSIDataset),
        dataset_args={
            "dataset": SISBOSIConvexityDataset(
                    dataset_path="./data/datasets/convexity_dataset",
                    transform=False, # Augmentations disabled
                    semantic=False
                ),
            "xytransform": "xy",
            "xytype": xytype,
            # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
            "mode": "model_input",
            "feature_dir": "./data/datasets/convexity_dataset/Feat",
            "bs" : None,
            "dimension": "3d", # 3d for the CNN (fcnet uses 2d)
            "model_input_requires_grad": True, # Enabled for the 3d (CNN) nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(CNNNet),
        segmentation_model_args={
            'width': 16,
            'depth': 2,
            'kernel_size': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(),
        prior_model_type=class_name(ConvexNet),
        loss_type=class_name(AwesomeImageLossJoint),
        loss_args={
            "criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=True,
                noneclass=2.,
                xygrad=0.01,
                rgbgrad=0.01,
                featgrad=0.0, # float, consistent with the non-joint CNNet cell
                xytype=xytype,
            ),
        },
        use_extra_penalty_hook=True, # Penalty hook for the penalty term requiring the models' outputs to match
        use_reduce_lr_in_extra_penalty_hook=True,
        use_binary_classification=True, 
        num_epochs=3000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs",
        use_progress_bar=False,
    )
    cfg.save_to_file(f"./config/benchmarks/{cfg.name_experiment}.yaml", no_uuid=True, override=True)
    # Seeded variants: deepcopy so the base cfg saved above is left untouched.
    for seed in seeds:
        c = copy.deepcopy(cfg)
        c.seed = seed
        c.name_experiment = cfg.name_experiment + f"+seed{seed}"
        c.save_to_file(f"./config/benchmarks_seeds/{c.name_experiment}.yaml", no_uuid=True, override=True)

##### CNNET with Diffeomorphism

In [None]:
# CNNet using "feat" input features with a ConvexDiffeomorphismNet prior.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
# Explicit import of the prior actually used below, so the cell does not rely on another cell's imports.
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet

cfg = AwesomeConfig(
    name_experiment="CNNET_benchmark+feat+diffeo",
    dataset_type=class_name(SISBOSIDataset),
    dataset_args={
        "dataset": SISBOSIConvexityDataset(
                dataset_path="./data/datasets/convexity_dataset",
                transform=False, # Augmentations disabled
                semantic=False
            ),
        "xytransform": "xy",
        "xytype": "feat",
        # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
        "mode": "model_input",
        "feature_dir": "./data/datasets/convexity_dataset/Feat",
        "bs" : None,
        "dimension": "3d", # 3d for the CNN (fcnet uses 2d)
        "model_input_requires_grad": True, # Enabled for the 3d (CNN) nets
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
    },
    segmentation_model_type=class_name(CNNNet),
    segmentation_model_args={
        'width': 16,
        'depth': 2,
        'kernel_size': 3,
        'input': 'rgbxy',
    },
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeImageLoss),
    loss_args={
        # NOTE(review): unlike the ConvexNet CNNet cells, "xytype"/"featgrad" are not passed here, so
        # GradientPenaltyLoss presumably uses its own defaults — confirm these match xytype "feat".
        "criterion": GradientPenaltyLoss(
            criterion=torch.nn.BCELoss(),
            apply_gradient_penalty=True,
            noneclass=2.,
            xygrad=0.01,
            rgbgrad=0.01,
            name="BCE+GradientPenalty",
        ),
    },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

In [None]:
# CNNet using "featxy" input features with a ConvexDiffeomorphismNet prior.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
# Explicit import of the prior actually used below, so the cell does not rely on another cell's imports.
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet

cfg = AwesomeConfig(
    name_experiment="CNNET_benchmark+featxy+diffeo",
    dataset_type=class_name(SISBOSIDataset),
    dataset_args={
        "dataset": SISBOSIConvexityDataset(
                dataset_path="./data/datasets/convexity_dataset",
                transform=False, # Augmentations disabled
                semantic=False
            ),
        "xytransform": "xy",
        "xytype": "featxy",
        # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
        "mode": "model_input",
        "feature_dir": "./data/datasets/convexity_dataset/Feat",
        "bs" : None,
        "dimension": "3d", # 3d for the CNN (fcnet uses 2d)
        "model_input_requires_grad": True, # Enabled for the 3d (CNN) nets
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
    },
    segmentation_model_type=class_name(CNNNet),
    segmentation_model_args={
        'width': 16,
        'depth': 2,
        'kernel_size': 3,
        'input': 'rgbxy',
    },
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeImageLoss),
    loss_args={
        # NOTE(review): unlike the ConvexNet CNNet cells, "xytype"/"featgrad" are not passed here, so
        # GradientPenaltyLoss presumably uses its own defaults — confirm these match xytype "featxy".
        "criterion": GradientPenaltyLoss(
            criterion=torch.nn.BCELoss(),
            apply_gradient_penalty=True,
            noneclass=2.,
            xygrad=0.01,
            rgbgrad=0.01,
            name="BCE+GradientPenalty",
        ),
    },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

In [None]:
# CNNet using "xy" input features with a ConvexDiffeomorphismNet prior.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
# Explicit import of the prior actually used below, so the cell does not rely on another cell's imports.
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet

cfg = AwesomeConfig(
    name_experiment="CNNET_benchmark+xy+diffeo",
    dataset_type=class_name(SISBOSIDataset),
    dataset_args={
        "dataset": SISBOSIConvexityDataset(
                dataset_path="./data/datasets/convexity_dataset",
                transform=False, # Augmentations disabled
                semantic=False
            ),
        "xytransform": "xy",
        "xytype": "xy",
        # Only "mode" entry: a former duplicate "scribbles" key was always overridden by this value.
        "mode": "model_input",
        "feature_dir": "./data/datasets/convexity_dataset/Feat",
        "bs" : None,
        "dimension": "3d", # 3d for the CNN (fcnet uses 2d)
        "model_input_requires_grad": True, # Enabled for the 3d (CNN) nets
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
    },
    segmentation_model_type=class_name(CNNNet),
    segmentation_model_args={
        'width': 16,
        'depth': 2,
        'kernel_size': 3,
        'input': 'rgbxy',
    },
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeImageLoss),
    loss_args={
        # NOTE(review): unlike the ConvexNet CNNet cells, "xytype"/"featgrad" are not passed here, so
        # GradientPenaltyLoss presumably uses its own defaults — confirm these match xytype "xy".
        "criterion": GradientPenaltyLoss(
            criterion=torch.nn.BCELoss(),
            apply_gradient_penalty=True,
            noneclass=2.,
            xygrad=0.01,
            rgbgrad=0.01,
            name="BCE+GradientPenalty",
        ),
    },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

### Our Pixel-wise Model

In [None]:
# Pixel-wise Net baseline: no prior model, plain torch BCELoss, scribble_percentage 1.

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.dataset.convexity_segmentation_dataset import ConvexitySegmentationDataset
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.fc_net import FCNet
from awesome.model.net import Net

# All AwesomeConfig settings collected first, then unpacked into the constructor.
net_baseline_kwargs = dict(
    name_experiment="NET_benchmark",
    dataset_type=class_name(ConvexitySegmentationDataset),
    dataset_args=dict(
        dataset_path="./data/datasets/convexity_dataset",
        batch_size=1,
        split_ratio=1.,
        shuffle_in_dataloader=False,
    ),
    scribble_percentage=1.,
    segmentation_model_type=class_name(Net),
    segmentation_model_args=dict(),
    segmentation_training_mode='single',
    use_prior_model=False,
    prior_model_args=None,
    prior_model_type=None,
    loss_type=class_name(torch.nn.BCELoss),
    loss_args=dict(),
    use_binary_classification=True,
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg = AwesomeConfig(**net_baseline_kwargs)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

In [None]:
# Pixel-wise Net with a ConvexNet prior (AwesomeLoss over BCE).

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.dataset.convexity_segmentation_dataset import ConvexitySegmentationDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.fc_net import FCNet
from awesome.model.net import Net
# ConvexNet was previously only in scope if an earlier cell had run; import it here so
# this cell works on a fresh kernel after the top import cell.
from awesome.model.convex_net import ConvexNet

cfg = AwesomeConfig(
    name_experiment="NET_benchmark+convex",
    dataset_type=class_name(ConvexitySegmentationDataset),
    dataset_args={
        "dataset_path": "./data/datasets/convexity_dataset",
        "batch_size": 1,
        "split_ratio": 1.,
        "shuffle_in_dataloader": False,
    },
    scribble_percentage=0.8,
    segmentation_model_type=class_name(Net),
    segmentation_model_args=dict(n_hidden=130),
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(n_hidden=130),
    prior_model_type=class_name(ConvexNet),
    loss_type=class_name(AwesomeLoss),
    loss_args={
        "criterion": torch.nn.BCELoss(),
        "alpha": 1.,
        "name": "BCE",
    },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

In [None]:
# Pixel-wise Net with a ConvexDiffeomorphismNet prior (AwesomeLoss over BCE).

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.dataset.convexity_segmentation_dataset import ConvexitySegmentationDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.fc_net import FCNet
from awesome.model.net import Net
# Explicit import of the prior used below, so the cell does not depend on another cell's imports.
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet

cfg = AwesomeConfig(
    name_experiment="NET_benchmark+diffeo",
    dataset_type=class_name(ConvexitySegmentationDataset),
    dataset_args={
        "dataset_path": "./data/datasets/convexity_dataset",
        "batch_size": 1,
        "split_ratio": 1.,
        "shuffle_in_dataloader": False,
    },
    scribble_percentage=0.8,
    segmentation_model_type=class_name(Net),
    segmentation_model_args=dict(n_hidden=130),
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(n_hidden=130),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeLoss),
    loss_args={
        "criterion": torch.nn.BCELoss(),
        "alpha": 1.,
        "name": "BCE",
    },
    use_binary_classification=True, 
    num_epochs=1000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs",
    use_progress_bar=False,
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", no_uuid=True, override=True)

## FBMS Configs

### FBMS CNNet Convex Sequential

In [None]:
# CNNNet segmentation model with a ConvexNet prior, trained sequentially (AwesomeImageLoss),
# for every combination of xytype and FBMS-59 test sequence.
# One YAML config per run is written to ./config/fbms/<sequence>/.

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from itertools import product


xytypes = ["featxy", "feat"]
datasets = [
    'camel01',
    'cars1',
    'cars10',
    'cars4',
    'cars5',
    'cats01',
    'cats03',
    'cats06',
    'dogs01',
    'dogs02',
    'farm01',
    'giraffes01',
    'goats01',
    'horses02',
    'horses04',
    'horses05',
    'lion01',
    'marple12',
    'marple2',
    'marple4',
    'marple6',
    'marple7',
    'marple9',
    'people03',
    'people1',
    'people2',
    'rabbits02',
    'rabbits03',
    'rabbits04',
    'tennis',
]

it = product(xytypes, datasets)

for xytype, dataset in it:
    print(xytype, dataset)
    cfg = AwesomeConfig(
        name_experiment=f"CNNET_+{dataset}+{xytype}+convex",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                dataset_path=f"./data/local_datasets/FBMS-59/test/{dataset}"
            ),
            "xytransform": "xy",
            "xytype": xytype,
            # Exactly one "mode" key: in a dict literal a repeated key is silently
            # overwritten by its last occurrence.
            "mode": "model_input",
            "feature_dir": f"./data/local_datasets/FBMS-59/test/{dataset}/Feat",
            "dimension": "3d",  # 2d for fcnet
            "model_input_requires_grad": True,  # Can be used for 3d nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(CNNNet),
        segmentation_model_args={
            'width': 16,
            'depth': 2,
            'kernel_size': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(),
        prior_model_type=class_name(ConvexNet),
        loss_type=class_name(AwesomeImageLoss),
        loss_args={
            "criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=True,
                noneclass=2.,
                xygrad=0.001,
                rgbgrad=0.001,
                featgrad=0.0,
                xytype=xytype,
            ),
            "prior_criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=False,
                noneclass=2.,
            ),
            "gamma": 1.0,
            "beta": 1.0,
        },
        use_extra_penalty_hook=True,  # Penalty hook for the penalty term that the models' outputs should match
        use_binary_classification=True,
        num_epochs=3000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms",
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3]
    )
    path = f"./config/fbms/{dataset}/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)

### FBMS CNNNet Joint Convex

In [None]:
# CNNNet segmentation model with a ConvexNet prior, trained jointly (AwesomeImageLossJoint),
# on a hand-picked subset of (xytype, FBMS-59 test sequence) pairs.
# One YAML config per run is written to ./config/fbms/joint/.

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from itertools import product


# Paired lists: features[i] is the xytype used for datasets[i].
datasets = ['cars1', 'cars1', 'marple9', 'cars5', 'people2', 'marple2', 'tennis', 'marple4', 'cars10']
features = ['feat', 'featxy', 'feat', 'featxy', 'feat', 'feat', 'feat', 'featxy', 'featxy']
# zip() silently stops at the shorter list, so fail loudly if the two lists drift apart.
assert len(features) == len(datasets), "features and datasets must be paired one-to-one"

it = zip(features, datasets)

for xytype, dataset in it:
    print(xytype, dataset)
    cfg = AwesomeConfig(
        name_experiment=f"CNNET_+{dataset}+{xytype}+convex+joint",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                dataset_path=f"./data/local_datasets/FBMS-59/test/{dataset}"
            ),
            "xytransform": "xy",
            "xytype": xytype,
            # Exactly one "mode" key: in a dict literal a repeated key is silently
            # overwritten by its last occurrence.
            "mode": "model_input",
            "feature_dir": f"./data/local_datasets/FBMS-59/test/{dataset}/Feat",
            "dimension": "3d",  # 2d for fcnet
            "model_input_requires_grad": True,  # Can be used for 3d nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(CNNNet),
        segmentation_model_args={
            'width': 16,
            'depth': 2,
            'kernel_size': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(),
        prior_model_type=class_name(ConvexNet),
        loss_type=class_name(AwesomeImageLossJoint),
        loss_args={
            "criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=True,
                noneclass=2.,
                xygrad=0.001,
                rgbgrad=0.001,
                featgrad=0.0,
                xytype=xytype,
            ),
            "prior_criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=False,
                noneclass=2.,
            ),
            "gamma": 1.0,
            "beta": 1.0,
        },
        use_extra_penalty_hook=True,  # Penalty hook for the penalty term that the models' outputs should match
        use_reduce_lr_in_extra_penalty_hook=True,
        use_binary_classification=True,
        num_epochs=3000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms",
        use_progress_bar=False,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3]
    )
    path = f"./config/fbms/joint/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)

### FBMS CNNNet Diffeo Sequential

In [None]:
# CNNNet segmentation model with a ConvexDiffeomorphismNet prior, trained sequentially
# (AwesomeImageLoss), for every combination of xytype and FBMS-59 test sequence.
# One YAML config per run is written to ./config/fbms_diffeo/.

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
# Imported here so the cell does not depend on an earlier cell having imported it.
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from itertools import product

xytypes = ["xy"]  # "featxy" and "feat" are currently disabled for this scenario
datasets = [
    'camel01',
    'cars1',
    'cars10',
    'cars4',
    'cars5',
    'cats01',
    'cats03',
    'cats06',
    'dogs01',
    'dogs02',
    'farm01',
    'giraffes01',
    'goats01',
    'horses02',
    'horses04',
    'horses05',
    'lion01',
    'marple12',
    'marple2',
    'marple4',
    'marple6',
    'marple7',
    'marple9',
    'people03',
    'people1',
    'people2',
    'rabbits02',
    'rabbits03',
    'rabbits04',
    'tennis',
]

it = product(xytypes, datasets)

for xytype, dataset in it:
    print(xytype, dataset)
    cfg = AwesomeConfig(
        name_experiment=f"CNNET_+{dataset}+{xytype}+diffeo",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                dataset_path=f"./data/local_datasets/FBMS-59/test/{dataset}"
            ),
            "xytransform": "xy",
            "xytype": xytype,
            # Exactly one "mode" key: in a dict literal a repeated key is silently
            # overwritten by its last occurrence.
            "mode": "model_input",
            "feature_dir": f"./data/local_datasets/FBMS-59/test/{dataset}/Feat",
            "dimension": "3d",  # 2d for fcnet
            "model_input_requires_grad": True,  # Can be used for 3d nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(CNNNet),
        segmentation_model_args={
            'width': 16,
            'depth': 2,
            'kernel_size': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(
            nf_layers=3,
            nf_hidden=70
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(AwesomeImageLoss),
        loss_args={
            "criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=True,
                noneclass=2.,
                xygrad=0.001,
                rgbgrad=0.001,
                featgrad=0.0,
                xytype=xytype,
            ),
            "prior_criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=False,
                noneclass=2.,
            ),
            "gamma": 1.0,
            "beta": 1.0,
        },
        use_extra_penalty_hook=True,  # Penalty hook for the penalty term that the models' outputs should match
        use_binary_classification=True,
        num_epochs=4000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/slurm_monitor/fbms",
        use_progress_bar=False,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3]
    )
    path = f"./config/fbms_diffeo/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)

### FBMS CNNNet Joint Diffeo

In [None]:
# CNNNet segmentation model with a ConvexDiffeomorphismNet prior, trained jointly
# (AwesomeImageLossJoint), on a hand-picked subset of (xytype, FBMS-59 test sequence) pairs.
# One YAML config per run is written to ./config/fbms_diffeo_joint/.

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from itertools import product
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet

# Paired lists: features[i] is the xytype used for datasets[i].
datasets = ['cars1', 'cars1', 'marple9', 'cars5', 'people2', 'marple2', 'tennis', 'marple4', 'cars10']
features = ['feat', 'featxy', 'feat', 'featxy', 'feat', 'feat', 'feat', 'featxy', 'featxy']
# zip() silently stops at the shorter list, so fail loudly if the two lists drift apart.
assert len(features) == len(datasets), "features and datasets must be paired one-to-one"

it = zip(features, datasets)

for xytype, dataset in it:
    print(xytype, dataset)
    cfg = AwesomeConfig(
        name_experiment=f"CNNET_+{dataset}+{xytype}+diffeo+joint",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                dataset_path=f"./data/local_datasets/FBMS-59/test/{dataset}"
            ),
            "xytransform": "xy",
            "xytype": xytype,
            # Exactly one "mode" key: in a dict literal a repeated key is silently
            # overwritten by its last occurrence.
            "mode": "model_input",
            "feature_dir": f"./data/local_datasets/FBMS-59/test/{dataset}/Feat",
            "dimension": "3d",  # 2d for fcnet
            "model_input_requires_grad": True,  # Can be used for 3d nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
        },
        segmentation_model_type=class_name(CNNNet),
        segmentation_model_args={
            'width': 16,
            'depth': 2,
            'kernel_size': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(
            nf_layers=3,
            nf_hidden=70
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(AwesomeImageLossJoint),
        loss_args={
            "criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=True,
                noneclass=2.,
                xygrad=0.001,
                rgbgrad=0.001,
                featgrad=0.0,
                xytype=xytype,
            ),
            "prior_criterion": GradientPenaltyLoss(
                criterion=torch.nn.BCELoss(),
                apply_gradient_penalty=False,
                noneclass=2.,
            ),
            "gamma": 1.0,
            "beta": 1.0,
        },
        use_extra_penalty_hook=True,  # Penalty hook for the penalty term that the models' outputs should match
        use_reduce_lr_in_extra_penalty_hook=True,
        use_binary_classification=True,
        num_epochs=4000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/slurm_monitor/fbms",
        use_progress_bar=False,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3]
    )
    path = f"./config/fbms_diffeo_joint/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)

## UNet (Amir) with Weighting, Sequential Training => Equivalent to what Amir does in his paper

In [None]:
# Pretrained per-sequence UNet with a ConvexDiffeomorphismNet prior, trained sequentially
# (AwesomeImageLoss) with a WeightedLoss-wrapped BCE, on the FBMS-59 train sequences.
# One YAML config per sequence is written to ./config/fbms_unet_diffeo/.

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
# WeightedLoss is used in loss_args below; without this import the cell fails
# with a NameError on a fresh kernel.
from awesome.measures.weighted_loss import WeightedLoss

xytypes = ["edge"]
datasets = [
    'bear01',
    'bear02',
    'cars2',
    'cars3',
    'cars6',
    'cars7',
    'cars8',
    'cars9',
    'cats02',
    'cats04',
    'cats05',
    'cats07',
    'ducks01',
    'horses01',
    'horses03',
    'horses06',
    'lion02',
    'marple1',
    'marple10',
    'marple11',
    'marple13',
    'marple3',
    'marple5',
    'marple8',
    'meerkats01',
    'people04',
    'people05',
    'rabbits01',
    'rabbits05',
]


it = zip(xytypes * len(datasets), datasets, ['train'] * len(datasets))

for xytype, dataset, dataset_kind in it:
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                dataset_path=data_path,
                weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
                processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
                confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
                do_weak_label_preprocessing=True,
                do_uncertainty_label_flip=True,
                all_frames=False
            ),
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d",  # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": "bgr",
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': 4,
        },
        segmentation_training_mode='single',
        segmentation_model_state_dict_path=f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth",  # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            nf_layers=3,
            nf_hidden=70
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(AwesomeImageLoss),
        loss_args={
            "criterion": GradientPenaltyLoss(
                criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms"),
                apply_gradient_penalty=False,
                noneclass=2.,
                xygrad=0.0,
                rgbgrad=0.0,
                featgrad=0.0,
                xytype=xytype,
            ),
            "prior_criterion": GradientPenaltyLoss(
                criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms"),
                apply_gradient_penalty=False,
                noneclass=2.,
            ),
            "gamma": 0.5,
            "beta": 0.5,
        },
        use_extra_penalty_hook=True,  # Penalty hook for the penalty term that the models' outputs should match
        extra_penalty_after_n_epochs=800,
        use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=4000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args={
            "lr": 0.01,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False
        },
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3],
    )
    path = f"./config/fbms_unet_diffeo/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


## UNet (Amir) with Weighting, Joint Training => Equivalent to what Amir does in his paper

In [None]:
# Pretrained per-sequence UNet with a ConvexDiffeomorphismNet prior, trained jointly
# (AwesomeImageLossJoint) with a WeightedLoss-wrapped BCE, on the FBMS-59 train sequences.
# One YAML config per sequence is written to ./config/fbms_unet_diffeo/ (suffix "+joint").

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss

xytypes = ["edge"]
datasets = [
    'bear01', 'bear02',
    'cars2', 'cars3', 'cars6', 'cars7', 'cars8', 'cars9',
    'cats02', 'cats04', 'cats05', 'cats07',
    'ducks01',
    'horses01', 'horses03', 'horses06',
    'lion02',
    'marple1', 'marple10', 'marple11', 'marple13', 'marple3', 'marple5', 'marple8',
    'meerkats01',
    'people04', 'people05',
    'rabbits01', 'rabbits05',
]


it = zip(xytypes * len(datasets), datasets, ['train'] * len(datasets))

for xytype, dataset, dataset_kind in it:
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

    # The sequence dataset is built first, then both loss criteria, in the same order
    # in which the original inline dict literals constructed them.
    unet_joint_dataset_args = {
        "dataset": FBMSSequenceDataset(
            dataset_path=data_path,
            weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
            processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
            confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
            do_weak_label_preprocessing=True,
            do_uncertainty_label_flip=True,
            all_frames=False,
        ),
        "xytype": xytype,
        "feature_dir": f"{data_path}/Feat",
        "dimension": "3d",  # 2d for fcnet
        "mode": "model_input",
        "model_input_requires_grad": False,
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
        "image_channel_format": "bgr",
        "do_image_blurring": True,
    }
    unet_joint_loss_args = {
        "criterion": GradientPenaltyLoss(
            criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms"),
            apply_gradient_penalty=False,
            noneclass=2.,
            xygrad=0.0,
            rgbgrad=0.0,
            featgrad=0.0,
            xytype=xytype,
        ),
        "prior_criterion": GradientPenaltyLoss(
            criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms"),
            apply_gradient_penalty=False,
            noneclass=2.,
        ),
        "gamma": 0.5,
        "beta": 0.5,
    }

    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+joint",
        dataset_type=class_name(AwesomeDataset),
        dataset_args=unet_joint_dataset_args,
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={'in_chn': 4},
        segmentation_training_mode='single',
        # Pretrained model checkpoint for this sequence.
        segmentation_model_state_dict_path=f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth",
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(nf_layers=3, nf_hidden=70),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(AwesomeImageLossJoint),
        loss_args=unet_joint_loss_args,
        use_extra_penalty_hook=True,  # Penalty hook for the penalty term that the models' outputs should match
        extra_penalty_after_n_epochs=800,
        use_reduce_lr_in_extra_penalty_hook=True,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=4000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args=dict(
            lr=0.01,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0,
            amsgrad=False,
        ),
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3],
    )
    path = f"./config/fbms_unet_diffeo/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


## Fitting Only the Prior to Amir's Model, Without Changing the Segmentation Model

In [3]:
# Fit only a ConvexDiffeomorphismNet prior to each pretrained per-sequence UNet.
# The UNet is not trained (segmentation_training_mode='none'); fitting runs in the
# agent's pretrain phase only. Sequences whose checkpoint is a multi-output model are excluded.
# One YAML config per sequence is written to ./config/fbms_unet_diffeo_only_prior/.

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss

xytypes = ["edge"]
datasets = [
    'bear01',
    'bear02',
    'cars2',
    'cars3',
    'cars6',
    'cars7',
    'cars8',
    # 'cars9',     # Multi output Model
    # 'cats02',    # Multi output Model
    'cats04',
    'cats05',
    # 'cats07',    # Multi output Model
    # 'ducks01',   # Multi output Model
    'horses01',
    'horses03',
    # 'horses06',  # Multi output Model
    # 'lion02',
    'marple1',
    'marple10',
    'marple11',
    # 'marple13',  # Multi output Model
    # 'marple3',   # Multi output Model
    'marple5',
    # 'marple8',   # Multi output Model
    'meerkats01',
    'people04',
    # 'people05',  # Multi output Model
    'rabbits01',
    # 'rabbits05'  # Multi output Model
]


it = zip(xytypes * len(datasets), datasets, ['train'] * len(datasets))

for xytype, dataset, dataset_kind in it:
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

    # Built in the same order as the original inline literals: dataset first, then losses.
    only_prior_dataset_args = {
        "dataset": FBMSSequenceDataset(
            dataset_path=data_path,
            weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
            processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
            confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
            do_weak_label_preprocessing=True,
            do_uncertainty_label_flip=True,
            all_frames=False,
        ),
        "xytype": xytype,
        "feature_dir": f"{data_path}/Feat",
        "dimension": "3d",  # 2d for fcnet
        "mode": "model_input",
        "model_input_requires_grad": False,
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
        "image_channel_format": "bgr",
        "do_image_blurring": True,
    }
    only_prior_loss_args = {
        "criterion": GradientPenaltyLoss(
            criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms"),
            apply_gradient_penalty=False,
            noneclass=2.,
            xygrad=0.0,
            rgbgrad=0.0,
            featgrad=0.0,
            xytype=xytype,
        ),
        "prior_criterion": GradientPenaltyLoss(
            criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms"),
            apply_gradient_penalty=False,
            noneclass=2.,
        ),
        "gamma": 0.5,
        "beta": 0.5,
    }
    # Prior fitting happens in the pretrain phase only.
    only_prior_agent_args = dict(
        do_pretraining=True,
        pretrain_only=True,  # Do Fitting in pretrain mode only
        force_pretrain=True,
        pretrain_args=dict(
            lr=0.003,
            use_logger=True,
            reuse_state=True,
            num_epochs=2000,
            reuse_state_epochs=200,
        ),
    )

    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+only_prior",
        dataset_type=class_name(AwesomeDataset),
        dataset_args=only_prior_dataset_args,
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={'in_chn': 4},
        segmentation_training_mode='none',
        # Pretrained model checkpoint for this sequence.
        segmentation_model_state_dict_path=f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth",
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(n_hidden=130, n_hidden_layers=1, nf_layers=4, nf_hidden=130),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(AwesomeImageLoss),
        loss_args=only_prior_loss_args,
        use_extra_penalty_hook=True,  # Penalty hook for the penalty term that the models' outputs should match
        extra_penalty_after_n_epochs=800,
        use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=4000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args=dict(
            lr=0.01,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=0,
            amsgrad=False,
        ),
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3],
        agent_args=only_prior_agent_args,
    )
    path = f"./config/fbms_unet_diffeo_only_prior/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


Loading frames...:   0%|          | 0/6 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/24 [00:00<?, ?it/s]

  bg_flip_coords = flip_probability[torch.argwhere(bg_flip_mask).squeeze(), :2].int().T


Loading frames...:   0%|          | 0/4 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/3 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/4 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/3 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/2 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/6 [00:00<?, ?it/s]

 Could not find a foreground weak label object id for sample cats04:79!
 Could not find a foreground weak label object id for sample cats04:97!


Loading frames...:   0%|          | 0/6 [00:00<?, ?it/s]

 Could not find a foreground weak label object id for sample cats05:86!


Loading frames...:   0%|          | 0/11 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/26 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/13 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/22 [00:00<?, ?it/s]

 Could not find a foreground weak label object id for sample lion02:139!
 Could not find a foreground weak label object id for sample lion02:159!
 Could not find a foreground weak label object id for sample lion02:179!
 Could not find a foreground weak label object id for sample lion02:199!
 Could not find a foreground weak label object id for sample lion02:219!
 Could not find a foreground weak label object id for sample lion02:239!
 Could not find a foreground weak label object id for sample lion02:259!
 Could not find a foreground weak label object id for sample lion02:415!


Loading frames...:   0%|          | 0/12 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/15 [00:00<?, ?it/s]

 Could not find a foreground weak label object id for sample marple10:399!
 Could not find a foreground weak label object id for sample marple10:449!
 Could not find a foreground weak label object id for sample marple10:459!


Loading frames...:   0%|          | 0/9 [00:00<?, ?it/s]

 Could not find a foreground weak label object id for sample marple11:0!
 Could not find a foreground weak label object id for sample marple11:9!


Loading frames...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/7 [00:00<?, ?it/s]

 Could not find a foreground weak label object id for sample marple8:39!
 Could not find a foreground weak label object id for sample marple8:49!
 Could not find a foreground weak label object id for sample marple8:71!


Loading frames...:   0%|          | 0/12 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/17 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/16 [00:00<?, ?it/s]

## Fitting Only the Prior to Amir's Models: Original, Retrained, and Retrained + xy

In [7]:
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.se import SE
from awesome.measures.ae import AE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss

# Segmentation-model variants fitted per sequence, given as (xytype, segmentation_model_switch):
#   ("edge", "original")     -> checkpoint from labels_with_uncertainty_flownet2_based, BGR input, 4 input channels
#   ("edge", "retrain")      -> checkpoint from refit_unet_uncertainty/23_11_13, RGB input, 4 input channels
#   ("edgexy", "retrain_xy") -> checkpoint from refit_spatial_unet_uncertainty/23_11_13, RGB input, 6 input channels
xytypes = [("edge", "original"), ("edge", "retrain"), ("edgexy", "retrain_xy")]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]


# One run per (sequence, model variant); every sequence is taken from the FBMS "train" split.
it = [(xytype, dataset, 'train') for dataset in datasets for xytype in xytypes]

# Criterion passed to pretrain_args for fitting the prior; one instance is shared by all configs.
prior_criterion = UnariesConversionLoss(SE(reduction="mean"))

for (xytype, segmentation_model_switch), dataset, dataset_kind in it:
    # Select the UNet checkpoint that matches the requested model variant.
    segmentation_model_state_dict_path = None
    if segmentation_model_switch == "original":
        segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
    elif segmentation_model_switch == "retrain":
        segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
    elif segmentation_model_switch == "retrain_xy":
        segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
    else:
        raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
    # The original checkpoint is fed BGR images, the refit checkpoints RGB.
    image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
    # "edge" inputs use 4 channels; the spatial "edgexy" variant uses 6.
    input_channels = 4 if xytype == "edge" else 6

    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    real_dataset = FBMSSequenceDataset(
                    dataset_path=data_path,
                    weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based",
                    processed_weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based/processed",
                    confidence_dir= "weak_labels/labels_with_uncertainty_flownet2_based/",
                    do_weak_label_preprocessing=True,
                    do_uncertainty_label_flip=True,
                    all_frames=True,
                    test_weak_label_integrity=False
                )
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+only_prior{'+REFIT' if segmentation_model_switch != 'original' else ''}",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": image_channel_format,
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='none',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path,
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            n_hidden=130,
            n_hidden_layers=2,
            diffeo_args=dict(
                num_coupling=6,
                width=130,
                backbone="normal_block"
            ),
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that models output should match
        #extra_penalty_after_n_epochs=800,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=200,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/prior_only/",
        optimizer_args={
            "lr": 0.01,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False
        },
        use_progress_bar=False,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        agent_args=dict(
            do_pretraining=True,
            pretrain_only=True, # Do Fitting in pretrain mode only
            force_pretrain=True,
            pretrain_state_path=f"./data/checkpoints/pretrain_states/23-11-13/model_{dataset}_unet_{xytype}_{segmentation_model_switch}_{prior_criterion.get_name()}.pth",
            pretrain_args=dict(
                pretrain_checkpoint_dir=f"./data/checkpoints/pretrain_states/23-11-13/model_{dataset}_unet_{xytype}_{segmentation_model_switch}_{prior_criterion.get_name()}",
                lr=0.001,
                use_logger=True,
                use_step_logger=False,
                num_epochs=2000,
                proper_prior_fit_retrys=1,
                criterion=prior_criterion,
                do_pretrain_checkpoints=True,
                use_pretrain_checkpoints=True,
            )
        )
    )

    path = f"./config/fbms_unet_diffeo_only_prior/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


KeyboardInterrupt: 

## Fine-tune the Prior Only, Including All Frames

In [4]:
# UNet (segmentation_training_mode='none', original checkpoint) + ConvexDiffeomorphismNet prior,
# fitted in pretrain-only mode on all frames ("deeper" variant with residual-block coupling backbone).

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
# Imported explicitly so this cell runs on a fresh kernel (previously leaked from an earlier cell).
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet

xytypes = ["edge"]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]


# Cross product of sequences and input types (all from the "train" split). A positional zip
# would mis-pair types and sequences as soon as xytypes has more than one entry.
it = [(xytype, dataset, 'train') for dataset in datasets for xytype in xytypes]

for xytype, dataset, dataset_kind in it:
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+only_prior+all_frames+deeper",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                    dataset_path=data_path,
                    weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based",
                    processed_weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based/processed",
                    confidence_dir= "weak_labels/labels_with_uncertainty_flownet2_based/",
                    do_weak_label_preprocessing=True,
                    do_uncertainty_label_flip=True,
                    all_frames=True,
                    _no_indexing=True
                ),
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": "bgr",
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': 4,
        },
        segmentation_training_mode='none',
        segmentation_model_state_dict_path=f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth", # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            n_hidden=130,
            n_hidden_layers=2,
            diffeo_args=dict(
                num_coupling=6,
                width=130,
                backbone="residual_block"
            ),
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=True, # Penalty hook for the penalty term that models output should match
        extra_penalty_after_n_epochs=800,
        use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=4000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args={
            "lr": 0.01,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False
        },
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3],
        agent_args=dict(
            do_pretraining=True,
            pretrain_only=True, # Do Fitting in pretrain mode only
            force_pretrain=True,
            pretrain_args=dict(
                lr=0.003,
                use_logger=True,
                reuse_state=True,
                num_epochs=2000,
                reuse_state_epochs=200
                )
        )
    )
    path = f"./config/fbms_unet_diffeo_only_prior_all_frames_deeper/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


# Fit the Prior to the Models
1. Version with RealNVP (11.01)

In [3]:
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.net_factory import real_nvp_path_connected_net
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.model.zoo import Zoo

xytypes = ["edge"]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]


# Which UNet checkpoint the prior is fitted against (see the checkpoint selection below).
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"

# Cross product of sequences and input types (all from the "train" split). A positional zip
# would mis-pair types and sequences as soon as xytypes has more than one entry.
it = [(xytype, dataset, 'train') for dataset in datasets for xytype in xytypes]

# Single source of truth for the prior-fit schedule: used in the experiment name, the checkpoint
# path AND the pretrain_args below, so they cannot drift apart.
prior_epochs = 4000
prior_refit_epochs = 400

for xytype, dataset, dataset_kind in it:
    # Select the UNet checkpoint that matches the requested model variant.
    segmentation_model_state_dict_path = None
    if segmentation_model_switch == "original":
        segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
    elif segmentation_model_switch == "retrain":
        segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
    elif segmentation_model_switch == "retrain_xy":
        segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
    else:
        raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
    # The original checkpoint is fed BGR images, the refit checkpoints RGB.
    image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
    input_channels = 4 if xytype == "edge" else 6

    prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

    # NOTE(review): this path encodes neither xytype nor segmentation_model_switch, so switching the
    # model variant reuses the same checkpoint directory (use_pretrain_checkpoints=True) — confirm intended.
    pretrain_state_path = f"./data/checkpoints/pretrain_states/2024-01-11/model_{dataset}_unet_spatial_realnvp_{prior_epochs}_{prior_refit_epochs}"

    real_dataset = FBMSSequenceDataset(
                    dataset_path=data_path,
                    weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based",
                    processed_weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based/processed",
                    confidence_dir= "weak_labels/labels_with_uncertainty_flownet2_based/",
                    do_weak_label_preprocessing=True,
                    do_uncertainty_label_flip=True,
                    test_weak_label_integrity=True,
                    all_frames=True,
                    _no_indexing=True
                )
    # NOTE(review): presumably redundant with the constructor argument above; kept in case the
    # constructor resets the flag (e.g. when _no_indexing=True) — confirm against FBMSSequenceDataset.
    real_dataset.test_weak_label_integrity = True
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+{segmentation_model_switch}+ep{prior_epochs}+refit{prior_refit_epochs}+realnvp",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": image_channel_format,
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path, # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            channels=2,
            hidden_units=32,
            flow_n_flows=12,
            flow_output_fn="tanh",
            norm="minmax",
            convex_net_hidden_units=130,
            convex_net_hidden_layers=2,
        ),
        prior_model_type=class_name(real_nvp_path_connected_net),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "penalty_criterion": prior_criterion.criterion,
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that models output should match
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=0,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/prior_training",
        optimizer_args={
            "lr": 0.003,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=False,
        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        save_images_after_pretraining=True,
        include_unaries_when_saving=True,
        agent_args=dict(
             do_pretraining=True,
             pretrain_only=True,
             force_pretrain=True,
             pretrain_state_path=pretrain_state_path + ".pth",
             pretrain_args=dict(
                 use_pretrain_checkpoints=True,
                 do_pretrain_checkpoints=True,
                 pretrain_checkpoint_dir=pretrain_state_path,
                 lr=0.001,
                 use_logger=True,
                 use_step_logger=True,
                 num_epochs=prior_epochs,
                 proper_prior_fit_retrys=1,
                 reuse_state_epochs=prior_refit_epochs,
                 # Prefit flow net identity => Flow will be identity(-like) at the beginning
                 prefit_flow_net_identity=True,
                 prefit_flow_net_identity_lr=1e-2,
                 prefit_flow_net_identity_weight_decay=1e-5,
                 prefit_flow_net_identity_num_epochs=100,
                 # Prefit convex net, to start with a convex thing
                 prefit_convex_net=True,
                 prefit_convex_net_lr=1e-3,
                 prefit_convex_net_weight_decay=0,
                 prefit_convex_net_num_epochs=200,
                 zoo=Zoo()
             )
        ),
    )

    path = f"./config/fbms_unet_prior_realnvp/2024_01_11/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


2. Version with RealNVP (11.01) and longer refit training per image (15.01)

In [4]:
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.net_factory import real_nvp_path_connected_net
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.model.zoo import Zoo

xytypes = ["edge"]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]


# Which UNet checkpoint the prior is fitted against (see the checkpoint selection below).
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"

# Cross product of sequences and input types (all from the "train" split). A positional zip
# would mis-pair types and sequences as soon as xytypes has more than one entry.
it = [(xytype, dataset, 'train') for dataset in datasets for xytype in xytypes]

# Prior-fit schedule; used in the experiment name, the checkpoint path and the pretrain_args below.
prior_epochs = 4000
prior_refit_epochs = 1000

for xytype, dataset, dataset_kind in it:
    # Select the UNet checkpoint that matches the requested model variant.
    segmentation_model_state_dict_path = None
    if segmentation_model_switch == "original":
        segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
    elif segmentation_model_switch == "retrain":
        segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
    elif segmentation_model_switch == "retrain_xy":
        segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
    else:
        raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
    # The original checkpoint is fed BGR images, the refit checkpoints RGB.
    image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
    input_channels = 4 if xytype == "edge" else 6

    prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

    # NOTE(review): this path encodes neither xytype nor segmentation_model_switch, so switching the
    # model variant reuses the same checkpoint directory (use_pretrain_checkpoints=True) — confirm intended.
    pretrain_state_path = f"./data/checkpoints/pretrain_states/2024-01-15/model_{dataset}_unet_spatial_realnvp_{prior_epochs}_{prior_refit_epochs}"

    real_dataset = FBMSSequenceDataset(
                    dataset_path=data_path,
                    weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based",
                    processed_weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based/processed",
                    confidence_dir= "weak_labels/labels_with_uncertainty_flownet2_based/",
                    do_weak_label_preprocessing=True,
                    do_uncertainty_label_flip=True,
                    test_weak_label_integrity=True,
                    all_frames=True,
                    _no_indexing=True
                )
    # NOTE(review): presumably redundant with the constructor argument above; kept in case the
    # constructor resets the flag (e.g. when _no_indexing=True) — confirm against FBMSSequenceDataset.
    real_dataset.test_weak_label_integrity = True
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+{segmentation_model_switch}+ep{prior_epochs}+refit{prior_refit_epochs}+realnvp",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": image_channel_format,
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path, # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            channels=2,
            hidden_units=32,
            flow_n_flows=12,
            flow_output_fn="tanh",
            norm="minmax",
            convex_net_hidden_units=130,
            convex_net_hidden_layers=2,
        ),
        prior_model_type=class_name(real_nvp_path_connected_net),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "penalty_criterion": prior_criterion.criterion,
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that models output should match
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=0,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/prior_training",
        optimizer_args={
            "lr": 0.003,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=False,
        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        save_images_after_pretraining=True,
        include_unaries_when_saving=True,
        agent_args=dict(
             do_pretraining=True,
             pretrain_only=True,
             force_pretrain=True,
             pretrain_state_path=pretrain_state_path + ".pth",
             pretrain_args=dict(
                 use_pretrain_checkpoints=True,
                 do_pretrain_checkpoints=True,
                 pretrain_checkpoint_dir=pretrain_state_path,
                 lr=0.001,
                 use_logger=True,
                 use_step_logger=True,
                 num_epochs=prior_epochs,
                 proper_prior_fit_retrys=1,
                 reuse_state_epochs=prior_refit_epochs,
                 # Prefit flow net identity => Flow will be identity(-like) at the beginning
                 prefit_flow_net_identity=True,
                 prefit_flow_net_identity_lr=1e-2,
                 prefit_flow_net_identity_weight_decay=1e-5,
                 prefit_flow_net_identity_num_epochs=100,
                 # Prefit convex net, to start with a convex thing
                 prefit_convex_net=True,
                 prefit_convex_net_lr=1e-3,
                 prefit_convex_net_weight_decay=0,
                 prefit_convex_net_num_epochs=200,
                 zoo=Zoo()
             )
        ),
    )

    path = f"./config/fbms_unet_prior_realnvp/2024_01_15/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


## Jointly Train the UNet with the Prior

In [5]:
# Joint training of the UNet (segmentation_training_mode='multi', original checkpoint)
# together with a ConvexDiffeomorphismNet prior, after a pretraining phase.

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet

xytypes = ["edge"]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]


# Cross product of sequences and input types (all from the "train" split). A positional zip
# would mis-pair types and sequences as soon as xytypes has more than one entry.
it = [(xytype, dataset, 'train') for dataset in datasets for xytype in xytypes]

for xytype, dataset, dataset_kind in it:
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+joint",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                    dataset_path=data_path,
                    weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based",
                    processed_weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based/processed",
                    confidence_dir= "weak_labels/labels_with_uncertainty_flownet2_based/",
                    do_weak_label_preprocessing=True,
                    do_uncertainty_label_flip=True,
                    all_frames=True,
                    _no_indexing=True
                ),
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": "bgr",
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': 4,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth", # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        # NOTE(review): the other ConvexDiffeomorphismNet configs in this notebook pass
        # diffeo_args=dict(num_coupling=..., width=..., backbone=...); the nf_layers / nf_hidden keys
        # here may target an older constructor signature — confirm before running.
        prior_model_args=dict(
            n_hidden=130,
            n_hidden_layers=1,
            nf_layers=4,
            nf_hidden=130,
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that models output should match
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=200,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args={
            "lr": 0.003,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=False,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3],
        agent_args=dict(
             do_pretraining=True,
             pretrain_state_path=f"./data/checkpoints/pretrain_states/model_{dataset}_joint_unet.pth",
             pretrain_args=dict(
                 lr=0.003,
                 use_logger=False
             )
        ),
        weight_decay_on_weight_norm_modules=5e-5,
    )
    path = f"./config/fbms_unet_diffeo_joint/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


### Deeper!

In [None]:
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss

# Config generation: UNet + deeper ConvexDiffeomorphismNet prior, trained jointly
# with FBMSJointLoss; writes one YAML per (xytype, FBMS sequence) combination.
xytypes = ["edge"]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]

# Every (xytype, dataset, split) combination. A plain
# `zip(xytypes * len(datasets), datasets, ...)` would pair dataset i with only
# xytypes[i % len(xytypes)] as soon as more than one xytype is listed.
run_specs = [
    (xytype, dataset, 'train')
    for xytype in xytypes
    for dataset in datasets
]

for xytype, dataset, dataset_kind in run_specs:
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

    real_dataset = FBMSSequenceDataset(
        dataset_path=data_path,
        weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
        processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
        confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
        do_weak_label_preprocessing=True,
        do_uncertainty_label_flip=True,
        all_frames=True,
        _no_indexing=True
    )

    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+joint+deeper",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": "bgr",
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': 4,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth", # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        # Deeper prior: two hidden layers plus a residual-block diffeomorphism.
        prior_model_args=dict(
            n_hidden=130,
            n_hidden_layers=2,
            diffeo_args=dict(
                num_coupling=6,
                width=130,
                backbone="residual_block"
            ),
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that the models' outputs should match; disabled here
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=200,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args={
            "lr": 0.003,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=False,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=[0, 1, 2, 3],
        agent_args=dict(
            do_pretraining=True,
            pretrain_state_path=f"./data/checkpoints/pretrain_states/model_{dataset}_joint_unet_deeper.pth",
            pretrain_args=dict(
                lr=0.003,
                use_logger=False
            )
        ),
        weight_decay_on_weight_norm_modules=5e-5,
    )
    path = f"./config/fbms_unet_diffeo_joint_deeper/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


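### Joint training with spatial information

Uses the `edgexy` input features (6-channel UNet input) and, by default, the UNet checkpoints refitted on spatial input (`retrain_xy`).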

In [8]:
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss

# Config generation: UNet on spatial (`edgexy`) input + deeper ConvexDiffeomorphismNet
# prior, short joint refit (15 epochs); writes one YAML per (xytype, FBMS sequence).
xytypes = ["edgexy"]
datasets = [
   'bear01',
   #'bear02',
   'cars2',
   'cars3',
   'cars6',
   #'cars7',
   #'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   ##'cats04',
   ##'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   ##'horses01',
   ##'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   ##'marple1',
   ##'marple10',
   ##'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   #'marple5',
   #'marple8', # Multi output Model
   ##'meerkats01',
   'people04',
   #'people05', # Multi output Model
   ##'rabbits01',
   #'rabbits05' # Multi output Model
   ]


segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "retrain_xy"

# Pretrained UNet checkpoint for each switch value; `{dataset}` is filled per sequence.
segmentation_checkpoint_templates = {
    "original": "./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth",
    "retrain": "./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
    "retrain_xy": "./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
}
if segmentation_model_switch not in segmentation_checkpoint_templates:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
# The original checkpoints are fed BGR images, the refitted ones RGB.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"

# Every (xytype, dataset, split) combination; zip over a repeated xytypes list
# would drop combinations as soon as more than one xytype is listed.
run_specs = [
    (xytype, dataset, 'train')
    for xytype in xytypes
    for dataset in datasets
]

for xytype, dataset, dataset_kind in run_specs:
    segmentation_model_state_dict_path = segmentation_checkpoint_templates[segmentation_model_switch].format(dataset=dataset)
    # "edge" -> 4 UNet input channels; inputs with extra spatial channels (e.g. "edgexy") -> 6.
    input_channels = 4 if xytype == "edge" else 6

    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

    real_dataset = FBMSSequenceDataset(
        dataset_path=data_path,
        weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
        processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
        confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
        do_weak_label_preprocessing=True,
        do_uncertainty_label_flip=True,
        test_weak_label_integrity=False,
        all_frames=True,
    )

    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+joint+deeper+REFIT",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": image_channel_format,
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path, # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            n_hidden=130,
            n_hidden_layers=2,
            diffeo_args=dict(
                num_coupling=6,
                width=130,
                backbone="residual_block"
            ),
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that the models' outputs should match; disabled here
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=15,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args={
            "lr": 0.0001,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=5,
        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        compute_metrics_during_training_nth_epoch=5,
        agent_args=dict(
            do_pretraining=True,
            pretrain_state_path=f"./data/checkpoints/pretrain_states/model_{dataset}_unet_{xytype}_{segmentation_model_switch}.pth",
            pretrain_args=dict(
                lr=0.003,
                use_logger=True,
                proper_prior_fit_retrys=1,
                criterion=UnariesConversionLoss(SE(reduction="mean"))
            )
        ),
        weight_decay_on_weight_norm_modules=5e-5,
    )
    path = f"./config/fbms_unet_quick/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


Loading frames...:   0%|          | 0/100 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/30 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/19 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/30 [00:00<?, ?it/s]

Loading frames...:   0%|          | 0/320 [00:00<?, ?it/s]

# Joint Training of UNet and Prior with CRF (Used Config)
Version 11.01.2024 — RealNVP-based path-connected prior (`real_nvp_path_connected_net`)

In [8]:
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.model.net_factory import real_nvp_path_connected_net
from awesome.model.zoo import Zoo

# Config generation (used config, 11.01.2024): UNet + RealNVP path-connected prior,
# trained jointly, with CRF computed during/after training; one YAML per
# (xytype, FBMS sequence) combination.
xytypes = ["edge"]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]


segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"

# Pretrained UNet checkpoint for each switch value; `{dataset}` is filled per sequence.
segmentation_checkpoint_templates = {
    "original": "./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth",
    "retrain": "./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
    "retrain_xy": "./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
}
if segmentation_model_switch not in segmentation_checkpoint_templates:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
# The original checkpoints are fed BGR images, the refitted ones RGB.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"

# Every (xytype, dataset, split) combination; zip over a repeated xytypes list
# would drop combinations as soon as more than one xytype is listed.
run_specs = [
    (xytype, dataset, 'train')
    for xytype in xytypes
    for dataset in datasets
]

# Prior pretraining schedule (also encoded in the pretrain checkpoint path).
prior_epochs = 4000
prior_refit_epochs = 400

for xytype, dataset, dataset_kind in run_specs:
    segmentation_model_state_dict_path = segmentation_checkpoint_templates[segmentation_model_switch].format(dataset=dataset)
    # "edge" -> 4 UNet input channels; inputs with extra spatial channels -> 6.
    input_channels = 4 if xytype == "edge" else 6

    # NOTE(review): only `prior_criterion.criterion` is used (as the joint loss's
    # penalty criterion); unlike the "spatial information" cell, it is not passed
    # as `criterion` in pretrain_args — confirm the pretraining default is intended.
    prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    pretrain_state_path = f"./data/checkpoints/pretrain_states/2024-01-11/model_{dataset}_unet_spatial_realnvp_{prior_epochs}_{prior_refit_epochs}"

    real_dataset = FBMSSequenceDataset(
        dataset_path=data_path,
        weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
        processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
        confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
        do_weak_label_preprocessing=True,
        do_uncertainty_label_flip=True,
        test_weak_label_integrity=True,
        all_frames=True,
        _no_indexing=True,
    )

    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+{segmentation_model_switch}+joint",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": image_channel_format,
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path, # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            channels=2,
            hidden_units=32,
            flow_n_flows=12,
            flow_output_fn="tanh",
            norm="minmax",
            convex_net_hidden_units=130,
            convex_net_hidden_layers=2,
        ),
        prior_model_type=class_name(real_nvp_path_connected_net),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "penalty_criterion": prior_criterion.criterion,
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that the models' outputs should match; disabled here
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=15,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args={
            "lr": 0.0001,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=False,

        plot_indices_during_training_nth_epoch=5,
        compute_metrics_during_training_nth_epoch=5,

        compute_crf_with_metrics=True,
        compute_crf_after_training=True,
        compute_crf_after_pretraining=True,
        save_images_after_pretraining=True,
        include_unaries_when_saving=True,

        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        agent_args=dict(
            do_pretraining=True,
            pretrain_only=False,
            force_pretrain=False,
            pretrain_state_path=pretrain_state_path + ".pth",
            pretrain_args=dict(
                use_pretrain_checkpoints=True,
                do_pretrain_checkpoints=True,
                pretrain_checkpoint_dir=pretrain_state_path,
                lr=0.001,
                use_logger=True,
                use_step_logger=True,
                num_epochs=prior_epochs,
                proper_prior_fit_retrys=1,
                reuse_state_epochs=prior_refit_epochs,
                # Prefit flow net identity => Flow will be identity(-like) at the beginning
                prefit_flow_net_identity=True,
                prefit_flow_net_identity_lr=1e-2,
                prefit_flow_net_identity_weight_decay=1e-5,
                prefit_flow_net_identity_num_epochs=100,
                # Prefit convex net, to start with a convex thing
                prefit_convex_net=True,
                prefit_convex_net_lr=1e-3,
                prefit_convex_net_weight_decay=0,
                prefit_convex_net_num_epochs=200,
                zoo=Zoo()
            )
        ),
        weight_decay_on_weight_norm_modules=0,
    )
    path = f"./config/fbms_unet_diffeo_joint/2024_01_11/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


# Joint Training of UNet and Prior with CRF (Used Config) with Multiple Random Seeds

In [3]:
import itertools
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.model.net_factory import real_nvp_path_connected_net
from awesome.model.zoo import Zoo

# Config generation (used config, 11.01.2024) repeated for several random seeds:
# UNet + RealNVP path-connected prior, trained jointly with CRF; one YAML per
# (seed, xytype, FBMS sequence) combination.
xytypes = ["edge"]
datasets = [
   'bear01',
   'bear02',
   'cars2',
   'cars3',
   'cars6',
   'cars7',
   'cars8',
   #'cars9', # Multi output Model
   #'cats02', # Multi output Model
   'cats04',
   'cats05',
   #'cats07',  # Multi output Model
   #'ducks01', # Multi output Model
   'horses01',
   'horses03',
   #'horses06', # Multi output Model
   #'lion02',
   'marple1',
   'marple10',
   'marple11',
   #'marple13', # Multi output Model
   #'marple3', # Multi output Model
   'marple5',
   #'marple8', # Multi output Model
   'meerkats01',
   'people04',
   #'people05', # Multi output Model
   'rabbits01',
   #'rabbits05' # Multi output Model
   ]

seeds = [47, 131]

segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"

# Pretrained UNet checkpoint for each switch value; `{dataset}` is filled per sequence.
segmentation_checkpoint_templates = {
    "original": "./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth",
    "retrain": "./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
    "retrain_xy": "./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
}
if segmentation_model_switch not in segmentation_checkpoint_templates:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
# The original checkpoints are fed BGR images, the refitted ones RGB.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"

# Prior pretraining schedule (also encoded in the pretrain checkpoint path).
prior_epochs = 4000
prior_refit_epochs = 400

# Full cross product seed x xytype x dataset (seed is the outermost loop). Zipping
# a repeated xytypes list would drop combinations with more than one xytype.
run_specs = itertools.product(seeds, xytypes, datasets, ['train'])

for seed, xytype, dataset, dataset_kind in run_specs:
    segmentation_model_state_dict_path = segmentation_checkpoint_templates[segmentation_model_switch].format(dataset=dataset)
    # "edge" -> 4 UNet input channels; inputs with extra spatial channels -> 6.
    input_channels = 4 if xytype == "edge" else 6

    prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    # Seed is part of the path so that runs with different seeds do not share pretrain states.
    pretrain_state_path = f"./data/checkpoints/pretrain_states/2024-01-11/model_{dataset}_unet_spatial_realnvp_{prior_epochs}_{prior_refit_epochs}_seed{seed}"

    # NOTE(review): the dataset is rebuilt for every seed; caching one instance per
    # sequence would avoid re-loading frames, if sharing it across configs is safe — confirm.
    real_dataset = FBMSSequenceDataset(
        dataset_path=data_path,
        weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
        processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
        confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
        do_weak_label_preprocessing=True,
        do_uncertainty_label_flip=True,
        test_weak_label_integrity=True,
        all_frames=True,
        _no_indexing=True,
    )

    cfg = AwesomeConfig(
        seed=seed,
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+{segmentation_model_switch}+joint+seed{seed}",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": image_channel_format,
            "do_image_blurring": True
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path, # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            channels=2,
            hidden_units=32,
            flow_n_flows=12,
            flow_output_fn="tanh",
            norm="minmax",
            convex_net_hidden_units=130,
            convex_net_hidden_layers=2,
        ),
        prior_model_type=class_name(real_nvp_path_connected_net),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "penalty_criterion": prior_criterion.criterion,
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that the models' outputs should match; disabled here
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=15,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args={
            "lr": 0.0001,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=False,

        plot_indices_during_training_nth_epoch=5,
        compute_metrics_during_training_nth_epoch=5,

        compute_crf_with_metrics=True,
        compute_crf_after_training=True,
        compute_crf_after_pretraining=True,
        save_images_after_pretraining=True,
        include_unaries_when_saving=True,

        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        agent_args=dict(
            do_pretraining=True,
            pretrain_only=False,
            force_pretrain=False,
            pretrain_state_path=pretrain_state_path + ".pth",
            pretrain_args=dict(
                use_pretrain_checkpoints=True,
                do_pretrain_checkpoints=True,
                pretrain_checkpoint_dir=pretrain_state_path,
                lr=0.001,
                use_logger=True,
                use_step_logger=True,
                num_epochs=prior_epochs,
                proper_prior_fit_retrys=1,
                reuse_state_epochs=prior_refit_epochs,
                # Prefit flow net identity => Flow will be identity(-like) at the beginning
                prefit_flow_net_identity=True,
                prefit_flow_net_identity_lr=1e-2,
                prefit_flow_net_identity_weight_decay=1e-5,
                prefit_flow_net_identity_num_epochs=100,
                # Prefit convex net, to start with a convex thing
                prefit_convex_net=True,
                prefit_convex_net_lr=1e-3,
                prefit_convex_net_weight_decay=0,
                prefit_convex_net_num_epochs=200,
                zoo=Zoo()
            )
        ),
        weight_decay_on_weight_norm_modules=0,
    )
    path = f"./config/fbms_unet_diffeo_joint/2024_01_11/seed/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


  bg_flip_coords = flip_probability[torch.argwhere(bg_flip_mask).squeeze(), :2].int().T


2. Version 15.01.2024: RealNVP prior with longer refitting (`prior_refit_epochs = 1000` instead of 400)

In [6]:
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.model.net_factory import real_nvp_path_connected_net
from awesome.model.zoo import Zoo

import itertools

# Experiment grid for the 15.01.2024 run: one config per (feature type, sequence) pair.
xytypes = ["edge"]
# FBMS-59 training sequences; commented-out ones are skipped (most are marked as needing a multi-output model).
datasets = [
    'bear01',
    'bear02',
    'cars2',
    'cars3',
    'cars6',
    'cars7',
    'cars8',
    #'cars9', # Multi output Model
    #'cats02', # Multi output Model
    'cats04',
    'cats05',
    #'cats07',  # Multi output Model
    #'ducks01', # Multi output Model
    'horses01',
    'horses03',
    #'horses06', # Multi output Model
    #'lion02',
    'marple1',
    'marple10',
    'marple11',
    #'marple13', # Multi output Model
    #'marple3', # Multi output Model
    'marple5',
    #'marple8', # Multi output Model
    'meerkats01',
    'people04',
    #'people05', # Multi output Model
    'rabbits01',
    #'rabbits05' # Multi output Model
]


# Selects which pretrained UNet checkpoint family the segmentation model starts from.
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"

# One (xytype, dataset, split) triple per config. The former
# `zip(xytypes * len(datasets), datasets, ...)` only paired correctly while there was a single
# xytype; with more it alternated xytypes across datasets and dropped combinations.
it = itertools.product(xytypes, datasets, ["train"])

# Prior fitting epochs: the initial fit, and the refit used when a stored state is reused.
prior_epochs = 4000
prior_refit_epochs = 1000

for vals in it:
    # Each grid entry is an (xytype, sequence name, FBMS split) triple.
    xytype, dataset, dataset_kind = vals

    # Candidate pretrained UNet checkpoints keyed by segmentation_model_switch; an unknown switch aborts.
    checkpoint_paths = {
        "original": f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth",
        "retrain": f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
        "retrain_xy": f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth",
    }
    segmentation_model_state_dict_path = checkpoint_paths.get(segmentation_model_switch)
    if segmentation_model_state_dict_path is None:
        raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
    # Channel order fed to the UNet: "bgr" only for the "original" checkpoints.
    image_channel_format = "rgb" if segmentation_model_switch != "original" else "bgr"
    # "edge" features give a 4-channel input; any other xytype is treated as 6 channels.
    input_channels = 6 if xytype != "edge" else 4

    prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    # Base path (no extension) of the cached prior pretraining state; also used as its checkpoint folder.
    pretrain_state_path = f"./data/checkpoints/pretrain_states/2024-01-15/model_{dataset}_unet_spatial_realnvp_{prior_epochs}_{prior_refit_epochs}"

    real_dataset = FBMSSequenceDataset(
        dataset_path=data_path,
        weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
        processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
        confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
        do_weak_label_preprocessing=True,
        do_uncertainty_label_flip=True,
        test_weak_label_integrity=False,
        all_frames=True,
    )

    dataset_args = dict(
        dataset=real_dataset,
        xytype=xytype,
        feature_dir=f"{data_path}/Feat",
        dimension="3d",  # 2d for fcnet
        mode="model_input",
        model_input_requires_grad=False,
        batch_size=1,
        split_ratio=1,
        shuffle_in_dataloader=False,
        image_channel_format=image_channel_format,
        do_image_blurring=True,
    )
    loss_args = dict(
        criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
        penalty_criterion=prior_criterion.criterion,
        alpha=1,
        beta=1,
    )
    optimizer_args = dict(lr=0.0001, betas=(0.9, 0.999), eps=1e-08, amsgrad=False)

    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+{segmentation_model_switch}+joint",
        dataset_type=class_name(AwesomeDataset),
        dataset_args=dataset_args,
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={'in_chn': input_channels},
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path,  # pretrained UNet weights
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            channels=2,
            hidden_units=32,
            flow_n_flows=12,
            flow_output_fn="tanh",
            norm="minmax",
            convex_net_hidden_units=130,
            convex_net_hidden_layers=2,
        ),
        prior_model_type=class_name(real_nvp_path_connected_net),
        loss_type=class_name(FBMSJointLoss),
        loss_args=loss_args,
        use_extra_penalty_hook=False,  # penalty hook for the penalty term that model outputs should match
        # extra_penalty_after_n_epochs=1,
        # use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=15,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/",
        optimizer_args=optimizer_args,
        use_progress_bar=False,
        plot_indices_during_training_nth_epoch=5,
        compute_metrics_during_training_nth_epoch=5,
        compute_crf_with_metrics=True,
        compute_crf_after_training=True,
        compute_crf_after_pretraining=True,
        save_images_after_pretraining=True,
        include_unaries_when_saving=True,
        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        agent_args=dict(
            do_pretraining=True,
            pretrain_only=False,
            force_pretrain=False,
            pretrain_state_path=pretrain_state_path + ".pth",
            pretrain_args=dict(
                use_pretrain_checkpoints=True,
                do_pretrain_checkpoints=True,
                pretrain_checkpoint_dir=pretrain_state_path,
                lr=0.001,
                use_logger=True,
                use_step_logger=True,
                num_epochs=prior_epochs,
                proper_prior_fit_retrys=1,
                reuse_state_epochs=prior_refit_epochs,
                # Prefit the flow net towards the identity so the flow starts identity(-like).
                prefit_flow_net_identity=True,
                prefit_flow_net_identity_lr=1e-2,
                prefit_flow_net_identity_weight_decay=1e-5,
                prefit_flow_net_identity_num_epochs=100,
                # Prefit the convex net so training starts from a convex shape.
                prefit_convex_net=True,
                prefit_convex_net_lr=1e-3,
                prefit_convex_net_weight_decay=0,
                prefit_convex_net_num_epochs=200,
                zoo=Zoo(),
            ),
        ),
        weight_decay_on_weight_norm_modules=0,
    )
    path = f"./config/fbms_unet_diffeo_joint/2024_01_15/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


## Spatio-Temporal Prior Fitting with Noise

In [8]:
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.model.noisy_path_connected_net import NoisyPathConnectedNet
from awesome.model.zoo import Zoo
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.measures.se import SE
from awesome.model.net_factory import real_nvp_path_connected_net
import itertools

# Single-sequence setup shared by every (noise level, seed) combination below.
xytype = "edge"
dataset_kind = "train"
dataset = "cars3"
all_frames = True
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"


# Resolve the pretrained UNet checkpoint for the chosen switch.
segmentation_model_state_dict_path = None
if segmentation_model_switch == "original":
    segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain_xy":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
else:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
input_channels = 4 if xytype == "edge" else 6
# NOTE(review): not referenced by the configs built in this cell.
prior_criterion = UnariesConversionLoss(SE(reduction="mean"))

data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

real_dataset = FBMSSequenceDataset(
    dataset_path=data_path,
    weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
    processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
    confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
    do_weak_label_preprocessing=True,
    do_uncertainty_label_flip=True,
    test_weak_label_integrity=False,
    # Use the configured flag instead of a hardcoded literal so the setting above takes effect.
    all_frames=all_frames,
)

# Prior pretraining schedule (epochs) and the batch size handed to the pretraining.
batch_size = 2
prior_epochs = 1000
prior_reuse_state_epochs = 400
prefit_flow_grid_epochs = 30
prefit_convex_net_epochs = 400
# Sweep: every noise level is combined with every seed.
noisy_percentages = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
seeds = [42, 47, 131]
ops = itertools.product(noisy_percentages, seeds)

for noisy_percentage, seed in ops:
    # Noise level rendered name-safe for experiment/file names, e.g. 0.1 -> "0_1".
    noise_tag = str(noisy_percentage).replace('.', '_')

    dataset_args = dict(
        dataset=real_dataset,
        xytype=xytype,
        feature_dir=f"{data_path}/Feat",
        dimension="3d",  # 2d for fcnet
        mode="model_input",
        model_input_requires_grad=False,
        batch_size=1,
        split_ratio=1,
        shuffle_in_dataloader=False,
        image_channel_format=image_channel_format,
        do_image_blurring=True,
        spatio_temporal=True,
    )
    prior_model_args = dict(
        channels=3,
        hidden_units=32,
        flow_n_flows=12,
        flow_output_fn="tanh",
        norm="minmax",
        convex_net_hidden_units=130,
        convex_net_hidden_layers=2,
        network_type=NoisyPathConnectedNet,
    )
    loss_args = dict(
        criterion=WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
        alpha=1,
        beta=1,
    )
    optimizer_args = dict(lr=0.003, betas=(0.9, 0.999), eps=1e-08, amsgrad=False)

    cfg = AwesomeConfig(
        seed=seed,
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+only_prior+realnvp+spatio-temporal+noisy+seed{seed}+np{noise_tag}",
        dataset_type=class_name(AwesomeDataset),
        dataset_args=dataset_args,
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={'in_chn': input_channels},
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path,  # pretrained UNet weights
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=prior_model_args,
        prior_model_type=class_name(real_nvp_path_connected_net),
        loss_type=class_name(FBMSJointLoss),
        loss_args=loss_args,
        use_extra_penalty_hook=False,  # penalty hook for the penalty term that model outputs should match
        # extra_penalty_after_n_epochs=1,
        # use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True,
        num_epochs=100,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/noisy_spatio_temporal",
        optimizer_args=optimizer_args,
        use_progress_bar=True,
        plot_indices_during_training_nth_epoch=5,
        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        save_images_after_pretraining=True,
        include_unaries_when_saving=True,
        agent_args=dict(
            # Only the prior is (re)fitted; pretraining is forced on every run.
            do_pretraining=True,
            pretrain_only=True,
            force_pretrain=True,
            pretrain_args=dict(
                use_pretrain_checkpoints=False,
                do_pretrain_checkpoints=False,
                lr=0.001,
                use_logger=True,
                use_step_logger=True,
                num_epochs=prior_epochs,
                proper_prior_fit_retrys=1,
                reuse_state_epochs=prior_reuse_state_epochs,
                # Prefit the flow net towards the identity so the flow starts identity(-like).
                prefit_flow_net_identity=True,
                prefit_flow_net_identity_lr=1e-2,
                prefit_flow_net_identity_weight_decay=1e-5,
                prefit_flow_net_identity_num_epochs=prefit_flow_grid_epochs,
                # Prefit the convex net so training starts from a convex shape.
                prefit_convex_net=True,
                prefit_convex_net_lr=1e-3,
                prefit_convex_net_weight_decay=0,
                prefit_convex_net_num_epochs=prefit_convex_net_epochs,
                batch_size=batch_size,
                noisy_percentage=noisy_percentage,
                zoo=Zoo(),
            ),
        ),
        # output_folder="./runs/fbms_local/unet/TestUnet/",
    )
    cfg.save_to_file(
        f"./config/fbms_noisy_spatio_temporal/{cfg.name_experiment}.yaml",
        override=True,
        make_dirs=True,
        no_uuid=True,
    )
    

  fg_flip_coords = flip_probability[torch.argwhere(fg_flip_mask).squeeze(), :2].int().T


## Refitting Amir's models with this code base to check whether the results are approximately the same

In [4]:
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss

import itertools

# Feature type(s) used as UNet input for the refit.
xytypes = ["edge"]
# FBMS-59 training sequences to refit; commented-out ones are skipped (most are marked as needing a multi-output model).
datasets = [
    'bear01',
    'bear02',
    'cars2',
    'cars3',
    'cars6',
    'cars7',
    'cars8',
    #'cars9', # Multi output Model
    #'cats02', # Multi output Model
    'cats04',
    'cats05',
    #'cats07',  # Multi output Model
    #'ducks01', # Multi output Model
    'horses01',
    'horses03',
    #'horses06', # Multi output Model
    #'lion02',
    'marple1',
    'marple10',
    'marple11',
    #'marple13', # Multi output Model
    #'marple3', # Multi output Model
    'marple5',
    #'marple8', # Multi output Model
    'meerkats01',
    'people04',
    #'people05', # Multi output Model
    'rabbits01',
    #'rabbits05' # Multi output Model
]


# One (xytype, dataset, split) triple per config. A product keeps every pairing when more
# xytypes are added, which the former `zip(xytypes * len(datasets), ...)` did not.
it = itertools.product(xytypes, datasets, ["train"])

# One refit config per sequence: UNet only, no prior model.
for xytype, dataset, dataset_kind in it:
    # UNet input channels must match the feature type: 4 for "edge", 6 for the other types.
    # Derived instead of hardcoded so the config stays valid if xytypes changes.
    input_channels = 4 if xytype == "edge" else 6
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+REFIT",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                dataset_path=data_path,
                weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
                processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
                confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
                do_weak_label_preprocessing=True,
                do_uncertainty_label_flip=True,
                all_frames=True,
                _no_indexing=True,
            ),
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d",  # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": True,
            "image_channel_format": "rgb",
            "do_image_blurring": True,
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        use_segmentation_output_inversion=True,
        use_prior_model=False,
        loss_type=class_name(WeightedLoss),
        loss_args={
            "criterion": torch.nn.BCELoss(),
            "mode": "sssdms",
            "noneclass": 2,
        },
        use_step_lr_scheduler=True,
        step_lr_scheduler_args={
            "gamma": 0.1,
            "step_size": 5,
        },
        use_binary_classification=True,
        num_epochs=15,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/refit/",
        optimizer_args={
            "lr": 0.01,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False,
        },
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=5,
        compute_metrics_during_training_nth_epoch=5,
        plot_indices_during_training=[0, 1, 2, 3],
        agent_args=dict(
            do_pretraining=False,
        ),
        weight_decay_on_weight_norm_modules=0,
    )
    path = f"./config/fbms_unet_refit/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


## Refitting with Additional Spatial Information

In [6]:
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss

import itertools

# Feature type(s) used as UNet input for the spatial refit (edges plus coordinates).
xytypes = ["edgexy"]
# FBMS-59 training sequences to refit; commented-out ones are skipped (most are marked as needing a multi-output model).
datasets = [
    'bear01',
    'bear02',
    'cars2',
    'cars3',
    'cars6',
    'cars7',
    'cars8',
    #'cars9', # Multi output Model
    #'cats02', # Multi output Model
    'cats04',
    'cats05',
    #'cats07',  # Multi output Model
    #'ducks01', # Multi output Model
    'horses01',
    'horses03',
    #'horses06', # Multi output Model
    #'lion02',
    'marple1',
    'marple10',
    'marple11',
    #'marple13', # Multi output Model
    #'marple3', # Multi output Model
    'marple5',
    #'marple8', # Multi output Model
    'meerkats01',
    'people04',
    #'people05', # Multi output Model
    'rabbits01',
    #'rabbits05' # Multi output Model
]


# One (xytype, dataset, split) triple per config. A product keeps every pairing when more
# xytypes are added, which the former `zip(xytypes * len(datasets), ...)` did not.
it = itertools.product(xytypes, datasets, ["train"])

# One spatial refit config per sequence: UNet only, no prior model.
for xytype, dataset, dataset_kind in it:
    # UNet input channels must match the feature type: 4 for "edge", 6 for the other types.
    # Derived instead of hardcoded so the config stays valid if xytypes changes.
    input_channels = 4 if xytype == "edge" else 6
    data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
    cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+REFIT",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                dataset_path=data_path,
                weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
                processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
                confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
                do_weak_label_preprocessing=True,
                do_uncertainty_label_flip=True,
                all_frames=True,
                _no_indexing=True,
            ),
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d",  # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": True,
            "image_channel_format": "rgb",
            "do_image_blurring": True,
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        use_segmentation_output_inversion=True,
        use_prior_model=False,
        loss_type=class_name(WeightedLoss),
        loss_args={
            "criterion": torch.nn.BCELoss(),
            "mode": "sssdms",
            "noneclass": 2,
        },
        use_step_lr_scheduler=True,
        step_lr_scheduler_args={
            "gamma": 0.1,
            "step_size": 5,
        },
        use_binary_classification=True,
        num_epochs=15,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/refit/",
        optimizer_args={
            "lr": 0.01,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False,
        },
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=5,
        compute_metrics_during_training_nth_epoch=5,
        plot_indices_during_training=[0, 1, 2, 3],
        agent_args=dict(
            do_pretraining=False,
        ),
        weight_decay_on_weight_norm_modules=0,
    )
    path = f"./config/fbms_unet_refit_spatial/{cfg.name_experiment}.yaml"
    os.makedirs(os.path.dirname(path), exist_ok=True)
    cfg.save_to_file(path, override=True, no_uuid=True)


### Util: Run all configs


In [None]:
from awesome.run.multi_config_config import MultiConfigConfig

# Bundle every config found under ./config/benchmarks into one multi-config run (job files are created too).
cfg = MultiConfigConfig(
    runner_type=class_name(AwesomeRunner),
    create_job_file=True,
    mode="scan_dir",
    scan_config_directory="./config/benchmarks",
    runs_path="./runs",
    runner_script_path="./scripts/run.py",
    config_directory="./scripts/slurm/config/",
    job_file_path="./scripts/slurm/job_files/",
    name_experiment="Run_Benchmark_Configs",
)
os.makedirs(cfg.runs_path, exist_ok=True)
os.makedirs(cfg.job_file_path, exist_ok=True)
os.makedirs(cfg.config_directory, exist_ok=True)

multi_config_path = f"./config/multi_config/{cfg.name_experiment}.yaml"
# The target folder is not guaranteed to exist on a fresh checkout; create it before saving.
os.makedirs(os.path.dirname(multi_config_path), exist_ok=True)
cfg.save_to_file(multi_config_path, override=True)

In [None]:
# CNNet

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet

xytype = "xy"

dataset = "marple2"

# CNNNet segmentation with a convex-diffeomorphism prior on one FBMS-59 test sequence.
# The config is only built; saving is commented out at the end of the cell.
cfg = AwesomeConfig(
    name_experiment=f"CNNET_+{dataset}+{xytype}+diffeo+joint",
    dataset_type=class_name(AwesomeDataset),
    dataset_args={
        "dataset": FBMSSequenceDataset(
            dataset_path=f"./data/local_datasets/FBMS-59/test/{dataset}"
        ),
        "xytransform": "xy",
        "xytype": xytype,
        # Single "mode" entry. The literal used to contain "mode": "scribbles" here and a later
        # duplicate "mode": "model_input"; Python kept this key position with the later value,
        # so "scribbles" never took effect. This preserves the effective value and key order.
        "mode": "model_input",
        "feature_dir": f"./data/local_datasets/FBMS-59/test/{dataset}/Feat",
        "dimension": "3d",  # 2d for fcnet
        "model_input_requires_grad": True,  # Can be used for 3d nets
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
        "subset": 1,
    },
    segmentation_model_type=class_name(CNNNet),
    segmentation_model_args={
        'width': 16,
        'depth': 2,
        'kernel_size': 3,
        'input': 'rgbxy',
    },
    segmentation_training_mode='single',
    use_prior_model=True,
    prior_model_args=dict(
        nf_layers=3,
        nf_hidden=70,
    ),
    prior_model_type=class_name(ConvexDiffeomorphismNet),
    loss_type=class_name(AwesomeImageLossJoint),
    loss_args={
        "criterion": GradientPenaltyLoss(
            criterion=torch.nn.BCELoss(),
            apply_gradient_penalty=True,
            noneclass=2.,
            xygrad=0.001,
            rgbgrad=0.001,
            featgrad=0.0,
            xytype=xytype,
        ),
        "prior_criterion": GradientPenaltyLoss(
            criterion=torch.nn.BCELoss(),
            apply_gradient_penalty=False,
            noneclass=2.,
        ),
        "gamma": 0.5,
        "beta": 0.5,
    },
    use_extra_penalty_hook=True,  # Penalty hook for the penalty term that model outputs should match
    extra_penalty_after_n_epochs=800,
    use_reduce_lr_in_extra_penalty_hook=False,
    use_lr_on_plateau_scheduler=False,
    use_binary_classification=True,
    num_epochs=4000,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs/fbms_local",
    optimizer_args={
        "lr": 0.02,
        "betas": (0.9, 0.999),
        "eps": 1e-08,
        "weight_decay": 0,
        "amsgrad": False,
    },
    use_progress_bar=True,
    semantic_soft_segmentation_code_dir="../siggraph/",
    semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
    plot_indices_during_training_nth_epoch=50,
    plot_indices_during_training=[0],
    tf_use_gpu=True,
)
#cfg.save_to_file("./config/Test_FBMS_CNNET.yaml", override=True)

In [None]:
# Print one OpenSSH client "Host" entry per cluster GPU node; each node is reached through a
# ProxyCommand jump via the login host `address`.
user = "js267086"
address = "omni.zimt.uni-siegen.de"
nodes = [f"gpu-node{index:03d}" for index in range(1, 11)]

for node in nodes:
    cfg = "\n".join([
        "",
        f"Host {node}",
        f"    User {user}",
        f"    ProxyCommand ssh -t -W %h:%p -q {user}@{address}",
        "",
    ])
    print(cfg)