In [1]:
import os
import torch
from torch.utils.data import DataLoader

from PR.utils import load_yaml
from PR.config.class_group import categories
from PR.NNCLR.model import NNCLRHead
from PR.metrics.metrics_utils import group_classes

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\lopez\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define Paths and Load Data

In [3]:
data_path = r"C:\Users\lopez\Projects\jku-pr\features_mae_large"
class_to_classid_path = os.path.join(data_path, "in1k_class_to_classid.yaml")
train_path = os.path.join(data_path, "mae_l23_cls_train_centercrop.th")
train_labels_path = os.path.join(data_path, "train-labels.th")
test_path = os.path.join(data_path, "mae_l23_cls_test_centercrop.th")
test_labels_path = os.path.join(data_path, "test-labels.th")

class_to_classid = load_yaml(class_to_classid_path)
train_data = torch.load(train_path, map_location=device)
train_labels = torch.load(train_labels_path, map_location=device)
test_data = torch.load(test_path, map_location=device)
test_labels = torch.load(test_labels_path, map_location=device)

train_dataloader = DataLoader(train_data, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=False)

# Load Model

In [4]:
model_state_dict = r"C:\Users\lopez\Projects\jku-pr\models\tensors_v1_seed_0_tensors_v2_seed_0_0_1-0_baseline.pth"
state_dict = torch.load(model_state_dict, map_location=device)

model = NNCLRHead()
model.load_state_dict(state_dict)

<All keys matched successfully>

# Group Data

In [5]:
category_names, category_labels = group_classes(test_labels, categories, class_to_classid)

In [1]:
import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet as wn

def check_synset_exists(synset_name):
    try:
        # attempt to get synset using provided name
        syn = wn.synset(synset_name)
        return True, syn.definition() # TODO: do we need the 2nd output?
    except nltk.corpus.reader.wordnet.WordNetError:
        return False, None # TODO: do we need the 2nd output
    
main_5_categories = {
    'animals': 'animal.n.01',   
    'vehicles': 'vehicle.n.01', 
    'household': 'artifact.n.01',
    'food': 'food.n.01',
    'nature': 'natural_object.n.01'
}

dog_15_categories = {
    'maltese dog': 'maltese_dog.n.01',
    'blenheim spaniel': 'blenheim_spaniel.n.01',
    'basset': 'basset.n.01',
    'norwegian elkhound': 'norwegian_elkhound.n.01',
    'giant schnauzer': 'giant_schnauzer.n.01',
    'golden retriever': 'golden_retriever.n.01',
    'brittany spaniel': 'brittany_spaniel.n.01',
    'clumber': 'clumber.n.01',
    'welsh springer spaniel': 'welsh_springer_spaniel.n.01',
    'groenendael': 'groenendael.n.01',
    'kelpie': 'kelpie.n.01',
    'shetland sheepdog': 'shetland_sheepdog.n.01',
    'doberman': 'doberman.n.01',
    'pug': 'pug.n.01',
    'chow': 'chow.n.01',
}

wen_10_categories = {
    'bird': 'bird.n.01',
    'boat': 'boat.n.01',
    'car': 'car.n.01',
    'cat': 'cat.n.01',
    'dog': 'dog.n.01',
    'fruit': 'fruit.n.01',
    'fungus': 'fungus.n.01',
    'insect': 'insect.n.01',
    'monkey': 'monkey.n.01',
    'truck': 'truck.n.01',
}

[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\lopez\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [5]:
for synset_name in wen_10_categories.values():
    print(check_synset_exists(synset_name)[0])

True
True
True
True
True
True
True
True
True
True
