<a href="https://colab.research.google.com/github/mgozon/DLG-UROP/blob/main/dlg_stats.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DLG Statistics
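This notebook provides utilities for evaluating Deep Leakage from Gradients (DLG) reconstructions against the ground-truth data:
- `assign_guess(guess, gt_dataset, n, verbose = False)`: returns `guess_perm`, the guesses reordered by a one-to-one (linear sum) assignment so that `guess_perm[j]` lines up with ground-truth sample `j`
- `assign_best(guess, gt_dataset, n, verbose = False)`: returns `best_match`, the closest ground-truth sample for each guess (not one-to-one)
- `compute_stats(guess_perm, gt_data, recovered_threshold = 0.25)`: returns `rel_errors`, `recovered_rate`, `cos_angles`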

In [None]:
import torch
from scipy.optimize import linear_sum_assignment
from math import sqrt

# Matchings / Assignments

In [3]:
def assign_guess(guess, gt_dataset, n, verbose = False):
    """Reorder guesses so each one lines up with its ground-truth sample.

    Solves a one-to-one assignment (``scipy.optimize.linear_sum_assignment``)
    between the first ``n`` guesses and the first ``n`` ground-truth samples.
    The assignment minimises the total relative error
    ``||guess_i - gt_j|| / ||gt_j||``, using the Euclidean norm over all
    non-batch dimensions.

    Args:
        guess: tensor of reconstructed samples indexed along dim 0 (>= n rows).
        gt_dataset: tensor of ground-truth samples indexed along dim 0 (>= n rows).
        n: number of samples to match.
        verbose: if True, print each ground-truth sample next to its assigned
            guess together with their relative error.

    Returns:
        guess_perm: tensor shaped like ``gt_dataset``, with the dtype and device
            of ``guess``. ``guess_perm[j]`` is the guess assigned to
            ``gt_dataset[j]``; rows ``j >= n`` are left as zeros.

    Raises:
        ZeroDivisionError: if one of the first n ground-truth samples has zero
            norm (relative error is undefined).
    """
    # Allocate on guess's dtype/device so callers get back what they passed in
    guess_perm = guess.new_zeros(gt_dataset.shape)
    if n == 0:
        return guess_perm

    # Costs are computed on detached float64 copies: precise, and no autograd graph
    guess_flat = guess[:n].detach().reshape(n, -1).to(torch.float64)
    gt_flat = gt_dataset[:n].detach().reshape(n, -1).to(device=guess_flat.device, dtype=torch.float64)
    gt_norms = torch.linalg.vector_norm(gt_flat, dim=1)
    if bool((gt_norms == 0).any()):
        raise ZeroDivisionError('assign_guess: a ground-truth sample has zero norm, so relative error is undefined')

    # cost_matrix[i, j] = relative error of guess i with respect to ground truth j
    cost_matrix = torch.linalg.vector_norm(guess_flat[:, None, :] - gt_flat[None, :, :], dim=2) / gt_norms
    row_ind, col_ind = linear_sum_assignment(cost_matrix.cpu().numpy())

    # Guess row_ind[k] is matched to ground truth col_ind[k]; use both explicitly
    rows = torch.as_tensor(row_ind, device=guess.device)
    cols = torch.as_tensor(col_ind, device=guess.device)
    guess_perm[cols] = guess[rows]

    if (verbose):
        print('gt data vs guess perm (linear sum assignment): ')
        # Relative error of the guess that ended up at each ground-truth index
        assigned_error = cost_matrix.new_zeros(n)
        assigned_error[cols] = cost_matrix[rows, cols]
        for i in range(n):
            print(gt_dataset[i], guess_perm[i], 'RE (AE / TN): ', assigned_error[i].item())

    return guess_perm

print('defined: assign_guess(guess, gt_dataset, n, verbose = False): guess_perm')

defined: assign_guess(guess, gt_dataset, n, verbose = False): guess_perm


In [4]:
# NOTE: this returns an array with the best corresponding gt_data since matching is not one to one
# --> this is NOT the same format as assign_guess()
def assign_best(guess, gt_dataset, n, verbose = False):
    cost_matrix = torch.tensor([[sqrt(torch.sum((guess[i]-gt_dataset[j])**2).item()) / sqrt(torch.sum(gt_dataset[j]**2).item()) for j in range(n)] for i in range(n)])
    match_idx = torch.argmin(cost_matrix, dim=1)
    best_match = gt_dataset[match_idx]

    if (verbose):
        # print('relative error matrix:')
        # print(cost_matrix)
        print('closest match: assignment and relative error (%):')
        for i in range(n):
            RE = sqrt(torch.sum((best_match[i]-guess[i])**2).item()) / sqrt(torch.sum(best_match[i]**2).item())
            print(guess[i], best_match[i], 100*RE)
    
    return best_match

print('defined: assign_best(guess, gt_dataset, n, verbose = False): best_match')

defined: assign_best(guess, gt_dataset, n, verbose = False): best_match


# DLG Statistics

In [None]:
#@title Old Statistics (SE, MSE, NE, cos distance) (not updated)
def compute_stats_old(n_elts, guess_perm, gt_data):
    """Legacy per-sample error statistics (superseded by ``compute_stats``).

    Compares ``guess_perm[i]`` with ``gt_data[i]`` for every i in range(n_elts).

    Returns:
        SEs: per-sample squared error, sum((guess_i - gt_i) ** 2), as floats.
        MSE: mean of SEs over n_elts.
        n_errors: per-sample squared error divided by ||gt_i|| ** 2.
        cos_angles: per-sample cosine of the angle between guess_i and gt_i.
    """
    SEs, n_errors, cos_angles = [], [], []
    for idx in range(n_elts):
        estimate, target = guess_perm[idx], gt_data[idx]
        sq_err = torch.sum((estimate - target) ** 2)
        target_norm = torch.linalg.norm(target)

        SEs.append(sq_err.item())
        n_errors.append((sq_err / (target_norm ** 2)).item())

        dot = torch.sum(estimate * target).item()
        cos_angles.append((dot / (target_norm * torch.linalg.norm(estimate))).item())

    MSE = sum(SEs) / n_elts
    return SEs, MSE, n_errors, cos_angles

In [5]:
def compute_stats(guess_perm, gt_data, recovered_threshold = 0.25):
    """Per-sample reconstruction statistics for aligned guesses and ground truth.

    Row i of ``guess_perm`` is compared with row i of ``gt_data``, so align the
    guesses first (e.g. with ``assign_guess``). Norms are Euclidean over all
    non-batch dimensions.

    Args:
        guess_perm: tensor of reconstructed samples indexed along dim 0.
        gt_data: tensor of ground-truth samples with the same length along dim 0.
        recovered_threshold: a sample counts as recovered when its relative
            error is <= this value.

    Returns:
        rel_errors: list of floats, ``||gt_i - guess_i|| / ||gt_i||``.
        recovered_rate: fraction of samples with rel_error <= recovered_threshold
            (0.0 when there are no samples).
        cos_angles: list of floats, the cosine of the angle between guess_i and
            gt_i; nan when guess_i is all zeros.

    Raises:
        AssertionError: if the inputs differ in length along dim 0.
        ZeroDivisionError: if some gt_i has zero norm (relative error undefined).
    """
    assert guess_perm.shape[0] == gt_data.shape[0], (
        f'length mismatch: {guess_perm.shape[0]} guesses vs {gt_data.shape[0]} ground-truth samples')
    n_elts = guess_perm.shape[0]
    if n_elts == 0:
        return [], 0.0, []

    # Flatten each sample and work in float64 on gt_data's device, detached from autograd
    gt_flat = gt_data.detach().reshape(n_elts, -1).to(torch.float64)
    guess_flat = guess_perm.detach().reshape(n_elts, -1).to(device=gt_flat.device, dtype=torch.float64)

    gt_norms = torch.linalg.vector_norm(gt_flat, dim=1)
    if bool((gt_norms == 0).any()):
        raise ZeroDivisionError('compute_stats: a ground-truth sample has zero norm, so relative error is undefined')
    guess_norms = torch.linalg.vector_norm(guess_flat, dim=1)

    rel_errors = torch.linalg.vector_norm(gt_flat - guess_flat, dim=1) / gt_norms
    # An all-zero guess gives 0/0 = nan here, i.e. the angle is undefined
    cos_angles = (guess_flat * gt_flat).sum(dim=1) / (gt_norms * guess_norms)

    recovered_rate = (rel_errors <= recovered_threshold).sum().item() / n_elts
    return rel_errors.tolist(), recovered_rate, cos_angles.tolist()

print('defined: compute_stats(guess_perm, gt_data, recovered_threshold = 0.25): rel_errors, recovered_rate, cos_angles')

defined: compute_stats(guess_perm, gt_data, recovered_threshold = 0.25): rel_errors, recovered_rate, cos_angles
