<div class='heading'>
    <div style='float:left;'><h1>CPSC 8810: Machine Learning for Graphs</h1></div>
     <img style="float: right; padding-right: 10px" width="100" src="https://raw.githubusercontent.com/bsethwalker/clemson-cs4300/main/images/clemson_paw.png"> </div>
     </div>

**Clemson University**<br>
**Instructor(s):** Aaron Masino <br>

## Lab 2: Network Feature Engineering and Traditional Machine Learning

This notebook demonstrates traditional machine learning approaches for graph-related tasks using node metrics (such as centrality and group measures) and graph metrics (such as density). This is analogous to the traditional approaches to image classification that preceded deep-learning models such as convolutional neural networks (CNNs). Before CNNs, researchers developed a large body of image metrics (so-called feature engineering) to support image classification with a variety of algorithms, including standard machine learning methods such as logistic regression and tree ensembles. With CNNs, nearly all current models work directly with the raw image and learn a useful, though latent, feature representation as part of training. Similarly, researchers have developed numerous graph metrics that can be viewed as engineered features and used as inputs to standard machine learning methods for prediction tasks.

### Learning Objectives
1. Apply NetworkX to compute and organize graph metrics in preparation for machine learning model development
2. Understand the differences in graph metric behavior using numerical analysis and visualization methods
3. Create machine learning models to predict graph node attributes using node metrics
4. Evaluate the performance of machine learning models for graph node attribute prediction
5. Create machine learning models to predict graph attributes using input graph metrics
6. Evaluate the performance of machine learning models for graph attribute prediction

## Library Imports and Setup

In [None]:
# Data manipulation and analysis
import numpy as np
import pandas as pd
from collections import Counter
import os
from pathlib import Path
import pickle
from itertools import cycle
import random

# Graph analysis
import networkx as nx
try:
    import pygraphviz as pgv
    PYGRAPHVIZ_AVAILABLE = True
except ImportError:
    PYGRAPHVIZ_AVAILABLE = False
    print("Warning: pygraphviz not available. Network visualizations will be limited.")

# PyTorch and PyTorch Geometric
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import to_networkx

# Machine Learning
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.metrics import (classification_report, confusion_matrix, 
                           roc_auc_score, roc_curve, auc, accuracy_score,
                           balanced_accuracy_score)
from sklearn.multiclass import OneVsRestClassifier
from sklearn.datasets import load_breast_cancer

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import warnings
warnings.filterwarnings('ignore')

# Setup
plt.style.use('default')
sns.set_palette("husl")
np.random.seed(42)
torch.manual_seed(42)

# Create the output directory for images and files
output_dir = Path('./output/lab_02')
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Output directory created: {output_dir}")

# Reusable random state
RANDOM_STATE = 654321

# Pandas display settings: show all columns on a wide console
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)

## 1. Scikit-learn Classification Refresher

Let's first review key machine learning concepts using scikit-learn before diving into graph-specific analysis.

### Key Concepts Covered
1. **Train/Test Split**: Proper data partitioning for unbiased evaluation
2. **Hyperparameter Tuning**: Grid search with cross-validation on training data
3. **Model Evaluation**: Classification report, confusion matrix, ROC analysis

As a review of fundamental machine learning practices in scikit-learn, let's build and evaluate a binary logistic regression classifier on the scikit-learn [breast cancer dataset](https://scikit-learn.org/stable/datasets/toy_dataset.html#breast-cancer-dataset). The model will predict whether a tumor is _benign_ or _malignant_.

We begin by loading the dataset, splitting it into train and test sets, and standardizing the features.
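As a quick sanity check of these two steps, the sketch below (a standalone illustration, not part of the lab's pipeline) confirms that `stratify=y` keeps the class proportions nearly identical in both splits, and that a scaler fit on the training data only centers the training features exactly while the test features are only approximately centered. The `random_state=0` value is an arbitrary choice for this sketch.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)

# Stratified 80/20 split (random_state chosen arbitrarily for this sketch)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)

# Class proportions should be nearly identical in the full data and in both splits
for name, lbls in [("all", y), ("train", y_tr), ("test", y_te)]:
    print(f"{name:>5}: fraction of class 1 = {lbls.mean():.3f}")

# Fit the scaler on the training data only, then apply it to both splits
scaler = StandardScaler().fit(X_tr)
X_tr_s, X_te_s = scaler.transform(X_tr), scaler.transform(X_te)

# Training features are centered exactly; test features only approximately
print("max |train mean|:", np.abs(X_tr_s.mean(axis=0)).max())
print("max |test mean| :", np.abs(X_te_s.mean(axis=0)).max())
```

Fitting the scaler on the test data as well would leak test-set statistics into training, which is why only `transform` is applied to the test split.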

In [None]:
# Load example dataset (binary classification)
data = load_breast_cancer()
X, y = data.data, data.target
feature_names = data.feature_names
class_names = data.target_names

print(f"Dataset: {X.shape[0]} samples, {X.shape[1]} features")
print(f"Classes: {class_names} (distribution: {np.bincount(y)})")

# TRAIN/TEST SPLIT
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=RANDOM_STATE
)
print(f"Training: {X_train.shape[0]} samples | Test: {X_test.shape[0]} samples")

# Feature scaling (important for logistic regression)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

### 1.1 Develop a logistic regression model
Now, let's create an instance of the [LogisticRegression](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html#sklearn.linear_model.LogisticRegression) model. We also create a `param_grid` dictionary to hold the values of the model tuning parameters we wish to test.

Next, we create an instance of [GridSearchCV](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html) to perform a grid search with stratified 5-fold cross-validation on the training data to identify the best hyperparameters. (This is not nested cross-validation: the held-out test set, rather than an outer cross-validation loop, provides the final performance estimate.)

Finally, we fit the `grid_search` object to the training data to identify the best hyperparameter combination. Because we set `refit=True` in the `GridSearchCV` inputs, the `grid_search` object then retrains the model on all of the training data using the optimal hyperparameter combination.
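One subtlety: in this lab the scaler is fit on the full training set before `GridSearchCV` splits it into folds, so each validation fold has been scaled with statistics that include it. The effect is usually small for standardization, but a common way to avoid it is to place the scaler and the model in a scikit-learn `Pipeline` so the scaler is refit inside every fold. The sketch below shows that alternative; the step names `scale` and `clf` are our own choices, and the `clf__` prefix is how `GridSearchCV` routes a grid entry to the named step.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, stratify=y, random_state=0)

# Scaling is a pipeline step, so it is re-fit on the training portion of each CV fold
pipe = Pipeline([
    ("scale", StandardScaler()),
    ("clf", LogisticRegression(solver="liblinear", max_iter=1000)),
])

# "<step name>__<parameter>" routes each grid entry to the named step
param_grid = {"clf__C": [0.01, 0.1, 1.0, 10.0], "clf__penalty": ["l1", "l2"]}

search = GridSearchCV(
    pipe, param_grid,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
    scoring="roc_auc", refit=True,
)
search.fit(X_tr, y_tr)  # pass unscaled features; the pipeline scales them

print("Best parameters:", search.best_params_)
print(f"Best CV ROC-AUC: {search.best_score_:.3f}")
# score() uses the same scorer (ROC-AUC); the refit pipeline scales raw test data first
print(f"Test ROC-AUC: {search.score(X_te, y_te):.3f}")
```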

In [None]:
# HYPERPARAMETER TUNING WITH CROSS-VALIDATION
# Create base model
lr = LogisticRegression(solver='liblinear', random_state=RANDOM_STATE, max_iter=1000)
param_grid = {
    'C': [0.01, 0.1, 1.0, 10.0, 100.0],  # Inverse regularization strength (smaller = stronger)
    'penalty': ['l1', 'l2']               # Regularization type
}


# Grid search with cross-validation (on training data only)
cv_strategy = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
grid_search = GridSearchCV(
    lr, param_grid, cv=cv_strategy, scoring='roc_auc', n_jobs=-1, refit=True
)

# Fit on training data
grid_search.fit(X_train_scaled, y_train)

print(f"Best parameters: {grid_search.best_params_}")
print(f"Best CV ROC-AUC: {grid_search.best_score_:.3f}")

# Get best model (already retrained on full training set)
best_model = grid_search.best_estimator_

### 1.2 Evaluating model performance
Now that we've selected our optimal tuning parameters and trained the model, let's evaluate performance on the test set. We will create a confusion matrix and an ROC plot.
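Since the ROC plot summarizes performance with the area under the curve, it helps to recall what that number means: the AUC equals the probability that a randomly chosen positive example receives a higher score than a randomly chosen negative one, with ties counting one half. The sketch below checks this pairwise definition against `roc_auc_score` on a small example; the labels and scores are made up for illustration.

```python
import numpy as np
from sklearn.metrics import roc_auc_score

# Toy labels and predicted scores (illustrative values only)
y_true = np.array([0, 0, 1, 1, 0, 1, 1, 0])
scores = np.array([0.10, 0.40, 0.35, 0.80, 0.20, 0.90, 0.60, 0.70])

def pairwise_auc(y, s):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos, neg = s[y == 1], s[y == 0]
    diff = pos[:, None] - neg[None, :]   # every positive score minus every negative score
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

# 13 of the 16 positive/negative pairs are ordered correctly, so both give 13/16
print(pairwise_auc(y_true, scores), roc_auc_score(y_true, scores))
```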

In [None]:
# Predictions
y_pred = best_model.predict(X_test_scaled)
y_proba = best_model.predict_proba(X_test_scaled)[:, 1]  # Probability of class 1 ('benign' in this dataset)

# Classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=class_names))

# Confusion matrix (left panel)
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names)
plt.title('Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')

# ROC curve (right panel)
fpr, tpr, _ = roc_curve(y_test, y_proba)
roc_auc = auc(fpr, tpr)

plt.subplot(1, 2, 2)
plt.plot(fpr, tpr, color='darkorange', lw=2, 
         label=f'ROC Curve (AUC = {roc_auc:.3f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.5)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend(loc="lower right")
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 2. Enron Dataset Node Classification
Here we will consider the [Enron Email Dataset](https://snap.stanford.edu/data/email-Enron.html). The dataset represents email exchanges among Enron employees, primarily senior staff. Nodes of the network are email addresses, and if address i sent at least one email to address j, the graph contains an undirected edge between i and j. The complete network contains about 37,000 nodes; here we will work with a connected subgraph of about 4,000 nodes.
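To make the edge definition concrete: sender-to-recipient pairs are naturally directed, and collapsing them into an undirected `nx.Graph` means that a one-way exchange, a two-way exchange, and repeated messages all become a single edge. The addresses below are invented for illustration and are not taken from the dataset.

```python
import networkx as nx

# Hypothetical (sender, recipient) pairs; a real email log would have one row per message
sent = [
    ("alice@x.com", "bob@x.com"),
    ("bob@x.com", "alice@x.com"),    # reply in the opposite direction
    ("alice@x.com", "bob@x.com"),    # repeated message
    ("carol@x.com", "alice@x.com"),
]

# Directed view: one edge per distinct (sender, recipient) pair
D = nx.DiGraph(sent)

# Undirected view: an edge i-j exists if either address emailed the other at least once
G_toy = nx.Graph()
G_toy.add_edges_from(sent)   # duplicates and reversed pairs collapse to one edge

print(D.number_of_edges(), G_toy.number_of_edges())
```

`D.to_undirected()` would produce the same undirected edge set from the directed graph.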

What's Enron? The [Enron scandal](https://en.wikipedia.org/wiki/Enron_scandal) was an accounting scandal sparked by the American energy company Enron Corporation filing for bankruptcy after news of widespread internal fraud became public in October 2001. After the company's stock price plummeted from a high of US$90.75 per share in mid-2000 to less than $1 by the end of November 2001, shareholders filed a $40 billion lawsuit and were eventually partially compensated with $7.2 billion. The email data were made public, and posted to the web, by the Federal Energy Regulatory Commission during its investigation.

Let's start by loading the data.

In [None]:
with open('../data/enron/enron_connected_subgraph.pkl', 'rb') as f:
    G = pickle.load(f)

labels = [G.nodes[node]['label'] for node in G.nodes()]
class_names = []

with open('../data/enron/label_mapping.csv', 'r') as f:
    # skip the header
    f.readline()
    class_names = [line.strip().split(',')[1] for line in f.readlines()]

print("Distinct class names:\n",class_names)
enron_class_names = class_names

print("Number of nodes in the graph:", len(G.nodes()))

### 2.1 Data Visualization and Exploration

Understanding the structure of our data is crucial before applying machine learning models. We'll examine:

1. **Class Distribution**: How balanced are the 5 employee categories?
2. **Network Structure**: Visual representation of the email network with node colors representing classes

**Visualization Tools**:
- **Matplotlib/Seaborn**: For statistical plots and class distributions
- **NetworkX + Pygraphviz**: For network layout and visualization ([NetworkX docs](https://networkx.org/documentation/stable/), [Pygraphviz docs](https://pygraphviz.github.io/documentation/stable/))
- **Interactive features**: Zoom capabilities if supported by the visualization backend

In [None]:
PLOT_NETWORK_VISUALIZATION = False

if PLOT_NETWORK_VISUALIZATION:
    print("Creating network visualization, this may take several minutes...")

    # Prepare node colors based on classes
    node_colors = [G.nodes[node]['label'] for node in G.nodes()]  # look up each node's label by node ID
    color_palette = sns.color_palette("husl", len(class_names))
    node_color_map = [color_palette[color] for color in node_colors]

    # Create network visualization
    plt.figure(figsize=(12,8))

# Choose layout based on pygraphviz availability
if PYGRAPHVIZ_AVAILABLE and PLOT_NETWORK_VISUALIZATION:
    print("Using Pygraphviz for high-quality layout...")
    try:
        # Use graphviz layout for better visualization
        pos = nx.nx_agraph.graphviz_layout(G, prog='sfdp', args='-Goverlap=false -Gsplines=true')
        layout_used = "Graphviz SFDP"
    except Exception as e:
        print(f"Graphviz layout failed: {e}. Falling back to spring layout.")
        pos = nx.spring_layout(G, k=1, iterations=50, seed=42)
        layout_used = "Spring Layout (fallback)"
elif PLOT_NETWORK_VISUALIZATION:
    print("Using NetworkX spring layout...")
    pos = nx.spring_layout(G, k=1, iterations=50, seed=42)
    layout_used = "Spring Layout"

# Draw the network
if PLOT_NETWORK_VISUALIZATION:
    nx.draw_networkx_nodes(G, pos, 
                          node_color=node_color_map,
                          node_size=50,
                          alpha=0.8,
                          linewidths=0.5,
                          edgecolors='black')

    nx.draw_networkx_edges(G, pos, alpha=0.3, width=0.3, edge_color='gray')

    # Add node labels (showing node IDs)
    # For readability, only show labels for a subset of nodes
    if len(G.nodes()) <= 500:  # Show all labels for smaller graphs
        nx.draw_networkx_labels(G, pos, font_size=6, font_color='black', alpha=0.7)
    else:  # Show labels for every 10th node to avoid overcrowding
        label_dict = {node: str(node) for i, node in enumerate(G.nodes()) if i % 10 == 0}
        nx.draw_networkx_labels(G, pos, labels=label_dict, font_size=6, font_color='black', alpha=0.7)

    plt.title(f'Enron Employee Network\n{G.number_of_nodes()} nodes, {G.number_of_edges()} edges\nLayout: {layout_used}', 
        fontsize=16, fontweight='bold', pad=20)

    # Create legend
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                             markerfacecolor=color_palette[i], markersize=10,
                             label=class_names[i].replace('_', ' '))
                        for i in range(len(class_names))]

    plt.legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1.02, 1), 
            title='Role', title_fontsize=12, fontsize=10)

    plt.axis('off')
    plt.tight_layout()

    # Save high-resolution version
    plt.savefig(output_dir / 'enron_network.png', dpi=300, bbox_inches='tight')
    plt.show()


Now, let's examine the class distribution of the nodes.

In [None]:
# Total number of labeled nodes, used to compute class percentages
denom = len(labels)
class_counts = Counter(labels)

# Create class distribution visualization
fig, ax1 = plt.subplots(1, 1, figsize=(8, 6))

# Bar plot
bars = ax1.bar(range(len(enron_class_names)), [class_counts[i] for i in range(len(enron_class_names))], 
               color=sns.color_palette("husl", len(enron_class_names)), alpha=0.8)
ax1.set_xlabel('Role', fontsize=12)
ax1.set_ylabel('Number of Employees', fontsize=12)
ax1.set_title('Enron: Class Distribution', fontsize=14, fontweight='bold')
ax1.set_xticks(range(len(enron_class_names)))
ax1.set_xticklabels([name.replace('_', '\n') for name in enron_class_names], rotation=45, ha='right')

# Add value labels on bars
for i, bar in enumerate(bars):
    height = bar.get_height()
    ax1.text(bar.get_x() + bar.get_width()/2., height + 5,
             f'{int(height)} ({(height/denom)*100:.1f}%)', ha='center', va='bottom', fontsize=10)

plt.tight_layout()
plt.savefig(output_dir / 'enron_class_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

### 2.2 Graph-Based Feature Engineering: Node Centrality & Group Measures

Node centrality measures capture the structural importance of nodes within the network. These features encode different aspects of a node's position and influence in the network. Node group measures provide an indication of the size and structure of communities that form within the graph. We anticipate that employees with different roles will have different influence within the network and belong to different groups, hence centrality and group metrics should be predictive of employee role.

### Centrality and Group Measures Overview

1. **Degree Centrality**: Number of connections, normalized by the maximum possible degree (n - 1)
   - *Interpretation*: The fraction of other employees with whom this person exchanges email
   - *NetworkX*: [`nx.degree_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.degree_centrality.html)

2. **Eigenvector Centrality**: Importance based on connections to other important nodes
   - *Interpretation*: An employee's importance derived from the importance of the people with whom they exchange email
   - *NetworkX*: [`nx.eigenvector_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.eigenvector_centrality.html)

3. **PageRank**: Google's algorithm measuring node importance via random walks
   - *Interpretation*: Employees likely to be reached by a random walk along email links
   - *NetworkX*: [`nx.pagerank()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.link_analysis.pagerank_alg.pagerank.html)

4. **Clustering Coefficient**: Measure of local network density around a node
   - *Interpretation*: The fraction of an employee's contacts who also email each other; high values indicate tightly connected communities
   - *NetworkX*: [`nx.clustering()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.cluster.clustering.html)

5. **Betweenness Centrality**: Frequency of node appearing on shortest paths
   - *Interpretation*: Employees who bridge otherwise separate roles or departments
   - *NetworkX*: [`nx.betweenness_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.betweenness_centrality.html)

6. **Largest Clique**: The number of nodes in the largest maximal clique containing a given node
   - *Interpretation*: The size of the largest group that includes this employee and in which every pair of members has exchanged email
   - *NetworkX*: [`nx.node_clique_number()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.clique.node_clique_number.html#networkx.algorithms.clique.node_clique_number)

7. **Clique Count**: The number of maximal cliques containing the node
   - *Interpretation*: The number of distinct groups this employee belongs to
   - *NetworkX*: [`nx.number_of_cliques()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.clique.number_of_cliques.html)

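**Note**: Degree and betweenness centrality are normalized by NetworkX, PageRank values sum to 1, and the two clique measures are raw counts. Any NaN values are replaced with 0, and all features are standardized before model fitting.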
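To build intuition for how these measures differ, the sketch below computes the five centrality measures on a small invented graph: a triangle (nodes 0, 1, 2) with a tail 2-3-4 attached at node 2. The graph is small enough to check by hand; for example, node 2 has degree 3 out of a possible 4, so its degree centrality is 0.75, and it lies on every shortest path between {0, 1} and {3, 4}, which gives it the highest betweenness.

```python
import networkx as nx

# Toy graph: a triangle 0-1-2 with a tail 2-3-4 attached at node 2
G_toy = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)])

measures = {
    "degree": nx.degree_centrality(G_toy),            # degree / (n - 1)
    "eigenvector": nx.eigenvector_centrality(G_toy, max_iter=1000),
    "pagerank": nx.pagerank(G_toy, alpha=0.85),
    "clustering": nx.clustering(G_toy),               # fraction of neighbor pairs that are linked
    "betweenness": nx.betweenness_centrality(G_toy),  # normalized=True by default
}

# One row per node, one entry per measure
for v in sorted(G_toy):
    print(v, {name: round(vals[v], 3) for name, vals in measures.items()})
```

Nodes 0 and 1 have a clustering coefficient of 1 (their two neighbors are linked) but zero betweenness, which illustrates that "local density" and "bridging" capture different structural roles.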

In [None]:
# Initialize progress tracking
centrality_measures = {}
computation_times = {}

import time

# 1. Degree Centrality
print("\n[1/7] Computing degree centrality...")
start_time = time.time()
degree_cent = nx.degree_centrality(G)
computation_times['degree'] = time.time() - start_time
print(f"   Completed in {computation_times['degree']:.2f} seconds")

# 2. Eigenvector Centrality (with error handling)
print("\n[2/7] Computing eigenvector centrality...")
start_time = time.time()
try:
    eigen_cent = nx.eigenvector_centrality(G, max_iter=1000, tol=1e-06)
    computation_times['eigenvector'] = time.time() - start_time
    print(f"   Completed in {computation_times['eigenvector']:.2f} seconds")
except nx.NetworkXError as e:
    print(f"   Warning: Eigenvector centrality failed ({e}). Using zeros.")
    eigen_cent = {node: 0.0 for node in G.nodes()}
    computation_times['eigenvector'] = time.time() - start_time

# 3. PageRank
print("\n[3/7] Computing PageRank...")
start_time = time.time()
pagerank_cent = nx.pagerank(G, alpha=0.85, max_iter=1000, tol=1e-06)
computation_times['pagerank'] = time.time() - start_time
print(f"   Completed in {computation_times['pagerank']:.2f} seconds")

# 4. Clustering Coefficient
print("\n[4/7] Computing clustering coefficient...")
start_time = time.time()
clustering_coeff = nx.clustering(G)
computation_times['clustering'] = time.time() - start_time
print(f"   Completed in {computation_times['clustering']:.2f} seconds")

# 5. Betweenness Centrality (most computationally expensive)
print("\n[5/7] Computing betweenness centrality...")
print("   This is the most computationally intensive measure...")
start_time = time.time()
betweenness_cent = nx.betweenness_centrality(G, normalized=True, seed=RANDOM_STATE)
computation_times['betweenness'] = time.time() - start_time
print(f"   Completed in {computation_times['betweenness']:.2f} seconds")

# 6. Largest Clique
print("\n[6/7] Computing largest clique...")
start_time = time.time()
largest_clique = nx.node_clique_number(G)
computation_times['largest_clique'] = time.time() - start_time
print(f"   Completed in {computation_times['largest_clique']:.2f} seconds")

# 7. Clique Count
print("\n[7/7] Computing clique count...")
start_time = time.time()
clique_count = nx.number_of_cliques(G)
computation_times['clique_count'] = time.time() - start_time
print(f"   Completed in {computation_times['clique_count']:.2f} seconds")

print(f"\nTotal computation time: {sum(computation_times.values()):.2f} seconds")
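
As we understand the NetworkX implementation, `nx.node_clique_number` and `nx.number_of_cliques` each enumerate the graph's maximal cliques on their own when called this way, so the cell above likely does that work twice. A possible alternative, sketched below, is to enumerate the maximal cliques once with `nx.find_cliques` (which requires an undirected graph) and derive both per-node measures from that single pass. The helper name `clique_features` is our own, and the toy graph is the triangle-plus-tail example.

```python
from collections import Counter

import networkx as nx

def clique_features(G):
    """Hypothetical helper: per-node largest maximal clique size and maximal-clique count,
    computed from a single enumeration of maximal cliques."""
    largest = {v: 0 for v in G}   # start every node at 0
    count = Counter()
    for clique in nx.find_cliques(G):   # each maximal clique is yielded once
        size = len(clique)
        for v in clique:
            count[v] += 1
            largest[v] = max(largest[v], size)
    return largest, {v: count[v] for v in G}

# Toy check: the maximal cliques are {0, 1, 2}, {2, 3}, and {3, 4}
G_toy = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)])
largest, count = clique_features(G_toy)
print(largest)
print(count)
```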



Before we use these metrics to develop a machine learning model, let's examine their distributions to get a sense of their variance and range. Recall that a feature must vary across samples to inform a statistical learning model. However, variance alone does not guarantee good model performance; the features must also carry information that is associated with the outcome label.
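Both conditions can be checked quickly before modeling. The sketch below uses scikit-learn's `VarianceThreshold` to drop a constant feature and `mutual_info_classif` to estimate how much information each remaining feature carries about the label. The data are synthetic (one informative feature, one pure-noise feature, one constant feature), so the exact mutual-information values are only illustrative.

```python
import numpy as np
from sklearn.feature_selection import VarianceThreshold, mutual_info_classif

rng = np.random.default_rng(0)
n = 500
y = rng.integers(0, 2, size=n)                      # binary label

informative = y + rng.normal(scale=0.5, size=n)     # shifts with the label
noise = rng.normal(size=n)                          # has variance but no relation to y
constant = np.zeros(n)                              # no variance at all
X = np.column_stack([informative, noise, constant])

# Step 1: remove features with zero variance (the default threshold)
selector = VarianceThreshold()
X_var = selector.fit_transform(X)
print("kept columns:", selector.get_support(indices=True))

# Step 2: estimate mutual information between each remaining feature and the label
mi = mutual_info_classif(X_var, y, random_state=0)
print("mutual information (nats):", np.round(mi, 3))
```

The noise column passes the variance check but should have near-zero mutual information, which is exactly the "variance is necessary but not sufficient" point.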

First, we'll organize the features into a numpy array and then a Pandas DataFrame for convenience with our visualization and model development.
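Because each NetworkX measure returns a dictionary keyed by node, another option is to pass those dictionaries straight to `pd.DataFrame`, which aligns them on the node keys so every row is guaranteed to describe the same node; labels stored as a node attribute can be attached the same way. The sketch below is an alternative to the manual array construction in the next cell, using a toy graph with a made-up `label` node attribute.

```python
import networkx as nx
import pandas as pd

# Toy graph with a hypothetical integer 'label' attribute on each node
G_toy = nx.Graph([(0, 1), (0, 2), (1, 2), (2, 3), (3, 4)])
nx.set_node_attributes(G_toy, {0: 0, 1: 0, 2: 1, 3: 1, 4: 2}, name="label")

# Each dict is keyed by node, so pandas aligns the columns on the node index
features = pd.DataFrame({
    "degree": nx.degree_centrality(G_toy),
    "clustering": nx.clustering(G_toy),
    "betweenness": nx.betweenness_centrality(G_toy),
}).sort_index()

# Attach labels from the node attributes; assignment aligns on the same index
features["label"] = pd.Series(nx.get_node_attributes(G_toy, "label"))

X = features.drop(columns="label").to_numpy()
y = features["label"].to_numpy()
print(features)
```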

In [None]:
# Fix a single (sorted) node order so every feature array lines up row by row
node_order = sorted(G.nodes())

# Create centrality feature matrix
centrality_features = {
    'degree': np.array([degree_cent[node] for node in node_order]),
    'eigenvector': np.array([eigen_cent[node] for node in node_order]),
    'pagerank': np.array([pagerank_cent[node] for node in node_order]),
    'clustering': np.array([clustering_coeff[node] for node in node_order]),
    'betweenness': np.array([betweenness_cent[node] for node in node_order]),
    'largest_clique': np.array([largest_clique[node] for node in node_order]),
    'clique_count': np.array([clique_count[node] for node in node_order])
}

# Check for and replace NaN values
print("\nChecking for NaN values...")
for feature_name, feature_values in centrality_features.items():
    nan_count = np.isnan(feature_values).sum()
    if nan_count > 0:
        print(f"   {feature_name}: {nan_count} NaN values found - replacing with 0")
        centrality_features[feature_name] = np.nan_to_num(feature_values, nan=0.0)
    else:
        print(f"   {feature_name}: No NaN values found")

# Create combined centrality feature matrix
centrality_matrix = np.column_stack([
    centrality_features['degree'],
    centrality_features['eigenvector'],
    centrality_features['pagerank'],
    centrality_features['clustering'],
    centrality_features['betweenness'],
    centrality_features['largest_clique'],
    centrality_features['clique_count']
])

print(f"\nCentrality feature matrix shape: {centrality_matrix.shape}")

# Compute statistics
feature_names = ['Degree', 'Eigenvector', 'PageRank', 'Clustering', 'Betweenness', 'Largest Clique', 'Clique Count']
feature_keys = [k for k in centrality_features.keys()]
stats_df = pd.DataFrame({
    'Feature': feature_names,
    'Mean': [np.mean(centrality_features[name]) for name in feature_keys],
    'Std': [np.std(centrality_features[name]) for name in feature_keys],
    'Min': [np.min(centrality_features[name]) for name in feature_keys],
    'Max': [np.max(centrality_features[name]) for name in feature_keys],
    'Zeros': [np.sum(centrality_features[name] == 0) for name in feature_keys]
})

print("\n=== Centrality Measures Statistics ===")
print(stats_df.round(4))

Now, let's create histograms and kernel density estimates for each of the node metrics.

In [None]:
# Set up the plot array (2x4 grid)
fig, axes = plt.subplots(2, 4, figsize=(18, 16))
axes = axes.ravel()  # Flatten for easier indexing

colors = sns.color_palette("husl", len(feature_names))

for i, (feature_name, color) in enumerate(zip(feature_names, colors)):
    feature_key = feature_keys[i]
    values = centrality_features[feature_key]
    
    # Create histogram with density curve
    axes[i].hist(values, bins=50, density=True, alpha=0.7, color=color, edgecolor='black', linewidth=0.5)
    
    # Add density curve
    if np.std(values) > 0:  # Only if there's variation
        kde_x = np.linspace(values.min(), values.max(), 100)
        try:
            from scipy import stats
            kde = stats.gaussian_kde(values)
            axes[i].plot(kde_x, kde(kde_x), color='darkred', linewidth=2, alpha=0.8)
        except ImportError:
            # Fallback if scipy not available
            pass
    
    # Customize subplot
    axes[i].set_title(f'{feature_name}\nMean: {np.mean(values):.4f}, Std: {np.std(values):.4f}', 
                     fontsize=12, fontweight='bold')
    axes[i].set_xlabel(f'{feature_name} Value', fontsize=10)
    axes[i].set_ylabel('Density', fontsize=10)
    axes[i].grid(True, alpha=0.3)
    
    # Add vertical line for mean
    axes[i].axvline(np.mean(values), color='red', linestyle='--', linewidth=2, alpha=0.8, label=f'Mean: {np.mean(values):.4f}')
    axes[i].legend(fontsize=9)

# Remove the empty subplot
fig.delaxes(axes[7])

# Add overall title
fig.suptitle('Node Centrality & Group Measures Distribution\nEnron Network', 
             fontsize=16, fontweight='bold', y=0.98)

plt.tight_layout()
plt.savefig(output_dir / 'enron_centrality_group_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

# Create correlation heatmap of centrality measures
plt.figure(figsize=(10, 8))

# Create correlation matrix
centrality_df = pd.DataFrame(centrality_matrix, columns=feature_names)
correlation_matrix = centrality_df.corr()

# Create heatmap
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))  # Mask upper triangle
sns.heatmap(correlation_matrix, mask=mask, annot=True, cmap='coolwarm', center=0,
            square=True, fmt='.3f', cbar_kws={"shrink": .8})

plt.title('Centrality & Group Measures Correlation Matrix\nEnron Network', 
          fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(output_dir / 'enron_centrality_group_correlations.png', dpi=300, bbox_inches='tight')
plt.show()

### 2.3 Logistic Regression with Node Metrics

We'll train and evaluate **logistic regression and random forest models** using the centrality measures to predict the employee role for a node.

### Methodology

1. **Data Splitting**: Split the data into an 80% training set and a 20% test set. Hyperparameters are tuned with cross-validation within the training set.
2. **Model 1**: [Logistic Regression with regularization](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html)
3. **Hyperparameter Tuning**: [Cross-validation](https://scikit-learn.org/stable/modules/cross_validation.html) on training set to optimize regularization parameter (C) and regularization penalty. 
4. **Cross-validation Evaluation Metric**: [Balanced accuracy](https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html) to handle class imbalance
5. **Performance Analysis**: Classification reports, confusion matrices, and ROC-AUC curves
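Balanced accuracy is the average of the per-class recalls, so every class counts equally regardless of how many nodes it contains. The sketch below checks this definition against `balanced_accuracy_score` on a small, made-up imbalanced example in which a classifier that leans toward the majority class looks better on plain accuracy.

```python
import numpy as np
from sklearn.metrics import accuracy_score, balanced_accuracy_score, recall_score

# Imbalanced toy labels: 8 examples of class 0, 2 of class 1, 2 of class 2
y_true = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2])

per_class_recall = recall_score(y_true, y_pred, average=None)   # one recall per class
print("per-class recall :", per_class_recall)                   # class 0: 8/8, class 1: 1/2, class 2: 1/2
print("mean of recalls  :", per_class_recall.mean())
print("balanced accuracy:", balanced_accuracy_score(y_true, y_pred))
print("plain accuracy   :", accuracy_score(y_true, y_pred))     # 10 of 12 correct
```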

First, let's split the data.

In [None]:
# Data splitting
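X = centrality_matrix.copy()
# Build labels in the same sorted node order used for the feature matrix
y = np.array([G.nodes[node]['label'] for node in node_order])

print(f"Feature matrix shape: {X.shape} | Labels shape: {y.shape}")

# Split the data into train/test sets (80%/20%)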
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=0.2,
                                                    stratify=y,
                                                    random_state=RANDOM_STATE)

# Standardize the features
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

print(f"Training: {X_train.shape[0]} samples | Test: {X_test.shape[0]} samples")

Now let's develop the model. We'll use a grid search with stratified 5-fold cross-validation to select the best regularization strength (`C`) and penalty type.
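Not every solver supports every penalty: `liblinear` handles `'l1'` and `'l2'`, while elastic-net requires the `'saga'` solver together with an `l1_ratio`. If you want to search over penalty types that need different solvers, `GridSearchCV` accepts a list of grids, and each dictionary is searched separately. The sketch below illustrates this on the binary breast cancer data; the grid values are arbitrary illustrative choices.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.preprocessing import StandardScaler

X_bc, y_bc = load_breast_cancer(return_X_y=True)
X_bc = StandardScaler().fit_transform(X_bc)   # scaled up front only to keep the sketch short

# Each dict is a separate sub-grid, so solvers are only paired with penalties they support
param_grid = [
    {"solver": ["liblinear"], "penalty": ["l1", "l2"], "C": [0.1, 1.0, 10.0]},
    {"solver": ["saga"], "penalty": ["elasticnet"], "l1_ratio": [0.2, 0.8], "C": [0.1, 1.0, 10.0]},
]

search = GridSearchCV(
    LogisticRegression(max_iter=5000),
    param_grid,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=0),
    scoring="balanced_accuracy",
)
search.fit(X_bc, y_bc)

print("Candidates evaluated:", len(search.cv_results_["params"]))   # 2*3 + 2*3 = 12
print("Best parameters:", search.best_params_)
```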

In [None]:
# Define the hyperparameter search space
param_grid = {'C': [0.01, 0.1, 1.0, 10.0, 100.0],   # Inverse regularization strength
             'penalty': ['l1', 'l2']}               # liblinear supports only L1 and L2

# Create base model; for multiclass targets, liblinear fits one binary
# one-vs-rest model per class
base_model = LogisticRegression(
    solver='liblinear',  # Good for small datasets; supports L1 and L2 penalties
    random_state=RANDOM_STATE,
    max_iter=1000
)

# Grid search with balanced accuracy scoring
grid_search = GridSearchCV(
    base_model,
    param_grid,
    cv=StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE),
    scoring='balanced_accuracy',
    n_jobs=-1,
    verbose=0,
    refit=True  # Refit the best model on the entire training set
)

# Fit grid search
start_time = time.time()
grid_search.fit(X_train, y_train)
end_time = time.time()

# Store the best model
best_model = grid_search.best_estimator_

print(f"Best C: {grid_search.best_params_['C']}, Best penalty: {grid_search.best_params_['penalty']}")
print(f"Best CV balanced accuracy: {grid_search.best_score_:.4f}")
print(f"Tuning completed in {end_time - start_time:.2f} seconds")

In reality, we would likely consider many alternative models during the training and evaluation process. For purposes of this lab, let's assume we've convinced ourselves that the logistic regression model is "good enough" and we'll use the optimal tuning parameters. 

Let's examine the test set results. We start with a classification report.

In [None]:
# Make predictions
y_test_pred = best_model.predict(X_test)

# Pretty print classification report
print(classification_report(y_test, y_test_pred, target_names=class_names, zero_division=0))

Let's also visualize the test results with a confusion matrix and per-class ROC curves.
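The confusion matrix in the next cell is computed with `normalize='pred'`, which divides each column by the number of predictions for that class, so the diagonal reads as per-class precision. `normalize='true'` instead divides each row by the number of true examples, so the diagonal reads as per-class recall. The toy example below, with made-up labels, shows the difference.

```python
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score, recall_score

y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 2, 2])

cm_pred = confusion_matrix(y_true, y_pred, normalize="pred")  # each column sums to 1
cm_true = confusion_matrix(y_true, y_pred, normalize="true")  # each row sums to 1

print("column-normalized:\n", cm_pred.round(2))
print("row-normalized:\n", cm_true.round(2))
print("diag(pred) == precision:",
      np.allclose(np.diag(cm_pred), precision_score(y_true, y_pred, average=None)))
print("diag(true) == recall   :",
      np.allclose(np.diag(cm_true), recall_score(y_true, y_pred, average=None)))
```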

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(8, 6))
cm = confusion_matrix(y_test, y_test_pred, normalize='pred')
sns.heatmap(cm, annot=True, fmt='.2f', cmap='Blues', cbar=False, ax=axes)
axes.set_xlabel('Predicted Label')
axes.set_ylabel('True Label')
axes.set_title('Enron Confusion Matrix', fontsize=14, fontweight='bold')
plt.xticks([_+0.5 for _ in range(len(enron_class_names))], enron_class_names, rotation=45, ha='center');
plt.yticks([_+0.5 for _ in range(len(enron_class_names))], enron_class_names, rotation=45, ha='right');

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'red', 'green', 'purple', 'brown'])

# Binarize the output for multi-class ROC
y_test_bin = label_binarize(y_test, classes=list(range(len(enron_class_names))))
n_classes = y_test_bin.shape[1]

# Compute ROC curve and AUC for each class
fpr = dict()
tpr = dict()
roc_auc = dict()

# get the test set probabilities
y_test_proba = best_model.predict_proba(X_test)

for j in range(n_classes):
    fpr[j], tpr[j], _ = roc_curve(y_test_bin[:, j], y_test_proba[:, j])
    roc_auc[j] = auc(fpr[j], tpr[j])

for j, color in zip(range(n_classes), colors):
    ax.plot(fpr[j], tpr[j], color=color, lw=2,
            label=f'{class_names[j].replace("_", " ")} (AUC = {roc_auc[j]:.2f})')

# Plot random classifier line
ax.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.5, label='Random (AUC = 0.50)')

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=11)
ax.set_ylabel('True Positive Rate', fontsize=11)
ax.set_title(f'Enron Network\nPer-Class ROC Curves',fontsize=12, fontweight='bold')
ax.legend(loc="lower right", fontsize=9)
ax.grid(True, alpha=0.3)


Finally, let's look at feature importance by examining the coefficients of the logistic regression model. Because we standardized the input features, the absolute values of the coefficients give a rough measure of global feature importance for the model. The sign of a coefficient indicates whether an increase in the feature increases the probability of class membership (positive coefficient) or decreases it (negative coefficient). Note also that our multi-class model uses a one-vs-rest approach, so there is actually one binary model per class and hence one set of coefficients per class.
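To see where the multiple sets of coefficients come from, the sketch below fits an explicit one-vs-rest model with `OneVsRestClassifier` (one binary logistic regression per class) on the three-class iris data, stacks the per-class coefficient vectors into an `(n_classes, n_features)` array, and summarizes global importance as the mean absolute coefficient across classes. Averaging absolute values is one reasonable summary among several, not a standard definition.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X_iris = StandardScaler().fit_transform(iris.data)   # standardize so coefficient magnitudes are comparable
y_iris = iris.target

# One binary "class k vs. all other classes" model per class
ovr = OneVsRestClassifier(LogisticRegression(max_iter=1000)).fit(X_iris, y_iris)

# Each binary estimator has coef_ of shape (1, n_features); stack them into (n_classes, n_features)
coefs = np.vstack([est.coef_ for est in ovr.estimators_])
coef_df = pd.DataFrame(coefs, index=iris.target_names, columns=iris.feature_names)
print(coef_df.round(3))

# One possible global summary: mean |coefficient| across the per-class models
importance = coef_df.abs().mean(axis=0).sort_values(ascending=False)
print(importance.round(3))
```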

In [None]:
coeffs = best_model.coef_
# This is multi-class classification, so coeffs will have shape (n_classes, n_features)
print("Logistic Regression Coefficients Shape:")
print(coeffs.shape)

# create a DataFrame for better visualization
coeffs_df = pd.DataFrame(coeffs, columns=feature_names, index=class_names)
print("\nLogistic Regression Coefficients:")
print(coeffs_df.round(4))
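A coefficient table like this is easier to read if you rank the features within each class by absolute coefficient. The sketch below assumes the coefficients come as an `(n_classes, n_features)` array, as `coef_` does here. The helper name `top_features_per_class` and the example numbers are illustrative only, not fitted values.

```python
import numpy as np

def top_features_per_class(coef, feature_names, class_names, k=2):
    """For each class (row of coef), list the k features with the largest |coefficient|."""
    coef = np.asarray(coef, dtype=float)
    ranking = {}
    for row, cls in zip(coef, class_names):
        order = np.argsort(-np.abs(row))[:k]  # indices sorted by descending magnitude
        ranking[cls] = [(feature_names[i], float(row[i])) for i in order]
    return ranking

# Made-up coefficients for 2 classes and 3 standardized features
example_coef = [[0.5, -1.2, 0.1],
                [-0.3, 0.2, 0.9]]
example = top_features_per_class(example_coef, ['degree', 'betweenness', 'pagerank'],
                                 ['class_a', 'class_b'], k=2)
for cls, feats in example.items():
    print(cls, feats)
```

With the fitted model, the same helper could be called as `top_features_per_class(best_model.coef_, feature_names, class_names)`.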

### 2.4 Random Forest with Graph Metrics
Now let's create a random forest model to predict the role of Enron employees. We will use the same graph metrics as we did in our logistic regression model.

#### **Exercise 2.1**
Complete the code below to construct the `param_grid` dictionary so that the model training can evaluate the following hyperparameter values:
- n_estimators: 50, 100, 200
- max_depth: None, 5, 10
- criterion: 'gini', 'entropy'

In [None]:
# Exercise 2.1
param_grid = {
# Your code here
}

Now let's develop the Random Forest model.

#### **Exercise 2.2**
Complete the code below to fit the model to the data using the `grid_search` and to get the best estimator (i.e., the random forest model with the optimal hyperparameters).

In [None]:
base_model = RandomForestClassifier(random_state=RANDOM_STATE)

# Grid search with balanced accuracy scoring
grid_search = GridSearchCV(
    base_model,
    param_grid,
    cv=StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=RANDOM_STATE),
    scoring='balanced_accuracy',
    n_jobs=-1,
    verbose=0,
    refit=True  # Refit the best model on the entire training set
)

# Fit the model
grid_search.___ # Your code here

# Store the best model
best_model = ___ # Your code here

print(f"Best n_estimators: {grid_search.best_params_['n_estimators']}, Best max_depth: {grid_search.best_params_['max_depth']}, Best criterion: {grid_search.best_params_['criterion']}")
print(f"Best CV balanced accuracy: {grid_search.best_score_:.4f}")

Let's examine the performance on the test set. 

In [None]:
# Make predictions
y_test_pred = best_model.predict(X_test)

# Print classification report
print(classification_report(y_test, y_test_pred, target_names=class_names, zero_division=0))

Now let's examine the confusion matrix.

#### **Exercise 2.3**
Complete the code below to create the confusion matrix.

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(8, 6))

cm = ____ # Your code here

sns.heatmap(cm, annot=True, fmt='.2f', cmap='Blues', cbar=False, ax=axes)
axes.set_xlabel('Predicted Label')
axes.set_ylabel('True Label')
axes.set_title('Enron Confusion Matrix', fontsize=14, fontweight='bold')
plt.xticks([_+0.5 for _ in range(len(enron_class_names))], enron_class_names, rotation=45, ha='center');
plt.yticks([_+0.5 for _ in range(len(enron_class_names))], enron_class_names, rotation=45, ha='right');

Finally, let's also examine the ROC curves.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'red', 'green', 'purple', 'brown'])

# Binarize the output for multi-class ROC
y_test_bin = label_binarize(y_test, classes=list(range(len(enron_class_names))))
n_classes = y_test_bin.shape[1]

# Compute ROC curve and AUC for each class
fpr = dict()
tpr = dict()
roc_auc = dict()

y_test_proba = best_model.predict_proba(X_test)

for j in range(n_classes):
    fpr[j], tpr[j], _ = roc_curve(y_test_bin[:, j], y_test_proba[:, j])
    roc_auc[j] = auc(fpr[j], tpr[j])

for j, color in zip(range(n_classes), colors):
    ax.plot(fpr[j], tpr[j], color=color, lw=2,
            label=f'{class_names[j].replace("_", " ")} (AUC = {roc_auc[j]:.2f})')

# Plot random classifier line
ax.plot([0, 1], [0, 1], 'k--', lw=2, alpha=0.5, label='Random (AUC = 0.50)')

ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=11)
ax.set_ylabel('True Positive Rate', fontsize=11)
ax.set_title(f'Enron Network\nPer-Class ROC Curves',fontsize=12, fontweight='bold')
ax.legend(loc="lower right", fontsize=9)
ax.grid(True, alpha=0.3)

# 3. MUTAG Dataset Graph Classification

In this section, we'll explore graph classification using the MUTAG dataset, which contains molecular structures labeled as mutagenic or non-mutagenic. Unlike the previous section where we predicted properties of individual nodes (employees), here we'll predict properties of entire graphs (molecules).

The MUTAG dataset consists of 188 chemical compounds (graphs), each labeled as mutagenic (1) or non-mutagenic (0). Each node in a graph is one of 7 atom types, stored as a one-hot node feature vector:

| Index | Atom |
|-------|------|
| 0     | C    |
| 1     | N    |
| 2     | O    |
| 3     | F    |
| 4     | I    |
| 5     | Cl   |
| 6     | Br   |

### Key Differences from Node Classification:
1. **Unit of Analysis**: Entire graphs instead of individual nodes
2. **Features**: Graph-level metrics instead of node-level ones (though we'll also average node-level metrics over each graph; this type of aggregation is common)
3. **Task**: Binary classification of molecular mutagenicity
4. **Metrics**: Graph structural properties such as diameter and degree assortativity
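Point 2 above mentions aggregating node-level metrics into a single graph-level value. Below is a minimal sketch of that idea using NetworkX. This section itself only uses the mean. The `max` and `std` summaries, and the helper name `aggregate_node_metric`, are illustrative additions.

```python
import networkx as nx
import numpy as np

def aggregate_node_metric(G, metric_fn):
    """Summarize a per-node metric (dict node -> value) into graph-level statistics."""
    values = np.array(list(metric_fn(G).values()), dtype=float)
    return {'mean': values.mean(), 'max': values.max(), 'std': values.std()}

# Star graph: one hub (node 0) connected to 4 leaves
star = nx.star_graph(4)
summary = aggregate_node_metric(star, nx.degree_centrality)
print(summary)
```

Different summaries keep different information. The mean blurs a single dominant hub into the average, while the max keeps it.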

Let's start by loading the data.

In [None]:
# Load MUTAG dataset
from torch_geometric.datasets import TUDataset

# Load the dataset
dataset = TUDataset(root='../data/TUDataset', name='MUTAG')

print(f"Dataset: {dataset}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")

### 3.1 Data Visualization and Exploration

It's generally a good idea to get a sense of the structure of the graphs in our dataset before applying machine learning models. We'll examine:

1. **Sample Visualization**: Visual representation of molecular structures with atoms colored by type
2. **Class Distribution**: Balance between mutagenic and non-mutagenic compounds
3. **Graph Statistics**: Size characteristics (nodes/edges) across the dataset

**Key Insights Expected**:
- Molecular graphs are typically small (10-30 nodes) compared to other networks such as social networks
- Graph connectivity patterns may differ between mutagenic and non-mutagenic compounds

First, let's visualize sample molecular graphs from both classes, with nodes colored by atom type.

In [None]:
# Visualize sample graphs from both classes
def visualize_mutag_samples(dataset, num_samples=5, random_seed=42):
    """Visualize molecular graphs with nodes colored by atom type"""
    
    # Collect the indices of graphs in each class
    mutagenic_indices = [i for i, data in enumerate(dataset) if data.y.item() == 1]
    non_mutagenic_indices = [i for i, data in enumerate(dataset) if data.y.item() == 0]
    
    # Split num_samples roughly in half between the two classes
    n = num_samples // 2
    m = num_samples - n  # remaining samples come from the non-mutagenic class
    # Ensure we have enough samples
    if len(mutagenic_indices) < n or len(non_mutagenic_indices) < m:
        raise ValueError("Not enough samples in one of the classes to visualize.")
    selected_indices = mutagenic_indices[:n] + non_mutagenic_indices[:m]

    # Atom types in MUTAG, in one-hot index order
    atom_names = ['C', 'N', 'O', 'F', 'I', 'Cl', 'Br']
    atom_colors = []
    random.seed(random_seed)  # For reproducibility
    for _ in atom_names:
        # Random fallback color for any atom type not assigned explicitly below
        atom_colors.append((random.random(), random.random(), random.random()))
    # Override with traditional coloring from organic chemistry
    atom_colors[0] = (0, 0, 0)                # C is black
    atom_colors[1] = (0, 0, 1)                # N is blue
    atom_colors[2] = (1, 0, 0)                # O is red
    atom_colors[3] = (0, 1, 1)                # F is cyan
    atom_colors[4] = (148/255, 0, 211/255)    # I is dark violet
    atom_colors[5] = (0, 1, 0)                # Cl is green
    atom_colors[6] = (150/255, 75/255, 0)     # Br is brown

    # Create subplots with at most 5 samples per row
    num_rows = max(1, (num_samples + 4) // 5)
    fig, axes = plt.subplots(num_rows, 5, figsize=(20, 4 * num_rows))
    # plt.subplots returns an ndarray of Axes here (5 columns), so always flatten it
    axes = np.asarray(axes).flatten()

    for idx, graph_idx in enumerate(selected_indices):
        data = dataset[graph_idx]
        
        # Convert to NetworkX
        G = to_networkx(data, to_undirected=True, remove_self_loops=True)
        
        # Get atom types (argmax of one-hot encoding)
        atom_types = data.x.argmax(dim=1).numpy()
        node_colors = [atom_colors[atom_type] for atom_type in atom_types]
        
        # Create layout
        pos = nx.spring_layout(G, seed=42, k=1, iterations=50)
        
        # Draw the graph
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=300, 
                              alpha=0.8, ax=axes[idx])
        nx.draw_networkx_edges(G, pos, alpha=0.6, width=2, edge_color='gray', ax=axes[idx])
        nx.draw_networkx_labels(G, pos, font_size=8, font_color='black', ax=axes[idx])
        
        # Set title
        label = 'Mutagenic' if data.y.item() == 1 else 'Non-mutagenic'
        axes[idx].set_title(f'{label}\n{data.x.size(0)} nodes, {data.edge_index.size(1)//2} edges', 
                           fontsize=12, fontweight='bold')
        axes[idx].axis('off')
    
    # add a legend for atom types
    legend_elements = [plt.Line2D([0], [0], marker='o', color='w', 
                                   markerfacecolor=atom_colors[i], markersize=10,
                                   label=f'{atom_names[i]}')
                       for i in range(len(atom_colors))]
    axes[-1].legend(handles=legend_elements, loc='upper left', bbox_to_anchor=(1.05, 1), 
                    title='Atom Types', title_fontsize=12, fontsize=10)
    plt.suptitle('MUTAG Dataset: Sample Molecular Graphs', fontsize=16, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.savefig(output_dir / 'mutag_sample_graphs.png', dpi=300, bbox_inches='tight')
    plt.show()

# Visualize samples
visualize_mutag_samples(dataset, 10, 42)
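The function above recovers each node's atom type with `data.x.argmax(dim=1)`. This works because every node's feature vector is one-hot, so the position of the single 1 is the atom index. Below is a NumPy sketch of that decoding. It uses a hypothetical 3-symbol vocabulary; the real index-to-atom mapping is the table at the start of this section.

```python
import numpy as np

# Hypothetical vocabulary for illustration only
example_atoms = ['C', 'N', 'O']

# Each row is one node's one-hot feature vector (like data.x in PyTorch Geometric)
x = np.array([[1, 0, 0],
              [0, 0, 1],
              [1, 0, 0],
              [0, 1, 0]])

atom_idx = x.argmax(axis=1)          # column position of the 1 in each row
symbols = [example_atoms[i] for i in atom_idx]
print(atom_idx, symbols)
```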

Next, let's examine the balance between mutagenic and non-mutagenic compounds in the dataset.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
mutag_class_names = ['Non-mutagenic', 'Mutagenic']
total_samples = len(dataset)

# Count samples in each class
class_counts = [0, 0]
for data in dataset:
    class_counts[data.y.item()] += 1

# Bar plot
colors = ['#FF6B6B', '#4ECDC4']  # Red for non-mutagenic, teal for mutagenic
bars = ax.bar(range(len(mutag_class_names)), [class_counts[i] for i in range(len(mutag_class_names))], 
              color=colors, alpha=0.8, edgecolor='black', linewidth=1)

ax.set_xlabel('Class', fontsize=12)
ax.set_ylabel('Number of Compounds', fontsize=12)
ax.set_title('MUTAG Dataset: Class Distribution', fontsize=14, fontweight='bold')
ax.set_xticks(range(len(mutag_class_names)))
ax.set_xticklabels(mutag_class_names)

# Add value labels on bars
for i, bar in enumerate(bars):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 1,
            f'{int(height)} ({(height/total_samples)*100:.1f}%)', 
            ha='center', va='bottom', fontsize=11, fontweight='bold')

ax.grid(True, alpha=0.3, axis='y')
plt.tight_layout()
plt.savefig(output_dir / 'mutag_class_distribution.png', dpi=300, bbox_inches='tight')
plt.show()
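Class balance sets the floor that any useful model must beat. A classifier that always predicts the majority class achieves an accuracy equal to the majority-class proportion. Below is a minimal sketch using illustrative labels, not the MUTAG counts; the helper name is hypothetical.

```python
import numpy as np

def majority_baseline_accuracy(labels):
    """Accuracy of a classifier that always predicts the most frequent class."""
    counts = np.bincount(np.asarray(labels))
    return counts.max() / counts.sum()

example_labels = [0, 1, 1, 1, 0, 1]   # illustrative labels
print(majority_baseline_accuracy(example_labels))
```

Applied to `[data.y.item() for data in dataset]`, this gives the baseline that the MUTAG models should be compared against. It is also one reason metrics such as balanced accuracy and ROC-AUC, used elsewhere in this lab, are more informative than raw accuracy on imbalanced data.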

Finally, let's analyze the structural properties of the molecular graphs, including size statistics.

In [None]:
# Analyze graph properties
def analyze_graph_properties(dataset):
    """Compute statistics about graph structure"""
    num_nodes = []
    num_edges = []
    labels_list = []
    
    for data in dataset:
        num_nodes.append(data.x.size(0))
        num_edges.append(data.edge_index.size(1) // 2)  # Divide by 2 for undirected edges
        labels_list.append(data.y.item())

    return np.array(num_nodes), np.array(num_edges), np.array(labels_list)

# Analyze properties
num_nodes, num_edges, labels_array = analyze_graph_properties(dataset)

# Create box plots of graph sizes
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Prepare data for box plots
nodes_by_class = [num_nodes[labels_array == i] for i in range(len(mutag_class_names))]
edges_by_class = [num_edges[labels_array == i] for i in range(len(mutag_class_names))]

# Nodes box plot
bp1 = axes[0].boxplot(nodes_by_class, labels=mutag_class_names, patch_artist=True)
colors = ['#FF6B6B', '#4ECDC4']
for patch, color in zip(bp1['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

axes[0].set_xlabel('Class', fontsize=12)
axes[0].set_ylabel('Number of Nodes', fontsize=12)
axes[0].set_title('Graph Size Distribution (Nodes)', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)

# Edges box plot
bp2 = axes[1].boxplot(edges_by_class, labels=mutag_class_names, patch_artist=True)
for patch, color in zip(bp2['boxes'], colors):
    patch.set_facecolor(color)
    patch.set_alpha(0.7)

axes[1].set_xlabel('Class', fontsize=12)
axes[1].set_ylabel('Number of Edges', fontsize=12)
axes[1].set_title('Graph Size Distribution (Edges)', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / 'mutag_graph_size_boxplots.png', dpi=300, bbox_inches='tight')
plt.show()
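The edge counts above divide `data.edge_index.size(1)` by 2. This works because PyTorch Geometric represents each undirected bond as two directed edges, one in each direction. The sketch below illustrates that convention, with a NumPy array standing in for the `torch` tensor. It assumes each undirected edge is stored exactly twice and there are no self-loops.

```python
import numpy as np

# Triangle on nodes 0, 1, 2; every undirected edge appears as (u, v) and (v, u)
edge_index = np.array([[0, 1, 1, 2, 2, 0],
                       [1, 0, 2, 1, 0, 2]])

num_directed = edge_index.shape[1]                      # number of columns
undirected = {tuple(sorted(e)) for e in edge_index.T}   # collapse (u, v) and (v, u)
print(num_directed, len(undirected), num_directed // 2)
```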

### 3.2 Graph-Based Feature Engineering: Graph Measures

Unlike node classification where we computed centrality measures for individual nodes, graph classification requires features that characterize entire graphs. These graph-level metrics capture different structural properties of molecular compounds that may be predictive of mutagenicity.

### Graph-Level Measures Overview

1. **Mean Node Degree Centrality**: Average connectivity across all nodes
   - *Interpretation*: Overall connectivity density of the molecule
   - *NetworkX*: [`nx.degree_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.degree_centrality.html)

2. **Degree Assortativity**: Tendency for nodes to connect to similar-degree nodes
   - *Interpretation*: Whether highly connected atoms tend to connect to other highly connected atoms
   - *NetworkX*: [`nx.degree_assortativity_coefficient()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.assortativity.degree_assortativity_coefficient.html)

3. **Graph Diameter**: Longest shortest path between any two nodes
   - *Interpretation*: Maximum "distance" across the molecule
   - *NetworkX*: [`nx.diameter()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.distance_measures.diameter.html)

4. **Average Betweenness Centrality**: Mean betweenness centrality across all nodes
   - *Interpretation*: Average tendency for atoms to lie on shortest paths between other atoms
   - *NetworkX*: [`nx.betweenness_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.betweenness_centrality.html)

5. **Average Eigenvector Centrality**: Mean eigenvector centrality across all nodes
   - *Interpretation*: Average importance of atoms based on connections to other important atoms
   - *NetworkX*: [`nx.eigenvector_centrality()`](https://networkx.org/documentation/stable/reference/algorithms/generated/networkx.algorithms.centrality.eigenvector_centrality.html)
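Two properties of these measures are worth keeping in mind when reading the results:

- For a simple undirected graph, the mean degree centrality is algebraically identical to the graph density.
- NetworkX scales eigenvector centrality to unit Euclidean norm, so its mean partly tracks graph size.

The sketch below checks both on toy graphs. The helper names are hypothetical.

```python
import networkx as nx
import numpy as np

def mean_degree_centrality(G):
    return np.mean(list(nx.degree_centrality(G).values()))

def mean_eigenvector_centrality(G):
    return np.mean(list(nx.eigenvector_centrality(G, max_iter=1000).values()))

# 1) Degree centrality is deg(v) / (n - 1), so its mean is 2m / (n (n - 1)),
#    which is the definition of density for a simple undirected graph.
path = nx.path_graph(5)   # n = 5, m = 4
print(mean_degree_centrality(path), nx.density(path))

# 2) On a cycle every node is equivalent, so with unit Euclidean norm each
#    eigenvector centrality equals 1 / sqrt(n); the mean shrinks as n grows.
for n in (6, 24):
    print(n, mean_eigenvector_centrality(nx.cycle_graph(n)), 1 / np.sqrt(n))
```

Because MUTAG molecules differ in size, these two averaged features may partly reflect molecule size rather than shape alone.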

Now let's compute these graph-level features for each molecular graph in the MUTAG dataset.

In [None]:
# Compute graph-level features for all MUTAG graphs
def compute_graph_features(dataset):
    """Compute graph-level features for each molecular graph"""
    
    # Initialize feature storage
    features = {
        'mean_degree_centrality': [],
        'degree_assortativity': [],
        'diameter': [],
        'mean_betweenness_centrality': [],
        'mean_eigenvector_centrality': [],
        'labels': []
    }
    
    print("Computing graph features...")
    print(f"Processing {len(dataset)} graphs...")
    
    for i, data in enumerate(dataset):
        if (i + 1) % 50 == 0 or i == 0:
            print(f"  Progress: {i+1}/{len(dataset)} graphs processed")
        
        # Convert PyTorch Geometric graph to NetworkX
        G = to_networkx(data, to_undirected=True, remove_self_loops=True)
        
        # Store label
        features['labels'].append(data.y.item())
        
        try:
            # 1. Mean node degree centrality
            degree_cent = nx.degree_centrality(G)
            features['mean_degree_centrality'].append(np.mean(list(degree_cent.values())))
            
            # 2. Degree assortativity coefficient
            try:
                assortativity = nx.degree_assortativity_coefficient(G)
                features['degree_assortativity'].append(assortativity if not np.isnan(assortativity) else 0.0)
            except Exception:
                features['degree_assortativity'].append(0.0)
            
            # 3. Graph diameter (handle disconnected graphs)
            try:
                if nx.is_connected(G):
                    diameter = nx.diameter(G)
                else:
                    # For disconnected graphs, use the maximum diameter of connected components
                    diameters = []
                    for component in nx.connected_components(G):
                        subgraph = G.subgraph(component)
                        if len(component) > 1:  # Avoid single-node components
                            diameters.append(nx.diameter(subgraph))
                    diameter = max(diameters) if diameters else 0
                features['diameter'].append(diameter)
            except Exception:
                features['diameter'].append(0)
            
            # 4. Average betweenness centrality
            try:
                betweenness_cent = nx.betweenness_centrality(G, normalized=True)
                features['mean_betweenness_centrality'].append(np.mean(list(betweenness_cent.values())))
            except Exception:
                features['mean_betweenness_centrality'].append(0.0)
            
            # 5. Average eigenvector centrality
            try:
                eigenvector_cent = nx.eigenvector_centrality(G, max_iter=1000, tol=1e-06)
                features['mean_eigenvector_centrality'].append(np.mean(list(eigenvector_cent.values())))
            except Exception:
                # Eigenvector centrality can fail for disconnected graphs or graphs with certain structures
                features['mean_eigenvector_centrality'].append(0.0)
            
        except Exception as e:
            print(f"  Error processing graph {i}: {e}")
            # Add default values for failed computations
            features['mean_degree_centrality'].append(0.0)
            features['degree_assortativity'].append(0.0)
            features['diameter'].append(0)
            features['mean_betweenness_centrality'].append(0.0)
            features['mean_eigenvector_centrality'].append(0.0)
    
    print(f"Feature computation completed!")
    return features

# Compute features
graph_features = compute_graph_features(dataset)

# Convert to numpy arrays for easier manipulation
feature_names = ['mean_degree_centrality', 'degree_assortativity', 'diameter', 
                'mean_betweenness_centrality', 'mean_eigenvector_centrality']

feature_matrix = np.column_stack([
    graph_features['mean_degree_centrality'],
    graph_features['degree_assortativity'],
    graph_features['diameter'],
    graph_features['mean_betweenness_centrality'],
    graph_features['mean_eigenvector_centrality']
])

labels_array = np.array(graph_features['labels'])

print(f"\nFeature matrix shape: {feature_matrix.shape}")
print(f"Labels shape: {labels_array.shape}")

# Display feature statistics
stats_df = pd.DataFrame({
    'Feature': feature_names,
    'Mean': [np.mean(graph_features[name]) for name in feature_names],
    'Std': [np.std(graph_features[name]) for name in feature_names],
    'Min': [np.min(graph_features[name]) for name in feature_names],
    'Max': [np.max(graph_features[name]) for name in feature_names],
    'Zeros': [np.sum(np.array(graph_features[name]) == 0) for name in feature_names]
})

print("\n=== Graph-Level Feature Statistics ===")
print(stats_df.round(4))
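The feature function above falls back to 0.0 when degree assortativity cannot be computed. Assortativity is a Pearson correlation of the degrees at the two ends of each edge. It is therefore undefined when every edge end has the same degree (for example, a ring in which every node has degree 2), because the formula divides by a zero variance.

Below is a self-contained sketch of that fallback. The helper name `safe_degree_assortativity` and the warning suppression are assumptions added for illustration.

```python
import warnings
import networkx as nx
import numpy as np

def safe_degree_assortativity(G, default=0.0):
    """Degree assortativity, falling back to `default` when it is undefined."""
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', RuntimeWarning)  # 0/0 in the correlation
        try:
            r = nx.degree_assortativity_coefficient(G)
        except Exception:
            return default
    return default if np.isnan(r) else float(r)

# Every edge of a star joins the hub to a leaf: perfectly disassortative
print(safe_degree_assortativity(nx.star_graph(4)))
# In a cycle every node has degree 2, so the coefficient is undefined
print(safe_degree_assortativity(nx.cycle_graph(6)))
```

Note that 0.0 is also the value for a graph with no degree preference. This fallback therefore makes "undefined" indistinguishable from "neutral"; adding a separate indicator feature is one alternative.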

Let's visualize the distributions of these graph-level features to understand their variance and potential discriminative power between classes.

In [None]:
# Create distribution plots for graph features
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
axes = axes.ravel()

# Define nice feature names for plotting
feature_display_names = [
    'Mean Degree Centrality',
    'Degree Assortativity',
    'Graph Diameter',
    'Mean Betweenness Centrality',
    'Mean Eigenvector Centrality'
]

colors = ['#FF6B6B', '#4ECDC4']  # Red for non-mutagenic, teal for mutagenic

# Plot distributions for all 5 features
for i in range(5):
    feature_name = feature_names[i]
    display_name = feature_display_names[i]
    
    # Create histogram for each class
    for class_idx, class_name in enumerate(mutag_class_names):
        mask = labels_array == class_idx
        feature_values = np.array(graph_features[feature_name])[mask]
        
        axes[i].hist(feature_values, bins=15, alpha=0.7, 
                    label=f'{class_name} ({np.sum(mask)} graphs)',
                    color=colors[class_idx], edgecolor='black', linewidth=0.5)
    
    axes[i].set_xlabel(display_name, fontsize=11)
    axes[i].set_ylabel('Frequency', fontsize=11)
    axes[i].set_title(f'{display_name}\nMean: {np.mean(graph_features[feature_name]):.3f}, Std: {np.std(graph_features[feature_name]):.3f}', 
                     fontsize=12, fontweight='bold')
    axes[i].legend(fontsize=9)
    axes[i].grid(True, alpha=0.3)

# Remove the unused sixth subplot (there are only 5 features)
fig.delaxes(axes[5])

plt.suptitle('MUTAG Dataset: Graph-Level Feature Distributions', 
             fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
plt.savefig(output_dir / 'mutag_graph_feature_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

# Create correlation heatmap of graph features
plt.figure(figsize=(10, 8))

# Create correlation matrix
graph_df = pd.DataFrame(feature_matrix, columns=feature_display_names)
correlation_matrix = graph_df.corr()

# Create heatmap
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool))  # Mask upper triangle
sns.heatmap(correlation_matrix, mask=mask, annot=True, cmap='coolwarm', center=0,
            square=True, fmt='.3f', cbar_kws={"shrink": .8})

plt.title('Graph-Level Feature Correlation Matrix\nMUTAG Dataset', 
          fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(output_dir / 'mutag_graph_feature_correlations.png', dpi=300, bbox_inches='tight')
plt.show()

### 3.3 Logistic Regression with Graph Metrics

We'll train and evaluate a **logistic regression model** using the graph-level features to predict molecular mutagenicity. This follows the same methodology as Section 2 but adapted for binary graph classification.

### Methodology

1. **Data Splitting**: Split the data into train/test sets (80%/20%) with feature standardization
2. **Model Training**: Logistic regression with regularization using grid search and cross-validation
3. **Hyperparameter Tuning**: 5-fold cross-validation to optimize regularization parameters 
4. **Performance Analysis**: Classification reports, confusion matrices, and ROC-AUC curves

**Key Differences from Enron Node Classification:**
- **Binary classification** (mutagenic vs non-mutagenic) instead of multi-class
- **Graph-level features** instead of individual node measures (though some are aggregated node-level features)
- **Smaller dataset** (188 graphs vs 4,203 nodes) 

First, let's split the data into training and test sets, and standardize the features.

In [None]:
# Data splitting for MUTAG graph classification
X = feature_matrix.copy()
y = labels_array.copy()

print(f"Feature matrix shape: {X.shape} | Labels shape: {y.shape}")
print(f"Class distribution: {np.bincount(y)} (0: Non-mutagenic, 1: Mutagenic)")

# Split the data into train/test sets (80%/20%)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, 
    test_size=0.2, 
    stratify=y, 
    random_state=RANDOM_STATE
)

# Standardize the features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

print(f"Training: {X_train.shape[0]} samples | Test: {X_test.shape[0]} samples")
print(f"Training class distribution: {np.bincount(y_train)}")
print(f"Test class distribution: {np.bincount(y_test)}")
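The cell above fits `StandardScaler` on the training split only, then reuses the training mean and standard deviation for the test split. This keeps information from the test set out of the model. Below is a NumPy sketch of what that standardization computes (population standard deviation, `ddof=0`), run on synthetic data. The helper names are hypothetical.

```python
import numpy as np

def fit_standardizer(X_train):
    """Learn per-feature mean and standard deviation from the training set only."""
    mu = X_train.mean(axis=0)
    sigma = X_train.std(axis=0)
    sigma[sigma == 0] = 1.0          # leave constant features unscaled
    return mu, sigma

def apply_standardizer(X, mu, sigma):
    return (X - mu) / sigma

rng = np.random.default_rng(0)
X_tr = rng.normal(loc=5.0, scale=2.0, size=(100, 3))   # synthetic training features
X_te = rng.normal(loc=5.0, scale=2.0, size=(20, 3))    # synthetic test features

mu, sigma = fit_standardizer(X_tr)
Z_tr = apply_standardizer(X_tr, mu, sigma)
Z_te = apply_standardizer(X_te, mu, sigma)   # reuse the training statistics
print(Z_tr.mean(axis=0).round(6), Z_tr.std(axis=0).round(6))
```

A stricter variant wraps the scaler and the model in a scikit-learn `Pipeline`, so the scaler is also refit within each cross-validation fold.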

Now let's create a logistic regression model with hyperparameter tuning using grid search and cross-validation.

#### **Exercise 3.1**
Complete the code below to create a LogisticRegression base model with the following parameters:
- solver='liblinear' (good for small datasets)
- random_state=RANDOM_STATE 
- max_iter=1000

In [None]:
# Define hyperparameter search space
param_grid = {
    'C': [0.01, 0.1, 1.0, 10.0, 100.0],  # Regularization strength
    'penalty': ['l1', 'l2']               # Regularization type
}

# Exercise 3.1: Create the base model
base_model = ___  # Your code here

# Grid search with cross-validation
cv_strategy = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
grid_search = GridSearchCV(
    base_model,
    param_grid,
    cv=cv_strategy,
    scoring='roc_auc',  # Use ROC-AUC for binary classification
    n_jobs=-1,
    refit=True
)

# Fit on training data
start_time = time.time()
grid_search.fit(X_train_scaled, y_train)
end_time = time.time()

print(f"Best parameters: {grid_search.best_params_}")
print(f"Best CV ROC-AUC: {grid_search.best_score_:.4f}")
print(f"Training completed in {end_time - start_time:.2f} seconds")

# Get best model
best_model = grid_search.best_estimator_

Let's evaluate the model performance on the test set with a classification report.

In [None]:
# Make predictions on test set
y_test_pred = best_model.predict(X_test_scaled)

# Print classification report
print("=== MUTAG Test Set Classification Report ===")
print(classification_report(y_test, y_test_pred, target_names=mutag_class_names, zero_division=0))

Next, let's visualize the confusion matrix to understand the classification errors.

In [None]:
# Create confusion matrix
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Normalize by predicted label (normalize='pred'): each column sums to 1
cm = confusion_matrix(y_test, y_test_pred, normalize='pred')

# Create heatmap
sns.heatmap(cm, annot=True, fmt='.2f', cmap='Blues', 
            xticklabels=mutag_class_names, yticklabels=mutag_class_names,
            cbar_kws={'label': 'Proportion'}, ax=ax)

ax.set_xlabel('Predicted Label', fontsize=12)
ax.set_ylabel('True Label', fontsize=12)
ax.set_title('MUTAG: Confusion Matrix', 
             fontsize=14, fontweight='bold')

# Add count annotations
cm_counts = confusion_matrix(y_test, y_test_pred)
for i in range(len(mutag_class_names)):
    for j in range(len(mutag_class_names)):
        ax.text(j + 0.5, i + 0.7, f'({cm_counts[i,j]})', 
               ha='center', va='center', fontsize=10, color='red')

plt.tight_layout()
plt.savefig(output_dir / 'mutag_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()
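The heatmap above uses `normalize='pred'`, so each column sums to 1. The diagonal entries are then the per-class precision: of the compounds predicted as a class, the fraction that truly belong to it. With `normalize='true'`, each row sums to 1 instead and the diagonal gives per-class recall. Below is a NumPy sketch of both normalizations on an illustrative count matrix. It uses scikit-learn's convention of rows for true labels and columns for predictions.

```python
import numpy as np

# Rows = true label, columns = predicted label
counts = np.array([[8, 2],
                   [4, 16]])

by_true = counts / counts.sum(axis=1, keepdims=True)  # like normalize='true': rows sum to 1
by_pred = counts / counts.sum(axis=0, keepdims=True)  # like normalize='pred': columns sum to 1

print(by_true.round(3))
print(by_pred.round(3))
```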

Now, let's create an ROC curve to visualize the model's performance across different classification thresholds.

#### **Exercise 3.2**
Complete the code below to compute the false positive rate (fpr) and true positive rate (tpr) for the ROC curve.

**Hint**: Use the `roc_curve` function from sklearn.metrics with `y_test` and `y_test_proba` as inputs.

In [None]:
y_test_proba = best_model.predict_proba(X_test_scaled)[:, 1]  # Probability of positive class (mutagenic)

# Exercise 3.2: Compute ROC curve
fpr, tpr, thresholds = ___  # Your code here

# Calculate AUC
roc_auc = auc(fpr, tpr)

# Create ROC plot
fig, ax = plt.subplots(1, 1, figsize=(8, 6))

# Plot ROC curve
ax.plot(fpr, tpr, color='darkorange', lw=2, 
        label=f'ROC Curve (AUC = {roc_auc:.3f})')

# Plot random classifier line
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.5,
        label='Random Classifier (AUC = 0.50)')

# Customize plot
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate', fontsize=12)
ax.set_ylabel('True Positive Rate', fontsize=12)
ax.set_title('MUTAG: ROC Curve for Graph Classification', fontsize=14, fontweight='bold')
ax.legend(loc="lower right", fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(output_dir / 'mutag_roc_curve.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"ROC-AUC Score: {roc_auc:.4f}")
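The area under the ROC curve has an equivalent pairwise reading. It is the probability that the model scores a randomly chosen positive (mutagenic) graph higher than a randomly chosen negative one, with ties counted as one half. Below is a brute-force sketch of that definition. It is quadratic in the number of samples, so it only suits small test sets like this one, and the helper name is hypothetical.

```python
import numpy as np

def pairwise_auc(y_true, scores):
    """Probability that a random positive outscores a random negative (ties count 1/2)."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]          # every positive/negative pair
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))
```

On the same inputs it should agree with `sklearn.metrics.roc_auc_score`.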

Finally, let's examine the feature importance by reviewing the logistic regression coefficients.

In [None]:
coeffs = best_model.coef_
# This is binary classification, so coeffs will have shape (1, n_features)
print("Logistic Regression Coefficients Shape:")
print(coeffs.shape)

for fn, coef in zip(feature_display_names, coeffs[0]):
    print(f"{fn}: {coef:.4f}")
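Because the features were standardized, each coefficient is the change in the log-odds of the mutagenic class for a one-standard-deviation increase in that feature, holding the others fixed. Exponentiating a coefficient gives an odds ratio, which is often easier to read. For example, a coefficient of ln 2 ≈ 0.693 doubles the odds, and -0.693 halves them. Below is a small sketch using illustrative coefficients rather than fitted values; the helper name is hypothetical.

```python
import numpy as np

def odds_ratios_per_std(coef, feature_names):
    """Map each standardized logistic-regression coefficient to exp(coef)."""
    return {name: float(np.exp(c)) for name, c in zip(feature_names, np.ravel(coef))}

# Illustrative coefficients: doubles, leaves unchanged, and halves the odds
example = odds_ratios_per_std([[np.log(2.0), 0.0, -np.log(2.0)]], ['a', 'b', 'c'])
print(example)
```

With the fitted model, this could be called as `odds_ratios_per_std(best_model.coef_, feature_display_names)`.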