## Evaluate slick association results on dataset

This notebook expects a GeoJSON file exported from the slick explorer as input. The file can be generated by clicking the `Run All` button.

In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
import matplotlib.pyplot as plt
import os

### Load results file along with truth

In [None]:
# GeoJSON exported from the slick explorer; the relative path is resolved
# against the notebook's current working directory.
results_file = 'results_20230413.geojson'
results = gpd.read_file(results_file)

In [None]:
# NOTE: absolute path under one user's home directory — point this at your
# own copy of the truth CSV before running on another machine.
truth_file = '/home/k3blu3/datasets/cerulean/slick_truth_year1.csv'
truth = pd.read_csv(truth_file)

In [None]:
# keep only PID (used as the merge key below) and the HITL / Algo label columns
needed_columns = ['PID', 'HITL MMSI', 'HITL Confidence', 'Algo MMSI', 'Algo Score']
truth = truth.loc[:, needed_columns]

In [None]:
# inner join on PID: only rows whose PID is present in both frames are kept
results = pd.merge(left=results, right=truth, how='inner', on='PID')

In [None]:
# a 'HITL MMSI' cell may list several comma-separated ids; split them and
# give every id its own row in a new 'truth' column
hitl_ids = results['HITL MMSI'].str.split(',')
results = results.assign(truth=hitl_ids).explode('truth')

In [None]:
# number of distinct PIDs being evaluated (a missing PID counts as one value)
results['PID'].nunique(dropna=False)

In [None]:
# preview the first five rows (the default for head)
results.head()

In [None]:
# tidy column types: strip padding left around each id by the comma split,
# and store every score column as float32
results['truth'] = results['truth'].str.strip()
for score_column in ('temporal_score', 'overlap_score', 'frechet_dist', 'total_score'):
    results[score_column] = results[score_column].astype(np.float32)

### Evaluate results against truth

In [None]:
# For every slick (one `slick_index` within one PID) keep a single record
# describing its best candidate, i.e. the row with the highest total_score.
records = []
for _, pid_rows in results.groupby('PID'):
    # treat each slick independently
    for _, slick_rows in pid_rows.groupby('slick_index'):
        # NOTE(review): rows duplicated by the earlier explode share the same
        # total_score, so which 'truth' value ends up on top is decided by the
        # sort — confirm this is acceptable for PIDs with several HITL MMSIs.
        best = slick_rows.sort_values('total_score', ascending=False).iloc[0]

        records.append({
            'PID': best.PID,
            'temporal_score': best.temporal_score,
            'overlap_score': best.overlap_score,
            'frechet_dist': best.frechet_dist,
            'Krishna MMSI': best.traj_id,
            'Krishna Score': best.total_score,
            'Truth MMSI': best['truth'],
            'Algo MMSI': best['Algo MMSI'],
        })

comparisons = pd.DataFrame(records)

In [None]:
def _mmsi_to_str(value):
    """Render an MMSI-like identifier as a plain string.

    Integral floats lose their trailing ``.0`` so that an identifier held as
    a float (e.g. 367000000.0, which is how pandas stores a numeric column
    that contains missing values) compares equal to the string form used by
    the truth labels ('367000000'). Every other value goes through ``str``.
    """
    if isinstance(value, (float, np.floating)) and np.isfinite(value):
        if float(value).is_integer():
            return str(int(value))
    return str(value)


# Parallel lists, one entry per (PID, non-DARK truth MMSI) pair.
pid = []
krishna_score = []
krishna_mmsi = []
algo_mmsi = []
truth_mmsi = []

for p, group in comparisons.groupby('PID'):
    # Normalise once per PID so the match test below and the stored value use
    # the same string form; the later `krishna_mmsi == truth_mmsi` comparison
    # would otherwise silently fail on an int-vs-str mismatch.
    krishna_ids = group['Krishna MMSI'].map(_mmsi_to_str)
    scores = group['Krishna Score'].astype(np.float32)
    algo_id = _mmsi_to_str(group['Algo MMSI'].iloc[0])

    # The loop variable is deliberately not called `truth`: that name is the
    # truth DataFrame loaded earlier and must survive this cell so earlier
    # cells can be re-run.
    # NOTE(review): a missing truth value is not equal to 'DARK', so it is
    # evaluated (and can never match) — confirm that is intended.
    for truth_id in pd.unique(group['Truth MMSI']):
        if truth_id == 'DARK':
            continue

        # Prefer the first slick whose top candidate is the true vessel;
        # otherwise fall back to the highest-scoring candidate for this PID.
        matches = krishna_ids == truth_id
        chosen = matches.idxmax() if matches.any() else scores.idxmax()

        pid.append(p)
        algo_mmsi.append(algo_id)
        truth_mmsi.append(truth_id)
        krishna_mmsi.append(krishna_ids.loc[chosen])
        krishna_score.append(scores.loc[chosen])

In [None]:
# assemble the per-(PID, truth MMSI) lists built above into one table
res = pd.DataFrame(
    dict(
        pid=pid,
        krishna_score=krishna_score,
        krishna_mmsi=krishna_mmsi,
        algo_mmsi=algo_mmsi,
        truth_mmsi=truth_mmsi,
    )
)

In [None]:
# a prediction is a hit when it equals the truth MMSI on the same row
krishna_hit = res['krishna_mmsi'] == res['truth_mmsi']
algo_hit = res['algo_mmsi'] == res['truth_mmsi']

krishna_correct = res[krishna_hit]
krishna_incorrect = res[~krishna_hit]
algo_correct = res[algo_hit]
algo_incorrect = res[~algo_hit]

In [None]:
# share of evaluated rows each method got right, in percent
n_evaluated = len(res)
krishna_pct = 100 * len(krishna_correct) / n_evaluated
algo_pct = 100 * len(algo_correct) / n_evaluated

In [None]:
# first line: krishna accuracy (%), second line: algo accuracy (%)
print(krishna_pct, algo_pct, sep='\n')

### Quick plot of scores

In [None]:
# side-by-side histograms of krishna_score for correct vs. incorrect matches
plt.figure(dpi=200, figsize=(10, 5))
plt.style.use('ggplot')

panels = (
    (krishna_correct, 'red', 'Correct Matches'),
    (krishna_incorrect, 'blue', 'Incorrect Matches'),
)
for position, (frame, colour, heading) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.hist(frame.krishna_score, color=colour, alpha=0.7)
    plt.title(heading)
    plt.xlabel('Score')
    plt.ylabel('Counts')