In [None]:
import shutil

import matplotlib.pyplot as plt
import pandas as pd
import tqdm

from validate_detector import *

# Checkpoint file of the CIL model that this notebook evaluates.
cil_model_path = (
    'weights/'
    'CIL_1000_250_2993-WA-mem50-resnet34-pretrained-drop0.5-augmented-adam.pt'
)

In [None]:
# Pickled metadata for the cropped images.
metadata_file = Path(DATASET_PATH, LOGODET_3K_NORMAL_PATH, METADATA_CROPPED_IMAGE_PATH)
metadata = pd.read_pickle(metadata_file)

# test.txt lists one path per line; only the final path component (the file
# name) is kept, because that is what is later compared with the metadata.
test_list_file = Path(DATASET_PATH, 'LogoDet-3K', 'test.txt')
test_instances = []
with open(test_list_file) as f:
    for line in f:
        test_instances.append(Path(line.strip()).name)

# Load the CIL model together with its index/class lookup tables and switch it
# to evaluation mode.
(cil_model,
 cil_idx2class,
 cil_class2idx,
 cil_class_remap) = load_cil_model(cil_model_path, None)
cil_model.eval()

In [None]:
# Normalise the path column to plain strings so it can be matched against the
# file names collected in `test_instances`.
metadata['cropped_image_path'] = metadata['cropped_image_path'].apply(str)

# Keep only the rows that belong to the test split. The explicit .copy() makes
# df_test an independent frame: a later cell adds a column to it, which would
# otherwise raise pandas' SettingWithCopyWarning on a slice of `metadata`.
df_test = metadata[metadata['cropped_image_path'].isin(test_instances)].copy()
print(len(df_test))
df_test.head()

In [None]:
# Evaluation pipeline: the 'test' transforms first, then the 'common' ones,
# chained into a single Compose.
test_trsf = iLogoDet3K_trsf['test']
common_trsf = iLogoDet3K_trsf['common']
all_trsf = transforms.Compose(list(test_trsf) + list(common_trsf))

In [None]:
import os

from PIL import Image
import pandas as pd
import torch
# Import the callable explicitly: the notebook's top cell only does
# `import tqdm` (the module), which is not callable by itself.
from tqdm import tqdm

cropped_path = 'dataset/LogoDet-3K/cropped'

# Set to True to re-run inference on the whole test split and refresh the
# cached predictions; with False the cached pickle is loaded instead.
COMPUTATION = False
if COMPUTATION:
    def predict_image(df_row):
        """Classify one test row with the CIL model.

        Parameters
        ----------
        df_row : mapping / pandas.Series
            Must provide 'cropped_image_path', 'brand' (ground-truth label)
            and 'img' (the already-transformed image).

        Returns
        -------
        dict
            Keys 'image', 'label', 'prediction', 'label_id', 'prediction_id'.
        """
        cropped_image_path = df_row['cropped_image_path']
        label = df_row['brand']
        im_trsf = df_row['img']
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            # expand(1, ...) adds the leading batch dimension.
            cil_prediction = cil_model(im_trsf.expand(1, *im_trsf.shape))
        cil_class = cil_prediction['logits'].argmax().int().item()
        # Model output index -> remapped index -> class name.
        resolved_label = cil_idx2class[cil_class_remap[cil_class]]
        return {
            'image': cropped_image_path,
            'label': label,
            'prediction': resolved_label,
            'label_id': cil_class2idx[label],
            'prediction_id': cil_class2idx[resolved_label]
        }

    # Load and transform every test image. The context manager guarantees the
    # file handle is released even if the transform raises.
    imgs = []
    for im in tqdm(df_test['cropped_image_path'].values, total=len(df_test)):
        with Image.open(f'{cropped_path}/{im}') as im_read:
            imgs.append(all_trsf(im_read))
    # assign() returns a new frame, so no write into a slice of `metadata`
    # happens (avoids SettingWithCopyWarning).
    df_test = df_test.assign(img=imgs)

    res = [predict_image(row) for _, row in tqdm(df_test.iterrows(), total=len(df_test))]
    print(len(res))

    df_res = pd.DataFrame(res)
    # Make sure the output directory exists so the (expensive) results are not
    # lost to a missing-directory error at save time.
    os.makedirs('./cm', exist_ok=True)
    df_res.to_pickle('./cm/predictions.pkl')
else:
    df_res = pd.read_pickle('./cm/predictions.pkl')

df_res.head()

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib

# Pin the label set explicitly. Without `labels=`, confusion_matrix indexes its
# rows/columns by the sorted labels that actually occur in the data, so a class
# id missing from both columns would shift every following index. Later cells
# translate matrix indices back to names through cil_idx2class, which is only
# valid when matrix index i is class id i.
# NOTE(review): assumes class ids are exactly 0 .. len(cil_idx2class) - 1 — confirm.
class_ids = list(range(len(cil_idx2class)))
cm = confusion_matrix(df_res['label_id'], df_res['prediction_id'], labels=class_ids)

In [None]:
# Save the confusion matrix as an image, with every cell divided by the grand
# total of the matrix.
# (Raw variant, currently disabled: matplotlib.image.imsave('cm/cm_raw.png', cm))
cm_total = cm.sum()
matplotlib.image.imsave('cm/cm_norm.png', cm / cm_total)

In [None]:
import pickle

import numpy as np

# out_diag maps a threshold k to the list of off-diagonal (ground-truth index,
# prediction index) cells whose count in `cm` is at least k.
out_diag = {}
k = 1
COMPUTATION_CM = True

if COMPUTATION_CM:
    # True everywhere except on the main diagonal (the correct predictions).
    off_diag_mask = ~np.eye(*cm.shape, dtype=bool)
    while True:
        cm_k = cm >= k
        # argwhere yields indices in row-major order, i.e. the same order a
        # nested "for gt / for pred" scan would produce; cast to plain ints so
        # the pickled result does not depend on numpy scalar types.
        res = [(int(gt), int(pred)) for gt, pred in np.argwhere(cm_k & off_diag_mask)]

        print(f'{k}: {len(res)}\r', end='')

        out_diag[k] = res
        if not res:
            # No confusion reaches this threshold: the (empty) entry is kept
            # as the terminating record and the sweep stops.
            break
        matplotlib.image.imsave(f'cm/cm_th{k}.png', cm_k)
        k += 1
    with open('cm/out_diag.pickle', 'wb') as handle:
        pickle.dump(out_diag, handle, protocol=pickle.HIGHEST_PROTOCOL)
else:
    # Only load pickles produced by this notebook: unpickling untrusted files
    # can execute arbitrary code.
    with open('cm/out_diag.pickle', 'rb') as handle:
        out_diag = pickle.load(handle)

out_diag

In [None]:
thd = 0

# Translate index pairs to class names, one list per threshold, and drop the
# thresholds that end up with no pair at all.
# NOTE(review): `(x - y) > thd` with thd = 0 keeps only the pairs whose
# ground-truth index is greater than the predicted index; confirm that
# discarding the other pairs is intended.
out_diag_filtered = {}
for threshold, index_pairs in out_diag.items():
    named_pairs = [
        (cil_idx2class[x], cil_idx2class[y])
        for x, y in index_pairs
        if (x - y) > thd
    ]
    if named_pairs != []:
        out_diag_filtered[threshold] = named_pairs

# Plain-text report: a short header, then one section per threshold listing
# the (ground truth, prediction) tuples.
with open('cm/cm_entry_GT_vs_PRED.txt', 'w') as f:
    f.writelines([
        'Vengono qui riportati tutte le entry della matrice di confusione al variare delle soglie.\n',
        '\n',
        'Per ogni soglia, viene riportata una tupla nel formato (Ground truth, Prediction).\n',
        '\n',
        '\n',
    ])
    for threshold, named_pairs in out_diag_filtered.items():
        print(threshold, len(named_pairs))
        f.write(f'> ConfusionMatrix_ij >= {threshold}\n')
        f.writelines(f'\t{(gt_name, pred_name)}\n' for gt_name, pred_name in named_pairs)
        f.write('\n')

In [None]:
import plotly.express as px

# One point per threshold: (threshold, number of entries kept for it).
points = [(threshold, len(entries)) for threshold, entries in out_diag_filtered.items()]
x, y = (list(column) for column in zip(*points))

# Extra closing point with zero entries.
# NOTE(review): the threshold 40 is hard-coded — presumably the first threshold
# with no entries left for the data this was written against; confirm.
x += [40]
y += [0]

print(x)
print(y)

fig = px.line(x=x, y=y, log_y=True)
fig.update_xaxes(title_text='Threshold')
fig.update_yaxes(title_text='Number of entry (log)')
fig.update_layout(
    title_text='# of entry CM_ij in the confusion matrix where CM_ij >= threshold (where i!=j)',
    title_x=0.5,
    # One tick per integer threshold, starting from 0.
    xaxis={'tickmode': 'linear', 'tick0': 0, 'dtick': 1},
)
fig.show()
fig.write_image('cm/cm_entry_plot.png')