In [17]:
import pickle
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``.

    Unpickling can execute arbitrary code, so only point this at files
    produced by a trusted source.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Detector results for the single-species and multi-species runs.
data = _load_pickle("data.pickle")
multi_data = _load_pickle("multi_species.pickle")

In [3]:
import pandas as pd
species_df = pd.read_csv("species.csv", names=["species", "id_0"])
# NOTE(review): id_0 is assumed to be a unique 0-based species id; a duplicate
# would silently overwrite an entry in species_lookup.
assert species_df["id_0"].is_unique, "species.csv contains duplicate id_0 values"

# Order by id so that species_names[i] is the species whose id_0 == i. The
# confusion-matrix heatmap assigns these names to rows/columns by position,
# so CSV row order alone is not enough.
species_sorted = species_df.sort_values("id_0", kind="stable")
species_names = species_sorted["species"].tolist()

# Species lookup keyed by the 1-based index (id_0 + 1).
species_lookup = dict(zip(species_sorted["id_0"] + 1, species_names))

In [4]:
import plotly.graph_objects as go

In [6]:
import numpy as np

def plot_pr_curves(results, title, skip_non_species=False):
    """Build a precision/recall figure with one line trace per species.

    Parameters
    ----------
    results : mapping
        Maps a 1-based species id (a key of ``species_lookup``) to a 2-D
        array-like of rows. Column 1 is plotted as precision and column 2 as
        recall.
    title : str
        Figure title.
    skip_non_species : bool, default False
        If True, silently skip keys that are not in ``species_lookup``
        (e.g. the ``'CONFUSION_MATRICES'`` entry of the multi-species results).
        If False, such a key raises ``KeyError``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    fig = go.Figure()
    for species_id, rows in results.items():
        if skip_non_species and species_id not in species_lookup:
            continue
        matrix = np.asarray(rows)
        # Columns 0 (keep threshold) and 3 ("doubles") are not plotted here.
        precision = matrix[:, 1]
        recall = matrix[:, 2]
        fig.add_trace(go.Scatter(x=recall, y=precision, mode='lines',
                                 name=species_lookup[species_id]))
    fig.update_layout(title=title,
                      xaxis_title="Recall",
                      yaxis_title="Precision")
    return fig


plot_pr_curves(data, 'Single Species Detector').show()

# multi_data also holds non-species entries such as 'CONFUSION_MATRICES'; skip
# those explicitly instead of swallowing every exception with a bare except.
plot_pr_curves(multi_data, 'Multi-Species Detector', skip_non_species=True).show()

In [44]:
# Plot the multi-species confusion matrix stored under the keep threshold
# (a key of multi_data['CONFUSION_MATRICES']) closest to `threshold`.
threshold = 0.00
trim_background = True

confusion_matrices = multi_data['CONFUSION_MATRICES']
threshold_keys = list(confusion_matrices.keys())
keep_ts = np.asarray(threshold_keys)
closest_idx = int(np.abs(keep_ts - threshold).argmin())
# Index with the original dict key rather than a numpy scalar rebuilt from it.
closest_keep_t = threshold_keys[closest_idx]
# Cast to float so column renormalization also works when the stored matrix
# holds integer counts (dividing into an integer array would fail).
confusion_matrix = np.asarray(confusion_matrices[closest_keep_t], dtype=float)

if trim_background:
    # Drop row/column 0 (background) for visualization purposes, then
    # renormalize each (predicted) column to sum to 1. All-zero columns are left
    # unchanged. `out` is a fresh copy, so the stored matrix is never mutated.
    confusion_matrix = confusion_matrix[1:, 1:]
    col_sums = confusion_matrix.sum(axis=0)
    confusion_matrix = np.divide(confusion_matrix, col_sums,
                                 out=confusion_matrix.copy(),
                                 where=col_sums > 0)
    labels = species_names
else:
    labels = ['Background'] + species_names

# Heatmap axis labels are assigned by position, so they must match the matrix size.
assert len(labels) == confusion_matrix.shape[0] == confusion_matrix.shape[1], (
    f"{len(labels)} labels for a {confusion_matrix.shape} confusion matrix")

# For visualization purposes, put the grey midpoint of the colorscale at the
# average non-zero value. Colorscale stops are fractions of the [zmin, zmax]
# range, so convert the average into that range and clamp it to [0, 1].
zmin = float(confusion_matrix.min())
zmax = float(confusion_matrix.max())
non_zero_values = confusion_matrix[confusion_matrix != 0]
average_non_zero = float(non_zero_values.mean()) if non_zero_values.size else 0.0
if zmax > zmin:
    midpoint = float(np.clip((average_non_zero - zmin) / (zmax - zmin), 0.0, 1.0))
else:
    # Constant matrix: there is no range to map into.
    midpoint = 0.5

fig = go.Figure(go.Heatmap(z=confusion_matrix,
                           x=labels,
                           y=labels,
                           zmin=zmin,
                           zmax=zmax,
                           colorscale=[[0, 'rgb(255,255,255)'],
                                       [midpoint, 'rgb(128,128,128)'],
                                       [1, 'rgb(0,0,0)']]))
fig.update_layout(title=f"Confusion Matrix @ {closest_keep_t}",
                  yaxis_title='Truth',
                  xaxis_title='Predicted',
                  yaxis_scaleanchor='x',
                  xaxis_gridcolor='rgba(0,0,0,0)',
                  yaxis_gridcolor='rgba(0,0,0,0)',
                  autosize=False,
                  width=1200,
                  height=1200)
fig.show()