# Model training 

Goal: Use a pretrained fastai model to classify albums as black metal or death metal based on their names alone.

In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import re
from sklearn.model_selection import train_test_split
import pandas as pd
from sklearn.preprocessing import LabelEncoder

Let's use an ASGD Weight-Dropped LSTM (AWD-LSTM) model, which is a type of Recurrent Neural Network (RNN). Since we don't have much data, instead of training a deep learning model from scratch I am reusing an RNN that was pretrained on a large amount of data, a process known as Transfer Learning.

In [None]:
from fastai.text.all import *
import pandas as pd

# ---------------------------
# 1. Load and filter data
# ---------------------------
# Fixed seed for the random train/validation split, so that re-running the
# notebook validates on the same albums every time.
SPLIT_SEED = 42

df = pd.read_csv("../data/angrymetalguy_reviews_with_scores_clean.csv")

# Keep only Black Metal and Death Metal albums, and drop rows where the
# album name or the genre is missing.
df = df[df['Genre'].isin(['Black Metal', 'Death Metal'])].dropna(subset=['Album', 'Genre'])

# ---------------------------
# 2. Create FastAI DataLoaders
# ---------------------------
# The album name is the only input text; the genre is the label.
# 20% of the rows are held out for validation, chosen with SPLIT_SEED.
dls = TextDataLoaders.from_df(
    df,
    text_col='Album',
    label_col='Genre',
    valid_pct=0.2,
    seed=SPLIT_SEED,
    text_vocab=None,
    is_lm=False,
    bs=32
)

# ---------------------------
# 3. Create the learner
# ---------------------------
# Text classifier built on the AWD_LSTM architecture, reporting accuracy
# on the validation set after each epoch.
learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)

# ---------------------------
# 4. Fine-tune the model
# ---------------------------
# In the logged output, fine_tune runs one preliminary epoch and then
# n_epochs further epochs.
n_epochs = 3
learn.fine_tune(n_epochs)

# ---------------------------
# 5. Extract training history
# ---------------------------
rec = learn.recorder

# Metrics per epoch: each entry of rec.values is read as
# [train_loss, valid_loss, accuracy].
# NOTE(review): rec.values presumably only covers the last fit run by
# fine_tune (not the preliminary epoch) - confirm against the Recorder docs.
train_losses = [v[0] for v in rec.values]
valid_losses = [v[1] for v in rec.values]
accuracies   = [v[2] for v in rec.values]

# Build DataFrame. The epoch numbers are derived from the number of rows that
# were actually recorded, so the columns always have matching lengths even if
# the recorder holds a different number of epochs than n_epochs.
history = pd.DataFrame({
    "epoch": range(1, len(train_losses) + 1),
    "train_loss": train_losses,
    "valid_loss": valid_losses,
    "accuracy": accuracies
})

# Save history (disabled; uncomment to persist the metrics to disk)
# history.to_csv("../data/training_history.csv", index=False)

# ---------------------------
# 6. Save the model
# ---------------------------
# Disabled; uncomment to export the trained learner.
# learn.export("../model/album_genre_classifier.pkl")

Due to IPython and Windows limitation, python multiprocessing isn't available now.
So `n_workers` has to be changed to 0 to avoid getting stuck


epoch,train_loss,valid_loss,accuracy,time
0,0.759605,0.700806,0.520315,00:31


epoch,train_loss,valid_loss,accuracy,time
0,0.698753,0.69348,0.538663,01:02
1,0.687852,0.687643,0.567497,00:54
2,0.679271,0.690677,0.563565,00:52


Now let's find the albums that the model predicts to be "most black" or "most death".

In [4]:
# Look up which position of the probability vector belongs to which genre
# instead of hard-coding 0 and 1. For a fastai text classifier, dls.vocab is
# [token vocab, label vocab], so the label classes are at index 1.
class_names = list(learn.dls.vocab[1])
black_idx = class_names.index("Black Metal")
death_idx = class_names.index("Death Metal")

# NOTE: df is the same frame the DataLoaders were built from, so it contains
# the training rows as well as the validation rows. Probabilities for albums
# the model was trained on are likely to look better than on unseen albums.
all_preds = []

for album, true_genre in zip(df['Album'], df['Genre']):
    # predict returns the decoded class, its index, and the per-class probabilities
    pred_class, pred_idx, probs = learn.predict(album)
    blackness = probs[black_idx].item()
    deathness = probs[death_idx].item()
    all_preds.append({
        "Album": album,
        "True_Genre": true_genre,
        "Predicted_Genre": pred_class,
        "Black_Prob": blackness,
        "Death_Prob": deathness
    })

# Convert to DataFrame for easier analysis (pandas is already imported above)
df_preds = pd.DataFrame(all_preds)

# Top predicted Black Metal albums: the ten highest Black_Prob values
top_black = df_preds.sort_values("Black_Prob", ascending=False).head(10)

# Top predicted Death Metal albums: the ten highest Death_Prob values
top_death = df_preds.sort_values("Death_Prob", ascending=False).head(10)

In [5]:
# The ten rows of df_preds with the highest Black_Prob
top_black

Unnamed: 0,Album,True_Genre,Predicted_Genre,Black_Prob,Death_Prob
1058,Vrees De Toorn Van De Wezens Verscholen Achter Majestueuze Vleugels,Black Metal,Black Metal,0.979812,0.020188
3671,De Doden Hebben Het Goed II,Black Metal,Black Metal,0.927115,0.072885
81,Met De Drietand Op Mijn Huid,Black Metal,Black Metal,0.925078,0.074922
2904,"Chante, Ô Flamme de la Liberté",Black Metal,Black Metal,0.923776,0.076224
1409,De Sève et de Sang,Black Metal,Black Metal,0.921404,0.078596
2035,La Caceria De Brujas,Black Metal,Black Metal,0.904933,0.095067
1748,La Caída De Tonatiuh,Death Metal,Black Metal,0.904933,0.095067
1141,La Era de la Bestia,Black Metal,Black Metal,0.902395,0.097605
1405,Le Triomphe du Charnier,Black Metal,Black Metal,0.902124,0.097876
1548,De verminkte stilte van het zijn,Black Metal,Black Metal,0.900578,0.099422


In [6]:
# The ten rows of df_preds with the highest Death_Prob
top_death

Unnamed: 0,Album,True_Genre,Predicted_Genre,Black_Prob,Death_Prob
3557,Those Who Have Fallen Beyond the Grace of God,Death Metal,Death Metal,0.068042,0.931958
2137,Those Who Reign Below,Death Metal,Death Metal,0.071965,0.928035
678,"And as We Have Seen the Storm, We Have Embraced the Eye",Death Metal,Death Metal,0.096148,0.903852
3205,The World That Was,Death Metal,Death Metal,0.099933,0.900067
3480,Fear Those Who Fear Him,Death Metal,Death Metal,0.100954,0.899046
1343,Tales of Grotesque Demise,Death Metal,Death Metal,0.104794,0.895206
376,Diorama of Human Suffering,Death Metal,Death Metal,0.108426,0.891574
3674,Echoes Review and Album Premiere,Death Metal,Death Metal,0.108576,0.891424
1955,"Violence, Our Power",Black Metal,Death Metal,0.116238,0.883762
615,Where the Corpses Sink Forever,Black Metal,Death Metal,0.118146,0.881854


Now, let's look at the most confidently wrong predictions.

In [7]:
# A row counts as wrong when the predicted genre is one of the two classes
# and the true genre is the other one.
predicted_genre = df_preds["Predicted_Genre"]
true_genre_col = df_preds["True_Genre"]

death_called_black = (predicted_genre == "Black Metal") & (true_genre_col == "Death Metal")
black_called_death = (predicted_genre == "Death Metal") & (true_genre_col == "Black Metal")

wrong_preds = df_preds[death_called_black | black_called_death]

# Misclassified rows ranked by how confident the model was in "Black Metal"
most_wrong_death = wrong_preds.sort_values(by="Black_Prob", ascending=False).head(10)

# Misclassified rows ranked by how confident the model was in "Death Metal"
most_wrong_black = wrong_preds.sort_values(by="Death_Prob", ascending=False).head(10)

In [8]:
# Up to ten misclassified rows with the highest Black_Prob
most_wrong_death

Unnamed: 0,Album,True_Genre,Predicted_Genre,Black_Prob,Death_Prob
1748,La Caída De Tonatiuh,Death Metal,Black Metal,0.904933,0.095067
1163,III-Hear Me O’ Death (Sing Thou Wretched Choirs),Death Metal,Black Metal,0.846486,0.153514
738,Au Bord du Précipice,Death Metal,Black Metal,0.846177,0.153823
304,L’Être et la Nausée,Death Metal,Black Metal,0.840861,0.159139
903,Beyond the Red Light District: A Canal Experiment,Death Metal,Black Metal,0.840619,0.159381
3538,L’abime dévore les âmes,Death Metal,Black Metal,0.78852,0.21148
689,Le Dernier Crépuscule,Death Metal,Black Metal,0.782667,0.217333
303,Le Déclin,Death Metal,Black Metal,0.765356,0.234644
2785,Kopár hant…az alvilág felé,Death Metal,Black Metal,0.764419,0.235581
2157,Nach uns die Grindflut,Death Metal,Black Metal,0.762469,0.237531


In [9]:
# Up to ten misclassified rows with the highest Death_Prob
most_wrong_black

Unnamed: 0,Album,True_Genre,Predicted_Genre,Black_Prob,Death_Prob
1955,"Violence, Our Power",Black Metal,Death Metal,0.116238,0.883762
615,Where the Corpses Sink Forever,Black Metal,Death Metal,0.118146,0.881854
2379,Flesh Torn – Spirit Pierced,Black Metal,Death Metal,0.141143,0.858857
1266,To Those Who Fell,Black Metal,Death Metal,0.147798,0.852202
2241,As We Were When We Were Not,Black Metal,Death Metal,0.151942,0.848058
844,Where Shadows Forever Reign,Black Metal,Death Metal,0.161329,0.838671
3632,Acts of Repentance,Black Metal,Death Metal,0.173525,0.826475
1844,Wounds of Desolation,Black Metal,Death Metal,0.176697,0.823303
3679,As Life Drifts Away,Black Metal,Death Metal,0.177094,0.822906
2011,Visions of Collapse,Black Metal,Death Metal,0.197172,0.802827


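Our model seems to classify non-English album names as black metal. Note that the validation accuracy is only about 56% (see the training log above), so its predictions should be interpreted with caution.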