In [1]:
import torch as th
import numpy as np

# This layer is inspired by Kolmogorov-Arnold Networks (KANs), but it parameterizes each 1-D function
# with Fourier coefficients instead of spline coefficients.
# It should be easier to optimize, because the Fourier basis is global (dense) whereas splines are local.
# Once training has converged, the 1-D functions can be replaced by spline approximations for faster
# evaluation, giving almost the same result.
# Another advantage of Fourier over splines is that the functions are periodic, and therefore numerically
# bounded; this avoids the problem of inputs falling outside the spline grid.

class NaiveFourierKANLayer(th.nn.Module):
    """KAN-style layer whose per-edge 1-D functions are truncated Fourier series.

    For every (output j, input i) pair the layer holds a function
        phi_ji(t) = sum_{k=1..gridsize} a_jik * cos(k t) + b_jik * sin(k t)
    and computes y_j = sum_i phi_ji(x_i) (+ bias_j when addbias is True).

    Args:
        inputdim: size of the last axis of the input.
        outdim: size of the last axis of the output.
        gridsize: number of harmonics (k = 1..gridsize) per edge function.
        addbias: if True, add a learnable bias of shape (1, outdim); it plays
            the role of the constant (k = 0) Fourier term.

    Shapes:
        input (..., inputdim) -> output (..., outdim). All leading axes are
        treated as batch axes, as with th.nn.Linear.
    """

    def __init__(self, inputdim, outdim, gridsize, addbias=True):
        super().__init__()
        self.gridsize = gridsize
        self.addbias = addbias
        self.inputdim = inputdim
        self.outdim = outdim

        # fouriercoeffs[0] holds the cosine coefficients and fouriercoeffs[1]
        # the sine coefficients, each of shape (outdim, inputdim, gridsize).
        # The 1/sqrt(inputdim * gridsize) scale is chosen so that, for inputs
        # whose coordinates have unit variance, each output coordinate also has
        # unit variance, independently of the layer sizes.
        self.fouriercoeffs = th.nn.Parameter(
            th.randn(2, outdim, inputdim, gridsize)
            / (np.sqrt(inputdim) * np.sqrt(self.gridsize))
        )
        if self.addbias:
            self.bias = th.nn.Parameter(th.zeros(1, outdim))

    def forward(self, x):
        """Evaluate the layer on ``x`` of shape (..., inputdim); returns (..., outdim)."""
        outshape = x.shape[:-1] + (self.outdim,)
        # Work in the promoted dtype (e.g. float64 input -> float64 result),
        # which is what elementwise broadcasting against the parameters yields;
        # matmul does not promote mixed dtypes by itself.
        dtype = th.promote_types(x.dtype, self.fouriercoeffs.dtype)
        x = th.reshape(x, (-1, self.inputdim)).to(dtype)
        nbatch = x.shape[0]
        nfeat = self.inputdim * self.gridsize

        # Harmonics start at 1 because the constant term is in the bias.
        k = th.arange(1, self.gridsize + 1, device=x.device, dtype=dtype)
        kx = x.unsqueeze(-1) * k  # (nbatch, inputdim, gridsize)
        # Explicit sizes (not -1) so an empty batch reshapes unambiguously.
        c = th.cos(kx).reshape(nbatch, nfeat)
        s = th.sin(kx).reshape(nbatch, nfeat)

        # Contract over (inputdim, gridsize) with a matmul rather than
        # broadcasting to (nbatch, outdim, inputdim, gridsize) and summing:
        # same result, without materializing that 4-D tensor.
        coeffs = self.fouriercoeffs.to(dtype).reshape(2, self.outdim, nfeat)
        y = c @ coeffs[0].T + s @ coeffs[1].T
        if self.addbias:
            y = y + self.bias
        return th.reshape(y, outshape)

In [None]:
# Top-level replay of NaiveFourierKANLayer.forward on a small standalone layer,
# so that every intermediate tensor can be inspected. It also checks that the
# einsum contraction matches the broadcast-and-sum formulation, and that both
# match the layer's own forward().
# Names carry a `chk_` prefix so this cell does not overwrite `x`, `k`, `c`, `s`
# or `y`, which the step-by-step cells below rely on.
chk_layer = NaiveFourierKANLayer(inputdim=4, outdim=3, gridsize=5)
chk_in = th.randn(2, 4)

chk_outshape = chk_in.shape[:-1] + (chk_layer.outdim,)
chk_x = th.reshape(chk_in, (-1, chk_layer.inputdim))
# Starting at 1 because constant terms are in the bias
chk_k = th.reshape(th.arange(1, chk_layer.gridsize + 1, device=chk_x.device),
                   (1, 1, 1, chk_layer.gridsize))
chk_xrshp = th.reshape(chk_x, (chk_x.shape[0], 1, chk_x.shape[1], 1))
chk_c = th.cos(chk_k * chk_xrshp)
chk_s = th.sin(chk_k * chk_xrshp)

# Broadcast-and-sum: materializes (batch, outdim, inputdim, gridsize) before summing.
chk_y = th.sum(chk_c * chk_layer.fouriercoeffs[0:1], (-2, -1))
chk_y = chk_y + th.sum(chk_s * chk_layer.fouriercoeffs[1:2], (-2, -1))
if chk_layer.addbias:
    chk_y = chk_y + chk_layer.bias

# einsum over the stacked (cos, sin) features: same contraction, smaller intermediates.
chk_feat_shape = (1, chk_x.shape[0], chk_x.shape[1], chk_layer.gridsize)
chk_cs = th.cat([th.reshape(chk_c, chk_feat_shape),
                 th.reshape(chk_s, chk_feat_shape)], dim=0)
chk_y2 = th.einsum("dbik,djik->bj", chk_cs, chk_layer.fouriercoeffs)
if chk_layer.addbias:
    chk_y2 = chk_y2 + chk_layer.bias
assert th.allclose(chk_y, chk_y2, atol=1e-5), "einsum and broadcast-sum disagree"

chk_y = th.reshape(chk_y, chk_outshape)
assert th.allclose(chk_y, chk_layer(chk_in), atol=1e-5), "replay disagrees with forward()"

In [4]:
def _print_demo_stats(label, inp, h, y):
    """Print shape/mean/variance diagnostics for one two-layer forward pass.

    Args:
        label: variable name printed for the input tensor (e.g. "x0", "xseq").
        inp: the input that was fed to the first layer.
        h: output of the first layer.
        y: output of the second layer.
    """
    print(label + ".shape")
    print(inp.shape)
    # Overall mean, and mean of the per-row variance along the feature axis.
    # The layer's initialization aims for roughly unit variance here.
    print("h.shape")
    print(h.shape)
    print("th.mean( h )")
    print(th.mean(h))
    print("th.mean( th.var(h,-1) )")
    print(th.mean(th.var(h, -1)))

    print("y.shape")
    print(y.shape)
    print("th.mean( y)")
    print(th.mean(y))
    print("th.mean( th.var(y,-1) )")
    print(th.mean(th.var(y, -1)))


def demo(bs=10, L=3, inputdim=50, hidden=200, outdim=100, gridsize=300, device="cpu"):
    """Stack two NaiveFourierKANLayers, run them on random data and print diagnostics.

    The first pass uses a flat (bs, inputdim) batch. The second uses a
    (bs, L, inputdim) sequence batch, showing that extra leading dimensions
    are batched like th.nn.Linear.

    Args:
        bs: batch size.
        L: sequence length for the second pass (only there to show extra batch axes).
        inputdim, hidden, outdim: widths of input, hidden and output features.
        gridsize: number of Fourier harmonics per edge function.
        device: torch device string, e.g. "cpu" or "cuda".
    """
    fkan1 = NaiveFourierKANLayer(inputdim, hidden, gridsize).to(device)
    fkan2 = NaiveFourierKANLayer(hidden, outdim, gridsize).to(device)

    x0 = th.randn(bs, inputdim).to(device)
    h = fkan1(x0)
    y = fkan2(h)
    _print_demo_stats("x0", x0, h, y)

    print(" ")
    print(" ")
    print("Sequence example")
    print(" ")
    print(" ")

    xseq = th.randn(bs, L, inputdim).to(device)
    h = fkan1(xseq)
    y = fkan2(h)
    _print_demo_stats("xseq", xseq, h, y)


demo()

x0.shape
torch.Size([10, 50])
h.shape
torch.Size([10, 200])
th.mean( h )
tensor(0.0086, grad_fn=<MeanBackward0>)
th.mean( th.var(h,-1) )
tensor(0.9996, grad_fn=<MeanBackward0>)
y.shape
torch.Size([10, 100])
th.mean( y)
tensor(0.0360, grad_fn=<MeanBackward0>)
th.mean( th.var(y,-1) )
tensor(0.9972, grad_fn=<MeanBackward0>)
 
 
Sequence example
 
 
xseq.shape
torch.Size([10, 3, 50])
h.shape
torch.Size([10, 3, 200])
th.mean( h )
tensor(0.0024, grad_fn=<MeanBackward0>)
th.mean( th.var(h,-1) )
tensor(1.0207, grad_fn=<MeanBackward0>)
y.shape
torch.Size([10, 3, 100])
th.mean( y)
tensor(0.0070, grad_fn=<MeanBackward0>)
th.mean( th.var(y,-1) )
tensor(1.0166, grad_fn=<MeanBackward0>)


In [16]:
# Scratch walk-through of the first steps of NaiveFourierKANLayer.forward.
# Hyper-parameters (same values as in demo()); L is not used in this cell.
bs, L, inputdim = 10, 3, 50
hidden, outdim, gridsize = 200, 100, 300

device = "cpu"  # or "cuda"

fkan1 = NaiveFourierKANLayer(inputdim, hidden, gridsize).to(device)
fkan2 = NaiveFourierKANLayer(hidden, outdim, gridsize).to(device)

# A flat (batch, inputdim) input.
x = th.randn(bs, inputdim).to(device)
print(x.shape)

# Output shape: every leading axis kept, last axis replaced by outdim.
xshp = x.shape
outshape = xshp[:-1] + (outdim,)
print(outshape)

# Collapse all leading axes into a single batch axis.
x = x.reshape(-1, inputdim)
print(x.shape)

torch.Size([10, 50])
torch.Size([10, 100])
torch.Size([10, 50])


In [20]:
# Harmonic indices 1..gridsize (the constant term is handled by the bias).
tmp = th.arange(1, gridsize + 1, device=x.device)
print(tmp.shape)
# Put the harmonics on the last axis of a 4-D tensor so they broadcast.
k = tmp.reshape(1, 1, 1, gridsize)
print(k.shape)

# x as (batch, 1, inputdim, 1): axis 1 is left free for outdim, axis 3 for k.
xrshp = x.reshape(x.shape[0], 1, x.shape[1], 1)
print(xrshp.shape)

# Fourier features cos(k*x) and sin(k*x), shape (batch, 1, inputdim, gridsize).
c, s = th.cos(xrshp * k), th.sin(xrshp * k)
print(c.shape)

torch.Size([300])
torch.Size([1, 1, 1, 300])
torch.Size([10, 1, 50, 1])
torch.Size([10, 1, 50, 300])


In [23]:
# Standalone copy of the layer's coefficient tensor, shape (2, outdim, inputdim, gridsize):
# index 0 along the first axis holds cosine coefficients, index 1 sine coefficients.
fouriercoeffs = th.nn.Parameter(
    th.randn(2, outdim, inputdim, gridsize)
    / (np.sqrt(inputdim) * np.sqrt(gridsize))
)
print(fouriercoeffs.shape)

# Slice (not index) so the leading axis survives with size 1 and broadcasts over batch.
tmp = fouriercoeffs[:1]
print(tmp.shape)

# (batch, 1, inputdim, gridsize) * (1, outdim, inputdim, gridsize)
#   -> (batch, outdim, inputdim, gridsize)
tmp2 = c * tmp
print(tmp2.shape)
# Reduce over inputdim and gridsize: cosine contribution to each output.
y = tmp2.sum(dim=(-2, -1))
print(y.shape)

torch.Size([2, 100, 50, 300])
torch.Size([1, 100, 50, 300])
torch.Size([10, 100, 50, 300])
torch.Size([10, 100])


In [28]:
# Broadcasting sanity check: a (1, 2) row times a (2, 1) column gives the
# (2, 2) outer product, a3[i, j] == a2[i, 0] * a1[0, j].
a1 = th.randn(1, 2)
a2 = th.randn(2, 1)
a3 = a1 * a2
print(a3.shape, a1, a2, a3, sep="\n")

torch.Size([2, 2])
tensor([[1.8056, 0.4584]])
tensor([[ 0.9807],
        [-0.4428]])
tensor([[ 1.7708,  0.4496],
        [-0.7996, -0.2030]])
