In [1]:
from nltk.corpus import wordnet as wn
import nltk
import csv
import time
import random

### 1. Download wordnet and files with class names of ImageNet-(2)1k 

In [2]:
_ = nltk.download('wordnet')

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\mathe\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Read ids of classes in ImageNet-21k and ImageNet-1k.

In [3]:
# Read synset ids and names of classes in ImageNet-21k
synset_ids_21k = []
class_names_21k = []
with open('../classes_imagenet/classes_in_imagenet_21k.csv', newline='') as csvfile:
    csv_reader = csv.reader(csvfile)
    next(csv_reader) # skip first line
    for row in csv_reader:
        synset_ids_21k.append(row[0])  # e.g. 'n00004475'
        class_names_21k.append(row[1].replace(' ', '_')) # e.g. 'organism'

In [4]:
# Read synset ids and names of classes in ImageNet-1k
synset_ids_1k = []
class_names_1k = []
with open('../classes_imagenet/classes_in_imagenet_1k.csv', newline='') as csvfile:
    csv_reader = csv.reader(csvfile)
    i = 0
    for row in csv_reader:
        class_id = row[0].split(':')[0]
        synset_ids_1k.append(class_id)
        class_name = row[0].split(':')[1].split(',')[0].replace(' ', '_')[1:]
        class_names_1k.append(class_name)

### 2. Generate synset_ids for out-of-distribution classes 

In [5]:
def is_far_from_1k(category, class_names_1k, thresh=0.1):
    #print('Checking if {} is far from ImageNet-1k'.format(category))
    obj1 = wn.synsets(category)[0]
    for some_class in class_names_1k:
        if some_class:
            obj2 = wn.synsets(some_class)[0]
            #if obj1.path_similarity(obj2) > thresh: 
            if obj1.lch_similarity(obj2) > thresh:
                #print('Nope')
                return False
    #print('Yes!')
    return True

Generate random synset_ids from synset_ids_21k which are not in synset_ids_1k and with minimum path distance to all 1k classes in ImageNet-1k.

In [6]:
nr_ood_classes = len(synset_ids_21k) - len(synset_ids_1k) # 200 # int(sys.argv[1])
ood_synset_ids_str = ''
nr_ood = 0

i = 0
start = time.time()
while i < nr_ood_classes:
    # random_idx = random.randint(0, max_id)
    # random_synset_id = synset_ids_21k[random_idx]
    # random_synset_name = class_names_21k[random_idx]
    ood_synset_id = synset_ids_21k[i]
    ood_class_name = class_names_21k[i]
    is_far_from_color = wn.synsets(ood_class_name)[0].lch_similarity(wn.synsets('color')[0]) < 1.5
    is_not_in_1k = ood_synset_id not in synset_ids_1k
    if is_not_in_1k and is_far_from_color and is_far_from_1k(ood_class_name, class_names_1k, thresh=1.35):
        ood_synset_ids_str += ood_synset_id + ' '
        nr_ood += 1
    i += 1
    if i % 2000 == 0:
        print(i)

end = time.time()
print('Done generating {} OOD classes after {:.2f} seconds!'.format(nr_ood, end-start))

2000
4000
6000
8000
10000
12000
14000
16000
18000
20000
Done generating 110 OOD classes after 235.51 seconds!


In [56]:
nr_ood

58

In [45]:
# Example of how path similarity works.
# 1
# obj1 = wn.synsets('color')[0]
# obj2 = wn.synsets('french_polish')[0]
# print('The similarity between {} and {} is: {}'.format('color', 'greenish_blue', obj1.lch_similarity(obj2)))
wn.synsets(ood_class_name)[0].lch_similarity(wn.synsets('color')[0]) < 1.5

True

In [38]:
# Example of how path similarity works.
# 1
obj1 = wn.synsets('ambulance')[0]
obj2 = wn.synsets('dog')[0]
print('The similarity between {} and {} is: {}'.format('ambulance', 'dog', obj1.lch_similarity(obj2)))
# 2
obj1 = wn.synsets('cat')[0]
obj2 = wn.synsets('kitty')[0]
print('The similarity between {} and {} is: {}'.format('cat', 'kitty', obj1.lch_similarity(obj2)))
# 3
obj1 = wn.synsets('cat')[0]
obj2 = wn.synsets('human')[0]
print('The similarity between {} and {} is: {}'.format('cat', 'human', obj1.lch_similarity(obj2)))
# 4
obj1 = wn.synsets('cat_food')[0]
obj2 = wn.synsets('dog_food')[0]
print('The similarity between {} and {} is: {}'.format('cat_food', 'dog_food', obj1.lch_similarity(obj2)))
# 5
obj1 = wn.synsets('white_rice')[0]
obj2 = wn.synsets('rice')[0]
print('The similarity between {} and {} is: {}'.format('white_rice', 'rice', obj1.lch_similarity(obj2)))

The similarity between ambulance and dog is: 0.9985288301111273
The similarity between cat and kitty is: 0.5020919437972361
The similarity between cat and human is: 1.6916760106710724
The similarity between cat_food and dog_food is: 2.538973871058276
The similarity between white_rice and rice is: 2.9444389791664407


In [8]:
print(f'This {nr_ood} classes below have lch similarity of max 1.35:')
ood_synset_ids_str

This 110 classes below have lch similarity of max 1.35:


'n00443803 n00447540 n00454493 n01314388 n01334217 n01338685 n01339336 n01461315 n01481331 n01557185 n01682435 n01803893 n01888411 n01910252 n01912152 n01956344 n01997119 n02103181 n02112497 n02122298 n02456962 n02577041 n02662239 n02663849 n02741367 n02761206 n02801450 n02829596 n02878222 n02917521 n02926591 n02982599 n03011741 n03042829 n03042984 n03145522 n03297735 n03334912 n03349296 n03392648 n03410022 n03410147 n03410303 n03524976 n03573848 n03582508 n03644073 n03683457 n03684611 n03695452 n03733644 n03856335 n03901750 n03902482 n03963294 n04023249 n04072960 n04082562 n04088797 n04100519 n04170933 n04269822 n04314914 n04327204 n04363991 n04368109 n04394031 n04476831 n04476972 n04486934 n04526520 n04562496 n05453412 n05453815 n05578095 n06209940 n07596362 n07596566 n07616590 n07643981 n07728053 n07728181 n07728284 n07728585 n07729142 n07729828 n07757753 n07905296 n07914128 n07914887 n07930554 n08511017 n08663703 n09495962 n09779280 n10027246 n10097477 n10104888 n10107303 n10117851

In [22]:
from collections import Counter
test = ood_synset_ids_str.split(' ')


Counter(test).keys() # equals to list(set(words))
Counter(test).values()

dict_values([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])

In [5]:
import numpy as np
import torch

In [34]:
a=torch.tensor([[1,2,3,4], [5,3,7,3]])
print(a)
maxe = torch.where(a == torch.max(a, dim=0).values, 1, 0)
maxe

tensor([[1, 2, 3, 4],
        [5, 3, 7, 3]])


tensor([[0, 0, 0, 1],
        [1, 1, 1, 0]])