In [1]:
%load_ext autoreload
%autoreload 2
from awesome.run.awesome_config import AwesomeConfig
from awesome.run.awesome_runner import AwesomeRunner
from awesome.util.reflection import class_name
import os
import torch

from awesome.model.convex_diffeomorphism_net import ConvexDiffeomorphismNet
from awesome.util.path_tools import get_project_root_path
from awesome.util.logging import basic_config
from awesome.dataset.awesome_dataset import AwesomeDataset
from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset
from awesome.measures.fbms_joint_loss import FBMSJointLoss
from awesome.model.unet import UNet
from awesome.measures.weighted_loss import WeightedLoss
from awesome.measures.se import SE
from awesome.measures.unaries_conversion_loss import UnariesConversionLoss
from typing import Literal

basic_config()

# Being in the project root is important: every relative path used below
# (./data, ./config, ./runs) is resolved against the current working directory.
project_root = get_project_root_path()
os.chdir(project_root)

  from tqdm.autonotebook import tqdm


In [2]:
from awesome.model.zoo import Zoo
from awesome.model.net_factory import real_nvp_path_connected_net

# Experiment parameters.
xytype = "edge"
dataset_kind = "train"
dataset = "bear01"
all_frames = True
subset = None  # Optionally restrict the frames used, e.g. 0 or slice(0, 5)
segmentation_model_switch: Literal["original", "retrain", "retrain_xy"] = "original"

# Select the pretrained UNet checkpoint that matches the chosen segmentation model.
segmentation_model_state_dict_path = None
if segmentation_model_switch == "original":
    segmentation_model_state_dict_path = f"./data/checkpoints/labels_with_uncertainty_flownet2_based/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
elif segmentation_model_switch == "retrain_xy":
    segmentation_model_state_dict_path = f"./data/checkpoints/refit_spatial_unet_uncertainty/23_11_13/model_{dataset}_unet.pth"
else:
    raise ValueError(f"Unknown segmentation_model_switch: {segmentation_model_switch}")

# Channel order handed to the model: BGR for the original checkpoints, RGB for the retrained ones.
image_channel_format = "bgr" if segmentation_model_switch == "original" else "rgb"
# NOTE(review): UNet input channel count depends on xytype (4 for "edge", 6 otherwise) — confirm against the dataset features.
input_channels = 4 if xytype == "edge" else 6
# NOTE(review): not referenced by the config below; kept in case later cells use it.
prior_criterion = UnariesConversionLoss(SE(reduction="mean"))

data_path = f"./data/local_datasets/FBMS-59/{dataset_kind}/{dataset}"

real_dataset = FBMSSequenceDataset(
    dataset_path=data_path,
    weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
    processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
    confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
    do_weak_label_preprocessing=True,
    do_uncertainty_label_flip=True,
    test_weak_label_integrity=False,
    # Honour the `all_frames` parameter; it is also encoded in the pretrain state paths below,
    # so the dataset and the cached pretrain states must agree on it.
    all_frames=all_frames,
)

cfg = AwesomeConfig(
    name_experiment=f"UNET+{dataset}+{xytype}+diffeo+only_prior+realnvp",
    dataset_type=class_name(AwesomeDataset),
    dataset_args={
        "dataset": real_dataset,
        "xytype": xytype,
        "feature_dir": f"{data_path}/Feat",
        "dimension": "3d",  # 2d for fcnet
        "mode": "model_input",
        "model_input_requires_grad": False,
        "batch_size": 1,
        "split_ratio": 1,
        "shuffle_in_dataloader": False,
        "image_channel_format": image_channel_format,
        "do_image_blurring": True,
        "subset": subset
    },
    segmentation_model_type=class_name(UNet),
    segmentation_model_args={
        'in_chn': input_channels,
    },
    segmentation_training_mode='multi',
    segmentation_model_state_dict_path=segmentation_model_state_dict_path,  # Path to the pretrained model
    use_segmentation_output_inversion=True,
    use_prior_model=True,
    prior_model_args=dict(
        channels=2,
        hidden_units=32,
        flow_n_flows=12,
        flow_output_fn="tanh",
        norm="minmax",
        convex_net_hidden_units=130,
        convex_net_hidden_layers=2,
    ),
    prior_model_type=class_name(real_nvp_path_connected_net),
    loss_type=class_name(FBMSJointLoss),
    loss_args={
        "criterion": WeightedLoss(torch.nn.BCELoss(), mode="sssdms", noneclass=2),
        "alpha": 1,
        "beta": 1,
    },
    use_extra_penalty_hook=False,  # Penalty hook for the penalty term that the models' outputs should match
    #extra_penalty_after_n_epochs=1,
    #use_reduce_lr_in_extra_penalty_hook=False,
    use_lr_on_plateau_scheduler=False,
    use_binary_classification=True,
    num_epochs=100,
    device="cuda",
    dtype=str(torch.float32),
    runs_path="./runs/fbms_local/unet/comparison_path_nets",
    optimizer_args={
        "lr": 0.003,
        "betas": (0.9, 0.999),
        "eps": 1e-08,
        "amsgrad": False
    },
    use_progress_bar=True,
    plot_indices_during_training_nth_epoch=5,
    plot_indices_during_training=real_dataset.get_ground_truth_indices(),
    save_images_after_pretraining=True,
    include_unaries_when_saving=True,
    agent_args=dict(
        do_pretraining=True,
        pretrain_only=True,
        force_pretrain=True,
        pretrain_state_path=f"./data/checkpoints/pretrain_states/model_{dataset}_unet_spatial_{all_frames}_{subset}_realnvp.pth",
        pretrain_args=dict(
            use_pretrain_checkpoints=True,
            do_pretrain_checkpoints=True,
            pretrain_checkpoint_dir=f"./data/checkpoints/pretrain_states/model_{dataset}_unet_spatial_{all_frames}_{subset}_realnvp",
            lr=0.001,
            use_logger=True,
            use_step_logger=True,
            num_epochs=4000,
            proper_prior_fit_retrys=1,
            reuse_state_epochs=400,
            # Prefit flow net identity => Flow will be identity(-like) at the beginning
            prefit_flow_net_identity=True,
            prefit_flow_net_identity_lr=1e-2,
            prefit_flow_net_identity_weight_decay=1e-5,
            prefit_flow_net_identity_num_epochs=100,
            # Prefit convex net, to start with a convex thing
            prefit_convex_net=True,
            prefit_convex_net_lr=1e-3,
            prefit_convex_net_weight_decay=0,
            prefit_convex_net_num_epochs=200,
            zoo=Zoo()
        )
    ),
    #output_folder="./runs/fbms_local/unet/TestUnet/",
)
cfg.save_to_file(f"./config/{cfg.name_experiment}.yaml", override=True, no_uuid=True)

  bg_flip_coords = flip_probability[torch.argwhere(bg_flip_mask).squeeze(), :2].int().T


'./config/UNET+bear01+edge+diffeo+only_prior+realnvp.yaml'

In [3]:
# Instantiate the runner for `cfg`, build it, then store the config.
# store_config() is the last expression, so its return value (a config file path,
# see the captured output) is what this cell displays.
runner = AwesomeRunner(cfg)
runner.build()
runner.store_config()

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
2024-01-11:12:40:17.526 INFO     [tensorboard.py:55] Tensorboard logger created at: runs\fbms_local\unet\comparison_path_nets\UNET+bear01+edge+diffeo+only_prior+realnvp_24_01_11_12_40_17


'./runs/fbms_local/unet/comparison_path_nets\\UNET+bear01+edge+diffeo+only_prior+realnvp_24_01_11_12_40_17\\init_cfg_awesome_config.yaml'

In [6]:
# Reset the parameters of the model's prior module (presumably re-initialisation) before training.
# NOTE(review): relies on the private `_get_model()` accessor and was executed out of order
# (In [6] after In [3]); confirm it is required when running the notebook on a fresh kernel.
runner.agent._get_model().prior_module.reset_parameters()


In [None]:
# Run training with the settings in `cfg`.
# NOTE(review): agent_args sets pretrain_only=True, so this presumably only runs the prior pretraining — confirm.
# Optional override of the number of epochs before training:
#runner.config.num_epochs = 2000
runner.train()

In [None]:
from awesome.run.functions import get_result, split_model_result, plot_image_scribbles
import matplotlib.pyplot as plt

from awesome.run.functions import get_mpl_figure, plot_mask, prepare_input_eval
from awesome.util.matplotlib import saveable
import normflows as nf
from typing import Optional, Tuple
from awesome.model.path_connected_net import PathConnectedNet, minmax
from matplotlib.axes import Axes

def coordinate_grid(image_shape):
    """Build a dense pixel-coordinate grid.

    Args:
        image_shape: ``(height, width)`` of the image (anything indexable by 0 and 1).

    Returns:
        Float tensor of shape ``(2, height, width)``; channel 0 holds the row (y)
        index and channel 1 the column (x) index of every pixel.
    """
    y = torch.arange(image_shape[0]).float()
    x = torch.arange(image_shape[1]).float()
    # indexing="ij" is the historical default layout ((y, x) order); passing it
    # explicitly silences PyTorch's meshgrid deprecation warning.
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    return torch.stack([yy, xx])

def create_circle(image_shape: Tuple[int, int], radius: float, center: Tuple[float, float]):
    """Rasterise a filled disc.

    Args:
        image_shape: ``(height, width)`` of the output.
        radius: Disc radius in pixels (points at exactly this distance are inside).
        center: Disc centre as ``(row, col)``.

    Returns:
        Float tensor of shape ``(1, height, width)`` with 1.0 inside the disc, 0.0 outside.
    """
    rows, cols = coordinate_grid(image_shape)
    squared_dist = (rows - center[0]) ** 2 + (cols - center[1]) ** 2
    inside = squared_dist <= radius ** 2
    return inside.float().unsqueeze(0)

def subsample_mask(x,
                   subsample: int = 25):
    """Boolean mask selecting every ``subsample``-th row and column.

    Only the shape of ``x`` is used: the mask covers its last two dimensions.

    Args:
        x: Torch tensor or numpy array with at least two dimensions, typically a
            ``(2, H, W)`` coordinate grid.
        subsample: Stride between selected rows/columns (must be positive).

    Returns:
        CPU ``torch.bool`` tensor of shape ``(H, W)`` that is True where both the
        row and the column index are multiples of ``subsample``.
    """
    height, width = x.shape[-2:]
    row_hit = torch.arange(height) % subsample == 0
    col_hit = torch.arange(width) % subsample == 0
    # Outer AND of the two 1-D selections; avoids materialising an (H*W, 2) index
    # tensor and works for any number of leading dimensions.
    return row_hit[:, None] & col_hit[None, :]

@saveable()
def plot_output(img, 
                output, 
                target, 
                grid: torch.Tensor, 
                subsample:int = 25, 
                **kwargs):
    """Plot masks over an image and draw the (deformed) coordinate grid on top.

    Args:
        img: Image forwarded to ``plot_match``.
        output: Mask forwarded to ``plot_match`` as ``output`` (may be None).
        target: Mask forwarded to ``plot_match`` as ``target`` (may be None).
        grid: Coordinate grid indexed as ``grid[0]`` (y) and ``grid[1]`` (x);
            assumed shape ``(2, H, W)`` — TODO confirm against ``get_deformation``.
        subsample: Stride between drawn grid lines.
        **kwargs: Not read in this body. NOTE(review): a ``size`` passed by callers is
            ignored (the overlay is drawn with ``size=5``); ``save``/``path`` are
            presumably consumed by ``@saveable`` — confirm.

    Returns:
        The figure returned by ``plot_grid``.
    """
    image_shape = grid.shape[-2:]

    def denorm_grid(grid):
        # Map each coordinate channel from its own [min, max] onto the pixel range of
        # the grid (presumably a linear rescale by ``minmax``); returns a numpy array.
        image_shape = grid.shape[-2:]
        grid_y = minmax(grid[0], grid[0].min(), grid[0].max(), 0, image_shape[0])
        grid_x = minmax(grid[1], grid[1].min(), grid[1].max(), 0, image_shape[1])
        grid_dnorm = torch.cat([grid_y[None, ...], grid_x[None, ...]], dim=0).detach().cpu().numpy()
        return grid_dnorm

    fig = plot_match(img, output, target, size=5, tight=True, subsample=subsample)
    ax = fig.axes[0]
    dnorm_grid_pt = denorm_grid(grid)

    # Keep every point inside the image: y in [0, H - 1], x in [0, W - 1].
    dnorm_grid_pt = torch.clamp(torch.tensor(dnorm_grid_pt), min=torch.tensor([[[0]], [[0]]]), max=torch.tensor([[[image_shape[0] - 1]], [[image_shape[1] - 1]]])).numpy()

    msk = subsample_mask(dnorm_grid_pt, subsample=subsample)

    # `color` must be passed by keyword: the 4th positional parameter of plot_grid is
    # `tight`, so a positional 'g' was silently ignored and the grid drawn in blue.
    return plot_grid(dnorm_grid_pt, msk, ax, color="g", origin="lower")
    
@saveable()
def plot_grid(grid: torch.Tensor, 
              mask: torch.Tensor, 
              ax: Optional[Axes] = None, 
              tight: Optional[bool] = False,
              size: Optional[float] = 5,
              color: str = "b",
              dense: bool = True,
              origin: Literal['lower', 'upper'] = "upper"
              ):
    """Draw the grid lines of a (possibly deformed) coordinate grid.

    Assumes ``grid`` is ``(2, H, W)`` with channel 0 plotted as y and channel 1 as x,
    and ``mask`` is ``(H, W)`` — TODO confirm. Every row index and every column index
    that contains a True entry of ``mask`` yields one polyline.

    Args:
        grid: Coordinate grid to draw.
        mask: Selection of the grid lines to draw.
        ax: Axes to draw on; a new figure is created when None.
        tight: Forwarded to ``get_mpl_figure`` (only used when ``ax`` is None).
        size: Forwarded to ``get_mpl_figure`` (only used when ``ax`` is None).
        color: Line colour.
        dense: When True each line uses every grid point; otherwise only the
            points at the selected crossing indices.
        origin: ``"upper"`` inverts the y axis after drawing; ``"lower"`` leaves it.

    Returns:
        The figure that owns the axes.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = get_mpl_figure(1, 1, tight=tight, size=size, ratio_or_img=grid)

    selected = torch.argwhere(mask)
    line_rows = torch.unique(selected[:, 0])
    line_cols = torch.unique(selected[:, 1])

    # Polylines along the width axis (one per selected row) and along the height axis
    # (one per selected column).
    along_width = grid[:, line_rows]
    along_height = grid[:, :, line_cols]

    for k in range(along_width.shape[1]):
        line = along_width[:, k]
        if dense:
            xs, ys = line[1], line[0]
        else:
            xs, ys = line[1, line_cols], line[0, line_cols]
        ax.plot(xs, ys, color=color)

    for k in range(along_height.shape[2]):
        line = along_height[:, :, k]
        if dense:
            xs, ys = line[1], line[0]
        else:
            xs, ys = line[1, line_rows], line[0, line_rows]
        ax.plot(xs, ys, color=color)

    if origin == "upper":
        ax.invert_yaxis()
    return fig


def plot_match(img, 
               output, 
               target, 
               subsample:int = 25, 
               grid: torch.Tensor = None,
               **kwargs):
    """Overlay masks and a subsampled lattice of grid points on an image.

    Args:
        img: Image to plot; its last two dimensions give the default grid size.
        output: Mask layer stacked after ``target`` (skipped when None).
        target: Mask layer stacked first (skipped when None).
        subsample: Stride of the lattice points drawn as an extra layer.
        grid: Grid whose shape defines the lattice; defaults to ``coordinate_grid``
            of the image size.
        **kwargs: Forwarded to ``plot_mask``.

    Returns:
        The figure returned by ``plot_mask``.
    """
    height_width = img.shape[-2:]
    if grid is None:
        grid = coordinate_grid(height_width)

    lattice = torch.zeros_like(grid[0])
    lattice[subsample_mask(grid, subsample=subsample)] = 1

    # Layer order: target, output, then the lattice points.
    layers = [layer for layer in (target, output) if layer is not None]
    layers.append(lattice.float()[None, ...])

    return plot_mask(img, torch.cat(layers, dim=0), **kwargs)

index = 0  # Kept for later cells; the loop below uses its own frame index `i`.

model = runner.agent._get_model()
dataloader = runner.agent.training_dataset
model_gets_targets = runner.agent.model_gets_targets
p = os.path.join(runner.agent.agent_folder, "pretrain_priors")
os.makedirs(p, exist_ok=True)

# Label used in the saved file names and figure titles.
# NOTE(review): hard-coded; pretrain_args.num_epochs in the config is 4000 — confirm which value is meant.
iterations = 2000

for i in range(len(dataloader)):
    res, ground_truth, img, fg, bg = get_result(model, dataloader, i, model_gets_targets=model_gets_targets)
    res = split_model_result(res, model, dataloader, img)
    res_prior = res.get("prior", None)
    res_pred = res["segmentation"]
    boxes = res.get("boxes", None)
    labels = res.get("labels", None)

    fig = plot_image_scribbles(image=img,
                               inference_result=res_pred,
                               foreground_mask=fg,
                               background_mask=bg,
                               prior_result=res_prior,
                               save=True,
                               path=os.path.join(p, f"prior_{i}_{iterations}.png"),
                               size=10,
                               title=f"Prior Epoch: {iterations}", open=True)
    plt.close(fig)

    with torch.no_grad():
        # NOTE(review): the model stays on CPU / in eval mode after the first iteration,
        # so subsequent get_result calls run on CPU.
        model.eval()
        model.to(torch.device("cpu"))
        # Use frame `i` so the deformation grid matches res_prior / res_pred above.
        image, ground_truth, _input, targets, fg, bg, prior_state = prepare_input_eval(dataloader, model, i)
        grid = model.prior_module.get_deformation(_input[2][None, ...])[0]

    fig = plot_output(img, 1 - res_prior, 1 - res_pred, grid=grid, size=30, subsample=10,
                      save=True,
                      path=os.path.join(p, f"deform_grid_{i}_{iterations}.png"),
                      )
    plt.close(fig)



In [None]:
from awesome.run.functions import get_mpl_figure, plot_mask, prepare_input_eval
from awesome.util.matplotlib import saveable
import normflows as nf
from typing import Optional, Tuple
from awesome.model.path_connected_net import PathConnectedNet, minmax
from matplotlib.axes import Axes

def coordinate_grid(image_shape):
    """Build a dense pixel-coordinate grid.

    NOTE(review): this cell redefines the helpers of the earlier plotting cell and
    shadows them; consider keeping a single definition.

    Args:
        image_shape: ``(height, width)`` of the image (anything indexable by 0 and 1).

    Returns:
        Float tensor of shape ``(2, height, width)``; channel 0 holds the row (y)
        index and channel 1 the column (x) index of every pixel.
    """
    y = torch.arange(image_shape[0]).float()
    x = torch.arange(image_shape[1]).float()
    # indexing="ij" is the historical default layout ((y, x) order); passing it
    # explicitly silences PyTorch's meshgrid deprecation warning.
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    return torch.stack([yy, xx])

def create_circle(image_shape: Tuple[int, int], radius: float, center: Tuple[float, float]):
    """Rasterise a filled disc.

    Args:
        image_shape: ``(height, width)`` of the output.
        radius: Disc radius in pixels (points at exactly this distance are inside).
        center: Disc centre as ``(row, col)``.

    Returns:
        Float tensor of shape ``(1, height, width)`` with 1.0 inside the disc, 0.0 outside.
    """
    rows, cols = coordinate_grid(image_shape)
    squared_dist = (rows - center[0]) ** 2 + (cols - center[1]) ** 2
    inside = squared_dist <= radius ** 2
    return inside.float().unsqueeze(0)

def subsample_mask(x,
                   subsample: int = 25):
    """Boolean mask selecting every ``subsample``-th row and column.

    Only the shape of ``x`` is used: the mask covers its last two dimensions.

    Args:
        x: Torch tensor or numpy array with at least two dimensions, typically a
            ``(2, H, W)`` coordinate grid.
        subsample: Stride between selected rows/columns (must be positive).

    Returns:
        CPU ``torch.bool`` tensor of shape ``(H, W)`` that is True where both the
        row and the column index are multiples of ``subsample``.
    """
    height, width = x.shape[-2:]
    row_hit = torch.arange(height) % subsample == 0
    col_hit = torch.arange(width) % subsample == 0
    # Outer AND of the two 1-D selections; avoids materialising an (H*W, 2) index
    # tensor and works for any number of leading dimensions.
    return row_hit[:, None] & col_hit[None, :]

@saveable()
def plot_output(img, 
                output, 
                target, 
                grid: torch.Tensor, 
                subsample:int = 25, 
                **kwargs):
    """Plot masks over an image and draw the (deformed) coordinate grid on top.

    Args:
        img: Image forwarded to ``plot_match``.
        output: Mask forwarded to ``plot_match`` as ``output`` (may be None).
        target: Mask forwarded to ``plot_match`` as ``target`` (may be None).
        grid: Coordinate grid indexed as ``grid[0]`` (y) and ``grid[1]`` (x);
            assumed shape ``(2, H, W)`` — TODO confirm against ``get_deformation``.
        subsample: Stride between drawn grid lines.
        **kwargs: Not read in this body. NOTE(review): a ``size`` passed by callers is
            ignored (the overlay is drawn with ``size=5``); ``save``/``path`` are
            presumably consumed by ``@saveable`` — confirm.

    Returns:
        The figure returned by ``plot_grid``.
    """
    image_shape = grid.shape[-2:]

    def denorm_grid(grid):
        # Map each coordinate channel from its own [min, max] onto the pixel range of
        # the grid (presumably a linear rescale by ``minmax``); returns a numpy array.
        image_shape = grid.shape[-2:]
        grid_y = minmax(grid[0], grid[0].min(), grid[0].max(), 0, image_shape[0])
        grid_x = minmax(grid[1], grid[1].min(), grid[1].max(), 0, image_shape[1])
        grid_dnorm = torch.cat([grid_y[None, ...], grid_x[None, ...]], dim=0).detach().cpu().numpy()
        return grid_dnorm

    fig = plot_match(img, output, target, size=5, tight=True, subsample=subsample)
    ax = fig.axes[0]
    dnorm_grid_pt = denorm_grid(grid)

    # Keep every point inside the image: y in [0, H - 1], x in [0, W - 1].
    dnorm_grid_pt = torch.clamp(torch.tensor(dnorm_grid_pt), min=torch.tensor([[[0]], [[0]]]), max=torch.tensor([[[image_shape[0] - 1]], [[image_shape[1] - 1]]])).numpy()

    msk = subsample_mask(dnorm_grid_pt, subsample=subsample)

    # `color` must be passed by keyword: the 4th positional parameter of plot_grid is
    # `tight`, so a positional 'g' was silently ignored and the grid drawn in blue.
    return plot_grid(dnorm_grid_pt, msk, ax, color="g", origin="lower")
    
@saveable()
def plot_grid(grid: torch.Tensor, 
              mask: torch.Tensor, 
              ax: Optional[Axes] = None, 
              tight: Optional[bool] = False,
              size: Optional[float] = 5,
              color: str = "b",
              dense: bool = True,
              origin: Literal['lower', 'upper'] = "upper"
              ):
    """Draw the grid lines of a (possibly deformed) coordinate grid.

    Assumes ``grid`` is ``(2, H, W)`` with channel 0 plotted as y and channel 1 as x,
    and ``mask`` is ``(H, W)`` — TODO confirm. Every row index and every column index
    that contains a True entry of ``mask`` yields one polyline.

    Args:
        grid: Coordinate grid to draw.
        mask: Selection of the grid lines to draw.
        ax: Axes to draw on; a new figure is created when None.
        tight: Forwarded to ``get_mpl_figure`` (only used when ``ax`` is None).
        size: Forwarded to ``get_mpl_figure`` (only used when ``ax`` is None).
        color: Line colour.
        dense: When True each line uses every grid point; otherwise only the
            points at the selected crossing indices.
        origin: ``"upper"`` inverts the y axis after drawing; ``"lower"`` leaves it.

    Returns:
        The figure that owns the axes.
    """
    if ax is not None:
        fig = ax.figure
    else:
        fig, ax = get_mpl_figure(1, 1, tight=tight, size=size, ratio_or_img=grid)

    selected = torch.argwhere(mask)
    line_rows = torch.unique(selected[:, 0])
    line_cols = torch.unique(selected[:, 1])

    # Polylines along the width axis (one per selected row) and along the height axis
    # (one per selected column).
    along_width = grid[:, line_rows]
    along_height = grid[:, :, line_cols]

    for k in range(along_width.shape[1]):
        line = along_width[:, k]
        if dense:
            xs, ys = line[1], line[0]
        else:
            xs, ys = line[1, line_cols], line[0, line_cols]
        ax.plot(xs, ys, color=color)

    for k in range(along_height.shape[2]):
        line = along_height[:, :, k]
        if dense:
            xs, ys = line[1], line[0]
        else:
            xs, ys = line[1, line_rows], line[0, line_rows]
        ax.plot(xs, ys, color=color)

    if origin == "upper":
        ax.invert_yaxis()
    return fig


def plot_match(img, 
               output, 
               target, 
               subsample:int = 25, 
               grid: torch.Tensor = None,
               **kwargs):
    """Overlay masks and a subsampled lattice of grid points on an image.

    Args:
        img: Image to plot; its last two dimensions give the default grid size.
        output: Mask layer stacked after ``target`` (skipped when None).
        target: Mask layer stacked first (skipped when None).
        subsample: Stride of the lattice points drawn as an extra layer.
        grid: Grid whose shape defines the lattice; defaults to ``coordinate_grid``
            of the image size.
        **kwargs: Forwarded to ``plot_mask``.

    Returns:
        The figure returned by ``plot_mask``.
    """
    height_width = img.shape[-2:]
    if grid is None:
        grid = coordinate_grid(height_width)

    lattice = torch.zeros_like(grid[0])
    lattice[subsample_mask(grid, subsample=subsample)] = 1

    # Layer order: target, output, then the lattice points.
    layers = [layer for layer in (target, output) if layer is not None]
    layers.append(lattice.float()[None, ...])

    return plot_mask(img, torch.cat(layers, dim=0), **kwargs)


