In [1]:
#@title LICENSE
# Licensed under the Apache License, Version 2.0

## JaxPruner Quick Start
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/jaxpruner/blob/main/colabs/quick_start.ipynb)

This interactive colab provides a short overview of some of the key features of the `jaxpruner` library:

- One-shot Pruning
- Pruning during Optimization (Integration w/ optax)
- ConfigDict Integration
- Compatibility with JAX parallelization via `pmap` and `pjit`

In [2]:
import functools
import flax
import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec
import jax.experimental.pjit
import numpy as np
import optax
import pprint

In [3]:
# !pip3 install git+https://github.com/google-research/jaxpruner
import jaxpruner
import ml_collections

Collecting git+https://github.com/google-research/jaxpruner
  Cloning https://github.com/google-research/jaxpruner to /private/var/folders/sq/g9fv1ssn3yqg_27cnkj1_c100000gr/T/pip-req-build-zhu9jesb
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/jaxpruner /private/var/folders/sq/g9fv1ssn3yqg_27cnkj1_c100000gr/T/pip-req-build-zhu9jesb
  Resolved https://github.com/google-research/jaxpruner to commit f133ee50f31a03d0152b8b272edb534065152f5d
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone


  index_match = re.match("(.*)\[([0-9]+)\]", key)



# One-shot Pruning
Pruning a given matrix to a desired level of sparsity is the building block of any pruning algorithm. Therefore jaxpruner provides a common API for one-shot
pruning. This is achieved by calling the `instant_sparsify` method.

In [4]:
# Side length of the square demo matrix.
matrix_size = 5
# NOTE(review): `learning_rate` is not referenced by any later cell visible in
# this notebook; the optimizers below pass their step sizes as literals.
learning_rate = 0.01
# Draw the demo matrix from a fixed PRNG key so the printed values are
# reproducible across runs.
matrix = jax.random.uniform(jax.random.PRNGKey(8), shape=(matrix_size, matrix_size))
print(matrix)

[[0.28248537 0.0030005  0.5755601  0.80048716 0.24104941]
 [0.4520849  0.30063164 0.11872995 0.9262264  0.02943528]
 [0.77339077 0.9098681  0.30932128 0.9311595  0.7131189 ]
 [0.76767194 0.60364604 0.54187894 0.54719424 0.68919384]
 [0.6382148  0.9619373  0.9574717  0.21418309 0.21543407]]


In [5]:
# Bind the target sparsity up front so the distribution function can be handed
# to a pruner as-is; this partial is reused by several later cells.
sparsity_distribution = functools.partial(
    jaxpruner.sparsity_distributions.uniform, sparsity=0.8)
pruner = jaxpruner.MagnitudePruning(sparsity_distribution_fn=sparsity_distribution)
# `instant_sparsify` returns the pruned array together with the mask applied.
pruned_matrix, mask = pruner.instant_sparsify(matrix)

print(pruned_matrix)
print(mask.dtype)
print(mask)

[[0.        0.        0.        0.        0.       ]
 [0.        0.        0.        0.9262264 0.       ]
 [0.        0.9098681 0.        0.9311595 0.       ]
 [0.        0.        0.        0.        0.       ]
 [0.        0.9619373 0.9574717 0.        0.       ]]
uint8
[[0 0 0 0 0]
 [0 0 0 1 0]
 [0 1 0 1 0]
 [0 0 0 0 0]
 [0 1 1 0 0]]


We can quickly change the sparsity structure using the `sparsity_type` argument.

In [6]:
# Same magnitude pruner as before, but with a structured N:M sparsity type.
# With NByM(1, 5) the printed mask keeps exactly one entry in each row of five.
pruner = jaxpruner.MagnitudePruning(sparsity_distribution_fn=sparsity_distribution,
                                    sparsity_type=jaxpruner.sparsity_types.NByM(1, 5))
pruned_matrix, mask = pruner.instant_sparsify(matrix)

print(pruned_matrix)
print(mask.dtype)
print(mask)

[[0.         0.         0.         0.80048716 0.        ]
 [0.         0.         0.         0.9262264  0.        ]
 [0.         0.         0.         0.9311595  0.        ]
 [0.76767194 0.         0.         0.         0.        ]
 [0.         0.9619373  0.         0.         0.        ]]
uint8
[[0 0 0 1 0]
 [0 0 0 1 0]
 [0 0 0 1 0]
 [1 0 0 0 0]
 [0 1 0 0 0]]


`instant_sparsify` also supports parameter collections, which are commonly used in deep learning.

In [7]:
# Alternative collection form left by the author (a list instead of a dict):
# params = [matrix, 1 - matrix]
# NOTE(review): `pruner` here is still the NByM(1, 5) pruner from the previous
# cell, which is why each printed row keeps a single entry.
params = {'pos': matrix, 'inv': 1 - matrix}
pruned_params, masks = pruner.instant_sparsify(params)
pprint.pprint(pruned_params)

{'inv': Array([[0.        , 0.9969995 , 0.        , 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.        , 0.9705647 ],
       [0.        , 0.        , 0.6906787 , 0.        , 0.        ],
       [0.        , 0.        , 0.45812106, 0.        , 0.        ],
       [0.        , 0.        , 0.        , 0.7858169 , 0.        ]],      dtype=float32),
 'pos': Array([[0.        , 0.        , 0.        , 0.80048716, 0.        ],
       [0.        , 0.        , 0.        , 0.9262264 , 0.        ],
       [0.        , 0.        , 0.        , 0.9311595 , 0.        ],
       [0.76767194, 0.        , 0.        , 0.        , 0.        ],
       [0.        , 0.9619373 , 0.        , 0.        , 0.        ]],      dtype=float32)}


It is common to choose different sparsities for different layers or keep them dense entirely. We provide some basic functions to distribute sparsity across different layers such as `uniform` (default) and `erk` under `jaxpruner.sparsity_distributions`. Users can also define their own distributions easily. Here we define a custom distribution function to set different sparsities for each variable.

In [8]:
def custom_distribution(params, sparsity=0.8):
  """Assigns a per-variable target sparsity.

  Args:
    params: Mapping whose keys are variable names.
    sparsity: Target used for every variable whose name lacks 'pos'.

  Returns:
    A dict with the same keys as `params`; names containing 'pos' map to 0.4,
    all others map to `sparsity`.
  """
  targets = {}
  for name in params:
    if 'pos' in name:
      targets[name] = 0.4
    else:
      targets[name] = sparsity
  return targets

# Prune the `params` dict with the per-variable targets defined above and print
# the resulting sparsity of each entry.
pruner = jaxpruner.MagnitudePruning(sparsity_distribution_fn=custom_distribution)
pruned_params, masks = pruner.instant_sparsify(params)
pprint.pprint(jaxpruner.summarize_sparsity(pruned_params))

{'_nparams': Array(50., dtype=float32),
 '_nparams_nnz': Array(20., dtype=float32),
 '_total_sparsity': Array(0.6, dtype=float32),
 'inv': Array(0.8, dtype=float32),
 'pos': Array(0.39999998, dtype=float32)}


Masks used for enforcing sparsity use the same tree structure as the parameters pruned. We use `None` values to indicate dense parameters. We don't create masks for dense variables. 

In [9]:
def custom_distribution2(params, sparsity=0.8):
  """Assigns a target sparsity to every variable except the 'pos' ones.

  Args:
    params: Mapping whose keys are variable names.
    sparsity: Target used for every variable whose name lacks 'pos'.

  Returns:
    A dict with the same keys as `params`; names containing 'pos' map to
    `None`, all others map to `sparsity`.
  """
  targets = {}
  for name in params:
    if 'pos' in name:
      targets[name] = None
    else:
      targets[name] = sparsity
  return targets

# Only the masks are of interest here, so the pruned parameters are discarded.
# The printed result has `None` in place of a mask for the 'pos' entry.
pruner = jaxpruner.MagnitudePruning(sparsity_distribution_fn=custom_distribution2)
_, masks = pruner.instant_sparsify(params)
pprint.pprint(masks)

{'inv': Array([[0, 1, 0, 0, 0],
       [0, 0, 1, 0, 1],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 1, 1]], dtype=uint8),
 'pos': None}


Changing the pruning algorithm is easy as they all inherit from the same `BaseUpdater`. We have the following baseline pruning and sparse training algorithms included in our library.

In [10]:
# Print each registered algorithm name next to the class that implements it.
for algorithm_name, updater_cls in jaxpruner.ALGORITHM_REGISTRY.items():
  print(algorithm_name, updater_cls)

no_prune <class 'jaxpruner.base_updater.NoPruning'>
magnitude <class 'jaxpruner.algorithms.pruners.MagnitudePruning'>
random <class 'jaxpruner.algorithms.pruners.RandomPruning'>
saliency <class 'jaxpruner.algorithms.pruners.SaliencyPruning'>
magnitude_ste <class 'jaxpruner.algorithms.ste.SteMagnitudePruning'>
random_ste <class 'jaxpruner.algorithms.ste.SteRandomPruning'>
global_magnitude <class 'jaxpruner.algorithms.global_pruners.GlobalMagnitudePruning'>
global_saliency <class 'jaxpruner.algorithms.global_pruners.GlobalSaliencyPruning'>
static_sparse <class 'jaxpruner.algorithms.sparse_trainers.StaticRandomSparse'>
rigl <class 'jaxpruner.algorithms.sparse_trainers.RigL'>
set <class 'jaxpruner.algorithms.sparse_trainers.SET'>


Next we use a gradient-based saliency score for pruning. `SaliencyPruning` requires gradients to be passed to `pruner.instant_sparsify`. Gradients are multiplied with parameter values to obtain a first-order Taylor approximation of the change in loss.

In [11]:
# Gradient based pruning.
pruner = jaxpruner.SaliencyPruning(sparsity_distribution_fn=sparsity_distribution)
# `1 - matrix` stands in for real gradients in this demo; `[0]` selects the
# pruned matrix from the (pruned, mask) pair that `instant_sparsify` returns.
print(pruner.instant_sparsify(matrix, grads=(1 - matrix))[0])

[[0.         0.         0.5755601  0.         0.        ]
 [0.4520849  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.        ]
 [0.         0.60364604 0.54187894 0.54719424 0.        ]
 [0.         0.         0.         0.         0.        ]]


# Pruning as optimization (jaxpruner + optax)

Often state-of-the-art pruning algorithms require iterative adjustments to the sparsity masks used. Such iterative approaches are stateful, i.e. they require some additional variables like masks, counters and initial values. This is similar to common optimization algorithms such as Adam and SGD+Momentum which require moving averages.

The observation that *most iterative pruning and sparse training algorithms can be implemented as an optimizer* played a key role when designing `jaxpruner` and led us to integrate `jaxpruner` with the `optax` optimization library.

Here is an example training loop where we find an orthogonal matrix using gradient descent:

In [12]:
# Side length of the square matrix optimized in the following cells.
matrix_size = 5

def loss_fn(params):
  """Measures how far `W @ W.T` is from the identity matrix.

  Args:
    params: Dict holding a square 2-D array under the key 'w'.

  Returns:
    Scalar sum of squared entries of `W @ W.T - I`; zero when the rows of `W`
    are orthonormal.
  """
  matrix = params['w']
  # Size the identity from the input rather than from the global
  # `matrix_size`, which later cells reassign; the loss then stays correct for
  # whatever square matrix is passed in.
  identity = jnp.eye(matrix.shape[0])
  loss = jnp.sum((matrix @ matrix.T - identity)**2)
  return loss

# Returns `(loss, grads)` for a given parameter dict.
grad_fn = jax.value_and_grad(loss_fn)

@functools.partial(jax.jit, static_argnames='optimizer')
def update_fn(params, opt_state, optimizer):
  """Runs a single jitted optimization step.

  Args:
    params: Parameter tree passed to `grad_fn`.
    opt_state: Optimizer state that belongs to `optimizer`.
    optimizer: Object exposing `update(grads, opt_state, params)`. It is a
      static argument of the jitted function.

  Returns:
    A tuple `(params, opt_state, loss)`: the updated parameters, the updated
    optimizer state, and the loss evaluated at the incoming `params`.
  """
  loss, grads = grad_fn(params)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss

def run_experiment(init_matrix):
  """Trains a dense matrix with plain SGD and returns the result.

  Runs 20 steps of `optax.sgd(0.05)` through `update_fn`, printing the loss on
  every fourth step.

  Args:
    init_matrix: Initial value stored under the 'w' key of the parameters.

  Returns:
    The 'w' entry of the parameters after the last step.
  """
  optimizer = optax.sgd(0.05)
  params = {'w': init_matrix}
  opt_state = optimizer.init(params)

  for i in range(20):
    params, opt_state, loss = update_fn(params, opt_state, optimizer)
    if i % 4 == 0:
      print(f'Step: {i}, loss: {loss}')
  return params['w']

First run the baseline training with a dense matrix. 

In [13]:
# Same PRNG key and shape as the demo matrix at the top of the notebook.
# Note that `params` holds a bare matrix here, not a parameter dict.
params = jax.random.uniform(jax.random.PRNGKey(8),
                            shape=(matrix_size, matrix_size))
run_experiment(params)

Step: 0, loss: 50.740779876708984
Step: 4, loss: 1.0821640491485596
Step: 8, loss: 0.5029404163360596
Step: 12, loss: 0.09175926446914673
Step: 16, loss: 0.0036935126408934593


Array([[-0.5452558 , -0.68374544,  0.39526057,  0.27897137, -0.01504929],
       [ 0.18433541, -0.26916775, -0.45707935,  0.31126106, -0.76552254],
       [-0.6820641 ,  0.08165327, -0.69178694, -0.14950337,  0.15661548],
       [ 0.33259526, -0.6524972 , -0.23232982, -0.6030979 ,  0.20593995],
       [-0.2942041 ,  0.16224197,  0.3196011 , -0.66244096, -0.5882758 ]],      dtype=float32)

Adding a pruner to an existing training loop requires just 2 lines. First we wrap an existing optimizer using the `pruner.wrap_optax` method. This wrapped optimizer ensures the masks are updated during the training. Second, we add a `pruner.post_gradient_update` call after our gradient step. This function defines algorithm specific parameter updates (like applying a mask to parameters) and provides flexibility when implementing various algorithms.

In [14]:
def run_pruning_experiment(init_matrix, pruner):
  """Trains a matrix with SGD while `pruner` sparsifies it.

  Mirrors `run_experiment`, except that the optimizer is wrapped by the pruner
  and `pruner.post_gradient_update` is called after every step. Every fourth
  step prints the loss and the total-sparsity summary of the parameters.

  Args:
    init_matrix: Initial value stored under the 'w' key of the parameters.
    pruner: Object exposing `wrap_optax` and `post_gradient_update`.

  Returns:
    The 'w' entry of the parameters after the last step.
  """
  optimizer = optax.sgd(0.05)
  # Modification #1
  optimizer = pruner.wrap_optax(optimizer)

  params = {'w': init_matrix}
  opt_state = optimizer.init(params)

  for i in range(20):
    params, opt_state, loss = update_fn(params, opt_state, optimizer)
    # Modification #2
    params = pruner.post_gradient_update(params, opt_state)

    if i % 4 == 0:
      print(f'Step: {i}, loss: {loss}')
      print(jaxpruner.summarize_sparsity(params, only_total_sparsity=True))
  return params['w']

Now, prune the matrix in one step (step=10).




In [15]:
# One-shot schedule: in the log printed below, the total sparsity is 0 up to
# step 8 and 0.8 from step 12 onwards.
pruner = jaxpruner.MagnitudePruning(
    sparsity_distribution_fn=sparsity_distribution,
    scheduler=jaxpruner.sparsity_schedules.OneShotSchedule(target_step=10)
    )
# Start from the same random matrix as the dense baseline.
params = jax.random.uniform(jax.random.PRNGKey(8),
                            shape=(matrix_size, matrix_size))
run_pruning_experiment(params, pruner)

Step: 0, loss: 50.740779876708984
{'_total_sparsity': Array(0., dtype=float32), '_nparams_nnz': Array(25., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 4, loss: 1.0821640491485596
{'_total_sparsity': Array(0., dtype=float32), '_nparams_nnz': Array(25., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 8, loss: 0.5029404163360596
{'_total_sparsity': Array(0., dtype=float32), '_nparams_nnz': Array(25., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 12, loss: 1.8950453996658325
{'_total_sparsity': Array(0.8, dtype=float32), '_nparams_nnz': Array(5., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 16, loss: 1.3515889644622803
{'_total_sparsity': Array(0.8, dtype=float32), '_nparams_nnz': Array(5., dtype=float32), '_nparams': Array(25., dtype=float32)}


Array([[-0.        , -0.99310285,  0.        ,  0.        , -0.        ],
       [ 0.        , -0.        , -0.        ,  0.        , -0.9267426 ],
       [-0.        ,  0.        , -0.9929704 , -0.        ,  0.        ],
       [ 0.        , -0.        , -0.        , -0.        ,  0.        ],
       [-0.        ,  0.        ,  0.        , -0.9202637 , -0.2308394 ]],      dtype=float32)

Alternatively we can prune it iteratively using the [polynomial schedule](https://arxiv.org/abs/1710.01878).

In [16]:
# Polynomial schedule: in the log printed below, the total sparsity rises in
# stages (0 -> 0.48 -> 0.76 -> 0.8) instead of in a single jump.
# NOTE(review): presumably masks are refreshed every `update_freq` steps between
# `update_start_step` and `update_end_step` - confirm against the scheduler.
pruner = jaxpruner.MagnitudePruning(
    sparsity_distribution_fn=sparsity_distribution,
    scheduler=jaxpruner.sparsity_schedules.PolynomialSchedule(
        update_freq=4, update_start_step=2, update_end_step=14)
)
# Start from the same random matrix as the dense baseline.
params = jax.random.uniform(jax.random.PRNGKey(8),
                            shape=(matrix_size, matrix_size))
run_pruning_experiment(params, pruner)

Step: 0, loss: 50.740779876708984
{'_total_sparsity': Array(0., dtype=float32), '_nparams_nnz': Array(25., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 4, loss: 1.0821640491485596
{'_total_sparsity': Array(0.48000002, dtype=float32), '_nparams_nnz': Array(13., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 8, loss: 1.2138574123382568
{'_total_sparsity': Array(0.76, dtype=float32), '_nparams_nnz': Array(6., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 12, loss: 1.3473515510559082
{'_total_sparsity': Array(0.8, dtype=float32), '_nparams_nnz': Array(5., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 16, loss: 1.0177024602890015
{'_total_sparsity': Array(0.8, dtype=float32), '_nparams_nnz': Array(5., dtype=float32), '_nparams': Array(25., dtype=float32)}


Array([[-0.        , -0.9976892 ,  0.        ,  0.        , -0.        ],
       [ 0.        , -0.        , -0.        ,  0.        , -0.9915333 ],
       [-0.        , -0.        , -0.99801445, -0.        ,  0.        ],
       [ 0.        , -0.        , -0.        , -0.67936283,  0.        ],
       [-0.        ,  0.        ,  0.        , -0.7284256 , -0.        ]],      dtype=float32)

# ml_collections.ConfigDict Integration

Many popular jax libraries like [scenic](https://github.com/google-research/scenic) and [big_vision](https://github.com/google-research/big_vision) use `ml_collections.ConfigDict` to configure experiments. `jaxpruner` provides a helper function (`jaxpruner.create_updater_from_config`) to make it easy to use a `ConfigDict` to generate pruner objects. 

In [17]:
# Describe the pruner declaratively; the next cell passes this config to
# `jaxpruner.create_updater_from_config`.
sparsity_config = ml_collections.ConfigDict()
# Key into `jaxpruner.ALGORITHM_REGISTRY` (listed earlier in this notebook).
sparsity_config.algorithm = 'magnitude'
sparsity_config.update_freq = 2
sparsity_config.update_end_step = 15
sparsity_config.update_start_step = 5
sparsity_config.sparsity = 0.6
sparsity_config.dist_type = 'uniform'

In [18]:
# Build the pruner from the config above, then train a random dense matrix
# with it using the same loop as in the previous section.
pruner = jaxpruner.create_updater_from_config(sparsity_config)
params = jax.random.uniform(jax.random.PRNGKey(8),
                            shape=(matrix_size, matrix_size))
run_pruning_experiment(params, pruner)

Step: 0, loss: 50.740779876708984
{'_total_sparsity': Array(0., dtype=float32), '_nparams_nnz': Array(25., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 4, loss: 1.0821640491485596
{'_total_sparsity': Array(0., dtype=float32), '_nparams_nnz': Array(25., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 8, loss: 0.7265740036964417
{'_total_sparsity': Array(0.48000002, dtype=float32), '_nparams_nnz': Array(13., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 12, loss: 1.3116353750228882
{'_total_sparsity': Array(0.6, dtype=float32), '_nparams_nnz': Array(10., dtype=float32), '_nparams': Array(25., dtype=float32)}
Step: 16, loss: 1.129997730255127
{'_total_sparsity': Array(0.6, dtype=float32), '_nparams_nnz': Array(10., dtype=float32), '_nparams': Array(25., dtype=float32)}


Array([[-0.5102305 , -0.6804417 ,  0.        ,  0.        , -0.        ],
       [ 0.        , -0.        , -0.        ,  0.5152243 , -0.8074678 ],
       [-0.44380936,  0.        , -0.8579842 , -0.        ,  0.        ],
       [ 0.        , -0.55724764, -0.        , -0.38684854,  0.        ],
       [-0.        ,  0.        ,  0.        , -0.72127473, -0.58689326]],      dtype=float32)

# Parallelization with `pmap` and `pjit`

The `jaxpruner` library is in general compatible with JAX parallelization mechanisms like `pmap` and `pjit`. There are some minor points to watch out for,
which we will now demonstrate using parallelized versions of the previously introduced orthogonal matrix optimization example.

## `pmap`

First, we demonstrate compatibility with `pmap` where a model is replicated to run different shards of a batch on different devices. Note that this example
has no actual model "inputs" apart from the parameter matrix and the replication is thus not directly useful, but the general mechanisms are the same as for real training.

The main point to watch out for is to make sure that the optimizer state is replicated **after** wrapping it with the `jaxpruner`.

In [19]:
# Side length of the square matrix optimized in this section.
matrix_size = 8

def loss_fn(params):
  """Measures how far `W @ W.T` is from the identity matrix.

  Args:
    params: Dict holding a square 2-D array under the key 'w'.

  Returns:
    Scalar sum of squared entries of `W @ W.T - I`; zero when the rows of `W`
    are orthonormal.
  """
  matrix = params['w']
  # Size the identity from the input rather than from the global
  # `matrix_size`, which is reassigned across cells; the loss then stays
  # correct for whatever square matrix is passed in.
  identity = jnp.eye(matrix.shape[0])
  loss = jnp.sum((matrix @ matrix.T - identity)**2)
  return loss

# Returns `(loss, grads)` for a given parameter dict.
grad_fn = jax.value_and_grad(loss_fn)

# out_axes: params and opt_state keep their leading device axis, while the loss
# is returned without one. Argument 2 (the optimizer) is treated as static.
@functools.partial(
    jax.pmap, out_axes=(0, 0, None), axis_name='batch',
    static_broadcasted_argnums=(2,)
)
def update_fn(params, opt_state, optimizer):
  """Runs a single optimization step on every device under `pmap`.

  Args:
    params: Parameter tree with a leading device axis.
    opt_state: Optimizer state with a leading device axis.
    optimizer: Object exposing `update(grads, opt_state, params)`.

  Returns:
    A tuple `(params, opt_state, loss)`: the updated parameters, the updated
    optimizer state, and the loss averaged over the 'batch' axis.
  """
  loss, grads = grad_fn(params)
  # Average loss and gradients across devices so every replica applies the
  # same update.
  loss = jax.lax.pmean(loss, 'batch')
  grads = jax.lax.pmean(grads, 'batch')
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss


# Same uniform 0.8 sparsity target as at the top of the notebook, redefined
# here with identical values.
sparsity_distribution = functools.partial(
    jaxpruner.sparsity_distributions.uniform, sparsity=0.8)

# One-shot schedule with `target_step=0`.
pruner = jaxpruner.MagnitudePruning(
    sparsity_distribution_fn=sparsity_distribution,
    scheduler=jaxpruner.sparsity_schedules.OneShotSchedule(target_step=0)
)

optimizer = optax.sgd(0.001)
optimizer = pruner.wrap_optax(optimizer)
params = {
    'w': jax.random.normal(jax.random.PRNGKey(0), (matrix_size, matrix_size))
}
opt_state = optimizer.init(params)
# The key step for using pmap with the jaxpruner is to replicate the optimizer
# state **after** wrapping it.
opt_state = flax.jax_utils.replicate(opt_state)
params = flax.jax_utils.replicate(params)

for i in range(100):
  params, opt_state, loss = update_fn(params, opt_state, optimizer)
  params = pruner.post_gradient_update(params, opt_state)
  if i % 5 == 0:
    print(f'Step: {i}, loss: {loss}')
# Drop the device axis again before printing the final matrix.
params = flax.jax_utils.unreplicate(params)
print(params['w'])

Step: 0, loss: 343.5727844238281
Step: 5, loss: 70.89682006835938
Step: 10, loss: 47.94889831542969
Step: 15, loss: 34.66495132446289
Step: 20, loss: 26.278722763061523
Step: 25, loss: 20.651508331298828
Step: 30, loss: 16.700218200683594
Step: 35, loss: 13.826349258422852
Step: 40, loss: 11.676509857177734
Step: 45, loss: 10.030875205993652
Step: 50, loss: 8.746789932250977
Step: 55, loss: 7.728399276733398
Step: 60, loss: 6.909374713897705
Step: 65, loss: 6.242647647857666
Step: 70, loss: 5.694087505340576
Step: 75, loss: 5.238475799560547
Step: 80, loss: 4.856863975524902
Step: 85, loss: 4.534801483154297
Step: 90, loss: 4.261125087738037
Step: 95, loss: 4.027107238769531
[[ 0.73600227 -0.         -0.         -0.         -0.         -0.
  -0.         -0.8174586 ]
 [-0.          0.88034064 -0.         -0.          0.         -0.
  -0.         -0.4709545 ]
 [-0.7798368  -0.         -0.         -0.          0.          0.
  -0.         -0.7573714 ]
 [ 0.         -0.         -1.0685847 

## `pjit`

Next, we demonstrate tensor-sharded training with `pjit`. Here the key is that the partition specifications of the wrapped optimizer state must also incorporate the `jaxpruner.base_updater.SparseState` produced by the pruning wrapper.

In [21]:
# Side length of the square matrix optimized in this section.
matrix_size = 8
# Use a 2x4 device mesh when the device count is a multiple of 8; otherwise
# fall back to a single-device mesh.
if jax.device_count() % 8 == 0:
  MESH_SHAPE = (2, 4)
else:
  MESH_SHAPE = (1, 1)

def loss_fn(params):
  """Measures how far `W @ W.T` is from the identity matrix.

  Args:
    params: Dict holding a square 2-D array under the key 'w'.

  Returns:
    Scalar sum of squared entries of `W @ W.T - I`; zero when the rows of `W`
    are orthonormal.
  """
  matrix = params['w']
  # Size the identity from the input rather than from the global
  # `matrix_size`, which is reassigned across cells; the loss then stays
  # correct for whatever square matrix is passed in.
  identity = jnp.eye(matrix.shape[0])
  loss = jnp.sum((matrix @ matrix.T - identity)**2)
  return loss

# Returns `(loss, grads)` for a given parameter dict.
grad_fn = jax.value_and_grad(loss_fn)

# Define the partition-specs for pjit; in most libraries for real models this
# is done somewhat automatically, yet this will likely require a small
# adjustment as shown below.

# Shard the two axes of 'w' across the mesh axes named 'X' and 'Y'.
params_partition = {
    'w': PartitionSpec('X', 'Y')
}

# The main step required to run the jaxpruner together with pjit is defining
# a partition-spec for the wrapped `SparseState` as shown below.
# The masks reuse the parameter partitioning; the remaining fields get `None`.
# NOTE(review): `(None, None)` presumably mirrors the structure of the state
# produced by `optax.sgd` - confirm when swapping in another optimizer.
opt_partition = jaxpruner.base_updater.SparseState(
    masks=params_partition,
    inner_state=(None, None),  # other optimizers may require sharding
    target_sparsities=None,
    count=None
)

# Shardings for the `(params, opt_state)` pair taken and returned by `update_fn`.
resources = (params_partition, opt_partition)

# Inputs and the first two outputs use the `(params, opt_state)` shardings in
# `resources`; the extra `None` entry is the spec for the loss output.
@functools.partial(
    jax.experimental.pjit.pjit,
    in_shardings=resources,
    out_shardings=resources + (None,),
    static_argnames='optimizer'
)
def update_fn(params, opt_state, optimizer):
  """Runs a single optimization step under `pjit`.

  Args:
    params: Parameter tree passed to `grad_fn`.
    opt_state: Optimizer state that belongs to `optimizer`.
    optimizer: Object exposing `update(grads, opt_state, params)`. It is a
      static argument of the compiled function.

  Returns:
    A tuple `(params, opt_state, loss)`: the updated parameters, the updated
    optimizer state, and the loss evaluated at the incoming `params`.
  """
  loss, grads = grad_fn(params)
  updates, opt_state = optimizer.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return params, opt_state, loss


# Same uniform 0.8 sparsity target as at the top of the notebook, redefined
# here with identical values.
sparsity_distribution = functools.partial(
    jaxpruner.sparsity_distributions.uniform, sparsity=0.8)
# One-shot schedule with `target_step=0`.
pruner = jaxpruner.MagnitudePruning(
    sparsity_distribution_fn=sparsity_distribution,
    scheduler=jaxpruner.sparsity_schedules.OneShotSchedule(target_step=0)
)

optimizer = optax.sgd(0.001)
optimizer = pruner.wrap_optax(optimizer)
params = {
    'w': jax.random.normal(jax.random.PRNGKey(0), (matrix_size, matrix_size))
}
opt_state = optimizer.init(params)

# Take exactly as many devices as the mesh has slots. Reshaping the full device
# list fails whenever the device count differs from the mesh size, e.g. 4
# devices with a (1, 1) mesh or 16 devices with a (2, 4) mesh.
num_mesh_devices = MESH_SHAPE[0] * MESH_SHAPE[1]
devices = np.asarray(jax.devices()[:num_mesh_devices]).reshape(MESH_SHAPE)
mesh = jax.sharding.Mesh(devices, ('X', 'Y'))

# The 'X'/'Y' axis names used in the partition specs are resolved against this
# mesh, so the training loop has to run inside its context.
with mesh:
  for i in range(100):
    params, opt_state, loss = update_fn(params, opt_state, optimizer)
    params = pruner.post_gradient_update(params, opt_state)
    if i % 5 == 0:
      print(f'Step: {i}, loss: {loss}')
  print(params['w'])
  # Show how the final matrix is laid out across the mesh devices.
  jax.debug.visualize_array_sharding(params['w'])

Step: 0, loss: 343.5727844238281
Step: 5, loss: 70.89682006835938
Step: 10, loss: 47.94889831542969
Step: 15, loss: 34.66495132446289
Step: 20, loss: 26.278722763061523
Step: 25, loss: 20.651508331298828
Step: 30, loss: 16.700218200683594
Step: 35, loss: 13.826349258422852
Step: 40, loss: 11.676509857177734
Step: 45, loss: 10.030875205993652
Step: 50, loss: 8.746789932250977
Step: 55, loss: 7.728399276733398
Step: 60, loss: 6.909374713897705
Step: 65, loss: 6.242647647857666
Step: 70, loss: 5.694087505340576
Step: 75, loss: 5.238475799560547
Step: 80, loss: 4.856863975524902
Step: 85, loss: 4.534801483154297
Step: 90, loss: 4.261125087738037
Step: 95, loss: 4.027107238769531
[[ 0.73600227 -0.         -0.         -0.         -0.         -0.
  -0.         -0.8174586 ]
 [-0.          0.88034064 -0.         -0.          0.         -0.
  -0.         -0.4709545 ]
 [-0.7798368  -0.         -0.         -0.          0.          0.
  -0.         -0.7573714 ]
 [ 0.         -0.         -1.0685847 