In [1]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
import torchvision.transforms as tt
from time import time
import matplotlib.pyplot as plt

import data_preprocessing as dp
import utils
from models.ResNet import ResNet
import train_zsl

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.cuda.get_device_name(0))
torch.manual_seed(0)

Google word embeddings loaded
cuda:0
Quadro GV100


<torch._C.Generator at 0x7fbcac1825f0>

In [2]:
from sklearn.neighbors import KDTree
import pandas as pd

In [3]:
data_dir = '../TINY_IMGNET/zsl_dataset'

In [4]:
train_ds, valid_ds, zsl_ds = dp.create_datasets()

Train set images: 75000
Validation set images: 7500
ZSL set images: 25000


In [5]:
label_vecs, target_labels, zsl_label_vecs, zsl_target_labels, train_target_vectors_norm = dp.preprocess_labels(train_ds, zsl_ds)

Categories split into seen and unseen
Labels transformed into average labels
Label vectors preprocessed
Target vectors normalized


In [6]:
batch_size = 512*4

In [7]:
net = ResNet(3, 150)
net = nn.DataParallel(net, device_ids=['cuda:0', 'cuda:1', 'cuda:2', 'cuda:3'])
net = net.to(device)

In [8]:
checkpoint = torch.load('./model_weights/CE_2_15_ces_0_55.pth')
net.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [9]:
zsl_target_labels_list = list(zsl_target_labels.values())
zsl_class_vecs_list = list(zsl_label_vecs.values())
zsl_target_vectors = torch.cat(zsl_class_vecs_list)

In [10]:
zsl_loader = DataLoader(zsl_ds, batch_size, shuffle=True, num_workers=16, pin_memory=True)

In [11]:
zsl_ids = {v: k for k, v in zsl_ds.class_to_idx.items()}

In [12]:
tree = KDTree(zsl_target_vectors)

In [13]:
target_labels_list = list(target_labels.values())

In [14]:
GOOGLE_VECS = dp.load_vectors()
train_cat, zsl_cat = dp.split_classes()
labels = pd.read_csv('./words.txt', sep='\t', header=None)
train_labels_df = labels[labels[0].isin(train_cat)]
zsl_labels_df = labels[labels[0].isin(zsl_cat)]

zsl_labels_df['average_label'] = zsl_labels_df[1].transform(dp.average_label)
zsl_labels_df['average_vector'] = zsl_labels_df['average_label'].transform(GOOGLE_VECS.get_vector)

Google word embeddings loaded
Categories split into seen and unseen


In [15]:
def evaluate_zsl_seen(model):
    model.eval()
    top_5 = []
    top_1 = []
    start_time = time()

    for i, (zsl_img, zsl_target) in enumerate(zsl_loader):
        zsl_img = zsl_img.to(device)
        zsl_target = zsl_target.to(device)
        zsl_x = zsl_img
        zsl_target = zsl_target.tolist()

        class_id_batch = [zsl_ids[class_num] for class_num in zsl_target] # target class ids from target batch
        labels_batch = [zsl_labels_df[zsl_labels_df[0] == class_id]['average_label'].item() for class_id in class_id_batch]

        pred_emb = model(zsl_x)[1]
        pred_emb = pred_emb.to('cpu')
        emb_batch = torch.Tensor.cpu(pred_emb.detach()).numpy().squeeze()

        vec_batch = [np.expand_dims(emb, axis=0) for emb in emb_batch]

        index_5_batch = [tree.query(vec, k=5, return_distance=False) for vec in vec_batch]
        index_5_batch = [arr.squeeze() for arr in index_5_batch]

        pred_ids_5 = [[zsl_target_labels_list[index] for index in array] for array in index_5_batch]
        pred_labels_5 = [[zsl_labels_df[zsl_labels_df[0] == class_id]['average_label'].item() for class_id in array] for array in pred_ids_5]

        index_1_batch = [tree.query(vec, k=1, return_distance=False) for vec in vec_batch]
        index_1_batch = [arr.squeeze() for arr in index_1_batch]

        pred_ids_1 = [zsl_target_labels_list[index] for index in index_1_batch]
        pred_labels_1 = [zsl_labels_df[zsl_labels_df[0] == class_id]['average_label'].item() for class_id in pred_ids_1]

        pairs_1 = list(zip(labels_batch, pred_labels_1))
        pairs_5 = list(zip(labels_batch, pred_labels_5))

        top1 = sum([x[0] == x[1] for x in pairs_1])
        top_1.append(top1)
        
        top5 = 0
        for pair in pairs_5:
            top = pair[0] in pair[1]
            top5 += top
        top_5.append(top5)
#         print(top1, top5, batch_size)
              
    top_1_mean = sum(top_1) / len(top_1)
    top_5_mean = sum(top_5) / len(top_5)
    
#     print(top_1_mean, top_5_mean)
    
    print(f'Top-1 accuracy: {round(top_1_mean / float(batch_size), 4)}')
    print(f'Top-5 accuracy: {round(top_5_mean / float(batch_size), 4)}')

    compute_time = round((time()-start_time), 2)
    print(f'Time: {compute_time} sec \n')
    
    return pairs_1, pairs_5

In [16]:
zsl_pairs = evaluate_zsl_seen(net)

Top-1 accuracy: 0.0778
Top-5 accuracy: 0.2585
Time: 65.11 sec 

