In [None]:
%matplotlib notebook
import torch
import numpy as np
from typing import NamedTuple
import tqdm

import matplotlib.pyplot as plt
import matplotlib.style
matplotlib.style.use('ggplot')

from experimental.beacon_dist import model, utils, train

In [None]:
# Seed PyTorch's global RNG up front so repeated runs of the notebook use the
# same random stream.
torch.manual_seed(1234)

# Batching configuration read by evaluate(): the DataLoader batch size and the
# argument handed to train.make_collator_fn.
ENVIRONMENTS_PER_BATCH, QUERIES_PER_ENVIRONMENT = 64, 16

# NOTE(review): machine-specific absolute paths -- adjust before running this
# notebook on another machine.
model_path = '/home/erick/scratch/initial_classification_test/model_000001000.pt'
eval_path = '/home/erick/scratch/beacon_dist/ABC_10k_images.npy'

In [None]:
def load_model(model_path, device='cuda'):
    """Build a ConfigurationModel, restore its weights and switch it to eval mode.

    Args:
        model_path: Location of a checkpoint readable by ``torch.load`` whose
            contents are passed to ``load_state_dict``.
        device: Device the model is moved to before being returned. Defaults
            to ``'cuda'``, which matches the previous hard-coded behaviour.

    Returns:
        The model, in evaluation mode, after ``.to(device)``.
    """
    # Architecture hyper-parameters; the checkpoint must have been produced by
    # a model built with the same values for load_state_dict to succeed.
    params = model.ConfigurationModelParams(
        descriptor_size=256,
        descriptor_embedding_size=32,
        position_encoding_factor=10000,
        num_encoder_heads=2,
        num_encoder_layers=2,
        num_decoder_heads=2,
        num_decoder_layers=2,
    )
    net = model.ConfigurationModel(params)

    # Load onto the CPU first so the checkpoint can be read no matter which
    # device it was saved from; the model is moved to `device` afterwards.
    state_dict = torch.load(model_path, map_location='cpu')
    net.load_state_dict(state_dict)
    net.eval()
    return net.to(device)


m = load_model(model_path)

In [None]:
def load_eval_dataset(dataset_path):
    """Load the array stored at `dataset_path` with np.load and wrap it in a utils.Dataset."""
    raw_data = np.load(dataset_path)
    return utils.Dataset(data=raw_data)


dataset = load_eval_dataset(eval_path)

In [None]:


def evaluate(m, dataset, device="cuda"):
    """Run the model over `dataset` and collect its predictions and the labels.

    Batches are built by a DataLoader using ENVIRONMENTS_PER_BATCH as the batch
    size and train.make_collator_fn(QUERIES_PER_ENVIRONMENT) as the collate
    function; each batch yields a (batch, queries) pair.

    Args:
        m: Callable model invoked as ``m(batch, queries)``; its output is
            passed through a sigmoid. It is expected to already live on
            `device`.
        dataset: Dataset handed to ``torch.utils.data.DataLoader``.
        device: Device each batch and its queries are moved to before the
            forward pass. Defaults to ``"cuda"``, the previous hard-coded value.

    Returns:
        A ``(outputs, labels)`` tuple of CPU tensors, each the concatenation
        over all batches of the sigmoid outputs and of the results of
        ``utils.is_valid_configuration`` respectively.
    """
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=ENVIRONMENTS_PER_BATCH,
        collate_fn=train.make_collator_fn(QUERIES_PER_ENVIRONMENT),
    )

    model_output = []
    query_labels = []

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        for batch, queries in tqdm.tqdm(data_loader):
            batch = batch.to(device)
            queries = queries.to(device)

            model_out = torch.sigmoid(m(batch, queries))
            # Labels are derived from the batch's class labels and the queries.
            batch_labels = utils.is_valid_configuration(batch.class_label, queries)

            # Move results to the CPU right away so they do not accumulate on
            # the compute device across batches.
            model_output.append(model_out.to('cpu'))
            query_labels.append(batch_labels.to('cpu'))

    return torch.concat(model_output), torch.concat(query_labels)


output, labels = evaluate(m, dataset)


In [None]:
# Inspect the shape of the concatenated model outputs returned by evaluate().
output.shape

In [None]:
# Inspect the shape of the concatenated labels returned by evaluate().
labels.shape

In [None]:
class ErrorRates(NamedTuple):
    """Confusion-matrix entries obtained at a single decision threshold.

    Despite the name, compute_error_rates fills the four outcome fields with
    raw counts (0-dim tensors produced by ``torch.sum``), not normalised
    rates; callers derive rates by dividing the counts.
    """
    # Decision threshold the outputs were compared against.
    threshold: float
    # Count of predicted-positive entries whose label is true.
    true_positive: torch.Tensor
    # Count of predicted-negative entries whose label is false.
    true_negative: torch.Tensor
    # Count of predicted-positive entries whose label is false.
    false_positive: torch.Tensor
    # Count of predicted-negative entries whose label is true.
    false_negative: torch.Tensor


def compute_error_rates(outputs: torch.Tensor, labels: torch.Tensor, threshold: float) -> ErrorRates:
    """Count true/false positives and negatives for one threshold.

    An entry is predicted positive when its output is strictly greater than
    `threshold`.

    Args:
        outputs: Tensor of model scores.
        labels: Tensor of ground-truth labels, combined element-wise with the
            thresholded outputs (non-boolean values are treated as true when
            non-zero by the torch logical ops).
        threshold: Decision threshold applied to `outputs`.

    Returns:
        An ErrorRates tuple holding `threshold` and the four counts as 0-dim
        tensors.
    """
    predicted_positive = outputs > threshold
    predicted_negative = torch.logical_not(predicted_positive)
    # Use torch (not numpy) for the negation so the whole computation stays in
    # torch and works for tensors on any device.
    actual_negative = torch.logical_not(labels)

    return ErrorRates(
        threshold=threshold,
        true_positive=torch.sum(torch.logical_and(predicted_positive, labels)),
        true_negative=torch.sum(torch.logical_and(predicted_negative, actual_negative)),
        false_positive=torch.sum(torch.logical_and(predicted_positive, actual_negative)),
        false_negative=torch.sum(torch.logical_and(predicted_negative, labels)),
    )


In [None]:
def plot_results(outputs, labels, thresholds=None):
    """Plot an ROC-style curve of the model outputs against the labels.

    For every threshold the confusion counts are obtained from
    compute_error_rates; the x value is FP / (FP + TN) and the y value is
    TP / (TP + FN). Each point is annotated with its threshold.

    Args:
        outputs: Model scores forwarded to compute_error_rates.
        labels: Ground-truth labels forwarded to compute_error_rates.
        thresholds: Optional iterable of decision thresholds. When omitted,
            the original fixed grid from 0.01 to 0.99 is used.

    Returns:
        None. A new matplotlib figure is created as a side effect.
    """
    if thresholds is None:
        thresholds = [0.01, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.99]
    else:
        # Materialise once: the thresholds are iterated twice below.
        thresholds = list(thresholds)

    error_rates = [compute_error_rates(outputs, labels, threshold) for threshold in thresholds]

    # Convert to plain floats so plotting does not depend on matplotlib's
    # implicit coercion of 0-dim tensors.
    xs = [float(er.false_positive / (er.false_positive + er.true_negative)) for er in error_rates]
    ys = [float(er.true_positive / (er.true_positive + er.false_negative)) for er in error_rates]

    fig, ax = plt.subplots()
    ax.plot(xs, ys, 'o-')
    ax.set_xlabel('$FP/(FP+TN)$')
    ax.set_ylabel('$TP/(TP+FN)$')
    for x, y, threshold in zip(xs, ys, thresholds):
        ax.text(x, y, str(threshold))

    ax.set_title('ROC Curve')

    fig.tight_layout()


plot_results(output, labels)