In [10]:
# Notebook setup: reload edited project modules automatically, so changes to the
# `awesome` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
from awesome.model.unet import UNet
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.util.path_tools import get_project_root_path
import os
# NOTE(review): the plotting helpers, matplotlib and numpy are not used in the cells
# visible here; presumably kept for later inspection/plotting cells — confirm.
from awesome.run.functions import plot_as_image, channel_masks_to_value_mask, transparent_added_listed_colormap, get_mpl_figure, plot_mask
import matplotlib.pyplot as plt
import numpy as np
import torch
# Later cells use relative "./data/..." and "./runs/..." paths; switch the working
# directory (presumably the project root) so they resolve regardless of where the
# notebook was launched from.
os.chdir(get_project_root_path())


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
# Standalone FBMS-59 sequence dataset for quick inspection of a single sequence.
dataset_kind, dataset = "train", "ducks01"
data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"
fbms_ds = FBMSSequenceDataset(
    dataset_path=data_path,
    # Weak-label sources (raw labels, processed labels, per-pixel confidences).
    weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based_new",
    processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
    confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
    # Use the labels as stored: no preprocessing and no uncertainty-based flipping.
    do_weak_label_preprocessing=False,
    do_uncertainty_label_flip=False,
    all_frames=True,
    test_weak_label_integrity=False,
    # Multi-object labelling; an empty id list selects no single object.
    label_mode="multiple_objects",
    segmentation_object_id=[],
)


In [12]:
from typing import Literal
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.measures.weighted_loss import WeightedLoss
from awesome.model.batch_size_multi_prior_module import BatchSizeMultiPriorModule
from awesome.model.combined_segmentation_module import CombinedSegmentationModule
from awesome.model.multiple_object_aware_path_connected_net import MultipleObjectsAwarePathConnectedNet
from awesome.model.number_based_multi_prior_module import NumberBasedMultiPriorModule
from awesome.model.zoo import Zoo
from awesome.model.net_factory import real_nvp_path_connected_net
from awesome.run.awesome_config import AwesomeConfig
from awesome.util.reflection import class_name

# Experiment parameters -----------------------------------------------------------
xytype = "edge"
dataset_kind = "train"
dataset = "ducks01"
all_frames = True
# Prior fitting epochs. Both values are also encoded in `pretrain_state_path`
# below, so they are passed to `pretrain_args` rather than repeated as literals;
# otherwise the checkpoint name and the actual training could silently disagree.
prior_epochs = 4000
prior_refit_epochs = 400
subset = None  # 0 #slice(0, 5) -- NOTE(review): not used in this cell; presumably for later cells
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"
segmentation_model_state_dict_path = None

# Choose the pretrained UNet checkpoint for the selected segmentation model variant.
if segmentation_model_switch == "original":
    segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain_xy":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
else:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
# The "original" checkpoints get BGR input, the retrained ones RGB.
# NOTE(review): presumably the channel order each checkpoint family was trained with.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
# The UNet input channel count depends on the spatial (xy) feature type.
input_channels = 4 if xytype == "edge" else 6

prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

# Base path for the prior pretraining state: ".pth" is appended for the state file,
# and the bare path is used as the checkpoint directory (see `agent_args` below).
pretrain_state_path = f"./data/checkpoints/pretrain_states/2024-01-11/model_{dataset}_unet_spatial_realnvp_{prior_epochs}_{prior_refit_epochs}_multi_prior_testing"

# Dataset --------------------------------------------------------------------------
real_dataset = FBMSSequenceDataset(
    dataset_path=data_path,
    weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based_new",
    processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
    confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
    do_weak_label_preprocessing=False,
    do_uncertainty_label_flip=False,
    test_weak_label_integrity=False,
    all_frames=all_frames,
    label_mode="multiple_objects",
    segmentation_object_id=[],
)
# The integrity check is disabled while constructing the dataset and enabled on the
# instance afterwards. NOTE(review): presumably so the check only applies to later
# accesses and not during construction — confirm this is intended.
real_dataset.test_weak_label_integrity = True
number_of_objects = real_dataset.get_number_of_objects()

# Prior model: per-object RealNVP path-connected nets wrapped in multi-prior modules.
inner_prior_factory = real_nvp_path_connected_net
inner_prior_factory_model_args = dict(
    channels=2,
    hidden_units=32,
    flow_n_flows=12,
    flow_output_fn="tanh",
    norm="minmax",
    convex_net_hidden_units=130,
    convex_net_hidden_layers=2,
    network_type=MultipleObjectsAwarePathConnectedNet,
)

prior_type, prior_args = BatchSizeMultiPriorModule.get_type_args(*NumberBasedMultiPriorModule.get_type_args(
    inner_prior_factory, inner_prior_factory_model_args
))

# Full experiment config -----------------------------------------------------------
cfg = AwesomeConfig(
    name_experiment="UNET+nbatch+multiobject+testing",
    combined_segmentation_module_type=class_name(CombinedSegmentationModule),
    combined_segmentation_module_args=dict(),
    dataset_type=class_name(AwesomeDataset),
    dataset_args={
        "dataset": real_dataset,
        "xytype": xytype,
        "feature_dir": f"{data_path}/Feat",
        "dimension": "3d",  # 2d for fcnet
        "mode": "model_input",
        "model_input_requires_grad": False,
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
        "image_channel_format": image_channel_format,
        "do_image_blurring": True
    },
    segmentation_model_type=class_name(UNet),
    segmentation_model_args={
        'in_chn': input_channels,
        # Single object: one output channel (BCE). Multiple objects: one channel per
        # object plus a background channel (CE).
        'out_chn': number_of_objects if number_of_objects == 1 else number_of_objects + 1,
    },
    segmentation_training_mode='multi',
    segmentation_model_state_dict_path=segmentation_model_state_dict_path,  # Path to the pretrained model
    use_segmentation_output_inversion=True,
    use_prior_model=True,
    prior_model_args=prior_args,
    prior_model_type=prior_type,
    loss_type=class_name(FBMSJointLoss),
    loss_args={
        "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
        "penalty_criterion": prior_criterion.criterion,
        "alpha": 1,
        "beta": 1,
    },
    use_extra_penalty_hook=False,  # Penalty hook for the penalty term that the model's output should match
    #extra_penalty_after_n_epochs=1,
    #use_reduce_lr_in_extra_penalty_hook=False,
    use_lr_on_plateau_scheduler=False,
    use_binary_classification=True,
    num_epochs=0,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs/fbms_local/unet/multi_prior",
    optimizer_args={
        "lr": 0.003,
        "betas": (0.9, 0.999),
        "eps": 1e-08,
        "amsgrad": False
    },
    use_progress_bar=False,
    plot_indices_during_training=real_dataset.get_ground_truth_indices(),
    save_images_after_pretraining=True,
    include_unaries_when_saving=True,
    agent_args=dict(
        do_pretraining=True,
        pretrain_only=True,
        force_pretrain=True,
        pretrain_state_path=pretrain_state_path + ".pth",
        pretrain_args=dict(
            use_pretrain_checkpoints=True,
            do_pretrain_checkpoints=True,
            pretrain_checkpoint_dir=pretrain_state_path,
            lr=0.001,
            use_logger=True,
            use_step_logger=True,
            # Kept in sync with the epoch counts encoded in `pretrain_state_path`.
            num_epochs=prior_epochs,
            proper_prior_fit_retrys=1,
            reuse_state_epochs=prior_refit_epochs,
            # Prefit flow net identity => Flow will be identity(-like) at the beginning
            prefit_flow_net_identity=True,
            prefit_flow_net_identity_lr=1e-2,
            prefit_flow_net_identity_weight_decay=1e-5,
            prefit_flow_net_identity_num_epochs=100,
            # Prefit convex net, to start with a convex thing
            prefit_convex_net=True,
            prefit_convex_net_lr=1e-3,
            prefit_convex_net_weight_decay=0,
            prefit_convex_net_num_epochs=200,
            zoo=Zoo()
        )
    )
)
#cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", override=True, no_uuid=True)



In [13]:
from awesome.run.awesome_runner import AwesomeRunner


# Build the runner from the config assembled in the previous cell, then start the run.
# `build()` must be called before `train()`.
# NOTE(review): the saved output shows this cell was interrupted (KeyboardInterrupt).
# Since cfg.agent_args sets pretrain_only=True and num_epochs=0, this presumably runs
# only the prior pretraining — confirm before relying on the outputs.
runner = AwesomeRunner(cfg)
runner.build()
runner.train()

KeyboardInterrupt: 

In [None]:
# Inspect the runner's internal context after the (possibly interrupted) run above.
# NOTE(review): dunder-style attribute, presumably internal to AwesomeRunner; this is a
# debugging cell and should not be relied upon in a finished notebook.
runner.__runner_context__