## This notebook loads the current model and evaluates it on the test set

### Initialize
First, initialize the model with all parameters and load the trained weights.


In [1]:
from data_source import DataSource
from visualize import Visualize
from sphere import Sphere
from model import Model
from loss import TripletLoss, ImprovedTripletLoss
from training_set import TrainingSet
from average_meter import AverageMeter
from data_splitter import DataSplitter

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import sys
import time
import math
import numpy as np
import pandas as pd
import open3d as o3d
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

from tqdm.auto import tqdm
from scipy import spatial
%reload_ext autoreload
%autoreload 2

In [2]:
# Run all CUDA work in this notebook on GPU 1 and let cuDNN autotune kernels.
torch.cuda.set_device(1)
torch.backends.cudnn.benchmark = True
net = Model().cuda()
restore = False
# NOTE(review): optimizer, batch_size, num_workers, net_input_size, n_features
# and cache are not used by the evaluation code below (the test loader
# hard-codes batch_size=10 and num_workers=1); presumably kept to mirror the
# training configuration.
optimizer = torch.optim.SGD(net.parameters(), lr=5e-3, momentum=0.9)
batch_size = 12
num_workers = 12
descriptor_size = 128  # length of each embedding; used to reshape descriptors later
bandwidth = 100
net_input_size = 2*bandwidth
n_features = 2
cache = 50
criterion = ImprovedTripletLoss(margin=2, alpha=0.5, margin2=0.2)
writer = SummaryWriter()
model_save = '../models/12000_b14_big.pkl'
# Load the checkpoint onto the CPU first. Without map_location, torch.load puts
# each tensor back on the device it was saved from; for a GPU-saved checkpoint
# that allocates on that GPU (not necessarily GPU 1 selected above) and fails
# if that device does not exist. load_state_dict then copies the weights into
# net's parameters, which already live on the current CUDA device.
net.load_state_dict(torch.load(model_save, map_location='cpu'))
# Uncomment to print a layer summary; the three inputs are presumably
# (n_features, net_input_size, net_input_size) for anchor/positive/negative.
#summary(net, input_size=[(2, 200, 200), (2, 200, 200), (2, 200, 200)])

<All keys matched successfully>

Initialize the data source

In [3]:
# Load the test split from disk. n_requested samples are asked for; afterwards
# n_data holds the number of anchors actually loaded (201 in the recorded output),
# which is the count every later cell relies on.
n_requested = 2000
ds = DataSource('/media/scratch/berlukas/spherical/test', n_requested, 5)
ds.load(n_requested)
n_data = len(ds.anchors)

Loading anchors from:	 /media/scratch/berlukas/spherical/test/anchor/
Loading positives from:	 /media/scratch/berlukas/spherical/test/positive/
Loading negatives from:	 /media/scratch/berlukas/spherical/test/negative/
Done loading dataset.
	Anchors total: 		201
	Positives total: 	201
	Negatives total: 	201


In [4]:
# Wrap the data source in a TrainingSet and build the evaluation loader.
# shuffle=False keeps samples in dataset order and drop_last=False keeps the
# final partial batch, so every one of the n_data samples is embedded.
test_set = TrainingSet(ds, restore, bandwidth)
print("Total size: ", len(test_set))
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=10,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
    drop_last=False,
)


Generating anchor spheres


HBox(children=(FloatProgress(value=0.0, max=201.0), HTML(value='')))


Generating positive spheres


HBox(children=(FloatProgress(value=0.0, max=201.0), HTML(value='')))


Generating negative spheres


HBox(children=(FloatProgress(value=0.0, max=201.0), HTML(value='')))


Generated features
Total size:  201


In [7]:
def accuracy(dista, distb, margin=0):
    """Return the fraction of triplets where ``dista`` beats ``distb``.

    A triplet counts as correct when ``dista - distb - margin < 0``, i.e. when
    ``dista < distb + margin``.

    Args:
        dista: tensor of anchor-to-positive distances; dim 0 indexes the
            triplets (assumed 1-D, one distance per triplet -- confirm).
        distb: tensor of anchor-to-negative distances, broadcast-compatible
            with ``dista``.
        margin: slack subtracted in the comparison. The default of 0 gives
            the plain ``dista < distb`` test.

    Returns:
        A 0-dim float tensor on the CPU holding the correct count divided by
        ``dista.size(0)``.
    """
    # detach() replaces the deprecated .data; the result never needs gradients.
    pred = (dista - distb - margin).detach().cpu()
    return (pred < 0).sum().float() / dista.size(0)

net.eval()
n_iter = 0
# Flat accumulators for every anchor / positive embedding, in loader order.
# np.empty(1) seeds each array with one uninitialised placeholder element,
# which the nearest-neighbour evaluation strips off with `[1:]` before reshaping.
anchor_embeddings = np.empty(1)
positive_embeddings = np.empty(1)
with torch.no_grad():
    test_accs = AverageMeter()
    test_pos_dist = AverageMeter()
    test_neg_dist = AverageMeter()

    for batch_idx, (data1, data2, data3) in enumerate(test_loader):
        n_batch = data1.size(0)
        embedded_a, embedded_p, embedded_n = net(data1.cuda().float(), data2.cuda().float(), data3.cuda().float())
        dist_to_pos, dist_to_neg, loss, loss_total = criterion(embedded_a, embedded_p, embedded_n)
        writer.add_scalar('Ext_Test/Loss', loss, n_iter)

        acc = accuracy(dist_to_pos, dist_to_neg)
        test_accs.update(acc, n_batch)
        # Feed the per-triplet mean with weight n_batch, exactly like accuracy.
        # Adding the batch *sum* with weight 1 would make the "average distance"
        # scale with batch size and let the final partial batch (201 samples
        # with batch_size 10 leaves a last batch of 1) count as much as a full
        # one. Assumes one distance per triplet and that AverageMeter.update(val, n)
        # keeps an n-weighted running mean, as the accuracy update relies on.
        test_pos_dist.update(dist_to_pos.mean().item(), n_batch)
        test_neg_dist.update(dist_to_neg.mean().item(), n_batch)

        writer.add_scalar('Ext_Test/Accuracy', test_accs.avg, n_iter)
        writer.add_scalar('Ext_Test/Distance/Positive', test_pos_dist.avg, n_iter)
        writer.add_scalar('Ext_Test/Distance/Negative', test_neg_dist.avg, n_iter)

        # np.append without an axis flattens both inputs, so the accumulators
        # stay 1-D (row-major) and are reshaped to (n_data, descriptor_size) later.
        anchor_embeddings = np.append(anchor_embeddings, embedded_a.cpu().numpy().ravel())
        positive_embeddings = np.append(positive_embeddings, embedded_p.cpu().numpy().ravel())
        n_iter = n_iter + 1

In [8]:
# Drop the placeholder element seeded into each flat buffer and restore rows.
desc_anchors = anchor_embeddings[1:].reshape([n_data, descriptor_size])
desc_positives = positive_embeddings[1:].reshape([n_data, descriptor_size])

# NOTE(review): kept from the original; presumably guards recursion in KD-tree
# construction -- confirm whether it is still needed.
sys.setrecursionlimit(50000)
tree = spatial.KDTree(desc_positives)
p_norm = 2
max_pos_dist = 0.05
max_anchor_dist = 1
max_n_neighbors = 20


def _first_hit_rank(hits):
    """Return per row of boolean ``hits`` the column of its first True, or its width if none."""
    ranks = np.argmax(hits, axis=1)
    ranks[~hits.any(axis=1)] = hits.shape[1]
    return ranks


# Query every anchor once with the largest k. KDTree.query returns neighbours
# sorted by distance, so the k nearest for every smaller k are a prefix of this
# result (up to exact distance ties), instead of re-querying 20 times per anchor.
_, nn_indices = tree.query(desc_anchors, p=p_norm, k=max_n_neighbors)
nn_indices = np.asarray(nn_indices).reshape(n_data, max_n_neighbors)

# KDTree pads missing neighbours with index n_data; only the entries before the
# first padded one are real matches (the original loop stopped at that point).
valid = np.cumprod(nn_indices < n_data, axis=1).astype(bool)
retrieved = desc_positives[np.where(valid, nn_indices, 0)]  # (n_data, k, descriptor_size)

# Euclidean distance of each retrieved positive to the query's true positive
# and to the query anchor itself (always L2, independent of p_norm).
pos_dists = np.linalg.norm(retrieved - desc_positives[:, None, :], axis=2)
anchor_dists = np.linalg.norm(retrieved - desc_anchors[:, None, :], axis=2)

# 0-based neighbour rank at which each anchor first satisfies each criterion.
pos_rank = _first_hit_rank(valid & (pos_dists <= max_pos_dist))
anchor_rank = _first_hit_rank(valid & (anchor_dists <= max_anchor_dist))
idx_rank = _first_hit_rank(nn_indices == np.arange(n_data)[:, None])

for n_nearest_neighbors in tqdm(range(1, max_n_neighbors + 1)):
    # Recall@k: fraction of anchors whose first hit lies among the k nearest.
    pos_recall = np.count_nonzero(pos_rank < n_nearest_neighbors) / n_data
    anchor_recall = np.count_nonzero(anchor_rank < n_nearest_neighbors) / n_data
    idx_recall = np.count_nonzero(idx_rank < n_nearest_neighbors) / n_data

    print(f'recall {idx_recall} for {n_nearest_neighbors} neighbors')
    writer.add_scalar('Ext_Test/Precision/Positive_Distance', pos_recall, n_nearest_neighbors)
    writer.add_scalar('Ext_Test/Precision/Anchor_Distance', anchor_recall, n_nearest_neighbors)
    writer.add_scalar('Ext_Test/Precision/Index_Count', idx_recall, n_nearest_neighbors)

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

recall 0.43781094527363185 for 1 neighbors
recall 0.6417910447761194 for 2 neighbors
recall 0.7611940298507462 for 3 neighbors
recall 0.8059701492537313 for 4 neighbors
recall 0.835820895522388 for 5 neighbors
recall 0.8606965174129353 for 6 neighbors
recall 0.8706467661691543 for 7 neighbors
recall 0.8756218905472637 for 8 neighbors
recall 0.8855721393034826 for 9 neighbors
recall 0.9203980099502488 for 10 neighbors
recall 0.9353233830845771 for 11 neighbors
recall 0.9502487562189055 for 12 neighbors
recall 0.9601990049751243 for 13 neighbors
recall 0.9651741293532339 for 14 neighbors
recall 0.9751243781094527 for 15 neighbors
recall 0.9751243781094527 for 16 neighbors
recall 0.9800995024875622 for 17 neighbors
recall 0.9800995024875622 for 18 neighbors
recall 0.9800995024875622 for 19 neighbors
recall 0.9850746268656716 for 20 neighbors

