In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import safetensors
import torch
from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle
import torch.nn.functional as F

In [9]:
# Checkpoint selection.
#
# Each entry describes one checkpoint whose balance factors can be inspected:
#   base_path     - the checkpoint's ".../model/model/decoder" directory
#   num_layers    - number of layer sub-directories (0 .. num_layers - 1) to read
#   tp_world_size - number of tensor-parallel shard files per layer
#   ckp_name      - label used in the plot titles
#
# Switch experiments by changing EXPERIMENT instead of (un)commenting blocks.
EXPERIMENTS = {
    "exp55": {
        "base_path": "/fsx/phuc/new_workspace/experiments/infini_attention_8b_llama/exp55_8b_llama_16384_ctx_length_and_8192_segment_length_and_1.3m_bs_and_global_lr_1.0e-5_and_balance_factor_lr_0.001/ckp_for_evals/20000/model/model/decoder",
        "num_layers": 32,
        "tp_world_size": 4,
        "ckp_name": "exp55: 8b llama3 at 20k ckp",
    },
    "exp51": {
        "base_path": "/fsx/phuc/new_workspace/experiments/infini_attention_8b_llama/exp51_200m_infini_llama2_256_ctx_length_and_64_segment_length_and_2m_bs_and_global_lr_0.0000375_and_balance_factor_lr_0.00015/checkpoints/10000/model/model/decoder",
        "num_layers": 5,
        "tp_world_size": 2,
        "ckp_name": "exp51: 200m llama at 10k ckp",
    },
    "exp57": {
        "base_path": "/fsx/phuc/new_workspace/experiments/exp57_8b_llama_1024_ctx_length_and_64_segment_length_and_100k_bs_and_global_lr_1.0e-5_and_balance_factor_lr_0.01_and_balance_factor_0_weight_decay/checkpoints/18000/model/model/decoder",
        "num_layers": 32,
        "tp_world_size": 4,
        "ckp_name": "exp57: 8b llama3, 1024 ctxlen, 64 segment len, 100k bs, 0.01 balance factor lr, no weight decay for balance factors",
    },
}

# Key of the experiment analysed by the rest of the notebook.
EXPERIMENT = "exp57"

_selected = EXPERIMENTS[EXPERIMENT]
base_path = _selected["base_path"]
num_layers = _selected["num_layers"]
tp_world_size = _selected["tp_world_size"]
ckp_name = _selected["ckp_name"]

In [10]:
# Load the balance factors of every layer and merge the tensor-parallel shards.
#
# For each layer index, the "data" tensor of every tp-rank shard file is read,
# cast to float32 and concatenated in tp-rank order into one 1-D array.
# `merged_tensors` ends up with one array per layer that had at least one shard.
merged_tensors = []
missing_shards = []  # shard files that were expected but not found on disk

for layer_idx in range(num_layers):
    layer_shards = []
    for tp_rank in range(tp_world_size):
        file_path = os.path.join(base_path, f"{layer_idx}/pp_block/attn/model_balance_factors_pp-rank-0-of-1_tp-rank-{tp_rank}-of-{tp_world_size}.safetensors")

        if not os.path.exists(file_path):
            # Keep the original best-effort behaviour (skip), but remember the
            # gap so it can be reported instead of silently shortening the data.
            missing_shards.append(file_path)
            continue

        # Context manager closes the file handle once the tensor has been read.
        with safetensors.safe_open(file_path, framework="pt", device="cpu") as shard_file:
            shard = shard_file.get_tensor("data").to(torch.float32).numpy()
        layer_shards.append(shard)

    if layer_shards:
        merged_tensors.append(np.concatenate(layer_shards))

if missing_shards:
    # A missing shard makes that layer's array shorter (or drops the layer),
    # which shifts layer/head indices in the plots further down.
    print(f"WARNING: {len(missing_shards)} balance-factor shard file(s) not found, e.g. {missing_shards[0]}")

if not merged_tensors:
    raise FileNotFoundError(f"No balance-factor shards found under {base_path}")

In [11]:
# Inspect the loaded balance factors: a list with one float32 array per layer.
merged_tensors

[array([-2.015625  , -2.        , -2.703125  , -2.75      , -1.7578125 ,
        -1.8515625 , -2.        , -1.6328125 , -1.7421875 , -0.70703125,
        -1.984375  , -0.859375  , -1.328125  , -1.1640625 , -0.9296875 ,
        -0.34375   , -1.3984375 , -0.90234375, -1.3125    , -1.28125   ,
        -1.1953125 , -1.1796875 , -1.296875  , -1.640625  , -2.0625    ,
        -1.5078125 , -1.9453125 , -1.421875  , -1.9375    , -2.        ,
        -2.        , -1.984375  ], dtype=float32),
 array([-1.734375  , -1.59375   , -1.171875  , -1.515625  , -1.984375  ,
        -2.        , -1.625     , -1.140625  , -0.78125   , -0.59765625,
        -1.7109375 , -1.0546875 ,  0.23535156, -1.984375  , -1.890625  ,
        -1.4609375 , -2.        , -1.4453125 , -2.        , -1.6171875 ,
        -2.        , -2.        , -1.03125   , -1.3203125 , -2.        ,
        -1.9765625 , -2.015625  , -1.734375  , -0.40625   , -0.70703125,
        -1.015625  , -1.203125  ], dtype=float32),
 array([-2.046875  , -

In [12]:
# Length of the first layer's merged balance-factor array, used as the head count.
num_heads = len(merged_tensors[0])

# Stack the per-layer arrays into a single 2-D ndarray before handing it to
# torch: calling torch.tensor() on a list of ndarrays is slow and emits a
# UserWarning. np.stack also fails loudly if the layers differ in length.
# The sigmoid squashes the raw balance factors into the [0, 1] range.
global_weights = torch.sigmoid(torch.from_numpy(np.stack(merged_tensors)))

In [13]:
# Length of the first layer's merged balance-factor array.
num_heads

32

### Distribution

In [14]:
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Distribution of the sigmoid-activated balance factors over all layers/heads:
# a histogram on the left and a pie chart of the same 0.1-wide bins on the right.

# Flatten to a plain 1-D NumPy array once, so NumPy and Plotly both consume the
# same data without relying on implicit tensor -> array coercion.
flat_tensor = global_weights.flatten().numpy()

# Bin edges 0.0, 0.1, ..., 1.0. linspace gives an exact edge count and end
# point, which a float-step arange does not guarantee.
BIN_WIDTH = 0.1
bins = np.linspace(0.0, 1.0, round(1.0 / BIN_WIDTH) + 1)

# Per-bin counts for the pie chart
hist, bin_edges = np.histogram(flat_tensor, bins=bins)

# One "lo-hi" label per bin
labels = [f'{lo:.1f}-{hi:.1f}' for lo, hi in zip(bins[:-1], bins[1:])]

# Left panel: cartesian axes for the histogram; right panel: domain for the pie
fig = make_subplots(rows=1, cols=2, specs=[[{'type':'xy'}, {'type':'domain'}]])

# Histogram uses the same bin geometry as the NumPy counts above
fig.add_trace(go.Histogram(x=flat_tensor,
                           xbins=dict(start=float(bins[0]), end=float(bins[-1]), size=BIN_WIDTH),
                           name='Histogram'), row=1, col=1)

# Pie chart of the per-bin counts
fig.add_trace(go.Pie(labels=labels, values=hist, name='Distribution'), row=1, col=2)

# Titles and figure size
fig.update_layout(title_text=f"[{ckp_name}] Global weights's distribution",
                  xaxis_title_text='Value',
                  yaxis_title_text='Frequency',
                  height=600, width=1000)

# Sigmoid outputs live in [0, 1], so pin the histogram's x-axis to that range
fig.update_xaxes(range=[0, 1], row=1, col=1)

fig.show()

In [15]:
import plotly.graph_objects as go
import numpy as np

# Heatmap of the sigmoid-activated balance factors, one row per layer and one
# column per head, with each cell annotated with its value.
# Requires `global_weights` (2-D tensor) and `ckp_name` from earlier cells.

# Convert the whole tensor at once instead of copying it element by element.
# float64 matches what filling an np.zeros array via .item() produced.
values = global_weights.numpy().astype(np.float64)
# Use the actual shape so the plot cannot index past what was really loaded.
heatmap_layers, heatmap_heads = values.shape

# Create the heatmap
fig = go.Figure(data=go.Heatmap(
    z=values,
    x=[f'Head {i}' for i in range(heatmap_heads)],
    y=[f'Layer {i}' for i in range(heatmap_layers)],
    hoverongaps=False,
    colorscale='Viridis',
    colorbar=dict(title='Value'),
    zmin=0,
    zmax=1
))

# Build all cell labels up front and attach them in a single layout update;
# calling fig.add_annotation once per cell gets slow for layers x heads cells.
# Viridis runs from dark purple (low) to bright yellow (high), so low values
# need white text and high values need black text to stay readable.
annotations = [
    dict(
        x=head,
        y=layer,
        text=f'{values[layer, head]:.3f}',
        showarrow=False,
        font=dict(
            color='black' if values[layer, head] > 0.5 else 'white',
            size=8
        )
    )
    for layer in range(heatmap_layers)
    for head in range(heatmap_heads)
]

# Update layout
fig.update_layout(
    title=f"[{ckp_name}] Global Weights Across Layers and Heads",
    xaxis_title='Head',
    yaxis_title='Layer',
    width=1000,
    height=750,
    yaxis=dict(autorange='reversed'),  # To have layer 0 at the top
    annotations=annotations
)

# Show the plot
fig.show()

In [21]:
# Fraction of sigmoid-activated balance factors that fall below 0.5.
below_half_mask = global_weights.flatten() < 0.5
below_half_mask.float().mean()

tensor(0.8604)

In [23]:
# Share, in percent, of the global weights on each side of the 0.5 threshold.
flat_global_weights = global_weights.flatten()
percentage_smaller_0_5 = (flat_global_weights < 0.5).float().mean() * 100
percentage_larger_equal_0_5 = (flat_global_weights >= 0.5).float().mean() * 100

In [24]:
# Percentage of global weights that are >= 0.5.
percentage_larger_equal_0_5

tensor(13.9648)

In [25]:
# Percentage of global weights that are < 0.5.
percentage_smaller_0_5

tensor(86.0352)