In [17]:
import os
import torch
import numpy as np
from kan import *

In [18]:
def process_torchdataset(dataset, number):
  imgs, labels = zip(*dataset)
  X = torch.cat(imgs, axis=0).flatten(1)
  Y = torch.from_numpy(np.vstack(labels)[:, 1:3]) if number==1 else torch.from_numpy(np.vstack(labels)[:, -2:])
  return X, Y.long()

def create_dataset_from_files(data_dir, data_filenames, max_n=1000):
    dataset = {}
    for name in ["train", "test"]:
      X, Y = [], []
      for filename, number in zip(data_filenames,  [1,9]):
          data = torch.load(os.path.join(data_dir, "{}_{}".format(name, filename)))[:max_n]
          x, y = process_torchdataset(data, number)
          X.extend(x)
          Y.extend(y)
      dataset[f'{name}_input'] = torch.stack(X, axis=0)
      dataset[f'{name}_label'] = torch.stack(Y, axis=0)
    return dataset

In [19]:
data_dir = "/home/carolina/Anansi/MA/KG/MNIST/data/MNIST/excluded_1_9"
data_filenames = ["GAP_1_normalized.pt", "GAP_9_normalized.pt"]
dataset = create_dataset_from_files(data_dir, data_filenames, max_n=50)

In [20]:
model = KAN(width=[784, 128, 64, 32, 16, 2], grid=10, k=3, seed=0, symbolic_enabled=True,
            bias_trainable=False, sp_trainable=False, sb_trainable=False)
model

KAN(
  (biases): ModuleList(
    (0): Linear(in_features=128, out_features=1, bias=False)
    (1): Linear(in_features=64, out_features=1, bias=False)
    (2): Linear(in_features=32, out_features=1, bias=False)
    (3): Linear(in_features=16, out_features=1, bias=False)
    (4): Linear(in_features=2, out_features=1, bias=False)
  )
  (act_fun): ModuleList(
    (0-4): 5 x KANLayer(
      (base_fun): SiLU()
    )
  )
  (base_fun): SiLU()
  (symbolic_fun): ModuleList(
    (0-4): 5 x Symbolic_KANLayer()
  )
)

In [21]:
total_params = 0
for name, param in model.named_parameters():
    print(name, param.shape)
    total_params += param.flatten().shape[0]
print(total_params)

biases.0.weight torch.Size([1, 128])
biases.1.weight torch.Size([1, 64])
biases.2.weight torch.Size([1, 32])
biases.3.weight torch.Size([1, 16])
biases.4.weight torch.Size([1, 2])
act_fun.0.grid torch.Size([100352, 11])
act_fun.0.coef torch.Size([100352, 13])
act_fun.0.scale_base torch.Size([100352])
act_fun.0.scale_sp torch.Size([100352])
act_fun.0.mask torch.Size([100352])
act_fun.1.grid torch.Size([8192, 11])
act_fun.1.coef torch.Size([8192, 13])
act_fun.1.scale_base torch.Size([8192])
act_fun.1.scale_sp torch.Size([8192])
act_fun.1.mask torch.Size([8192])
act_fun.2.grid torch.Size([2048, 11])
act_fun.2.coef torch.Size([2048, 13])
act_fun.2.scale_base torch.Size([2048])
act_fun.2.scale_sp torch.Size([2048])
act_fun.2.mask torch.Size([2048])
act_fun.3.grid torch.Size([512, 11])
act_fun.3.coef torch.Size([512, 13])
act_fun.3.scale_base torch.Size([512])
act_fun.3.scale_sp torch.Size([512])
act_fun.3.mask torch.Size([512])
act_fun.4.grid torch.Size([32, 11])
act_fun.4.coef torch.Size([

In [22]:
def train_acc():
    return torch.mean((torch.round(model(dataset['train_input'])) == dataset['train_label']).float())

def test_acc():
    return torch.mean((torch.round(model(dataset['test_input'])) == dataset['test_label']).float())

results = model.train(dataset, opt="LBFGS", steps=1, batch=-1, loss_fn=None, 
                      lamb=0.01, lamb_entropy=10., update_grid=False,
                      grid_update_num=10, metrics=(train_acc, test_acc));
#results['train_acc'][-1], results['test_acc'][-1]

''' dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., 
lamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, 
batch=-1, small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, 
save_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder='./video', device='cpu')


steps = training_steps = epochs * num_batches = epochs * rpund_up(train_dataset.size()/num_batches)
'''

train loss: 5.71e-01 | test loss: 5.24e-01 | reg: 3.01e+03 : 100%|███| 1/1 [12:28<00:00, 748.48s/it]


' dataset, opt="LBFGS", steps=100, log=1, lamb=0., lamb_l1=1., lamb_entropy=2., lamb_coef=0., \nlamb_coefdiff=0., update_grid=True, grid_update_num=10, loss_fn=None, lr=1., stop_grid_update_step=50, \nbatch=-1, small_mag_threshold=1e-16, small_reg_factor=1., metrics=None, sglr_avoid=False, \nsave_fig=False, in_vars=None, out_vars=None, beta=3, save_fig_freq=1, img_folder=\'./video\', device=\'cpu\')\n\n\nsteps = training_steps = epochs * num_batches = epochs * rpund_up(train_dataset.size()/num_batches)\n'

In [23]:
results['train_acc'][-1], results['test_acc'][-1]

(0.5550000071525574, 0.5299999713897705)

In [42]:
def plot(model, folder="./figures", beta=3, mask=False, mode="supervised", scale=0.5, tick=False, sample=False, in_vars=None, out_vars=None, title=None):
    '''
    plot KAN
    
    Args:
    -----
        folder : str
            the folder to store pngs
        beta : float
            positive number. control the transparency of each activation. transparency = tanh(beta*l1).
        mask : bool
            If True, plot with mask (need to run prune() first to obtain mask). If False (by default), plot all activation functions.
        mode : bool
            "supervised" or "unsupervised". If "supervised", l1 is measured by absolution value (not subtracting mean); if "unsupervised", l1 is measured by standard deviation (subtracting mean).
        scale : float
            control the size of the diagram
        in_vars: None or list of str
            the name(s) of input variables
        out_vars: None or list of str
            the name(s) of output variables
        title: None or str
            title
        
    Returns:
    --------
        Figure
        
    Example
    -------
    >>> # see more interactive examples in demos
    >>> model = KAN(width=[2,3,1], grid=3, k=3, noise_scale=1.0)
    >>> x = torch.normal(0,1,size=(100,2))
    >>> model(x) # do a forward pass to obtain model.acts
    >>> model.plot()
    '''
    if not os.path.exists(folder):
        os.makedirs(folder)
    # matplotlib.use('Agg')
    depth = len(model.width) - 1
    for l in range(depth-1, depth):
        w_large = 2.0
        for i in range(model.width[l]):
            for j in range(model.width[l + 1]):
                rank = torch.argsort(model.acts[l][:, i])
                fig, ax = plt.subplots(figsize=(w_large, w_large))

                num = rank.shape[0]

                symbol_mask = model.symbolic_fun[l].mask[j][i]
                numerical_mask = model.act_fun[l].mask.reshape(model.width[l + 1], model.width[l])[j][i]
                if symbol_mask > 0. and numerical_mask > 0.:
                    color = 'purple'
                    alpha_mask = 1
                if symbol_mask > 0. and numerical_mask == 0.:
                    color = "red"
                    alpha_mask = 1
                if symbol_mask == 0. and numerical_mask > 0.:
                    color = "black"
                    alpha_mask = 1
                if symbol_mask == 0. and numerical_mask == 0.:
                    color = "white"
                    alpha_mask = 0

                if tick == True:
                    ax.tick_params(axis="y", direction="in", pad=-22, labelsize=50)
                    ax.tick_params(axis="x", direction="in", pad=-15, labelsize=50)
                    x_min, x_max, y_min, y_max = model.get_range(l, i, j, verbose=False)
                    plt.xticks([x_min, x_max], ['%2.f' % x_min, '%2.f' % x_max])
                    plt.yticks([y_min, y_max], ['%2.f' % y_min, '%2.f' % y_max])
                else:
                    plt.xticks([])
                    plt.yticks([])
                if alpha_mask == 1:
                    plt.gca().patch.set_edgecolor('black')
                else:
                    plt.gca().patch.set_edgecolor('white')
                plt.gca().patch.set_linewidth(1.5)
                # plt.axis('off')

                plt.plot(model.acts[l][:, i][rank].cpu().detach().numpy(), model.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, lw=5)
                if sample == True:
                    plt.scatter(model.acts[l][:, i][rank].cpu().detach().numpy(), model.spline_postacts[l][:, j, i][rank].cpu().detach().numpy(), color=color, s=400 * scale ** 2)
                plt.gca().spines[:].set_color(color)

                lock_id = model.act_fun[l].lock_id[j * model.width[l] + i].long().item()
                if lock_id > 0:
                    im = plt.imread(f'{folder}/lock.png')
                    newax = fig.add_axes([0.15, 0.7, 0.15, 0.15])
                    plt.text(500, 400, lock_id, fontsize=15)
                    newax.imshow(im)
                    newax.axis('off')

                plt.savefig(f'{folder}/sp_{l}_{i}_{j}.png', bbox_inches="tight", dpi=400)
                plt.close()

    def score2alpha(score):
        return np.tanh(beta * score)

    if mode == "supervised":
        alpha = [score2alpha(score.cpu().detach().numpy()) for score in model.acts_scale]
    elif mode == "unsupervised":
        alpha = [score2alpha(score.cpu().detach().numpy()) for score in model.acts_scale_std]

    # draw skeleton
    width = np.array(model.width)
    A = 1
    y0 = 0.4  # 0.4

    # plt.figure(figsize=(5,5*(neuron_depth-1)*y0))
    neuron_depth = len(width)
    min_spacing = A / np.maximum(np.max(width), 5)

    max_neuron = np.max(width)
    max_num_weights = np.max(width[:-1] * width[1:])
    y1 = 0.4 / np.maximum(max_num_weights, 3)

    fig, ax = plt.subplots(figsize=(10 * scale, 10 * scale * (neuron_depth - 1) * y0))
    # fig, ax = plt.subplots(figsize=(5,5*(neuron_depth-1)*y0))

    # plot scatters and lines
    for l in range(neuron_depth):
        n = width[l]
        spacing = A / n
        for i in range(n):
            plt.scatter(1 / (2 * n) + i / n, l * y0, s=min_spacing ** 2 * 10000 * scale ** 2, color='black')

            if l < neuron_depth - 1:
                # plot connections
                n_next = width[l + 1]
                N = n * n_next
                for j in range(n_next):
                    id_ = i * n_next + j

                    symbol_mask = model.symbolic_fun[l].mask[j][i]
                    numerical_mask = model.act_fun[l].mask.reshape(model.width[l + 1], model.width[l])[j][i]
                    if symbol_mask == 1. and numerical_mask == 1.:
                        color = 'purple'
                        alpha_mask = 1.
                    if symbol_mask == 1. and numerical_mask == 0.:
                        color = "red"
                        alpha_mask = 1.
                    if symbol_mask == 0. and numerical_mask == 1.:
                        color = "black"
                        alpha_mask = 1.
                    if symbol_mask == 0. and numerical_mask == 0.:
                        color = "white"
                        alpha_mask = 0.
                    if mask == True:
                        plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1], color=color, lw=2 * scale, alpha=alpha[l][j][i] * model.mask[l][i].item() * model.mask[l + 1][j].item())
                        plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale, alpha=alpha[l][j][i] * model.mask[l][i].item() * model.mask[l + 1][j].item())
                    else:
                        plt.plot([1 / (2 * n) + i / n, 1 / (2 * N) + id_ / N], [l * y0, (l + 1 / 2) * y0 - y1], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask)
                        plt.plot([1 / (2 * N) + id_ / N, 1 / (2 * n_next) + j / n_next], [(l + 1 / 2) * y0 + y1, (l + 1) * y0], color=color, lw=2 * scale, alpha=alpha[l][j][i] * alpha_mask)

        plt.xlim(0, 1)
        plt.ylim(-0.1 * y0, (neuron_depth - 1 + 0.1) * y0)

    # -- Transformation functions
    DC_to_FC = ax.transData.transform
    FC_to_NFC = fig.transFigure.inverted().transform
    # -- Take data coordinates and transform them to normalized figure coordinates
    DC_to_NFC = lambda x: FC_to_NFC(DC_to_FC(x))

    plt.axis('off')

    # plot splines
    for l in range(neuron_depth - 1):
        n = width[l]
        for i in range(n):
            n_next = width[l + 1]
            N = n * n_next
            for j in range(n_next):
                id_ = i * n_next + j
                im = plt.imread(f'{folder}/sp_{l}_{i}_{j}.png')
                left = DC_to_NFC([1 / (2 * N) + id_ / N - y1, 0])[0]
                right = DC_to_NFC([1 / (2 * N) + id_ / N + y1, 0])[0]
                bottom = DC_to_NFC([0, (l + 1 / 2) * y0 - y1])[1]
                up = DC_to_NFC([0, (l + 1 / 2) * y0 + y1])[1]
                newax = fig.add_axes([left, bottom, right - left, up - bottom])
                # newax = fig.add_axes([1/(2*N)+id_/N-y1, (l+1/2)*y0-y1, y1, y1], anchor='NE')
                if mask == False:
                    newax.imshow(im, alpha=alpha[l][j][i])
                else:
                    ### make sure to run model.prune() first to compute mask ###
                    newax.imshow(im, alpha=alpha[l][j][i] * model.mask[l][i].item() * model.mask[l + 1][j].item())
                newax.axis('off')

    if in_vars != None:
        n = model.width[0]
        for i in range(n):
            plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), -0.1, in_vars[i], fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')

    if out_vars != None:
        n = model.width[-1]
        for i in range(n):
            plt.gcf().get_axes()[0].text(1 / (2 * (n)) + i / (n), y0 * (len(model.width) - 1) + 0.1, out_vars[i], fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')

    if title != None:
        plt.gcf().get_axes()[0].text(0.5, y0 * (len(model.width) - 1) + 0.2, title, fontsize=40 * scale, horizontalalignment='center', verticalalignment='center')

In [43]:
#model(dataset["train_input"][:10])
# the last layer plottings work. This is because, as the data is so large, as well as the model,
# I was probably not training enough, and what happens is that the first layers produce the same (x,y) per i,j
# for every image in the training dataser (plot=1 point)
# x = pre-spline activations/input of each layer (batch, in_dim); y= post_spline activations (batch, in_dim, out_dim)
#model.plot()
plot(model)

FileNotFoundError: [Errno 2] No such file or directory: './figures/sp_0_0_0.png'

Error in callback <function flush_figures at 0x7f8af8e8bb00> (for post_execute), with arguments args (),kwargs {}:


KeyboardInterrupt: 

In [26]:
l=0
for i in range(784):
    for j in range(2):
        symbol_mask = model.symbolic_fun[l].mask[j][i]
        numerical_mask = model.act_fun[l].mask.reshape(model.width[l + 1], model.width[l])[j][i]
        if symbol_mask != 0:
            print(symbol_mask)
        if numerical_mask != 1:
            print(numerical_mask)

In [41]:
l=4
i=0
j=0
rank = torch.argsort(model.acts[l][:, i])
model.acts[l][:, i][rank] #x in plot sp_l_i_j; each value of x corresponds to the input (dim=i) values for layer l

tensor([0.9566, 1.0561, 1.5314, 1.7213, 1.9203, 2.0120, 2.8071, 2.8179, 2.8220,
        2.9448, 2.9932, 3.0043, 3.0427, 3.0429, 3.1859, 3.1922, 3.2206, 3.2578,
        3.2589, 3.2703, 3.3166, 3.3275, 3.3315, 3.3497, 3.3661, 3.4060, 3.4147,
        3.4387, 3.4419, 3.4491, 3.4658, 3.4699, 3.4712, 3.4725, 3.4814, 3.4859,
        3.4882, 3.4925, 3.4931, 3.4984, 3.4996, 3.5166, 3.5174, 3.5204, 3.5266,
        3.5276, 3.5372, 3.5391, 3.5476, 3.5545, 3.5588, 3.5639, 3.5644, 3.5659,
        3.5669, 3.5683, 3.5727, 3.5736, 3.5746, 3.5755, 3.5770, 3.5772, 3.5772,
        3.5774, 3.5793, 3.5793, 3.5793, 3.5797, 3.5802, 3.5802, 3.5807, 3.5807,
        3.5808, 3.5808, 3.5809, 3.5810, 3.5810, 3.5811, 3.5811, 3.5812, 3.5812,
        3.5812, 3.5813, 3.5814, 3.5814, 3.5814, 3.5816, 3.5822, 3.5825, 3.5832,
        3.5832, 3.5833, 3.5849, 3.5853, 3.5855, 3.5859, 3.5866, 3.5868, 3.5871,
        3.5931], grad_fn=<IndexBackward0>)

In [38]:
model.spline_postacts[l][:, j, i][rank] #y; each value of y corresponds to a batch image/sample's
#pos_spline values (batch, in_neurons, out_neurons) of layer l

tensor([0.3086, 0.3708, 0.3075, 0.3551, 0.4073, 0.4316, 0.6438, 0.6467, 0.6478,
        0.6804, 0.6932, 0.6961, 0.7063, 0.7063, 0.7440, 0.7457, 0.7532, 0.7629,
        0.7632, 0.7662, 0.7784, 0.7812, 0.7823, 0.7870, 0.7913, 0.8017, 0.8040,
        0.8103, 0.8111, 0.8130, 0.8173, 0.8184, 0.8187, 0.8191, 0.8214, 0.8226,
        0.8232, 0.8243, 0.8244, 0.8258, 0.8261, 0.8306, 0.8308, 0.8316, 0.8332,
        0.8334, 0.8359, 0.8364, 0.8386, 0.8404, 0.8415, 0.8428, 0.8430, 0.8434,
        0.8436, 0.8440, 0.8451, 0.8454, 0.8456, 0.8459, 0.8462, 0.8463, 0.8463,
        0.8464, 0.8468, 0.8468, 0.8469, 0.8469, 0.8471, 0.8471, 0.8472, 0.8472,
        0.8472, 0.8472, 0.8473, 0.8473, 0.8473, 0.8473, 0.8473, 0.8473, 0.8473,
        0.8473, 0.8474, 0.8474, 0.8474, 0.8474, 0.8474, 0.8476, 0.8477, 0.8479,
        0.8479, 0.8479, 0.8483, 0.8484, 0.8485, 0.8486, 0.8488, 0.8488, 0.8489,
        0.8504])

In [24]:
l, i, j = 0, 0, 0
len(model.acts[l][:, :]), model.acts[l][0, :].shape, model.acts[l][0, :]
#model.spline_postacts[l][:, j, i]

(100,
 torch.Size([784]),
 tensor([-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
         -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
  

In [12]:
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','tan','abs']
model.auto_symbolic(lib=lib)
formula = model.symbolic_formula()[0][0]
formula

fixing (0,0,0) with abs, r2=0.0
fixing (0,0,1) with abs, r2=0.0
fixing (0,1,0) with abs, r2=0.0
fixing (0,1,1) with abs, r2=0.0
fixing (0,2,0) with abs, r2=0.0
fixing (0,2,1) with abs, r2=0.0
fixing (0,3,0) with abs, r2=0.0
fixing (0,3,1) with abs, r2=0.0
fixing (0,4,0) with abs, r2=0.0
fixing (0,4,1) with abs, r2=0.0
fixing (0,5,0) with exp, r2=6.860967089600267e-12
fixing (0,5,1) with abs, r2=0.0
fixing (0,6,0) with abs, r2=0.0
fixing (0,6,1) with abs, r2=0.0
fixing (0,7,0) with abs, r2=0.0
fixing (0,7,1) with abs, r2=0.0
fixing (0,8,0) with abs, r2=0.0
fixing (0,8,1) with abs, r2=0.0
fixing (0,9,0) with abs, r2=0.0
fixing (0,9,1) with abs, r2=0.0
fixing (0,10,0) with abs, r2=0.0
fixing (0,10,1) with abs, r2=0.0
fixing (0,11,0) with abs, r2=0.0
fixing (0,11,1) with abs, r2=0.0
fixing (0,12,0) with abs, r2=0.0
fixing (0,12,1) with abs, r2=0.0
fixing (0,13,0) with abs, r2=0.0
fixing (0,13,1) with abs, r2=0.0
fixing (0,14,0) with abs, r2=0.0
fixing (0,14,1) with abs, r2=0.0
fixing (0,15

KeyboardInterrupt: 

In [107]:
del model

How do they increase number of grid points during training?
*   at each n step, they just reinitialize the model with this method: initialize_from_another_model?