# 0. Setup (Don't unfold!)

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import os
import sys

# Work from the parent of the notebook directory (presumably the repository root,
# where `configs/` and `Marigold/` live). A bare os.chdir('../') climbs one more
# level every time this cell is re-run, so remember the target on first execution
# and chdir to it idempotently afterwards.
if "_PROJECT_ROOT" not in globals():
    _PROJECT_ROOT = os.path.abspath(os.pardir)
os.chdir(_PROJECT_ROOT)

# Only takes effect if set before CUDA is initialised, so keep it ahead of `import torch`.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

# Make the Marigold sources importable (path is relative to the cwd set above);
# skip the append on re-run so sys.path does not accumulate duplicates.
if "Marigold" not in sys.path:
    sys.path.append("Marigold")

import torch
import numpy as np
import jhutil

# NOTE(review): the meaning of the 1111 argument is defined by jhutil — confirm.
jhutil.color_log(1111)

In [None]:
import sys

vggt = True
if vggt:
    sys.argv = [
        "ludvig_uplift.py", 
        "--colmap_dir", "./dataset/llff_data/fern/vggt/", 
        "--gs_source", "./dataset/llff_data/fern/vggt/point_cloud/iteration_0/point_cloud.ply", 
        "--config", "configs/dif_NVOS.yaml", 
        "--height", "1199", 
        "--width", "1600", 
        "--tag", "fern"
    ]
else:
    sys.argv = [
        "ludvig_uplift.py", 
        "--colmap_dir", "./dataset/llff_data/fern/", 
        "--gs_source", "./dataset/llff_data/fern/gs/point_cloud/iteration_30000/point_cloud.ply", 
        "--config", "configs/dif_NVOS.yaml", 
        "--height", "1199", 
        "--width", "1600", 
        "--tag", "fern"
    ]

In [None]:
from ludvig_uplift import *

# Parse the command line faked in the previous cell (via sys.argv) into run options.
args = parse_args()
# Seed the run with 0 for repeatability. NOTE(review): presumably seeds
# python/numpy/torch RNGs — confirm in ludvig_uplift.reproducibility.
reproducibility(0)
# Build the uplifting model; later cells rely on its .config, .gaussian,
# .colmap_cameras, .render, .scene, .img_height/.img_width and .logdir.
model = LUDVIGUplift(args)

# 1. Load dataset

In [None]:
t0 = time()
print("Uplifting features...")
# Work on a shallow copy of the feature config instead of popping from
# model.config directly. Popping mutated the shared config, which caused two
# problems:
#   - re-running this cell silently fell back to the default image directory;
#   - 'directory' was missing from the config.yaml dumped in section 3.
# 'directory' is still removed from the copy so it is not passed twice below.
feature_cfg = dict(model.config['feature'])
directory = feature_cfg.pop(
    'directory',
    os.path.join(model.colmap_dir, 'images')
)
dataset = config_to_instance(
    directory=directory,
    gaussian=model.gaussian,
    cameras=model.colmap_cameras,
    render_fn=model.render,
    scene=model.scene,
    height=model.img_height,
    width=model.img_width,
    **feature_cfg,
)
# One-shot iterator consumed by `uplifting` in the next cell; re-run this cell
# before re-running the uplift.
loader = iter(dataset)

# 2. Uplift

In [None]:
# Uplift the 2D features onto the Gaussians by consuming `loader` (one pass).
uplift_kwargs = dict(prune_gaussians=model.config.get("prune_gaussians", None))
features, _ = uplifting(loader, model.gaussian, **uplift_kwargs)

# Optional per-row L2 normalisation; the epsilon guards against zero-norm rows.
if model.config.get('normalize', False):
    print("l2-normalizing uplifted features.")
    features /= features.norm(dim=1, keepdim=True) + 1e-6

# t0 was set in the dataset-loading cell, so this also counts preprocessing.
elapsed_s = round(time() - t0)
print(
    f"Total time for preprocessing + uplifting {len(model.colmap_cameras)} images: {elapsed_s}s"
)
model.features = features

# 3. Hyperparameter search (HPO)

In [None]:
########################################################
# load eval_fn
########################################################

from evaluation.spin_nvos.diffusion import SegmentationDiffusionNVOS

# Persist the run configuration next to the evaluation outputs.
os.makedirs(model.logdir, exist_ok=True)
cfg_path = os.path.join(model.logdir, "config.yaml")
# Context manager: the file is flushed and closed deterministically, even if
# yaml.dump raises (a bare open(...) relied on garbage collection to close it).
with open(cfg_path, "w") as cfg_file:
    yaml.dump(model.config, cfg_file)

# Evaluation settings; an absent 'evaluation' section falls back to {}.
# (Previously computed but unused — the call below read model.config.evaluation.)
eval_kwargs = model.config.get("evaluation", dict())

eval_fn: SegmentationDiffusionNVOS = config_to_instance(
    gaussian=model.gaussian,
    features=model.features,
    render_fn=model.render,
    render_rgb=model.render_rgb,
    logdir=model.logdir,
    image_dir=model.colmap_dir,
    colmap_cameras=model.colmap_cameras,
    scene=model.scene,
    height=model.img_height,
    width=model.img_width,
    **eval_kwargs,
)

In [None]:
# args = eval_fn.hyperparameter_search()

In [None]:
from diffusion.segmentation import GraphDiffusionSeg
from utils.graph import energy_fn

def graph_call(graph: GraphDiffusionSeg, features, binarize=1e-5):
    """Run one graph-diffusion pass with the graph's current bandwidths.

    Used by the bandwidth grid search. The one-time setup (initial features,
    seed mask, similarity precomputation) only runs while
    ``graph.initial_features`` is ``None``. Later calls therefore only redo the
    bandwidth-dependent steps.

    NOTE(review): presumably a trimmed copy of ``GraphDiffusionSeg.__call__``.
    Confirm they stay in sync.

    Args:
        graph: Diffusion graph. Its ``feature_bandwidth`` is passed to
            ``energy_fn``. Its ``reg_bandwidth`` is presumably read inside
            ``graph.compute_regularizer``.
        features: Features to diffuse. They are first passed through
            ``graph.normalize_features``.
        binarize: Threshold forwarded to ``graph.run_diffusion``. Defaults to
            ``1e-5``, the value previously hard-coded.

    Returns:
        ``(diffused, graph.reg_similarities)``. ``diffused`` is
        ``run_diffusion(...) > 0`` cast to float32 and multiplied by
        ``graph.reg_similarities[:, None]`` (row-wise for 2-D output).

    Side effects:
        On the first call, sets ``graph.mask`` and calls
        ``compute_initial_features`` / ``precompute_similarities``.
        Every call runs ``graph.compute_regularizer(features)``.
    """
    features = graph.normalize_features(features)

    if graph.initial_features is None:
        # One-time setup. The mask marks entries whose initial feature is > 0.
        # precompute_similarities presumably populates graph.similarities,
        # which energy_fn reads below.
        graph.compute_initial_features()
        graph.mask = graph.initial_features.squeeze() > 0
        graph.precompute_similarities(features)

    similarities = energy_fn(graph.similarities, graph.feature_bandwidth, graph.mask)

    graph.compute_regularizer(features)
    reg = graph.reg_similarities
    # Scale each kNN edge by the geometric mean of its endpoints' regulariser
    # weights. Out-of-place on purpose: if energy_fn ever returns
    # graph.similarities (or a view of it), an in-place `*=` would corrupt the
    # cache reused by every later call in the grid search.
    similarities = similarities * torch.sqrt(reg[graph.knn_neighbor_indices] * reg[:, None])

    diffused = graph.run_diffusion(similarities, binarize=binarize)
    diffused = (diffused > 0) * graph.reg_similarities[:, None].to(torch.float32)

    return diffused, graph.reg_similarities



In [None]:
# np.linspace(1, 4, 5)
# np.linspace(1, 4, 10)

In [None]:
from tqdm import tqdm
from itertools import product

# Grid-search the feature / regulariser bandwidths over powers of two
# (2**1 .. 2**4 each) and keep the pair with the best IoU.
frange = np.arange(1, 5)  # exponents for graph.feature_bandwidth
grange = np.arange(1, 5)  # exponents for graph.reg_bandwidth
param_combinations = list(product(frange, grange))

# Fallbacks if no combination scores above 0.
best_iou = 0
k_best, f_best, g_best = 0, 2, 2
results = []  # (feature exponent, reg exponent, IoU) per combination

for f, g in tqdm(param_combinations):
    eval_fn.graph.feature_bandwidth = 2.0**f
    eval_fn.graph.reg_bandwidth = 2.0**g
    eval_fn.manifold_features, _ = graph_call(eval_fn.graph, eval_fn.features)
    cur_iou, k_iou = eval_fn.segment_and_evaluate(
        eval_fn.manifold_features,
        save=False,
        use_sam=eval_fn.sam_model is not None,
    )
    results.append((f, g, cur_iou))
    if cur_iou > best_iou:
        best_iou, k_best = cur_iou, k_iou
        f_best, g_best = 2.0**f, 2.0**g

# Recompute the features once with the winning bandwidths.
eval_fn.graph.feature_bandwidth = f_best
eval_fn.graph.reg_bandwidth = g_best
eval_fn.graph.trace_name = eval_fn.trace_name
# NOTE(review): candidates were scored with graph_call(), but the final features
# come from the graph's own __call__. Confirm the two agree, otherwise the
# chosen bandwidths may not reproduce best_iou.
eval_fn.manifold_features, eval_fn.reg_similarities = eval_fn.graph(eval_fn.features)

# 4. Render mask 

In [None]:
# eval_fn.evaluate(k_best, f_best, g_best)

In [None]:
from evaluation.spin_nvos.base import *

features = eval_fn.manifold_features
ev_name = eval_fn.ev_name
img_name = ev_name.split("/")[-1]

# Find the COLMAP camera of the image this evaluation mask belongs to. Use an
# explicit error rather than letting a bare next() raise an opaque StopIteration.
target_image = eval_fn.mask_to_img[img_name]
camera = next(
    (cam for cam in eval_fn.colmap_cameras if cam.image_name == target_image),
    None,
)
if camera is None:
    raise ValueError(
        f"No COLMAP camera has image_name {target_image!r} (evaluation mask {img_name!r})."
    )

gt_path = eval_fn.gtpath_from_name(ev_name)
gt_img = Image.open(gt_path)

# Render the features replicated 3x along dim 1 and keep the first slice of the
# output. NOTE(review): assumes features is (N, 1) and the render is
# channel-first — confirm.
anchor = eval_fn.render_fn(features.repeat(1, 3), camera)[:1]

# Normalise over every dimension, then resize to the ground-truth resolution.
# PIL's Image.size is (width, height), hence the swap.
anchor = viz_normalization(anchor, dim=range(len(anchor.shape)))
_img_up = resize(anchor, (gt_img.size[1], gt_img.size[0])).squeeze()


In [None]:
# Quick visual check of the upsampled render. NOTE(review): `.chans` is not a
# standard torch.Tensor attribute — presumably added by a tensor-visualisation
# patch (e.g. lovely-tensors, possibly enabled via jhutil); confirm it is installed.
_img_up.chans

# 5. Compute IoU

In [None]:
# from utils.evaluation import segmentation_loop

# best_iou, mask_best, _, _ = segmentation_loop(
#     _img_up, gt_img, k_best, metric="iou"
# )
# print("IoU:", round(best_iou, 3))

In [None]:
from utils.evaluation import to_pil, iou

# Threshold the normalised render using k_best selected by the grid search
# (k is a percentage).
mask_2d = _img_up > (1 - k_best / 100)
mask_best_iou = to_pil(mask_2d)
img_arr, gt_img_arr = np.array(mask_best_iou), np.array(gt_img)

# Rescale both masks so foreground becomes 1. This assumes binary masks stored
# as 0/255 or 0/1 — NOTE(review): confirm for the GT format.
# Both divisors are clamped to >= 1. Previously only the prediction was guarded,
# so an all-zero GT mask triggered a divide-by-zero (NaN for float images).
img_arr = img_arr // max(img_arr.max(), 1)
gt_img_arr = gt_img_arr // max(gt_img_arr.max(), 1)

# NOTE(review): this overwrites the grid search's best_iou with the final
# rendered-mask IoU.
best_iou = iou(img_arr, gt_img_arr, class_label=1)

print("IoU:", round(best_iou, 3))

In [None]:
# Visual check of the thresholded boolean mask. NOTE(review): `.chans` is not a
# standard torch.Tensor attribute — presumably a tensor-visualisation patch
# (e.g. lovely-tensors); confirm its source.
mask_2d.chans