In [6]:
from hydra import initialize, compose
import dotenv
import os
import pathlib
import torch

from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.models import ModelFactory


In [7]:
def get_mod(run_id: str, device):
    with initialize("../configs", version_base="1.2.0"):
        cfg = compose(
            "config.yaml",
            overrides=[
                "compute.distributed=False",
                "dataset=imagenet",
                "model=vit",
                f"experiment.run_id={run_id}",
                "training.batch_size=2",
            ],
        )
    dotenv.load_dotenv("../.env", override=True)
    os.environ["IMAGE_NET_PATH"]
    checkpoint_dir = pathlib.Path(f"../artifacts/checkpoints/20230601_{run_id}")
    checkpoint = Checkpoint.load_best_checkpoint(checkpoint_dir=checkpoint_dir)
    model_state = checkpoint.model
    model = ModelFactory.load_model(
        model=cfg.model.name, dataset=cfg.dataset.name, diet=cfg.rigl.diet
    )
    model.to(device)
    try:
        model.load_state_dict(model_state)
    except RuntimeError:
        model_state = (
            checkpoint.get_single_process_model_state_from_distributed_state()
        )
        model.load_state_dict(model_state)
    return model.get_submodule("encoder.layers.encoder_layer_11.mlp.0")


__RUN_IDS = {90: "nrblbn15"}

# t_fc = get_mod(__RUN_IDS[90], "cpu") # Run me if you have the artifact on this device

with open("../artifacts/trained_vit_layers/vit16-mlp-layer-90-torch.pkl", "rb") as handle:
    t_fc = torch.load(handle)



In [8]:
import torch

import jax
from typing import Any, Callable, Sequence, Optional, Tuple, Union
from jax import random, vmap, numpy as jnp
import flax
from flax import linen as nn
import numpy as np
from functools import partial


with torch.no_grad():
    kernel = t_fc.weight.detach().cpu().numpy()
    print(kernel.shape)
    bias = t_fc.bias.detach().cpu().numpy()

    # [outC, inC] -> [inC, outC]
    kernel = jnp.transpose(kernel, (1, 0))

    key = random.key(0)
    x = random.normal(key, (64, t_fc.in_features))

    variables = {'params': {'kernel': kernel, 'bias': bias}}
    j_fc = nn.Dense(features=t_fc.out_features)
    j_out = j_fc.apply(variables, x)

    t_x = torch.from_numpy(np.array(x))
    t_out = t_fc(t_x)
    t_out = t_out.detach().cpu().numpy()

    np.testing.assert_almost_equal(j_out, t_out, decimal=3)
    

(3072, 768)


In [9]:
# with open("../artifacts/trained_vit_layers/vit16-mlp-layer-90-torch.pkl", "wb") as handle:
#     torch.save(t_fc, handle)

In [10]:
from numpy.typing import DTypeLike
from jax.typing import ArrayLike 
from flax.core.scope import VariableDict
from copy import deepcopy


def _torch_get_active_neuron_idx(weight: torch.Tensor) -> torch.Tensor:
    # We find all-zero rows in first dimension of weight tensor
    return weight.sum(dim=list(range(1, weight.dim()))) != 0


def _torch_get_fine_grained_idx(
    weight: torch.Tensor, active_neuron_idx
) -> torch.Tensor:
    return (weight[active_neuron_idx] != 0).to(torch.bool)



def _get_active_neuron_idx(kernel: ArrayLike) -> jax.Array:
    # We find all-zero rows in first dimension of weight tensor
    # NOTE: Only works with fc for now, need to test conv later
    # return weight.sum(dim=list(range(1, weight.dim()))) != 0
    # return kernel.sum(axis=0)!=0  # we swap dim with torch
    return kernel.sum(axis=1)!=0  # we swap dim with torch


def _get_fine_grained_idx(
    kernel: ArrayLike, active_neuron_idx: ArrayLike
) -> jax.Array:
    return (kernel[active_neuron_idx] != 0).astype("bool")


kernel, bias = variables["params"]["kernel"].T, variables["params"]["bias"].T
active_neuron_idx = _get_active_neuron_idx(kernel)
fine_grained_idx = _get_fine_grained_idx(kernel, active_neuron_idx)

t_ani = _torch_get_active_neuron_idx(t_fc.weight)
t_fgi = _torch_get_fine_grained_idx(t_fc.weight, t_ani)



assert (active_neuron_idx == t_ani.numpy()).all()
assert (fine_grained_idx == t_fgi.numpy()).all()  # NOTE: transpose here

In [11]:
# struc_kernel = kernel[:, active_neuron_idx]
# struc_kernel[fine_grained_idx].shape
# condensed_kernel = struc_kernel[fine_grained_idx].reshape(-1, struc_kernel.shape[1])
struc_kernel = kernel[active_neuron_idx]
struc_kernel[fine_grained_idx].shape
condensed_kernel = struc_kernel[fine_grained_idx].reshape(struc_kernel.shape[0], -1)

In [12]:
fine_grained_idx.shape

(1145, 768)

In [13]:
fine_grained_idx.nonzero()

(Array([   0,    0,    0, ..., 1144, 1144, 1144], dtype=int32),
 Array([  1,   4,  17, ..., 757, 759, 767], dtype=int32))

In [15]:
from numpy.typing import DTypeLike
from jax.typing import ArrayLike
from flax.core.scope import VariableDict
from copy import deepcopy





def condensed_param_converter(dense_params: VariableDict, dtype: Optional[DTypeLike]=None) -> VariableDict:
    dense_params = deepcopy(dense_params)
    kernel, bias = dense_params["params"]["kernel"].T, dense_params["params"]["bias"].T
    # Without transpose here I found broadcasting issues in original condensed implementation
    if dtype is None:
        dtype = kernel.dtype

    
    active_neuron_idx = _get_active_neuron_idx(kernel)
    fine_grained_idx = _get_fine_grained_idx(kernel, active_neuron_idx)
    struct_kernel = kernel[active_neuron_idx]
    condensed_kernel = struc_kernel[fine_grained_idx].reshape(struct_kernel.shape[0], -1)
    # struct_kernel = kernel[:, active_neuron_idx]
    # condensed_kernel = struc_kernel[fine_grained_idx].reshape(-1, struct_kernel.shape[1])
    
    # TODO: Can speed-up the below, we used torch.nonzero(as_tuple=True) prev
    idxs = []
    for neuron in fine_grained_idx:
        idxs.append(jnp.argwhere(neuron!=0).flatten())
    indx_seqs = jnp.stack(idxs)
    return dict(
        params=dict(
            kernel=condensed_kernel,
            bias=bias[active_neuron_idx],
            indx_seqs=indx_seqs
        )
    )

condensed_params = condensed_param_converter(variables)




In [16]:
class CondensedLinear(nn.Module):
    features: int
    fan_in: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, input: ArrayLike) -> jax.Array:
        kernel = self.param("kernel", self.kernel_init, (self.features, self.fan_in))
        bias = self.param("bias", self.bias_init, (self.features,))
        indx_seqs = self.param("indx_seqs", self.kernel_init, (self.features, self.fan_in))
        return jnp.sum(kernel * input[:, indx_seqs], axis=2) + bias
    

class CondensedLinearWithState(nn.Module):
    kernel: ArrayLike
    bias: ArrayLike
    indx_seqs: ArrayLike
        
    def __call__(self, input: ArrayLike) -> jax.Array:
        return jnp.sum(self.kernel * input[:, self.indx_seqs], axis=2) + self.bias

In [17]:
# return jnp.sum(variables["params"]["kernel"] * input[:, variables["params"]["indx_seqs"]], axis=2) + variables["params"]["bias"]
cl = CondensedLinear(features=1145, fan_in=206)
cl_fast = jax.jit(lambda x: cl.apply(condensed_params, x))
cl_fast(x)

Array([[ 0.06890744, -0.1435932 ,  0.18185034, ...,  0.19422379,
        -1.0179393 , -0.8103184 ],
       [-0.5491562 , -0.02691372, -0.9410149 , ...,  0.36850947,
         0.23354463, -0.14654508],
       [ 0.13191406, -0.25492084,  0.6769124 , ...,  0.5262484 ,
        -0.12328374, -0.06665101],
       ...,
       [ 0.78998613, -0.04756172,  0.49412474, ...,  0.3047766 ,
        -0.10613711,  1.3015689 ],
       [ 0.5847777 ,  0.41667846, -0.80086994, ...,  0.18786182,
         0.54880977, -0.97745925],
       [ 0.6309595 , -0.260169  , -1.5061429 , ..., -0.5892743 ,
         0.8418522 , -0.12319788]], dtype=float32)

In [18]:
%%timeit 
cl_fast(x).block_until_ready()

175 µs ± 7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
j_fc_fast = jax.jit(lambda x: j_fc.apply(variables, x))
j_fc_fast(x)

Array([[ 2.2639807e-12,  6.8915978e-02,  1.9262955e-12, ...,
        -1.2273963e-12, -8.1014007e-01, -8.9613508e-12],
       [ 2.2639807e-12, -5.4868144e-01,  1.9262955e-12, ...,
        -1.2273963e-12, -1.4640963e-01, -8.9613508e-12],
       [ 2.2639807e-12,  1.3192402e-01,  1.9262955e-12, ...,
        -1.2273963e-12, -6.6684671e-02, -8.9613508e-12],
       ...,
       [ 2.2639807e-12,  7.9031599e-01,  1.9262955e-12, ...,
        -1.2273963e-12,  1.3017950e+00, -8.9613508e-12],
       [ 2.2639807e-12,  5.8474576e-01,  1.9262955e-12, ...,
        -1.2273963e-12, -9.7734028e-01, -8.9613508e-12],
       [ 2.2639807e-12,  6.3065708e-01,  1.9262955e-12, ...,
        -1.2273963e-12, -1.2312210e-01, -8.9613508e-12]], dtype=float32)

In [20]:
%%timeit
j_fc_fast(x).block_until_ready()

61.8 µs ± 1.51 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [21]:
cl_2 = CondensedLinearWithState(**condensed_params["params"])
cl_2_fast = jax.jit(lambda x: cl_2(x))
cl_2_fast(x)

Array([[ 0.06890744, -0.1435932 ,  0.18185034, ...,  0.19422379,
        -1.0179393 , -0.8103184 ],
       [-0.5491562 , -0.02691372, -0.9410149 , ...,  0.36850947,
         0.23354463, -0.14654508],
       [ 0.13191406, -0.25492084,  0.6769124 , ...,  0.5262484 ,
        -0.12328374, -0.06665101],
       ...,
       [ 0.78998613, -0.04756172,  0.49412474, ...,  0.3047766 ,
        -0.10613711,  1.3015689 ],
       [ 0.5847777 ,  0.41667846, -0.80086994, ...,  0.18786182,
         0.54880977, -0.97745925],
       [ 0.6309595 , -0.260169  , -1.5061429 , ..., -0.5892743 ,
         0.8418522 , -0.12319788]], dtype=float32)

In [22]:
%%timeit
cl_2_fast(x).block_until_ready()

168 µs ± 5.65 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [23]:
variables["params"]["bias"].shape

(3072,)

In [24]:
variables = {'params': {'kernel': kernel, 'bias': bias}}
j_fc = nn.Dense(features=t_fc.out_features)
j_out = j_fc.apply(variables, x)

ScopeParamShapeError: Initializer expected to generate shape (3072, 768) but got shape (768, 3072) instead for parameter "kernel" in "/". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

In [None]:
condensed_params["params"]["indx_seqs"].shape

(1145, 206)

In [None]:
x.shape

# TODO: Figure out this broadcasting buiness

(64, 768)

In [None]:
x.shape

(64, 768)

In [None]:
condensed_params["params"]["indx_seqs"].shape

(206, 1145)

In [None]:
condensed_params["params"]["indx_seqs"]

Array([[  1,   1,   1, ...,   1,   2,   1],
       [  4,   2,   2, ...,   2,  34,   2],
       [ 17,  17,  16, ...,   4,  42,  18],
       ...,
       [757, 759, 758, ..., 757, 759, 757],
       [759, 761, 764, ..., 762, 764, 759],
       [762, 767, 767, ..., 767, 767, 767]], dtype=int32)

In [None]:
t_fc.weight.shape

torch.Size([3072, 768])

In [None]:
condensed_params["params"]["kernel"].shape

(206, 1145)

In [None]:
jax.vmap(x, in_axes=0)

(64, 768)

In [None]:
print(condensed_params["params"]["kernel"].shape)
print(condensed_params["params"]["indx_seqs"].shape)
print(x.shape)

(206, 1145)
(206, 1145)
(64, 768)


In [None]:

def forward_orig(variables, input) -> ArrayLike:
    return jnp.sum(variables["params"]["kernel"] * input[:, variables["params"]["indx_seqs"]], axis=2) + variables["params"]["bias"]

condensed_out = forward_orig(condensed_params, x)