In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf

import random, glob, os
import numpy as np

from pydub import AudioSegment
from pydub import effects
from utils.refactored_common import *
# from utils.refactored_common import unision_shuffled_copies
from tqdm.notebook import tqdm
import pydub
import librosa
try :
    from keras.utils import Sequence #   sequence =  keras.utils.Sequence
except:
    from keras.utils.all_utils import Sequence


# import tensorflow_io as tfio

import soundfile as sf
import audioflux
from scipy import signal

import matplotlib.pyplot as plt

In [26]:
from generators import base_generator_audio as BASE
from  curricula import selection
from audio_models import base_cnn, transformer_classifier

In [53]:
class rho_generator_audio(BASE.BaseClassificationGenerator):
    def __init__(self, param, base_dir : str, batch_size : int = 16, shuffle : bool = True, gentype : str = 'train', return_spec : bool = False, return_fft : bool = False, ext : str = 'flac', 
                 selector = None, irred_model : tf.keras.Model = None, target_model_path : str = '', epoch_cutoff : int = 3, minibatch_size : float = 0.6, loss : tf.keras.losses.Loss = tf.keras.losses.categorical_crossentropy):
        super().__init__(param, base_dir, batch_size, shuffle, gentype, return_spec, return_fft, ext)
        self.selector = selector
        self.irred_model = irred_model
        self.target_model_path = target_model_path
        self.epoch_cutoff = epoch_cutoff
        self.minibatch_size = minibatch_size
        self.loss = loss

        self.select_func = None;
        self.target_model = None

        self.cache = None

        #start rho_selection after epoch_cutoff
        #! selector takes in pretrained model, target model, and returns a list of indices


    def on_epoch_end(self):
        self.epoch_cutoff = self.epoch_cutoff - 1 if self.epoch_cutoff > 0 else 0
        if self.selector is not None and self.epoch_cutoff <= 0:
            self.target_model = tf.keras.models.load_model(self.target_model_path)
            self.target_model.compile(optimizer = 'adam', loss = self.loss, metrics = ['accuracy'])

            self.select_func = self.selector(self.irred_model, self.target_model, self.minibatch_size)
        print(self.cache)

    def __getitem__(self, index):
        a, b = super().__getitem__(index)
        
        if self.selector is not None and self.epoch_cutoff <= 0:
            indices = self.select_func(a, b)
            a = a[indices]
            b = b[indices]
        self.cache = (a.shape, b.shape)
        return a, b

### RHO Training

In [54]:
run_name = "audio_mnist__CNN__Transformer__rho_selection"
irred_chkpt = "checkpoints/audio_mnist_losses_spectrogram.keras"
target_chkpt = "checkpoints/audio_mnist_transformer.keras"
width = 25
height = 128
num_classes = 10

param = yaml_load("cfg.yaml")


In [55]:
model = transformer_classifier.BaseTransformerClassifier(width, height, num_classes)
checkpoint_filepath = target_chkpt
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor='accuracy',
    mode='max',
    save_best_only=True)

import datetime
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=f"logs/{target_chkpt.split('.')[0]}_{datetime.datetime.now().strftime('%Y%m%d-%H%M%S')}", histogram_freq=1)

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [59]:
import json

def to_json(obj):
    return json.dumps(obj, default=lambda obj: obj.__dict__)

In [65]:
import jsonpickle
json_string = jsonpickle.encode(rho_generator.__dict__)

In [67]:
rho_generator.__dict__

{'ext': 'wav',
 'base_dir': 'data/mnist/',
 'batch_size': 32,
 'shuffle': True,
 'gentype': 'train',
 'classes': ['5_dgt',
  '6_dgt',
  '1_dgt',
  '2_dgt',
  '3_dgt',
  '7_dgt',
  '4_dgt',
  '0_dgt',
  '8_dgt',
  '9_dgt'],
 'class_dict': {'5_dgt': 0,
  '6_dgt': 1,
  '1_dgt': 2,
  '2_dgt': 3,
  '3_dgt': 4,
  '7_dgt': 5,
  '4_dgt': 6,
  '0_dgt': 7,
  '8_dgt': 8,
  '9_dgt': 9},
 'files': [('5_dgt', 'data/mnist//train/5_dgt/5_08_23.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_40_12.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_53_10.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_24_25.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_35_48.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_35_11.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_27_20.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_28_18.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_28_36.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_50_15.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_57_3.wav'),
  ('5_dgt', 'data/mnist//train/5_dgt/5_03_

In [63]:
json.dumps()

TypeError: Object of type type is not JSON serializable

In [56]:
rho_generator = rho_generator_audio(
    param, "data/mnist/", batch_size=32, return_spec=True, ext = 'wav', selector=selection.irreducible_loss_selector, irred_model = tf.keras.models.load_model(irred_chkpt), target_model_path = target_chkpt, epoch_cutoff = 3, minibatch_size = 0.6, loss = 'categorical_crossentropy'
)

In [57]:
a, b = rho_generator.__getitem__(0)

In [58]:
model.fit(rho_generator, epochs=10, callbacks=[model_checkpoint_callback])

Epoch 1/10


  self._warn_if_super_not_called()


[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 135ms/step - accuracy: 0.1920 - loss: 2.1427((32, 25, 128), (32, 10))
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m68s[0m 135ms/step - accuracy: 0.1922 - loss: 2.1422
Epoch 2/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 137ms/step - accuracy: 0.5498 - loss: 1.1648((32, 25, 128), (32, 10))
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m65s[0m 137ms/step - accuracy: 0.5499 - loss: 1.1646
Epoch 3/10
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 131ms/step - accuracy: 0.7214 - loss: 0.7063((32, 25, 128), (32, 10))
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 131ms/step - accuracy: 0.7215 - loss: 0.7062
Epoch 4/10
[1m467/468[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 175ms/step - accuracy: 0.7701 - loss: 0.6133((19, 25, 128), (19, 10))
[1m468/468[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m83s[0m 176ms/step - accurac

<keras.src.callbacks.history.History at 0x7fec44976130>