In [None]:
"""
Created on Mon Apr 29 16:15:46 2024

@author: Michaela Alksne

Script to train a ResNet-18 CNN to classify A and B calls in 30-second spectrograms.
Sets the model and spectrogram parameters and connects to Weights & Biases (wandb) so the user can monitor training progress.

Model parameters: 
    - multi-target model: 3 labels per sample
    - classification with ResampleLoss function
    - weights pretrained on ImageNet
    - learning rate = 0.001
    - cooling factor = 0.3 (every ten epochs the learning rate is multiplied by 0.3, e.g. 0.001 * 0.3 = 0.0003)
    - epochs = 12 
    - batch_size = 12

Spectrogram parameters:
    - 30 second windows
    - 3200 Hz (samples/second) sampling rate
    - 3200-point FFT, which results in 1 Hz frequency bins
    - 90 % overlap (or 1400 samples), resulting in 100 ms bins
    - Hamming window of 1600 samples. The Hamming window tapers each frame to smooth the signal and reduce spectral leakage/artifacts in the FFT.
    - minimum frequency: 10 Hz
    - maximum frequency: 150 Hz
    
Spectrogram augmentations: 
    - frequency_mask: adds random horizontal bars over the image
    - time_mask: adds random vertical bars over the image
    - add_noise: adds random Gaussian noise to the image
    
Notes for user:
batch_size – number of training samples to load and process before the loss is recalculated and backpropagated
num_workers – number of parallel worker processes used for loading data (i.e., CPU cores)
log_interval – interval in epochs at which to evaluate the model on the validation dataset and print metrics to the log

"""
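
The arithmetic behind two of the parameters listed in the docstring can be checked with a short sketch. It assumes the learning rate follows a simple step decay (multiplied by the cooling factor once every ten epochs, with epochs counted from 0) and that the frequency bin width is the sampling rate divided by the FFT size. The function names are hypothetical helpers for illustration, not part of OpenSoundscape.

```python
# Illustrative helpers only; names and defaults are assumptions taken from the docstring.

def stepped_learning_rate(epoch, initial_lr=0.001, cooling_factor=0.3, step=10):
    """Learning rate at a 0-indexed epoch, assuming a step decay every `step` epochs."""
    return initial_lr * cooling_factor ** (epoch // step)


def frequency_bin_width(sample_rate=3200, fft_size=3200):
    """Width in Hz of one FFT frequency bin."""
    return sample_rate / fft_size


# Learning rate for each of the 12 training epochs
schedule = [stepped_learning_rate(epoch) for epoch in range(12)]
print(schedule)

# Frequency resolution of the spectrogram
print(frequency_bin_width())
```

Under these assumptions, only the last two of the 12 epochs run at the reduced learning rate, since the first decay happens after ten epochs.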

In [None]:
import opensoundscape
import glob
import os
import pandas as pd
import numpy as np
import sklearn
import librosa
import torch
import wandb
import random

In [None]:
# Read in the train and validation dataframes; the first three columns form the index
train_clips = pd.read_csv('../../data/processed/train.csv', index_col=[0, 1, 2])
val_clips = pd.read_csv('../../data/processed/validation.csv', index_col=[0, 1, 2])
# Print the column sums for each split
print(train_clips.sum())
print(val_clips.sum())
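
To see what `index_col=[0,1,2]` and `.sum()` do here, the following minimal sketch reads a small in-memory CSV instead of the real files. The column names (`file`, `start_time`, `end_time`) and the label columns (`A`, `B`, `other`) are assumptions made for illustration; the actual CSV layout may differ.

```python
import io

import pandas as pd

# Hypothetical stand-in for train.csv: three index columns followed by
# one 0/1 column per label. All names and values are made up.
csv_text = (
    "file,start_time,end_time,A,B,other\n"
    "a.wav,0.0,30.0,1,0,0\n"
    "a.wav,30.0,60.0,1,1,0\n"
    "b.wav,0.0,30.0,0,0,1\n"
)

# The first three columns become a three-level MultiIndex
clips = pd.read_csv(io.StringIO(csv_text), index_col=[0, 1, 2])
index_names = list(clips.index.names)
print(index_names)

# With 0/1 label columns, the column sums are the number of positive clips per label
label_counts = clips.sum()
print(label_counts)
```

If the label columns are one-hot encoded like this, printing the sums is a quick way to spot class imbalance between the train and validation splits.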

In [None]:
# modify relative filepaths 

