In [3]:
import torch
import math

In [31]:
def calculate_kan_bottlneck(dmodel, expert_size, kan_type, latent_factor=1):
    """Return the bottleneck width that parameter-matches a KAN block to an MLP.

    The width is the (truncated) solution of "KAN block weights ==
    2 * dmodel * expert_size", where every KAN connection is counted as
    10 weights -- the same accounting calculate_kan_params uses.

    Handled ``kan_type`` values:
        'kan'         KAN -> KAN (dmodel, kan_bottlneck, dmodel)          *ReLU
        'mlp_kan'     MLP -> KAN (dmodel, kan_bottlneck, dmodel)          *no ReLU
        'kan_latent'  MLP -> KAN -> MLP
                      (dmodel, kan_bottlneck, kan_bottlneck*latent_factor, dmodel)
                                                                          *no ReLU

    'kan_squared' (KAN (dmodel, dmodel)) and any other value are not handled
    and yield None.

    Args:
        dmodel: width passed as the outer dimension of the block.
        expert_size: hidden width of the MLP being matched.
        kan_type: one of the strings listed above.
        latent_factor: ratio between the two inner widths; only read for
            'kan_latent'. A value of 0 raises ZeroDivisionError there.

    Returns:
        The bottleneck width as an int (float result truncated by int()),
        or None when ``kan_type`` is not handled.
    """
    if kan_type == 'kan':
        # 2 * 10 * dmodel * b == 2 * dmodel * expert_size
        return int((1/10) * expert_size)

    if kan_type == 'mlp_kan':
        # (1 + 10) * dmodel * b == 2 * dmodel * expert_size
        return int((2/11) * expert_size)

    if kan_type == 'kan_latent':
        # Positive root of
        #   10*lf*b**2 + dmodel*(lf + 1)*b - 2*dmodel*expert_size == 0
        lf_plus_one = latent_factor + 1
        discriminant = dmodel*(80*expert_size*latent_factor + dmodel*lf_plus_one**2)
        root = (math.sqrt(discriminant) - dmodel*lf_plus_one)/(20*latent_factor)
        return int(root)

    return None


In [32]:
# Reference sizes shared by the cells below.
dmodel = 256  # passed as `dmodel` to both helper functions
dff = 1024    # passed as the MLP hidden width (`expert_size`) to be matched

# Architectures for which a bottleneck width is computed.
kan_types = [
    'kan',
    'mlp_kan',
    'kan_latent',
]

In [33]:
# Bottleneck width of each architecture for the reference sizes.
# NOTE: the value labelled "params" is the bottleneck width returned by
# calculate_kan_bottlneck, not a weight count.
for ktype in kan_types:
    params = calculate_kan_bottlneck(dmodel=dmodel, expert_size=dff, kan_type=ktype)
    print(f'params: {params}\ttype: {ktype}')

params: 102	type: kan
params: 186	type: mlp_kan
params: 204	type: kan_latent


In [34]:
# Sweep latent_factor for the 'kan_latent' architecture; output rows follow
# the order of `lfs`. The value labelled "params" is the bottleneck width.
lfs = [1, 1.2, 1.3, 1.5, 2]
# The type label is a constant of this cell, so the printed label always
# matches the type actually passed to calculate_kan_bottlneck and the cell
# does not rely on a loop variable defined by another cell.
latent_type = 'kan_latent'
for lf in lfs:
    params = calculate_kan_bottlneck(dmodel, dff, latent_type, lf)
    print(f'params: {params}\ttype: {latent_type}')

params: 204	type: kan_latent
params: 186	type: kan_latent
params: 179	type: kan_latent
params: 166	type: kan_latent
params: 143	type: kan_latent


In [28]:
def calculate_kan_params(dmodel, bottlneck, kan_type, latent_factor=1):
    """Count the weights of one block for the given architecture.

    A dense (MLP) connection is counted as 1 weight and a KAN connection as
    10 weights.

    Handled ``kan_type`` values:
        'mlp'         MLP -> MLP (dmodel, bottlneck, dmodel)
        'kan'         KAN -> KAN (dmodel, bottlneck, dmodel)              *ReLU
        'mlp_kan'     MLP -> KAN (dmodel, bottlneck, dmodel)              *no ReLU
        'kan_latent'  MLP -> KAN -> MLP
                      (dmodel, bottlneck, int(latent_factor*bottlneck), dmodel)
                                                                          *no ReLU

    'kan_squared' (KAN (dmodel, dmodel)) and any other value are not handled
    and yield None.

    Args:
        dmodel: outer width of the block.
        bottlneck: inner width (for 'mlp' this is the MLP hidden width).
        kan_type: one of the strings listed above.
        latent_factor: multiplier giving the second inner width; only read
            for 'kan_latent', where the product is truncated with int().

    Returns:
        The weight count, or None when ``kan_type`` is not handled.
    """
    if kan_type == 'mlp':
        return 2*dmodel*bottlneck

    if kan_type == 'kan':
        return 2*10*dmodel*bottlneck

    if kan_type == 'mlp_kan':
        # dense input projection + KAN output layer
        return dmodel*bottlneck + 10*dmodel*bottlneck

    if kan_type == 'kan_latent':
        # second inner width, truncated to a whole number of units
        latent_width = int(latent_factor*bottlneck)
        return dmodel*bottlneck + 10*latent_width*bottlneck + dmodel*latent_width

    return None

In [29]:
# Weight count of the MLP baseline (2 * dmodel * dff) that the KAN variants
# are matched against.
params = calculate_kan_params(dmodel, dff, kan_type='mlp', latent_factor=1)
print(f'mlp params: {params}')

mlp params: 524288


In [30]:
# Spot check: weight count of a 'kan_latent' block with bottleneck 180 and
# latent_factor 1.3, for comparison with the MLP baseline count.
params = calculate_kan_params(dmodel, bottlneck=180, kan_type='kan_latent', latent_factor=1.3)
print(params)

527184
