In [1]:
!pip install keras_applications
!pip install transformers

Collecting keras_applications
  Downloading Keras_Applications-1.0.8-py3-none-any.whl (50 kB)
     |████████████████████████████████| 50 kB 729 kB/s             
Installing collected packages: keras-applications
Successfully installed keras-applications-1.0.8


In [2]:
import numpy as np 
import pandas as pd 
import os
import re
from sklearn.utils import shuffle
from tensorflow.keras.preprocessing import image
import time
from nltk.corpus import wordnet as wn
from PIL import Image
import requests
import torch
from transformers import BeitFeatureExtractor, BeitForImageClassification
from transformers import AutoFeatureExtractor, ViTForImageClassification

import cv2
import matplotlib.pyplot as plt

In [3]:
in_path = '../input/dataset-crawler-hyponyms/'
data_dict = {}
for path, dirs, files in os.walk(in_path):
    for i in dirs:
        for path2, dirs2, files2 in os.walk('../input/dataset-crawler-hyponyms/'+i):
            urls = [in_path+i+'/'+j for j in files2]
            data_dict[str(i)] = urls
print(len(data_dict))

10


In [4]:
data_dict.keys()

dict_keys(['tabby cat', 'angora cat', 'lynx cat', 'siamese cat', 'tiger cat', 'persian cat', 'cougar cat', 'leopard cat', 'egyptian cat', 'cat'])

# Classifiers

In [5]:
def evaluate(pred_class, actual_class, acc, TP, FN, sp_list, lch_list, wups_list, siblings, common_misconceptions):
    # find wordnet synsets for predicted class 
    wn_preds = wn.synsets(pred_class, pos=wn.NOUN)
    wn_pred=wn_preds[0]
    
    # find synset of ground truth class
    try:            
        actual_class=actual_class.replace(' ', '_')
        wn_label = wn.synsets(actual_class, pos=wn.NOUN)[0]
    except:   # we have added the word 'cat' for disambiguation (e.g. tiger cat vs tiger)
        actual_class = actual_class.replace('_cat', '')
        wn_label = wn.synsets(actual_class, pos=wn.NOUN)[0]
    #print('Actual:', wn_label, actual_class)
    #print('Pred:', wn_preds, pred_class)
    
    # Find hypernyms of ground truth class and predicted class 
    try:
        label_hyper_actual = wn_label.hypernyms()[0]
    except: # if there are no hypernyms, get entity as the global hypernym
        label_hyper_actual = wn.synsets('entity')[0]
    try:
        label_hyper_pred = wn_pred.hypernyms()[0]
    except:
        label_hyper_pred = wn.synsets('entity')[0]
    #print('Hypernyms:', label_hyper_pred, label_hyper_actual)
    
    # check if predicted class is a sub-class of actual label 
    # first, find all hyponyms (in any taxonomy depth) of the ground truth class
    label_hypo = list(set([w for s in wn_label.closure(lambda s:s.hyponyms()) for w in s.lemma_names()]))
    
    # then, find all hyponym synsets
    synset_hypo = [wn.synsets(word, pos=wn.NOUN)[0] for word in label_hypo]
    
    # if predicted is subset, then prediction is correct
    hypo_intersection=set(wn_preds).intersection(set(synset_hypo))
    same=False
    
    # if there is an overlap in actual/pred synsets or label names
    if hypo_intersection or (wn_label in wn_preds) or (actual_class.lower()==pred_class.lower()):
        #print('Same')
        acc+=1
        TP+=1
        same=True
        
    # if actual/pred synsets have a common immediate parent node (or entity as the common parent)
    elif (label_hyper_pred==label_hyper_actual) and not(label_hyper_pred==wn.synset('entity.n.01') and label_hyper_actual==wn.synset('entity.n.01')):
        #print('Siblings. Hypernyms:', label_hyper_pred)
        FN+=1
        siblings+=1
        common_misconceptions.append((actual_class, pred_class))
        
    # otherwise, ground truth and pred classes have more than 1-level difference
    else:
        #print('Not immediately related')
        FN+=1
        common_misconceptions.append((actual_class, pred_class))
     
    if not same:   # find WordNet-based metrics for different ground truth/pred synsets
        min_hypernym = wn_pred.lowest_common_hypernyms(wn_label)
        #print('Lowest common hypernym: ', min_hypernym)
    
        shortest_path = wn_pred.path_similarity(wn_label)
        sp_list.append(shortest_path)
        #print('Shortest path similarity between {} and {}: {}'.format(pred_class, actual_class, shortest_path))
    
        # hortest path and the maximum depth of the taxonomy
        lch = wn_pred.lch_similarity(wn_label)
        lch_list.append(lch)
        #print('Leacock-Chodorow similarity between {} and {}: {}'.format(pred_class, actual_class, lch))
    
        # depth of the two synsets in the taxonomy and that of their Least Common Subsumer (most specific ancestor node)
        wups = wn_pred.wup_similarity(wn_label)
        wups_list.append(wups)
        #print('Wu-Palmer similarity between {} and {}: {}'.format(pred_class, actual_class, wups))
    
    return acc, TP, FN, sp_list, lch_list, wups_list, siblings, common_misconceptions

In [6]:
def predictions(image, model, feature_extractor, all_time, preds):
    start = time.time()    
    inputs = feature_extractor(images=image, return_tensors="pt")
    outputs = model(**inputs)
    logits = outputs.logits
    elapsed_time = time.time()-start
    probs = logits.softmax(dim=1)
    idx = torch.topk(probs, 3, largest=True, sorted=True, out=None)
    top3pos = idx[1][0]
    pred_classes = {}
    top_counter = 1
    for item in top3pos:
        predicted_class_idx = item.item()
        pred_classes[top_counter] = model.config.id2label[predicted_class_idx]
        top_counter+=1
    
    # model predicts one of the 1000 ImageNet classes
    preds.append(pred_classes)
    print(pred_classes)
    all_time.append(elapsed_time)
    
    return all_time, preds, pred_classes

In [7]:
feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
model_name = 'BeiT'

Downloading:   0%|          | 0.00/276 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.59M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/395M [00:00<?, ?B/s]

In [8]:
results={}
all_misconceptions = {}
for data_name, data in data_dict.items():
    all_time = []
    preds = []
    acc=0
    TP=0
    FN=0
    siblings=0
    common_misconceptions = []
    sp_list, lch_list, wups_list = [], [], []

    for url in data:
        image = cv2.imread(url)
        #try:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        all_time, preds, top = predictions(image, model, feature_extractor, all_time, preds)
        top_1_class=list(top.values())[0]
        if ',' in top_1_class:
            splt = top_1_class.split(',')
            top_1_class = splt[0]
            top_1_class=top_1_class.rstrip().lstrip()
        if ' ' in top_1_class:
            top_1_class=top_1_class.replace(' ', '_')
        acc, TP, FN, sp_list, lch_list, wups_list, siblings, common_misconceptions = evaluate(top_1_class, data_name, acc, 
                                                             TP, FN, sp_list, lch_list, wups_list, siblings, common_misconceptions)

        #except Exception as e:
        #    print(str(e))
        mean_acc = acc/(TP+FN)
        #print(mean_acc, TP, FN, np.mean(sp_list), np.mean(lch_list), np.mean(wups_list), np.sum(all_time), siblings)
        run_id = data_name+'_'+model_name
        results[run_id] = (data_name, model_name, mean_acc, TP, FN, np.mean(sp_list), 
                               np.mean(lch_list), np.mean(wups_list), np.sum(all_time), siblings)
        all_misconceptions[run_id] = common_misconceptions

{1: 'wildcat', 2: 'European_wildcat, catamountain, Felis_silvestris', 3: 'stray'}
{1: 'pet', 2: 'tabby, tabby_cat', 3: 'domestic_animal, domesticated_animal'}
{1: 'tabby, tabby_cat', 2: 'feline, felid', 3: 'Manx, Manx_cat'}
{1: 'domestic_cat, house_cat, Felis_domesticus, Felis_catus', 2: 'feline, felid', 3: 'cat, true_cat'}
{1: 'tabby, tabby_cat', 2: 'feline, felid', 3: 'tiger_cat'}
{1: 'kitten, kitty', 2: 'margay, margay_cat, Felis_wiedi', 3: 'tiger_cat'}
{1: 'caterer', 2: 'bobcat, bay_lynx, Lynx_rufus', 3: 'tiger_cat'}
{1: 'tabby, tabby_cat', 2: 'tabby, queen', 3: 'tiger_cat'}
{1: 'cat, true_cat', 2: 'tom, tomcat', 3: 'Manx, Manx_cat'}
{1: 'alpine_gold, alpine_hulsea, Hulsea_algida', 2: 'young_mammal', 3: 'feline, felid'}
{1: 'Egyptian_cat', 2: 'tabby, tabby_cat', 3: 'kitty, kitty-cat, puss, pussy, pussycat'}
{1: 'tabby, tabby_cat', 2: 'pet', 3: 'animal, animate_being, beast, brute, creature, fauna'}
{1: 'tabby, tabby_cat', 2: 'tom, tomcat', 3: 'cat, true_cat'}
{1: 'tabby, tabby_cat'

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


{1: 'feline, felid', 2: 'lynx, catamount', 3: 'kitty, kitty-cat, puss, pussy, pussycat'}
{1: 'blue_point_Siamese', 2: 'feline, felid', 3: 'lynx, catamount'}
{1: 'Siamese_cat, Siamese', 2: 'blue_point_Siamese', 3: 'feline, felid'}
{1: 'Siamese_cat, Siamese', 2: 'blue_point_Siamese', 3: 'Egyptian_cat'}
{1: 'Siamese_cat, Siamese', 2: 'blue_point_Siamese', 3: 'cat, true_cat'}
{1: 'Siamese_cat, Siamese', 2: 'blue_point_Siamese', 3: 'kitten, kitty'}
{1: 'Siamese_cat, Siamese', 2: 'domestic_cat, house_cat, Felis_domesticus, Felis_catus', 3: 'cat, true_cat'}
{1: 'Siamese_cat, Siamese', 2: 'blue_point_Siamese', 3: 'tom, tomcat'}
{1: 'blue_point_Siamese', 2: 'Siamese_cat, Siamese', 3: 'Burmese_cat'}
{1: 'blue_point_Siamese', 2: 'Siamese_cat, Siamese', 3: 'tom, tomcat'}
{1: 'kitten, kitty', 2: 'feline, felid', 3: 'young, offspring'}
{1: 'Siamese_cat, Siamese', 2: 'blue_point_Siamese', 3: 'wildcat'}
{1: 'muffler', 2: 'sand_cat', 3: 'male'}
{1: 'Siamese_cat, Siamese', 2: 'blue_point_Siamese', 3: 'i

In [9]:
all_df = pd.DataFrame(results.values(), columns=['dataset', 'model', 'accuracy', 'TP', 'FN', 'path_similarity', 'LCH', 'WUPS', 'predict_time', 'siblings'])
all_df

Unnamed: 0,dataset,model,accuracy,TP,FN,path_similarity,LCH,WUPS,predict_time,siblings
0,tabby cat,BeiT,0.34,17,33,0.242394,2.039635,0.765372,23.800472,1
1,angora cat,BeiT,0.64,32,18,0.302075,2.216198,0.717113,23.023868,9
2,lynx cat,BeiT,0.34,17,33,0.03768,0.349971,0.07766,22.89855,0
3,siamese cat,BeiT,0.88,44,6,0.15873,1.698549,0.672535,24.327689,0
4,tiger cat,BeiT,0.083333,5,55,0.205819,1.857868,0.751651,29.662709,15
5,persian cat,BeiT,0.8,40,10,0.227763,1.812894,0.611691,24.458804,2
6,cougar cat,BeiT,0.9,45,5,0.193333,1.897683,0.79569,24.012208,0
7,leopard cat,BeiT,0.06,3,47,0.308017,2.426305,0.911461,24.299089,37
8,egyptian cat,BeiT,0.26,13,37,0.172819,1.681919,0.613177,23.44292,10
9,cat,BeiT,0.62963,17,10,0.336869,2.253426,0.782431,13.361513,0


In [10]:
all_df.to_pickle('./cats_imagenet_BeiT.pkl')

In [11]:
import pickle 

with open('./BeiT_all_common_misconceptions.pkl', 'wb') as f:
    pickle.dump(all_misconceptions, f)
        
with open('./BeiT_all_common_misconceptions.pkl', 'rb') as f:
    loaded_dict = pickle.load(f)