### Trains audio models, stores benchmarks

In [2]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to project code are picked up.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

import random, glob, os
import numpy as np

from pydub import AudioSegment
from pydub import effects
from utils.refactored_common import *
# from utils.refactored_common import unision_shuffled_copies
from tqdm.notebook import tqdm
import pydub
import librosa
try:
    from keras.utils import Sequence #   sequence =  keras.utils.Sequence
except ImportError:
    # Fallback import location; only an import failure should trigger it, so
    # unrelated errors (and Ctrl-C) are no longer silently swallowed.
    from keras.utils.all_utils import Sequence


# import tensorflow_io as tfio

import soundfile as sf
import audioflux
from scipy import signal

import matplotlib.pyplot as plt

In [4]:
from generators import base_generator_audio as BASE
from  curricula import selection
from models import base_cnn, transformer_classifier

#### Run Params

These parameters could be moved to a config file; for now they are kept inline in the cell below.

In [5]:
def return_checkpoints(target_path, log_path):
    """Build the Keras callbacks used for one training run.

    Args:
        target_path: File path the best model checkpoint is written to.
        log_path: Path prefix for this run's logs. A timestamp is appended,
            giving ``{log_path}_{timestamp}.csv`` for the CSV log and
            ``{log_path}_{timestamp}`` as the TensorBoard log directory.

    Returns:
        list: ``[ModelCheckpoint, CSVLogger, TensorBoard]`` callbacks, ready to
        be passed as ``callbacks=`` to ``model.fit``.
    """
    import datetime

    # One timestamp shared by both loggers so their outputs pair up by name
    # (two separate now() calls could straddle a second boundary).
    stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    stamped_log_path = f"{log_path}_{stamp}"

    # The log outputs are siblings of `log_path`, so only its parent directory
    # has to exist; creating `log_path` itself would leave an empty directory.
    os.makedirs(os.path.dirname(log_path) or ".", exist_ok=True)

    # NOTE(review): 'accuracy' is the training metric, so "best" is judged on
    # training accuracy rather than 'val_accuracy' -- confirm this is intended.
    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=target_path,
        monitor='accuracy',
        mode='max',
        save_best_only=True)

    csv_callback = tf.keras.callbacks.CSVLogger(f"{stamped_log_path}.csv", append=True)

    # Previously constructed but never returned, so it never ran.
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=stamped_log_path, histogram_freq=1)

    # Checkpoint and CSV logger keep their original positions in the list.
    return [model_checkpoint_callback, csv_callback, tensorboard_callback]

In [6]:
# Run configuration: everything a reader might tune is defined here and then
# collected into `run_params` so the settings can be saved alongside the results.
run_name = "audio_mnist__CNN__CNN__rho_selection"   #! Convention: "dataset__irred_model__target_model__curriculum"

# Checkpoint files for the irreducible-loss, control (baseline) and target (RHO) models
irred_chkpt = f"results/{run_name}/checkpoints/mnist_spec_CNN_small.keras"
control_chkpt = f"results/{run_name}/checkpoints/mnist_spec_CNN_large.keras"
target_chkpt = f"results/{run_name}/checkpoints/mnist_spec_CNN_large_rho.keras"

# Log path prefixes for the same three models (a timestamp is appended by return_checkpoints)
irred_log = f"results/{run_name}/logs/mnist_spec_CNN_small"
control_log = f"results/{run_name}/logs/mnist_spec_CNN_large"
target_log = f"results/{run_name}/logs/mnist_spec_CNN_large_rho"

# Model input size / output classes and training schedule
width = 25
height = 128
num_classes = 10
epochs = 10

# Data source and generator settings
cfg = "cfg.yaml"
base_dir = "data/mnist/"
minibatch_size = 0.5
batch_size = 32
ext = 'wav'
return_spec = True
return_fft = False

params = yaml_load(cfg)

# Snapshot of the settings above, written out at the end of the notebook.
run_params = {
    "run_name": run_name,
    "irred_chkpt": irred_chkpt,
    "control_chkpt": control_chkpt,
    "target_chkpt": target_chkpt,
    "irred_log": irred_log,
    "control_log": control_log,
    "target_log": target_log,
    "width": width,
    "height": height,
    "num_classes": num_classes,
    "epochs": epochs,
    "cfg": cfg,
    "base_dir": base_dir,
    "minibatch_size": minibatch_size,
    "batch_size": batch_size,
    "ext": ext,
    "return_spec": return_spec,
    # return_fft is passed to every generator, so it is recorded as well
    "return_fft": return_fft,
}



#### Loading Dataloaders

In [7]:
#! Basic dataloaders: one generator per split, all built with the same settings
_shared_gen_kwargs = dict(return_spec=return_spec, return_fft=return_fft, ext=ext)

# Built in the order train -> val -> test; the 'val' split is the holdout set
train_gen, holdout_gen, test_gen = [
    BASE.BaseClassificationGenerator(params, base_dir, batch_size, gentype=split, **_shared_gen_kwargs)
    for split in ('train', 'val', 'test')
]

# Record the training generator's own description with the run parameters
base_class_params = train_gen.toJSON()
run_params["base_dataloader_params"] = base_class_params



#### Calculating Baseline

In [8]:
# Baseline ("control") model: a transformer classifier fed by the plain training generator
control_model = transformer_classifier.BaseTransformerClassifier(width, height, num_classes)
control_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
control_model.summary()

# Checkpoint/logging callbacks pointed at the control model's output paths
callbacks = return_checkpoints(control_chkpt, control_log)

In [9]:
# Train the baseline for the configured number of epochs (the same `epochs`
# value that is recorded in run_params and used for the RHO-LOSS run), so the
# two runs stay comparable if the setting is changed.
control_model.fit(train_gen, validation_data=test_gen, epochs=epochs, callbacks=callbacks)

Epoch 1/10


  self._warn_if_super_not_called()


[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 192ms/step - accuracy: 0.2361 - loss: 2.0610 - val_accuracy: 0.1863 - val_loss: 2.3632
Epoch 2/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m91s[0m 192ms/step - accuracy: 0.5560 - loss: 1.0574 - val_accuracy: 0.2411 - val_loss: 2.8628
Epoch 3/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 139ms/step - accuracy: 0.7118 - loss: 0.7097 - val_accuracy: 0.2507 - val_loss: 2.5860
Epoch 4/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m66s[0m 139ms/step - accuracy: 0.7975 - loss: 0.5314 - val_accuracy: 0.3252 - val_loss: 2.5972
Epoch 5/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m81s[0m 137ms/step - accuracy: 0.8448 - loss: 0.4020 - val_accuracy: 0.3324 - val_loss: 2.6063
Epoch 6/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m87s[0m 184ms/step - accuracy: 0.8746 - loss: 0.3403 - val_accuracy: 0.2895 - val_loss: 2.9670
Epoch 7/10
[1m468/46

<keras.src.callbacks.history.History at 0x7f25f50f7670>

#### Irreducible Model Training

In [10]:
#! change this according to your base model. also can probably look at changing base_cnn params
irred_model = base_cnn.BaseCNN(width, height, num_classes)
irred_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
irred_model.summary()

# Checkpoint/logging callbacks pointed at the irreducible model's output paths
callbacks = return_checkpoints(irred_chkpt, irred_log)



  super().__init__(


In [11]:
# Train the irreducible-loss model on the holdout ('val') generator only, for
# twice the configured epoch count; the test generator is used for validation.
irred_model.fit(holdout_gen, validation_data=test_gen, epochs=epochs*2, callbacks=callbacks) #! lightweight

Epoch 1/20


[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 128ms/step - accuracy: 0.0972 - loss: 2.5628 - val_accuracy: 0.1003 - val_loss: 2.3196
Epoch 2/20
[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 130ms/step - accuracy: 0.1766 - loss: 2.2454 - val_accuracy: 0.1707 - val_loss: 2.2366
Epoch 3/20
[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 129ms/step - accuracy: 0.2080 - loss: 2.1183 - val_accuracy: 0.2064 - val_loss: 2.1489
Epoch 4/20
[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 130ms/step - accuracy: 0.2572 - loss: 1.9445 - val_accuracy: 0.2195 - val_loss: 2.2603
Epoch 5/20
[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m41s[0m 131ms/step - accuracy: 0.3283 - loss: 1.8041 - val_accuracy: 0.2373 - val_loss: 2.0888
Epoch 6/20
[1m187/187[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 127ms/step - accuracy: 0.3646 - loss: 1.7096 - val_accuracy: 0.2557 - val_loss: 2.0344
Epoch 7/20
[1m187/18

<keras.src.callbacks.history.History at 0x7f25800e20d0>

#### RHO-LOSS Training

In [12]:
# Target model for the RHO-LOSS run: same architecture and compile settings as the control model
target_model = transformer_classifier.BaseTransformerClassifier(width, height, num_classes)
target_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
target_model.summary()

# Checkpoint/logging callbacks pointed at the target model's output paths
callbacks = return_checkpoints(target_chkpt, target_log)

In [13]:
# Reload the irreducible-loss model from the checkpoint written during its training
irred_model = tf.keras.models.load_model(irred_chkpt)

# Training generator that is handed the selector together with both models
train_rho_gen = BASE.rho_generator_audio(
    params,
    base_dir,
    batch_size,
    gentype='train',
    return_spec=return_spec,
    return_fft=return_fft,
    ext=ext,
    selector=selection.irreducible_loss_selector,
    irred_model=irred_model,
    target_model=target_model,
    epoch_cutoff=0,
)

In [14]:
# Train the target model from the RHO generator for the configured number of
# epochs, validating on the test generator.
target_model.fit(train_rho_gen, validation_data=test_gen, epochs=epochs, callbacks=callbacks)

Epoch 1/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 154ms/step - accuracy: 0.2126 - loss: 2.1764((19, 25, 128), (19, 10))


[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m88s[0m 178ms/step - accuracy: 0.2128 - loss: 2.1757 - val_accuracy: 0.3166 - val_loss: 2.0141
Epoch 2/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 159ms/step - accuracy: 0.4999 - loss: 1.2286((19, 25, 128), (19, 10))


[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m149s[0m 194ms/step - accuracy: 0.4999 - loss: 1.2284 - val_accuracy: 0.2527 - val_loss: 2.3225
Epoch 3/10
[1m467/468[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 163ms/step - accuracy: 0.6706 - loss: 0.8340((19, 25, 128), (19, 10))


[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m89s[0m 188ms/step - accuracy: 0.6707 - loss: 0.8337 - val_accuracy: 0.2479 - val_loss: 2.7805
Epoch 4/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 153ms/step - 

<keras.src.callbacks.history.History at 0x7f2576044670>

#### Saving Params

In [15]:
import json

# Write the collected run parameters next to the run's other results.
_runparams_path = f"results/{run_name}/runparams.json"

# Create the results directory first so this cell does not depend on a
# training cell having created it already.
os.makedirs(os.path.dirname(_runparams_path), exist_ok=True)

# Plain write mode is enough: the file is only written, never read back here.
with open(_runparams_path, 'w') as f:
    json.dump(run_params, f, indent=4)