In [1]:
# Step up one directory so the relative 'configs/' and 'work_dirs/' paths used
# by the later cells resolve. Not idempotent: every re-run of this cell climbs
# one more level, so run it exactly once per kernel session.
%cd ..

/home/ubuntu/dev/pepper/projects/metric_learning_playground


In [38]:
import os.path as osp
from numbers import Number
from copy import deepcopy

import numpy as np

import matplotlib
import matplotlib.pyplot as plt

In [3]:
# Render figures as static images embedded in the cell output.
%matplotlib inline
# Interactive alternative (ipympl backend), left disabled:
# %matplotlib widget

In [25]:
from mmcv import Config
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import load_checkpoint

from mmcls.apis import multi_gpu_test, single_gpu_test
from mmcls.datasets import build_dataset, build_dataloader
from mmcls.models import build_classifier

from src import *
from src.apis import single_gpu_metric_test

In [14]:
# configs
# The identifiers are stored as `*_name` on purpose: later cells bind `dataset`
# and `model` to the built dataset / classifier objects, and re-running this
# cell must not overwrite those objects with plain strings.
model_name = 'lenet'
dataset_name = 'mnist'
ckpt_iter = 6000

# Both paths are relative to the current working directory.
ckpt_path = osp.join(
    'work_dirs', f'{model_name}_{dataset_name}', f'iter_{ckpt_iter}.pth'
)
cfg_fp = osp.join('configs', model_name, f'{model_name}_{dataset_name}.py')

# Fail early, and say which file is missing, before any expensive setup runs.
assert osp.exists(cfg_fp), f"config file not found: {cfg_fp}"
assert osp.exists(ckpt_path), f"checkpoint file not found: {ckpt_path}"

# setup variables

cfg = Config.fromfile(cfg_fp)
# Uncomment to inspect the fully resolved config:
# print(cfg.pretty_text)

In [17]:
# Build the dataset described by cfg.data.test, forcing test mode.
dataset = build_dataset(cfg.data.test, default_args={"test_mode": True})

# Baseline loader options for a single, non-distributed GPU.
loader_cfg = {
    "samples_per_gpu": 128,
    "workers_per_gpu": 2,
    "num_gpus": 1,
    "dist": False,
    "round_up": True,
}

# Layering order (later wins): baseline -> fixed test-time options ->
# any `test_dataloader` overrides present in the config.
test_loader_cfg = dict(loader_cfg, shuffle=False, sampler_cfg=None)
test_loader_cfg.update(cfg.data.get("test_dataloader", {}))

data_loader = build_dataloader(dataset, **test_loader_cfg)

In [23]:
# Instantiate the classifier from the `model` section of the config.
model = build_classifier(cfg.model)
# Load the checkpoint at ckpt_path into the model, mapping tensors to CPU.
# The returned object is kept because a later cell reads
# checkpoint["meta"]["CLASSES"] from it.
checkpoint = load_checkpoint(model, ckpt_path, map_location="cpu")

load checkpoint from local path: work_dirs/lenet_mnist/iter_6000.pth


In [24]:
# Class names are taken from the metadata stored inside the checkpoint.
CLASSES = checkpoint["meta"]["CLASSES"]
print(CLASSES)

['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']


In [27]:
# Wrap the classifier for inference on GPU 0. The isinstance guard keeps this
# cell idempotent: without it, every re-run would wrap the previous wrapper in
# yet another MMDataParallel.
if not isinstance(model, MMDataParallel):
    model = MMDataParallel(model, device_ids=[0])
model.CLASSES = CLASSES

# Run the model over the test loader. The call's result is unpacked into two
# values: the predictions (fed to dataset.evaluate later) and the features
# (used for the scatter plot).
preds, feats = single_gpu_metric_test(model, data_loader)

[>>>>>>>                  ] 2999/10000, 2607.4 task/s, elapsed: 1s, ETA:     3s

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[>>>>>>>>>>>>>>>>>>>>>>>  ] 9253/10000, 2899.1 task/s, elapsed: 3s, ETA:     0s

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [33]:
# Score the predictions with the dataset's own evaluator.
eval_results = dataset.evaluate(
    results=preds, metric=["accuracy", "f1_score"], metric_options=None
)

# Print every metric rounded to two decimals; array-valued metrics are
# rounded element-wise, anything that is neither array nor number is rejected.
for metric_name, metric_value in eval_results.items():
    if isinstance(metric_value, np.ndarray):
        shown = [round(item, 2) for item in metric_value.tolist()]
    elif isinstance(metric_value, Number):
        shown = round(metric_value, 2)
    else:
        raise ValueError(f"Unsupport metric type: {type(metric_value)}")
    print(f"\n{metric_name} : {shown}")


accuracy_top-1 : 98.9

accuracy_top-5 : 99.87

f1_score : 98.91


In [None]:
# Scatter-plot the first two feature columns of the test set, one colour per
# class label.
num_classes = 10
# Fixed: this was assigned to the misspelled name `labes`, so the `labels`
# lookups in the loop below raised NameError on a fresh kernel.
labels = np.array(dataset.get_gt_labels())
# Work on a copy so `feats` itself is never touched by the plotting code.
features = np.array(deepcopy(feats))

# "C0".."C9" are matplotlib's default colour-cycle aliases.
colors = [f"C{label_idx}" for label_idx in range(num_classes)]

fig, ax = plt.subplots()
for label_idx in range(num_classes):
    class_mask = labels == label_idx
    ax.scatter(
        features[class_mask, 0],
        features[class_mask, 1],
        c=colors[label_idx],
        s=1,
        label=str(label_idx),  # legend entry travels with the artist
    )
ax.set(
    xlabel="feature dim 0",
    ylabel="feature dim 1",
    title="Test-set features by class",
)
ax.legend(loc="upper right")
plt.show()