## Train a CNN to detect Rana sierrae in audio recordings

This notebook uses the training and validation data generated in `02_prep_training_data.ipynb` to train
a CNN with OpenSoundscape.

Note that training is a stochastic process, so results will differ slightly each time
the notebook is run. The original model object trained and used in the manuscript is included in
the subfolder `./resources/rana_seirrae_cnn.model`.
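If you want repeated runs to be more comparable, one common approach is to fix the random seeds before building and training the model. This is an assumption, not part of the original workflow: `set_seeds` is a hypothetical helper, and seeding reduces but does not necessarily eliminate run-to-run differences (some GPU operations and multi-process data loading can still be nondeterministic).

```python
import random

import numpy as np


def set_seeds(seed: int) -> None:
    """Hypothetical helper: seed the Python, NumPy and (if installed) PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
    except ImportError:
        # PyTorch is not installed here, so only Python and NumPy are seeded
        return
    torch.manual_seed(seed)  # seeds PyTorch's generators on all devices


# the same seed reproduces the same random draws
set_seeds(0)
first = np.random.rand(3)
set_seeds(0)
second = np.random.rand(3)
print(np.allclose(first, second))  # True
```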

### This takes a long time
Training a deep learning CNN model is computationally expensive and slow. It is much faster when a GPU is available (OpenSoundscape will automatically use a GPU if one is available), but even so it takes about an hour to run (an estimated 20 hours on a CPU-only machine).

You can proceed through the rest of the notebooks without re-training the model by using the original model object trained and used in the manuscript, which is included in
the subfolder `./resources/rana_seirrae_cnn.model`.


This notebook is part of a series of notebooks and scripts in the [repository](https://github.com/kitzeslab/rana-sierrae-cnn):

- `01_explore_annotated_data.ipynb` Explore annotated dataset of Rana sierrae call types

- `02_prep_training_data.ipynb` Prepare annotated files for training a CNN machine learning model

- `03_train_cnn.ipynb` Train a CNN to recognize Rana sierrae vocalizations

- `04_cnn_prediction.ipynb` Use the CNN to detect Rana sierrae in audio recordings

- `05_cnn_validation.ipynb` Analyze the accuracy and performance of the CNN

- `06_aggregate_scores.py` Aggregate scores from CNN prediction across dates and times of day

- `07_explore_results.ipynb` Analyze temporal patterns of vocal activity using the CNN detections


Imports

```python
import pandas as pd
from opensoundscape.torch.models.cnn import CNN
from opensoundscape.data_selection import resample
import wandb
```

Load training data

```python
# Load the training and validation datasets prepared in the notebook 02_prep_training_data.ipynb
train_df = pd.read_csv('./resources/training_set.csv').set_index(['file', 'start_time', 'end_time'])
# add a 'negative' class so that every clip belongs to exactly one class
train_df['negative'] = 1 - train_df['rana_sierrae']
val_df = pd.read_csv('./resources/validation_set.csv').set_index(['file', 'start_time', 'end_time'])
val_df['negative'] = 1 - val_df['rana_sierrae']

# upsample every class to the size of the largest class (samples of smaller classes are reused)
train_df = resample(train_df, upsample=True, n_samples_per_class=train_df.sum().max())
```
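The label columns and the upsampling step can be illustrated on a toy table. This is a minimal pandas sketch, not OpenSoundscape's `resample` implementation: the file names, the class counts, and the helper `upsample_to_largest_class` are hypothetical. It shows why `train_df.sum().max()` is the size of the largest class: with one-hot `rana_sierrae`/`negative` columns, each column sum is that class's sample count.

```python
import pandas as pd

# hypothetical toy labels: 2 positive clips and 5 negative clips
toy = pd.DataFrame(
    {
        "file": [f"clip_{i}.wav" for i in range(7)],
        "start_time": [0.0] * 7,
        "end_time": [2.0] * 7,
        "rana_sierrae": [1, 1, 0, 0, 0, 0, 0],
    }
).set_index(["file", "start_time", "end_time"])

# one-hot encode: every clip is either rana_sierrae or negative
toy["negative"] = 1 - toy["rana_sierrae"]

# column sums are the per-class counts, so the max is the largest class size
class_counts = toy.sum()
n_target = int(class_counts.max())


def upsample_to_largest_class(df: pd.DataFrame, n_per_class: int, seed: int = 0) -> pd.DataFrame:
    """Hypothetical helper: keep every row of each class, then add random repeats
    (sampled with replacement) until the class has n_per_class rows."""
    parts = []
    for label in df.columns:
        rows = df[df[label] == 1]
        extra = rows.sample(n=n_per_class - len(rows), replace=True, random_state=seed)
        parts.append(pd.concat([rows, extra]))
    return pd.concat(parts)


balanced = upsample_to_largest_class(toy, n_target)
print(class_counts)
print(balanced.sum())
```

The original clips are all kept and only the smaller class gains duplicates, which is why upsampling is usually paired with augmentation (as in the preprocessing settings that follow) so that repeated clips are not seen identically.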

Initialize a Weights and Biases (wandb) session

```python
# initialize a Weights and Biases logging session for tracking model training progress
# (requires logging in to wandb the first time)
# this step is optional: to skip it, comment out these lines, pass wandb_session=None
# to train(), and comment out the wandb_session.finish() call after training
wandb_session = wandb.init(
    entity='kitzeslab',  # replace this with your WandB "entity", i.e. group name
    project="rana_sierrae_notebooks",
    config=dict(
        comment="Description: training resnet18 on A & E classes and excluding unknown class X",
    )
)
```
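Because the Weights and Biases session is optional, one way to avoid commenting lines in and out is a flag that controls whether a session is created. This is a sketch under assumptions: `USE_WANDB` and `finish_wandb` are hypothetical names that do not appear in the original notebook, and the `config` argument is omitted for brevity.

```python
# hypothetical flag: set to True to log training progress to Weights and Biases
USE_WANDB = False

if USE_WANDB:
    import wandb

    wandb_session = wandb.init(
        entity="kitzeslab",  # replace with your WandB "entity", i.e. group name
        project="rana_sierrae_notebooks",
    )
else:
    # passing wandb_session=None to model.train() skips WandB logging
    wandb_session = None


def finish_wandb(session):
    """Close the WandB run if one was started; do nothing otherwise."""
    if session is not None:
        session.finish()
```

With this pattern, `model.train(..., wandb_session=wandb_session)` works in both cases, and `finish_wandb(wandb_session)` can replace the unconditional `wandb_session.finish()` after training.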

Create a CNN with OpenSoundscape and customize preprocessing

```python
# create an opensoundscape CNN object to train a ResNet-18 on 2-second audio clips
# single_target=True: each clip belongs to exactly one class (rana_sierrae or negative)
model = CNN(architecture='resnet18', classes=train_df.columns, sample_duration=2.0, single_target=True)

# modify preprocessing of the CNN:
# bandpass spectrograms to 300-2000 Hz
model.preprocessor.pipeline.bandpass.set(min_f=300, max_f=2000)
# modify augmentation routine parameters
model.preprocessor.pipeline.frequency_mask.set(max_masks=5, max_width=0.1)
model.preprocessor.pipeline.time_mask.set(max_masks=5, max_width=0.1)
model.preprocessor.pipeline.add_noise.set(std=0.01)

# decrease the learning rate from the default value
model.optimizer_params['lr'] = 0.002
```
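To make the preprocessing settings more concrete, the sketch below applies analogous operations to a NumPy array standing in for a spectrogram (rows are frequency bins, columns are time frames). It is an illustration under assumptions, not OpenSoundscape's implementation: the helper names, the linear 0-11025 Hz frequency axis, and the masking rule (up to `max_masks` bands, each at most `max_width` of the axis, set to zero) are hypothetical choices that only mirror the parameter names used above.

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical spectrogram: 257 frequency bins spanning 0-11025 Hz, 100 time frames
freqs = np.linspace(0, 11025, 257)
spec = rng.random((257, 100))


def bandpass(spec, freqs, min_f, max_f):
    """Keep only the frequency rows within [min_f, max_f] Hz."""
    keep = (freqs >= min_f) & (freqs <= max_f)
    return spec[keep], freqs[keep]


def random_masks(spec, axis, max_masks, max_width, rng):
    """Zero out up to max_masks random bands, each at most max_width of the axis length."""
    out = spec.copy()
    length = out.shape[axis]
    max_len = max(1, int(max_width * length))
    for _ in range(rng.integers(0, max_masks + 1)):
        width = rng.integers(1, max_len + 1)
        start = rng.integers(0, length - width + 1)
        if axis == 0:
            out[start:start + width, :] = 0.0
        else:
            out[:, start:start + width] = 0.0
    return out


def add_noise(spec, std, rng):
    """Add zero-mean Gaussian noise with standard deviation std."""
    return spec + rng.normal(0.0, std, size=spec.shape)


band, band_freqs = bandpass(spec, freqs, min_f=300, max_f=2000)
augmented = random_masks(band, axis=0, max_masks=5, max_width=0.1, rng=rng)  # frequency masks
augmented = random_masks(augmented, axis=1, max_masks=5, max_width=0.1, rng=rng)  # time masks
augmented = add_noise(augmented, std=0.01, rng=rng)
print(band.shape, augmented.shape)
```

The bandpass is deterministic and shrinks the input to the 300-2000 Hz range, while the masks and noise are random and change on every call, so each upsampled copy of a clip looks slightly different to the network.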

Train the CNN for 20 epochs on the training set, evaluating on the validation set

During training (the cell below), the model is saved to `./resources/` once every 5 epochs. The model performing best on the validation set is saved in the same folder as `best.model`.

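```python
# train the CNN for 20 epochs with a batch size of 128
model.train(
    train_df,
    val_df,
    epochs=20,
    batch_size=128,
    num_workers=12,  # parallel data-loading processes; reduce on machines with fewer CPU cores
    save_path='./resources/',
    save_interval=5,  # save a model checkpoint every 5 epochs
    log_interval=10,
    validation_interval=1,  # evaluate on val_df after every epoch
    wandb_session=wandb_session
)

# let wandb know this run finished successfully
wandb_session.finish()
```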
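The checkpointing behavior described above (a saved model every `save_interval` epochs, plus `best.model` for the best validation score) can be sketched in plain Python. This is a hypothetical illustration of the bookkeeping, not OpenSoundscape's code: the `epoch-N.model` file names, the assumption that a higher validation score is better, and the per-epoch scores themselves are all made up for the example.

```python
import math


def checkpoint_plan(epochs, save_interval, val_scores):
    """Hypothetical bookkeeping: which epochs get a periodic save, and which epoch is 'best'.

    val_scores[i] is a made-up validation score for epoch i + 1 (assumed higher is better).
    """
    periodic = [e for e in range(1, epochs + 1) if e % save_interval == 0]
    best_epoch, best_score = None, -math.inf
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:  # a new best score would overwrite best.model
            best_epoch, best_score = epoch, score
    return periodic, best_epoch


# hypothetical validation scores for 20 epochs (one validation per epoch, validation_interval=1)
scores = [0.60, 0.68, 0.72, 0.75, 0.78, 0.80, 0.79, 0.83, 0.82, 0.84,
          0.85, 0.84, 0.86, 0.85, 0.85, 0.87, 0.86, 0.86, 0.85, 0.86]
periodic, best_epoch = checkpoint_plan(epochs=20, save_interval=5, val_scores=scores)
print([f"epoch-{e}.model" for e in periodic])  # ['epoch-5.model', 'epoch-10.model', 'epoch-15.model', 'epoch-20.model']
print(best_epoch)  # 16
```

In this made-up run the best model (epoch 16) does not coincide with any periodic checkpoint, which is why keeping a separate `best.model` is useful.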