In [1]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import wandb
from datasets import load_from_disk
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import roc_auc_score, confusion_matrix, multilabel_confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import classification_report, plot_roc_curve
from sklearn.model_selection import ParameterGrid
from sklearn.multioutput import MultiOutputClassifier
from argparse import Namespace

In [2]:
# Run-wide settings: per-split subset sizes, the random seed, the W&B project
# name and the n_jobs value handed to the multi-label classifier.
config = dict(
    train_subset=1500,
    valid_subset=400,
    test_subset=200,
    seed=42,
    wandb_project_name="rf_grid_search",
    n_jobs=-1,
)

# Attribute-style view of the same settings (args.seed, args.n_jobs, ...).
args = Namespace(**config)

### Define the Sweep config

In [3]:
# Grid sweep definition: every combination of the listed hyperparameter values
# is tried, and runs are ranked by the logged 'roc_auc_micro' metric.
sweep_config = dict(
    method='grid',
    metric=dict(name='roc_auc_micro', goal='maximize'),
    parameters={
        hyperparameter: {'values': candidates}
        for hyperparameter, candidates in (
            ('criterion', ['gini', 'entropy']),
            ('n_estimators', [100, 250, 500]),
            ('min_samples_split', [10, 50, 100]),
            ('min_samples_leaf', [5, 25, 50]),
        )
    },
)

### Log in to Weights & Biases

In [4]:
# Authenticate this session with Weights & Biases before the sweep is created.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mm2im[0m ([33mnpsdaor[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

### Read the dataset

In [5]:
# Load the saved DatasetDict (train / validation / test splits) from local disk.
# NOTE(review): this is an absolute, machine-specific path — consider moving it
# into `config` so the notebook can be re-run elsewhere.
violence_hidden = load_from_disk("/data4/mmendieta/data/geo_corpus.0.0.1_datasets_hidden_e5_all_labels")

In [6]:
# Display the splits with their column names and row counts.
violence_hidden

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'hidden_state'],
        num_rows: 1500000
    })
    validation: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'hidden_state'],
        num_rows: 400000
    })
    test: Dataset({
        features: ['text', 'labels', 'input_ids', 'attention_mask', 'hidden_state'],
        num_rows: 200000
    })
})

In [7]:
# Work out which columns are unnecessary: everything except the features
# ('hidden_state') and the targets ('labels').
keep_cols = ['hidden_state', 'labels']
remove_columns = list(
    filter(lambda name: name not in keep_cols, violence_hidden['train'].column_names)
)

In [8]:
# Drop the columns collected in `remove_columns` from the DatasetDict.
violence_hidden = violence_hidden.remove_columns(remove_columns)

In [9]:
# Display the DatasetDict again to check which columns remain after the removal.
violence_hidden

DatasetDict({
    train: Dataset({
        features: ['labels', 'hidden_state'],
        num_rows: 1500000
    })
    validation: Dataset({
        features: ['labels', 'hidden_state'],
        num_rows: 400000
    })
    test: Dataset({
        features: ['labels', 'hidden_state'],
        num_rows: 200000
    })
})

In [None]:
# Draw a seeded random subset from each split: shuffle with the configured
# seed, then keep the first N rows, where N is the subset size from `args`.
train_clf_ds = (
    violence_hidden["train"]
    .shuffle(args.seed)
    .select(range(args.train_subset))
)
validation_clf_ds = (
    violence_hidden["validation"]
    .shuffle(args.seed)
    .select(range(args.valid_subset))
)
test_clf_ds = (
    violence_hidden["test"]
    .shuffle(args.seed)
    .select(range(args.test_subset))
)

In [10]:
# Some hidden-state datasets were already saved with the required number of
# samples; for those, use each split as-is instead of sub-sampling it as in the
# cell above.
train_clf_ds, validation_clf_ds, test_clf_ds = (
    violence_hidden[split_name] for split_name in ("train", "validation", "test")
)

In [11]:
# Peek at the first training example to check its 'labels' and 'hidden_state' values.
train_clf_ds[0]

{'labels': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
         0., 0., 1., 1.]),
 'hidden_state': tensor([ 0.6473,  0.3876, -1.1127,  ...,  0.6801, -0.8819,  0.6340])}

### Create a feature matrix

In [12]:
# Materialise the features ('hidden_state') and targets ('labels') of every
# split as NumPy arrays for scikit-learn.
# `%time` is a line magic, so only the X_train conversion on its line is timed.
%time X_train = np.array(train_clf_ds["hidden_state"])
y_train = np.array(train_clf_ds["labels"])
X_validation = np.array(validation_clf_ds["hidden_state"])
y_validation = np.array(validation_clf_ds["labels"])
X_test = np.array(test_clf_ds["hidden_state"])
y_test = np.array(test_clf_ds["labels"])
# Last expression: show the training feature matrix shape as the cell output.
X_train.shape

CPU times: user 6.96 s, sys: 26.4 s, total: 33.3 s
Wall time: 8.83 s


(1500000, 1024)

## Helper Functions

In [None]:
# Old function
# Builds the report of evaluation metrics for a classifier's predictions.
def metricsReport(y_true, y_pred, y_probs):
    """Return ROC AUC, precision, recall and F1 under micro and weighted averaging.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, used for precision / recall / F1.
        y_probs: Scores handed unchanged to ``roc_auc_score``.

    Returns:
        dict: Metric name mapped to its value.
    """
    report = {}

    # ROC AUC is computed from the probability scores.
    for key, averaging in (
        ("roc_auc_micro", "micro"),
        ("roc_auc_weighted", "weighted"),
    ):
        report[key] = roc_auc_score(y_true, y_probs, average=averaging)

    # Threshold-based metrics are computed from the hard predictions.
    for key, scorer, averaging in (
        ("weighted_precision", precision_score, 'weighted'),
        ("weighted_recall", recall_score, 'weighted'),
        ("weighted_f1", f1_score, 'weighted'),
        ("micro_precision", precision_score, 'micro'),
        ("micro_recall", recall_score, 'micro'),
        ("micro_f1", f1_score, 'micro'),
    ):
        report[key] = scorer(y_true, y_pred, average=averaging, zero_division=0)

    return report

In [13]:
def metricsReport(y_true, y_pred, y_probs):
    """
    Build the dictionary of evaluation metrics for one set of predictions.

    Args:
        y_true (array-like): True labels.
        y_pred (array-like): Binary predictions.
        y_probs (array-like or list): Prediction probabilities. A list of
            per-label arrays is collapsed into one 2-D array by taking
            column 1 of each entry.

    Returns:
        dict: ROC AUC (micro / weighted) plus precision, recall and F1 under
        weighted and micro averaging. The ROC AUC entries are NaN when
        ``roc_auc_score`` raises ``ValueError``.
    """
    # MultiOutputClassifier.predict_proba yields one array per label; keep the
    # second column (class 1) of each and place them side by side.
    if isinstance(y_probs, list):
        positive_class_probs = [per_label[:, 1] for per_label in y_probs]
        y_probs = np.stack(positive_class_probs, axis=1)

    # Start from NaN so a failed AUC computation still produces a full report.
    auc = {"micro": np.nan, "weighted": np.nan}
    try:
        # Micro first: if it fails, the weighted variant is not attempted.
        for averaging in ("micro", "weighted"):
            auc[averaging] = roc_auc_score(y_true, y_probs, average=averaging)
    except ValueError as e:
        print(f"Warning: Could not calculate ROC AUC. Error: {e}")

    report = {
        "roc_auc_micro": auc["micro"],
        "roc_auc_weighted": auc["weighted"],
    }

    # Threshold-based metrics are computed from the hard predictions.
    for key, scorer, averaging in (
        ("weighted_precision", precision_score, 'weighted'),
        ("weighted_recall", recall_score, 'weighted'),
        ("weighted_f1", f1_score, 'weighted'),
        ("micro_precision", precision_score, 'micro'),
        ("micro_recall", recall_score, 'micro'),
        ("micro_f1", f1_score, 'micro'),
    ):
        report[key] = scorer(y_true, y_pred, average=averaging, zero_division=0)

    return report


In [None]:
def plot_confusion_matrix(y_test, y_pred, clf:str):
    """Plot one binary confusion matrix per label on a shared 2x3 grid.

    Only the first six label columns of ``y_test`` / ``y_pred`` are drawn, one
    per name in the hard-coded ``labels`` tuple below.

    Args:
        y_test: 2-D array-like of true labels, indexed as ``[sample, label]``.
            Values are assumed to be binary (0 / 1) — TODO confirm for callers.
        y_pred: Predicted labels with the same layout as ``y_test``.
        clf (str): Text used as the figure title.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # NOTE(review): these names are assumed to match the first six label
    # columns of the dataset — confirm against the label order.
    labels = ('post7geo10', 'post7geo30', 'post7geo50',
              'pre7geo10', 'pre7geo30', 'pre7geo50')
    n_rows, n_cols = 2, 3

    y_test = np.asarray(y_test)
    y_pred = np.asarray(y_pred)

    f, axes = plt.subplots(n_rows, n_cols, figsize=(25, 15))
    f.suptitle(clf, fontsize=36)
    axes = axes.ravel()
    for i, label in enumerate(labels):
        # Passing labels=[0, 1] forces a 2x2 matrix even when a column holds a
        # single class, so it always matches the two display labels.
        cm = confusion_matrix(y_test[:, i].astype(int),
                              y_pred[:, i].astype(int),
                              labels=[0, 1])
        # Each per-label problem is binary, so the tick labels are 0 and 1
        # (not the loop index).
        disp = ConfusionMatrixDisplay(cm, display_labels=[0, 1])
        disp.plot(ax=axes[i], values_format='.4g')
        disp.ax_.set_title(label)
        # Keep the x-axis label only on the bottom row ...
        if i < (n_rows - 1) * n_cols:
            disp.ax_.set_xlabel('')
        # ... and the y-axis label only in the first column.
        if i % n_cols != 0:
            disp.ax_.set_ylabel('')
        # Drop the per-axes colorbar; one shared colorbar is added below.
        disp.im_.colorbar.remove()

    plt.subplots_adjust(wspace=0.10, hspace=0.1)
    f.colorbar(disp.im_, ax=axes)
    plt.show()

# Random Forest Classifier

In [None]:
# Old function
def train():
    """Run one sweep trial with a single RandomForestClassifier fitted on all labels.

    Reads the hyperparameters from ``wandb.config``, fits on the training
    arrays, scores the validation arrays with ``metricsReport`` and logs the
    resulting metrics to W&B.
    """
    with wandb.init():
        sweep_params = wandb.config

        forest = RandomForestClassifier(
            criterion=sweep_params.criterion,
            n_estimators=sweep_params.n_estimators,
            min_samples_split=sweep_params.min_samples_split,
            min_samples_leaf=sweep_params.min_samples_leaf,
            random_state=args.seed,
            n_jobs=args.n_jobs,
        )
        forest.fit(X_train, y_train)

        val_preds = forest.predict(X_validation)
        val_probs = forest.predict_proba(X_validation)

        # A list of per-label (n_samples, 2) arrays is collapsed into a single
        # (n_samples, n_labels) array holding the class-1 column of each.
        if isinstance(val_probs, list):
            class_one_columns = [per_label[:, 1] for per_label in val_probs]
            val_probs = np.stack(class_one_columns, axis=1)

        wandb.log(metricsReport(y_validation, val_preds, val_probs))

In [14]:
def train():
    """
    Run a single W&B sweep trial.

    Wraps a RandomForestClassifier in a MultiOutputClassifier, fits it on the
    training arrays using the hyperparameters from ``wandb.config``, evaluates
    it on the validation arrays with ``metricsReport`` and logs the scores.
    """
    with wandb.init():
        sweep_params = wandb.config

        # Base forest settings. The forest itself runs single-threaded
        # (n_jobs=1); parallelism comes from the multi-output wrapper below.
        forest_kwargs = dict(
            n_estimators=sweep_params.n_estimators,
            criterion=sweep_params.criterion,
            min_samples_split=sweep_params.min_samples_split,
            min_samples_leaf=sweep_params.min_samples_leaf,
            n_jobs=1,
            random_state=args.seed,
        )

        # The multi-output wrapper handles the multi-label targets;
        # its n_jobs spreads the work across the labels.
        multilabel_model = MultiOutputClassifier(
            estimator=RandomForestClassifier(**forest_kwargs),
            n_jobs=args.n_jobs,
        )
        multilabel_model.fit(X_train, y_train)

        # Hard predictions and probabilities on the validation split.
        val_preds = multilabel_model.predict(X_validation)
        val_probs = multilabel_model.predict_proba(X_validation)

        # Score the trial and send the metrics to W&B.
        wandb.log(metricsReport(y_validation, val_preds, val_probs))

### Initialize the Sweep

In [15]:
# Register the sweep defined in `sweep_config` under the configured W&B project;
# the returned ID is what the agent below attaches to.
sweep_id = wandb.sweep(sweep=sweep_config, project=args.wandb_project_name)

Create sweep with ID: l7kstfg7
Sweep URL: https://wandb.ai/npsdaor/rf_grid_search/sweeps/l7kstfg7


### Start the Sweep agent

In [None]:
# Start an agent for the sweep; it calls train() once per run it receives.
wandb.agent(sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: 6rpgkegf with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 100


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.37402
micro_precision,0.61549
micro_recall,0.26864
roc_auc_micro,0.70996
roc_auc_weighted,0.58388
weighted_f1,0.26237
weighted_precision,0.58697
weighted_recall,0.26864


[34m[1mwandb[0m: Agent Starting Run: uas2wb30 with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 250


VBox(children=(Label(value='0.001 MB of 0.006 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122341…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.37246
micro_precision,0.62005
micro_recall,0.26617
roc_auc_micro,0.71295
roc_auc_weighted,0.59175
weighted_f1,0.25634
weighted_precision,0.59647
weighted_recall,0.26617


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: yaqk2tfa with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 10
[34m[1mwandb[0m: 	n_estimators: 500


wandb: Network error (ReadTimeout), entering retry loop.
wandb: Network error (ReadTimeout), entering retry loop.
wandb: Network error (ReadTimeout), entering retry loop.
wandb: Network error (ReadTimeout), entering retry loop.
wandb: Network error (ReadTimeout), entering retry loop.
wandb: Network error (ReadTimeout), entering retry loop.
wandb: Network error (ReadTimeout), entering retry loop.


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.3709
micro_precision,0.62175
micro_recall,0.26428
roc_auc_micro,0.71403
roc_auc_weighted,0.59549
weighted_f1,0.25346
weighted_precision,0.59929
weighted_recall,0.26428


[34m[1mwandb[0m: Agent Starting Run: g4uzw4ms with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 50
[34m[1mwandb[0m: 	n_estimators: 100


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.36542
micro_precision,0.6208
micro_recall,0.25891
roc_auc_micro,0.71142
roc_auc_weighted,0.58704
weighted_f1,0.24288
weighted_precision,0.62658
weighted_recall,0.25891


[34m[1mwandb[0m: Agent Starting Run: iyzqllfu with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 50
[34m[1mwandb[0m: 	n_estimators: 250


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.131211…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.36346
micro_precision,0.62402
micro_recall,0.2564
roc_auc_micro,0.71331
roc_auc_weighted,0.59315
weighted_f1,0.23813
weighted_precision,0.63196
weighted_recall,0.2564


[34m[1mwandb[0m: Agent Starting Run: a1nhdf9y with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 50
[34m[1mwandb[0m: 	n_estimators: 500


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.3621
micro_precision,0.62516
micro_recall,0.25486
roc_auc_micro,0.71399
roc_auc_weighted,0.59576
weighted_f1,0.23605
weighted_precision,0.63441
weighted_recall,0.25486


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: eo28bitf with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 100


[34m[1mwandb[0m: [32m[41mERROR[0m Error while calling W&B API: dial tcp 35.226.229.132:3307: connect: connection refused (<Response [500]>)
[34m[1mwandb[0m: [32m[41mERROR[0m Error while calling W&B API: dial tcp 35.226.229.132:3307: connect: connection refused (<Response [500]>)
[34m[1mwandb[0m: [32m[41mERROR[0m Error while calling W&B API: dial tcp 35.226.229.132:3307: connect: connection refused (<Response [500]>)
[34m[1mwandb[0m: Network error (HTTPError), entering retry loop.
wandb: ERROR Error while calling W&B API: dial tcp 35.226.229.132:3307: connect: connection refused (<Response [500]>)
wandb: ERROR Error while calling W&B API: dial tcp 35.226.229.132:3307: connect: connection refused (<Response [500]>)
wandb: ERROR Error while calling W&B API: dial tcp 35.226.229.132:3307: connect: connection refused (<Response [500]>)
[34m[1mwandb[0m: [32m[41mERROR[0m Error while calling W&B API: context deadline exceeded (<Response [500]>)


0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.35915
micro_precision,0.62204
micro_recall,0.25245
roc_auc_micro,0.7114
roc_auc_weighted,0.5883
weighted_f1,0.23014
weighted_precision,0.64487
weighted_recall,0.25245


[34m[1mwandb[0m: Agent Starting Run: 1ml59m1a with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 250


VBox(children=(Label(value='0.001 MB of 0.005 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.131188…

0,1
micro_f1,▁
micro_precision,▁
micro_recall,▁
roc_auc_micro,▁
roc_auc_weighted,▁
weighted_f1,▁
weighted_precision,▁
weighted_recall,▁

0,1
micro_f1,0.35693
micro_precision,0.62435
micro_recall,0.2499
roc_auc_micro,0.71287
roc_auc_weighted,0.5937
weighted_f1,0.22622
weighted_precision,0.64963
weighted_recall,0.2499


[34m[1mwandb[0m: Agent Starting Run: i9krn3d5 with config:
[34m[1mwandb[0m: 	criterion: gini
[34m[1mwandb[0m: 	min_samples_leaf: 5
[34m[1mwandb[0m: 	min_samples_split: 100
[34m[1mwandb[0m: 	n_estimators: 500
