In [1]:
import sys
import os

# Resolve the parent of the current working directory
root_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Make that directory importable. The membership guard keeps the cell
# idempotent: re-running it does not pile duplicate entries onto sys.path.
if root_dir not in sys.path:
    sys.path.append(root_dir)

In [2]:
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.lr_scheduler import OneCycleLR
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from TINTOlib.tinto import TINTO
from kan import *
from tqdm import tqdm


import traceback
import time
import gc
import copy
import traceback
import torch.nn as nn
import cv2
import math
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import csv
#from torch.optim import LBFGS


# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)



cuda


In [3]:
SEED = 381
# Seed the Python, PyTorch and NumPy random number generators so runs are repeatable.
# np.random.seed stays last: it returns None, so the cell displays no output.
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

# BEST: ACC = 0.998

In [4]:
# Directory that holds the dataset files
folder = "data/wall-robot-navigation"
# Feature columns are named V1 ... V24
x_col = [f"V{idx}" for idx in range(1, 25)]
# Name of the label column
target_col = ["class"]

# Functions

### Load Dataset and Images

In [5]:
def load_and_clean(npy_filename, y_filename, x_col, target_col, data_dir=None):
    """
    Load a feature array and its label array, drop every row whose features
    contain a NaN, and return both as labelled DataFrames.

    Args:
        npy_filename: File name of the ``.npy`` feature array, relative to the data directory.
        y_filename: File name of the ``.npy`` label array, relative to the data directory.
        x_col: Column names assigned to the feature DataFrame.
        target_col: Column names assigned to the label DataFrame.
        data_dir: Directory holding both files. When None (the default) the
            module-level ``folder`` is used, which keeps existing calls working.

    Returns:
        (df_X, df_y): the cleaned features and the labels filtered with the same row mask.

    Raises:
        ValueError: if the two arrays do not have the same number of rows.
    """
    # Resolve the directory lazily so the global is only needed when no explicit one is given
    base_dir = folder if data_dir is None else data_dir

    # Load numpy arrays
    X = np.load(os.path.join(base_dir, npy_filename))
    y = np.load(os.path.join(base_dir, y_filename))

    # Ensure the number of rows matches between X and y
    if X.shape[0] != y.shape[0]:
        raise ValueError("The number of rows in {} and {} do not match.".format(npy_filename, y_filename))

    # Boolean mask of the rows that do NOT have any NaN value in X
    valid_rows = ~np.isnan(X).any(axis=1)

    # Apply the same mask to both arrays so features and labels stay aligned
    df_X = pd.DataFrame(X[valid_rows])
    df_y = pd.DataFrame(y[valid_rows])
    df_X.columns = x_col
    df_y.columns = target_col
    return df_X, df_y

In [6]:
def load_and_preprocess_data(X_train, y_train, X_test, y_test, X_val, y_val, image_model, problem_type, batch_size=32):
    """
    Build train/val/test DataLoaders that pair scaled tabular features with
    their synthetic images.

    For each split the images are generated under ``{images_folder}/{split}``
    unless that directory already exists (``fit_transform`` for train,
    ``transform`` for val/test), then read with OpenCV and scaled to [0, 1].
    Tabular features are min-max scaled with statistics fitted on the training
    split only.

    Args:
        X_train, X_val, X_test: Feature DataFrames.
        y_train, y_val, y_test: Targets exposing ``.values``.
        image_model: Object providing ``fit_transform``, ``transform`` and ``saveHyperparameters``.
        problem_type: Base name of the ``<problem_type>.csv`` index written next to the images.
        batch_size: Batch size of all three loaders.

    Returns:
        (train_loader, val_loader, test_loader, attributes, imgs_shape) where
        ``attributes`` is the number of tabular columns and ``imgs_shape`` is
        ``(channels, height, width)``.

    Note:
        Relies on a module-level ``images_folder`` string defined outside this function.
    """
    split_names = ("train", "val", "test")
    features = {"train": X_train, "val": X_val, "test": X_test}
    targets = {"train": y_train, "val": y_val, "test": y_test}

    # The image generator receives the features together with the target column
    with_target = {}
    for name in split_names:
        frame = features[name].copy()
        frame["target"] = targets[name].values
        with_target[name] = frame

    # Generate the images when missing, then load the per-split index of image paths
    image_index = {}
    for name in split_names:
        split_dir = f'{images_folder}/{name}'
        if os.path.exists(split_dir):
            print("The images are already generated")
        elif name == "train":
            # Only the training split fits the image model; its hyperparameters are persisted
            image_model.fit_transform(with_target[name], split_dir)
            image_model.saveHyperparameters(f'{images_folder}/model.pkl')
        else:
            image_model.transform(with_target[name], split_dir)

        listing = pd.read_csv(os.path.join(split_dir, f'{problem_type}.csv'))
        listing["images"] = images_folder + f"/{name}/" + listing["images"]
        image_index[name] = listing

    # Image data
    pixels = {}
    for name in split_names:
        pixels[name] = np.array([cv2.imread(path) for path in image_index[name]["images"]])

    # Scale numerical data; the scaler is fitted on the training split only
    scaler = MinMaxScaler()
    scaled = {"train": pd.DataFrame(scaler.fit_transform(X_train), columns=X_train.columns)}
    for name in ("val", "test"):
        scaled[name] = pd.DataFrame(scaler.transform(features[name]), columns=features[name].columns)

    attributes = len(scaled["train"].columns)
    height, width, channels = pixels["train"][0].shape
    imgs_shape = (channels, height, width)

    print("Images shape: ", imgs_shape)
    print("Attributes: ", attributes)

    # Convert data to PyTorch tensors and bundle them per split
    datasets = {}
    for name in split_names:
        numeric = torch.as_tensor(scaled[name].values, dtype=torch.float32)
        # OpenCV yields (N, H, W, C); move channels first and rescale to [0, 1]
        image = torch.as_tensor(pixels[name], dtype=torch.float32).permute(0, 3, 1, 2) / 255.0
        label = torch.as_tensor(targets[name].values, dtype=torch.float32).reshape(-1, 1)
        datasets[name] = TensorDataset(numeric, image, label)

    # Create DataLoaders: only the training loader shuffles
    train_loader = DataLoader(datasets["train"], batch_size=batch_size, shuffle=True, pin_memory=True)
    val_loader = DataLoader(datasets["val"], batch_size=batch_size, shuffle=False, pin_memory=True)
    test_loader = DataLoader(datasets["test"], batch_size=batch_size, shuffle=False, pin_memory=True)

    return train_loader, val_loader, test_loader, attributes, imgs_shape

In [7]:
def complete_coordinate_and_xcol(coordinate, x_col):
    """
    Pad a feature-to-pixel mapping so that every slot of the square grid has a label.

    Args:
        coordinate: Tuple ``(row_coords, col_coords)`` of NumPy arrays giving
            the grid position of each feature.
        x_col: Feature names, in the same order as the coordinates.

    Returns:
        (completed_coordinate, new_x_col): the original coordinates and names,
        followed by one entry per empty slot. Empty slots are visited in
        row-major order and labelled 'Ex1', 'Ex2', ...
    """
    row_coords, col_coords = coordinate

    # The grid is square with side = largest index on either axis + 1.
    # (Previously both axes were bounded by the column maximum only, which
    # skipped rows whenever the largest row index exceeded the largest column index.)
    side = int(max(row_coords.max(), col_coords.max())) + 1

    occupied = set(zip(row_coords, col_coords))
    # Row-major iteration already yields the slots in sorted order
    missing_coords = [(r, c) for r in range(side) for c in range(side) if (r, c) not in occupied]

    # Append the empty slots after the real features
    new_row_coords = list(row_coords) + [r for r, _ in missing_coords]
    new_col_coords = list(col_coords) + [c for _, c in missing_coords]
    new_x_col = list(x_col) + [f"Ex{idx}" for idx in range(1, len(missing_coords) + 1)]

    completed_coordinate = (np.array(new_row_coords), np.array(new_col_coords))
    return completed_coordinate, new_x_col

In [8]:
def plot_feature_mapping(x_col, coordinate, scale=(4,4)):
    """
    Draw a grid showing which feature name is mapped onto which pixel.

    Args:
        x_col: Feature names, ordered like the coordinates.
        coordinate: Tuple ``(rows, cols)`` with the grid position of each feature.
        scale: ``(n_rows, n_cols)`` of the grid.

    Positions that have no matching name in ``x_col`` are shown as "?".
    """
    n_rows, n_cols = scale[0], scale[1]

    # Fill a label matrix; cells without a feature keep the empty string
    labels = np.full(scale, "", dtype=object)
    rows, cols = coordinate
    for position, (r, c) in enumerate(zip(rows, cols)):
        labels[r, c] = x_col[position] if position < len(x_col) else "?"

    plt.figure(figsize=(n_cols * 2, n_rows * 2))
    for i, j in np.ndindex(n_rows, n_cols):
        # x is the column index and y the row index
        plt.text(j, i, labels[i, j], ha='center', va='center', fontsize=10,
                 bbox=dict(facecolor='white', edgecolor='gray'))

    plt.xticks(np.arange(n_cols))
    plt.yticks(np.arange(n_rows))
    plt.grid(True)
    plt.title("Feature → Pixel Mapping")
    plt.gca().invert_yaxis()  # So row 0 is at the top
    plt.show()

In [9]:
def combine_loader(loader):
    """
    Concatenate every batch of a DataLoader along the batch dimension.

    Each batch must unpack into exactly three tensors:
    (mlp_tensor, img_tensor, target_tensor).

    Returns:
        Tuple ``(mlp, img, target)`` holding the concatenated tensors.
    """
    buckets = ([], [], [])
    for batch in loader:
        mlp, img, target = batch
        for bucket, tensor in zip(buckets, (mlp, img, target)):
            bucket.append(tensor)
    return tuple(torch.cat(bucket, dim=0) for bucket in buckets)

### Functions for KAN

In [10]:
# Floating dtype used when casting boolean match masks into accuracies
dtype = torch.get_default_dtype()
# Bounds that rounded predictions are clamped to before being compared with the labels
min_expected, max_expected = 0.0, 1.0

In [11]:
def plot_sorted_feature_importance(columns, importances):
    """
    Print and plot feature importances as a horizontal bar chart, most important first.

    Args:
        columns: Feature names.
        importances: Importance per feature; a torch tensor is moved to the CPU
            and converted to NumPy first.

    Returns:
        List of ``(column, importance)`` pairs sorted by descending importance.
    """
    # Tensors may live on the GPU and carry gradients; bring them to plain NumPy
    if isinstance(importances, torch.Tensor):
        importances = importances.detach().cpu().numpy()

    # Rank the (name, importance) pairs from most to least important
    ranked = list(zip(columns, importances))
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    print(ranked)
    names, values = zip(*ranked)

    # Create the bar plot
    plt.figure(figsize=(4, 3))
    plt.barh(names, values, color='royalblue')
    plt.xlabel('Importance')
    plt.title('KAN Feature Importances')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.show()
    return ranked

In [12]:
def plot_training_ACC(y_true, y_pred, train_loss, val_loss, plot=False):
    """
    Print and return the accuracy of rounded predictions, optionally plotting the training curves.

    Predictions are rounded and clamped to the module-level
    [min_expected, max_expected] range before being compared with ``y_true``.

    Args:
        y_true: Tensor of true labels.
        y_pred: Tensor of raw model outputs.
        train_loss: Sequence of training values to plot.
        val_loss: Sequence of validation values to plot.
        plot: When True, draw both curves on a log-scaled y axis.

    Returns:
        0-d tensor with the fraction of predictions equal to the labels.
    """
    predicted_labels = torch.round(y_pred).clamp(min=min_expected, max=max_expected)
    accuracy = (predicted_labels == y_true).type(dtype).mean()
    print(accuracy)
    if not plot:
        return accuracy

    plt.figure(figsize=(5, 4))
    for curve in (train_loss, val_loss):
        plt.plot(curve)
    plt.legend(['train', 'val'])
    plt.ylabel('Accuracy')
    plt.xlabel('step')
    plt.yscale('log')
    return accuracy

In [13]:
def plot_confusion_matrix(y_true_tensor, y_pred_tensor, title="Confusion Matrix", plot=False):
    """
    Print and return the confusion matrix of rounded predictions, optionally drawing it as a heatmap.

    Predictions are rounded and clamped to the module-level
    [min_expected, max_expected] range, so they are already whole numbers when
    the matrix is computed.

    Args:
        y_true_tensor: Tensor of true labels.
        y_pred_tensor: Tensor of raw model outputs.
        title: Title of the heatmap.
        plot: When True, draw the matrix as an annotated heatmap.

    Returns:
        The confusion matrix, with rows/columns ordered by the sorted union of
        the labels found in the truth and in the predictions.
    """
    # Imported here because the notebook's visible import cell does not provide this name
    from sklearn.metrics import confusion_matrix

    clipped = torch.clamp(torch.round(y_pred_tensor), min=min_expected, max=max_expected)

    # Move tensors to CPU and flatten to 1-D label vectors
    y_true = y_true_tensor.detach().cpu().numpy().flatten()
    y_pred = clipped.detach().cpu().numpy().flatten()

    # Get sorted list of all unique labels
    all_labels = sorted(set(y_true) | set(y_pred))

    # Compute confusion matrix
    cm = confusion_matrix(y_true, y_pred, labels=all_labels)
    print(cm)
    if plot:
        # NOTE(review): `sns` (seaborn) is not imported in the visible import cell —
        # confirm it is imported elsewhere in the notebook before using plot=True.
        plt.figure(figsize=(4, 3))
        sns.heatmap(cm, annot=True, fmt='g', cmap='Blues',
                    xticklabels=all_labels, yticklabels=all_labels)
        plt.xlabel("Predicted")
        plt.ylabel("True")
        plt.title(title)
        plt.tight_layout()
        plt.show()
    return cm

In [14]:
def _kan_accuracy(input_key, label_key):
    """Accuracy of the global ``model`` on one split of the global ``dataset``."""
    # Round the first output column to the nearest class and clamp it into the label range
    predictions = torch.round(model(dataset[input_key])[:, 0])
    predictions = torch.clamp(predictions, min=min_expected, max=max_expected)
    matches = predictions == dataset[label_key][:, 0]
    return torch.mean(matches.type(dtype))

def train_acc_kan():
    """Training accuracy; takes no arguments so it can be passed as a fit metric."""
    return _kan_accuracy('train_input', 'train_label')

def val_acc_kan():
    """Validation accuracy; takes no arguments so it can be passed as a fit metric."""
    return _kan_accuracy('val_input', 'val_label')

In [15]:
def custom_fit(model, dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., lamb_coefdiff=0., update_grid=True, 
               grid_update_num=10, loss_fn=None, lr=1., start_grid_update_step=-1, stop_grid_update_step=50, batch=-1,
               metrics=None, save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', 
               singularity_avoiding=False, y_th=1000., reg_metric='edge_forward_spline_n', display_metrics=None):
    '''
    Train a KAN model and keep a copy of the best weights seen during training.

    Model selection: when display_metrics is given, the LAST displayed metric is
    maximised (e.g. a validation accuracy); otherwise the validation loss is
    minimised.

    Args:
    -----
        model :
            the KAN model to train
        dataset : dic
            contains dataset['train_input'], dataset['train_label'], dataset['val_input'], dataset['val_label']
        opt : str
            "LBFGS" or "Adam"
        steps : int
            training steps
        log : int
            logging frequency
        lamb : float
            overall penalty strength
        lamb_l1 : float
            l1 penalty strength
        lamb_entropy : float
            entropy penalty strength
        lamb_coef : float
            coefficient magnitude penalty strength
        lamb_coefdiff : float
            difference of nearby coefficients (smoothness) penalty strength
        update_grid : bool
            If True, update grid regularly before stop_grid_update_step
        grid_update_num : int
            the number of grid updates before stop_grid_update_step
        start_grid_update_step : int
            no grid updates before this training step
        stop_grid_update_step : int
            no grid updates after this training step
        loss_fn : function
            loss function; defaults to the mean squared error
        lr : float
            learning rate
        batch : int
            batch size, if -1 then full.
        save_fig : bool
            If True, plot the model and save the figure every save_fig_freq steps
        in_vars, out_vars, beta :
            forwarded to model.plot when save_fig is True
        save_fig_freq : int
            save figure every (save_fig_freq) steps
        img_folder : str
            folder the figures are written to (created when missing)
        singularity_avoiding : bool
            indicate whether to avoid singularity for the symbolic part
        y_th : float
            singularity threshold (anything above the threshold is considered singular and is softened in some ways)
        reg_metric : str
            regularization metric. Choose from {'edge_forward_spline_n', 'edge_forward_spline_u', 'edge_forward_sum', 'edge_backward', 'node_backward'}
        metrics : a list of metrics (as functions)
            the metrics to be computed in training; each is called without arguments
        display_metrics : a list of str
            names (keys of results) of the metrics displayed in the tqdm progress bar

    Returns:
    --------
        best_model_state : dict or None
            deep copy of model.state_dict() at the best step (None if no step ever improved)
        results : dic
            results['train_loss'], 1D array of square roots of the training loss (RMSE for the default loss)
            results['val_loss'], 1D array of square roots of the validation loss
            results['reg'], 1D array of regularization
            other metrics specified in metrics
        best_epoch : int
            step at which best_model_state was captured (-1 if none)

    Raises:
    -------
        ValueError : if opt is neither "LBFGS" nor "Adam"
        Exception : if a name in display_metrics is not a key of results
    '''

    if lamb > 0. and not model.save_act:
        print('setting lamb=0. If you want to set lamb > 0, set model.save_act=True')
        
    old_save_act, old_symbolic_enabled = model.disable_symbolic_in_fit(lamb)

    pbar = tqdm(range(steps), desc='description', ncols=100)

    if loss_fn is None:
        loss_fn = loss_fn_eval = lambda x, y: torch.mean((x - y) ** 2)
    else:
        loss_fn_eval = loss_fn

    grid_update_freq = int(stop_grid_update_step / grid_update_num)

    if opt == "Adam":
        optimizer = torch.optim.Adam(model.get_params(), lr=lr)
    elif opt == "LBFGS":
        optimizer = LBFGS(model.get_params(), lr=lr, history_size=10, 
                          line_search_fn="strong_wolfe", 
                          tolerance_grad=1e-32,
                          tolerance_change=1e-32,
                          tolerance_ys=1e-32)
    else:
        # Fail fast instead of silently running a loop that never updates the model
        raise ValueError(f"Unsupported optimizer '{opt}'")

    results = {}
    results['train_loss'] = []
    results['val_loss'] = []
    results['reg'] = []
    if metrics is not None:
        for metric_fn in metrics:
            results[metric_fn.__name__] = []

    # Validate the displayed metric names up front: the set of result keys is fixed from here on
    if display_metrics is not None:
        for metric in display_metrics:
            if metric not in results:
                raise Exception(f'{metric} not recognized')

    if batch == -1 or batch > dataset['train_input'].shape[0]:
        batch_size = dataset['train_input'].shape[0]
        batch_size_val = dataset['val_input'].shape[0]
    else:
        batch_size = batch
        batch_size_val = batch

    # The LBFGS closure publishes the latest loss terms through these module globals
    global train_loss, reg_

    def closure():
        global train_loss, reg_
        optimizer.zero_grad()
        pred = model.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th)
        train_loss = loss_fn(pred, dataset['train_label'][train_id])
        if model.save_act:
            if reg_metric == 'edge_backward':
                model.attribute()
            if reg_metric == 'node_backward':
                model.node_attribute()
            reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
        else:
            reg_ = torch.tensor(0.)
        objective = train_loss + lamb * reg_
        objective.backward()
        return objective

    if save_fig:
        if not os.path.exists(img_folder):
            os.makedirs(img_folder)

    # A displayed metric is treated as "higher is better"; the validation loss as "lower is better"
    maximise = bool(display_metrics)
    best_model_state = None
    best_epoch = -1
    best_metric = 0 if maximise else float('inf')
    for epoch in pbar:
        
        if epoch == steps-1 and old_save_act:
            model.save_act = True
            
        if save_fig and epoch % save_fig_freq == 0:
            save_act = model.save_act
            model.save_act = True
        
        train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
        val_id = np.random.choice(dataset['val_input'].shape[0], batch_size_val, replace=False)

        if epoch % grid_update_freq == 0 and epoch < stop_grid_update_step and update_grid and epoch >= start_grid_update_step:
            model.update_grid(dataset['train_input'][train_id])

        if opt == "LBFGS":
            optimizer.step(closure)

        if opt == "Adam":
            pred = model.forward(dataset['train_input'][train_id], singularity_avoiding=singularity_avoiding, y_th=y_th)
            train_loss = loss_fn(pred, dataset['train_label'][train_id])
            if model.save_act:
                if reg_metric == 'edge_backward':
                    model.attribute()
                if reg_metric == 'node_backward':
                    model.node_attribute()
                reg_ = model.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
            else:
                reg_ = torch.tensor(0.)
            loss = train_loss + lamb * reg_
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        val_loss = loss_fn_eval(model.forward(dataset['val_input'][val_id]), dataset['val_label'][val_id])
        
        if metrics is not None:
            for metric_fn in metrics:
                results[metric_fn.__name__].append(metric_fn().item())
        
        results['train_loss'].append(torch.sqrt(train_loss).cpu().detach().numpy())
        results['val_loss'].append(torch.sqrt(val_loss).cpu().detach().numpy())
        results['reg'].append(reg_.cpu().detach().numpy())

        if epoch % log == 0:
            if display_metrics is None:
                pbar.set_description("| train_loss: %.2e | val_loss: %.2e | reg: %.2e | " % (torch.sqrt(train_loss).cpu().detach().numpy(), torch.sqrt(val_loss).cpu().detach().numpy(), reg_.cpu().detach().numpy()))
            else:
                string = ''
                data = ()
                for metric in display_metrics:
                    string += f' {metric}: %.2e |'
                    data += (results[metric][-1],)
                pbar.set_description(string % data)

        # Model selection is evaluated on every step, independently of the logging frequency
        if maximise:
            val_metric = results[display_metrics[-1]][-1]
            improved = val_metric > best_metric
        else:
            val_metric = val_loss.item()
            improved = val_metric < best_metric

        if improved:
            best_epoch = epoch
            best_metric = val_metric
            best_model_state = copy.deepcopy(model.state_dict())

        if save_fig and epoch % save_fig_freq == 0:
            # Figures are titled and named after the current step
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(epoch), beta=beta)
            plt.savefig(img_folder + '/' + str(epoch) + '.jpg', bbox_inches='tight', dpi=200)
            plt.close()
            model.save_act = save_act

    model.log_history('fit')
    # revert back to original state
    model.symbolic_enabled = old_symbolic_enabled
    metric_label = "Accuracy" if maximise else "loss"
    print(f"✅ Best validation {metric_label}: {best_metric:.4e} at {best_epoch} epoch")
    return best_model_state, results, best_epoch

### Grad CAM Functions

In [16]:
def grad_cam_side_by_side(model, model_state, num_input, img_input, x_col, coordinate,
                          zoom=1, target_index=None, save_path=None, show=True):
    """
    Compute a Grad-CAM heatmap on the first layer of ``model.cnn_branch`` for one
    sample and draw it next to the input image, with feature names overlaid.

    Args:
        model: Model exposing ``cnn_branch``, ``device`` and ``forward(num_input, img_input)``.
        model_state: State dict loaded into ``model`` before the forward pass.
        num_input: Numerical input of one sample; a 1-D tensor gets a batch axis added.
        img_input: Image of one sample; a 3-D tensor gets a batch axis added.
        x_col: Feature names, ordered like ``coordinate``.
        coordinate: Tuple ``(rows, cols)`` with the pixel position of each feature.
        zoom: Number of image pixels per feature cell along each axis.
        target_index: Output column to explain; the whole output when None.
        save_path: When given, the figure is saved there.
        show: Show the figure when True, otherwise close it.

    Returns:
        2-D CPU tensor with the heatmap resized to the image size. Values are
        rescaled to [0, 1] when any activation is positive, otherwise all zero.
    """
    # Imported here because the notebook's visible import cell does not bind `F`
    import torch.nn.functional as F

    model.load_state_dict(model_state)
    model.eval()
    model.zero_grad()

    if num_input.dim() == 1:
        num_input = num_input.unsqueeze(0)
    if img_input.dim() == 3:
        img_input = img_input.unsqueeze(0)

    num_input = num_input.to(model.device)
    img_input = img_input.to(model.device)

    # Store activations and gradients
    activations = {}
    gradients = {}

    def forward_hook(module, input, output):
        activations["value"] = output.detach()

    def backward_hook(module, grad_input, grad_output):
        gradients["value"] = grad_output[0].detach()

    conv_layer = model.cnn_branch[0]
    h_fwd = conv_layer.register_forward_hook(forward_hook)
    h_bwd = conv_layer.register_full_backward_hook(backward_hook)
    try:
        output = model(num_input, img_input)
        target = output if target_index is None else output[:, target_index]
        target.backward()
    finally:
        # Always detach the hooks, even if the forward/backward pass fails,
        # so they do not pile up on the layer across calls
        h_fwd.remove()
        h_bwd.remove()

    act = activations["value"].squeeze(0)
    grad = gradients["value"].squeeze(0)
    # Channel weights are the spatially averaged gradients; the map is the weighted sum of channels
    weights = grad.mean(dim=(1, 2))
    cam = (weights[:, None, None] * act).sum(dim=0)

    cam = torch.relu(cam)
    if cam.max() > 0:
        cam -= cam.min()
        cam /= cam.max()
    else:
        cam[:] = 0.0

    cam_resized = F.interpolate(cam.unsqueeze(0).unsqueeze(0), size=img_input.shape[-2:], mode='bilinear', align_corners=False)
    # Index instead of squeeze() so the heatmap is always 2-D
    heatmap = cam_resized[0, 0].cpu()

    # Drop only the batch axis: a blanket squeeze() would also remove the channel
    # axis of single-channel images and break the grayscale branch below
    img_disp = img_input.squeeze(0).detach().cpu()
    if img_disp.shape[0] == 1:
        img_disp = img_disp[0]
        cmap = 'gray'
    else:
        img_disp = img_disp.permute(1, 2, 0)
        cmap = None

    h, w = img_disp.shape[:2]

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    axs[0].imshow(img_disp, cmap=cmap)
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    axs[1].imshow(img_disp, cmap=cmap)
    heatmap_img = axs[1].imshow(heatmap, cmap='jet', alpha=0.5)
    axs[1].set_title("Grad-CAM")
    axs[1].axis('off')

    # Add colorbar (legend)
    cbar = fig.colorbar(heatmap_img, ax=axs[1], fraction=0.046, pad=0.04)
    cbar.set_label("Grad-CAM Intensity", fontsize=10)

    # Overlay abbreviated feature names at the centre of their cell, on both panels
    for i, col in enumerate(x_col):
        abbrev = col.split("-")[0][:8]
        if i < len(coordinate[0]):
            r, c = coordinate[0][i], coordinate[1][i]
            ry = r * zoom + zoom // 2
            cx = c * zoom + zoom // 2
            if ry < h and cx < w:
                for ax in (axs[1], axs[0]):
                    ax.text(cx, ry, abbrev,
                            color='white', fontsize=9, ha='center', va='center',
                            bbox=dict(facecolor='black', edgecolor='none', pad=1.0, alpha=0.4))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)

    if show:
        plt.show()
    else:
        plt.close()

    return heatmap

In [17]:
def heatmap_to_feature_relevance(heatmap, coordinate, x_col, zoom=1):
    """
    Read, for every feature, the heatmap value at the centre pixel of its cell.

    Args:
        heatmap: 2-D heatmap whose elements support ``.item()``.
        coordinate: Tuple ``(rows, cols)`` with the cell position of each feature.
        x_col: Feature names, ordered like ``coordinate``.
        zoom: Number of heatmap pixels per feature cell along each axis.

    Returns:
        Dictionary {feature_name: relevance_score}. Features without a
        coordinate, or whose centre falls outside the heatmap, are omitted.
    """
    rows, cols = coordinate[0], coordinate[1]
    half_cell = zoom // 2
    max_y, max_x = heatmap.shape[0], heatmap.shape[1]

    feature_scores = {}
    # Only features that actually have a coordinate can be looked up
    for idx in range(min(len(x_col), len(rows))):
        ry = rows[idx] * zoom + half_cell
        cx = cols[idx] * zoom + half_cell
        if ry >= max_y or cx >= max_x:
            continue
        feature_scores[x_col[idx]] = heatmap[ry, cx].item()

    return feature_scores

def plot_feature_relevance_bar(feature_scores):
    """
    Print and plot Grad-CAM feature relevances as a horizontal bar chart, highest first.

    Args:
        feature_scores: Dictionary {feature_name: relevance_score}.

    Returns:
        List of ``(feature, score)`` pairs sorted by descending score.
    """
    ranked = list(feature_scores.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    print(ranked)
    features, scores = zip(*ranked)

    plt.figure(figsize=(6, 3))
    plt.barh(features, scores, color='royalblue')
    plt.xlabel("Grad-CAM Relevance")
    plt.title("Feature Relevance for Test")
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.show()
    return ranked

In [18]:
def compute_avg_feature_relevance_from_val(model, model_state, val_inputs, val_imgs, coordinate, x_col, zoom=1):
    """
    Average the per-feature Grad-CAM relevance over every given sample,
    showing a tqdm progress bar while iterating.

    Args:
        model: Trained model
        model_state: Trained weights to be loaded
        val_inputs: List or tensor of numerical inputs
        val_imgs: List or tensor of image inputs
        coordinate: Coordinate tuple (row array, col array) locating each feature
        x_col: List of feature names (including extras)
        zoom: Zoom level used when generating the images

    Returns:
        Dictionary {feature: mean relevance}; a feature that never received a
        score maps to 0.0.
    """
    per_feature = {feature: [] for feature in x_col}

    samples = zip(val_inputs, val_imgs)
    for num_input, img_input in tqdm(samples, total=len(val_inputs), desc="Computing Grad-CAM"):
        # Figures are not shown here; only the returned heatmap is used
        heatmap = grad_cam_side_by_side(
            model=model,
            model_state=model_state,
            num_input=num_input,
            img_input=img_input,
            coordinate=coordinate,
            x_col=x_col,
            zoom=zoom,
            show=False
        )
        sample_scores = heatmap_to_feature_relevance(heatmap, coordinate, x_col, zoom)
        for feature, value in sample_scores.items():
            per_feature[feature].append(value)

    # Compute average
    avg_scores = {}
    for feature, values in per_feature.items():
        avg_scores[feature] = float(np.mean(values)) if values else 0.0

    return avg_scores


### CNN Functions

In [19]:
def fit_cnn_only_model(model, dataset, steps=100, lr=1.0, loss_fn=None, batch=-1, opt="LBFGS"):
    """
    Train a CNN-only model with LBFGS or AdamW and keep the weights with the
    lowest validation loss.

    Args:
        model: CNN-only PyTorch model exposing ``device`` and called as ``model(0, images)``.
        dataset: Dictionary with keys: 'train_img', 'train_label', 'val_img', 'val_label'.
        steps: Number of training iterations.
        lr: Learning rate.
        loss_fn: Loss function. Defaults to MSE.
        batch: Mini-batch size; -1 (or anything larger than the training set) uses the full set.
        opt: "LBFGS" or "AdamW".

    Returns:
        best_model_state: Deep copy of the state dict at the best validation loss
            (None if the validation loss never dropped below infinity).
        results: Dict with the per-step 'train_loss' and 'val_loss' lists.
        best_epoch: Step at which the best validation loss was reached.

    Raises:
        ValueError: If ``opt`` is not a supported optimizer name.
    """
    device = model.device
    if loss_fn is None:
        loss_fn = nn.MSELoss()

    # Optimizer selection
    if opt == "LBFGS":
        optimizer = LBFGS(model.parameters(), lr=lr, history_size=10, 
                          line_search_fn="strong_wolfe", 
                          tolerance_grad=1e-32, 
                          tolerance_change=1e-32, 
                          tolerance_ys=1e-32)
    elif opt == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=lr)
    else:
        raise ValueError(f"Unsupported optimizer '{opt}'")

    n_train = dataset["train_img"].shape[0]
    batch_size = n_train if batch == -1 or batch > n_train else batch

    # The validation tensors do not change between steps, so move them to the device once
    val_img = dataset["val_img"].to(device)
    val_label = dataset["val_label"].to(device)

    results = {'train_loss': [], 'val_loss': []}
    best_model_state = None
    best_loss = float("inf")
    best_epoch = -1

    pbar = tqdm(range(steps), desc=f"Training CNN Only ({opt})", ncols=100)

    for step in pbar:
        train_idx = np.random.choice(n_train, batch_size, replace=False)
        x_train = dataset["train_img"][train_idx].to(device)
        y_train = dataset["train_label"][train_idx].to(device)
        if opt == "LBFGS":
            def closure():
                optimizer.zero_grad()
                output = model(0, x_train)
                loss = loss_fn(output, y_train)
                loss.backward()
                return loss
            optimizer.step(closure)
            # Report the loss after the update; no gradients are needed for that
            with torch.no_grad():
                train_loss = loss_fn(model(0, x_train), y_train).item()

        else:  # AdamW
            optimizer.zero_grad()
            output = model(0, x_train)
            loss = loss_fn(output, y_train)
            loss.backward()
            optimizer.step()
            train_loss = loss.item()

        # NOTE(review): the model is never switched to eval() for validation; if it
        # contains BatchNorm/Dropout layers, confirm that evaluating in the current mode is intended.
        with torch.no_grad():
            val_output = model(0, val_img)
            val_loss = loss_fn(val_output, val_label).item()

        results["train_loss"].append(train_loss)
        results["val_loss"].append(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            best_epoch = step
            best_model_state = copy.deepcopy(model.state_dict())

        pbar.set_description(f"| Train: {train_loss:.4e} | Val: {val_loss:.4e} |")

    print(f"✅ Best validation loss: {best_loss:.4e} at {best_epoch} epoch")
    return best_model_state, results, best_epoch

In [20]:
def build_custom_cnn_model(cnn_blocks, dense_layers, imgs_shape, device='cuda'):
    """
    Build a CNN with a configurable number of convolutional blocks and dense layers.

    Every block starts with a 3x3 convolution (padding 2) whose channel count
    doubles from block to block, starting at 16. All blocks but the last add
    BatchNorm + ReLU + 2x2 max pooling; the last one adds LayerNorm + Sigmoid +
    Flatten. The dense head halves its width at every layer and ends in a
    single output unit.

    Args:
        cnn_blocks: Number of convolutional blocks.
        dense_layers: Number of linear layers in the head (including the output layer).
        imgs_shape: Input shape as ``(channels, height, width)``.
        device: Device the layers are created on.

    Returns:
        An ``nn.Module`` whose ``forward(num_input, img_input)`` ignores
        ``num_input`` and returns a ``(batch, 1)`` tensor.
    """
    kernel_size, padding, pool_size = 3, 2, 2
    # A stride-1 convolution changes each spatial side by 2 * padding - (kernel_size - 1)
    conv_growth = 2 * padding - (kernel_size - 1)

    class CustomCNNModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.device = device

            cnn_layers = []
            in_channels = imgs_shape[0]
            out_channels = 16
            # Track the feature-map size so the LayerNorm shape follows the input
            # size instead of being hard-coded for one image resolution
            height, width = int(imgs_shape[1]), int(imgs_shape[2])

            for i in range(cnn_blocks):
                cnn_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding))
                height, width = height + conv_growth, width + conv_growth

                if i < cnn_blocks - 1:
                    cnn_layers.append(nn.BatchNorm2d(out_channels))
                    cnn_layers.append(nn.ReLU())
                    cnn_layers.append(nn.MaxPool2d(pool_size))
                    # Max pooling floors the halved size
                    height, width = height // pool_size, width // pool_size
                else:
                    # Last block: LayerNorm + Sigmoid + Flatten
                    cnn_layers.append(nn.LayerNorm([out_channels, height, width]))
                    cnn_layers.append(nn.Sigmoid())
                    cnn_layers.append(nn.Flatten())
                in_channels = out_channels
                out_channels *= 2

            self.cnn_branch = nn.Sequential(*cnn_layers).to(device)
            self.flat_size = self._get_flat_size(imgs_shape)

            # Dense (FC) layers
            fc_layers = []
            input_dim = self.flat_size
            for _ in range(dense_layers - 1):
                fc_layers.append(nn.Linear(int(input_dim), int(input_dim // 2)))
                fc_layers.append(nn.ReLU())
                input_dim = input_dim // 2
            fc_layers.append(nn.Linear(int(input_dim), 1))

            self.fc = nn.Sequential(*fc_layers).to(device)

        def _get_flat_size(self, imgs_shape):
            """Number of features produced by the CNN branch, measured with a dummy input."""
            dummy_input = torch.zeros(1, *imgs_shape, device=self.device)
            x = self.cnn_branch(dummy_input)
            return x.shape[1]

        def forward(self, num_input, img_input):
            # num_input is accepted for interface parity with the other models but is not used
            img_input = img_input.to(self.device)
            features = self.cnn_branch(img_input)
            output = self.fc(features)
            return output

    return CustomCNNModel()

### Write metrics

In [21]:
def create_csv_with_header(filename, columns_opt):
    """
    Create (or overwrite) a CSV file that contains only the metrics header row.

    Args:
        filename: Path of the CSV file to write.
        columns_opt: Name of the experiment-specific column, placed as the
            fourth header entry.
    """
    header = ['kan_neurons', 'kan_grid', 'lamb', columns_opt, 'ACC', 'Conf_Mtx', 'Best_Epoch',
              'KAN_Relevance', 'CNN_Relevance', 'KAN M.R.F.', 'CNN M.R.F.']
    # newline='' lets the csv module control line endings itself
    with open(filename, mode='w', newline='') as file:
        csv.writer(file).writerow(header)

In [22]:
def format_top_3(pairs):
    """Render the first three (name, value) pairs as 'name: value' lines with two decimals."""
    lines = []
    for name, value in pairs[:3]:
        lines.append(f"{name}: {value:.2f}")
    return '\n'.join(lines)

In [23]:
def append_row_to_csv(filename, kan_neurons, kan_grid, lamb, opt_col_val, acc, cm, best_epoch, k_rel, cnn_rel, kan_mrf, cnn_mrf):
    """Append one result row to an existing metrics CSV file.

    The values are written in the column order produced by
    ``create_csv_with_header``. ``kan_mrf`` and ``cnn_mrf`` are passed through
    ``format_top_3`` before being written.

    Raises:
        FileNotFoundError: if ``filename`` does not exist yet; the file is
            never created here so that a header row is always present.
    """
    # Fail fast, before any formatting work, when the target file is missing.
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"{filename} does not exist. Please create the file first with a header.")
    row = [kan_neurons, kan_grid, lamb, opt_col_val, acc, cm, best_epoch, k_rel, cnn_rel,
           format_top_3(kan_mrf), format_top_3(cnn_mrf)]
    with open(filename, mode='a', newline='') as file:
        csv.writer(file).writerow(row)

### Hybrid Functions

In [24]:
def print_mkan_vs_cnn_relevance(feature_scores, mkan_len):
    """Print and return the share of total relevance held by each branch.

    Args:
        feature_scores: 1D relevance scores (torch.Tensor, numpy array or plain
            sequence). The first ``mkan_len`` entries are attributed to the
            m_kan branch, the remaining ones to the CNN branch.
        mkan_len (int): Number of leading entries that belong to m_kan.

    Returns:
        (m_kan_fraction, cnn_fraction) as Python floats. Both are NaN when the
        scores sum to zero.
    """
    if isinstance(feature_scores, torch.Tensor):
        feature_scores = feature_scores.detach().cpu().numpy()
    # Plain lists/tuples have no .sum(); normalise them to an array (no-op for arrays).
    feature_scores = np.asarray(feature_scores)

    mkan_relevance = feature_scores[:mkan_len].sum()
    cnn_relevance = feature_scores[mkan_len:].sum()
    total_relevance = mkan_relevance + cnn_relevance
    m_kan_relevance_perct = float(mkan_relevance / total_relevance)
    cnn_relevance_perct = float(cnn_relevance / total_relevance)
    print(f"M_KAN Relevance: {m_kan_relevance_perct}")
    print(f"CNN Relevance: {cnn_relevance_perct}")
    return m_kan_relevance_perct, cnn_relevance_perct

In [25]:
def plot_mkan_vs_cnn_relevance(feature_scores, mkan_len=6, title="Feature Relevance Split"):
    """
    Plots a pie chart comparing the total feature relevance from m_kan output vs CNN output.

    Args:
        feature_scores (torch.Tensor, numpy array or list): 1D relevance values from final_kan.
            The first ``mkan_len`` entries count as m_kan output, the rest as CNN output.
        mkan_len (int): Number of dimensions from m_kan (default: 6).
        title (str): Title for the pie chart.
    """
    if isinstance(feature_scores, torch.Tensor):
        feature_scores = feature_scores.detach().cpu().numpy()
    # Plain lists/tuples have no .sum(); normalise them to an array (no-op for arrays).
    feature_scores = np.asarray(feature_scores)

    mkan_relevance = feature_scores[:mkan_len].sum()
    cnn_relevance = feature_scores[mkan_len:].sum()

    sizes = [mkan_relevance, cnn_relevance]
    labels = ['m_kan Output', 'CNN Output']
    explode = (0.05, 0)  # Slightly explode m_kan slice for emphasis

    plt.figure(figsize=(3, 3))
    plt.pie(sizes, labels=labels, explode=explode, autopct='%1.1f%%',
            shadow=True, startangle=140, colors=["#66c2a5", "#fc8d62"])
    plt.title(title)
    plt.axis('equal')  # keep the pie circular
    plt.tight_layout()
    plt.show()

In [26]:
def fit_hybrid_dataloaders(model,
                           dataset,
                           opt="AdamW",
                           steps=100,
                           log=1,
                           lamb=0.,
                           lamb_l1=1.,
                           lamb_entropy=2.,
                           lamb_coef=0.,
                           lamb_coefdiff=0.,
                           update_grid=True,
                           grid_update_num=10,
                           loss_fn=None,
                           lr=1.,
                           start_grid_update_step=-1,
                           stop_grid_update_step=50,
                           batch=-1,
                           metrics=None,
                           save_fig=False,
                           in_vars=None,
                           out_vars=None,
                           beta=3,
                           save_fig_freq=1,
                           img_folder='./video',
                           singularity_avoiding=False,
                           y_th=1000.,
                           reg_metric='edge_forward_spline_n',
                           display_metrics=None,
                           sum_f_reg=True):
    """
    Trains a hybrid model (KAN branch ``model.m_kan`` + CNN branch, fused by
    ``model.final_kan``) with a steps-based loop adapted from KAN.fit(),
    including grid updates and regularization.

    Args:
        model: Hybrid module exposing ``m_kan``, ``final_kan``,
            ``get_concat_output(mlp, img)`` and ``forward(mlp, img)``.
        dataset (dict): Indexable tensors under the keys ``train_input``,
            ``train_img``, ``train_label``, ``val_input``, ``val_img`` and
            ``val_label``.
        opt (str): ``"AdamW"`` or ``"LBFGS"``.
        steps (int): Number of optimisation steps.
        lamb: Weight of the regularisation term. Regularisation is only
            computed while ``model.m_kan.save_act`` is truthy.
        lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff: Forwarded to ``get_reg``.
        update_grid, grid_update_num, start_grid_update_step, stop_grid_update_step:
            Control when the grids of both KANs are refreshed from the current batch.
        loss_fn: Callable ``(prediction, target) -> scalar``; defaults to MSE.
        lr: Learning rate.
        batch (int): Mini-batch size; ``-1`` (or anything above the training
            set size) uses the full training set.
        metrics: Optional zero-argument callables returning a tensor; each is
            stored in ``results`` under its ``__name__`` every step.
        display_metrics: Optional list of ``results`` keys shown in the
            progress bar. The LAST one listed is the quantity that is
            maximised to select the best model state.
        reg_metric (str): Regularisation metric name forwarded to ``get_reg``.
        sum_f_reg (bool): Also add the regularisation term of ``final_kan``.
        save_fig, save_fig_freq, img_folder, in_vars, out_vars, beta:
            Snapshot plotting options.
        log, singularity_avoiding, y_th: Accepted but not used by this function.

    Returns:
        (best_model_state, results, best_epoch):
            * ``best_model_state``: state dict at the step with the highest
              selection metric; the final state if no step ever exceeded 0
              (``best_epoch`` is then -1).
            * ``results``: per-step lists ``train_loss`` (sqrt of the
              regularised objective, as shown in the progress bar),
              ``eval_loss`` (sqrt of the validation loss), ``reg`` and one
              list per metric.
            * ``best_epoch``: step index of the best model.
    """
    # `device` below is the notebook-level global device.

    # Warn if regularization is requested but model's internal flag isn't enabled.
    if lamb > 0. and not getattr(model.m_kan, "save_act", False):
        print("setting lamb=0. If you want to set lamb > 0, set model.m_kan.save_act=True")

    # Disable symbolic processing for training if applicable (KAN internal logic).
    # The previous flags are kept so they can be restored after training.
    if hasattr(model.m_kan, "disable_symbolic_in_fit"):
        _, old_symbolic_enabled = model.m_kan.disable_symbolic_in_fit(lamb)
        _, f_old_symbolic_enabled = model.final_kan.disable_symbolic_in_fit(lamb)
    else:
        old_symbolic_enabled = None
        f_old_symbolic_enabled = None

    pbar = tqdm(range(steps), desc='Training', ncols=100)

    # Default loss function (mean squared error) if not provided
    if loss_fn is None:
        loss_fn = lambda x, y: torch.mean((x - y) ** 2)

    # Grid update frequency; clamped to 1 so `step % grid_update_freq` can never divide by zero.
    grid_update_freq = max(1, int(stop_grid_update_step / grid_update_num)) if grid_update_num > 0 else 1

    # Determine total number of training / evaluation examples
    n_train = dataset["train_input"].shape[0]
    n_eval = dataset["val_input"].shape[0]  # using val set for evaluation during training
    batch_size = n_train if batch == -1 or batch > n_train else batch

    # Set up optimizer.
    if opt == "AdamW":
        optimizer = optim.AdamW(model.parameters(), lr=lr)
    elif opt == "LBFGS":
        # NOTE(review): `tolerance_ys` is not a torch.optim.LBFGS argument, so this
        # presumably relies on the LBFGS exported by `from kan import *` -- confirm.
        optimizer = LBFGS(model.parameters(), lr=lr, history_size=10,
                          line_search_fn="strong_wolfe",
                          tolerance_grad=1e-32,
                          tolerance_change=1e-32,
                          tolerance_ys=1e-32)
    else:
        raise ValueError("Optimizer not recognized. Use 'AdamW' or 'LBFGS'.")

    # Prepare results dictionary.
    results = {'train_loss': [], 'eval_loss': [], 'reg': []}

    if metrics is not None:
        for metric in metrics:
            results[metric.__name__] = []

    best_model_state = None
    best_epoch = -1
    best_metric = 0
    val_metric = 0

    def compute_objective(indices):
        """Forward pass on the selected training rows; returns (regularised loss, reg term)."""
        mlp_batch = dataset["train_input"][indices]
        img_batch = dataset["train_img"][indices]
        target_batch = dataset["train_label"][indices]
        outputs = model(mlp_batch, img_batch)
        train_loss = loss_fn(outputs, target_batch)
        # Regularisation is only available while the KAN stores its activations.
        if getattr(model.m_kan, "save_act", False):
            if reg_metric == 'edge_backward':
                model.m_kan.attribute()
                model.final_kan.attribute()
            if reg_metric == 'node_backward':
                model.m_kan.node_attribute()
                model.final_kan.node_attribute()
            reg = model.m_kan.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
            if sum_f_reg:
                reg = reg + model.final_kan.get_reg(reg_metric, lamb_l1, lamb_entropy, lamb_coef, lamb_coefdiff)
        else:
            reg = torch.tensor(0., device=device)
        return train_loss + lamb * reg, reg

    for step in pbar:
        # Randomly sample indices for a mini-batch from the training set.
        train_indices = np.random.choice(n_train, batch_size, replace=False)
        # Use full evaluation set for evaluation; you can also sample if desired.
        eval_indices = np.arange(n_eval)

        # Perform grid update if applicable.
        if (step % grid_update_freq == 0 and step < stop_grid_update_step
                and update_grid and step >= start_grid_update_step):
            mlp_batch = dataset['train_input'][train_indices]
            cnn_batch = dataset['train_img'][train_indices]

            model.m_kan.update_grid(mlp_batch)
            # final_kan sees the fused representation, so its grid is fitted on that.
            concatenated = model.get_concat_output(mlp_batch, cnn_batch)
            model.final_kan.update_grid(concatenated)

        # Perform an optimizer step.
        if opt == "LBFGS":
            # LBFGS may evaluate the closure several times; keep the last evaluation.
            cached_loss = {}

            def closure():
                optimizer.zero_grad()
                objective, reg = compute_objective(train_indices)
                objective.backward()
                cached_loss['loss'] = objective.detach()
                cached_loss['reg'] = reg.detach()
                return objective

            optimizer.step(closure)
            loss_val = cached_loss['loss']
            reg_val = cached_loss['reg']
        else:  # AdamW branch
            optimizer.zero_grad()
            loss_val, reg_val = compute_objective(train_indices)
            loss_val.backward()
            optimizer.step()

        with torch.no_grad():
            mlp_eval = dataset["val_input"][eval_indices]
            img_eval = dataset["val_img"][eval_indices]
            target_eval = dataset["val_label"][eval_indices]
            eval_loss = loss_fn(model(mlp_eval, img_eval), target_eval)

        # Record results (using square-root of loss similar to KAN.fit).
        # 'train_loss' is the same regularised objective shown in the progress bar.
        results['train_loss'].append(torch.sqrt(loss_val.detach()).item())
        results['eval_loss'].append(torch.sqrt(eval_loss.detach()).item())
        results['reg'].append(reg_val.detach().item())

        if metrics is not None:
            for metric in metrics:
                # Here, we assume each metric returns a tensor.
                results[metric.__name__].append(metric().item())

        # Update progress bar.
        if display_metrics is None:
            pbar.set_description("| train_loss: %.2e | eval_loss: %.2e | reg: %.2e |" %
                                 (torch.sqrt(loss_val.detach()).item(),
                                  torch.sqrt(eval_loss.detach()).item(),
                                  reg_val.detach().item()))
        else:
            string = ''
            data = ()
            for metric in display_metrics:
                # Validate the name before indexing so the error is explicit.
                if metric not in results:
                    raise KeyError(f'{metric} not recognized')
                # After the loop, val_metric holds the LAST displayed metric,
                # which is the one used for best-model selection below.
                val_metric = results[metric][-1]
                string += f' {metric}: %.2e |'
                data += (val_metric,)
            pbar.set_description(string % data)

        if val_metric > best_metric:
            best_epoch = step
            best_metric = val_metric
            best_model_state = copy.deepcopy(model.state_dict())

        # Optionally save a figure snapshot.
        # NOTE(review): this calls model.plot(); confirm the hybrid model
        # provides it before enabling save_fig.
        if save_fig and step % save_fig_freq == 0:
            save_act_backup = getattr(model.m_kan, "save_act", False)
            model.m_kan.save_act = True
            model.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title=f"Step {step}", beta=beta)
            plt.savefig(os.path.join(img_folder, f"{step}.jpg"), bbox_inches='tight', dpi=200)
            plt.close()
            model.m_kan.save_act = save_act_backup

    # Restore original settings if applicable (both KANs were switched above).
    if old_symbolic_enabled is not None:
        model.m_kan.symbolic_enabled = old_symbolic_enabled
    if f_old_symbolic_enabled is not None:
        model.final_kan.symbolic_enabled = f_old_symbolic_enabled
    if hasattr(model.m_kan, "log_history"):
        model.m_kan.log_history('fit')

    # No step beat the initial threshold (e.g. display_metrics is None): hand back
    # the final weights instead of None so callers can still load_state_dict().
    if best_model_state is None:
        best_model_state = copy.deepcopy(model.state_dict())

    print(f"✅ Best validation Accuracy: {best_metric:.4e} at {best_epoch} epoch")
    return best_model_state, results, best_epoch

In [27]:
import traceback

def try_create_model(model_class, attributes, imgs_shape, kan_neurons, kan_grid, cnn_bottleneck_dim, alpha, hidden_dim, embed_dim, num_heads):
    """Instantiate ``model_class`` and smoke-test it with one random batch.

    Returns the model when both construction and a forward pass on four
    random samples succeed; otherwise prints the traceback and returns None.
    """
    optional_kwargs = dict(
        cnn_bottleneck_dim=cnn_bottleneck_dim,
        alpha=alpha,
        hidden_dim=hidden_dim,
        embed_dim=embed_dim,
        num_heads=num_heads,
    )
    try:
        candidate = model_class(attributes, imgs_shape, kan_neurons, kan_grid, **optional_kwargs)

        # Forward pass on random data; only its success matters, the output is discarded.
        sample_tabular = torch.randn(4, attributes)
        sample_images = torch.randn(4, *imgs_shape)
        candidate(sample_tabular, sample_images)

        print(f"Successfully created and tested {model_class.__name__}")
        return candidate
    except Exception:
        print(f"Error creating or testing {model_class.__name__}:")
        traceback.print_exc()
        return None

In [28]:
def cnn_branch_relevance(model, best_model_state):
    """Compute average feature relevance on the test split and plot it as a bar chart.

    Reads the notebook-level ``dataset``, ``completed_coordinate`` and
    ``completed_x_col``; returns whatever ``plot_feature_relevance_bar`` returns.
    """
    mean_scores = compute_avg_feature_relevance_from_val(
        model=model,
        model_state=best_model_state,
        val_inputs=dataset["test_input"],
        val_imgs=dataset["test_img"],
        coordinate=completed_coordinate,
        x_col=completed_x_col,
        zoom=2,
    )
    bar_result = plot_feature_relevance_bar(mean_scores)
    return bar_result

In [29]:
def train_and_plot_relevance(model_class, kan_neurons, kan_grid, lamb, steps, cnn_bottleneck_dim=-1, alpha=-1, hidden_dim=-1, embed_dim=-1, num_heads=-1, n_kan_len=None, filename=None, opt_col_val=None):
    """Build, train and evaluate one hybrid model, then log its metrics to CSV.

    Relies on notebook-level globals: ``attributes``, ``imgs_shape``,
    ``dataset``, ``min_expected``, ``max_expected`` and ``dtype``.

    Args:
        model_class: Hybrid model class handed to ``try_create_model``.
        kan_neurons, kan_grid, cnn_bottleneck_dim, alpha, hidden_dim,
        embed_dim, num_heads: Model hyper-parameters.
        lamb: Regularisation weight used during training.
        steps (int): Number of LBFGS training steps.
        n_kan_len: Number of leading ``final_kan`` inputs attributed to the
            KAN branch; falls back to ``kan_neurons`` when not given.
        filename: Existing metrics CSV to append to. When None, no row is written.
        opt_col_val: Value for the experiment-specific CSV column.

    Raises:
        RuntimeError: if the model could not be created or smoke-tested.
    """
    torch.cuda.empty_cache()
    gc.collect()
    model = try_create_model(model_class, attributes, imgs_shape, kan_neurons=kan_neurons, kan_grid=kan_grid,
                             cnn_bottleneck_dim=cnn_bottleneck_dim, alpha=alpha, hidden_dim=hidden_dim, embed_dim=embed_dim, num_heads=num_heads)
    # try_create_model returns None (after printing the traceback) on failure.
    if model is None:
        raise RuntimeError(f"Could not create {model_class.__name__}; see the traceback printed above.")

    def _split_accuracy(split):
        """Accuracy of rounded, clamped predictions on the given dataset split."""
        # Metrics are evaluated every training step; no autograd graph is needed.
        with torch.no_grad():
            rounded = torch.round(model(dataset[f'{split}_input'], dataset[f'{split}_img'])[:, 0])
        clipped = torch.clamp(rounded, min=min_expected, max=max_expected)
        return torch.mean((clipped == dataset[f'{split}_label'][:, 0]).type(dtype))

    # The function names below become the keys of `results`, so keep them stable.
    def train_acc_hybrid():
        return _split_accuracy('train')

    def val_acc_hybrid():
        return _split_accuracy('val')

    # val_acc_hybrid is listed last: the fit routine selects the best model on the last display metric.
    model_state, results, best_epoch = fit_hybrid_dataloaders(
        model, dataset, opt="LBFGS", steps=steps, lamb=lamb,
        metrics=(train_acc_hybrid, val_acc_hybrid), display_metrics=['train_acc_hybrid', 'val_acc_hybrid'])

    model.load_state_dict(model_state)
    acc = plot_training_ACC(dataset['test_label'][:,0], model(dataset['test_input'], dataset['test_img'])[:,0],
                            results['train_acc_hybrid'], results['val_acc_hybrid'])

    cm = plot_confusion_matrix(dataset['test_label'][:,0], model(dataset['test_input'], dataset['test_img'])[:,0], title="Confusion Matrix")

    if not n_kan_len:
        n_kan_len = kan_neurons
    k_rel, cnn_rel = print_mkan_vs_cnn_relevance(model.final_kan.feature_score, mkan_len=n_kan_len)
    # Most-relevant-feature summaries are currently disabled; empty values yield empty CSV cells.
    kan_mrf = ""
    cnn_mrf = ""

    if filename is None:
        print("No filename given; metrics row not written.")
    else:
        append_row_to_csv(filename, kan_neurons, kan_grid, lamb, opt_col_val, acc.item(), cm, best_epoch, k_rel, cnn_rel, kan_mrf, cnn_mrf)

### Hybrid Model Classes

In [30]:
class Model4_1(nn.Module):
    """
    Hybrid model with a CNN bottleneck.

    Image features from the CNN branch are compressed to ``cnn_bottleneck_dim``
    values by a linear layer, concatenated with the ``kan_neurons`` outputs of
    a KAN over the tabular inputs, and mapped to one output by a final KAN.

    ``alpha``, ``hidden_dim``, ``embed_dim`` and ``num_heads`` are accepted but
    not used by this class.
    """
    def __init__(self, attributes, imgs_shape, kan_neurons, kan_grid, cnn_bottleneck_dim=-1, alpha=-1, hidden_dim=-1, embed_dim=-1, num_heads=-1, device=device):
        super(Model4_1, self).__init__()

        # The -1 default is only a placeholder; this architecture needs a real width.
        if cnn_bottleneck_dim <= 0:
            raise ValueError(f"Model4_1 requires a positive cnn_bottleneck_dim, got {cnn_bottleneck_dim}")

        self.device = device

        # LayerNorm must match the conv output size, which depends on the image size.
        ln_h, ln_w = self._conv_output_hw(imgs_shape)

        # CNN branch; sizes below are for an input of shape (C, H, W).
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(imgs_shape[0], 16, kernel_size=3, padding=2),     # out: 16 x (H+2) x (W+2)
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),                                            # out: 16 x (H+2)//2 x (W+2)//2

            nn.Conv2d(16, 32, kernel_size=3, padding=2),                # out: 32 x ln_h x ln_w
            nn.LayerNorm([32, ln_h, ln_w]),
            nn.Sigmoid(),
            nn.Flatten()
        ).to(device)

        # Dummy pass to get flattened size
        self.flat_size = self._get_flat_size(imgs_shape)

        # Bottleneck layer
        self.cnn_bottleneck = nn.Linear(self.flat_size, cnn_bottleneck_dim).to(device)

        # KAN branch
        self.m_kan = KAN(
            width=[attributes, kan_neurons],
            grid=kan_grid,
            k=3,
            seed=SEED,
            device=device
        )

        # Final KAN layer
        self.final_kan = KAN(
            width=[cnn_bottleneck_dim + kan_neurons, 1],
            grid=kan_grid,
            k=3,
            seed=SEED,
            device=device
        )

    @staticmethod
    def _conv_output_hw(imgs_shape):
        """Spatial size (H, W) after conv -> pool -> conv for a (C, H, W) input."""
        _, height, width = imgs_shape
        # Each 3x3 conv with padding 2 grows a side by 2; MaxPool2d(2) floor-halves it.
        return (int(height) + 2) // 2 + 2, (int(width) + 2) // 2 + 2

    def _get_flat_size(self, imgs_shape):
        """Number of features per sample produced by ``cnn_branch``."""
        dummy_input = torch.zeros(1, *imgs_shape, device=self.device)
        x = self.cnn_branch(dummy_input)
        return x.shape[1]

    def get_concat_output(self, mlp_input, cnn_input):
        """Return [KAN output | bottlenecked CNN output] concatenated along dim 1."""
        kan_input = mlp_input.to(self.device)
        cnn_input = cnn_input.to(self.device)

        conv_out = self.cnn_branch(cnn_input)
        cnn_output = self.cnn_bottleneck(conv_out)

        kan_output = self.m_kan(kan_input)

        return torch.cat((kan_output, cnn_output), dim=1)

    def forward(self, mlp_input, cnn_input):
        """Fuse both branches and map the result through ``final_kan``."""
        concat_output = self.get_concat_output(mlp_input, cnn_input)
        return self.final_kan(concat_output)

In [31]:
class Model4_2(nn.Module):
    """
    Hybrid model with a fixed scaling of the CNN features.

    The flattened CNN output is multiplied by the constant ``alpha``,
    concatenated with the ``kan_neurons`` outputs of a KAN over the tabular
    inputs, and mapped to one output by a final KAN.

    ``cnn_bottleneck_dim``, ``hidden_dim``, ``embed_dim`` and ``num_heads`` are
    accepted but not used by this class.
    """
    def __init__(self, attributes, imgs_shape, kan_neurons, kan_grid, cnn_bottleneck_dim=-1, alpha=-1, hidden_dim=-1, embed_dim=-1, num_heads=-1, device=device):
        super(Model4_2, self).__init__()

        # Set first so that _get_flat_size uses the device given to the constructor.
        self.device = device
        self.alpha = alpha

        # LayerNorm must match the conv output size, which depends on the image size.
        ln_h, ln_w = self._conv_output_hw(imgs_shape)

        # CNN branch; sizes below are for an input of shape (C, H, W).
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(imgs_shape[0], 16, kernel_size=3, padding=2),     # out: 16 x (H+2) x (W+2)
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),                                            # out: 16 x (H+2)//2 x (W+2)//2

            nn.Conv2d(16, 32, kernel_size=3, padding=2),                # out: 32 x ln_h x ln_w
            nn.LayerNorm([32, ln_h, ln_w]),
            nn.Sigmoid(),
            nn.Flatten()
        ).to(device)

        # KAN branch over the tabular inputs
        self.m_kan = KAN(
            width=[attributes, kan_neurons],
            grid=kan_grid,
            k=3,
            seed=SEED,
            device=device
        )

        # Calculate the size of the flattened output
        self.flat_size = self._get_flat_size(imgs_shape)

        # Final KAN layer over the concatenated features
        self.final_kan = KAN(
            width=[self.flat_size + kan_neurons, 1],
            grid=kan_grid,
            k=3,
            seed=SEED,
            device=device
        )

    @staticmethod
    def _conv_output_hw(imgs_shape):
        """Spatial size (H, W) after conv -> pool -> conv for a (C, H, W) input."""
        _, height, width = imgs_shape
        # Each 3x3 conv with padding 2 grows a side by 2; MaxPool2d(2) floor-halves it.
        return (int(height) + 2) // 2 + 2, (int(width) + 2) // 2 + 2

    def _get_flat_size(self, imgs_shape):
        """Number of features per sample produced by ``cnn_branch``."""
        # Forward pass with dummy input to calculate flat size
        dummy_input = torch.zeros(4, *imgs_shape, device=self.device)
        x = self.cnn_branch(dummy_input)
        return x.size(1)

    def get_concat_output(self, mlp_input, cnn_input):
        """Return [KAN output | alpha * CNN output] concatenated along dim 1."""
        # Ensure inputs are moved to the correct device
        kan_input = mlp_input.to(self.device)
        cnn_input = cnn_input.to(self.device)

        cnn_output = self.cnn_branch(cnn_input)  # Process image input
        cnn_output = cnn_output * self.alpha
        kan_output = self.m_kan(kan_input)  # Process numerical input

        return torch.cat((kan_output, cnn_output), dim=1)

    def forward(self, mlp_input, cnn_input):
        """Fuse both branches and map the result through ``final_kan``."""
        concat_output = self.get_concat_output(mlp_input, cnn_input)
        return self.final_kan(concat_output)

In [32]:
class Model4_3(nn.Module):
    """
    Hybrid model with a learned gate between the two branches.

    A small MLP reads the concatenated KAN and CNN features and emits one gate
    value per sample in [0, 1]; the KAN features are scaled by ``1 - gate`` and
    the CNN features by ``gate`` before a final KAN maps them to one output.

    ``cnn_bottleneck_dim``, ``alpha``, ``embed_dim`` and ``num_heads`` are
    accepted but not used by this class.
    """
    def __init__(self, attributes, imgs_shape, kan_neurons, kan_grid, cnn_bottleneck_dim=-1, alpha=-1, hidden_dim=-1, embed_dim=-1, num_heads=-1, device=device):
        super(Model4_3, self).__init__()

        # The -1 default is only a placeholder; the gating MLP needs a real width.
        if hidden_dim <= 0:
            raise ValueError(f"Model4_3 requires a positive hidden_dim, got {hidden_dim}")

        self.device = device

        # LayerNorm must match the conv output size, which depends on the image size.
        ln_h, ln_w = self._conv_output_hw(imgs_shape)

        # CNN branch; sizes below are for an input of shape (C, H, W).
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(imgs_shape[0], 16, kernel_size=3, padding=2),     # out: 16 x (H+2) x (W+2)
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),                                            # out: 16 x (H+2)//2 x (W+2)//2

            nn.Conv2d(16, 32, kernel_size=3, padding=2),                # out: 32 x ln_h x ln_w
            nn.LayerNorm([32, ln_h, ln_w]),
            nn.Sigmoid(),
            nn.Flatten()
        ).to(device)

        self.flat_size = self._get_flat_size(imgs_shape)

        # KAN branch
        self.m_kan = KAN(
            width=[attributes, kan_neurons],
            grid=kan_grid,
            k=3,
            seed=SEED,
            device=device
        )

        # Gating MLP: inputs are concatenated CNN + KAN representations
        self.gate_net = nn.Sequential(
            nn.Linear(self.flat_size + kan_neurons, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output ∈ [0,1]
        ).to(device)

        # Final regressor (KAN layer)
        self.final_kan = KAN(
            width=[kan_neurons + self.flat_size, 1],
            grid=kan_grid,
            k=3,
            seed=SEED,
            device=device
        )

    @staticmethod
    def _conv_output_hw(imgs_shape):
        """Spatial size (H, W) after conv -> pool -> conv for a (C, H, W) input."""
        _, height, width = imgs_shape
        # Each 3x3 conv with padding 2 grows a side by 2; MaxPool2d(2) floor-halves it.
        return (int(height) + 2) // 2 + 2, (int(width) + 2) // 2 + 2

    def _get_flat_size(self, imgs_shape):
        """Number of features per sample produced by ``cnn_branch``."""
        dummy_input = torch.zeros(4, *imgs_shape, device=self.device)
        x = self.cnn_branch(dummy_input)
        return x.size(1)

    def get_concat_output(self, mlp_input, cnn_input):
        """Return the gated [KAN | CNN] features concatenated along dim 1."""
        mlp_input = mlp_input.to(self.device)
        cnn_input = cnn_input.to(self.device)

        kan_out = self.m_kan(mlp_input)                  # shape: (B, kan_neurons)
        cnn_out = self.cnn_branch(cnn_input)             # shape: (B, cnn_flat)

        concat = torch.cat((kan_out, cnn_out), dim=1)    # For gating
        alpha = self.gate_net(concat)                    # shape: (B, 1)

        gated_kan = (1 - alpha) * kan_out                # shape: (B, kan_neurons)
        gated_cnn = alpha * cnn_out                      # shape: (B, cnn_flat)

        return torch.cat((gated_kan, gated_cnn), dim=1)  # shape: (B, total)

    def forward(self, mlp_input, cnn_input):
        """Gate and fuse both branches, then map the result through ``final_kan``."""
        fused = self.get_concat_output(mlp_input, cnn_input)
        return self.final_kan(fused)


In [33]:
class Model4_4(nn.Module):
    """
    Hybrid model that fuses the branches with cross-attention.

    The KAN output is projected to a query and the flattened CNN output to a
    key and a value (all of width ``embed_dim``); the attention output is
    mapped to one value by a final KAN.

    ``cnn_bottleneck_dim``, ``alpha`` and ``hidden_dim`` are accepted but not
    used by this class.
    """
    def __init__(self, attributes, imgs_shape, kan_neurons, kan_grid, cnn_bottleneck_dim=-1, alpha=-1, hidden_dim=-1, embed_dim=-1, num_heads=-1, device=device):
        super(Model4_4, self).__init__()

        # The -1 defaults are only placeholders; attention needs real sizes.
        if embed_dim <= 0 or num_heads <= 0:
            raise ValueError(
                f"Model4_4 requires positive embed_dim and num_heads, got embed_dim={embed_dim}, num_heads={num_heads}")

        self.device = device

        # LayerNorm must match the conv output size, which depends on the image size.
        ln_h, ln_w = self._conv_output_hw(imgs_shape)

        # CNN branch; sizes below are for an input of shape (C, H, W).
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(imgs_shape[0], 16, kernel_size=3, padding=2),     # out: 16 x (H+2) x (W+2)
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),                                            # out: 16 x (H+2)//2 x (W+2)//2

            nn.Conv2d(16, 32, kernel_size=3, padding=2),                # out: 32 x ln_h x ln_w
            nn.LayerNorm([32, ln_h, ln_w]),
            nn.Sigmoid(),
            nn.Flatten()
        ).to(device)

        self.flat_size = self._get_flat_size(imgs_shape)

        # KAN Branch
        self.m_kan = KAN(
            width=[attributes, kan_neurons],
            grid=kan_grid,
            k=3,
            seed=SEED,
            device=device
        )

        # Linear projections for Q, K, V
        self.query_proj = nn.Linear(kan_neurons, embed_dim).to(device)
        self.key_proj = nn.Linear(self.flat_size, embed_dim).to(device)
        self.value_proj = nn.Linear(self.flat_size, embed_dim).to(device)

        # Attention module
        self.attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True).to(device)

        # Final regression layer (KAN again)
        self.final_kan = KAN(
            width=[embed_dim, 1], grid=kan_grid, k=3, seed=SEED, device=device
        )

    @staticmethod
    def _conv_output_hw(imgs_shape):
        """Spatial size (H, W) after conv -> pool -> conv for a (C, H, W) input."""
        _, height, width = imgs_shape
        # Each 3x3 conv with padding 2 grows a side by 2; MaxPool2d(2) floor-halves it.
        return (int(height) + 2) // 2 + 2, (int(width) + 2) // 2 + 2

    def get_concat_output(self, mlp_input, cnn_input):
        """Return the cross-attention output (KAN as query, CNN as key/value), shape [B, E]."""
        # Get KAN and CNN outputs
        kan_out = self.m_kan(mlp_input.to(self.device))  # [B, D_kan]
        cnn_out = self.cnn_branch(cnn_input.to(self.device))  # [B, D_cnn]

        # Project into Q, K, V space
        Q = self.query_proj(kan_out).unsqueeze(1)  # [B, 1, E]
        K = self.key_proj(cnn_out).unsqueeze(1)    # [B, 1, E]
        V = self.value_proj(cnn_out).unsqueeze(1)  # [B, 1, E]
        # Cross-attention: KAN attends to CNN
        # NOTE(review): K and V hold a single token each, so the softmax over keys
        # is always 1 and attn_out appears not to depend on Q (i.e. on the KAN
        # branch) at all -- confirm this is intended.
        attn_out, _ = self.attn(Q, K, V)  # [B, 1, E]
        attn_out = attn_out.squeeze(1)   # [B, E]

        return attn_out

    def _get_flat_size(self, imgs_shape):
        """Number of features per sample produced by ``cnn_branch``."""
        dummy_input = torch.zeros(1, *imgs_shape, device=self.device)
        return self.cnn_branch(dummy_input).shape[1]

    def forward(self, mlp_input, cnn_input):
        """Fuse both branches with attention and map the result through ``final_kan``."""
        attn_out = self.get_concat_output(mlp_input, cnn_input)

        return self.final_kan(attn_out)

# Load Dataset and Images

In [34]:
# Load the train / test / validation feature and label arrays from `folder`;
# load_and_clean drops feature rows containing NaNs together with their labels.
X_train, y_train = load_and_clean('N_train.npy', 'y_train.npy',x_col, target_col)
X_test, y_test   = load_and_clean('N_test.npy',  'y_test.npy', x_col, target_col)
X_val, y_val     = load_and_clean('N_val.npy',   'y_val.npy', x_col, target_col)

In [35]:
# Number of columns in the training features
num_columns = X_train.shape[1]

# Number of columns minus one
# NOTE(review): not referenced in the visible part of the notebook -- confirm it is still needed.
columns_minus_one = num_columns - 1

# Side length of the smallest square grid with at least num_columns cells
image_size = math.ceil(math.sqrt(num_columns))
print(image_size)

4


In [36]:
dataset_name = 'wall-robot-navigation'
# Select the image-generation model (TINTO) and its parameters
problem_type = "supervised"
pixel=20
image_model = TINTO(problem=problem_type, blur=False, pixels=pixel, random_seed=SEED)
name = f"TINTO"

# Folder where the generated images are saved
# NOTE(review): the path says "Regression" although this notebook reports accuracy
# and a confusion matrix -- confirm the folder name is intended.
images_folder = f"HyNNImages/Regression/{dataset_name}/images_{dataset_name}_{name}"

In [37]:
# Build the train/val/test DataLoaders (batch size 16) from the tabular splits
# and the image model.
# Mind the ordering: the splits are passed as train, test, val, while the loaders
# come back as train, val, test.
# NOTE(review): presumably this also generates or reuses the TINTO images (the
# saved output prints "The images are already generated") -- confirm in
# load_and_preprocess_data.
train_loader, val_loader, test_loader, attributes, imgs_shape = load_and_preprocess_data(
    X_train, y_train, X_test, y_test, X_val, y_val,
    image_model=image_model,
    problem_type=problem_type,
    batch_size=16
)

The images are already generated
The images are already generated
The images are already generated
Images shape:  (3, 8, 8)
Attributes:  10


In [39]:
# def get_feature_coordinates_from_model(model, with_names=False):
#     """
#     Extracts the (row, col) positions of features from a fitted REFINED model.
#     Parameters
#     ----------
#     model : REFINED object
#         The fitted REFINED image_model.
#     with_names : bool, optional
#         If True, include feature names in the output.
#     Returns
#     -------
#     feature_to_position : dict
#         If with_names=False:
#             {feature_idx: (row, col)}
#         If with_names=True:
#             {feature_idx: {'name': str, 'position': (row, col)}}
#     """
#     if not hasattr(model, 'map_in_int_MDS') or not hasattr(model, 'gene_names_MDS'):
#         raise RuntimeError("The REFINED algorithm has not been fitted yet. Please call `fit()` first.")
#     feature_to_position = {}
#     grid = model.map_in_int_MDS
#     for row in range(grid.shape[0]):
#         for col in range(grid.shape[1]):
#             feat_idx = grid[row, col]
#             if feat_idx != -1:
#                 if with_names:
#                     feature_to_position[int(feat_idx)] = {
#                         "name": model.gene_names_MDS[int(feat_idx)],
#                         "position": (row, col)
#                     }
#                 else:
#                     feature_to_position[int(feat_idx)] = (row, col)
#     return feature_to_position

# coords = get_feature_coordinates_from_model(image_model, with_names=True)
# for idx, info in coords.items():
#     print(f"Feature {idx} ({info['name']}) → Position {info['position']}")

In [40]:
# x_col

Feature 0 (0) → Position (1, 2)
Feature 1 (1) → Position (2, 3)
Feature 2 (2) → Position (3, 2)
Feature 3 (3) → Position (2, 2)
Feature 4 (4) → Position (1, 0)
Feature 5 (5) → Position (0, 3)
Feature 6 (6) → Position (1, 3)
Feature 7 (7) → Position (3, 3)
Feature 8 (8) → Position (3, 1)
Feature 9 (9) → Position (0, 0)


In [42]:
# Combine dataloaders into tensors.
# Combine each DataLoader into three tensors: tabular input, image, target.
train_mlp, train_img, train_target = combine_loader(train_loader)
val_mlp, val_img, val_target = combine_loader(val_loader)
test_mlp, test_img, test_target = combine_loader(test_loader)

# Move every tensor to `device`, keyed as "<split>_input", "<split>_img"
# and "<split>_label" for the train, val and test splits (in that order).
dataset = {
    f"{split}_{part}": tensor.to(device)
    for split, tensors in (
        ("train", (train_mlp, train_img, train_target)),
        ("val", (val_mlp, val_img, val_target)),
        ("test", (test_mlp, test_img, test_target)),
    )
    for part, tensor in zip(("input", "img", "label"), tensors)
}

In [43]:
# Print the shapes of the tensors
# Report the tabular-input and target tensor shapes of every split.
for split_title, split_key in (("Train", "train"), ("Test", "test"), ("Validation", "val")):
    print(f"{split_title} data shape:", dataset[f"{split_key}_input"].shape)
    print(f"{split_title} target shape:", dataset[f"{split_key}_label"].shape)

Train data shape: torch.Size([1173, 10])
Train target shape: torch.Size([1173, 1])
Test data shape: torch.Size([367, 10])
Test target shape: torch.Size([367, 1])
Validation data shape: torch.Size([294, 10])
Validation target shape: torch.Size([294, 1])


# Set File Names

In [45]:
# One results CSV per fusion option (Op1 .. Op4).
filename_1, filename_2, filename_3, filename_4 = (
    f'{dataset_name}_tinto_Concat_Op{option}.csv' for option in range(1, 5)
)

In [46]:
# Name(s) of the option-specific column(s) for each results CSV.
columns_opt1, columns_opt2, columns_opt3, columns_opt4 = (
    'cnn_bottleneck_dim',
    'alpha',
    'hidden_dim',
    'embed_dim, num_heads',
)

# Option 1: Concatenate KAN with CNN (a dense layer reduces the CNN output size)

In [None]:
# Start the Option 1 results file with its option-specific column ('cnn_bottleneck_dim').
# NOTE(review): presumably this (re)writes the CSV header -- confirm it does not
# wipe existing results before re-running.
create_csv_with_header(filename_1, columns_opt1)

In [48]:
# Option 1 sweep: for each KAN configuration, train Model4_1 once per CNN
# bottleneck size and log the run to filename_1.
# Each entry pairs one KAN configuration with the bottleneck sizes tried for it.
option1_sweep = [
    (dict(kan_neurons=12, kan_grid=8, lamb=0.001), [1, 3, 6, 9, 12, 15, 18, 21, 24, 27]),
    (dict(kan_neurons=6, kan_grid=7, lamb=1e-05), [1, 2, 3, 6, 9, 12, 15]),
    (dict(kan_neurons=8, kan_grid=8, lamb=0.001), [1, 2, 3, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24]),
    (dict(kan_neurons=3, kan_grid=7, lamb=0.001), [1, 2, 3, 4, 6, 8, 9]),
]

# The banner is rendered from the configuration itself, so the printed values
# can never drift from the hyper-parameters actually passed to the training call.
banner_template = "------------------------------ kan_neurons={kan_neurons}, kan_grid={kan_grid}, lamb={lamb} ------------------------------"

for kan_cfg, bottleneck_dims in option1_sweep:
    print(banner_template.format(**kan_cfg))
    for cnn_bottleneck_dim in bottleneck_dims:
        print(f"cnn_bottleneck_dim: {cnn_bottleneck_dim}")
        train_and_plot_relevance(Model4_1, steps=100, cnn_bottleneck_dim=cnn_bottleneck_dim,
                                 filename=filename_1, opt_col_val=cnn_bottleneck_dim, **kan_cfg)

------------------------------ kan_neurons=3, kan_grid=7, lamb=1e-06 ------------------------------
cnn_bottleneck_dim: 1
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.89e-01 | val_acc_hybrid: 6.84e-01 |: 100%|█████| 50/50 [00:16<00:00,  3.03it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 2 epoch
tensor(0.7139, device='cuda:0')
[[126  57]
 [ 48 136]]
M_KAN Relevance: 0.2593177855014801
CNN Relevance: 0.7406821846961975
cnn_bottleneck_dim: 2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.90e-01 | val_acc_hybrid: 6.77e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.23it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 8 epoch
tensor(0.7248, device='cuda:0')
[[132  51]
 [ 50 134]]
M_KAN Relevance: 0.29896867275238037
CNN Relevance: 0.7010313272476196
cnn_bottleneck_dim: 3
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.86e-01 | val_acc_hybrid: 6.84e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.24it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 1 epoch
tensor(0.7084, device='cuda:0')
[[125  58]
 [ 49 135]]
M_KAN Relevance: 0.2637682557106018
CNN Relevance: 0.7362317442893982
cnn_bottleneck_dim: 6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.41e-01 | val_acc_hybrid: 6.70e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.20it/s]


saving model version 0.1
✅ Best validation Accuracy: 6.9048e-01 at 40 epoch
tensor(0.6894, device='cuda:0')
[[120  63]
 [ 51 133]]
M_KAN Relevance: 0.30125123262405396
CNN Relevance: 0.698748767375946
cnn_bottleneck_dim: 7
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.94e-01 | val_acc_hybrid: 6.84e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.23it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 2 epoch
tensor(0.7302, device='cuda:0')
[[131  52]
 [ 47 137]]
M_KAN Relevance: 0.24447648227214813
CNN Relevance: 0.7555235028266907
cnn_bottleneck_dim: 9
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.95e-01 | val_acc_hybrid: 7.24e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.26it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.4830e-01 at 36 epoch
tensor(0.6839, device='cuda:0')
[[131  52]
 [ 64 120]]
M_KAN Relevance: 0.2285248041152954
CNN Relevance: 0.7714751958847046
cnn_bottleneck_dim: 10
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.82e-01 | val_acc_hybrid: 6.87e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.25it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2789e-01 at 5 epoch
tensor(0.7057, device='cuda:0')
[[123  60]
 [ 48 136]]
M_KAN Relevance: 0.15092574059963226
CNN Relevance: 0.8490742444992065
cnn_bottleneck_dim: 12
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.76e-01 | val_acc_hybrid: 6.87e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.26it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3810e-01 at 7 epoch
tensor(0.6975, device='cuda:0')
[[127  56]
 [ 55 129]]
M_KAN Relevance: 0.229435995221138
CNN Relevance: 0.7705640196800232
------------------------------ kan_neurons=5, kan_grid=7, lamb=0.0001 ------------------------------
cnn_bottleneck_dim: 1
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 8.06e-01 | val_acc_hybrid: 6.53e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.30it/s]


saving model version 0.1
✅ Best validation Accuracy: 6.8367e-01 at 5 epoch
tensor(0.7112, device='cuda:0')
[[136  47]
 [ 59 125]]
M_KAN Relevance: 0.9483941793441772
CNN Relevance: 0.05160585790872574
cnn_bottleneck_dim: 2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 8.00e-01 | val_acc_hybrid: 7.01e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.26it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.4490e-01 at 12 epoch
tensor(0.7057, device='cuda:0')
[[133  50]
 [ 58 126]]
M_KAN Relevance: 0.31496790051460266
CNN Relevance: 0.6850321292877197
cnn_bottleneck_dim: 4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.78e-01 | val_acc_hybrid: 6.84e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.26it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 6 epoch
tensor(0.7248, device='cuda:0')
[[132  51]
 [ 50 134]]
M_KAN Relevance: 0.20957835018634796
CNN Relevance: 0.7904216647148132
cnn_bottleneck_dim: 6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 7.84e-01 | val_acc_hybrid: 6.94e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.19it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 2 epoch
tensor(0.7302, device='cuda:0')
[[130  53]
 [ 46 138]]
M_KAN Relevance: 0.31841519474983215
CNN Relevance: 0.6815847754478455
cnn_bottleneck_dim: 8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_1


 train_acc_hybrid: 8.19e-01 | val_acc_hybrid: 6.60e-01 |: 100%|█████| 50/50 [00:15<00:00,  3.31it/s]

saving model version 0.1
✅ Best validation Accuracy: 7.1429e-01 at 13 epoch
tensor(0.7084, device='cuda:0')
[[129  54]
 [ 53 131]]
M_KAN Relevance: 0.24121418595314026
CNN Relevance: 0.7587857842445374





# Option 2: Multiply CNN output by factor

In [None]:
# Start the Option 2 results file with its option-specific column ('alpha').
# NOTE(review): presumably this (re)writes the CSV header -- confirm it does not
# wipe existing results before re-running.
create_csv_with_header(filename_2, columns_opt2)

In [49]:
# Option 2 sweep: for each KAN configuration, train Model4_2 once per alpha
# value and log the run to filename_2. Every configuration uses the same alphas.
alpha_values = [.9, .8, .75, .7, .6, .5, .4, .3, .25, .2, .1, .05, .01]
option2_kan_configs = [
    dict(kan_neurons=12, kan_grid=8, lamb=0.001),
    dict(kan_neurons=6, kan_grid=7, lamb=1e-05),
    dict(kan_neurons=8, kan_grid=8, lamb=0.001),
    dict(kan_neurons=3, kan_grid=7, lamb=0.001),
]

# The banner is rendered from the configuration itself, so the printed values
# can never drift from the hyper-parameters actually passed to the training call.
banner_template = "------------------------------ kan_neurons={kan_neurons}, kan_grid={kan_grid}, lamb={lamb} ------------------------------"

for kan_cfg in option2_kan_configs:
    print(banner_template.format(**kan_cfg))
    for alpha in alpha_values:
        print(f"alpha: {alpha}")
        train_and_plot_relevance(Model4_2, steps=120, alpha=alpha,
                                 filename=filename_2, opt_col_val=alpha, **kan_cfg)

------------------------------ kan_neurons=3, kan_grid=7, lamb=1e-06 ------------------------------
alpha: 0.9
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.70e-01 | val_acc_hybrid: 6.22e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.84it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3469e-01 at 1 epoch
tensor(0.7166, device='cuda:0')
[[126  57]
 [ 47 137]]
M_KAN Relevance: 0.0019358622375875711
CNN Relevance: 0.9980641603469849
alpha: 0.8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.60e-01 | val_acc_hybrid: 6.36e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.82it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 1 epoch
tensor(0.7248, device='cuda:0')
[[124  59]
 [ 42 142]]
M_KAN Relevance: 0.0020198177080601454
CNN Relevance: 0.9979802370071411
alpha: 0.75
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.56e-01 | val_acc_hybrid: 6.19e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.82it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2789e-01 at 1 epoch
tensor(0.7166, device='cuda:0')
[[123  60]
 [ 44 140]]
M_KAN Relevance: 0.002253951271995902
CNN Relevance: 0.9977460503578186
alpha: 0.7
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.66e-01 | val_acc_hybrid: 6.43e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.84it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2789e-01 at 2 epoch
tensor(0.7384, device='cuda:0')
[[137  46]
 [ 50 134]]
M_KAN Relevance: 0.0033220998011529446
CNN Relevance: 0.9966778755187988
alpha: 0.6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.76e-01 | val_acc_hybrid: 5.65e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.83it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2789e-01 at 4 epoch
tensor(0.7330, device='cuda:0')
[[134  49]
 [ 49 135]]
M_KAN Relevance: 0.0034475724678486586
CNN Relevance: 0.9965524673461914
alpha: 0.5
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.58e-01 | val_acc_hybrid: 6.29e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.81it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 27 epoch
tensor(0.6894, device='cuda:0')
[[126  57]
 [ 57 127]]
M_KAN Relevance: 0.0015984050696715713
CNN Relevance: 0.9984015822410583
alpha: 0.4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 9.20e-01 | val_acc_hybrid: 6.02e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.87it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1769e-01 at 7 epoch
tensor(0.7248, device='cuda:0')
[[140  43]
 [ 58 126]]
M_KAN Relevance: 0.0027699486818164587
CNN Relevance: 0.997230052947998
alpha: 0.3
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.66e-01 | val_acc_hybrid: 5.78e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.82it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1429e-01 at 5 epoch
tensor(0.7084, device='cuda:0')
[[133  50]
 [ 57 127]]
M_KAN Relevance: 0.0034978354815393686
CNN Relevance: 0.996502161026001
alpha: 0.25
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.82e-01 | val_acc_hybrid: 5.75e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.84it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1429e-01 at 3 epoch
tensor(0.7166, device='cuda:0')
[[132  51]
 [ 53 131]]
M_KAN Relevance: 0.0038090243469923735
CNN Relevance: 0.9961909651756287
alpha: 0.2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.96e-01 | val_acc_hybrid: 5.99e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.84it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 7 epoch
tensor(0.7139, device='cuda:0')
[[136  47]
 [ 58 126]]
M_KAN Relevance: 0.004009689204394817
CNN Relevance: 0.9959903359413147
alpha: 0.1
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 9.15e-01 | val_acc_hybrid: 5.61e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.86it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.0748e-01 at 7 epoch
tensor(0.7193, device='cuda:0')
[[133  50]
 [ 53 131]]
M_KAN Relevance: 0.004993771202862263
CNN Relevance: 0.9950062036514282
alpha: 0.05
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 9.01e-01 | val_acc_hybrid: 5.78e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.81it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.0068e-01 at 7 epoch
tensor(0.7248, device='cuda:0')
[[136  47]
 [ 54 130]]
M_KAN Relevance: 0.0055891661904752254
CNN Relevance: 0.9944108128547668
alpha: 0.01
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.73e-01 | val_acc_hybrid: 6.22e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.79it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2109e-01 at 5 epoch
tensor(0.7057, device='cuda:0')
[[130  53]
 [ 55 129]]
M_KAN Relevance: 0.016921790316700935
CNN Relevance: 0.9830781817436218
------------------------------ kan_neurons=5, kan_grid=7, lamb=0.0001 ------------------------------
alpha: 0.9
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.73e-01 | val_acc_hybrid: 5.58e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.82it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 1 epoch
tensor(0.7330, device='cuda:0')
[[130  53]
 [ 45 139]]
M_KAN Relevance: 0.013979962095618248
CNN Relevance: 0.9860200881958008
alpha: 0.8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.74e-01 | val_acc_hybrid: 6.26e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.80it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2109e-01 at 3 epoch
tensor(0.7221, device='cuda:0')
[[131  52]
 [ 50 134]]
M_KAN Relevance: 0.02135593444108963
CNN Relevance: 0.978644073009491
alpha: 0.75
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.58e-01 | val_acc_hybrid: 6.60e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.80it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1769e-01 at 1 epoch
tensor(0.7030, device='cuda:0')
[[125  58]
 [ 51 133]]
M_KAN Relevance: 0.013213453814387321
CNN Relevance: 0.9867866039276123
alpha: 0.7
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 9.00e-01 | val_acc_hybrid: 6.46e-01 |: 100%|█████| 60/60 [00:31<00:00,  1.88it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1088e-01 at 1 epoch
tensor(0.7139, device='cuda:0')
[[126  57]
 [ 48 136]]
M_KAN Relevance: 0.01624481752514839
CNN Relevance: 0.9837552309036255
alpha: 0.6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.25e-01 | val_acc_hybrid: 6.63e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.80it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2109e-01 at 11 epoch
tensor(0.6948, device='cuda:0')
[[116  67]
 [ 45 139]]
M_KAN Relevance: 0.014568538405001163
CNN Relevance: 0.9854314923286438
alpha: 0.5
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.93e-01 | val_acc_hybrid: 6.53e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.86it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1769e-01 at 8 epoch
tensor(0.7084, device='cuda:0')
[[130  53]
 [ 54 130]]
M_KAN Relevance: 0.01920514740049839
CNN Relevance: 0.9807949066162109
alpha: 0.4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.63e-01 | val_acc_hybrid: 6.22e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.82it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1429e-01 at 6 epoch
tensor(0.7112, device='cuda:0')
[[133  50]
 [ 56 128]]
M_KAN Relevance: 0.021625559777021408
CNN Relevance: 0.9783744215965271
alpha: 0.3
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 9.03e-01 | val_acc_hybrid: 5.95e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.83it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.0408e-01 at 4 epoch
tensor(0.7221, device='cuda:0')
[[137  46]
 [ 56 128]]
M_KAN Relevance: 0.03119996003806591
CNN Relevance: 0.9688000679016113
alpha: 0.25
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.58e-01 | val_acc_hybrid: 6.63e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.81it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1429e-01 at 13 epoch
tensor(0.7275, device='cuda:0')
[[131  52]
 [ 48 136]]
M_KAN Relevance: 0.012129932641983032
CNN Relevance: 0.9878700375556946
alpha: 0.2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.36e-01 | val_acc_hybrid: 6.84e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.80it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 23 epoch
tensor(0.7003, device='cuda:0')
[[128  55]
 [ 55 129]]
M_KAN Relevance: 0.01939028687775135
CNN Relevance: 0.9806097149848938
alpha: 0.1
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.68e-01 | val_acc_hybrid: 6.53e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.81it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1429e-01 at 14 epoch
tensor(0.6975, device='cuda:0')
[[136  47]
 [ 64 120]]
M_KAN Relevance: 0.03770007938146591
CNN Relevance: 0.9622999429702759
alpha: 0.05
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.56e-01 | val_acc_hybrid: 6.02e-01 |: 100%|█████| 60/60 [00:33<00:00,  1.81it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 2 epoch
tensor(0.7057, device='cuda:0')
[[130  53]
 [ 55 129]]
M_KAN Relevance: 0.10377513617277145
CNN Relevance: 0.8962247967720032
alpha: 0.01
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_2


 train_acc_hybrid: 8.97e-01 | val_acc_hybrid: 6.56e-01 |: 100%|█████| 60/60 [00:32<00:00,  1.84it/s]

saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 1 epoch
tensor(0.6948, device='cuda:0')
[[106  77]
 [ 35 149]]
M_KAN Relevance: 0.015258121304214
CNN Relevance: 0.9847419261932373





# Option 3: Dynamic factor

In [None]:
# Start the Option 3 results file with its option-specific column ('hidden_dim').
# NOTE(review): presumably this (re)writes the CSV header -- confirm it does not
# wipe existing results before re-running.
create_csv_with_header(filename_3, columns_opt3)

In [50]:
# Option 3 sweep: for each KAN configuration, train Model4_3 once per hidden_dim
# value and log the run to filename_3. Every configuration uses the same sizes.
#
# A stray multi-line note had been pasted into the middle of this cell as bare
# text, which made the whole cell a SyntaxError. It is preserved here verbatim:
#   "Hybrid3 cnn_blocks=2" | 0.96153 | 60 | "width=[24, 11], grid=3, lamb=0.001"
hidden_dims = [128, 64, 32, 16, 8]
option3_kan_configs = [
    dict(kan_neurons=12, kan_grid=8, lamb=0.001),
    dict(kan_neurons=6, kan_grid=7, lamb=1e-05),
    dict(kan_neurons=8, kan_grid=8, lamb=0.001),
    dict(kan_neurons=3, kan_grid=7, lamb=0.001),
]

# The banner is rendered from the configuration itself, so the printed values
# can never drift from the hyper-parameters actually passed to the training call.
banner_template = "------------------------------ kan_neurons={kan_neurons}, kan_grid={kan_grid}, lamb={lamb} ------------------------------"

for kan_cfg in option3_kan_configs:
    print(banner_template.format(**kan_cfg))
    for hidden_dim in hidden_dims:
        print(f"hidden_dim: {hidden_dim}")
        train_and_plot_relevance(Model4_3, steps=150, hidden_dim=hidden_dim,
                                 filename=filename_3, opt_col_val=hidden_dim, **kan_cfg)

------------------------------ kan_neurons=3, kan_grid=7, lamb=1e-06 ------------------------------
hidden_dim: 128
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 7.95e-01 | val_acc_hybrid: 6.60e-01 |: 100%|█████| 70/70 [00:41<00:00,  1.67it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1769e-01 at 4 epoch
tensor(0.7139, device='cuda:0')
[[127  56]
 [ 49 135]]
M_KAN Relevance: 2.2801490558777004e-05
CNN Relevance: 0.999977171421051
hidden_dim: 64
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 8.08e-01 | val_acc_hybrid: 6.77e-01 |: 100%|█████| 70/70 [00:42<00:00,  1.66it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2109e-01 at 4 epoch
tensor(0.7193, device='cuda:0')
[[133  50]
 [ 53 131]]
M_KAN Relevance: 0.0012987710069864988
CNN Relevance: 0.9987012147903442
hidden_dim: 32
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 8.34e-01 | val_acc_hybrid: 6.22e-01 |: 100%|█████| 70/70 [00:41<00:00,  1.69it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2109e-01 at 1 epoch
tensor(0.7057, device='cuda:0')
[[126  57]
 [ 51 133]]
M_KAN Relevance: 0.0003959068562835455
CNN Relevance: 0.9996040463447571
hidden_dim: 16
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 8.73e-01 | val_acc_hybrid: 6.70e-01 |: 100%|█████| 70/70 [00:40<00:00,  1.75it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3810e-01 at 10 epoch
tensor(0.7193, device='cuda:0')
[[134  49]
 [ 54 130]]
M_KAN Relevance: 0.001716263359412551
CNN Relevance: 0.9982837438583374
hidden_dim: 8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 0.00e+00 | val_acc_hybrid: 0.00e+00 |: 100%|█████| 70/70 [00:42<00:00,  1.64it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 4 epoch
tensor(0.7357, device='cuda:0')
[[132  51]
 [ 46 138]]
M_KAN Relevance: 0.0
CNN Relevance: 1.0
------------------------------ kan_neurons=5, kan_grid=7, lamb=0.0001 ------------------------------
hidden_dim: 128
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 8.06e-01 | val_acc_hybrid: 6.90e-01 |: 100%|█████| 70/70 [00:41<00:00,  1.68it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 23 epoch
tensor(0.7166, device='cuda:0')
[[133  50]
 [ 54 130]]
M_KAN Relevance: 0.00813452061265707
CNN Relevance: 0.9918654561042786
hidden_dim: 64
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 8.30e-01 | val_acc_hybrid: 6.63e-01 |: 100%|█████| 70/70 [00:40<00:00,  1.74it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1769e-01 at 11 epoch
tensor(0.7221, device='cuda:0')
[[133  50]
 [ 52 132]]
M_KAN Relevance: 0.005971659906208515
CNN Relevance: 0.9940283298492432
hidden_dim: 32
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 7.95e-01 | val_acc_hybrid: 6.97e-01 |: 100%|█████| 70/70 [00:42<00:00,  1.66it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3469e-01 at 33 epoch
tensor(0.7193, device='cuda:0')
[[141  42]
 [ 61 123]]
M_KAN Relevance: 0.006340572610497475
CNN Relevance: 0.9936594367027283
hidden_dim: 16
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 8.36e-01 | val_acc_hybrid: 6.53e-01 |: 100%|█████| 70/70 [00:40<00:00,  1.75it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1769e-01 at 4 epoch
tensor(0.7357, device='cuda:0')
[[133  50]
 [ 47 137]]
M_KAN Relevance: 0.006724129896610975
CNN Relevance: 0.9932758808135986
hidden_dim: 8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_3


 train_acc_hybrid: 8.76e-01 | val_acc_hybrid: 6.50e-01 |: 100%|█████| 70/70 [00:40<00:00,  1.72it/s]

saving model version 0.1
✅ Best validation Accuracy: 7.2109e-01 at 10 epoch
tensor(0.7302, device='cuda:0')
[[135  48]
 [ 51 133]]
M_KAN Relevance: 0.008050317876040936
CNN Relevance: 0.9919496774673462





# Opt4: MultiHead Attention

In [None]:
# NOTE(review): one-off setup call, left disabled. Presumably it (re)creates the results
# CSV at `filename_4` with the `columns_opt4` header - confirm before re-enabling, because
# the Opt4 sweep in the next cell passes the same `filename_4` to train_and_plot_relevance.
# create_csv_with_header(filename_4, columns_opt4)

In [51]:
# Opt4 sweep: train Model4_4 once per (KAN setting, attention size) combination.
#
# The original cell repeated the same pair of loops four times with only the KAN
# hyper-parameters changed, and printed hand-typed banners that could drift from the
# arguments actually passed. Here the settings live in one table and the banner is
# built from the values in use, so the two can no longer disagree. The order of the
# calls, their arguments and the printed text are unchanged.
#
# NOTE(review): the output currently saved under this cell shows banners for
# kan_neurons=3 / lamb=1e-06 and kan_neurons=5 / lamb=0.0001, which are not in the table
# below - it looks like it was produced by an earlier version of the cell. Re-run to refresh.

# (kan_neurons, kan_grid, lamb) settings, evaluated in this order.
OPT4_KAN_SETTINGS = [
    (12, 8, 0.001),
    (6, 7, 1e-05),
    (8, 8, 0.001),
    (3, 7, 0.001),
]

# (embed_dim, num_head) pairs tried for every KAN setting: first the power-of-two grid,
# then the three sizes paired with 6 heads. Every pair has embed_dim divisible by
# num_head (required by torch.nn.MultiheadAttention - assuming Model4_4 uses it; confirm).
OPT4_ATTENTION_GRID = (
    [(dim, heads) for dim in [64, 32, 16] for heads in [2, 4, 8]]
    + [(dim, 6) for dim in [48, 24, 12]]
)

OPT4_BANNER_RULE = "------------------------------"

# Loop variables carry an `opt4_` prefix so this cell does not overwrite notebook-level
# globals that may be called `kan_neurons`, `kan_grid` or `lamb`.
for opt4_neurons, opt4_grid, opt4_lamb in OPT4_KAN_SETTINGS:
    # str() of the floats gives '0.001' / '1e-05', matching the original hand-typed banners.
    print(f"{OPT4_BANNER_RULE} kan_neurons={opt4_neurons}, kan_grid={opt4_grid}, "
          f"lamb={opt4_lamb} {OPT4_BANNER_RULE}")
    for embed_dim, num_head in OPT4_ATTENTION_GRID:
        print(f"embed_dim: {embed_dim}, num_head:{num_head}")
        train_and_plot_relevance(Model4_4, kan_neurons=opt4_neurons, kan_grid=opt4_grid,
                                 lamb=opt4_lamb, steps=180,
                                 embed_dim=embed_dim, num_heads=num_head, filename=filename_4,
                                 opt_col_val=f'{embed_dim}, {num_head}')


------------------------------ kan_neurons=3, kan_grid=7, lamb=1e-06 ------------------------------
embed_dim: 64, num_head:2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.44e-01 | val_acc_hybrid: 7.07e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.04it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.4150e-01 at 8 epoch
tensor(0.7084, device='cuda:0')
[[121  62]
 [ 45 139]]
M_KAN Relevance: 0.08535192161798477
CNN Relevance: 0.9146481156349182
embed_dim: 64, num_head:4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.44e-01 | val_acc_hybrid: 7.07e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.05it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.4150e-01 at 8 epoch
tensor(0.7084, device='cuda:0')
[[121  62]
 [ 45 139]]
M_KAN Relevance: 0.08535192161798477
CNN Relevance: 0.9146481156349182
embed_dim: 64, num_head:8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.44e-01 | val_acc_hybrid: 7.07e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.06it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.4150e-01 at 8 epoch
tensor(0.7084, device='cuda:0')
[[121  62]
 [ 45 139]]
M_KAN Relevance: 0.08535192161798477
CNN Relevance: 0.9146481156349182
embed_dim: 32, num_head:2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.42e-01 | val_acc_hybrid: 6.77e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.05it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.0748e-01 at 3 epoch
tensor(0.6785, device='cuda:0')
[[141  42]
 [ 76 108]]
M_KAN Relevance: 0.09611522406339645
CNN Relevance: 0.903884768486023
embed_dim: 32, num_head:4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.56e-01 | val_acc_hybrid: 6.97e-01 |: 100%|█████| 90/90 [00:30<00:00,  3.00it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2109e-01 at 44 epoch
tensor(0.7139, device='cuda:0')
[[131  52]
 [ 53 131]]
M_KAN Relevance: 0.15937860310077667
CNN Relevance: 0.8406214118003845
embed_dim: 32, num_head:8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.80e-01 | val_acc_hybrid: 6.84e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.19it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 4 epoch
tensor(0.7166, device='cuda:0')
[[121  62]
 [ 42 142]]
M_KAN Relevance: 0.04014917090535164
CNN Relevance: 0.9598508477210999
embed_dim: 16, num_head:2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 8.04e-01 | val_acc_hybrid: 6.77e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.13it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2789e-01 at 3 epoch
tensor(0.6975, device='cuda:0')
[[124  59]
 [ 52 132]]
M_KAN Relevance: 0.27207615971565247
CNN Relevance: 0.7279238104820251
embed_dim: 16, num_head:4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.61e-01 | val_acc_hybrid: 6.90e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.16it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.4490e-01 at 4 epoch
tensor(0.7166, device='cuda:0')
[[129  54]
 [ 50 134]]
M_KAN Relevance: 0.2094883769750595
CNN Relevance: 0.7905116081237793
embed_dim: 16, num_head:8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 8.04e-01 | val_acc_hybrid: 6.77e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.11it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2789e-01 at 3 epoch
tensor(0.6975, device='cuda:0')
[[124  59]
 [ 52 132]]
M_KAN Relevance: 0.27207615971565247
CNN Relevance: 0.7279238104820251
embed_dim: 48, num_head:6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.65e-01 | val_acc_hybrid: 7.11e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.04it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 28 epoch
tensor(0.7221, device='cuda:0')
[[133  50]
 [ 52 132]]
M_KAN Relevance: 0.02622860297560692
CNN Relevance: 0.9737713932991028
embed_dim: 24, num_head:6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.74e-01 | val_acc_hybrid: 6.97e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.07it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3469e-01 at 3 epoch
tensor(0.7030, device='cuda:0')
[[120  63]
 [ 46 138]]
M_KAN Relevance: 0.08549175411462784
CNN Relevance: 0.9145082831382751
embed_dim: 12, num_head:6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.70e-01 | val_acc_hybrid: 6.97e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.13it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3469e-01 at 2 epoch
tensor(0.7112, device='cuda:0')
[[128  55]
 [ 51 133]]
M_KAN Relevance: 0.02661152556538582
CNN Relevance: 0.9733884930610657
------------------------------ kan_neurons=5, kan_grid=7, lamb=0.0001 ------------------------------
embed_dim: 64, num_head:2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.67e-01 | val_acc_hybrid: 7.14e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.18it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3810e-01 at 5 epoch
tensor(0.7139, device='cuda:0')
[[121  62]
 [ 43 141]]
M_KAN Relevance: 0.07354620099067688
CNN Relevance: 0.9264537692070007
embed_dim: 64, num_head:4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.65e-01 | val_acc_hybrid: 7.18e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.08it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3129e-01 at 8 epoch
tensor(0.7112, device='cuda:0')
[[120  63]
 [ 43 141]]
M_KAN Relevance: 0.02184392884373665
CNN Relevance: 0.9781560897827148
embed_dim: 64, num_head:8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.67e-01 | val_acc_hybrid: 7.14e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.17it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3810e-01 at 5 epoch
tensor(0.7139, device='cuda:0')
[[121  62]
 [ 43 141]]
M_KAN Relevance: 0.07354620099067688
CNN Relevance: 0.9264537692070007
embed_dim: 32, num_head:2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.67e-01 | val_acc_hybrid: 7.07e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.11it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.4490e-01 at 8 epoch
tensor(0.7193, device='cuda:0')
[[130  53]
 [ 50 134]]
M_KAN Relevance: 0.1849890798330307
CNN Relevance: 0.8150109052658081
embed_dim: 32, num_head:4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.82e-01 | val_acc_hybrid: 7.01e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.12it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3469e-01 at 2 epoch
tensor(0.7193, device='cuda:0')
[[126  57]
 [ 46 138]]
M_KAN Relevance: 0.2563580870628357
CNN Relevance: 0.7436418533325195
embed_dim: 32, num_head:8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.82e-01 | val_acc_hybrid: 7.01e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.14it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3469e-01 at 2 epoch
tensor(0.7193, device='cuda:0')
[[126  57]
 [ 46 138]]
M_KAN Relevance: 0.2563580870628357
CNN Relevance: 0.7436418533325195
embed_dim: 16, num_head:2
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.92e-01 | val_acc_hybrid: 6.84e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.07it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2789e-01 at 7 epoch
tensor(0.7330, device='cuda:0')
[[134  49]
 [ 49 135]]
M_KAN Relevance: 0.23544730246067047
CNN Relevance: 0.7645527124404907
embed_dim: 16, num_head:4
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.60e-01 | val_acc_hybrid: 6.50e-01 |: 100%|█████| 90/90 [00:30<00:00,  2.98it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.0748e-01 at 2 epoch
tensor(0.6703, device='cuda:0')
[[105  78]
 [ 43 141]]
M_KAN Relevance: 0.3717891275882721
CNN Relevance: 0.6282108426094055
embed_dim: 16, num_head:8
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.73e-01 | val_acc_hybrid: 6.87e-01 |: 100%|█████| 90/90 [00:29<00:00,  3.06it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.1769e-01 at 5 epoch
tensor(0.6812, device='cuda:0')
[[114  69]
 [ 48 136]]
M_KAN Relevance: 0.47175654768943787
CNN Relevance: 0.5282434225082397
embed_dim: 48, num_head:6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 8.06e-01 | val_acc_hybrid: 7.04e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.19it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.3810e-01 at 3 epoch
tensor(0.7221, device='cuda:0')
[[127  56]
 [ 46 138]]
M_KAN Relevance: 0.07518815994262695
CNN Relevance: 0.924811840057373
embed_dim: 24, num_head:6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.80e-01 | val_acc_hybrid: 6.80e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.18it/s]


saving model version 0.1
✅ Best validation Accuracy: 7.2449e-01 at 3 epoch
tensor(0.7248, device='cuda:0')
[[126  57]
 [ 44 140]]
M_KAN Relevance: 0.3621828854084015
CNN Relevance: 0.6378171443939209
embed_dim: 12, num_head:6
checkpoint directory created: ./model
saving model version 0.0
checkpoint directory created: ./model
saving model version 0.0
Successfully created and tested Model4_4


 train_acc_hybrid: 7.71e-01 | val_acc_hybrid: 6.97e-01 |: 100%|█████| 90/90 [00:28<00:00,  3.15it/s]

saving model version 0.1
✅ Best validation Accuracy: 7.4490e-01 at 5 epoch
tensor(0.7275, device='cuda:0')
[[131  52]
 [ 48 136]]
M_KAN Relevance: 0.42902079224586487
CNN Relevance: 0.5709791779518127



