<a href="https://colab.research.google.com/github/ericodle/GenreDiscern/blob/main/LSTM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

0. Enable the GPU on your computer or Google Colab environment (Runtime>Change runtime type).

1. Download the GTZAN dataset. It's on Kaggle; you need to log in to download it. (If working in a Google Colab environment, the easiest method is to save GTZAN to your Google Drive.)

2. Pre-process the GTZAN dataset. You will need to define the GTZAN path (music_path), output_path, and output_filename by yourself.

In [1]:
# User configuration for the pre-processing step.
# music_path: root folder of the dataset that is walked for audio files.
music_path = "/content/drive/MyDrive/gtzan" # Enter your path here.
# output_path: folder in which the extracted-feature JSON file is written.
output_path = "/content/drive/MyDrive/test_output" # Enter your path here.
# output_filename: JSON file name WITHOUT extension (".json" is appended later).
output_filename = "gtzan_mfccs"

In [3]:
import sys
import json
import os
import math
import librosa

# Constants for audio processing.
SAMPLE_RATE = 22050  # Standard sample rate for GTZAN audio data.
SONG_LENGTH = 30  # Duration of each song clip in seconds.
SAMPLE_COUNT = SAMPLE_RATE * SONG_LENGTH  # Total number of samples per clip.


def mfcc_to_json(music_path, output_path, output_filename, mfcc_count=13, n_fft=2048, hop_length=512, seg_length=30):
    """Extract MFCCs from every clip under ``music_path`` and save them as JSON.

    Every sub-folder found while walking ``music_path`` is registered as one
    genre (its folder name goes into ``"mapping"``).  For each audio file at
    least ``seg_length`` seconds long, the centred ``seg_length``-second window
    is converted to a (frames x mfcc_count) MFCC matrix.

    Args:
        music_path: Root folder of the dataset (one sub-folder per genre).
        output_path: Folder that receives the JSON file; created if missing.
        output_filename: Output file name without the ".json" extension.
        mfcc_count: Number of MFCC coefficients per frame.
        n_fft: FFT window size handed to librosa.
        hop_length: Hop length handed to librosa.
        seg_length: Length in seconds of the window taken from each clip.

    Returns:
        None.  Progress and errors are reported with ``print``; unreadable or
        too-short files are skipped instead of aborting the run.
    """
    # Initialize the data dictionary to store extracted features and labels.
    extracted_data = {
        "mapping": [],  # List to map numeric labels to genre names.
        "labels": [],   # List to store numeric labels for each audio clip.
        "mfcc": []      # List to store extracted MFCCs.
    }

    # Number of samples in one analysed segment.
    seg_samples = int(seg_length * SAMPLE_RATE)

    # Loop through each genre folder in the dataset.
    for folder_path, _dir_names, file_names in os.walk(music_path):
        if folder_path == music_path:
            # The root itself is not a genre.
            continue

        # Genre label is the folder's own name (separator-agnostic).
        genre_label = os.path.basename(folder_path)
        extracted_data["mapping"].append(genre_label)
        # The label is the genre's position in "mapping", so both always
        # agree no matter in which order os.walk visits the folders.
        label_index = len(extracted_data["mapping"]) - 1
        print("\nProcessing: {}".format(genre_label))

        # Iterate over each audio file in the genre folder.
        for song_clip in file_names:
            file_path = os.path.join(folder_path, song_clip)
            try:
                # Load (and resample) the audio file.
                audio_sig, sr = librosa.load(file_path, sr=SAMPLE_RATE)
            except Exception as e:
                # Corrupt or non-audio files are reported and skipped.
                print(f"Error loading file {file_path}: {e}")
                continue

            if len(audio_sig) < seg_samples:
                print(f"{file_path} is shorter than {seg_length} seconds. Skipping...")
                continue

            # Centre the window on the middle of the clip.  The end is derived
            # from the start so every segment has exactly seg_samples samples.
            middle_index = len(audio_sig) // 2
            segment_start = max(0, middle_index - (seg_samples // 2))
            segment_end = segment_start + seg_samples

            try:
                mfcc = librosa.feature.mfcc(y=audio_sig[segment_start:segment_end], sr=sr, n_mfcc=mfcc_count, n_fft=n_fft, hop_length=hop_length)
                # Transpose to (frames, coefficients).
                mfcc = mfcc.T
            except Exception as e:
                print(f"Error computing MFCCs for {file_path}: {e}")
                continue

            # Append MFCCs and label to the data dictionary.
            extracted_data["mfcc"].append(mfcc.tolist())
            extracted_data["labels"].append(label_index)
            print("{}, segment:{}-{}".format(file_path, segment_start, segment_end))

    # Write the extracted data to a JSON file.
    output_file_path = os.path.join(output_path, output_filename + ".json")
    try:
        if output_path:
            # Make sure the destination exists; otherwise open() below fails.
            os.makedirs(output_path, exist_ok=True)
        with open(output_file_path, "w") as fp:
            json.dump(extracted_data, fp, indent=4)
        print(f"Successfully wrote data to {output_file_path}")
    except Exception as e:
        print(f"Error writing data to {output_file_path}: {e}")

# Script entry point: run the MFCC extraction with the configured paths.
if __name__ == "__main__":
    mfcc_to_json(
        music_path=music_path,
        output_path=output_path,
        output_filename=output_filename,
    )




Processing: pop
/content/drive/MyDrive/gtzan/pop/pop.00002.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00001.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00005.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00006.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00004.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00000.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00003.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00017.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00013.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00015.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00018.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00011.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00012.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00008.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00007.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00014.wav, segment:2
/content/drive/MyDrive/gtzan/pop/pop.00016.wav, segment

  audio_sig, sr = librosa.load(file_path, sr=SAMPLE_RATE)
	Deprecated as of librosa version 0.10.0.
	It will be removed in librosa version 1.0.
  y, sr_native = __audioread_load(path, offset, duration, dtype)


Error loading file /content/drive/MyDrive/gtzan/jazz/jazz.00054.wav: 
/content/drive/MyDrive/gtzan/jazz/jazz.00056.wav, segment:1890
/content/drive/MyDrive/gtzan/jazz/jazz.00047.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00051.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00057.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00069.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00070.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00065.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00067.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00066.wav, segment:2990
/content/drive/MyDrive/gtzan/jazz/jazz.00059.wav, segment:5300
/content/drive/MyDrive/gtzan/jazz/jazz.00064.wav, segment:240
/content/drive/MyDrive/gtzan/jazz/jazz.00058.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00060.wav, segment:147
/content/drive/MyDrive/gtzan/jazz/jazz.00061.wav, segment:2660
/content/drive/MyDrive/gtzan/jazz/jazz.00062.wav, segment:

3. Define the model architectures.

In [4]:
import sys
import os
import json
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn import functional as F

################################################
#       　   Fully Connected    　 　   #
################################################

class FC_model(nn.Module):
    """Fully connected classifier: 16796 flattened features -> 10 class scores.

    The input is flattened per sample, so any shape whose non-batch
    dimensions multiply to 16796 is accepted.

    NOTE(review): the network ends in a softmax, i.e. it returns probabilities
    rather than logits.  If it is trained with a loss that applies
    (log-)softmax itself, the normalisation is applied twice -- confirm
    against the training loop.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(

            ### Fully-connected layer
            nn.Flatten(),
            # NOTE(review): this ReLU acts on the raw flattened input and
            # zeroes every negative feature -- confirm this is intended.
            nn.ReLU(),

            nn.Linear(16796, 256),
            nn.ReLU(),
            nn.Dropout(p=0.3),

            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.3),

            nn.Linear(128, 10),
            # Normalise over the class axis explicitly; leaving dim unset
            # relies on a deprecated implicit choice and emits a warning.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        '''Forward pass: returns a (batch, 10) tensor of class probabilities.'''
        return self.layers(x)

################################################
#        Convolutional Neural Network          #
################################################


class CNN_model(nn.Module):
    """Convolutional classifier: four conv stages, then a 3-layer dense head.

    The first convolution takes a single input channel and the dense head
    expects exactly 82432 flattened features after the convolutional stack.

    NOTE(review): the network ends in a softmax, i.e. it returns probabilities
    rather than logits.  If it is trained with a loss that applies
    (log-)softmax itself, the normalisation is applied twice -- confirm
    against the training loop.
    """

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            ### Convolutional layer

            nn.Conv2d(1, 256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.AvgPool2d(3, stride=2),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.AvgPool2d(3, stride=2),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 512, kernel_size=(4, 4), padding=1),
            nn.ReLU(),
            # Kernel size 1 with stride 2: pure sub-sampling, no averaging.
            nn.AvgPool2d(1, stride=2),
            nn.BatchNorm2d(512),

            ### Fully-connected layer
            nn.Flatten(),
            nn.ReLU(),

            nn.Linear(82432, 256),
            nn.ReLU(),
            nn.Dropout(p=0.2),

            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),

            nn.Linear(128, 10),
            # Normalise over the class axis explicitly; leaving dim unset
            # relies on a deprecated implicit choice and emits a warning.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        '''Forward pass: returns a (batch, 10) tensor of class probabilities.'''
        return self.layers(x)


################################################
#          Long Short-Term Memory     　       #
################################################

class LSTM_model(nn.Module):
    """Stacked unidirectional LSTM followed by a linear read-out.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        layer_dim: Number of stacked LSTM layers.
        output_dim: Number of output scores (classes).
        dropout_prob: Dropout applied between stacked LSTM layers.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, dropout_prob):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        self.dropout_prob = dropout_prob
        self.rnn = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True, bidirectional=False, dropout=dropout_prob)
        self.fc = nn.Linear(hidden_dim, output_dim)
        # Kept for backward compatibility; not used by this class itself.
        self.batch_size = None
        self.hidden = None

    def forward(self, x):
        """Return (batch, output_dim) scores for a (batch, seq, input_dim) input."""
        h0, c0 = self.init_hidden(x)

        out, (hn, cn) = self.rnn(x, (h0, c0))

        # Only the output at the last time step feeds the classifier.
        out = self.fc(out[:, -1, :])
        return out

    def init_hidden(self, x):
        """Return zero-filled ``[h0, c0]`` matching the batch size of ``x``.

        The states are created on the same device (and with the same dtype)
        as ``x``, so the model works on CPU-only machines as well as on GPU
        instead of unconditionally requiring CUDA.
        """
        h0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        c0 = x.new_zeros(self.layer_dim, x.size(0), self.hidden_dim)
        return [h0, c0]


################################################
#       　   Gated Recurrent Unit      　 　   #
################################################

class GRU_model(nn.Module):
    """Stacked GRU followed by a linear read-out of the last time step.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the GRU hidden state.
        layer_dim: Number of stacked GRU layers.
        output_dim: Number of output scores.
        dropout_prob: Dropout applied between stacked GRU layers.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, dropout_prob):
        super().__init__()

        # Remembered so forward() can size the initial hidden state.
        self.layer_dim = layer_dim
        self.hidden_dim = hidden_dim

        # Recurrent stack (batch-first input) and the final projection.
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=layer_dim,
            batch_first=True,
            dropout=dropout_prob,
        )
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map a (batch, seq, input_dim) tensor to (batch, output_dim) scores."""
        batch_size = x.size(0)

        # Zero initial hidden state, moved to wherever the input lives.
        initial_state = torch.zeros(self.layer_dim, batch_size, self.hidden_dim)
        initial_state = initial_state.requires_grad_().to(x.device)

        sequence_out, _ = self.gru(x, initial_state)

        # Keep only the last time step: (batch, hidden_dim).
        last_step = sequence_out[:, -1, :]

        return self.fc(last_step)

################################################
#       　        Transformer           　 　   #
################################################

class TransformerLayer(nn.Module):
    """One post-norm Transformer encoder block (self-attention + feed-forward).

    Inputs and outputs are batch-first: (batch, seq, input_dim).

    Args:
        input_dim: Feature (embedding) size of every token.
        hidden_dim: Accepted for signature compatibility; not used here.
        num_heads: Number of attention heads (must divide ``input_dim``).
        ff_dim: Width of the hidden layer of the feed-forward sub-block.
        dropout: Dropout probability applied after each sub-block.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, ff_dim, dropout):
        super(TransformerLayer, self).__init__()
        # batch_first=True: the models in this file feed (batch, seq, feature)
        # tensors and read x[:, -1, :] as the last token.  With the default
        # sequence-first layout attention would mix samples across the batch
        # instead of attending over time steps.
        self.multihead_attn = nn.MultiheadAttention(input_dim, num_heads, batch_first=True)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(input_dim)
        self.ff_layer = nn.Sequential(
            nn.Linear(input_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, input_dim)
        )
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(input_dim)

    def forward(self, x):
        """Apply self-attention and feed-forward, each with residual + LayerNorm."""
        attn_output, _ = self.multihead_attn(x, x, x)
        attn_output = self.dropout1(attn_output)
        x = self.norm1(x + attn_output)
        ff_output = self.ff_layer(x)
        ff_output = self.dropout2(ff_output)
        x = self.norm2(x + ff_output)
        return x

################################################
#       　        Tr_FC           　 　   #
################################################

class Tr_FC(nn.Module):
    """Stack of TransformerLayer blocks with a linear classification head.

    Args:
        input_dim: Feature size per token.
        hidden_dim: Forwarded to every TransformerLayer.
        num_layers: Number of stacked TransformerLayer blocks.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward width per block.
        output_dim: Number of classes.
        dropout: Dropout probability per block.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, ff_dim, output_dim, dropout):
        super().__init__()
        self.input_dim = input_dim
        encoder_blocks = [
            TransformerLayer(input_dim, hidden_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]
        self.transformer_layers = nn.ModuleList(encoder_blocks)
        self.output_layer = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return log-probabilities of shape (batch, output_dim)."""
        for block in self.transformer_layers:
            x = block(x)
        # Classify from the representation at the last position of axis 1.
        logits = self.output_layer(x[:, -1, :])
        return F.log_softmax(logits, dim=1)

################################################
#       　        Tr_CNN           　 　   #
################################################

class Tr_CNN(nn.Module):
    """TransformerLayer stack followed by a convolutional classifier.

    The transformer output gains a channel axis and then goes through four
    convolution stages and a three-layer dense head ending in a softmax.  The
    dense head expects exactly 82432 flattened features.

    Args:
        input_dim: Feature size per token.
        hidden_dim: Forwarded to every TransformerLayer.
        num_layers: Number of stacked TransformerLayer blocks.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward width per block.
        output_dim: Number of classes.
        dropout: Dropout probability per TransformerLayer.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, ff_dim, output_dim, dropout):
        super().__init__()
        self.input_dim = input_dim

        encoder_blocks = [
            TransformerLayer(input_dim, hidden_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]
        self.transformer_layers = nn.ModuleList(encoder_blocks)

        # Convolutional feature extractor (single input channel).
        conv_stack = [
            nn.Conv2d(1, 256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.AvgPool2d(3, stride=2),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 256, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.AvgPool2d(3, stride=2),
            nn.BatchNorm2d(256),
            nn.Conv2d(256, 512, kernel_size=(4, 4), padding=1),
            nn.ReLU(),
            nn.AvgPool2d(1, stride=2),
            nn.BatchNorm2d(512),
        ]
        self.conv_layers = nn.Sequential(*conv_stack)

        # Dense classification head producing class probabilities.
        dense_stack = [
            nn.Flatten(),
            nn.ReLU(),
            nn.Linear(82432, 256),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(128, output_dim),
            nn.Softmax(dim=1),
        ]
        self.fc_layers = nn.Sequential(*dense_stack)

    def forward(self, x):
        """Return class probabilities of shape (batch, output_dim)."""
        for block in self.transformer_layers:
            x = block(x)
        # Insert a channel axis at position 1 for the 2-D convolutions.
        feature_maps = self.conv_layers(x.unsqueeze(1))
        return self.fc_layers(feature_maps)

################################################
#       　        Tr_LSTM           　 　   #
################################################

class Tr_LSTM(nn.Module):
    """TransformerLayer stack whose output is classified by an LSTM_model.

    Args:
        input_dim: Feature size per token.
        hidden_dim: Forwarded to the TransformerLayers and used as LSTM hidden size.
        num_layers: Number of TransformerLayer blocks AND of stacked LSTM layers.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward width per block.
        output_dim: Number of classes.
        dropout: Dropout probability for the blocks and the LSTM.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, ff_dim, output_dim, dropout):
        super().__init__()
        self.input_dim = input_dim
        encoder_blocks = [
            TransformerLayer(input_dim, hidden_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]
        self.transformer_layers = nn.ModuleList(encoder_blocks)
        # Recurrent head that turns the encoded sequence into class scores.
        self.lstm = LSTM_model(input_dim, hidden_dim, num_layers, output_dim, dropout)

    def forward(self, x):
        """Return log-probabilities of shape (batch, output_dim)."""
        for block in self.transformer_layers:
            x = block(x)

        # The LSTM head consumes the output of the last Transformer block.
        scores = self.lstm(x)
        return F.log_softmax(scores, dim=1)

################################################
#       　        Tr_GRU           　 　   #
################################################

class Tr_GRU(nn.Module):
    """TransformerLayer stack whose output is classified by a GRU_model.

    Args:
        input_dim: Feature size per token.
        hidden_dim: Forwarded to the TransformerLayers and used as GRU hidden size.
        num_layers: Number of TransformerLayer blocks AND of stacked GRU layers.
        num_heads: Attention heads per block.
        ff_dim: Feed-forward width per block.
        output_dim: Number of classes.
        dropout: Dropout probability for the blocks and the GRU.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_heads, ff_dim, output_dim, dropout):
        super().__init__()
        self.input_dim = input_dim
        encoder_blocks = [
            TransformerLayer(input_dim, hidden_dim, num_heads, ff_dim, dropout)
            for _ in range(num_layers)
        ]
        self.transformer_layers = nn.ModuleList(encoder_blocks)
        # Recurrent head (GRU) that turns the encoded sequence into class scores.
        self.gru = GRU_model(input_dim, hidden_dim, num_layers, output_dim, dropout)

    def forward(self, x):
        """Return log-probabilities of shape (batch, output_dim)."""
        for block in self.transformer_layers:
            x = block(x)

        # The GRU head consumes the output of the last Transformer block.
        scores = self.gru(x)
        return F.log_softmax(scores, dim=1)



4. Define mfcc_path, model_type, output_directory, and initial_lr.

In [10]:
# User configuration for training/evaluation.
# mfcc_path: JSON file with the extracted features ("mfcc" and "labels" keys).
mfcc_path = "/content/drive/MyDrive/test_output/gtzan_mfccs.json"
# model_type: which architecture to build; values handled by main() include
# "FC", "CNN", "LSTM", "GRU", "Tr_FC", "Tr_CNN" and "Tr_LSTM".
model_type = "LSTM"
# output_directory: folder that receives plots and metric files.
output_directory = "/content/drive/MyDrive/test_output"
# initial_lr: starting learning rate handed to main().
initial_lr = 0.0001 #Choose any value you want to use.

5. Train and evaluate your model.

In [11]:
import sys

import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import joblib

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_auc_score, roc_curve, auc

import torch
from torch import nn
import torch.optim as optim
from torch.optim.lr_scheduler import _LRScheduler
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms


# Reproducibility: disable the cuDNN auto-tuner and request deterministic
# cuDNN kernels (only relevant when a GPU is used).
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Pick the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device", device)

def load_data(data_path):
    """Read the feature JSON at ``data_path`` and return ``(X, y)`` arrays.

    ``X`` is built from the "mfcc" entry and ``y`` from the "labels" entry of
    the JSON document.
    """
    with open(data_path, "r") as json_file:
        payload = json.load(json_file)

    # Nested lists -> numpy arrays.
    features = np.asarray(payload["mfcc"])
    targets = np.asarray(payload["labels"])

    print("Data succesfully loaded!")

    return features, targets

def test_ann_model(model, test_dataloader, device='cpu'):
    """Evaluate ``model`` on every batch of ``test_dataloader``.

    Each input batch gets an extra axis inserted at position 1 before it is
    fed to the model.  Gradients are disabled during evaluation.

    Returns:
        Tuple ``(ground_truth, predicted_genres, predicted_probs, accuracy)``:
        concatenated labels, arg-max predictions, softmax of the model
        outputs, and the fraction of correct predictions (a tensor).
    """
    model.eval()
    model = model.to(device)

    n_seen = 0
    n_correct = 0
    label_batches, pred_batches, prob_batches = [], [], []

    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs = inputs.unsqueeze(1).to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            batch_probs = torch.softmax(outputs, dim=-1)
            batch_preds = torch.max(outputs, 1)[1]

            n_seen += targets.size(dim=0)
            n_correct += (batch_preds == targets).sum()

            # Collect everything on the CPU so it can be concatenated later.
            label_batches.append(targets.cpu())
            pred_batches.append(batch_preds.cpu().detach())
            prob_batches.append(batch_probs.cpu().detach())

    ground_truth = torch.cat(label_batches)
    predicted_genres = torch.cat(pred_batches)
    predicted_probs = torch.cat(prob_batches)
    accuracy = n_correct / n_seen

    return ground_truth, predicted_genres, predicted_probs, accuracy

def test_recurrent_model(model, test_dataloader, device='cpu'):
    """Evaluate a sequence model on every batch of ``test_dataloader``.

    Batches are passed to the model unchanged (no extra channel axis).
    Gradients are disabled during evaluation.  The model builds its own
    initial hidden state, so none is created here; the function therefore no
    longer requires the model to expose ``layer_dim``/``hidden_dim``.

    Returns:
        Tuple ``(ground_truth, predicted_genres, predicted_probs, accuracy)``:
        concatenated labels, arg-max predictions, softmax of the model
        outputs, and the fraction of correct predictions (a tensor).
    """
    model.eval()
    count = 0
    correct = 0
    true = []
    preds = []
    probs = []

    model = model.to(device)

    with torch.no_grad():
        for X_testbatch, y_testbatch in test_dataloader:
            X_testbatch = X_testbatch.to(device)
            y_testbatch = y_testbatch.to(device)

            y_val = model(X_testbatch)
            y_probs = torch.softmax(y_val, dim=-1)
            predicted = torch.max(y_val, 1)[1]

            count += y_testbatch.size(dim=0)
            correct += (predicted == y_testbatch).sum()

            # Collect everything on the CPU so it can be concatenated later.
            true.append(y_testbatch.cpu())
            preds.append(predicted.cpu().detach())
            probs.append(y_probs.cpu().detach())

    ground_truth = torch.cat(true)
    predicted_genres = torch.cat(preds)
    predicted_probs = torch.cat(probs)
    accuracy = correct / count

    return ground_truth, predicted_genres, predicted_probs, accuracy

def test_transformer_model(model, test_dataloader, device='cpu'):
    """Evaluate a transformer-based model on every batch of ``test_dataloader``.

    The model is moved to ``device`` once, and the whole loop runs under
    ``torch.no_grad()`` so no autograd graph is built during evaluation
    (matching the other evaluation helpers in this file).

    Returns:
        Tuple ``(ground_truth, predicted_genres, predicted_probs, accuracy)``:
        concatenated labels, arg-max predictions, softmax of the model
        outputs, and the fraction of correct predictions (a Python float).
    """
    model.eval()
    model = model.to(device)

    count = 0
    correct = 0
    true = []
    preds = []
    probs = []

    with torch.no_grad():
        for X_testbatch, y_testbatch in test_dataloader:
            X_testbatch = X_testbatch.to(device)
            y_testbatch = y_testbatch.to(device)

            y_val = model(X_testbatch)

            y_probs = torch.softmax(y_val, dim=-1)
            predicted = torch.max(y_val, 1)[1]

            count += y_testbatch.size(0)
            correct += (predicted == y_testbatch).sum().item()

            # Collect everything on the CPU so it can be concatenated later.
            true.append(y_testbatch.detach().cpu())
            preds.append(predicted.detach().cpu())
            probs.append(y_probs.detach().cpu())

    ground_truth = torch.cat(true)
    predicted_genres = torch.cat(preds)
    predicted_probs = torch.cat(probs)
    accuracy = correct / count

    return ground_truth, predicted_genres, predicted_probs, accuracy

def calculate_roc_auc(y_true, y_probs):
    """Return the one-vs-rest ROC AUC of every class as a list.

    Column ``k`` of ``y_probs`` is scored against the binary target
    ``y_true == k``; the list has one entry per column of ``y_probs``.
    """
    n_classes = y_probs.shape[1]
    return [
        roc_auc_score(y_true == class_idx, y_probs[:, class_idx])
        for class_idx in range(n_classes)
    ]

def plot_roc_curve(y_true, y_probs, class_names, output_directory):
    """Write per-class AUC values and a ROC plot into ``output_directory``.

    Column ``k`` of ``y_probs`` is evaluated one-vs-rest against
    ``y_true == k``.  Two files are produced: ``auc.txt`` (one
    "<class>: <auc>" line per class) and ``ROC.png`` (all curves in one
    figure).

    Each curve is computed once and reused for both outputs.
    """
    # (class name, false-positive rates, true-positive rates, AUC) per class.
    curves = []
    for class_idx in range(y_probs.shape[1]):
        fpr, tpr, _ = roc_curve(y_true == class_idx, y_probs[:, class_idx])
        curves.append((class_names[class_idx], fpr, tpr, auc(fpr, tpr)))

    auc_file = os.path.join(output_directory, 'auc.txt')
    with open(auc_file, 'w') as f:
        for class_name, _, _, roc_auc in curves:
            f.write(f'{class_name}: {roc_auc:.2f}\n')

    plt.figure(figsize=(8, 6))
    for class_name, fpr, tpr, roc_auc in curves:
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves')
    plt.legend(loc='lower right')  # Adjust legend position
    output_file = os.path.join(output_directory, 'ROC.png')
    plt.savefig(output_file)
    plt.close()

def save_ann_confusion_matrix(ground_truth, predicted_genres, class_names, output_directory):
    """Save a confusion-matrix heat map and a classification report.

    Writes ``confusion_matrix.png`` and ``confusion_metrics.txt`` into
    ``output_directory``.  ``ground_truth`` and ``predicted_genres`` are
    tensors; they are flattened and converted to numpy before scoring.
    """
    # Flatten both tensors and move them to numpy.
    truth_np = ground_truth.view(-1).detach().cpu().numpy()
    pred_np = predicted_genres.view(-1).detach().cpu().numpy()

    # Confusion matrix and per-class metrics.
    matrix = confusion_matrix(truth_np, pred_np)
    report = classification_report(truth_np, pred_np, target_names=class_names, output_dict=True)
    report_frame = pd.DataFrame(report).transpose()

    # Heat map of the confusion matrix.
    matrix_frame = pd.DataFrame(matrix, index=class_names, columns=class_names)
    plt.figure(figsize=(10, 7))
    sns.heatmap(matrix_frame, annot=True, fmt="d", cmap='BuGn')
    plt.xlabel("Predictions")
    plt.ylabel("Ground Truths")
    plt.title('Confusion Matrix', fontsize=15)
    plt.savefig(os.path.join(output_directory, 'confusion_matrix.png'))
    plt.close()

    # Plain-text classification report.
    metrics_path = os.path.join(output_directory, 'confusion_metrics.txt')
    with open(metrics_path, 'w') as report_file:
        report_file.write("Classification Report:\n")
        report_file.write(report_frame.to_string())

    print("Confusion matrix and accuracy metrics saved successfully.")


def train_val_split(X, y, val_ratio):
    """Shuffle ``X``/``y`` and split them into training and validation parts.

    Args:
        X: Features.
        y: Labels aligned with ``X``.
        val_ratio: Passed to ``train_test_split`` as ``test_size``.

    Returns:
        ``(X_train, X_val, y_train, y_val)``.
    """
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=val_ratio, shuffle=True)
    return X_train, X_val, y_train, y_val

def accuracy(out, labels):
    """Return how many rows of ``out`` have their arg-max equal to ``labels``.

    The result is a COUNT of correct predictions (Python number), not a ratio.
    """
    predicted = torch.max(out, dim=1)[1]
    n_correct = (predicted == labels).sum()
    return n_correct.item()

class CyclicLR(_LRScheduler):
    """Learning-rate scheduler driven by a user-supplied schedule function.

    ``schedule(epoch, base_lr)`` is called for every base learning rate of the
    optimizer and must return the learning rate to use at ``epoch``.
    """

    def __init__(self, optimizer, schedule, last_epoch=-1):
        assert callable(schedule)
        # Must be stored before the base-class constructor runs, because that
        # constructor already performs a first step and calls get_lr().
        self.schedule = schedule
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the scheduled learning rate for each parameter group."""
        rates = []
        for base_lr in self.base_lrs:
            rates.append(self.schedule(self.last_epoch, base_lr))
        return rates

def cosine(t_max, eta_min=0):
    """Build a cosine schedule function with period ``t_max`` epochs.

    The returned function maps ``(epoch, base_lr)`` to a learning rate that
    starts at ``base_lr`` when ``epoch % t_max == 0`` and decays towards
    ``eta_min`` along half a cosine wave, restarting every ``t_max`` epochs.
    """
    def schedule(epoch, base_lr):
        # Position inside the current cycle.
        phase = epoch % t_max
        cosine_term = np.cos(np.pi * phase / t_max)
        return eta_min + (base_lr - eta_min) * (1 + cosine_term) / 2

    return schedule


def plot_learning_metrics(train_loss, val_loss, train_acc, val_acc, output_directory):
    """Plot loss and accuracy per epoch and save ``learning_metrics.png``.

    Loss curves use the left y-axis and accuracy curves a twin right y-axis.
    The x-axis is the 1-based epoch index derived from ``len(train_loss)``.
    """
    epoch_axis = range(1, len(train_loss) + 1)

    fig, loss_ax = plt.subplots(figsize=(10, 5), dpi=600)

    # Left axis: losses.
    loss_color = 'tab:red'
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss', color=loss_color)
    loss_ax.plot(epoch_axis, train_loss, label='Train Loss', color=loss_color)
    loss_ax.plot(epoch_axis, val_loss, label='Validation Loss', color='tab:orange')
    loss_ax.tick_params(axis='y', labelcolor=loss_color)

    # Right axis (shares x with the loss axis): accuracies.
    acc_ax = loss_ax.twinx()
    acc_color = 'tab:blue'
    acc_ax.set_ylabel('Accuracy', color=acc_color)
    acc_ax.plot(epoch_axis, train_acc, label='Train Accuracy', color=acc_color)
    acc_ax.plot(epoch_axis, val_acc, label='Validation Accuracy', color='green')
    acc_ax.tick_params(axis='y', labelcolor=acc_color)

    # Leave room for the title and for the legend placed outside the axes.
    fig.tight_layout(rect=[0.05, 0.05, 0.9, 0.9])
    fig.legend(loc='upper left', bbox_to_anchor=(1,1))
    plt.title('Learning Metrics', pad=20)
    # bbox_inches='tight' keeps the outside legend from being cut off.
    target_path = os.path.join(output_directory, "learning_metrics.png")
    plt.savefig(target_path, bbox_inches='tight')
    plt.close()

# Model types whose input batches get a singleton axis inserted at dim 1
# before the forward pass (presumably a channel axis -- confirm against the
# FC/CNN model definitions).
_CHANNEL_DIM_MODELS = ("FC", "CNN")

# Labels for the confusion matrix and ROC plots.
# NOTE(review): this order is hard-coded; it must match the numeric label
# indices produced during preprocessing -- confirm against the dataset mapping.
_CLASS_NAMES = ['pop', 'classical', 'jazz', 'hiphop', 'reggae', 'disco', 'metal', 'country', 'blues', 'rock']


def _build_model(model_type):
    """Instantiate the network named by ``model_type``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    recurrent_kwargs = dict(input_dim=13, hidden_dim=256, layer_dim=2, output_dim=10, dropout_prob=0.2)
    transformer_kwargs = dict(input_dim=13, hidden_dim=256, num_layers=4, num_heads=1, ff_dim=4, dropout=0.2, output_dim=10)
    # Lambdas keep the class lookup lazy: only the selected class is resolved.
    factories = {
        'FC': lambda: FC_model(),
        'CNN': lambda: CNN_model(),
        'LSTM': lambda: LSTM_model(**recurrent_kwargs),
        'GRU': lambda: GRU_model(**recurrent_kwargs),
        'Tr_FC': lambda: Tr_FC(**transformer_kwargs),
        'Tr_CNN': lambda: Tr_CNN(**transformer_kwargs),
        'Tr_LSTM': lambda: Tr_LSTM(**transformer_kwargs),
        'Tr_GRU': lambda: Tr_GRU(**transformer_kwargs),
    }
    factory = factories.get(model_type) if isinstance(model_type, str) else None
    if factory is None:
        raise ValueError("Invalid model_type")
    return factory()


def _prepare_batch(x, y, add_channel_dim):
    """Move one batch to ``device`` and cast the labels to int64."""
    if add_channel_dim:
        x = x.unsqueeze(1)
    # Use the module-level ``device`` (same one the model lives on) instead
    # of a hard-coded ``.cuda()`` so CPU-only runs work as well.
    return x.to(device), y.to(device).to(torch.int64)


def _train_with_early_stopping(model, model_type, train_dataloader, val_dataloader,
                               criterion, opt, sched, output_directory, n_epochs, patience):
    """Train ``model``, checkpointing on validation-accuracy improvements.

    The whole model object is saved to ``<output_directory>/model`` every
    time validation accuracy strictly improves. Training stops once
    ``patience`` consecutive epochs bring no improvement, or after
    ``n_epochs`` epochs.

    Returns:
        Four lists ``(train_loss, val_loss, train_acc, val_acc)`` with one
        entry per completed epoch; accuracies are percentages.
    """
    add_channel_dim = model_type in _CHANNEL_DIM_MODELS
    train_loss, val_loss, train_acc, val_acc = [], [], [], []
    best_acc, trials = 0, 0

    for epoch in range(1, n_epochs + 1):
        # ---- training pass ----
        model.train()
        tcorrect, ttotal = 0, 0
        running_train_loss = 0
        for x_batch, y_batch in train_dataloader:
            x_batch, y_batch = _prepare_batch(x_batch, y_batch, add_channel_dim)
            opt.zero_grad()
            out = model(x_batch)
            loss = criterion(out, y_batch)
            running_train_loss += loss.item()
            loss.backward()
            opt.step()
            # The scheduler is stepped per batch, not per epoch.
            sched.step()
            ttotal += y_batch.size(0)
            tcorrect += (out.argmax(dim=1) == y_batch).sum().item()
        train_acc.append(100 * tcorrect / ttotal)
        train_loss.append(running_train_loss / len(train_dataloader))

        # ---- validation pass ----
        model.eval()
        vcorrect, vtotal = 0, 0
        running_val_loss = 0
        # No gradients are needed here; skipping graph construction saves
        # memory and time without changing the outputs.
        with torch.no_grad():
            for x_vbatch, y_vbatch in val_dataloader:
                x_vbatch, y_vbatch = _prepare_batch(x_vbatch, y_vbatch, add_channel_dim)
                out = model(x_vbatch)
                # argmax over raw scores; softmax is monotonic so it would
                # select the same class.
                vcorrect += (out.argmax(dim=1) == y_vbatch).sum().item()
                vtotal += y_vbatch.size(0)
                running_val_loss += criterion(out, y_vbatch).item()
        vacc = vcorrect / vtotal
        val_acc.append(vacc*100)
        val_loss.append(running_val_loss / len(val_dataloader))

        if epoch % 5 == 0:
            # ``loss`` is the loss of the last training batch of this epoch.
            print(f'Epoch: {epoch:3d}. Loss: {loss.item():.4f}. Val Acc.: {vacc:2.2%}')

        if vacc > best_acc:
            trials = 0
            best_acc = vacc
            torch.save(model, os.path.join(output_directory, "model"))
            print(f'Epoch {epoch} best model saved with val accuracy: {best_acc:2.2%}')
        else:
            trials += 1
            if trials >= patience:
                print(f'Early stopping on epoch {epoch}')
                break

    return train_loss, val_loss, train_acc, val_acc


def _run_test(model, model_type, test_dataloader):
    """Call the test helper that belongs to ``model_type``.

    Returns whatever the helper returns; ``main`` unpacks it as
    ``(ground_truth, predicted_genres, predicted_probs, accuracy)``.
    """
    # ``device`` is forwarded only for the model types that already received it.
    if model_type in ("FC", "CNN"):
        return test_ann_model(model, test_dataloader)
    if model_type in ("LSTM", "GRU"):
        return test_recurrent_model(model, test_dataloader, device=device)
    if model_type in ("Tr_FC", "Tr_CNN"):
        return test_transformer_model(model, test_dataloader)
    return test_transformer_model(model, test_dataloader, device=device)


def main(mfcc_path, model_type, output_directory, initial_lr):
    """Train one model on the MFCC data, then plot and report its metrics.

    Args:
        mfcc_path: Location handed to ``load_data`` to obtain ``X`` and ``y``.
        model_type: One of 'FC', 'CNN', 'LSTM', 'GRU', 'Tr_FC', 'Tr_CNN',
            'Tr_LSTM', 'Tr_GRU'.
        output_directory: Directory receiving the saved model and the plots.
        initial_lr: Base learning rate for RMSprop; the cosine schedule
            decays it down to ``initial_lr / 100`` within each cycle.

    Raises:
        ValueError: If ``model_type`` is not supported.
    """
    # load data
    X, y = load_data(mfcc_path)

    # Diagnostic prints to check data dimensions
    print("Loaded data dimensions:")
    print("X shape:", X.shape)
    print("y shape:", y.shape)

    # create train/val split
    X_train, X_val, y_train, y_val = train_val_split(X, y, 0.2)

    train_dataset = TensorDataset(torch.Tensor(X_train), torch.Tensor(y_train))
    val_dataset = TensorDataset(torch.Tensor(X_val), torch.Tensor(y_val))
    # NOTE(review): the "test" set is the complete dataset, so it overlaps
    # with the training data and the reported test metrics are optimistic.
    # A held-out split would be needed for an unbiased estimate.
    test_dataset = TensorDataset(torch.Tensor(X), torch.Tensor(y))

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=True)

    # Training hyperparameters
    lr = initial_lr
    n_epochs = 10000
    iterations_per_epoch = len(train_dataloader)
    patience = 20

    # Initialize model based on model_type
    model = _build_model(model_type).to(device)
    criterion = nn.CrossEntropyLoss()
    opt = torch.optim.RMSprop(model.parameters(), lr=lr)
    # One cosine cycle spans two epochs' worth of batches.
    sched = CyclicLR(opt, cosine(t_max=iterations_per_epoch * 2, eta_min=lr / 100))
    print(f'Training {model_type} model with learning rate of {initial_lr}.')

    train_loss, val_loss, train_acc, val_acc = _train_with_early_stopping(
        model, model_type, train_dataloader, val_dataloader,
        criterion, opt, sched, output_directory, n_epochs, patience,
    )

    print("Training finished!")

    # Evaluate trained model
    plot_learning_metrics(train_loss, val_loss, train_acc, val_acc, output_directory)
    print("Learning metrics plotted!")

    # Test the model
    # NOTE(review): this evaluates the in-memory model from the last epoch,
    # not the best checkpoint saved during training.
    ground_truth, predicted_genres, predicted_probs, test_accuracy = _run_test(
        model, model_type, test_dataloader
    )

    # Print test accuracy
    print(f'Test accuracy: {test_accuracy * 100:.2f}%')

    # Plot confusion matrix
    class_names = list(_CLASS_NAMES)
    save_ann_confusion_matrix(ground_truth, predicted_genres, class_names, output_directory)

    # Calculate ROC AUC scores
    roc_auc_scores = calculate_roc_auc(ground_truth, predicted_probs)

    # Print ROC AUC scores
    for class_idx, score in enumerate(roc_auc_scores):
        print(f'Class {class_idx} ROC AUC: {score:.4f}')

    # Plot ROC curves
    plot_roc_curve(ground_truth, predicted_probs, class_names, output_directory)

if __name__ == "__main__":
    # Entry point: run the full pipeline using the module-level settings
    # (these four names are not defined in this block and are expected to
    # exist by the time this line runs).
    main(
        mfcc_path=mfcc_path,
        model_type=model_type,
        output_directory=output_directory,
        initial_lr=initial_lr,
    )


Using device cuda:0
Data succesfully loaded!
Loaded data dimensions:
X shape: (990, 1292, 13)
y shape: (990,)
Training LSTM model with learning rate of 0.0001.
Epoch 1 best model saved with val accuracy: 31.82%
Epoch 2 best model saved with val accuracy: 33.84%
Epoch 3 best model saved with val accuracy: 37.37%
Epoch 4 best model saved with val accuracy: 37.88%
Epoch:   5. Loss: 1.7359. Val Acc.: 39.90%
Epoch 5 best model saved with val accuracy: 39.90%
Epoch 7 best model saved with val accuracy: 42.93%
Epoch 8 best model saved with val accuracy: 43.94%
Epoch:  10. Loss: 1.1109. Val Acc.: 45.45%
Epoch 10 best model saved with val accuracy: 45.45%
Epoch 14 best model saved with val accuracy: 46.97%
Epoch:  15. Loss: 0.8730. Val Acc.: 44.44%
Epoch:  20. Loss: 1.2171. Val Acc.: 46.97%
Epoch:  25. Loss: 0.9014. Val Acc.: 49.49%
Epoch 25 best model saved with val accuracy: 49.49%
Epoch:  30. Loss: 1.1782. Val Acc.: 49.49%
Epoch 31 best model saved with val accuracy: 50.00%
Epoch:  35. Loss:

Finished.