## Random forest

This notebook contains:
- Code for training a random forest (an ensemble of decision trees) to classify texts as either LLM-generated or human-written.
- Code for calculating and visualizing SHAP values (feature importance).

In [3]:
import pandas as pd
from scipy.sparse import load_npz
from scipy.sparse import csr_matrix
from scipy.sparse import hstack
import pickle
from sklearn.metrics import f1_score
import numpy as np
import shap
import pickle
from sklearn.ensemble import RandomForestClassifier



In [2]:
# Each split (train / pretest / test) is stored as a TF-IDF sparse matrix plus
# a CSV of engineered dense features. They are stacked column-wise with the
# sparse columns first and the dense columns after. The combined matrix is
# densified via .toarray(), so memory grows with n_rows * n_columns.
DATA_DIR = '../data'


def _load_split_features(split):
    """Return (sparse, dense, combined) feature objects for one data split."""
    sparse_part = load_npz(f'{DATA_DIR}/{split}_data_x_sparse.npz')
    dense_part = pd.read_csv(f'{DATA_DIR}/{split}_data_x_dense.csv')
    stacked = hstack([sparse_part, csr_matrix(dense_part.values)])
    return sparse_part, dense_part, pd.DataFrame(stacked.toarray())


def _load_split_labels(split):
    """Return one split's label file flattened to a 1-D array."""
    return np.ravel(pd.read_csv(f'{DATA_DIR}/{split}_data_y.csv'))


train_data_x_sparse, train_data_x_dense, train_data_x = _load_split_features('train')
pretest_data_x_sparse, pretest_data_x_dense, pretest_data_x = _load_split_features('pretest')
test_data_x_sparse, test_data_x_dense, test_data_x = _load_split_features('test')

train_data_y = _load_split_labels('train')
pretest_data_y = _load_split_labels('pretest')
test_data_y = _load_split_labels('test')

# Names of the sparse (TF-IDF) columns, saved next to the matrices.
# NOTE: pickle.load can execute arbitrary code; only load this file if it was
# produced by this project's own preprocessing.
with open('../data/sparse_matrices_feature_names.pkl', 'rb') as f:
    sparse_matrices_feature_names = pickle.load(f)

sparse_feature_names = list(sparse_matrices_feature_names)  # TF-IDF features
dense_feature_names = list(train_data_x_dense.columns)  # Engineered features
all_feature_names = sparse_feature_names + dense_feature_names

# Column positions mirror the hstack([sparse, dense]) order used to build
# *_data_x: sparse columns occupy [0, n_sparse) and dense columns follow.
# Deriving indices from position, rather than all_feature_names.index(name),
# stays correct when a TF-IDF token and an engineered feature share a name
# (.index would return the first, i.e. sparse, match). It also avoids an
# O(n_features^2) lookup.
n_sparse_features = len(sparse_feature_names)
sparse_feature_indices = list(range(n_sparse_features))
dense_feature_indices = list(
    range(n_sparse_features, n_sparse_features + len(dense_feature_names))
)

assert len(all_feature_names) == train_data_x.shape[1], (
    "feature-name count does not match the number of stacked columns"
)

In [4]:

# Random forest on the combined TF-IDF + engineered features. A fixed
# random_state makes the bootstrap and feature sampling reproducible.
forest_params = dict(random_state=28, n_estimators=100)
rf_classifier = RandomForestClassifier(**forest_params)
rf_classifier.fit(train_data_x, train_data_y)


In [6]:
# Score the fitted forest on every split. f1_score uses its defaults
# average='binary', pos_label=1 (assumes labels are encoded 0/1; TODO confirm).
# A train score of 1.0 says little about generalization, so compare it against
# the pretest/test scores.
eval_splits = {
    'train': (train_data_x, train_data_y),
    'pretest': (pretest_data_x, pretest_data_y),
    'test': (test_data_x, test_data_y),
}
f1_by_split = {
    split: f1_score(y_true, rf_classifier.predict(x))
    for split, (x, y_true) in eval_splits.items()
}
train_data_f1 = f1_by_split['train']
pretest_data_f1 = f1_by_split['pretest']
test_data_f1 = f1_by_split['test']

# One-column text table. Only the outer borders get '|'; the separator row is
# a plain rule because there is no inner column boundary to mark.
LABEL_WIDTH, VALUE_WIDTH = 20, 34
print(f"{'':<{LABEL_WIDTH}s}|{'Random forest':^{VALUE_WIDTH}s}|")
print(f"{'-' * LABEL_WIDTH}|{'-' * VALUE_WIDTH}|")
for split, score in f1_by_split.items():
    print(f"{'F1 score ' + split:<{LABEL_WIDTH}s}|{score:^{VALUE_WIDTH}.4f}|")

                    |           Random forest          |
--------------------|-----------------+----------------|
F1 score train      |              1.0000              |
F1 score pretest    |              0.9112              |
F1 score test       |              0.9543              |


In [7]:
# Draw a seeded random subset of 100 test rows to explain, then split its
# columns into the TF-IDF (sparse) and engineered (dense) groups.
N_SHAP_SAMPLES = 100
shap_samples = shap.sample(test_data_x, nsamples=N_SHAP_SAMPLES, random_state=42)
sparse_shap_samples, dense_shap_samples = (
    shap_samples.iloc[:, column_positions]
    for column_positions in (sparse_feature_indices, dense_feature_indices)
)

In [None]:
# Explain the forest's predicted probability for rf_classifier.classes_[1].
# TreeExplainer computes exact SHAP values from the fitted trees. It stays
# tractable with thousands of TF-IDF columns, and it accepts the fitted
# estimator directly.
# KernelExplainer is avoided because it needs a callable (e.g. predict_proba),
# not the estimator object. With nsamples='auto' it would also run about
# (2 * n_features + 2048) * n_background model evaluations per explained row.
explainer = shap.TreeExplainer(rf_classifier)
raw_shap_values = explainer.shap_values(shap_samples)

# For classifiers, older shap returns a list with one (n_samples, n_features)
# array per class, while newer shap returns a single
# (n_samples, n_features, n_classes) array. Keep the class-1 slice either way.
if isinstance(raw_shap_values, list):
    shap_values = np.asarray(raw_shap_values[1])
else:
    shap_values = np.asarray(raw_shap_values)[..., 1]

assert shap_values.shape == shap_samples.shape, "expected one SHAP value per (row, feature)"

dense_shap_values = shap_values[:, dense_feature_indices]
sparse_shap_values = shap_values[:, sparse_feature_indices]

In [None]:
# Visualize feature importance with beeswarm plots: all features, then the
# engineered (dense) and TF-IDF (sparse) groups on their own.
def _plot_beeswarm(values, samples, feature_names):
    """Draw a beeswarm plot of the 20 features with the largest SHAP impact.

    values: SHAP values with one row per sample and one column per feature.
    samples: DataFrame of the matching feature values (same rows and columns).
    feature_names: labels for the columns of `values`.
    """
    explanation = shap.Explanation(
        values=values,
        # Plain array so the explicit feature_names (not the integer DataFrame
        # column labels) are what label the plot.
        data=samples.to_numpy(),
        feature_names=feature_names,
    )
    # shap exposes beeswarm under shap.plots; there is no top-level shap.beeswarm.
    shap.plots.beeswarm(explanation, max_display=20, plot_size=(10, 6))


_plot_beeswarm(shap_values, shap_samples, all_feature_names)
_plot_beeswarm(dense_shap_values, dense_shap_samples, dense_feature_names)
_plot_beeswarm(sparse_shap_values, sparse_shap_samples, sparse_feature_names)