## Designing boundary conditions for CNNs

In [127]:
import jax.numpy as jnp
import jax
import numpy as np
import haiku as hk

In [11]:
def findPadding(kernel):
    """Return the per-axis padding that keeps an odd-sized kernel's output "same"-sized.

    For each dimension of ``kernel`` the padding is ``length // 2``.

    Args:
        kernel: array-like whose shape is read via ``jnp.shape``.

    Returns:
        list[int]: one padding amount per kernel dimension.

    Raises:
        AssertionError: if any kernel dimension has even length.
    """
    dims = jnp.shape(kernel)
    if any(length % 2 == 0 for length in dims):
        raise AssertionError('Kernel must have odd lengths in each dimension')
    return [length // 2 for length in dims]

def createPaddedMesh(mesh, padding):
    """Embed a 2-D mesh in the centre of a larger zero-filled array.

    Args:
        mesh: 2-D array of shape (rows, cols).
        padding: sequence whose first two entries are the row and column
            padding added on *each* side (e.g. the output of ``findPadding``).

    Returns:
        Array of shape (rows + 2*padding[0], cols + 2*padding[1]) with ``mesh``
        in the interior and zeros in the padding strips (default ``jnp.zeros`` dtype).
    """
    rowPad = padding[0]
    colPad = padding[1]

    (rows, cols) = jnp.shape(mesh)

    paddedMesh = jnp.zeros((rows + 2 * rowPad, cols + 2 * colPad))
    # Explicit end indices: the slice `rowPad:-rowPad` collapses to the empty
    # slice `0:0` when a padding entry is 0, which made `.set(mesh)` fail.
    return paddedMesh.at[rowPad:rowPad + rows, colPad:colPad + cols].set(mesh)

def createPaddedMesh_jit(mesh, kernel):
    """Zero-pad a 2-D mesh by the amount implied by ``kernel``'s shape (jit-compiled below).

    Only the *shape* of ``kernel`` is used (via ``findPadding``). Under ``jax.jit``
    array shapes are static, so the padding amounts and slice bounds are plain
    Python ints at trace time.

    Args:
        mesh: 2-D array of shape (rows, cols).
        kernel: array whose first two dimensions have odd length.

    Returns:
        Array of shape (rows + 2*rowPad, cols + 2*colPad) with ``mesh`` in the
        interior and zeros in the padding strips.
    """
    padding = findPadding(kernel)
    rowPad = padding[0]
    colPad = padding[1]

    (rows, cols) = jnp.shape(mesh)

    paddedMesh = jnp.zeros((rows + 2 * rowPad, cols + 2 * colPad))
    # Explicit end indices: `rowPad:-rowPad` would be the empty slice when the
    # padding is 0 (e.g. a 1x1 kernel).
    return paddedMesh.at[rowPad:rowPad + rows, colPad:colPad + cols].set(mesh)


createPaddedMesh_jit = jax.jit(createPaddedMesh_jit)

In [5]:
## Now let's apply boundary conditions:
#     - periodic
#     - Dirichlet
#     - Neumann (still to be decided)

In [255]:
# 5x5 test field holding 1..25 in row-major order.
field = jnp.array(np.linspace(1, 25, 25).reshape(5, 5))

kernel = jnp.ones((3, 3))

padding = findPadding(kernel)

# Zero-padded copy of the field (jit variant: createPaddedMesh_jit(field, kernel)).
test_mesh = createPaddedMesh(field, padding)

## periodic
rowPad, colPad = padding[0], padding[1]

pad = colPad

def periodicPadding(data, pad, axis=0):
    """
    implements periodic padding to both ends of given dimension

    The first ``pad`` columns are filled from the last ``pad`` interior columns
    (``data[:, -2*pad:-pad]``), and the last ``pad`` columns from the first
    ``pad`` interior columns (``data[:, pad:2*pad]``).

    axis=0 -> left and right
    axis=1 -> top and bottom (transpose then left and right then transpose)

    Args:
        data: 2-D array that already contains ``pad`` padding columns/rows on
            each side of the chosen direction.
        pad: width of each padding strip. 0 returns ``data`` unchanged.
        axis: 0 for left/right, 1 for top/bottom (this file's convention).

    Returns:
        A new array with the padding strips filled periodically.
    """
    if pad == 0:
        # No padding strip to fill; without this guard `data[:, -0:]` selects
        # the whole array and the `.set` below fails on a shape mismatch.
        return data

    if axis == 1:
        data = data.T

    # Left strip <- rightmost interior columns.
    data = data.at[:, :pad].set(data[:, -2 * pad:-pad])
    # Right strip <- leftmost interior columns.
    data = data.at[:, -pad:].set(data[:, pad:2 * pad])

    if axis == 1:
        data = data.T

    return data

def dirichletPadding(data, pad, leftPad, rightPad, axis=0):
    """
    implements dirichlet padding to both ends of given dimension

    The first ``pad`` columns are set to ``leftPad`` and the last ``pad``
    columns to ``rightPad``.

    axis=0 -> left and right
    axis=1 -> top and bottom (transpose then left and right then transpose)

    Args:
        data: 2-D array that already contains the padding strips.
        pad: width of each padding strip. 0 returns ``data`` unchanged.
        leftPad: value written to the left (axis=0) / top (axis=1) strip.
        rightPad: value written to the right (axis=0) / bottom (axis=1) strip.
        axis: 0 for left/right, 1 for top/bottom (this file's convention).

    Returns:
        A new array with the padding strips set to the given constants.
    """
    if pad == 0:
        # Without this guard `data.at[:, -0:]` is the whole array, so every
        # value would silently be overwritten with rightPad.
        return data

    if axis == 1:
        data = data.T

    data = data.at[:, :pad].set(leftPad)

    data = data.at[:, -pad:].set(rightPad)

    if axis == 1:
        data = data.T

    return data

def padCorners(data, pad, value):
    """Set the four corner regions of a padded 2-D array to a constant.

    Each corner region is ``rowPad x colPad``. For ``pad == 1`` this matches
    setting just the four corner elements. Previously ``pad`` was ignored, so
    wider padding left most of each corner untouched.

    Args:
        data: 2-D padded array.
        pad: either a single int used for both axes, or a ``(rowPad, colPad)``
            pair (e.g. the list returned by ``findPadding``). A zero entry
            means there are no corner regions, and ``data`` is returned unchanged.
        value: constant written into the corner regions.

    Returns:
        A new array with the corner regions set to ``value``.
    """
    try:
        rowPad, colPad = pad
    except TypeError:
        # Scalar padding: same width on both axes.
        rowPad = colPad = pad

    (rows, cols) = data.shape
    # Explicit start indices (rows - rowPad) instead of -rowPad, so a zero pad
    # yields empty slices rather than the whole axis.
    top, bottom = slice(0, rowPad), slice(rows - rowPad, rows)
    left, right = slice(0, colPad), slice(cols - colPad, cols)

    data = data.at[top, left].set(value)
    data = data.at[top, right].set(value)

    data = data.at[bottom, left].set(value)
    data = data.at[bottom, right].set(value)

    return data

def channelFlowPadding(data, kernel, wallValue=1.1, cornerValue=0):
    """Fill the padding strips of ``data`` with this notebook's channel-flow boundary values.

    - top/bottom strips (axis=1): periodic wrap of the interior rows
    - left/right strips (axis=0): constant ``wallValue``
    - corner regions: constant ``cornerValue``

    Args:
        data: 2-D array already padded by ``findPadding(kernel)`` on each side.
        kernel: array whose shape determines the padding widths.
        wallValue: constant for the left/right strips (default keeps the
            previous hard-coded 1.1).
        cornerValue: constant for the corner regions (default keeps the
            previous hard-coded 0).

    Returns:
        A new array with all padding strips filled.
    """
    (padRow, padCol) = findPadding(kernel)

    data = periodicPadding(data, padRow, axis=1)

    data = dirichletPadding(data, padCol, wallValue, wallValue, axis=0)

    # Use this call's own kernel-derived padding, not the notebook-global `pad`
    # the original relied on.
    return padCorners(data, (padRow, padCol), cornerValue)

# Show the zero-padded mesh, then display the same mesh with the boundary strips filled.
print(test_mesh)
channelFlowPadding(test_mesh, kernel)



[[ 0.  0.  0.  0.  0.  0.  0.]
 [ 0.  1.  2.  3.  4.  5.  0.]
 [ 0.  6.  7.  8.  9. 10.  0.]
 [ 0. 11. 12. 13. 14. 15.  0.]
 [ 0. 16. 17. 18. 19. 20.  0.]
 [ 0. 21. 22. 23. 24. 25.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.]]


DeviceArray([[ 0. , 21. , 22. , 23. , 24. , 25. ,  0. ],
             [ 1.1,  1. ,  2. ,  3. ,  4. ,  5. ,  1.1],
             [ 1.1,  6. ,  7. ,  8. ,  9. , 10. ,  1.1],
             [ 1.1, 11. , 12. , 13. , 14. , 15. ,  1.1],
             [ 1.1, 16. , 17. , 18. , 19. , 20. ,  1.1],
             [ 1.1, 21. , 22. , 23. , 24. , 25. ,  1.1],
             [ 0. ,  1. ,  2. ,  3. ,  4. ,  5. ,  0. ]], dtype=float32)

In [124]:
test = test_mesh

(rowPad, colPad) = findPadding(kernel)
# Recover the interior by stripping the padding. Explicit end indices are used
# because `rowPad:-rowPad` would be an empty slice when the padding is 0.
test = test.at[rowPad:test.shape[0] - rowPad, colPad:test.shape[1] - colPad].get()

In [221]:
# Append a trailing size-1 channel axis: (5, 5) -> (5, 5, 1).
dfield = field[:, :, jnp.newaxis]
jnp.shape(dfield)

(5, 5, 1)

In [227]:
class CNN(hk.Module):
    """Minimal Haiku model: a single 3x3 convolution producing one output channel."""

    def __init__(self):
        super().__init__(name="CNN")
        # padding="SAME" keeps the output's spatial size equal to the input's;
        # NOTE(review): this is the library's built-in edge handling, not the
        # custom boundary padding built above.
        self.conv1 = hk.Conv2D(
            output_channels=1,
            kernel_shape=(3, 3),
            padding="SAME",
        )

    def __call__(self, x):
        """Apply the convolution to ``x``."""
        return self.conv1(x)

def ConvNet(x):
    """Forward function for ``hk.transform``: build a ``CNN`` and apply it to ``x``."""
    return CNN()(x)

# Turn the module-building function into a pure (init, apply) pair.
conv_net = hk.transform(ConvNet)

# Fixed seed so the initial parameters are reproducible between runs.
rng = jax.random.PRNGKey(42)
params = conv_net.init(rng, dfield)


In [229]:
pred = conv_net.apply(params, rng, dfield)  # forward pass on the single-channel field

In [230]:
pred.shape

(5, 5, 1)