In [None]:
# Setup: enable autoreload so edits to the project modules are picked up without restarting the
# kernel, import the project classes used by the configuration below, and switch the working
# directory to the project root.
%load_ext autoreload
%autoreload 2
from awesome.run.awesome_config import AwesomeConfig
from awesome.run.awesome_runner import AwesomeRunner
from awesome.util.reflection import class_name
import os
import torch

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.dataset.convexity_segmentation_dataset import ConvexitySegmentationDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.model.net import Net
import awesome
from awesome.util.path_tools import get_project_root_path


os.chdir(get_project_root_path()) # Being in the root directory of the project is important for the relative paths to work consistently

In [None]:
# CNNet

from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet

xytype = "xy"

dataset = "marple2"
# Root folder of the selected FBMS-59 test sequence, relative to the project root (the first
# cell makes that the working directory). Both the sequence dataset and its "Feat" feature
# directory are read from below it.
dataset_root = f"./data/local_datasets/FBMS-59/test/{dataset}"

cfg = AwesomeConfig(
        name_experiment=f"CNNET_+{dataset}+{xytype}+diffeo+joint",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": FBMSSequenceDataset(
                    dataset_path=dataset_root
                ),
            "xytransform": "xy",
            "xytype": xytype,
            # Keep a single "mode" entry: a repeated key in a dict literal silently replaces the
            # earlier value, so only one of them can ever take effect.
            "mode": "model_input",
            "feature_dir": f"{dataset_root}/Feat",
            "dimension": "3d", # 2d for fcnet
            "model_input_requires_grad": True, # Can be used for 3d nets
            "batch_size": 1,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "subset": 1
        },
        segmentation_model_type=class_name(CNNNet),
        segmentation_model_args={
            'width': 16,
            'depth': 2,
            'kernel_size': 3,
            'input': 'rgbxy',
        },
        segmentation_training_mode='single',
        use_prior_model=True,
        prior_model_args=dict(
            nf_layers=3,
            nf_hidden=70
        ),
        prior_model_type=class_name(ConvexDiffeomorphismNet),
        loss_type=class_name(AwesomeImageLossJoint),
        loss_args={
            "criterion": GradientPenaltyLoss(**{
                "criterion": torch.nn.BCELoss(),
                "apply_gradient_penalty": True,
                "noneclass" : 2.,
                "xygrad" : 0.001,
                "rgbgrad" : 0.001,
                "featgrad" : 0.0,
                "xytype" : xytype,}),
            "prior_criterion": GradientPenaltyLoss(**{
                "criterion": torch.nn.BCELoss(),
                "apply_gradient_penalty": False,
                "noneclass" : 2.,}),
            "gamma": 0.5,
            "beta": 0.5,
            },
        use_extra_penalty_hook=True, # Penalty hook for the penalty term requiring the models' outputs to match
        extra_penalty_after_n_epochs=800,
        use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True, 
        num_epochs=4000,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local",
        optimizer_args={
            "lr": 0.02,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "weight_decay": 0,
            "amsgrad": False
        },
        use_progress_bar=True,
        semantic_soft_segmentation_code_dir="../siggraph/",
        semantic_soft_segmentation_model_checkpoint_dir="./data/sss_checkpoint/model",
        plot_indices_during_training_nth_epoch=50,
        plot_indices_during_training=[0],
        tf_use_gpu=True,
    )
# Uncomment to write this configuration to a YAML file:
#cfg.save_to_file("./config/Test_FBMS_CNNET.yaml", override=True)

In [None]:
# Wrap the configuration defined above in a runner and build it; training happens in the next cell.
runner = AwesomeRunner(cfg)
runner.build()

In [None]:
# Start training. The epoch count comes from the configuration (num_epochs). For a shorter run,
# override it on the runner before calling train(), e.g. `runner.config.num_epochs = 2000`.
runner.train()

In [None]:
from awesome.util.temporary_property import TemporaryProperty


# Plot the weak labels of the first item of the runner's dataloader.
# TemporaryProperty sets mode="sample" and return_prior=False on runner.dataloader for the
# duration of the `with` block (presumably restoring the previous values afterwards — confirm).
# In that configuration, indexing yields a mapping with a "raw_sample" entry.
# NOTE(review): if plot_weak_labels draws through pyplot, the inline backend may render the
# figure a second time after display(fig) — confirm.
with TemporaryProperty(runner.dataloader, mode="sample", return_prior=False):
    fig = runner.dataloader[0]["raw_sample"].plot_weak_labels()
    display(fig)

In [None]:
# Inspect one sample side by side:
#   - ground truth
#   - weak labels
#   - selected weak labels
#   - selected foreground/background ground truth
# NOTE: despite its name, `dataloader` holds `runner.dataloader.__dataset__`, presumably the
# wrapped dataset rather than a DataLoader.
dataloader = runner.dataloader.__dataset__
import awesome.run.functions as F
import matplotlib.pyplot as plt

index = 0  # Which sample to inspect

size = 10  # Width/height (inches) of each panel
cols = 4   # One panel per view plotted below
fig, ax = plt.subplots(1, cols, figsize=(size * cols, size))

sample = dataloader[index]

sample.plot(ax=ax[0], labels=sample.ground_truth_object_ids)
ax[0].set_title("Ground truth")

sample.plot_weak_labels(ax=ax[1])
ax[1].set_title("Weak labels")

sample.plot_selected_weak_labels(ax=ax[2])
ax[2].set_title("Selected Weak labels")

sample.plot_selected(ax=ax[3], labels=["Foreground", "Background"])
ax[3].set_title("Selected ground truth")


mapping = sample._get_gt_object_id_weak_label_mapping()
display(mapping)

# Close the pyplot figure so the inline backend does not flush it as a second copy at the end of
# the cell. The figure object stays valid and is rendered exactly once through its rich repr below.
plt.close(fig)
fig

## Tracks: copy the multicut trajectories into the dataset folders

In [None]:
# Optional one-off data preparation: copy the multicut trajectory file of every FBMS-59 sequence
# from the separately extracted tracks folder into <dataset_dirs>/<sequence>/tracks/multicut.
# It is disabled by default, so "Restart & Run All" neither fails here nor modifies the dataset
# folder. Set EXTRACT_TRACKS = True to run it.
import shutil
import warnings

EXTRACT_TRACKS = False

tracks_path = "data/local_datasets/FBMS-59/tracks"
dataset_dirs = "data/local_datasets/FBMS-59/test/"


def extract_multicut_tracks(tracks_root, dataset_root, inner_path="MulticutResults/pfldof0.5000004"):
    """Copy one multicut track file per sequence into the dataset folder.

    For every sub-directory ``<tracks_root>/<sequence>`` that contains ``inner_path``, the first
    regular file (in sorted name order) of ``<tracks_root>/<sequence>/<inner_path>`` is copied
    to ``<dataset_root>/<sequence>/tracks/multicut/``, which is created if missing.

    Some entries are skipped with a warning instead of aborting the whole run:
    entries of ``tracks_root`` without an ``inner_path`` directory (e.g. stray files), and
    ``inner_path`` directories without any regular file.

    Args:
        tracks_root: Directory with one sub-directory per sequence.
        dataset_root: Directory whose per-sequence folders receive the copied file.
        inner_path: Directory holding the track file, relative to each sequence folder.

    Returns:
        list[str]: Destination paths of the copied files, in sorted sequence order.
    """
    copied = []
    for sequence in sorted(os.listdir(tracks_root)):
        source_dir = os.path.join(tracks_root, sequence, inner_path)
        if not os.path.isdir(source_dir):
            warnings.warn(f"Skipping {sequence!r}: {source_dir} is not a directory")
            continue
        # os.listdir order is arbitrary: sort so the chosen file is deterministic. Ignore
        # sub-directories, which shutil.copy cannot copy.
        track_files = sorted(
            name for name in os.listdir(source_dir)
            if os.path.isfile(os.path.join(source_dir, name))
        )
        if not track_files:
            warnings.warn(f"Skipping {sequence!r}: no track file in {source_dir}")
            continue
        if len(track_files) > 1:
            warnings.warn(
                f"{sequence!r}: {len(track_files)} files in {source_dir}, copying only {track_files[0]!r}"
            )
        target_dir = os.path.join(dataset_root, sequence, "tracks", "multicut")
        os.makedirs(target_dir, exist_ok=True)
        target_file = os.path.join(target_dir, track_files[0])
        shutil.copy(os.path.join(source_dir, track_files[0]), target_file)
        copied.append(target_file)
    return copied


if EXTRACT_TRACKS:
    copied_track_files = extract_multicut_tracks(tracks_path, dataset_dirs)
