In [None]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import wandb
from datasets import load_from_disk
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import roc_auc_score, confusion_matrix, multilabel_confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report, plot_roc_curve
from argparse import Namespace

In [None]:
# Run-wide settings: the number of rows drawn from each split, the seed shared
# by the shuffles and the forest, and the W&B sweep settings.
config = dict(
    train_subset=1_500_000,
    valid_subset=400_000,
    test_subset=200_000,
    seed=42,
    wandb_project_name="rf_param_opt",
    count=20,  # Number of runs for the Sweep
)

# Attribute-style view of the settings (args.seed, args.count, ...).
args = Namespace(**config)

### Define the Sweep config

In [None]:
# Random search over Random Forest hyperparameters. The sweep treats a higher
# logged `roc_auc` as better, so each run must log a metric under that name.
sweep_config = dict(
    method='random',
    metric=dict(name='roc_auc', goal='maximize'),
    parameters=dict(
        criterion=dict(values=['gini', 'entropy']),
        n_estimators=dict(values=[50, 500, 1000]),
        min_samples_split=dict(values=[10, 100, 1000]),
        min_samples_leaf=dict(values=[10, 50, 100]),
    ),
)

### Log in to Weights & Biases

In [None]:
# Authenticate this kernel with Weights & Biases; the sweep and agent cells
# further down need an authenticated session.
wandb.login()

### Read the dataset

In [None]:
# Load the saved dataset from disk. Later cells index it by "train",
# "validation" and "test" splits and rely on its `hidden_state` and `labels`
# columns. NOTE(review): the path suffix suggests the features are XLM-T hidden
# states — confirm against the script that produced this dataset.
violence_hidden = load_from_disk("../../Violence_data/geo_corpus.0.0.1_datasets_hidden_xlmt")

In [None]:
# Inspect the loaded splits and their columns before any are dropped.
violence_hidden

In [None]:
# Remove unnecessary columns: everything except the features and the labels.
# The column list is taken from the train split.
keep_cols = ['hidden_state', 'labels']
remove_columns = [
    column_name
    for column_name in violence_hidden['train'].column_names
    if column_name not in keep_cols
]

In [None]:
# Drop the unneeded columns from every split. Only names still present in the
# data are passed. If this cell is re-run after the columns are already gone,
# it removes nothing instead of raising on the missing columns.
violence_hidden = violence_hidden.remove_columns(
    [col for col in remove_columns if col in violence_hidden['train'].column_names]
)

In [None]:
# Confirm that only the `hidden_state` and `labels` columns remain.
violence_hidden

In [None]:
# Extract a subset of each split: shuffle it with the shared seed, then keep
# the first N rows, where N is that split's configured subset size.
train_clf_ds, validation_clf_ds, test_clf_ds = (
    violence_hidden[split_name].shuffle(args.seed).select(range(n_rows))
    for split_name, n_rows in (
        ("train", args.train_subset),
        ("validation", args.valid_subset),
        ("test", args.test_subset),
    )
)

In [None]:
# Peek at a single training example to check the feature/label layout.
train_clf_ds[0]

### Create a feature matrix

In [None]:
%time X_train = np.array(train_clf_ds["hidden_state"])
y_train = np.array(train_clf_ds["labels"])
X_validation = np.array(validation_clf_ds["hidden_state"])
y_validation = np.array(validation_clf_ds["labels"])
X_test = np.array(test_clf_ds["hidden_state"])
y_test = np.array(test_clf_ds["labels"])
X_train.shape

## Helper Functions

In [None]:
# Create a function to report the various metrics for each classifier
def metricsReport(test_labels, predictions, y_score=None):
    """Compute ROC AUC, precision, recall and F1 for a classifier's predictions.

    Args:
        test_labels: Ground-truth labels in any form accepted by scikit-learn's
            scorers (for this notebook, presumably a 0/1 indicator matrix with
            one column per label — confirm against the dataset).
        predictions: Hard predictions with the same shape as ``test_labels``;
            used for precision, recall and F1.
        y_score: Optional continuous scores (e.g. positive-class probabilities)
            with the same shape. ROC AUC ranks samples by score, so it is only
            informative when computed from scores. When omitted, ROC AUC is
            computed from the hard ``predictions``.

    Returns:
        dict mapping "roc_auc", "weighted_precision", "weighted_recall",
        "weighted_f1", "micro_precision", "micro_recall" and "micro_f1" to the
        metric value formatted as a string with four decimal places.
    """
    auc_input = predictions if y_score is None else y_score
    metrics = {"roc_auc": roc_auc_score(test_labels, auc_input, average="micro")}

    # Weighted averages first, then micro, which fixes the key order of the result.
    for average in ("weighted", "micro"):
        metrics[f"{average}_precision"] = precision_score(test_labels, predictions, average=average)
        metrics[f"{average}_recall"] = recall_score(test_labels, predictions, average=average)
        metrics[f"{average}_f1"] = f1_score(test_labels, predictions, average=average)

    return {name: format(value, ".4f") for name, value in metrics.items()}

In [None]:
def plot_confusion_matrix(y_test, y_pred, clf: str, label_names=None):
    """Plot one binary confusion matrix per label column, laid out in a grid.

    Args:
        y_test: 2-D ground-truth indicator array; column ``i`` is label ``i``.
        y_pred: 2-D predicted indicator array with the same columns as ``y_test``.
        clf: Figure title, e.g. the classifier's name.
        label_names: Display name of each column, in column order. Defaults to
            the six labels used in this notebook.

    Raises:
        ValueError: if ``label_names`` is empty.
    """
    if label_names is None:
        label_names = ['post7geo10', 'post7geo30', 'post7geo50', 'pre7geo10', 'pre7geo30', 'pre7geo50']
    if len(label_names) == 0:
        raise ValueError("label_names must contain at least one label")

    n_cols = 3
    n_rows = -(-len(label_names) // n_cols)  # ceiling division
    f, axes = plt.subplots(n_rows, n_cols, figsize=(25, 15))
    f.suptitle(clf, fontsize=36)
    axes = axes.ravel()
    for i, name in enumerate(label_names):
        # Force a 2x2 matrix even when a column contains a single class, so it
        # always matches the two display labels (negative=0, positive=1).
        cm = confusion_matrix(y_test[:, i], y_pred[:, i], labels=[0, 1])
        disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
        disp.plot(ax=axes[i], values_format='.4g')
        disp.ax_.set_title(name)
        # Keep axis labels only on the outer edge of the grid: x labels on the
        # bottom row, y labels in the first column.
        if i // n_cols < n_rows - 1:
            disp.ax_.set_xlabel('')
        if i % n_cols != 0:
            disp.ax_.set_ylabel('')
        # One shared colorbar is added for the whole figure below.
        disp.im_.colorbar.remove()

    # Hide grid cells left empty when the label count is not a multiple of n_cols.
    for spare_ax in axes[len(label_names):]:
        spare_ax.set_visible(False)

    plt.subplots_adjust(wspace=0.10, hspace=0.1)
    f.colorbar(disp.im_, ax=axes)
    plt.show()

## Random Forest Classifier

In [None]:
def train():
    """Run one W&B sweep trial: fit a Random Forest and log validation metrics.

    The hyperparameters ``criterion``, ``n_estimators``, ``min_samples_split``
    and ``min_samples_leaf`` are read from the trial's ``wandb.config``, which
    the sweep agent fills from ``sweep_config``. The model is scored on the
    validation split, so the test split stays unused during hyperparameter
    selection.
    """
    with wandb.init():
        run_config = wandb.config
        rf_classifier = RandomForestClassifier(
            n_jobs=-1,
            random_state=args.seed,
            criterion=run_config.criterion,  # swept in sweep_config; must reach the model
            n_estimators=run_config.n_estimators,
            min_samples_split=run_config.min_samples_split,
            min_samples_leaf=run_config.min_samples_leaf,
        )
        rf_classifier.fit(X_train, y_train)

        # Select hyperparameters on validation data; scoring on the test split
        # here would leak it into model selection.
        rf_preds = rf_classifier.predict(X_validation)
        # NOTE(review): roc_auc is computed from hard 0/1 predictions here, which
        # makes it a coarse ranking metric; consider using predict_proba scores.
        scores = metricsReport(y_validation, rf_preds)

        # metricsReport returns 4-decimal strings; convert to floats for W&B.
        # Keys (including 'roc_auc', the sweep's target metric) are logged as-is.
        wandb.log({name: float(value) for name, value in scores.items()})

### Initialize the Sweep

In [None]:
# Register the sweep defined in `sweep_config` under the configured W&B project;
# the returned id is handed to the agent in the next cell.
sweep_id = wandb.sweep(sweep=sweep_config, project=args.wandb_project_name)

### Start the Sweep agent

In [None]:
# Run `args.count` sweep trials; each trial calls `train`, which reads its
# hyperparameters from `wandb.config`.
wandb.agent(sweep_id, function=train, count=args.count)