# Dion Optimizer Test Notebook

In [1]:
import torch
import torch.nn as nn
import math
from dion import Dion
from tqdm.auto import tqdm

In [2]:
# Set device
# The run is pinned to the CPU. The commented-out line below is the alternative
# that picks CUDA when it is available; swap it in to run on a GPU.
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = "cpu"
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Define a simple TransformerBlock
class TransformerBlock(nn.Module):
    """Post-norm transformer layer: self-attention followed by a two-layer MLP.

    Each sub-layer is wrapped in a residual connection and followed by a
    LayerNorm. No attention mask is applied.

    Args:
        model_dim: Size of the feature dimension of the input and output.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward sub-layer.
    """

    def __init__(self, model_dim, num_heads, ff_dim):
        super().__init__()
        # batch_first=True: inputs are laid out as (batch, sequence, feature).
        self.attn = nn.MultiheadAttention(
            embed_dim=model_dim, num_heads=num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(model_dim)
        # Kept as an nn.Sequential so the parameter names stay ff.0.* / ff.2.*.
        feed_forward_layers = [
            nn.Linear(model_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, model_dim),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.norm2 = nn.LayerNorm(model_dim)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        # Self-attention: query, key and value are all the block input. The
        # second return value (attention weights) is not needed.
        attended = self.attn(x, x, x)[0]
        normed = self.norm1(x + attended)
        return self.norm2(normed + self.ff(normed))

# Define the main TransformerModel
class TransformerModel(nn.Module):
    """Token embedding, a stack of ``TransformerBlock`` layers, and a linear head.

    Args:
        vocab_dim: Number of distinct token ids; also the size of the output
            (logit) dimension.
        model_dim: Embedding / hidden feature size.
        num_heads: Attention heads per block.
        ff_dim: Hidden width of each block's feed-forward sub-layer.
        num_layers: Number of stacked blocks.
    """

    def __init__(self, vocab_dim, model_dim, num_heads, ff_dim, num_layers):
        super().__init__()
        self.model_dim = model_dim
        self.embedding = nn.Embedding(vocab_dim, model_dim)
        layers = [
            TransformerBlock(model_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ]
        self.blocks = nn.ModuleList(layers)
        self.lm_head = nn.Linear(model_dim, vocab_dim)

    def forward(self, x):
        """Map integer token ids to per-position logits over the vocabulary."""
        hidden = self.embedding(x)
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.lm_head(hidden)

In [4]:
# Hyperparameters
vocab_dim = 100
model_dim = 32
num_heads = 4
ff_dim = 64
num_layers = 2
lr = 1e-3
seed = 0

# Seed torch's global RNG so that weight initialisation (and later torch.rand*
# calls made in the same kernel session) are reproducible on Restart & Run All.
torch.manual_seed(seed)

# Instantiate the model
model = TransformerModel(vocab_dim, model_dim, num_heads, ff_dim, num_layers).to(device)
print("Model created.")

# Separate parameters into groups:
#   - 2-D weights inside the transformer blocks
#   - every non-2-D parameter inside the blocks (biases, LayerNorm scale/shift)
#   - the embedding table
#   - the LM head (weight and bias)
matrix_params = [p for n, p in model.blocks.named_parameters() if p.ndim == 2 and 'weight' in n]
vector_params = [p for n, p in model.blocks.named_parameters() if p.ndim != 2]
embed_params  = [p for p in model.embedding.parameters()]
lm_head_params= [p for p in model.lm_head.parameters()]

param_groups = [
    dict(params=matrix_params),  # will default to "dion" algorithm
    dict(params=vector_params, algorithm="lion"),
    dict(params=embed_params, algorithm="lion", weight_decay=0),
    # The LM head uses a learning rate scaled down by sqrt(model_dim).
    dict(params=lm_head_params, algorithm="lion", lr=lr / math.sqrt(model_dim), weight_decay=0)
]

print(f"Number of parameter groups: {len(param_groups)}")
print(f"Matrix params: {len(matrix_params)}")
print(f"Vector params: {len(vector_params)}")
print(f"Embedding params: {len(embed_params)}")
print(f"LM Head params: {len(lm_head_params)}")

# verify all parameters are included
total_param_count = sum(p.numel() for p in model.parameters())
included_param_count = sum(p.numel() for group in param_groups for p in group['params'])
assert total_param_count == included_param_count, "Some parameters are missing from the optimizer!"

# Element counts alone can be fooled: a tensor listed twice can offset a missing
# tensor of the same size. Also check by identity that every model parameter
# appears in exactly one group.
grouped_param_ids = [id(p) for group in param_groups for p in group['params']]
assert len(grouped_param_ids) == len(set(grouped_param_ids)), \
    "A parameter is assigned to more than one optimizer group!"
assert set(grouped_param_ids) == {id(p) for p in model.parameters()}, \
    "Some parameters are missing from the optimizer!"

Model created.
Number of parameter groups: 4
Matrix params: 8
Vector params: 16
Embedding params: 1
LM Head params: 2


In [5]:
# Instantiate the optimizer.
# The keyword arguments act as optimizer-wide defaults; a parameter group that
# sets its own `lr` or `weight_decay` keeps that value (see the printed groups).
optimizer = Dion(param_groups, lr=lr, weight_decay=0.1)

print("Dion optimizer created.")
print(optimizer)

Dion optimizer created.
Dion (
Parameter Group 0
    algorithm: dion
    beta1: 0.9
    beta2: 0.95
    epsilon: 1e-08
    lr: 0.001
    mu: 0.95
    oversample: 1.25
    rank_fraction: 1.0
    rank_multiple_of: 1
    step: 0
    weight_decay: 0.1

Parameter Group 1
    algorithm: lion
    beta1: 0.9
    beta2: 0.95
    epsilon: 1e-08
    lr: 0.001
    mu: 0.95
    oversample: 1.25
    rank_fraction: 1.0
    rank_multiple_of: 1
    step: 0
    weight_decay: 0.1

Parameter Group 2
    algorithm: lion
    beta1: 0.9
    beta2: 0.95
    epsilon: 1e-08
    lr: 0.001
    mu: 0.95
    oversample: 1.25
    rank_fraction: 1.0
    rank_multiple_of: 1
    step: 0
    weight_decay: 0

Parameter Group 3
    algorithm: lion
    beta1: 0.9
    beta2: 0.95
    epsilon: 1e-08
    lr: 0.00017677669529663688
    mu: 0.95
    oversample: 1.25
    rank_fraction: 1.0
    rank_multiple_of: 1
    step: 0
    weight_decay: 0
)


In [8]:
# Basic training loop: a smoke test that the optimizer runs end to end and
# actually modifies the model weights on every logged step.
batch_size = 4
seq_len = 32
num_steps = 100

criterion = nn.CrossEntropyLoss()

for step in range(num_steps):
    # Generate dummy data: token ids drawn uniformly at random. The targets are
    # the inputs shifted by one position (next-token prediction layout).
    data = torch.randint(0, vocab_dim, (batch_size, seq_len + 1), device=device)
    inputs = data[:, :-1]
    targets = data[:, 1:]

    # Forward pass
    outputs = model(inputs)

    # Flatten (batch, seq, vocab) logits and (batch, seq) targets for the loss
    loss = criterion(outputs.reshape(-1, vocab_dim), targets.reshape(-1))

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()

    # Snapshot the weights before the optimizer step. detach() keeps the copies
    # out of the autograd graph; they are only used for an equality check.
    model_state_before = {
        name: param.detach().clone() for name, param in model.named_parameters()
    }

    optimizer.step()

    # True as soon as any parameter differs from its pre-step snapshot
    weights_changed = any(
        not torch.equal(model_state_before[name], param.detach())
        for name, param in model.named_parameters()
    )

    if (step + 1) % 10 == 0:
        print(f"Step [{step+1}/{num_steps}], Loss: {loss.item():.4f}, Weights changed: {weights_changed}")

print("\nTraining test finished.")

Step [10/100], Loss: 4.7768, Weights changed: True
Step [20/100], Loss: 4.7174, Weights changed: True
Step [30/100], Loss: 4.6965, Weights changed: True
Step [40/100], Loss: 4.7253, Weights changed: True
Step [50/100], Loss: 4.6936, Weights changed: True
Step [60/100], Loss: 4.7259, Weights changed: True
Step [70/100], Loss: 4.7118, Weights changed: True
Step [80/100], Loss: 4.6719, Weights changed: True
Step [90/100], Loss: 4.6928, Weights changed: True
Step [100/100], Loss: 4.5925, Weights changed: True


Training test finished.