In [None]:
# Cell 1: Initialize ViT-B/16 from scratch, optionally convert LN -> DynamicTanh (DyT)

import os
import random
from types import SimpleNamespace

import torch
from timm.models import create_model

from datasets import build_dataset  # uses repo's timm create_transform + ImageNet-style norms
from dynamic_tanh import convert_ln_to_dyt

# --- user knobs ---
MODEL_NAME = "vit_base_patch16_224"   # timm model name
USE_DYT = True                        # toggle DyT on/off
BATCH_SIZE = 4
DATA_SET = "IMNET"                    # 'IMNET' | 'CIFAR' | 'image_folder'
DATA_PATH = os.environ.get("IMAGENET_PATH", "/path/to/imagenet")  # expects train/val subfolders for IMNET
EVAL_DATA_PATH = None                 # only used for data_set='image_folder'
SEED = 0                              # makes model init + batch sampling reproducible on Restart & Run All

# The classifier head must cover the dataset's label space ('CIFAR' is CIFAR-100 in this repo,
# cf. the training cell). For 'image_folder', set this to the number of class subfolders.
NB_CLASSES = {"IMNET": 1000, "CIFAR": 100}.get(DATA_SET, 1000)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed before building the model: weight init and the DataLoader shuffle both draw from
# torch's global RNG; Python's `random` is seeded too for any augmentation that uses it.
random.seed(SEED)
torch.manual_seed(SEED)

# match repo init style (main.py): create_model(..., pretrained=False, num_classes, global_pool='avg', drop_path_rate)
model = create_model(
    MODEL_NAME,
    pretrained=False,
    num_classes=NB_CLASSES,
    global_pool="avg",
    drop_path_rate=0.0,
)

if USE_DYT:
    # swap LayerNorm modules for DynamicTanh (per the cell header: LN -> DyT)
    model = convert_ln_to_dyt(model)

model = model.to(DEVICE)
model


In [None]:
# Cell 2: Build dataset with the repo's transforms and get one batch (B=BATCH_SIZE)

# Minimal args needed by datasets.build_dataset/build_transform
args = SimpleNamespace(
    # dataset selection
    data_set=DATA_SET,
    data_path=DATA_PATH,
    eval_data_path=EVAL_DATA_PATH,
    # Tie the expected class count to the classifier head built in Cell 1 instead of
    # hard-coding 1000, which is wrong for DATA_SET='CIFAR' (100 classes) / 'image_folder'.
    nb_classes=int(model.num_classes),
    # transform params (match defaults in main.py)
    input_size=224,
    imagenet_default_mean_and_std=True,
    color_jitter=0.4,
    aa="rand-m9-mstd0.5-inc1",
    train_interpolation="bicubic",
    reprob=0.25,
    remode="pixel",
    recount=1,
    crop_pct=None,
)

# This uses the same transform codepath as training in this repo
train_dataset, nb_classes = build_dataset(is_train=True, args=args)

# Every target must index into the model's logits, otherwise CrossEntropyLoss in Cell 8 fails.
assert nb_classes <= model.num_classes, (
    f"dataset has {nb_classes} classes but the model head only has {model.num_classes} outputs"
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=False,
    drop_last=True,
)

samples, targets = next(iter(train_loader))
print("samples:", samples.shape, samples.dtype)
print("targets:", targets.shape, targets.dtype)
(samples, targets)


In [None]:
# Cell 3: Forward pass (same structure as repo training/validation)
#
# engine.py does `samples = samples.to(device); output = model(samples)`. This cell mirrors
# that with autograd disabled, as a quick check of the logits' shape, dtype and device.

model.eval()

with torch.set_grad_enabled(False):
    samples_device = samples.to(DEVICE, non_blocking=False)
    outputs = model(samples_device)

print(f"outputs: {outputs.shape} {outputs.dtype} device: {outputs.device}")
outputs


In [None]:
# Cell 4: Inspect patch embeddings, positional embeddings, and their combination (matches timm logic)
#
# Step-by-step re-implementation of timm's VisionTransformer._pos_embed so every intermediate
# tensor can be inspected; the result is then cross-checked against model._pos_embed itself.

from timm.layers import resample_abs_pos_embed

model.eval()

with torch.no_grad():
    x_img = samples.to(DEVICE, non_blocking=False)

    # 1) Patch embeddings (same as: x = self.patch_embed(x))
    x_patch = model.patch_embed(x_img)

    if model.pos_embed is None:
        # timm returns the flattened patch tokens untouched: no prefix tokens, no pos add, no pos_drop
        pos_embed_used = None
        x_tokens = x_patch.view(x_patch.shape[0], -1, x_patch.shape[-1])
    else:
        # 2) Positional embedding actually used (same as inside VisionTransformer._pos_embed)
        if getattr(model, "dynamic_img_size", False):
            # dynamic_img_size: patch_embed emits NHWC, so resample pos_embed to the runtime
            # (H, W) grid and flatten NHWC -> NLC
            B, H, W, C = x_patch.shape
            pos_embed_used = resample_abs_pos_embed(
                model.pos_embed,
                new_size=(H, W),
                old_size=model.patch_embed.grid_size,
                num_prefix_tokens=0 if model.no_embed_class else model.num_prefix_tokens,
            )
            x_tokens = x_patch.view(B, -1, C)
        else:
            pos_embed_used = model.pos_embed
            x_tokens = x_patch

        # 3) Prefix tokens (CLS first, then registers), broadcast over the batch
        to_cat = []
        if getattr(model, "cls_token", None) is not None:
            to_cat.append(model.cls_token.expand(x_tokens.shape[0], -1, -1))
        if getattr(model, "reg_token", None) is not None:
            to_cat.append(model.reg_token.expand(x_tokens.shape[0], -1, -1))

        if model.no_embed_class:
            # pos_embed covers patch tokens only: add first, then prepend prefix tokens
            x_tokens = x_tokens + pos_embed_used
            if to_cat:
                x_tokens = torch.cat(to_cat + [x_tokens], dim=1)
        else:
            # pos_embed includes prefix positions: prepend prefix tokens first, then add
            if to_cat:
                x_tokens = torch.cat(to_cat + [x_tokens], dim=1)
            x_tokens = x_tokens + pos_embed_used

        # timm returns self.pos_drop(x)
        x_tokens = model.pos_drop(x_tokens)

    # Reference result from timm itself. Dropout is a no-op in eval mode, so both must agree.
    # Reported instead of asserted because _pos_embed is private API and may change.
    x_tokens_ref = model._pos_embed(x_patch) if hasattr(model, "_pos_embed") else None

# Report shapes
print("x_patch shape:", tuple(x_patch.shape))
print("pos_embed_used shape:", None if pos_embed_used is None else tuple(pos_embed_used.shape))
print("x_tokens (after pos add + prefix concat + pos_drop) shape:", tuple(x_tokens.shape))
if isinstance(x_tokens_ref, torch.Tensor):
    _matches = x_tokens_ref.shape == x_tokens.shape and torch.allclose(x_tokens_ref, x_tokens)
    print("matches model._pos_embed:", _matches)

# Expose all three
x_patch, pos_embed_used, x_tokens


In [None]:
# Cell 5: Re-initialize all DyT alphas to a new alpha_init_value

from dynamic_tanh import DynamicTanh

NEW_ALPHA_INIT_VALUE = 0.8  # <-- set whatever you want

# Overwrite every DynamicTanh's learnable `alpha` in place, and update its stored
# `alpha_init_value` attribute so it agrees with the value actually in the parameter.
num_dyt = 0
with torch.no_grad():
    for module in model.modules():
        if not isinstance(module, DynamicTanh):
            continue
        module.alpha_init_value = float(NEW_ALPHA_INIT_VALUE)
        module.alpha.data.fill_(float(NEW_ALPHA_INIT_VALUE))
        num_dyt += 1

print(f"Reinitialized {num_dyt} DynamicTanh modules with alpha={NEW_ALPHA_INIT_VALUE}")

# Sanity check: read the alphas back (module order) and show the first few.
alphas = []
for module in model.modules():
    if isinstance(module, DynamicTanh):
        alphas.append(module.alpha.item())
print("first 5 alphas:", alphas[:5])


In [None]:
# Cell 6: Cache inputs to each DynamicTanh in forward order + cosine similarity/angle vs previous

import math
import torch
import matplotlib.pyplot as plt

from dynamic_tanh import DynamicTanh

model.eval()

# Leading prefix tokens (CLS, plus any register tokens) are excluded from token-level
# comparisons and from pooling, like timm's avg pool which drops num_prefix_tokens.
_num_prefix = int(getattr(model, "num_prefix_tokens", 1))

# We'll store DyT inputs in the order they're executed
_dyt_inputs = []          # list[torch.Tensor] on CPU, each [B, S, C]
_dyt_names = []           # list[str]
_dyt_kinds = []           # list[str]  'tokens' or 'pooled'
_handles = []


def _make_dyt_pre_hook(name: str):
    """Return a forward-pre-hook that caches the first positional input of a DyT module on CPU.

    Rank-3 inputs are stored as-is ('tokens'); rank-2 inputs are stored as [B, 1, C]
    ('pooled') so every cached tensor is rank 3. Any other rank raises ValueError.
    """
    def _hook(mod, inputs):
        x = inputs[0]
        # token-level DyTs see [B, N, C]; fc_norm DyT sees [B, C]
        if x.ndim == 3:
            kind = "tokens"
            x_use = x.detach().cpu()
        elif x.ndim == 2:
            kind = "pooled"
            x_use = x.detach().cpu().unsqueeze(1)
        else:
            raise ValueError(f"Unexpected DyT input ndim={x.ndim} for {name}")

        _dyt_inputs.append(x_use)
        _dyt_names.append(name)
        _dyt_kinds.append(kind)

    return _hook


for name, m in model.named_modules():
    if isinstance(m, DynamicTanh):
        _handles.append(m.register_forward_pre_hook(_make_dyt_pre_hook(name)))

# Remove the hooks even if the forward (or a hook) raises; otherwise they stay attached
# and fire on every later forward pass in this notebook.
try:
    with torch.no_grad():
        _ = model(samples.to(DEVICE, non_blocking=False))
finally:
    for h in _handles:
        h.remove()

if not _dyt_inputs:
    raise RuntimeError(
        "No DynamicTanh inputs were captured; was the model converted to DyT (USE_DYT=True)?"
    )

# Compute cosine similarities vs previous DyT (prefix token positions excluded for token-level tensors)
EPS = 1e-8
B = _dyt_inputs[0].shape[0]
max_S = max(t.shape[1] for t in _dyt_inputs)  # includes prefix for token-level; pooled has S=1
L = len(_dyt_inputs)

# We'll store padded [B, L, max_S] (NaN where not applicable, including index 0)
cos_sims = torch.full((B, L, max_S), float('nan'))
angles = torch.full((B, L, max_S), float('nan'))

# Also keep per-layer unpadded dicts (keyed by DyT execution index)
cos_sims_by_dyt = {}
angles_by_dyt = {}

for i in range(1, L):
    cur = _dyt_inputs[i]      # [B, S_cur, C]
    prev = _dyt_inputs[i - 1]

    if _dyt_kinds[i] == "pooled":
        # Current is pooled (S=1): avg-pool previous token activations (prefix excluded) to match
        if prev.shape[1] > 1:
            prev_vec = prev[:, _num_prefix:, :].mean(dim=1, keepdim=True)  # [B, 1, C]
        else:
            prev_vec = prev
        cur_vec = cur  # already [B, 1, C]
    else:
        # token-level: both inputs carry the prefix tokens first; drop them and align grids
        cur_vec = cur[:, _num_prefix:, :]
        prev_vec = prev[:, _num_prefix:, :] if prev.shape[1] > 1 else prev.expand_as(cur_vec)

        # in case something changes shape, trim to min token length
        S_min = min(cur_vec.shape[1], prev_vec.shape[1])
        cur_vec = cur_vec[:, :S_min, :]
        prev_vec = prev_vec[:, :S_min, :]

    # cosine over embedding dim -> [B, S]
    dot = (cur_vec * prev_vec).sum(dim=-1)
    denom = cur_vec.norm(dim=-1) * prev_vec.norm(dim=-1) + EPS
    c = (dot / denom).clamp(-1.0, 1.0)
    a = torch.acos(c)

    # store
    S = c.shape[1]
    cos_sims[:, i, :S] = c
    angles[:, i, :S] = a
    cos_sims_by_dyt[i] = c
    angles_by_dyt[i] = a

# Plot: average over batch and sequence (excluding NaNs)
mean_cos = torch.nanmean(cos_sims, dim=(0, 2)).numpy()   # [L]
mean_ang = torch.nanmean(angles, dim=(0, 2)).numpy()     # [L]

plt.figure(figsize=(8, 3))
plt.plot(mean_cos, marker='o')
plt.title('Average cosine similarity vs DyT execution index')
plt.xlabel('DyT idx')
plt.ylabel('avg cosine')
plt.grid(True)
plt.show()

plt.figure(figsize=(8, 3))
plt.plot(mean_ang, marker='o')
plt.title('Average angle (arccos) vs DyT execution index')
plt.xlabel('DyT idx')
plt.ylabel('avg angle (radians)')
plt.grid(True)
plt.show()

print(f"Captured {L} DynamicTanh inputs")
print("First few DyTs:")
for i in range(min(5, L)):
    print(i, _dyt_names[i], _dyt_kinds[i], tuple(_dyt_inputs[i].shape))

# Expose arrays + dicts
cos_sims, angles, cos_sims_by_dyt, angles_by_dyt, _dyt_names, _dyt_kinds


In [None]:
# Cell 7: Cache output of each transformer block and plot average activation norm vs block

import torch
import matplotlib.pyplot as plt

model.eval()

_block_outputs = []   # list[torch.Tensor] on CPU, each [B, N, C]
_block_names = []
_handles = []

# Leading prefix tokens (CLS, plus any register tokens) to drop before averaging norms.
_num_prefix = int(getattr(model, "num_prefix_tokens", 1))


def _make_block_hook(name: str):
    """Return a forward hook that stores a detached CPU copy of the block's output under `name`."""
    def _hook(mod, inputs, output):
        # output is [B, N, C]
        _block_outputs.append(output.detach().cpu())
        _block_names.append(name)
    return _hook


for i, blk in enumerate(model.blocks):
    _handles.append(blk.register_forward_hook(_make_block_hook(f"blocks.{i}")))

# Remove the hooks even if the forward pass fails, so they never leak into later cells.
try:
    with torch.no_grad():
        _ = model(samples.to(DEVICE, non_blocking=False))
finally:
    for h in _handles:
        h.remove()

# Compute average token-vector norm per block, excluding prefix tokens
avg_norms = []
for out in _block_outputs:
    out_tok = out[:, _num_prefix:, :]  # [B, S, C]
    token_norm = out_tok.norm(dim=-1)  # [B, S]
    avg_norms.append(token_norm.mean().item())

plt.figure(figsize=(8, 3))
plt.plot(avg_norms, marker='o')
plt.title('Average activation norm vs transformer block')
plt.xlabel('block idx')
plt.ylabel('avg ||x||')
plt.grid(True)
plt.show()

avg_norms, _block_names


In [None]:
# Cell 8: Backward pass (training-style) + cache grads of each block output; plot grad norms

import torch
import matplotlib.pyplot as plt

model.train()  # training-style backward; note the model is left in train mode after this cell

# Capture block outputs (with gradients)
_block_out_refs = []
_handles = []

# Leading prefix tokens (CLS, plus any register tokens) are excluded from per-token grad norms.
_num_prefix = int(getattr(model, "num_prefix_tokens", 1))


def _make_block_hook_retain():
    """Return a forward hook that keeps each block output (graph attached) and calls
    retain_grad() on it so autograd populates its `.grad` during backward."""
    def _hook(mod, inputs, output):
        output.retain_grad()
        _block_out_refs.append(output)
    return _hook


for blk in model.blocks:
    _handles.append(blk.register_forward_hook(_make_block_hook_retain()))

# Forward + loss + backward (same structure as engine.py full precision)
criterion = torch.nn.CrossEntropyLoss()

model.zero_grad(set_to_none=True)

samples_device = samples.to(DEVICE, non_blocking=False)
targets_device = targets.to(DEVICE, non_blocking=False)

# The hooks call retain_grad(), which raises on tensors that do not require grad (e.g. any
# later forward under torch.no_grad()), so they must be removed even if this step fails.
try:
    outputs = model(samples_device)
    loss = criterion(outputs, targets_device)
    loss.backward()
finally:
    for h in _handles:
        h.remove()

# Compute average grad norm per block output (prefix tokens excluded)
avg_grad_norms = []
avg_log_grad_norms = []

for idx, out in enumerate(_block_out_refs):
    if out.grad is None:
        raise RuntimeError(f"block output {idx} received no gradient; is it disconnected from the loss?")
    g_tok = out.grad.detach().cpu()[:, _num_prefix:, :]   # [B, S, C]
    token_gnorm = g_tok.norm(dim=-1)                        # [B, S]

    avg_grad_norms.append(token_gnorm.mean().item())
    # log(grad_norm) with small epsilon for stability
    avg_log_grad_norms.append(torch.log(token_gnorm + 1e-12).mean().item())

plt.figure(figsize=(8, 3))
plt.plot(avg_grad_norms, marker='o')
plt.title('Average gradient norm vs transformer block output')
plt.xlabel('block idx')
plt.ylabel('avg ||dL/dx||')
plt.grid(True)
plt.show()

plt.figure(figsize=(8, 3))
plt.plot(avg_log_grad_norms, marker='o')
plt.title('Average log gradient norm vs transformer block output')
plt.xlabel('block idx')
plt.ylabel('avg log(||dL/dx||)')
plt.grid(True)
plt.show()

loss.item(), avg_grad_norms, avg_log_grad_norms


In [None]:
# Cell 11: Train via main.py (no re-implementation of training loop)

from pathlib import Path

import main as dyt_main

# --- important knobs ---
# Model/dataset/batch knobs carry a TRAIN_ prefix (and the parsed args are `train_args`) so
# running this cell does not overwrite MODEL_NAME / DATA_SET / DATA_PATH / BATCH_SIZE / args
# from Cells 1-2, which would silently change the analysis if those cells are re-run later.
TRAIN_MODEL_NAME = "vit_base_patch16_224"
EPOCHS = 10

TRAIN_DATA_SET = "CIFAR"
TRAIN_DATA_PATH = "/tmp/cifar100"      # CIFAR will download here
TRAIN_NB_CLASSES = 100
INPUT_SIZE = 224

DYNAMIC_TANH = True
DYT_ALPHA_INIT_VALUE = 0.8

# optimizer / schedule knobs
TRAIN_BATCH_SIZE = 64
LR = 4e-3
MIN_LR = 1e-6
WEIGHT_DECAY = 0.05
WARMUP_EPOCHS = 0                # for 10-epoch runs, warmup=0 is usually more sensible than 20
WARMUP_STEPS = -1

# checkpoint/log dirs
OUTPUT_DIR = Path("/tmp/vit_dyt_main_cifar100")
LOG_DIR = OUTPUT_DIR / "tb"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

# --- run ---
# Build args via the repo's argparse, then call main.main(args)
parser = dyt_main.get_args_parser()
train_args = parser.parse_args([
    "--model", TRAIN_MODEL_NAME,
    "--epochs", str(EPOCHS),
    "--batch_size", str(TRAIN_BATCH_SIZE),
    "--lr", str(LR),
    "--min_lr", str(MIN_LR),
    "--weight_decay", str(WEIGHT_DECAY),
    "--warmup_epochs", str(WARMUP_EPOCHS),
    "--warmup_steps", str(WARMUP_STEPS),

    "--data_set", TRAIN_DATA_SET,
    "--data_path", TRAIN_DATA_PATH,
    "--nb_classes", str(TRAIN_NB_CLASSES),
    "--input_size", str(INPUT_SIZE),

    "--output_dir", str(OUTPUT_DIR),
    "--log_dir", str(LOG_DIR),

    "--dynamic_tanh", "true" if DYNAMIC_TANH else "false",
    "--dyt_alpha_init_value", str(DYT_ALPHA_INIT_VALUE),

    # Intended: save only init + final checkpoints, no best checkpoints.
    # Presumes main.py saves periodically every save_ckpt_freq epochs (so only the last
    # epoch here) and keeps save_ckpt_num of them -- confirm against main.py.
    "--save_ckpt", "true",
    "--save_init_ckpt", "true",
    "--save_best_ckpt", "false",
    "--save_best_ema_ckpt", "false",
    "--save_ckpt_freq", str(EPOCHS),
    "--save_ckpt_num", "1",

    # keep W&B off by default
    "--enable_wandb", "false",
])

# Single-process notebook run (no torchrun)
dyt_main.main(train_args)

print("Done. Logs + checkpoints are in:", OUTPUT_DIR)
print("- log.txt:", OUTPUT_DIR / "log.txt")
print("- init ckpt:", OUTPUT_DIR / "checkpoint-init.pth")
print("- final ckpt:", OUTPUT_DIR / f"checkpoint-{EPOCHS - 1}.pth")


In [None]:
# Cell 12: Plot training loss + validation accuracy from main.py's output_dir/log.txt

from pathlib import Path
import json

import matplotlib.pyplot as plt

OUTPUT_DIR = Path("/tmp/vit_dyt_main_cifar100")  # must match Cell 11
log_path = OUTPUT_DIR / "log.txt"

if not log_path.exists():
    raise FileNotFoundError(f"{log_path} not found; run Cell 11 first (or fix OUTPUT_DIR)")

# log.txt holds one JSON record per line with an integer "epoch". Training into an existing
# OUTPUT_DIR adds more records, so keep only the most recent record per epoch (otherwise the
# curves interleave runs) and skip blank lines instead of failing in json.loads.
_records_by_epoch = {}
with open(log_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        rec = json.loads(line)
        _records_by_epoch[int(rec["epoch"])] = rec

_epoch_ids = sorted(_records_by_epoch)
epochs = [e + 1 for e in _epoch_ids]  # 1-based epochs for plotting
train_loss = [_records_by_epoch[e].get("train_loss") for e in _epoch_ids]
val_acc1 = [_records_by_epoch[e].get("test_acc1") for e in _epoch_ids]

fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(epochs, train_loss, marker='o')
ax.set_title('Train loss vs epoch')
ax.set_xlabel('epoch')
ax.set_ylabel('train loss')
ax.grid(True)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "train_loss_curve.png", dpi=150)
plt.show()

fig, ax = plt.subplots(figsize=(7, 3))
ax.plot(epochs, val_acc1, marker='o')
ax.set_title('Validation Acc@1 vs epoch')
ax.set_xlabel('epoch')
ax.set_ylabel('val acc@1')
ax.grid(True)
fig.tight_layout()
fig.savefig(OUTPUT_DIR / "val_acc1_curve.png", dpi=150)
plt.show()

print("Saved plots to:")
print("-", OUTPUT_DIR / "train_loss_curve.png")
print("-", OUTPUT_DIR / "val_acc1_curve.png")

train_loss, val_acc1
