In [1]:
import sys
sys.path.append("../../")

from shared.sae_actions import load_pretrained_sae, sae_featurize_data 
from shared.models import MiniPileDataset

%load_ext autoreload
%autoreload 2

In [2]:
# Load the pretrained SAE from its saved run directory.
SAE_DIR = "../training_sae/saes/spam_messages_4_20241106_180053"
sae = load_pretrained_sae(SAE_DIR)

Featurize the data (train, validation, and test splits)

In [14]:
import pandas as pd
import numpy as np
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize

def prepare_dataset(
    sentences_file,
    embeddings_file,
    sae,
    num_features=200,
    feature_registry_file="feature_registry.npy",
    label_column='label',
    text_key="text",
    split_sentences=True,
    negative_label='ham',
):
    """Build an ``(X, y)`` pair of SAE feature rows and binary labels from a CSV.

    The CSV at ``sentences_file`` is loaded as a DataFrame (for labels and
    text); the same path, together with ``embeddings_file``, is handed to
    ``MiniPileDataset`` for featurization. Features are memory-mapped from
    ``feature_registry_file`` when that file exists; otherwise
    ``sae_featurize_data`` is called with ``output_file=feature_registry_file``.

    Args:
        sentences_file: Path to the CSV holding the text and label columns.
        embeddings_file: Path passed through to ``MiniPileDataset``.
        sae: Model object; ``sae.encoder.weight.shape[0]`` is used as the
            number of feature rows stored in the registry file.
        num_features: Keep only the first ``num_features`` feature columns.
            ``None`` keeps all of them.
        feature_registry_file: Cache file, read as float32 with shape
            ``(sae.encoder.weight.shape[0], len(dataset.sentences))``.
        label_column: CSV column holding the class label.
        text_key: CSV column holding the message text.
        split_sentences: When true, each row's feature vector and label are
            repeated once per sentence ``sent_tokenize`` finds in its text
            (rows with zero sentences are dropped).
        negative_label: Label value mapped to class 0; every other value is
            mapped to class 1.

    Returns:
        Tuple ``(X, y)``: ``X`` has one row per example (or per sentence when
        ``split_sentences`` is true) and ``y`` holds the matching 0/1 labels.

    Raises:
        ValueError: If the number of feature rows differs from the number of
            CSV rows, which would silently pair features with the wrong
            label/text.
    """
    # Load dataset
    df = pd.read_csv(sentences_file)

    # Create MiniPileDataset
    mini_pile_dataset = MiniPileDataset(sentences_file, embeddings_file, key=text_key)

    # Featurize data: reuse the cached registry if present, else compute it.
    try:
        X = np.memmap(
            feature_registry_file,
            dtype="float32",
            mode="r",
            shape=(sae.encoder.weight.shape[0], len(mini_pile_dataset.sentences)),
        )
    except FileNotFoundError:
        X = sae_featurize_data(mini_pile_dataset, sae, output_file=feature_registry_file)

    # Registry is stored features-by-examples; flip to examples-by-features.
    X = X.T

    # Row i of X is paired with row i of df below, so the counts must agree.
    if X.shape[0] != len(df):
        raise ValueError(
            f"Feature matrix has {X.shape[0]} rows but {sentences_file} has "
            f"{len(df)} rows; features and labels would be misaligned."
        )

    # Apply num_features if specified
    if num_features is not None:
        X = X[:, :num_features]

    y = np.where(df[label_column] == negative_label, 0, 1)

    if split_sentences:
        # NOTE(review): every sentence of a message receives the *same*
        # message-level feature vector, so this only re-weights examples by
        # sentence count — confirm that is the intent.
        sentence_counts = np.fromiter(
            (len(sent_tokenize(text)) for text in df[text_key]),
            dtype=np.intp,
            count=len(df),
        )
        # np.asarray drops the memmap subclass so a plain ndarray is returned.
        X = np.repeat(np.asarray(X), sentence_counts, axis=0)
        y = np.repeat(y, sentence_counts)

    return X, y

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [16]:
# Training split: CSV of messages plus its precomputed embeddings.
sentences_file = "../data_preparation/data/spam_messages_train.csv"
embeddings_file = "../data_preparation/embedding_chunks/embedded_chunks/spam_messages_train_20241106_234743/embeddings.npy"
X_train, y_train = prepare_dataset(
    sentences_file=sentences_file,
    embeddings_file=embeddings_file,
    sae=sae,
    feature_registry_file="feature_registry_train_all.npy",
)

In [18]:
# Validation split: CSV of messages plus its precomputed embeddings.
sentences_file = "../data_preparation/data/spam_messages_val.csv"
embeddings_file = "../data_preparation/embedding_chunks/embedded_chunks/spam_messages_val_20241107_005540/embeddings.npy"
X_val, y_val = prepare_dataset(
    sentences_file=sentences_file,
    embeddings_file=embeddings_file,
    sae=sae,
    feature_registry_file="feature_registry_val_all.npy",
)

In [19]:
# Test split: CSV of messages plus its precomputed embeddings.
sentences_file = "../data_preparation/data/spam_messages_test.csv"
embeddings_file = "../data_preparation/embedding_chunks/embedded_chunks/spam_messages_test_20241107_005747/embeddings.npy"
X_test, y_test = prepare_dataset(
    sentences_file=sentences_file,
    embeddings_file=embeddings_file,
    sae=sae,
    feature_registry_file="feature_registry_test_all.npy",
)

goodfire intervention

feature selection

In [20]:
from sklearn.ensemble import RandomForestClassifier
from boruta import BorutaPy

# Boruta feature selection wrapped around a shallow, class-balanced random forest.
boruta_forest = RandomForestClassifier(n_jobs=-1, class_weight="balanced", max_depth=5)
boruta_selector = BorutaPy(boruta_forest, n_estimators="auto", verbose=0, random_state=1)

# Only the first 200 feature columns of the training matrix are offered to Boruta.
boruta_selector.fit(X_train[:, :200], y_train)

In [21]:
# Number of features Boruta selected (BorutaPy's `n_features_` attribute).
boruta_selector.n_features_

95

In [9]:
# Per-feature Boruta ranking; per the BorutaPy docs, rank 1 = selected and rank 2 = tentative.
boruta_selector.ranking_

array([ 50,  41,   1,   2,   1,   2,  44,  19,   3,  77, 113,   2,  47,
        66,  94,  48, 113,   1, 113,   1,   2,  69,   1,   1,   1,   1,
         1,  15,   1,  44,  88,   1,   1,   1,   1,  99, 113,   1,   1,
        30,  81,   1,  85,  65,  71,   1,  91,  54,   1,  46,   1,   1,
        32,  28,   1,   1,   1, 105,   1,   1,   1, 103,  38,   1,  60,
        20,   1,  82,   1, 104,  63,   1,  21,   1,   1,  53,  56,   1,
         1,  40,   1,   1,   3, 113,  34,   1,  33,  77,   1,   1,   1,
        15,  70,  34,   5,  84,  95,  44,   1,   1,  38, 113,   7,  83,
         1,  17,  28,  26,  85,  60,  93, 113,  59,   1,   1,  15,  42,
         7,  10,   1,   1,  52,  10,   3,  79,  75,   1,  36,  95,  71,
       113,   1,   1,   1,  22,  91, 106,   1,  89,  67, 101,   1,  63,
       102,  13, 113,  18, 113, 113,   1,   1,  98,   1,  90,   1,   1,
         5, 113, 113,  74,  49, 113,   1,   1,  50,  77,  68,   1,   2,
        23,  79,  25,   1,  37,  99,  56,  12,   1,   1,   1,   

In [22]:
features_folder = "../feature_extraction/features/20241106_222955"

def get_feature_labels_from_mask(mask, features_dir=None):
    """Print a table of the features flagged in ``mask`` and return their labels.

    For every truthy position ``i`` in ``mask`` the file ``feature_{i}.json``
    is read from the features directory; its ``'label'`` entry (required) and
    ``'confidence'`` entry (``'N/A'`` when absent) are collected.

    Args:
        mask: Iterable of truthy/falsy flags, one per feature index.
        features_dir: Directory holding the ``feature_{i}.json`` files.
            Defaults to the notebook-level ``features_folder``.

    Returns:
        List of labels for the flagged features, in ascending index order
        (the order of ``mask``), NOT in any caller-defined ranking order.

    Raises:
        FileNotFoundError: If a flagged feature has no JSON file.
        KeyError: If a feature file has no ``'label'`` entry.
    """
    import json
    import os
    from tabulate import tabulate

    # Fall back to the notebook-level default so existing callers are unaffected.
    if features_dir is None:
        features_dir = features_folder

    feature_info = []
    for i, is_selected in enumerate(mask):
        if not is_selected:
            continue
        feature_file = os.path.join(features_dir, f"feature_{i}.json")
        # JSON is UTF-8; be explicit so non-ASCII labels decode the same way
        # regardless of the platform's default locale encoding.
        with open(feature_file, encoding="utf-8") as f:
            feature_data = json.load(f)
        feature_info.append({
            'index': i,
            'label': feature_data['label'],
            'confidence': feature_data.get('confidence', 'N/A')
        })

    # Print nicely formatted table
    headers = ['Index', 'Label', 'Confidence']
    table = [[info['index'], info['label'], info['confidence']] for info in feature_info]
    print(tabulate(table, headers=headers, tablefmt='grid'))

    return [info['label'] for info in feature_info]


In [23]:
# `support_weak_` is BorutaPy's mask of *tentative* features (not yet confirmed
# within max_iter); the confirmed selection lives in `support_`. The header is
# worded to match the mask that is actually shown.
print("\nTentative Features (boruta_selector.support_weak_):")
get_feature_labels_from_mask(boruta_selector.support_weak_)


Selected Features:
+---------+----------------------------------------------+--------------+
|   Index | Label                                        |   Confidence |
|       4 | Contains contact information                 |           80 |
+---------+----------------------------------------------+--------------+
|      20 | Spam or promotional indication               |           26 |
+---------+----------------------------------------------+--------------+
|      21 | Fragmented, urgent, obfuscated text patterns |            6 |
+---------+----------------------------------------------+--------------+
|      39 | Text with special patterns and symbols       |           75 |
+---------+----------------------------------------------+--------------+
|      97 | Contains direct call to action               |           30 |
+---------+----------------------------------------------+--------------+
|     121 | Spam-like and promotional structure          |           30 |
+---------+-------

['Contains contact information',
 'Spam or promotional indication',
 'Fragmented, urgent, obfuscated text patterns',
 'Text with special patterns and symbols',
 'Contains direct call to action',
 'Spam-like and promotional structure',
 'Financial gain discussion',
 'Repetitive text sequences',
 'Promotional or sensitive content presence',
 'Structured financial and web content']

Now let's train the classifier and evaluate it on the test set.

In [24]:
from classifier_model import BinaryClassifierModel

# Build the project's binary classifier wrapper.
model = BinaryClassifierModel()

# Fit on the training split, passing the validation split alongside it.
# use_feature_selection=False presumably disables the wrapper's own feature
# selection step — confirm in classifier_model.
training_inputs = {
    "X_train": X_train,
    "y_train": y_train,
    "X_val": X_val,
    "y_val": y_val,
    "use_feature_selection": False,
}
model.train_model(**training_inputs)

# Score the fitted model on the held-out test split.
model.evaluate_model(X_test, y_test)

Scaling features...
Will try 5 different max_depths and 3 different n_estimators
Total combinations to try: 15


Training Progress:   0%|          | 0/15 [00:00<?, ?it/s]


Iteration Results (max_depth=10, n_estimators=10):
Accuracy: 0.6544
Precision: 0.9940
Recall: 0.2747
F1: 0.4304
AUC-ROC: 0.7318

Iteration Results (max_depth=10, n_estimators=50):
Accuracy: 0.6546
Precision: 0.9909
Recall: 0.2760
F1: 0.4317
AUC-ROC: 0.7341

Iteration Results (max_depth=10, n_estimators=100):
Accuracy: 0.6560
Precision: 0.9923
Recall: 0.2785
F1: 0.4349
AUC-ROC: 0.7347

Iteration Results (max_depth=20, n_estimators=10):
Accuracy: 0.6584
Precision: 0.9925
Recall: 0.2835
F1: 0.4410
AUC-ROC: 0.7327

Iteration Results (max_depth=20, n_estimators=50):
Accuracy: 0.6569
Precision: 0.9926
Recall: 0.2802
F1: 0.4370
AUC-ROC: 0.7350

Iteration Results (max_depth=20, n_estimators=100):
Accuracy: 0.6570
Precision: 0.9924
Recall: 0.2805
F1: 0.4374
AUC-ROC: 0.7374

Iteration Results (max_depth=30, n_estimators=10):
Accuracy: 0.6563
Precision: 0.9793
Recall: 0.2828
F1: 0.4389
AUC-ROC: 0.7318

Iteration Results (max_depth=30, n_estimators=50):
Accuracy: 0.6552
Precision: 0.9721
Recall: 

Let's try the same features with a logistic regression model.

In [19]:
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import numpy as np

def _print_metrics(split_name, y_true, y_pred, y_proba):
    """Print accuracy, precision, recall, F1 and AUC-ROC for one data split.

    Args:
        split_name: Text placed before " Metrics:" in the header line.
        y_true: Ground-truth labels.
        y_pred: Hard class predictions.
        y_proba: Positive-class scores, used only for AUC-ROC.
    """
    print(f"\n{split_name} Metrics:")
    print(f"Accuracy: {accuracy_score(y_true, y_pred):.4f}")
    print(f"Precision: {precision_score(y_true, y_pred):.4f}")
    print(f"Recall: {recall_score(y_true, y_pred):.4f}")
    print(f"F1: {f1_score(y_true, y_pred):.4f}")
    print(f"AUC-ROC: {roc_auc_score(y_true, y_proba):.4f}")


# Initialize scaler and model
scaler = MinMaxScaler(feature_range=(0, 1))
model = LogisticRegression(max_iter=1000)

# Scale the features: the scaler is fitted on train only, then reused for val/test
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

# Train model
model.fit(X_train_scaled, y_train)

# Get predictions on validation set
y_pred_val = model.predict(X_val_scaled)
y_pred_proba_val = model.predict_proba(X_val_scaled)[:, 1]

# Calculate validation metrics
_print_metrics("Validation", y_val, y_pred_val, y_pred_proba_val)

# Get predictions on test set
y_pred_test = model.predict(X_test_scaled)
y_pred_proba_test = model.predict_proba(X_test_scaled)[:, 1]

# Calculate test metrics
_print_metrics("Test", y_test, y_pred_test, y_pred_proba_test)

# Print feature indexes with the highest weights and their labels
feature_weights = model.coef_[0]
top_feature_indices = np.argsort(np.abs(feature_weights))[::-1][:10]  # Get top 10 feature indices
print("\nTop 10 Features with Highest Weights:")

# Create boolean mask for get_feature_labels_from_mask
mask = np.zeros(len(feature_weights), dtype=bool)
mask[top_feature_indices] = True

# get_feature_labels_from_mask walks the mask in ascending index order, so its
# labels line up with np.flatnonzero(mask), not with the weight-sorted
# top_feature_indices. Zipping the two directly attaches labels to the wrong
# features; build an index -> label map and look each index up instead.
feature_labels = get_feature_labels_from_mask(mask)
label_by_index = dict(zip(np.flatnonzero(mask), feature_labels))
for idx in top_feature_indices:
    print(f"Feature {idx} ({label_by_index[idx]}): weight = {feature_weights[idx]:.4f}")



Validation Metrics:
Accuracy: 0.6817
Precision: 0.9637
Recall: 0.2387
F1: 0.3826
AUC-ROC: 0.6811

Test Metrics:
Accuracy: 0.6944
Precision: 0.9498
Recall: 0.2355
F1: 0.3774
AUC-ROC: 0.6744

Top 10 Features with Highest Weights:
+---------+----------------------------------------------+--------------+
|   Index | Label                                        |   Confidence |
|      60 | Business and finance communication in Polish |           79 |
+---------+----------------------------------------------+--------------+
|      68 | Business and operational references          |           16 |
+---------+----------------------------------------------+--------------+
|      71 | Incomplete URL presence                      |           79 |
+---------+----------------------------------------------+--------------+
|      81 | Mentions Vikings and sports content          |           80 |
+---------+----------------------------------------------+--------------+
|      89 | Finance and economi