In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import pandas as pd
from PIL import Image
import numpy as np
from collections import Counter

import kuzushiji.viz
from kuzushiji.data_utils import get_encoded_classes, load_train_df

In [None]:
# Training annotations. Each `labels` string is split below into groups of 5 tokens:
# a class followed by 4 box coordinates.
df_train = load_train_df()

In [None]:
# Per-crop predictions of the resnet152 classifier run; later cells use the
# image_id, true, pred, top_k_classes and top_k_logits columns.
df = pd.read_csv('_runs/clf-resnet152/errors.csv.gz')
df.head(5)

In [None]:
# Overall accuracy, seg_fp included as a regular class.
df['pred'].eq(df['true']).mean()

In [None]:
# Share of rows whose ground truth is seg_fp.
df['true'].eq('seg_fp').mean()

In [None]:
# Share of rows predicted as seg_fp.
df['pred'].eq('seg_fp').mean()

In [None]:
# Recall of seg_fp: among true seg_fp rows, the fraction also predicted as seg_fp.
df.loc[df['true'] == 'seg_fp', 'pred'].eq('seg_fp').mean()

In [None]:
# Precision of seg_fp: among rows predicted as seg_fp, the fraction truly seg_fp.
df.loc[df['pred'] == 'seg_fp', 'true'].eq('seg_fp').mean()

In [None]:
# Accuracy restricted to rows whose true label is not seg_fp.
df_true_no_seg_fp = df.loc[df['true'].ne('seg_fp')]
df_true_no_seg_fp['true'].eq(df_true_no_seg_fp['pred']).mean()

In [None]:
# Break misclassified rows down by error kind, each as a share of all rows.
err_items = df[df['pred'] != df['true']]
# char <-> char confusions (neither side is seg_fp)
err_items_chars = err_items[(err_items['true'] != 'seg_fp') & (err_items['pred'] != 'seg_fp')]
# true seg_fp predicted as some character
err_items_seg_fp_fn = err_items[err_items['true'] == 'seg_fp']
# true character predicted as seg_fp
err_items_seg_fp_fp = err_items[err_items['pred'] == 'seg_fp']
print('Error kinds:')
print(f'{len(err_items_chars) / len(df):.2%} true char != pred char')
# The two seg_fp subsets are disjoint (pred != true), so their row counts add up.
# Adding the frames themselves (`a + b`) does index-aligned element-wise
# arithmetic on every column rather than combining rows.
print(f'{(len(err_items_seg_fp_fn) + len(err_items_seg_fp_fp)) / len(df):.2%} with segfp')
print(f'{len(err_items_seg_fp_fn) / len(df):.2%} true segfp != pred char')
print(f'{len(err_items_seg_fp_fp) / len(df):.2%} true char != pred segfp')

In [None]:
# Number of misclassified rows per image, worst images first.
image_error_count = (
    df.loc[df['pred'].ne(df['true'])]
    .groupby('image_id')['pred']
    .count()
    .sort_values(ascending=False)
)
image_error_count

In [None]:
# Train-set frequency of each class. Every 5th token of a label string
# (offsets 0, 5, 10, ...) is a class.
freqs = Counter(
    token
    for label in df_train['labels']
    for token in label.split()[::5]
)

In [None]:
# Invert the class -> id encoding so that numeric ids (as stored in top_k_classes)
# can be mapped back to class names.
cls_by_id = {class_id: class_name for class_name, class_id in get_encoded_classes().items()}

In [None]:
def analyze_errors(df, only_errors=False, n=50, class_by_id=None, class_freqs=None):
    """Summarize a random sample of classifier predictions, one row per input row.

    Args:
        df: predictions frame with ``true`` and ``pred`` class names,
            ``top_k_classes`` (space-separated integer class ids) and
            ``top_k_logits`` (space-separated scores aligned with the ids).
        only_errors: if True, keep only rows where ``pred != true`` before sampling.
        n: maximum number of rows to sample (``random_state=42``); if fewer
            rows are available, all of them are used.
        class_by_id: mapping from integer class id to class name.
            Defaults to the notebook-level ``cls_by_id``.
        class_freqs: mapping from class name to its train-set count.
            Defaults to the notebook-level ``freqs``.

    Returns:
        DataFrame with columns true, pred, true_score, pred_score,
        second_score, true_place, true_freq, pred_freq.
        ``true_place`` is the 1-based rank of the true class among the top-k
        scores, or ``k + 1`` when the true class is not in the top-k.
        ``true_score`` is 0 when the true class is not in the top-k.
        ``second_score`` is the second-highest top-k score (NaN when k == 1).
    """
    if class_by_id is None:
        class_by_id = cls_by_id
    if class_freqs is None:
        class_freqs = freqs
    if only_errors:
        df = df[df['true'] != df['pred']]
    # sample() without replacement raises if asked for more rows than exist.
    df = df.sample(n=min(n, len(df)), random_state=42)
    data = []
    for item in df.itertuples():
        top_logits_list = [float(score) for score in item.top_k_logits.split()]
        top_classes = [class_by_id[int(class_id)] for class_id in item.top_k_classes.split()]
        top_logits = dict(zip(top_classes, top_logits_list))
        # Highest score first; ties keep their original top-k order.
        ranked = sorted(top_logits.items(), key=lambda kv: -kv[1])
        places = {cls: place for place, (cls, _) in enumerate(ranked, start=1)}
        data.append({
            'true': item.true,
            'pred': item.pred,
            'true_score': top_logits.get(item.true, 0),
            'pred_score': top_logits[item.pred],
            # Taken from the ranking so it does not rely on the stored order.
            'second_score': ranked[1][1] if len(ranked) > 1 else float('nan'),
            # Outside the top-k ranks strictly after every listed class.
            'true_place': places.get(item.true, len(places) + 1),
            'true_freq': class_freqs.get(item.true, 0),
            'pred_freq': class_freqs.get(item.pred, 0),
        })
    # TODO:
    # - true/pred example crops
    return pd.DataFrame(data)

# Random sample of misclassified rows across all images.
analyze_errors(df, only_errors=True, n=50)

In [None]:
# Rows predicted as seg_fp whose true label is something else.
analyze_errors(df.loc[df['pred'].eq('seg_fp')], only_errors=True)

In [None]:
# All rows predicted as seg_fp, whether correct or not.
analyze_errors(df.loc[df['pred'].eq('seg_fp')], only_errors=False)

In [None]:
def viz_errors(image_id, with_true_boxes=True):
    """Plot classifier errors for one page, optionally overlaying ground-truth boxes.

    The page image and title come from ``kuzushiji.viz.visualize_clf_errors``,
    called with the notebook-level predictions ``df``. When ``with_true_boxes``
    is set, boxes parsed from that page's ``df_train`` labels are drawn on top
    in black. The title is both printed and used as the figure title.
    """
    page, caption = kuzushiji.viz.visualize_clf_errors(image_id, df)
    if with_true_boxes:
        page_labels = df_train.loc[df_train['image_id'] == image_id, 'labels'].iloc[0]
        # Labels come in groups of 5 tokens: a class, then 4 box coordinates.
        label_grid = np.array(page_labels.split()).reshape(-1, 5)
        gt_boxes = label_grid[:, 1:].astype(int)
        page = kuzushiji.viz.visualize_boxes(page, gt_boxes, thickness=2, color=(0, 0, 0))
    print(caption)
    _, ax = plt.subplots(figsize=(20, 20))
    ax.set_title(caption)
    ax.imshow(page)
    
# One example page, with ground-truth boxes drawn in black.
viz_errors(image_id='200003076_00149_1', with_true_boxes=True)

In [None]:
# The 10 pages with the most misclassified rows.
for image_id in image_error_count.head(10).index:
    viz_errors(image_id)

In [None]:
# 10 distinct pages drawn reproducibly (seed 42) from the sorted set of image ids.
rng = np.random.RandomState(42)
for image_id in rng.choice(sorted(df['image_id'].unique()), size=10, replace=False):
    viz_errors(image_id)