In [1]:
import os

import functools
import jax.numpy as jnp

# Expose only GPU index 1 to CUDA for this process.
# NOTE(review): `jax.numpy` is already imported above this line; presumably the
# setting still takes effect because no device has been used yet — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import haiku as hk
import jax
import numpy as np

from deeprte.data.pipeline import DataPipeline
from deeprte.model.tf.rte_dataset import np_to_tensor_dict, divide_batch_feat
from deeprte.model.tf.rte_features import _BATCH_FEATURE_NAMES, _COLLOCATION_FEATURE_NAMES
from deeprte.config import get_config

In [2]:
import functools
import haiku as hk
import jax
import jax.numpy as jnp


def parameter_shapes(params):
  """Return a pytree shaped like `params` whose leaves are the leaf `.shape`s.

  Args:
    params: Pytree whose leaves expose a `.shape` attribute.

  Returns:
    A pytree with the same structure as `params`, each leaf replaced by the
    shape of the corresponding input leaf.
  """
  def _shape_of(leaf):
    return leaf.shape

  return jax.tree_util.tree_map(_shape_of, params)


def transform_and_print_shapes(fn, x_shape=(2, 3)):
  """Initialise `fn` under `hk.transform` and print its parameter shapes.

  Args:
    fn: Function handed to `hk.transform`; it is initialised with a single
      array argument.
    x_shape: Shape of the all-ones array passed to `init`.
  """
  dummy_input = jnp.ones(x_shape)
  init_key = jax.random.PRNGKey(42)  # fixed seed for the initialisation
  initial_params = hk.transform(fn).init(init_key, dummy_input)

  print('\nThe name and shape of the parameters are:')
  print(parameter_shapes(initial_params))

def assert_all_equal(params_1, params_2):
  """Assert that every pair of matching leaves in the two pytrees is equal.

  Each pair of leaves is compared with `==` and reduced with `.all()`; an
  `AssertionError` is raised if any pair differs.
  """
  leafwise_equal = jax.tree_util.tree_map(
      lambda lhs, rhs: (lhs == rhs).all(), params_1, params_2)
  assert all(jax.tree_util.tree_leaves(leafwise_equal))

In [3]:
# Module-level initializer (truncated normal, stddev=1) used by SimpleModule
# when it requests its "w" parameter.
w_init = hk.initializers.TruncatedNormal(stddev=1)

class SimpleModule(hk.Module):
  """Haiku module with a single weight matrix `w`.

  Calling the module returns `jnp.dot(x, w)`, where `w` has shape
  `(x.shape[-1], output_channels)`.
  """

  def __init__(self, output_channels, name=None):
    """Record the output width.

    Args:
      output_channels: Size of the last axis of `w`; must be an `int`.
      name: Optional module name forwarded to `hk.Module`.
    """
    super().__init__(name=name)
    assert isinstance(output_channels, int)
    self._output_channels = output_channels

  def __call__(self, x):
    """Return `jnp.dot(x, w)`; `w` comes from `hk.get_parameter` with `w_init`."""
    in_channels = x.shape[-1]
    weights = hk.get_parameter(
        "w",
        shape=(in_channels, self._output_channels),
        dtype=x.dtype,
        init=w_init,
    )
    return jnp.dot(x, weights)

In [4]:
from deeprte.model.layer_stack_alphafold import layer_stack

In [9]:
def f1(x):
  """Apply `SimpleModule(output_channels=3)` through `hk.experimental.layer_stack(2)`."""
  def layer(x):
    # The module is constructed inside the stacked function.
    return SimpleModule(output_channels=3)(x)

  # Alternative noted in the original cell: build one SimpleModule outside
  # and wrap that instance with layer_stack directly.
  return hk.experimental.layer_stack(2)(layer)(x)

def f2(x):
  """Apply `SimpleModule(output_channels=3)` through the imported `layer_stack(2)`."""
  def layer(x):
    # The module is constructed inside the stacked function.
    return SimpleModule(output_channels=3)(x)

  # Alternative noted in the original cell: build one SimpleModule outside
  # and wrap that instance with layer_stack directly.
  return layer_stack(2)(layer)(x)

# transform_and_print_shapes(f)

In [10]:
# Same PRNG key and all-ones input for both initialisations.
rng = jax.random.PRNGKey(42)
x = jnp.ones((3, 3))

# Parameters from f1 (hk.experimental.layer_stack).
params_1 = hk.transform(f1).init(rng, x)

# Parameters from f2 (imported layer_stack); `transformed_fn` stays bound to
# this transform after the cell has run.
transformed_fn = hk.transform(f2)
params_2 = transformed_fn.init(rng, x)

In [11]:
# Compare the stacked "w" parameters produced by the two layer_stack variants.
jnp.allclose(
    params_1['__layer_stack_no_per_layer/simple_module']["w"],
    params_2['__layer_stack_no_state/simple_module']["w"],
)

DeviceArray(False, dtype=bool)

In [12]:
# Display the full parameter tree obtained from initialising f2.
params_2

{'__layer_stack_no_state/simple_module': {'w': DeviceArray([[[-1.1022383 , -1.0554173 ,  0.63827467],
                [ 0.17438962, -0.1926215 , -0.52258176],
                [-0.1787666 ,  1.1934166 , -0.30565965]],
  
               [[ 1.463427  , -0.10136782, -0.76178634],
                [ 1.030862  ,  1.0989444 , -0.48557475],
                [-0.01121216, -0.9829121 , -0.76874983]]], dtype=float32)}}