In [None]:
import os 
import sys
import json, yaml
from types import SimpleNamespace

import torch
import torchaudio
import torch.nn as nn
import numpy as np
import pandas as pd
import lightning as L
from tqdm import tqdm
from timm import create_model, list_models
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.tuner import Tuner

# Device used for tensors created manually in this notebook (e.g. the loss class weights).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Make the project's `src/` package (custom.*) importable. Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if '../src' not in sys.path:
    sys.path.append('../src')

from custom.data import AudioDataset, DataModule
from custom.trainer import TrainModule
from custom.net import SpectrogramCNN
from custom.utils import batch_to_device

# IPython autoreload: with mode 2, imported modules (including the project's custom.*
# modules under ../src) are reloaded before each cell runs, so edits are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Base directory relative to the notebook's working directory (presumably the repo root —
# confirm). Used below to locate the class-weights file. Note the trailing slash yields
# '..//...' when combined with a '/' in f-strings; harmless on POSIX.
full_path = '../'

In [None]:
# Many user-editable parameters are defined in config_insecteffnet.yaml.
# yaml.safe_load (not yaml.load) is used, so the file cannot construct arbitrary objects.
with open("config_insecteffnet.yaml", "rt", encoding="utf-8") as infp:
    cfg = SimpleNamespace(**yaml.safe_load(infp))

# Calculate some derived parameters.
# data_path gets one sub-directory per crop length, e.g. wav_crop_len=2.5 -> ".../2-5s_crop/".
cfg.data_path = f'{cfg.data_path_base}/{str(cfg.wav_crop_len).replace(".", "-")}s_crop/'
# Window size equals the FFT size and the hop is half of it
# (i.e. 50% overlap, assuming these feed the STFT — TODO confirm in custom.net).
cfg.window_size = cfg.n_fft
cfg.hop_length = int(cfg.n_fft / 2)

if cfg.minmax_norm:
    # NOTE(review): `get_min_max` is not imported anywhere in this notebook, so on a fresh
    # kernel this branch used to fail with a bare NameError. TODO: add the correct
    # `from custom.<module> import get_min_max` line to the imports cell.
    if "get_min_max" not in globals():
        raise NameError(
            "cfg.minmax_norm is enabled but `get_min_max` is not defined; "
            "import it from the project package in the imports cell."
        )
    cfg.min, cfg.max = get_min_max(cfg, DataModule, SpectrogramCNN)

In [None]:
# Loss function and class weights.
# np.load returns whatever dtype was saved (numpy defaults to float64), and torch.from_numpy
# would keep it. CrossEntropyLoss requires the weight tensor to match the logits' dtype,
# so cast explicitly (assumes the network emits float32 logits — TODO confirm).
class_weights = np.load(os.path.join(full_path, 'class_weights', 'class_weights_2.npy'))
class_weight_tensor = torch.as_tensor(class_weights, dtype=torch.float32, device=device)
loss_fn = nn.CrossEntropyLoss(weight=class_weight_tensor,
                              label_smoothing=cfg.label_smoothing)

# Data logic, loading, augmentation.
dm = DataModule(cfg=cfg)

# Network.
model = SpectrogramCNN(cfg)

# Training logic: wraps the network with the loss and an Adam optimizer configured from cfg.
tmod = TrainModule(model,
                   loss_fn=loss_fn,
                   optimizer_name='Adam',
                   optimizer_hparams={"lr": cfg.lr, "weight_decay": cfg.weight_decay},
                   cfg=cfg)

In [None]:
# Trainer used by the LR finder below (and for any later training run).
trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices="auto",
    enable_checkpointing=True,
    # Integer option: 0 disables reloading dataloaders between epochs.
    reload_dataloaders_every_n_epochs=0,
)

In [None]:
# Adjust min_lr and max_lr to define the learning-rate search space.
# num_training is the number of training steps (i.e. candidate learning rates) evaluated
# across that range; larger values give a finer-grained curve.
# update_attr=False: the suggested LR is read manually via lr_finder.suggestion().
# Do not pass attr_name="cfg" with update_attr=True — Lightning would then overwrite the
# module's `cfg` attribute (and the datamodule's, if it has one) with the suggested float.
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model=tmod, datamodule=dm, min_lr=1e-7, max_lr=1e-2,
                          num_training=100, update_attr=False)

In [None]:
# Tested learning rates and their recorded losses, as a table
# (pandas truncates long frames instead of dumping the whole dict).
lr_results = pd.DataFrame(lr_finder.results)
display(lr_results)

# Loss-vs-LR curve with the suggested point marked. show=True renders it via plt.show();
# fig.show() on the inline (non-GUI) backend may only emit a warning.
fig = lr_finder.plot(suggest=True, show=True)

# Pick a point based on the plot, or use the automatic suggestion.
# suggestion() returns None when it cannot compute one (e.g. too few points).
new_lr = lr_finder.suggestion()
if new_lr is None:
    print("LR finder could not compute a suggestion; pick a learning rate from the plot.")
else:
    print(f"Suggested learning rate: {new_lr:.2e}")