In [8]:
from collections.abc import Mapping

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from flax.training import checkpoints
from flax.traverse_util import flatten_dict
from scipy.stats import kurtosis
import optax

from train import TrainConfig, init_train_state, get_default_config


In [9]:
# Default training configuration. The attention-head count and embedding
# width read from it are reused by the weight-statistics cells further down.
config = get_default_config()
num_heads, hidden_size = config.model.num_heads, config.model.num_embeds
print(config)

TrainConfig(seed=555, out_dir='out', train_pattern='openwebtext/train_??.tfrecord', val_pattern='openwebtext/val_??.tfrecord', shuffle_buffer_size=128, eval_interval=1000, eval_steps=50, eval_only=False, keep_checkpoints=6, batch_size=16, train_steps=150000, weight_decay=0.1, grad_clip=1.0, gradient_accumulation_steps=1, betas=[0.9, 0.95], learning_rate=CosineDecayScheduleConfig(init_value=0.0, peak_value=0.00064, warmup_steps=1000, decay_steps=150000, end_value=6.4e-05), wandb=WandbConfig(entity='jenkspt', project='owt', name='gpt-124m', mode='online', notes=''), model=GPTConfig(block_size=1024, vocab_size=50304, num_layers=12, num_heads=12, num_embeds=768, dropout_rate=0.0, use_bias=True, dtype='bfloat16'), remat=False)


In [10]:
def load_checkpoint(config, ckpt_dir):
    """Restore a training state from ``ckpt_dir``.

    A template state is built with ``init_train_state`` (PRNG key derived
    from ``config.seed``, warmup-cosine schedule built from the attributes of
    ``config.learning_rate``) and then filled from disk by
    ``checkpoints.restore_checkpoint``.  The restored step is printed.

    Args:
        config: configuration object providing ``seed`` and
            ``learning_rate``; the attributes of the latter are passed as
            keyword arguments to ``optax.warmup_cosine_decay_schedule``.
        ckpt_dir: checkpoint file or directory handed to
            ``restore_checkpoint``.

    Returns:
        The restored train state.

    Raises:
        FileNotFoundError: if nothing could be restored from ``ckpt_dir``.
    """
    key = jax.random.PRNGKey(config.seed)
    learning_rate = optax.warmup_cosine_decay_schedule(**vars(config.learning_rate))
    template_state = init_train_state(key, config, learning_rate)
    train_state = checkpoints.restore_checkpoint(ckpt_dir, template_state)

    # restore_checkpoint returns the very same target object when it finds no
    # checkpoint.  Without this check a wrong path would silently yield the
    # freshly initialised weights and every downstream statistic would
    # describe an untrained model.
    if train_state is template_state:
        raise FileNotFoundError(
            f"No checkpoint could be restored from {ckpt_dir!r}"
        )

    print("Loaded step:", int(train_state.step))
    return train_state

In [11]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import plotly.express as px
from scipy.stats import kurtosis, skew
import plotly.graph_objects as go
from collections import defaultdict

In [12]:
class DistributionMetric:
    """Descriptor for one named statistic over a weight tensor.

    Args:
        name: identifier the metric is registered and reported under.
        compute_fn: callable that receives an array and returns the metric
            value.
        requires_matrix: flag stored for the caller; ``False`` by default.
    """

    def __init__(self, name, compute_fn, requires_matrix=False):
        # Assigned left to right, so attribute order matches the signature.
        self.name, self.compute_fn, self.requires_matrix = (
            name,
            compute_fn,
            requires_matrix,
        )


# Global lookup table: metric name -> DistributionMetric instance.
METRIC_REGISTRY = dict()


def register_metric(metric: DistributionMetric):
    """Store ``metric`` in ``METRIC_REGISTRY`` under ``metric.name``.

    An entry already registered under the same name is replaced.
    """
    METRIC_REGISTRY.update({metric.name: metric})

In [13]:
# ---- Basic distribution metrics ----

def metric_mean(w):
    """Return the arithmetic mean over all elements of ``w`` as a built-in float."""
    average = np.mean(w)
    return float(average)

def metric_std(w):
    """Return the standard deviation of ``w`` (``np.std`` defaults) as a built-in float."""
    spread = np.std(w)
    return float(spread)

def metric_skew(w):
    """Return the sample skewness of ``w`` (``scipy.stats.skew`` defaults) as a built-in float."""
    skewness = skew(w)
    return float(skewness)

def metric_kurtosis(w):
    """Return the excess (Fisher) kurtosis of ``w`` as a built-in float."""
    excess_kurtosis = kurtosis(w, fisher=True)
    return float(excess_kurtosis)


# ---- Norm based ----

def metric_l2_norm(w):
    """Return ``np.linalg.norm(w)`` (default arguments) as a built-in float."""
    magnitude = np.linalg.norm(w)
    return float(magnitude)


# ---- Matrix based ----

def metric_spectral_norm(w):
    """Return the spectral norm (largest singular value) of ``w`` as a float.

    Args:
        w: matrix accepted by ``np.linalg.svd``.

    Returns:
        The maximum singular value as a built-in float.
    """
    # Only the singular values are needed, so skip computing U and V.
    singular_values = np.linalg.svd(w, compute_uv=False)
    return float(np.max(singular_values))


def metric_effective_rank(w):
    """Return the effective rank of ``w``: exp of the entropy of its normalised singular values.

    The singular values are normalised to sum to one and treated as a
    probability distribution ``p``; the result is ``exp(-sum(p * log(p)))``
    with the convention ``0 * log(0) = 0``.

    Args:
        w: matrix accepted by ``np.linalg.svd``.

    Returns:
        The effective rank as a built-in float.  An all-zero matrix has no
        non-zero singular value and yields ``0.0``.
    """
    # Only the singular values are needed, so skip computing U and V.
    singular_values = np.linalg.svd(w, compute_uv=False)

    total = np.sum(singular_values)
    if total == 0:
        # Avoid 0/0 -> NaN for an all-zero matrix.
        return 0.0

    p = singular_values / total
    # Drop exact zeros (their entropy contribution is 0 by convention).
    # Using `!= 0` rather than `> 0` keeps NaNs so they still propagate.
    p = p[p != 0]
    entropy = -np.sum(p * np.log(p))
    return float(np.exp(entropy))


# Register defaults
# Metrics registered with the default requires_matrix=False.
register_metric(DistributionMetric(name="mean", compute_fn=metric_mean))
register_metric(DistributionMetric(name="std", compute_fn=metric_std))
register_metric(DistributionMetric(name="skew", compute_fn=metric_skew))
register_metric(DistributionMetric(name="kurtosis", compute_fn=metric_kurtosis))
register_metric(DistributionMetric(name="l2_norm", compute_fn=metric_l2_norm))

# SVD-based metrics, flagged as needing the full matrix.
register_metric(
    DistributionMetric(
        name="spectral_norm",
        compute_fn=metric_spectral_norm,
        requires_matrix=True,
    )
)
register_metric(
    DistributionMetric(
        name="effective_rank",
        compute_fn=metric_effective_rank,
        requires_matrix=True,
    )
)

In [14]:
import numpy as np
from collections import defaultdict


class WeightStatisticsEngine:
    """Compute distribution metrics for the kernel weights of a parameter tree.

    Every parameter whose dotted path contains ``"kernel"`` and a purely
    numeric path component (used as the layer index) gets the selected
    metrics computed over the whole tensor.  Parameters whose path also
    contains ``"c_attn"`` and that are 2-D with ``3 * hidden_size`` columns
    are additionally split into Q/K/V blocks and then into attention heads.

    Args:
        params: nested mapping of parameters.  Plain ``dict`` and any other
            ``collections.abc.Mapping`` (e.g. a Flax ``FrozenDict``) are
            traversed.
        num_heads: number of attention heads.
        hidden_size: model width; ``head_dim`` is ``hidden_size // num_heads``.
        sample_size: maximum number of elements passed to metrics that do not
            require the full matrix; larger tensors are subsampled without
            replacement.  A falsy value disables subsampling.
        seed: seed of the random generator used for subsampling, so that
            repeated runs over the same parameters give identical results.
    """

    def __init__(self, params, num_heads, hidden_size,
                 sample_size=100_000, seed=0):
        self.params = params
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads
        self.sample_size = sample_size
        # Dedicated generator: reproducible and independent of global state.
        self._rng = np.random.default_rng(seed)

    # ---------------------------------
    # Flatten Flax param tree
    # ---------------------------------
    def _flatten_flax_dict(self, d, parent_key=""):
        """Flatten a nested mapping into ``{"a.b.c": leaf}``.

        Any ``Mapping`` value is recursed into, not only ``dict``; otherwise
        an immutable mapping would be mistaken for a leaf and its kernels
        would never be visited.
        """
        items = {}
        for k, v in d.items():
            new_key = f"{parent_key}.{k}" if parent_key else k
            if isinstance(v, Mapping):
                items.update(self._flatten_flax_dict(v, new_key))
            else:
                items[new_key] = v
        return items

    # ---------------------------------
    # Extract layer index
    # ---------------------------------
    def _extract_layer_id(self, key):
        """Return the first all-digit component of the dotted ``key`` as int, or ``None``."""
        for part in key.split("."):
            if part.isdigit():
                return int(part)
        return None

    # ---------------------------------
    # Metric computation
    # ---------------------------------
    def _compute_metrics(self, tensor, selected_metrics):
        """Evaluate each metric in ``selected_metrics`` on ``tensor``.

        Metrics flagged ``requires_matrix`` receive ``tensor`` unchanged; all
        others receive its flattened (and, if larger than ``sample_size``,
        subsampled) values.

        Returns:
            Dict mapping metric name to a built-in float.

        Raises:
            KeyError: if a name is missing from ``METRIC_REGISTRY``.
        """
        flat = tensor.reshape(-1)

        if self.sample_size and flat.size > self.sample_size:
            idx = self._rng.choice(flat.size, self.sample_size, replace=False)
            flat = flat[idx]

        results = {}
        for metric_name in selected_metrics:
            metric = METRIC_REGISTRY[metric_name]
            data = tensor if metric.requires_matrix else flat
            results[metric_name] = float(metric.compute_fn(data))

        return results

    def _iter_heads(self, w):
        """Yield ``(projection_name, head_index, head_matrix)`` for a fused QKV kernel."""
        # The fused kernel is laid out column-wise as [Q | K | V].
        for proj_name, proj_matrix in zip(("Q", "K", "V"), np.split(w, 3, axis=1)):
            per_head = proj_matrix.reshape(
                self.hidden_size,
                self.num_heads,
                self.head_dim
            )
            for head in range(self.num_heads):
                yield proj_name, head, per_head[:, head, :]

    # ---------------------------------
    # MAIN COMPUTE
    # ---------------------------------
    def compute(self, selected_metrics):
        """Compute layer-wise and head-wise statistics for ``self.params``.

        Args:
            selected_metrics: iterable of metric names registered in
                ``METRIC_REGISTRY``.

        Returns:
            ``{"layer_stats": {layer_id: {param_path: {metric: value}}},
            "head_stats": {layer_id: {"Q"|"K"|"V": {"head_<i>": {metric: value}}}}}``
            (the nested containers are ``defaultdict`` instances).

        Raises:
            KeyError: if any requested metric name is not registered.
        """
        # Materialise once: the selection is iterated for every tensor, which
        # would exhaust a one-shot iterable after the first parameter.
        selected_metrics = list(selected_metrics)
        unknown = [name for name in selected_metrics if name not in METRIC_REGISTRY]
        if unknown:
            raise KeyError(
                f"Unknown metric(s) {unknown}; registered metrics: {sorted(METRIC_REGISTRY)}"
            )

        flat_params = self._flatten_flax_dict(self.params)

        results = {
            "layer_stats": defaultdict(dict),
            "head_stats": defaultdict(lambda: defaultdict(dict))
        }

        for key, tensor in flat_params.items():

            if "kernel" not in key:
                continue

            layer_id = self._extract_layer_id(key)
            if layer_id is None:
                continue

            w = np.array(tensor)

            # Layer-wise distribution over the whole kernel.
            results["layer_stats"][layer_id][key] = self._compute_metrics(
                w, selected_metrics
            )

            # Head-wise distribution: only for fused attention projections.
            if "c_attn" not in key:
                continue

            if w.ndim != 2 or w.shape[1] != 3 * self.hidden_size:
                continue

            for proj_name, head, head_tensor in self._iter_heads(w):
                results["head_stats"][layer_id][proj_name][
                    f"head_{head}"
                ] = self._compute_metrics(head_tensor, selected_metrics)

        return results

In [15]:
import numpy as np
import wandb


class LLMDistributionVisualizer:
    """Log layer-wise and head-wise weight-statistic charts to Weights & Biases.

    ``weight_stats`` is read as
    ``{"layer_stats": {layer_id: {param_path: {metric: value}}},
    "head_stats": {layer_id: {projection: {"head_<i>": {metric: value}}}}}``.

    Every ``plot_*`` method calls ``wandb.log`` unconditionally, so a wandb
    run must be active whenever one of them executes; ``use_wandb`` only
    controls whether the constructor calls ``wandb.init``.

    Args:
        weight_stats: statistics dictionary in the layout above (or ``None``,
            in which case ``render`` does nothing).
        log_scale: stored on the instance; no method of this class reads it.
        use_wandb: when true, ``wandb.init`` is called in the constructor.
        project: wandb project name; defaults to ``"llm-distribution"``.
    """

    # Projection order used for the per-head line plots emitted by render().
    PROJECTIONS = ("Q", "K", "V")

    def __init__(self, weight_stats=None,
                 log_scale=False, use_wandb=False, project=None):

        self.weight_stats = weight_stats
        self.log_scale = log_scale
        self.use_wandb = use_wandb

        if use_wandb:
            wandb.init(project=project or "llm-distribution")

    @staticmethod
    def _log_chart(chart_fn, key, columns, rows, x, y, title):
        """Build a ``wandb.Table`` from ``rows`` and log ``chart_fn`` of it under ``key``."""
        table = wandb.Table(columns=columns)
        for row in rows:
            table.add_data(*row)
        wandb.log({key: chart_fn(table, x, y, title=title)})

    # --------------------------------------------------
    # LAYER-WISE METRICS
    # --------------------------------------------------
    def plot_weight_metric_layerwise(self, metric_name):
        """Log a line chart of ``metric_name`` averaged over each layer's tensors.

        Layers for which no tensor carries ``metric_name`` are left out of
        the chart instead of contributing a NaN point.
        """
        layer_stats = self.weight_stats["layer_stats"]

        rows = []
        for layer in sorted(layer_stats.keys()):
            vals = [
                tensor_metrics[metric_name]
                for tensor_metrics in layer_stats[layer].values()
                if metric_name in tensor_metrics
            ]
            if not vals:
                # np.mean([]) would be NaN and emit a RuntimeWarning.
                continue
            rows.append((layer, float(np.mean(vals))))

        chart_name = f"Layer vs {metric_name}"
        self._log_chart(
            wandb.plot.line,
            key=chart_name,
            columns=["Layer", metric_name],
            rows=rows,
            x="Layer",
            y=metric_name,
            title=chart_name,
        )

    # --------------------------------------------------
    # HEAD-WISE METRICS
    # --------------------------------------------------
    def plot_weight_metric_headwise(self, metric_name, layer_id):
        """Log a bar chart of ``metric_name`` per head for all projections of ``layer_id``.

        Returns without logging when ``layer_id`` has no head statistics.
        """
        head_stats = self.weight_stats["head_stats"]

        if layer_id not in head_stats:
            return

        rows = [
            (head, head_metrics[metric_name], proj)
            for proj, heads in head_stats[layer_id].items()  # Q, K, V
            for head, head_metrics in heads.items()
            if metric_name in head_metrics
        ]

        chart_name = f"Layer {layer_id} Head Distribution ({metric_name})"
        self._log_chart(
            wandb.plot.bar,
            key=chart_name,
            columns=["Head", metric_name, "Projection"],
            rows=rows,
            x="Head",
            y=metric_name,
            title=chart_name,
        )

    def plot_projection_metric_headwise_line(self, metric_name, layer_id, projection="Q"):
        """Log a line chart of ``metric_name`` versus head index for one projection.

        Returns without logging when ``layer_id`` or ``projection`` has no
        head statistics.
        """
        head_stats = self.weight_stats["head_stats"]

        if layer_id not in head_stats:
            return

        if projection not in head_stats[layer_id]:
            return

        per_head = head_stats[layer_id][projection]

        # Keys look like "head_<i>"; sort numerically so head_10 follows head_9.
        indexed = sorted(
            (int(head.split("_")[1]), head_metrics)
            for head, head_metrics in per_head.items()
        )
        rows = [
            (head_idx, head_metrics[metric_name])
            for head_idx, head_metrics in indexed
            if metric_name in head_metrics
        ]

        self._log_chart(
            wandb.plot.line,
            key=f"Layer {layer_id} {projection} Head Line Plot ({metric_name})",
            columns=["Head_Index", metric_name],
            rows=rows,
            x="Head_Index",
            y=metric_name,
            title=f"Layer {layer_id} {projection} Heads ({metric_name})",
        )

    # --------------------------------------------------
    # RENDER
    # --------------------------------------------------
    def render(self, weight_metrics=None,
               head_layers=None):
        """Log all charts for the given metrics.

        Args:
            weight_metrics: iterable of metric names; one layer-wise chart is
                logged per metric.
            head_layers: iterable of layer ids; for each metric and each of
                these layers one per-head line chart is logged for Q, K and V.

        Nothing is logged when ``weight_stats`` or ``weight_metrics`` is
        empty or ``None``.
        """
        # Materialise: both collections are iterated more than once.
        weight_metrics = list(weight_metrics) if weight_metrics else []
        head_layers = list(head_layers) if head_layers else []

        if not (self.weight_stats and weight_metrics):
            return

        for m in weight_metrics:
            self.plot_weight_metric_layerwise(m)

        for m in weight_metrics:
            for layer_id in head_layers:
                for projection in self.PROJECTIONS:
                    self.plot_projection_metric_headwise_line(
                        m, layer_id, projection=projection
                    )


In [16]:
def visualize_model_distributions(
    params,
    num_heads,
    hidden_size,
    stats=("kurtosis", "std"),
    log_scale=False,
    head_layers=None,          # layers that get per-head plots
    sample_size=100_000
):
    """Compute weight statistics for ``params`` and log the charts to wandb.

    The statistics are produced by ``WeightStatisticsEngine.compute`` and
    rendered by ``LLMDistributionVisualizer`` (always created with
    ``use_wandb=True``):
        1) one layer-wise chart per entry of ``stats``
        2) per-head charts for every layer listed in ``head_layers``

    Args:
        params: Flax model params
        num_heads: number of attention heads
        hidden_size: model hidden dimension
        stats: names of the metrics to compute and plot
        log_scale: forwarded to ``LLMDistributionVisualizer``
        head_layers: list of layer indices for head-level plots
        sample_size: forwarded to ``WeightStatisticsEngine``
    """
    # Statistics over the parameter tree.
    computed_stats = WeightStatisticsEngine(
        params=params,
        num_heads=num_heads,
        hidden_size=hidden_size,
        sample_size=sample_size,
    ).compute(stats)

    # Charts for the computed statistics.
    visualizer = LLMDistributionVisualizer(
        weight_stats=computed_stats,
        log_scale=log_scale,
        use_wandb=True,
    )
    visualizer.render(weight_metrics=stats, head_layers=head_layers)

In [17]:
# Presumably the base directory of the training checkpoints on Kaggle.
# NOTE(review): no live code in the visible cells reads this variable —
# confirm whether it is still needed.
checkpoint_base = "/kaggle/input/notebooks/pankajkumar2002/gpt-flax-openweeb/gpt-jax/out/checkpoints/train_state/"

In [18]:
# Analyse one checkpoint: restore it, then log layer-wise statistics and
# per-head statistics for layers 0-11.
# TODO: iterate over every checkpoint under `checkpoint_base` instead of a
# single hard-coded one.
# NOTE: despite its name, `relative_path` is an absolute, machine-specific path.
relative_path = "/home/batman/git/gpt-jax/out/checkpoints/checkpoint_101000"
train_state = load_checkpoint(config, relative_path)
params = train_state.params
visualize_model_distributions(
    params,
    num_heads=num_heads,
    hidden_size=hidden_size,
    stats=["kurtosis", "std", "spectral_norm"],
    log_scale=True,
    head_layers=list(range(12)),
)

=== Source Location Trace: ===
learning/45eac/tfrc/runtime/libtpu_init_utils.cc:310


Loaded step: 101000


[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /home/batman/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mimpankaj[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
