In [None]:
import os
import shutil
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from sklearn import preprocessing, metrics
tf.get_logger().setLevel('ERROR')

In [None]:
# Name of the dataset variant to use. It selects the CSV files that are loaded
# (data/train/<name>.csv, data/test/<name>.csv) and the sub-directory under
# which models and predictions are written.
dataset_name = 'small_balanced'

In [None]:
# TF Hub handles for the two BERT variants used in this notebook.
# Full-size encoder: bert_en_uncased_L-12_H-768_A-12
bert_enc = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3'
bert_prepr = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
# Small encoder: small_bert/bert_en_uncased_L-2_H-256_A-4
# (paired with the same preprocessing handle as the full-size encoder)
bert_enc_sm = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1'
bert_prepr_sm = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'

In [None]:
def create_bert_model(bert_enc, bert_prepr, classes_count, learning_rate=3e-5):
    """Build and compile a BERT-based text classifier.

    The model accepts a batch of raw strings, runs them through the TF Hub
    preprocessing layer, feeds the result to a trainable BERT encoder and
    classifies the encoder's ``pooled_output`` with a softmax layer.

    Args:
        bert_enc: TF Hub handle of the BERT encoder.
        bert_prepr: TF Hub handle of the preprocessing model for that encoder.
        classes_count: Number of target classes (units of the output layer).
        learning_rate: Learning rate passed to the Adam optimizer. The encoder
            is fine-tuned (``trainable=True``), so the default is a small
            value; Keras' stock Adam rate of 1e-3 is much larger than what is
            normally used for BERT fine-tuning and can destroy the pretrained
            weights.

    Returns:
        A compiled ``tf.keras.Model`` that maps strings to class probabilities.
    """
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)
    preprocessing_layer = hub.KerasLayer(bert_prepr, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)
    # trainable=True: the encoder weights are updated together with the head.
    encoder = hub.KerasLayer(bert_enc, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(0.1)(net)
    net = tf.keras.layers.Dense(classes_count, activation='softmax', name='classifier')(net)
    model = tf.keras.Model(text_input, net)

    # Labels are integer class ids, hence the sparse categorical loss.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

In [None]:
def tranform_X(data_x):
    """Convert a sequence of texts into an (n, 1) NumPy array of strings.

    Args:
        data_x: Array-like of values (e.g. a list or pandas Series).

    Returns:
        A 2-D NumPy string array with a single column, one row per element.
    """
    as_strings = np.array(data_x).astype('str')
    column = as_strings.reshape(-1, 1)
    return column

In [None]:
def train_bert(data_x, data_y, test_x, test_y, model, label_encoder, epochs=1, model_dir='models', filename="bert_tuned"):
    """Fit ``model`` on the training data and save it to disk.

    Labels are encoded with ``label_encoder`` and texts are converted with
    ``tranform_X``. The (test_x, test_y) pair is passed to ``model.fit`` as
    validation data. The fitted model is saved under
    ``<model_dir>/<dataset_name>/model_<filename>``, where ``dataset_name`` is
    the notebook-level global.

    Args:
        data_x: Training texts.
        data_y: Training labels (values known to ``label_encoder``).
        test_x: Validation texts.
        test_y: Validation labels.
        model: Compiled Keras model to train.
        label_encoder: Fitted encoder exposing ``transform``.
        epochs: Number of training epochs.
        model_dir: Root directory for saved models.
        filename: Suffix used in the saved model's name.

    Returns:
        The history object returned by ``model.fit``.
    """
    print('Training...')
    save_path = os.path.join(model_dir, dataset_name, f'model_{filename}')

    # Encode labels as integer ids in a single column; texts as a string column.
    train_labels = label_encoder.transform(data_y).reshape(-1, 1)
    train_texts = tranform_X(data_x)
    val_labels = label_encoder.transform(test_y).reshape(-1, 1)
    val_texts = tranform_X(test_x)

    validation_pair = (val_texts, val_labels)
    history = model.fit(
        x=train_texts,
        y=train_labels,
        epochs=epochs,
        validation_data=validation_pair,
    )

    model.save(save_path)
    print('Success!')
    return history

In [None]:
def test_bert(data_x, model, label_encoder, pred_dir='predictions', filename="bert_tuned"):
    """Predict labels for ``data_x`` and write them to a CSV file.

    Predictions are written, one label per row with no header or index, to
    ``<pred_dir>/<dataset_name>/model_<filename>.csv`` (``dataset_name`` is the
    notebook-level global). An existing file at that path is overwritten.

    Args:
        data_x: Texts to classify.
        model: Trained Keras model whose ``predict`` returns per-class scores.
        label_encoder: Fitted encoder exposing ``inverse_transform``.
        pred_dir: Root directory for prediction files.
        filename: Suffix used in the output file's name.

    Returns:
        Array of predicted labels decoded back to their original values.
    """
    print('Testing...')
    out_dir = os.path.join(pred_dir, dataset_name)
    # Create the target directory here so the function does not depend on a
    # separate setup cell having been run first.
    os.makedirs(out_dir, exist_ok=True)
    fname = os.path.join(out_dir, f'model_{filename}.csv')

    X = tranform_X(data_x)
    # argmax over the class axis already yields a 1-D array of class ids.
    predictions_enc = np.argmax(model.predict(X), axis=1)
    predictions = label_encoder.inverse_transform(predictions_enc)

    # mode='w' replaces any previous file, so no explicit removal is needed.
    pd.DataFrame(predictions.reshape(-1, 1)).to_csv(fname, mode='w', index=False, header=False)
    print('Success!')
    return predictions

In [None]:
def get_results(y_true, y_pred):
    """Print accuracy, balanced accuracy and weighted F1 score.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
    """
    print('RESULTS:')
    # (label, scoring function, extra keyword arguments), in print order.
    report = (
        ('accuracy', metrics.accuracy_score, {}),
        ('balanced accuracy', metrics.balanced_accuracy_score, {}),
        ('f1 score', metrics.f1_score, {'average': "weighted"}),
    )
    for label, scorer, extra in report:
        score = scorer(y_true=y_true, y_pred=y_pred, **extra)
        print(f'{label} = {score}')

In [None]:
# Make sure the per-dataset output directories exist.
# exist_ok=True makes this cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
model_dir = os.path.join('models', dataset_name)
os.makedirs(model_dir, exist_ok=True)

pred_dir = os.path.join('predictions', dataset_name)
os.makedirs(pred_dir, exist_ok=True)

In [None]:
# Load the train and test splits of the selected dataset. Later cells read the
# `lyrics` (text) and `genre` (label) columns from these frames.
train_data = pd.read_csv(f'data/train/{dataset_name}.csv')
test_data = pd.read_csv(f'data/test/{dataset_name}.csv')

In [None]:
# Fit the label encoder on the distinct genres present in the training split.
genres = np.unique(train_data.genre)
label_encoder = preprocessing.LabelEncoder().fit(genres)
# Last expression of the cell: display the learned class labels.
label_encoder.classes_

In [None]:
# Classifier built on the small BERT encoder, with one output unit per genre.
bert_model_sm = create_bert_model(bert_enc_sm, bert_prepr_sm, len(label_encoder.classes_))

In [None]:
# Classifier built on the full-size BERT encoder, with one output unit per genre.
bert_model = create_bert_model(bert_enc, bert_prepr, len(label_encoder.classes_))

In [None]:
# Fine-tune the small model for one epoch and save it with the
# "small_bert_tuned" suffix.
# NOTE(review): the test split is passed as validation data, so the validation
# metrics shown during training come from the same data used for the final
# evaluation below.
hist_sm = train_bert(train_data.lyrics, train_data.genre, 
                     test_data.lyrics, test_data.genre,
                     bert_model_sm, label_encoder, epochs=1, filename="small_bert_tuned")

In [None]:
# Predict genres for the test lyrics with the small model, write them to CSV,
# and print the evaluation metrics against the true genres.
preds = test_bert(test_data.lyrics, bert_model_sm, label_encoder, filename="small_bert_tuned")
get_results(test_data.genre, preds)

In [None]:
# Fine-tune the full-size model for one epoch; it is saved under the default
# "bert_tuned" suffix.
# NOTE(review): as with the small model, the test split doubles as the
# validation data during training.
hist = train_bert(train_data.lyrics, train_data.genre, 
                  test_data.lyrics, test_data.genre,
                  bert_model, label_encoder, epochs=1)

In [None]:
# Predict genres for the test lyrics with the full-size model (default
# "bert_tuned" output file) and print the metrics. This rebinds `preds`,
# replacing the small-model predictions.
preds = test_bert(test_data.lyrics, bert_model, label_encoder)
get_results(test_data.genre, preds)