In [None]:
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize

In [None]:
# Evaluate the best saved checkpoint on the test set, collecting:
#   - correct_test / total_test : running counts used for accuracy
#   - predicted_labels          : arg-max class index per sample
#   - true_labels               : ground-truth label per sample
#   - all_predicted_probs       : per-batch arrays of the raw model outputs
#   - wrong_predictions         : one dict per misclassified sample
num_classes = 2

# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load('DM_gtzan_best.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

correct_test = 0
total_test = 0
predicted_labels = []
true_labels = []
wrong_predictions = []

all_predicted_probs = []
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.float().to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

        # Bring predictions and labels to the host once per batch instead of
        # synchronising with the device for every element comparison/.item().
        predicted_cpu = predicted.cpu()
        labels_cpu = labels.cpu()

        total_test += labels_cpu.size(0)
        correct_test += (predicted_cpu == labels_cpu).sum().item()

        predicted_labels.extend(predicted_cpu.numpy())
        true_labels.extend(labels_cpu.numpy())

        # These are the model outputs exactly as returned (no softmax is
        # applied here). NOTE(review): if the model emits logits, normalise
        # them before treating these values as probabilities.
        all_predicted_probs.append(outputs.cpu().numpy())

        # Indices of misclassified samples, computed in one vectorised step.
        # Assumes labels are 1-D class indices, as the element-wise comparison
        # against `predicted` already requires.
        wrong_indices = torch.nonzero(predicted_cpu != labels_cpu).flatten().tolist()
        for i in wrong_indices:
            wrong_predictions.append({
                # Stored on the CPU: a device-side view would keep the whole
                # batch alive in accelerator memory for every wrong sample.
                'Audio': inputs[i].cpu(),
                'True Label': labels_cpu[i].item(),
                'Predicted Label': predicted_cpu[i].item()
            })

In [None]:
# Summarise the test run: overall accuracy, confusion matrix, and
# support-weighted precision / recall / F1 from the collected label lists.
test_accuracy = correct_test / total_test
print(f"Test Accuracy: {test_accuracy:.2%}")

conf_matrix = confusion_matrix(true_labels, predicted_labels)

# average='weighted' averages the per-class scores weighted by class support.
precision = precision_score(true_labels, predicted_labels, average='weighted')
recall = recall_score(true_labels, predicted_labels, average='weighted')
f1 = f1_score(true_labels, predicted_labels, average='weighted')

# Stack the per-batch output arrays into a single array (samples on axis 0).
# The name is rebound to the result, so guard the call: on a re-run of this
# cell the value is already an ndarray, and concatenating a 2-D array along
# axis 0 would silently flatten it into 1-D.
if isinstance(all_predicted_probs, list):
    all_predicted_probs = np.concatenate(all_predicted_probs, axis=0)

# Indicator encoding of the true labels over the class ids 0..num_classes-1.
true_labels_bin = label_binarize(true_labels, classes=range(num_classes))

In [None]:
# Render the confusion matrix as an annotated heatmap: integer counts in each
# cell, class names on both axes (rows = true labels, columns = predictions).
plt.figure(figsize=(10, 8))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
).set(
    # heatmap returns the Axes it drew on; label it directly.
    xlabel='Predicted Labels',
    ylabel='True Labels',
    title='Confusion Matrix',
)
plt.show()

In [None]:
# Report the weighted metrics, one per line, to four decimal places.
print(
    f"Precision: {precision:.4f}",
    f"Recall: {recall:.4f}",
    f"F1 Score: {f1:.4f}",
    sep="\n",
)