In [None]:
%load_ext autoreload
%autoreload 2
import os

from awesome.dataset.awesome_dataset import AwesomeDataset
import os
import torch
from awesome.util.path_tools import get_project_root_path
import matplotlib.pyplot as plt
# Run from the project root so the relative ./data/... and output/... paths used below resolve.
os.chdir(get_project_root_path()) 

In [None]:
from awesome.measures.miou import MIOU

from awesome.dataset.fbms_sequence_dataset import FBMSSequenceDataset, FBMSSequenceSample

# Sequence to inspect; weak labels are the FlowNet2-based labels with uncertainty.
dataset_name = "marple10"
data_path = f"./data/local_datasets/FBMS-59/train/{dataset_name}/"

dataset = FBMSSequenceDataset(
    data_path,
    weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
    # If this directory already contains files, they must correspond to the processing settings below.
    processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
    confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
    do_weak_label_preprocessing=True,
    do_uncertainty_label_flip=True,
    test_weak_label_integrity=True,
    all_frames=True,
)
dataset.get_ground_truth_indices()

In [None]:
from awesome.run.functions import plot_as_image, prepare_input_eval
data_path = f"./data/local_datasets/FBMS-59/train/{dataset_name}"

# Wrap the sequence dataset with the settings below (edge xy-type, 3d dimension, model-input mode).
awds = AwesomeDataset(
    dataset=dataset,
    xytype="edge",
    feature_dir=f"{data_path}/Feat",
    dimension="3d",  # 2d for fcnet
    mode="model_input",
    model_input_requires_grad=False,
    batch_size=1,
    split_ratio=1,
    shuffle_in_dataloader=False,
    image_channel_format="bgr",
    do_image_blurring=True,
)

image, ground_truth, _input, targets, fg, bg, prior_state = prepare_input_eval(awds, None, 349)
display(plot_as_image(ground_truth))
# Highlight the pixels whose target value is 0.
display(plot_as_image(torch.where(targets == 0, 1, 0), size=10))

In [None]:
def get_dataset(name: str, with_check: bool = True, root: str = "./data/local_datasets/FBMS-59/train"):
    """Create the FBMS sequence dataset for a single sequence.

    Uses the FlowNet2-based weak labels with uncertainty, with weak-label
    preprocessing and uncertainty label flipping enabled, over all frames.

    Args:
        name: Sequence folder name below ``root`` (e.g. ``"marple10"``).
        with_check: Forwarded as ``test_weak_label_integrity``.
        root: Directory containing the sequence folders. Defaults to the
            FBMS-59 train split used throughout this notebook.

    Returns:
        The constructed ``FBMSSequenceDataset``.
    """
    data_path = f"{root}/{name}/"
    return FBMSSequenceDataset(
        data_path,
        weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based",
        # If this directory already contains files, they must correspond to the processing settings below.
        processed_weak_labels_dir="weak_labels/labels_with_uncertainty_flownet2_based/processed",
        confidence_dir="weak_labels/labels_with_uncertainty_flownet2_based/",
        do_weak_label_preprocessing=True,
        do_uncertainty_label_flip=True,
        test_weak_label_integrity=with_check,
        all_frames=True,
    )
    


In [None]:
# Sequences whose ground-truth mask overview is rendered for a visual sanity check.
datasets = [
    'bear01', 'bear02',
    'cars2', 'cars3', 'cars6', 'cars7', 'cars8',
    'cats04', 'cats05',
    'horses01', 'horses03',
    'marple1', 'marple10', 'marple11', 'marple5',
    'meerkats01',
    'people04',
    'rabbits01',
]
path = "output/dataset_check/"
os.makedirs(path, exist_ok=True)

# Write one <sequence>_after.png per sequence and free each figure right away.
for dataset_name in datasets:
    dataset = get_dataset(dataset_name)
    p = os.path.join(path, f"{dataset_name}_after.png")
    fig = dataset.plot_ground_truth_mask_images(save=True, path=p, override=True)
    plt.close(fig)


In [None]:

# Rebuild marple10 (integrity check enabled) and show its ground-truth mask overview.
dataset = get_dataset("marple10", with_check=True)
fig = dataset.plot_ground_truth_mask_images()
fig


In [None]:
from typing import Literal
from awesome.run.functions import get_mpl_figure, plot_mask
from awesome.run.functions import value_mask_to_channel_masks
#sample = dataset[149]

#sample.trajectory_mask.shape

def add_label_info(_id, sample, mode: Literal["weak", "gt"]) -> str:
    """Build a plot label for an object id, suffixed with " FG"/" BG" when selected.

    Args:
        _id: Object id (int or 0-d tensor).
        sample: Sample providing ``foreground_weak_label_object_id``,
            ``background_weak_label_object_id`` and
            ``weak_label_id_ground_truth_object_id_mapping``.
        mode: ``"weak"`` compares ``_id`` with the selected weak-label ids directly;
            ``"gt"`` first maps the selected weak-label ids to ground-truth ids.

    Returns:
        ``"<id> FG"``, ``"<id> BG"`` or ``"<id>"``.

    Raises:
        ValueError: If ``mode`` is neither ``"weak"`` nor ``"gt"``.
    """
    if isinstance(_id, torch.Tensor):
        _id = _id.item()

    if mode == "weak":
        fg_id = sample.foreground_weak_label_object_id
        bg_id = sample.background_weak_label_object_id
    elif mode == "gt":
        # Translate the selected weak-label ids into ground-truth ids (None if unmapped).
        mapping = sample.weak_label_id_ground_truth_object_id_mapping
        fg_id = mapping.get(sample.foreground_weak_label_object_id, None)
        bg_id = mapping.get(sample.background_weak_label_object_id, None)
    else:
        # Previously fell through and returned None, which then ended up as a plot label.
        raise ValueError(f"Unknown mode {mode!r}; expected 'weak' or 'gt'.")

    if _id == fg_id:
        return str(_id) + " FG"
    if _id == bg_id:
        return str(_id) + " BG"
    return str(_id)

indices = dataset.get_ground_truth_indices()

# One row per ground-truth frame: weak labels | ground truth | selected ground truth.
rows = len(indices)
fig, axs = get_mpl_figure(rows=rows, cols=3, size=5, tight=False, ratio_or_img=dataset[0].image, ax_mode="2d")

for i, index in enumerate(indices):
    sample = dataset[index]
    row_axs = axs[i]

    # Draw into the grid axes; `fig` keeps referring to the grid figure created above.
    plot_mask(sample.image, sample.trajectory_mask, ax=row_axs[0],
              labels=[add_label_info(x, sample, "weak") for x in sample.trajectory_mask_object_ids])
    row_axs[0].set_title("Weak Label: " + str(index))

    # Best effort: report why a panel could not be drawn instead of silently leaving it empty.
    try:
        plot_mask(sample.image, sample.ground_truth_mask, ax=row_axs[1],
                  labels=[add_label_info(x, sample, "gt") for x in sample.ground_truth_object_ids])
        row_axs[1].set_title("GT: " + str(index))
    except Exception as e:
        print(f"Skipping GT panel for index {index}: {e!r}")

    try:
        plot_mask(sample.image, 1 - sample.label, ax=row_axs[2])
        row_axs[2].set_title("GT Selected: " + str(index))
    except Exception as e:
        print(f"Skipping GT Selected panel for index {index}: {e!r}")

display(fig)
plt.close(fig)

# Selected foreground / background weak-label ids of the last plotted sample.
display(sample.foreground_weak_label_object_id)
display(sample.background_weak_label_object_id)

In [None]:
# Raw trajectory (weak-label) mask of the current sample.
sample.trajectory_mask

In [None]:
# Object ids present in the trajectory mask (used as plot labels above).
sample.trajectory_mask_object_ids

In [None]:
# Mapping from weak-label object ids to ground-truth object ids (used for the FG/BG annotations).
sample.weak_label_id_ground_truth_object_id_mapping

In [None]:
# Distinct values present in the weak label.
torch.unique(sample.weak_label)

In [None]:
# NOTE(review): shows the same tensor as the bare `sample.trajectory_mask` cell earlier.
display(sample.trajectory_mask)

In [None]:
# NOTE(review): duplicate of the earlier torch.unique(sample.weak_label) cell.
torch.unique(sample.weak_label)

In [None]:
# Distinct values present in the trajectory mask.
torch.unique(sample.trajectory_mask)

In [None]:
# Overlay the trajectory mask on the image at a larger size.
plot_mask(sample.image, sample.trajectory_mask, size=15)

In [None]:
# NOTE(review): duplicate of the earlier sample.trajectory_mask_object_ids cell.
sample.trajectory_mask_object_ids

In [None]:
from awesome.run.functions import value_mask_to_channel_masks

# Split the weak label into per-value channel masks. 2 is passed as the value to ignore —
# presumably the "no label" value (an earlier commented-out call uses ignore_value=2); confirm.
value_mask_to_channel_masks(sample.weak_label, 2)

In [None]:
# Weak-label -> ground-truth id mapping for frame 97.
sample = dataset[97]
sample.weak_label_id_ground_truth_object_id_mapping

In [None]:
# Save the weak-label plots of frames 160-169 into temp/ for offline inspection.
# Make sure the output directory exists before saving into it.
os.makedirs("temp", exist_ok=True)
for i in range(160, 170):
    sample = dataset[i]
    path = f"temp/{sample['feat_name']}_{i}.png"
    fig = sample.plot_weak_labels(size=10, path=path, save=True, override=True)
    # Close each figure so the loop does not accumulate open matplotlib figures.
    plt.close(fig)

In [None]:
from awesome.measures.miou import MIOU


In [None]:
# Show every diagnostic plot for the first three frames, one plot at a time.
for i in range(3):
    index = i
    sample = dataset[index]
    for plot_fn in (sample.plot, sample.plot_weak_labels, sample.plot_selected_weak_labels, sample.plot_selected):
        display(plot_fn())
    

In [None]:
# Re-display the weak-label plot of the current sample.
display(sample.plot_weak_labels())

In [None]:
# Ten references to the same weak-label tensor (input for the stacking experiment below).
masks = [sample.weak_label for _ in range(10)]


In [None]:
# Stack the masks, treat the value 2 as 0, and sum over the stack per pixel.
m = torch.stack(masks, dim=0)
m = torch.where(m == 2, torch.zeros_like(m), m).sum(dim=0)

from awesome.run.functions import plot_as_image

plot_as_image(m, size=10, colorbar=True)

In [None]:
# NOTE(review): only creates an empty figure; looks like leftover scratch and could be removed.
plt.figure()



In [None]:

# NOTE(review): `index` is whatever an earlier cell left behind — set it explicitly if this matters.
sample = dataset[index]

# Enable blurring on this sample, then show the selected image and selected weak labels.
sample.blur_image = True
for view in (sample.plot_selected, sample.plot_selected_weak_labels):
    display(view())

In [None]:
# Reference sample from the Self-supervised-Sparse-to-Dense-Motion-Segmentation code,
# used further below to cross-check this data pipeline.
# NOTE: torch.load unpickles arbitrary objects — only load files from a trusted source.
compare_path = "../Self-supervised-Sparse-to-Dense-Motion-Segmentation/FBMS59-train-masks-with-confidence-flownet2-based/temp/test_pth/10_cars2.pth"

_cmp = torch.load(compare_path)
(compare_input, compare_mask, compare_confidence, compare_image_path) = _cmp



In [None]:
# Shape of the reference mask loaded from the comparison file.
compare_mask.shape

In [None]:
from awesome.dataset.awesome_dataset import AwesomeDataset

# NOTE(review): `dataset` and `data_path` come from earlier cells — confirm they refer to the
# same sequence as the reference sample loaded above.
awesome_dataset = AwesomeDataset(
    dataset,
    xytype="edge",
    feature_dir=os.path.join(data_path, "Feat"),
    edge_dir=os.path.join(data_path, "edge"),
    dimension="3d",
    model_input_requires_grad=False,
    batch_size=1,
    split_ratio=1,
    shuffle_in_dataloader=False,
    image_channel_format="bgr",
    do_image_blurring=True,
)

In [None]:
# Item 1 unpacks as ((img, feature_encoding, xy_clean, args), target).
ret = awesome_dataset[1]
model_inputs, target = ret
img, feature_encoding, xy_clean, args = model_inputs


In [None]:
# Reference input: first three channels are the BGR image, channel 3 is the edge map.
compare_bgr = compare_input[:3, ...]
compare_edgemap = compare_input[3, ...]

# Convert the reference mask to our label format: 0 -> 1, 1 -> 0, any other value v -> v + 3.
cmp_m = compare_mask + 3
cmp_m[compare_mask == 0] = 1
cmp_m[compare_mask == 1] = 0

def amir_to_our_mask_format(compare_mask: torch.Tensor) -> torch.Tensor:
    """Convert a reference-format mask into this project's label format.

    Mapping: 1 -> 0, 0 -> 1, every other value v -> v + 3 (so -1 becomes 2,
    presumably the "no label"/ignore value — confirm). Works on tensors and
    numpy arrays alike; the input is not modified.
    """
    converted = compare_mask + 3
    where_one = compare_mask == 1
    where_zero = compare_mask == 0
    converted[where_one] = 0
    converted[where_zero] = 1
    return converted

def _matches(ours: torch.Tensor, reference) -> bool:
    """True when `reference` has the same shape as `ours` and is element-wise close to it.

    `reference` is converted to `ours.dtype` first, because torch.allclose rejects
    mismatched dtypes (e.g. a float64 numpy array against a float32 tensor). The shape
    check prevents broadcasting from hiding a size mismatch.
    """
    reference = torch.as_tensor(reference, dtype=ours.dtype)
    return ours.shape == reference.shape and torch.allclose(ours, reference)

print("Image equal: ", _matches(img, compare_bgr))
print("Edgemap equal: ", _matches(feature_encoding, compare_edgemap))
print("Target equal: ", _matches(target, cmp_m))


In [None]:
# Re-derive the reference target in our label format via the shared helper instead of
# repeating the conversion inline.
cmp_m = amir_to_our_mask_format(compare_mask)

In [None]:
import numpy as np
# Total absolute difference to the reference (0 means identical). Only a cell's last
# expression is shown automatically, so display both values explicitly.
display(np.abs(img.numpy() - compare_bgr).sum())
# Compare against the converted reference target (our label format), consistent with the
# equality check and the difference plot.
display(np.abs(target.numpy() - cmp_m).sum())

In [None]:
from awesome.run.functions import plot_as_image

# Visualize where our image / target differ from the reference (all-zero plots mean identical).
display(plot_as_image(img - compare_bgr, title="img - compare_bgr", open=True, colorbar=True))
display(plot_as_image(target - cmp_m, title="target - cmp_m", size=10, open=True, colorbar=True))

In [None]:
from awesome.run.functions import plot_as_image

# Show our image, feature encoding and target for the item unpacked above.
display(plot_as_image(img))
display(plot_as_image(feature_encoding, title="feature_encoding"))
# Was `open=Tr`, which raised NameError.
display(plot_as_image(target, title="target", open=True))

In [None]:
# Weak label tensor of the current sample.
sample.weak_label

In [None]:
# Plot frame 2's weak labels against the dataset-wide set of weak-label object ids
# (presumably so colours stay consistent across frames).
u_id = dataset.unique_weak_label_object_ids

index = 2
sample = dataset[index]

sample.plot_weak_labels(all_object_ids=u_id, cmap="tab10")

In [None]:
# Mapping from ground-truth object ids to weak-label object ids for the current sample.
sample.ground_truth_object_id_weak_label_mapping

In [None]:
# #assert False, "Save the images"
# from tqdm.autonotebook import tqdm

# name = os.path.basename(dataset.dataset_path)
# max_len = 5 * 24 #len(dataset)
# paths = []
# for index in tqdm(range(min(len(dataset), max_len))):
#     sample = dataset[index]
#     path = f"output/gif/Traj_{name}_{index:03d}_.png"
#     paths.append(path)
#     fig = sample.plot_weak_labels(all_object_ids=u_id, size=5, save=True, path=path, override=True, cmap="tab10")
#     plt.close(fig)

# from awesome.util.gif_writer_images import GifWriterImages


# writer = GifWriterImages(f"{name}_traj.gif", paths, "output/gif")
# writer(1)


# Moving multicut tracks into the dataset directories

In [None]:
assert False, "Stop here"
# Copy each sequence's multicut track file into <dataset_dirs>/<sequence>/tracks/multicut/.
tracks_path = "data/local_datasets/FBMS-59/tracks"
dataset_dirs = "data/local_datasets/FBMS-59/test/"

import shutil

inner_path = "MulticutResults/pfldof0.5000004"
for folder in sorted(os.listdir(tracks_path)):
    complete_track_path = os.path.join(tracks_path, folder, inner_path)
    # Skip stray files / sequences without multicut results instead of crashing.
    if not os.path.isdir(complete_track_path):
        print(f"Skipping {folder}: {complete_track_path} is not a directory")
        continue
    # Sorted so the chosen file is deterministic (os.listdir order is arbitrary).
    track_files = sorted(os.listdir(complete_track_path))
    if not track_files:
        print(f"Skipping {folder}: {complete_track_path} is empty")
        continue
    if len(track_files) > 1:
        print(f"Warning: {folder} has {len(track_files)} track files, copying {track_files[0]}")
    tracks_file = track_files[0]
    tracks_file_path = os.path.join(complete_track_path, tracks_file)

    target_path = os.path.join(dataset_dirs, folder, "tracks", "multicut")
    os.makedirs(target_path, exist_ok=True)
    target_file_path = os.path.join(target_path, tracks_file)
    shutil.copy(tracks_file_path, target_file_path)


## Moving hdf5 files

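Converts the weak labels stored in each sequence's HDF5 files into PNG masks and writes the corresponding confidences to `<name>_confidence.h5` files, both inside that sequence's `weak_labels/<algo_name>` directory.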


In [None]:
#assert False, "Stop here"
import os
from tqdm.auto import tqdm
import h5py
from awesome.run.functions import save_mask
# Convert the per-frame HDF5 weak labels (+ confidences) of every sequence into PNG masks
# and <name>_confidence.h5 files inside <dataset_dirs>/<sequence>/weak_labels/<algo_name>/.
tracks_path = "data/local_datasets/FBMS-59/labels_with_uncertainty_flownet2_based"
dataset_dirs = "data/local_datasets/FBMS-59/train/"
import numpy as np

import shutil
algo_name = "labels_with_uncertainty_flownet2_based_new"


def _weak_label_to_mask(weak_label: np.ndarray) -> np.ndarray:
    """Map raw HDF5 weak-label values (0 = background, 1 = foreground, -1 = no label) to a uint8 mask.

    Single-object case (exactly three distinct values including 0 and 1):
    0 -> 255, 1 -> 1, everything else -> 0. Otherwise every value v -> v + 1,
    so -1 (no label) becomes 0.
    """
    mask = np.zeros_like(weak_label, dtype=np.uint8)
    vals = np.unique(weak_label)
    if len(vals) == 3 and (0 in vals) and (1 in vals):
        mask[weak_label == 0] = 255
        mask[weak_label == 1] = 1
    else:
        mask[...] = weak_label + 1
    return mask


it = tqdm(sorted(os.listdir(tracks_path)), desc="Processing folders")

for folder in it:
    it.set_description(f"Processing {folder}")
    complete_track_path = os.path.join(tracks_path, folder)
    # Stray files next to the sequence folders would make os.listdir below fail.
    if not os.path.isdir(complete_track_path):
        continue
    h5_files = sorted(os.listdir(complete_track_path))

    target_path = os.path.join(dataset_dirs, folder, "weak_labels", algo_name)
    os.makedirs(target_path, exist_ok=True)

    # leave=False so the inner bars do not pile up, one per sequence.
    h5it = tqdm(h5_files, desc="Processing h5 files", leave=False)
    for h5_file in h5it:
        if '.h5' not in h5_file:
            continue
        path = os.path.join(complete_track_path, h5_file)
        name = h5_file.split(".")[0]
        with h5py.File(path, "r") as f:
            # Stored transposed relative to the image layout, hence .T.
            weak_label = np.asarray(f["img"]).T
            confidence = np.asarray(f["confidence"]).T

        save_mask(_weak_label_to_mask(weak_label), os.path.join(target_path, f"{name}.png"))

        with h5py.File(os.path.join(target_path, f"{name}_confidence.h5"), "w") as f:
            f['confidence'] = confidence


## Unpack model checkpoints and extract only the `state_dict`

In [None]:
# Re-save only the `state_dict` of every checkpoint as <stem>_unet.pth.
checkpoint_dir = "data/modelsUncertFbms/FBMS59-train-masks-with-confidence-flownet2-based/with-voting/checkpoint"
checkpoint_target_dir = "data/checkpoints/labels_with_uncertainty_flownet2_based"

os.makedirs(checkpoint_target_dir, exist_ok=True)

for file in os.listdir(checkpoint_dir):
    # Substring match on purpose so files such as *.pth.tar are included too.
    if ".pth" not in file:
        continue
    name = file.split(".")[0] + "_unet"
    # NOTE: torch.load unpickles arbitrary objects — only run this on trusted checkpoints.
    checkpoint = torch.load(os.path.join(checkpoint_dir, file), map_location=torch.device('cpu'))
    # A checkpoint that is not a dict (e.g. a bare model/tensor) has no 'state_dict' entry;
    # calling .get on it would raise AttributeError.
    state_dict = checkpoint.get('state_dict') if isinstance(checkpoint, dict) else None
    if state_dict is None:
        print(f"Could not load {file}")
        continue
    torch.save(state_dict, os.path.join(checkpoint_target_dir, name + ".pth"))


In [None]:



# NOTE(review): this cell repeats the previous one (same source and target directories);
# re-running it just rewrites the same files.
checkpoint_dir = "data/modelsUncertFbms/FBMS59-train-masks-with-confidence-flownet2-based/with-voting/checkpoint"
checkpoint_target_dir = "data/checkpoints/labels_with_uncertainty_flownet2_based"

if not os.path.exists(checkpoint_target_dir):
    os.makedirs(checkpoint_target_dir)

checkpoint_files = [f for f in os.listdir(checkpoint_dir) if ".pth" in f]
for file in checkpoint_files:
    loaded = torch.load(os.path.join(checkpoint_dir, file), map_location=torch.device('cpu'))
    state_dict = loaded.get('state_dict')
    if state_dict is not None:
        name = file.split(".")[0] + "_unet"
        torch.save(state_dict, os.path.join(checkpoint_target_dir, name + ".pth"))
    else:
        print(f"Could not load {file}")