In [1]:
import numpy as np
import tensorflow as tf

from string import punctuation
from nltk.corpus import stopwords
from tensorflow.python import keras
from tensorflow.python.keras import layers
from tensorflow.python.keras.preprocessing.text import Tokenizer
from tensorflow.python.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import confusion_matrix, accuracy_score
from imblearn.over_sampling import SMOTE

In [2]:
def clean_lyric_tokens(tokens):
    #remove punctuation from each token
    table = str.maketrans('', '', punctuation)
    tokens = [w.translate(table) for w in tokens]
    #remove tokens that are not alphabetic
    tokens = [word for word in tokens if word.isalpha()]
    #filter out stop words
    stop_words = set(stopwords.words('english'))
    tokens = [word for word in tokens if not word in stop_words]
    #filter out short tokens
    tokens = [word for word in tokens if len(word) > 2]

    return tokens

def tokens_to_line(tokens, vocab):
    # clean doc
    tokens = clean_lyric_tokens(tokens)
    # filter by vocab
    tokens = [w for w in tokens if w in vocab]
    
    return ' '.join(tokens)

def baseline_model():
    
    model = keras.Sequential([
        layers.Dense(64, input_shape=(4563,), activation=tf.nn.relu),
        layers.Dropout(0.5),
        layers.Dense(64, activation=tf.nn.relu),
        layers.Dropout(0.5),
        layers.Dense(64, activation=tf.nn.relu),
        layers.Dense(5, activation=tf.nn.softmax)
    ])

    model.compile(optimizer='adam', 
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

In [3]:
vocab = ''
with open("vocab.txt") as f:
    vocab = f.read()

vocab = vocab.split()
vocab = set(vocab)

lyrics_dataset = {}
with open("lyrics_dataset.txt") as f:
    lyrics_dataset = dict(x.rstrip().split(None, 1) for x in f)

lines = []
labels = []
for id in lyrics_dataset.keys():
    lyrics_dataset[id] = eval(lyrics_dataset[id])
    line = tokens_to_line(lyrics_dataset[id][0], vocab)
    lines.append(line)
    labels.append(lyrics_dataset[id][1])

labels = np.array(labels)
categories, target_label = np.unique(labels, return_inverse=True)

print(categories)
target_label

['ambiguous' 'angry' 'happy' 'relaxed' 'sad']


array([2, 2, 2, ..., 1, 1, 1])

In [4]:
tokenizer = Tokenizer()
tokenizer.fit_on_texts(lines)

dataset = tokenizer.texts_to_matrix(lines, mode='tfidf')
print(dataset.shape)
dataset

(7197, 4563)


array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 2.53212838, 1.39257974, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 2.53212838, 2.35784246, ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 3.14847665, 0.        , ..., 0.        , 0.        ,
        0.        ]])

In [5]:
sm = SMOTE(sampling_strategy={0:3115, 1:3115, 2:3115, 3:3115, 4:3115}, random_state=7)
dataset_bal, target_label_bal = sm.fit_sample(dataset, target_label)

np.bincount(target_label_bal)

array([3115, 3115, 3115, 3115, 3115])

In [6]:
estimator = KerasClassifier(build_fn=baseline_model, epochs=100, batch_size=128, verbose=0)
kfold = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=7)

for train_index, test_index in kfold.split(dataset_bal, target_label_bal):
    X_tr, X_tes = dataset_bal[train_index], dataset_bal[test_index]
    y_tr, y_tes = target_label_bal[train_index], target_label_bal[test_index]
    estimator.fit(X_tr, y_tr) 

    y_pred=estimator.predict(X_tes)
    acc = accuracy_score(y_tes, y_pred)
    cnf_matrix = confusion_matrix(y_tes, y_pred)
    print(acc)
    print(cnf_matrix)

Instructions for updating:
keep_dims is deprecated, use keepdims instead
0.8401284109149277
[[615   0   1   1   6]
 [  0 588  18   0  17]
 [ 13  23 410   1 176]
 [  1   0   0 622   0]
 [ 28  45 160   8 382]]
[0.         1.20657274 0.         ... 0.         0.         0.        ] 2 4
[0. 0. 0. ... 0. 0. 0.] 2 4
[0.         1.20657274 2.35784246 ... 0.         0.         0.        ] 4 2
[0. 0. 0. ... 0. 0. 0.] 2 4
[0.         1.20657274 0.         ... 0.         0.         0.        ] 2 4
[0.         2.04290523 4.45239016 ... 0.         0.         0.        ] 4 2
[0.         2.53212838 0.         ... 0.         0.         0.        ] 2 4
[0. 0. 0. ... 0. 0. 0.] 2 4
[0. 0. 0. ... 0. 0. 0.] 4 2
[0.         3.36846087 0.         ... 0.         0.         0.        ] 4 2
[0.         0.         4.10241478 ... 0.         0.         0.        ] 2 4
[0.         1.20657274 4.96447624 ... 0.         0.         0.        ] 4 2
[0.         0.         5.92973896 ... 0.         0.         0.        ] 

[0.         1.20657274 2.35784246 ... 0.         0.         0.        ] 2 4
[0.         0.         3.88774767 ... 0.         0.         0.        ] 4 2
[0.         3.36846087 4.10241478 ... 0.         0.         0.        ] 4 3
[0. 0. 0. ... 0. 0. 0.] 0 4
[0. 0. 0. ... 0. 0. 0.] 4 1
[0.         1.20657274 1.39257974 ... 0.         0.         0.        ] 2 4
[0.         1.20657274 1.39257974 ... 0.         0.         0.        ] 4 2
[0. 0. 0. ... 0. 0. 0.] 4 1
[0.         2.04290523 0.         ... 0.         0.         0.        ] 2 1
[0. 0. 0. ... 0. 0. 0.] 1 4
[0. 0. 0. ... 0. 0. 0.] 4 1
[0. 0. 0. ... 0. 0. 0.] 4 2
[0.         3.85768402 0.         ... 0.         0.         0.        ] 2 4
[0. 0. 0. ... 0. 0. 0.] 2 1
[0.         2.04290523 0.         ... 0.         0.         0.        ] 0 4
[0.         2.04290523 0.         ... 0.         0.         0.        ] 2 4
[0.         1.20657274 0.         ... 0.         0.         0.        ] 4 0
[0.         0.         1.39257974 ... 0.    

[0. 0. 0. ... 0. 0. 0.] 2 4
[0.         0.         1.39257974 ... 0.         0.         0.        ] 4 1
[0.         0.         2.92248495 ... 0.         0.         0.        ] 4 2
[0.         1.20657274 1.39257974 ... 0.         0.         0.        ] 1 2
[0.         2.04290523 0.         ... 0.         0.         0.        ] 2 4
[0. 0. 0. ... 0. 0. 0.] 2 0
[0.         1.20657274 0.         ... 0.         0.         0.        ] 1 4
[0. 0. 0. ... 0. 0. 0.] 4 2
[0.         2.87923773 0.         ... 0.         0.         0.        ] 4 2
[0.         1.20657274 0.         ... 0.         0.         0.        ] 4 1
[0.         2.53212838 0.         ... 0.         0.         0.        ] 2 1
[0.         1.20657274 0.         ... 0.         0.         0.        ] 2 4
[0.         0.         1.39257974 ... 0.         0.         0.        ] 4 2
[0. 0. 0. ... 0. 0. 0.] 2 4
[0. 0. 0. ... 0. 0. 0.] 2 4
[0.         1.20657274 0.         ... 0.         0.         0.        ] 2 4
[0. 0. 0. ... 0. 0. 0.] 