In [9]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from models.KAN import Model
from data_provider.data_factory import data_provider
from exp.exp_main import Exp_Main
import argparse

# Configuration for evaluating the pretrained KAN forecaster on the MRO series.
# This cell only builds `args`; the experiment/checkpoint is loaded in the next cell.
# NOTE(review): these values are expected to match the ones used at training time —
# confirm against the checkpoint that gets loaded.
args = argparse.Namespace(**{
    'is_training': 0,
    'root_path': './dataset/',
    'data_path': 'MRO.csv',
    'model': 'KAN',
    'data': 'custom',
    'features': 'S',
    'target': 'Adj Close',          # column forecast from the CSV (presumably)
    'seq_len': 720,                 # input window length
    'pred_len': 120,                # forecast horizon
    'label_len': 48,
    'enc_in': 1,
    'des': 'Exp',
    'batch_size': 8,
    'learning_rate': 0.0005,
    'checkpoints': './checkpoints/',
    'freq': 'd',
    'embed': 'timeF',
    'use_gpu': False,               # run on CPU
    'gpu': 0,
    'use_multi_gpu': False,
    'num_workers': 0,
    'itr': 1,
    'train_epochs': 10,
    'patience': 3,
    'dropout': 0.05,
    'loss': 'mse',
    'lradj': 'type1',
    'use_amp': False,
    'individual': False,
    'c_out': 1,
    'seed': 421,
})

In [10]:
# Rebuild the experiment, evaluate the trained checkpoint on the test split,
# and grab one test batch for later inspection.
exp = Exp_Main(args)

# Derive the checkpoint setting name from `args` instead of hard-coding the
# numbers, so it cannot silently drift from the config cell. With the current
# args this yields exactly 'MRO_720_120_S_channels_6_seed_421'.
# NOTE(review): 'channels_6' is kept verbatim from the training run's naming;
# it does not follow from enc_in=1 — confirm against the checkpoints directory.
dataset_name = args.data_path.rsplit('.', 1)[0]  # 'MRO.csv' -> 'MRO'
setting = (
    f'{dataset_name}_{args.seq_len}_{args.pred_len}_{args.features}'
    f'_channels_6_seed_{args.seed}'
)
model_name = args.model

# With test=1 this presumably loads the saved weights (the output shows
# "loading model") and prints the test metrics.
exp.test(setting, model_name, test=1)
model = exp.model
model.eval()  # inference mode for the analysis below

# One batch from the test loader; cast to float32 for the model.
test_data, test_loader = data_provider(args, flag='test')
batch_x, batch_y, _, _ = next(iter(test_loader))
batch_x = batch_x.float()
batch_y = batch_y.float()

Use CPU

Model(
  (kans): ModuleList(
    (0): KAN(
      (layers): ModuleList(
        (0-2): 3 x KANLinear(
          (base_activation): SiLU()
        )
      )
    )
  )
)


Initial test data size: 1453
loading model

Prediction Metrics:
--------------------------------------------------------------------------------------------------------------------------------------------
KAN                  MSE: 0.66052         MAE: 0.62938         SE: 1.16774          RRMSE: 37.81%        RMAE: 29.28%     (numpy)
Repeat               MSE: 0.39761         MAE: 0.47471         SE: 0.70651          RRMSE: 29.33%        RMAE: 22.08%     (numpy)

Shapes: pred (1453, 120, 1), true (1453, 120, 1), naive_pred (1453, 120, 1), gt (1453, 840, 1), metrics {'KAN': {'mse': 0.66051555, 'mae': 0.6293815, 'se': 1.1677376, 'relative_rmse': 0.378053, 'relative_mae': 0.29276904, 'mse_torch': 0.6605156064033508, 'mae_torch': 0.629381537437439, 'se_torch': 1.1677377223968506, 'relative_rmse_torch': 0.378053025724

In [11]:
# Analyze the base (linear) weights of every KANLinear layer in the model's KAN.
# The model repr printed above shows `kans` holds a single KAN, so only index 0 is inspected.
kan = model.kans[0]

for i, layer in enumerate(kan.layers):
    print(f"\nAnalyzing KANLinear Layer {i+1}:")

    # .cpu() makes this work even if the model was loaded onto a GPU.
    base_weights = layer.base_weight.detach().cpu().numpy()
    # Laid out as (out_features, in_features), matching how shape[0]/shape[1] are read.
    out_features, in_features = base_weights.shape
    print(f"Input features: {in_features}, Output features: {out_features}")

    print("Base weight stats:")
    print(f"Mean: {base_weights.mean():.4f}")
    print(f"Std: {base_weights.std():.4f}")
    print(f"Min: {base_weights.min():.4f}")
    print(f"Max: {base_weights.max():.4f}")

    fig, (ax_hist, ax_heat) = plt.subplots(1, 2, figsize=(10, 4))

    # Left panel: distribution of weight values.
    ax_hist.hist(base_weights.ravel(), bins=50)
    ax_hist.set(
        title=f'Layer {i+1} Base Weight Distribution',
        xlabel='Weight Value',
        ylabel='Count',
    )

    # Right panel: weight matrix heatmap. Symmetric color limits put zero at the
    # neutral midpoint of the diverging RdBu colormap, so color encodes sign correctly.
    vlim = float(np.abs(base_weights).max())
    im = ax_heat.imshow(base_weights, aspect='auto', cmap='RdBu', vmin=-vlim, vmax=vlim)
    fig.colorbar(im, ax=ax_heat)
    ax_heat.set(
        title=f'Layer {i+1} Base Weight Matrix',
        xlabel='Input Features',
        ylabel='Output Features',
    )

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; avoids accumulating figures across the loop


Analyzing KANLinear Layer 1:
Input features: 720, Output features: 256
Base weight stats:
Mean: -0.0023
Std: 0.0230
Min: -0.0866
Max: 0.0737

Analyzing KANLinear Layer 2:
Input features: 256, Output features: 128
Base weight stats:
Mean: -0.0018
Std: 0.0373
Min: -0.1039
Max: 0.1064

Analyzing KANLinear Layer 3:
Input features: 128, Output features: 120
Base weight stats:
Mean: -0.0099
Std: 0.0534
Min: -0.1441
Max: 0.1249
