In [1]:
import jax
import jax.numpy as jnp

In [74]:
key = jax.random.PRNGKey(0)
n_params = 100
params_vec = jax.random.normal(key, (n_params,))
mask = jax.random.normal(key, (n_params,))
def model(params):
    return jnp.sum(params)
    # subnet = jax.lax.top_k(jax.nn.softmax(params), k=10)[0]
    # return jnp.sum(subnet)

def subnet(mask, params, k=10):
    subnet_params = params * jax.nn.softmax(mask)
    other_params = params * (1 - jax.nn.softmax(mask))
    other_params = jax.lax.stop_gradient(other_params)
    return model(subnet_params + other_params)

In [78]:
subnet(params_vec, mask)

Array(8.212992, dtype=float32)

In [79]:
model(params_vec)

Array(8.212992, dtype=float32)

In [75]:
jax.jacobian(model)(params_vec)

Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32)

In [77]:
jax.jacobian(subnet, argnums=0)(mask, params_vec, 10)

Array([-0.00239839,  0.00218263, -0.00044763, -0.00549583, -0.00434076,
       -0.00097209, -0.00549857, -0.00531801, -0.00550981, -0.00527788,
       -0.00375996, -0.003722  ,  0.00107607, -0.00153689, -0.00529077,
        0.00084627, -0.00484503, -0.00464658, -0.00241067, -0.00535627,
       -0.00503356,  0.01949507, -0.00506478, -0.00320252, -0.00547537,
       -0.00466237, -0.00548037, -0.00481282, -0.00304566, -0.00268466,
        0.00204134, -0.0024138 , -0.00553483, -0.0034703 , -0.00532534,
       -0.0049394 , -0.00507513, -0.00321392, -0.00065511, -0.00464129,
       -0.00475285, -0.0052299 , -0.0053873 ,  0.00556816,  0.02801784,
        0.00109676, -0.00552725, -0.00385907, -0.00432169,  0.00095799,
       -0.00522872, -0.0042306 ,  0.02544134, -0.00527138, -0.00503143,
       -0.00531495, -0.00548444,  0.05548716, -0.00534412, -0.00483855,
       -0.0042225 ,  0.02990714, -0.00553267,  0.00618682, -0.00385848,
        0.02178323, -0.00436594, -0.00552285, -0.0053798 ,  0.02

In [43]:
params_vec * jax.nn.softmax(mask)

jax.lax.top_k(params_vec * jax.nn.softmax(mask), k=10)[0]

Array([0.15037847, 0.09958701, 0.06265382, 0.06230341, 0.05946638,
       0.05557143, 0.05227708, 0.0499831 , 0.0464476 , 0.02494846],      dtype=float32)

In [46]:
idx = jax.lax.top_k(jax.nn.softmax(mask), k=10)[1]

In [64]:
idx = jax.nn.softmax(mask) > jax.lax.top_k(jax.nn.softmax(mask), k=10)[0][-1]
params_small = jnp.where(idx, params_vec, 0)
params_rest = jnp.where(1 - idx, params_vec, 0)


In [66]:
params_small

Array([0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 1.5863018, 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 0.       ,
       0.       , 0.       , 1.7405769, 0.       , 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 1.6977485, 0.       ,
       0.       , 0.       , 0.       , 2.0786808, 0.       , 0.       ,
       0.       , 1.7702677, 0.       , 0.       , 0.       , 1.6315418,
       0.       , 0.       , 0.       , 1.6594526, 0.       , 0.       ,
       0.       , 0.       , 0.       , 0.       , 0.       , 1.7738531,
       0.       , 0.       , 0.       , 0.       , 

In [53]:
jax.lax.index_in_dim(params_vec, idx[0])

Array([2.3627229], dtype=float32)

In [49]:
params_vec.shape

(100,)

In [45]:
params_vec[97]

Array(2.3627229, dtype=float32)

In [39]:
    
model(params_vec)

Array(0.3607679, dtype=float32)

In [41]:
jax.jacobian(model)(params_vec)

Array([-0.00029821, -0.00616735, -0.00526413, -0.00223952, -0.0008384 ,
       -0.005065  , -0.00223076, -0.00146435, -0.00219088, -0.00141988,
       -0.00062973, -0.00061796, -0.00580364, -0.00484066, -0.00143371,
       -0.00572543, -0.00108675, -0.00097828, -0.00030054, -0.00252562,
       -0.00120954,  0.01871699, -0.00287492, -0.00409321, -0.00229743,
       -0.00321893, -0.00172268, -0.00106792, -0.00417103, -0.00035522,
       -0.006122  , -0.00030114, -0.00201749, -0.00395534, -0.00147296,
       -0.00299252, -0.00286463, -0.00047665, -0.00518631, -0.00097561,
       -0.00314896, -0.00137156, -0.00247586, -0.00718041,  0.02183921,
       -0.00581063, -0.00210435, -0.00066139, -0.00083051, -0.00576358,
       -0.00269724, -0.00351711,  0.02092362, -0.00264431, -0.00120801,
       -0.00146082, -0.00227317,  0.03062482, -0.00254399, -0.00108292,
       -0.00352226,  0.02249735, -0.00205662,  0.01302888, -0.00374186,
        0.0195832 , -0.00342888, -0.00186539, -0.00154289,  0.02

In [12]:
jax.lax.approx_max_k(jax.nn.softmax(mask), k=10)[0]

Array([0.06364626, 0.04790876, 0.03532075, 0.03519434, 0.03416475,
       0.03273243, 0.0315026 , 0.0306355 , 0.02928043, 0.02038208],      dtype=float32)