In [None]:
import os

from ool.clevrtex_eval import CLEVRTEX, collate_fn

# Test split of the 'full' dataset variant. The tilde is expanded here so the
# loader receives a concrete path whether or not it expands '~' on its own.
data = CLEVRTEX(os.path.expanduser('~/ool_data/'), dataset_variant='full', split='test')

In [None]:
import torch

# Use the first GPU when CUDA is usable; otherwise fall back to the CPU so the
# notebook still runs (slowly) on a machine without a GPU instead of failing
# on the first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
import numpy as np
from pprint import pprint
from matplotlib import pyplot as plt

In [None]:
from ool.clevrtex_eval import CLEVRTEX_Evaluator

In [None]:
import pytorch_lightning as pl
import torch
import os
import sys
# Entries of sys.path are used verbatim by the import system, so the tilde has
# to be expanded here; a literal '~/experiments' entry would never match the
# directory in the user's home.
sys.path.insert(0, os.path.expanduser('~/experiments'))
from pathlib import Path
import gc

from tqdm.auto import tqdm
import warnings


def array(key, *res):
    """Gather ``statistic(key)`` from each result object into one NumPy array.

    Args:
        key: Name of the statistic to query on every result.
        *res: Objects exposing a ``statistic(key)`` method.

    Returns:
        ``np.ndarray`` holding one entry per object in ``res``, in call order.
    """
    collected = []
    for result in res:
        collected.append(result.statistic(key))
    return np.array(collected)

@torch.no_grad()
def evaluate_checkpoint(cls, ckpt, data, bg=True, step=False, key='steps', last=False, half_batch=False):
    """Run one checkpoint over ``data`` and return the filled-in evaluator.

    The network is moved to the module-level ``device`` before evaluation.

    Args:
        cls: Class providing ``load_from_checkpoint``. The loaded object must
            expose ``hparams.batch_size``, ``global_step`` and the wrapped
            network as ``.model``.
        ckpt: Location of the checkpoint file (``str`` or ``Path``). It may
            live anywhere; it does not have to be inside ``output/``.
        data: Dataset handed to ``torch.utils.data.DataLoader``.
        bg: Forwarded to ``CLEVRTEX_Evaluator`` as ``masks_have_background``.
        step: If true, the network is called as ``model(img, global_step)``;
            otherwise as ``model(img)``.
        key: Entry of the network output whose ``'mask'`` item is scored.
        last: Forwarded to the ``DataLoader`` as ``drop_last``.
        half_batch: If true, iterate with half of the checkpoint's batch size.

    Returns:
        The ``CLEVRTEX_Evaluator`` after ``update`` was called for every batch.
    """
    ckpt = Path(ckpt)
    model = cls.load_from_checkpoint(checkpoint_path=ckpt)

    # Switch off training-time options on the loaded wrapper.
    model.nowatermark = True
    model.ddp = False
    model.gc = -1
    model.cpu_metrics = False
    model.workers = 8
    model.drop_last_batch = False
    model.eval()

    # NOTE(review): read before the trainer reference is cleared; confirm that
    # a freshly loaded module reports the trained step here rather than 0.
    global_step = model.global_step
    model.trainer = None
    evl = CLEVRTEX_Evaluator(masks_have_background=bg)
    bs = model.hparams.batch_size

    if half_batch:
        bs = bs // 2
    dl = torch.utils.data.DataLoader(data, batch_size=bs, shuffle=False, collate_fn=collate_fn, drop_last=last, num_workers=8)

    # Only the wrapped network is needed from here on; drop the wrapper.
    m = model.model
    del model
    model = m.eval().to(device)
    gc.collect()
    for batch in tqdm(dl):
        ind, img, mask, meta = batch
        img = img.to(device)
        mask = mask.to(device)
        if step:
            out = model(img, global_step)
        else:
            out = model(img)
        evl.update(out['canvas'], out[key]['mask'], img, mask, meta)
    del model
    return evl

def print_res(res):
    """Print the headline metrics of ``res`` as LaTeX table rows.

    Args:
        res: Sequence of objects exposing ``statistic(key)`` for the keys
            ``'acc'``, ``'ARI'``, ``'mIoU_fg'``, ``'ARI_FG'`` and ``'MSE'``.

    Prints a header row followed by a row of means. When more than one result
    is given, a third row with the standard deviations is added. For an empty
    ``res`` a short notice is printed instead of rows of ``nan``.
    """
    if len(res) == 0:
        # Averaging nothing would only produce nan values and NumPy warnings.
        print('No results to report.')
        return

    acc = array('acc', *res)
    ari = array('ARI', *res)
    iou = array('mIoU_fg', *res)
    ari_fg = array('ARI_FG', *res)
    mse = array('MSE', *res)
    # Header order matches the value rows below: acc, ari, ari_fg, miou, mse.
    print('& acc & ari & ari_fg & miou & mse \\\\')
    print(f'& {acc.mean():6.5f} & {ari.mean():6.5f} & {ari_fg.mean():6.5f} & {iou.mean():6.5f} & {mse.mean():6.5f} \\\\')
    if len(res) > 1:
        # The spread is only meaningful with several runs.
        print(f'& \\tiny \\(\\pm {acc.std():6.5f}\\) & \\tiny \\(\\pm {ari.std():6.5f}\\) & \\tiny \\(\\pm {ari_fg.std():6.5f}\\) & \\tiny \\(\\pm {iou.std():6.5f}\\) & \\tiny \\(\\pm {mse.std():6.5f}\\) \\\\')

### Evaluate several trained models

`res` contains a list of `CLEVRTEX_Evaluator` objects that can be further inspected to see the breakdown.

In [None]:
from gnm import LitGNM

# Checkpoint paths are assembled as <prefix><run directory><model>.
prefix = 'output/'
model = '/checkpoints/last.ckpt'
runs = [
    ''.join((prefix, 'clt_fullcm/gnm/gnm-on-clevrtex', model)),  # Change this to a checkpoint of interest
]

# The two filters below only apply while the checkpoints are being evaluated;
# the previous warning configuration is restored when the block exits.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action='ignore',
        message='The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.',
    )
    warnings.filterwarnings(
        action='ignore',
        message='Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.',
    )

    # One evaluator per checkpoint, in the order given by `runs`.
    res = list(map(
        lambda run: evaluate_checkpoint(LitGNM, run, data, bg=False, step=True, key='steps'),
        runs,
    ))
print_res(res)

In [None]:
from space import LitSPACE

# Add the checkpoints to evaluate to this list.
runs = []

# The two filters below only apply while the checkpoints are being evaluated;
# the previous warning configuration is restored when the block exits.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action='ignore',
        message='The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.',
    )
    warnings.filterwarnings(
        action='ignore',
        message='Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.',
    )

    # One evaluator per checkpoint, in the order given by `runs`.
    res = list(map(
        lambda run: evaluate_checkpoint(LitSPACE, run, data, bg=False, step=True, key='steps'),
        runs,
    ))
print_res(res)

In [None]:
from genesisv2 import LitGenesis

# Add the checkpoints to evaluate to this list.
runs = []

# The two filters below only apply while the checkpoints are being evaluated;
# the previous warning configuration is restored when the block exits.
with warnings.catch_warnings():
    warnings.filterwarnings(
        action='ignore',
        message='The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.',
    )
    warnings.filterwarnings(
        action='ignore',
        message='Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.',
    )

    # One evaluator per checkpoint, in the order given by `runs`.
    res = list(map(
        lambda run: evaluate_checkpoint(LitGenesis, run, data, bg=True, step=False, key='layers'),
        runs,
    ))
print_res(res)

In [None]:
from spair import LitSpair

# Add the checkpoints to evaluate to this list. It is defined in this cell so
# the evaluation never silently reuses a `runs` list left over from another
# model's cell.
runs = []

# The filters only apply inside this block; the previous warning configuration
# is restored on exit.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore',
                            message='The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.')
    warnings.filterwarnings('ignore',
                            message='Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.')

    # One evaluator per checkpoint, in the order given by `runs`.
    res = [
        evaluate_checkpoint(LitSpair, r, data, bg=False, step=False, key='steps') for r in runs
    ]

print_res(res)

In [None]:
from slota import LitSlot
import warnings

# Add the checkpoints to evaluate to this list. It is defined in this cell so
# the evaluation never silently reuses a `runs` list left over from another
# model's cell.
runs = []

# The filters only apply inside this block; the previous warning configuration
# is restored on exit.
with warnings.catch_warnings():
    warnings.filterwarnings('ignore',
                            message='The default behavior for interpolate/upsample with float scale_factor changed in 1.6.0 to align with other frameworks/libraries, and now uses scale_factor directly, instead of relying on the computed output size. If you wish to restore the old behavior, please set recompute_scale_factor=True. See the documentation of nn.Upsample for details.')
    warnings.filterwarnings('ignore',
                            message='Default grid_sample and affine_grid behavior has changed to align_corners=False since 1.3.0. Please specify align_corners=True if the old behavior is desired. See the documentation of grid_sample for details.')

    # One evaluator per checkpoint, in the order given by `runs`.
    res = [
        evaluate_checkpoint(LitSlot, r, data, bg=True, step=False, key='layers') for r in runs
    ]

print_res(res)