In [1]:
import os
import pickle
import sys
sys.path.append("../") # adding root folder to the path

import torch 
import torchvision
from torchvision import transforms
from torchvision.models import *
from torchvision import datasets
from tqdm import tqdm

from MRL import *
from imagenetv2_pytorch import ImageNetV2Dataset
from argparse import ArgumentParser
from utils import *

# nesting list is by default from 8 to 2048 in powers of 2, can be modified from here.
BATCH_SIZE = 1024
IMG_SIZE = 256
CENTER_CROP_SIZE = 224
NESTING_LIST=[2**i for i in range(3, 12)]
ROOT="/local/xiangyu/CSC2233/train/" # path to validation datasets
model_weight_path = "/home/ericliu/csc2233/MRL/train/trainlogs/bfb14b69-f5c6-4754-958c-c7c522fe44be/final_weights.pt"
output_dir = "/local/eric/ft-vect-train"

In [2]:
model = resnet50(False, weights=None)
model = load_from_old_ckpt(model, False, NESTING_LIST, extract_ft=True)
apply_blurpool(model)	
model.load_state_dict(get_ckpt(model_weight_path)) # Since our models have a torch DDP wrapper, we modify keys to exclude first 7 chars. 
model = model.cuda()
model.eval()

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
test_transform = transforms.Compose([
				transforms.Resize(IMG_SIZE),
				transforms.CenterCrop(CENTER_CROP_SIZE),
				transforms.ToTensor(),
				normalize])

dataset = torchvision.datasets.ImageFolder(ROOT+'val/', transform=test_transform)
idx_to_class = {v: k for k, v in dataset.class_to_idx.items()} 
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=8, shuffle=False)



In [3]:
ft_to_vecs = {ft : [] for ft in NESTING_LIST}
ft_to_label_to_vecs = {ft : {k : [] for k in dataset.class_to_idx} for ft in NESTING_LIST}

with torch.no_grad():
    for img_inputs, labels in tqdm(dataloader):
        logits = model(img_inputs.cuda())
        label_names = [idx_to_class[label.item()] for label in labels]

        for i, num_feat in enumerate(NESTING_LIST):
            for ft_vec, label in zip(logits[i], label_names):
                ft_vec = ft_vec.cpu()
                ft_to_vecs[num_feat].append(ft_vec)
                ft_to_label_to_vecs[num_feat][label].append(ft_vec)

with open(f"{output_dir}/ft_to_vecs.pkl", "wb") as file:
    pickle.dump(ft_to_vecs, file)

with open(f"{output_dir}/ft_to_label_to_vecs.pkl", "wb") as file:
    pickle.dump(ft_to_label_to_vecs, file)

for num_feat, vecs in ft_to_vecs.items():
    ft_dir = f"{output_dir}/ft_size_{num_feat}"
    os.makedirs(ft_dir, exist_ok = True)
    filepath = f"{ft_dir}/ft_{num_feat}.pack"
    save_fvecs(filepath, vecs)

for num_feat, label_to_vecs in ft_to_label_to_vecs.items():
    ft_dir = f"{output_dir}/ft_size_{num_feat}"
    os.makedirs(ft_dir, exist_ok = True)
    for label, vecs in label_to_vecs.items():
        out_dir = f"{ft_dir}/{label}"
        os.makedirs(out_dir, exist_ok = True)
        filepath = f"{out_dir}/{label}.pack"
        save_fvecs(filepath, vecs)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 49/49 [00:58<00:00,  1.19s/it]


In [None]:
all_vecs = []
all_labels = []
for label, vecs in tqdm(ft_to_label_to_vecs[2048].items()):
    all_vecs = all_vecs + vecs
    all_labels = all_labels + [label] * len(vecs)
all_vecs = torch.stack(all_vecs)
all_vecs = all_vecs.numpy()

with open(f"{output_dir}/all_vecs.pkl", "wb") as file:
    pickle.dump(all_vecs, file)
with open(f"{output_dir}/all_labels.pkl", "wb") as file:
    pickle.dump(all_labels, file)

In [4]:
from collections import Counter
from sklearn.neighbors import NearestNeighbors
import numpy as np

with open(f"{output_dir}/all_vecs.pkl", "rb") as file:
    all_vecs = pickle.load(file)
with open(f"{output_dir}/all_labels.pkl", "rb") as file:
    all_labels = pickle.load(file)

vec_per_label = dict(Counter(all_labels))

In [None]:
batch_size = 10
x = np.array(all_vecs)
k = vec_per_label[all_labels[0]]
nn = NearestNeighbors(n_neighbors=k, algorithm='brute', metric='euclidean').fit(x)

for start in tqdm(range(0, all_vecs.shape[0], batch_size)):
    end = min(start + batch_size, all_vecs.shape[0])
    query_batch = all_vecs[start:end]
    distances, indices = nn.kneighbors(all_vecs[start].reshape(1, -1))

or set the environment variable OPENBLAS_NUM_THREADS to 64 or lower
  0%|                                                                                       | 76/128117 [01:16<33:55:00,  1.05it/s]

In [3]:
all_vecs = []
all_labels = []

with open("/local/eric/ft-vect-validation/ft_to_label_to_vecs.pkl", "rb") as file:
    ft_to_label_to_vecs = pickle.load(file)

for label, vecs in tqdm(ft_to_label_to_vecs[2048].items()):
    all_vecs = all_vecs + vecs
    all_labels = all_labels + [label] * len(vecs)
all_vecs = torch.stack(all_vecs)
all_vecs = all_vecs.numpy()

with open(f"{output_dir}/all_vecs.pkl", "wb") as file:
    pickle.dump(all_vecs, file)
with open(f"{output_dir}/all_labels.pkl", "wb") as file:
    pickle.dump(all_labels, file)

100%|████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 4656.64it/s]
