# Generating predictions using the CNN

This assumes you've used the `train_form_cfg.ipynb` notebook to train a CNN and save its weights to `./run_1/latest.pt`, and to create the file `./validation_set.csv`, which lists the file, start time, and end time of each clip along with its labels.
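As a rough illustration of the layout described above, the sketch below writes a tiny table with `file`, `start_time`, and `end_time` columns followed by one label column per class. The file name and label values are made up, and the 0/1 label encoding and class names `A` to `E` are assumptions based on the outputs shown later in this notebook.

```python
import csv
import io

# Hypothetical rows: the file name and the label values are invented for illustration.
rows = [
    {"file": "recording_1.mp3", "start_time": 0.0, "end_time": 2.0,
     "A": 1, "B": 0, "C": 0, "D": 0, "E": 0},
    {"file": "recording_1.mp3", "start_time": 2.0, "end_time": 4.0,
     "A": 0, "B": 0, "C": 1, "D": 0, "E": 0},
]

# Write the table to an in-memory buffer rather than to disk.
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=list(rows[0].keys()))
writer.writeheader()
writer.writerows(rows)
csv_text = buffer.getvalue()
print(csv_text)

# The first three columns identify a clip; the remaining columns are class labels.
header = csv_text.splitlines()[0].split(",")
index_columns, class_columns = header[:3], header[3:]
```

Reading such a file with `index_col=[0,1,2]` makes the first three columns the index, leaving only the label columns as data.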

Make a copy of `./configs/default.yml` and edit the parameters as desired. Then point `config_file` to your copy in the cell that loads the config file below.

Run this notebook to load a trained CNN and use it to generate predictions on audio data. For documentation and tutorials, visit [opensoundscape.org](https://opensoundscape.org). 

Import packages:

In [27]:
import opensoundscape

In [25]:
from opensoundscape import CNN
from load_cfg import cnn_from_cfg
import yaml
import pandas as pd
import wandb

Load the config file. Change the file path to the location of your config file:

In [58]:
config_file = "./configs/default.yml"

with open(config_file, "r") as f:
    cfg = yaml.safe_load(f)
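After loading, `cfg` is a nested dictionary, and the prediction cells below pass one of its sections to `predict` with `**cfg['predict']`. The sketch below shows how that unpacking works using a plain dictionary and a stand-in function. Only `overlap_fraction` is named in this notebook; the other keys and the stand-in `predict` function are hypothetical and only illustrate the mechanism.

```python
# Hypothetical stand-in for the dictionary that yaml.safe_load returns.
# Only 'overlap_fraction' appears in this notebook; the other keys are assumed.
cfg = {
    "predict": {
        "batch_size": 64,
        "num_workers": 0,
        "overlap_fraction": 0.0,
    }
}


def predict(samples, batch_size=1, num_workers=0, overlap_fraction=0.0):
    """Stand-in for a predict method: returns the settings it received."""
    return {
        "n_samples": len(samples),
        "batch_size": batch_size,
        "num_workers": num_workers,
        "overlap_fraction": overlap_fraction,
    }


# **cfg["predict"] expands each key/value pair into a keyword argument,
# so this call is the same as predict(samples, batch_size=64, num_workers=0, ...).
settings = predict(["a.mp3", "b.mp3"], **cfg["predict"])
print(settings)
```

A key in the config section that does not match a parameter name raises a `TypeError`, so the section should contain only arguments that the method accepts.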

In [66]:
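# build the CNN described by the config file
cnn = cnn_from_cfg(config_file)
# load the trained weights saved by the training notebook
cnn.load_weights('./run_1/latest.pt')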

We demonstrate prediction for two use cases:

1. The user has created a dataframe with the exact start and end times of each clip

2. The user passes a list of audio files, which are automatically split into clips of the appropriate length during `.predict()`

## Predict on a pre-defined table of clips

In [67]:
# load the validation dataset into a dataframe, with ('file','start_time','end_time') as the index
samples = pd.read_csv('./validation_set.csv',index_col=[0,1,2]).sample(200)

# generate predictions for each clip using settings from the config file
preds = cnn.predict(samples,**cfg['predict'])

The returned dataframe looks very similar to the original, but contains the CNN's output score for each sample (row) and class (column).

In [68]:
preds.head(2)

| file | start_time | end_time | A | B | C | D | E |
|---|---|---|---|---|---|---|---|
| /Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220622_034500_0-10s.mp3 | 4.0 | 6.0 | -6.000661 | -23.155195 | -14.586831 | -17.094234 | -12.850677 |
| /Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220622_103000_0-10s.mp3 | 3.0 | 5.0 | -6.004264 | -23.164827 | -14.594211 | -17.101292 | -12.855599 |
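The scores above are large negative numbers rather than values between 0 and 1, which suggests they are raw (pre-activation) outputs. Assuming they are logits from a model that scores each class independently, a sigmoid maps them to the 0 to 1 range. This is a minimal sketch under that assumption, not a description of how the library post-processes scores.

```python
import math


def sigmoid(logit):
    """Map a raw score (logit) to a value between 0 and 1, avoiding overflow."""
    if logit >= 0:
        return 1.0 / (1.0 + math.exp(-logit))
    e = math.exp(logit)
    return e / (1.0 + e)


# Scores for classes A-E copied from the first row of the output above.
scores = {"A": -6.000661, "B": -23.155195, "C": -14.586831,
          "D": -17.094234, "E": -12.850677}

probabilities = {name: sigmoid(score) for name, score in scores.items()}
for name, p in probabilities.items():
    print(f"{name}: {p:.2e}")
```

Under this assumption every class receives a value well below 0.01, which is consistent with a model that has barely been trained.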


Let's evaluate with a few metrics. 

We expect terrible performance, since we didn't actually train the model for long.

In [71]:
# compare the labels (targets) with the predicted scores
_, metrics = cnn.eval(samples.values,preds.values)
print(f"Validation set mean avg precision: {metrics['map']:0.2f}")
for c in preds.columns:
    print(f"Validation set {c} class avg precision: {metrics[c]['avg_precision']:0.2f}")

Metrics:
	MAP: 0.089
Validation set mean avg precision: 0.09
Validation set A class avg precision: 0.27
Validation set B class avg precision: 0.01
Validation set C class avg precision: 0.04
Validation set D class avg precision: 0.03
Validation set E class avg precision: 0.09
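To make the numbers above easier to interpret, here is a minimal pure-Python sketch of average precision for one class: clips are ranked by score, and the precision at the rank of each positive clip is averaged. This is the standard definition when scores have no ties; library implementations may handle ties and edge cases differently, and the toy labels and scores are made up.

```python
def average_precision(labels, scores):
    """Average of the precision values at the rank of each positive clip."""
    # Rank clips from highest to lowest score.
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precisions = []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    # Undefined when there are no positive clips; return 0.0 here by convention.
    return sum(precisions) / len(precisions) if precisions else 0.0


# Toy example: positives are ranked 1st and 3rd, so AP = (1/1 + 2/3) / 2.
labels = [1, 0, 1, 0]
scores = [0.9, 0.8, 0.7, 0.1]
ap = average_precision(labels, scores)
print(f"average precision: {ap:.3f}")

# Mean average precision is the mean of the per-class values.
# These are the rounded per-class values printed above.
per_class = [0.27, 0.01, 0.04, 0.03, 0.09]
mean_ap = sum(per_class) / len(per_class)
print(f"mean of rounded per-class values: {mean_ap:.3f}")
```

The mean of the rounded per-class values is 0.088, which matches the reported MAP of 0.089 up to rounding.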


## Predict on a list of audio files

If a list of audio files is passed to `predict`, each file is automatically preprocessed into the correct-length clips. Overlap between clips is defined in the config by `cfg['predict']['overlap_fraction']`.
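The effect of the overlap setting can be sketched with a small helper that lists clip start and end times for one file. The function `clip_times` is hypothetical, and dropping a trailing partial clip is an assumption chosen to match the output below, where files named `0-10s` yield clips at 0-3, 3-6, and 6-9 seconds. The library's actual splitting logic may differ.

```python
import math


def clip_times(duration, clip_duration, overlap_fraction=0.0):
    """Return (start, end) times of full-length clips within one audio file."""
    # Each clip starts this many seconds after the previous one.
    step = clip_duration * (1.0 - overlap_fraction)
    if duration < clip_duration or step <= 0:
        return []
    # Count only clips that fit entirely inside the file (assumption).
    n_clips = math.floor((duration - clip_duration) / step) + 1
    return [(i * step, i * step + clip_duration) for i in range(n_clips)]


# A 10-second file with 3-second clips and no overlap.
print(clip_times(10, 3, overlap_fraction=0.0))

# With 50% overlap, each clip starts 1.5 seconds after the previous one.
print(clip_times(10, 3, overlap_fraction=0.5))
```

Increasing the overlap fraction produces more clips per file, and therefore more rows in the prediction table.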

In [41]:
# take the first two file paths from the earlier predictions as example inputs
files = list(preds.reset_index()['file'][0:2])
files

['/Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220621_043000_0-10s.mp3',
 '/Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220623_133000_0-10s.mp3']

In [64]:
preds2 = cnn.predict(files,**cfg['predict'])
preds2.head()

| file | start_time | end_time | A | B | C | D | E |
|---|---|---|---|---|---|---|---|
| /Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220621_043000_0-10s.mp3 | 0.0 | 3.0 | -6.481035 | -22.489037 | -14.08933 | -15.926941 | -13.430865 |
| /Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220621_043000_0-10s.mp3 | 3.0 | 6.0 | -6.480886 | -22.487982 | -14.089136 | -15.92626 | -13.43055 |
| /Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220621_043000_0-10s.mp3 | 6.0 | 9.0 | -6.480846 | -22.487713 | -14.089093 | -15.926157 | -13.430593 |
| /Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220623_133000_0-10s.mp3 | 0.0 | 3.0 | -6.4815 | -22.491108 | -14.090951 | -15.928777 | -13.432294 |
| /Users/SML161/labeled_datasets/rana_sierrae_2022/mp3/sine2022a_MSD-0558_20220623_133000_0-10s.mp3 | 3.0 | 6.0 | -6.481314 | -22.490112 | -14.090863 | -15.928102 | -13.43205 |
