In [89]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

In [90]:
from sklearn.metrics import (
    accuracy_score, 
    f1_score, 
    classification_report, 
    confusion_matrix, 
    roc_auc_score, 
    precision_score, 
    recall_score,
)
from scikitplot.metrics import plot_roc
from scikitplot.metrics import plot_precision_recall

In [91]:
df = pd.read_csv("../../our_analyses/dataset_prepared.csv")
df_test= pd.read_csv("../../our_analyses/dataset_test_prepared.csv")

In [92]:

from sklearn.preprocessing import LabelEncoder, StandardScaler

df = df.drop(['name', 'artists', 'album_name'], axis=1)
df_test = df_test.drop(['name', 'artists', 'album_name'], axis=1)

genre_groups = {
    'idm': 0, 'iranian': 0, 'study': 0,  # Electronic/Dance
    'black-metal': 1, 'breakbeat': 1, 'techno': 1,  # Metal/Rock
    'brazil': 2, 'forro': 2, 'happy': 2, 'spanish': 2, 'j-idol': 2,  # Country/Folk/Pop
    'afrobeat': 3, 'chicago-house': 3, 'industrial': 3, 'j-dance': 3,  # World/Commercial Pop
    'bluegrass': 4, 'disney': 4, 'indian': 4, 'mandopop': 4, 'sleep': 4  # Other
}

df['genre_group'] = df['genre'].map(genre_groups).astype(int)
df_test['genre_group'] = df_test['genre'].map(genre_groups).astype(int)

le = LabelEncoder()
df['explicit'] = le.fit_transform(df['explicit'])
df_test['explicit'] = le.transform(df_test['explicit'])

In [93]:
# Separazione delle features e dei target
X_train = df.drop(['genre_group', 'genre'], axis=1)
y_train = df['genre_group'].values
X_test = df_test.drop(['genre_group', 'genre'], axis=1)
y_test = df_test['genre_group'].values

y = np.array(df['genre_group'])

# Rimozione delle colonne 'genre_group' e 'genre' per ottenere le feature
X = df.drop(['genre_group', 'genre'], axis=1)

y_test = np.array(df['genre_group'])

X_test = df.drop(['genre_group', 'genre'], axis=1)


In [94]:
scaler = StandardScaler()
X_train_norm = scaler.fit_transform(X_train)
X_test_norm = scaler.transform(X_test)

In [95]:
from sklearn.neighbors import KNeighborsClassifier

# Addestramento del classificatore KNN
clf = KNeighborsClassifier(n_neighbors=8, metric="cityblock", weights="uniform")
clf.fit(X_train_norm, y_train)

# Valutazione del classificatore sul set di addestramento
y_train_pred = clf.predict(X_train_norm)
print("Accuracy sul set di addestramento:", accuracy_score(y_train, y_train_pred))

# Valutazione del classificatore sul set di test
y_test_pred = clf.predict(X_test_norm)
print("Accuracy sul set di test:", accuracy_score(y_test, y_test_pred))

print(classification_report(y_test, y_test_pred))

Accuracy sul set di addestramento: 0.7744
Accuracy sul set di test: 0.7146
              precision    recall  f1-score   support

           0       0.70      0.72      0.71      1500
           1       0.75      0.61      0.68      1000
           2       0.73      0.82      0.77      1750
           3       0.66      0.60      0.63       750

    accuracy                           0.71      5000
   macro avg       0.71      0.69      0.70      5000
weighted avg       0.71      0.71      0.71      5000



In [96]:
import plotly.figure_factory as ff
from sklearn.metrics import confusion_matrix

# Assumiamo che 'y_test' e 'y_test_pred' siano già definiti e contengano le etichette vere e predette dal tuo modello

# Genera la matrice di confusione
cf = confusion_matrix(y_test, y_test_pred)

# Definisci le etichette per i generi in base ai raggruppamenti definiti
genre_labels = ['Dance/Electronic', 'Ambient/Relaxing', 'Global/Traditional', 'Metal/Industrial', 'Pop/World']

# Crea la heatmap usando plotly
fig = ff.create_annotated_heatmap(z=cf, x=genre_labels, y=genre_labels,
                                  annotation_text=cf.astype(str), colorscale='Greens')

# Aggiorna il layout
fig.update_layout(title_text='Confusion Matrix', title_x=0.5,
                  xaxis=dict(title='Predicted Labels', tickangle=45),
                  yaxis=dict(title='True Labels', tickmode='array', tickvals=list(range(len(genre_labels))), ticktext=genre_labels),
                  yaxis_autorange='reversed')  # Reverse the y-axis to have the first class at the top

# Mostra il grafico
fig.show()


PlotlyError: oops, the x list that you provided does not match the width of your z matrix 

In [None]:
import plotly.graph_objs as go
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

# Calcola i punteggi di probabilità
y_prob = clf.predict_proba(X_test_norm)

# Binarizza le etichette in un formato one-vs-all
y_test_binarized = label_binarize(y_test, classes=[0, 1, 2, 3, 4])

# Calcola la ROC curve e l'area sotto la curva (AUC) per ogni classe
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(5):  # Numero di classi
    fpr[i], tpr[i], _ = roc_curve(y_test_binarized[:, i], y_prob[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Crea una figura per Plotly
fig = go.Figure()

# Aggiungi una linea ROC per ogni classe
for i in range(6):
    fig.add_trace(go.Scatter(x=fpr[i], y=tpr[i], mode='lines', name=f'ROC curve (area = {roc_auc[i]:.2f}) for class {i}'))

# Aggiungi la linea di riferimento per una classificazione casuale
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Random', line=dict(dash='dash')))

# Aggiorna il layout
fig.update_layout(title='Receiver Operating Characteristic (ROC) - Multi-Class',
                  xaxis_title='False Positive Rate',
                  yaxis_title='True Positive Rate',
                  xaxis=dict(scaleanchor="y", scaleratio=1),  # Questo assicura che l'asse x abbia lo stesso rapporto di scala dell'asse y
                  yaxis=dict(constrain='domain'),  # Questo fa in modo che l'asse y sia proporzionato correttamente
                  width=800,  # Larghezza della figura
                  height=800)  # Altezza della figura (rendendola uguale alla larghezza, il grafico sarà quadrato)

# Mostra il grafico
fig.show()


KeyError: 5

ValueError: 
    Invalid value of type 'builtins.str' received for the 'color' property of scatter.line
        Received value: 'transparent'

    The 'color' property is a color and may be specified as:
      - A hex string (e.g. '#ff0000')
      - An rgb/rgba string (e.g. 'rgb(255,0,0)')
      - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
      - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
      - A named CSS color:
            aliceblue, antiquewhite, aqua, aquamarine, azure,
            beige, bisque, black, blanchedalmond, blue,
            blueviolet, brown, burlywood, cadetblue,
            chartreuse, chocolate, coral, cornflowerblue,
            cornsilk, crimson, cyan, darkblue, darkcyan,
            darkgoldenrod, darkgray, darkgrey, darkgreen,
            darkkhaki, darkmagenta, darkolivegreen, darkorange,
            darkorchid, darkred, darksalmon, darkseagreen,
            darkslateblue, darkslategray, darkslategrey,
            darkturquoise, darkviolet, deeppink, deepskyblue,
            dimgray, dimgrey, dodgerblue, firebrick,
            floralwhite, forestgreen, fuchsia, gainsboro,
            ghostwhite, gold, goldenrod, gray, grey, green,
            greenyellow, honeydew, hotpink, indianred, indigo,
            ivory, khaki, lavender, lavenderblush, lawngreen,
            lemonchiffon, lightblue, lightcoral, lightcyan,
            lightgoldenrodyellow, lightgray, lightgrey,
            lightgreen, lightpink, lightsalmon, lightseagreen,
            lightskyblue, lightslategray, lightslategrey,
            lightsteelblue, lightyellow, lime, limegreen,
            linen, magenta, maroon, mediumaquamarine,
            mediumblue, mediumorchid, mediumpurple,
            mediumseagreen, mediumslateblue, mediumspringgreen,
            mediumturquoise, mediumvioletred, midnightblue,
            mintcream, mistyrose, moccasin, navajowhite, navy,
            oldlace, olive, olivedrab, orange, orangered,
            orchid, palegoldenrod, palegreen, paleturquoise,
            palevioletred, papayawhip, peachpuff, peru, pink,
            plum, powderblue, purple, red, rosybrown,
            royalblue, rebeccapurple, saddlebrown, salmon,
            sandybrown, seagreen, seashell, sienna, silver,
            skyblue, slateblue, slategray, slategrey, snow,
            springgreen, steelblue, tan, teal, thistle, tomato,
            turquoise, violet, wheat, white, whitesmoke,
            yellow, yellowgreen