In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import wandb
from datasets import load_from_disk
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import roc_auc_score, confusion_matrix, multilabel_confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report, plot_roc_curve
from argparse import Namespace

In [2]:
config = {
    "train_subset": 1500000,
    "valid_subset": 400000,
    "test_subset" : 200000,
    "seed": 42,
    "wandb_project_name": "rf_param_opt",
    "count": 20   # Number of runs for the Sweep
}

args = Namespace(**config)

### Define the Sweep config

In [3]:
sweep_config = {
    'method' : 'random',
    'metric': {
        'name': 'roc_auc',
        'goal': 'maximize'
    },
    'parameters': {
        'criterion': {'values': ['gini', 'entropy']},
        'n_estimators': {'values': [50, 500, 1000]},
        'min_samples_split': {'values': [10, 100, 1000]},
        'min_samples_leaf': {'values': [10, 50, 100]},
    }
}

### Log in to Weight and Biases

In [4]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mm2im[0m ([33mnpsdaor[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

### Read the dataset

In [5]:
violence_hidden = load_from_disk("../../Violence_data/geo_corpus.0.0.1_datasets_hidden_xlmt")

In [6]:
violence_hidden

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'hidden_state'],
        num_rows: 16769932
    })
    validation: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'hidden_state'],
        num_rows: 4192483
    })
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'hidden_state'],
        num_rows: 2329158
    })
})

In [7]:
# Remove unncesary columns
keep_cols = ['hidden_state', 'labels']
remove_columns = [col for col in violence_hidden['train'].column_names if col not in keep_cols]

In [8]:
violence_hidden = violence_hidden.remove_columns(remove_columns)

In [9]:
violence_hidden

DatasetDict({
    train: Dataset({
        features: ['labels', 'hidden_state'],
        num_rows: 16769932
    })
    validation: Dataset({
        features: ['labels', 'hidden_state'],
        num_rows: 4192483
    })
    test: Dataset({
        features: ['labels', 'hidden_state'],
        num_rows: 2329158
    })
})

In [10]:
# Extract a subset of the dataset
train_clf_ds = violence_hidden["train"].shuffle(args.seed).select(range(args.train_subset))
validation_clf_ds = violence_hidden["validation"].shuffle(args.seed).select(range(args.valid_subset))
test_clf_ds = violence_hidden["test"].shuffle(args.seed).select(range(args.test_subset))

Loading cached shuffled indices for dataset at ../../Violence_data/geo_corpus.0.0.1_datasets_hidden_xlmt/train/cache-333532fe06490884.arrow
Loading cached shuffled indices for dataset at ../../Violence_data/geo_corpus.0.0.1_datasets_hidden_xlmt/validation/cache-879cab607103254e.arrow
Loading cached shuffled indices for dataset at ../../Violence_data/geo_corpus.0.0.1_datasets_hidden_xlmt/test/cache-9d513c2ddd67e482.arrow


In [None]:
train_clf_ds[0]

### Create a feature matrix

In [11]:
%time X_train = np.array(train_clf_ds["hidden_state"])
y_train = np.array(train_clf_ds["labels"])
X_validation = np.array(validation_clf_ds["hidden_state"])
y_validation = np.array(validation_clf_ds["labels"])
X_test = np.array(test_clf_ds["hidden_state"])
y_test = np.array(test_clf_ds["labels"])
X_train.shape

CPU times: user 24 s, sys: 20.6 s, total: 44.6 s
Wall time: 29.9 s


(1500000, 768)

## Helper Functions

In [12]:
# Create a function to report the various metrics for each classifier
def metricsReport(test_labels, predictions):
    
    roc_auc = roc_auc_score(test_labels, predictions, average = "micro")

    weighted_precision = precision_score(test_labels, predictions, average='weighted')
    weighted_recall = recall_score(test_labels, predictions, average='weighted')
    weighted_f1 = f1_score(test_labels, predictions, average='weighted')

    micro_precision = precision_score(test_labels, predictions, average='micro')
    micro_recall = recall_score(test_labels, predictions, average='micro')
    micro_f1 = f1_score(test_labels, predictions, average='micro')
    
    
    return {"roc_auc": format(roc_auc, '.4f'), "weighted_precision": format(weighted_precision, '.4f'),
                                    "weighted_recall": format(weighted_recall, '.4f'), "weighted_f1": format(weighted_f1, '.4f'), "micro_precision": format(micro_precision, '.4f'),
                                    "micro_recall": format(micro_recall, '.4f'), "micro_f1": format(micro_f1, '.4f')}

In [13]:
def plot_confusion_matrix(y_test, y_pred, clf:str):
    f, axes = plt.subplots(2, 3, figsize=(25, 15))
    f.suptitle(clf, fontsize=36)
    axes = axes.ravel()
    for i in range(6):
        labels=['post7geo10', 'post7geo30', 'post7geo50','pre7geo10','pre7geo30', 'pre7geo50']
        disp = ConfusionMatrixDisplay(confusion_matrix(y_test[:, i],
                                                       y_pred[:, i]),
                                      display_labels=[0, i])
        disp.plot(ax=axes[i], values_format='.4g')
        disp.ax_.set_title(labels[i])
        if i<10:
            disp.ax_.set_xlabel('')
        if i%5!=0:
            disp.ax_.set_ylabel('')
        disp.im_.colorbar.remove()

    plt.subplots_adjust(wspace=0.10, hspace=0.1)
    f.colorbar(disp.im_, ax=axes)
    plt.show()

# Random Forest Classifier

In [14]:
def train():
    with wandb.init():
        config = wandb.config
        rfClassifier = RandomForestClassifier(
            n_jobs=-1, 
            random_state=args.seed,
            n_estimators = config.n_estimators,
            min_samples_split = config.min_samples_split,
            min_samples_leaf = config.min_samples_leaf)
        rfClassifier.fit(X_train, y_train)
        rfPreds = rfClassifier.predict(X_test)
        scores = metricsReport(y_test, rfPreds)
        roc_auc = float(scores["roc_auc"])
        prec_w = float(scores["weighted_precision"])
        recall_w = float(scores["weighted_recall"])
        f1_w = float(scores["weighted_f1"])
        prec_micro = float(scores["micro_precision"])
        recall_micro = float(scores["micro_recall"])
        f1_micro = float(scores["micro_f1"])
    
        wandb.log({'roc_auc': roc_auc,
              'weighted_precision': prec_w,
              'weighted_recall': recall_w,
              'weighted_f1': f1_w,
              'micro_precision': prec_micro,
              'micro_recall': recall_micro,
              'micro_f1': f1_micro})

### Initialize the Sweep

In [15]:
sweep_id = wandb.sweep(sweep=sweep_config, project=args.wandb_project_name)

Create sweep with ID: e9ww50tw
Sweep URL: https://wandb.ai/npsdaor/rf_param_opt/sweeps/e9ww50tw


### Start the Sweep agent

In [16]:
wandb.agent(sweep_id, function=train, count=args.count)

[34m[1mwandb[0m: Agent Starting Run: 8yo4teer with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 50
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 50


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5505
micro_precision,0.579
micro_recall,0.5246
roc_auc,0.6003
weighted_f1,0.4647
weighted_precision,0.5966
weighted_recall,0.5246


[34m[1mwandb[0m: Agent Starting Run: gde3904h with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 50


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5537
micro_precision,0.5805
micro_recall,0.5293
roc_auc,0.6022
weighted_f1,0.488
weighted_precision,0.5897
weighted_recall,0.5293


[34m[1mwandb[0m: Agent Starting Run: dmmawwfd with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 50
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 1000


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.140779…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5366
micro_precision,0.5769
micro_recall,0.5015
roc_auc,0.5946
weighted_f1,0.4338
weighted_precision,0.599
weighted_recall,0.5015


[34m[1mwandb[0m: Agent Starting Run: 79fu8zqe with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 1000


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.140700…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.537
micro_precision,0.5773
micro_recall,0.5021
roc_auc,0.5949
weighted_f1,0.4346
weighted_precision,0.5997
weighted_recall,0.5021


[34m[1mwandb[0m: Agent Starting Run: kgfa8rks with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 50
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 500


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.148330…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.547
micro_precision,0.5815
micro_recall,0.5164
roc_auc,0.6004
weighted_f1,0.4559
weighted_precision,0.6029
weighted_recall,0.5164


[34m[1mwandb[0m: Agent Starting Run: 2ibe6wvp with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 1000


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5494
micro_precision,0.5834
micro_recall,0.5191
roc_auc,0.6022
weighted_f1,0.461
weighted_precision,0.6054
weighted_recall,0.5191


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: m4072jvu with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 100
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 1000


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.140595…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5359
micro_precision,0.5767
micro_recall,0.5005
roc_auc,0.5943
weighted_f1,0.4326
weighted_precision,0.5987
weighted_recall,0.5005


[34m[1mwandb[0m: Agent Starting Run: 3a4y1qpp with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 50


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.140805…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5388
micro_precision,0.5761
micro_recall,0.506
roc_auc,0.5949
weighted_f1,0.4386
weighted_precision,0.5973
weighted_recall,0.506


[34m[1mwandb[0m: Agent Starting Run: wq6fu585 with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 1000


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.537
micro_precision,0.5773
micro_recall,0.5021
roc_auc,0.5949
weighted_f1,0.4346
weighted_precision,0.5997
weighted_recall,0.5021


[34m[1mwandb[0m: Agent Starting Run: 7yzplijb with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 50


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.148475…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5537
micro_precision,0.5805
micro_recall,0.5293
roc_auc,0.6022
weighted_f1,0.488
weighted_precision,0.5897
weighted_recall,0.5293


[34m[1mwandb[0m: Agent Starting Run: s2hmzqkg with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 50
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 500


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.148359…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.547
micro_precision,0.5815
micro_recall,0.5164
roc_auc,0.6004
weighted_f1,0.4559
weighted_precision,0.6029
weighted_recall,0.5164


[34m[1mwandb[0m: Agent Starting Run: 9pagtrgb with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 50
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 1000


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5468
micro_precision,0.5817
micro_recall,0.5159
roc_auc,0.6004
weighted_f1,0.4554
weighted_precision,0.6034
weighted_recall,0.5159


[34m[1mwandb[0m: Agent Starting Run: zwsu664q with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 100
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 500


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5426
micro_precision,0.5794
micro_recall,0.5102
roc_auc,0.5979
weighted_f1,0.4467
weighted_precision,0.6012
weighted_recall,0.5102


[34m[1mwandb[0m: Agent Starting Run: 5by07wwc with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 100
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 50


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.148330…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5455
micro_precision,0.5778
micro_recall,0.5166
roc_auc,0.598
weighted_f1,0.4527
weighted_precision,0.5981
weighted_recall,0.5166


[34m[1mwandb[0m: Agent Starting Run: t2d9kplp with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 100
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 1000


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5359
micro_precision,0.5767
micro_recall,0.5005
roc_auc,0.5943
weighted_f1,0.4326
weighted_precision,0.5987
weighted_recall,0.5005


[34m[1mwandb[0m: Agent Starting Run: lssc0jur with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 500


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5372
micro_precision,0.5773
micro_recall,0.5024
roc_auc,0.595
weighted_f1,0.435
weighted_precision,0.5997
weighted_recall,0.5024


[34m[1mwandb[0m: Agent Starting Run: qu0y4zg0 with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 50
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 50


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.538
micro_precision,0.5761
micro_recall,0.5047
roc_auc,0.5947
weighted_f1,0.437
weighted_precision,0.5978
weighted_recall,0.5047


[34m[1mwandb[0m: Agent Starting Run: gxvnahlf with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 1000
[34m[1mwandb[0m: 	n_estimators: 500


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5372
micro_precision,0.5773
micro_recall,0.5024
roc_auc,0.595
weighted_f1,0.435
weighted_precision,0.5997
weighted_recall,0.5024


[34m[1mwandb[0m: Agent Starting Run: mkc3kpzu with config:
[34m[1mwandb[0m: 	criterion: entropy
[34m[1mwandb[0m: 	min_samples_leaf: 10
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 50


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.148359…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5537
micro_precision,0.5805
micro_recall,0.5293
roc_auc,0.6022
weighted_f1,0.488
weighted_precision,0.5897
weighted_recall,0.5293


[34m[1mwandb[0m: Agent Starting Run: x3agtxbf with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 50
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 50


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.148475…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.5505
micro_precision,0.579
micro_recall,0.5246
roc_auc,0.6003
weighted_f1,0.4647
weighted_precision,0.5966
weighted_recall,0.5246
