In [1]:
from hydra import initialize, compose
import dotenv
import os
import pathlib
import torch

from rigl_torch.utils.checkpoint import Checkpoint
from rigl_torch.models import ModelFactory


In [2]:
def get_mod(run_id: str, device):
    with initialize("../configs", version_base="1.2.0"):
        cfg = compose(
            "config.yaml",
            overrides=[
                "compute.distributed=False",
                "dataset=imagenet",
                "model=vit",
                f"experiment.run_id={run_id}",
                "training.batch_size=2",
            ],
        )
    dotenv.load_dotenv("../.env", override=True)
    os.environ["IMAGE_NET_PATH"]
    checkpoint_dir = pathlib.Path(f"../artifacts/checkpoints/20230601_{run_id}")
    checkpoint = Checkpoint.load_best_checkpoint(checkpoint_dir=checkpoint_dir)
    model_state = checkpoint.model
    model = ModelFactory.load_model(
        model=cfg.model.name, dataset=cfg.dataset.name, diet=cfg.rigl.diet
    )
    model.to(device)
    try:
        model.load_state_dict(model_state)
    except RuntimeError:
        model_state = (
            checkpoint.get_single_process_model_state_from_distributed_state()
        )
        model.load_state_dict(model_state)
    return model.get_submodule("encoder.layers.encoder_layer_11.mlp.0")


__RUN_IDS = {90: "nrblbn15"}

# t_fc = get_mod(__RUN_IDS[90], "cpu") # Run me if you have the artifact on this device

with open("../artifacts/trained_vit_layers/vit16-mlp-layer-90-torch.pkl", "rb") as handle:  # TODO: try skinnier layer
    t_fc = torch.load(handle)



In [3]:
import torch

import jax
from typing import Any, Callable, Sequence, Optional, Tuple, Union
from jax import random, vmap, numpy as jnp
import flax
from flax import linen as nn
import numpy as np
from functools import partial

# _dtype = jnp.bfloat16 # faster on gpu
_dtype = jnp.float32 # faster on cpu @ batch size 1. slower at 64
# t_fc = t_fc.to(torch.bfloat16) # try bf16, Time to beat (176micro for dense, 137 micro for fastest condensed)
# conversion to jax/flax
with torch.no_grad():
    kernel = t_fc.weight.detach().cpu().numpy()
    bias = t_fc.bias.detach().cpu().numpy()

    # [outC, inC] -> [inC, outC]
    kernel = jnp.transpose(kernel, (1, 0)).astype(_dtype)

    key = random.key(0)
    x = random.normal(key, (64, t_fc.in_features))

    variables = {'params': {'kernel': kernel, 'bias': bias.astype(_dtype)}}
    j_fc = nn.Dense(features=t_fc.out_features)
    j_out = j_fc.apply(variables, x)

    t_x = torch.from_numpy(np.array(x))
    t_out = t_fc(t_x)
    t_out = t_out.detach().cpu().numpy()

    np.testing.assert_almost_equal(j_out, t_out, decimal=2)
    

I0000 00:00:1696012059.441441 3998717 tfrt_cpu_pjrt_client.cc:349] TfrtCpuClient created.
2023-09-29 12:27:39.452337: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:276] failed call to cuInit: CUDA_ERROR_COMPAT_NOT_SUPPORTED_ON_DEVICE: forward compatibility was attempted on non supported HW
2023-09-29 12:27:39.452809: E external/xla/xla/stream_executor/cuda/cuda_diagnostics.cc:312] kernel version 535.86.10 does not match DSO version 535.104.5 -- cannot find working devices in this configuration
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [4]:
input_size = t_fc.in_features
layer_width = t_fc.out_features
batch_size = 16


key = random.PRNGKey(42)
key, subkey = random.split(key)
x = random.normal(subkey, (batch_size, input_size), dtype=_dtype)
x = jax.device_put(x)

dense_layer = nn.Dense(features=layer_width, use_bias=True)
dense_params = variables


In [5]:
%timeit dense_layer.apply(dense_params, x).block_until_ready()

9.87 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
dense_fast = jax.jit(lambda x: dense_layer.apply(dense_params, x))

In [7]:
dense_fast(x).block_until_ready()

Array([[ 2.2639807e-12,  7.5584012e-01,  1.9262955e-12, ...,
        -1.2273963e-12,  8.1965402e-02, -8.9613508e-12],
       [ 2.2639807e-12, -2.0178687e-02,  1.9262955e-12, ...,
        -1.2273963e-12,  1.0754541e+00, -8.9613508e-12],
       [ 2.2639807e-12,  7.9683006e-01,  1.9262955e-12, ...,
        -1.2273963e-12, -4.9104637e-01, -8.9613508e-12],
       ...,
       [ 2.2639807e-12,  3.0003309e-01,  1.9262955e-12, ...,
        -1.2273963e-12, -1.6757858e-01, -8.9613508e-12],
       [ 2.2639807e-12, -1.5575090e-02,  1.9262955e-12, ...,
        -1.2273963e-12, -5.2756774e-01, -8.9613508e-12],
       [ 2.2639807e-12, -2.2859134e-01,  1.9262955e-12, ...,
        -1.2273963e-12,  1.3307133e-01, -8.9613508e-12]], dtype=float32)

In [8]:
%timeit dense_fast(x).block_until_ready()

323 µs ± 7.51 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [9]:
# Flax condensed sparsity

from numpy.typing import DTypeLike
from jax.typing import ArrayLike
from flax.core.scope import VariableDict
from copy import deepcopy

def condensed_param_converter(dense_params: VariableDict, dtype: Optional[DTypeLike]=None) -> VariableDict:
    """Convert dense tensor with sparse weights into condensed version"""
    dense_params = deepcopy(dense_params)
    kernel, bias = dense_params["params"]["kernel"].T, dense_params["params"]["bias"].T
    # Without transpose here I found broadcasting issues in original condensed implementation
    if dtype is None:
        dtype = kernel.dtype
    active_neuron_idx = _get_active_neuron_idx(kernel)
    fine_grained_idx = _get_fine_grained_idx(kernel, active_neuron_idx)
    struct_kernel = kernel[active_neuron_idx]
    condensed_kernel = struct_kernel[fine_grained_idx].reshape(struct_kernel.shape[0], -1)

    # TODO: Can speed-up the below, we used torch.nonzero(as_tuple=True)
    # previously, need to translate the typical 2D tensor output from jax.nonzero into the same
    # format. We don't really care about speed here for our purposes anyways
    idxs = []
    for neuron in fine_grained_idx:
        idxs.append(jnp.argwhere(neuron!=0).flatten())
    indx_seqs = jnp.stack(idxs)
    return dict(
        params=dict(
            kernel=condensed_kernel,
            bias=bias[active_neuron_idx],
            indx_seqs=indx_seqs
        )
    )

def _get_active_neuron_idx(kernel: ArrayLike) -> jax.Array:
  # We find all-zero rows in first dimension of weight tensor
  return kernel.sum(axis=list(range(1, kernel.ndim))) != 0


def _get_fine_grained_idx(
    kernel: ArrayLike, active_neuron_idx: ArrayLike
) -> jax.Array:
    return (kernel[active_neuron_idx] != 0).astype("bool")

class CondensedLinear(nn.Module):
    features: int
    fan_in: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    @nn.compact
    def __call__(self, input: ArrayLike) -> jax.Array:
        kernel = self.param("kernel", self.kernel_init, (self.features, self.fan_in))
        bias = self.param("bias", self.bias_init, (self.features,))
        indx_seqs = self.param("indx_seqs", self.kernel_init, (self.features, self.fan_in))
        return jnp.sum(kernel * input[:, indx_seqs], axis=2) + bias


condensed_params = condensed_param_converter(variables)
jax.tree_util.tree_map(lambda x: print(x.shape), variables)
jax.tree_util.tree_map(lambda x: print(x.shape), condensed_params)


(3072,)
(768, 3072)
(1145,)
(1145, 206)
(1145, 206)


{'params': {'bias': None, 'indx_seqs': None, 'kernel': None}}

In [10]:
condensed_params["params"]["kernel"].shape # features, fan_in for condensed linear ctor

(1145, 206)

In [11]:
cl = CondensedLinear(*condensed_params["params"]["kernel"].shape)
cl_fast = jax.jit(lambda x: cl.apply(condensed_params, x))
cl_fast(x).block_until_ready()

Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,
        -0.00192568,  0.08196551],
       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,
        -0.4476214 ,  1.0754542 ],
       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,
        -0.48441118, -0.49104658],
       ...,
       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,
        -0.1293261 , -0.16757864],
       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,
        -0.12092257, -0.527568  ],
       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,
         0.5686784 ,  0.13307133]], dtype=float32)

In [12]:
%%timeit
cl_fast(x).block_until_ready()

2.8 ms ± 377 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
# key, subkey = random.split(key)
# input = random.uniform(subkey, (batch_size, input_size), dtype=jnp.float32)
# input = jax.device_put(input)

# # Create mmore realistic indx seqs by randomly shuffling and sampling
# indx_seqs_stack = []
# for i in range(layer_width):
#   key, subkey = random.split(key)
#   key, subkey2 = random.split(key)
#   indx_seqs_stack.append(jax.random.shuffle(subkey, jax.random.choice(subkey2, jnp.arange(input_size), (sparsity,))))
# indx_seqs = jnp.stack(indx_seqs_stack)
# indx_seqs = jax.device_put(indx_seqs)

# key, subkey = random.split(key)
# weights = random.uniform(subkey, (layer_width, sparsity))

weights, bias, indx_seqs = condensed_params['params']['kernel'], condensed_params['params']['bias'], condensed_params['params']['indx_seqs']
weights = jax.device_put(weights)
bias = jax.device_put(bias)
indx_seqs = jax.device_put(indx_seqs)
input = x

In [14]:
for a in [input, weights, indx_seqs]:
  print(a.shape)

(16, 768)
(1145, 206)
(1145, 206)


In [15]:
def forward_orig(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(weights * input[:, indx_seqs], axis=2) + bias

forward_orig_fast = jax.jit(forward_orig)
forward_orig_fast(input, weights, indx_seqs, bias).block_until_ready()

Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,
        -0.00192568,  0.08196551],
       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,
        -0.4476214 ,  1.0754542 ],
       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,
        -0.48441118, -0.49104658],
       ...,
       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,
        -0.1293261 , -0.16757864],
       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,
        -0.12092257, -0.527568  ],
       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,
         0.5686784 ,  0.13307133]], dtype=float32)

In [16]:
forward_orig_fast(input, weights, indx_seqs, bias).block_until_ready()

Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,
        -0.00192568,  0.08196551],
       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,
        -0.4476214 ,  1.0754542 ],
       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,
        -0.48441118, -0.49104658],
       ...,
       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,
        -0.1293261 , -0.16757864],
       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,
        -0.12092257, -0.527568  ],
       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,
         0.5686784 ,  0.13307133]], dtype=float32)

In [17]:
%timeit forward_orig(input, weights, indx_seqs, bias).block_until_ready()

6.69 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
%timeit forward_orig_fast(input, weights, indx_seqs, bias).block_until_ready()

2.73 ms ± 32.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
forward_orig_faster = jax.jit(partial(forward_orig, weights=weights, indx_seqs=indx_seqs, bias=bias))

In [20]:
forward_orig_faster(input)

Array([[ 0.75584006, -0.17063873,  1.6859026 , ...,  0.18504576,
        -0.00192568,  0.08196551],
       [-0.02017869, -0.71082234,  0.16039723, ...,  0.41155237,
        -0.4476214 ,  1.0754542 ],
       [ 0.7968303 , -0.10603629,  1.2908362 , ..., -0.32634115,
        -0.48441118, -0.49104658],
       ...,
       [ 0.30003327, -0.17096956,  0.10126591, ...,  0.6488617 ,
        -0.1293261 , -0.16757864],
       [-0.01557497, -0.39440617,  0.21288827, ...,  0.32918108,
        -0.12092257, -0.527568  ],
       [-0.22859119, -0.54928684,  0.2176865 , ...,  0.11042877,
         0.5686784 ,  0.13307133]], dtype=float32)

In [21]:
%timeit forward_orig_faster(input).block_until_ready()

2.63 ms ± 30.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


*italicized text*## Method #1: Use slicing/indexing and broadcasting

In [22]:
orig_output = forward_orig(input, weights, indx_seqs, bias)

In [23]:
# Do forward pass for a single neuron from a single batch
def forward_neuron_single(input: jnp.ndarray, weights: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(input[indices] * weights)

def forward_neuron_v(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    return vmap(partial(forward_neuron_single, input), in_axes=0, out_axes=0)(weights, indx_seqs) + bias

def forward_neuron(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    return vmap(partial(forward_neuron_v, weights=weights, indx_seqs=indx_seqs, bias=bias))(input)

# Do forward pass for all neurons over sparsity axis from a single batch
def forward_sparsity_single(input: jnp.ndarray, weights: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray:
    return input[indices] * weights

def forward_sparsity_v(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    output_neurons = vmap(partial(forward_sparsity_single, input), in_axes=1, out_axes=1)(weights, indx_seqs)
    return jnp.sum(output_neurons, axis=1) + bias

def forward_sparsity(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
    return vmap(partial(forward_sparsity_v, weights=weights, indx_seqs=indx_seqs, bias=bias))(input)

forward_neuron_fast = jax.jit(forward_neuron)
forward_neuron_faster = jax.jit(partial(forward_neuron, weights=weights, indx_seqs=indx_seqs, bias=bias))
forward_sparsity_fast = jax.jit(forward_sparsity)
forward_sparsity_faster = jax.jit(partial(forward_sparsity, weights=weights, indx_seqs=indx_seqs, bias=bias))



## Method #2: vmap over neuron/sparsity axes

In [24]:
# call once so JIT happens
fast_sparsity_output = forward_sparsity_fast(input, weights, indx_seqs, bias)
fast_sparsity_output_faster = forward_sparsity_faster(input)
assert jnp.allclose(orig_output, fast_sparsity_output)
assert jnp.allclose(orig_output, fast_sparsity_output_faster)

AssertionError: 

In [None]:
%timeit forward_sparsity(input, weights, indx_seqs, bias).block_until_ready()

In [None]:
%timeit forward_sparsity_fast(input, weights, indx_seqs, bias).block_until_ready()

In [None]:
%timeit forward_sparsity_faster(input).block_until_ready()

## Method #3: vmap over sparsity/neuron axes

In [None]:
# call once so JIT happens
fast_neuron_output = forward_neuron_fast(input, weights, indx_seqs, bias)
faster_neuron_output = forward_neuron_faster(input)
assert jnp.allclose(orig_output, fast_neuron_output)
assert jnp.allclose(orig_output, faster_neuron_output)

In [None]:
%timeit forward_neuron(input, weights, indx_seqs, bias).block_until_ready()

In [None]:
%timeit forward_neuron_fast(input, weights, indx_seqs, bias).block_until_ready()

In [None]:
%timeit forward_neuron_faster(input).block_until_ready()

In [None]:
class CondensedLinearVmapNeuron(nn.Module):
    features: int
    fan_in: int
    kernel_init: Callable = nn.initializers.lecun_normal()
    bias_init: Callable = nn.initializers.zeros_init()

    def forward_neuron(self, input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray, bias: jnp.ndarray) -> jnp.ndarray:
      return vmap(partial(forward_neuron_v, weights=weights, indx_seqs=indx_seqs, bias=bias))(input)

    @nn.compact
    def __call__(self, input: ArrayLike) -> jax.Array:
        kernel = self.param("kernel", self.kernel_init, (self.features, self.fan_in))
        bias = self.param("bias", self.bias_init, (self.features,))
        indx_seqs = self.param("indx_seqs", self.kernel_init, (self.features, self.fan_in))
        return self.forward_neuron(input, weights, indx_seqs, bias)

In [None]:
cl = CondensedLinearVmapNeuron(*condensed_params["params"]["kernel"].shape)
cl_fast = jax.jit(lambda x: cl.apply(condensed_params, x))
cl_fast(x).block_until_ready()

In [None]:
%timeit cl_fast(x).block_until_ready()

In [None]:
%timeit cl.apply(condensed_params, x).block_until_ready()

In [None]:
# Do forward pass for a single neuron from a single batch
def forward_batch_neuron_single(input: jnp.ndarray, weights: jnp.ndarray, indices: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(input[:, indices] * weights[None, :], axis=1)

def forward_batch_neuron(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray) -> jnp.ndarray:
    return vmap(partial(forward_batch_neuron_single, input), in_axes=0, out_axes=0)(weights, indx_seqs).T

def forward_batch_sparsity(input: jnp.ndarray, weights: jnp.ndarray, indx_seqs: jnp.ndarray) -> jnp.ndarray:
    return vmap(partial(forward_batch_neuron_single, input), in_axes=0, out_axes=1)(weights, indx_seqs)

forward_batch_neuron_fast = jax.jit(forward_batch_neuron)
forward_batch_sparsity_fast = jax.jit(forward_batch_sparsity)

In [None]:
forward_batch_neuron_output = forward_batch_neuron_fast(input, weights, indx_seqs)
forward_batch_sparsity_fast_output = forward_batch_sparsity_fast(input, weights, indx_seqs)
assert jnp.allclose(orig_output, forward_batch_neuron_output)  ## TODO: Add bias to above
assert jnp.allclose(orig_output, forward_batch_sparsity_fast_output)  ## TODO: Add bias to above

In [None]:
%timeit forward_batch_neuron(input, weights, indx_seqs).block_until_ready()

In [None]:
%timeit forward_batch_neuron_fast(input, weights, indx_seqs).block_until_ready()