In [4]:
import json
import pickle
import torch

import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import f1_score,precision_score,recall_score
import pandas as pd
from tqdm import tqdm
import statistics as stats

In [2]:
def read_pkl(path):
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data

def read_json(path):
    with open(path) as f:
        data = json.load(f)
    return data

img_path = 'data/image_features_all.pkl'
label_path = 'data/label_features_all.pkl'
label_ppt_path = 'data/label_features_prompt_all.pkl'
ann_path = '../data/validation.json'
label_name_path = '../data/iMat_fashion_2018_label_map_228.csv'
one_hot_path = '../data/target_labels.json'

image_fts = read_pkl(img_path)
label_fts_ = read_pkl(label_path)
label_ppt_fts = read_pkl(label_ppt_path)
label_map = pd.read_csv(label_name_path)
label_map = dict(zip(label_map.labelId, label_map.labelName))
one_hot = read_json(one_hot_path)

anns = read_json(ann_path)
anns = anns['annotations']

def get_pred(idx,prompt=False):
    img_ft = image_fts[idx]
    img_ft = torch.from_numpy(img_ft)
    img_ft /= img_ft.norm(dim=-1, keepdim=True)

    if prompt:
        curr_label_fts = label_ppt_fts
    else:
        curr_label_fts = label_fts_

    label_fts = torch.from_numpy(curr_label_fts)
    label_fts /= label_fts.norm(dim=-1, keepdim=True)
    pred = img_ft @ label_fts.T
    return pred

def get_label_names(labels):
    label_name = {}
    for label in labels:
        label_name[label] = label_map[int(label)]
    return label_name

In [26]:
def probe_img(i=10, prompt=False, verbose=False):
    ann = anns[i]

    img_id = ann['imageId']
    if verbose:
        print(f'Prediction for image id : {img_id}')
    target = ann['labelId']
    tr_map = get_label_names(target)
    if verbose:
        print(f'targets : {tr_map}')

    pred = get_pred(int(img_id),prompt)
    a,ids = torch.topk(pred, 8)
    ids = [x+1 for x in list(ids.detach().numpy()[0])]
    pr_map = get_label_names(ids)
    if verbose:
        print(f'prediction : {pr_map}')
    return target, ids

def list_to_oneHot(label_list,num_tar=228,indexed=False):
    if not indexed:
        label_list = [int(l)-1 for l in label_list]
    one_hot = [0]*228
    for l in label_list:
        one_hot[l] = 1
    return one_hot

# def f1_score_old(pred_oh, tar_oh):
#     p = precision_score(tar_oh, pred_oh, average='micro')
#     r = recall_score(tar_oh, pred_oh, average='micro')
#     f = f1_score(tar_oh, pred_oh, average='micro')
#     print(f'prec: {p}, re:{r}, f1 :{f}')

def f1_score_(pred, tar):
    tar = [int(x) for x in tar]
    def prec(pred,tar):
        dem = len(pred)
        num = 0
        for i in pred:
            if i in tar:
                num += 1
        return num/dem

    def rec(pred,tar):
        dem = len(tar)
        num = 0
        for i in pred:
            if i in tar:
                num += 1
        return num/dem

    def f1(r,p):
        # return 2*p*r/(p+r)
        return (2*p*r)/(p+r+np.finfo(float).eps)

    p = prec(pred,tar)
    r = rec(pred,tar)
    f = round(f1(r,p),2)
    return p,r,f


In [29]:
precs = []
recs = []
f1s = []
for id in tqdm(range(len(anns))):

    try:
        tar,pred = probe_img(int(id))
        p,r,f = f1_score_(pred,tar)
        precs.append(p)
        recs.append(r)
        f1s.append(f)
    except Exception as e:
        print(f'{id}:{e}')


def get_mean(l):
    return np.mean(np.array(l))

# print(f'{get_mean(precs)}, {get_mean(recs)},{get_mean(f1s)}')

100%|██████████| 9897/9897 [00:01<00:00, 5848.95it/s]


In [30]:
print(f'{get_mean(precs)}, {get_mean(recs)},{get_mean(f1s)}')

0.1800040416287764, 0.18944891873715883,0.1808103465696676


In [22]:
f1s
# stats.mean(f1s)

[0.4,
 0.24,
 0.5,
 0.25,
 0.12,
 0.25,
 0.22,
 0.33,
 0.13,
 0.14,
 0.33,
 0.11,
 0.12,
 0.4,
 0.15,
 0.22,
 0.27,
 0.38,
 0.25,
 0.24,
 0.38,
 0.4,
 0.11,
 0.22,
 0.1,
 0.25,
 0.14,
 0.1,
 0.13,
 0.27,
 0.11,
 0.25,
 0.15,
 0.13,
 0.24,
 1.0,
 0.13,
 0.29,
 0.33,
 0.15,
 0.25,
 0.27,
 0.14,
 1.0,
 1.0,
 0.14,
 0.11,
 0.17,
 0.13,
 0.12,
 1.0,
 0.29,
 0.14,
 0.25,
 0.27,
 0.11,
 0.1,
 0.22,
 0.13,
 0.15,
 0.13,
 0.11,
 0.24,
 0.21,
 0.11,
 0.13,
 0.14,
 0.13,
 0.15,
 0.14,
 0.11,
 0.13,
 0.25,
 0.12,
 0.29,
 0.15,
 0.27,
 0.25,
 1.0,
 0.33,
 0.11,
 0.25,
 0.12,
 0.4,
 0.11,
 1.0,
 0.12,
 0.21,
 0.13,
 0.29,
 0.22,
 0.53,
 0.11,
 0.14,
 0.25,
 1.0,
 0.13,
 1.0,
 0.14,
 0.29,
 0.3,
 0.31,
 0.15,
 1.0,
 0.11,
 0.13,
 0.12,
 0.12,
 1.0,
 0.12,
 0.13,
 0.27,
 0.17,
 0.13,
 0.12,
 0.4,
 0.12,
 0.13,
 0.38,
 0.22,
 0.29,
 0.13,
 0.24,
 0.43,
 0.22,
 0.25,
 0.4,
 0.13,
 0.12,
 0.29,
 0.33,
 0.33,
 0.11,
 0.29,
 0.29,
 0.31,
 1.0,
 0.25,
 0.15,
 0.17,
 0.13,
 0.12,
 0.12,
 0.25,
 0.4,
 0.22,
 