In [1]:
import os
import numpy as np
import pandas as pd
import absl.logging
from nlp_embedding import GloVe, SmallBert, Bert, Word2vec, LargeBert
from nlp_classifier import NaiveBayes, SVM, XGBoost, CNN, BinaryCNN
from sklearn import preprocessing, metrics
from ast import literal_eval
absl.logging.set_verbosity(absl.logging.ERROR)

In [2]:
max_words = 400
dataset_name = 'small'
cnn_epochs = 5
indiv_genre = 'Rock'
optimizer = 'adam'

In [3]:
import tensorflow as tf
from tensorflow.keras import layers, models
tf.get_logger().setLevel('ERROR')

class CNN2Step():
    def __init__(self, vec_len, class_count, optimizer, indiv_class):
        self.name = '2_step_cnn'
        self.model1 = BinaryCNN(vec_len, optimizer)
        self.model2 = CNN(vec_len, class_count - 1, optimizer)
        self.indiv_class = indiv_class

    def partial_fit(self, X, Y):
        Y_binary = np.array(Y == self.indiv_class).astype(int)
        self.model1.partial_fit(X.reshape(*X.shape, 1), Y_binary.reshape(-1, 1))
        
        X_other = X[Y != self.indiv_class]
        Y_other = Y[Y != self.indiv_class]
        self.model2.partial_fit(X_other.reshape(*X_other.shape, 1), Y_other.reshape(-1, 1))

    def predict(self, X):
        pred1 = self.model1.predict(X.reshape(*X.shape, 1))
        X2 = X[pred1 == 0]
        pred2 = self.model2.predict(X2.reshape(*X2.shape, 1))
        
        pred = np.zeros(len(X), dtype=int)
        pred[pred1 == 1] = self.indiv_class
        pred[pred1 == 0] = pred2

        print(np.unique(pred))
        print(np.unique(pred1))
        print(np.unique(pred2))
        
        return pred.flatten()

    def predict_proba(self, X):
        # we are not using it anyway for now
        pass
    
    def save(self, filename):
        self.model1.save(f'{filename}1')
        self.model2.save(f'{filename}2')

    def load(self, filename):
        self.model1 = models.load_model(f'{filename}1')
        self.model2 = models.load_model(f'{filename}2')

In [4]:
def train(data_x, data_y, nlp_embedding, nlp_classifier, label_encoder, batch_size, dataset_name, epochs=1, model_dir='models', start_idx=0, fname_end=''):
    print('Training...')
    fname = os.path.join(model_dir, dataset_name, f'model_{nlp_embedding.name}_{nlp_classifier.name}{fname_end}')
    data_y_enc = label_encoder.transform(data_y)
    
    for epoch in range(epochs):
        print(f'Epoch: {str(epoch + 1)}/{str(epochs)}')
        for i in range(start_idx, data_x.shape[0], batch_size):
            
            if i + batch_size > data_x.shape[0]:
                j = data_x.shape[0]
            else:
                j = i + batch_size
            
            print(f'Processing rows: {i} - {j - 1}')

            embeddings = nlp_embedding.embed_lyrics(data_x[i:j])
            nlp_classifier.partial_fit(embeddings, data_y_enc[i:j])
            nlp_classifier.save(fname)
        start_idx = 0
    
    print('Success!')

In [5]:
def test(data_x, nlp_embedding, nlp_classifier, label_encoder, batch_size, dataset_name, pred_dir='predictions', start_idx=0, fname_end=''):
    print('Testing...')
    fname = os.path.join(pred_dir, dataset_name, f'model_{nlp_embedding.name}_{nlp_classifier.name}{fname_end}.csv')
    predictions_all = []

    if start_idx == 0 and os.path.exists(fname):
        os.remove(fname)
    
    for i in range(start_idx, data_x.shape[0], batch_size):

        if i + batch_size > data_x.shape[0]:
            j = data_x.shape[0]
        else:
            j = i + batch_size
        
        print(f'Processing rows: {i} - {j - 1}')

        embeddings = nlp_embedding.embed_lyrics(data_x[i:j])
        predictions_enc = nlp_classifier.predict(embeddings)
        predictions = label_encoder.inverse_transform(predictions_enc)
        
        predictions_all.extend(predictions)

        pd.DataFrame(predictions.reshape(-1, 1)).to_csv(fname, mode='a', index=False, header=False)
    
    print('Success!')    
    
    return predictions_all

In [6]:
def get_results(y_true, y_pred):
    print('RESULTS:')
    print(f'accuracy = {metrics.accuracy_score(y_true=y_true, y_pred=y_pred)}')
    print(f'balanced accuracy = {metrics.balanced_accuracy_score(y_true=y_true, y_pred=y_pred)}')
    print(f'f1 score = {metrics.f1_score(y_true=y_true, y_pred=y_pred, average="weighted")}')

In [7]:
def train_and_save_results(emb, clf, x_train, y_train, x_test, y_test, dataset_name, le, batch_size=5000, epochs=1, fname_end=''):
    train(x_train, y_train, emb, clf, le, batch_size, dataset_name, epochs=epochs, fname_end=fname_end)
    y_pred = test(x_test, emb, clf, le, batch_size, dataset_name, fname_end=fname_end)
    get_results(y_test, y_pred)
    return

In [8]:
def add_normalized_lyrics(data):
    tokens = data.tokens.apply(literal_eval)
    data['normalized_lyrics'] = [' '.join(t) for t in tokens]

In [9]:
model_dir = os.path.join('models', dataset_name)
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

pred_dir = os.path.join('predictions', dataset_name)
if not os.path.exists(pred_dir):
    os.makedirs(pred_dir)

In [10]:
train_data = pd.read_csv(f'data/train/{dataset_name}.csv')
test_data = pd.read_csv(f'data/test/{dataset_name}.csv')

In [11]:
train_data.genre.value_counts()

Rock       87661
Pop        33527
Metal      24294
Hip-Hop    21746
Country    16257
Name: genre, dtype: int64

In [12]:
test_data.genre.value_counts()

Rock       37569
Pop        14369
Metal      10412
Hip-Hop     9320
Country     6967
Name: genre, dtype: int64

In [13]:
add_normalized_lyrics(train_data)
add_normalized_lyrics(test_data)

In [14]:
genres = np.unique(train_data.genre)
label_encoder = preprocessing.LabelEncoder()
label_encoder.fit(genres)
label_encoder.classes_

array(['Country', 'Hip-Hop', 'Metal', 'Pop', 'Rock'], dtype=object)

In [15]:
indiv_genre_label = label_encoder.transform([indiv_genre])[0]

## Smaller BERT

In [None]:
emb_sm_bert = SmallBert(max_words)

In [None]:
clf_cnn = CNN(max_words * emb_sm_bert.embedding_size, len(genres), optimizer)
train_and_save_results(emb_sm_bert, clf_cnn_norm,
                       train_data.lyrics, train_data.genre, test_data.normalized_lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, fname_end='_norm')

In [None]:
clf_cnn_norm = CNN(max_words * emb_sm_bert.embedding_size, len(genres), optimizer)
train_and_save_results(emb_sm_bert, clf_cnn_norm,
                       train_data.normalized_lyrics, train_data.genre, test_data.normalized_lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, fname_end='_norm')

## BERT

In [None]:
emb_bert = Bert(max_words)

In [None]:
clf_cnn_b = CNN(max_words * emb_bert.embedding_size, len(genres), optimizer)
train_and_save_results(emb_bert, clf_cnn_b,
                       train_data.lyrics, train_data.genre, test_data.lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, batch_size=1000)

In [None]:
clf_cnn_b_norm = CNN(max_words * emb_bert.embedding_size, len(genres), optimizer)
train_and_save_results(emb_bert, clf_cnn_b_norm,
                       train_data.normalized_lyrics, train_data.genre, test_data.normalized_lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, batch_size=1000, fname_end='_norm')

## Large BERT

In [None]:
emb_lr_bert = LargeBert(max_words)

In [None]:
clf_cnn_bl = CNN(max_words * emb_lr_bert.embedding_size, len(genres), optimizer)
train_and_save_results(emb_lr_bert, clf_cnn_bl,
                       train_data.lyrics, train_data.genre, test_data.lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, batch_size=1000)

In [None]:
clf_cnn_bl_norm = CNN(max_words * emb_lr_bert.embedding_size, len(genres), optimizer)
train_and_save_results(emb_lr_bert, clf_cnn_bl_norm,
                       train_data.normalized_lyrics, train_data.genre, test_data.normalized_lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, batch_size=1000, fname_end='_norm')

## GloVe

In [16]:
emb_glove = GloVe(max_words)

glove_100d download started this may take some time.
Approximate size to download 145.3 MB
[OK!]


In [None]:
clf_cnn_g = CNN(max_words * emb_glove.embedding_size, len(genres), optimizer)
train_and_save_results(emb_glove, clf_cnn_g, 
                       train_data.lyrics, train_data.genre, test_data.lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs)

In [None]:
clf_cnn_g_norm = CNN(max_words * emb_glove.embedding_size, len(genres), optimizer)
train_and_save_results(emb_glove, clf_cnn_g_norm, 
                       train_data.normalized_lyrics, train_data.genre, test_data.normalized_lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, fname_end='_norm')

### 2-step CNN

In [17]:
clf_cnn2_g_norm = CNN2Step(max_words * emb_glove.embedding_size, len(genres), optimizer, indiv_genre_label)
train_and_save_results(emb_glove, clf_cnn2_g_norm, 
                       train_data.normalized_lyrics, train_data.genre, test_data.normalized_lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, fname_end='_norm')

Training...
Epoch: 1/5
Processing rows: 0 - 4999
Processing rows: 5000 - 9999
Processing rows: 10000 - 14999
Processing rows: 15000 - 19999
Processing rows: 20000 - 24999
Processing rows: 25000 - 29999
Processing rows: 30000 - 34999
Processing rows: 35000 - 39999
Processing rows: 40000 - 44999
Processing rows: 45000 - 49999
Processing rows: 50000 - 54999
Processing rows: 55000 - 59999
Processing rows: 60000 - 64999
Processing rows: 65000 - 69999
Processing rows: 70000 - 74999
Processing rows: 75000 - 79999
Processing rows: 80000 - 84999
Processing rows: 85000 - 89999
Processing rows: 90000 - 94999
Processing rows: 95000 - 99999
Processing rows: 100000 - 104999
Processing rows: 105000 - 109999
Processing rows: 110000 - 114999
Processing rows: 115000 - 119999
Processing rows: 120000 - 124999
Processing rows: 125000 - 129999
Processing rows: 130000 - 134999
Processing rows: 135000 - 139999
Processing rows: 140000 - 144999
Processing rows: 145000 - 149999
Processing rows: 150000 - 154999
P

Processing rows: 10000 - 14999
Processing rows: 15000 - 19999
Processing rows: 20000 - 24999
Processing rows: 25000 - 29999
Processing rows: 30000 - 34999
Processing rows: 35000 - 39999
Processing rows: 40000 - 44999
Processing rows: 45000 - 49999
Processing rows: 50000 - 54999
Processing rows: 55000 - 59999
Processing rows: 60000 - 64999
Processing rows: 65000 - 69999
Processing rows: 70000 - 74999
Processing rows: 75000 - 79999
Processing rows: 80000 - 84999
Processing rows: 85000 - 89999
Processing rows: 90000 - 94999
Processing rows: 95000 - 99999
Processing rows: 100000 - 104999
Processing rows: 105000 - 109999
Processing rows: 110000 - 114999
Processing rows: 115000 - 119999
Processing rows: 120000 - 124999
Processing rows: 125000 - 129999
Processing rows: 130000 - 134999
Processing rows: 135000 - 139999
Processing rows: 140000 - 144999
Processing rows: 145000 - 149999
Processing rows: 150000 - 154999
Processing rows: 155000 - 159999
Processing rows: 160000 - 164999
Processing ro

Processing rows: 15000 - 19999
Processing rows: 20000 - 24999
Processing rows: 25000 - 29999
Processing rows: 30000 - 34999
Processing rows: 35000 - 39999
Processing rows: 40000 - 44999
Processing rows: 45000 - 49999
Processing rows: 50000 - 54999
Processing rows: 55000 - 59999
Processing rows: 60000 - 64999
Processing rows: 65000 - 69999
Processing rows: 70000 - 74999
Processing rows: 75000 - 79999
Processing rows: 80000 - 84999
Processing rows: 85000 - 89999
Processing rows: 90000 - 94999
Processing rows: 95000 - 99999
Processing rows: 100000 - 104999
Processing rows: 105000 - 109999
Processing rows: 110000 - 114999
Processing rows: 115000 - 119999
Processing rows: 120000 - 124999
Processing rows: 125000 - 129999
Processing rows: 130000 - 134999
Processing rows: 135000 - 139999
Processing rows: 140000 - 144999
Processing rows: 145000 - 149999
Processing rows: 150000 - 154999
Processing rows: 155000 - 159999
Processing rows: 160000 - 164999
Processing rows: 165000 - 169999
Processing 

Processing rows: 20000 - 24999
Processing rows: 25000 - 29999
Processing rows: 30000 - 34999
Processing rows: 35000 - 39999
Processing rows: 40000 - 44999
Processing rows: 45000 - 49999
Processing rows: 50000 - 54999
Processing rows: 55000 - 59999
Processing rows: 60000 - 64999
Processing rows: 65000 - 69999
Processing rows: 70000 - 74999
Processing rows: 75000 - 79999
Processing rows: 80000 - 84999
Processing rows: 85000 - 89999
Processing rows: 90000 - 94999
Processing rows: 95000 - 99999
Processing rows: 100000 - 104999
Processing rows: 105000 - 109999
Processing rows: 110000 - 114999
Processing rows: 115000 - 119999
Processing rows: 120000 - 124999
Processing rows: 125000 - 129999
Processing rows: 130000 - 134999
Processing rows: 135000 - 139999
Processing rows: 140000 - 144999
Processing rows: 145000 - 149999
Processing rows: 150000 - 154999
Processing rows: 155000 - 159999
Processing rows: 160000 - 164999
Processing rows: 165000 - 169999
Processing rows: 170000 - 174999
Processin

Processing rows: 25000 - 29999
Processing rows: 30000 - 34999
Processing rows: 35000 - 39999
Processing rows: 40000 - 44999
Processing rows: 45000 - 49999
Processing rows: 50000 - 54999
Processing rows: 55000 - 59999
Processing rows: 60000 - 64999
Processing rows: 65000 - 69999
Processing rows: 70000 - 74999
Processing rows: 75000 - 79999
Processing rows: 80000 - 84999
Processing rows: 85000 - 89999
Processing rows: 90000 - 94999
Processing rows: 95000 - 99999
Processing rows: 100000 - 104999
Processing rows: 105000 - 109999
Processing rows: 110000 - 114999
Processing rows: 115000 - 119999
Processing rows: 120000 - 124999
Processing rows: 125000 - 129999
Processing rows: 130000 - 134999
Processing rows: 135000 - 139999
Processing rows: 140000 - 144999
Processing rows: 145000 - 149999
Processing rows: 150000 - 154999
Processing rows: 155000 - 159999
Processing rows: 160000 - 164999
Processing rows: 165000 - 169999
Processing rows: 170000 - 174999
Processing rows: 175000 - 179999
Process

[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 35000 - 39999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 40000 - 44999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 45000 - 49999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 50000 - 54999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 55000 - 59999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 60000 - 64999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 65000 - 69999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 70000 - 74999
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Processing rows: 75000 - 78636
[0 1 2 3 4]
[0 1]
[0 1 2 3]
Success!
RESULTS:
accuracy = 0.5329552246398006
balanced accuracy = 0.5089186429107216
f1 score = 0.5391173053706446


## Word2vec

In [None]:
emb_wv = Word2vec(max_words)

In [None]:
clf_cnn_w_nn = CNN(max_words * emb_wv.embedding_size, len(genres), 'adam')
train_and_save_results(emb_wv, clf_cnn_w_nn, 
                       train_data.lyrics, train_data.genre, test_data.lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs)

In [None]:
clf_cnn_w = CNN(max_words * emb_wv.embedding_size, len(genres), 'adam')
train_and_save_results(emb_wv, clf_cnn_w, 
                       train_data.normalized_lyrics, train_data.genre, test_data.normalized_lyrics, test_data.genre, 
                       dataset_name, label_encoder, epochs=cnn_epochs, fname_end='_norm')