In [1]:
%matplotlib inline
import matplotlib.pyplot as plt

import sys
import numpy as np
import torch

sys.path.insert(0, '..')
from isegm.utils import vis, exp

from isegm.inference import utils
from isegm.inference.evaluation import evaluate_dataset, evaluate_sample

device = torch.device('cuda:0')
cfg = exp.load_config_file('../config.yml', return_edict=True)

### Initialize dataset

In [2]:
# Name of the dataset to evaluate; utils.get_dataset resolves it together with cfg.
# Built-in choices include 'GrabCut', 'Berkeley', 'DAVIS', 'COCO_MVal', 'SBD'.
# 'Avalanche3' is not one of those -- presumably a custom dataset registered in
# isegm.inference.utils / config.yml (TODO confirm).
DATASET = 'Avalanche3'
dataset = utils.get_dataset(DATASET, cfg)

In [3]:
from pathlib import Path
from isegm.inference.predictors import get_predictor

# Shared by the evaluation cells below: click budget per object and the
# probability cut-off used to binarise the model's output.
EVAL_MAX_CLICKS = 20
MODEL_THRESH = 0.49

# Predictor mode handed to get_predictor.
# Possible choices: 'NoBRS', 'f-BRS-A', 'f-BRS-B', 'f-BRS-C', 'RGB-BRS', 'DistMap-BRS'
brs_mode = 'NoBRS'

# Locate the named checkpoint under the configured models directory and load it onto `device`.
checkpoint_path = utils.find_checkpoint(
    cfg.INTERACTIVE_MODELS_PATH,
    'coco_lvis_h18_itermask.pth',
)
model = utils.load_is_model(checkpoint_path, device)

predictor = get_predictor(model, brs_mode, device, prob_thresh=MODEL_THRESH)

### Dataset evaluation

In [4]:
# IoU passed to evaluate_dataset as max_iou_thr (presumably the per-object stopping IoU).
TARGET_IOU = 0.9

all_ious, elapsed_time = evaluate_dataset(
    dataset,
    predictor,
    pred_thr=MODEL_THRESH,
    max_iou_thr=TARGET_IOU,
    max_clicks=EVAL_MAX_CLICKS,
)
mean_spc, mean_spi = utils.get_time_metrics(all_ious, elapsed_time)

# Number-of-clicks statistics at three IoU levels, within the same click budget.
noc_list, over_max_list = utils.compute_noc_metric(
    all_ious, iou_thrs=[0.8, 0.85, 0.9], max_clicks=EVAL_MAX_CLICKS
)

header, table_row = utils.get_results_table(
    noc_list, over_max_list, brs_mode, DATASET, mean_spc, elapsed_time, EVAL_MAX_CLICKS
)
print(header, table_row, sep='\n')

  0%|          | 0/36 [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


-----------------------------------------------------------------------------------------------
|  BRS Type   |  Dataset  | NoC@80% | NoC@85% | NoC@90% |>=20@85% |>=20@90% | SPC,s |  Time   |
-----------------------------------------------------------------------------------------------
|    NoBRS    | Avalanche |  5.47   |  6.36   |  9.47   |    1    |    6    | 0.196 | 0:01:06 |


### Single-sample evaluation

In [5]:
import cv2

# Stopping IoU for the per-sample runs. Deliberately a separate name: reassigning
# TARGET_IOU here would make a later re-run of the dataset-evaluation cell silently
# use 0.95 instead of its own 0.9.
SAMPLE_TARGET_IOU = 0.95
# cv2 dsize is (width, height). NOTE(review): a fixed size ignores each image's
# aspect ratio -- confirm the stretch is intended.
SAMPLE_RESIZE_WH = (3000, 2000)
# Samples to visualise. Narrow this (e.g. range(3)) for a quick look: each sample runs
# up to EVAL_MAX_CLICKS predictions at full resolution.
SAMPLE_IDS = range(len(dataset))
# Figure width in inches; the height follows the image's aspect ratio.
FIG_WIDTH_IN = 40

for sample_id in SAMPLE_IDS:
    sample = dataset.get_sample(sample_id)

    # Resize into a local instead of overwriting sample.image, so the sample stays as loaded.
    image = cv2.resize(sample.image, dsize=SAMPLE_RESIZE_WH)
    # Nearest-neighbour avoids inventing interpolated values between mask labels.
    gt_mask = cv2.resize(sample.gt_mask, SAMPLE_RESIZE_WH, interpolation=cv2.INTER_NEAREST)

    clicks_list, ious_arr, pred = evaluate_sample(image, gt_mask, predictor,
                                                  pred_thr=MODEL_THRESH,
                                                  max_iou_thr=SAMPLE_TARGET_IOU,
                                                  max_clicks=EVAL_MAX_CLICKS)
    print(ious_arr)

    # Panels, left to right: blended prediction with clicks, predicted mask, ground-truth mask.
    pred_mask = pred > MODEL_THRESH
    overlay = vis.draw_with_blend_and_clicks(image, mask=pred_mask, clicks_list=clicks_list)
    # Build the mask panels as uint8. `255 * bool_array` is int64 and would promote the
    # whole concatenated image to 8 bytes per channel.
    # NOTE(review): assumes `overlay` is uint8 RGB -- confirm.
    pred_panel = np.repeat(pred_mask[:, :, np.newaxis], 3, axis=2).astype(np.uint8) * 255
    gt_panel = np.repeat((gt_mask > 0)[:, :, np.newaxis], 3, axis=2).astype(np.uint8) * 255
    draw = np.concatenate((overlay, pred_panel, gt_panel), axis=1)

    fig, ax = plt.subplots(figsize=(FIG_WIDTH_IN, FIG_WIDTH_IN * draw.shape[0] / draw.shape[1]))
    ax.imshow(draw)
    ax.set_title(f'sample {sample_id}')
    plt.show()
    plt.close(fig)  # release the large figure immediately when many samples are drawn

[0.07218324 0.48172757 0.6717496  0.8051948  0.82116246 0.87198067
 0.8568147  0.86165047 0.87363493 0.8716577  0.8767661  0.9022674
 0.9026275  0.91043615 0.91341656 0.91844815 0.919586   0.921963
 0.9208401  0.91775244]


KeyboardInterrupt: 