In [None]:
%cd "~/Projects/Segmentation/TreeSeg"
import json
import math
import os
from pathlib import Path

import hashlib
import matplotlib.pyplot as plt
import networkx as nx
import nvdiffrast.torch as dr
import torch
from torch import Tensor, nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import yaml
import cv2
from rich.console import Console
from rich.tree import Tree
import torch_geometric as pyg

import tree_segmentation as ts
import tree_segmentation.extension as ext
from tree_segmentation.extension import ops_3d, Mesh, utils
from semantic_sam import SemanticSAM, semantic_sam_l, semantic_sam_t
from segment_anything import build_sam
from tree_segmentation import  TreePredictor, TreeSegmentMetric, Tree2D, MaskData
from tree_segmentation.util import show_masks, show_all_levels, get_hash_name
from evaluation.batch_eval_PartNet import get_mesh_and_gt_tree, get_images
import pycocotools.mask as mask_util

In [None]:
# Session setup: report the torch version, enable module auto-reloading and inline plots,
# disable autograd globally (nothing below needs gradients), and pick the compute device.
print(torch.__version__)
%load_ext autoreload
%autoreload 2
%matplotlib inline
torch.set_grad_enabled(False)  # global switch: affects every later cell in this kernel
console = Console()
device = torch.device("cuda")
# device = torch.device("cpu")
utils.set_printoptions(linewidth=120)

In [None]:
# SA-1B root; each part lives in a zero-padded sub-folder (e.g. 000000/) that holds
# <name>.jpg images next to their <name>.json annotation files.
data_root = Path('/data5/SA-1B/')
debug_part_idx = 0  # part used by the single-image debug cells
data_part_1 = data_root.joinpath(f"{debug_part_idx:06d}")
# Sorted so that images_paths[0] is the same image on every machine: Path.glob yields
# entries in arbitrary, filesystem-dependent order.
images_paths = sorted(data_part_1.glob('*.jpg'))
print(f"There are {len(images_paths)} images")
print(len(list(data_part_1.glob('*.json'))))

In [None]:
def read_annotations(json_path: Path):
    """Load an SA-1B per-image annotation JSON file as a ground-truth ``Tree2D``.

    Each entry of ``data['annotations']`` contributes one mask (decoded from its
    ``segmentation`` field with ``pycocotools.mask.decode``) and one score (its
    ``predicted_iou``). Masks and scores are stacked, converted to tensors, wrapped
    in ``MaskData`` and handed to ``Tree2D``; ``update_tree()`` is then called
    (presumably to build the mask hierarchy).

    Args:
        json_path: path to the annotation JSON for one image.

    Returns:
        The ``Tree2D`` built from all annotations in the file.

    Raises:
        ValueError: raised by ``np.stack`` when the file contains no annotations.
    """
    with open(json_path, 'r') as f:
        annotations = json.load(f)['annotations']
    # One (iou, mask) pair per annotation, in file order.
    decoded = [(ann['predicted_iou'], mask_util.decode(ann['segmentation'])) for ann in annotations]
    iou_preds = np.stack([iou for iou, _ in decoded])
    masks = np.stack([mask for _, mask in decoded])
    gt_tree = Tree2D(MaskData(masks=torch.from_numpy(masks), iou_preds=torch.from_numpy(iou_preds)))
    gt_tree.update_tree()
    return gt_tree


# Debug: single-image sanity check (ground-truth tree vs. Semantic-SAM prediction)

In [None]:
# Load images_paths[0] and display it.
debug_image_path = images_paths[0]
image = utils.load_image(debug_image_path)
plt.imshow(image)

In [None]:
# Ground truth for the same image: its annotation JSON shares the image's path stem.
gt_json_path = images_paths[0].with_suffix('.json')
tree2d = read_annotations(gt_json_path)
tree2d.print_tree()
show_all_levels(image, tree2d, alpha=0.8)


In [None]:
# Load the Semantic-SAM "swinl" checkpoint and wrap it in a TreePredictor.
# Alternative backbones, kept for reference:
#   build_sam(Path('./weights/sam_vit_h_4b8939.pth').expanduser())
#   semantic_sam_t(Path("./weights/swint_only_sam_many2many.pth").expanduser())
assert torch.cuda.is_available()
sam_checkpoint = Path("./weights/swinl_only_sam_many2many.pth").expanduser()
model = semantic_sam_l(sam_checkpoint).eval().to(device)
tree_seg = TreePredictor(model, box_nms_thresh=0.9)

In [None]:
# Resize so the longer side becomes 1024 px (aspect ratio preserved), then run the predictor.
print(image.shape)
H, W, _ = image.shape
scale = min(1024 / H, 1024 / W)
# round() instead of int(): float error can make `1024 / H * H` evaluate to 1023.999...
# (e.g. H = 49), which int() would truncate, leaving the long side one pixel short.
# cv2.resize takes dsize as (width, height).
image_resized = cv2.resize(image, (round(scale * W), round(scale * H)), interpolation=cv2.INTER_AREA)
print(image_resized.shape)
plt.imshow(image_resized)
prediction = tree_seg.generate(image_resized)

In [None]:
# Inspect the predicted tree: print its hierarchy, then visualise it over the resized image
# (show_all_levels — presumably one overlay per tree level).
prediction.print_tree()
show_all_levels(image_resized, prediction)

In [None]:
# Score the single-image prediction against the SA-1B ground truth. The prediction was made
# on the resized image, hence is_resize_2d_as_gt=True (presumably resizes the prediction to
# the ground-truth resolution).
from tree_segmentation.metric import TreeSegmentMetric
metric = TreeSegmentMetric(is_resize_2d_as_gt=True)
metric.update(prediction.to(device), tree2d.to(device))
summary = metric.summarize()
for name, value in summary.items():
    print(name, value)

In [None]:
# Debug: IoU between two specific predicted masks. The ids are 1-based
# (presumably as printed by print_tree), hence the "- 1" when indexing.
id_a, id_b = 142, 150
mask_a = prediction.masks[id_a - 1]
mask_b = prediction.masks[id_b - 1]
area_a, area_b = mask_a.sum(), mask_b.sum()
print(area_a, area_b)
inter = (mask_a * mask_b).sum()
print(inter)
union = area_a + area_b - inter
print(inter / union)
show_masks(None, mask_a, mask_b)

In [None]:
# Tree-structure score of the single-image prediction.
structure_score = metric.calc_tree_structure_score(prediction)
print(structure_score)

# Evaluate Semantic-SAM-L

In [None]:
# choose eval images
part_idx = 110  # SA-1B part folder to sample from (original note said "< 110" — TODO confirm valid range)
num_eval = 10

# Separate name so the global `images_paths` used by the debug cells is not clobbered.
part_image_paths = sorted(data_root.joinpath(f"{part_idx:06d}").glob('*.jpg'))
assert len(part_image_paths) >= num_eval, \
    f"part {part_idx:06d} has only {len(part_image_paths)} images, need {num_eval}"
# Global seed kept so that any later numpy randomness is also reproducible.
np.random.seed(42)
# replace=False: np.random.choice samples WITH replacement by default, which could evaluate
# the same image several times and bias the averaged metrics.
eval_image_paths = np.random.choice(part_image_paths, num_eval, replace=False)
print(f"Try To evaluate {len(eval_image_paths)} image")

In [None]:
# load model
model = semantic_sam_l(Path("./weights/swinl_only_sam_many2many.pth").expanduser())
# model = semantic_sam_t( Path("./weights/swint_only_sam_many2many.pth").expanduser())
model = model.eval().to(device)
tree_seg = TreePredictor(model)
# init metric
# Predictions are generated on images resized to a 1024-px long side, while the ground-truth
# masks from read_annotations are not resized (presumably original image resolution). Use the
# same resize-to-GT setting as the single-image debug metric (is_resize_2d_as_gt=True).
# NOTE: the name `metirc` (sic) is kept because the evaluation loop references it.
metirc = TreeSegmentMetric(is_resize_2d_as_gt=True)

In [None]:
from tree_segmentation.metric import TreeSegmentMetric
# Evaluate each sampled image: resize (long side -> 1024 px), predict a tree, and accumulate
# the metric against that image's SA-1B ground truth. time_avg records per-stage timings.
timer = utils.TimeEstimator(num_eval)
time_avg = utils.TimeWatcher()
timer.start()
time_avg.start()
for i, image_path in enumerate(eval_image_paths, 1):  # i counts from 1 = images completed after this step
    image = utils.load_image(image_path)
    H, W, _ = image.shape
    scale = min(1024 / H, 1024 / W)
    # round() instead of int(): float error can make `1024 / H * H` 1023.999..., which int() truncates.
    image = cv2.resize(image, (round(scale * W), round(scale * H)), interpolation=cv2.INTER_AREA)
    time_avg.log('image')
    gt = read_annotations(image_path.with_suffix('.json'))
    time_avg.log('gt')
    prediction = tree_seg.generate(image, device=device)
    time_avg.log('tree2d')
    metirc.update(prediction, gt.to(device), return_match=False)
    time_avg.log('metric')
    timer.step()
    if i % 2 == 0:
        # i is already 1-based, so it is the count of processed images.
        print(f'Process [{i}/{num_eval}], time: {timer.progress}',
              ', '.join(f'{k}: {utils.float2str(v)}' for k, v in metirc.summarize().items()))

print('Complete Evalution')
print('Time:', time_avg)
for k, v in metirc.summarize().items():
    print(f"{k:5s}: {v}")