In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install tensorflow-cpu -q
!pip install tensorflow_datasets -q
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [2]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

In [3]:
# @title Python imports

import enum
import re
import string
import pdb

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from functools import partial

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [4]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

## Fine tuning the Gemma model

In [5]:
import flax
from flax.traverse_util import flatten_dict
from flax.core.meta import unbox


def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        # print(f"Shape: {param.shape}")
        # print(f"dtype: {param.dtype}")
        # print(f"Value: {param}")
        if isinstance(param, flax.core.meta.Partitioned):
            array = unbox(param)
        else:
            array = param
        print(jax.debug.visualize_array_sharding(array))
        print("-" * 40)

## Try LoRA with simpleNN

In [6]:
input_dim = 32 
hidden_dim = 8
output_dim = 1

In [7]:

# Helper function for creating NamedSharding
def create_sharding(pspec):
    return NamedSharding(mesh, pspec)

In [8]:
def mesh_sharding(pspec: PartitionSpec) -> NamedSharding:
  return NamedSharding(mesh, pspec)

NameError: name 'PartitionSpec' is not defined

In [9]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import linen as nn
from flax.training import train_state
import optax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
import functools
from functools import partial
import jax
import jax.numpy as jnp

class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int
    
    @nn.compact
    def __call__(self, x):
        x = with_sharding_constraint(x, mesh_sharding(PartitionSpec('data', 'model')))

        x = nn.Dense(features=self.hidden_dim, 
                     kernel_init=nn.with_partitioning(nn.initializers.xavier_normal(), ('data', 'model')),
                     use_bias=False)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim, 
                     use_bias=False)(x)
        return x

# Set up the device mesh
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, 4))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

print(mesh)

Mesh('data': 1, 'model': 4)


In [164]:
sample_batch = jnp.ones(shape=(1, input_dim))
x_sharding = mesh_sharding(PartitionSpec('data', 'model')) # dimensions: (batch, length)
sample_batch = jax.device_put(sample_batch, x_sharding)
jax.debug.visualize_array_sharding(x)

In [165]:
# Initialize function
def init_fn(key, x, model, optimizer):
    params = model.init(key, x)  # Initialize the model
    state = train_state.TrainState.create(  # Create a `TrainState`
        apply_fn=model.apply,
        params=params['params'],
        tx=optimizer)
    return state

In [166]:
model = SimpleNN(hidden_dim, output_dim)
optimizer = optax.adam(learning_rate=0.001)

In [167]:
abstract_variables = jax.eval_shape(
    functools.partial(
        init_fn, 
        model=model, 
        optimizer=optimizer
    ),
    jax.random.PRNGKey(99),
    sample_batch
)


In [135]:
abstract_variables

TrainState(step=ShapeDtypeStruct(shape=(), dtype=int32), apply_fn=<bound method Module.apply of SimpleNN(
    # attributes
    hidden_dim = 8
    output_dim = 1
)>, params={'Dense_0': {'kernel': Partitioned(value=ShapeDtypeStruct(shape=(32, 8), dtype=float32), names=('data', 'model'), mesh=None)}, 'Dense_1': {'bias': ShapeDtypeStruct(shape=(1,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(8, 1), dtype=float32)}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f5e0c4291b0>, update=<function chain.<locals>.update_fn at 0x7f5e0c4292d0>), opt_state=(ScaleByAdamState(count=ShapeDtypeStruct(shape=(), dtype=int32), mu={'Dense_0': {'kernel': Partitioned(value=ShapeDtypeStruct(shape=(32, 8), dtype=float32), names=('data', 'model'), mesh=None)}, 'Dense_1': {'bias': ShapeDtypeStruct(shape=(1,), dtype=float32), 'kernel': ShapeDtypeStruct(shape=(8, 1), dtype=float32)}}, nu={'Dense_0': {'kernel': Partitioned(value=ShapeDtypeStruct(shape=(32, 8), dtype=float32), 

In [168]:
state_sharding = nn.get_sharding(abstract_variables, mesh)

In [138]:
state_sharding

TrainState(step=NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), apply_fn=<bound method Module.apply of SimpleNN(
    # attributes
    hidden_dim = 8
    output_dim = 1
)>, params={'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec('data', 'model'))}, 'Dense_1': {'bias': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), 'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec())}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7f5e0c4291b0>, update=<function chain.<locals>.update_fn at 0x7f5e0c4292d0>), opt_state=(ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), mu={'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec('data', 'model'))}, 'Dense_1': {'bias': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec()), 'kernel': NamedSharding(mesh=Mesh('data': 1, 'mod

In [169]:
jit_init_fn = jax.jit(init_fn, 
                          static_argnums=(2, 3),
                          in_shardings=(mesh_sharding(pspec=()), x_sharding),  # PRNG key and x
                          out_shardings=state_sharding)

In [170]:
initialized_state = jit_init_fn(jax.random.PRNGKey(99), sample_batch, model, optimizer)

In [171]:
initialized_state.params['Dense_1']['kernel']

Array([[-0.14665894],
       [-0.22765416],
       [-0.41170552],
       [-0.47610608],
       [ 0.09983327],
       [-0.55441546],
       [-0.32387885],
       [ 0.44042173]], dtype=float32)

In [181]:
print_params(initialized_state.params)

Name: Dense_0/kernel


None
----------------------------------------
Name: Dense_1/kernel


None
----------------------------------------


In [182]:
def forward_pass(params, state, batch):
  input_images, labels = batch
    
  # call forward pass function.
  logits = state.apply_fn({"params": params}, input_images)

  # compute loss
  loss = optax.squared_error(logits, labels)
  loss = loss.mean()
  return loss, logits

In [183]:
def backward_pass(state, batch):
  # create a function to compute gradients wrt to loss
  # returned by our `forward_pass` function.
  grad_fn = jax.value_and_grad(forward_pass, argnums=(0), has_aux=True)

  # compute gradients.
  (loss, _), grads = grad_fn(state.params, state, batch)

  # apply gradients.
  state = state.apply_gradients(grads=grads)

  return state

In [207]:
@functools.partial(jax.jit, in_shardings=(state_sharding, x_sharding, None),
                   out_shardings=state_sharding)
def train_step(state, inputs, targets):
  return backward_pass(state, (inputs, targets))

In [208]:
batch = (jnp.ones(shape=(1, input_dim)), jnp.zeros(shape=(1, output_dim)))

In [209]:
batch0 = jax.device_put(batch[0], x_sharding)

In [210]:
batch1 = jax.device_put(batch[1], mesh_sharding(PartitionSpec('data', None)))

In [211]:
with mesh:
    new_state = train_step(initialized_state, batch0, batch1)

In [212]:
new_state

TrainState(step=Array(1, dtype=int32, weak_type=True), apply_fn=<bound method Module.apply of SimpleNN(
    # attributes
    hidden_dim = 8
    output_dim = 1
)>, params={'Dense_0': {'kernel': Partitioned(value=Array([[ 0.29087412, -0.05503109, -0.20790817, -0.29750723,  0.07447191,
         0.0784103 ,  0.01066407,  0.49764064],
       [ 0.3668062 , -0.34450766,  0.1546367 ,  0.32535422,  0.12399623,
        -0.29323256,  0.0124924 ,  0.04483453],
       [ 0.03978511,  0.13564567, -0.21708795, -0.20757158, -0.11483077,
         0.06651969,  0.30071765,  0.21988122],
       [-0.14387387,  0.13240269, -0.21598163,  0.18907267,  0.15003328,
         0.383029  ,  0.27404705,  0.34015822],
       [-0.14029394, -0.42128232, -0.0436018 ,  0.3238122 , -0.07844601,
        -0.01554115,  0.06825367, -0.01932324],
       [-0.23955524, -0.01933642,  0.0330421 , -0.0531561 ,  0.1021446 ,
         0.14570239, -0.33155596,  0.1250879 ],
       [ 0.02744705, -0.21270357,  0.10063621, -0.37119505,  0.

In [12]:
def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        print(f"Shape: {param.shape}")
        print(f"dtype: {param.dtype}")
        print(f"Value: {param}")
        # if isinstance(param, flax.core.meta.Partitioned):
        #     array = unbox(param)
        # else:
        #     array = param
        # print(jax.debug.visualize_array_sharding(array))
        print("-" * 40)

In [28]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from flax import linen as nn
from flax.training import train_state
import optax
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
import functools
from functools import partial
import jax
import jax.numpy as jnp

In [29]:
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec as PS
from flax import linen as nn
from flax.training.train_state import TrainState
import optax
from functools import partial
from jax.experimental import mesh_utils
import re


In [30]:
class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int
    
    @nn.compact
    def __call__(self, x):
        x = with_sharding_constraint(x, mesh_sharding(PartitionSpec('data', 'model')))

        x = nn.Dense(features=self.hidden_dim, 
                     kernel_init=nn.with_partitioning(nn.initializers.xavier_normal(), ('data', 'model')),
                     use_bias=False)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim, 
                     use_bias=False)(x)
        return x

In [31]:
# Set up the device mesh
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, 4))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

In [25]:
model = SimpleNN(hidden_dim, output_dim)
optimizer = optax.adam(learning_rate=0.001)

In [40]:
# Define the model
class SimpleNN(nn.Module):
    hidden_dim: int
    output_dim: int
    
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=self.hidden_dim, use_bias=False)(x)
        x = nn.relu(x)
        x = nn.Dense(features=self.output_dim, use_bias=False)(x)
        return x

In [82]:
class JaxRNG(object):
    """ A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
        pure function.
    """

    @classmethod
    def from_seed(cls, seed):
        return cls(jax.random.PRNGKey(seed))

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        if keys is None:
            self.rng, split_rng = jax.random.split(self.rng)
            return split_rng
        elif isinstance(keys, int):
            split_rngs = jax.random.split(self.rng, num=keys + 1)
            self.rng = split_rngs[0]
            return tuple(split_rngs[1:])
        else:
            split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
            self.rng = split_rngs[0]
            return {key: val for key, val in zip(keys, split_rngs[1:])}


In [146]:
def tree_path_to_string(path, sep=None):
    keys = []
    for key in path:
        if isinstance(key, jax.tree_util.SequenceKey):
            keys.append(str(key.idx))
        elif isinstance(key, jax.tree_util.DictKey):
            keys.append(str(key.key))
        elif isinstance(key, jax.tree_util.GetAttrKey):
            keys.append(str(key.name))
        elif isinstance(key, jax.tree_util.FlattenedIndexKey):
            keys.append(str(key.key))
        else:
            keys.append(str(key))
    if sep is None:
        return tuple(keys)
    return sep.join(keys)


In [143]:
def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
    """ An extended version of jax.tree_util.tree_map, where the mapped function
        f takes both the name (path) and the tree leaf as input.
    """
    return jax.tree_util.tree_map_with_path(
        lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
        tree, *rest,
        is_leaf=is_leaf
    )

In [194]:
import numpy as np
def match_partition_rules(rules, params):
    """ Returns a pytree of PartitionSpec according to rules. Supports handling
        Flax TrainState and Optax optimizer state.
    """
    def get_partition_spec(name, leaf):
        if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
            """ Don't partition scalar values. """
            return PS()
        for rule, ps in rules:
            if re.search(rule, name) is not None:
                return ps
        raise ValueError(f'Partition rule not found for param: {name}')
    return named_tree_map(get_partition_spec, params, sep='/')

In [35]:
def get_jax_mesh(axis_dims, names):
    dims = [int(x) for x in axis_dims.split(',')]
    mesh_shape = jax.numpy.arange(jax.device_count()).reshape(dims).shape
    return 

In [42]:
def with_sharding_constraint(x, partition_specs):
    return jax.lax.with_sharding_constraint(x, partition_specs)

In [44]:
# Configuration
input_dim = 32 
hidden_dim = 8
output_dim = 1

In [192]:
class SimpleNNConfigurator:
    @staticmethod
    def get_partition_rules():
        return (
            ("params/params/Dense_0/kernel", PS("data", "model")),
            ('.*', PS(None)),
        )
    
    @staticmethod
    def get_jax_mesh(mesh_dim):
        return Mesh(mesh_utils.create_device_mesh((1, jax.device_count())), ('data', 'model'))

    @staticmethod
    def rng_keys():
        return ('params', 'dropout', 'other')

In [95]:
def create_trainstate_from_params(params):
    return TrainState.create(params=params, tx=optimizer, apply_fn=None)

In [96]:
def init_fn(rng, input_dim=32, hidden_dim=8, output_dim=1):
    rng_generator = JaxRNG(rng)
    model = SimpleNN(hidden_dim=hidden_dim, output_dim=output_dim)
    params = model.init(
        rng_generator(SimpleNNConfigurator.rng_keys()),
        jnp.zeros((4, input_dim)),
    )
    return TrainState.create(params=params, tx=optimizer, apply_fn=None)

In [None]:
create_train_state(jax.random.PRNGKey(99), 0.1)

In [169]:
def train_step(train_state, rng, batch):
    rng_generator = JaxRNG(rng)
    batch = with_sharding_constraint(batch, PS(('data', 'model')))
    
    def loss_and_accuracy(params):
        pred = model.apply(params, batch['input'],
            rngs=rng_generator(SimpleNNConfigurator.rng_keys()),
        )
        loss = optax.squared_error(pred, batch['target'])
        loss = loss.mean()
        return loss
        
    grad_fn = jax.grad(loss_and_accuracy)
    grads = grad_fn(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    return train_state


In [99]:
rng = JaxRNG(0)
model = SimpleNN(hidden_dim=hidden_dim, output_dim=output_dim)

In [100]:
optimizer = optax.adam(1e-3)

In [101]:
train_state_shapes = jax.eval_shape(init_fn, jax.random.PRNGKey(99))

In [193]:
train_state_partition = match_partition_rules(
    SimpleNNConfigurator.get_partition_rules(), train_state_shapes
)

> [0;32m/tmp/ipykernel_4163/2397452676.py[0m(13)[0;36mget_partition_spec[0;34m()[0m
[0;32m     11 [0;31m            [0;32mif[0m [0mre[0m[0;34m.[0m[0msearch[0m[0;34m([0m[0mrule[0m[0;34m,[0m [0mname[0m[0;34m)[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m                [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 13 [0;31m                [0;32mreturn[0m [0mps[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m        [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34mf'Partition rule not found for param: {name}'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m    [0;32mreturn[0m [0mnamed_tree_map[0m[0;34m([0m[0mget_partition_spec[0m[0;34m,[0m [0mparams[0m[0;34m,[0m [0msep[0m[0;34m=[0m[0;34m'/'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  name


'params/params/Dense_0/kernel'


ipdb>  c


> [0;32m/tmp/ipykernel_4163/2397452676.py[0m(13)[0;36mget_partition_spec[0;34m()[0m
[0;32m     11 [0;31m            [0;32mif[0m [0mre[0m[0;34m.[0m[0msearch[0m[0;34m([0m[0mrule[0m[0;34m,[0m [0mname[0m[0;34m)[0m [0;32mis[0m [0;32mnot[0m [0;32mNone[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     12 [0;31m                [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 13 [0;31m                [0;32mreturn[0m [0mps[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m        [0;32mraise[0m [0mValueError[0m[0;34m([0m[0;34mf'Partition rule not found for param: {name}'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     15 [0;31m    [0;32mreturn[0m [0mnamed_tree_map[0m[0;34m([0m[0mget_partition_spec[0m[0;34m,[0m [0mparams[0m[0;34m,[0m [0msep[0m[0;34m=[0m[0;34m'/'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  name


'params/params/Dense_1/kernel'


ipdb>  quit


In [185]:
def create_named_sharding(partition_spec):
    return NamedSharding(mesh, partition_spec)

In [186]:
train_state_named_sharding = jax.tree.map(create_named_sharding, train_state_partition)

In [187]:
train_state_named_sharding.params

{'params': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec('data', 'model'))},
  'Dense_1': {'kernel': NamedSharding(mesh=Mesh('data': 1, 'model': 4), spec=PartitionSpec(None,))}}}

In [153]:
sharded_init_fn = jax.jit(
    init_fn,
    in_shardings=NamedSharding(mesh, PS()),
    # out_shardings=train_state_partition, out_shardings is optional in jax.jit, GSPMD will figure it out.
    static_argnums=(1, 2, 3)
    
)

In [170]:
sharded_create_trainstate_from_params = jax.jit(
    create_trainstate_from_params,
    in_shardings=(train_state_partition.params, ),
    # out_shardings=train_state_partition,
    donate_argnums=(0, ),
)

In [171]:
sharded_train_step = jax.jit(
    train_step,
    in_shardings=(train_state_named_sharding, NamedSharding(mesh, PS()), NamedSharding(mesh, PS())),
    # out_shardings=(train_state_partition),
    donate_argnums=(0, 1),
)

In [156]:
device_mesh = mesh_utils.create_device_mesh((1, 4))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

In [175]:
with mesh:
    train_state = sharded_init_fn(jax.random.PRNGKey(99))
    batch = {
        'input': jnp.ones((32, input_dim)),
        'target': jnp.zeros((32, 1), dtype=jnp.int32),
    }
    sharded_rng = jax.random.PRNGKey(99)
    train_state = sharded_train_step(
        train_state, sharded_rng, batch
    )
    print(train_state)

TrainState(step=Array(1, dtype=int32, weak_type=True), apply_fn=None, params={'params': {'Dense_0': {'kernel': Array([[ 9.58651379e-02, -5.77238090e-02,  8.71751755e-02,
         2.06664905e-01, -1.54080182e-01,  3.74663025e-01,
         1.21506892e-01, -1.27760872e-01],
       [-1.65383801e-01,  9.92371961e-02, -6.61521703e-02,
         2.75267810e-01,  3.34123582e-01,  1.94328249e-01,
        -2.11523082e-02,  7.46464133e-02],
       [ 6.08712770e-02, -1.41109407e-01, -8.76355544e-02,
         2.75042266e-01, -2.81154457e-02, -1.16207547e-01,
        -2.79863566e-01,  2.71764964e-01],
       [ 1.61056504e-01, -9.62803438e-02,  3.30342233e-01,
        -3.67344916e-01,  4.07360569e-02, -1.63870811e-01,
        -1.81345776e-01,  4.84794006e-03],
       [-9.79540274e-02,  1.76957369e-01,  3.66404593e-01,
        -1.52017111e-02,  5.93522601e-02,  9.57616698e-03,
         9.57029611e-02, -1.66373268e-01],
       [ 7.82506913e-02,  2.19656542e-01,  4.44784090e-02,
         3.93143147e-02, 

In [174]:
train_state

TrainState(step=Array(1, dtype=int32, weak_type=True), apply_fn=None, params={'params': {'Dense_0': {'kernel': Array([[ 9.58651379e-02, -5.77238090e-02,  8.71751755e-02,
         2.06664905e-01, -1.54080182e-01,  3.74663025e-01,
         1.21506892e-01, -1.27760872e-01],
       [-1.65383801e-01,  9.92371961e-02, -6.61521703e-02,
         2.75267810e-01,  3.34123582e-01,  1.94328249e-01,
        -2.11523082e-02,  7.46464133e-02],
       [ 6.08712770e-02, -1.41109407e-01, -8.76355544e-02,
         2.75042266e-01, -2.81154457e-02, -1.16207547e-01,
        -2.79863566e-01,  2.71764964e-01],
       [ 1.61056504e-01, -9.62803438e-02,  3.30342233e-01,
        -3.67344916e-01,  4.07360569e-02, -1.63870811e-01,
        -1.81345776e-01,  4.84794006e-03],
       [-9.79540274e-02,  1.76957369e-01,  3.66404593e-01,
        -1.52017111e-02,  5.93522601e-02,  9.57616698e-03,
         9.57029611e-02, -1.66373268e-01],
       [ 7.82506913e-02,  2.19656542e-01,  4.44784090e-02,
         3.93143147e-02, 