In [1]:
import os
import numpy as np 
import pandas as pd
import librosa
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import pandas as pd
import altair as alt
# Turn off Altair's row-count check so charts can be built from large result tables.
alt.data_transformers.disable_max_rows()
# Directory prefix for the transformer inference CSVs.
# NOTE(review): this variable is not referenced again in the visible part of the
# notebook; later cells spell the path out literally.
transformer_output_path = "transformer_inference_csv/"

# Before evaluation

## Installing the environment

If you haven't already, run the following cell to install the environment.

In [None]:
# Shell command: create the conda environment described by ../environment.yml.
!conda env create -f ../environment.yml

## Running inference

Before evaluating you must run the pipeline on your dataset with the following cell / line

In [None]:
# Shell command: run the inference script that lives next to this notebook.
!python candidate_revision_inference.py

# Click detection

## Phase 1 click evaluation

### Define evaluation functions

In [2]:
from scipy.ndimage import label as get_connected_components

def supress_multiple_pred(data_df,sending_th):
    """Keep only the most confident prediction of every above-threshold run.

    Consecutive rows whose ``model_confidence`` is strictly greater than
    ``sending_th`` form one run (a 1-D connected component). For each run only
    the row with the highest confidence survives; every other row is set to 0.

    Parameters
    ----------
    data_df : pandas.DataFrame (or any mapping of column name -> sequence)
        Must provide a ``"model_confidence"`` column.
    sending_th : float
        Rows with confidence ``<= sending_th`` never belong to a run.

    Returns
    -------
    numpy.ndarray of float, one entry per row
        The run maximum at the position of that maximum, 0 elsewhere.
    """
    # Work on positional numpy values: mixing index *labels* (Series.idxmax)
    # with positional writes breaks as soon as the frame has a non-default index.
    model_conf = np.asarray(data_df["model_confidence"], dtype=float)

    components, num_components = get_connected_components(model_conf > sending_th)

    adjusted_pred = np.zeros(len(model_conf), dtype=float)
    for component_id in range(1, num_components + 1):
        positions = np.flatnonzero(components == component_id)
        # argmax returns the first maximum, i.e. the earliest row on ties.
        best_position = positions[np.argmax(model_conf[positions])]
        adjusted_pred[best_position] = model_conf[best_position]
    return adjusted_pred


def get_TP_FP_FN(pred_list, gt_list, time_diff_threshold = 0.0045351):
    """Match predicted click times to ground-truth click times.

    A prediction and a ground-truth click are matched (true positive) when
    they are each other's nearest neighbour in time AND their absolute time
    difference is strictly below ``time_diff_threshold``. Every unmatched
    ground-truth click is a false negative and every unmatched prediction is
    a false positive. On equal distances the earliest element wins.

    Parameters
    ----------
    pred_list : sequence of predicted click times.
    gt_list : sequence of ground-truth click times.
    time_diff_threshold : float
        Maximum (exclusive) time difference for a match, in the same unit as
        the two time lists.

    Returns
    -------
    is_good : array shaped like ``pred_list``; 1 where the prediction matched.
    is_caught : array shaped like ``gt_list``; 1 where the click was matched.
    TP, FP, FN : int

    Notes
    -----
    Builds an ``(len(gt_list), len(pred_list))`` distance matrix, so memory
    grows with the product of the two lengths.
    """
    is_good = np.zeros_like(pred_list)
    is_caught = np.zeros_like(gt_list)

    pred_times = np.asarray(pred_list)
    gt_times = np.asarray(gt_list)
    n_pred = len(pred_times)
    n_gt = len(gt_times)

    # With one side empty nothing can match: all predictions are false
    # positives and all ground-truth clicks are false negatives.
    if n_pred == 0 or n_gt == 0:
        return is_good, is_caught, 0, n_pred, n_gt

    # time_diffs[g, p] = |gt time g - predicted time p|
    time_diffs = np.abs(gt_times[:, None] - pred_times[None, :])
    closest_pred_of_gt = time_diffs.argmin(axis=1)  # first minimum per ground truth
    closest_gt_of_pred = time_diffs.argmin(axis=0)  # first minimum per prediction

    gt_ids = np.arange(n_gt)
    mutual = closest_gt_of_pred[closest_pred_of_gt] == gt_ids
    close_enough = time_diffs[gt_ids, closest_pred_of_gt] < time_diff_threshold
    matched = mutual & close_enough

    is_caught[matched] = 1
    is_good[closest_pred_of_gt[matched]] = 1

    # Mutual matches are one-to-one, so each one uses up exactly one
    # prediction and one ground-truth click.
    TP = int(matched.sum())
    FP = n_pred - TP
    FN = n_gt - TP

    return is_good, is_caught, TP, FP, FN

In [3]:
def calc_metrics(TP,FP,FN):
    """Print and return recall, precision and F-score for the given counts.

    Any metric whose denominator is zero is reported as 0 instead of raising
    ``ZeroDivisionError`` (precision was already handled this way).

    Parameters
    ----------
    TP, FP, FN : numbers of true positives, false positives, false negatives.

    Returns
    -------
    (recall, precision, f_score)
    """
    if TP+FN == 0:
        recall = 0
    else:
        recall = TP/(TP+FN)

    if TP+FP == 0:
        precision = 0
    else:
        precision = TP/(TP+FP)

    if 2*TP+FP+FN == 0:
        f_score = 0
    else:
        f_score = (2*TP)/(2*TP+FP+FN)

    print("TP",TP)
    print("FP",FP)
    print("FN",FN)

    print('recall',recall)
    print('precision',precision)
    print('f_score',f_score)

    return recall, precision, f_score

In [4]:
# Confidence above which a window is grouped into a candidate run by supress_multiple_pred.
sending_th = 0.3
# Decision thresholds swept by old_model_graphs (denser near 1).
th_list = [0.3,0.35,0.4,0.45,0.5,0.55,0.6,0.65,0.7,0.75,0.8,0.85,0.9,0.91,0.92,0.93,0.94,0.95,0.96,0.975,0.98,0.99,0.991,0.992,0.993,0.994,0.995,0.996,0.997,0.998,0.999,0.9995,0.9999,1]

In [5]:
def old_model_graphs(data_path, thresholds=None, candidate_th=None):
    """Sweep decision thresholds over the phase-1 click detector and plot the results.

    Every ``.csv`` file in ``data_path`` is read, overlapping detections are
    collapsed with ``supress_multiple_pred`` and, for each threshold, the
    surviving predictions are matched against the labelled clicks with
    ``get_TP_FP_FN``. Counts are summed over all files before the metrics are
    computed.

    Parameters
    ----------
    data_path : str
        Directory with one CSV per recording; each needs the columns
        ``model_confidence``, ``label`` and ``window_center_seconds``.
    thresholds : sequence of float, optional
        Thresholds to evaluate. Defaults to the notebook-level ``th_list``.
    candidate_th : float, optional
        Threshold handed to ``supress_multiple_pred``. Defaults to the
        notebook-level ``sending_th`` as it is at call time.

    Returns
    -------
    tuple of four altair charts
        (F1 vs threshold, precision vs recall, precision vs threshold,
        recall vs threshold).
    """
    if thresholds is None:
        thresholds = th_list
    if candidate_th is None:
        candidate_th = sending_th
    thresholds = list(thresholds)

    th_TP = dict.fromkeys(thresholds, 0)
    th_FP = dict.fromkeys(thresholds, 0)
    th_FN = dict.fromkeys(thresholds, 0)

    # Sorted for a reproducible processing order; anything that is not a CSV
    # (e.g. a checkpoint directory) is skipped instead of crashing read_csv.
    for data_name in sorted(os.listdir(data_path)):
        if not data_name.endswith(".csv"):
            continue
        print("Processing",data_name)
        data_df = pd.read_csv(os.path.join(data_path, data_name))

        preds = supress_multiple_pred(data_df,candidate_th)
        window_times = np.array(data_df["window_center_seconds"])
        # The ground truth does not depend on the threshold: compute it once per file.
        gt_times = window_times[np.asarray(data_df["label"] == 1)]

        for pred_th in thresholds:
            pred_times = window_times[preds >= pred_th]
            is_good, is_caught, TP, FP, FN = get_TP_FP_FN(pred_times,gt_times)

            th_TP[pred_th] += TP
            th_FP[pred_th] += FP
            th_FN[pred_th] += FN

    # Anchor point at threshold 0 (same values the notebook has always used).
    all_recall = [1]
    all_precision = [0]
    all_f1 = [0]
    all_th = [0]+thresholds

    for th in thresholds:
        TP = th_TP[th]
        FP = th_FP[th]
        FN = th_FN[th]

        # Undefined ratios: recall/F1 fall back to 0, precision to 1 (no predictions made).
        all_recall.append(TP/(TP+FN) if TP+FN != 0 else 0)
        all_precision.append(TP/(TP+FP) if TP+FP != 0 else 1)
        all_f1.append((2*TP)/(2*TP+FN+FP) if 2*TP+FN+FP != 0 else 0)

    results_df = pd.DataFrame({
        "threshold": all_th,
        "recall": all_recall,
        "precision": all_precision,
        "f1": all_f1,
    })

    def line_chart(x_field, y_field, title):
        # All four charts share colour and the [0, 1] axis domains.
        return alt.Chart(results_df).mark_line(color="#994F00").encode(
            alt.X(x_field, scale=alt.Scale(domain=[0,1])),
            alt.Y(y_field, scale=alt.Scale(domain=[0,1]))
        ).properties(title=title)

    return (
        line_chart("threshold", "f1", "F1 graph"),
        line_chart("recall", "precision", "Precision-Recall graph"),
        line_chart("threshold", "precision", "Precision graph"),
        line_chart("threshold", "recall", "Recall graph"),
    )

### Run evaluation functions

In [6]:
# Phase-1 click curves on the validation split.
val_data_path = "transformer_dataset/val/"
val_click_f1, val_click_prec_recall, val_click_precision, val_click_recall = old_model_graphs(val_data_path)
# Second name for the validation precision-recall chart.
# NOTE(review): not used again in the visible part of the notebook.
old_prec_recall = val_click_prec_recall
# 2x2 grid: F1 and precision-recall on top, precision and recall below.
alt.vconcat(alt.hconcat(val_click_f1,val_click_prec_recall),alt.hconcat(val_click_precision,val_click_recall))

Processing sw085a003_1.csv
Processing sw090a005_1.csv
Processing sw100b002_0.csv
Processing sw061b002_1.csv
Processing sw090a005_0.csv
Processing sw061b002_0.csv
Processing sw091b004_1.csv
Processing sw091b004_0.csv
Processing sw085a003_0.csv
Processing sw061b002_2.csv
Processing sw090b003_0.csv


In [7]:
# Phase-1 click curves on the test split (displayed in the next cell).
test_data_path = "transformer_dataset/test/"
test_click_f1, test_click_prec_recall, test_click_precision, test_click_recall = old_model_graphs(test_data_path)

Processing sw100b003_0.csv
Processing sw090a004_1.csv
Processing sw091a002_1.csv
Processing sw091a003_0.csv
Processing sw061b001_1.csv
Processing sw106a002_1.csv
Processing sw090a004_0.csv
Processing sw078b001_0.csv
Processing sw100b003_1.csv
Processing sw078b001_1.csv
Processing sw106a002_0.csv
Processing sw091a002_2.csv
Processing sw091a003_1.csv
Processing sw091a003_2.csv
Processing sw106a002_2.csv
Processing sw091a002_0.csv
Processing sw061b001_2.csv
Processing sw061b001_0.csv


In [8]:
# 2x2 grid: F1 and precision-recall on top, precision and recall below.
alt.vconcat(alt.hconcat(test_click_f1,test_click_prec_recall),alt.hconcat(test_click_precision,test_click_recall))

## Calculate number of clicks dropped in the first phase

In [9]:
# Inputs for counting the clicks lost in the first phase.
val_data_path = "transformer_dataset/val/"
test_data_path = "transformer_dataset/test/"
# NOTE(review): this rebinds the notebook-level sending_th (0.3 earlier). old_model_graphs
# reads that global at call time, so re-running it after this cell uses 0.7.
sending_th = 0.7

In [10]:
# Count phase-1 matches on the validation split at the current sending_th.
# The false negatives found here can never be recovered by the second phase,
# so they are carried forward as val_forced_FN.
all_TP, all_FP, all_FN = (0,0,0)

# Sorted for a reproducible order; non-CSV entries (e.g. checkpoint folders) are skipped.
for data_name in sorted(os.listdir(val_data_path)):
    if not data_name.endswith(".csv"):
        continue
    data_df = pd.read_csv(os.path.join(val_data_path, data_name))

    # supress_multiple_pred already does the thresholding and run grouping; the
    # former in-cell copy of that logic was unused (and passed the scalar
    # threshold to the component labeller), so it has been removed.
    preds = supress_multiple_pred(data_df,sending_th)
    labels = data_df["label"]

    window_times = np.array(data_df["window_center_seconds"])
    pred_times = window_times[preds >= sending_th]
    gt_times = window_times[np.asarray(labels == 1)]

    is_good, is_caught, TP, FP, FN = get_TP_FP_FN(pred_times,gt_times)
    all_TP += TP
    all_FP += FP
    all_FN += FN

val_forced_FN = all_FN
all_TP, all_FP, all_FN


(2854, 27486, 51)

In [11]:
# Count phase-1 matches on the test split at the current sending_th.
# The false negatives found here can never be recovered by the second phase,
# so they are carried forward as test_forced_FN.
all_TP, all_FP, all_FN = (0,0,0)

# Sorted for a reproducible order; non-CSV entries (e.g. checkpoint folders) are skipped.
for data_name in sorted(os.listdir(test_data_path)):
    if not data_name.endswith(".csv"):
        continue
    data_df = pd.read_csv(os.path.join(test_data_path, data_name))

    # supress_multiple_pred already does the thresholding and run grouping; the
    # former in-cell copy of that logic was unused (and passed the scalar
    # threshold to the component labeller), so it has been removed.
    preds = supress_multiple_pred(data_df,sending_th)
    labels = data_df["label"]

    window_times = np.array(data_df["window_center_seconds"])
    pred_times = window_times[preds >= sending_th]
    gt_times = window_times[np.asarray(labels == 1)]

    is_good, is_caught, TP, FP, FN = get_TP_FP_FN(pred_times,gt_times)
    all_TP += TP
    all_FP += FP
    all_FN += FN

test_forced_FN = all_FN
all_TP, all_FP, all_FN

(2816, 29024, 44)

## Transformer metrics

### Define evaluation functions

In [12]:
def transformer_metrics(data_path,forced_FN):
    """Sweep the transformer confidence threshold and plot the resulting curves.

    All ``.csv`` files in ``data_path`` are pooled. Every distinct confidence
    value is used as a threshold: candidates with confidence ``>=`` the
    threshold are kept; kept candidates with ``is_correct == 1`` are true
    positives and the other kept candidates are false positives.

    Parameters
    ----------
    data_path : str
        Directory of CSVs with the columns ``transformer_confidence`` and
        ``is_correct``.
    forced_FN : int
        False negatives added to every threshold's false-negative count
        (clicks that never reached the transformer).

    Returns
    -------
    tuple of four altair charts
        (F1 vs threshold, precision vs recall, precision vs threshold,
        recall vs threshold).
    """
    all_conf = []
    all_label = []

    # Sorted for a reproducible order; non-CSV entries are skipped.
    for data_name in sorted(os.listdir(data_path)):
        if not data_name.endswith(".csv"):
            continue
        data_df = pd.read_csv(os.path.join(data_path, data_name))

        all_conf += list(data_df["transformer_confidence"])
        all_label += list(data_df["is_correct"])

    all_conf = np.array(all_conf)
    all_label = np.array(all_label)

    def ratio(numerator, denominator, if_undefined):
        # Avoid 0/0 producing NaN points (and warnings) in the charts.
        return numerator/denominator if denominator != 0 else if_undefined

    # Anchor point: threshold 1 with nothing accepted.
    all_th = [1]
    all_recall = [0]
    all_precision = [1]
    all_f1 = [0]

    for conf in np.unique(all_conf):
        kept = all_conf >= conf
        TP = np.sum(all_label[kept])
        FP = np.sum((1-all_label)[kept])
        FN = np.sum(all_label[all_conf < conf])+forced_FN

        all_th.append(conf)
        all_recall.append(ratio(TP, TP+FN, 0))
        all_precision.append(ratio(TP, TP+FP, 1))
        all_f1.append(ratio(2*TP, 2*TP+FN+FP, 0))

    results_df = pd.DataFrame({
        "threshold": all_th,
        "recall": all_recall,
        "precision": all_precision,
        "f1": all_f1,
    })

    def step_chart(x_field, y_field, title):
        # All four charts share colour, step interpolation and [0, 1] domains.
        return alt.Chart(results_df).mark_line(color="#006CD1",interpolate="step").encode(
            alt.X(x_field, scale=alt.Scale(domain=[0,1])),
            alt.Y(y_field, scale=alt.Scale(domain=[0,1]))
        ).properties(title=title)

    return (
        step_chart("threshold", "f1", "F1 graph"),
        step_chart("recall", "precision", "Precision-Recall graph"),
        step_chart("threshold", "precision", "Precision graph"),
        step_chart("threshold", "recall", "Recall graph"),
    )

### Run evaluation functions

In [13]:
# Transformer curves on the validation split.
# NOTE(review): val_data_path is rebound here from the dataset folder to the inference output folder.
val_data_path = "transformer_inference_csv/val/"

val_transformer_f1, val_transformer_prec_recall, val_transformer_precision, val_transformer_recall = transformer_metrics(val_data_path,val_forced_FN)
# 2x2 grid: F1 and precision-recall on top, precision and recall below.
alt.vconcat(alt.hconcat(val_transformer_f1,val_transformer_prec_recall),alt.hconcat(val_transformer_precision,val_transformer_recall))

In [14]:
# Transformer curves on the test split.
# NOTE(review): test_data_path is rebound here from the dataset folder to the inference output folder.
test_data_path = "transformer_inference_csv/test/"

test_transformer_f1, test_transformer_prec_recall, test_transformer_precision, test_transformer_recall = transformer_metrics(test_data_path,test_forced_FN)
# 2x2 grid: F1 and precision-recall on top, precision and recall below.
alt.vconcat(alt.hconcat(test_transformer_f1,test_transformer_prec_recall),alt.hconcat(test_transformer_precision,test_transformer_recall))

# Coda and whale clustering evaluation

### Define evaluation functions

In [15]:
import json
import pandas as pd
import numpy as np
import os
import altair as alt
alt.data_transformers.disable_max_rows()
import networkx as nx
import scipy.optimize as opt
from scipy.stats import mode
import scipy
from sklearn.cluster import SpectralClustering

In [16]:
def eval_partition(partition_vector,prob_matrix):
    """Log-likelihood of a two-way partition under pairwise "same group" probabilities.

    ``partition_vector`` holds a 0/1 side for every item and ``prob_matrix[i][j]``
    is the probability that items i and j belong together. Each pair contributes
    ``log(p)`` when both items are on the same side and ``log(1 - p)`` otherwise;
    contributions are floored at ``log(1e-15)``. Returns the sum over all
    entries of the matrix.
    """
    first_side = partition_vector
    second_side = 1 - partition_vector

    # 1 where both items share a side, 0 where they are separated.
    together = np.outer(first_side, first_side) + np.outer(second_side, second_side)

    pair_likelihood = together*prob_matrix+(1-together)*(1-prob_matrix)
    pair_likelihood = np.clip(pair_likelihood, 1e-15, None)
    return np.log(pair_likelihood).sum()

In [17]:
# Cluster coda probabilities using eigenvectors
def eigen_clustering(prob_matrix,label_matrix,random_state=None):
    """Cluster clicks into codas by recursive spectral bisection.

    Parameters
    ----------
    prob_matrix : (n, n) numpy.ndarray
        Pairwise "same coda" probabilities; symmetrised internally.
    label_matrix : (n, n) numpy.ndarray
        Pairwise same-coda targets. It is sliced alongside ``prob_matrix`` but
        does not influence the clustering.
    random_state : int, numpy RandomState or None, optional
        Forwarded to ``SpectralClustering`` so runs can be made reproducible.
        ``None`` (default) keeps the previous unseeded behaviour.

    Returns
    -------
    list of numpy.ndarray
        One array of click indices (rows of ``prob_matrix``) per coda.
    """
    final_output = []
    prob_matrix = (prob_matrix+prob_matrix.transpose())/2

    # Two clicks are never put in the same coda unless a path of > 0.5
    # probabilities connects them, so each connected component of that graph
    # can be solved on its own.
    # (nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array is
    # its replacement and already returns an undirected Graph.)
    G = nx.from_numpy_array(prob_matrix > 0.5)

    # Work queue of (same-coda probabilities, same-coda labels, click ids).
    to_divide = []
    for component in nx.connected_components(G):
        selected = np.array(sorted(component))
        sub_prob_matrix = prob_matrix[selected,:][:,selected]
        sub_label_matrix = label_matrix[selected,:][:,selected]
        to_divide.append((sub_prob_matrix,sub_label_matrix,selected))

    k = 0
    while k < len(to_divide):
        sub_prob_matrix,sub_label_matrix, original_ids = to_divide[k]
        k += 1

        n = sub_prob_matrix.shape[0] # Number of clicks in this sub-problem
        if n == 1:
            final_output.append(original_ids)
            continue

        # Approximate best split into two clusters.
        clustering = SpectralClustering(n_clusters=2,affinity="precomputed",random_state=random_state).fit(sub_prob_matrix)
        best_cut_partition_vector = clustering.labels_
        mask_one = best_cut_partition_vector > 0.5
        mask_two = best_cut_partition_vector <= 0.5

        # A "split" that leaves one side empty is no split at all. Re-queueing
        # it would repeat the same sub-problem forever (and enqueue an empty one).
        if not mask_one.any() or not mask_two.any():
            final_output.append(original_ids)
            continue

        best_cut_score = eval_partition(best_cut_partition_vector,sub_prob_matrix)

        # Compare against keeping all clicks in one cluster.
        no_split_score = eval_partition(np.zeros(n),sub_prob_matrix)
        if no_split_score > best_cut_score:
            # Keeping everything together is more likely: this sub-problem is one coda.
            final_output.append(original_ids)
        else:
            # Otherwise keep dividing both halves.
            to_divide.append((sub_prob_matrix[mask_one,:][:,mask_one],sub_label_matrix[mask_one,:][:,mask_one],original_ids[mask_one]))
            to_divide.append((sub_prob_matrix[mask_two,:][:,mask_two],sub_label_matrix[mask_two,:][:,mask_two],original_ids[mask_two]))

    return final_output
        

In [18]:
# Calculate weight between two codas
# The lower it is the more we want to assign same whale
def compare_clusters(cluster_a,cluster_b,whale_prob_matrix):
    """Weight between two codas; the lower it is, the more they look like the same whale.

    For every click pair ``(a, b)`` with ``a`` in ``cluster_a`` and ``b`` in
    ``cluster_b`` that has an entry in ``whale_prob_matrix``, the pair
    contributes ``log(1 - p) - log(p)``, where ``p`` is the stored same-whale
    probability. Probabilities are clipped to ``[1e-15, 1 - 1e-15]`` so that
    values of exactly 0 or 1 give large finite weights instead of inf or NaN.

    Parameters
    ----------
    cluster_a, cluster_b : iterables of click ids.
    whale_prob_matrix : dict mapping ``(a, b)`` tuples to probabilities.

    Returns
    -------
    The summed weight, or 0 when no pair has an entry.
    """
    probabilities = [
        whale_prob_matrix[(a,b)]
        for a in cluster_a
        for b in cluster_b
        if (a,b) in whale_prob_matrix
    ]

    if len(probabilities) == 0:
        return 0

    eps = 1e-15
    probabilities = np.clip(np.array(probabilities, dtype=float), eps, 1-eps)
    logs = -np.log(probabilities)+np.log(1-probabilities)
    return np.sum(logs)

In [19]:
# Obtain groups from same / different pairs
def find_groups(n,same_different,id_map):
    click_id_to_cluster_id = np.zeros(n)
    all_clusters = {}
    
    last_cluster_id = 0
    for i in range(n):
        click_id_to_cluster_id[i] = last_cluster_id
        all_clusters[last_cluster_id] = [i]
        last_cluster_id += 1

    for i in range(n):
        for j in range(i+1,n,1):
            if same_different[id_map[(i,j)]] > 0.5 and click_id_to_cluster_id[i] != click_id_to_cluster_id[j]:
                # Move all the nodes from j's cluster to i's cluester
                new_cluster_id = click_id_to_cluster_id[i]

                old_cluster_id = click_id_to_cluster_id[j]
                old_cluster = all_clusters[old_cluster_id]
                
                for old_cluster_node in old_cluster:
                    click_id_to_cluster_id[old_cluster_node] = new_cluster_id

                all_clusters[new_cluster_id] += old_cluster
                del all_clusters[old_cluster_id]
    
    return all_clusters

In [20]:
# Finds optimal whale clusterings
def solve_problem(weight_matrix):
    """Find the optimal same/different assignment for all pairs of items.

    One integer variable is created per unordered pair ``(i, j)``, ``i < j``,
    bounded so that it can only take the values 0 ("different") or 1 ("same").
    The objective minimises the sum of ``weight_matrix[i][j]`` over the pairs
    marked "same". For every triple, three constraints enforce transitivity:
    two "same" pairs of a triple force the third pair to be "same" as well.

    Parameters
    ----------
    weight_matrix : (n, n) numpy.ndarray
        Pair weights; only the entries above the diagonal are read.

    Returns
    -------
    ans : scipy.optimize.OptimizeResult
        Result of ``scipy.optimize.milp``; ``ans["x"]`` holds one value per pair.
    id_map : dict
        Maps ``(i, j)`` with ``i < j`` to the index of its variable.
    """
    # Imported here so the function does not depend on how scipy was imported elsewhere.
    from scipy import sparse

    n = weight_matrix.shape[0]

    num_var = n*(n-1)//2          # one variable per unordered pair
    num_res = n*(n-1)*(n-2)//2    # three constraints per unordered triple

    id_map = {}
    c = np.zeros(num_var)
    id_num = 0
    for i in range(n):
        for j in range(i+1,n):
            id_map[(i,j)] = id_num
            c[id_num] = weight_matrix[i][j]
            id_num += 1

    # Each constraint row has only three non-zero entries, so the matrix is
    # built in sparse form; a dense version would need num_res * num_var floats.
    rows, cols, vals = [], [], []
    sign_patterns = ((2.0,2.0,-2.0),(2.0,-2.0,2.0),(-2.0,2.0,2.0))
    cond_num = 0
    for i in range(n):
        for j in range(i+1,n):
            for k in range(j+1,n):
                triple = (id_map[(i,j)], id_map[(j,k)], id_map[(i,k)])
                for signs in sign_patterns:
                    # 2a + 2b - 2c <= 3 with integer 0/1 variables means a + b - c <= 1.
                    rows.extend((cond_num, cond_num, cond_num))
                    cols.extend(triple)
                    vals.extend(signs)
                    cond_num += 1

    if num_res > 0:
        A = sparse.csr_matrix((vals,(rows,cols)),shape=(num_res,num_var))
        constraint = opt.LinearConstraint(A,ub=3)
    else:
        # Fewer than three items: there is no triple to constrain.
        constraint = None

    # Integer variables within [-0.5, 1.5] can only be 0 or 1.
    bounds = opt.Bounds(lb=-0.5,ub=1.5)
    ans = opt.milp(c=c,integrality=1,bounds=bounds,constraints=constraint)
    return ans, id_map


In [21]:
def eval_clusters(folder_path):
    """Evaluate coda clustering and whale clustering on a folder of inference JSON files.

    For every ``.json`` file in ``folder_path``:

    1. clicks with ``click_probability > 0.5`` are selected;
    2. they are grouped into codas with ``eigen_clustering`` using the pairwise
       same-coda probabilities;
    3. codas are grouped into whales: ``compare_clusters`` gives a weight per
       coda pair, and each connected component of negative weights is solved
       with ``solve_problem`` / ``find_groups``;
    4. pairwise TP/FP/FN are accumulated over pairs of clicks with
       ``click_label > 0.5`` that list each other in ``same_different_times``.

    The pooled coda and whale F1 scores and counts are printed; nothing is returned.

    Parameters
    ----------
    folder_path : str
        Directory containing the JSON files. Each file maps a click time (as a
        string) to a record with the keys read below.
    """

    def mutual_positions(entry_i, entry_j, i_time, j_time):
        # Positions at which two clicks list each other, or None when the
        # listing is not mutual.
        times_i = entry_i["same_different_times"]
        times_j = entry_j["same_different_times"]
        if j_time in times_i and i_time in times_j:
            return times_i.index(j_time), times_j.index(i_time)
        return None

    def count_pairs(clicks, time_group, target_key):
        # Pairwise TP/FP/FN of a grouping against the targets stored under target_key.
        tp = fp = fn = 0
        for a in range(len(time_group)):
            i_time, group_i = time_group[a]
            for b in range(a+1,len(time_group)):
                j_time, group_j = time_group[b]
                positions = mutual_positions(clicks[i_time], clicks[j_time], i_time, j_time)
                if positions is None:
                    continue
                label = clicks[i_time][target_key][positions[0]]
                pred = (group_i == group_j)*1

                tp += label*pred
                fp += pred*(1-label)
                fn += (1-pred)*label
        return tp, fp, fn

    def f1_score(tp, fp, fn):
        # F1 from counts; 0 when there is nothing to score.
        denominator = 2*tp+fn+fp
        return (2*tp)/denominator if denominator != 0 else 0

    coda_tp = coda_fp = coda_fn = 0
    whale_tp = whale_fp = whale_fn = 0

    # Sorted for a reproducible order; non-JSON entries are skipped.
    for file_name in sorted(os.listdir(folder_path)):
        if not file_name.endswith(".json"):
            continue

        with open(os.path.join(folder_path, file_name)) as f:
            data = json.load(f)

        # Key the records by float time once, instead of rebuilding the JSON
        # key with str(float(key)), which only works when that round-trips.
        clicks = {float(key): record for key, record in data.items()}
        all_predicted_times = [t for t in clicks if clicks[t]["click_probability"] > 0.5]

        n = len(all_predicted_times)
        if n == 0:
            # No predicted clicks: nothing to cluster or score in this file.
            continue

        # Pairwise same-coda probabilities / labels and same-whale probabilities.
        coda_prob_matrix = np.eye(n)
        coda_label_matrix = np.eye(n)
        solo_whale_prob_matrix = {(i,i): 1 for i in range(n)}

        for i in range(n):
            i_time = all_predicted_times[i]
            entry_i = clicks[i_time]
            for j in range(i+1,n):
                j_time = all_predicted_times[j]
                entry_j = clicks[j_time]
                positions = mutual_positions(entry_i, entry_j, i_time, j_time)
                if positions is None:
                    continue
                j_pos, i_pos = positions

                coda_prob_matrix[i][j] = entry_i["same_different_coda_probability"][j_pos]
                coda_prob_matrix[j][i] = entry_j["same_different_coda_probability"][i_pos]

                coda_label_matrix[i][j] = entry_i["same_different_coda_targets"][j_pos]
                coda_label_matrix[j][i] = entry_j["same_different_coda_targets"][i_pos]

                # The same-whale probability is the mean of the two directed predictions.
                whale_prob = (entry_i["same_different_whale_probability"][j_pos]+entry_j["same_different_whale_probability"][i_pos])/2
                solo_whale_prob_matrix[(i,j)] = whale_prob
                solo_whale_prob_matrix[(j,i)] = whale_prob

        # Run coda clustering algorithm
        coda_prob_matrix = (coda_prob_matrix+coda_prob_matrix.transpose())/2
        coda_clusterings = eigen_clustering(coda_prob_matrix,coda_label_matrix)

        # Evaluate coda clustering on the clicks that are labelled as real clicks
        time_coda = [
            (all_predicted_times[click_id], coda_id)
            for coda_id, coda in enumerate(coda_clusterings)
            for click_id in coda
            if clicks[all_predicted_times[click_id]]["click_label"] > 0.5
        ]
        tp, fp, fn = count_pairs(clicks, time_coda, "same_different_coda_targets")
        coda_tp += tp
        coda_fp += fp
        coda_fn += fn

        # Calculate whale weights between codas
        m = len(coda_clusterings)
        whale_weights = np.zeros((m,m))
        for i in range(m):
            for j in range(i+1,m):
                whale_weights[i][j] = compare_clusters(coda_clusterings[i],coda_clusterings[j],solo_whale_prob_matrix)
                whale_weights[j][i] = whale_weights[i][j]

        # Two codas are never assigned to the same whale unless a path of
        # negative weights connects them, so every connected component of the
        # "weight < 0" graph is solved separately.
        # (nx.from_numpy_matrix was removed in networkx 3.0.)
        G = nx.from_numpy_array(whale_weights < 0)

        whale_output = []
        for component in nx.connected_components(G):
            original_ids = np.array(sorted(component))
            sub_whale_matrix = whale_weights[original_ids,:][:,original_ids]
            sub_n = sub_whale_matrix.shape[0]

            # Obtain the best whale clustering with integer linear programming
            if sub_n > 1:
                ans, id_map = solve_problem(sub_whale_matrix)
                clusters = find_groups(sub_n,ans["x"],id_map)
            else:
                clusters = {0: [0]}

            for members in clusters.values():
                whale_output.append(original_ids[sorted(members)])

        # Evaluate whale clustering on the clicks that are labelled as real clicks
        time_whale = []
        for whale_id, coda_ids in enumerate(whale_output):
            for coda_id in coda_ids:
                for click_id in coda_clusterings[coda_id]:
                    click_time = all_predicted_times[click_id]
                    if clicks[click_time]["click_label"] > 0.5:
                        time_whale.append((click_time, whale_id))

        tp, fp, fn = count_pairs(clicks, time_whale, "same_different_whale_targets")
        whale_tp += tp
        whale_fp += fp
        whale_fn += fn

    print("coda f1",f1_score(coda_tp,coda_fp,coda_fn))
    print("coda tp",coda_tp)
    print("coda fp",coda_fp)
    print("coda fn",coda_fn)


    print("whale f1",f1_score(whale_tp,whale_fp,whale_fn))
    print("whale tp",whale_tp)
    print("whale fp",whale_fp)
    print("whale fn",whale_fn)


### Run evaluation metrics

In [22]:
# Coda and whale clustering scores on the validation split.
eval_clusters("transformer_inference_json/val/")



coda f1 0.9143471508780313
coda tp 7654
coda fp 468
coda fn 966
whale f1 0.9116984985263744
whale tp 39286
whale fp 3724
whale fn 3886


In [23]:
# Coda and whale clustering scores on the test split.
eval_clusters("transformer_inference_json/test/")



coda f1 0.9251982423354715
coda tp 9159
coda fp 997
coda fn 484
whale f1 0.8430764155903061
whale tp 40893
whale fp 8073
whale fn 7150
