In [None]:
import glob

from potoo.plot import *
from potoo.util import *
import sklearn

from cache import *
from constants import *
from datasets import *
from features import *
from load import *
from sp14.model import *
from util import *
from viz import *

# Default figure size for plots in this notebook (`figsize` comes from one of the
# star imports above); the trailing `;` suppresses the cell's output repr.
figsize('inline_short');

In [None]:
# Load recordings from both datasets. Only 'peterson-field-guide' is used for
# the train/test eval below. For faster dev iterations, pass limit=30 to load.recs.
load = Load()
recs = load.recs(datasets=['peterson-field-guide', 'recordings'])
display(df_summary(recs), recs.head(5))

# Fit search

In [None]:
# Add .feat (for eval functions below): re-project every rec with a pinned,
# previously saved projection, replacing any existing `feat` column.
# XXX Forced cache refresh -- presumably recomputes instead of reusing cached
# results (slow on Restart & Run All); confirm and drop once stable.
cache_control(refresh=True)
projection = Projection.load('peterson-v0-26bae1c')  # pinned projection id (name suggests dataset-version-commit; verify)
recs = projection.transform(recs
    # errors='ignore': don't KeyError when recs has no `feat` column yet
    .drop(columns=['feat'], errors='ignore')  # XXX
)

In [None]:
# Split the peterson-field-guide recs 50/50 into train/test (train gets the
# extra rec when the count is odd). The split is deterministic (random_state=0).
recs_eval = (recs
    [lambda df: df.dataset == 'peterson-field-guide']
    # [:10]  # Faster dev
    .reset_index(drop=True)
)
# Fail early with a clear message rather than an opaque error inside search.fit / eval
assert len(recs_eval) >= 2, (
    f'Need at least 2 peterson-field-guide recs for a train/test split, got {len(recs_eval)}'
)
train_n, test_n = (len(recs_eval)+1)//2, len(recs_eval)//2
recs_train, recs_test = (recs_eval
    # NOTE: shuffle + sample(n=len(df)) shuffles twice -- redundant, but kept so the
    # split stays identical to earlier runs
    .pipe(sklearn.utils.shuffle, random_state=0)
    .sample(train_n + test_n, random_state=0)
    .pipe(lambda df: (
        df[:train_n],
        df[train_n : train_n + test_n],
    ))
)
# Invariants: the split covers all of recs_eval and train/test are disjoint
# (recs_eval has a fresh RangeIndex, which shuffle/sample preserve as row labels)
assert len(recs_train) + len(recs_test) == len(recs_eval)
assert recs_train.index.intersection(recs_test.index).empty
log('params', **{
    'recs_eval': len(recs_eval),
    'recs_train': len(recs_train),
    'recs_test': len(recs_test),
})

In [None]:
# Re-imported here although the first cell already imports it; presumably to
# rebind sp14.model names that later star imports could shadow -- harmless if redundant.
from sp14.model import *

# Neighbors consulted per query (alternatives previously tried: 3, 10)
N_NEIGHBORS = 5
search = Search(projection=projection, n_neighbors=N_NEIGHBORS)
search.fit(recs_train)

# Eval search

In [None]:
# Per-species coverage error on the held-out recs, with species ordered by error.
# The grey line marks the number of distinct classes seen during fit -- presumably
# the worst possible coverage error (as in sklearn's definition); verify.
coverage_by_species = search.coverage_error(recs_test, by='species')
n_fit_classes = len(set(search.fit_classes_))
(coverage_by_species
    .pipe(df_transform_cat, lambda _: coverage_by_species.sort_values('coverage_error').species, 'species')
    .pipe(ggplot, aes(x='species', y='coverage_error'))
    + geom_point()
    + coord_flip()
    + geom_hline(yintercept=n_fit_classes, color='grey')
    + theme_figsize('inline')
    + ggtitle('Coverage error')
)

In [None]:
%%time
# Confusion matrix of the fitted search on the held-out test recs (cell is timed).
# with_figsize('full') switches to the full-size figure preset, presumably only
# for the duration of the with-block.
with with_figsize('full'):
    search.plot_confusion_matrix(recs_test)

In [None]:
# More usage examples (kept commented out; uncomment to inspect the first 5 test recs)
# Predicted species per rec:
# search.species(recs_test[:5])
# Species probabilities, keeping only the first n_neighbors + 1 columns of the result:
# search.species_probs(recs_test[:5]).T[:search.knn_.n_neighbors + 1].T
# 10 most similar recs for each query rec:
# search.similar_recs(recs_test[:5], 10)