In [2]:
from hydra import initialize, compose
import dotenv
import os
import pathlib
import torch

from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.models import ModelFactory


In [3]:
def get_mod(run_id: str, device):
    with initialize("../configs", version_base="1.2.0"):
        cfg = compose(
            "config.yaml",
            overrides=[
                "compute.distributed=False",
                "dataset=imagenet",
                "model=vit",
                f"experiment.run_id={run_id}",
                "training.batch_size=2",
            ],
        )
    dotenv.load_dotenv("../.env", override=True)
    os.environ["IMAGE_NET_PATH"]
    checkpoint_dir = pathlib.Path(f"../artifacts/checkpoints/20230601_{run_id}")
    checkpoint = Checkpoint.load_best_checkpoint(checkpoint_dir=checkpoint_dir)
    model_state = checkpoint.model
    model = ModelFactory.load_model(
        model=cfg.model.name, dataset=cfg.dataset.name, diet=cfg.rigl.diet
    )
    model.to(device)
    try:
        model.load_state_dict(model_state)
    except RuntimeError:
        model_state = (
            checkpoint.get_single_process_model_state_from_distributed_state()
        )
        model.load_state_dict(model_state)
    return model.get_submodule("encoder.layers.encoder_layer_11.mlp.0")


__RUN_IDS = {90: "nrblbn15"}

# t_fc = get_mod(__RUN_IDS[90], "cpu") # Run me if you have the artifact on this device

with open("../artifacts/trained_vit_layers/vit16-mlp-layer-90-torch.pkl", "rb") as handle:
    t_fc = torch.load(handle)



In [5]:
t_fc.out_features

3072

In [7]:
import torch

import jax
from typing import Any, Callable, Sequence, Optional, Tuple, Union
from jax import random, vmap, numpy as jnp
import flax
from flax import linen as nn
import numpy as np
from functools import partial


with torch.no_grad():
    kernel = t_fc.weight.detach().cpu().numpy()
    print(kernel.shape)
    bias = t_fc.bias.detach().cpu().numpy()

    # [outC, inC] -> [inC, outC]
    kernel = jnp.transpose(kernel, (1, 0))

    key = random.key(0)
    x = random.normal(key, (64, t_fc.in_features))

    variables = {'params': {'kernel': kernel, 'bias': bias}}
    j_fc = nn.Dense(features=t_fc.out_features)
    j_out = j_fc.apply(variables, x)

    t_x = torch.from_numpy(np.array(x))
    t_out = t_fc(t_x)
    t_out = t_out.detach().cpu().numpy()

    np.testing.assert_almost_equal(j_out, t_out, decimal=3)
    

(3072, 768)


In [17]:
# with open("../artifacts/trained_vit_layers/vit16-mlp-layer-90-torch.pkl", "wb") as handle:
#     torch.save(t_fc, handle)

In [32]:
type(variables["params"]["kernel"])

jaxlib.xla_extension.ArrayImpl

In [19]:
print(j_fc.param_dtype)

<class 'jax.numpy.float32'>


In [None]:
jax.

In [24]:
variables["params"]["kernel"].dtype

dtype('float32')

In [30]:
(t_fc.weight.sum(dim=1) !=0).sum()

tensor(1145)

In [31]:
weights = variables["params"]["kernel"]
(weights.sum(axis=0)!=0).sum()

Array(1145, dtype=int32)

In [34]:
weights.shape  # features, num_neurons

(768, 3072)

In [36]:
((weights[:, weights.sum(axis=0)!=0 ]).sum(axis=0)!=0).sum()

Array(1145, dtype=int32)

In [38]:
(weights[:, weights.sum(axis=0)!=0 ]).shape

(768, 1145)

In [58]:
from numpy.typing import DTypeLike
from jax.typing import ArrayLike 
from flax.core.scope import VariableDict
from copy import deepcopy


def _torch_get_active_neuron_idx(weight: torch.Tensor) -> torch.Tensor:
    # We find all-zero rows in first dimension of weight tensor
    return weight.sum(dim=list(range(1, weight.dim()))) != 0


def _torch_get_fine_grained_idx(
    weight: torch.Tensor, active_neuron_idx
) -> torch.Tensor:
    return (weight[active_neuron_idx] != 0).to(torch.bool)



def _get_active_neuron_idx(kernel: ArrayLike) -> jax.Array:
    # We find all-zero rows in first dimension of weight tensor
    # NOTE: Only works with fc for now, need to test conv later
    # return weight.sum(dim=list(range(1, weight.dim()))) != 0
    return kernel.sum(axis=0)!=0  # we swap dim with torch


def _get_fine_grained_idx(
    kernel: ArrayLike, active_neuron_idx: ArrayLike
) -> jax.Array:
    return (kernel[:, active_neuron_idx] != 0).astype("bool")


kernel, bias = variables["params"]["kernel"], variables["params"]["bias"]
active_neuron_idx = _get_active_neuron_idx(kernel)
fine_grained_idx = _get_fine_grained_idx(kernel, active_neuron_idx)

t_ani = _torch_get_active_neuron_idx(t_fc.weight)
t_fgi = _torch_get_fine_grained_idx(t_fc.weight, t_ani)



assert (active_neuron_idx == t_ani.numpy()).all()
assert (fine_grained_idx == t_fgi.T.numpy()).all()  # NOTE: transpose here

In [81]:
struc_kernel = kernel[:, active_neuron_idx]
struc_kernel[fine_grained_idx].shape
condensed_kernel = struc_kernel[fine_grained_idx].reshape(-1, struc_kernel.shape[1])

In [82]:
fine_grained_idx.nonzero()

(Array([  0,   0,   0, ..., 767, 767, 767], dtype=int32),
 Array([   8,   14,   21, ..., 1142, 1143, 1144], dtype=int32))

In [111]:
jnp.flatnonzero(fine_grained_idx).reshape(-1, struc_kernel.shape[1])

Array([[     8,     14,     21, ...,   3058,   3059,   3060],
       [  3063,   3064,   3065, ...,  18359,  18360,  18361],
       [ 18363,  18364,  18365, ...,  20575,  20576,  20577],
       ...,
       [867631, 867633, 867635, ..., 871368, 871373, 871380],
       [871384, 871385, 871388, ..., 875413, 875414, 875415],
       [875417, 875420, 875421, ..., 879357, 879358, 879359]],      dtype=int32)

In [133]:
fine_grained_idx.shape

(768, 1145)

In [134]:
(fine_grained_idx[:, 0]!=0).sum()

Array(206, dtype=int32)

In [None]:
fine_grained_idx[]

In [139]:
idx = jnp.argwhere(fine_grained_idx[:, 0]!=0).flatten()
fine_grained_idx[idx, 0]

Array(True, dtype=bool)

In [155]:
idxs = []
for neuron in fine_grained_idx.T:
    idxs.append(jnp.argwhere(neuron!=0).flatten())

idx_seqs = jnp.stack(idxs).T

In [156]:
idx_seqs.shape

(206, 1145)

In [159]:
struc_kernel.shape

(768, 1145)

In [158]:
idx_seqs

Array([[  1,   1,   1, ...,   1,   2,   1],
       [  4,   2,   2, ...,   2,  34,   2],
       [ 17,  17,  16, ...,   4,  42,  18],
       ...,
       [757, 759, 758, ..., 757, 759, 757],
       [759, 761, 764, ..., 762, 764, 759],
       [762, 767, 767, ..., 767, 767, 767]], dtype=int32)

In [157]:
struc_kernel[idx_seqs].shape

(206, 1145, 1145)

In [144]:
idx = jnp.argwhere(fine_grained_idx!=0)
idx

Array([[   0,    8],
       [   0,   14],
       [   0,   21],
       ...,
       [ 767, 1142],
       [ 767, 1143],
       [ 767, 1144]], dtype=int32)

In [140]:
indx_seqs = jax.vmap(lambda row: jnp.argwhere(row!=0), in_axes=1)(fine_grained_idx)
# struc_kernel[indx_seqs]

# indx_seqs = indx_seqs.reshape(-1, struc_kernel.shape[1])
# indx_seqs.shape

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[].
The size argument of jnp.nonzero must be statically specified to use jnp.nonzero within JAX transformations.
This BatchTracer with object id 140649380001440 was created on line:
  /tmp/ipykernel_87785/1142911197.py:1 (<lambda>)

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [116]:
fine_grained_idx.shape

(768, 1145)

In [107]:
indx_seqs

Array([[   8,   14,   21, ...,  768,  769,  770],
       [ 773,  774,  775, ...,   39,   40,   41],
       [  43,   44,   45, ..., 1110, 1111, 1112],
       ...,
       [ 866,  868,  870, ...,   23,   28,   35],
       [  39,   40,   43, ...,  633,  634,  635],
       [ 637,  640,  641, ..., 1142, 1143, 1144]], dtype=int32)

In [117]:
_, t_input_mask = t_fgi.nonzero(as_tuple=True)
t_input_mask = t_input_mask.reshape(
            shape=(t_fc.weight[t_ani].shape[0], -1)
        )
t_input_mask.shape

# assert (t_input_mask.T.numpy() == indx_seqs).all()

torch.Size([1145, 206])

In [None]:
_, t_input_mask = self.fine_grained_idx.nonzero(as_tuple=True)
self.input_mask = self.input_mask.reshape(
    shape=(module.weight[self.active_neuron_idx].shape[0], -1)
)

In [164]:
from numpy.typing import DTypeLike
from jax.typing import ArrayLike
from flax.core.scope import VariableDict
from copy import deepcopy





def condensed_param_converter(dense_params: VariableDict, dtype: Optional[DTypeLike]=None) -> VariableDict:
    dense_params = deepcopy(dense_params)
    kernel, bias = dense_params["params"]["kernel"], dense_params["params"]["bias"]
    if dtype is None:
        dtype = kernel.dtype

    
    active_neuron_idx = _get_active_neuron_idx(kernel)
    fine_grained_idx = _get_fine_grained_idx(kernel, active_neuron_idx)
    struct_kernel = kernel[:, active_neuron_idx]
    condensed_kernel = struc_kernel[fine_grained_idx].reshape(-1, struct_kernel.shape[1])
    idxs = []
    for neuron in fine_grained_idx.T:
        idxs.append(jnp.argwhere(neuron!=0).flatten())
    indx_seqs = jnp.stack(idxs).T
    return dict(
        params=dict(
            kernel=condensed_kernel,
            bias=bias[active_neuron_idx],
            indx_seqs=indx_seqs
        )
    )

condensed_params = condensed_param_converter(variables)

# class CondensedLinear(nn.Module):
#     dense_params:
#     dtype: Optional[DTypeLike] = None

    
#     def setup(self):
#         if self.dtype is None:
#             self.dtype = self.module.param_dtype
#         self.active_neuron_idx = self.module.weight.sum(dim=1) != 0
#         self.fine_grained_idx = (self.module.weight[self.active_neuron_idx] != 0).to(
#             torch.bool
#         )
#         _, self.input_mask = self.fine_grained_idx.nonzero(as_tuple=True)
#         self.input_mask = self.input_mask.reshape(
#             shape=(module.weight[self.active_neuron_idx].shape[0], -1)
#         )
#         with torch.no_grad():
#             # self.weight = nn.Parameter(
#             #     module.weight[self.active_neuron_idx].contiguous()
#             # )
#             # self.condensed_weight = nn.Parameter(
#             #     self.weight[self.fine_grained_idx]
#             #     .reshape(shape=(self.weight.shape[0], -1))
#             #     .contiguous()
#             # )
#             # self.sparse_weight = nn.Parameter(
#             #     self.weight.to_sparse_csr()
#             # )
#             # if hasattr(module, "bias"):
#             #     self.bias = nn.Parameter(
#             #         module.bias[self.active_neuron_idx].contiguous()
#             #     )
#             # else:
#             #     self.register_parameter("bias", None)
#             self.weight = nn.Parameter(
#                 torch.clone(
#                     module.weight[self.active_neuron_idx].detach().type(dtype)
#                 )
#             )
#             self.condensed_weight = nn.Parameter(
#                 torch.clone(
#                     self.weight[self.fine_grained_idx]
#                     .reshape(shape=(self.weight.shape[0], -1))
#                     .detach()
#                     .type(dtype)
#                 ),
#                 requires_grad=False,
#             )
#             self.sparse_weight = nn.Parameter(
#                 torch.clone(self.weight.detach().type(dtype).to_sparse_csr()),
#                 requires_grad=False,
#             )
#             if hasattr(module, "bias"):
#                 self.bias = nn.Parameter(
#                     torch.clone(
#                         module.bias[self.active_neuron_idx].detach().type(dtype)
#                     ),
#                     requires_grad=False,
#                 )
#             else:
#                 self.register_parameter("bias", None)
        
#     def __call__():
#         pass

In [163]:
variables["params"]["bias"].shape

(3072,)

In [166]:
variables = {'params': {'kernel': kernel, 'bias': bias}}
j_fc = nn.Dense(features=t_fc.out_features)
j_out = j_fc.apply(variables, x)

In [170]:
condensed_params["params"]["indx_seqs"].shape

(206, 1145)

In [171]:
x.shape

# TODO: Figure out this broadcasting buiness

(64, 768)

In [172]:
# def forward_orig(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray) -> jnp.ndarray:
#     return jnp.sum(weights * input[:, indx_seqs], axis=2)

def forward_orig(variables, input) -> ArrayLike:
    return jnp.sum(variables["params"]["kernel"] * input[:, variables["params"]["indx_seqs"].T], axis=2) + variables["params"]["bias"]

condensed_out = forward_orig(condensed_params, x)

ValueError: Incompatible shapes for broadcasting: shapes=[(206, 1145), (64, 1145, 206)]