## Data Loading

In [2]:
import h5py
import numpy as np
import pickle

# Load the base vectors ('train') and the query vectors ('test') into memory.
# The HDF5 file is opened in a context manager so the handle is released as
# soon as both datasets have been copied into NumPy arrays.
with h5py.File('fashion-mnist-784-euclidean.hdf5', 'r') as data:
    dataset = np.array(data['train'])
    queries = np.array(data['test'])

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles that were produced by a trusted pipeline.
with open('clusters_fashion_mnist_784_euclidean.pkl', 'rb') as f:
    clusters = pickle.load(f)
with open('ground_truth_fashion_mnist_784_normalized_euclidean_0_0_0_5.pkl', 'rb') as f:
    ground_truth_total = pickle.load(f)

In [3]:
# Re-index the ground-truth records into a [first field][second field] grid of
# lists (100 x 10000). Later cells index this grid as [cluster_id][query_id],
# so record[0] and record[1] act as the cluster id and the query id.
ground_truth_total_level = [[[] for _ in range(10000)] for _ in range(100)]
for cluster_idx in range(100):
    for record in ground_truth_total[cluster_idx]:
        row, column = record[0], record[1]
        ground_truth_total_level[row][column].append(record)

In [4]:
# One summary value per cluster.
# NOTE(review): np.mean without an `axis` argument averages over every element
# of `cluster`, so each entry is a single scalar. If a per-dimension centroid
# vector was intended, `axis=0` would be needed -- confirm against the layout
# of `clusters` before changing.
centroids = [np.mean(cluster) for cluster in clusters]

## Prepare Inputs

In [61]:
def euclidean_dist_normalized(x1, x2=None, eps=1e-8):
    """Root-mean-square difference between two inputs after dividing both by 255.

    Args:
        x1: array-like of raw values.
        x2: scalar or array-like to compare against. ``None`` or any NaN
            entry is treated as "no reference available".
        eps: unused; kept so existing callers that pass it keep working.

    Returns:
        ``1.0`` when ``x2`` is missing (``None``) or contains NaN, otherwise
        ``sqrt(mean((x1 / 255 - x2 / 255) ** 2))``.
    """
    # The declared default used to crash inside np.isnan; treat it like NaN.
    if x2 is None:
        return 1.0
    # np.any makes the check valid for arrays as well as scalars (a bare
    # `if np.isnan(array)` raises an ambiguous-truth-value error).
    if np.any(np.isnan(x2)):
        return 1.0
    left = x1 / 255.0
    right = x2 / 255.0
    return np.sqrt(((left - right) ** 2).mean())

def _build_features_and_targets(query_ids, query_vectors, ground_truth_level,
                                num_clusters=100, max_threshold=0.5, slot=0.01):
    """Build one (feature, multi-label target) pair per query and threshold step.

    Args:
        query_ids: iterable of row indices into ``query_vectors``.
        query_vectors: array of raw query vectors (divided by 255 here).
        ground_truth_level: nested structure indexed as
            ``[cluster_id][query_id][threshold_id]``; the last field of each
            record is added to a running per-cluster total.
            (Presumably that field is the match count for the threshold slot --
            TODO confirm against the ground-truth generator.)
        num_clusters: number of clusters, i.e. length of every target vector.
        max_threshold: exclusive upper bound of the threshold grid start values.
        slot: width of one threshold step.

    Returns:
        ``(features, targets)`` -- two parallel lists. Each feature is the
        normalized query followed by the upper edge of the threshold slot;
        each target is a 0/1 list flagging clusters whose running total is > 0.
    """
    thresholds = np.arange(0.0, max_threshold, slot)
    features = []
    targets = []
    for query_id in query_ids:
        normalized_query = query_vectors[query_id] / 255.0
        # Running totals accumulate across increasing thresholds for this query.
        cardinality = [0] * num_clusters
        for threshold_id, threshold in enumerate(thresholds):
            indicator = []
            for cluster_id in range(num_clusters):
                cardinality[cluster_id] += ground_truth_level[cluster_id][query_id][threshold_id][-1]
                indicator.append(1 if cardinality[cluster_id] > 0 else 0)
            features.append(np.concatenate((normalized_query, [threshold + slot])))
            targets.append(indicator)
    return features, targets


# The first 8000 queries are used for training, the remaining 2000 for testing.
# The previous per-query distance-to-centroid computation was dropped: its
# result was never read by the feature or target construction.
slot = 0.01
train_features, train_targets = _build_features_and_targets(
    range(8000), queries, ground_truth_total_level, slot=slot)
test_features, test_targets = _build_features_and_targets(
    range(8000, 10000), queries, ground_truth_total_level, slot=slot)
        
        

In [62]:
# Sanity check: length of one feature vector (the normalized query followed by
# a single threshold value).
len(train_features[0])

785

In [63]:
import torch
batch_size = 128


def _as_float_tensor(rows):
    """Convert a list of equal-length rows into one float32 tensor.

    Going through a single contiguous NumPy array avoids the slow path that
    torch takes when handed a Python list of ndarrays.
    """
    return torch.from_numpy(np.asarray(rows, dtype=np.float32))


# Mini-batch loaders over (feature, target) pairs for training and evaluation.
train_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(_as_float_tensor(train_features), _as_float_tensor(train_targets)),
        batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(_as_float_tensor(test_features), _as_float_tensor(test_targets)),
        batch_size=batch_size, shuffle=True)


## Multi-label Networks

In [64]:
from __future__ import print_function
import argparse
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# Layer sizes used by Model below.
input_dimension = 785  # matches the feature length displayed earlier in the notebook
hidden_num = 256       # width of each hidden layer
output_num = 100       # one output per entry of a target vector

class Model(nn.Module):
    """Four-layer fully connected network with sigmoid outputs.

    Args:
        in_features: input width; defaults to the module-level ``input_dimension``.
        hidden_features: hidden width; defaults to the module-level ``hidden_num``.
        out_features: output width; defaults to the module-level ``output_num``.

    Calling ``Model()`` with no arguments builds exactly the same layers
    (``nn1`` .. ``nn4``) as before, so saved state dicts remain loadable.
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None):
        super(Model, self).__init__()
        # Fall back to the notebook-level constants only when a size is omitted.
        if in_features is None:
            in_features = input_dimension
        if hidden_features is None:
            hidden_features = hidden_num
        if out_features is None:
            out_features = output_num
        self.nn1 = nn.Linear(in_features, hidden_features)
        self.nn2 = nn.Linear(hidden_features, hidden_features)
        self.nn3 = nn.Linear(hidden_features, hidden_features)
        self.nn4 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return per-output values in (0, 1) for a batch ``x``."""
        out1 = F.relu(self.nn1(x))
        out2 = F.relu(self.nn2(out1))
        out3 = F.relu(self.nn3(out2))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        out4 = torch.sigmoid(self.nn4(out3))
        return out4

def loss_fn(estimates, targets):
    """Mean squared error between the network outputs and the 0/1 targets."""
    mse = F.mse_loss(input=estimates, target=targets, reduction='mean')
    return mse

def print_loss(estimates, targets):
    """Compute precision and recall of thresholded estimates (nothing is printed).

    An estimate counts as a positive prediction when it is >= 0.5; a target
    counts as positive when it equals 1 (targets are expected to be 0/1).

    Args:
        estimates: tensor of scores, same shape as ``targets``.
        targets: tensor of 0/1 labels.

    Returns:
        ``(precision, recall)`` as Python floats. A ratio whose denominator is
        zero (no positive predictions / no positive targets) is reported as 0.0
        instead of raising ZeroDivisionError.
    """
    predicted_positive = estimates >= 0.5
    actual_positive = targets == 1

    true_positive = float((predicted_positive & actual_positive).sum())
    # Predicted positive but actually negative -> false positive.
    false_positive = float((predicted_positive & ~actual_positive).sum())
    # Predicted negative but actually positive -> false negative.
    false_negative = float((~predicted_positive & actual_positive).sum())

    predicted_total = true_positive + false_positive
    actual_total = true_positive + false_negative
    precision = true_positive / predicted_total if predicted_total > 0 else 0.0
    recall = true_positive / actual_total if actual_total > 0 else 0.0
    return precision, recall

In [None]:
# Train for 30 passes over train_loader; after each pass, evaluate on
# test_loader and report the batch-averaged loss, precision and recall.
model = Model()
opt = optim.Adam(model.parameters(), lr=0.001)
for e in range(30):
    model.train()
    for batch_idx, (features, targets) in enumerate(train_loader):
        opt.zero_grad()
        estimates = model(features)
        loss = loss_fn(estimates, targets)
        if batch_idx % 100 == 0:
            print('Training: Iteration {0}, Batch {1}, Loss {2}'.format(e, batch_idx, loss.item()))
        loss.backward()
        opt.step()

    model.eval()
    test_loss = 0.0
    precision = 0.0
    recall = 0.0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for batch_idx, (features, targets) in enumerate(test_loader):
            estimates = model(features)
            loss = loss_fn(estimates, targets)
            test_loss += loss.item()
            prec, rec = print_loss(estimates, targets)
            precision += prec
            recall += rec
            if batch_idx % 100 == 0:
                print('Testing: Iteration {0}, Batch {1}, Loss {2}, Precision {3}, Recall {4}'.format(e, batch_idx, loss.item(), prec, rec))
    # Metrics are averaged per batch, not per sample.
    test_loss /= len(test_loader)
    precision /= len(test_loader)
    recall /= len(test_loader)
    print('Testing: Loss {0}, Precision {1}, Recall {2}'.format(test_loss, precision, recall))
    
    

Training: Iteration 0, Batch 0, Loss 0.2496676743030548
Training: Iteration 0, Batch 100, Loss 0.2118474394083023
Training: Iteration 0, Batch 200, Loss 0.19117672741413116
Training: Iteration 0, Batch 300, Loss 0.13398313522338867
Training: Iteration 0, Batch 400, Loss 0.06092129275202751
Training: Iteration 0, Batch 500, Loss 0.03990978002548218
Training: Iteration 0, Batch 600, Loss 0.03648688271641731
Training: Iteration 0, Batch 700, Loss 0.034134093672037125
Training: Iteration 0, Batch 800, Loss 0.058855827897787094
Training: Iteration 0, Batch 900, Loss 0.03753291815519333
Training: Iteration 0, Batch 1000, Loss 0.03673548623919487
Training: Iteration 0, Batch 1100, Loss 0.030431687831878662
Training: Iteration 0, Batch 1200, Loss 0.03044726699590683
Training: Iteration 0, Batch 1300, Loss 0.02446487359702587
Training: Iteration 0, Batch 1400, Loss 0.02101174183189869
Training: Iteration 0, Batch 1500, Loss 0.021719638258218765
Training: Iteration 0, Batch 1600, Loss 0.02661372

In [20]:
# Persist only the learned parameters (state dict), not the whole module object.
torch.save(model.state_dict(), 'global_fashion_mnist_784_euclidean_binary_query_threshold.model')

## Model Usage

In [70]:
# Rebuild the architecture, then restore the parameters saved earlier.
model = Model()
model.load_state_dict(torch.load('global_fashion_mnist_784_euclidean_binary_query_threshold.model'))
# eval() switches the module to evaluation mode and returns it, so the module
# summary is displayed as the cell output.
model.eval()

Model(
  (nn1): Linear(in_features=785, out_features=256, bias=True)
  (nn2): Linear(in_features=256, out_features=256, bias=True)
  (nn3): Linear(in_features=256, out_features=256, bias=True)
  (nn4): Linear(in_features=256, out_features=100, bias=True)
)

In [71]:
# For every test batch, show the thresholded prediction for the first sample
# next to its target vector. Inference only, so autograd is disabled.
with torch.no_grad():
    for batch_idx, (features, targets) in enumerate(test_loader):
        estimates = model(features)
        print(estimates[0] >= 0.5)
        print(targets[0])

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.,

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.,

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

tensor([False,  True,  True,  True,  True,  True,  True, False,  True, False,
         True,  True, False,  True, False,  True, False, False,  True,  True,
         True,  True,  True,  True, False,  True,  True, False,  True,  True,
         True,  True, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True, False,  True,  True,  True, False,
        False,  True,  True,  True, False, False, False,  True,  True,  True,
         True,  True, False,  True,  True,  True,  True, False, False,  True,
        False, False,  True,  True,  True,  True,  True,  True, False,  True,
        False,  True,  True,  True, False,  True, False,  True, False,  True,
        False,  True,  True, False,  True, False, False,  True, False, False])
tensor([0., 1., 1., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0.,
        0., 1., 1., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1.,
        1., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0., 0., 1.,

tensor([False, False, False, False, False,  True, False, False, False, False,
        False, False, False,  True, False, False, False,  True, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False,  True, False,  True, False,
        False, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False,  True, False,  True, False, False,  True, False, False, False,
        False,  True,  True, False, False, False, False, False, False, False,
        False,  True,  True, False, False, False,  True, False, False, False])
tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,

tensor([0., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 0., 1.,
        1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 0., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,
        0., 1., 1., 0., 1., 0., 1., 1., 0., 0.])
tensor([ True, False, False, False,  True, False, False,  True,  True,  True,
        False, False,  True, False,  True, False,  True, False, False, False,
        False, False,  True, False,  True, False, False,  True, False,  True,
        False,  True,  True, False, False,  True, False,  True, False, False,
        False, False,  True, False, False,  True,  True, False, False,  True,
         True, False, False,  True,  True, False, False, False, False,  True,
        False,  True,  True, False, False, False, False,  True, False, False,
     

tensor([False,  True, False, False,  True,  True,  True, False, False, False,
        False, False, False,  True, False,  True, False, False,  True,  True,
         True,  True,  True, False, False, False, False, False,  True,  True,
        False,  True, False, False,  True,  True,  True,  True,  True, False,
         True, False, False, False, False, False,  True, False,  True, False,
        False,  True,  True,  True, False,  True, False, False, False,  True,
        False,  True, False,  True,  True, False, False, False, False,  True,
        False,  True,  True,  True,  True, False,  True, False, False,  True,
        False,  True,  True,  True, False, False, False, False, False,  True,
        False,  True,  True, False, False,  True,  True, False, False, False])
tensor([0., 1., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 0.,
        1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 0., 0., 1., 1.,
        1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1.,

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

tensor([False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False,  True, False,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False,  True, False,
        False, False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False,  True, False, False,  True, False, False, False,
        False,  True,  True, False, False, False, False, False, False, False,
        False,  True,  True, False, False, False,  True, False, False, False])
tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.,

tensor([False,  True, False, False, False,  True, False, False, False, False,
        False, False, False,  True, False,  True, False,  True,  True,  True,
         True,  True,  True, False, False, False, False, False, False, False,
        False, False, False, False,  True, False,  True,  True,  True,  True,
         True, False, False, False, False, False, False, False,  True, False,
        False,  True,  True,  True, False,  True, False, False, False, False,
        False,  True, False, False,  True, False,  True, False, False,  True,
        False,  True,  True,  True,  True, False,  True, False, False,  True,
        False,  True,  True,  True, False, False, False,  True, False, False,
        False,  True,  True, False, False, False,  True,  True, False, False])
tensor([0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1.,
        1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
        1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1.,

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.,

tensor([ True,  True, False, False,  True,  True, False, False, False,  True,
        False, False,  True,  True, False,  True, False, False, False, False,
        False,  True,  True, False,  True, False, False, False, False,  True,
        False,  True,  True, False, False,  True, False,  True, False,  True,
        False, False, False, False, False, False,  True, False,  True, False,
         True, False, False, False,  True,  True, False, False, False,  True,
        False,  True,  True, False,  True, False, False,  True,  True, False,
        False, False,  True,  True, False, False, False, False,  True, False,
        False, False,  True,  True,  True, False,  True,  True,  True,  True,
         True, False,  True,  True, False,  True, False, False,  True, False])
tensor([1., 1., 0., 0., 1., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0., 1., 0., 0.,
        0., 0., 0., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 1.,
        0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 1.,

tensor([0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.,
        1., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0.,
        1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([False, False,  True,  True, False, False,  True, False,  True, False,
         True,  True, False, False, False, False, False, False,  True,  True,
         True, False,  True,  True, False,  True,  True, False,  True, False,
         True, False, False,  True,  True, False,  True,  True, False, False,
         True,  True, False,  True,  True,  True, False,  True, False, False,
        False,  True,  True,  True, False, False,  True,  True,  True,  True,
         True,  True, False,  True,  True,  True,  True, False, False,  True,
     

tensor([1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 0., 1., 0.,
        0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1.,
        1., 1., 0., 1., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1.,
        1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 0.,
        1., 0., 1., 1., 0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 0., 0., 1., 1., 0., 0., 1., 1., 0.])
tensor([False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
     

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([ True, False, False, False, False, False, False,  True, False, False,
        False, False, False, False,  True, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
     

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

tensor([False, False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False, False, False, False])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

tensor([False,  True, False, False,  True, False, False, False, False, False,
        False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,  True,
        False, False,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False,  True, False, False, False,
        False, False, False, False,  True, False, False, False, False,  True,
        False, False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False,  True,  True,  True,
         True, False, False, False, False,  True, False, False,  True, False])
tensor([0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [None]:
# NOTE(review): this cell relies on names that are not defined in the cells
# visible in this notebook (`prepare_dataset`, `train_num`, `test_num`), and
# `f` was last bound as a pickle file handle in the data-loading cell -- it
# looks like a leftover from a different experiment; confirm before running.
use_cuda = torch.cuda.is_available()

#     torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")
train_dataset = np.array(f['train'])
test_dataset = np.array(f['test'])
train_lefts, train_rights, test_lefts, test_rights = prepare_dataset(train_dataset, test_dataset, train_num, test_num)

# These assignments rebind train_loader / test_loader created earlier.
# NOTE(review): a plain 2-tuple is passed as the dataset, so the loader would
# index it as a dataset of length 2 -- presumably a TensorDataset was
# intended; confirm.
train_loader = torch.utils.data.DataLoader(
    (train_lefts, train_rights), batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    (test_lefts, test_rights), batch_size=batch_size, shuffle=True)


In [None]:
# NOTE(review): `test` is not defined in the cells visible in this notebook.
# hash_distances, input_distances = test(model, device, train_loader)
hash_distances, input_distances = test(model, device, test_loader)

In [None]:
# Compare row 0 of f['train'] (repeated 999 times) against rows 1..999.
# NOTE(review): `angular_distance` and `l1_distance` are not defined in the
# cells visible here, and the Model defined in this notebook expects
# `input_dimension` features per row -- confirm f['train'] rows have that width.
lefts = torch.FloatTensor([f['train'][0] for x in range(999)])
rights = torch.FloatTensor(f['train'][1:1000])
inputdistance = angular_distance(lefts, rights).detach().numpy()
hashdistance = l1_distance(model(lefts), model(rights)).detach().numpy()


In [None]:
# for xx in zip(inputdistance, hashdistance):
#     print (xx[0], xx[1])
index_1 = np.argsort(hashdistance, 0)
index_2 = np.argsort(inputdistance, 0)
# np.random.shuffle(index_2)

# Mean absolute difference between each item's rank under the hash distance
# and its rank under the input distance (0 means identical orderings).
# Computed with array indexing; the accumulator no longer shadows the builtin
# `sum`.
hash_order = np.ravel(index_1)
input_order = np.ravel(index_2)
# input_rank[i] is the position of item i in the input-distance ordering.
input_rank = np.empty(len(input_order), dtype=np.int64)
input_rank[input_order] = np.arange(len(input_order))
rank_displacement = np.abs(np.arange(len(hash_order)) - input_rank[hash_order])
rank_displacement.mean()

In [None]:
# pyplot is imported here because this cell runs before the later cell that
# imports it; without it a fresh-kernel run fails with NameError on `plt`.
import matplotlib.pyplot as plt

# Input distances in ascending order.
sorted_input_distance = np.sort(inputdistance, 0)
plt.plot(sorted_input_distance)
plt.xlabel('rank')
plt.ylabel('input distance')
plt.show()

In [None]:
import math
# Walk the items in hash-distance order and bucket each input distance by
# flooring it after scaling by 40.
distances = [math.floor(inputdistance[item].item() * 40) for item in index_1]

In [None]:
import matplotlib.pyplot as plt

# Bucketed input distance of each item, plotted in hash-distance order.
plt.plot(distances)
plt.show()

In [None]:
# Cosine similarity between rows 0 and 6 of f['train'] (each lifted to a batch of one).
# NOTE(review): `f` is not an HDF5 handle in the cells visible here -- confirm its origin.
F.cosine_similarity(torch.FloatTensor(f['train'][0]).unsqueeze(0), torch.FloatTensor(f['train'][6]).unsqueeze(0), dim=1, eps=1e-8)

In [None]:
# Print the first 30 (hash distance, input distance) pairs of the first entry.
for x, y in zip(hash_distances[0][0:30], input_distances[0][0:30]):
    print (x, y)

In [None]:
# Run the model over every row of f['train'] in a single forward pass.
# NOTE(review): autograd is left enabled here; later cells call .detach() on the result.
dataset_vector = model(torch.FloatTensor(f['train']))

In [None]:
# Run the model over every row of f['test'] in a single forward pass.
query_vector = model(torch.FloatTensor(f['test']))

In [None]:
def binarization(vector):
    """Threshold a 2-D collection of scores into 0/1 codes.

    Args:
        vector: 2-D array-like of numeric scores (one row per item).

    Returns:
        Integer ndarray of the same shape: 0 where the score is < 0.5,
        1 everywhere else.
    """
    scores = np.asarray(vector)
    # Single vectorized comparison instead of a Python loop per element.
    return np.where(scores < 0.5, 0, 1)
# Detach the model outputs from autograd and convert them to 0/1 codes.
dataset_binary = binarization(dataset_vector.detach().numpy())
query_binary = binarization(query_vector.detach().numpy())

In [None]:
# Number of binarized dataset rows.
len(dataset_binary)

In [None]:
# Number of binarized query rows.
len(query_binary)

In [None]:
import math
# Group dataset row indices by their binary code. The code is folded into a
# single number with bit `pos` weighted by 2**pos.
# NOTE(review): math.pow returns a float, so the keys are floats; codes longer
# than 53 bits cannot all be represented exactly and distinct codes may share
# a key -- confirm the code length before relying on exact buckets.
hash_table = {}
for idx, point in enumerate(dataset_binary):
    key = 0
    for pos, d in enumerate(point):
        key += d * math.pow(2, pos)
    hash_table.setdefault(key, []).append(idx)

In [None]:
# Display the full 'neighbors' dataset.
# NOTE(review): this materializes and prints the whole array; a slice would keep the output small.
f['neighbors'][:]

In [None]:
def find_candidate_distance(vector, hash_table, candidate_num=None):
    """Look up the exact-match bucket for every binary code in ``vector``.

    The key is built the same way as when the table is filled: bit ``pos`` of
    the code contributes ``2 ** pos`` (computed with ``math.pow``).

    Only exact matches (Hamming distance 0) are probed; expanding the search
    to neighbouring codes is not implemented.

    Args:
        vector: iterable of binary codes (each an iterable of 0/1 values).
        hash_table: mapping from code key to a list of dataset indices.
        candidate_num: optional cap on the number of candidates returned per
            code; ``None`` returns the whole bucket.

    Returns:
        A list with one entry per code in ``vector``: the (possibly truncated)
        bucket contents, or an empty list when the code has no bucket.
    """
    candidate = []
    # Iterate over the argument, not a notebook global, and probe once per
    # code (the previous `while` loop could never terminate).
    for point in vector:
        key = 0
        for pos, d in enumerate(point):
            key += d * math.pow(2, pos)
        bucket = hash_table.get(key, [])
        # Slicing copies, so callers cannot mutate the table through the result.
        candidate.append(list(bucket[:candidate_num]))
    return candidate
# NOTE(review): `find_candidate_0_distance` is not defined in the cells visible
# here; the function defined in this cell is named `find_candidate_distance`.
find_candidate_0_distance(query_binary, hash_table)

In [None]:
class Node(object):
    """One node of a hierarchical hash index.

    A node owns a subset of dataset rows (``data_index_set``) and, after
    ``partition``, one child per distinct code produced by its model.

    NOTE(review): ``train``, ``validate`` and ``select_children`` are free
    functions that are not defined in the cells visible in this notebook.
    """

    def __init__(self, hash_code, data_index_set):
        self.hash_code = hash_code            # code that routed data to this node
        self.data_index_set = data_index_set  # dataset row indices owned by this node
        self.children = []

    def isLeaf(self):
        """Return True when the node has not been partitioned into children."""
        return len(self.children) == 0

    def train(self, dataset):
        """Fit this node's model on the rows of ``dataset`` it owns."""
        train_data = dataset[self.data_index_set]
        # Calls the module-level `train` function (not this method).
        self.model = train(train_data)

    def partition(self, dataset):
        """Split this node's rows into children, one per distinct model code.

        Requires ``self.model`` to be set. Codes are used as dictionary keys,
        so they must be hashable.
        """
        points = dataset[self.data_index_set]
        hash_table = {}
        codes = self.model(points)
        for idx, code in enumerate(codes):
            hash_table.setdefault(code, []).append(self.data_index_set[idx])
        for key, value in hash_table.items():
            self.children.append(Node(key, value))

    def search(self, query, dataset):
        """Collect results for ``query`` from the leaves selected beneath this node."""
        if self.isLeaf():
            return validate(dataset[self.data_index_set])
        children_idxes = select_children(query)
        result = []
        for idx in children_idxes:
            result += self.children[idx].search(query, dataset)
        return result
    
    
    

def index_construction(dataset):
    # NOTE(review): stub -- the trained model is bound to a local name and the
    # function returns None; `train` is not defined in the cells visible here.
    model = train(dataset)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import manifold, datasets

data = np.array(f['train'])

# 2-D t-SNE embedding of a random subset of rows (PCA initialisation, fixed
# seed for the embedding itself).
tsne = manifold.TSNE(n_components=2, init='pca', random_state=501)
# Sampling without replacement fails when more rows are requested than exist,
# so cap the sample at the number of available rows.
sample_size = min(100000, data.shape[0])
sample_rows = np.random.choice(data.shape[0], sample_size, replace=False)
X_tsne = tsne.fit_transform(data[sample_rows])