In [26]:
# From Neel's code
import os
from pathlib import Path
from typing import Callable
import json
import functools
import math


import jax
import jax.numpy as jnp
from jax import random, nn

from optax import adam, rmsprop, sgd

import haiku as hk
from haiku.initializers import Initializer, Constant, RandomNormal, TruncatedNormal, VarianceScaling

import chex


import dataclasses
from meta_transformer import utils
from jax import vmap
from meta_transformer.transformer import Transformer
from meta_transformer.meta_model import MetaModelClassifier, NetEmbedding, ChunkCNN
import numpy as np

from typing import Dict, Any, Optional, Tuple, Callable, Mapping, Sequence, Iterable, Union, List
from jax.typing import ArrayLike


def ctc_net_fn(x: jnp.ndarray,
               n_classes: int,
               n_conv_layers: int = 1, #3,
               kernel_size: tuple = (3, 3),
               n_filters: int = 2, #32,
               n_fc_layers: int = 1, #3,
               fc_width: int = 8, #128,
               activation: Callable = nn.relu,
               w_init: Initializer = TruncatedNormal()) -> jnp.ndarray:  # TODO: Batchnorm?
    convs = [hk.Conv2D(output_channels=n_filters, kernel_shape=kernel_size, padding="SAME", w_init=w_init)
             for _ in range(n_conv_layers)]
    fcs = [hk.Linear(fc_width, w_init=w_init) for _ in range(n_fc_layers - 1)]

    seq = []
    for conv in convs:
        seq.append(conv)
        seq.append(activation)
    seq.append(hk.Flatten())
    for fc in fcs:
        seq.append(fc)
        seq.append(activation)
    seq.append(hk.Linear(n_classes, w_init=w_init))

    net = hk.Sequential(seq)
    return net(x)


key = random.PRNGKey(4)

def tree_shape(tree):
    return jax.tree_map(lambda x: x.shape, tree)

In [27]:
net = hk.without_apply_rng(hk.transform(ctc_net_fn))
key, subkey = random.split(key)
params = net.init(subkey, jnp.ones((1, 32, 32, 1)), 10)
param_shapes = tree_shape(params)
param_shapes

{'conv2_d': {'b': (2,), 'w': (3, 3, 1, 2)},
 'linear': {'b': (10,), 'w': (2048, 10)}}

In [4]:
chunk = ChunkCNN(1024, 4*256)
chunked_params = chunk(params)
#tree_shape(chunked_params)
print(chunked_params.keys())

dict_keys(['conv2_d_chunk_0', 'linear_chunk_0', 'linear_chunk_1', 'linear_chunk_2', 'linear_chunk_3', 'linear_chunk_4', 'linear_chunk_5', 'linear_chunk_6', 'linear_chunk_7', 'linear_chunk_8', 'linear_chunk_9', 'linear_chunk_10', 'linear_chunk_11', 'linear_chunk_12', 'linear_chunk_13', 'linear_chunk_14', 'linear_chunk_15', 'linear_chunk_16', 'linear_chunk_17', 'linear_chunk_18', 'linear_chunk_19', 'linear_chunk_20'])


In [5]:
jax.tree_map(lambda x: x.size, params)

{'conv2_d': {'b': 2, 'w': 18}, 'linear': {'b': 10, 'w': 20480}}

In [21]:
@chex.dataclass(frozen=True)
class Test:
    a: int
    b: int

    def method(self):
        return self.a + self.b

In [22]:
@jax.jit
def test_fn(test: Test):
    return test.method()

In [25]:
t = Test(a=5, b=2)
test_fn(t)

Array(7, dtype=int32, weak_type=True)

In [6]:
def unchunk_layers(chunked_params: Dict[str, ArrayLike]) -> Dict[str, jax.Array]:
    """Unchunk a dictionary of chunked parameters. Both in an output
    are flat (one-level) dictionaries.)"""
    unchunked_params = {}
    for k, v in chunked_params.items():
        layer, _ = k.split("_chunk_")
        if layer not in unchunked_params:
            unchunked_params[layer] = [v]
        else:
            unchunked_params[layer].append(v)
    unchunked_params = {k: jnp.concatenate(v) for k, v in unchunked_params.items()}
    return unchunked_params


def get_layer_sizes(params: Dict[str, Dict[str, ArrayLike]]) -> Dict[str, ArrayLike]:
    """Get the size (weights.size + bias.size) of each layer in params.)"""
    return {k: sum([v.size for v in layer.values()]) 
            for k, layer in params.items()}


def un

def nest_params(
        params: Dict[str, ArrayLike],
        nested_shapes: Dict[str, ArrayLike]) -> Dict[str, Dict[str, jax.Array]]:
    """Nest a flat dictionary of parameters into a nested dictionary."""
    nested_params = {}
    for layer, shapes in nested_shapes.items():
        nested_params[layer] = {}
        i = 0
        for k, v in shapes.items():
            nested_params[layer][k] = params[layer][i:i+v.size].reshape(v)
            i += v.size
    return nested_params


@dataclasses.dataclass
class UnChunkCNN:
    """Inverse of ChunkCNN."""
    linear_chunk_size: int
    conv_chunk_size: int
    param_shapes: Dict[str, ArrayLike]

    def __post_init__(self):
        self.layer_sizes = get_layer_sizes(self.param_shapes)

    def __call__(self, chunked_params: Dict[str, ArrayLike]) -> Dict[str, jax.Array]:
        """Map chunked CNN weights back to original shape."""
        # un-chunk the layers:
        params = unchunk_layers(chunked_params)
        # remove padding:
        params = {k: v[:self.layer_sizes[k]] for k, v in params.items()}
        # convert back to nested dict and reshape layers:
        nested_params = {}
        for layer, shapes in self.param_shapes.items():
            nested_params[layer] = {
                k: params[k][:] for k, v in shapes.items()
            }

        params = {k: v.reshape(self.param_shapes[k]) for k, v in params.items()}
        return params
    

In [15]:
params = unchunk_layers(chunked_params)

In [None]:

# remove padding and reshape:
params = {k: v[:np.prod(self.param_shapes[k])] for k, v in params.items()}
params = {k: v.reshape(self.param_shapes[k]) for k, v in params.items()}

In [7]:
tree_shape(unchunk_layers(chunked_params))

{'conv2_d': (1024,), 'linear': (21504,)}

In [9]:
unchunk = UnChunkCNN(1024, 4*256, param_shapes)
unchunked_params = unchunk(chunked_params)
# assert same shape
print(tree_shape(params) == tree_shape(unchunked_params))
#chex.assert_trees_all_close(params, unchunked_params, rtol=1, atol=1)
jax.tree_map(lambda x, y: jnp.allclose(x, y), params, unchunked_params)

IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(None, {'b': (2,), 'w': (3, 3, 1, 2)}, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).

In [87]:
import jax.numpy as jnp
import dataclasses
from typing import Dict, Tuple

@dataclasses.dataclass
class UnChunkCNN:
    linear_chunk_size: int
    conv_chunk_size: int
    original_param_shapes: Dict[str, Dict[str, Tuple[int]]]

    def __call__(self, chunked_params: dict) -> dict:
        """Merge CNN weight chunks back into the original parameters."""

        def unchunk_layers():
            unchunked_params = {}
            for key, chunk in chunked_params.items():
                layer_key, _, chunk_idx = key.rpartition('_chunk_')
                chunk_idx = int(chunk_idx)
                unchunked_params.setdefault(layer_key, []).extend([None] * (1 + chunk_idx - len(unchunked_params[layer_key])))
                unchunked_params[layer_key][chunk_idx] = chunk
            return unchunked_params

        def concatenate_chunks(unchunked_params):
            return {k: jnp.concatenate(vs) for k, vs in unchunked_params.items()}

        def reshape_layer_params(layer_key, layer_data):
            layer_shapes = self.original_param_shapes[layer_key]
            layer_params = {}
            start_idx = 0
            for param_key, param_shape in layer_shapes.items():
                param_size = jnp.prod(jnp.array(param_shape))
                layer_data_sliced = layer_data[start_idx:start_idx + param_size]
                layer_params[param_key] = jnp.reshape(layer_data_sliced, param_shape)
                start_idx += param_size
            return layer_params

        unchunked_params = unchunk_layers()
        concatenated_params = concatenate_chunks(unchunked_params)
        original_params = {layer_key: reshape_layer_params(layer_key, layer_data) for layer_key, layer_data in concatenated_params.items()}

        return original_params


In [58]:
import jax.numpy as jnp
import dataclasses
from typing import Dict, Tuple
import re

@dataclasses.dataclass
class UnChunkCNN:
    linear_chunk_size: int
    conv_chunk_size: int
    original_param_shapes: Dict[str, Dict[str, Tuple[int]]]

    def __call__(self, chunked_params: dict) -> dict:
        """Merge CNN weight chunks back into the original parameters."""
        # First, unchunk the layers
        unchunked_params = {}
        for key, chunk in chunked_params.items():
            layer_key, _, chunk_idx = key.rpartition('_chunk_')
            chunk_idx = int(chunk_idx)

            if layer_key not in unchunked_params:
                unchunked_params[layer_key] = []

            while len(unchunked_params[layer_key]) <= chunk_idx:
                unchunked_params[layer_key].append(None)

            unchunked_params[layer_key][chunk_idx] = chunk

        # Then, concatenate the chunks
        unchunked_params = {
            k: jnp.concatenate(vs) for k, vs in unchunked_params.items()
        }

        # Finally, reshape the parameters into their original shapes
        original_params = {}
        for layer_key, layer_data in unchunked_params.items():
            if layer_key not in self.original_param_shapes:
                raise ValueError(f"Layer key {layer_key} not found in original_param_shapes")

            layer_shapes = self.original_param_shapes[layer_key]

            layer_params = {}
            start_idx = 0
            for param_key, param_shape in layer_shapes.items():
                param_size = jnp.prod(jnp.array(param_shape))
                layer_params[param_key] = jnp.reshape(
                    layer_data[start_idx:start_idx + param_size], param_shape
                )
                start_idx += param_size

            original_params[layer_key] = layer_params

        return original_params


In [8]:
# chunked_params = chunk_cnn(params, 1024, 256)
# jax.tree_util.tree_map(lambda x: x.shape, chunked_params)

In [9]:
utils.count_params(params) / 1e6

4.23105

In [10]:
stacked = utils.tree_stack([params, params])
jax.tree_map(lambda x: x.shape, stacked)

{'conv2_d': {'b': (2, 32), 'w': (2, 3, 3, 1, 32)},
 'conv2_d_1': {'b': (2, 32), 'w': (2, 3, 3, 32, 32)},
 'conv2_d_2': {'b': (2, 32), 'w': (2, 3, 3, 32, 32)},
 'linear': {'b': (2, 128), 'w': (2, 32768, 128)},
 'linear_1': {'b': (2, 128), 'w': (2, 128, 128)},
 'linear_2': {'b': (2, 10), 'w': (2, 128, 10)}}

In [11]:
%autoreload

In [16]:
def model_fn(params: dict):
    net = MetaModelClassifier(
        model_size=4*32, 
        num_classes=10, 
        transformer=Transformer(
            num_heads=4,
            num_layers=2,
            key_size=32,
            dropout_rate=0.0,
        ))
    return net(params)


model = hk.transform(model_fn)
    
key, subkey = random.split(key)
meta_params = jax.jit(model.init)(subkey, stacked)
model_forward = jax.jit(model.apply)
key, subkey = random.split(key)
out = model_forward(meta_params, subkey, stacked)

In [17]:
out.shape

(2, 10)

In [18]:
out

Array([[-0.21297008, -0.2555064 , -0.22231516,  1.1032469 , -0.36101305,
        -0.2299385 , -0.43089166, -0.07022727, -1.0288874 , -0.95691407],
       [-0.21297008, -0.2555064 , -0.22231516,  1.1032469 , -0.36101305,
        -0.2299385 , -0.43089166, -0.07022727, -1.0288874 , -0.95691407]],      dtype=float32)

In [19]:
out[0] - out[1] 

Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)