In [None]:
%load_ext autoreload
%autoreload 2
from awesome.run.awesome_config import AwesomeConfig
from awesome.run.awesome_runner import AwesomeRunner
from awesome.util.reflection import class_name
import os
import torch

from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.dataset.convexity_segmentation_dataset import ConvexitySegmentationDataset
from awesome.measures.awesome_loss import AwesomeLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.model.net import Net
import awesome
from awesome.util.path_tools import get_project_root_path
from awesome.util.logging import basic_config
import matplotlib.pyplot as plt
from awesome.analytics.result_model import ResultModel
from awesome.run.functions import get_result, split_model_result, plot_image_scribbles, plot_mask_labels
from awesome.util.temporary_property import TemporaryProperty
from awesome.run.functions import get_result, split_model_result,register_alpha_map, plot_image_scribbles, plot_mask_labels, plot_mask
import numpy as np
from matplotlib.colors import to_hex, to_rgb
import matplotlib as mpl
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from awesome.run.functions import get_mpl_figure
from typing import Literal
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.dataset.sisbosi_dataset import SISBOSIDataset, ConvexityDataset as SISBOSIConvexityDataset
from awesome.measures.awesome_image_loss_joint import AwesomeImageLossJoint
from awesome.measures.awesome_image_loss import AwesomeImageLoss
from awesome.measures.gradient_penalty_loss import GradientPenaltyLoss
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.measures.regularizer_loss import RegularizerLoss
from awesome.model.cnn_net import CNNNet
from awesome.measures.tv import TV
from awesome.model.convex_net import ConvexNet
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.se import SE
from awesome.measures.ae import AE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from awesome.model.wrapper_module import WrapperModule
#load_ext matplotlib
#matplotlib tk
import normflows as nf
# Project logging setup (awesome.util.logging.basic_config).
basic_config()

# Being in the root directory of the project is important for the relative paths
# (./data, ./runs, ./config) used in the cells below to resolve consistently.
os.chdir(get_project_root_path())

In [None]:
from awesome.model.zoo import Zoo
from awesome.model.net_factory import real_nvp_path_connected_net

xytype = "edge"
dataset_kind = "train"
dataset = "cars3"
all_frames = True
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"


# Pretrained UNet checkpoint for the selected segmentation model variant.
segmentation_model_state_dict_path = None
if segmentation_model_switch == "original":
    segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain_xy":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
else:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
# BGR input for the original checkpoints, RGB for the refitted ones.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
# UNet input channels (used as in_chn of the segmentation model): 4 for "edge", else 6.
input_channels = 4 if xytype == "edge" else 6
prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
# Number of input channels of the prior model (prior_model_args.channels).
channels = 3
data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

real_dataset = FBMSSequenceDataset(
                    dataset_path=data_path,
                    weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based",
                    processed_weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based/processed",
                    confidence_dir= "weak_labels/labels_with_uncertainty_flownet2_based/",
                    do_weak_label_preprocessing=True,
                    do_uncertainty_label_flip=True,
                    test_weak_label_integrity=True,
                    # Use the flag defined above; it is also part of the pretrain
                    # checkpoint directory name, so both must agree.
                    all_frames=all_frames,
                )
pretrain_checkpoint_dir = f"./data/checkpoints/pretrain_states/model_{dataset}_unet_spatial_{all_frames}_realnvp_spatio_temporal"

# Batch size (used for the dataset and for the prior pretraining) and epoch
# budgets of the individual prior pretraining stages configured in agent_args.
batch_size = 2
prior_epochs = 1000  # pretrain_args.num_epochs
prior_reuse_state_epochs = 400  # pretrain_args.reuse_state_epochs
prefit_flow_grid_epochs = 30  # pretrain_args.prefit_flow_net_identity_num_epochs
prefit_convex_net_epochs = 400  # pretrain_args.prefit_convex_net_num_epochs

# Experiment configuration: UNet segmentation model (weights loaded from
# segmentation_model_state_dict_path) combined with a RealNVP path-connected prior.
# NOTE(review): num_epochs=0 together with agent_args pretrain_only=True presumably
# means only the prior pretraining stage runs — confirm in the agent.
cfg = AwesomeConfig(
        name_experiment=f"UNET+{dataset}+{xytype}+diffeo+spatio-temporal+realnvp",
        dataset_type=class_name(AwesomeDataset),
        dataset_args={
            "dataset": real_dataset,
            "xytype": xytype,
            "feature_dir": f"{data_path}/Feat",
            "dimension": "3d", # 2d for fcnet
            "mode": "model_input",
            "model_input_requires_grad": False,
            "batch_size": batch_size,
            "split_ratio": 1,
            "shuffle_in_dataloader": False,
            "image_channel_format": image_channel_format,
            "do_image_blurring": True,
            "spatio_temporal": True,
        },
        segmentation_model_type=class_name(UNet),
        segmentation_model_args={
            'in_chn': input_channels,
        },
        segmentation_training_mode='multi',
        segmentation_model_state_dict_path=segmentation_model_state_dict_path, # Path to the pretrained model
        use_segmentation_output_inversion=True,
        use_prior_model=True,
        prior_model_args=dict(
            channels=channels,
            hidden_units=32,
            flow_n_flows=18,
            flow_output_fn="tanh",
            norm="minmax",
            convex_net_hidden_units=130,
            convex_net_hidden_layers=2,
        ),
        prior_model_type=class_name(real_nvp_path_connected_net),
        loss_type=class_name(FBMSJointLoss),
        loss_args={
            "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
            "alpha": 1,
            "beta": 1,
        },
        use_extra_penalty_hook=False, # Penalty hook for the penalty term that the model outputs should match
        #extra_penalty_after_n_epochs=1,
        #use_reduce_lr_in_extra_penalty_hook=False,
        use_lr_on_plateau_scheduler=False,
        use_binary_classification=True, 
        num_epochs=0,
        device="cuda",
        dtype=str(torch.float32),
        runs_path="./runs/fbms_local/unet/spatio_temporal",
        optimizer_args={
            "lr": 0.003,
            "betas": (0.9, 0.999),
            "eps": 1e-08,
            "amsgrad": False
        },
        use_progress_bar=True,
        plot_indices_during_training_nth_epoch=20,
        plot_indices_during_training=real_dataset.get_ground_truth_indices(),
        save_images_after_pretraining=True,
        include_unaries_when_saving=True,
        agent_args=dict(
             do_pretraining=True,
             pretrain_only=True, 
             force_pretrain=True,
             pretrain_state_path=pretrain_checkpoint_dir + ".pth",
             pretrain_args=dict(
                 use_pretrain_checkpoints=True,
                 do_pretrain_checkpoints=True,
                 pretrain_checkpoint_dir=pretrain_checkpoint_dir,
                 lr=0.001,
                 use_logger=True,
                 use_step_logger=True,
                 num_epochs=prior_epochs,
                 proper_prior_fit_retrys=1,
                 reuse_state_epochs=prior_reuse_state_epochs,
                 # Prefit flow net identity => Flow will be identity(-like) at the beginning
                 prefit_flow_net_identity=True,
                 prefit_flow_net_identity_lr=1e-2,
                 prefit_flow_net_identity_weight_decay=1e-5,
                 prefit_flow_net_identity_num_epochs=prefit_flow_grid_epochs,
                 # Prefit convex net, to start with a convex thing
                 prefit_convex_net=True,
                 prefit_convex_net_lr=1e-3,
                 prefit_convex_net_weight_decay=0,
                 prefit_convex_net_num_epochs=prefit_convex_net_epochs,
                 batch_size=batch_size,
                 zoo=Zoo()
             )
        ),
        #output_folder="./runs/fbms_local/unet/TestUnet/",
    )
# Persist the configuration as YAML under a dated config folder
# (override=True presumably overwrites an existing file of the same name).
path = f"./config/fbms_spatio_temporal/2024_01_16/{cfg.name_experiment}.yaml"
os.makedirs(os.path.dirname(path), exist_ok=True)
cfg.save_to_file(path, override=True, no_uuid=True)

In [None]:
# Create the runner for `cfg`, build it, and store its config.
# NOTE(review): build() presumably instantiates dataset, models and agent — confirm in AwesomeRunner.
runner = AwesomeRunner(cfg)
runner.build()
runner.store_config()

In [None]:
# Run training.
# NOTE(review): the config sets num_epochs=0 and agent_args pretrain_only=True,
# so presumably only the prior pretraining is executed — confirm in the agent.
runner.train()

In [None]:
# Switch matplotlib to the Tk GUI backend so figures open in separate windows.
%matplotlib tk

In [None]:
import numpy as np
from awesome.run.functions import get_mpl_figure, get_result, split_model_result
from awesome.model.path_connected_net import PathConnectedNet
from typing import Any
plt.close("all")

# Cache of normalized coordinate grids, keyed by spatial (H, W) shape.
grid_shapes = dict()
model = runner.agent._get_model()
dataloader = runner.dataloader

# Evaluate every position of the runner's dataloader in order; position = time step.
index = list(range(len(dataloader)))

t_n = len(index)
# Largest time index, used to normalize time to [0, 1]. Clamped to 1 so a
# single-item sequence does not divide by zero (its only frame maps to t=0).
t_max = max(len(dataloader) - 1, 1)

images = []
segmentations = []
priors = []

for i in index:
    res, ground_truth, img, _, _ = get_result(model, dataloader, i, False)
    res = split_model_result(res, model, dataloader, img)

    res_prior = res.get("prior", None)
    res_pred = res["segmentation"]
    if res_prior is None:
        # torch.stack below cannot stack None; fail with a clear message instead.
        raise ValueError(f"Model result for index {i} has no 'prior' entry; cannot build the spatio-temporal plot.")

    images.append(img)
    segmentations.append(res_pred)
    priors.append(res_prior)

images = torch.stack(images)
segmentations = torch.stack(segmentations)
priors = torch.stack(priors)

shp = priors.shape[-2:]
if shp not in grid_shapes:
    grid_shapes[shp] = PathConnectedNet.create_normalized_grid(shp).cpu().numpy()
# Look up with the key that was just ensured, not with the last loop item's shape.
grid = grid_shapes[shp]

# Stack time
# NOTE(review): stacked per-item priors, presumably (T, C, H, W) — confirm.
pred = priors
# Spatio-temporal grid: for every time step, the spatial grid plus one constant
# channel holding the normalized time t / t_max.
# NOTE(review): assumes grid[0] is the spatial coordinate grid (2, H, W) — confirm.
spatial_grid = torch.as_tensor(grid[0])
t_grid = torch.stack([
    torch.cat([spatial_grid, torch.full((1, *shp), t / t_max, dtype=spatial_grid.dtype)], dim=0)
    for t in index
])



def plot_spatio_temporal_object(grid: Any, unaries: Any, size: float = 5):
    """Draw the 0.5 level set of every time slice as stacked contours in 3D.

    Args:
        grid: Coordinate grid(s) as tensor or array. Per slice, channel 0 is used
            as x, channel 1 as y, and the maximum of channel 2 as the height
            (time) at which the slice's contour is drawn. Inputs with fewer than
            4 dimensions get a leading slice axis.
        unaries: Per-slice values; channel 0 of each slice is contoured at
            level 0.5. Inputs with fewer than 4 dimensions get a leading slice axis.
        size: Currently unused.

    Returns:
        The matplotlib figure containing the 3D axes.
    """

    def _batched_numpy(value: Any):
        # Tensors are brought to host memory as numpy arrays; anything with
        # fewer than 4 dims gets a leading slice axis.
        if isinstance(value, torch.Tensor):
            value = value.cpu().numpy()
        return value[None] if len(value.shape) < 4 else value

    grid = _batched_numpy(grid)
    unaries = _batched_numpy(unaries)

    fig, ax = get_mpl_figure(subplot_kw=dict(projection='3d'))

    for slice_idx, slice_grid in enumerate(grid):
        level_values = unaries[slice_idx][0]
        ax.contour(
            slice_grid[0],
            slice_grid[1],
            level_values,
            levels=[0.5],
            colors="red",
            offset=slice_grid[2].max(),  # draw this slice at its time value
            linewidths=2,
        )

    x_lo, x_hi = ax.get_xlim()
    y_lo, y_hi = ax.get_ylim()
    # NOTE(review): (y_lo - y_hi) is negative for a non-inverted y axis, so the
    # x box length passed to set_box_aspect is negative — confirm this
    # mirroring is intended.
    x_box = (x_hi - x_lo) / (y_lo - y_hi)

    ax.set_box_aspect(aspect=(x_box, 1, 1), zoom=1)
    ax.view_init(elev=130, azim=90, roll=0)

    ax.set_axis_off()
    return fig

# Render the stacked 0.5-contours of all priors over the spatio-temporal grid;
# the trailing `fig` makes the notebook display the figure.
fig = plot_spatio_temporal_object(t_grid, pred)
fig


In [None]:
import cv2 as cv
from awesome.run.functions import plot_as_image


all_contours = []
all_hierarchy = []

# Normalized time of each frame (same t / t_max normalization as the spatio-temporal grid).
times = [t / t_max for t in index]

for i in range(len(pred)):
    p = pred[i][0]
    t = times[i]
    # Binarize: 1 where (1 - p) * 255 exceeds 123 (i.e. low prior values), else 0.
    inverted = ((1 - p.detach().cpu().numpy()) * 255).astype(np.uint8)
    _, thresh = cv.threshold(inverted, 123, 1, cv.THRESH_BINARY)
    contours, hierarchy = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_SIMPLE)
    if hierarchy is None:
        # findContours found no contour in this frame; nothing to add.
        continue

    # We are losing some information here, e.g. holes in the object are not represented.
    # This is because we are only using the contour points of the object.

    # OpenCV contour points are (x, y) = (column, row) while p.shape is (H, W),
    # so x is normalized by W - 1 and y by H - 1.
    height, width = p.shape[-2:]
    scale = np.array([width - 1, height - 1])
    local_contours = [
        np.concatenate([c[:, 0, :] / scale, np.full((c.shape[0], 1), t)], axis=1)
        for c in contours
    ]
    all_contours.extend(local_contours)
    all_hierarchy.extend(hierarchy[0])


In [None]:
import numpy as np
import pyvista as pv

# Build a point cloud from all (x, y, t) contour points and show it.
points = np.concatenate(all_contours)
cloud = pv.PolyData(points)
cloud.plot()

# Delaunay tetrahedralization of the cloud (alpha=1), then its outer surface.
volume = cloud.delaunay_3d(alpha=1)
shell = volume.extract_geometry()

# NOTE(review): show_actor() is called for its side effect and its return value
# is passed to display() — confirm this shows what is intended.
axes = pv.Axes()
display(axes.show_actor())

shell.plot(show_axes=False)

In [None]:
# Install trame-vuetify into the current kernel's environment.
# NOTE(review): presumably required by pyvista's trame-based notebook backend — confirm.
%pip install trame-vuetify

In [None]:
# Plot contour #1 of `contours` as left by the contour-extraction loop
# (x = column, y = row in pixels, as returned by OpenCV).
# NOTE(review): raises IndexError if that frame produced fewer than two contours.
fig, ax = get_mpl_figure()

x = contours[1].squeeze()[:, 0]
y = contours[1].squeeze()[:, 1]
ax.plot(x, y, color="red", linewidth=2)
fig

In [None]:
# Inspect the x (column) coordinates of contour #0.
contours[0].squeeze()[:, 0]

In [None]:
# Sampled normalized time stamps.
t = torch.linspace(0, 1, 5)
# Append a constant time channel to the spatial grid for every sampled time,
# following the same construction as t_grid in the spatio-temporal cell.
# NOTE(review): the original line was a syntax error (a dangling `for t in`, and
# `torch.full(t)` without a size). It referenced `res`, which at this point is a
# result dict, so the spatial `grid` is used instead — confirm intent.
spatial_grid = torch.as_tensor(grid[0])
t_sampled_grid = torch.stack([
    torch.cat([spatial_grid, torch.full((1, *spatial_grid.shape[-2:]), float(t_i), dtype=spatial_grid.dtype)], dim=0)
    for t_i in t
])

In [None]:
# Inspect the shape of `grid` (the normalized coordinate grid from the spatio-temporal cell).
grid.shape

## Old experiments with Glow normalizing flows

In [None]:
xytype = "edge"
dataset_kind = "train"
dataset = "bear01"
all_frames = True
subset = None  # alternative: slice(0, 5)
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"


# Pretrained UNet checkpoint for the selected segmentation model variant.
segmentation_model_state_dict_path = None
if segmentation_model_switch == "original":
    segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain_xy":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
else:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")
# BGR input for the original checkpoints, RGB for the refitted ones.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
# UNet input channels: 4 for the "edge" xytype, otherwise 6.
input_channels = 4 if xytype == "edge" else 6

prior_criterion = UnariesConversionLoss(SE(reduction="mean"))
data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

real_dataset = FBMSSequenceDataset(
                    dataset_path=data_path,
                    weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based",
                    processed_weak_labels_dir = "weak_labels/labels_with_uncertainty_flownet2_based/processed",
                    confidence_dir= "weak_labels/labels_with_uncertainty_flownet2_based/",
                    do_weak_label_preprocessing=True,
                    do_uncertainty_label_flip=True,
                    test_weak_label_integrity=False,
                    all_frames=all_frames,
                )
# NOTE: `dataset` is rebound here from the sequence name to the AwesomeDataset
# wrapper; every path that depends on the name has been built above.
dataset = AwesomeDataset(
    **{
        "dataset": real_dataset,
        "xytype": xytype,
        "feature_dir": f"{data_path}/Feat",
        "dimension": "3d", # 2d for fcnet
        "mode": "model_input",
        # Listed once: this key previously appeared twice (False, then True) and
        # the later True silently won.
        "model_input_requires_grad": True,
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
        "image_channel_format": image_channel_format,
        "do_image_blurring": True,
        "subset": subset,
        "spatio_temporal": False,
    }
)


# NOTE(review): assumes UNet's first positional argument is in_chn (the keyword
# used in the spatio-temporal config); for "edge" this is the previous value 4.
segmentation_model = UNet(input_channels, 1)
# map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines;
# load_state_dict copies the values into the module's existing parameters.
segmentation_model.load_state_dict(torch.load(segmentation_model_state_dict_path, map_location="cpu"))


def init_glow(channels: int, 
              hidden_channels: int,
              n_flows: int,
              height: int, 
              width: int,
              scale: bool = True,
              scale_map: Literal["sigmoid", "exp"] = "sigmoid",
              ) -> nf.NormalizingFlow:
    """Build a Glow-style normalizing flow over inputs of shape (channels, height, width).

    Args:
        channels: Number of input channels.
        hidden_channels: Hidden channels of each GlowBlock's coupling network.
        n_flows: Number of GlowBlock layers stacked in sequence.
        height: Input height.
        width: Input width.
        scale: Forwarded to GlowBlock as ``scale``.
        scale_map: Forwarded to GlowBlock as ``scale_map``.

    Returns:
        A ``nf.NormalizingFlow`` whose base distribution is ``Uniform(0, 1)`` over
        (channels, height, width). Every GlowBlock uses split_mode='channel',
        leaky=0.01 and net_actnorm=False.
    """
    base_distribution = nf.distributions.base.Uniform((channels, height, width), 0, 1)

    glow_blocks = [
        nf.flows.GlowBlock(channels, hidden_channels,
                           split_mode='channel',
                           scale_map=scale_map, leaky=0.01,
                           scale=scale, net_actnorm=False)
        for _ in range(n_flows)
    ]

    # Plain (single-scale) stack of blocks. The base distribution is also passed
    # as the third positional argument of NormalizingFlow.
    # NOTE(review): that slot is the target distribution in normflows — confirm intended.
    return nf.NormalizingFlow(base_distribution, glow_blocks, base_distribution)


In [None]:
import logging
from awesome.agent.torch_agent import TorchAgent
from awesome.dataset.prior_dataset import PriorManager
from awesome.measures.unaries_weighted_loss import UnariesWeightedLoss
from awesome.model.wrapper_module import WrapperModule
from awesome.model.unet import UNet
from awesome.model.path_connected_net import PathConnectedNet
from awesome.model.convex_net import ConvexNextNet
from normflows import NormalizingFlow
from normflows.flows import GlowBlock
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from awesome.util.torch import TensorUtil



    
# Number of prior input channels: used as the Glow flow's channels and as the
# convex net's in_features below.
channels = 2

# Spatial shape from the first dataset item.
# NOTE(review): assumes dataset[0][0][0] is a (C, H, W) tensor — confirm.
image_shape = dataset[0][0][0].shape[1:]

flow_model = init_glow(channels=channels, hidden_channels=256, n_flows=3, 
                       height=image_shape[0], 
                       width=image_shape[1],
                       scale=True)
convex_model = ConvexNextNet(n_hidden=130, 
                             n_hidden_layers=2,
                             in_features=channels)

# Prior = convex net combined with the Glow flow.
path_connected_model = PathConnectedNet(convex_model, flow_model)

# Segmentation UNet plus prior wrapped into one module.
wrapper_module = WrapperModule(
    segmentation_module=segmentation_model,
    prior_module=path_connected_model,
    prior_arg_mode="param_clean_grid"
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Optimizer settings for the per-image pretraining loop (weight_decay is applied
# to the flow net parameters only there).
lr = 1e-3
weight_decay = 1e-7

# NOTE(review): not used in the visible cells.
previous_state = None
previous_center_of_mass = None

prior_module = wrapper_module.prior_module


# Flags read by the pretraining loop.
use_prior_sigmoid = True
use_logger = False
use_step_logger = False
batch_progress_bar = None

criterion = UnariesConversionLoss(SE(reduction="mean"))

# Move the wrapper to `device`; the return value is ignored, so this presumably
# moves the module in place — confirm in TensorUtil.to.
TensorUtil.to(wrapper_module, device=device)

In [None]:
# Train the Glow flow alone on the coordinate grid of the first dataset item,
# minimizing normflows' forward KL divergence estimate (forward_kld).
inputs, labels, indices, prior_state = TorchAgent.decompose_training_item(dataset[0], training_dataset=dataset)

max_iter = 1000

prior_module = wrapper_module.prior_module
model = prior_module.flow_net

# Coordinate grid with a leading (batch) axis.
# NOTE(review): assumes inputs[2] is the grid — confirm against decompose_training_item.
grid = inputs[2].to(device)[None, ...]
model = model.to(device)

# Created after the move so it is clearly bound to the on-device parameters.
optimizer = torch.optim.Adamax(model.parameters(), lr=1e-3, weight_decay=1e-5)

model.train()

loss_values = []
for i in tqdm(range(max_iter)):
    optimizer.zero_grad()
    loss = model.forward_kld(grid)

    # Only step on finite losses; non-finite losses are still recorded.
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()

    loss_values.append(loss.detach().to('cpu').numpy())
    del loss

# Built once at the end instead of np.append per step, which copies the whole
# array on every iteration.
loss_hist = np.array(loss_values, dtype=float)


In [None]:

num_epochs = 1000

data_loader = DataLoader(dataset, batch_size=1, shuffle=False)
it = data_loader
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
use_progress_bar = True

if use_progress_bar:
    it = tqdm(it, desc="Pretraining images")

# Fit the prior module (flow net + convex net) to the segmentation unaries of
# each image in turn; PriorManager switches the per-image prior state.
for i, item in enumerate(it):
    inputs, labels, indices, prior_state = TorchAgent.decompose_training_item(item, training_dataset=dataset)
    # May be a single tensor or a list of tensors (both are handled below).
    device_inputs = TensorUtil.to(inputs, device=device)

    # Switch prior weights if needed, using context manager
    with PriorManager(wrapper_module,
                      prior_state=prior_state,
                      prior_cache=dataset.__prior_cache__,
                      model_device=device,
                      training=True
                      ):

        # Get the unaries from the segmentation model only: prior evaluation is
        # temporarily disabled and no gradients are recorded.
        with torch.no_grad(), TemporaryProperty(wrapper_module, evaluate_prior=False):
            if isinstance(device_inputs, list):
                unaries = wrapper_module(*device_inputs)
            else:
                unaries = wrapper_module(device_inputs)

        # Getting inputs for prior
        prior_args, prior_kwargs = wrapper_module.get_prior_args(device_inputs[0],
                                                                 *device_inputs[1:],
                                                                 segm=unaries[0, ...],
                                                                 )
        # Detached copy of the prior's first positional input.
        actual_input = prior_args[0].detach().clone()

        # Skip images whose thresholded unaries contain only one class.
        if len(torch.unique(unaries >= 0.5)) == 1:
            logging.warning(f"Unaries of segmentation model contain no foreground. Skipping image. {i}")
            continue

        epochs = num_epochs

        # Named `steps` (not `it`) so the image iterator above is not rebound.
        steps = range(epochs)
        if use_progress_bar:
            desc = f'Image {i + 1}: Pretraining'
            if batch_progress_bar is None:
                batch_progress_bar = tqdm(
                    total=epochs,
                    desc=desc,
                    leave=True)
            else:
                batch_progress_bar.reset(total=epochs)
                batch_progress_bar.set_description(desc)

        # Fresh optimizer per image; weight decay only on the flow net.
        optimizer = torch.optim.Adam([
            dict(params=prior_module.flow_net.parameters(), weight_decay=weight_decay),
            dict(params=prior_module.convex_net.parameters()),
        ], lr=lr)

        device_prior_output = None

        with torch.set_grad_enabled(True):
            for step in steps:
                optimizer.zero_grad()
                # Forward pass
                device_prior_output = prior_module(actual_input, *prior_args[1:], **prior_kwargs)
                device_prior_output = wrapper_module.process_prior_output(
                    device_prior_output, use_sigmoid=use_prior_sigmoid)[None, ...]  # Add batch dim again

                loss: torch.Tensor = criterion(device_prior_output, unaries)

                if torch.isfinite(loss):
                    loss.backward()
                    optimizer.step()
                else:
                    logging.warning(
                        f"Loss is nan or inf. Skipping step {step} of image {i}")
                    break

                if use_logger and use_step_logger:
                    # NOTE(review): `logger` is not defined in the visible cells;
                    # define it before enabling use_logger.
                    logger.log_value(
                        loss.item(), f"PretrainingLoss/Image_{i}", step=step)

                prior_module.enforce_convexity()
                if batch_progress_bar is not None:
                    batch_progress_bar.set_postfix(
                        loss=loss.item(), refresh=False)
                    batch_progress_bar.update()



In [None]:

# Push a coordinate grid through the trained flow net.
# NOTE(review): `inputs` is whatever the previous cell left behind (the last item
# iterated by the pretraining loop) — confirm that is the intended item.
grid = inputs[2]

prior_module = wrapper_module.prior_module
norm_flow = prior_module.flow_net

with torch.no_grad():
    # eval() is not undone afterwards, so the flow stays in eval mode.
    norm_flow.eval()
    grid = grid.to(device)
    grid = grid[None, ...]  # add a leading (batch) axis
    out_grid = norm_flow(grid)


# Smallest value of the flow output (displayed as the cell result).
out_grid.min()


In [None]:
# Evaluate item 0 without the CRF step and plot image, scribbles, segmentation and prior.
# NOTE(review): the title shows `step`, i.e. the last optimization step of the last
# image fitted by the pretraining loop — not an epoch count, and not necessarily item 0.
raw_result, ground_truth, img, fg, bg = get_result(wrapper_module, dataset, 0, model_gets_targets=False)
res = split_model_result(raw_result, wrapper_module, dataset, img, compute_crf=False)
res_prior, res_pred = res.get("prior", None), res["segmentation"]
fig = plot_image_scribbles(
    img, res_pred, fg, bg, res_prior,
    save=True,
    size=5,
    tight_layout=True,
    title="Epoch: " + str(step),
    legend=False,
)
fig

In [None]:
# Display the prior module (its repr lists the submodules).
prior_module