# SPINN - Sparse Physics-Informed Neural Networks
## Tool Wear Prediction for CNC Milling

**Target Performance:**
- Dense Model: R² ≥ 0.95 (< 5% error)
- Pruned Model: 70-80% parameter reduction with R² ≥ 0.93

**Execution Order:**
1. Cell 1 - Diagnostic check
2. Cell 2 - Clone repository
3. Cell 3 - Mount Google Drive
4. Cell 4 - Load data from .mat file
5. Cell 5 - Import libraries
6. Cell 6 - Define model architecture
7. Cell 7 - Load data tensors
8. Cell 8 - Feature engineering (boosts R² from 0.87 → 0.95+)
9. Cell 9 - Train dense model (30-40 min)
10. Cell 10 - Structured pruning (10-15 min)
11. Cell 11 - GPU benchmark

---
# Cell 1: Diagnostic Check

In [None]:
import os

# Cell purpose: print a status report of which data files, saved models and
# Google Drive backups currently exist. Nothing is created or modified.

print("="*70)
print("üîç DIAGNOSTIC CHECK - CURRENT STATUS")
print("="*70)

# Check for data files (paths are relative to the current working directory)
print("\nüìä DATA FILES:")
data_files = [
    'data/processed/nasa_milling_processed.csv',
    'data/raw/nasa/mill.mat',
]
for f in data_files:
    exists = "‚úÖ" if os.path.exists(f) else "‚ùå"
    print(f"   {exists} {f}")

# Check for models, reporting the on-disk size of each one that is present
print("\nü§ñ MODEL FILES:")
model_files = [
    'models/saved/dense_pinn.pth',
    'models/saved/spinn_structured.pth',
]
for f in model_files:
    if os.path.exists(f):
        size_mb = os.path.getsize(f) / (1024*1024)
        print(f"   ‚úÖ {f} ({size_mb:.1f} MB)")
    else:
        print(f"   ‚ùå {f}")

# Check Drive backup: list every .pth file below the backup folder
print("\n‚òÅÔ∏è  GOOGLE DRIVE BACKUP:")
try:
    if os.path.exists('/content/drive/MyDrive/SPINN_BACKUP'):
        drive_files = []
        for root, dirs, files in os.walk('/content/drive/MyDrive/SPINN_BACKUP'):
            for file in files:
                if file.endswith('.pth'):
                    drive_files.append(os.path.join(root, file))

        if drive_files:
            for f in drive_files:
                size_mb = os.path.getsize(f) / (1024*1024)
                print(f"   ‚úÖ {f.replace('/content/drive/MyDrive/SPINN_BACKUP/', '')} ({size_mb:.1f} MB)")
        else:
            print(f"   ‚ö†Ô∏è  No .pth files found in backup")
    else:
        print(f"   ‚ö†Ô∏è  Drive not mounted or no backup folder")
except OSError as e:
    # Only filesystem errors are expected here; anything else (including
    # KeyboardInterrupt) should propagate instead of being hidden.
    print(f"   ‚ö†Ô∏è  Drive not accessible ({e})")

---
# Cell 2: Clone Repository

In [None]:
import os

# Clone or update repository
if not os.path.exists('SPINN'):
    !git clone https://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@github.com/krithiks4/SPINN.git
    print("‚úÖ Repository cloned")
else:
    !cd SPINN && git pull
    print("‚úÖ Repository updated")

# Change to repo directory
os.chdir('SPINN')

# Install dependencies
!pip install -q scipy scikit-learn matplotlib seaborn

print("‚úÖ Setup complete!")

---
# Cell 3: Mount Google Drive

In [None]:
# Mount Google Drive at /content/drive. Requires the google.colab package, so
# this cell only works inside a Colab runtime. Other cells look for backups
# under /content/drive/MyDrive/SPINN_BACKUP.
from google.colab import drive
drive.mount('/content/drive')

---
# Cell 4: Load Data from .mat File

In [None]:
import os
import numpy as np
import pandas as pd
from scipy.io import loadmat
from pathlib import Path


def _build_experiment_frame(case_data, case_idx, downsample_factor, spindle_speed):
    """Turn one record of the ``mill`` struct array into a downsampled DataFrame.

    Args:
        case_data: One element of ``mill`` (``mill[0, case_idx]``).
        case_idx: Position of the record inside ``mill``; stored as ``case_index``.
        downsample_factor: Keep every ``downsample_factor``-th sensor sample.
        spindle_speed: Constant written to the ``spindle_speed`` column.

    Returns:
        DataFrame with one row per retained sample, holding the record's scalar
        metadata, the four sensor channels and the derived columns.

    Raises:
        Any exception raised while indexing or converting a missing/malformed
        field; the caller catches it and skips the record.
    """
    # Scalar metadata of the record
    case_num = int(case_data['case'][0, 0])
    vb = float(case_data['VB'][0, 0])
    doc = float(case_data['DOC'][0, 0])
    feed = float(case_data['feed'][0, 0])

    # Sensor time-series
    force_ac = case_data['smcAC']
    force_dc = case_data['smcDC']
    vib_table = case_data['vib_table']
    vib_spindle = case_data['vib_spindle']

    n_samples = force_ac.shape[0]
    indices = np.arange(0, n_samples, downsample_factor)

    exp_df = pd.DataFrame({
        'experiment_id': case_num,
        'case_index': case_idx,
        # sample index / 1000 — presumably seconds at 1 kHz; TODO confirm rate
        'time': indices / 1000.0,
        'tool_wear': vb,
        'depth_of_cut': doc,
        'feed_rate': feed,
        'force_ac': force_ac[indices].flatten(),
        'force_dc': force_dc[indices].flatten(),
        'vib_table': vib_table[indices].flatten(),
        'vib_spindle': vib_spindle[indices].flatten(),
    })

    # Approximate 3-axis forces: these are plain copies of the columns above.
    exp_df['force_x'] = exp_df['force_ac']
    exp_df['force_y'] = exp_df['force_dc']
    exp_df['force_z'] = exp_df['vib_table']
    exp_df['spindle_speed'] = spindle_speed

    # Derived features; cumulative sums run over this record only.
    exp_df['force_magnitude'] = np.sqrt(
        exp_df['force_x']**2 + exp_df['force_y']**2 + exp_df['force_z']**2
    )
    exp_df['mrr'] = exp_df['spindle_speed'] * exp_df['feed_rate'] * exp_df['depth_of_cut']
    exp_df['cumulative_mrr'] = exp_df['mrr'].cumsum()
    exp_df['heat_generation'] = exp_df['force_magnitude'] * exp_df['spindle_speed'] * 0.001
    exp_df['cumulative_heat'] = exp_df['heat_generation'].cumsum()

    # Thermal displacement
    # NOTE(review): presumably an expansion coefficient and a tool length; the
    # units and the 0.01 scale factor are not documented — confirm.
    alpha = 11.7e-6
    L_tool = 100
    exp_df['thermal_displacement'] = alpha * L_tool * exp_df['cumulative_heat'] * 0.01

    return exp_df


# Delete old data if exists
processed_file = 'data/processed/nasa_milling_processed.csv'
if os.path.exists(processed_file):
    print(f"üóëÔ∏è  Deleting old CSV: {processed_file}")
    os.remove(processed_file)

print("="*70)
print("LOADING NASA MILLING DATA")
print("="*70)

# Look for .mat file; sorting makes the choice deterministic when several exist
print("\nüìÅ Looking for .mat file...")
mat_files = sorted(Path('data/raw').rglob('*.mat'))

if not mat_files:
    print("‚ùå No .mat file found. Please upload mill.mat:")
    # Imported lazily so the cell also runs outside Colab when the file exists.
    from google.colab import files
    uploaded = files.upload()
    if not uploaded:
        raise RuntimeError("No file was uploaded; mill.mat is required to continue")
    mat_file = next(iter(uploaded))
    os.makedirs('data/raw/nasa', exist_ok=True)
    mat_path = f'data/raw/nasa/{mat_file}'
    with open(mat_path, 'wb') as f:
        f.write(uploaded[mat_file])
else:
    mat_path = str(mat_files[0])

print(f"‚úÖ Found: {mat_path}")
file_size_mb = os.path.getsize(mat_path) / (1024*1024)
print(f"   File size: {file_size_mb:.1f} MB")

# Load .mat file
print(f"\nüì¶ Loading MATLAB file...")
data = loadmat(mat_path)
mill = data['mill']

print(f"   mill shape: {mill.shape}")
print(f"   Detected {mill.shape[1]} experiments")

# Extract data with downsampling
all_experiments = []
downsample_factor = 100
spindle_speed = 3000.0

print(f"\nüîÑ Processing experiments with downsampling (1/{downsample_factor})...")

for case_idx in range(mill.shape[1]):
    try:
        exp_df = _build_experiment_frame(
            mill[0, case_idx], case_idx, downsample_factor, spindle_speed
        )
    except Exception as e:
        # Best effort: a malformed record is reported and skipped, not fatal.
        print(f"   ‚ö†Ô∏è Skipping case {case_idx + 1}: {e}")
        continue

    all_experiments.append(exp_df)

    if (case_idx + 1) % 20 == 0:
        print(f"   Processed {case_idx + 1}/{mill.shape[1]} experiments...")

print(f"‚úÖ Extracted {len(all_experiments)} experiments")

if not all_experiments:
    # pd.concat([]) would fail with an unhelpful "No objects to concatenate".
    raise RuntimeError(
        f"No experiment could be extracted from {mat_path}; see the skip messages above"
    )

# Combine and clean: drop non-finite rows, non-positive wear and outliers
df = pd.concat(all_experiments, ignore_index=True)
df = df.replace([np.inf, -np.inf], np.nan)
df = df.dropna()
df = df[df['tool_wear'] > 0]
df = df[df['thermal_displacement'] < 1.0]

print(f"\nüìä Data Summary:")
print(f"   Shape: {df.shape}")
print(f"   Samples: {len(df):,}")
print(f"   Experiments: {df['experiment_id'].nunique()}")
print(f"\n‚úÖ Tool Wear Statistics:")
print(f"   Range: [{df['tool_wear'].min():.6f}, {df['tool_wear'].max():.6f}]")
print(f"   Mean:  {df['tool_wear'].mean():.6f}")
print(f"   Unique values: {df['tool_wear'].nunique()}")

# Save
os.makedirs('data/processed', exist_ok=True)
df.to_csv(processed_file, index=False)
print(f"\nüíæ Saved: {processed_file}")
print(f"   {df.shape[0]:,} rows √ó {df.shape[1]} columns")
print(f"\n{'='*70}")
print("‚úÖ DATA LOADING COMPLETE")
print("="*70)

---
# Cell 5: Import Libraries

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset, DataLoader
from sklearn.metrics import r2_score
import time

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

---
# Cell 6: Define Model Architecture

In [None]:
class DensePINN(nn.Module):
    """Fully connected network built from Linear -> BatchNorm1d -> ReLU stages.

    One stage is created per entry of ``hidden_dims``. Every stage except the
    last hidden one is followed by ``Dropout(dropout)`` when ``dropout > 0``.
    A final ``Linear`` maps the last hidden width to ``output_dim``. All
    modules live in ``self.layers`` (an ``nn.Sequential``).
    """

    def __init__(self, input_dim, hidden_dims, output_dim, dropout=0.1):
        super().__init__()

        stack = []
        width = input_dim

        for position, units in enumerate(hidden_dims, start=1):
            stack += [nn.Linear(width, units), nn.BatchNorm1d(units), nn.ReLU()]
            # No dropout after the final hidden stage.
            if dropout > 0 and position < len(hidden_dims):
                stack.append(nn.Dropout(dropout))
            width = units

        stack.append(nn.Linear(width, output_dim))
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the sequential stack to ``x`` and return its output."""
        return self.layers(x)

def calculate_neuron_importance(layer):
    """Score every output neuron of ``layer`` by the L1 norm of its weight row.

    Args:
        layer: Module with a 2-D ``weight`` laid out as (out_features, in_features).

    Returns:
        1-D tensor with one non-negative score per output neuron.
    """
    importance = torch.sum(torch.abs(layer.weight.data), dim=1)
    return importance


def prune_linear_layer(current_layer, next_layer, keep_ratio):
    """Drop the least important output neurons of ``current_layer``.

    The neurons with the largest importance scores are kept in their original
    order. The matching input columns of ``next_layer`` are kept as well so
    the two layers stay shape-compatible.

    Args:
        current_layer: ``nn.Linear`` whose output neurons are pruned.
        next_layer: ``nn.Linear`` consuming ``current_layer``'s output, or
            ``None`` when there is no following layer to adjust.
        keep_ratio: Fraction of neurons to keep. At least one neuron and at
            most all neurons are kept, whatever the value.

    Returns:
        ``(new_current, new_next)``: freshly built ``nn.Linear`` modules that
        own their parameters; ``new_next`` is ``None`` if ``next_layer`` was.
    """
    importance = calculate_neuron_importance(current_layer)
    n_neurons = importance.shape[0]
    # Clamp to [1, n_neurons]: topk raises when asked for more than exist.
    n_keep = min(n_neurons, max(1, int(n_neurons * keep_ratio)))

    _, indices = torch.topk(importance, n_keep)
    indices = sorted(indices.tolist())

    new_current = nn.Linear(current_layer.in_features, n_keep, bias=(current_layer.bias is not None))
    # Indexing with a list copies the data, so no storage is shared.
    new_current.weight.data = current_layer.weight.data[indices, :]
    if current_layer.bias is not None:
        new_current.bias.data = current_layer.bias.data[indices]

    if next_layer is not None:
        new_next = nn.Linear(n_keep, next_layer.out_features, bias=(next_layer.bias is not None))
        new_next.weight.data = next_layer.weight.data[:, indices]
        if next_layer.bias is not None:
            # clone(): the bias is not sliced, so without a copy the new layer
            # would share storage with next_layer and training it would
            # silently modify the original layer too.
            new_next.bias.data = next_layer.bias.data.clone()
    else:
        new_next = None

    return new_current, new_next

print("‚úÖ Model architecture defined")

---
# Cell 7: Load Data Tensors (Original Features)

In [None]:
# Cell purpose: read the processed CSV, split it 85/15 and then again at
# 17.6% for validation, standardise features and targets with statistics
# fitted on the training part only, and move everything to `device`.
print("="*70)
print("LOADING DATA - ORIGINAL FEATURES (16 features)")
print("="*70)

processed_file = 'data/processed/nasa_milling_processed.csv'
df = pd.read_csv(processed_file)

print(f"\nüìã Available columns: {list(df.columns)}")
print(f"üìä Data shape: {df.shape}")

# Targets, plus the bookkeeping columns that must not be used as features
target_cols = ['tool_wear', 'thermal_displacement']
exclude_cols = [*target_cols, 'experiment_id', 'case_index']
feature_cols = [col for col in df.columns if col not in exclude_cols]

print(f"\nüî¢ Features ({len(feature_cols)}): {feature_cols}")
print(f"üéØ Targets ({len(target_cols)}): {target_cols}")

X = df[feature_cols].values
y = df[target_cols].values

# Hold out the test set first, then carve validation out of the remainder
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.15, random_state=42
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=0.176, random_state=42
)

# Standardise: fit on the training split, apply to all three splits
scaler_X = StandardScaler()
scaler_y = StandardScaler()

X_train = scaler_X.fit_transform(X_train)
X_val, X_test = [scaler_X.transform(part) for part in (X_val, X_test)]

y_train = scaler_y.fit_transform(y_train)
y_val, y_test = [scaler_y.transform(part) for part in (y_val, y_test)]

# Float32 tensors on the selected device
X_train_tensor, y_train_tensor = [
    torch.FloatTensor(part).to(device) for part in (X_train, y_train)
]
X_val_tensor, y_val_tensor = [
    torch.FloatTensor(part).to(device) for part in (X_val, y_val)
]
X_test_tensor, y_test_tensor = [
    torch.FloatTensor(part).to(device) for part in (X_test, y_test)
]

input_dim = X.shape[1]
output_dim = y.shape[1]

print(f"\n{'='*70}")
print(f"‚úÖ DATA READY (ORIGINAL)")
print(f"{'='*70}")
print(f"Input dim: {input_dim}, Output dim: {output_dim}")
print(f"Train: {X_train.shape[0]:,}, Val: {X_val.shape[0]:,}, Test: {X_test.shape[0]:,}")

---
# Cell 8: Feature Engineering (Run to boost R² from 0.87 → 0.95+)

In [None]:
# Cell purpose: add interaction, polynomial and ratio/history features to the
# processed data, rebuild the train/val/test tensors (``*_eng`` names) and
# compare a linear baseline with and without the new features.
# Requires processed_file, exclude_cols, target_cols and device from earlier cells.
print("="*70)
print("üîß FEATURE ENGINEERING - Boost R¬≤ to ‚â• 0.95")
print("="*70)

# Reload data
df_eng = pd.read_csv(processed_file)

# Remember the pre-engineering feature set instead of hard-coding its size.
original_feature_cols = [col for col in df_eng.columns if col not in exclude_cols]
n_original_features = len(original_feature_cols)

print(f"\nüìä Original: {n_original_features} features")

# 1. Interaction features
print(f"\nüîÑ Adding features...")
df_eng['force_dc_x_time'] = df_eng['force_dc'] * df_eng['time']
df_eng['vib_spindle_x_time'] = df_eng['vib_spindle'] * df_eng['time']
df_eng['force_mag_x_time'] = df_eng['force_magnitude'] * df_eng['time']
df_eng['cumul_heat_x_time'] = df_eng['cumulative_heat'] * df_eng['time']

# 2. Polynomial features
df_eng['force_dc_squared'] = df_eng['force_dc'] ** 2
df_eng['force_dc_cubed'] = df_eng['force_dc'] ** 3
df_eng['vib_spindle_squared'] = df_eng['vib_spindle'] ** 2
df_eng['cumulative_heat_sq'] = df_eng['cumulative_heat'] ** 2

# 3. Physics-based features (1e-6 keeps the denominators away from zero)
df_eng['specific_cutting_energy'] = df_eng['force_magnitude'] / (df_eng['mrr'] + 1e-6)
df_eng['force_dc_ac_ratio'] = df_eng['force_dc'] / (df_eng['force_ac'].abs() + 1e-6)
# NOTE(review): unlike the ratio above this denominator is signed, so it can
# still land on zero; such rows become inf and are removed by the cleaning step.
df_eng['vib_ratio'] = df_eng['vib_table'] / (df_eng['vib_spindle'] + 1e-6)
# NOTE(review): groups are keyed on experiment_id, not case_index — confirm
# that running totals are meant to span every record sharing an id.
df_eng['cumulative_force'] = df_eng.groupby('experiment_id')['force_magnitude'].cumsum()
# expanding() returns a (experiment_id, original row) MultiIndex. Dropping only
# the group level keeps the original row labels, so the assignment aligns each
# value with the row it was computed for even if rows are not sorted by group.
df_eng['avg_force_history'] = (
    df_eng.groupby('experiment_id')['force_magnitude']
    .expanding()
    .mean()
    .reset_index(level=0, drop=True)
)

# Clean
df_eng = df_eng.replace([np.inf, -np.inf], np.nan)
df_eng = df_eng.dropna()

# Prepare tensors
feature_cols_eng = [col for col in df_eng.columns if col not in exclude_cols]
n_enhanced_features = len(feature_cols_eng)

print(f"‚úÖ Enhanced: {n_enhanced_features} features (+{n_enhanced_features - n_original_features} new)")

X_eng = df_eng[feature_cols_eng].values
y_eng = df_eng[target_cols].values

# Split
X_temp_eng, X_test_eng, y_temp_eng, y_test_eng = train_test_split(X_eng, y_eng, test_size=0.15, random_state=42)
X_train_eng, X_val_eng, y_train_eng, y_val_eng = train_test_split(X_temp_eng, y_temp_eng, test_size=0.176, random_state=42)

# Normalize (statistics come from the training split only)
scaler_X_eng = StandardScaler()
scaler_y_eng = StandardScaler()

X_train_eng = scaler_X_eng.fit_transform(X_train_eng)
X_val_eng = scaler_X_eng.transform(X_val_eng)
X_test_eng = scaler_X_eng.transform(X_test_eng)

y_train_eng = scaler_y_eng.fit_transform(y_train_eng)
y_val_eng = scaler_y_eng.transform(y_val_eng)
y_test_eng = scaler_y_eng.transform(y_test_eng)

# To tensors
X_train_tensor_eng = torch.FloatTensor(X_train_eng).to(device)
y_train_tensor_eng = torch.FloatTensor(y_train_eng).to(device)
X_val_tensor_eng = torch.FloatTensor(X_val_eng).to(device)
y_val_tensor_eng = torch.FloatTensor(y_val_eng).to(device)
X_test_tensor_eng = torch.FloatTensor(X_test_eng).to(device)
y_test_tensor_eng = torch.FloatTensor(y_test_eng).to(device)

input_dim_eng = X_eng.shape[1]
output_dim_eng = y_eng.shape[1]

# Test linear regression on the first target (tool wear)
from sklearn.linear_model import LinearRegression
lr_eng = LinearRegression()
lr_eng.fit(X_train_eng, y_train_eng[:, 0])
y_pred_lr_eng = lr_eng.predict(X_val_eng)
r2_linear_eng = r2_score(y_val_eng[:, 0], y_pred_lr_eng)

# Baseline: the same split restricted to the original columns, so the reported
# improvement is measured on identical rows rather than against a fixed number.
original_positions = [
    pos for pos, col in enumerate(feature_cols_eng) if col in original_feature_cols
]
lr_orig = LinearRegression()
lr_orig.fit(X_train_eng[:, original_positions], y_train_eng[:, 0])
r2_linear_orig = r2_score(y_val_eng[:, 0], lr_orig.predict(X_val_eng[:, original_positions]))

print(f"\n{'='*70}")
print(f"‚úÖ ENHANCED DATA READY")
print(f"{'='*70}")
print(f"Features: {n_original_features} ‚Üí {input_dim_eng} (+{input_dim_eng - n_original_features})")
print(f"\nüìä Linear R¬≤ improvement:")
print(f"   Original: {r2_linear_orig:.4f}")
print(f"   Enhanced: {r2_linear_eng:.4f} (+{r2_linear_eng - r2_linear_orig:.4f})")
print(f"\nüéØ Expected Neural Net R¬≤: 0.92-0.97")
print(f"\nüìã Next: Run Cell 9 to train with enhanced features!")

---
# Cell 9: Train Dense Model

In [None]:
import os
import shutil

# Cell purpose: train DensePINN on the enhanced tensors when they exist
# (otherwise on the original ones), keep the weights with the best validation
# R² on the first target, save the model locally and copy it to Drive.
dense_model_path = 'models/saved/dense_pinn.pth'

# Check if enhanced features available
if 'X_train_tensor_eng' in globals():
    print("üîß Using ENHANCED features from Cell 8")
    X_train_use = X_train_tensor_eng
    y_train_use = y_train_tensor_eng
    X_val_use = X_val_tensor_eng
    y_val_use = y_val_tensor_eng
    input_dim_use = input_dim_eng
    output_dim_use = output_dim_eng
else:
    print("üìä Using ORIGINAL features (Cell 8 not run)")
    X_train_use = X_train_tensor
    y_train_use = y_train_tensor
    X_val_use = X_val_tensor
    y_val_use = y_val_tensor
    input_dim_use = input_dim
    output_dim_use = output_dim

print("\nüèãÔ∏è Training from scratch (30-50 min)...\n")

# Model
dense_model = DensePINN(input_dim_use, [1024, 512, 512, 256, 128], output_dim_use, dropout=0.2).to(device)
total_params = sum(p.numel() for p in dense_model.parameters())
print(f"Architecture: {input_dim_use} ‚Üí 1024 ‚Üí 512 ‚Üí 512 ‚Üí 256 ‚Üí 128 ‚Üí {output_dim_use}")
print(f"Parameters: {total_params:,}")
print(f"Target: R¬≤ ‚â• 0.95\n")

loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(dense_model.parameters(), lr=0.002, weight_decay=5e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=50, T_mult=2, eta_min=1e-6)

train_dataset = TensorDataset(X_train_use, y_train_use)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

best_r2 = -float('inf')
best_state = None
patience_counter = 0  # counts evaluations (every 5 epochs) without improvement

for epoch in range(500):
    dense_model.train()
    train_loss = 0.0

    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        y_pred = dense_model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(dense_model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()

    scheduler.step()

    # Evaluate every 5 epochs
    if (epoch + 1) % 5 == 0:
        dense_model.eval()
        with torch.no_grad():
            val_pred = dense_model(X_val_use)
            val_loss = loss_fn(val_pred, y_val_use)
            val_r2 = r2_score(y_val_use[:, 0].cpu().numpy(), val_pred[:, 0].cpu().numpy())

        current_lr = optimizer.param_groups[0]['lr']
        error_pct = (1 - val_r2) * 100

        print(f"Epoch {epoch+1:3d}: Loss={val_loss:.6f}, R¬≤={val_r2:.4f}, Error={error_pct:.2f}%, LR={current_lr:.6f}")

        if val_r2 > best_r2:
            best_r2 = val_r2
            # state_dict() hands out the live tensors, and dict.copy() is
            # shallow, so each tensor must be cloned or the snapshot would
            # keep changing as training continues.
            best_state = {k: v.detach().cpu().clone() for k, v in dense_model.state_dict().items()}
            patience_counter = 0
            print(f"   ‚≠ê New best R¬≤! (Error: {(1-best_r2)*100:.2f}%)")
        else:
            patience_counter += 1

        if val_r2 >= 0.98:
            print(f"\nüéâ EXCELLENT! R¬≤ ‚â• 0.98 achieved!")
            break

        if val_r2 >= 0.95 and epoch >= 100:
            print(f"\n‚úÖ Target R¬≤ ‚â• 0.95 achieved!")
            break

        if patience_counter >= 40:
            print(f"\n‚ö†Ô∏è Early stopping")
            break

# Restore the best snapshot (load_state_dict copies it back onto the device)
if best_state is not None:
    dense_model.load_state_dict(best_state)

# Final evaluation
dense_model.eval()
with torch.no_grad():
    val_pred = dense_model(X_val_use)
    final_r2 = r2_score(y_val_use[:, 0].cpu().numpy(), val_pred[:, 0].cpu().numpy())

# Save (the whole module object is pickled, not just its state_dict)
os.makedirs(os.path.dirname(dense_model_path), exist_ok=True)
torch.save(dense_model, dense_model_path)

# Best-effort Drive backup: a failure is reported but does not stop the cell
try:
    drive_path = '/content/drive/MyDrive/SPINN_BACKUP/models/saved/dense_pinn.pth'
    os.makedirs(os.path.dirname(drive_path), exist_ok=True)
    shutil.copy(dense_model_path, drive_path)
except OSError as e:
    print(f"   Drive backup skipped: {e}")

print(f"\n{'='*70}")
print(f"‚úÖ TRAINING COMPLETE")
print(f"{'='*70}")
print(f"Best R¬≤: {best_r2:.4f}")
print(f"Final R¬≤: {final_r2:.4f}")
print(f"Parameters: {total_params:,}")
print(f"üíæ Saved: {dense_model_path}")

---
# Cell 10: Structured Pruning

In [None]:
import os
import shutil

import torch.optim as optim

# Cell purpose: iteratively shrink a copy of the dense model by removing whole
# neurons from its interior Linear layers, fine-tuning after every round, then
# save the result. Requires dense_model, DensePINN, prune_linear_layer,
# calculate_neuron_importance and the data tensors from earlier cells.
TARGET_SPARSITY = 0.80
N_PRUNE_ROUNDS = 4
EPOCHS_PER_ROUND = 40
MIN_R2_THRESHOLD = 0.93


def _kept_neuron_indices(layer, keep_ratio):
    """Return the sorted neuron indices prune_linear_layer keeps for ``layer``."""
    importance = calculate_neuron_importance(layer)
    n_neurons = importance.shape[0]
    n_keep = min(n_neurons, max(1, int(n_neurons * keep_ratio)))
    _, top = torch.topk(importance, n_keep)
    return sorted(top.tolist())


def _slice_batchnorm(bn, indices):
    """Copy ``bn`` keeping only the channels in ``indices`` (learned state included)."""
    kept = nn.BatchNorm1d(
        len(indices),
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
    )
    if bn.affine:
        kept.weight.data = bn.weight.data[indices]
        kept.bias.data = bn.bias.data[indices]
    if bn.track_running_stats:
        kept.running_mean = bn.running_mean[indices]
        kept.running_var = bn.running_var[indices]
        kept.num_batches_tracked = bn.num_batches_tracked.clone()
    return kept


def _prune_interior_layers(layers, keep_ratio):
    """Return a new layer list with every interior Linear layer narrowed.

    The first Linear keeps its width and the last Linear keeps its output
    size; each Linear in between loses output neurons. The BatchNorm1d that
    follows a pruned Linear is sliced to the same neurons, and every other
    module (ReLU, Dropout) is carried over unchanged.
    """
    layers = list(layers)
    linear_positions = [pos for pos, m in enumerate(layers) if isinstance(m, nn.Linear)]

    # Pair each interior Linear with the Linear that consumes its output.
    for cur_pos, next_pos in zip(linear_positions[1:-1], linear_positions[2:]):
        current = layers[cur_pos]
        kept = _kept_neuron_indices(current, keep_ratio)
        pruned_current, pruned_next = prune_linear_layer(current, layers[next_pos], keep_ratio)
        if pruned_current.out_features != len(kept):
            raise RuntimeError("kept-neuron selection disagrees with prune_linear_layer")

        layers[cur_pos] = pruned_current
        # The consumer only lost input columns here; if it is interior too,
        # its own outputs are pruned on the next iteration.
        layers[next_pos] = pruned_next

        for mid in range(cur_pos + 1, next_pos):
            if isinstance(layers[mid], nn.BatchNorm1d):
                layers[mid] = _slice_batchnorm(layers[mid], kept)

    return layers


print("="*70)
print(f"STRUCTURED PRUNING - Target: {TARGET_SPARSITY*100:.0f}% reduction")
print("="*70)

dense_params = sum(p.numel() for p in dense_model.parameters())
# Per-round fraction of neurons kept; compounds to 1 - TARGET_SPARSITY.
keep_ratio = (1 - TARGET_SPARSITY) ** (1 / N_PRUNE_ROUNDS)

# Determine which tensors to use
if 'X_train_tensor_eng' in globals():
    X_train_prune = X_train_tensor_eng
    y_train_prune = y_train_tensor_eng
    X_val_prune = X_val_tensor_eng
    y_val_prune = y_val_tensor_eng
    input_dim_prune = input_dim_eng
    output_dim_prune = output_dim_eng
else:
    X_train_prune = X_train_tensor
    y_train_prune = y_train_tensor
    X_val_prune = X_val_tensor
    y_val_prune = y_val_tensor
    input_dim_prune = input_dim
    output_dim_prune = output_dim

# Start from a copy of the trained dense weights so dense_model stays intact.
spinn_model = DensePINN(input_dim_prune, [1024, 512, 512, 256, 128], output_dim_prune, dropout=0.15).to(device)
spinn_model.load_state_dict(dense_model.state_dict())

train_dataset = TensorDataset(X_train_prune, y_train_prune)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
prune_loss_fn = nn.MSELoss()

for round_num in range(1, N_PRUNE_ROUNDS + 1):
    print(f"\nüîÑ ROUND {round_num}/{N_PRUNE_ROUNDS}")

    # Round 1 sees a DensePINN (modules under .layers); later rounds see the
    # nn.Sequential built by the previous round.
    source_layers = spinn_model.layers if hasattr(spinn_model, 'layers') else spinn_model
    spinn_model = nn.Sequential(*_prune_interior_layers(source_layers, keep_ratio)).to(device)

    pruned_params = sum(p.numel() for p in spinn_model.parameters())
    reduction = (1 - pruned_params / dense_params) * 100
    print(f"Parameters: {dense_params:,} ‚Üí {pruned_params:,} ({reduction:.1f}% reduction)")

    # Fine-tune
    optimizer = optim.AdamW(spinn_model.parameters(), lr=0.003, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS_PER_ROUND)

    best_r2 = -float('inf')
    best_state = None

    for epoch in range(EPOCHS_PER_ROUND):
        spinn_model.train()
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            y_pred = spinn_model(X_batch)
            loss = prune_loss_fn(y_pred, y_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(spinn_model.parameters(), max_norm=1.0)
            optimizer.step()

        scheduler.step()

        # Validate every 10 epochs and on the last epoch of the round
        if (epoch + 1) % 10 == 0 or epoch == EPOCHS_PER_ROUND - 1:
            spinn_model.eval()
            with torch.no_grad():
                val_pred = spinn_model(X_val_prune)
                val_r2 = r2_score(y_val_prune[:, 0].cpu().numpy(), val_pred[:, 0].cpu().numpy())

            if val_r2 > best_r2:
                best_r2 = val_r2
                best_state = {k: v.cpu().clone() for k, v in spinn_model.state_dict().items()}

    if best_state is not None:
        spinn_model.load_state_dict({k: v.to(device) for k, v in best_state.items()})

    print(f"‚úÖ Best R¬≤: {best_r2:.4f}")

    # NOTE(review): on this early stop the model from the failing round is what
    # gets evaluated and saved below — confirm whether the previous round's
    # model should be restored instead.
    if best_r2 < MIN_R2_THRESHOLD:
        print(f"‚ö†Ô∏è R¬≤ < {MIN_R2_THRESHOLD}, stopping")
        break

# Final evaluation
spinn_model.eval()
with torch.no_grad():
    val_pred = spinn_model(X_val_prune)
    final_r2 = r2_score(y_val_prune[:, 0].cpu().numpy(), val_pred[:, 0].cpu().numpy())

final_params = sum(p.numel() for p in spinn_model.parameters())
final_reduction = (1 - final_params / dense_params) * 100

print(f"\n{'='*70}")
print(f"‚úÖ PRUNING COMPLETE")
print(f"{'='*70}")
print(f"Dense:   {dense_params:,}")
print(f"Pruned:  {final_params:,}")
print(f"Reduction: {final_reduction:.1f}%")
print(f"Final R¬≤: {final_r2:.4f}")
print(f"Compression: {dense_params/final_params:.1f}x")

# Save (file name carries the achieved reduction, truncated to an integer)
spinn_path = f'models/saved/spinn_structured_{int(final_reduction)}pct.pth'
os.makedirs(os.path.dirname(spinn_path), exist_ok=True)
torch.save(spinn_model, spinn_path)

# Best-effort Drive backup: a failure is reported but does not stop the cell
try:
    drive_path = f'/content/drive/MyDrive/SPINN_BACKUP/models/saved/spinn_structured_{int(final_reduction)}pct.pth'
    os.makedirs(os.path.dirname(drive_path), exist_ok=True)
    shutil.copy(spinn_path, drive_path)
except OSError as e:
    print(f"   Drive backup skipped: {e}")

print(f"\nüíæ Saved: {spinn_path}")

---
# Cell 11: GPU Benchmark

In [None]:
# Cell purpose: time full-batch forward passes of the dense and the pruned
# model on the validation tensor and report the median latency of each.
print("="*70)
print("GPU BENCHMARK")
print("="*70)

n_trials = 200
warmup = 50

# Determine which validation tensor to use
X_val_bench = X_val_tensor_eng if 'X_val_tensor_eng' in globals() else X_val_tensor


def _forward_latencies_ms(model, batch, n_warmup, n_timed):
    """Return ``n_timed`` forward-pass durations in milliseconds.

    The model is put in eval mode and run ``n_warmup`` times untimed first. On
    CUDA each pass is measured with a pair of CUDA events and a synchronize
    before and after; on CPU ``time.perf_counter`` is used.
    """
    on_gpu = device.type == 'cuda'

    model.eval()
    for _ in range(n_warmup):
        with torch.no_grad():
            _ = model(batch)
    if on_gpu:
        torch.cuda.synchronize()

    durations = []
    for _ in range(n_timed):
        if on_gpu:
            torch.cuda.synchronize()
            tic = torch.cuda.Event(enable_timing=True)
            toc = torch.cuda.Event(enable_timing=True)
            tic.record()
            with torch.no_grad():
                _ = model(batch)
            toc.record()
            torch.cuda.synchronize()
            durations.append(tic.elapsed_time(toc))
        else:
            tic = time.perf_counter()
            with torch.no_grad():
                _ = model(batch)
            toc = time.perf_counter()
            durations.append((toc - tic) * 1000)
    return durations


# Dense model benchmark
dense_times = _forward_latencies_ms(dense_model, X_val_bench, warmup, n_trials)
dense_median = np.median(dense_times)

# SPINN model benchmark
spinn_times = _forward_latencies_ms(spinn_model, X_val_bench, warmup, n_trials)
spinn_median = np.median(spinn_times)

speedup = dense_median / spinn_median

print(f"\nDense:  {dense_median:.2f} ms")
print(f"SPINN:  {spinn_median:.2f} ms")
print(f"‚ö° SPEEDUP: {speedup:.2f}x")

print(f"\n{'='*70}")
print(f"BENCHMARK COMPLETE")
print(f"{'='*70}")