In [1]:
import functools
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as axes3d
import keras
from sympy import *
# Turn on SymPy's pretty-printing so symbolic expressions render as formatted
# math when they are the output of a cell.
init_printing()

Using TensorFlow backend.


In [2]:
# Shared real-valued SymPy symbols for the two angles. In SH() below, theta
# enters through cos(theta) (the Legendre argument) and phi through
# cos(m*phi) / sin(m*phi); Directional.build() lambdifies over this same pair.
theta, phi = symbols('theta, phi', real=True)

### Recursive generation of associated Legendre polynomials
- memoized with `functools.lru_cache` (top-down dynamic programming)

In [3]:
@functools.lru_cache(maxsize=4096)
def Pmm(m, x):
    """Return the diagonal associated Legendre polynomial P_m^m(x).

    Uses the closed form with the Condon-Shortley phase,

        P_m^m(x) = (-1)^m * (2m-1)!! * (1 - x^2)^(m/2),

    evaluated through the one-step recurrence

        P_m^m(x) = -(2m-1) * sqrt(1 - x^2) * P_{m-1}^{m-1}(x),   P_0^0(x) = 1.

    Args:
        m: order of the polynomial; any m <= 0 yields the constant 1.
        x: SymPy expression or symbol (must be hashable, because results
            are memoized with lru_cache).

    Returns:
        A simplified SymPy expression.
    """
    if m <= 0:
        return Number(1)
    fact = Number(2*m - 1)
    # Exactly ONE sign flip per recursion level. Multiplying by (-1)**m at
    # every level compounds to (-1)**(m*(m+1)/2), which has the wrong sign
    # whenever m % 4 is 2 or 3 (e.g. it yields P_2^2 = -3*(1 - x**2)).
    pmm = -fact * sqrt(1 - x**2) * Pmm(m-1, x)
    return pmm.simplify()

@functools.lru_cache(maxsize=4096)
def Pmmp1(m, x):
    """Return P_{m+1}^m(x), the first off-diagonal associated Legendre term.

    Computed from the diagonal term as x * (2m + 1) * P_m^m(x).

    Args:
        m: order of the polynomial.
        x: SymPy expression or symbol (hashable, for lru_cache).

    Returns:
        A simplified SymPy expression.
    """
    scale = 2*m + 1
    diagonal = Pmm(m, x)
    return (x * scale * diagonal).simplify()

@functools.lru_cache(maxsize=4096)
def P(l, m, x):
    """Return the associated Legendre polynomial P_l^m(x) for 0 <= m <= l.

    Starts from the two seed terms P_m^m (Pmm) and P_{m+1}^m (Pmmp1) and
    raises the degree with the three-term recurrence

        (l - m) * P_l^m = (2l - 1) * x * P_{l-1}^m - (l + m - 1) * P_{l-2}^m.

    Intermediate results are memoized with lru_cache, so each (l, m, x)
    triple is only derived once.

    Args:
        l: degree (band index).
        m: order, with 0 <= m <= l.
        x: SymPy expression or symbol (hashable, for lru_cache).

    Returns:
        A simplified SymPy expression.

    Raises:
        ValueError: if m < 0 or l < m. Without this check the recurrence
            never reaches a seed term for l < m and recurses until the
            interpreter's recursion limit is hit.
    """
    if m < 0 or l < m:
        raise ValueError(
            "P(l, m, x) requires 0 <= m <= l, got l=%r, m=%r" % (l, m))
    if l == m:
        return Pmm(m, x)
    if l == m+1:
        return Pmmp1(m, x)
    pll = ((2*l-1) * x * P(l-1, m, x) - (l+m-1) * P(l-2, m, x)) / (l-m)
    return pll.simplify()

### Real spherical harmonic functions in 3D
- memoized with `functools.lru_cache` (top-down dynamic programming)

In [4]:
@functools.lru_cache(maxsize=4096)
def K(l, m):
    """Return the scaling constant used by SH() for band l and order m.

    K(l, m) = sqrt( (2l + 1) / (4*pi) * (l - m)! / (l + m)! )

    Args:
        l: band index.
        m: order (SH() always passes a non-negative value).

    Returns:
        A simplified SymPy expression.
    """
    band_term = (2*l + 1) / (4*pi)
    factorial_ratio = factorial(l - m) / factorial(l + m)
    return sqrt(band_term * factorial_ratio).simplify()

@functools.lru_cache(maxsize=4096)
def SH(l, m, theta, phi):
    """Return the real spherical harmonic of band l and order m.

    * m > 0:  sqrt(2) * K(l, m)  * cos(m*phi)  * P(l, m,  cos(theta))
    * m < 0:  sqrt(2) * K(l, -m) * sin(-m*phi) * P(l, -m, cos(theta))
    * m == 0: K(l, 0) * P(l, 0, cos(theta))

    Args:
        l: band index.
        m: order; its sign selects the cosine or sine branch.
        theta: SymPy expression whose cosine is the Legendre argument.
        phi: SymPy expression used in the cos/sin factor.

    Returns:
        A simplified SymPy expression in theta and phi.
    """
    polar = cos(theta)
    if m > 0:
        order = m
        azimuthal = cos(m*phi)
    elif m < 0:
        order = -m
        azimuthal = sin(-m*phi)
    else:
        # Zonal case: no phi dependence and no sqrt(2) factor.
        return (K(l, m) * P(l, m, polar)).simplify()
    harmonic = sqrt(2) * K(l, order) * azimuthal * P(l, order, polar)
    return harmonic.simplify()

### MNIST dataset for testing

In [5]:
# Fetch the MNIST train/test split through Keras.
(X_train, Y_train), (X_test, Y_test) = keras.datasets.mnist.load_data()

# Rescale both image arrays by 1/255 (presumably mapping 8-bit pixel
# intensities into [0, 1] -- confirm against the dataset's dtype).
X_train = X_train / 255
X_test = X_test / 255

### A custom Keras 3D directional neural layer
- it reuses the spherical harmonic generation procedure defined above
- it also demonstrates how to combine SymPy and TensorFlow

In [6]:
class Directional(keras.layers.Layer):
    """Dense-like layer whose weights are real spherical-harmonic expansions.

    The layer is called on a pair ``[x, angles]``:

    * ``x``      -- tensor of shape (batch, n_in)
    * ``angles`` -- tensor of shape (n_in, units, 2); the last axis holds
      (THETA, PHI) for every input/unit connection.

    For each connection the effective weight is a learned combination of the
    ``bands**2`` spherical harmonics evaluated at that connection's angles:

        w[i, u, c] = sum_k coeff[u, c, k] * SH_k(THETA[i, u], PHI[i, u])
        out[b, u]  = sum_c tanh( sum_i x[b, i] * w[i, u, c] + bias[u, c] )

    Args:
        units: number of output units.
        channels: number of independent expansions summed per unit.
        bands: number of spherical-harmonic bands (l = 0 .. bands-1).
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, units, channels=4, bands=4, ** kwargs):
        self.units = units
        self.channels = channels
        self.bands = bands
        super(Directional, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the weights and compile the harmonics into tensor functions.

        Raises:
            ValueError: if the last axis of the angle input is not of size 2.
        """
        data_shape, angle_shape = input_shape
        if angle_shape[-1] != 2:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError("last dimension of the input must be (THETA, PHI)")
        # NOTE: the weight name 'coff' (sic) is kept as-is; renaming it would
        # change the variable names stored in previously saved weights.
        self.coeff = self.add_weight(
            shape=(self.units, self.channels, self.bands*self.bands),
            initializer='uniform', name='coff')
        self.bias = self.add_weight(
            shape=(self.units, self.channels),
            initializer='zeros', name='bias')

        def lambdify_keep_dims(args, func):
            # A constant harmonic (l = 0) would lambdify to a plain scalar;
            # return a tensor shaped like the angle input instead.
            if func.is_constant():
                value = float(func)
                # ones_like follows the runtime shape of the angles, so it
                # does not require a fully static shape like np.full would.
                return lambda angle_theta, angle_phi: (
                    keras.backend.ones_like(angle_theta) * value)
            return lambdify(args, func, modules=['tensorflow','numpy'])

        # One callable per (l, m) pair, ordered by band then by order; the
        # position in this list is the index into the last axis of coeff.
        self.__SHs = [
            lambdify_keep_dims((theta, phi), SH(l, m, theta, phi).evalf())
            for l in range(self.bands)
            for m in range(-l, l+1)]
        super(Directional, self).build(data_shape)

    def call(self, inputs):
        """Apply the layer to ``[x, angles]`` and return a (batch, units) tensor."""
        x, angles = inputs
        THETA, PHI = angles[:,:,0], angles[:,:,1]
        # (n_in, units, 1) * (units, channels) broadcasts to
        # (n_in, units, channels); no explicit tiling of either operand is
        # needed, which also removes the dependency on a static input width.
        weights = keras.backend.sum([
            keras.backend.expand_dims(harmonic(THETA, PHI), axis=-1)
            * self.coeff[:,:,k]
            for k, harmonic in enumerate(self.__SHs)], axis=0)
        weights = keras.backend.reshape(
            weights, (-1, self.units*self.channels))
        logits = keras.backend.reshape(
            x @ weights, (-1, self.units, self.channels))
        logits += self.bias
        # Squash each channel, then merge the channels of every unit.
        activation = keras.backend.sum(keras.backend.tanh(logits), axis=-1)
        return activation

    def compute_output_shape(self, input_shape):
        """Return (batch, units), taking batch from the first input's shape."""
        return (input_shape[0][0], self.units)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized."""
        config = super(Directional, self).get_config()
        config.update({
            'units': self.units,
            'channels': self.channels,
            'bands': self.bands,
        })
        return config

In [7]:
# The (theta, phi) directions are baked into the graph as constants and are
# NOT part of the trainable weights. Drawing them from a seeded generator
# makes them reproducible, so a rebuilt model gets the same directions and
# previously trained weights remain meaningful.
ANGLE_SEED = 0
angle_rng = np.random.RandomState(ANGLE_SEED)


def random_angles(n_inputs, n_units):
    """Return a constant (n_inputs, n_units, 2) tensor of random directions.

    The last axis holds (theta, phi), with theta drawn uniformly from
    [0, pi) and phi from [0, 2*pi). The values are sampled once, here, so
    the Lambda layer yields the same tensor every time it is invoked.
    """
    angles = angle_rng.rand(n_inputs, n_units, 2) * (np.pi, np.pi*2)
    return keras.layers.Lambda(lambda _: keras.backend.constant(angles))([])


X = X_input = keras.layers.Input((28,28))
X = keras.layers.Reshape((28*28,))(X)
X_angle = random_angles(28*28, 64)
X = Directional(64, channels=3, bands=4)([X, X_angle])
X_angle = random_angles(64, 32)
X = Directional(32, channels=2, bands=5)([X, X_angle])
X_angle = random_angles(32, 10)
X = Directional(10, channels=1, bands=6)([X, X_angle])
X = keras.layers.Activation('softmax')(X)
M = keras.Model(X_input, X)
M.compile('adam', 'sparse_categorical_crossentropy', ['acc'])
M.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 28, 28)       0                                            
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 784)          0           input_1[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (784, 64, 2)         0                                            
__________________________________________________________________________________________________
directional_1 (Directional)     (None, 64)           3264        reshape_1[0][0]                  
                                                                 lambda_1[0][0]                   
__________

In [8]:
%%time
M.fit(X_train, Y_train, validation_data=(X_test, Y_test),
    batch_size=128, epochs=50, callbacks=[
        keras.callbacks.ReduceLROnPlateau(monitor='loss', patience=2, verbose=True),
        keras.callbacks.EarlyStopping(monitor='loss'),
    ])

Train on 60000 samples, validate on 10000 samples
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
CPU times: user 5min 56s, sys: 28.9 s, total: 6min 24s
Wall time: 5min 33s


<keras.callbacks.History at 0x7ff7a6995e48>