In [1]:

import dataclasses
import chex
import jax
import jax.numpy as jnp
import functools
import optax
import pprint

In [2]:
from jaxpruner.sparsity_types import SparsityType, Unstructured, Block, NByM
from typing import Any, NamedTuple, Tuple, Union, List
from jaxpruner.base_updater import SparsityDistributionFnType
from jaxpruner.mask_calculator import get_topk_fn
from jaxpruner import mask_calculator

class ConstFanIn(NamedTuple):
    const_fan_in: List[int]
    
SparsityType = Union[Unstructured, Block, NByM, ConstFanIn]

In [37]:
@functools.partial(jax.jit, static_argnames="const_fan_in")
def const_fan_in_top_k(scores, const_fan_in: int):
    updated_mask = jnp.zeros(shape=scores.shape, dtype=mask_calculator.MASK_DTYPE)
    updated_mask.get[:, jnp.argsort(scores)[-const_fan_in:]].set(1)
    return updated_mask
    
    
    

def _const_fan_in_inner(scores, const_fan_in: int):
    return jnp.argmax(scores)
    

In [30]:
x = jax.random.uniform(jax.random.PRNGKey(42), shape=(3,3), minval=-1, maxval=1)
x

Array([[ 0.2878754 , -0.35496855, -0.61301255],
       [ 0.77298665,  0.44948697, -0.6161399 ],
       [-0.30973816, -0.49523377,  0.26380277]], dtype=float32)

In [38]:
jnp.arange(len(x))

Array([0, 1, 2], dtype=int32)

In [40]:
x.argsort(axis=0)

Array([[2, 2, 1],
       [0, 0, 0],
       [1, 1, 2]], dtype=int32)

In [36]:
x[jnp.arange(len(x)),x.argsort(axis=1)]

Array([[-0.61301255,  0.44948697, -0.30973816],
       [-0.61301255,  0.44948697, -0.30973816],
       [-0.35496855,  0.77298665,  0.26380277]], dtype=float32)

In [29]:
x[x.argsort(axis=-1) > 1]

Array([0.2878754 , 0.77298665, 0.26380277], dtype=float32)

In [19]:
x[0,x.argsort(axis=1)[0]]

Array([-0.61301255, -0.35496855,  0.2878754 ], dtype=float32)

In [20]:
x[:,x.argsort(axis=1)]

Array([[[-0.61301255, -0.35496855,  0.2878754 ],
        [-0.61301255, -0.35496855,  0.2878754 ],
        [-0.35496855,  0.2878754 , -0.61301255]],

       [[-0.6161399 ,  0.44948697,  0.77298665],
        [-0.6161399 ,  0.44948697,  0.77298665],
        [ 0.44948697,  0.77298665, -0.6161399 ]],

       [[ 0.26380277, -0.49523377, -0.30973816],
        [ 0.26380277, -0.49523377, -0.30973816],
        [-0.49523377, -0.30973816,  0.26380277]]], dtype=float32)

In [15]:
jnp.take_along_axis(x, x.argsort(order="ascending"), axis=1)

ValueError: 'order' argument to argsort is not supported.

In [14]:
x

Array([[ 0.2878754 , -0.35496855, -0.61301255],
       [ 0.77298665,  0.44948697, -0.6161399 ],
       [-0.30973816, -0.49523377,  0.26380277]], dtype=float32)

In [5]:
from jaxpruner.algorithms import RigL
import dataclasses

@dataclasses.dataclass
class ConstFanInPruner(RigL):
  sparsity_type: ConstFanIn
  sparsity_distribution_fn: SparsityDistributionFnType
  
  
  def __post_init__(self):
    if self.rng_seed is None:
      self.rng_seed = jax.random.PRNGKey(8)
    self.topk_drop_fn = get_topk_fn(Unstructured)
    self.topk_grow_fn = const_fan_in_top_k
  
  def get_initial_masks(self, params, target_sparsities):
    ...
  
  def _update_masks(self, old_mask, drop_score, grow_score, drop_fraction):
    density = jnp.sum(old_mask) / old_mask.size
    sparsity = 1 - density
    const_fan_in = jnp.sum(old_mask, axis=1)
    
    intermediate_density = density * (1-drop_fraction)
    intermediate_sparsity  = 1 - intermediate_density
    
    # Explicitly set inactive param scores to lowest possible
    lowest_score = jnp.min(drop_score) - self.eps
    new_drop_score = jnp.where(old_mask == 0, lowest_score, drop_score)
    dropped_mask = self.topk_drop_fn(new_drop_score, intermediate_sparsity)
    if self.is_debug:
      # All ones in the dropped mask should exist in the original mask.
      chex.assert_trees_all_close(
          jnp.sum(dropped_mask * (1 - old_mask)), jnp.array(0)
      )
    
    # Raise active connections to avoid considering for regrowth
    highest_score = jnp.max(grow_score) + self.eps
    new_grow_scores = jnp.where(dropped_mask == 1, highest_score, grow_score)
    updated_mask = self.topk_grow_fn(new_grow_scores, const_fan_in)
    
  def update_state(self, sparse_state, params, grads):
    drop_scores = self._get_drop_scores(sparse_state, params, grads)
    grow_scores = self._get_grow_scores(sparse_state, params, grads)
    current_drop_fraction = self.drop_fraction_fn(sparse_state.count)
    update_masks_fn = functools.partial(
        self._update_masks, drop_fraction=current_drop_fraction
    )
    

In [22]:
x = jax.random.uniform(jax.random.PRNGKey(42), shape=(10,10), minval=-1, maxval=1)
jnp.argsort(x, axis=1)
# jax.tree_map(functools.partial(jnp.abs, axis=1), x)

Array([[9, 4, 6, 5, 8, 3, 0, 2, 7, 1],
       [9, 7, 1, 8, 6, 4, 0, 2, 5, 3],
       [8, 0, 1, 7, 5, 9, 3, 4, 2, 6],
       [4, 3, 9, 7, 8, 0, 2, 1, 6, 5],
       [4, 3, 1, 2, 9, 7, 0, 5, 8, 6],
       [9, 5, 0, 6, 4, 8, 1, 2, 7, 3],
       [7, 0, 3, 5, 6, 2, 9, 1, 8, 4],
       [5, 0, 2, 9, 6, 4, 3, 8, 1, 7],
       [8, 2, 7, 9, 0, 4, 5, 3, 6, 1],
       [3, 7, 5, 2, 6, 9, 4, 1, 0, 8]], dtype=int32)

In [24]:
jnp.argsort(x, axis=1)[0]

Array([9, 4, 6, 5, 8, 3, 0, 2, 7, 1], dtype=int32)

In [31]:
jnp.argsort(x, axis=1)

Array([[9, 4, 6, 5, 8, 3, 0, 2, 7, 1],
       [9, 7, 1, 8, 6, 4, 0, 2, 5, 3],
       [8, 0, 1, 7, 5, 9, 3, 4, 2, 6],
       [4, 3, 9, 7, 8, 0, 2, 1, 6, 5],
       [4, 3, 1, 2, 9, 7, 0, 5, 8, 6],
       [9, 5, 0, 6, 4, 8, 1, 2, 7, 3],
       [7, 0, 3, 5, 6, 2, 9, 1, 8, 4],
       [5, 0, 2, 9, 6, 4, 3, 8, 1, 7],
       [8, 2, 7, 9, 0, 4, 5, 3, 6, 1],
       [3, 7, 5, 2, 6, 9, 4, 1, 0, 8]], dtype=int32)

In [30]:
x[0]

Array([ 0.36908197,  0.8857446 ,  0.79026294, -0.1892128 , -0.69729066,
       -0.3294983 , -0.5691731 ,  0.8822291 , -0.20968175, -0.9315412 ],      dtype=float32)

In [34]:
x[:, jnp.argsort(x, axis=1)]

Array([[[-0.9315412 , -0.69729066, -0.5691731 , -0.3294983 ,
         -0.20968175, -0.1892128 ,  0.36908197,  0.79026294,
          0.8822291 ,  0.8857446 ],
        [-0.9315412 ,  0.8822291 ,  0.8857446 , -0.20968175,
         -0.5691731 , -0.69729066,  0.36908197,  0.79026294,
         -0.3294983 , -0.1892128 ],
        [-0.20968175,  0.36908197,  0.8857446 ,  0.8822291 ,
         -0.3294983 , -0.9315412 , -0.1892128 , -0.69729066,
          0.79026294, -0.5691731 ],
        [-0.69729066, -0.1892128 , -0.9315412 ,  0.8822291 ,
         -0.20968175,  0.36908197,  0.79026294,  0.8857446 ,
         -0.5691731 , -0.3294983 ],
        [-0.69729066, -0.1892128 ,  0.8857446 ,  0.79026294,
         -0.9315412 ,  0.8822291 ,  0.36908197, -0.3294983 ,
         -0.20968175, -0.5691731 ],
        [-0.9315412 , -0.3294983 ,  0.36908197, -0.5691731 ,
         -0.69729066, -0.20968175,  0.8857446 ,  0.79026294,
          0.8822291 , -0.1892128 ],
        [ 0.8822291 ,  0.36908197, -0.1892128 , -0.3

In [7]:
x = jnp.arange(10.1-)
print(x)

SyntaxError: invalid syntax (2158789446.py, line 1)

In [3]:
type(x)

jaxlib.xla_extension.ArrayImpl

In [4]:
x

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)