In [None]:
from tensorflow.keras.models import load_model
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score
import numpy as np

In [None]:
# Paths and column names used by the scoring step.
MODEL_PATH = 'community_classification_model.h5'
UNLABELLED_POSTS_PATH = './../data/compiled-posts/validation_data_unlabelled_binary.csv'
FEATURE_COLUMNS = ['Toxicity', 'Rationality', 'Mutual Respect', 'Emotion', 'Moderator', 'Diversity']
LABEL_COLUMNS = ['Combative', 'Deliberative']

# Load the trained Keras model
model = load_model(MODEL_PATH)

# Load the untagged community posts and extract the model's input features.
# NOTE(review): column order presumably must match the order used at training time — confirm.
new_data = pd.read_csv(UNLABELLED_POSTS_PATH)
X_new = new_data[FEATURE_COLUMNS].values

# Make predictions with the model; expect one score column per label.
predictions = model.predict(X_new)
expected_shape = (len(new_data), len(LABEL_COLUMNS))
assert predictions.shape == expected_shape, (
    f"expected predictions of shape {expected_shape}, got {predictions.shape}"
)

# Attach each score column individually: single-column assignment creates the
# column unambiguously and simply overwrites it if this cell is re-run.
for col_idx, col_name in enumerate(LABEL_COLUMNS):
    new_data[col_name] = predictions[:, col_idx]

# Validation metrics for the model

In [None]:
# Load the community posts that were used to predict but with the actual labels.
# NOTE(review): rows are assumed to be in the same order as
# validation_data_unlabelled_binary.csv — there is no join key here; confirm.
actual_labels = pd.read_csv('./../data/compiled-posts/validation_posts_binary.csv')
actual_labels = actual_labels[['Combative', 'Deliberative']].values
assert actual_labels.shape == predictions.shape, (
    f"label shape {actual_labels.shape} does not match prediction shape {predictions.shape}"
)

# Turn the score matrix into a one-hot prediction: the highest-scoring column
# in each row gets 1, every other column 0 (np.argmax picks the first column on ties).
binary_predictions = np.eye(predictions.shape[1], dtype=int)[predictions.argmax(axis=1)]



In [None]:
# Sanity check: every (Combative, Deliberative) row should contain exactly one 1.
# Print any row that is [0 0] or [1 1] (row sum != 1) in either the ground-truth
# labels or the binarized predictions; such rows make the argmax-based
# confusion matrix below misleading.
for array_name, label_array in (('actual_labels', actual_labels),
                                ('binary_predictions', binary_predictions)):
    for bad_row in label_array[label_array.sum(axis=1) != 1]:
        print(array_name, bad_row)

In [None]:
# Display the (rows, columns) shape of the one-hot prediction matrix.
np.shape(binary_predictions)

In [None]:
# Compare the one-hot predictions with the ground-truth labels.
# Both are 2-D indicator arrays, so accuracy_score reports exact-row-match accuracy.
accuracy = accuracy_score(y_true=actual_labels, y_pred=binary_predictions)
# Precision / recall / F1 are support-weighted averages over the two label columns.
precision = precision_score(y_true=actual_labels, y_pred=binary_predictions, average='weighted')
recall = recall_score(y_true=actual_labels, y_pred=binary_predictions, average='weighted')
f1 = f1_score(y_true=actual_labels, y_pred=binary_predictions, average='weighted')
# ROC AUC is computed from the raw model scores, not the thresholded output.
roc_auc = roc_auc_score(y_true=actual_labels, y_score=predictions)
# Confusion matrix over column indices (0 = Combative, 1 = Deliberative).
conf_matrix = confusion_matrix(
    y_true=actual_labels.argmax(axis=1),
    y_pred=binary_predictions.argmax(axis=1),
)


In [None]:


# Print each metric as "<label> <value>", in a fixed order.
for metric_label, metric_value in (
    ("Accuracy:", accuracy),
    ("Precision:", precision),
    ("Recall:", recall),
    ("F1 Score:", f1),
    ("Confusion Matrix:\n", conf_matrix),
    ("ROC AUC:", roc_auc),
):
    print(metric_label, metric_value)