In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import pandas as pd
import numpy as np
import os
import copy
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error # For evaluation

# torch-pruning
import torch_pruning as tp

# Type Hinting (Optional but good practice)
from typing import Tuple, List, Dict, Union, Optional

 ### Data Loading and Preprocessing Functions (Using your provided code)

In [2]:
# --- Functions from your provided code ---

column_names = ['unit_number', 'time_in_cycles'] + [f'op_setting_{i}' for i in range(1, 4)] + [f'sensor_{i}' for i in range(1, 24)]

def load_dataframe(file_path: str) -> pd.DataFrame | None:
    """Loads a single CMaps data file."""
    try:
        df = pd.read_csv(file_path, sep=' ', header=None, names=column_names)
        # Drop the last two columns if they are all NaNs (often artifacts of space delimiter)
        df.dropna(axis=1, how='all', inplace=True)
        return df
    except Exception as e:
        print(f"Error loading {file_path}: {e}")
        return None

def clean_data(df: pd.DataFrame) -> list:
    """Identifies columns to remove based on low std dev."""
    if df is None:
        return []
    # Columns with std dev < 0.02 (potential candidates for removal)
    # Avoid removing unit_number or time_in_cycles here.
    cols_to_check = [col for col in df.columns if 'sensor' in col or 'op_setting' in col]
    low_std_cols = [col for col in cols_to_check if df[col].std() < 0.02]
    print(f"Columns with std < 0.02 (potential removal): {low_std_cols}")
    # You might decide which ones to actually remove based on domain knowledge or experiment
    # For this example, let's remove them as identified.
    return low_std_cols

def add_rul(df: pd.DataFrame) -> pd.DataFrame | None:
    """Calculates and adds the Remaining Useful Life (RUL) column."""
    if df is None:
        return None
    max_cycles = df.groupby('unit_number')['time_in_cycles'].max().reset_index()
    max_cycles.columns = ['unit_number', 'max_cycle']
    df = df.merge(max_cycles, on='unit_number', how='left')
    df['RUL'] = df['max_cycle'] - df['time_in_cycles']
    df.drop(columns=['max_cycle'], inplace=True)
    # Clip RUL (optional, common practice to limit max RUL)
    df['RUL'] = df['RUL'].clip(upper=125)
    return df

def normalize_data(df: pd.DataFrame,
                   columns_to_normalize: List[str], scaler: MinMaxScaler = None) -> Tuple[pd.DataFrame, MinMaxScaler] | Tuple[None, None]:
    """Normalizes specified columns using MinMaxScaler."""
    if df is None:
        return None, None
    data_to_scale = df[columns_to_normalize]
    if scaler is None:
        scaler = MinMaxScaler()
        df[columns_to_normalize] = scaler.fit_transform(data_to_scale)
    else:
        # Ensure only columns present in the scaler are transformed
        valid_cols = [col for col in columns_to_normalize if col in scaler.feature_names_in_]
        if len(valid_cols) < len(columns_to_normalize):
            print("Warning: Some columns not found in the provided scaler. Skipping them.")
        if valid_cols: # Check if there's anything to transform
             df[valid_cols] = scaler.transform(df[valid_cols])

    return df, scaler

# --- Data Preparation Main Function ---
def prepare_cmapss_data(data_dir: str, train_file: str, test_file: str, test_rul_file: str) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, MinMaxScaler, List[str]]:
    """Loads, cleans, preprocesses train/test data and RUL."""
    print("--- Preparing Training Data ---")
    train_df = load_dataframe(os.path.join(data_dir, train_file))
    train_df = add_rul(train_df)

    print("\n--- Preparing Test Data ---")
    test_df = load_dataframe(os.path.join(data_dir, test_file))
    test_rul_df = pd.read_csv(os.path.join(data_dir, test_rul_file), header=None, names=['RUL'])
    # Adjust RUL based on test_rul_df and clipping if needed
    # Test RUL is usually the RUL at the *end* of the test sequence
    # We'll use this test_rul_df directly later for evaluation targets

    # Clean Data - identify columns based on TRAINING data variance
    cols_to_remove = clean_data(train_df)
    feature_cols = [col for col in train_df.columns if col not in ['unit_number', 'time_in_cycles', 'RUL'] + cols_to_remove]
    print(f"\nUsing Features: {feature_cols}")

    # Drop removed columns from both train and test
    train_df.drop(columns=cols_to_remove, inplace=True, errors='ignore')
    test_df.drop(columns=cols_to_remove, inplace=True, errors='ignore')


    # Normalize features based on TRAINING data
    print("\n--- Normalizing Data ---")
    train_df_norm, scaler = normalize_data(train_df.copy(), feature_cols, scaler=None)
    # Use the same scaler for test data
    test_df_norm, _ = normalize_data(test_df.copy(), feature_cols, scaler=scaler)

    return train_df_norm, test_df_norm, test_rul_df, scaler, feature_cols

### Define MLP Model and Dataset Class


In [3]:
# --- MLP Model Definition (Using your provided class) ---
class MLPmodel(nn.Module):
    def __init__(self,
                 layer_units: List[int],    # List of neuron counts for hidden layers, e.g., [512, 256, 128]
                 input_size: int,           # Number of input features
                 output_size: int = 1,      # Number of output units (1 for RUL regression)
                 dropout_rate: float = 0.2,
                 use_batchnorm: bool = False, # Flag to enable/disable BatchNorm
                 final_activation: Optional[nn.Module] = None # Optional: for tasks needing a final activation
                ):
        super(MLPmodel, self).__init__()
        self.model_type = 'MLP'
        self.layer_units = layer_units # Store for potential reference
        self.input_size = input_size
        self.output_size = output_size
        self.use_batchnorm = use_batchnorm
        self.final_activation = final_activation

        layers = []
        current_dim = input_size

        # Create hidden layers
        for hidden_units in layer_units:
            layers.append(nn.Linear(current_dim, hidden_units))
            if use_batchnorm:
                layers.append(nn.BatchNorm1d(hidden_units)) # BatchNorm applied to the output of Linear
            layers.append(nn.ReLU()) # Activation function
            layers.append(nn.Dropout(dropout_rate)) # Dropout for regularization
            current_dim = hidden_units # Update current_dim for the next layer

        # Create the output layer
        layers.append(nn.Linear(current_dim, output_size))

        # Add final activation if specified (e.g., Sigmoid for binary classification)
        if self.final_activation is not None:
            layers.append(self.final_activation)

        # Combine all layers into a sequential module
        self.network = nn.Sequential(*layers)

        # Initialize weights
        self.init_weights()

    def init_weights(self) -> None:
        """Initializes weights for linear layers."""
        for module in self.network.modules(): # Iterate through all modules in self.network
            if isinstance(module, nn.Linear):
                # Kaiming uniform is often good for layers followed by ReLU
                nn.init.kaiming_uniform_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm1d):
                # Default initialization for BatchNorm1d is usually fine (mean 0, var 1 for weights, 0 for bias)
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Defines the forward pass of the MLP.
        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, input_size).
        Returns:
            torch.Tensor: Output tensor, shape depends on output_size and final_activation.
                          For RUL regression, usually (batch_size, 1).
        """
        # Input x shape should be [batch_size, input_features]
        return self.network(x)

    def get_prunable_layers(self) -> List[nn.Linear]:
        """Helper to get a list of linear layers, excluding the final output layer for pruning."""
        prunable = []
        # Iterate through direct children of self.network, which are the layers we added
        for i, layer in enumerate(self.network):
            if isinstance(layer, nn.Linear):
                # Check if it's the last Linear layer in the sequence
                is_last_linear = True
                for subsequent_layer in list(self.network)[i+1:]: # Look ahead
                    if isinstance(subsequent_layer, nn.Linear):
                        is_last_linear = False
                        break
                if not is_last_linear:
                    prunable.append(layer)
        return prunable

    def get_output_layer(self) -> Optional[nn.Linear]:
        """Helper to get the final output linear layer, which should typically be ignored during pruning."""
        # Iterate backwards to find the last nn.Linear layer in self.network
        for layer in reversed(list(self.network.children())):
            if isinstance(layer, nn.Linear):
                return layer
        return None

# --- PyTorch Dataset for CMaps MLP ---
class CMAPSS_MLP_Dataset(Dataset):
    def __init__(self, features: np.ndarray, targets: np.ndarray):
        if features.ndim == 3: # Handle potential sequence input by flattening or taking last step
             print("Warning: Input features seem sequential. Taking last step for MLP.")
             # This assumes LSTM-prepared data; better to prepare MLP data correctly upstream
             features = features[:, -1, :]
        if targets.ndim > 1 and targets.shape[1] > 1:
             print("Warning: Targets have more than one dimension. Squeezing.")
             targets = targets.squeeze()


        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32).unsqueeze(1) # Ensure target shape is [N, 1]

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.features[idx], self.targets[idx]

### Training and Evaluation Functions (Adapted for Regression)

In [4]:
# --- Utility Functions ---
def save_model_as_onnx(model: nn.Module, example_input: torch.Tensor, output_path: str, device: torch.device):
    """Exports a PyTorch model to ONNX format."""
    model.eval() # Ensure model is in evaluation mode
    model.to(device) # Ensure model is on the correct device
    example_input_on_device = example_input.to(device) # Ensure example input is on the device

    try:
        torch.onnx.export(
            model,
            example_input_on_device,
            output_path,
            export_params=True,
            opset_version=11, # Common opset, can be 12, 13, etc.
            input_names=['input'],
            output_names=['output'],
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        print(f"✅ Model successfully saved as ONNX to {output_path}")
    except Exception as e:
        print(f" Loglevel: Error - Failed to export model to ONNX at {output_path}: {e}")
        print(" Loglevel: Error - Please check model compatibility with ONNX opset version or input/output names.")


def save_model_state(model: nn.Module, path_prefix: str, example_input_for_onnx: Optional[torch.Tensor] = None, device_for_onnx: Optional[torch.device] = None):
    """Saves model state dictionary (.pth) and optionally its ONNX version."""
    # Ensure path_prefix ends with .pth for the PyTorch state_dict
    if not path_prefix.endswith(".pth"):
        pth_path = path_prefix + ".pth"
    else:
        pth_path = path_prefix # Assume full .pth path given
        path_prefix = path_prefix[:-4] # Remove .pth for ONNX naming

    os.makedirs(os.path.dirname(pth_path), exist_ok=True)
    torch.save(model.state_dict(), pth_path)
    print(f"Model state (.pth) saved to {pth_path}")

    if example_input_for_onnx is not None and device_for_onnx is not None:
        onnx_path = path_prefix + ".onnx" # Use the prefix before .pth was added
        print(f"Attempting to save ONNX model to: {onnx_path}")
        save_model_as_onnx(model, example_input_for_onnx, onnx_path, device_for_onnx)

def load_model_state(model, path, device):
    """Loads model state dictionary."""
    model.load_state_dict(torch.load(path, map_location=device))
    print(f"Model state loaded from {path}")
    return model

def calculate_flops_params(model, example_input): # Assuming this is defined elsewhere or here
    # Ensure model and example_input are on the same device for count_ops_and_params
    device = next(model.parameters()).device
    example_input_on_device = example_input.to(device)
    return tp.utils.count_ops_and_params(model, example_input_on_device)

# --- Evaluation Function (RMSE) ---
def evaluate_model_rmse(model: nn.Module, data_loader: DataLoader, device: torch.device, example_input: torch.Tensor) -> Dict[str, float]:
    """Evaluates the regression model using RMSE and calculates FLOPs/Params."""
    model.eval()
    all_predictions = []
    all_targets = []
    with torch.no_grad():
        for features, targets in data_loader:
            features = features.to(device)
            predictions = model(features)
            all_predictions.append(predictions.cpu().numpy())
            all_targets.append(targets.numpy()) # Targets are already [N, 1]

    # Concatenate results from all batches
    all_predictions = np.concatenate(all_predictions).squeeze()
    all_targets = np.concatenate(all_targets).squeeze()

    # Calculate RMSE
    rmse = np.sqrt(mean_squared_error(all_targets, all_predictions))
    print(f"Evaluation RMSE: {rmse:.4f}")

    # Calculate FLOPs and Params
    flops, params = calculate_flops_params(model, example_input.to(device))
    size_mb = params * 4 / 1e6 # Approximation

    return {
        'rmse': rmse,
        'flops': flops,
        'params': params,
        'size_mb': size_mb
    }



def train_mlp_model(
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
        device: torch.device,
        num_epochs: int,
        patience: int = 10,
        model_save_path: str = "temp_best_model.pth"
                    ) -> Tuple[nn.Module, List[float], List[float], List[float]]: # Return history
    """Trains an MLP model with validation, early stopping.
       Returns: best model, train_loss_history, val_loss_history, val_rmse_history (if calculable)
    """

    best_val_loss = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    train_loss_history = []
    val_loss_history = []
    # val_rmse_history = [] # We'll calculate RMSE outside based on final model on test set for simplicity in main loop

    print(f"Starting training on {device} with patience={patience}")

    for epoch in range(num_epochs):
        model.train()
        running_loss_train = 0.0
        for features, targets in train_loader:
            features, targets = features.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss_train += loss.item() * features.size(0)

        epoch_train_loss = running_loss_train / len(train_loader.dataset)
        train_loss_history.append(epoch_train_loss)

        # Validation phase
        model.eval()
        running_loss_val = 0.0
        all_val_preds = []
        all_val_targets = []
        with torch.no_grad():
            for features, targets in val_loader:
                features, targets = features.to(device), targets.to(device)
                outputs = model(features)
                loss = criterion(outputs, targets)
                running_loss_val += loss.item() * features.size(0)
                all_val_preds.append(outputs.cpu().numpy())
                all_val_targets.append(targets.cpu().numpy())


        epoch_val_loss = running_loss_val / len(val_loader.dataset)
        val_loss_history.append(epoch_val_loss)

        # Optional: Calculate validation RMSE per epoch if needed for detailed plotting
        # current_val_preds = np.concatenate(all_val_preds).squeeze()
        # current_val_targets = np.concatenate(all_val_targets).squeeze()
        # epoch_val_rmse = np.sqrt(mean_squared_error(current_val_targets, current_val_preds))
        # val_rmse_history.append(epoch_val_rmse)


        current_lr = optimizer.param_groups[0]['lr']
        print(f"Epoch {epoch+1}/{num_epochs}: Train Loss={epoch_train_loss:.4f}, Val Loss={epoch_val_loss:.4f}, LR={current_lr:.1e}")
        # print(f"Epoch {epoch+1}/{num_epochs}: Train Loss={epoch_train_loss:.4f}, Val Loss={epoch_val_loss:.4f}, Val RMSE={epoch_val_rmse:.4f} LR={current_lr:.1e}")


        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            epochs_no_improve = 0
            best_model_state = copy.deepcopy(model.state_dict())
            print(f"*** New best validation loss: {best_val_loss:.4f} (Epoch {epoch+1}) ***")
        else:
            epochs_no_improve += 1
            # print(f"Val loss did not improve for {epochs_no_improve} epoch(s).") # Can be verbose

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {patience} epochs without improvement on Val Loss.")
            break

        if scheduler:
            if isinstance(scheduler, ReduceLROnPlateau):
                scheduler.step(epoch_val_loss)
            else:
                scheduler.step()

    print(f"Training finished. Best validation loss: {best_val_loss:.4f}")
    if best_model_state:
        model.load_state_dict(best_model_state)
        print("Loaded best model state based on validation loss.")
    else:
         print("Warning: No improvement in validation loss. Using model from last epoch.")

    return model, train_loss_history, val_loss_history #, val_rmse_history

### Pruning Function (Adapted for MLP)

In [5]:
def prune_mlp_model(model_to_prune: nn.Module, example_input: torch.Tensor, strategy_name_for_debug:str, strategy: Dict, # Added strategy_name_for_debug
                    target_sparsity: float = 0.5, iterative_steps: int = 1,
                    importance: Optional[tp.importance.Importance] = None) -> nn.Module:
    """Prunes an MLP model using the specified strategy (Revised with more prints)."""
    device = example_input.device
    model_to_prune.eval().to(device)

    print(f"\n--- Model Structure BEFORE Pruning ({strategy_name_for_debug}) ---")
    print(model_to_prune)
    flops_before, params_before = calculate_flops_params(model_to_prune, example_input)
    print(f"State Before Pruning ({strategy_name_for_debug}): FLOPs={flops_before/1e6:.4f}M, Params={params_before/1e6:.4f}M")


    ignored_layers = []
    if hasattr(model_to_prune, 'get_output_layer') and callable(model_to_prune.get_output_layer):
        output_layer_instance = model_to_prune.get_output_layer()
        if output_layer_instance:
            ignored_layers.append(output_layer_instance)
            print(f"Ignoring output layer ({output_layer_instance}) during pruning for {strategy_name_for_debug}.")
    elif hasattr(model_to_prune, 'layers') and isinstance(model_to_prune.layers, nn.ModuleList) and len(model_to_prune.layers) > 0: # Fallback for older MLP structure
        output_layer_candidate = model_to_prune.layers[-1]
        if isinstance(output_layer_candidate, nn.Linear):
             ignored_layers.append(output_layer_candidate)
             print(f"Ignoring output layer ({output_layer_candidate}) using direct access for {strategy_name_for_debug}.")
    else:
        print(f"Warning ({strategy_name_for_debug}): Could not automatically determine output layer. Check model structure.")


    if importance is None:
        importance_metric = strategy['importance']
    else:
        importance_metric = importance
    print(f"Using Importance Metric ({strategy_name_for_debug}): {type(importance_metric).__name__}")

    pruner_class = strategy['pruner']
    print(f"Using Pruner Class ({strategy_name_for_debug}): {pruner_class.__name__}")

    try:
        pruner = pruner_class(
            model=model_to_prune,
            example_inputs=example_input.to(device),
            importance=importance_metric,
            iterative_steps=iterative_steps,
            ch_sparsity=target_sparsity,
            root_module_types=[nn.Linear],
            ignored_layers=ignored_layers,
        )
    except TypeError as e:
        print(f"Warning ({strategy_name_for_debug}): Error initializing pruner {pruner_class.__name__} with standard args: {e}")
        print("Attempting initialization with fewer args...")
        try:
            pruner = pruner_class(
                model=model_to_prune,
                example_inputs=example_input.to(device),
                importance=importance_metric,
                ch_sparsity=target_sparsity,
                ignored_layers=ignored_layers,
             )
        except Exception as E:
             print(f"ERROR ({strategy_name_for_debug}): Could not initialize pruner {pruner_class.__name__}")
             raise E


    print(f"Starting pruning with {strategy_name_for_debug}, Target Sparsity: {target_sparsity:.2f}")

    if isinstance(importance_metric, tp.importance.TaylorImportance):
        model_to_prune.train()
        input_on_device = example_input.to(device)
        output = model_to_prune(input_on_device)
        loss = torch.sum(output**2)
        model_to_prune.zero_grad()
        try:
            loss.backward()
            print(f"Calculated gradients for TaylorImportance ({strategy_name_for_debug}).")
        except Exception as e:
            print(f"ERROR ({strategy_name_for_debug}): Could not perform backward pass for TaylorImportance: {e}")
            raise e
        finally:
             model_to_prune.eval()


    # --- BEGIN DEBUGGING PRUNING GROUPS (Optional) ---
    # if "magnitude" in strategy_name_for_debug: # Or any strategy you want to debug
    #     print(f"--- Inspecting Pruning Groups for {strategy_name_for_debug} (Interactive Mode) ---")
    #     # Create a temporary pruner instance for inspection only to avoid altering the main one
    #     temp_pruner_for_inspection = pruner_class(
    #         model=model_to_prune, example_inputs=example_input.to(device),
    #         importance=importance_metric, ch_sparsity=target_sparsity,
    #         root_module_types=[nn.Linear], ignored_layers=ignored_layers,
    #     )
    #     # This call is just to get the groups, not to prune with this temp_pruner
    #     num_groups_found = 0
    #     # DG should be created upon pruner instantiation if using torch_pruning>=1.2
    #     if hasattr(temp_pruner_for_inspection, 'DG'):
    #         example_group = temp_pruner_for_inspection.DG.get_pruning_group(model_to_prune.layers[0], tp.pruning.pruning_dim.pruning_out_channels, 0) # Example
    #         print(f"Example Group from DG: {example_group}")

    #     # Iterate through what the pruner would do if it were interactive
    #     # Note: pruner.step(interactive=True) might need to be called to generate groups for some pruners
    #     # This part of debugging might be tricky without a deeper dive into a specific pruner
    #     print(f"--- End of Group Inspection for {strategy_name_for_debug} ---")
    # --- END DEBUGGING PRUNING GROUPS ---

    try:
        pruner.step()
    except AttributeError as e:
         print(f"ERROR ({strategy_name_for_debug}): During pruner.step() for {pruner_class.__name__}: {e}")
         raise e
    except Exception as e:
        print(f"ERROR ({strategy_name_for_debug}): An unexpected error during pruner.step() for {pruner_class.__name__}: {e}")
        raise e

    print(f"\n--- Model Structure AFTER Pruning ({strategy_name_for_debug}) ---")
    print(model_to_prune)
    flops_after, params_after = calculate_flops_params(model_to_prune, example_input.to(device))
    print(f"Pruning finished for {strategy_name_for_debug}. Final FLOPs: {flops_after/1e6:.4f}M, Params: {params_after/1e6:.4f}M")
    if flops_before > 0: print(f"FLOPs Reduction ({strategy_name_for_debug}): {(flops_before-flops_after)/flops_before*100:.2f}%")
    if params_before > 0: print(f"Params Reduction ({strategy_name_for_debug}): {(params_before-params_after)/params_before*100:.2f}%")

    print(f"\n--- Named Parameters AFTER Pruning ({strategy_name_for_debug}) ---")
    for name, param in model_to_prune.named_parameters():
        print(f"{name}: shape={param.shape}, num_elements={param.numel()}, requires_grad={param.requires_grad}")


    num_params_fixed = 0
    for name, param in model_to_prune.named_parameters():
        if not param.requires_grad:
            param.requires_grad = True
            num_params_fixed += 1
    if num_params_fixed > 0:
        print(f"Set requires_grad=True for {num_params_fixed} parameters ({strategy_name_for_debug}).")

    return model_to_prune

###  Comparison and Plotting (Adapted for Regression)

In [6]:
import os # Ensure imported
from typing import Dict # Ensure imported

def compare_results_and_plot_rmse(results: Dict[str, Dict[str, float]], output_dir: str):
    """Prints comparison table and plots results for regression (RMSE)."""

    print("\n=== Pruning Strategy Comparison (RMSE) ===")
    print(f"{'Strategy':<12} | {'FLOPs':<12} | {'Params':<10} | {'Size (MB)':<10} | {'RMSE':<10}")
    print("-" * 65)
    # Sort strategies by RMSE (lower is better) for better comparison
    sorted_strategies = sorted(results.keys(), key=lambda s: results[s].get('rmse', float('inf')))

    for strategy in sorted_strategies:
        metrics = results[strategy]
        # Use .get() with defaults for robustness if a metric is missing
        flops_m = metrics.get('flops', 0) / 1e6
        params_m = metrics.get('params', 0) / 1e6
        size_mb_val = metrics.get('size_mb', 0)
        rmse_val = metrics.get('rmse', float('nan'))
        print(f"{strategy:<12} | {flops_m:<11.2f}M | {params_m:<9.2f}M | {size_mb_val:>9.2f} | {rmse_val:<10.4f}")

    os.makedirs(output_dir, exist_ok=True)
    # Ensure 'initial' is first if it exists for plotting reference
    plot_strategies = ['initial'] + [s for s in sorted_strategies if s != 'initial']
    metrics_to_plot = ['flops', 'params', 'size_mb', 'rmse']
    titles = {'flops': 'FLOPs Comparison', 'params': 'Parameters Comparison',
              'size_mb': 'Model Size (MB) Comparison', 'rmse': 'RMSE Comparison (Lower is Better)'}
    y_labels = {'flops': 'FLOPs (Millions)', 'params': 'Parameters (Millions)',
                'size_mb': 'Size (MB)', 'rmse': 'RMSE'}

    colors = plt.cm.viridis(np.linspace(0, 1, len(plot_strategies)))

    for metric_name in metrics_to_plot:
        if not any(metric_name in results.get(s, {}) for s in plot_strategies):
             print(f"Skipping plot for '{metric_name}', data not found in results.")
             continue

        values = []
        for strategy in plot_strategies:
             metric_val = results.get(strategy, {}).get(metric_name, np.nan)
             if metric_name in ['flops', 'params']:
                 if not np.isnan(metric_val): # Avoid division by zero or on NaN
                     metric_val /= 1e6
             values.append(metric_val)

        plt.figure(figsize=(12, 6))
        bars = plt.bar(plot_strategies, values, color=colors)
        plt.xlabel('Strategy')
        plt.ylabel(y_labels[metric_name])
        plt.title(titles[metric_name])
        plt.xticks(rotation=45, ha='right')

        # Determine format string for labels OUTSIDE the f-string placeholder
        label_format = '.4f' if metric_name == 'rmse' else '.2f'
        for bar in bars:
            yval = bar.get_height()
            if not np.isnan(yval):
                # Use the determined label_format inside the f-string placeholder
                plt.text(bar.get_x() + bar.get_width()/2., yval, f'{yval:{label_format}}',
                         ha='center', va='bottom', fontsize=9)

        # Add initial model reference line
        if 'initial' in results and metric_name in results['initial'] and not np.isnan(results['initial'][metric_name]):
            initial_value = results['initial'][metric_name]
            if metric_name in ['flops', 'params']:
                initial_value /= 1e6

            # --- CORRECTED PART for the label ---
            # 1. Determine the format string based on the metric
            initial_label_format = '.4f' if metric_name == 'rmse' else '.2f'
            # 2. Apply the format string to the value
            formatted_initial_value = f"{initial_value:{initial_label_format}}"
            # 3. Construct the label string
            initial_line_label = f"Initial ({formatted_initial_value})"

            plt.axhline(y=initial_value, color='r', linestyle='--', label=initial_line_label) # Use the constructed label
            plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'mlp_{metric_name}_comparison.png'))
        plt.close()

    print(f"Comparison plots saved to {output_dir}")

### Main Workflow Configuration

In [7]:
# --- Configuration ---
DATA_DIR = './data/CMaps/' # <<< IMPORTANT: Set path to your NASA CMaps data directory
OUTPUT_DIR = './output_mlp_pruning/fd001/'
TRAIN_FILE = 'train_FD001.txt'
TEST_FILE = 'test_FD001.txt'
TEST_RUL_FILE = 'RUL_FD001.txt'

# Create output directory if it doesn't exist
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Model Config
MLP_HIDDEN_UNITS = [2048, 1024, 1024, 512, 512, 256, 128]
OUTPUT_SIZE = 1 # For RUL regression
DROPOUT_RATE = 0.3 # Maybe increase for bigger model
USE_BATCHNORM = True # Set to True to include BatchNorm layers

# Training Config
INITIAL_TRAIN_EPOCHS = 100 # Train longer initially
FINETUNE_EPOCHS = 100    # Fine-tune potentially as long
BATCH_SIZE = 128
INITIAL_LR = 0.001
FINETUNE_LR = 0.0005
PATIENCE = 15 # Patience for early stopping
VAL_SPLIT_RATIO = 0.2 # Use 20% of training engines for validation

# Pruning Config
PRUNING_TARGET_SPARSITY = 0.2 # Target 50% sparsity
PRUNING_ITERATIVE_STEPS = 1 # For structured pruning, 1 step is common
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pruning strategies (can reuse from ResNet example)
# Note: BNScalePruner/GroupNormPruner less applicable to MLP without BatchNorm layers
pruning_strategies = {
    'magnitude': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.MagnitudeImportance(p=2)},
    #'bn_scale': {'pruner': tp.pruner.BNScalePruner, 'importance': tp.importance.BNScaleImportance()}, # If you add BN layers
    'random': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.RandomImportance()},
    'Taylor': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.TaylorImportance()},
    #'Hessian': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.GroupHessianImportance()}, # Slow, requires grads
    'lamp': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.LAMPImportance(p=2)},
    #'geometry': {'pruner': tp.pruner.MagnitudePruner, 'importance': tp.importance.FPGMImportance()} # More geometric
}

### Data Loading and Preparation Execution

In [8]:
# --- Load and Prepare Data ---
train_df_norm, test_df_norm, test_rul_df, scaler, feature_cols = prepare_cmapss_data(
    DATA_DIR, TRAIN_FILE, TEST_FILE, TEST_RUL_FILE
)

INPUT_SIZE = len(feature_cols) # Determine input size dynamically
print(f"MLP Input Size determined as: {INPUT_SIZE}")

# --- Split Training Data into Train/Validation (by engine unit) ---
train_units = train_df_norm['unit_number'].unique()
np.random.seed(42) # For reproducible split
np.random.shuffle(train_units)
split_idx = int(len(train_units) * (1 - VAL_SPLIT_RATIO))
train_unit_ids = train_units[:split_idx]
val_unit_ids = train_units[split_idx:]

df_train_split = train_df_norm[train_df_norm['unit_number'].isin(train_unit_ids)]
df_val_split = train_df_norm[train_df_norm['unit_number'].isin(val_unit_ids)]

print(f"Training data split: {len(df_train_split)} samples ({len(train_unit_ids)} engines)")
print(f"Validation data split: {len(df_val_split)} samples ({len(val_unit_ids)} engines)")

# --- Prepare MLP Inputs/Outputs ---
# Training data: use all time steps
X_train = df_train_split[feature_cols].values
y_train = df_train_split['RUL'].values

# Validation data: use all time steps
X_val = df_val_split[feature_cols].values
y_val = df_val_split['RUL'].values

# Test data: use ONLY the LAST time step for each engine
X_test = []
test_engine_ids = test_df_norm['unit_number'].unique()
for eng_id in test_engine_ids:
    eng_data = test_df_norm[test_df_norm['unit_number'] == eng_id]
    last_step_features = eng_data[feature_cols].iloc[-1].values # Get last row features
    X_test.append(last_step_features)
X_test = np.array(X_test)
# Target RULs for test set are provided directly in RUL_FD001.txt
y_test = test_rul_df['RUL'].values[:len(X_test)] # Ensure lengths match if RUL file has extra lines

print(f"Prepared MLP data shapes:")
print(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
print(f"X_val: {X_val.shape}, y_val: {y_val.shape}")
print(f"X_test: {X_test.shape}, y_test: {y_test.shape}")


# --- Create DataLoaders ---
train_dataset = CMAPSS_MLP_Dataset(X_train, y_train)
val_dataset = CMAPSS_MLP_Dataset(X_val, y_val)
test_dataset = CMAPSS_MLP_Dataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) # Use same batch size for consistency

# Create an example input tensor for pruning/flops calculation
example_input_tensor = torch.randn(1, INPUT_SIZE).to(DEVICE)

--- Preparing Training Data ---

--- Preparing Test Data ---
Columns with std < 0.02 (potential removal): ['op_setting_1', 'op_setting_2', 'op_setting_3', 'sensor_1', 'sensor_5', 'sensor_6', 'sensor_10', 'sensor_16', 'sensor_18', 'sensor_19']

Using Features: ['sensor_2', 'sensor_3', 'sensor_4', 'sensor_7', 'sensor_8', 'sensor_9', 'sensor_11', 'sensor_12', 'sensor_13', 'sensor_14', 'sensor_15', 'sensor_17', 'sensor_20', 'sensor_21']

--- Normalizing Data ---
MLP Input Size determined as: 14
Training data split: 16340 samples (80 engines)
Validation data split: 4291 samples (20 engines)
Prepared MLP data shapes:
X_train: (16340, 14), y_train: (16340,)
X_val: (4291, 14), y_val: (4291,)
X_test: (100, 14), y_test: (100,)


In [9]:
def plot_structural_metrics_comparison():
    pass
def plot_finetuning_curves():
    pass

### Main Pruning Workflow Execution

In [10]:
# Cell 9: Main Pruning Workflow Execution
import pickle
# --- Main Workflow ---
all_results = {}
structural_metrics_after_pruning = {}
fine_tuning_history = {}
initial_train_history = {} # Initialize as empty dict

# Define base paths for saving (without .pth extension)
initial_model_base_path = os.path.join(OUTPUT_DIR, "mlp_initial")
initial_history_path = os.path.join(OUTPUT_DIR, "mlp_initial_train_history.pkl") # Path for history
current_mlp_model = None

# --- 1. Initial Training ---
if not os.path.exists(initial_model_base_path + ".pth"):
    print("\n--- Training Initial MLP Model ---")
    current_mlp_model = MLPmodel(layer_units=MLP_HIDDEN_UNITS, input_size=INPUT_SIZE, dropout_rate=DROPOUT_RATE).to(DEVICE)
    # ... (criterion, optimizer, scheduler setup for initial training) ...
    criterion_initial = nn.MSELoss()
    optimizer_initial = torch.optim.Adam(current_mlp_model.parameters(), lr=INITIAL_LR, weight_decay=1e-5)
    scheduler_initial = ReduceLROnPlateau(optimizer_initial, mode='min', factor=0.5, patience=int(PATIENCE/2), verbose=True)


    current_mlp_model, train_hist, val_hist = train_mlp_model(
        model=current_mlp_model, train_loader=train_loader, val_loader=val_loader,
        criterion=criterion_initial, optimizer=optimizer_initial, scheduler=scheduler_initial,
        device=DEVICE, num_epochs=INITIAL_TRAIN_EPOCHS, patience=PATIENCE,
        model_save_path=initial_model_base_path + "_best_val_checkpoint.pth"
    )
    initial_train_history['train_loss'] = train_hist
    initial_train_history['val_loss'] = val_hist
    save_model_state(current_mlp_model, initial_model_base_path, example_input_for_onnx=example_input_tensor, device_for_onnx=DEVICE)
    # Save the training history
    with open(initial_history_path, 'wb') as f:
        pickle.dump(initial_train_history, f)
    print(f"Initial model training history saved to {initial_history_path}")

else:
    print(f"\n--- Loading Initial MLP Model from {initial_model_base_path}.pth ---")
    current_mlp_model = MLPmodel(layer_units=MLP_HIDDEN_UNITS, input_size=INPUT_SIZE, dropout_rate=DROPOUT_RATE).to(DEVICE)
    current_mlp_model = load_model_state(current_mlp_model, initial_model_base_path + ".pth", DEVICE)
    # Load the training history if it exists
    if os.path.exists(initial_history_path):
        with open(initial_history_path, 'rb') as f:
            initial_train_history = pickle.load(f)
        print(f"Initial model training history loaded from {initial_history_path}")
    else:
        print(f"Warning: Initial model training history file not found at {initial_history_path}. Will not be plotted.")
        initial_train_history = {} # Ensure it's an empty dict if not loaded

# --- 2. Evaluate Initial Model on Test Set ---
print("\n--- Evaluating Initial MLP Model on Test Set ---")
initial_metrics_on_test = evaluate_model_rmse(current_mlp_model, test_loader, DEVICE, example_input_tensor)
all_results['initial'] = initial_metrics_on_test # This will be used in the final comparison table
# Store initial structural metrics for the structural comparison plot
structural_metrics_after_pruning['initial'] = {
    'flops': initial_metrics_on_test['flops'],
    'params': initial_metrics_on_test['params'],
    'size_mb': initial_metrics_on_test['size_mb']
}


# --- 3. Pruning and Fine-tuning Loop ---
for strategy_name, strategy_details in pruning_strategies.items():
    print(f"\n\n{'='*20} STRATEGY: {strategy_name.upper()} {'='*20}")
    print(f"--- Preparing Model for Pruning with Strategy: {strategy_name} ---")

    # Load a fresh copy of the *initial trained model* for each pruning strategy
    model_for_this_strategy = MLPmodel(layer_units=MLP_HIDDEN_UNITS, input_size=INPUT_SIZE, dropout_rate=DROPOUT_RATE).to(DEVICE)
    model_for_this_strategy = load_model_state(model_for_this_strategy, initial_model_base_path + ".pth", DEVICE) # Load initial weights

    current_importance_metric = None # For stateful importance like Taylor
    if strategy_name == 'Taylor':
         current_importance_metric = tp.importance.TaylorImportance()
         # TaylorImportance requires gradients to be calculated before pruner.step() uses it.
         # This is handled inside the `prune_mlp_model` function.

    pruned_successfully_flag = False
    # Create a deep copy to prune, so model_for_this_strategy (the initial loaded one) remains untouched if pruning fails
    model_to_actually_prune = copy.deepcopy(model_for_this_strategy)

    print(f"--- Attempting Pruning with Strategy: {strategy_name} ---")
    try:
        pruned_model_output = prune_mlp_model( # This function modifies the model in-place
            model_to_prune=model_to_actually_prune, # Pass the copy
            example_input=example_input_tensor,
            strategy_name_for_debug=strategy_name,
            strategy=strategy_details,
            target_sparsity=PRUNING_TARGET_SPARSITY,
            iterative_steps=PRUNING_ITERATIVE_STEPS,
            importance = current_importance_metric # Pass stateful importance if created
        )
        # After successful pruning, model_to_actually_prune is the pruned model
        pruned_model_for_finetuning = model_to_actually_prune
        pruned_successfully_flag = True
        print(f"Pruning successful for {strategy_name}.")
    except Exception as e:
        print(f"!!!!!! CRITICAL PRUNING FAILURE for strategy {strategy_name}: {e} !!!!!!")
        # If pruning fails, we will fine-tune the original unpruned model for this strategy slot
        # to keep the workflow running and have some comparison point.
        print(f"Using UNPRUNED (initial) model for fine-tuning slot of strategy {strategy_name} due to pruning error.")
        pruned_model_for_finetuning = model_for_this_strategy # Use the original loaded model
        pruned_successfully_flag = False


    # Store structural metrics IMMEDIATELY AFTER PRUNING (or use initial if pruning failed)
    if pruned_successfully_flag:
        flops_post_prune, params_post_prune = calculate_flops_params(pruned_model_for_finetuning, example_input_tensor)
        size_mb_post_prune = params_post_prune * 4 / 1e6
        structural_metrics_after_pruning[strategy_name] = {
            'flops': flops_post_prune,
            'params': params_post_prune,
            'size_mb': size_mb_post_prune
        }
    else: # Pruning failed, so structural metrics are same as initial
        structural_metrics_after_pruning[strategy_name] = structural_metrics_after_pruning['initial']

    # Save the state of the model after pruning (even if it's the unpruned one due to failure)
    pruned_model_base_path = os.path.join(OUTPUT_DIR, f"mlp_{strategy_name}_pruned")
    save_model_state(pruned_model_for_finetuning, pruned_model_base_path) # Only .pth for intermediate

    # --- Fine-tune the (potentially) pruned model ---
    print(f"\n--- Fine-tuning MLP for Strategy: {strategy_name} ---")
    # Ensure all parameters that should be trained require gradients
    for param in pruned_model_for_finetuning.parameters():
        param.requires_grad = True

    optimizer_ft = torch.optim.Adam(pruned_model_for_finetuning.parameters(), lr=FINETUNE_LR, weight_decay=1e-5)
    scheduler_ft = ReduceLROnPlateau(optimizer_ft, mode='min', factor=0.5, patience=int(PATIENCE/2)-1 if PATIENCE>2 else 1, verbose=True) # Slightly less patience for scheduler than early stopping
    criterion_ft = nn.MSELoss()

    # The model passed (pruned_model_for_finetuning) will be modified in-place by train_mlp_model
    fine_tuned_model_instance, ft_train_hist, ft_val_hist = train_mlp_model(
        model=pruned_model_for_finetuning, # This model gets updated with best val weights
        train_loader=train_loader, val_loader=val_loader,
        criterion=criterion_ft, optimizer=optimizer_ft, scheduler=scheduler_ft,
        device=DEVICE, num_epochs=FINETUNE_EPOCHS, patience=PATIENCE,
        model_save_path=os.path.join(OUTPUT_DIR, f"mlp_{strategy_name}_finetune_best_val_checkpoint.pth")
    )
    fine_tuning_history[strategy_name] = {
        'train_loss': ft_train_hist,
        'val_loss': ft_val_hist,
    }

    # --- Evaluate the fine-tuned model on Test Set ---
    print(f"\n--- Evaluating Fine-tuned MLP ({strategy_name}) on Test Set ---")
    final_metrics_on_test = evaluate_model_rmse(fine_tuned_model_instance, test_loader, DEVICE, example_input_tensor)

    # Combine final test performance (RMSE) with its structural metrics (which were set after pruning)
    # The FLOPs/Params don't change during fine-tuning, so use from structural_metrics_after_pruning
    combined_final_metrics = {
        'rmse': final_metrics_on_test['rmse'], # RMSE from test set after fine-tuning
        **structural_metrics_after_pruning[strategy_name] # FLOPs, Params, size_mb from after pruning
    }
    all_results[strategy_name] = combined_final_metrics


    # Save the final fine-tuned model state (.pth) and its ONNX version
    final_model_base_path = os.path.join(OUTPUT_DIR, f"mlp_{strategy_name}_final")
    save_model_state(fine_tuned_model_instance, final_model_base_path, example_input_for_onnx=example_input_tensor, device_for_onnx=DEVICE)


# --- 4. Compare Final Test Results (Table for RMSE and post-pruning structural metrics) ---
print("\n\n--- Final Test RMSE and Structural Metrics (Post-Pruning) Comparison ---")
compare_results_and_plot_rmse(all_results, OUTPUT_DIR) # This uses all_results which contains RMSE and structural data

# --- 5. Plot Training/Fine-tuning History (Loss curves over epochs) ---
print("\n\n--- Plotting Fine-tuning Loss Curves ---")
if 'fine_tuning_history' in locals() and fine_tuning_history:
    plot_finetuning_curves(fine_tuning_history,
                           initial_history=(initial_train_history if initial_train_history else None), # Pass loaded history
                           output_dir=OUTPUT_DIR)
else:
    print("Fine-tuning history dictionary not found or is empty.")

# --- 6. Plot Structural Metrics (Bar Chart: FLOPs/Params Initial vs. Post-Pruning) ---
print("\n\n--- Plotting Structural Metrics (Initial vs. Post-Pruning) ---")
if 'structural_metrics_after_pruning' in locals() and structural_metrics_after_pruning:
    plot_structural_metrics_comparison(structural_metrics_after_pruning, OUTPUT_DIR)
else:
    print("Structural metrics (post-pruning) dictionary not found or is empty.")

print("\nMLP Pruning Workflow Completed!")


--- Training Initial MLP Model ---




Starting training on cuda with patience=15
Epoch 1/100: Train Loss=991.6038, Val Loss=571.1660, LR=1.0e-03
*** New best validation loss: 571.1660 (Epoch 1) ***
Epoch 2/100: Train Loss=665.4012, Val Loss=612.5370, LR=1.0e-03
Epoch 3/100: Train Loss=639.5186, Val Loss=1176.5614, LR=1.0e-03
Epoch 4/100: Train Loss=622.9357, Val Loss=707.1494, LR=1.0e-03
Epoch 5/100: Train Loss=622.5278, Val Loss=1188.2493, LR=1.0e-03
Epoch 6/100: Train Loss=580.2433, Val Loss=1490.3106, LR=1.0e-03
Epoch 7/100: Train Loss=601.4487, Val Loss=1185.8904, LR=1.0e-03
Epoch 8/100: Train Loss=565.4043, Val Loss=1016.8169, LR=1.0e-03
Epoch 9/100: Train Loss=551.8645, Val Loss=830.8966, LR=1.0e-03
Epoch 10/100: Train Loss=516.5370, Val Loss=1258.1280, LR=5.0e-04
Epoch 11/100: Train Loss=505.1975, Val Loss=1158.3758, LR=5.0e-04
Epoch 12/100: Train Loss=493.8528, Val Loss=861.8777, LR=5.0e-04
Epoch 13/100: Train Loss=486.1058, Val Loss=913.9163, LR=5.0e-04
Epoch 14/100: Train Loss=473.9388, Val Loss=1100.9368, LR=5.0




--- Model Structure AFTER Pruning (magnitude) ---
MLPmodel(
  (network): Sequential(
    (0): Linear(in_features=14, out_features=1638, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.3, inplace=False)
    (3): Linear(in_features=1638, out_features=819, bias=True)
    (4): ReLU()
    (5): Dropout(p=0.3, inplace=False)
    (6): Linear(in_features=819, out_features=819, bias=True)
    (7): ReLU()
    (8): Dropout(p=0.3, inplace=False)
    (9): Linear(in_features=819, out_features=409, bias=True)
    (10): ReLU()
    (11): Dropout(p=0.3, inplace=False)
    (12): Linear(in_features=409, out_features=409, bias=True)
    (13): ReLU()
    (14): Dropout(p=0.3, inplace=False)
    (15): Linear(in_features=409, out_features=204, bias=True)
    (16): ReLU()
    (17): Dropout(p=0.3, inplace=False)
    (18): Linear(in_features=204, out_features=102, bias=True)
    (19): ReLU()
    (20): Dropout(p=0.3, inplace=False)
    (21): Linear(in_features=102, out_features=1, bias=True)
  )
)
Pruning finished



Epoch 1/100: Train Loss=812.5706, Val Loss=1147.5005, LR=5.0e-04
*** New best validation loss: 1147.5005 (Epoch 1) ***
Epoch 2/100: Train Loss=628.3127, Val Loss=832.0639, LR=5.0e-04
*** New best validation loss: 832.0639 (Epoch 2) ***
Epoch 3/100: Train Loss=612.1355, Val Loss=775.6994, LR=5.0e-04
*** New best validation loss: 775.6994 (Epoch 3) ***
Epoch 4/100: Train Loss=598.3383, Val Loss=1029.1780, LR=5.0e-04
Epoch 5/100: Train Loss=574.3806, Val Loss=565.8034, LR=5.0e-04
*** New best validation loss: 565.8034 (Epoch 5) ***
Epoch 6/100: Train Loss=578.2769, Val Loss=816.6314, LR=5.0e-04
Epoch 7/100: Train Loss=554.2090, Val Loss=1215.0672, LR=5.0e-04
Epoch 8/100: Train Loss=566.1565, Val Loss=735.9275, LR=5.0e-04
Epoch 9/100: Train Loss=545.2213, Val Loss=799.6775, LR=5.0e-04
Epoch 10/100: Train Loss=543.8791, Val Loss=986.0833, LR=5.0e-04
Epoch 11/100: Train Loss=526.8757, Val Loss=1001.3114, LR=5.0e-04
Epoch 12/100: Train Loss=521.4647, Val Loss=846.0322, LR=5.0e-04
Epoch 13/100



Epoch 1/100: Train Loss=674.4255, Val Loss=1133.1097, LR=5.0e-04
*** New best validation loss: 1133.1097 (Epoch 1) ***
Epoch 2/100: Train Loss=612.9259, Val Loss=1038.8062, LR=5.0e-04
*** New best validation loss: 1038.8062 (Epoch 2) ***
Epoch 3/100: Train Loss=589.0846, Val Loss=1449.0496, LR=5.0e-04
Epoch 4/100: Train Loss=579.1713, Val Loss=895.7759, LR=5.0e-04
*** New best validation loss: 895.7759 (Epoch 4) ***
Epoch 5/100: Train Loss=565.6650, Val Loss=672.7230, LR=5.0e-04
*** New best validation loss: 672.7230 (Epoch 5) ***
Epoch 6/100: Train Loss=566.9346, Val Loss=920.5810, LR=5.0e-04
Epoch 7/100: Train Loss=546.0113, Val Loss=1489.1959, LR=5.0e-04
Epoch 8/100: Train Loss=524.8692, Val Loss=748.6324, LR=5.0e-04
Epoch 9/100: Train Loss=551.7423, Val Loss=633.0120, LR=5.0e-04
*** New best validation loss: 633.0120 (Epoch 9) ***
Epoch 10/100: Train Loss=527.6934, Val Loss=688.5154, LR=5.0e-04
Epoch 11/100: Train Loss=521.9492, Val Loss=1221.2100, LR=5.0e-04
Epoch 12/100: Train Lo



Epoch 1/100: Train Loss=726.1977, Val Loss=577.9471, LR=5.0e-04
*** New best validation loss: 577.9471 (Epoch 1) ***
Epoch 2/100: Train Loss=650.8271, Val Loss=833.5916, LR=5.0e-04
Epoch 3/100: Train Loss=613.5140, Val Loss=839.4975, LR=5.0e-04
Epoch 4/100: Train Loss=623.4720, Val Loss=1146.2618, LR=5.0e-04
Epoch 5/100: Train Loss=579.8646, Val Loss=817.7384, LR=5.0e-04
Epoch 6/100: Train Loss=580.7284, Val Loss=1099.1759, LR=5.0e-04
Epoch 7/100: Train Loss=577.9139, Val Loss=814.0399, LR=5.0e-04
Epoch 8/100: Train Loss=575.6259, Val Loss=1110.4190, LR=5.0e-04
Epoch 9/100: Train Loss=532.0131, Val Loss=1281.5012, LR=2.5e-04
Epoch 10/100: Train Loss=536.2497, Val Loss=1457.3895, LR=2.5e-04
Epoch 11/100: Train Loss=528.8344, Val Loss=1123.1528, LR=2.5e-04
Epoch 12/100: Train Loss=516.6119, Val Loss=1208.5332, LR=2.5e-04
Epoch 13/100: Train Loss=518.9981, Val Loss=1551.0431, LR=2.5e-04
Epoch 14/100: Train Loss=517.7538, Val Loss=1427.6987, LR=2.5e-04
Epoch 15/100: Train Loss=513.0484, Va

TypeError: plot_finetuning_curves() got an unexpected keyword argument 'initial_history'

### Plotting Epoch-wise fine tune history

In [11]:
import matplotlib.pyplot as plt # Ensure plt is imported here
import numpy as np            # Ensure np is imported here
import os                     # Ensure os is imported here
from typing import Dict, List, Optional # Ensure typing is imported here


def plot_finetuning_curves(history_dict: Dict[str, Dict[str, List[float]]],
                           initial_history: Optional[Dict[str, List[float]]] = None,
                           output_dir: str = "./"):
    """Plots training and validation loss curves from fine-tuning."""
    num_strategies_ft = len(history_dict) # Use a different variable name

    # Create a figure for each strategy's fine-tuning curves
    for strategy_name, history in history_dict.items():
        if not history or 'train_loss' not in history or 'val_loss' not in history:
            print(f"Skipping plot for {strategy_name}, missing history data.")
            continue

        plt.figure(figsize=(10, 6))
        epochs_ft = range(1, len(history['train_loss']) + 1)
        plt.plot(epochs_ft, history['train_loss'], label=f'{strategy_name} - Train Loss')
        plt.plot(epochs_ft, history['val_loss'], label=f'{strategy_name} - Val Loss')

        # Plot initial model's training curves if available and valid
        if initial_history and \
           initial_history.get('train_loss') and \
           initial_history.get('val_loss'): # Check keys and ensure not empty lists
             epochs_initial = range(1, len(initial_history['train_loss']) + 1)
             plt.plot(epochs_initial, initial_history['train_loss'], linestyle='--', color='gray', alpha=0.8, label='Initial Model - Train Loss')
             plt.plot(epochs_initial, initial_history['val_loss'], linestyle=':', color='darkgray', alpha=0.8, label='Initial Model - Val Loss')

        plt.xlabel("Epochs")
        plt.ylabel("Loss (MSE)")
        plt.title(f"Fine-tuning Loss: {strategy_name} Strategy")
        plt.legend()
        plt.grid(True)
        plt.yscale('log')
        plot_path = os.path.join(output_dir, f"mlp_finetuning_loss_{strategy_name}.png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved fine-tuning curve for {strategy_name} to {plot_path}")

    # Combined plot for all strategies' validation losses
    if num_strategies_ft > 0: # Only create combined plot if there's history
        plt.figure(figsize=(12, 7))
        plotted_something_combined = False
        for strategy_name, history in history_dict.items():
            if history and history.get('val_loss'): # Check key and ensure not empty
                epochs_ft = range(1, len(history['val_loss']) + 1)
                plt.plot(epochs_ft, history['val_loss'], label=f'{strategy_name} - Val Loss')
                plotted_something_combined = True

        if initial_history and \
           initial_history.get('val_loss'): # Check key and ensure not empty
            epochs_initial = range(1, len(initial_history['val_loss']) + 1)
            plt.plot(epochs_initial, initial_history['val_loss'], linestyle=':', color='black', linewidth=2, label='Initial Model - Val Loss')
            plotted_something_combined = True

        if plotted_something_combined:
            plt.xlabel("Epochs")
            plt.ylabel("Validation Loss (MSE)")
            plt.title("Comparison of Validation Loss During Training/Fine-tuning")
            plt.legend(loc='upper left', bbox_to_anchor=(1.05, 1.0)) # Adjusted for potential many lines
            plt.grid(True)
            plt.yscale('log')
            combined_plot_path = os.path.join(output_dir, "mlp_val_loss_comparison.png")
            plt.savefig(combined_plot_path, bbox_inches='tight')
            plt.close()
            print(f"Saved combined validation loss curves to {combined_plot_path}")
        else:
            print("No valid validation loss data found to create combined plot.")
    else:
        print("No fine-tuning history available to plot combined validation losses.")


# --- Call the plotting function for fine-tuning history ---
if 'fine_tuning_history' in locals() and fine_tuning_history: # Check if dict exists and is not empty
    plot_finetuning_curves(fine_tuning_history,
                           initial_history=(initial_train_history if 'initial_train_history' in locals() else None),
                           output_dir=OUTPUT_DIR)
else:
    print("Fine-tuning history dictionary not found or is empty.")




Saved fine-tuning curve for magnitude to ./output_mlp_pruning/fd001/mlp_finetuning_loss_magnitude.png
Saved fine-tuning curve for random to ./output_mlp_pruning/fd001/mlp_finetuning_loss_random.png
Saved fine-tuning curve for Taylor to ./output_mlp_pruning/fd001/mlp_finetuning_loss_Taylor.png
Saved fine-tuning curve for lamp to ./output_mlp_pruning/fd001/mlp_finetuning_loss_lamp.png
Saved combined validation loss curves to ./output_mlp_pruning/fd001/mlp_val_loss_comparison.png


### --- Plotting Structural Metrics (Bar Chart for FLOPs/Params after pruning) ---

In [12]:

def plot_structural_metrics_comparison(structural_metrics: Dict[str, Dict[str, float]], output_dir: str):
    if not structural_metrics or 'initial' not in structural_metrics: # Ensure initial exists
        print("No structural metrics to plot, or 'initial' metrics are missing.")
        return

    # Ensure 'initial' is first for consistent plotting order
    labels = ['initial'] + [s for s in structural_metrics.keys() if s != 'initial']

    # Filter out strategies not present in structural_metrics for robustness
    labels = [s for s in labels if s in structural_metrics]
    if not labels:
        print("No valid strategies found in structural metrics for plotting.")
        return

    flops_values = [structural_metrics[s].get('flops', 0) / 1e6 for s in labels]  # MFLOPs
    params_values = [structural_metrics[s].get('params', 0) / 1e6 for s in labels] # MParams

    x = np.arange(len(labels))
    width = 0.35

    fig, ax1 = plt.subplots(figsize=(14, 8)) # Increased figure size slightly

    color_flops = 'tab:blue'
    ax1.set_xlabel('Model State (Initial vs. Post-Pruning, Pre-Finetuning)')
    ax1.set_ylabel('FLOPs (Millions)', color=color_flops)
    bars1 = ax1.bar(x - width/2, flops_values, width, label='FLOPs', color=color_flops)
    ax1.tick_params(axis='y', labelcolor=color_flops)
    ax1.set_xticks(x)
    ax1.set_xticklabels(labels, rotation=45, ha="right")

    ax2 = ax1.twinx()
    color_params = 'tab:red'
    ax2.set_ylabel('Parameters (Millions)', color=color_params)
    bars2 = ax2.bar(x + width/2, params_values, width, label='Parameters', color=color_params)
    ax2.tick_params(axis='y', labelcolor=color_params)

    def autolabel(bars, ax):
        for bar in bars:
            height = bar.get_height()
            if not np.isnan(height): # Check for NaN before annotating
                ax.annotate(f'{height:.2f}',
                            xy=(bar.get_x() + bar.get_width() / 2, height),
                            xytext=(0, 3),
                            textcoords="offset points",
                            ha='center', va='bottom', fontsize=9) # Increased font size slightly
    autolabel(bars1, ax1)
    autolabel(bars2, ax2)

    fig.suptitle('Structural Metrics: Initial vs. Post-Pruning (Before Fine-tuning)', fontsize=14) # Changed title position
    fig.tight_layout(rect=[0, 0.1, 1, 0.96]) # Adjust layout to make space for suptitle and legend

    # Combined legend
    handles1, labels1_fig = ax1.get_legend_handles_labels()
    handles2, labels2_fig = ax2.get_legend_handles_labels()
    fig.legend(handles1 + handles2, labels1_fig + labels2_fig, loc='upper center', bbox_to_anchor=(0.5, 0.07), ncol=2)


    structural_plot_path = os.path.join(output_dir, "mlp_structural_metrics_after_pruning.png")
    plt.savefig(structural_plot_path, bbox_inches='tight')
    plt.close()
    print(f"Saved structural metrics plot to {structural_plot_path}")

#Ensure this is called at the end of your main workflow cell:
if 'structural_metrics_after_pruning' in locals() and structural_metrics_after_pruning:
    plot_structural_metrics_comparison(structural_metrics_after_pruning, OUTPUT_DIR)
else:
    print("Structural metrics (post-pruning) dictionary not found or is empty.")

Saved structural metrics plot to ./output_mlp_pruning/fd001/mlp_structural_metrics_after_pruning.png
