In [None]:
import numpy as np
import os
import struct
import itertools
import matplotlib.pyplot as plt

import jax, optax
from jax import random, numpy as jnp
from flax import linen as nn
from flax.training import train_state
from flax.core.frozen_dict import freeze, unfreeze

from google.colab import drive
# Mount Google Drive, then make the project folder the working directory;
# later cells rely on relative paths such as 'data'.
PROJECT_DIR = '/content/drive/My Drive/jax_cnn/'
drive.mount('/content/drive')
os.chdir(PROJECT_DIR)

from read import fw_to_np

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Load all arrays via the project-local reader; the path is relative to the
# working directory set in the first cell.
# NOTE(review): the meaning of each returned array (velocity components,
# pressure, temperature, counter, coordinates) is inferred from names and the
# printed shapes only -- confirm against `read.fw_to_np`.
DATA_PATH = 'data'
u, v, w, p, t, icount, x, y, z = fw_to_np(DATA_PATH)

In [None]:
# Sanity-check the loaded data: print each array's shape, one per line,
# in the same order the loader returned them.
for loaded_array in (u, v, w, p, t, icount, x, y, z):
    print(np.shape(loaded_array))

(80, 105, 15)
(79, 106, 15)
(79, 105, 16)
(79, 105, 15)
(79, 105, 15)
()
(80,)
(106,)
(16,)


In [None]:
# Two random 5x5x5 filters stacked along a trailing "filter" axis.
# A fixed seed makes the filters -- and every activation computed from them --
# reproducible under Restart & Run All; change FILTER_SEED for a new pair.
FILTER_SEED = 0
filter_rng = np.random.default_rng(FILTER_SEED)
filter0 = filter_rng.standard_normal((5, 5, 5))
filter1 = filter_rng.standard_normal((5, 5, 5))

# Layout (5, 5, 5, 2), float64: filter_array[:, :, :, k] is filter k.
filter_array = np.stack([filter0, filter1], axis=-1)

In [None]:
class FixedConvFilterModel(nn.Module):
    """Single bias-free 3D convolution followed by a magnitude threshold.

    Each input sample is reshaped to a single-channel cube of shape
    ``(*grid_shape, 1)``, convolved with ``features`` kernels of size
    ``kernel_size`` (no bias, 'SAME' padding), and mapped through
    ``relu(|conv| - threshold)`` so only responses whose magnitude exceeds
    ``threshold`` survive. Optionally followed by max pooling.

    The defaults reproduce the original hard-coded configuration:
    10x10x10 cubes, two 5x5x5 kernels, threshold 0.2, 2x2x2 pooling.

    Attributes:
        grid_shape: spatial size (Nx, Ny, Nz) each sample is reshaped to.
        features: number of convolution kernels (output channels).
        kernel_size: spatial size of each kernel.
        threshold: value subtracted from |conv| before the ReLU.
        pool_size: max-pool window, also used as the stride.
    """
    grid_shape: tuple = (10, 10, 10)
    features: int = 2
    kernel_size: tuple = (5, 5, 5)
    threshold: float = 0.2
    pool_size: tuple = (2, 2, 2)

    @nn.compact
    def __call__(self, x, apply_pooling=False):
        """Apply the filter bank.

        Args:
            x: array whose total size is a multiple of prod(grid_shape);
                it is reshaped to (batch, *grid_shape, 1).
            apply_pooling: if True, max-pool with window and stride
                ``pool_size`` ('VALID' padding).

        Returns:
            Array of shape (batch, *grid_shape, features) -- with each spatial
            dim floor-divided by the pool size when ``apply_pooling`` is True.
        """
        # Single input channel: the trailing 1 is the channel axis.
        x = x.reshape((-1, *self.grid_shape, 1))
        x = nn.Conv(features=self.features, kernel_size=self.kernel_size,
                    padding='SAME', use_bias=False)(x)
        # Sign-agnostic thresholding: keep only strong responses of either sign.
        x = nn.relu(jnp.abs(x) - self.threshold)
        if apply_pooling:
            x = nn.max_pool(x, window_shape=self.pool_size,
                            strides=self.pool_size, padding='VALID')
        return x

In [None]:
# Parameter shapes depend only on the input layout, not on the data, so a dummy
# batch containing one 10x10x10 single-channel cube is enough to initialise the
# model. This removes this cell's dependency on `images`, which is never defined
# in this section of the notebook.
# NOTE(review): `images` is still needed by the forward-pass cell -- TODO define
# it (presumably 10x10x10 sub-volumes of the loaded fields) before that cell.
key, key1 = random.split(random.PRNGKey(0))
dummy_input = jnp.zeros((1, 10, 10, 10, 1), dtype=jnp.float32)
variables = FixedConvFilterModel().init(key1, dummy_input)

# Display the parameter shapes (Conv_0/kernel) for inspection.
jax.tree_util.tree_map(lambda x: x.shape, variables['params'])

In [None]:
# Overwrite the randomly initialised conv kernel with the hand-made filters.
# flax.linen.Conv stores its kernel as (*kernel_size, in_features, features),
# here (5, 5, 5, 1, 2) because the model feeds a single input channel.
# filter_array carries only (5, 5, 5, 2) with that channel implicit, and flax
# rejects a param whose shape differs from the one it would initialise.
# Reshaping to the initialised kernel's shape inserts the missing in_features
# axis without reordering any filter values; a size mismatch raises instead of
# silently producing a wrong kernel.
params = unfreeze(variables['params'])
init_kernel = params['Conv_0']['kernel']
params['Conv_0']['kernel'] = jnp.asarray(
    filter_array, dtype=init_kernel.dtype).reshape(init_kernel.shape)
new_params = freeze(params)

# NOTE(review): an Adam optimiser is attached, but nothing in this section calls
# apply_gradients; if the filters are meant to stay fixed, confirm this is intended.
state = train_state.TrainState.create(
    apply_fn=FixedConvFilterModel().apply,
    params=new_params,
    tx=optax.adam(learning_rate=0.001))

In [None]:
# Run the fixed-filter model over `images` and copy both the activations and
# the installed kernel back to host memory as NumPy arrays.
# NOTE(review): `images` is not defined anywhere in this section; it must be
# created before this cell runs on a fresh kernel.
model_variables = {'params': state.params}
device_output = state.apply_fn(model_variables, images[:])
conv_output = jax.device_get(device_output)

installed_kernel = state.params['Conv_0']['kernel']
filter_vals = jax.device_get(installed_kernel)