# In [4]:
# -*- coding: utf-8 -*-
"""
Created on Mon Feb  8 17:39:59 2021

Collection of custom evaluation functions for embeddings, based on k-nearest-neighbor label statistics.

@author: marathomas
"""

import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors


def make_nn_stats_dict(calltypes, labels, indices):
    """Count, for every point of each class, the labels of its nearest neighbors.

    Parameters
    ----------
    calltypes : sequence
        Class labels to report on; their order defines the column order of
        every count matrix.
    labels : array-like
        Class label of every data point (position i = point i). Lists and
        pandas Series are converted to a NumPy array so that element-wise
        comparison and positional indexing behave the same for all inputs.
    indices : array-like of int
        Row i holds the positions of the nearest neighbors of point i
        (e.g. the first return value of ``get_knn``).

    Returns
    -------
    dict
        Maps each calltype to a float array of shape
        ``(n_points_of_that_type, len(calltypes))``; entry ``[i, j]`` is the
        number of neighbors of the i-th point of that type whose label is
        ``calltypes[j]``.

    Raises
    ------
    KeyError
        If a neighbor's label is not contained in ``calltypes``.
    """
    calltypes = list(calltypes)
    labels = np.asarray(labels)
    indices = np.asarray(indices)
    n_types = len(calltypes)

    # Precompute label -> column once instead of searching calltypes for every
    # neighbor; setdefault keeps the first occurrence, matching np.where(...)[0][0].
    column_of = {}
    for col, calltype in enumerate(calltypes):
        column_of.setdefault(calltype, col)

    nn_stats_dict = {}
    for calltype in calltypes:
        call_indices = np.flatnonzero(labels == calltype)
        calltype_counts = np.zeros((call_indices.shape[0], n_types))

        for row, ind in enumerate(call_indices):
            for neighbor in indices[ind]:
                calltype_counts[row, column_of[labels[neighbor]]] += 1
        nn_stats_dict[calltype] = calltype_counts

    return nn_stats_dict

def get_knn(k, embedding):
    """Find the k nearest neighbors (Euclidean, brute force) of every point.

    Parameters
    ----------
    k : int
        Number of neighbors per point, not counting the point itself.
    embedding : array-like of shape (n_samples, n_features)
        Points to index; the same points are used as queries.

    Returns
    -------
    indices : ndarray of shape (n_samples, k)
        Positions of each point's neighbors, nearest first.
    distances : ndarray of shape (n_samples, k)
        The matching Euclidean distances.
    """
    nbrs = NearestNeighbors(metric='euclidean', n_neighbors=k, algorithm='brute').fit(embedding)

    # Calling kneighbors() without X queries the fitted points and lets
    # scikit-learn exclude each point's own index. Querying with the embedding
    # and slicing off column 0 is not equivalent: with duplicate points the
    # zero-distance tie may put a duplicate first. The point itself would then
    # stay in its own neighbor list and a genuine neighbor would be dropped.
    distances, indices = nbrs.kneighbors()

    return indices, distances


def make_statstabs(nn_stats_dict, calltypes, labels, k):
    """Turn neighbor counts into percentage tables.

    Parameters
    ----------
    nn_stats_dict : dict
        Output of ``make_nn_stats_dict``: calltype -> count matrix with one
        column per entry of ``calltypes``.
    calltypes : sequence
        Class labels, in row/column order of the resulting tables.
    labels : array-like
        Class label of every data point (converted to a NumPy array).
    k : int
        Number of neighbors each count row was built from.

    Returns
    -------
    stats_tab : pandas.DataFrame
        Rows ``calltypes + ['overall']``, columns ``calltypes``. Class rows
        hold the mean percentage of the k neighbors that belong to the column
        class. The ``'overall'`` row holds each class's share of all
        points, in percent.
    stats_tab_norm : pandas.DataFrame
        Rows and columns ``calltypes``. ``log2`` of each class row divided by
        the ``'overall'`` shares; zeros are replaced by 0.0001 before the log.
    """
    calltypes = list(calltypes)
    labels = np.asarray(labels)
    n_types = len(calltypes)

    # Share of each class among all points, in percent.
    overall = np.array([np.count_nonzero(labels == calltype) for calltype in calltypes], dtype=float)
    overall = (overall / np.sum(overall)) * 100

    stats_tab = np.zeros((n_types, n_types))
    stats_tab_norm = np.zeros((n_types, n_types))

    for i, calltype in enumerate(calltypes):
        # Mean percentage of the k neighbors that fall into each class.
        neighbor_pct = (np.mean(nn_stats_dict[calltype], axis=0) / k) * 100
        stats_tab[i, :] = neighbor_pct
        # Relative to each class's overall share of the data set.
        stats_tab_norm[i, :] = neighbor_pct / overall

    stats_tab = pd.DataFrame(
        np.vstack([stats_tab, overall]),
        index=calltypes + ['overall'],
        columns=calltypes,
    )
    stats_tab_norm = pd.DataFrame(stats_tab_norm, index=calltypes, columns=calltypes)

    # log2(0) would be -inf; clamp zeros to a small positive value first.
    stats_tab_norm = np.log2(stats_tab_norm.replace(0, 0.0001))

    return stats_tab, stats_tab_norm


class nn:
    """k-nearest-neighbor label statistics of an embedding.

    The constructor finds the ``k`` nearest neighbors of every point in
    ``embedding`` and tabulates how the neighbors' labels are distributed.

    Attributes
    ----------
    embedding, labels, k
        The constructor arguments, stored as given.
    statstab : pandas.DataFrame
        Row = class of the query point, column = class of the neighbor,
        value = mean percentage of the k neighbors in the column class.
        A final ``'overall'`` row holds each class's share of all points.
    statstabnorm : pandas.DataFrame
        ``log2`` of the class rows of ``statstab`` divided by the
        ``'overall'`` shares (zeros replaced by 0.0001 before the log).
    """

    def __init__(self, embedding, labels, k):
        self.embedding = embedding
        self.labels = labels
        self.k = k

        label_types = sorted(set(labels))
        # Hand the helpers an array so element-wise comparisons and positional
        # indexing work for lists and pandas Series as well as NumPy arrays.
        label_array = np.asarray(labels)
        indices, _ = get_knn(k, embedding)
        nn_stats_dict = make_nn_stats_dict(label_types, label_array, indices)
        self.statstab, self.statstabnorm = make_statstabs(nn_stats_dict, label_types, label_array, k)

    def get_statstab(self):
        """Return the neighbor-percentage table (includes the 'overall' row)."""
        return self.statstab

    def get_statstabnorm(self):
        """Return the log2 table normalized by the overall class shares."""
        return self.statstabnorm

    def getS(self):
        """Return the mean over classes of the same-class neighbor percentage."""
        # statstab has one more row than columns, so its diagonal covers only
        # the class rows and never the 'overall' row.
        return np.mean(np.diagonal(self.statstab))

    def getSnorm(self):
        """Return the mean over classes of the normalized same-class value."""
        return np.mean(np.diagonal(self.statstabnorm))

    def get_ownclass_S(self):
        """Return, per class, the percentage of neighbors sharing that class."""
        return np.diagonal(self.statstab)

    def get_ownclass_Snorm(self):
        """Return, per class, the normalized (log2) same-class value.

        Previously this method was also named ``get_ownclass_S``, which
        silently replaced the raw version above.
        """
        return np.diagonal(self.statstabnorm)
        
        