# Testing a KAN (Kolmogorov-Arnold Network) Model

This notebook demonstrates the use of Kolmogorov-Arnold Networks (KAN) for regression on wire drawing data.

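KAN is a neural network architecture based on the Kolmogorov-Arnold representation theorem. Instead of fixed activation functions on the nodes, it places learnable univariate functions (typically splines) on the edges and sums their outputs, which can make it well suited to regression tasks such as this one.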
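
For reference, the theorem says that any continuous function $f$ of $n$ variables on the unit cube $[0,1]^n$ can be written using only continuous univariate functions and addition:

$$
f(x_1, \dots, x_n) = \sum_{q=0}^{2n} \Phi_q\left( \sum_{p=1}^{n} \phi_{q,p}(x_p) \right)
$$

A KAN generalizes this two-level form to arbitrary depth and width and learns the univariate functions from data. The `grid` and `k` entries in the parameters set later in this notebook presumably control the spline grid size and spline order; that reading is an assumption, since the internals of `KANModelTrainTest` are not shown here.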

In this notebook, we will:
1. Load the wire drawing data
2. Create and configure a KAN model
3. Train and evaluate the model
4. Visualize the results
5. Analyze errors and save the model

In [None]:
import sys
import os
import pickle
sys.path.append('../../')

In [6]:
from analysis_functions import test_after_opt, split_transform_one_comp_cv, opener

In [7]:
# Load data
X_stress_components_new = opener(
    "X_stress_components_new_components", path_import="../../new_components_resources/"
)
y_stress_components_new = opener(
    "y_stress_components_new_components", path_import="../../new_components_resources/"
)

print(X_stress_components_new.shape)

component_num = 0
X_current = X_stress_components_new[component_num]
y_current = y_stress_components_new[component_num]

../../new_components_resources//X_stress_components_new_components.pkl
../../new_components_resources//y_stress_components_new_components.pkl
(3, 1597, 5)


In [9]:
from analysis_functions import KANModelTrainTest

In [10]:
kan_model = KANModelTrainTest(False)

2025-05-11 15:05:14,012 - INFO - Using device: cpu


In [11]:
kan_model.best_params = {
    "n_layers": 3,
    "opt": "LBFGS",
    "steps": 679,
    "grid": 3,
    "k": 1,
    "n_units_0": 59,
    "n_units_1": 31,
    "n_units_2": 60,
    "lr": 0.013473746056928101,
}
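
As a rough illustration of what these parameters describe, the sketch below turns them into a list of layer widths. `build_width` is a hypothetical helper, not part of `analysis_functions`, and the input size of 5 is an assumption based on the last axis of the `(3, 1597, 5)` shape printed earlier; the real translation inside `KANModelTrainTest` may differ.

```python
# Hypothetical helper: one plausible mapping from the tuned parameters to a
# KAN width list. The actual logic in KANModelTrainTest is not shown here.
def build_width(best_params, n_inputs, n_outputs=1):
    hidden = [best_params[f"n_units_{i}"] for i in range(best_params["n_layers"])]
    return [n_inputs] + hidden + [n_outputs]

best_params = {"n_layers": 3, "n_units_0": 59, "n_units_1": 31, "n_units_2": 60}

# n_inputs=5 assumes the last axis of the data holds the features.
width = build_width(best_params, n_inputs=5)
print(width)  # [5, 59, 31, 60, 1]
```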

In [12]:
kan_model.create_train_val_test(
    X_current,
    y_current,
)
kan_model.calc_test_metric()

2025-05-11 15:07:10,132 - INFO - Testing with optimizer: LBFGS, Learning rate: 0.013473746056928101


checkpoint directory created: ./model
saving model version 0.0


| train_loss: 3.55e+01 | test_loss: 3.49e+01 | reg: 0.00e+00 | :   1%| | 5/679 [01:58<4:25:21, 23.62



KeyboardInterrupt: 
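
The progress line above gives enough information to estimate the full run time. The sketch below redoes that arithmetic, assuming the truncated trailing figure 23.62 is seconds per step (which is consistent with 5 steps taking 1:58).

```python
# Figures read off the progress bar above.
total_steps = 679
done_steps = 5
seconds_per_step = 23.62  # assumed unit: seconds per optimizer step

remaining_seconds = (total_steps - done_steps) * seconds_per_step
hours, rem = divmod(int(remaining_seconds), 3600)
minutes, seconds = divmod(rem, 60)
print(f"{hours}h {minutes:02d}m {seconds:02d}s")  # 4h 25m 19s
```

This agrees with the remaining time of about 4 h 25 min shown by the progress bar, so a full 679-step LBFGS run on CPU with these settings takes roughly four and a half hours.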

In [None]:
# Create directory for saving model if it doesn't exist
os.makedirs('saved_models', exist_ok=True)

# Save the entire KAN model object using pickle
model_file_path = f'saved_models/kan_model_component_{component_num}.pkl'

with open(model_file_path, 'wb') as f:
    pickle.dump(kan_model, f)

print(f"KAN model saved to {model_file_path}")

# To demonstrate how to load the model
print("\nExample of loading the model:")
print("with open('saved_models/kan_model_component_0.pkl', 'rb') as f:")
print("    loaded_model = pickle.load(f)")
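
Before relying on a saved file, it can help to confirm that it loads back with the same contents. The sketch below shows such a round-trip check using a `types.SimpleNamespace` as a hypothetical stand-in for the trained `kan_model` object and a temporary directory instead of `saved_models/`.

```python
import os
import pickle
import tempfile
from types import SimpleNamespace

# Hypothetical stand-in for the trained model object.
original = SimpleNamespace(best_params={"n_layers": 3, "opt": "LBFGS"}, component_num=0)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "kan_model_component_0.pkl")

    # Write the object, then read it back from the same file.
    with open(path, "wb") as f:
        pickle.dump(original, f)
    with open(path, "rb") as f:
        restored = pickle.load(f)

round_trip_ok = restored.best_params == original.best_params
print(round_trip_ok)  # True
```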

In [None]:
# Load the pickled KAN model (demonstration - commented out to avoid duplication)
"""
with open(model_file_path, 'rb') as f:
    loaded_kan_model = pickle.load(f)

# Test the loaded model (make predictions on test data)
import torch
import numpy as np

# Get test data
X_test = loaded_kan_model.cur_X_test
y_test = loaded_kan_model.cur_y_test

# Convert test data to tensor and move to device
test_tensor = torch.tensor(X_test, dtype=torch.float32).to(loaded_kan_model.device)

# Generate predictions with loaded model
with torch.no_grad():
    loaded_predictions = loaded_kan_model.final_model(test_tensor).cpu().numpy().flatten()

# Calculate RMSE
rmse = np.sqrt(np.mean(np.square(loaded_predictions - y_test.flatten())))
print(f"RMSE from loaded model: {rmse:.6f}")
"""

# Note: In production, you can use the saved model as follows:
"""
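import pickle
import torch

# analysis_functions must be on sys.path so that pickle can resolve
# the KANModelTrainTest class when loading.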

# Load the saved model
with open('saved_models/kan_model_component_0.pkl', 'rb') as f:
    model = pickle.load(f)

# Make predictions
def predict(input_features):
    # Convert input to tensor
    tensor_input = torch.tensor(input_features, dtype=torch.float32).to(model.device)
    
    # Generate predictions
    with torch.no_grad():
        predictions = model.final_model(tensor_input).cpu().numpy()
    
    return predictions
"""

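The check in the commented-out cell above computes the root-mean-square error, $\sqrt{\frac{1}{N}\sum_{i=1}^{N}(\hat{y}_i - y_i)^2}$. A dependency-free sketch of the same calculation follows; the numbers are made up for illustration and are not results from this model.

```python
import math

def rmse(predictions, targets):
    """Root-mean-square error between two equal-length, non-empty sequences."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    squared_errors = [(p - t) ** 2 for p, t in zip(predictions, targets)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))

# Illustrative values only: errors are 1, 0 and -2, so the mean squared error is 5/3.
example_rmse = rmse([2.0, 4.0, 6.0], [1.0, 4.0, 8.0])
print(f"RMSE: {example_rmse:.6f}")
```
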
## Model Saving Options

In this notebook, we've demonstrated saving the KAN model using pickle. There are several ways to save machine learning models:

1. **Pickle (used above)**:
   - Pros: Saves the entire object, including all methods and attributes
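   - Cons: Fragile across Python and library versions; loading requires the same dependencies and class definitions to be importable, and unpickling a file from an untrusted source can execute arbitrary code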

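2. **PyTorch's save/load (`torch.save` / `torch.load` on a `state_dict`)**:
   - Pros: More portable across PyTorch versions
   - Cons: Saves only the model parameters, not the surrounding object; the model must be re-created with the same architecture before the parameters can be loaded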

3. **ONNX format**:
   - Pros: Framework-agnostic, allows deployment across platforms
   - Cons: More complex to set up, may not preserve all KAN-specific features

Choose the method that best fits your deployment needs. For research and development within the same environment, pickle is often the simplest option.
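
As a minimal sketch of option 2, the example below saves and restores only the parameters. It uses a small `torch.nn.Linear` as a stand-in for `kan_model.final_model` and an in-memory buffer in place of a file path; whether the KAN model used here restores cleanly from a `state_dict` alone has not been verified in this notebook.

```python
import io
import torch

# Stand-in network; in this notebook the equivalent would be kan_model.final_model.
model = torch.nn.Linear(5, 1)

# Save only the parameters (a BytesIO buffer stands in for a file on disk).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# Loading requires re-creating the same architecture first.
restored = torch.nn.Linear(5, 1)
buffer.seek(0)
restored.load_state_dict(torch.load(buffer))
restored.eval()

# Both networks now hold identical parameters, so their outputs match exactly.
x = torch.ones(2, 5)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(same)  # True
```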