In [None]:
%matplotlib notebook
import torch
import numpy as np
from typing import NamedTuple
import tqdm
import glob

import matplotlib.pyplot as plt
import matplotlib.style
matplotlib.style.use('ggplot')

import string

from experimental.beacon_dist import model, utils, train, generate_letter_dataset, generate_ycb_dataset
import experimental.beacon_dist.multiview_dataset as mvd

In [None]:
# Evaluation shards and checkpoint. glob.glob returns matches in arbitrary,
# filesystem-dependent order; sort them so the dataset (and therefore batch)
# order is reproducible across machines and runs.
eval_path = sorted(glob.glob('/home/erick/scratch/beacon_dist/ycb_100k_scenes_4_view_strafe.part_0000[0-1].npz'))
# Fail fast instead of silently evaluating an empty dataset.
assert eval_path, 'no evaluation shards matched the glob pattern'
model_path = '/home/erick/scratch/ycb_10k_strafe_test/model_000000128.pt'
# NOTE(review): only torch's RNG is seeded here; if the query collator draws
# randomness from numpy/random, those need seeding too for reproducible queries.
torch.manual_seed(1234)

ENVIRONMENTS_PER_BATCH = 90
QUERIES_PER_ENVIRONMENT = 16

In [None]:
# Raw checkpoint kept only for interactive inspection (presumably a state dict,
# e.g. `a.keys()`); `load_model` does its own load for the model actually used.
# map_location='cpu' avoids materializing a second copy of the weights on the
# GPU when the checkpoint was saved from GPU tensors.
a = torch.load(model_path, map_location='cpu')

In [None]:
def load_model(model_path, device='cuda'):
    """Build a ConfigurationModel with fixed hyperparameters and load saved weights.

    Args:
        model_path: path to a checkpoint that ``torch.load`` returns as a mapping of
            parameter name -> tensor (a state dict).
        device: device to move the model to (defaults to ``'cuda'``, as before).

    Returns:
        The model in eval mode on ``device``.
    """
    # These hyperparameters must match the ones the checkpoint was trained with,
    # otherwise load_state_dict will fail on shape/key mismatches.
    m = model.ConfigurationModel(
        model.ConfigurationModelParams(
            descriptor_size=256,
            descriptor_embedding_size=32,
            position_encoding_factor=10000,
            num_encoder_heads=2,
            num_encoder_layers=2,
            num_decoder_heads=2,
            num_decoder_layers=2,
        )
    )
    # Deserialize onto the CPU so loading works regardless of which device the
    # checkpoint was saved from; the model is moved to `device` at the end.
    state_dict = torch.load(model_path, map_location='cpu')
    # Checkpoints saved from a DataParallel/DistributedDataParallel wrapper prefix
    # every key with 'module.' (presumably why this strip exists). Remove only that
    # leading prefix: str.replace would also mangle keys that merely contain
    # 'module.' somewhere inside them (e.g. 'submodule.weight' -> 'subweight').
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    m.load_state_dict(state_dict)
    m.eval()
    return m.to(device)

m = load_model(model_path)

In [None]:
def load_eval_dataset(dataset_paths):
    """Wrap the given shard files in a MultiviewDataset (no index file, no preloaded tables)."""
    inputs = mvd.DatasetInputs(
        file_paths=dataset_paths,
        index_path=None,
        data_tables=None,
    )
    return mvd.MultiviewDataset(inputs)


dataset = load_eval_dataset(eval_path)

In [None]:


def evaluate(m, dataset, device="cuda", num_workers=10):
    """Run ``m`` over every batch of ``dataset`` and collect results on the CPU.

    Batches are built by ``train.make_collator_fn(QUERIES_PER_ENVIRONMENT)`` with
    ``ENVIRONMENTS_PER_BATCH`` dataset items per batch; the collator yields
    ``(batch, query)`` pairs.

    Args:
        m: model, called as ``m(batch, query)``; it must already live on ``device``.
        dataset: dataset to evaluate.
        device: device batches are moved to before inference (default ``"cuda"``).
        num_workers: number of DataLoader worker processes (default 10).

    Returns:
        A tuple of per-batch lists of CPU tensors:
        ``(model_output, query_labels, context_image_ids, query_image_ids, queries)``.
        ``model_output`` holds ``sigmoid(m(batch, query))`` and ``query_labels`` holds
        ``utils.is_valid_configuration(batch.query.class_label, query)``.
    """
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=ENVIRONMENTS_PER_BATCH,
        collate_fn=train.make_collator_fn(QUERIES_PER_ENVIRONMENT),
        num_workers=num_workers,
    )

    model_output = []
    query_labels = []
    context_image_ids = []
    query_image_ids = []
    queries = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch, query in tqdm.tqdm(data_loader):
            batch = batch.to(device)
            query = query.to(device)

            # sigmoid maps the raw model output into (0, 1).
            model_out = torch.sigmoid(m(batch, query))
            # Ground-truth label for each query.
            labels = utils.is_valid_configuration(batch.query.class_label, query)

            # Move results off the device so accelerator memory does not grow
            # with the number of batches.
            model_output.append(model_out.cpu())
            query_labels.append(labels.cpu())
            context_image_ids.append(batch.context.image_id.cpu())
            query_image_ids.append(batch.query.image_id.cpu())
            queries.append(query.cpu())

    return (
        model_output,
        query_labels,
        context_image_ids,
        query_image_ids,
        queries,
    )


output_list, labels_list, context_image_ids_list, query_image_ids_list, queries_list = evaluate(m, dataset)


In [None]:
# Flatten the per-batch lists returned by `evaluate` into single tensors
# (torch.cat is the same operation as torch.concatenate).
output, labels, context_image_ids, query_image_ids = (
    torch.cat(per_batch)
    for per_batch in (output_list, labels_list, context_image_ids_list, query_image_ids_list)
)
# NOTE(review): `queries_list` is intentionally left as a per-batch list; the reason
# is not visible here (possibly the per-batch query tensors differ in shape) — TODO confirm.

In [None]:
class ErrorRates(NamedTuple):
    """Confusion-matrix counts at a single decision threshold.

    Despite the class name, the four count fields hold raw counts (0-dim tensors
    produced by ``torch.sum``), not normalized rates.
    """

    threshold: float
    true_positive: torch.Tensor
    true_negative: torch.Tensor
    false_positive: torch.Tensor
    false_negative: torch.Tensor


def compute_error_rates(outputs: torch.Tensor, labels: torch.Tensor, threshold: float) -> ErrorRates:
    """Count TP/TN/FP/FN when ``outputs > threshold`` is treated as the positive prediction.

    Args:
        outputs: scores to threshold.
        labels: ground truth, broadcastable against ``outputs``; any nonzero value
            counts as positive.
        threshold: strict decision threshold (``outputs == threshold`` is negative).

    Returns:
        An ``ErrorRates`` whose count fields are 0-dim integer tensors.
    """
    # Stay entirely in torch: np.logical_not on a tensor forces a NumPy conversion,
    # which fails for CUDA tensors and for tensors that require grad.
    predicted_pos = outputs > threshold
    predicted_neg = torch.logical_not(predicted_pos)
    actual_neg = torch.logical_not(labels)

    return ErrorRates(
        threshold=threshold,
        true_positive=torch.sum(torch.logical_and(predicted_pos, labels)),
        true_negative=torch.sum(torch.logical_and(predicted_neg, actual_neg)),
        false_positive=torch.sum(torch.logical_and(predicted_pos, actual_neg)),
        false_negative=torch.sum(torch.logical_and(predicted_neg, labels)),
    )


In [None]:
def plot_results(outputs, labels, thresholds=(0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99)):
    """Plot an ROC curve by sweeping the decision threshold over ``outputs``.

    For each threshold, ``compute_error_rates`` counts TP/TN/FP/FN and the point
    (FP / (FP + TN), TP / (TP + FN)) is plotted and labelled with its threshold.

    Args:
        outputs: scores compared against each threshold with ``>``.
        labels: ground-truth labels matching ``outputs``.
        thresholds: iterable of thresholds to evaluate; defaults to the original sweep.

    Returns:
        The matplotlib Axes containing the curve.
    """
    # Materialize once: the sequence is iterated twice (counts and annotations),
    # which would silently drop the annotations if a generator were passed.
    thresholds = list(thresholds)
    error_rates = [compute_error_rates(outputs, labels, t) for t in thresholds]

    # The counts are 0-dim tensors, so an empty class gives 0/0 = NaN rather than
    # ZeroDivisionError; convert to plain floats for matplotlib.
    fprs = [float(er.false_positive / (er.false_positive + er.true_negative)) for er in error_rates]
    tprs = [float(er.true_positive / (er.true_positive + er.false_negative)) for er in error_rates]

    fig, ax = plt.subplots()
    ax.plot(fprs, tprs, 'o-')
    ax.set_xlabel('$FP/(FP+TN)$')
    ax.set_ylabel('$TP/(TP+FN)$')
    ax.set_title('ROC Curve')
    for x, y, threshold in zip(fprs, tprs, thresholds):
        # Text at a NaN position cannot be placed; skip it.
        if np.isfinite(x) and np.isfinite(y):
            ax.text(x, y, str(threshold))

    fig.tight_layout()
    return ax

plot_results(output, labels);

In [None]:
# Display the evaluation shard paths matched by the glob in the config cell.
eval_path

In [None]:
# Pick the single worst prediction: the index where the sigmoid output most exceeds
# its label (assuming labels are 0/1, this is the most confident false positive).
# FIXME: this cell is stale and fails on a fresh kernel:
#   - `image_ids`, `image_descriptors` and `queries` are never defined in this notebook
#     (`evaluate` produced `context_image_ids`, `query_image_ids` and `queries_list`);
#   - `generate_data` is never imported (only generate_letter_dataset and
#     generate_ycb_dataset are);
#   - the `char`/`x`/`y`/`theta` fields look like letter-dataset records, not the YCB
#     shards loaded above — TODO confirm which dataset this was written for.
# NOTE(review): if `labels` is a bool tensor, `output - labels` raises in torch
# (subtraction with a bool tensor is unsupported); it would need `labels.float()`.
# NOTE(review): np.argmax returns a flat index when given a multi-dimensional input.
batch_idx = np.argmax(output - labels)
image_id = image_ids[batch_idx].item()
# Presumably a record table with image_id/char/x/y/theta fields — confirm.
letters = image_descriptors[image_descriptors['image_id'] == image_id]
query = queries[batch_idx]
# Map each character to its placement so the scene can be re-rendered below.
letter_set = {x['char']: generate_data.LetterPosition(x=x['x'], y=x['y'], angle=x['theta']) for x in letters}

In [None]:
# Render a 1280x720 image from `letter_set` and overlay the keypoints of
# `dataset[image_id]`, colored by class label.
# FIXME: `generate_data` is never imported in this notebook, and this cell depends on
# the stale worst-prediction cell, so it fails on a fresh kernel.
# NOTE(review): assumes `dataset[image_id]` indexes by image id and returns an object
# with `x`, `y` and `class_label` fields — confirm against MultiviewDataset.
image = generate_data.image_from_letter_set(letter_set, width=1280, height=720)
kps = dataset[image_id]
plt.figure()
plt.imshow(image)
plt.scatter(kps.x, kps.y, c=kps.class_label)
plt.colorbar()

In [None]:
# Same overlay as the full-keypoint figure, restricted to the keypoints selected by `query`.
fig, ax = plt.subplots()
ax.imshow(image)
selected = ax.scatter(kps.x[query], kps.y[query], c=kps.class_label[query])
fig.colorbar(selected)