In [1]:
import pytorch_lightning as pl

In [2]:
from dataset import *
import random
import numpy as np
from tqdm import tqdm
import multiprocessing
from scipy.stats import wasserstein_distance
from statistics import mean
def chunks(a, b, n):
    """Yield successive n-sized chunks of two equally long arrays in lockstep.

    Parameters
    ----------
    a, b : array-like with a ``.shape`` attribute
        Arrays whose first dimensions must match.
    n : int
        Maximum number of leading-axis rows per chunk; the last chunk may be
        shorter.

    Yields
    ------
    tuple
        ``(a[i:i + n], b[i:i + n])`` for ``i = 0, n, 2n, ...``.

    Raises
    ------
    AssertionError
        If the first dimensions differ. Because this is a generator, the check
        runs when the first chunk is requested, not when ``chunks`` is called.
    """
    assert a.shape[0] == b.shape[0], "Unequal array sizes"
    for i in range(0, len(a), n):
        yield a[i:i + n], b[i:i + n]
def cos_sim(a, b):
    """Cosine similarity between the rows of ``a`` and ``b``.

    Parameters
    ----------
    a : array-like, shape (n, d)
        One vector per row.
    b : array-like, shape (d,) or (m, d)
        A single vector, or one vector per row.

    Returns
    -------
    numpy.ndarray
        Shape ``(n,)`` when ``b`` is 1-D, otherwise ``(n, m)`` with entry
        ``[i, j]`` being the cosine similarity of ``a[i]`` and ``b[j]``.
    """
    a = np.asarray(a)
    b = np.asarray(b)
    a_norm = np.linalg.norm(a, axis=1)
    if b.ndim == 1:
        # Single reference vector: its norm is a scalar.
        return np.dot(a, b) / (a_norm * np.linalg.norm(b))
    # Several reference vectors: normalise each (i, j) entry by the norms of
    # its own row pair instead of by the Frobenius norm of the whole matrix.
    b_norm = np.linalg.norm(b, axis=1)
    return np.dot(a, b.T) / (a_norm[:, np.newaxis] * b_norm)


In [3]:
# K: number of sample groups drawn per image; m: samples per group.
# Each image gets K*m genuine and K*m impostor samples.
K, m = 12, 24

In [4]:
# Instantiate the dataset over the four listed training subsets.
# NOTE(review): the first argument is presumably the datasets root directory — confirm.
dataset_root = '../../Datasets/'
training_subsets = [
    'train_iris_casia_v4',
    'train_iris_nd_0405',
    'train_iris_nd_crosssensor_2013',
    'train_iris_utris_v1',
]
ds = IrisVerificationDatasetPseudo(dataset_root, training_subsets)

In [5]:
# Positional index of every annotation: 0 .. len(annotations) - 1.
image_ids = np.arange(len(ds.annotations), dtype=np.int32)

# String identifier of every annotation, aligned with `image_ids`.
image_ids_string = np.array([entry['__image_id'] for entry in ds.annotations])

# Class number of every annotation, aligned with `image_ids`.
image_classes = np.zeros(len(ds.annotations), dtype=np.int32)
for position, entry in enumerate(ds.annotations):
    image_classes[position] = entry['__class_number']

In [6]:
# image id -> (pairs: sampled same-class image ids, impostors: sampled other-class image ids)
# Each tuple holds m*K ids drawn with replacement (np.random.choice default).
permute_dict = {}
for image_idx in tqdm(image_ids, "Generating random pairs and impostors"):
    image_cls = image_classes[image_idx]
    image_id_str = image_ids_string[image_idx]

    # One class comparison serves both candidate pools.
    same_class = image_classes == image_cls
    image_ids_pair = image_ids[same_class & (image_ids != image_idx)]
    image_ids_impostor = image_ids[~same_class]

    # np.random.choice cannot sample from an empty pool; fail with a message
    # that names the offending image instead of an opaque numpy error.
    if image_ids_pair.size == 0 or image_ids_impostor.size == 0:
        raise ValueError(
            f"Image {image_id_str!r} (class {image_cls}) has no "
            f"{'same-class partner' if image_ids_pair.size == 0 else 'impostor'} "
            "to sample from"
        )

    permute_dict[image_id_str] = (
        tuple(image_ids_string[np.random.choice(image_ids_pair, m*K)].tolist()),
        tuple(image_ids_string[np.random.choice(image_ids_impostor, m*K)].tolist()),
    )

Generating random pairs and impostors:: 100%|██████████| 170280/170280 [02:58<00:00, 952.87it/s] 


In [7]:
# Map every image id to a 512-dimensional vector drawn uniformly from [-1, 1)
# (np.random.rand samples [0, 1); the result is scaled by 2 and shifted by -1).
# NOTE(review): presumably a random stand-in for real embeddings — confirm.
rand_vec = {}
for image_id in image_ids_string:
    # NOTE: `image_id` stays bound to the last id once this loop finishes.
    rand_vec[image_id] = (np.random.rand(512)*2-1)


In [49]:
# Bidirectional lookup between image ids and their row in `vector_batch`,
# following the insertion order of `rand_vec`.
id_to_idx = {img: idx for idx, img in enumerate(rand_vec)}
idx_to_id = {idx: img for img, idx in id_to_idx.items()}

# Dense (n_images, 512) matrix holding one vector per row.
vector_batch = np.zeros((len(rand_vec), 512))
for img, idx in id_to_idx.items():
    vector_batch[idx] = rand_vec[img]

# Per-row L2 norms and their reciprocals, used to normalise dot products.
vector_len = np.linalg.norm(vector_batch, axis=1)
one_over_vector_len = 1 / vector_len

# Output buffer: for every image, K*m pair values followed by K*m impostor values.
distances_v = np.zeros((len(permute_dict), K * m * 2, 1))

In [50]:
q = {}  # image id -> mean Wasserstein distance; populated by a later cell
batch_size = 512
n_images = len(permute_dict)
samples_per_image = K * m * 2  # K*m pair samples followed by K*m impostor samples
vector_dim = vector_batch.shape[1]

for start_id in tqdm(range(0, n_images, batch_size), "Calculating cosine distances"):
    # Clamp the window so the final, shorter batch does not run past the data.
    end_id = min(start_id + batch_size, n_images)
    rows_in_batch = end_id - start_id

    # (B, dim, 1) column vectors and (B, 1, 1) inverse norms of the root images.
    root_vectors = vector_batch[start_id:end_id, :, np.newaxis]
    root_one_over = one_over_vector_len[start_id:end_id, np.newaxis, np.newaxis]

    # (B, samples, dim) sampled vectors and (B, samples, 1) inverse norms.
    batch_vectors = np.zeros((rows_in_batch, samples_per_image, vector_dim))
    batch_one_over = np.zeros((rows_in_batch, samples_per_image, 1))

    for row in range(start_id, end_id):
        # Look up the samples that belong to THIS row's image id.
        pair, impostor = permute_dict[idx_to_id[row]]
        idxs = np.array(
            [id_to_idx[p] for p in pair] + [id_to_idx[s] for s in impostor],
            dtype=np.int32
        )
        batch_vectors[row - start_id, :, :] = vector_batch[idxs]
        batch_one_over[row - start_id, :, :] = one_over_vector_len[idxs, np.newaxis]

    # Batched matmul gives the dot products (B, samples, 1); multiplying by both
    # inverse norms turns them into cosine similarities.
    distances_v[start_id:end_id, :] = (
        (batch_vectors @ root_vectors) * batch_one_over * root_one_over
    )

Calculating Wasserstein distances::  68%|██████▊   | 226/333 [05:27<02:40,  1.50s/it]

In [None]:
# For every image, split its K*m pair values and K*m impostor values into K
# groups of m, take the Wasserstein distance between the matching pair and
# impostor groups, and store the mean over the K groups under the image's id.
impostor_offset = K * m  # impostor values start right after the pair values
for i, dv in enumerate(distances_v):
    # Each row is stored as (2*K*m, 1); wasserstein_distance expects 1-D samples.
    dv = dv.ravel()

    w_distances = np.zeros(K)
    for j in range(0, K * m, m):
        w_distances[j // m] = wasserstein_distance(
            dv[j:j + m],
            dv[impostor_offset + j:impostor_offset + j + m],
        )
    # Row i of distances_v belongs to the image id at index i.
    q[idx_to_id[i]] = np.mean(w_distances)

    

In [39]:
# Debug check: shape of `root_vectors` as left by the most recent batch-loop iteration.
root_vectors.shape

(128, 512, 1)

In [40]:
# Debug check: shape of the batched matmul product scaled by both inverse-norm arrays.
((batch_vectors @ root_vectors) * batch_one_over * root_one_over).shape

(128, 576, 1)

In [26]:
# Debug check: shape of the sampled pair/impostor vectors gathered for one batch.
batch_vectors.shape

(128, 576, 512)

In [8]:
import pickle

In [9]:
# Persist the sampled pair/impostor ids with pickle (default protocol).
with open('../../Datasets/pseudo_permutations.pt', 'wb') as out_file:
    pickle.Pickler(out_file).dump(permute_dict)

In [39]:
# Convert a byte count to MiB (bytes / 1024 / 1024).
# NOTE(review): 529213234 is presumably the size of the pickled file — confirm.
529213234/1024/1024

504.6970691680908