# Fine-tuning Gemma2 2B model on Roadrunner with JAX, Flax.

We have adopted the Gemma2 notebook from Google Deepmind to use HuggingFace's libraries and and simplified the steps.

## Setup

In [1]:
%%capture
!pip install --upgrade kagglehub -q
!pip install ipywidgets -q
!pip install torch --index-url https://download.pytorch.org/whl/cpu -q
!pip install git+https://github.com/felafax/gemma.git -q
!pip install qax -q
!pip install jax-lorax -q

In [2]:
import os
os.environ['HF_HUB_CACHE'] = '/mnt/persistent-disk/hf/'
os.environ['HF_HOME'] = '/mnt/persistent-disk/hf/'
!export HF_HUB_CACHE="/mnt/persistent-disk/hf/"
!export HF_HOME="/mnt/persistent-disk/hf/"

In [10]:
# @title Python imports

import enum
import re
import string

# We import JAX and some related packages.
import chex
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
import optax
from functools import partial

# Model partitioning related imports
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.lax import with_sharding_constraint
from jax.experimental import mesh_utils
from flax.traverse_util import flatten_dict
from flax.core.meta import unbox

# For LoRA
import lorax

# We will use HuggingFace's dataset, tokenizer, and model classes.
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
from datasets import Dataset, load_dataset, concatenate_datasets
import torch

# Finally, we import Gemma.
from gemma import params as params_lib
from gemma import sampler as sampler_lib
from gemma import transformer as transformer_lib
import sentencepiece as spm


In [4]:
# HuggingFace username and token to use when downloading.
MODEL_NAME="felafax/gemma-2-2b-it-Flax"
HUGGINGFACE_USERNAME = input("INPUT: Please provide your HUGGINGFACE_USERNAME: ")
HUGGINGFACE_TOKEN = input("INPUT: Please provide your HUGGINGFACE_TOKEN: ")

model_name=MODEL_NAME
hugging_face_token=HUGGINGFACE_TOKEN

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  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


In [6]:
%%capture
from huggingface_hub import snapshot_download

ckpt_path = snapshot_download(repo_id=MODEL_NAME, token=HUGGINGFACE_TOKEN)
vocab_path = os.path.join(ckpt_path, 'tokenizer.model')

## Fine tuning the Gemma model

In [8]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [11]:
# Set up the device mesh
devices = jax.devices()
device_mesh = mesh_utils.create_device_mesh((1, 4))
mesh = Mesh(devices=device_mesh, axis_names=('data', 'model'))

In [7]:
# Load parameters.
params = params_lib.load_and_format_params(os.path.join(ckpt_path, 'gemma2-2b-it'))

In [48]:
# Load model config.
config = transformer_lib.TransformerConfig.gemma2_2b(cache_size=30)
model = transformer_lib.Transformer(config=config)

# You can also infer the model config by using the number of layers in the params.
# config_2b = transformer_lib.TransformerConfig.from_params(params, cache_size=30)

In [49]:
import flax
from flax.traverse_util import flatten_dict

def print_params(params):
    flat_params = flatten_dict(params)    
    for path, param in flat_params.items():
        # Join the path components to create a string name
        name = "/".join(str(x) for x in path)
        print(f"Name: {name}")
        print(f"Shape: {param.shape}")
        print(f"dtype: {param.dtype}")
        # print(f"Value: {param}")
        print("-" * 40)

### print params before

In [50]:
print_params(params)

Name: transformer/embedder/input_embedding
Shape: (256128, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/final_norm/scale
Shape: (2304,)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/attn_vec_einsum/w
Shape: (8, 256, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/kv_einsum/w
Shape: (2, 4, 2304, 256)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/q_einsum/w
Shape: (8, 2304, 256)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/mlp/gating_einsum
Shape: (2, 2304, 9216)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/mlp/linear
Shape: (9216, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/post_attention_norm/scale
Shape: (2304,)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/post_ffw_

In [51]:
from lorax.constants import LORA_FULL, LORA_FREEZE

def decision_fn(path, param):
    if 'embedding' in path:
        print(f'Fully finetuning param {path}')
        return LORA_FULL
    dim = 2
    print(f'Using LoRA with dim={dim} for param {path}')
    return dim

In [52]:
lora_spec = lorax.simple_spec(params, decision_fn=decision_fn, tune_vectors=True)

Using LoRA with dim=2 for param (DictKey(key='transformer'), DictKey(key='embedder'), DictKey(key='input_embedding'))
Using LoRA with dim=2 for param (DictKey(key='transformer'), DictKey(key='layer_0'), DictKey(key='attn'), DictKey(key='attn_vec_einsum'), DictKey(key='w'))
Using LoRA with dim=2 for param (DictKey(key='transformer'), DictKey(key='layer_0'), DictKey(key='attn'), DictKey(key='kv_einsum'), DictKey(key='w'))
Using LoRA with dim=2 for param (DictKey(key='transformer'), DictKey(key='layer_0'), DictKey(key='attn'), DictKey(key='q_einsum'), DictKey(key='w'))
Using LoRA with dim=2 for param (DictKey(key='transformer'), DictKey(key='layer_0'), DictKey(key='mlp'), DictKey(key='gating_einsum'))
Using LoRA with dim=2 for param (DictKey(key='transformer'), DictKey(key='layer_0'), DictKey(key='mlp'), DictKey(key='linear'))
Using LoRA with dim=2 for param (DictKey(key='transformer'), DictKey(key='layer_1'), DictKey(key='attn'), DictKey(key='attn_vec_einsum'), DictKey(key='w'))
Using Lo

In [53]:
lora_spec

{'transformer': {'embedder': {'input_embedding': 2},
  'final_norm': {'scale': -1},
  'layer_0': {'attn': {'attn_vec_einsum': {'w': 2},
    'kv_einsum': {'w': 2},
    'q_einsum': {'w': 2}},
   'mlp': {'gating_einsum': 2, 'linear': 2},
   'post_attention_norm': {'scale': -1},
   'post_ffw_norm': {'scale': -1},
   'pre_attention_norm': {'scale': -1},
   'pre_ffw_norm': {'scale': -1}},
  'layer_1': {'attn': {'attn_vec_einsum': {'w': 2},
    'kv_einsum': {'w': 2},
    'q_einsum': {'w': 2}},
   'mlp': {'gating_einsum': 2, 'linear': 2},
   'post_attention_norm': {'scale': -1},
   'post_ffw_norm': {'scale': -1},
   'pre_attention_norm': {'scale': -1},
   'pre_ffw_norm': {'scale': -1}},
  'layer_10': {'attn': {'attn_vec_einsum': {'w': 2},
    'kv_einsum': {'w': 2},
    'q_einsum': {'w': 2}},
   'mlp': {'gating_einsum': 2, 'linear': 2},
   'post_attention_norm': {'scale': -1},
   'post_ffw_norm': {'scale': -1},
   'pre_attention_norm': {'scale': -1},
   'pre_ffw_norm': {'scale': -1}},
  'layer_

In [54]:
lora_params = lorax.init_lora(params, lora_spec, jax.random.PRNGKey(0))

In [55]:
print_params(lora_params)

Name: transformer/embedder/input_embedding
Shape: (256128, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/final_norm/scale
Shape: (2304,)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/attn_vec_einsum/w
Shape: (8, 256, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/kv_einsum/w
Shape: (2, 4, 2304, 256)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/attn/q_einsum/w
Shape: (8, 2304, 256)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/mlp/gating_einsum
Shape: (2, 2304, 9216)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/mlp/linear
Shape: (9216, 2304)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/post_attention_norm/scale
Shape: (2304,)
dtype: bfloat16
----------------------------------------
Name: transformer/layer_0/post_ffw_

In [56]:
print(lora_params)

{'transformer': {'embedder': {'input_embedding': LoraWeight(shape=(256128, 2304), dtype=dtype(bfloat16), w=Array([[0.0351562, -0.0229492, 0.081543, ..., 0.0211182, 0.0527344,
        -0.0351562],
       [-0.0200195, 0.0522461, -0.0302734, ..., 0.0027771, -0.0240479,
        -0.017334],
       [-0.000164032, -0.00592041, 0.0222168, ..., 0.0151978,
        -0.00735474, -0.0119019],
       ...,
       [0.0227051, -0.0375977, 0.0356445, ..., 0.0402832, 0.0117798,
        -0.0308838],
       [0.0319824, -0.0368652, 0.0410156, ..., 0.0385742, 0.0196533,
        -0.0270996],
       [0.0203857, -0.0405273, 0.0368652, ..., 0.0400391, 0.0180664,
        -0.0306396]], dtype=bfloat16), a=Array([[ 0.02199156,  0.0049439 , -0.01322838, ..., -0.01093095,
         0.01156266,  0.01847863],
       [ 0.01022545, -0.01507886, -0.0023296 , ...,  0.00661303,
         0.02698121,  0.00401224]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],

In [57]:
lora_model = lorax.lora(model.apply)

### print params after applying LoRA

## Step 1: prepare the dataset

For this project, we're utilizing the refined **Alpaca dataset**, curated by yahma. This dataset is a carefully filtered selection of 52,000 entries from the original Alpaca collection. Feel free to substitute this section with your own data preparation code if you prefer.

It's crucial to include the EOS_TOKEN (End of Sequence Token) in your tokenized output. Failing to do so may result in endless generation loops.

In [58]:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, 
    token=HUGGINGFACE_TOKEN
)
if not tokenizer.pad_token:
    print("Tokenizer doesn't have a pad token.")
    tokenizer.pad_token = tokenizer.eos_token

In [59]:
def get_dataset(*, tokenizer, batch_size=1, max_length=25, debug_mode=True):
    # Define Alpaca prompt template
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
    
    ### Instruction: {}
    
    ### Input: {}
    
    ### Response: {}"""
    
    EOS_TOKEN = tokenizer.eos_token
    
    # Define formatting function.
    def _format_prompts(examples):
        instructions = examples["instruction"]
        inputs = examples["input"]
        outputs = examples["output"]
        texts = []
        for instruction, input, output in zip(instructions, inputs, outputs):
            text = alpaca_prompt.format(instruction, input, output) + EOS_TOKEN
            texts.append(text)
        return {"text": texts}

    def _tokenize(examples, max_length=None):
        tokenized = tokenizer(examples["text"], truncation=True, padding="max_length", max_length=25+1 if not max_length else max_length+1)
        tokenized['input_ids'] = [input_id[:-1] for input_id in tokenized['input_ids']]
        return {
            'input_tokens': tokenized['input_ids'],
            'target_mask': tokenized['attention_mask']
        }

    def _custom_collate_fn(batch):
        """Applies default_collate_fn from transformers and converts to JAX NumPy arrays."""
        batch = default_data_collator(batch)
        jax_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                jax_batch[key] = jnp.array(value.numpy())
            else:
                jax_batch[key] = value
        
        return jax_batch

    # Load and preprocess the dataset.
    dataset = load_dataset("yahma/alpaca-cleaned", split="train")
    if debug_mode:
        dataset = dataset.select(range(32)) # Use just 32 exampfor faster iteration
    dataset = dataset.map(_format_prompts, batched=True)

    # Create train and test dataset.
    ds = dataset.train_test_split(test_size=0.15)
    ds['train'] = ds['train'].map(_tokenize, batched=True, remove_columns=dataset.column_names)
    ds['test'] = ds['test'].map(_tokenize, batched=True, remove_columns=dataset.column_names)

    # Create DataLoader
    train_dataloader = torch.utils.data.DataLoader(
        ds['train'],
        shuffle=True,
        batch_size=1 if not batch_size else batch_size,
        collate_fn=_custom_collate_fn
    )
    
    test_dataloader = torch.utils.data.DataLoader(
        ds['test'],
        shuffle=True,
        batch_size=1 if not batch_size else batch_size,
        collate_fn=_custom_collate_fn
    )

    return train_dataloader, test_dataloader

In [60]:
# # # Test Dataset
# train_dataloader, _ = get_dataset(tokenizer=tokenizer)
# for i, batch in enumerate(train_dataloader):
#     if i>10:
#         break
#     input_ids, attention_mask = (
#         batch["input_tokens"],
#         batch["target_mask"],
        
#     )
#     print(input_ids)
#     print()
#     print(attention_mask)

In [61]:
def forward_and_loss_fn(params,
                        *,
                        lora_model,
                        input_tokens: jax.Array,            # Shape [B, L]
                        input_mask: jax.Array,              # Shape [B, L]
                        positions: jax.Array,               # Shape [B, L]
                        attention_mask: jax.Array,          # [B, L, L]
                        ) -> jax.Array:
  """Forward pass and loss function.

  Args:
    params: model's input parameters.
    model: gemma transformer model to call.
    input_tokens: input tokens sequence, shape [B, L].
    input_mask: tokens to ignore when computing the loss, shape [B, L].
    positions: relative position of each token, shape [B, L].
    attention_mask: input attention mask, shape [B, L].

  Returns:
    Softmax cross-entropy loss for the next-token prediction task.
  """

  # Forward pass on the input data.
  # No attention cache is needed here.
  logits, _ = lora_model(
        params,
        input_tokens,
        positions,
        None,              # Attention cache is None.
        attention_mask,
    )

  # Exclude the last step as it does not appear in the targets.
  logits = logits[:, :-1]


  # Similarly, the first token cannot be predicteds.
  target_tokens = input_tokens[:, 1:]
  target_mask = input_mask[:, 1:]

  # Convert the target labels into one-hot encoded vectors.
  target_mask = target_mask[...,1:] # TODO
  one_hot = jax.nn.one_hot(target_tokens, logits.shape[-1])

  # Don't update on unwanted tokens.
  one_hot = one_hot * target_mask.astype(one_hot.dtype)[..., None]

  # Normalisation factor.
  norm_factor = 1 / (jnp.sum(target_mask) + 1e-8)

  # Return the nll loss.
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot) * norm_factor

The Gemma transformer requires an attention mask and position vector alongside each input. We can conveniently generate these using the following function:

In [62]:
def get_attention_mask_and_positions(example: jax.Array,
                                     pad_id : int,
                                     )-> tuple[jax.Array, jax.Array]:
  """Builds the position and attention mask vectors from the given tokens."""
  pad_mask = example != pad_id
  current_token_position = transformer_lib.build_positions_from_mask(pad_mask)
  attention_mask = transformer_lib.make_causal_attn_mask(pad_mask)
  return current_token_position, attention_mask

We can now build the train_step function which performs the backward pass and updates the model's parameters accordingly.

In [63]:
def train_step(lora_model,
               params,
               optimizer: optax.GradientTransformation,
               opt_state: optax.OptState,
               pad_id: int,
               example):
  """Train step.

  Args:
    model: gemma transformer model.
    params: model's input parameters.
    optimizer: optax optimizer to use.
    opt_state: input optimizer's state.
    pad_id: id of the pad token.
    example: input batch.

  Returns:
    Training loss, updated parameters, updated optimizer state.
  """
  # Build the position and attention mask vectors.
  positions, attention_mask = get_attention_mask_and_positions(example['input_tokens'], pad_id)


  # Forward and backward passes
  train_loss, grads = jax.value_and_grad(forward_and_loss_fn)(params,
                                                             lora_model=lora_model,
                                                             input_tokens=example['input_tokens'],
                                                             input_mask=example['target_mask'],
                                                             positions=positions,
                                                             attention_mask=attention_mask)
  # Update the parameters
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)

  return train_loss, params, opt_state

Similarly, we build a `validation_step` function without backward pass.

And now the training loop itself.

In [64]:
@chex.dataclass(frozen=True)
class TrainingConfig:
  learning_rate: float
  num_epochs: int
  eval_every_n: int
  batch_size: int
  max_steps: int | None = None

from dataclasses import dataclass
import numpy as np


def train_loop(
    lora_model,
    params,
    train_dataloader,
    tokenizer,
    training_cfg: TrainingConfig,
    lora_spec):


  compiled_train_step = train_step# , static_argnames=['lora_model', 'optimizer'])
  optimizer = optax.sgd(training_cfg.learning_rate)
  optimizer = lorax.wrap_optimizer(optimizer, lora_spec)
  opt_state = optimizer.init(params)

  n_steps = 0
  avg_loss=0

  for i, train_example in enumerate(train_dataloader):
    train_loss, params, opt_state = train_step(lora_model=lora_model,
                                                        params=params,
                                                        optimizer=optimizer,
                                                        opt_state=opt_state,
                                                        pad_id=tokenizer.pad_token_id,
                                                        example=train_example)
    n_steps += 1
    avg_loss += train_loss
    print(f"train_loss {train_loss}")
    if training_cfg.max_steps is not None and n_steps > training_cfg.max_steps:
      break
  return params

We can fine-tune our model on a limited number of steps.

In [65]:
params['transformer']

{'embedder': {'input_embedding': Array([[0.0351562, -0.0229492, 0.081543, ..., 0.0211182, 0.0527344,
          -0.0351562],
         [-0.0200195, 0.0522461, -0.0302734, ..., 0.0027771, -0.0240479,
          -0.017334],
         [-0.000164032, -0.00592041, 0.0222168, ..., 0.0151978,
          -0.00735474, -0.0119019],
         ...,
         [0.0227051, -0.0375977, 0.0356445, ..., 0.0402832, 0.0117798,
          -0.0308838],
         [0.0319824, -0.0368652, 0.0410156, ..., 0.0385742, 0.0196533,
          -0.0270996],
         [0.0203857, -0.0405273, 0.0368652, ..., 0.0400391, 0.0180664,
          -0.0306396]], dtype=bfloat16)},
 'final_norm': {'scale': Array([2.32812, 2.34375, 2.28125, ..., 4.65625, 2.53125, 2.4375],      dtype=bfloat16)},
 'layer_0': {'attn': {'attn_vec_einsum': {'w': Array([[[0.0090332, 0.0100708, 0.0155029, ..., 0.00256348, -0.00537109,
             0.00848389],
            [0.0114136, 0.0202637, 0.00952148, ..., -0.000166893, 0.0108032,
             0.0124512],
     

In [66]:
lora_params['transformer']

{'embedder': {'input_embedding': LoraWeight(shape=(256128, 2304), dtype=dtype(bfloat16), w=Array([[0.0351562, -0.0229492, 0.081543, ..., 0.0211182, 0.0527344,
          -0.0351562],
         [-0.0200195, 0.0522461, -0.0302734, ..., 0.0027771, -0.0240479,
          -0.017334],
         [-0.000164032, -0.00592041, 0.0222168, ..., 0.0151978,
          -0.00735474, -0.0119019],
         ...,
         [0.0227051, -0.0375977, 0.0356445, ..., 0.0402832, 0.0117798,
          -0.0308838],
         [0.0319824, -0.0368652, 0.0410156, ..., 0.0385742, 0.0196533,
          -0.0270996],
         [0.0203857, -0.0405273, 0.0368652, ..., 0.0400391, 0.0180664,
          -0.0306396]], dtype=bfloat16), a=Array([[ 0.02199156,  0.0049439 , -0.01322838, ..., -0.01093095,
           0.01156266,  0.01847863],
         [ 0.01022545, -0.01507886, -0.0023296 , ...,  0.00661303,
           0.02698121,  0.00401224]], dtype=float32), b=Array([[0., 0.],
         [0., 0.],
         [0., 0.],
         ...,
         [0.,

In [68]:
lora_spec

{'transformer': {'embedder': {'input_embedding': 2},
  'final_norm': {'scale': -1},
  'layer_0': {'attn': {'attn_vec_einsum': {'w': 2},
    'kv_einsum': {'w': 2},
    'q_einsum': {'w': 2}},
   'mlp': {'gating_einsum': 2, 'linear': 2},
   'post_attention_norm': {'scale': -1},
   'post_ffw_norm': {'scale': -1},
   'pre_attention_norm': {'scale': -1},
   'pre_ffw_norm': {'scale': -1}},
  'layer_1': {'attn': {'attn_vec_einsum': {'w': 2},
    'kv_einsum': {'w': 2},
    'q_einsum': {'w': 2}},
   'mlp': {'gating_einsum': 2, 'linear': 2},
   'post_attention_norm': {'scale': -1},
   'post_ffw_norm': {'scale': -1},
   'pre_attention_norm': {'scale': -1},
   'pre_ffw_norm': {'scale': -1}},
  'layer_10': {'attn': {'attn_vec_einsum': {'w': 2},
    'kv_einsum': {'w': 2},
    'q_einsum': {'w': 2}},
   'mlp': {'gating_einsum': 2, 'linear': 2},
   'post_attention_norm': {'scale': -1},
   'post_ffw_norm': {'scale': -1},
   'pre_attention_norm': {'scale': -1},
   'pre_ffw_norm': {'scale': -1}},
  'layer_

In [74]:
training_cfg = TrainingConfig(learning_rate=1e-4,
                              num_epochs=1,
                              eval_every_n=20,
                              batch_size=1,
                              max_steps=10)

params = train_loop(lora_model=lora_model,
                    params={'transformer': lora_params},
                    train_dataloader=train_dataloader,
                    tokenizer=tokenizer,
                    training_cfg=training_cfg, 
                   lora_spec=lora_spec)

ValueError: Dict key mismatch; expected keys: ['embedder', 'final_norm', 'layer_0', 'layer_1', 'layer_10', 'layer_11', 'layer_12', 'layer_13', 'layer_14', 'layer_15', 'layer_16', 'layer_17', 'layer_18', 'layer_19', 'layer_2', 'layer_20', 'layer_21', 'layer_22', 'layer_23', 'layer_24', 'layer_25', 'layer_3', 'layer_4', 'layer_5', 'layer_6', 'layer_7', 'layer_8', 'layer_9']; dict: {'transformer': {'embedder': {'input_embedding': LoraWeight(shape=(256128, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.02199156,  0.0049439 , -0.01322838, ..., -0.01093095,
         0.01156266,  0.01847863],
       [ 0.01022545, -0.01507886, -0.0023296 , ...,  0.00661303,
         0.02698121,  0.00401224]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'final_norm': {'scale': Array([2.32812, 2.34375, 2.28125, ..., 4.65625, 2.53125, 2.4375],      dtype=bfloat16)}, 'layer_0': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 1.3663782e-02, -6.0665291e-03,  1.4391938e-02, ...,
        -7.4021826e-03, -2.7578839e-03, -8.8478941e-03],
       [ 1.8342493e-02,  3.5765569e-03,  2.9392069e-05, ...,
        -1.2318562e-02,  2.7163564e-03,  2.3088796e-02]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([-0.53125, -0.515625, -0.490234, ..., -0.53125, 1.42188, -0.519531],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([-0.229492, -0.189453, -0.194336, ..., -0.361328, 0.441406,
       -0.162109], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.116699, 0.134766, 0.192383, ..., 0.636719, 0.0402832, 0.243164],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.227539, 0.208008, 0.208008, ..., 0.992188, 2.15625, 0.197266],      dtype=bfloat16)}}, 'layer_1': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.01047093,  0.01922145, -0.01181239, ...,  0.0187807 ,
        -0.002852  ,  0.01409159],
       [-0.01944142,  0.00886716, -0.00052842, ..., -0.02301286,
         0.00416193, -0.00264629]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([-0.507812, -0.46875, -0.466797, ..., -0.503906, 0.102539,
       -0.498047], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([-0.0354004, 0.0598145, 0.043457, ..., -0.202148, 0.308594,
       0.113281], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.648438, 0.589844, 0.640625, ..., 1.27344, 0.229492, 0.5625],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.498047, 0.570312, 0.535156, ..., 1.22656, 0.361328, 0.462891],      dtype=bfloat16)}}, 'layer_10': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.00627691, -0.01385751, -0.00372067, ..., -0.01245145,
        -0.01636037,  0.00584008],
       [-0.01115024, -0.00439675,  0.00728602, ..., -0.01480319,
         0.01306054,  0.00127973]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.244141, 0.330078, 0.421875, ..., 0.183594, 0.165039, 0.129883],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.496094, 0.535156, 0.570312, ..., 0.429688, 0.511719, 0.371094],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.699219, 0.773438, 0.65625, ..., 0.820312, 0.0186768, 0.671875],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.12793, 0.229492, 0.152344, ..., 0.227539, -0.212891, 0.125977],      dtype=bfloat16)}}, 'layer_11': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.01029943, -0.01047805,  0.00627802, ...,  0.00658243,
         0.00416735, -0.00188006],
       [-0.00780879, -0.00588758,  0.00545886, ...,  0.00931677,
        -0.00922768,  0.00010113]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.417969, 0.449219, 0.496094, ..., 0.236328, 0.0996094, 0.226562],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.570312, 0.636719, 0.699219, ..., 0.523438, 0.554688, 0.421875],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.648438, 0.875, 0.738281, ..., 0.773438, -0.0170898, 0.667969],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.074707, 0.15625, 0.0830078, ..., 0.185547, -0.205078, 0.0991211],      dtype=bfloat16)}}, 'layer_12': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00788363,  0.00197184,  0.00437608, ...,  0.01334583,
         0.00926806,  0.01390658],
       [ 0.00132567, -0.00222159, -0.00384167, ..., -0.02493023,
         0.00545663, -0.00319182]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.416016, 0.648438, 0.589844, ..., 0.335938, 0.133789, 0.298828],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.644531, 0.695312, 0.757812, ..., 0.589844, 0.589844, 0.507812],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.933594, 1.07812, 0.972656, ..., 0.863281, 0.0952148, 0.957031],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.0192871, 0.0976562, 0.0195312, ..., 0.10498, -0.239258,
       0.0118408], dtype=bfloat16)}}, 'layer_13': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.01005734,  0.01642189,  0.01510733, ...,  0.01328816,
         0.00516436,  0.00271901],
       [ 0.01016035, -0.00417797, -0.00903327, ..., -0.00616623,
        -0.00244164, -0.01308654]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.742188, 0.71875, 0.671875, ..., 0.597656, 0.515625, 0.494141],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.875, 0.867188, 0.953125, ..., 0.859375, 0.839844, 0.679688],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.722656, 0.894531, 0.761719, ..., 0.726562, 0.0529785, 0.789062],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.126953, -0.0751953, -0.126953, ..., -0.0358887, -0.269531,
       -0.0932617], dtype=bfloat16)}}, 'layer_14': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.00014633,  0.00200139, -0.00434174, ..., -0.001957  ,
        -0.01323679, -0.00078773],
       [ 0.00267099, -0.01303148,  0.02141839, ..., -0.00115613,
        -0.01118799, -0.00374962]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.882812, 0.957031, 0.742188, ..., 0.757812, 0.648438, 0.628906],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([1.14844, 1.125, 1.20312, ..., 1.07031, 1.04688, 0.898438],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.878906, 1.07812, 0.859375, ..., 0.964844, 0.373047, 0.929688],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.185547, -0.140625, -0.193359, ..., -0.111328, -0.271484,
       -0.130859], dtype=bfloat16)}}, 'layer_15': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.00193953,  0.00879681, -0.02188212, ...,  0.00268727,
         0.00692501, -0.02095343],
       [-0.01345678,  0.01583269,  0.00739929, ...,  0.00684354,
        -0.01523105,  0.00528662]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([1.125, 1.09375, 1.17969, ..., 1.22656, 0.921875, 0.847656],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([1.42188, 1.45312, 1.50781, ..., 1.46094, 1.53125, 1.23438],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.589844, 0.570312, 0.498047, ..., 0.628906, 0.259766, 0.570312],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.176758, -0.133789, -0.176758, ..., -0.145508, -0.251953,
       -0.133789], dtype=bfloat16)}}, 'layer_16': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00321405, -0.01531544,  0.01043977, ...,  0.01724594,
        -0.00019943, -0.00515615],
       [-0.00242979, -0.00892317, -0.00779914, ...,  0.01742655,
         0.00521834,  0.00618159]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([1.25781, 1.17969, 1.125, ..., 0.96875, 1.0625, 0.910156], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([1.8125, 1.82812, 1.90625, ..., 1.71094, 1.84375, 1.57031],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.871094, 0.992188, 0.765625, ..., 0.648438, 0.5625, 0.960938],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.108398, -0.0561523, -0.0947266, ..., -0.074707, -0.150391,
       -0.0505371], dtype=bfloat16)}}, 'layer_17': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.00229674,  0.00096951,  0.00196753, ...,  0.00066601,
         0.00454113, -0.01169481],
       [-0.00571106, -0.00023862, -0.00663244, ...,  0.00123321,
        -0.01802823,  0.00163148]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([1.17188, 1.17188, 1.03125, ..., 1.05469, 0.980469, 1.01562],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([2.125, 2.04688, 2.1875, ..., 2, 2.15625, 1.875], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.535156, 0.675781, 0.472656, ..., 0.484375, 0.423828, 0.667969],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.0554199, -0.00891113, -0.0585938, ..., -0.0483398, -0.0942383,
       -0.000480652], dtype=bfloat16)}}, 'layer_18': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.01076485, -0.01254259,  0.00719575, ..., -0.01088936,
         0.02172294,  0.00355693],
       [-0.00358216, -0.01947643, -0.00234379, ..., -0.00571739,
         0.02209909,  0.01588861]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([1.36719, 1.41406, 1.35938, ..., 1.39062, 1.60938, 1.26562],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([2.46875, 2.42188, 2.48438, ..., 2.34375, 2.53125, 2.3125],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.443359, 0.570312, 0.347656, ..., 0.453125, 0.388672, 0.5625],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.160156, 0.269531, 0.191406, ..., 0.158203, 0.11377, 0.259766],      dtype=bfloat16)}}, 'layer_19': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.01803358, -0.00303927,  0.00691553, ...,  0.00371117,
         0.00372419,  0.00021878],
       [ 0.00078573, -0.00061509, -0.0201269 , ...,  0.01225177,
         0.00519481,  0.00182117]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([1.25, 1.07812, 1.29688, ..., 1.35156, 1.29688, 1.14062], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([2.64062, 2.625, 2.73438, ..., 2.51562, 2.625, 2.3125], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.414062, 0.59375, 0.5, ..., 0.451172, 0.375, 0.554688], dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.175781, 0.236328, 0.203125, ..., 0.168945, 0.128906, 0.248047],      dtype=bfloat16)}}, 'layer_2': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.01155356,  0.00352542, -0.00475057, ...,  0.00461095,
         0.00289454, -0.0215128 ],
       [ 0.00688512, -0.00726845, -0.00535207, ...,  0.00995815,
         0.00537037,  0.00453264]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([-0.232422, -0.229492, -0.160156, ..., -0.414062, 0.0179443,
       -0.265625], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.131836, 0.105469, 0.139648, ..., -0.141602, 0.326172, 0.135742],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.714844, 0.667969, 0.71875, ..., 1.07812, 0.300781, 0.515625],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.5, 0.511719, 0.53125, ..., 0.941406, 0.00320435, 0.40625],      dtype=bfloat16)}}, 'layer_20': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00383564, -0.00756023, -0.01255484, ...,  0.01498562,
        -0.01903302,  0.02186585],
       [-0.00933835,  0.0013327 ,  0.01779663, ..., -0.00167542,
         0.00606556, -0.00059748]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([1.39062, 1.32031, 1.55469, ..., 1.32031, 1.50781, 1.29688],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([2.84375, 2.875, 2.90625, ..., 2.64062, 2.78125, 2.5625], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.359375, 0.330078, 0.347656, ..., 0.306641, 0.318359, 0.5],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.0220947, 0.0688477, 0.0303955, ..., 0.0279541, 0.0133057,
       0.108887], dtype=bfloat16)}}, 'layer_21': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00473069, -0.01229254,  0.00118823, ...,  0.00760194,
        -0.01410474, -0.00785806],
       [-0.00138563, -0.00167912,  0.01098308, ...,  0.01268801,
         0.0018889 , -0.00222404]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([1.34375, 1.36719, 1.53906, ..., 1.33594, 1.35156, 1.24219],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([3.1875, 3.34375, 3.375, ..., 3.04688, 3.32812, 3.03125], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.421875, 0.402344, 0.400391, ..., 0.457031, 0.480469, 0.53125],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.0605469, -0.0107422, -0.0512695, ..., -0.0512695, -0.0771484,
       0.0262451], dtype=bfloat16)}}, 'layer_22': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.00068823,  0.00820747,  0.00691753, ..., -0.01032871,
         0.01214039, -0.01452747],
       [-0.00807814, -0.00087587,  0.01156158, ..., -0.00500905,
         0.0198493 ,  0.00917682]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([2.54688, 2.875, 2.70312, ..., 2.375, 2.73438, 2.35938], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([3.73438, 3.75, 3.84375, ..., 3.57812, 3.89062, 3.48438], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.232422, 0.261719, 0.180664, ..., 0.330078, 0.298828, 0.355469],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.090332, -0.0563965, -0.0976562, ..., -0.104004, -0.11084,
       -0.0168457], dtype=bfloat16)}}, 'layer_23': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.00148038,  0.0081275 ,  0.0149825 , ...,  0.00171928,
        -0.0045339 ,  0.00112129],
       [-0.01026383,  0.0025746 , -0.02289991, ..., -0.00131096,
         0.00538384, -0.00893523]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([2.21875, 2.28125, 2.20312, ..., 2.07812, 2.29688, 2.03125],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([4.15625, 4.28125, 4.34375, ..., 4.25, 4.53125, 4], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.306641, 0.337891, 0.300781, ..., 0.367188, 0.396484, 0.429688],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.158203, -0.146484, -0.166016, ..., -0.147461, -0.173828,
       -0.10791], dtype=bfloat16)}}, 'layer_24': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.0118248 ,  0.00209964,  0.00429165, ...,  0.00531876,
         0.0007639 , -0.01389843],
       [ 0.00161521,  0.01745081, -0.0089213 , ...,  0.01427022,
        -0.00016093,  0.01795714]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([2.5625, 2.40625, 2.51562, ..., 2.625, 2.76562, 2.46875], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([4.75, 4.78125, 4.90625, ..., 5.59375, 5.34375, 4.6875], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.248047, 0.223633, 0.21582, ..., 0.345703, 0.259766, 0.339844],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.188477, -0.169922, -0.19043, ..., -0.158203, -0.195312,
       -0.142578], dtype=bfloat16)}}, 'layer_25': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.005503  , -0.00465591,  0.02688703, ...,  0.00167953,
        -0.00383867,  0.01666052],
       [ 0.00788943, -0.01279993, -0.01336973, ...,  0.00526251,
        -0.00046879,  0.01022278]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([2.39062, 2.375, 2.32812, ..., 3.42188, 2.84375, 2.29688], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([4.6875, 4.84375, 4.9375, ..., 5.9375, 4.5625, 4.6875], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.296875, 0.261719, 0.265625, ..., 0.380859, 0.296875, 0.373047],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([-0.136719, -0.134766, -0.133789, ..., -0.0761719, -0.129883,
       -0.105957], dtype=bfloat16)}}, 'layer_3': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00757848,  0.00208123, -0.0095167 , ...,  0.01346242,
         0.01381325, -0.00154688],
       [-0.01088316, -0.00166531, -0.02113895, ..., -0.00199354,
         0.00065937,  0.00700611]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([-0.166992, -0.209961, -0.139648, ..., -0.363281, -0.0634766,
       -0.261719], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.361328, 0.380859, 0.306641, ..., 0.0698242, 0.480469, 0.386719],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.617188, 0.453125, 0.699219, ..., 0.785156, 0.363281, 0.527344],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.859375, 0.824219, 0.894531, ..., 1.19531, 0.0356445, 0.714844],      dtype=bfloat16)}}, 'layer_4': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[ 0.00232868,  0.00122111, -0.02715557, ...,  0.00275968,
         0.0063306 ,  0.00120457],
       [ 0.01015269, -0.00104287, -0.01713954, ...,  0.01712951,
        -0.02799168, -0.01150994]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([-0.193359, -0.28125, -0.21582, ..., -0.378906, -0.0407715,
       -0.332031], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.285156, 0.245117, 0.246094, ..., 0.0476074, 0.613281, 0.271484],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.652344, 0.507812, 0.613281, ..., 0.824219, -0.0534668, 0.363281],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.554688, 0.570312, 0.554688, ..., 0.851562, 0.124512, 0.453125],      dtype=bfloat16)}}, 'layer_5': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00584595, -0.00180545, -0.00662199, ..., -0.0220709 ,
        -0.00569964, -0.00861686],
       [ 0.00151854, -0.00964406,  0.00366633, ...,  0.00940564,
        -0.00930284,  0.00364371]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.0136108, -0.150391, -0.020874, ..., -0.355469, -0.0186768,
       -0.425781], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.0888672, 0.0622559, 0.12793, ..., -0.00866699, 0.337891,
       0.0756836], dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.490234, 0.255859, 0.330078, ..., 0.367188, -0.0544434, 0.046875],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.388672, 0.382812, 0.402344, ..., 0.597656, -0.0415039, 0.236328],      dtype=bfloat16)}}, 'layer_6': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.0155806 , -0.01371003, -0.01509875, ...,  0.01296713,
         0.00798594,  0.00760426],
       [-0.00061655,  0.01749108, -0.01459254, ...,  0.01223726,
        -0.01537374,  0.00116413]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.0039978, -0.124512, -0.0678711, ..., -0.241211, 0.00665283,
       -0.359375], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.106934, 0.0874023, 0.219727, ..., 0.0168457, 0.365234, 0.10498],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([1.10156, 0.796875, 0.84375, ..., 0.988281, 0.0463867, 0.369141],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.435547, 0.447266, 0.443359, ..., 0.636719, -0.0673828, 0.363281],      dtype=bfloat16)}}, 'layer_7': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00536764,  0.01348992, -0.00598283, ..., -0.01268126,
        -0.01520117, -0.01440818],
       [ 0.00239141,  0.01290342,  0.00408962, ...,  0.00013094,
         0.00600098, -0.00192806]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.0143433, -0.0620117, 0.0375977, ..., -0.0898438, 0.110352,
       -0.135742], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.191406, 0.195312, 0.308594, ..., 0.118164, 0.625, 0.196289],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.726562, 0.589844, 0.59375, ..., 0.789062, -0.155273, 0.455078],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.542969, 0.582031, 0.550781, ..., 0.742188, -0.00248718, 0.455078],      dtype=bfloat16)}}, 'layer_8': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.00295627, -0.00559398, -0.00677068, ...,  0.01155958,
        -0.01056946, -0.01427677],
       [ 0.02465803,  0.01395147, -0.01297336, ..., -0.0011971 ,
        -0.00226062, -0.00428502]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.120117, 0.132812, 0.211914, ..., 0.0247803, 0.125977, 0.0432129],      dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.277344, 0.273438, 0.369141, ..., 0.199219, 0.644531, 0.271484],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([1.14062, 1.02344, 1.03906, ..., 1.14062, 0.0703125, 0.894531],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.601562, 0.675781, 0.648438, ..., 0.785156, 0.0917969, 0.589844],      dtype=bfloat16)}}, 'layer_9': {'attn': {'attn_vec_einsum': {'w': LoraWeight(shape=(8, 256, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[-0.00212097, -0.00132751],
        [-0.00643921, -0.0130615],
        [-0.00396729, -0.000341415],
        ...,
        [-0.0174561, 0.0014267],
        [-0.00152588, -0.00854492],
        [0.0148315, 0.00283813]],

       [[-0.0050354, 0.00366211],
        [0.00344849, -0.0195312],
        [0.00680542, 0.000835419],
        ...,
        [-0.00692749, -0.0195312],
        [-0.00375366, -0.00375366],
        [0.000246048, 0.0101929]],

       [[0.00732422, -0.0130615],
        [0.0251465, 0.0119629],
        [-0.000341415, 0.00325012],
        ...,
        [-0.000146866, -0.00598145],
        [-0.0125732, -0.00273132],
        [-0.00375366, -0.000341415]],

       ...,

       [[0.00897217, -0.00741577],
        [-0.00312805, -0.00460815],
        [-0.00482178, 0.00325012],
        ...,
        [0.00897217, -0.00273132],
        [0.0115967, -0.00334167],
        [-0.0114136, 0.00515747]],

       [[0.0162354, -0.0166016],
        [0.00408936, 0.0018158],
        [-0.0166016, -0.0233154],
        ...,
        [-0.00668335, -0.00769043],
        [0.0119629, -0.00878906],
        [0.0124512, -0.0100098]],

       [[-0.00668335, -0.00854492],
        [0.0078125, -0.0211182],
        [-0.00692749, 0.0078125],
        ...,
        [-0.00460815, 0.000246048],
        [0.00122833, 0.00610352],
        [0.00680542, 0.000835419]]], dtype=bfloat16), alpha=1.0)}, 'kv_einsum': {'w': LoraWeight(shape=(2, 4, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]],      dtype=bfloat16), b=Array([[[[0.00515747, 0.00585938],
         [0.00866699, 0.0045166],
         [-0.0289307, 0.0133057],
         ...,
         [0.00262451, -0.00439453],
         [-0.00622559, 0.0119629],
         [0.0203857, -0.00273132]],

        [[-0.00334167, 0.00202942],
         [-0.0117798, 0.00122833],
         [0.0045166, -0.00823975],
         ...,
         [0.0124512, 0.0108643],
         [-0.000934601, -0.00909424],
         [0.00387573, 0.00366211]],

        [[0.00430298, -0.00231934],
         [0.00162506, 0.00262451],
         [0.0178223, -0.000146866],
         ...,
         [0.00897217, -0.00552368],
         [0.0128174, -0.00970459],
         [-0.00132751, -0.00172424]],

        [[-0.0050354, 0.0155029],
         [0.00585938, -0.00552368],
         [-0.000341415, -0.000541687],
         ...,
         [0.00634766, -0.00622559],
         [0.00561523, -0.00396729],
         [-0.00692749, -0.0140991]]],


       [[[0.0133057, 0.0119629],
         [0.000246048, -0.0211182],
         [-0.00112915, 0.0115967],
         ...,
         [0.0189209, 0.000835419],
         [0.00430298, 0.00515747],
         [-0.00769043, 0.0148315]],

        [[-0.00823975, -0.0146484],
         [-0.00334167, 0.00344849],
         [0.00732422, -0.0025177],
         ...,
         [0.00610352, -0.0117798],
         [0.00387573, -0.00439453],
         [0.00473022, -0.00172424]],

        [[-0.00292969, -0.00799561],
         [-0.0107422, -0.00552368],
         [-0.00273132, -0.0233154],
         ...,
         [-0.00396729, -0.00854492],
         [-0.00769043, -0.00482178],
         [0.00325012, -0.0050354]],

        [[0.000246048, 0.000637054],
         [-0.00334167, 0.000637054],
         [0.0112305, -0.00439453],
         ...,
         [0.0128174, 0.00927734],
         [0.0030365, -0.00643921],
         [0.00811768, -0.000934601]]]], dtype=bfloat16), alpha=1.0)}, 'q_einsum': {'w': LoraWeight(shape=(8, 2304, 256), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00515747, 0.00585938],
        [0.00866699, 0.0045166],
        [-0.0289307, 0.0133057],
        ...,
        [0.00262451, -0.00439453],
        [-0.00622559, 0.0119629],
        [0.0203857, -0.00273132]],

       [[-0.00334167, 0.00202942],
        [-0.0117798, 0.00122833],
        [0.0045166, -0.00823975],
        ...,
        [0.0124512, 0.0108643],
        [-0.000934601, -0.00909424],
        [0.00387573, 0.00366211]],

       [[0.00430298, -0.00231934],
        [0.00162506, 0.00262451],
        [0.0178223, -0.000146866],
        ...,
        [0.00897217, -0.00552368],
        [0.0128174, -0.00970459],
        [-0.00132751, -0.00172424]],

       ...,

       [[-0.00823975, -0.0146484],
        [-0.00334167, 0.00344849],
        [0.00732422, -0.0025177],
        ...,
        [0.00610352, -0.0117798],
        [0.00387573, -0.00439453],
        [0.00473022, -0.00172424]],

       [[-0.00292969, -0.00799561],
        [-0.0107422, -0.00552368],
        [-0.00273132, -0.0233154],
        ...,
        [-0.00396729, -0.00854492],
        [-0.00769043, -0.00482178],
        [0.00325012, -0.0050354]],

       [[0.000246048, 0.000637054],
        [-0.00334167, 0.000637054],
        [0.0112305, -0.00439453],
        ...,
        [0.0128174, 0.00927734],
        [0.0030365, -0.00643921],
        [0.00811768, -0.000934601]]], dtype=bfloat16), alpha=1.0)}}, 'mlp': {'gating_einsum': LoraWeight(shape=(2, 2304, 9216), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=bfloat16), b=Array([[[0.00262451, -0.00668335],
        [0.0014267, -0.00720215],
        [0.00708008, 0.0142822],
        ...,
        [-0.000341415, 0.00408936],
        [-0.00439453, 0.0108643],
        [0.0220947, 0.0203857]],

       [[-0.00527954, 0.00283813],
        [0.0101929, 0.00927734],
        [0.00836182, -0.0211182],
        ...,
        [-0.000341415, -0.000341415],
        [0.00221252, 0.0078125],
        [-0.0050354, 0.0124512]]], dtype=bfloat16), alpha=1.0), 'linear': LoraWeight(shape=(9216, 2304), dtype=dtype(bfloat16), w=MaskedNode(), a=Array([[-0.0046731 ,  0.00286664,  0.00511242, ...,  0.01160862,
        -0.00942933, -0.01301723],
       [ 0.0052865 ,  0.00223007, -0.01150817, ..., -0.00929926,
        -0.01117802,  0.00230565]], dtype=float32), b=Array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32), alpha=1.0)}, 'post_attention_norm': {'scale': Array([0.107422, 0.113281, 0.164062, ..., 0.0292969, 0.0986328,
       -0.0441895], dtype=bfloat16)}, 'post_ffw_norm': {'scale': Array([0.484375, 0.488281, 0.539062, ..., 0.390625, 0.65625, 0.408203],      dtype=bfloat16)}, 'pre_attention_norm': {'scale': Array([0.953125, 0.917969, 0.871094, ..., 1.10156, 0.15918, 0.871094],      dtype=bfloat16)}, 'pre_ffw_norm': {'scale': Array([0.267578, 0.339844, 0.296875, ..., 0.400391, -0.120117, 0.251953],      dtype=bfloat16)}}}}.