# Prepare notebook 

In [None]:
!pip uninstall -y orbax flax dopamine-rl
!pip install "numba<=0.56.0"
!pip install librosa==0.9.2 timit-utils==0.9.0 torchaudio

In [None]:
!pip install --upgrade tensorflow-addons==0.18.0
!pip install --upgrade tensorflow-probability==0.17.0
!pip install --upgrade tensorflow-io==0.26.0
!pip install --upgrade tensorflow==2.9.1
!pip install matplotlib==3.4
!pip install wandb==0.16

!pip install -q opencv-python-headless librosa wandb scikit-learn

In [None]:
!pip install audiomentations

## Import needed libraries

In [None]:
# Smoke-test cell: prints a fixed string and defines nothing used by later cells.
print("Hello")

In [None]:
import os
import pandas as pd
pd.options.mode.chained_assignment = None # avoids assignment warning
import numpy as np
import random
from glob import glob
from tqdm import tqdm 
tqdm.pandas()  # enable progress bars in pandas operations
import gc

import librosa
import sklearn
import json

# Import for visualization
import matplotlib as mpl
import matplotlib.pyplot as plt
import librosa.display as lid 
import IPython.display as ipd

# Import tensorflow
import tensorflow as tf
# Set logging level to avoid unnecessary messages
tf.get_logger().setLevel('ERROR') 
# Set autograph verbosity to avoid unnecessary messages
tf.autograph.set_verbosity(0) 
# Enable xla for speed up
tf.config.optimizer.set_jit(True)

# Import required tensorflow modules
import tensorflow_io as tfio
import tensorflow_addons as tfa 
import tensorflow_probability as tfp 
import tensorflow.keras.backend as K 

print("Packages loaded")

## Create CFG class to store all hyperparameters

In [None]:
class CFG:
    """Central configuration: hyperparameters, audio/STFT settings, augmentation
    probabilities, label mappings and paths used throughout the notebook."""

    # Plot training history
    training_plot = True

    # Kaggle CLI command that downloads the GTZAN dataset
    # (despite the attribute name, this is a download command, not a notebook link)
    notebook_link = 'kaggle datasets download -d andradaolteanu/gtzan-dataset-music-genre-classification'

    # Verbosity level
    verbose = 2

    # Device and random seed
    device = 'GPU'
    seed = 21 # 42

    # Spectrogram size [n_mels, n_frames] and batch size.
    # n_frames ~ duration * sample_rate / hop_length (661500 / 512 ~ 1292) and
    # img_size[0] must equal n_mels: keep these in sync with the settings below.
    img_size = [64, 1292]
    batch_size = 32

    # Drop remainder - dropping last batch if size < batch size
    drop_remainder = True

    # TRAINING SETTINGS
    # Number of epochs, and number of folds
    epochs = 50
    fsr = True # reduce stride of stem block
    num_fold = 5

    # Selected folds for training and evaluation
    selected_folds = [0]

    # Learning rate, optimizer, and scheduler
    lr = 1e-4 # 1e-3
    scheduler = 'cos'
    optimizer = 'Adam' # AdamW, Adam

    # Loss function and label smoothing
    loss = 'CCE' # BCE, CCE
    label_smoothing = 0.05 # label smoothing

    # AUDIO FILES CONFIGURATION
    duration = 30 # second
    sample_rate = 22050
    audio_len = duration*sample_rate  # number of samples per clip

    # STFT parameters
    # taken from: https://ieeexplore-1ieee-1org-10000470f00e3.wbg2.bg.agh.edu.pl/stamp/stamp.jsp?tp=&arnumber=9778067
    window = 1024
    hop_length = 512
    nfft = 4096
    n_mels=64
    fmin = 0
    fmax = 8000
    normalize = True

    # AUGMENTATION CONFIG
    augment=True

    # SPECTROGRAM AUGMENTATION
    spec_augment_prob = 1 # IMPORTANT: SHOULD BE 1

    # Mix-Up augmentation
    mixup_prob = 0.2 # 0.4
    mixup_alpha = 0.5

    # Cut-Mix augmentation
    cutmix_prob = 0 # 0.3
    cutmix_alpha = 1

    # Frequency and Time masking
    mask_prob = 0.4 # 0.4
    freq_mask = 30
    time_mask = 30

    # AUDIO AUGMENTATION
    audio_augment_prob = 1 # IMPORTANT: SHOULD BE 1

    # Time shift
    timeshift_prob = 0.5 # 0.5

    # Gaussian Noise
    gn_prob = 0 # 0.3

    # Pitch Shift
    pitch_shift_prob = 1

    # Speed Change
    speed_change_prob = 0

    # Data Preprocessing Settings
    labelsMapping = {0: 'blues',
                 1: 'classical',
                 2: 'country',
                 3: 'disco',
                 4: 'hiphop',
                 5: 'jazz',
                 6: 'metal',
                 7: 'pop',
                 8: 'reggae',
                 9: 'rock'}

    class_names = list(labelsMapping.values())
    num_classes = len(class_names)
    class_labels = list(range(num_classes))
    label2name = dict(zip(class_labels, class_names))
    name2label = {v:k for k,v in label2name.items()}

    # Paths. Built with os.path.join so the separator matches the host OS
    # (the previous backslash literals only worked on Windows); the trailing ''
    # keeps the trailing separator that callers concatenate file names onto.
    DESTINATION_PATH = os.path.join('working', '') # ON KAGGLE: "/kaggle/working/Music-Genres-Classification/"
    DATA_DIRECTORY = os.path.join('gtzan', 'Data', '') # ON KAGGLE: '/kaggle/input/gtzan-1000/Data/'

    NETWORK_NAME = "ResNet"
    
# Set seeding

def seeding(SEED):
    """Seed Python's `random`, NumPy and TensorFlow with the same value.

    Note: PYTHONHASHSEED is read by the interpreter at startup, so setting it
    here only influences processes started afterwards, not the running kernel.
    """
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(SEED)
    os.environ['PYTHONHASHSEED'] = str(SEED)
    tf.random.set_seed(SEED)
    print('seeding done!!!')

seeding(CFG.seed)

# Load data

## Inspect the data directory

In [None]:
# List the top-level entries of the data directory to confirm it is reachable.
os.listdir(CFG.DATA_DIRECTORY)

## Set seeding

## Prepare filepath | genre dataframe

In [None]:
# Build the (filepath, genre, target) table: one row per audio file, with the
# genre taken from the sub-directory name under genres_original/.
GENRES_DIR = CFG.DATA_DIRECTORY + "genres_original/"
genres = sorted(os.listdir(GENRES_DIR))

# Collect plain records and create the DataFrame once; calling pd.concat inside
# the loop copies the whole frame for every file (quadratic in the file count).
records = [
    {"filepath": GENRES_DIR + genre + "/" + audio, "genre": genre}
    for genre in genres
    for audio in sorted(os.listdir(GENRES_DIR + genre))
]
df = pd.DataFrame(records, columns=["filepath", "genre"])

# Exclude jazz.00054.wav. NOTE(review): presumably because this file cannot be
# decoded; confirm.
df = df[df.filepath != GENRES_DIR + "jazz/jazz.00054.wav"]
df['target'] = df.genre.map(CFG.name2label)

## Check if the path exists

In [None]:
# Sanity check: the first filepath in the table should exist on disk (expect True).
tf.io.gfile.exists(df.filepath.iloc[0])

# Log in to Weights & Biases (W&B) to track training runs

In [None]:
import wandb

# Never hard-code credentials in a notebook. Read the API key from the
# WANDB_API_KEY environment variable; when it is unset (None), wandb.login()
# falls back to its usual lookup (saved credentials or an interactive prompt).
# NOTE: any API key that was ever committed to this notebook should be revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

wandb_config = {
    "architecture": CFG.NETWORK_NAME,
    # (n_mels, n_frames, channels), derived from CFG so it cannot drift out of sync
    "input_shape": (*CFG.img_size, 3),
    "epochs": CFG.epochs,
    "batch_size": CFG.batch_size,
    "seed": CFG.seed,
    "use_small_sample": False,
}

wandb.init(
    # set the wandb project where this run will be logged
    project="mgc-augmentation",

    # track hyperparameters and run metadata with wandb.config
    config=wandb_config
)

# Record the augmentation settings of this run
wandb.log({
    "Augment": CFG.augment,
    "Timeshift Prob": CFG.timeshift_prob,
    "Gaussian Noise Prob": CFG.gn_prob,
    "PitchShift Prob": CFG.pitch_shift_prob,
    "SpeedChange Prob": CFG.speed_change_prob,
    "Mixup Prob": CFG.mixup_prob,
    "MixUp Alpha": CFG.mixup_alpha,
    "CutMix Prob": CFG.cutmix_prob,
    "CutMix Alpha": CFG.cutmix_alpha,
    "Masking Prob": CFG.mask_prob,
    "Mask Width": CFG.freq_mask
})

# EDA

## Utils to load and display sample audio, waveform and spectrogram

In [None]:
import cv2

def load_audio(filepath):
    """Load an audio file with librosa, resampled to CFG.sample_rate.

    Returns:
        (audio, sr): the waveform and its sampling rate (== CFG.sample_rate).
    """
    # Pass the rate explicitly: librosa's implicit default (22050 Hz) only matches
    # CFG.sample_rate by coincidence, and the plotting helpers assume CFG.sample_rate.
    audio, sr = librosa.load(filepath, sr=CFG.sample_rate)
    return audio, sr

def get_spectrogram(audio):
    """Return the librosa mel spectrogram of `audio` in dB relative to its peak (EDA only).

    fmin/fmax are deliberately not passed, so librosa's defaults apply.
    """
    mel_power = librosa.feature.melspectrogram(
        y=audio,
        sr=CFG.sample_rate,
        n_mels=CFG.n_mels,
        n_fft=CFG.nfft,
        hop_length=CFG.hop_length,
    )
    return librosa.power_to_db(mel_power, ref=np.max)

def display_audio(row):
    """Play one dataset row and plot its waveform above its mel spectrogram.

    Args:
        row: a row of `df` exposing `filepath`, `genre` and `target`.
    """
    caption = f'Id: {row.filepath} | Genre: {row.genre} | Target: {row.target}'

    # Load, keep at most CFG.audio_len samples, then build the dB mel spectrogram
    waveform = load_audio(row.filepath)[0][:CFG.audio_len]
    mel_db = get_spectrogram(waveform)

    print("# AUDIO:")
    display(ipd.Audio(waveform, rate=CFG.sample_rate))

    print(caption)
    print('# VISUALIZATION:')

    fig, (wave_ax, spec_ax) = plt.subplots(2, 1, figsize=(12, 6), sharex=True, tight_layout=True)

    # Top: waveform; bottom: mel spectrogram sharing the time axis
    lid.waveshow(waveform, sr=CFG.sample_rate, ax=wave_ax)
    lid.specshow(mel_db,
                 sr=CFG.sample_rate,
                 hop_length=CFG.hop_length,
                 n_fft=CFG.nfft,
                 x_axis='time',
                 y_axis='mel',
                 cmap='coolwarm',
                 ax=spec_ax)
    wave_ax.set_xlabel('')

    plt.show()

## Display samples

### Blues

In [None]:
# Random blues clip: print its path, then play and plot it
class_name = CFG.class_names[0]
class_df = df[df["genre"] == class_name]
print(f'# Category: {class_name}')
print(f'# Num Samples: {len(class_df)}')
row = class_df.sample(1).iloc[0]

print(row['filepath'])
display_audio(row)

### Classical

In [None]:
# Random classical clip: play and plot it
class_name = "classical"
class_df = df[df["genre"] == class_name]
print(f'# Category: {class_name}')
print(f'# Num Samples: {len(class_df)}')
row = class_df.sample(1).iloc[0]

display_audio(row)

### Hiphop

In [None]:
# Random hip-hop clip: play and plot it
class_name = CFG.class_names[4]
class_df = df[df["genre"] == class_name]
print(f'# Category: {class_name}')
print(f'# Num Samples: {len(class_df)}')
row = class_df.sample(1).iloc[0]

display_audio(row)

### Rock

In [None]:
# Random rock clip: play and plot it
class_name = "rock"
class_df = df[df["genre"] == class_name]
print(f'# Category: {class_name}')
print(f'# Num Samples: {len(class_df)}')
row = class_df.sample(1).iloc[0]

display_audio(row)

## Display 3 random waveforms

In [None]:
def display_wave(row, ax):
    """Draw the waveform of `row`'s clip (first CFG.audio_len samples) on `ax`, titled with its genre."""
    waveform, _ = load_audio(row.filepath)
    lid.waveshow(waveform[:CFG.audio_len], sr=CFG.sample_rate, ax=ax)
    ax.set_xlabel('')
    ax.set_title(row.genre)
    

# Plot the waveforms of 3 random clips, one per row
fig, axs = plt.subplots(3, 1)
rows = df.sample(3)
for ax_i, (_, row) in zip(axs, rows.iterrows()):
    display_wave(row, ax_i)

plt.suptitle("Samples from GTZAN dataset as waveplots", fontsize=13)
plt.subplots_adjust(hspace=0.5)
# Shared axis labels, placed in figure coordinates
fig.text(0.035, 0.5, 'Amplitude', ha='center', va='center', rotation='vertical')
fig.text(0.51, 0.03, 'time [s]', ha='center', va='center')
plt.show()




# Split Data

# Train-Val-Test split

In [None]:
from sklearn.model_selection import train_test_split

df = df.reset_index(drop=True)
df["split"] = -1  # placeholder; every row is re-labelled below

# 70% train / 30% held out, stratified by genre
Train, Test = train_test_split(df, test_size=0.3, stratify=df['genre'], random_state=CFG.seed)
# Held-out 30% -> 67% validation (~20% overall) and 33% test (~10% overall)
Val, Test = train_test_split(Test, test_size=0.33, stratify=Test['genre'], random_state=CFG.seed)

# Tag each part with its split name and stack them back into one table
Train, Val, Test = (
    part.assign(split=split_name)
    for part, split_name in ((Train, "train"), (Val, "val"), (Test, "test"))
)
df = pd.concat([Train, Val, Test])

# Augment Data

## Random number helpers

In [None]:
# Thin wrappers around tf.random.uniform used by the augmentation code.

def random_int(shape=[], minval=0, maxval=1):
    """Sample int32 values uniformly from [minval, maxval); a scalar when shape == []."""
    return tf.random.uniform(shape, minval, maxval, tf.int32)


def random_float(shape=[], minval=0.0, maxval=1.0):
    """Sample float32 values uniformly from [minval, maxval); a scalar when shape == []."""
    return tf.random.uniform(shape, minval, maxval, tf.float32)

## Augment Audio

In [None]:
# Import required packages
import tensorflow as tf

from audiomentations import Compose, AddGaussianNoise, TimeStretch, PitchShift, Shift
import numpy as np


# Define a function to crop or pad audio data to a target length
@tf.function
def CropOrPad(audio, target_len, pad_mode='constant'):
    """Randomly crop or pad a 1-D waveform to exactly `target_len` samples.

    Longer clips keep a random window of `target_len` samples; every start offset
    0..(len - target_len) is possible. Shorter clips are padded, with the padding
    split randomly between start and end (anything from all-at-start to all-at-end).

    Args:
        audio: 1-D tensor whose static length (`audio.shape[0]`) must be known,
            because the branch is chosen in Python. NOTE(review): as a tf.function
            this retraces for every distinct input length.
        target_len: desired number of samples (Python int).
        pad_mode: padding mode forwarded to `tf.pad` ('constant' pads with zeros).

    Returns:
        Tensor of shape [target_len].
    """
    audio_len = audio.shape[0]

    if audio_len < target_len:
        diff_len = target_len - audio_len
        # maxval is exclusive: +1 so pad1 can be diff_len (all padding at the start)
        pad1 = random_int([], minval=0, maxval=diff_len + 1)
        pad2 = diff_len - pad1
        audio = tf.pad(audio, paddings=[[pad1, pad2]], mode=pad_mode)
    elif audio_len > target_len:
        diff_len = audio_len - target_len
        # Valid start offsets are 0..diff_len inclusive; maxval is exclusive, hence +1
        idx = random_int([], minval=0, maxval=diff_len + 1)
        audio = audio[idx: (idx + target_len)]
    return tf.reshape(audio, [target_len])


# Waveform augmentations. Each transform's probability is read from CFG, i.e. the
# same values that are logged to W&B, instead of a hard-coded 0.5 that silently
# ignored the configuration.
augmentation_pipeline = Compose([
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=CFG.gn_prob),
        TimeStretch(min_rate=0.8, max_rate=1.25, p=CFG.speed_change_prob),
        PitchShift(min_semitones=-4, max_semitones=4, p=CFG.pitch_shift_prob),
        Shift(p=CFG.timeshift_prob),
    ])

def AudioAug(audio_np):
    """Apply `augmentation_pipeline` to a NumPy waveform sampled at CFG.sample_rate."""
    return augmentation_pipeline(samples=audio_np, sample_rate=CFG.sample_rate)


@tf.function
def Normalize(data, min_max=True):
    """Standardise `data` to zero mean / unit std, then optionally rescale to [0, 1].

    divide_no_nan yields 0 instead of NaN when the std (or the min-max range) is zero.
    """
    centered = data - tf.math.reduce_mean(data)
    data = tf.math.divide_no_nan(centered, tf.math.reduce_std(data))
    if not min_max:
        return data
    lo = tf.math.reduce_min(data)
    hi = tf.math.reduce_max(data)
    return tf.math.divide_no_nan(data - lo, hi - lo)

## Augment Spectrogram

In [None]:
@tf.function
def Spec2Img(spec, num_channels=3):
    """Turn a 2-D spectrogram (H, W) into an image tensor (H, W, num_channels).

    The single spectrogram is repeated across channels, giving an input with the
    channel count expected by 3-channel image models.
    """
    img = spec[..., tf.newaxis]
    if num_channels > 1:
        img = tf.tile(img, [1, 1, num_channels])
    return img

# Convert an image (H, W, C) back to a spectrogram (H, W)
@tf.function
def Img2Spec(img):
    """Return channel 0 of an (H, W, C) image (Spec2Img makes all channels identical)."""
    first_channel = img[..., 0]
    return first_channel


# Randomly mask data in time and freq axis
@tf.function
def TimeFreqMask(spec, time_mask, freq_mask, prob=0.5):
    """With probability `prob`, apply one frequency mask and then one time mask to `spec`.

    `freq_mask` / `time_mask` are the maximum mask widths passed to tfio as `param`.
    """
    apply_mask = random_float() < prob
    if apply_mask:
        spec = tfio.audio.time_mask(tfio.audio.freq_mask(spec, param=freq_mask), param=time_mask)
    return spec


# Applies augmentation to Spectrogram
def SpecAug(spec):
    """Randomly mask an (H, W, C) spectrogram image in frequency and time.

    The image is reduced to one (mel, time) channel and transposed to (time, mel)
    for tfio's maskers, then transposed back and re-tiled to 3 channels.
    """
    time_major = tf.transpose(Img2Spec(spec))
    masked = TimeFreqMask(time_major, time_mask=CFG.time_mask, freq_mask=CFG.freq_mask, prob=CFG.mask_prob)
    return Spec2Img(tf.transpose(masked))


def mixup_image_aug(images, labels, alpha=CFG.mixup_alpha):
    """MixUp on a batch: blend every sample with its neighbour (the batch rolled by one).

    Applied with probability CFG.mixup_prob. A single lambda ~ Beta(alpha, alpha)
    is shared by the whole batch and weights images and labels identically.
    """
    if random_float() <= CFG.mixup_prob:
        images_shape, labels_shape = tf.shape(images), tf.shape(labels)
        lam = tfp.distributions.Beta(alpha, alpha).sample(1)[0]

        rolled_images = tf.roll(images, shift=1, axis=0)
        rolled_labels = tf.roll(labels, shift=1, axis=0)

        images = tf.reshape(lam * images + (1 - lam) * rolled_images, images_shape)
        labels = tf.reshape(lam * labels + (1 - lam) * rolled_labels, labels_shape)
    return images, labels

import tensorflow as tf



def create_cutmix_mask(bbx1, bby1, bbx2, bby2, height, width, channels, batch_size, num_classes=10):
    """Build per-sample rectangular CutMix masks and the fraction of the image each covers.

    Args:
        bbx1, bbx2: [batch_size] int32 box bounds along the WIDTH axis (columns), [bbx1, bbx2).
        bby1, bby2: [batch_size] int32 box bounds along the HEIGHT axis (rows), [bby1, bby2).
        height, width, channels, batch_size: dimensions of the image batch.
        num_classes: number of columns of the returned `ratio` (label width).

    Returns:
        mask: float32 [batch_size, height, width, channels], 1 inside the box, 0 elsewhere.
        ratio: float32 [batch_size, num_classes], box area / image area repeated per class.
    """
    # With the default 'xy' indexing both grids have shape (height, width):
    # cols[i, j] == j (width index) and rows[i, j] == i (height index).
    # x bounds must be compared with column indices and y bounds with row indices.
    cols, rows = tf.meshgrid(tf.range(width), tf.range(height))

    # Reshape the bounding box coordinates to make them broadcastable over the batch size
    bbx1 = tf.reshape(bbx1, [batch_size, 1, 1])
    bby1 = tf.reshape(bby1, [batch_size, 1, 1])
    bbx2 = tf.reshape(bbx2, [batch_size, 1, 1])
    bby2 = tf.reshape(bby2, [batch_size, 1, 1])

    # (batch_size, height, width) boolean box masks -> float32
    mask = (cols >= bbx1) & (cols < bbx2) & (rows >= bby1) & (rows < bby2)
    mask = tf.cast(mask, tf.float32)

    # Fraction of the image covered by each box (sum over height and width)
    patch_area = tf.reduce_sum(mask, axis=[1, 2])
    ratio = patch_area / tf.cast(height * width, tf.float32)
    ratio = tf.tile(tf.expand_dims(ratio, -1), [1, num_classes])

    # Add the channel dimension and repeat the mask across channels
    mask = tf.tile(tf.expand_dims(mask, -1), [1, 1, 1, channels])

    return mask, ratio


def cutmix(images, labels, probability=0.5, alpha=1.0):
    """CutMix on a batch of spectrogram images.

    With probability `probability`, every sample is combined with a randomly chosen
    partner from the same batch. The sample keeps its own pixels inside a rectangular
    box (mask == 1) and takes the partner's pixels outside it. Labels are blended
    with the same weights, `ratio` being the box's share of the image area.

    Args:
        images: batch assumed to be [batch_size, height, width, channels].
        labels: one-hot / soft labels. NOTE(review): create_cutmix_mask returns
            `ratio` with 10 columns here (no class count is passed), so this
            presumably requires exactly 10 classes; confirm.
        probability: chance of applying CutMix to this batch (0 effectively disables it).
        alpha: parameter of the Beta(alpha, alpha) draw that sets the box size.

    Returns:
        (mixed_images, mixed_labels) with the same shapes as the inputs.
    """
    # Only apply CutMix with the given probability
    if random_float() > probability:
        return images, labels

    # Assume images is a 4D tensor of shape [batch_size, height, width, channels]
    shape = tf.shape(images)
    
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]
    channels = shape[3]
    
    # Sample lambda once per batch; the box side is sqrt(1 - lambda) times the
    # image side, so every sample in the batch gets the same box size.
    beta = tfp.distributions.Beta(alpha, alpha)
    lambda_val = beta.sample(1)

    cut_rat = tf.sqrt(1. - lambda_val)
    
    cut_w = tf.cast(width, tf.float32) * cut_rat
    cut_w = tf.cast(cut_w, tf.int32)  # Now, cut_w is an int32 tensor.
    
    cut_h = tf.cast(height, tf.float32) * cut_rat
    cut_h = tf.cast(cut_h, tf.int32)  # Now, cut_h is an int32 tensor.
    
    # Uniformly sample the center of the patch (one centre per sample)
    cx = tf.random.uniform([batch_size], minval=0, maxval=width, dtype=tf.int32)
    cy = tf.random.uniform([batch_size], minval=0, maxval=height, dtype=tf.int32)
    
    # Calculate the patch coordinates, clipped to the image bounds (boxes near
    # the border are therefore smaller than cut_w x cut_h)
    bbx1 = tf.clip_by_value(cx - cut_w // 2, 0, width)
    bby1 = tf.clip_by_value(cy - cut_h // 2, 0, height)
    bbx2 = tf.clip_by_value(cx + cut_w // 2, 0, width)
    bby2 = tf.clip_by_value(cy + cut_h // 2, 0, height)
    
    # Box mask (1 inside the box) and the box's share of the image area
    mask, ratio = create_cutmix_mask(bbx1, bby1, bbx2, bby2, height, width, channels, batch_size)
    # Random partner for every sample (a permutation, so it may be the sample itself)
    indices = tf.random.shuffle(tf.range(batch_size))
    mixed_images = images * mask + tf.gather(images, indices) * (1 - mask)

    # Mix the labels with the same area weights as the pixels
    mixed_labels = labels * ratio + tf.gather(labels, indices) * (1 - ratio)
    
    return mixed_images, mixed_labels

## Convert Audio to Spectrogram

In [None]:
# Compute Spectrogram from audio
@tf.function
def Audio2Spec(audio, spec_shape=CFG.img_size, sr=CFG.sample_rate, nfft=CFG.nfft, window=CFG.window, fmin=0, fmax=8000):
    """
    Computes a Mel-scaled spectrogram from audio using TensorFlow and TensorFlow-IO.

    The STFT stride is CFG.hop_length. The result is a (mel, time) dB spectrogram
    (top_db=80) with frames beyond spec_shape[1] dropped, reshaped to `spec_shape`.
    """
    n_mels, n_frames = spec_shape[0], spec_shape[1]

    # Power spectrogram -> mel scale -> decibels, all with TensorFlow-IO
    power = tfio.audio.spectrogram(audio, nfft=nfft, window=window, stride=CFG.hop_length)
    mel = tfio.audio.melscale(power, rate=sr, mels=n_mels, fmin=fmin, fmax=fmax)
    mel_db = tfio.audio.dbscale(mel, top_db=80)

    # (time, mel) -> (mel, time); slicing past the end is a no-op, so this only
    # trims when there are more than n_frames frames
    mel_db = tf.transpose(mel_db, perm=[1, 0])[:, :n_frames]

    return tf.reshape(mel_db, spec_shape)

## Audio and Spectrogram Decoders

In [None]:
def audio_decoder(path, label=None, with_labels=True, dim=CFG.audio_len, CFG=CFG):
    """Decode a WAV file into a fixed-length float waveform (plus a one-hot label).

    Meant to run inside tf.py_function: when `path` is not a `str` it is treated
    as an eager tensor and its `.numpy()` value is used (likewise `label`).

    Returns:
        audio of shape [dim], or (audio, one_hot_label) when `with_labels` is True.
    """
    def get_audio(filepath):
        # WAV bytes -> int16 samples -> float32; squeezing the last axis assumes mono
        raw = tfio.audio.decode_wav(tf.io.read_file(filepath), dtype=tf.int16)
        samples = tf.squeeze(tf.cast(raw, tf.float32), axis=-1)
        return Normalize(samples) if CFG.normalize else samples

    def get_target(target):
        n_classes = len(CFG.class_labels)
        one_hot = tf.one_hot(tf.reshape(target, [1]), n_classes)
        return tf.reshape(tf.cast(one_hot, tf.float32), [n_classes])

    def decode(path):
        # Random crop / pad to exactly `dim` samples
        fixed = CropOrPad(get_audio(path), dim)
        return tf.reshape(fixed, [dim])

    def decode_with_labels(path, label):
        target = get_target(label)
        return decode(path), target

    if type(path) != str:
        path = path.numpy()
        if with_labels:
            label = label.numpy()
    return decode_with_labels(path, label) if with_labels else decode(path)


def spec_decoder(with_labels=True, dim=CFG.img_size, CFG=CFG):
    """Return a map function that turns a waveform into a (H, W, 3) spectrogram image.

    With `with_labels`, the returned function takes (audio, label) and passes the label through.
    """
    def decode(audio):
        mel_db = Audio2Spec(audio, spec_shape=dim, sr=CFG.sample_rate,
                            nfft=CFG.nfft, window=CFG.window, fmin=CFG.fmin, fmax=CFG.fmax)
        # (H, W) -> (H, W, 3), with a static shape for downstream layers
        return tf.reshape(Spec2Img(mel_db, num_channels=3), [*dim, 3])

    def decode_with_labels(audio, label):
        return decode(audio), label

    if with_labels:
        return decode_with_labels
    return decode

## Augmenters

Decide whether to apply augmentation to a given sample

In [None]:
def audio_augmenter(with_labels=True, dim=CFG.audio_len, CFG=CFG):
    """Return a map function that runs the NumPy `AudioAug` pipeline on a waveform.

    The augmentation runs eagerly through tf.py_function and is applied with
    probability CFG.audio_augment_prob; the output's static shape is set to [dim].
    """
    def _augment_numpy(audio_tensor):
        samples = audio_tensor.numpy()
        if random.random() <= CFG.audio_augment_prob:
            samples = AudioAug(samples)
        return np.array(samples, dtype=np.float32)

    def augment(audio, dim=dim):
        out = tf.py_function(_augment_numpy, [audio], tf.float32)
        out.set_shape([dim])  # py_function output has no static shape
        return out

    def augment_with_labels(audio, label):
        return augment(audio), label

    return augment_with_labels if with_labels else augment


def spec_augmenter(with_labels=True, dim=CFG.img_size, CFG=CFG):
    """Return a map function applying SpecAug with probability CFG.spec_augment_prob.

    The output is reshaped to (dim[0], dim[1], 3).
    """
    def augment(spec, dim=dim):
        if random_float() <= CFG.spec_augment_prob:
            spec = SpecAug(spec)
        return tf.reshape(spec, [*dim, 3])

    def augment_with_labels(spec, label):
        return augment(spec), label

    if not with_labels:
        return augment
    return augment_with_labels

# Data Pipeline 

## Specify data pipeline

In [None]:
def build_dataset(paths, labels=None, batch_size=32, target_size=CFG.img_size, 
                  audio_decode_fn=None, audio_augment_fn=None, 
                  spec_decode_fn=None, spec_augment_fn=None,
                  cache=True, cache_dir="",drop_remainder=False,
                  augment=True, repeat=True, shuffle=100):
    """
    Creates a TensorFlow dataset from the given paths and labels.

    Pipeline: decode audio -> (cache) -> (repeat) -> (shuffle) -> [audio augment]
    -> spectrogram -> [spec augment] -> batch -> [MixUp, CutMix] -> prefetch.
    Bracketed steps only run when `augment` is True; MixUp/CutMix additionally
    require labels, because they mix labels as well as images.

    Args:
        paths (list): A list of file paths to the audio files.
        labels (list): A list of corresponding labels for the audio files, or None.
        batch_size (int): Batch size for the created dataset.
        target_size (list): Currently unused; spectrogram size comes from CFG.img_size.
        audio_decode_fn (function): A function to decode the audio file.
        audio_augment_fn (function): A function to augment the audio file.
        spec_decode_fn (function): A function to decode the spectrogram.
        spec_augment_fn (function): A function to augment the spectrogram.
        cache (bool): Whether to cache the dataset or not.
        cache_dir (str): Directory path to cache the dataset.
        drop_remainder (bool): Whether to drop the last batch if it is smaller than batch_size.
        augment (bool): Whether to augment the dataset or not.
        repeat (bool): Whether to repeat the dataset or not.
        shuffle (int): Number of elements from the dataset to buffer for shuffling.

    Returns:
        ds (tf.data.Dataset): A TensorFlow dataset.
    """
    has_labels = labels is not None

    # Create cache directory if cache is enabled
    if cache_dir != "" and cache is True:
        os.makedirs(cache_dir, exist_ok=True)

    # Default decoder: tell audio_decoder whether labels are present. It defaults
    # to with_labels=True, so calling it with a path alone would fail on label.numpy().
    if audio_decode_fn is None:
        audio_decode_fn = lambda *args: audio_decoder(*args, with_labels=has_labels)

    if audio_augment_fn is None:
        audio_augment_fn = audio_augmenter(has_labels, dim=CFG.audio_len, CFG=CFG)

    if spec_decode_fn is None:
        spec_decode_fn = spec_decoder(has_labels, dim=CFG.img_size, CFG=CFG)

    if spec_augment_fn is None:
        spec_augment_fn = spec_augmenter(has_labels, dim=CFG.img_size, CFG=CFG)

    AUTO = tf.data.experimental.AUTOTUNE

    ds = tf.data.Dataset.from_tensor_slices((paths, labels) if has_labels else paths)

    # tf.py_function drops static shapes; restore them so batching and Keras see
    # [audio_len] waveforms and [num_classes] one-hot labels.
    def decode_unlabeled(x):
        audio = tf.py_function(audio_decode_fn, [x], tf.float32)
        audio.set_shape([CFG.audio_len])
        return audio

    def decode_labeled(x, y):
        audio, label = tf.py_function(audio_decode_fn, [x, y], [tf.float32, tf.float32])
        audio.set_shape([CFG.audio_len])
        label.set_shape([CFG.num_classes])
        return audio, label

    ds = ds.map(decode_labeled if has_labels else decode_unlabeled, num_parallel_calls=AUTO)

    ds = ds.cache(cache_dir) if cache else ds

    ds = ds.repeat() if repeat else ds

    opt = tf.data.Options()

    if shuffle:
        ds = ds.shuffle(shuffle, seed=CFG.seed)
        opt.experimental_deterministic = False

    if CFG.device == 'GPU':
        opt.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

    ds = ds.with_options(opt)

    if augment:
        ds = ds.map(audio_augment_fn, num_parallel_calls=AUTO)

    ds = ds.map(spec_decode_fn, num_parallel_calls=AUTO)

    if augment:
        ds = ds.map(spec_augment_fn, num_parallel_calls=AUTO)

    ds = ds.batch(batch_size, drop_remainder=drop_remainder)

    # Batch-level MixUp / CutMix take (images, labels), so they need labels
    if augment and has_labels:
        ds = ds.map(mixup_image_aug, num_parallel_calls=AUTO)
        ds = ds.map(
            lambda images, batch_labels: cutmix(images, batch_labels,
                                                probability=CFG.cutmix_prob,
                                                alpha=CFG.cutmix_alpha),
            num_parallel_calls=AUTO)

    ds = ds.prefetch(AUTO)
    return ds

## Utils for Visualization of Augmentation

In [None]:
def plot_batch(batch, row=3, col=3, filename="batch_plot.png"):
    """Plot one batch of spectrogram images and save the figure.

    Args:
        batch: an image batch (B, H, W, C), or an (images, one-hot labels) pair/list.
        row, col: subplot grid size; at most min(row * col, B) panels are drawn.
        filename: PNG file name, saved under CFG.DESTINATION_PATH.
    """
    if isinstance(batch, (tuple, list)):
        imgs, tars = batch
    else:
        imgs, tars = batch, None

    plt.figure(figsize=(col*5, row*3))
    # Never index past the end of a batch that is smaller than the grid
    n_panels = min(row * col, len(imgs))
    for idx in range(n_panels):
        plt.subplot(row, col, idx+1)
        lid.specshow(imgs[idx][...,0].numpy(),
                     sr = CFG.sample_rate,
                     hop_length = CFG.hop_length,
                     fmin=CFG.fmin,
                     fmax=CFG.fmax,
                     x_axis = 'time',
                     y_axis = 'mel',
                     cmap = 'coolwarm')
        if tars is not None:
            label = tars[idx].numpy().argmax()
            plt.title(CFG.label2name[label])
    plt.tight_layout()
    # Save under the configured output folder (portable, and created on first use)
    os.makedirs(CFG.DESTINATION_PATH, exist_ok=True)
    plt.savefig(os.path.join(CFG.DESTINATION_PATH, filename), dpi=300, bbox_inches='tight')
    plt.show()
    
    
def plot_history(history, fold=None):
    """Plot training history, credit: @cdeotte.

    Draws train/val AUC on the left axis and train/val loss on a twin right axis,
    marking the best validation AUC and the lowest validation loss.

    Args:
        history: Keras History whose `.history` has 'auc', 'val_auc', 'loss', 'val_loss'.
        fold: zero-based fold index shown in the title. If None, a notebook-global
            `fold` is used when one is defined; otherwise the title has no fold number.
    """
    if fold is None:
        # Fallback for callers that set a global `fold` (this function used to
        # read it unconditionally, raising NameError when it was undefined).
        fold = globals().get('fold')

    hist = history.history
    epoch_idx = np.arange(len(hist['auc']))

    plt.figure(figsize=(15,5))
    plt.plot(epoch_idx, hist['auc'], '-o', label='Train AUC', color='#ff7f0e')
    plt.plot(epoch_idx, hist['val_auc'], '-o', label='Val AUC', color='#1f77b4')
    x = np.argmax(hist['val_auc'])
    y = np.max(hist['val_auc'])
    xdist = plt.xlim()[1] - plt.xlim()[0]
    ydist = plt.ylim()[1] - plt.ylim()[0]
    plt.scatter(x, y, s=200, color='#1f77b4')
    plt.text(x-0.03*xdist, y-0.13*ydist, 'max auc\n%.2f'%y, size=14)
    plt.ylabel('AUC (PR)', size=14)
    plt.xlabel('Epoch', size=14)
    plt.legend(loc=2)

    # Loss on a twin y-axis sharing the epoch axis
    plt2 = plt.gca().twinx()
    plt2.plot(epoch_idx, hist['loss'], '-o', label='Train Loss', color='#2ca02c')
    plt2.plot(epoch_idx, hist['val_loss'], '-o', label='Val Loss', color='#d62728')
    x = np.argmin(hist['val_loss'])
    y = np.min(hist['val_loss'])
    ydist = plt.ylim()[1] - plt.ylim()[0]
    plt.scatter(x, y, s=200, color='#d62728')
    plt.text(x-0.03*xdist, y+0.05*ydist, 'min loss', size=14)
    plt.ylabel('Loss', size=14)

    title = 'Training Plot' if fold is None else 'Fold %i - Training Plot' % (fold + 1)
    plt.title(title, size=18)
    plt.legend(loc=3)
    plt.show()

### Visualize not augmented batch

In [None]:
# Build an un-augmented, unshuffled dataset over the whole table and plot its first batch.
ds = build_dataset(df.filepath.tolist(), df.target.tolist(), augment=False, cache=False, shuffle=None)
ds = ds.take(100)  # only the first batch is consumed below
imgs, labels = next(iter(ds))
plot_batch((imgs, labels), row=3, col=4, filename="no_augmentation.png")

### Visualize augmented batch

In [None]:
# Same as above but with augment=True. NOTE: every augmentation enabled in CFG is
# applied (audio, spectrogram, MixUp, CutMix), not only time shift as the file name suggests.
ds = build_dataset(df.filepath.tolist(), df.target.tolist(), augment=True, cache=False, shuffle=None)
ds = ds.take(100)  # only the first batch is consumed below
imgs, labels = next(iter(ds))
plot_batch((imgs, labels), row=3, col=4, filename="timeshift_augmentation.png")

# Modelling utils

## Define metrics, loss, optimizers

In [None]:
import sklearn.metrics

def get_metrics():
    """Training metrics: categorical accuracy ('acc') and area under the PR curve ('auc')."""
    pr_auc = tf.keras.metrics.AUC(curve='PR', name='auc', multi_label=False)
    accuracy = tf.keras.metrics.CategoricalAccuracy(name='acc')
    return [accuracy, pr_auc]

def padded_cmap(y_true, y_pred, padding_factor=5):
    """Macro-averaged average precision after appending `padding_factor` all-ones rows.

    The padding gives every class at least one positive (predicted perfectly), so the
    per-class AP is always defined; it also nudges the score upwards.
    """
    n_classes = y_true.shape[1]
    ones = np.array([[1] * n_classes] * padding_factor)
    return sklearn.metrics.average_precision_score(
        np.concatenate([y_true, ones]),
        np.concatenate([y_pred, ones]),
        average='macro',
    )

def get_loss():
    """Return the Keras loss selected by CFG.loss, with CFG.label_smoothing applied.

    'CCE' -> CategoricalCrossentropy, 'BCE' -> BinaryCrossentropy (the two options
    listed next to CFG.loss).

    Raises:
        ValueError: for any other CFG.loss value.
    """
    if CFG.loss == "CCE":
        loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=CFG.label_smoothing)
    elif CFG.loss == "BCE":
        loss = tf.keras.losses.BinaryCrossentropy(label_smoothing=CFG.label_smoothing)
    else:
        raise ValueError("Loss not found")
    return loss
    
def get_optimizer():
    """Return the optimizer named by CFG.optimizer at learning rate CFG.lr.

    Only 'Adam' is implemented; any other name raises ValueError.
    """
    if CFG.optimizer != "Adam":
        raise ValueError("Optmizer not found")
    return tf.keras.optimizers.Adam(learning_rate=CFG.lr)

## Train-Val-Test Fit function

## Save model and history

In [None]:
from datetime import datetime
import pickle
MODELS_PATH = CFG.DESTINATION_PATH + "Models"

def save_model_and_history(model, history):
    """Save `model` (model.h5) and `history.history` (history.pkl) in a new run folder.

    The folder is MODELS_PATH/<YYYY-MM-DD_HH-MM-SS-ffffff>. These names sort
    chronologically, which load_model_and_history relies on to find the newest run.
    """
    # str(datetime.now()) contains ':' (illegal in Windows paths), so format explicitly
    stamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")
    path = os.path.join(MODELS_PATH, stamp)

    # makedirs also creates MODELS_PATH on the first save (os.mkdir would fail there)
    os.makedirs(path)

    model.save(os.path.join(path, "model.h5"))

    with open(os.path.join(path, "history.pkl"), "wb") as f:
        pickle.dump(history.history, f)

## Load model and history

In [None]:
def load_model_and_history(newness_number=1, model_directory=None):
    """Load a saved Keras model and its pickled training-history dict from MODELS_PATH.

    Args:
        newness_number: 1 = last run folder in sorted order, 2 = the one before, ...
        model_directory: explicit run folder name; takes precedence when truthy.

    Returns:
        (model, history) where `history` is the unpickled history.pkl contents.
    """
    run_dirs = sorted(os.listdir(MODELS_PATH))
    chosen = model_directory if model_directory else run_dirs[-newness_number]
    model_dir = os.path.join(MODELS_PATH, chosen)

    model = tf.keras.models.load_model(os.path.join(model_dir, "model.h5"))

    # history.pkl is a pickle, not JSON. Unpickling can execute arbitrary code,
    # so only load run folders you created yourself.
    with open(os.path.join(model_dir, "history.pkl"), 'rb') as fh:
        history = pickle.load(fh)

    return model, history

## Plot learning curve

In [None]:
def _plot_learning_curve(hist, key, val_key, label, val_label, ylabel):
    """Plot hist[key] and hist[val_key] per epoch on one 'Learning curve' figure."""
    plt.plot(hist[key], label=label)
    plt.plot(hist[val_key], label=val_label)
    plt.xlabel("Epochs")
    plt.ylabel(ylabel)
    plt.title("Learning curve")
    plt.legend()
    plt.show()


def display_accuracy_curve(hist):
    """Plot training vs validation accuracy from a history dict ('acc' / 'val_acc')."""
    _plot_learning_curve(hist, "acc", "val_acc", "accuracy", "val_accuracy", "Accuracy")


def display_loss_curve(hist):
    """Plot training vs validation loss from a history dict ('loss' / 'val_loss')."""
    _plot_learning_curve(hist, "loss", "val_loss", "loss", "val_loss", "Loss")


def display_auc_curve(hist):
    """Plot training vs validation PR-AUC from a history dict ('auc' / 'val_auc')."""
    _plot_learning_curve(hist, "auc", "val_auc", "auc", "val_auc", "AUC")

## Plot confusion matrix

In [None]:
import itertools
from sklearn.metrics import confusion_matrix

def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
    """Draw confusion matrix `cm` with per-cell values, save it as confusion_matrix.png and show it.

    Rows are true labels and columns predicted labels. With normalize=True each
    row is divided by its sum.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    positions = np.arange(len(classes))
    plt.xticks(positions, classes, rotation=45)
    plt.yticks(positions, classes)

    # Annotate every cell; white text on dark cells, black on light ones
    cell_fmt = '.2f' if normalize else 'd'
    threshold = cm.max() / 2.
    n_rows, n_cols = cm.shape
    for r in range(n_rows):
        for c in range(n_cols):
            value = cm[r, c]
            plt.text(c, r, format(value, cell_fmt),
                     horizontalalignment="center",
                     color="white" if value > threshold else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig("confusion_matrix.png")

    plt.show()

## Predict genres

In [None]:
def predict_genres(model, paths, batch_size=1):
    """Predict a genre class index for every audio file in ``paths``.

    Args:
        model: Trained Keras classifier returning per-class scores.
        paths: Sequence (list or numpy array) of audio file paths.
        batch_size: Inference batch size passed to ``build_dataset``
            (default 1); larger values are typically faster.

    Returns:
        np.ndarray of predicted class indices, one per path, in input order
        (the dataset is built with ``shuffle=False``). Empty input gives an
        empty integer array.
    """
    paths = np.asarray(paths)
    if len(paths) == 0:
        return np.array([], dtype=int)

    # build_dataset expects labels; these placeholders are not used for inference.
    fake_labels = np.zeros(len(paths), dtype=int)
    ds = build_dataset(paths, fake_labels,
                       batch_size=batch_size, cache=False, shuffle=False,
                       augment=False, repeat=False, drop_remainder=False)

    preds = model.predict(ds)
    return np.argmax(preds, axis=1)

## Log test dataset results

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def log_test_metrics(test_labels, preds, prefix="Test"):
    """Compute classification metrics, log them to W&B and print them.

    Args:
        test_labels: Ground-truth class indices (y_true). Order matters: the
            weighted averages use the class supports of y_true.
        preds: Predicted class indices (y_pred).
        prefix: Prefix of the W&B metric names (default "Test", giving e.g.
            "Test Accuracy"); use another prefix such as "Validation" so
            non-test results are not logged as test metrics.
    """
    accuracy = accuracy_score(test_labels, preds)
    # zero_division=0 returns the same value sklearn uses by default for
    # classes that are never predicted, without the UndefinedMetricWarning.
    precision = precision_score(test_labels, preds, average='weighted', zero_division=0)
    recall = recall_score(test_labels, preds, average='weighted', zero_division=0)
    f1 = f1_score(test_labels, preds, average='weighted', zero_division=0)

    wandb.log({
        f"{prefix} Accuracy": accuracy,
        f"{prefix} Precision": precision,
        f"{prefix} Recall": recall,
        f"{prefix} F1 Score": f1
    })

    print(f"Accuracy: {accuracy} Precision: {precision} Recall: {recall} F1: {f1}")

# EfficientNet

## Build model

In [None]:

from tensorflow.keras.applications.efficientnet import EfficientNetB0

def build_model(CFG, compile_model=True):
    """Create an EfficientNetB0-based genre classifier.

    An ImageNet-pretrained EfficientNetB0 without its top is used as a nested
    backbone. It is followed by global average pooling, Dense(32, relu),
    batch norm, dropout 0.5 and a softmax layer with ``len(CFG.class_names)``
    units. Height and width are left unspecified (``None``); inputs have 3
    channels. When ``compile_model`` is true, the model is compiled with
    ``get_optimizer()``, ``get_loss()`` and ``get_metrics()``.
    """
    spatial = (None, None)

    backbone = tf.keras.applications.EfficientNetB0(
        include_top=False,
        weights="imagenet",
        input_shape=(*spatial, 3),
    )

    inputs = tf.keras.layers.Input(shape=(*spatial, 3))

    # Classification head. Layers are instantiated in the order they are
    # applied, so Keras' auto-generated layer names are unaffected.
    head = [
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(32, activation='relu'),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(len(CFG.class_names), activation='softmax'),
    ]

    features = backbone(inputs)
    for layer in head:
        features = layer(features)

    model = tf.keras.Model(inputs=inputs, outputs=features)
    if not compile_model:
        return model

    optimizer = get_optimizer()
    loss_fn = get_loss()
    metric_list = get_metrics()
    model.compile(optimizer=optimizer, loss=loss_fn, metrics=metric_list)
    return model

In [None]:
# Build the EfficientNet model and list its top-level layers.
model = build_model(CFG)
base_model_layer_names = [layer.name for layer in model.layers]

for layer_name in base_model_layer_names:
    print(layer_name)

# Run one forward pass in inference mode to check the output shape.
# NOTE(review): `imgs` is not defined in this part of the notebook; presumably
# a sample batch of spectrogram images created in an earlier cell -- confirm.
out = model(imgs, training=False)
print(out.shape)

# summary() prints the table itself and returns None; wrapping it in print()
# only added a stray "None" line.
model.summary()

# ResNet (flattened, non-nested variant)

In [None]:

from tensorflow.keras.applications.efficientnet import EfficientNetB0



import tensorflow as tf
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model

def build_flattened_model_resnet(CFG, compile_model=True):
    """Build a ResNet50 classifier whose head is attached to the backbone's graph.

    The head is wired directly onto ``backbone.output`` and the model is
    created from ``backbone.input``, so the ResNet50 layers become layers of
    the returned model instead of one nested sub-model. The head is global
    average pooling plus a softmax Dense layer named ``output_layer`` with
    ``len(CFG.class_names)`` units. All backbone layers are marked trainable.
    When ``compile_model`` is true, the model is compiled with
    ``get_optimizer()``, ``get_loss()`` and ``get_metrics()``.
    """
    spatial = (None, None)
    backbone = ResNet50(weights='imagenet', include_top=False, input_shape=(*spatial, 3))

    # Fine-tune the whole backbone.
    for backbone_layer in backbone.layers:
        backbone_layer.trainable = True

    pooled = GlobalAveragePooling2D()(backbone.output)
    probabilities = Dense(len(CFG.class_names), activation='softmax', name='output_layer')(pooled)
    model = Model(inputs=backbone.input, outputs=probabilities)

    if compile_model:
        # Keyword arguments are evaluated left to right: optimizer, loss, metrics.
        model.compile(optimizer=get_optimizer(), loss=get_loss(), metrics=get_metrics())

    return model

In [None]:
# Build the flattened (non-nested) ResNet50 variant and print its layer summary.
m = build_flattened_model_resnet(CFG)
m.summary()

## Check output shape

## Define the train/validation fit function

In [None]:
from wandb.keras import WandbCallback


def train_val_test_fit(model, df, CFG, model_name="Model"):
    """Train ``model`` on the 'train' split of ``df`` and validate on 'val'.

    The epoch with the best ``val_acc`` is checkpointed to
    ``training_save.keras`` and reloaded from that same path before returning.

    Args:
        model: Compiled Keras model.
        df: DataFrame with ``split``, ``filepath`` and ``target`` columns.
        CFG: Configuration providing batch_size, augment, drop_remainder,
            epochs, img_size and scheduler.
        model_name: Name shown in the training banner.

    Returns:
        (best_model, history): the reloaded best checkpoint and the Keras
        ``History`` returned by ``fit``.
    """
    checkpoint_path = 'training_save.keras'

    train_df = df.query("split == 'train'").reset_index(drop=True)
    valid_df = df.query("split=='val'").reset_index(drop=True)

    train_paths, train_labels = train_df.filepath.values, train_df.target.values
    valid_paths, valid_labels = valid_df.filepath.values, valid_df.target.values

    # Shuffle training paths and labels with the same permutation.
    index = np.arange(len(train_paths))
    np.random.shuffle(index)
    train_paths = train_paths[index]
    train_labels = train_labels[index]

    train_ds = build_dataset(train_paths, train_labels,
                             batch_size=CFG.batch_size, cache=True, shuffle=True,
                             augment=CFG.augment, drop_remainder=CFG.drop_remainder)
    valid_ds = build_dataset(valid_paths, valid_labels,
                             batch_size=CFG.batch_size, cache=True, shuffle=False,
                             augment=False, repeat=False, drop_remainder=CFG.drop_remainder)

    print('#'*25); print('#### FOLD')
    print('#### Image Size: (%i, %i) | Model: %s | Batch Size: %i | Scheduler: %s'%
          (*CFG.img_size, model_name, CFG.batch_size, CFG.scheduler))
    print('#### Num Train: {:,} | Num Valid: {:,}'.format(len(train_paths), len(valid_paths)))

    # Keep only the best epoch (by validation accuracy).
    sv = tf.keras.callbacks.ModelCheckpoint(
        checkpoint_path, monitor='val_acc', verbose=0, save_best_only=True,
        save_weights_only=False, mode='max', save_freq='epoch')
    callbacks = [sv]  # OPTIONALLY: WandbCallback(generator=valid_ds)

    print('# Training')
    history = model.fit(
        train_ds,
        epochs=CFG.epochs,
        callbacks=callbacks,
        steps_per_epoch=len(train_paths)//CFG.batch_size,
        validation_data=valid_ds,
        # verbose=CFG.verbose,
    )

    # Reload from the exact path the checkpoint callback wrote to; the previous
    # hard-coded 'working\\training_save.keras' pointed somewhere else.
    model = tf.keras.models.load_model(checkpoint_path)

    return model, history

## Fit model on the train/validation split

In [None]:
 # # Clear the session, build and train the model
K.clear_session()  # reset Keras global state (e.g. auto-generated layer names)
model, history = train_val_test_fit(
    build_flattened_model_resnet(CFG), df, CFG, model_name="ResNet"
)

## Save last model and history

In [None]:
# Persist the trained model together with its training history, as this
# section's heading says. This replaces a load of
# '/kaggle/working/training_save.h5', a file the training function does not
# write: it checkpoints to 'training_save.keras'.
save_model_and_history(model, history)

## Load last model and learning history

In [None]:
# model, hist = load_model_and_history(model_directory="2023-05-24 14:12:31.673830")

## Plot learning curves

### Accuracy

In [None]:
# display_accuracy_curve(hist)

### Loss

In [None]:
# display_loss_curve(hist)

### AUC

In [None]:
# display_auc_curve(hist)

## Predict on test set

### Load test dataset

In [None]:
# Rows of the held-out test split, re-indexed from 0.
test_df = df.query("split=='test'").reset_index(drop=True)
test_paths, test_labels = test_df["filepath"].values, test_df["target"].values

# Deterministic pipeline (batch of 1, no shuffling or augmentation) so the
# element order matches test_labels.
test_ds = build_dataset(
    test_paths,
    test_labels,
    batch_size=1,
    cache=False,
    shuffle=False,
    augment=False,
    repeat=False,
    drop_remainder=False,
)

### Predict

In [None]:
# Predict a class index for every test file, then log metrics against the
# ground truth (argument order: y_true first, then y_pred).
pred_labels = predict_genres(model, test_paths)

log_test_metrics(test_labels, pred_labels)

#### Accuracy score

#### Confusion matrix

In [None]:
# Pass the label order explicitly so rows/columns line up with the class-name
# ticks even if a class is absent from both y_true and y_pred. The result is
# named test_cm so that re-running this cell does not overwrite the global
# `cm`, which the saliency-map cells import as matplotlib.cm.
test_cm = confusion_matrix(test_labels, pred_labels, labels=list(CFG.labelsMapping.keys()))
classes = list(CFG.labelsMapping.values())

plot_confusion_matrix(test_cm, classes)

wandb.log({"Top 10 Confusion Matrix": wandb.Image("confusion_matrix.png")})

In [None]:
# Close the active W&B run so the logged metrics and images are flushed.
wandb.finish()

# ResNet

## Build model

In [None]:
# import efficientnet.tfkeras as efn
from tensorflow.keras.applications.resnet50 import ResNet50

def build_model_resnet(CFG, compile_model=True):
    """Create a ResNet50-based genre classifier.

    The ImageNet-pretrained ResNet50 (without its top) is applied as a single
    nested layer. It is followed by global average pooling and a softmax Dense
    layer with ``len(CFG.class_names)`` units. Height and width are ``None``
    and inputs have 3 channels. When ``compile_model`` is true, the model is
    compiled with ``get_optimizer()``, ``get_loss()`` and ``get_metrics()``.
    """
    input_shape = (None, None, 3)

    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=input_shape,
    )

    inputs = tf.keras.layers.Input(shape=input_shape)
    pooling = tf.keras.layers.GlobalAveragePooling2D()
    classifier = tf.keras.layers.Dense(len(CFG.class_names), activation='softmax')

    outputs = classifier(pooling(backbone(inputs)))
    model = tf.keras.Model(inputs=inputs, outputs=outputs)

    if compile_model:
        # Keyword arguments are evaluated left to right: optimizer, loss, metrics.
        model.compile(optimizer=get_optimizer(), loss=get_loss(), metrics=get_metrics())

    return model

In [None]:
def build_model_resnet_with_added_conv(CFG, compile_model=True, conv_filters=128,
                                       conv_kernel_size=(3, 3)):
    """Build a ResNet50 classifier with one extra Conv2D layer before pooling.

    The architecture is: ImageNet-pretrained ResNet50 (no top, nested) ->
    Conv2D(``conv_filters``, ``conv_kernel_size``, relu, same padding) ->
    global average pooling -> softmax Dense with ``len(CFG.class_names)`` units.

    Args:
        CFG: Configuration; ``CFG.class_names`` sets the number of classes.
        compile_model: If true, compile with ``get_optimizer()``,
            ``get_loss()`` and ``get_metrics()``.
        conv_filters: Number of filters of the added conv layer (default 128).
        conv_kernel_size: Kernel size of the added conv layer (default (3, 3)).

    Returns:
        The (optionally compiled) ``tf.keras.Model``.
    """
    DIM = (None, None)

    # Base model - Resnet50
    base = tf.keras.applications.ResNet50(
        include_top=False,
        weights="imagenet",
        input_shape=(*DIM, 3),
    )

    inp = tf.keras.layers.Input(shape=(*DIM, 3))
    out = base(inp)

    # Added conv layer. It is left unnamed, so Keras auto-names it
    # (presumably "conv2d" in a fresh session, the name the saliency-map
    # cells look up) -- keep it unnamed to preserve that.
    out = tf.keras.layers.Conv2D(conv_filters, kernel_size=conv_kernel_size,
                                 activation="relu", padding="same")(out)

    out = tf.keras.layers.GlobalAveragePooling2D()(out)
    out = tf.keras.layers.Dense(len(CFG.class_names), activation='softmax')(out)

    model = tf.keras.Model(inputs=inp, outputs=out)
    if compile_model:
        opt = get_optimizer()
        loss = get_loss()
        metrics = get_metrics()
        model.compile(optimizer=opt,
                      loss=loss,
                      metrics=metrics)

    return model

In [None]:
# Build (and compile) the ResNet50 variant with the extra conv layer.
m = build_model_resnet_with_added_conv(CFG)

In [None]:
# Print the layer table of the model built in the previous cell.
m.summary()

## Fit model

In [None]:
# Reset Keras state, then build and train the nested ResNet50 model.
K.clear_session()
model = build_model_resnet(CFG)
model, history = train_val_test_fit(model, df, CFG, model_name="ResNet")

## Save model

In [None]:
# Persist the trained ResNet model together with its training history.
save_model_and_history(model, history)

## Load model

In [None]:
# Reload a previously saved run. NOTE(review): the directory name is a
# hard-coded timestamp from an earlier session; update it for new runs.
model_resnet, hist_resnet = load_model_and_history(model_directory="2023-05-24 16:00:34.599873")

In [None]:
# Print the layer table of the reloaded ResNet model.
model_resnet.summary()

In [None]:
# List the layer names inside the nested 'resnet50' backbone sub-model.
for elem in model_resnet.get_layer('resnet50').layers:
    print(elem.name)

In [None]:
# Inspect the symbolic output of the backbone's last block ('conv5_block3_out').
model_resnet.get_layer('resnet50').get_layer('conv5_block3_out').output

## Plot learning curves

### Accuracy

In [None]:
# Training vs. validation accuracy per epoch.
display_accuracy_curve(hist_resnet)

### Loss

In [None]:
# Training vs. validation loss per epoch.
display_loss_curve(hist_resnet)

### AUC

In [None]:
# Training vs. validation AUC per epoch.
display_auc_curve(hist_resnet)

## Predict on test set

### Load test dataset

In [None]:
# Rows of the validation split, re-indexed from 0.
val_df = df.query("split=='val'").reset_index(drop=True)
val_paths, val_labels = val_df["filepath"].values, val_df["target"].values

# Deterministic pipeline (batch of 1, no shuffling or augmentation) so the
# element order matches val_labels.
val_ds = build_dataset(
    val_paths,
    val_labels,
    batch_size=1,
    cache=False,
    shuffle=False,
    augment=False,
    repeat=False,
    drop_remainder=False,
)

In [None]:
# Predict a class index for every validation file (uses `model`, the most
# recently trained model, not `model_resnet`).
pred_val_labels = predict_genres(model, val_paths)

In [None]:
# Arguments are (y_true, y_pred): ground truth must come first, because the
# weighted precision/recall/F1 depend on which array is the reference.
# NOTE(review): these validation metrics are logged under the "Test ..." keys.
log_test_metrics(val_labels, pred_val_labels)

In [None]:
# Rows of the held-out test split, re-indexed from 0.
test_df = df.query("split=='test'").reset_index(drop=True)
test_paths, test_labels = test_df["filepath"].values, test_df["target"].values

# Deterministic pipeline (batch of 1, no shuffling or augmentation) so the
# element order matches test_labels.
test_ds = build_dataset(
    test_paths,
    test_labels,
    batch_size=1,
    cache=False,
    shuffle=False,
    augment=False,
    repeat=False,
    drop_remainder=False,
)

### Predict

In [None]:
# Predict a class index for every test file, then log metrics against the
# ground truth (argument order: y_true first, then y_pred).
pred_labels = predict_genres(model, test_paths)

log_test_metrics(test_labels, pred_labels)

#### Accuracy score

In [None]:
from sklearn.metrics import accuracy_score

# Plain test accuracy, shown as the cell output.
accuracy_score(test_labels, pred_labels)

#### Confusion matrix

In [None]:
# Pass the label order explicitly so rows/columns line up with the class-name
# ticks even if a class is absent from both y_true and y_pred. The result is
# named test_cm so that this cell does not overwrite the global `cm`, which
# the saliency-map cells import as matplotlib.cm.
test_cm = confusion_matrix(test_labels, pred_labels, labels=list(CFG.labelsMapping.keys()))
classes = list(CFG.labelsMapping.values())

plot_confusion_matrix(test_cm, classes)

wandb.log({"Top 10 Confusion Matrix": wandb.Image("confusion_matrix.png")})

# Saliency maps



## Get last layer names

In [None]:
# Layer used for Grad-CAM/CAM. NOTE(review): "conv2d" presumably is the
# auto-generated name of the extra Conv2D layer added by
# build_model_resnet_with_added_conv; confirm it exists in the loaded model.
last_conv_layer_name_resnet = "conv2d"

## Grad-CAM

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tensorflow.keras.utils import img_to_array, array_to_img

def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for a single image.

    Args:
        img_array: One image without a batch axis (a batch axis of size 1 is
            added here), e.g. the (H, W, 3) output of preprocess_spectrogram.
        model: Keras model containing a layer named ``last_conv_layer_name``
            (looked up with ``model.get_layer``).
        last_conv_layer_name: Name of the convolutional layer to explain.
        pred_index: Class index to explain; defaults to the top-scoring class.

    Returns:
        numpy array over the conv layer's spatial grid (h, w), scaled to
        [0, 1]. It is all zeros when no location contributes positively.
    """
    img_array = np.expand_dims(img_array, axis=0)

    # Model mapping the input to (last conv activations, predictions).
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )

    # Record the forward pass to differentiate the chosen class score with
    # respect to the conv activations.
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, last_conv_layer_output)

    # Channel importance = mean gradient over batch and spatial positions.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Weight each feature map by its importance and sum over channels.
    last_conv_layer_output = last_conv_layer_output[0]
    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
    # Squeeze only the trailing channel axis so a spatial size of 1 survives.
    heatmap = tf.squeeze(heatmap, axis=-1)

    # Keep positive evidence and scale to [0, 1]. Only divide when the max is
    # positive, so an all-non-positive map gives zeros instead of NaN/inf.
    heatmap = tf.maximum(heatmap, 0)
    max_value = tf.math.reduce_max(heatmap)
    if max_value > 0:
        heatmap = heatmap / max_value
    return heatmap.numpy()

def save_and_display_gradcam(img_array, heatmap, cam_path="cam.jpg", alpha=0.4):
    """Overlay a heatmap on an image using the "jet" colormap.

    Despite its name, this function neither saves nor displays anything;
    ``cam_path`` is unused and kept only for call compatibility.

    Args:
        img_array: Image as an (H, W, C) array or tensor. Its height and width
            set the size the colored heatmap is resized to.
        heatmap: 2-D array with values expected in [0, 1] (as returned by
            ``make_gradcam_heatmap``). NaNs are treated as 0 and values are
            clipped to [0, 1].
        cam_path: Unused.
        alpha: Weight of the colored heatmap in the blend.

    Returns:
        (superimposed_img, img): the blended PIL image and ``img_array`` unchanged.
    """
    img = img_array

    # Map [0, 1] to integer colormap indices 0..255.
    heatmap = np.uint8(255 * np.clip(np.nan_to_num(heatmap), 0, 1))

    # plt.get_cmap avoids depending on the global name `cm`, which is easy to
    # clobber (e.g. with a confusion-matrix array).
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Resize the RGB heatmap to the image size; PIL takes (width, height).
    jet_heatmap = array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = img_to_array(jet_heatmap)

    superimposed_img = jet_heatmap * alpha + img
    superimposed_img = array_to_img(superimposed_img)
    return superimposed_img, img

### ResNet

In [None]:
# Reload the saved ResNet run used for the saliency maps.
# NOTE(review): hard-coded timestamp directory from an earlier session;
# presumably the run with the extra "conv2d" layer -- confirm.
model_resnet, hist_resnet = load_model_and_history(model_directory="2023-05-31 14:02:35.553126")

model_resnet.summary()

#### Util - Process audio, spectrogram

In [None]:
def preprocess_audio(path):
    """Read a WAV file and return a fixed-length float32 waveform tensor.

    The file is decoded as int16 samples and cast to float32. Its last axis
    is squeezed away (presumably a single channel). The waveform is normalized
    when ``CFG.normalize`` is set, then cropped or padded to ``CFG.audio_len``
    samples.
    """
    raw = tfio.audio.decode_wav(tf.io.read_file(path), dtype=tf.int16)
    waveform = tf.squeeze(tf.cast(raw, tf.float32), axis=-1)

    if CFG.normalize:
        waveform = Normalize(waveform)

    waveform = CropOrPad(waveform, CFG.audio_len)
    return tf.reshape(waveform, [CFG.audio_len])

def preprocess_spectrogram(audio):
    """Turn a waveform into a 3-channel spectrogram image of shape [*CFG.img_size, 3]."""
    spec_2d = Audio2Spec(audio, spec_shape=CFG.img_size, sr=CFG.sample_rate,
                         nfft=CFG.nfft, window=CFG.window, fmin=CFG.fmin, fmax=CFG.fmax)
    image = Spec2Img(spec_2d, num_channels=3)
    return tf.reshape(image, [*CFG.img_size, 3])

#### Util - Show spectrogram

In [None]:
def show_spectrogram(spectrogram_tensor, ax=None):
    """Display the first channel of a spectrogram image with librosa.

    Args:
        spectrogram_tensor: numpy array or eager tensor whose last axis is the
            channel axis (e.g. the output of ``preprocess_spectrogram``).
        ax: Optional matplotlib Axes to draw on; by default librosa uses the
            current axes.
    """
    first_channel = spectrogram_tensor[..., 0]
    # Eager tensors must be converted; numpy arrays are used as-is.
    if not isinstance(spectrogram_tensor, np.ndarray):
        first_channel = first_channel.numpy()

    lid.specshow(first_channel,
                 sr=CFG.sample_rate,
                 hop_length=CFG.hop_length,
                 fmin=CFG.fmin,
                 fmax=CFG.fmax,
                 x_axis='time',
                 y_axis='mel',
                 cmap='coolwarm',
                 ax=ax)


#### Create function to compare spectrogram with CAM result

In [None]:
from PIL import Image

def display_cam_spectrogram(cam_image, image, main_title, size=(500, 350)):
    """Show a spectrogram image and its CAM overlay side by side.

    Both PIL images are resized to ``size`` (width, height) and flipped top to
    bottom before display. The spectrogram's first (red) channel is colored
    with the "coolwarm" colormap.

    Args:
        cam_image: PIL image of the CAM / Grad-CAM overlay.
        image: PIL image of the input spectrogram.
        main_title: Genre name shown in the figure title.
        size: Display size of both panels (default (500, 350)).
    """
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    fig.suptitle(f"Music genre: {main_title}")

    # plt.get_cmap works on the pinned matplotlib 3.4 (matplotlib.colormaps
    # needs >= 3.5). It also does not require the bare name `matplotlib`,
    # which the import cell does not bind (only `mpl` and `plt`).
    colormap = plt.get_cmap("coolwarm")
    red_channel = np.array(image.convert("RGB"))[:, :, 0] / 255.0
    colored_rgba = colormap(red_channel)
    colored_image = Image.fromarray((colored_rgba[:, :, :3] * 255).astype(np.uint8))

    def _flip_vertically(pil_img):
        # rotate(180) followed by a left-right mirror is a top-bottom flip.
        return pil_img.rotate(180).transpose(Image.FLIP_LEFT_RIGHT)

    axs[0].imshow(_flip_vertically(colored_image.resize(size)))
    axs[0].set_title("Input Spectrogram")

    axs[1].imshow(_flip_vertically(cam_image.resize(size)))
    axs[1].set_title("CAM Visualization")

    plt.show()

#### Choose random song from test set

In [None]:
# Draw one rock track from the *test* split (as the heading and variable names
# say; previously the train split was queried). Seeded for reproducibility.
random_test_data = (
    df.query("split == 'test' and genre == 'rock'")
    .sample(1, random_state=CFG.seed)
    .squeeze()
)

random_test_path = random_test_data["filepath"]
random_test_label = random_test_data["target"]

print(f"Song: {random_test_path}. Label: {random_test_label}")

#### Process audio, process and display spectrogram

In [None]:
# Load the chosen track, convert it to the model's spectrogram image and plot it.
audio = preprocess_audio(random_test_path)
spec = preprocess_spectrogram(audio)
show_spectrogram(spec)

#### Predict

In [None]:
# Predict on a batch of one and report the top class and its probability.
prediction_resnet = model_resnet.predict(np.expand_dims(spec, axis=0), verbose=0)

print(f"Resnet -- {CFG.labelsMapping[np.argmax(prediction_resnet)]} -- probability={np.max(prediction_resnet)}")

#### Heatmap and Grad-CAM

In [None]:
# Grad-CAM heatmap for the top class, then the heatmap blended over the spectrogram.
heatmap_resnet = make_gradcam_heatmap(spec, model_resnet, last_conv_layer_name_resnet)
gradcam_resnet, image = save_and_display_gradcam(spec, heatmap_resnet)

In [None]:
# Min-max rescale the spectrogram image (a tensor here) to 0..255 so it can be
# converted to a uint8 PIL image for the side-by-side display.
# NOTE(review): divides by zero if the image is constant.
OldRange = (np.max(image) - np.min(image))  
NewRange = 255 - 0
image = (((image - np.min(image)) * NewRange) / OldRange) + 0

display_cam_spectrogram(gradcam_resnet, 
                        Image.fromarray(image.numpy().astype(np.uint8)), 
                        CFG.labelsMapping[np.argmax(prediction_resnet)])

In [None]:
# Show the raw (unflipped, unresized) Grad-CAM overlay.
plt.imshow(gradcam_resnet)

#### Visualize

In [None]:
# Left: raw Grad-CAM heatmap; right: heatmap blended over the spectrogram.
# Titles now say which panel is which (they used to be identical).
fig, axs = plt.subplots(1, 2, figsize=(10,5))
top_probability = np.max(prediction_resnet)

axs[0].matshow(heatmap_resnet)
axs[0].set_title(f"ResNet heatmap: {top_probability}")

axs[1].imshow(gradcam_resnet)
axs[1].set_title(f"ResNet Grad-CAM: {top_probability}")

plt.show()

### CAM

#### Functions used for CAM

In [None]:
import numpy as np
import cv2
from keras import backend as K
import matplotlib.cm as cm
import matplotlib
from tensorflow.keras.utils import img_to_array, array_to_img

def get_class_activation_map(model, img, last_conv_layer_name, output_size=(150, 150)):
    '''
    Compute the class activation map (CAM) for the model's top predicted class.

    Inputs:
        1) model (tensorflow model): trained model. Assumes its last layer is
           a Dense classifier whose kernel rows correspond one-to-one to the
           channels of `last_conv_layer_name` (e.g. conv -> global average
           pooling -> Dense), which CAM requires.
        2) img (array of shape (H, W, C)): a single image without a batch
           axis; a batch axis of size 1 is added here.
        3) last_conv_layer_name (str): name of the conv layer whose feature
           maps are weighted.
        4) output_size (tuple): (width, height) that the feature maps are
           resized to with cv2 (cv2's dsize order). Default (150, 150).

    Returns:
        (cam, label_index): `cam` has shape (height, width) and holds
        unnormalized activation scores; `label_index` is the argmax class.
    '''
    batch = np.expand_dims(img, axis=0)

    # Top predicted class.
    predictions = model.predict(batch, verbose=0)
    label_index = np.argmax(predictions)

    # Dense kernel is (num_features, num_classes); keep the winner's column.
    class_weights = model.layers[-1].get_weights()[0]
    class_weights_winner = class_weights[:, label_index]

    final_conv_layer = model.get_layer(last_conv_layer_name)
    get_conv_output = K.function([model.layers[0].input], [final_conv_layer.output])
    # Index away the batch axis instead of np.squeeze, so a feature map with a
    # spatial size of 1 keeps its (h, w, channels) rank.
    conv_outputs = get_conv_output([batch])[0][0]

    resized_maps = cv2.resize(conv_outputs, dsize=tuple(output_size),
                              interpolation=cv2.INTER_CUBIC)
    # Weighted sum over channels -> (height, width).
    final_output = np.dot(resized_maps, class_weights_winner)

    return final_output, label_index

def save_and_display_cam(img_array, heatmap, cam_path="cam.jpg", alpha=0.4):
    """Overlay a raw CAM on an image using the "jet" colormap.

    Despite its name, this function neither saves nor displays anything;
    ``cam_path`` is unused and kept only for call compatibility.

    Args:
        img_array: Image as an (H, W, C) array or tensor. Its height and width
            set the size the colored CAM is resized to.
        heatmap: 2-D array of unnormalized CAM scores. It is min-max scaled to
            0..255 here; a constant map becomes all zeros instead of dividing
            by zero.
        cam_path: Unused.
        alpha: Weight of the colored CAM in the blend.

    Returns:
        (superimposed_img, image): blended PIL image and the input as a PIL image.
    """
    img = img_array

    heatmap = np.asarray(heatmap)
    value_range = np.max(heatmap) - np.min(heatmap)
    if value_range > 0:
        heatmap = ((heatmap - np.min(heatmap)) * 255) / value_range
    else:
        heatmap = np.zeros_like(heatmap)
    heatmap = np.uint8(heatmap)

    # plt.get_cmap works on the pinned matplotlib 3.4 (matplotlib.colormaps
    # only exists from 3.5).
    jet = plt.get_cmap("jet")
    jet_colors = jet(np.arange(256))[:, :3]
    jet_heatmap = jet_colors[heatmap]

    # Resize the RGB heatmap to the image size; PIL takes (width, height).
    jet_heatmap = array_to_img(jet_heatmap)
    jet_heatmap = jet_heatmap.resize((img.shape[1], img.shape[0]))
    jet_heatmap = img_to_array(jet_heatmap)

    superimposed_img = jet_heatmap * alpha + img
    superimposed_img = array_to_img(superimposed_img)
    return superimposed_img, array_to_img(img)



#### Choose random song from test set

In [None]:
# Draw one pop track from the *test* split (as the heading and variable names
# say; previously the train split was queried). Seeded for reproducibility.
random_test_data = (
    df.query("split == 'test' and genre == 'pop'")
    .sample(1, random_state=CFG.seed)
    .squeeze()
)

random_test_path = random_test_data["filepath"]
random_test_label = random_test_data["target"]

print(f"Song: {random_test_path}. Label: {random_test_label}")

#### Process audio, process and display spectrogram

In [None]:
# Load the chosen track, convert it to the model's spectrogram image and plot it.
audio = preprocess_audio(random_test_path)
spec = preprocess_spectrogram(audio)
show_spectrogram(spec)

#### Predict

In [None]:
# Predict on a batch of one and report the top class and its probability.
prediction_resnet = model_resnet.predict(np.expand_dims(spec, axis=0), verbose=0)

print(f"Resnet -- {CFG.labelsMapping[np.argmax(prediction_resnet)]} -- probability={np.max(prediction_resnet)}")

#### Calculate Heatmaps and CAM Visualizations

In [None]:
# Raw CAM for the top class, then the CAM blended over the spectrogram
# (plus the spectrogram itself as a PIL image).
final_output, label_index = get_class_activation_map(model_resnet, spec,  last_conv_layer_name_resnet)
cam_image, image = save_and_display_cam(spec, final_output, cam_path="cam.jpg", alpha=0.4)

#### Visualize CAM and spectrogram

In [None]:
# Side-by-side: colored input spectrogram vs. CAM overlay, titled with the predicted genre.
display_cam_spectrogram(cam_image,
                        image,
                        CFG.labelsMapping[np.argmax(prediction_resnet)])

In [None]:
import wandb

# Authenticate without hard-coding the API key: wandb.login() uses the
# WANDB_API_KEY environment variable or stored credentials, and otherwise
# prompts. SECURITY: the key previously hard-coded here was exposed in this
# notebook and should be revoked/rotated.
wandb.login()

wandb.init(project='mgc-augmentation',
           entity='kmotyka2000org',
           id='5bhcwqps',
           resume='must')

# NOTE(review): `run` is bound to this second wandb.init call, not to the
# resumed run above; confirm the resumed run is still needed.
run = wandb.init(
    # set the wandb project where this run will be logged
    project="mgc-augmentation",
)

In [None]:
import tensorflow as tf

# Coordinates of the W&B model artifact to restore.
entity = "kmotyka2000org"
project = "mgc-augmentation"
artifact_name = "model-prime-wildflower-20"
# Artifact version/alias (not a training epoch despite the name).
epoch = "latest"

# Declare the artifact as an input of the current run.
artifact = run.use_artifact(f'{entity}/{project}/{artifact_name}:{epoch}')

# Download the model
artifact_dir = artifact.download()

# Load the model
model = tf.keras.models.load_model(artifact_dir)

# model = tf.keras.models.load_model('/kaggle/input/no-aug-best/tensorflow2/no-aug-model/1/model-best.h5')

In [None]:
# NOTE(review): this replaces the model just loaded from the W&B artifact with
# a local '/kaggle/working/fold-0.h5' checkpoint; run only one of the two cells.
model = tf.keras.models.load_model('/kaggle/working/fold-0.h5')

## Close WANDB session

In [None]:
# Close the active W&B run.
wandb.finish()