In [1]:
import sys
import os
import glob
sys.path.append('..')
import copy
import cv2
import numpy as np
import torch
from vm2m.utils.metric import build_metric
from vm2m.dataloader import MultiInstVidDataset

In [5]:
# Root directory of the saved predictions to evaluate. `evaluate` reads this
# module-level name and looks for masks under <pred_dir>/<subset>/<video>/<frame>/*.png.
# Exactly one assignment below should be active; the others are kept as
# commented-out alternatives.

# Public OTVM
# pred_dir = "/home/chuongh/OTVM/output/benchmark"

# Finetuned OTVM (was previously active but immediately overwritten below)
# pred_dir = "/home/chuongh/OTVM/output/xmem_finetuned"

# Public FTP-VM
# pred_dir = "/home/chuongh/FTP-VM/publicoutput"

# Finetuned FTP-VM (active)
pred_dir = "/home/chuongh/FTP-VM/output/ft_22k_rgb"

In [6]:
# create dataset and dataloader
def _update_video_metrics(video_name, val_error_dict, all_preds, all_gts, all_trimap):
    """Update every metric with one complete video and print the per-video values.

    Args:
        video_name: Name printed in front of the per-video log line.
        val_error_dict: Mapping of metric name to metric object. Each object is
            called as ``update(pred, gt, trimap=...)`` and its return value is
            formatted as a float.
        all_preds: List of per-frame prediction arrays; stacked along a new
            leading axis.
        all_gts: List of per-frame ground-truth alpha arrays; concatenated
            along axis 0.
        all_trimap: List of per-frame trimap arrays; concatenated along axis 0.

    Note:
        Assumes all frames of a video yield arrays of matching shapes so that
        they can be stacked/concatenated -- TODO confirm against the dataset.
    """
    preds = np.stack(all_preds, axis=0)
    trimaps = np.concatenate(all_trimap, axis=0)
    gts = np.concatenate(all_gts, axis=0)

    # Metric-name suffix -> trimap value selecting the region the metric is
    # restricted to. Metrics without one of these suffixes get trimap=None.
    region_values = {"_fg": 2, "_bg": 0, "_unk": 1}

    current_metrics = {}
    for k, v in val_error_dict.items():
        current_trimap = None
        for suffix, region_value in region_values.items():
            if k.endswith(suffix):
                current_trimap = (trimaps[None] == region_value).astype('float32')
                break
        # A leading axis of size 1 is added to every array before the update.
        current_metrics[k] = v.update(preds[None], gts[None], trimap=current_trimap)

    log_str = f"{video_name}: "
    for k, v in current_metrics.items():
        log_str += "{} - {:.4f}, ".format(k, v)
    print(video_name, log_str)


def evaluate(subset):
    """Evaluate saved alpha-matte predictions for one benchmark split and print metrics.

    Frames are read one at a time from ``MultiInstVidDataset`` (batch size 1,
    clip length 1, not shuffled). Consecutive frames whose image path has the
    same parent directory name are treated as one video. For every frame the
    predictions are the PNG files found in
    ``<pred_dir>/<subset>/<video>/<frame name without ".jpg">/``, loaded as
    grayscale in sorted filename order, resized to the ground-truth alpha size
    and divided by 255. Whenever the video changes, and once more after the
    last frame, all metrics are updated with the whole video and the per-video
    values are printed. Finally the average of every metric is printed, once
    per line and once as a single comma-separated line.

    Args:
        subset: Split name. Passed as ``split`` to the dataset and used as the
            sub-directory of the module-level ``pred_dir``.

    Returns:
        None. All results are printed.

    Raises:
        FileNotFoundError: If no predicted PNG mask exists for a frame.
        OSError: If a predicted mask file exists but cannot be decoded.
    """
    # NOTE(review): flip_p=0.5 is passed although is_train=False -- confirm the
    # dataset does not apply random flips during evaluation.
    dataset = MultiInstVidDataset(root_dir="/mnt/localssd/syn/benchmark", split=subset, clip_length=1, overlap=0, padding_inst=10, is_train=False, short_size=576, 
                        crop=[512, 512], flip_p=0.5, bin_alpha_max_k=30,
                        max_step_size=5, random_seed=2023, mask_dir_name='mask', pha_dir='pha', weight_mask_dir='', is_ss_dataset=False)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    val_error_dict = build_metric(['MAD', 'MSE', 'SAD', 'dtSSD']) 
    # Independent copies of the MAD metric, each restricted to one trimap region.
    val_error_dict["MAD_fg"] = copy.deepcopy(val_error_dict['MAD'])
    val_error_dict["MAD_bg"] = copy.deepcopy(val_error_dict['MAD'])
    val_error_dict["MAD_unk"] = copy.deepcopy(val_error_dict['MAD'])

    # Per-video accumulators; flushed into the metrics whenever the video changes.
    video_name = None
    all_preds = []
    all_gts = []
    all_trimap = []
    num_batches = len(dataloader)
    for i, batch in enumerate(dataloader):
        print(f"{i}/{num_batches}")
        image_names = batch.pop('image_names')
        trimap = batch.pop('trimap').numpy()
        alpha_gt = batch.pop('alpha').numpy()

        # The last two path components are the video directory and the frame file.
        frame_video, frame_name = image_names[0][0].split('/')[-2:]

        if frame_video != video_name:
            # A new video starts: score the frames collected for the previous one.
            if len(all_gts) > 0:
                _update_video_metrics(video_name, val_error_dict, all_preds, all_gts, all_trimap)
                all_preds = []
                all_gts = []
                all_trimap = []
            video_name = frame_video

        all_gts.append(alpha_gt[0])
        all_trimap.append(trimap[0])

        # Load the predicted masks of this frame (one PNG per file, sorted by name).
        frame_dir = os.path.join(pred_dir, subset, video_name, frame_name.replace(".jpg", ""))
        mask_names = sorted(glob.glob(frame_dir + '/*.png'))
        if not mask_names:
            # Without this check np.stack fails with an unhelpful
            # "need at least one array to stack" message.
            raise FileNotFoundError(f"No predicted masks (*.png) found in {frame_dir}")

        all_masks = []
        for mask_name in mask_names:
            mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                # cv2.imread signals a decode failure by returning None.
                raise OSError(f"Failed to read predicted mask: {mask_name}")
            # cv2.resize takes (width, height), i.e. the last two axes reversed.
            mask = cv2.resize(mask, (alpha_gt.shape[-1], alpha_gt.shape[-2]))
            mask = mask / 255.0
            all_masks.append(mask)
        all_preds.append(np.stack(all_masks, axis=0))

    # Score the last video, which is never followed by a video change.
    _update_video_metrics(video_name, val_error_dict, all_preds, all_gts, all_trimap)

    print("Metrics:")
    metric_str = ""
    plain_str = ""
    for k, v in val_error_dict.items():
        average = v.average()
        metric_str += "{}: {}\n".format(k, average)
        plain_str += str(average) + ","
    print(metric_str)
    print(plain_str)

In [7]:
# Evaluate the predictions under pred_dir for the "comp_easy" split.
evaluate("comp_easy")

0/465
1/465
2/465
3/465
4/465
5/465
6/465
7/465
8/465
9/465
10/465
11/465
12/465
13/465
14/465
15/465
16/465
17/465
18/465
19/465
20/465
21/465
22/465
23/465
24/465
25/465
26/465
27/465
28/465
29/465
30/465
00001 00001: MAD - 3.0862, MSE - 2.0635, SAD - 6.3996, dtSSD - 0.1183, MAD_fg - 30.2133, MAD_bg - 0.0000, MAD_unk - 49.9587, 
31/465
32/465
33/465
34/465
35/465
36/465
37/465
38/465
39/465
40/465
41/465
42/465
43/465
00002 00002: MAD - 3.0199, MSE - 1.0898, SAD - 6.2622, dtSSD - 0.2206, MAD_fg - 0.2358, MAD_bg - 0.0128, MAD_unk - 62.0097, 
44/465
45/465
46/465
47/465
48/465
49/465
50/465
51/465
52/465
53/465
54/465
55/465
56/465
57/465
58/465
59/465
60/465
61/465
62/465
63/465
64/465
65/465
66/465
67/465
68/465
69/465
70/465
71/465
00005 00005: MAD - 1.4485, MSE - 0.7052, SAD - 3.0037, dtSSD - 0.1115, MAD_fg - 1.0719, MAD_bg - 0.0941, MAD_unk - 46.5243, 
72/465
73/465
74/465
75/465
76/465
77/465
78/465
79/465
80/465
81/465
82/465
83/465
84/465
85/465
86/465
87/465
88/465
89/465
90/4

-- BGR

10.529733940243519,8.887114821314155,20.245137786202328,0.254121503874973,45.79688570206778,3.220213327375712,81.53992698174184,

-- RGB

2.6376528049379617,1.1923663155268645,5.071319443698794,0.15257648995907963,3.9979234912484305,0.06640529503092797,55.2621984315902,


In [8]:
# Evaluate the predictions under pred_dir for the "comp_medium" split.
evaluate("comp_medium")

0/478
1/478
2/478
3/478
4/478
5/478
6/478
7/478
8/478
9/478
10/478
11/478
12/478
13/478
14/478
15/478
16/478
17/478
18/478
19/478
20/478
21/478
22/478
23/478
24/478
25/478
26/478
27/478
28/478
29/478
30/478
00029 00029: MAD - 11.2975, MSE - 9.1051, SAD - 19.7661, dtSSD - 0.4121, MAD_fg - 32.0365, MAD_bg - 2.5235, MAD_unk - 125.1053, 
31/478
32/478
33/478
34/478
35/478
36/478
37/478
38/478
39/478
40/478
41/478
42/478
43/478
44/478
45/478
46/478
47/478
48/478
49/478
50/478
51/478
52/478
53/478
54/478
55/478
56/478
57/478
58/478
59/478
60/478
00033 00033: MAD - 5.1585, MSE - 2.7638, SAD - 9.0253, dtSSD - 0.3118, MAD_fg - 5.1336, MAD_bg - 0.0984, MAD_unk - 80.0377, 
61/478
62/478
63/478
64/478
65/478
66/478
67/478
68/478
69/478
70/478
71/478
72/478
73/478
74/478
75/478
76/478
77/478
78/478
79/478
80/478
81/478
82/478
83/478
84/478
85/478
86/478
87/478
88/478
89/478
90/478
00037 00037: MAD - 2.3379, MSE - 0.5688, SAD - 4.0903, dtSSD - 0.1332, MAD_fg - 0.1099, MAD_bg - 0.0067, MAD_unk - 47.5

In [9]:
# Evaluate the predictions under pred_dir for the "comp_hard" split.
evaluate("comp_hard")

0/495
1/495
2/495
3/495
4/495
5/495
6/495
7/495
8/495
9/495
10/495
11/495
12/495
13/495
00002 00002: MAD - 3.5496, MSE - 1.7638, SAD - 7.3604, dtSSD - 0.2677, MAD_fg - 3.6665, MAD_bg - 0.0250, MAD_unk - 71.7783, 
14/495
15/495
16/495
17/495
18/495
19/495
20/495
21/495
22/495
23/495
24/495
25/495
26/495
27/495
28/495
29/495
30/495
31/495
32/495
33/495
34/495
35/495
36/495
37/495
38/495
39/495
40/495
41/495
42/495
43/495
00012 00012: MAD - 3.9079, MSE - 2.0206, SAD - 8.1033, dtSSD - 0.1387, MAD_fg - 0.8275, MAD_bg - 0.4181, MAD_unk - 73.8174, 
44/495
45/495
46/495
47/495
48/495
49/495
50/495
51/495
52/495
53/495
54/495
55/495
56/495
57/495
58/495
59/495
60/495
61/495
62/495
63/495
64/495
65/495
66/495
67/495
68/495
69/495
70/495
71/495
72/495
73/495
00015 00015: MAD - 13.8308, MSE - 10.3027, SAD - 25.8713, dtSSD - 0.4720, MAD_fg - 20.3308, MAD_bg - 4.5014, MAD_unk - 98.0010, 
74/495
75/495
76/495
77/495
78/495
79/495
80/495
81/495
82/495
83/495
84/495
85/495
86/495
87/495
88/495
89/495
9

In [10]:
# Evaluate the predictions under pred_dir for the "real" split.
evaluate("real")

0/679
1/679
2/679
3/679
4/679
5/679
6/679
7/679
8/679
9/679
10/679
11/679
12/679
13/679
14/679
15/679
3048929 3048929: MAD - 4.7338, MSE - 2.3680, SAD - 9.8161, dtSSD - 0.4238, MAD_fg - 1.2590, MAD_bg - 0.0467, MAD_unk - 79.9148, 
16/679
17/679
18/679
19/679
20/679
21/679
22/679
23/679
24/679
25/679
26/679
27/679
28/679
Pexels_Videos_2795385 Pexels_Videos_2795385: MAD - 4.1757, MSE - 1.7795, SAD - 8.6588, dtSSD - 0.3786, MAD_fg - 0.0706, MAD_bg - 0.0013, MAD_unk - 85.5926, 
29/679
30/679
31/679
32/679
33/679
34/679
35/679
36/679
37/679
38/679
39/679
40/679
41/679
42/679
43/679
44/679
45/679
46/679
47/679
48/679
49/679
50/679
pexels-artem-podrez-5457692 pexels-artem-podrez-5457692: MAD - 11.2718, MSE - 6.2919, SAD - 23.3732, dtSSD - 0.6817, MAD_fg - 10.8988, MAD_bg - 0.0042, MAD_unk - 88.7010, 
51/679
52/679
53/679
54/679
55/679
56/679
57/679
58/679
59/679
60/679
61/679
62/679
63/679
64/679
65/679
66/679
67/679
68/679
69/679
70/679
71/679
72/679
73/679
74/679
75/679
76/679
77/679
78/679