In [None]:
#importing the models
import torch.nn as nn
import torch
import torch.nn.functional as F
from tqdm import tqdm
import models.MaqamCNN1D
import models.MaqamCNN2D
import models.CNN_LSTM
import models.MFCC_LSTM
import models.MFCC_LSTM1D
import models.ANNModel
import models.ANNModel1
import models.ANNModel2
import MaqamDataset
from MaqamDataset import*
import torch
from torch.utils.data import DataLoader
import librosa
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim

In [None]:
# Before running this notebook, set the dataset path in MaqamDataset.py to the location of your data.

In [None]:
# Settings: batch size, and which model to run via `option`:
#   1 = CNN, 2 = LSTM, 3 = ANN, 4 = combine 3 pre-trained ANN models (test only)
batch_size = 64
option = 1
if option not in (1, 2, 3, 4):
    raise ValueError(f"option must be 1, 2, 3 or 4, got {option!r}")

# Input feature passed to MaqamDataset. Every option currently uses MFCCs, so it
# is set unconditionally (this also keeps `feature` defined for option 4).
# Alternative: feature = 'chroma'
feature = 'mfcc'

In [None]:
# Train, validate and test the 2-D CNN model when option == 1, then save the
# weights and the loss/accuracy curves under results/.
import copy
import os

if option == 1:
    # Load train and validation datasets
    train_dataset = MaqamDataset(mode='train', cache_file='running_cache/train.pkl', feature=feature)
    val_dataset = MaqamDataset(mode='val', cache_file='running_cache/validation.pkl', feature=feature)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
    # Shuffling has no effect on the aggregated validation metrics; keep it deterministic.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

    l = 0.001  # Learning rate

    # Per-epoch history, used for the plots at the end of the cell
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    print("__________________________________________________________________________________________________________________")
    print("Learning rate =", l)

    # Initialize model, loss function, and optimizer
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MaqamCNN2D().to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=l)

    num_epochs = 40
    # patience == num_epochs means early stopping never triggers; lower it to enable.
    patience = num_epochs
    best_val_loss = float('inf')
    best_model_state_dict = None
    no_improvement_epochs = 0

    print("Starting training")
    for epoch in range(num_epochs):
        # --- Training pass ---
        model.train()
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
            targets = targets.to(device)
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss returns the batch mean; weight it by batch size so the
            # epoch loss is a per-sample average, like the validation loss below.
            running_loss += loss.item() * len(targets)
            _, predicted_labels = torch.max(outputs, 1)
            correct_predictions += (predicted_labels == targets).sum().item()
            total_samples += len(targets)

        avg_loss = running_loss / total_samples
        avg_accuracy = 100 * correct_predictions / total_samples
        print(f'Epoch {epoch + 1}/{num_epochs}: Train Loss={avg_loss:.5f}, Train Accuracy={avg_accuracy:.5f}%')

        # --- Validation pass ---
        model.eval()
        val_loss = 0.0
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc='Validation', leave=False):
                targets = targets.to(device)
                inputs = inputs.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item() * len(targets)
                _, predicted_labels = torch.max(outputs, 1)
                total_correct += (predicted_labels == targets).sum().item()
                total_samples += len(targets)
        val_loss /= total_samples
        val_acc = 100 * total_correct / total_samples
        train_losses.append(avg_loss)
        train_accuracies.append(avg_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)
        print(f'Epoch {epoch + 1}/{num_epochs} validation: val_loss={val_loss:.5f}, val_acc={val_acc:.5f}%')

        # --- Best-model tracking / early stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors; without a
            # copy the "best" snapshot would keep changing as training continues.
            best_model_state_dict = copy.deepcopy(model.state_dict())
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1
        if no_improvement_epochs >= patience:
            print("Early stopping. No improvement in validation loss for {} epochs.".format(patience))
            break

    # Restore the weights with the lowest validation loss
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)

    # Test the model on the test dataset
    test_dataset = MaqamDataset(mode='test', cache_file='running_cache/test.pkl', feature=feature)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc='Testing', leave=False):
            targets = targets.to(device)
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted_labels = torch.max(outputs, 1)
            total_correct += (predicted_labels == targets).sum().item()
            total_samples += len(targets)
    test_acc = 100 * total_correct / total_samples
    print(f'Test Accuracy: {test_acc:.5f}%')

    # Save the trained model and the training curves
    os.makedirs('results', exist_ok=True)
    torch.save(model.state_dict(), 'results/CNN1.pth')

    plt.figure()
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Learning Rate: {l}')
    plt.savefig('results/CNN loss_plot.png')
    plt.close()

    plt.figure()
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title(f'Learning Rate: {l}')
    plt.savefig('results/CNN accuracy_plot.png')
    plt.close()

In [None]:
# Train, validate and test the MFCC LSTM model when option == 2, then save the
# weights and the loss/accuracy curves under results/. Progress is printed every
# 50 epochs.
import copy
import os

if option == 2:
    l = 0.001  # Learning rate

    # Per-epoch history, used for the plots at the end of the cell
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    # Load train and validation datasets
    train_dataset = MaqamDataset(mode='train', cache_file='running_cache/train.pkl', feature=feature)
    val_dataset = MaqamDataset(mode='val', cache_file='running_cache/val.pkl', feature=feature)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

    print("_________________________________________________________")
    print("Learning rate =", l)

    # Initialize model and define loss function and optimizer
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MFCC_LSTM1D().to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=l)

    num_epochs = 800
    # patience == num_epochs means early stopping never triggers; lower it to enable.
    patience = num_epochs
    best_val_loss = float('inf')
    best_model_state_dict = None
    no_improvement_epochs = 0

    print("Starting training")
    for epoch in range(num_epochs):
        report_epoch = epoch % 50 == 0  # only print every 50th epoch

        # --- Training pass ---
        model.train()
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0
        for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
            targets = targets.to(device)
            inputs = inputs.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # CrossEntropyLoss returns the batch mean; weight it by batch size so the
            # epoch loss is a per-sample average, like the validation loss below.
            running_loss += loss.item() * len(targets)
            _, predicted_labels = torch.max(outputs, 1)
            correct_predictions += (predicted_labels == targets).sum().item()
            total_samples += len(targets)

        avg_loss = running_loss / total_samples
        avg_accuracy = 100 * correct_predictions / total_samples
        if report_epoch:
            print(f'Epoch {epoch + 1}/{num_epochs}: Train Loss={avg_loss:.5f}, Train Accuracy={avg_accuracy:.5f}%')

        # --- Validation pass ---
        model.eval()
        val_loss = 0.0
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for inputs, targets in tqdm(val_loader, desc='Validation', leave=False):
                targets = targets.to(device)
                inputs = inputs.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, targets).item() * len(targets)
                _, predicted_labels = torch.max(outputs, 1)
                total_correct += (predicted_labels == targets).sum().item()
                total_samples += len(targets)

        val_loss /= total_samples
        val_acc = 100 * total_correct / total_samples

        train_losses.append(avg_loss)
        train_accuracies.append(avg_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        if report_epoch:
            print(f'Epoch {epoch + 1}/{num_epochs} validation: val_loss={val_loss:.5f}, val_acc={val_acc:.5f}%')

        # --- Best-model tracking / early stopping ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # state_dict() returns references to the live parameter tensors; without a
            # copy the "best" snapshot would keep changing as training continues.
            best_model_state_dict = copy.deepcopy(model.state_dict())
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1

        if no_improvement_epochs >= patience:
            print("Early stopping. No improvement in validation loss for {} epochs.".format(patience))
            break

    # Restore the weights with the lowest validation loss
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)

    # Test the model on the test dataset
    test_dataset = MaqamDataset(mode='test', cache_file='running_cache/test.pkl', feature=feature)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc='Testing', leave=False):
            targets = targets.to(device)
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted_labels = torch.max(outputs, 1)
            total_correct += (predicted_labels == targets).sum().item()
            total_samples += len(targets)

    test_acc = 100 * total_correct / total_samples
    print(f'Test Accuracy: {test_acc:.5f}%')

    # Save the trained model and the training curves
    os.makedirs('results', exist_ok=True)
    torch.save(model.state_dict(), 'results/lstm.pth')

    plt.figure()
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Learning Rate: {l}')
    plt.savefig(f'results/LSTM loss_plot_lr_{l}.png')
    plt.close()

    plt.figure()
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title(f'Learning Rate: {l}')
    plt.savefig(f'results/LSTM accuracy_plot_lr_{l}.png')
    plt.close()


In [None]:
# Train, validate and test the ANN model when option == 3, logging metrics and
# parameter histograms to TensorBoard (logs/), then save the weights and the
# loss/accuracy curves under results/.
import copy
import os

if option == 3:
    from torch.utils.tensorboard import SummaryWriter

    l = 0.0001  # Learning rate

    # Per-epoch history, used for the plots at the end of the cell
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    # Load train and validation datasets
    train_dataset = MaqamDataset(mode='train', cache_file='running_cache/train.pkl', feature=feature)
    val_dataset = MaqamDataset(mode='val', cache_file='running_cache/val.pkl', feature=feature)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate)
    # Shuffling has no effect on the aggregated validation metrics; keep it deterministic.
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

    print("_________________________________________________________")
    print("Learning rate =", l)

    # Initialize model and define loss function and optimizer
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ANNModel2().to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=l)

    num_epochs = 125
    # patience == num_epochs means early stopping never triggers; lower it to enable.
    patience = num_epochs
    best_val_loss = float('inf')
    best_model_state_dict = None
    no_improvement_epochs = 0

    # The writer is only created when this branch runs, and is always closed, even
    # if training is interrupted.
    writer = SummaryWriter('logs/')
    try:
        print("Starting training")
        for epoch in range(num_epochs):
            # --- Training pass ---
            model.train()
            running_loss = 0.0
            correct_predictions = 0
            total_samples = 0

            writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)

            for inputs, targets in tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False):
                targets = targets.to(device)
                inputs = inputs.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

                # CrossEntropyLoss returns the batch mean; weight it by batch size so the
                # epoch loss is a per-sample average, like the validation loss below.
                running_loss += loss.item() * len(targets)
                _, predicted_labels = torch.max(outputs, 1)
                correct_predictions += (predicted_labels == targets).sum().item()
                total_samples += len(targets)

            avg_loss = running_loss / total_samples
            avg_accuracy = 100 * correct_predictions / total_samples
            train_losses.append(avg_loss)
            train_accuracies.append(avg_accuracy)
            writer.add_scalar('Train Loss', avg_loss, epoch)
            writer.add_scalar('Train Accuracy', avg_accuracy, epoch)
            print(f'Epoch {epoch + 1}/{num_epochs}: Train Loss={avg_loss:.5f}, Train Accuracy={avg_accuracy:.5f}%')

            # --- Validation pass ---
            model.eval()
            val_loss = 0.0
            total_correct = 0
            total_samples = 0
            with torch.no_grad():
                for inputs, targets in tqdm(val_loader, desc='Validation', leave=False):
                    targets = targets.to(device)
                    inputs = inputs.to(device)
                    outputs = model(inputs)
                    val_loss += criterion(outputs, targets).item() * len(targets)
                    _, predicted_labels = torch.max(outputs, 1)
                    total_correct += (predicted_labels == targets).sum().item()
                    total_samples += len(targets)

            val_loss /= total_samples
            val_acc = 100 * total_correct / total_samples
            val_losses.append(val_loss)
            val_accuracies.append(val_acc)
            print(f'Epoch {epoch + 1}/{num_epochs} validation: val_loss={val_loss:.5f}, val_acc={val_acc:.5f}%')
            writer.add_scalar('Validation Loss', val_loss, epoch)
            writer.add_scalar('Validation Accuracy', val_acc, epoch)

            # --- Best-model tracking / early stopping ---
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # state_dict() returns references to the live parameter tensors; without a
                # copy the "best" snapshot would keep changing as training continues.
                best_model_state_dict = copy.deepcopy(model.state_dict())
                no_improvement_epochs = 0
            else:
                no_improvement_epochs += 1

            if no_improvement_epochs >= patience:
                print("Early stopping. No improvement in validation loss for {} epochs.".format(patience))
                break

            # Log histograms of model parameters
            for name, param in model.named_parameters():
                writer.add_histogram(name, param, epoch)

        # Restore the weights with the lowest validation loss
        if best_model_state_dict is not None:
            model.load_state_dict(best_model_state_dict)

        # Test the model on the test dataset
        test_dataset = MaqamDataset(mode='test', cache_file='running_cache/test.pkl', feature=feature)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

        model.eval()
        total_correct = 0
        total_samples = 0
        with torch.no_grad():
            for inputs, targets in tqdm(test_loader, desc='Testing', leave=False):
                targets = targets.to(device)
                inputs = inputs.to(device)
                outputs = model(inputs)
                _, predicted_labels = torch.max(outputs, 1)
                total_correct += (predicted_labels == targets).sum().item()
                total_samples += len(targets)

        test_acc = 100 * total_correct / total_samples
        print(f'Test Accuracy: {test_acc:.5f}%')
    finally:
        writer.close()

    # Save the trained model and the training curves
    os.makedirs('results', exist_ok=True)
    torch.save(model.state_dict(), 'results/ANN.pth')

    plt.figure()
    plt.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
    plt.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title(f'Learning Rate: {l}')
    plt.savefig(f'results/ANN loss_plot_lr_{l}.png')
    plt.close()

    plt.figure()
    plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
    plt.plot(range(1, len(val_accuracies) + 1), val_accuracies, label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title(f'Learning Rate: {l}')
    plt.savefig(f'results/ANN accuracy_plot_lr_{l}.png')
    plt.close()


In [None]:
# Evaluate the trained `model` on the test split and report overall and
# per-maqam accuracy. Also collects the true (`t`) and predicted (`p`) labels,
# one 0-d tensor per sample, for the confusion-matrix cell.
# NOTE: relies on `model` and `device` created by one of the training cells
# (options 1-3); it cannot run on a fresh kernel with option 4.
test_dataset = MaqamDataset(mode='test', cache_file='running_cache/test2.pkl', feature=feature)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

# Maqam class names; position i is assumed to correspond to label i.
classes = ['Ajam', 'Bayat', 'Hijaz', 'Kurd', 'Nahawand', 'Rast', 'Saba', 'Seka']

model.eval()
total_correct = 0
total_samples = 0

# Per-class tallies: acc_by_maqam_ss[i] counts test samples whose label is i,
# acc_by_maqam_cc[i] counts how many of those were predicted correctly.
acc_by_maqam_idx = list(range(len(classes)))
acc_by_maqam_ss = [0] * len(classes)
acc_by_maqam_cc = [0] * len(classes)
p = []  # predicted labels
t = []  # true labels

with torch.no_grad():
    for inputs, targets in tqdm(test_loader, desc='Testing', leave=False):
        targets = targets.to(device)
        inputs = inputs.to(device)
        outputs = model(inputs)

        _, predicted_labels = torch.max(outputs, 1)
        t.extend(targets.cpu())
        p.extend(predicted_labels.cpu())

        correct_mask = predicted_labels == targets
        total_correct += correct_mask.sum().item()
        total_samples += len(targets)

        # Vectorised per-class counts (one tensor reduction per class per batch)
        for i in acc_by_maqam_idx:
            class_mask = targets == i
            acc_by_maqam_ss[i] += class_mask.sum().item()
            acc_by_maqam_cc[i] += (class_mask & correct_mask).sum().item()

test_acc = 100 * total_correct / total_samples
print(f'Test Accuracy: {test_acc:.5f}%')

# Per-maqam accuracy as a fraction in [0, 1]
print('Accuracy by maqam:')
for i in acc_by_maqam_idx:
    if acc_by_maqam_ss[i] == 0:
        # A maqam absent from the test split would otherwise raise ZeroDivisionError.
        print(classes[i] + " accuracy = ", "n/a (no test samples)")
    else:
        print(classes[i] + " accuracy = ", acc_by_maqam_cc[i] / acc_by_maqam_ss[i])


In [None]:
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

# Build confusion matrices (counts, joint probability, row-normalised) and
# per-class ROC curves from the true (`t`) and predicted (`p`) label lists
# collected by the test-evaluation cell. Outputs are written to results/.
combined_tensor = torch.stack(t)
combined_tensor1 = torch.stack(p)
true = combined_tensor.cpu()
predicted = combined_tensor1.cpu()
y_true = true.numpy()
y_pred = predicted.numpy()

class_names = ['Ajam', 'Bayat', 'Hijaz', 'Kurd', 'Nahawand', 'Rast', 'Saba', 'Seka']
classes = class_names
n_classes = len(class_names)
class_labels = np.arange(n_classes)

# Passing `labels` keeps the matrix n_classes x n_classes even when a class never
# occurs in y_true or y_pred; otherwise the labelled DataFrame below would fail.
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)

# --- Raw counts ---
plt.figure(figsize=(8, 6))
sns.heatmap(cm_df, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix')
plt.savefig('results/confusion_matrix.jpeg', format='jpeg')  # saved as JPEG
plt.show()
cm_df.to_csv('results/confusion_matrix.csv')

# --- Joint probability: each cell divided by the total number of test samples ---
cm_probability = cm / cm.sum()
cm_df_probability = pd.DataFrame(cm_probability, index=class_names, columns=class_names)
plt.figure(figsize=(10, 8))
sns.heatmap(cm_df_probability, annot=True, fmt='.2f', cmap='Blues')
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix (Joint Probability)')
# Own file name so the row-normalised plot below does not overwrite it.
plt.savefig('results/confusion_matrix_probability_global.png', format='png')
plt.show()

# --- Row-normalised: each row sums to 1, so the diagonal is per-class recall ---
row_totals = cm.sum(axis=1, keepdims=True)
# Rows of classes with no test samples stay 0 instead of becoming NaN.
cm_probability = np.divide(cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0)
cm_df_probability = pd.DataFrame(cm_probability, index=class_names, columns=class_names)
plt.figure(figsize=(10, 8))
sns.heatmap(cm_df_probability, annot=True, fmt='.2f', cmap='Blues')
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix (Probability)')
plt.savefig('results/confusion_matrix_probability.png', format='png')
plt.show()

# --- ROC curves ---
# NOTE(review): these curves are computed from hard predicted labels, not class
# scores, so each curve has a single operating point. Collect softmax outputs in
# the evaluation cell if a meaningful ROC/AUC is required.
# Binarise against the full label set so labels keep their column even when some
# class is missing from the test split.
true_labels_onehot = label_binarize(y_true, classes=class_labels)
predicted_labels_onehot = label_binarize(y_pred, classes=class_labels)

fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
    if not true_labels_onehot[:, i].any():
        # The ROC curve is undefined for a class with no positive samples.
        continue
    fpr[i], tpr[i], _ = roc_curve(true_labels_onehot[:, i], predicted_labels_onehot[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Micro-average over all (sample, class) pairs
fpr["micro"], tpr["micro"], _ = roc_curve(true_labels_onehot.ravel(), predicted_labels_onehot.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

plt.figure(figsize=(8, 6))
for i in range(n_classes):
    if i in roc_auc:
        plt.plot(fpr[i], tpr[i], label=f'ROC curve ({classes[i]}) (area = {roc_auc[i]:.2f})')
plt.plot(fpr["micro"], tpr["micro"], label=f'Micro-average ROC curve (area = {roc_auc["micro"]:.2f})', linestyle=':', linewidth=4)
plt.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Random Guess')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve - Multiclass Classification')
plt.legend()
plt.grid(True)
plt.savefig('results/ROC_curve.png', format='png')
plt.show()


In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

if option == 4:  # combine the results of 3 pre-trained ANN models
    batch_size = 32

    # `device` is otherwise only defined inside the training cells (options 1-3),
    # none of which run when option == 4.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the test dataset
    # NOTE(review): unlike the other cells, this uses cache_file='test.pkl' and no
    # `feature` argument; confirm it matches what the saved ANN models were trained on.
    test_dataset = MaqamDataset(mode='test', cache_file='test.pkl')
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate)

    # Build the models and load their saved weights
    model1 = ANNModel().to(device)
    model2 = ANNModel1().to(device)
    model3 = ANNModel2().to(device)

    model1_path = "results/85FullData/ANN.pth"
    model2_path = "results/94mid_data/ANN.pth"
    model3_path = "results/all_features_85/ANN.pth"

    # map_location lets weights saved on a GPU load on a CPU-only machine.
    model1.load_state_dict(torch.load(model1_path, map_location=device))
    model2.load_state_dict(torch.load(model2_path, map_location=device))
    model3.load_state_dict(torch.load(model3_path, map_location=device))

    model1.eval()
    model2.eval()
    model3.eval()

    # Weights for the weighted sum of the models' outputs
    weight1 = 0.25
    weight2 = 0.5
    weight3 = 0.25

    # Ensemble predictions and the corresponding true labels
    all_predictions = []
    all_labels = []

    # Each model's own predictions, for individual accuracy
    model1_predictions = []
    model2_predictions = []
    model3_predictions = []

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            predictions1 = model1(inputs)
            predictions2 = model2(inputs)
            predictions3 = model3(inputs)

            model1_predictions.extend(predictions1.argmax(dim=1).cpu().numpy())
            model2_predictions.extend(predictions2.argmax(dim=1).cpu().numpy())
            model3_predictions.extend(predictions3.argmax(dim=1).cpu().numpy())

            # Weighted sum of the raw model outputs; softmax is monotonic, so it only
            # rescales to probabilities and does not change the argmax.
            combined_predictions = weight1 * predictions1 + weight2 * predictions2 + weight3 * predictions3
            probabilities = F.softmax(combined_predictions, dim=1)
            _, predicted_labels = torch.max(probabilities, 1)

            all_predictions.extend(predicted_labels.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Accuracy of each model on its own
    correct_model1 = sum(int(pred == label) for pred, label in zip(model1_predictions, all_labels))
    correct_model2 = sum(int(pred == label) for pred, label in zip(model2_predictions, all_labels))
    correct_model3 = sum(int(pred == label) for pred, label in zip(model3_predictions, all_labels))

    total_samples = len(all_labels)
    acc_model1 = correct_model1 / total_samples * 100
    acc_model2 = correct_model2 / total_samples * 100
    acc_model3 = correct_model3 / total_samples * 100

    print(f'Model 1 Accuracy: {acc_model1:.5f}%')
    print(f'Model 2 Accuracy: {acc_model2:.5f}%')
    print(f'Model 3 Accuracy: {acc_model3:.5f}%')

    # Accuracy of the weighted ensemble
    correct_combined = sum(int(pred == label) for pred, label in zip(all_predictions, all_labels))
    acc_combined = correct_combined / total_samples * 100
    print(f'Combined Accuracy: {acc_combined:.5f}%')