In [3]:
import numpy as np
import torch
import tensorflow as tf

In [4]:
VELOCITIES_X = np.array([
    [-1, 0, 1,],
    [-1, 0, 1,],
    [-1, 0, 1,],
]).reshape(-1)
VELOCITIES_Y = np.array([
     [1,  1,  1,],
     [0,  0,  0,],
    [-1, -1, -1,],
]).reshape(-1)
WEIGHTS_MAT = np.array([
    [1/36, 1/9, 1/36,],
    [1/9,  4/9, 1/9,],
    [1/36, 1/9, 1/36,],
]).reshape(-1)

In [5]:
a11 = np.array([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
])

a12 = np.array([  # a11 + 10
    [11, 12, 13],
    [14, 15, 16],
    [17, 18, 19],
])

a21 = np.array([
    [21, 22, 23],
    [24, 25, 26],
    [27, 28, 29],
])

a22 = np.array([
    [21, 22, 23],
    [24, 25, 26],
    [27, 28, 29],
])

data = np.array([
    [a11, a12],
    [a21, a22],
])
data = data.astype(np.float32)
data = np.random.rand(2, 2, 3, 3)

In [6]:
print("before reshape", data.shape)
data_flat = data.reshape(*data.shape[:2], -1)
print("after reshape", data.shape)

before reshape (2, 2, 3, 3)
after reshape (2, 2, 3, 3)


In [7]:
data_flat

array([[[0.10565302, 0.32626684, 0.08616004, 0.10136257, 0.77084287,
         0.96432351, 0.48993472, 0.57517283, 0.31116092],
        [0.92045774, 0.53245491, 0.12824768, 0.57679453, 0.35308036,
         0.47117326, 0.82012214, 0.17665592, 0.58657387]],

       [[0.98215709, 0.83071514, 0.79922174, 0.78835093, 0.08582844,
         0.37492559, 0.04919977, 0.19289428, 0.63192522],
        [0.41513291, 0.88426208, 0.210706  , 0.00876332, 0.07120861,
         0.57305402, 0.54778629, 0.63190535, 0.23984528]]])

# Tensorflow block

Interactive session can be replace with [Eager Execution](https://www.tensorflow.org/guide/eager)

In [8]:
def cmp(one, two):
    if not isinstance(one, np.ndarray):
        one = one.numpy()
    if not isinstance(two, np.ndarray):
        two = two.numpy()
    assert np.all(one == two)
    
PRINT = False

def print_v(*data):
    if PRINT:
        print(*data)

In [33]:
def build_graph():
    dtype = tf.float32
    velocities_x_tf = tf.constant(VELOCITIES_X, dtype=dtype)
    velocities_y_tf = tf.constant(VELOCITIES_Y, dtype=dtype)
    weights_tf = tf.constant(WEIGHTS_MAT, dtype=dtype)
    
    def ones_init(shape, dtype=None, partition_info=None):
        kernel = np.zeros(shape)
        kernel[0, 0, :, 0] = 1.0
        return tf.cast(kernel, dtype)

    sum_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=ones_init)

    def vel_x_init_many_to_one(shape, dtype=None, partition_info=None):
        kernel = np.zeros(shape)
        kernel[0, 0, :, 0] = VELOCITIES_X
        return tf.cast(kernel, dtype)

    vel_x_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=vel_x_init_many_to_one)


    def vel_y_init_many_to_one(shape, dtype=None, partition_info=None):
        kernel = np.zeros(shape)
        kernel[0, 0, :, 0] = VELOCITIES_Y
        return tf.cast(kernel, dtype)

    vel_y_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=vel_y_init_many_to_one)
    return velocities_x_tf, velocities_y_tf, weights_tf, sum_conv, vel_x_conv, vel_y_conv

In [78]:
from tensorflow.keras import Input, Model

batch = Input(shape=data_flat.shape)
velocities_x_tf, velocities_y_tf, weights_tf, sum_conv, vel_x_conv, vel_y_conv = build_graph()
rho = sum_conv(batch)
ux_lattices = vel_x_conv(batch) / rho
uy_lattices = vel_y_conv(batch) / rho
ux_elements = tf.math.multiply(ux_lattices, velocities_x_tf)
uy_elements = tf.math.multiply(uy_lattices, velocities_y_tf)
before_weights = (
    1 + 3 * (ux_elements + uy_elements) +
    9 * (ux_elements + uy_elements) ** 2 / 2 - 
    3 * (ux_lattices ** 2 + uy_lattices ** 2) / 2
)
after_weights = tf.math.multiply(before_weights, weights_tf)
F_eq = tf.math.multiply(rho, after_weights)
model = Model(inputs=batch, outputs=F_eq)
graph_model = tf.function(model)
tf.saved_model.save(model, '/tmp/model')
loaded = tf.saved_model.load('/tmp/model')
infer = loaded.signatures["serving_default"]

def predict(data):
    data_batch = data.reshape(1, *data.shape)
    tf_res = infer(tf.constant(data_batch, dtype=tf.float32))
    np_res = tf_res[model.output_names[0]].numpy().squeeze()
    return np_res

INFO:tensorflow:Assets written to: /tmp/model/assets


In [79]:
%%timeit
predict(data_flat)

267 µs ± 36.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [21]:
def calc_tf_eq(data_flat, tf_params):
    velocities_x_tf, velocities_y_tf, weights_tf, sum_conv, vel_x_conv, vel_y_conv = tf_params
#     batch = data_flat.reshape(1, *data_flat.shape)
    batch = data_flat
    rho = sum_conv(batch)
    ux_lattices = vel_x_conv(batch) / rho
    uy_lattices = vel_y_conv(batch) / rho
    ux_elements = tf.math.multiply(ux_lattices, velocities_x_tf)
    uy_elements = tf.math.multiply(uy_lattices, velocities_y_tf)
    before_weights = (
        1 + 3 * (ux_elements + uy_elements) +
        9 * (ux_elements + uy_elements) ** 2 / 2 - 
        3 * (ux_lattices ** 2 + uy_lattices ** 2) / 2
    )
    after_weights = tf.math.multiply(before_weights, weights_tf)
    F_eq = tf.math.multiply(rho, after_weights)
    return F_eq

# We should swithch from old to new indexes somewhere
# F_eq_tf = calc_tf_eq(data_flat, tf_params)
# print("F_eq_tf", F_eq_tf[0][0].reshape(3, 3))

In [25]:
calc_tf_eq_func = tf.function(calc_tf_eq, jit_compile=True)

In [24]:
%%timeit
calc_tf_eq(data_flat.reshape(1, *data_flat.shape), tf_params);

1.23 ms ± 26.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [23]:
%%timeit
calc_tf_eq_func(data_flat.reshape(1, *data_flat.shape), tf_params);

319 µs ± 45.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [27]:
%%timeit
calc_tf_eq_func(data_flat.reshape(1, *data_flat.shape), tf_params);

353 µs ± 20.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [22]:
def calc_np_eq(data_flat):
    F = data_flat
    F_eq_l = np.zeros(F.shape)
    for x_idx in range(F.shape[0]):
        for y_idx in range(F.shape[1]):
            lattice = F[x_idx, y_idx, :]
            rho_l = np.sum(lattice)
            ux_l = np.sum(lattice * VELOCITIES_X) / rho_l
            uy_l = np.sum(lattice * VELOCITIES_Y) / rho_l
            print_v("rho_l", rho_l)
            print_v("ux_l", ux_l)
            print_v("uy_l", uy_l)
            ux_elements = VELOCITIES_X * ux_l
            uy_elements = VELOCITIES_Y * uy_l
            print_v("ux_elements", ux_elements)
            print_v("uy_elements", uy_elements)
            before_weights = (
                1 + 3 * (ux_elements + uy_elements) +
                9 * (ux_elements + uy_elements) ** 2 / 2 - 
                3 * (ux_l ** 2 + uy_l ** 2) / 2
            )
            print_v("before_weights", before_weights)
            after_weights = WEIGHTS_MAT * before_weights
            print_v("after_weights", after_weights)
            F_eq_lat_2 = rho_l * WEIGHTS_MAT * before_weights
            F_eq_lattice = rho_l * WEIGHTS_MAT * (1 +
                3 * (VELOCITIES_X * ux_l + VELOCITIES_Y * uy_l) +
                9 * (VELOCITIES_X * ux_l + VELOCITIES_Y * uy_l) ** 2 / 2 - 
                3 * (ux_l ** 2 + uy_l ** 2) / 2
            )
            assert np.all(F_eq_lat_2 == F_eq_lattice)
            print_v("F_eq_lattice", F_eq_lattice)
            F_eq_l[x_idx, y_idx, :] = F_eq_lattice
    return F_eq_l

F_eq_np = calc_np_eq(data_flat)
print("F_eq_np", F_eq_np[0][0].reshape(3, 3))

F_eq_np [[0.13733372 0.53035208 0.12802397]
 [0.60559401 2.33841034 0.56433773]
 [0.16685153 0.64440603 0.15553313]]


# Olde debug verions

In [21]:
sess = tf.compat.v1.InteractiveSession()

dtype = tf.float32
velocities_x_tf = tf.constant(velocities_x.reshape(-1),dtype=dtype)
velocities_y_tf = tf.constant(velocities_y.reshape(-1),dtype=dtype)
weights_tf = tf.constant(weights_mat.reshape(-1), dtype=dtype)


def ones_init(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, :, 0] = 1.0
    return tf.cast(kernel, dtype)

sum_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=ones_init)

def vel_x_init_many_to_one(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, :, 0] = velocities_x.reshape(-1)
    return tf.cast(kernel, dtype)

vel_x_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=vel_x_init_many_to_one)


def vel_y_init_many_to_one(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, :, 0] = velocities_y.reshape(-1)
    return tf.cast(kernel, dtype)

vel_y_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=vel_y_init_many_to_one)


# batch = tf.constant(data_flat.reshape(1, *data_flat.shape), dtype=dtype)
batch = data_flat.reshape(1, *data_flat.shape)
rho = sum_conv(batch)
ux_lattices = vel_x_conv(batch) / rho
uy_lattices = vel_y_conv(batch) / rho
ux_elements = tf.math.multiply(ux_lattices, velocities_x_tf)
uy_elements = tf.math.multiply(uy_lattices, velocities_y_tf)
before_weights = (
    1 + 3 * (ux_elements + uy_elements) +
    9 * (ux_elements + uy_elements) ** 2 / 2 - 
    3 * (ux_lattices ** 2 + uy_lattices ** 2) / 2
)
after_weights = tf.math.multiply(before_weights, weights_tf)
F_eq = tf.math.multiply(rho, after_weights)

sess.close()

In [22]:
print("rho", rho.numpy().squeeze()[0][0])
print("ux", ux_lattices.numpy().squeeze()[0][0])
print("uy", uy_lattices.numpy().squeeze()[0][0])
print("ux_elements", ux_elements.numpy().squeeze()[0][0])
print("before_weights", before_weights.numpy().squeeze()[0][0])
print("after_weights", after_weights.numpy().squeeze()[0][0])
print("F_eq", F_eq.numpy().squeeze()[0][0].reshape(3, 3))
# after_weights [
#  [0.01148148 0.02814815 0.00703704]
#  [0.04592593 0.32592593 0.13481481]
#  [0.05148148 0.29481481 0.10037037]]

rho 4.4057665
ux -0.051699623
uy -0.2041674
ux_elements [ 0.05169962 -0.         -0.05169962  0.05169962 -0.         -0.05169962
  0.05169962 -0.         -0.05169962]
before_weights [0.5806698  0.50854146 0.4604689  1.1005911  0.9334642  0.7903932
 1.995671   1.733546   1.4954765 ]
after_weights [0.01612972 0.05650461 0.0127908  0.1222879  0.414873   0.08782146
 0.05543531 0.19261622 0.04154101]
F_eq [[0.07106376 0.2489461  0.05635329]
 [0.5387719  1.8278335  0.38692084]
 [0.24423502 0.8486221  0.18302001]]


In [12]:
F_eq.numpy().squeeze()[1][1].reshape(3, 3)

array([[0.09538613, 0.41864648, 0.11592166],
       [0.5784694 , 2.58973   , 0.72433376],
       [0.22581865, 1.0040988 , 0.27821532]], dtype=float32)

In [13]:
F_eq_l[1][1]

array([[0.09538612, 0.41864649, 0.11592166],
       [0.57846934, 2.58972995, 0.72433379],
       [0.22581862, 1.00409876, 0.2782153 ]])

In [9]:
F = data
F_eq_l = np.zeros(F.shape)
for x_idx in range(F.shape[0]):
    for y_idx in range(F.shape[1]):
        lattice = F[x_idx, y_idx, :]
        rho_l = np.sum(lattice)
        ux_l = np.sum(lattice * velocities_x) / rho_l
        uy_l = np.sum(lattice * velocities_y) / rho_l
        print_v("rho_l", rho_l)
        print_v("ux_l", ux_l)
        print_v("uy_l", uy_l)
        ux_elements = velocities_x * ux_l
        uy_elements = velocities_y * uy_l
        print_v("ux_elements", ux_elements)
        print_v("uy_elements", uy_elements)
        before_weights = (
            1 + 3 * (ux_elements + uy_elements) +
            9 * (ux_elements + uy_elements) ** 2 / 2 - 
            3 * (ux_l ** 2 + uy_l ** 2) / 2
        )
        print_v("before_weights", before_weights)
        after_weights = weights_mat * before_weights
        print_v("after_weights", after_weights)
        F_eq_lat_2 = rho_l * weights_mat * before_weights
        F_eq_lattice = rho_l * weights_mat * (1 +
            3 * (velocities_x * ux_l + velocities_y * uy_l) +
            9 * (velocities_x * ux_l + velocities_y * uy_l) ** 2 / 2 - 
            3 * (ux_l ** 2 + uy_l ** 2) / 2
        )
        assert np.all(F_eq_lat_2 == F_eq_lattice)
        print_v("F_eq_lattice", F_eq_lattice)
        F_eq_l[x_idx, y_idx, :] = F_eq_lattice
#         break
#     break

In [141]:
lattice.shape

(9,)

In [213]:
# Save of the initila work

sess = tf.compat.v1.InteractiveSession()



sum_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer='Ones')

# cmp(rho, data_flat.sum(axis=-1))

def vel_x_init_many_to_one(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, :, 0] = velocities_x.reshape(-1)
    return kernel

vel_x_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=vel_x_init_many_to_one)


def vel_y_init_many_to_one(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, :, 0] = velocities_y.reshape(-1)
    return kernel

vel_y_conv = tf.keras.layers.Conv2D(1, (1, 1), kernel_initializer=vel_y_init_many_to_one)


def vel_x_init_one_to_many(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, 0, :] = velocities_x.reshape(-1)
    return kernel

vel_x_conv_to_many = tf.keras.layers.Conv2D(9, (1, 1), kernel_initializer=vel_x_init_one_to_many)


def vel_y_init_one_to_many(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, 0, :] = velocities_y.reshape(-1)
    return kernel

vel_y_conv_to_many = tf.keras.layers.Conv2D(9, (1, 1), kernel_initializer=vel_y_init_one_to_many)

def weight_init(shape, dtype=None, partition_info=None):
    kernel = np.zeros(shape)
    kernel[0, 0, 0, :] = weights_mat.reshape(-1)
    return kernel

weight_conv = tf.keras.layers.Conv2D(9, (1, 1), kernel_initializer=weight_init)

batch = tf.constant(data_flat.reshape(1, *data_flat.shape), dtype=tf.float32)
rho = sum_conv(batch)
ux_lattices = vel_x_conv(batch) / rho
uy_lattices = vel_y_conv(batch) / rho
# ux_elements = vel_x_conv_to_many(ux_lattices)
# uy_elements = vel_y_conv_to_many(uy_lattices)
ux_elements = tf.math.multiply(ux_lattices, velocities_x.reshape(-1))
uy_elements = tf.math.multiply(uy_lattices, velocities_y.reshape(-1))
before_weights = (
    1 + 3 * (ux_elements + uy_elements) +
    9 * (ux_elements + uy_elements) ** 2 / 2 - 
    3 * (ux_l ** 2 + uy_l ** 2) / 2
)
after_weights = tf.math.multiply(before_weights, weights_mat.reshape(-1))
F_eq = tf.math.multiply(rho_l, after_weights)

sess.close()

# Pytorch block

In [54]:
# convert to batch
def to_batch(data):
    tensor = torch.from_numpy(data.reshape(1, *data.shape))
    tensor = tensor.permute(0, 3, 1, 2)
    return tensor

def from_batch(batch):
    batch =  batch.permute(0, 2, 3, 1)
    return batch.cpu().numpy()

def cmp(pt_array, np_array):
    assert np.all(np_array == pt_array.detach().cpu().numpy())

assert np.all(data == from_batch(to_batch(data)))

In [37]:
batch = to_batch(data)

In [82]:
sum_conv = torch.nn.Conv2d(9, 1, kernel_size=(1, 1), stride=1)
sum_conv.bias.data.fill_(0)
sum_conv.weight.data.fill_(1);

In [56]:
summed_pt = sum_conv(batch)
summed_np = data.sum(axis=-1)
print(summed_pt)
print(summed_np)

cmp(summed_pt, summed_np)

tensor([[[[ 45., 135.],
          [225., 225.]]]], grad_fn=<ThnnConv2DBackward0>)
[[ 45. 135.]
 [225. 225.]]


In [97]:
weight_conv = torch.nn.Conv2d(1, 9, kernel_size=(1, 1), stride=1)
weight_conv.bias.data.fill_(0)
init_weight = weight_conv.weight.data
new_weight = torch.from_numpy(weights_mat.reshape(-1)).reshape(init_weight.shape)
new_weight = new_weight.type(init_weight.dtype)
print(weight_conv.weight.data.dtype)
weight_conv.weight.data = new_weight
print(weight_conv.weight.data.dtype)

torch.float32
torch.float32


Upper cell can be replaced with functional calls
https://discuss.pytorch.org/t/setting-custom-kernel-for-cnn-in-pytorch/27176

In [None]:
weight_conv = torch.nn.Conv2d(9, 1, kernel_size=(1, 1), stride=1)
weight_conv.bias.data.fill_(0)
weight_conv.weight.data = torch.from_numpy(weights_mat.reshape(-1)).reshape(1, 9, 1, 1)