# Model Training and Evaluation
In this notebook, we train three machine learning models (Logistic Regression, SVM, and KNN) on the features produced by the vectorisation step (TF-IDF, Word2Vec, and Sentence Transformers embeddings). We evaluate each model's performance using K-fold cross-validation.

## Set Up Dependencies

In [29]:
import pandas as pd
import numpy as np
import pickle
from scipy import sparse
from sklearn.model_selection import cross_validate, KFold
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score

## Load Text Representations and Labels

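Retrieve the balanced DataFrame, which holds the class label, the raw and cleaned tweet text, and the per-representation preprocessed text.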

In [47]:
# Deserialize the DataFrame stored at data/df_balanced.pkl.
# NOTE(review): unpickling can execute arbitrary code, so this file must come
# from a trusted source (presumably an earlier notebook of this project — confirm).
with open('data/df_balanced.pkl', 'rb') as pickle_file:
    df_balanced = pickle.load(pickle_file)

# Show the first rows as this cell's rich output.
df_balanced.head()

Unnamed: 0,class,tweet,cleaned_text,Word2Vec,TF-IDF,SentenceTrans
0,0,I LOVE my 10 &amp; 5 but most days they remind...,I LOVE my 10 &amp; 5 but most days they remind...,"[love, my, amp, but, most, day, they, remind, ...",love my amp but most day they remind me why bi...,i love my amp but most days they remind me w...
1,1,She be thinking she throwing that pussy back s...,She be thinking she throwing that pussy back s...,"[she, be, think, she, throw, that, pussy, back...",she be think she throw that pussy back so good...,she be thinking she throwing that pussy back s...
2,1,RT @lamessican: I love when bitches throw shad...,I love when bitches throw shade. Just confirms...,"[love, when, bitch, throw, shade, just, confir...",love when bitch throw shade just confirm do so...,i love when bitches throw shade just confirms ...
3,1,"If you ain't a hoe, get up out my trap house @...","If you ain't a hoe, get up out my trap house .","[if, you, ain, hoe, get, up, out, my, trap, ho...",if you ain hoe get up out my trap house,if you aint a hoe get up out my trap house
4,0,Just hit 40 in flappy bird.&#128527;,Just hit 40 in flappy bird.&#128527;,"[just, hit, in, flappy, bird]",just hit in flappy bird,just hit in flappy bird


Retrieve the text representations (TF-IDF, Word2Vec, and Sentence Transformers)

In [73]:
# TF-IDF features are read from a SciPy sparse-matrix archive (.npz).
# NOTE(review): this path has no 'representations/' prefix, unlike the two
# .npy files below — confirm that the file really lives in the working directory.
x_tfidf = sparse.load_npz('x_tfidf.npz')

# Word2Vec and Sentence Transformers features are read from NumPy .npy files,
# in that order.
# NOTE(review): allow_pickle=True is only required for object-dtype arrays and
# permits arbitrary code execution on load — confirm whether it is needed here.
x_w2v, x_st = (
    np.load(npy_path, allow_pickle=True)
    for npy_path in ('representations/x_w2v.npy', 'representations/x_st.npy')
)

Retrieve the labels

In [25]:
# Target vector: the 'class' column of the balanced DataFrame, used as `y` for
# every cross-validation run below.
y = df_balanced.loc[:, 'class']

## Initialise K-Fold Cross Validation

In [75]:
# Shared evaluation protocol: 5 shuffled folds with a fixed seed, so every
# model/representation pair is scored on the same splits.
cv = KFold(
    n_splits=5,
    shuffle=True,
    random_state=42,
)

# Metrics to compute per fold. Each key is also the scorer name handed to
# scikit-learn; cross_validate reports the results under 'test_<key>'.
scoring = {metric: metric for metric in ('accuracy', 'precision', 'recall', 'f1')}

In [74]:
# Sanity check: report the shape of each feature matrix, one per line.
print(
    f"TF-IDF shape: {x_tfidf.shape}",
    f"Word2Vec shape: {x_w2v.shape}",
    f"Sentence Transformers shape: {x_st.shape}",
    sep="\n",
)

TF-IDF shape: (8326, 7924)
Word2Vec shape: (8326, 100)
Sentence Transformers shape: (8326, 384)


## Classification

### Logistic Regression

In [77]:
def _cv_and_report(tag, features, estimator):
    """Cross-validate ``estimator`` on ``features`` and print the raw result.

    Relies on the notebook-level ``y``, ``cv`` and ``scoring`` objects.

    Parameters
    ----------
    tag : str
        Prefix printed in front of the result dictionary.
    features
        Feature matrix passed to ``cross_validate`` as ``X``.
    estimator
        Estimator passed to ``cross_validate``.

    Returns
    -------
    dict
        The dictionary returned by ``cross_validate`` (fit/score times plus one
        ``test_<metric>`` array per entry in ``scoring``).
    """
    fold_scores = cross_validate(estimator, features, y, cv=cv, scoring=scoring)
    print(f'{tag}:{fold_scores}')
    return fold_scores


# A single default-configured estimator instance is passed to every run.
log_reg = LogisticRegression()

# Evaluate the classifier on each text representation in turn.
TFIDF_logistic_scores = _cv_and_report('TF-IDF', x_tfidf, log_reg)
W2V_logistic_scores = _cv_and_report('W2V', x_w2v, log_reg)
ST_logistic_scores = _cv_and_report('ST', x_st, log_reg)

TF-IDF:{'fit_time': array([0.02603626, 0.02750111, 0.01592422, 0.00698638, 0.01618481]), 'score_time': array([0.00999951, 0.00598073, 0.00598359, 0.01771045, 0.        ]), 'test_accuracy': array([0.94597839, 0.93093093, 0.92912913, 0.94294294, 0.94654655]), 'test_precision': array([0.97875   , 0.96942675, 0.97880795, 0.96875   , 0.96833773]), 'test_recall': array([0.91471963, 0.89319249, 0.87871581, 0.91288344, 0.91864831]), 'test_f1': array([0.94565217, 0.92974954, 0.92606516, 0.93998737, 0.94283879])}
W2V:{'fit_time': array([0.01637459, 0.03666639, 0.03224373, 0.02982116, 0.01637244]), 'score_time': array([0.        , 0.        , 0.00191164, 0.        , 0.01647949]), 'test_accuracy': array([0.8547419 , 0.83423423, 0.84624625, 0.84324324, 0.83723724]), 'test_precision': array([0.86899038, 0.84698795, 0.86335404, 0.83698297, 0.82432432]), 'test_recall': array([0.84462617, 0.82511737, 0.82639715, 0.84417178, 0.83979975]), 'test_f1': array([0.85663507, 0.83590963, 0.84447145, 0.840562  ,