# Training a CNN 

Make a copy of `./configs/default.yml` and edit the parameters as desired, then point to your config file in the second code cell.

Run this notebook to train a convolutional neural network (CNN) with OpenSoundscape's `CNN` class. For documentation and tutorials, visit [opensoundscape.org](https://opensoundscape.org).

Import packages:

In [50]:
import opensoundscape as opso 
from load_cfg import cnn_from_cfg
import yaml
import pandas as pd
import wandb
from pathlib import Path
from glob import glob

Load your config file. Set `config_file` on the first line to the path of your own config file.

In [37]:
config_file = "./configs/default.yml"

with open(config_file, "r") as f:
    cfg = yaml.safe_load(f)
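Later cells read a few specific keys from the loaded config: `cfg['class_list']`, `cfg['train']['save_path']` (along with any other keyword arguments passed to `cnn.train`), and `cfg['wandb_init']['project']`. Below is a minimal sketch of a sanity check you could run right after loading, assuming those are the only required keys; `missing_config_keys` is a hypothetical helper, not part of OpenSoundscape or `load_cfg`.

```python
def missing_config_keys(cfg):
    """Return dotted names of keys this notebook reads but that cfg lacks.

    Hypothetical helper: the required keys are the ones accessed in later cells.
    A key that is present with value None (e.g. wandb_init.project: null) counts as present.
    """
    required = [
        ("class_list",),
        ("train", "save_path"),
        ("wandb_init", "project"),
    ]
    missing = []
    for path in required:
        node = cfg
        for key in path:
            if not isinstance(node, dict) or key not in node:
                missing.append(".".join(path))
                break
            node = node[key]
    return missing


# Example: a config with no wandb_init section
example_cfg = {"class_list": ["A", "B"], "train": {"save_path": "./runs"}}
print(missing_config_keys(example_cfg))  # ['wandb_init.project']
```

In the notebook, `assert not missing_config_keys(cfg), missing_config_keys(cfg)` would stop early with the list of missing keys.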

Generate an `opensoundscape.CNN` object from the parameters in the config file:

In [15]:
cnn = cnn_from_cfg(config_file)

Load the table of annotated clips:

In [16]:
df = pd.read_csv('/Users/SML161/labeled_datasets/rana_sierrae_2022/labels_2s.csv')

# point to the location of the audio files
data_root = Path('/Users/SML161/labeled_datasets/rana_sierrae_2022/mp3')
df['file'] = [data_root / f for f in df['file']]

# set the index to (file, start_time, end_time), the format OpenSoundscape expects
df = df.set_index(['file', 'start_time', 'end_time'])

# subset the table to the desired classes (columns)
df = df[cfg['class_list']]
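The cell above assumes `labels_2s.csv` has `file`, `start_time`, and `end_time` columns plus one 0/1 column per class. The sketch below builds a tiny made-up table in that layout (file names, the folder path, and label values are illustrative only) to show the shape of the table after `set_index` and the class subset. It uses separate variable names so it does not overwrite `df` if run in this notebook.

```python
from pathlib import Path

import pandas as pd

# Made-up labels in the same layout as labels_2s.csv: one row per 2 s clip,
# one 0/1 column per class (file names and values are illustrative only)
example_df = pd.DataFrame(
    {
        "file": ["site1.mp3", "site1.mp3", "site2.mp3"],
        "start_time": [0.0, 2.0, 0.0],
        "end_time": [2.0, 4.0, 2.0],
        "A": [1, 0, 0],
        "B": [0, 1, 0],
        "extra_label": [0, 0, 1],
    }
)

example_root = Path("/path/to/audio")  # placeholder for your audio folder
example_df["file"] = [str(example_root / f) for f in example_df["file"]]

# a (file, start_time, end_time) multi-index identifies each clip
example_df = example_df.set_index(["file", "start_time", "end_time"])

# keep only the classes the model will be trained on
example_classes = ["A", "B"]
example_df = example_df[example_classes]
print(example_df)
```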

Split the labeled data into training and validation sets, and save the validation set to `validation_set.csv`:

In [22]:
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(df, test_size=0.2, random_state=0)
val_df.to_csv('./validation_set.csv')
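With a float `test_size`, `train_test_split` shuffles the rows and assigns `ceil(test_size * n)` of them to the held-out set. A plain-Python sketch of the same idea follows; it is an illustration, not scikit-learn's code, and `random_split_indices` is a hypothetical name, so it will not pick the same rows as `random_state=0` above.

```python
import math
import random


def random_split_indices(n_rows, test_size=0.2, seed=0):
    """Shuffle row positions and split them into (train, validation) lists.

    Sketch of a random holdout split: the validation set gets
    ceil(test_size * n_rows) rows and the training set gets the rest.
    """
    positions = list(range(n_rows))
    random.Random(seed).shuffle(positions)
    n_val = math.ceil(test_size * n_rows)
    return positions[n_val:], positions[:n_val]


train_pos, val_pos = random_split_indices(1003, test_size=0.2)
print(len(train_pos), len(val_pos))  # 802 201
```

The positions could be applied with `df.iloc[train_pos]`. Because rows are shuffled individually, 2 s clips cut from the same recording may land in both sets; if that matters for your data, consider splitting by `file` instead.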

Inspect class imbalance by counting the positive clips for each class in the training set:

In [49]:
train_df.sum()

A    2136
B      88
C     404
D     288
E     584
dtype: int64
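With the counts above, the most common class (A, 2136 positive clips) outnumbers the rarest (B, 88) by roughly 24 to 1. A quick calculation of what resampling every class to 300 clips means for each one, with the counts copied from the output above:

```python
# positive-clip counts per class, copied from train_df.sum() above
counts = {"A": 2136, "B": 88, "C": 404, "D": 288, "E": 584}
target = 300  # clips per class used in the resampling step

imbalance_ratio = max(counts.values()) / min(counts.values())
print(f"max/min ratio: {imbalance_ratio:.1f}")  # 24.3

for name, n in counts.items():
    if n < target:
        action = "upsample"
    elif n > target:
        action = "downsample"
    else:
        action = "unchanged"
    # factor > 1 means clips are repeated; factor < 1 means clips are dropped
    print(f"{name}: {n} -> {target} ({action}, x{target / n:.2f})")
```

Class B ends up with each clip repeated about 3.4 times on average, while most class A clips are left out of this balanced set.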

Resample the training set so that each class is evenly represented, with 300 clips of each call type:

In [51]:
balanced_train = opso.data_selection.resample(train_df, 300, upsample=True, downsample=True)
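The idea behind per-class resampling: for each class, draw `n` of its positive clips, sampling with replacement (upsampling) when the class has fewer than `n` positives and without replacement (downsampling) when it has more. Below is a pandas sketch of that idea for a one-hot table like `train_df`. `resample_per_class` is a hypothetical helper, not OpenSoundscape's implementation, so the selected rows and their order will differ from `opso.data_selection.resample`.

```python
import pandas as pd


def resample_per_class(df, n_per_class, random_state=0):
    """Return a table with n_per_class positive rows drawn for each class column.

    Hypothetical sketch: classes with too few positives are sampled with
    replacement (upsampled); classes with too many are sampled without
    replacement (downsampled).
    """
    parts = []
    for cls in df.columns:
        positives = df[df[cls] == 1]
        if len(positives) == 0:
            continue  # nothing to sample for this class
        replace = len(positives) < n_per_class
        parts.append(
            positives.sample(n=n_per_class, replace=replace, random_state=random_state)
        )
    return pd.concat(parts)


# toy one-hot table: class "rare" has 2 positives, class "common" has 8
toy = pd.DataFrame({"rare": [1, 1] + [0] * 8, "common": [0, 0] + [1] * 8})
balanced = resample_per_class(toy, n_per_class=5)
print(balanced.sum())
```

Since a clip can be positive for several classes and may be drawn once for each of them, per-class totals in a multi-label result can exceed `n_per_class`.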

Train the CNN with the parameters in the config file.

This saves model checkpoints and a copy of the config to a new subfolder (one per run) inside the folder specified in `cfg["train"]["save_path"]`.
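When `cfg["wandb_init"]["project"]` is null (no Weights & Biases logging), the next cell names each run folder `run_<k>`, where `k` is one more than the number of existing `run_*` entries in the save folder. The same logic as a small standalone function, exercised on a temporary directory; the name `next_run_dir` is hypothetical.

```python
import tempfile
from glob import glob
from pathlib import Path


def next_run_dir(save_dir):
    """Create and return save_dir/run_<k>, with k = (number of existing run_* entries) + 1."""
    save_dir = Path(save_dir)
    run_number = len(glob(f"{save_dir}/run_*")) + 1
    run_dir = save_dir / f"run_{run_number}"
    # exist_ok=False raises FileExistsError instead of reusing an earlier run's folder
    run_dir.mkdir(parents=True, exist_ok=False)
    return run_dir


with tempfile.TemporaryDirectory() as tmp:
    first = next_run_dir(tmp)
    second = next_run_dir(tmp)
    print(first.name, second.name)  # run_1 run_2
```

Because the name comes from counting folders, deleting an earlier run (for example `run_1` while `run_2` remains) makes the next name collide; `exist_ok=False` turns that into an error rather than a silent overwrite.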

In [41]:
# Set up Weights & Biases logging, and a subfolder for this run's checkpoints & config
save_dir = Path(cfg["train"]["save_path"])
if cfg["wandb_init"]["project"] is not None:
    wandb.login()  # only needed once per machine (find your API key at wandb.ai/authorize)
    wandb_session = wandb.init(**cfg["wandb_init"])
    run_name = wandb_session.name
else:
    wandb_session = None
    run_number = len(glob(f"{save_dir}/run_*")) + 1
    run_name = f"run_{run_number}"

# define and create a sub-folder for this run's checkpoints
# (parents=True also creates save_dir if it does not exist yet)
run_dir = save_dir / run_name
run_dir.mkdir(parents=True, exist_ok=False)

# update the save path to point to the newly created subfolder
cfg["train"]["save_path"] = str(run_dir)

# save a copy of the config file to the subfolder as well
with open(run_dir / "config.yml", "w") as f:
    yaml.safe_dump(cfg, f)

# train the CNN using the parameters defined in the config
cnn.train(
    train_df=balanced_train,
    validation_df=val_df,
    wandb_session=wandb_session,
    **cfg["train"]
)

# notify wandb that the run completed successfully (only if a session was started)
if wandb_session is not None:
    wandb.finish()


Training Epoch 0
Epoch: 0 [batch 0/143, 0.00%] 
	DistLoss: 34.335
Metrics:
Epoch: 0 [batch 10/143, 6.99%] 
	DistLoss: 17.449
Metrics:
Epoch: 0 [batch 20/143, 13.99%] 
	DistLoss: 16.614
Metrics:
Epoch: 0 [batch 30/143, 20.98%] 
	DistLoss: 15.297
Metrics:
Epoch: 0 [batch 40/143, 27.97%] 
	DistLoss: 14.507
Metrics:
Epoch: 0 [batch 50/143, 34.97%] 
	DistLoss: 13.930
Metrics:
Epoch: 0 [batch 60/143, 41.96%] 
	DistLoss: 13.554
Metrics:
Epoch: 0 [batch 70/143, 48.95%] 
	DistLoss: 13.321
Metrics:
Epoch: 0 [batch 80/143, 55.94%] 
	DistLoss: 13.162
Metrics:
Epoch: 0 [batch 90/143, 62.94%] 
	DistLoss: 13.139
Metrics:
Epoch: 0 [batch 100/143, 69.93%] 
	DistLoss: 12.999
Metrics:
Epoch: 0 [batch 110/143, 76.92%] 
	DistLoss: 12.842
Metrics:
Epoch: 0 [batch 120/143, 83.92%] 
	DistLoss: 12.745
Metrics:
Epoch: 0 [batch 130/143, 90.91%] 
	DistLoss: 12.728
Metrics:
Epoch: 0 [batch 140/143, 97.90%] 
	DistLoss: 12.675
Metrics:




Metrics:
	MAP: 0.075

Validation.




Metrics:
	MAP: 0.091

Training Epoch 1
Epoch: 1 [batch 0/143, 0.00%] 
	DistLoss: 11.522
Metrics:
Epoch: 1 [batch 10/143, 6.99%] 
	DistLoss: 10.957
Metrics:
Epoch: 1 [batch 20/143, 13.99%] 
	DistLoss: 11.299
Metrics:
Epoch: 1 [batch 30/143, 20.98%] 
	DistLoss: 11.505
Metrics:
Epoch: 1 [batch 40/143, 27.97%] 
	DistLoss: 11.507
Metrics:
Epoch: 1 [batch 50/143, 34.97%] 
	DistLoss: 11.441
Metrics:
Epoch: 1 [batch 60/143, 41.96%] 
	DistLoss: 11.584
Metrics:
Epoch: 1 [batch 70/143, 48.95%] 
	DistLoss: 11.709
Metrics:
Epoch: 1 [batch 80/143, 55.94%] 
	DistLoss: 11.658
Metrics:
Epoch: 1 [batch 90/143, 62.94%] 
	DistLoss: 11.746
Metrics:
Epoch: 1 [batch 100/143, 69.93%] 
	DistLoss: 11.740
Metrics:
Epoch: 1 [batch 110/143, 76.92%] 
	DistLoss: 11.738
Metrics:
Epoch: 1 [batch 120/143, 83.92%] 
	DistLoss: 11.758
Metrics:
Epoch: 1 [batch 130/143, 90.91%] 
	DistLoss: 11.772
Metrics:
Epoch: 1 [batch 140/143, 97.90%] 
	DistLoss: 11.754
Metrics:




Metrics:
	MAP: 0.075

Validation.




Metrics:
	MAP: 0.092

Training Epoch 2
Epoch: 2 [batch 0/143, 0.00%] 
	DistLoss: 11.091
Metrics:
Epoch: 2 [batch 10/143, 6.99%] 
	DistLoss: 11.233
Metrics:
Epoch: 2 [batch 20/143, 13.99%] 
	DistLoss: 11.157
Metrics:
Epoch: 2 [batch 30/143, 20.98%] 
	DistLoss: 11.372
Metrics:
Epoch: 2 [batch 40/143, 27.97%] 
	DistLoss: 11.522
Metrics:
Epoch: 2 [batch 50/143, 34.97%] 
	DistLoss: 11.494
Metrics:
Epoch: 2 [batch 60/143, 41.96%] 
	DistLoss: 11.467
Metrics:
Epoch: 2 [batch 70/143, 48.95%] 
	DistLoss: 11.532
Metrics:
Epoch: 2 [batch 80/143, 55.94%] 
	DistLoss: 11.661
Metrics:
Epoch: 2 [batch 90/143, 62.94%] 
	DistLoss: 11.683
Metrics:
Epoch: 2 [batch 100/143, 69.93%] 
	DistLoss: 11.653
Metrics:
Epoch: 2 [batch 110/143, 76.92%] 
	DistLoss: 11.617
Metrics:
Epoch: 2 [batch 120/143, 83.92%] 
	DistLoss: 11.708
Metrics:
Epoch: 2 [batch 130/143, 90.91%] 
	DistLoss: 11.712
Metrics:
Epoch: 2 [batch 140/143, 97.90%] 
	DistLoss: 11.632
Metrics:




Metrics:
	MAP: 0.077

Validation.




Metrics:
	MAP: 0.085

Training Epoch 3
Epoch: 3 [batch 0/143, 0.00%] 
	DistLoss: 13.337
Metrics:
Epoch: 3 [batch 10/143, 6.99%] 
	DistLoss: 11.184
Metrics:
Epoch: 3 [batch 20/143, 13.99%] 
	DistLoss: 11.318
Metrics:
Epoch: 3 [batch 30/143, 20.98%] 
	DistLoss: 11.703
Metrics:
Epoch: 3 [batch 40/143, 27.97%] 
	DistLoss: 11.710
Metrics:
Epoch: 3 [batch 50/143, 34.97%] 
	DistLoss: 11.683
Metrics:
Epoch: 3 [batch 60/143, 41.96%] 
	DistLoss: 11.711
Metrics:
Epoch: 3 [batch 70/143, 48.95%] 
	DistLoss: 11.649
Metrics:
Epoch: 3 [batch 80/143, 55.94%] 
	DistLoss: 11.564
Metrics:
Epoch: 3 [batch 90/143, 62.94%] 
	DistLoss: 11.574
Metrics:
Epoch: 3 [batch 100/143, 69.93%] 
	DistLoss: 11.537
Metrics:
Epoch: 3 [batch 110/143, 76.92%] 
	DistLoss: 11.596
Metrics:
Epoch: 3 [batch 120/143, 83.92%] 
	DistLoss: 11.559
Metrics:
Epoch: 3 [batch 130/143, 90.91%] 
	DistLoss: 11.608
Metrics:
Epoch: 3 [batch 140/143, 97.90%] 
	DistLoss: 11.641
Metrics:




Metrics:
	MAP: 0.076

Validation.




Metrics:
	MAP: 0.088

Training Epoch 4
Epoch: 4 [batch 0/143, 0.00%] 
	DistLoss: 12.655
Metrics:
Epoch: 4 [batch 10/143, 6.99%] 
	DistLoss: 11.807
Metrics:
Epoch: 4 [batch 20/143, 13.99%] 
	DistLoss: 11.586
Metrics:
Epoch: 4 [batch 30/143, 20.98%] 
	DistLoss: 11.462
Metrics:
Epoch: 4 [batch 40/143, 27.97%] 
	DistLoss: 11.678
Metrics:
Epoch: 4 [batch 50/143, 34.97%] 
	DistLoss: 11.579
Metrics:
Epoch: 4 [batch 60/143, 41.96%] 
	DistLoss: 11.545
Metrics:
Epoch: 4 [batch 70/143, 48.95%] 
	DistLoss: 11.599
Metrics:
Epoch: 4 [batch 80/143, 55.94%] 
	DistLoss: 11.645
Metrics:
Epoch: 4 [batch 90/143, 62.94%] 
	DistLoss: 11.664
Metrics:
Epoch: 4 [batch 100/143, 69.93%] 
	DistLoss: 11.656
Metrics:
Epoch: 4 [batch 110/143, 76.92%] 
	DistLoss: 11.561
Metrics:
Epoch: 4 [batch 120/143, 83.92%] 
	DistLoss: 11.663
Metrics:
Epoch: 4 [batch 130/143, 90.91%] 
	DistLoss: 11.632
Metrics:
Epoch: 4 [batch 140/143, 97.90%] 
	DistLoss: 11.637
Metrics:




Metrics:
	MAP: 0.077

Validation.




Metrics:
	MAP: 0.086

Training Epoch 5
Epoch: 5 [batch 0/143, 0.00%] 
	DistLoss: 10.340
Metrics:
Epoch: 5 [batch 10/143, 6.99%] 
	DistLoss: 11.020
Metrics:
Epoch: 5 [batch 20/143, 13.99%] 
	DistLoss: 11.231
Metrics:
Epoch: 5 [batch 30/143, 20.98%] 
	DistLoss: 11.419
Metrics:
Epoch: 5 [batch 40/143, 27.97%] 
	DistLoss: 11.494
Metrics:
Epoch: 5 [batch 50/143, 34.97%] 
	DistLoss: 11.565
Metrics:
Epoch: 5 [batch 60/143, 41.96%] 
	DistLoss: 11.487
Metrics:
Epoch: 5 [batch 70/143, 48.95%] 
	DistLoss: 11.626
Metrics:
Epoch: 5 [batch 80/143, 55.94%] 
	DistLoss: 11.690
Metrics:
Epoch: 5 [batch 90/143, 62.94%] 
	DistLoss: 11.639
Metrics:
Epoch: 5 [batch 100/143, 69.93%] 
	DistLoss: 11.633
Metrics:
Epoch: 5 [batch 110/143, 76.92%] 
	DistLoss: 11.569
Metrics:
Epoch: 5 [batch 120/143, 83.92%] 
	DistLoss: 11.585
Metrics:
Epoch: 5 [batch 130/143, 90.91%] 
	DistLoss: 11.570
Metrics:
Epoch: 5 [batch 140/143, 97.90%] 
	DistLoss: 11.610
Metrics:




Metrics:
	MAP: 0.078

Validation.


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/SML161/miniconda3/envs/opso_dev/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/SML161/miniconda3/envs/opso_dev/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
  File "/Users/SML161/miniconda3/envs/opso_dev/lib/python3.9/site-packages/torch/__init__.py", line 1239, in <module>
    from torch import onnx as onnx
  File "/Users/SML161/miniconda3/envs/opso_dev/lib/python3.9/site-packages/torch/onnx/__init__.py", line 12, in <module>
    from . import (  # usort:skip. Keep the order instead of sorting lexicographically
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstrap_externa

KeyboardInterrupt: 

Save the current model weights to `latest.pt` in this run's folder (useful when training is interrupted, as above):

In [46]:
cnn.save_weights(run_dir / 'latest.pt')