# Neural Additive Growths?

Investigatory notebook for playing around with different ideas for learning additivity in neural networks (or possibly Kolmogorov-Arnold Networks).

In [None]:
import os

from tqdm.notebook import tqdm

from itertools import combinations

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.tensorboard import SummaryWriter

from entmax import sparsemax, entmax15, entmax_bisect, normmax_bisect, budget_bisect

from torcheval.metrics.functional import mean_squared_error

%matplotlib inline

In [None]:
# Map x, y coordinates to a set of features z, just so we can visualize things easier
def xy_to_z(x, y):
    """Expand coordinate columns ``x`` and ``y`` into 12 feature columns.

    Only used so the synthetic features can be visualized over the (x, y) plane.

    Args:
        x, y: column tensors, presumably of shape (N, 1) -- they are concatenated
            along dim=1.

    Returns:
        Tensor with 12 columns, in order:
        sin(2*pi*x), sin(2*pi*y), cos(2*pi*x), cos(2*pi*y), x+y, -x*y, e^x, e^y,
        (x+y)^2, (x-y)^2, -(x+y)^2, (y-x)^2.
        Note: columns 9 and 11 are identical and column 10 is the negation of column 8.
    """
    two_pi_x = 2*np.pi*x
    two_pi_y = 2*np.pi*y
    features = [
        torch.sin(two_pi_x),
        torch.sin(two_pi_y),
        torch.cos(two_pi_x),
        torch.cos(two_pi_y),
        x + y,
        -x * y,
        torch.exp(x),
        torch.exp(y),
        torch.pow(x + y, 2),
        torch.pow(x - y, 2),
        -torch.pow(x + y, 2),
        torch.pow(y - x, 2),
    ]
    return torch.cat(features, dim=1)

def plot_data(xy, z, ms=20, title=""):
    """Scatter the points ``xy`` coloured by ``z`` and show the figure.

    Args:
        xy: tensor whose columns 0 and 1 are the x and y coordinates.
        z: one value per point, mapped through the 'viridis' colour map.
        ms: marker size.
        title: axes title.
    """
    def to_numpy(t):
        # Tensors may live on the GPU and/or carry autograd history.
        return t.cpu().detach().numpy()

    fig, ax = plt.subplots(figsize=(8, 6))
    points = ax.scatter(to_numpy(xy[:, 0]), to_numpy(xy[:, 1]), c=to_numpy(z), cmap='viridis', s=ms)
    fig.colorbar(points)
    ax.set_title(title)
    plt.show()

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(42)

# Sample (x, y) uniformly from the unit square and expand them into the feature matrix z.
N_SAMPLES = 15000
xy = torch.rand((N_SAMPLES, 2)).to(device)

z = xy_to_z(xy[:, 0].unsqueeze(1), xy[:, 1].unsqueeze(1))

# Derive the feature count from z so it cannot drift out of sync with xy_to_z.
num_features = z.shape[1]

In [None]:
# One scatter plot per feature column over the (x, y) plane.
for feat_idx in range(num_features):
    plot_data(xy, z[:, feat_idx], title=f"z_{feat_idx}")

We then generate the labels from a randomly sampled function of the 12 features $z_0, \dots, z_{11}$, using a relatively simple/naive process which we call $g(z)$:

1. For each feature $z_i$, select a random function from among $\sin(2\pi z_i), \cos(2\pi z_i), \exp(z_i), \sqrt{|z_i|}$
2. For each disjoint pair of features (i.e. $(z_0, z_1), (z_2, z_3), \dots, (z_{10}, z_{11})$), select a random function from among $z_i z_j, \sin(2\pi(z_i+z_j)), \cos(2\pi(z_i+z_j))$
3. Randomly generate a bias value $\beta \in [a, b)$ for some $a, b$ (by default $a = 0$, $b = 10$)
4. Add the resulting values together with Gaussian noise $\epsilon \sim \mathcal{N}(0, \sigma^2)$ (by default $\sigma = 0.2$) to generate the final label

In [None]:
class SingleFunction:
    """Abstract single-variable term f(z_i) of the synthetic ground-truth function.

    Subclasses implement ``calculate`` (vectorized over the rows of ``z``) and
    ``equation`` (a human-readable formula string).
    """

    def __init__(self, i):
        self.i = i  # index of the column of z this term reads

    def calculate(self, z):
        """Return f(z[:, i]) for every row of ``z``. Must be overridden."""
        # NotImplementedError is a subclass of Exception, so existing
        # `except Exception` handlers keep working.
        raise NotImplementedError("Abstract Class should not be called")

    def equation(self):
        """Return the term as a formula string. Must be overridden."""
        raise NotImplementedError("Abstract Class should not be called")

class Exp(SingleFunction):
    """Term exp(z_i); the constructor is inherited from SingleFunction."""

    def calculate(self, z):
        column = z[:, self.i]
        return column.exp()

    def equation(self):
        return f"e^z{self.i}"

class Sqrt(SingleFunction):
    """Term sqrt(|z_i|); the constructor is inherited from SingleFunction."""

    def calculate(self, z):
        # Absolute value first: several features (e.g. -x*y) are negative.
        return torch.sqrt(torch.abs(z[:, self.i]))

    def equation(self):
        # Show the absolute value so the printed formula matches what calculate() computes.
        return "|z{}|^(1/2)".format(self.i)

class Sin(SingleFunction):
    """Term sin(2*pi*z_i); the constructor is inherited from SingleFunction."""

    def calculate(self, z):
        angle = 2*np.pi*z[:, self.i]
        return angle.sin()

    def equation(self):
        return f"sin(2*pi*z{self.i})"

class Cos(SingleFunction):
    """Term cos(2*pi*z_i); the constructor is inherited from SingleFunction."""

    def calculate(self, z):
        angle = 2*np.pi*z[:, self.i]
        return angle.cos()

    def equation(self):
        return f"cos(2*pi*z{self.i})"

# Candidate single-variable terms. Order matters: DataGenerator picks from this list with a seeded rng.
single_funcs = [
    Exp,
    Sqrt,
    Sin,
    Cos,
]

In [None]:
class MultFunction:
    """Abstract two-variable interaction term f(z_i, z_j) of the synthetic ground-truth function.

    Subclasses implement ``calculate`` (vectorized over the rows of ``z``) and
    ``equation`` (a human-readable formula string).
    """

    def __init__(self, i, j):
        self.i = i  # first column of z this term reads
        self.j = j  # second column of z this term reads

    def calculate(self, z):
        """Return f(z[:, i], z[:, j]) for every row of ``z``. Must be overridden."""
        # NotImplementedError is a subclass of Exception, so existing
        # `except Exception` handlers keep working.
        raise NotImplementedError("Abstract Class should not be called")

    def equation(self):
        """Return the term as a formula string. Must be overridden."""
        raise NotImplementedError("Abstract Class should not be called")

class Cos2(MultFunction):
    """Interaction term cos(2*pi*(z_i + z_j)); the constructor is inherited from MultFunction."""

    def calculate(self, z):
        angle = 2*np.pi*(z[:, self.i] + z[:, self.j])
        return angle.cos()

    def equation(self):
        return f"cos(2*pi*(z{self.i} + z{self.j}))"

class Sin2(MultFunction):
    """Interaction term sin(2*pi*(z_i + z_j)); the constructor is inherited from MultFunction."""

    def calculate(self, z):
        angle = 2*np.pi*(z[:, self.i] + z[:, self.j])
        return angle.sin()

    def equation(self):
        return f"sin(2*pi*(z{self.i} + z{self.j}))"

class Mult(MultFunction):
    """Interaction term z_i * z_j; the constructor is inherited from MultFunction."""

    def calculate(self, z):
        return torch.mul(z[:, self.i], z[:, self.j])

    def equation(self):
        return f"z{self.i}*z{self.j}"

# Candidate pairwise terms. Order matters: DataGenerator picks from this list with a seeded rng.
multi_funcs = [
    Cos2,
    Sin2,
    Mult,
]

In [None]:
class DataGenerator:
    """Randomly sampled additive ground-truth function g(z) used to make synthetic labels.

    g(z) = bias + sum_i f_i(z_i) + sum_k f_k(z_{2k}, z_{2k+1}) + eps, where each f_i is
    drawn from ``single_funcs``, each pairwise f_k from ``multi_funcs`` over the disjoint
    pairs (z0, z1), (z2, z3), ..., and eps = noise_str * N(0, 1).

    Args:
        rng: numpy Generator used for the bias and for choosing the terms.
        a, b: the bias is drawn uniformly from [a, b).
        noise_str: standard deviation of the additive Gaussian noise.
        num_features: number of columns of z that get a single-variable term.
            With an odd value the last feature has no pairwise partner.
    """

    def __init__(self, rng, a=0, b=10, noise_str=0.2, num_features=12):
        self.bias = rng.random() * (b - a) + a
        self.eqs = ["{:.3f}".format(self.bias)]

        self.noise_str = noise_str

        # Add single variable functions (one per feature)
        self.f1s = []
        for i in range(num_features):
            func = rng.choice(single_funcs)(i)

            self.f1s.append(func)
            self.eqs.append(func.equation())

        # Add two variable functions over disjoint (i, i+1) pairs; stopping at
        # num_features - 1 keeps i+1 in range when num_features is odd.
        self.f2s = []
        for i in range(0, num_features - 1, 2):
            func = rng.choice(multi_funcs)(i, i+1)

            self.f2s.append(func)
            self.eqs.append(func.equation())

    def equation(self):
        """Return the sampled function as a human-readable string."""
        return " + ".join(self.eqs) + " + eps"

    def calculate(self, z):
        """Evaluate g on every row of ``z`` (a 2-D tensor), including fresh noise.

        The result is a 1-D tensor on z's device with z's dtype.
        """
        # Follow z's device/dtype instead of relying on a notebook-global device.
        v = torch.full((z.shape[0],), self.bias, dtype=z.dtype, device=z.device)

        # Add single variable functions
        for f1 in self.f1s:
            v += f1.calculate(z)

        # Add two variable functions
        for f2 in self.f2s:
            v += f2.calculate(z)

        noises = torch.randn_like(v) * self.noise_str

        v += noises

        return v

In [None]:
# Fixed seed so the sampled ground-truth function is reproducible across runs.
rng = np.random.default_rng(1234)
data_generator = DataGenerator(rng)

In [None]:
equation_str = data_generator.equation()
print(f"labels = {equation_str}")

# One noisy label per row of z.
labels = data_generator.calculate(z)

In [None]:
# Ground-truth labels over the (x, y) plane (default marker size spelled out).
plot_data(xy, labels, ms=20, title="True Data")

In [None]:
class Dataset(torch.utils.data.Dataset):
    """In-memory (features, label) pairs held on ``device``.

    Args:
        z: feature matrix (tensor, array or nested list), one row per sample.
        labels: 1-D targets, one per row of ``z``; stored as a column (N, 1) so each
            item's label has shape (1,).
        device: torch device to hold the data on.
    """

    def __init__(self, z, labels, device):
        # torch.as_tensor accepts tensors on any device (the legacy torch.Tensor(...)
        # constructor only builds CPU float tensors) and avoids copies when possible.
        self.z = torch.as_tensor(z, dtype=torch.float32).to(device)
        self.labels = torch.as_tensor(labels, dtype=torch.float32).unsqueeze(1).to(device)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair at ``idx``."""
        return self.z[idx], self.labels[idx]

# Contiguous 10000 / 2500 / 2500 train / validation / test split of the sampled rows.
n_train, n_valid = 10000, 2500
valid_end = n_train + n_valid

z_train, labels_train = z[:n_train], labels[:n_train]
z_valid, labels_valid = z[n_train:valid_end], labels[n_train:valid_end]
z_test, labels_test = z[valid_end:], labels[valid_end:]

train_ds = Dataset(z_train, labels_train, device)
valid_ds = Dataset(z_valid, labels_valid, device)
test_ds = Dataset(z_test, labels_test, device)

# Only the training loader shuffles.
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
valid_loader = torch.utils.data.DataLoader(valid_ds, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(test_ds, batch_size=batch_size)

## Models

Model constructions for the synthetic data

### Base Model

The Base Model just consists of a basic feed forward network, no bells or whistles

In [None]:
class BaseModel(torch.nn.Module):
    """Plain MLP baseline: num_features -> 64 -> 128 -> 128 -> 64 -> num_labels.

    Each hidden layer is Linear -> ReLU -> Dropout(0.05); the output layer is linear.
    """

    # Hidden widths, in order.
    _HIDDEN_WIDTHS = (64, 128, 128, 64)

    def __init__(self, num_features, num_labels=1):
        super().__init__()

        layers = []
        width_in = num_features
        for width_out in self._HIDDEN_WIDTHS:
            layers += [torch.nn.Linear(width_in, width_out), torch.nn.ReLU(), torch.nn.Dropout(0.05)]
            width_in = width_out
        layers.append(torch.nn.Linear(width_in, num_labels))

        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

    def loss(self, output, y):
        """Return the named loss terms; only MSE for this model."""
        mse = F.mse_loss(output, y)
        return {"mse": mse}

    def get_output(self, output):
        """The forward output already is the prediction."""
        return output

    def plot_data(self):
        """No model-specific plots."""

    def add_epoch_info(self, writer, epoch):
        """No per-epoch logging."""

    def name(self):
        return "Base"

### "Cheating" Additive Model

This model consists of a separate neural network for each individual feature (like NAM) and for each pairwise combination of features. This would not scale to any reasonable number of features, but it is a good point of comparison.

In [None]:
class AdditiveModel(torch.nn.Module):
    """Additive model with one MLP per feature and one MLP per unordered feature pair.

    The prediction is a linear combination of all the per-feature and per-pair outputs.
    """

    def __init__(self, num_features, num_labels=1):
        super().__init__()

        self.num_features = num_features
        # All unordered pairs (i, j) with i < j, in itertools.combinations order.
        self._pairs = list(combinations(range(num_features), 2))

        def subnet(n_inputs):
            # Small MLP: n_inputs -> 32 -> 32 -> 1, ReLU + 5% dropout between layers.
            return torch.nn.Sequential(
                torch.nn.Linear(n_inputs, 32),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.05),
                torch.nn.Linear(32, 32),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.05),
                torch.nn.Linear(32, 1),
            )

        self.first_order = torch.nn.ModuleList(subnet(1) for _ in range(num_features))
        self.second_order = torch.nn.ModuleList(subnet(2) for _ in self._pairs)
        self.final = torch.nn.Linear(num_features + len(self._pairs), num_labels)

    def forward(self, x):
        singles = [net(x[:, idx].unsqueeze(1)) for idx, net in enumerate(self.first_order)]
        pairs = [net(x[:, pair]) for net, pair in zip(self.second_order, self._pairs)]
        return self.final(torch.cat(singles + pairs, dim=1))

    def loss(self, output, y):
        """Return the named loss terms; only MSE for this model."""
        mse = F.mse_loss(output, y)
        return {"mse": mse}

    def get_output(self, output):
        """The forward output already is the prediction."""
        return output

    def plot_data(self):
        """No model-specific plots."""

    def add_epoch_info(self, writer, epoch):
        """No per-epoch logging."""

    def name(self):
        return "Additive"

## Attention Selection Model

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention *scores* between tokens.

    Every head gets its own full-width (embed_dim) query/key/value projection, so the
    per-head dimension equals ``embed_dim``. Only the raw scores
    ``q @ k^T / sqrt(embed_dim)`` are returned. The value slice and ``out_proj`` are
    created but unused; they are kept so parameters and state-dict keys stay stable.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim * num_heads)
        self.out_proj = nn.Linear(embed_dim, embed_dim)  # unused in forward

    def forward(self, x):
        """Map x of shape (batch, n_tokens, embed_dim) to scores (batch, num_heads, n_tokens, n_tokens)."""
        batch, n_tokens, _ = x.size()
        per_head = self.qkv_proj(x).reshape(batch, n_tokens, self.num_heads, 3 * self.embed_dim)
        per_head = per_head.transpose(1, 2)  # (batch, num_heads, n_tokens, 3 * embed_dim)
        query, key, _ = per_head.split(self.embed_dim, dim=-1)

        scores = query @ key.transpose(-2, -1)
        return scores / (self.embed_dim ** 0.5)


class GraphFragmentationLoss(nn.Module):
    """Graph-structure penalties on a symmetric weighted adjacency matrix ``A``.

    ``forward`` returns ``(connectivity, spectral, sparsity)``:
      * connectivity: total off-diagonal weight of A,
      * spectral: second-smallest eigenvalue of the Laplacian D - A (smaller means
        the graph is closer to splitting into components),
      * sparsity: entry-wise L1 norm of A.
    """

    def __init__(self):
        super().__init__()

    def connectivity_loss(self, A):
        """Sum of all entries of A minus its diagonal."""
        diagonal_total = torch.diag(A).sum()
        return A.sum() - diagonal_total

    def spectral_loss(self, L, k=1):
        """Sum of the eigenvalues of symmetric ``L`` ranked 2..k+1 in ascending order.

        The smallest eigenvalue is skipped. For a Laplacian D - A every row sums to zero,
        so 0 is an eigenvalue; it is the smallest one when A is non-negative.
        """
        ascending = torch.linalg.eigvalsh(L)
        return ascending[1:k + 1].sum()

    def sparsity_loss(self, A):
        """Entry-wise L1 norm of A."""
        return torch.norm(A, p=1)

    def forward(self, A):
        laplacian = torch.diag(A.sum(1)) - A
        return (
            self.connectivity_loss(A),
            self.spectral_loss(laplacian),
            self.sparsity_loss(A),
        )


import torch
import torch.nn.functional as F

def laplacian_matrix(adj):
    """Return the graph Laplacian D - adj, where D holds the row sums of ``adj`` on its diagonal."""
    return torch.diag(adj.sum(dim=1)) - adj

def spectral_loss(laplacian, k=1):
    """Sum of the ``k`` smallest eigenvalues of ``laplacian`` that exceed 1e-5.

    Eigenvalues at or below 1e-5 are treated as zero and ignored. ``torch.linalg.eigh``
    returns eigenvalues in ascending order, so the first ``k`` survivors are the smallest.
    """
    eigvals = torch.linalg.eigh(laplacian).eigenvalues
    nonzero = eigvals[eigvals > 1e-5]
    return nonzero[:k].sum()

def spectral_loss_incl_zero(laplacian, k=1, verbose=False):
    """Sum of the ``k`` smallest eigenvalues of ``laplacian``, including zero eigenvalues.

    Unlike ``spectral_loss``, nothing is filtered out. ``torch.linalg.eigh`` returns
    eigenvalues in ascending order, so the first ``k`` entries are the smallest.

    Args:
        laplacian: symmetric matrix.
        k: number of smallest eigenvalues to sum.
        verbose: if True, print all eigenvalues (previously always printed as a debug aid).
    """
    eigenvalues = torch.linalg.eigh(laplacian).eigenvalues
    if verbose:
        print("Eigenvalues", eigenvalues)
    return eigenvalues[:k].sum()

def edge_sparsity_loss(adj):
    """Plain (signed) sum of all entries of ``adj``; equals the L1 norm when ``adj`` is non-negative."""
    return adj.sum()

def connectivity_penalty(adj, k=3):
    """Trace of ``adj`` raised to the ``k``-th matrix power.

    For an unweighted adjacency matrix this counts closed walks of length ``k``.
    """
    return torch.linalg.matrix_power(adj, k).trace()

def fragmentation_loss(adj, alpha=1.0, beta=1.0, gamma=1.0, k=1):
    """Weighted fragmentation loss components for adjacency matrix ``adj``.

    Args:
        adj: symmetric adjacency matrix.
        alpha: weight on the spectral term (``spectral_loss`` of the Laplacian).
        beta: weight on the edge-sparsity term (``edge_sparsity_loss``).
        gamma: weight on the connectivity term (``connectivity_penalty``, default power 3).
        k: number of smallest non-zero Laplacian eigenvalues in the spectral term.

    Returns:
        ``(alpha * spectral, beta * sparsity, gamma * connectivity)``. The terms are
        returned separately (not summed) so callers can log them individually; with the
        default weights this equals the unweighted components.
    """
    laplacian = laplacian_matrix(adj)
    loss_spectral = spectral_loss(laplacian, k)
    loss_sparsity = edge_sparsity_loss(adj)
    loss_connectivity = connectivity_penalty(adj)

    # Previously alpha/beta/gamma were accepted but silently ignored.
    return alpha * loss_spectral, beta * loss_sparsity, gamma * loss_connectivity

class FeatureAttentionModel(nn.Module):
    """Regression model whose pairwise interactions are chosen by sparse attention.

    Each feature has a learned embedding. Multi-head attention scores between the
    embeddings are turned into per-head selection weights with ``sparsemax``, and the
    diagonal is forced to 1 so every feature always selects itself. The prediction is a
    linear projection of
      * h1: one small MLP per feature applied to that feature alone, and
      * h2: one small MLP per (feature, head) applied to the selection-weighted input vector.

    Args:
        num_features: number of input features.
        embedding_dim: size of the feature embeddings and of every MLP's output.
        num_heads: number of attention heads.
        include_spectral_loss / include_connectivity_loss / include_sparsity_loss:
            add the corresponding ``GraphFragmentationLoss`` term, computed on the
            degree-normalized co-selection matrix ``A``.
        mse_weight: multiplier on the MSE term; applied only when a graph loss is enabled.
        spectral_weight: multiplier on the spectral term.
    """

    def __init__(self, num_features=4, embedding_dim=64, num_heads=8, include_spectral_loss=False, include_connectivity_loss=False, 
                 include_sparsity_loss=False, mse_weight=50, spectral_weight=10):
        super(FeatureAttentionModel, self).__init__()

        self.include_spectral_loss = include_spectral_loss
        self.include_connectivity_loss = include_connectivity_loss
        self.include_sparsity_loss = include_sparsity_loss
        self.mse_weight = mse_weight
        self.spectral_weight = spectral_weight
        if include_spectral_loss or include_connectivity_loss or include_sparsity_loss:
            self.include_graph_losses = True
            self.GFL = GraphFragmentationLoss()
        else:
            self.include_graph_losses = False

        self.num_features = num_features
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        
        self.embedding_layer = nn.Embedding(num_features, embedding_dim)
        self.attention_layer = MultiHeadAttention(embedding_dim, num_heads)

        # One MLP per feature: scalar feature -> embedding_dim
        self.linear1 = nn.ModuleList([
            nn.Sequential(
                nn.Linear(1, 16),
                nn.ReLU(),
                nn.Dropout(0.05),
                nn.Linear(16, 32),
                nn.ReLU(),
                nn.Dropout(0.05),
                nn.Linear(32, embedding_dim),
                nn.ReLU(),
                nn.Dropout(0.05)
            ) for _ in range(num_features)
        ])

        # One MLP per (feature, head): selection-weighted input vector -> embedding_dim
        self.linear2 = nn.ModuleList([
            nn.ModuleList([
                nn.Sequential(
                    nn.Linear(num_features, 16),
                    nn.ReLU(),
                    nn.Dropout(0.05),
                    nn.Linear(16, 32),
                    nn.ReLU(),
                    nn.Dropout(0.05),
                    nn.Linear(32, embedding_dim),
                    nn.ReLU(),
                    nn.Dropout(0.05)
                ) for _ in range(num_heads)
            ]) for _ in range(num_features)
        ])

        self.proj = nn.Linear(num_features*embedding_dim + num_features * num_heads * embedding_dim, 1)

    def _param_device(self):
        """Device of the model's parameters; used instead of a notebook-global device."""
        return self.embedding_layer.weight.device

    def _selection(self):
        """Compute embeddings, identity mask, attention scores and selection weights.

        Returns:
            embs: (num_features, embedding_dim) feature embeddings.
            mask: (num_features, num_features) identity matrix.
            attn: (1, num_heads, num_features, num_features) attention scores.
            selection: same shape as ``attn``; sparsemax of the scores with the
                diagonal forced to 1.
        """
        device = self._param_device()
        embs = self.embedding_layer(torch.arange(self.num_features, device=device))
        mask = torch.eye(self.num_features, device=device)
        attn = self.attention_layer(embs.unsqueeze(0))
        selection = sparsemax(attn) * (1 - mask) + mask
        return embs, mask, attn, selection

    def forward(self, x):
        """Predict from ``x`` of shape (batch, num_features).

        Returns the tuple ``(embs, mask, attn, selection, selectionx, h1, h2, A, y)``;
        ``A`` is the degree-normalized co-selection matrix and ``y`` the prediction.
        """
        embs, mask, attn, selection = self._selection()

        # Co-selection matrix summed over heads, then symmetrically normalized:
        # A <- D^{-1/2} A D^{-1/2}
        A = (selection[0].mT @ selection[0]).sum(dim=0)
        degree = A.sum(dim=1)
        D_inv_sqrt = torch.diag(torch.pow(degree, -0.5))
        A = D_inv_sqrt @ A @ D_inv_sqrt

        # selectionx[b, j, h, k] = selection[0, h, j, k] * x[b, k]
        selectionx = torch.einsum("ijk,bk->bjik", selection[0], x)

        h1 = torch.cat([l(x[:, i].unsqueeze(1)) for i, l in enumerate(self.linear1)], dim=1)

        h2 = torch.cat([l(selectionx[:, i, j]) for i, l2 in enumerate(self.linear2) for j, l in enumerate(l2)], dim=1)

        y = self.proj(torch.cat([h1, h2], dim=1))

        return embs, mask, attn, selection, selectionx, h1, h2, A, y

    def loss(self, output, y):
        """Return named loss terms for a forward ``output`` tuple and targets ``y``."""
        A = output[-2]

        mse_loss = F.mse_loss(output[-1], y)

        losses = {
            "mse": mse_loss
        }

        if self.include_graph_losses:
            conn_loss, spec_loss, sparse_loss = self.GFL(A)

            if self.include_spectral_loss:
                losses["spectral"] = self.spectral_weight * spec_loss

            if self.include_sparsity_loss:
                losses["sparsity"] = sparse_loss

            if self.include_connectivity_loss:
                losses["connectivity"] = conn_loss

            # Re-weight MSE only when graph losses compete with it.
            losses["mse"] = self.mse_weight * losses["mse"]

        return losses

    def get_output(self, output):
        """Extract the prediction from the forward output tuple."""
        return output[-1]

    @staticmethod
    def _plot_heatmap(matrix, labels, title, tag, save, epoch, writer):
        """Draw ``matrix`` as an annotated heatmap.

        If ``save``, also write ``{tag}.png``, log it to ``writer`` (when given) as image
        ``tag``, and close the figure.
        """
        fig, ax = plt.subplots(figsize=(20, 16))
        sns.heatmap(matrix, ax=ax, annot=True, cmap='coolwarm', fmt=".2f",
                    xticklabels=labels, yticklabels=labels)
        ax.set_title(title)

        if save:
            path = f"{tag}.png"
            fig.savefig(path)
            if writer is not None:
                # Round-trip through the PNG so TensorBoard gets the rendered image as (C, H, W)
                image = plt.imread(path)
                writer.add_image(tag, torch.tensor(image).permute(2, 0, 1), epoch)
            plt.close(fig)

    def plot_data(self, save=False, epoch=0, writer=None):
        """Plot the raw co-selection matrix and the head-averaged selection weights.

        With ``save=True`` the heatmaps are written to disk and logged to ``writer``;
        otherwise they are shown inline.
        """
        with torch.no_grad():
            _, _, _, selection = self._selection()
            # Unnormalized co-selection matrix (forward() additionally degree-normalizes it)
            A = (selection[0].mT @ selection[0]).sum(dim=0).cpu()
            selection = selection.mean(dim=(0, 1)).cpu()

        labels = [f"Feature {i}" for i in range(self.num_features)]
        self._plot_heatmap(A, labels, "Attention Matrix", "A", save, epoch, writer)
        self._plot_heatmap(selection, labels, "Selection Heads (Aggregated)", "selection", save, epoch, writer)

        if not save:
            plt.show()

    def add_epoch_info(self, writer, epoch):
        """Per-epoch hook: log the heatmaps to ``writer``."""
        self.plot_data(True, epoch, writer)

    def name(self):
        return "FAM{}{}{}".format("S" if self.include_spectral_loss else "", "E" if self.include_sparsity_loss else "", 
                                  "C" if self.include_connectivity_loss else "")

### KAN-based models

Consider a model with a KAN-based encoder, which takes in features $x_1, \dots, x_D$ and outputs hidden variables $h_1, \dots, h_H$, ideally in a modular group structure that is clusterable (the network can be partitioned into multiple components, each with some inputs and outputs, such that the weight of edges between clusters is low). The final layer then produces output $y_j$ in one of two ways:

1) Dot product (linear layer): 
$y_j = \mathbf{w}^T \mathbf{h} + b = b + \sum_{i=1}^H w_i h_i$
where $w_i, b$ are learnable parameters. NOT SUPPORTED YET.

2) Generalized additive model (1-layer KAN):
$y_j = b + \sum_{i=1}^H \phi_i(h_i; \theta_i)$
where the $\phi_i$ are learnable activation functions (parameterized with B-splines). In this case, we include this as the last layer in the KAN, but do "clusterability regularization" excluding the last layer.

In [None]:
# Variable names for the KAN plots; derived from num_features so they stay in sync with z.
IN_VARS = [f"z{i}" for i in range(num_features)]
OUT_VARS = ["y"]
PLOT_DIR = "."

import kan
class ModularKAN(nn.Module):
    def __init__(self, num_features=4, num_outputs=1, num_layers=2, embedding_dim=4, final_layer="gam",
                 include_spectral_loss=False, include_connectivity_loss=False, include_sparsity_loss=False,
                 mse_weight=50, spectral_weight=10,
                 kan_grid=3, kan_grid_margin=1.0, kan_noise=0.3, kan_base_fun="silu", kan_affine_trainable=True,
                 kan_absolute_deviation=False, kan_flat_entropy=True, kan_update_grid_until=20):
        """
        num_features: Number of input features.
        num_outputs: Number of final outputs (1 if single-output)
        num_layers: Number of layers EXCLUDING the final layer
        embedding_dim: Number of neurons in hidden layers
        final_layer: `gam` if final layer is a normal KAN layer, `linear` if we want a final dot-product.
            NOTE: stored but currently unused -- the KAN built below always ends in a normal KAN layer.
        include_spectral_loss: True if you want to include spectral loss (penalizing the smallest eigenvalues ->
            encouraging the graph to be more bottlenecked / easily clusterable)
        include_connectivity_loss: True if you want to include connectivity penalty based on matrix powers
        include_sparsity_loss: True to include an L1 penalty on the edge scores
        mse_weight: multiplier on the MSE term; applied only when at least one graph loss is enabled
        spectral_weight: multiplier on the spectral term

        kan_grid, kan_grid_margin, kan_noise, kan_base_fun, kan_affine_trainable, kan_absolute_deviation:
            forwarded to `kan.KAN` as grid, grid_margin, noise_scale, base_fun, affine_trainable,
            absolute_deviation
        kan_flat_entropy: forwarded as `flat_entropy` to `self.kan.reg`
        kan_update_grid_until: update grid every epoch until this number of epochs
        """
        super(ModularKAN, self).__init__()

        self.include_spectral_loss = include_spectral_loss
        self.include_connectivity_loss = include_connectivity_loss
        self.include_sparsity_loss = include_sparsity_loss
        self.mse_weight = mse_weight
        self.spectral_weight = spectral_weight
        if include_spectral_loss or include_connectivity_loss or include_sparsity_loss:
            self.include_graph_losses = True
            self.GFL = GraphFragmentationLoss()
        else:
            self.include_graph_losses = False

        self.num_features = num_features
        self.num_outputs = num_outputs
        self.embedding_dim = embedding_dim
        self.final_layer = final_layer
        self.kan_flat_entropy = kan_flat_entropy
        self.update_grid_until = kan_update_grid_until

        # Size of each layer: [inputs, hidden nodes, ..., outputs]
        self.layer_sizes = [num_features] + [embedding_dim] * (num_layers) + [num_outputs]

        # For the adjacency matrix, we have a node for each neuron (across all layers).
        # Precompute which index each layer's neurons start from.
        # For example, if layer 0 (input) neurons are 0-3, layer 1's neurons are 4-11, and layer 2 (output) neurons is 12, 
        # self.layer_starts would contain: [0, 4, 12, 13] - 13 is the end of the last layer (exclusive)
        self.layer_starts = [0]
        for layer_size in self.layer_sizes:
            self.layer_starts.append(self.layer_starts[-1] + layer_size)
        self.total_nodes = sum(self.layer_sizes)
        assert self.total_nodes == self.layer_starts[-1]

        # Initialize the KAN model
        # TODO: Dot-product final layer is not supported yet
        self.kan = kan.KAN(width=self.layer_sizes, grid=kan_grid, k=3, seed=torch.initial_seed(), device=device,
                           noise_scale=kan_noise, base_fun=kan_base_fun, affine_trainable=kan_affine_trainable, grid_eps=1.0, 
                           grid_margin=kan_grid_margin, absolute_deviation=kan_absolute_deviation)

    def forward(self, x):
        """Run the KAN and build the neuron-level adjacency matrix from its edge scores.

        Returns ``(A, output)``. ``A`` is a symmetric (total_nodes, total_nodes) matrix
        whose only non-zero blocks link layer l's neurons to layer l+1's neurons, holding
        the KAN edge scores. ``output`` is the KAN prediction.
        Side effects: caches ``x``, ``A`` and the KAN connection cost on ``self``.
        """
        self.x = x  # Cache input (used by update_grid in add_epoch_info)
        output = self.kan(x)

        # KAN sparsity losses
        # NOTE: the lamb values passed are completely unused, as we direclty obtain the individual loss components and weight them later.
        # For default weights see https://github.com/KindXiaoming/pykan/blob/master/kan/MultKAN.py#L1411
        kan_l1_loss, kan_entropy_loss, kan_coef_loss, kan_coefdiff_loss, kan_coefdiff2_loss, conn_cost = self.kan.reg(
            reg_metric='edge_backward', lamb_l1=1., lamb_entropy=1., lamb_coef=1., lamb_coefdiff=1.,
            return_indiv=True, flat_entropy=self.kan_flat_entropy)

        # Allocate A on the input's device so the graph losses live on the same device as the MSE.
        A = torch.zeros((self.total_nodes, self.total_nodes), device=x.device)
        for i, layer_edges in enumerate(self.kan.edge_scores):
            # self.kan.edge_scores is a list. Each element is of shape [out_dim, in_dim]
            input_nodes = slice(self.layer_starts[i], self.layer_starts[i+1])
            output_nodes = slice(self.layer_starts[i+1], self.layer_starts[i+2])
            A[input_nodes, output_nodes] = layer_edges.T
            A[output_nodes, input_nodes] = layer_edges
        self.A = A
        self.conn_cost = conn_cost
        return A, output

    def loss(self, output, y):
        """Return named loss terms for a forward ``(A, prediction)`` output and targets ``y``."""
        A = output[-2]
        
        mse_loss = F.mse_loss(output[-1], y)

        losses = {
            "mse": mse_loss
        }

        if self.include_graph_losses:
            conn_loss, spec_loss, sparse_loss = self.GFL(A)

            if self.include_spectral_loss:
                losses["spectral"] = self.spectral_weight * spec_loss

            if self.include_sparsity_loss:
                losses["sparsity"] = sparse_loss

            if self.include_connectivity_loss:
                losses["connectivity"] = conn_loss

            # Re-weight MSE only when graph losses compete with it.
            losses["mse"] = self.mse_weight * losses["mse"]

        return losses

    def get_output(self, output):
        """Extract the prediction from the forward output tuple."""
        return output[-1]

    def plot_data(self, save=False, epoch=0, writer=None):
        """Plot the KAN splines and the adjacency matrix from the most recent forward().

        With ``save=True`` the figures are written to disk and the adjacency heatmap is
        logged to ``writer`` (if given) as image "A"; otherwise they are shown inline.
        """
        # Plot the KAN
        # Produce edge/node importance scores
        self.kan.attribute()
        self.kan.node_attribute()

        # Plot the unpruned model
        self.kan.plot(folder=os.path.join(PLOT_DIR, "splines"), in_vars=IN_VARS, out_vars=OUT_VARS, scale=5, varscale=0.13)
        if save:
            plt.savefig(os.path.join(PLOT_DIR, f"epoch{epoch}_kan_plot.png"))
            plt.close()

        # Plot the adjacency matrix. It spans every neuron of every layer, so label each
        # node by (layer, index within layer) rather than by input feature.
        A = self.A.cpu().detach()
        node_labels = [f"L{layer} n{idx}" for layer, size in enumerate(self.layer_sizes) for idx in range(size)]

        fig, ax = plt.subplots(figsize=(20, 16))
        sns.heatmap(A, ax=ax, annot=True, cmap='coolwarm', fmt=".2f",
                    xticklabels=node_labels, yticklabels=node_labels)
        ax.set_title("Adjacency Matrix")

        if save:
            fig.savefig("A.png")
            if writer is not None:
                # Load the heatmap image and convert it to a (C, H, W) tensor
                image = plt.imread("A.png")
                writer.add_image("A", torch.tensor(image).permute(2, 0, 1), epoch)
            plt.close(fig)
        else:
            plt.show()

    def add_epoch_info(self, writer, epoch):
        """
        joshuafan: I'm overloading this method so that it also performs KAN
        operations that should be run infrequently, such as (1) updating the spline
        grids, and (2) swapping to reduce connection cost (encourage modularity)
        """
        # Update grid from the input cached by the most recent forward() call.
        # NOTE(review): that is whatever batch forward() last saw -- confirm it is the intended sample.
        if epoch < self.update_grid_until:
            with torch.no_grad():
                self.kan.update_grid(self.x)

        # Swap to reduce connection cost
        self.kan.auto_swap()

        # Create plots
        self.plot_data(True, epoch, writer)

    def name(self):
        return "KAN{}{}{}".format("S" if self.include_spectral_loss else "", "E" if self.include_sparsity_loss else "", 
                                  "C" if self.include_connectivity_loss else "")

## Model Training

Next we train the different models we're working with and plot their outputs.

In [None]:
def _plot_split_curves(curves, quantity):
    """Plot train/validation/test curves of ``quantity`` side by side, skipping epoch 0.

    ``curves`` maps "train"/"valid"/"test" to per-epoch values. Epoch 0 is skipped
    because its value is very high and would squash the y-axis.
    """
    fig, axs = plt.subplots(1, 3, figsize=(20, 4))
    for ax, (split, label) in zip(axs, [("train", "Training"), ("valid", "Validation"), ("test", "Test")]):
        values = curves[split]
        ax.plot(range(1, len(values)), values[1:], linewidth=3)
        ax.set_title(f"{label} {quantity}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(quantity)


def plot_model_results(model, losses, metrics, data_generator):
    """Visualize a trained model on freshly sampled data, plus its training curves.

    Args:
        model: trained model exposing ``get_output``, ``name`` and ``plot_data``.
        losses, metrics: dicts mapping "train"/"valid"/"test" to per-epoch total loss / MSE.
        data_generator: ground-truth generator used to label the fresh sample.
    """
    # Generate new data
    xy = torch.rand((15000, 2)).to(device)

    z = xy_to_z(xy[:,0].unsqueeze(1), xy[:,1].unsqueeze(1))

    labels = data_generator.calculate(z)

    # Predict in eval mode so dropout is off, then restore the caller's train/eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            lhat = model.get_output(model(z))
    finally:
        model.train(was_training)

    name = model.name()
    plot_data(xy, labels, title="True Data")
    plot_data(xy, lhat, title="{} Predictions".format(name))
    # Colour = log absolute error
    plot_data(xy, torch.log(torch.abs(labels.unsqueeze(1) - lhat)), title="{} Predictions vs. True Data".format(name))

    _plot_split_curves(losses, "Loss")
    _plot_split_curves(metrics, "MSE")

    model.plot_data()

def model_summary(model):
    """Print the number of trainable (``requires_grad``) parameters, labelled with ``model.name()``."""
    n_trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            n_trainable += param.numel()
    print(f"Trainable parameters for {model.name()}: {n_trainable}")

def _run_epoch(model, loader, optimizer=None):
    """Make one pass over ``loader``, collecting per-component losses and the output MSE.

    If ``optimizer`` is given, the model is put in training mode and updated after every
    batch. Otherwise it is put in evaluation mode and run with gradients disabled.

    Returns:
        (mean_losses, total_loss, mse):

        - ``mean_losses`` maps each key returned by ``model.loss`` to its mean over batches.
        - ``total_loss`` is the sum of those means, converted via ``.cpu().numpy()``.
        - ``mse`` is the mean squared error between the concatenated
          ``model.get_output(...)`` values and the labels.
    """
    is_training = optimizer is not None
    # Set the mode on every pass: evaluation switches the model to eval mode, so each
    # training pass must switch it back.
    model.train(is_training)

    outputs, labels = [], []
    batch_losses = {}
    with torch.set_grad_enabled(is_training):
        for x, y in loader:
            output = model(x)
            losses = model.loss(output, y)

            loss = 0
            for key, component in losses.items():
                loss += component
                batch_losses.setdefault(key, []).append(component.detach())

            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            outputs.append(model.get_output(output).detach())
            labels.append(y.detach())

    mean_losses = {k: torch.stack(v, dim=0).mean() for k, v in batch_losses.items()}

    total_loss = 0
    for value in mean_losses.values():
        total_loss += value.cpu().numpy()

    mse = mean_squared_error(torch.cat(outputs, dim=0).flatten(), torch.cat(labels, dim=0).flatten())
    return mean_losses, total_loss, mse


def train(model, num_epochs, train_loader, valid_loader, test_loader, data_generator, run=None):
    """Train ``model`` with AdamW and keep the weights with the lowest validation MSE.

    Each epoch does one training pass, then evaluates on the validation and test loaders.
    Per-epoch MSE and loss components are logged to TensorBoard under ``logs/<name>``.
    Whenever validation MSE improves, the state dict is saved to ``models/<name>.pt``.
    After training, the best weights are loaded back and the results are plotted.

    Args:
        model: model exposing ``name``, ``loss``, ``get_output`` and ``add_epoch_info``.
        num_epochs: number of training epochs.
        train_loader, valid_loader, test_loader: iterables of ``(x, y)`` batches.
        data_generator: passed to ``plot_model_results`` to produce fresh labels.
        run: optional suffix appended to the model name for log/checkpoint paths.

    Returns:
        The model, with the best-validation-MSE weights loaded (if any epoch ran).
    """
    optimizer = torch.optim.AdamW(model.parameters())

    name = model.name()
    if run is not None:
        name = "{}_{}".format(name, run)

    # Create a SummaryWriter instance
    writer = SummaryWriter(log_dir='logs/{}'.format(name))

    print("Training {}".format(name))
    model_summary(model)

    best_state = None
    best_mse = None

    os.makedirs("models", exist_ok=True)

    loss_epochs = {"train": [], "valid": [], "test": []}
    metric_epochs = {"train": [], "valid": [], "test": []}

    try:
        for epoch in tqdm(range(num_epochs), leave=True):
            train_losses, train_loss, train_mse = _run_epoch(model, train_loader, optimizer)
            valid_losses, valid_loss, valid_mse = _run_epoch(model, valid_loader)

            if best_mse is None or valid_mse < best_mse:
                best_mse = valid_mse
                torch.save(model.state_dict(), "models/{}.pt".format(name))
                # state_dict() returns references to the live tensors, which later optimizer
                # steps overwrite in place; clone them so the best weights really survive.
                best_state = {k: v.detach().clone() if torch.is_tensor(v) else v
                              for k, v in model.state_dict().items()}

            test_losses, test_loss, test_mse = _run_epoch(model, test_loader)

            for split, split_loss, split_mse in (("train", train_loss, train_mse),
                                                 ("valid", valid_loss, valid_mse),
                                                 ("test", test_loss, test_mse)):
                loss_epochs[split].append(split_loss)
                metric_epochs[split].append(split_mse.cpu().detach().numpy())

            writer.add_scalar("MSE/train", train_mse, epoch)
            writer.add_scalar("MSE/valid", valid_mse, epoch)
            writer.add_scalar("MSE/test", test_mse, epoch)

            for k, loss in train_losses.items():
                writer.add_scalar("Loss/{}/train".format(k), loss, epoch)

            for k, loss in valid_losses.items():
                writer.add_scalar("Loss/{}/valid".format(k), loss, epoch)

            for k, loss in test_losses.items():
                writer.add_scalar("Loss/{}/test".format(k), loss, epoch)

            model.add_epoch_info(writer, epoch)

            if epoch % 10 == 0 or epoch == num_epochs - 1:
                print("{} TRAINING MSE: {:.3f} VALID MSE: {:.3f} TEST MSE: {:.3f}".format(epoch, train_mse, valid_mse, test_mse))
                print("\tTRAINING LOSSES: {}".format(" ".join(["{}: {:.3f}".format(k, l) for k, l in train_losses.items()])))
                print("\tVALIDATION LOSSES: {}".format(" ".join(["{}: {:.3f}".format(k, l) for k, l in valid_losses.items()])))
                print("\tTEST LOSSES: {}".format(" ".join(["{}: {:.3f}".format(k, l) for k, l in test_losses.items()])))
    finally:
        # Flush pending events and release the log file even if training is interrupted.
        writer.close()

    # With num_epochs == 0 there is no "best" state; keep the current weights.
    if best_state is not None:
        model.load_state_dict(best_state)

    print("Plotting Results")

    plot_model_results(model, loss_epochs, metric_epochs, data_generator)

    return model

In [None]:
# Number of epochs passed to every `train(...)` call in the cells below.
num_epochs: int = 100

#### Base Model

Base Model run

In [None]:
# Build the base model over all `num_features` inputs and train it.
model = BaseModel(num_features)
model = model.to(device)
train(model=model, num_epochs=num_epochs, train_loader=train_loader, valid_loader=valid_loader,
      test_loader=test_loader, data_generator=data_generator)

#### "Cheating" Additive Model

Additive Model run

In [None]:
# Build the additive model over all `num_features` inputs and train it.
model = AdditiveModel(num_features)
model = model.to(device)
train(model=model, num_epochs=num_epochs, train_loader=train_loader, valid_loader=valid_loader,
      test_loader=test_loader, data_generator=data_generator)

#### KAN-based model

In [None]:
embedding_dim = 8

# Two-layer KAN with the spectral, connectivity and sparsity loss flags all enabled.
model = ModularKAN(
    num_features=num_features,
    num_outputs=1,
    num_layers=2,
    embedding_dim=embedding_dim,
    include_spectral_loss=True,
    include_connectivity_loss=True,
    include_sparsity_loss=True,
)
model = model.to(device)

train(model, num_epochs, train_loader, valid_loader, test_loader, data_generator, run=0)

#### Feature Attention Model

Feature Attention Model with just MSE loss

In [None]:
embedding_dim, num_heads = 32, 6

# Only the required constructor arguments are passed; everything else uses the model's defaults.
model = FeatureAttentionModel(num_features, embedding_dim, num_heads)
model = model.to(device)

train(model, num_epochs, train_loader, valid_loader, test_loader, data_generator, run=0)

Feature Attention Model with graph-based losses

In [None]:
embedding_dim = 32
num_heads = 6

# NOTE(review): the positional arguments True, False, True, 500 are forwarded to
# FeatureAttentionModel's constructor, whose signature is not visible here. Per the
# markdown caption they presumably switch on the graph-based losses — confirm, and
# consider passing them as keyword arguments so the intent is readable.
model = FeatureAttentionModel(num_features, embedding_dim, num_heads, True, False, True, 500).to(device)

# No `run` suffix is passed here (unlike the MSE-only cell, which passes 0), so the
# TensorBoard log dir and checkpoint path in `train` use model.name() alone.
train(model, num_epochs, train_loader, valid_loader, test_loader, data_generator)