# Binary Group Composition on $D_n$ (Dihedral Group)

**Group:** Dihedral group $D_n$ of order $2n$ (rotations and reflections of a regular $n$-gon); this notebook uses $n = 5$, i.e. $D_5$ of order 10.  
**Task:** Given encodings of two group elements $g_1, g_2 \in D_n$, predict the encoding of their product $g_1 \cdot g_2$.  
**Sequence length:** $k = 2$ (binary composition).  
**Architecture:** `TwoLayerNet` with square nonlinearity.  
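**Key result:** The network learns the irreducible representations one at a time, so the training loss descends in a staircase of plateaus; below, the measured loss is compared against theoretically predicted plateau levels.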

## Imports

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from escnn.group import DihedralGroup
from torch.utils.data import DataLoader, TensorDataset

import src.dataset as dataset
import src.model as model
import src.optimizer as optimizer
import src.power as power
import src.template as template
import src.train as train_mod
import src.viz as viz

## Configuration

In [None]:
TEST_MODE = os.environ.get("NOTEBOOK_TEST_MODE", "0") == "1"

# Seed Python, NumPy and PyTorch RNGs so that runs are reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Group under study: D_n has 2n elements, so D_5 has order 10.
n = 5
group = DihedralGroup(n)
group_size = group.order()

# NOTEBOOK_TEST_MODE=1 shrinks the model and the run to a quick smoke test.
if TEST_MODE:
    hidden_size, epochs = 20, 2
else:
    hidden_size, epochs = 180, 2000
lr, init_scale = 0.01, 1e-3

# All figures produced below are written into this directory.
FIGURES_DIR = "figures"
os.makedirs(FIGURES_DIR, exist_ok=True)

print(f"Group: D_{n}, order {group_size}")

## Template and Dataset

In [None]:
# Build a template with known Fourier structure on D_n.
# fourier_coef_diag_values holds one value per irrep of D_n. D_n has two 1D
# irreps plus (n - 1) / 2 two-dimensional irreps when n is odd, and four 1D
# irreps plus (n - 2) / 2 two-dimensional irreps when n is even. For D_5 that
# is 4 irreps (two 1D, two 2D).
# The values below are written for n = 5, so check them against the configured
# n rather than letting a mismatch surface obscurely (or silently) downstream.
fourier_coef_diag_values = [0.0, 5.0, 30.0, 300.0]
n_expected_irreps = n // 2 + (3 if n % 2 == 0 else 2)
if len(fourier_coef_diag_values) != n_expected_irreps:
    raise ValueError(
        f"D_{n} has {n_expected_irreps} irreps, but "
        f"{len(fourier_coef_diag_values)} Fourier coefficients were given; "
        "update fourier_coef_diag_values to match n."
    )
tpl = template.fixed_group(group, fourier_coef_diag_values)

# Build exhaustive dataset: all group_size^2 pairs
X, Y = dataset.group_dataset(group, tpl)
X_tensor, Y_tensor, device = dataset.move_dataset_to_device_and_flatten(X, Y)

ds = TensorDataset(X_tensor, Y_tensor)
# batch_size=len(ds) puts the whole dataset in a single batch (full-batch training).
dataloader = DataLoader(ds, batch_size=len(ds), shuffle=False)

print(f"Dataset: {len(ds)} samples (all {group_size}x{group_size} pairs)")
print(f"X shape: {X_tensor.shape}, Y shape: {Y_tensor.shape}")

In [None]:
# Side-by-side view of the template (left) and its power per irrep (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: template value at every group element.
ax1.bar(range(group_size), tpl, color="black")
ax1.set(
    xlabel="Group element",
    ylabel="Template value",
    title=f"Template $t$ on $D_{{{n}}}$",
)

# Right panel: power spectrum of the template, one bar per irrep.
# gp is also used by later cells (plateau predictions, reference power levels).
gp = power.GroupPower(tpl, group)
pwr = gp.group_power_spectrum()
ax2.bar(range(len(pwr)), pwr, color="steelblue")
ax2.set(
    xlabel="Irrep index",
    ylabel="Power",
    title="Power spectrum (by irrep)",
)

plt.tight_layout()
plt.savefig(f"{FIGURES_DIR}/dihedral_template.pdf", bbox_inches="tight")
plt.show()

## Model and Optimizer

In [None]:
# Two-layer network with a square nonlinearity; init_scale is forwarded to the
# model constructor. nn.Module.to returns the module itself, so it can be chained.
net = model.TwoLayerNet(
    group_size=group_size,
    hidden_size=hidden_size,
    nonlinearity="square",
    init_scale=init_scale,
).to(device)

criterion = nn.MSELoss()

# Degree passed to the per-neuron scaled SGD; defined once so the optimizer and
# the printed summary cannot disagree.
# NOTE(review): 3 presumably reflects that, with a square nonlinearity, each
# neuron's output is cubic in its parameters (two input-layer factors, one
# output-layer factor) — confirm against src/optimizer.py.
optimizer_degree = 3
opt = optimizer.PerNeuronScaledSGD(net, lr=lr, degree=optimizer_degree)

print(f"Model: TwoLayerNet(group_size={group_size}, hidden={hidden_size}, init_scale={init_scale})")
print(f"Optimizer: PerNeuronScaledSGD(lr={lr}, degree={optimizer_degree})")
print(f"Training for {epochs} epochs")

## Training

In [None]:
# Run training. save_param_interval=1 requests a parameter snapshot every epoch;
# param_history is consumed later by the power-spectrum cell. verbose_interval
# presumably controls how often progress is reported (about 10 times per run).
train_result = train_mod.train(
    net, dataloader, criterion, opt,
    epochs=epochs,
    verbose_interval=max(1, epochs // 10),
    save_param_interval=1,
)
(
    loss_history,
    val_loss_history,
    param_history,
    param_save_epochs,
    final_epoch,
) = train_result

## Training Loss

In [None]:
# Theoretical plateau levels of the loss staircase, drawn as dashed guides.
theory_levels = gp.loss_plateau_predictions()

axis_label_kw = dict(fontsize=18)
plateau_line_kw = dict(color="black", linestyle="--", linewidth=2, zorder=-2)

fig, ax = plt.subplots(figsize=(8, 6))
# NOTE(review): x is the 0-based list index, and a log x-axis cannot show x = 0,
# so the very first recorded loss is not visible on this plot.
ax.plot(loss_history, lw=4)
for level in theory_levels:
    ax.axhline(y=level, **plateau_line_kw)

# Log-log axes make the plateaus and the drops between them visible.
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel("Epochs", **axis_label_kw)
ax.set_ylabel("Train Loss", **axis_label_kw)
ax.set_title(f"Training loss on $D_{{{n}}}$", fontsize=20)
viz.style_axes(ax)
ax.grid(False)  # set after style_axes so the final plot has no grid

plt.tight_layout()
plt.savefig(f"{FIGURES_DIR}/dihedral_loss.pdf", bbox_inches="tight")
plt.show()

## Power Spectrum Over Training

In [None]:
# Network power in each irrep over the saved parameter snapshots.
powers_over_time, power_steps = power.model_power_over_time(
    group_name="dihedral",
    model=net,
    param_history=param_history,
    model_inputs=X_tensor,
    group=group,
)

# Target levels: the template's own power in each irrep (dotted lines).
template_pwr = gp.group_power_spectrum()

colors = ["tab:blue", "tab:orange", "tab:red", "tab:green", "tab:brown", "tab:purple"]
reference_line_kw = dict(linestyle="dotted", linewidth=2, alpha=0.5, zorder=-10)

fig, ax = plt.subplots(figsize=(8, 6))

n_irreps = powers_over_time.shape[1]
for k in range(n_irreps):
    # Fixed palette for the first irreps; beyond it, use matplotlib's color cycle.
    if k < len(colors):
        color = colors[k]
    else:
        color = f"C{k}"
    ax.plot(power_steps, powers_over_time[:, k], color=color, lw=4, label=rf"$\rho_{{{k}}}$")
    ax.axhline(template_pwr[k], color=color, **reference_line_kw)

ax.set_xscale("log")
ax.set_ylabel("Power", fontsize=18)
ax.set_xlabel("Epochs", fontsize=18)
ax.set_title(f"Power spectrum over training on $D_{{{n}}}$", fontsize=20)
legend_kw = dict(fontsize=12, title="Irrep", title_fontsize=14, loc="upper left", labelspacing=0.25)
ax.legend(**legend_kw)
viz.style_axes(ax)
ax.grid(False)

plt.tight_layout()
plt.savefig(f"{FIGURES_DIR}/dihedral_power_spectrum.pdf", bbox_inches="tight")
plt.show()

## Irreducible Representations

In [None]:
# Figure of the group's irreps from the project's viz helper.
# NOTE(review): plt.savefig writes the *current* figure, presumably the one
# plot_irreps just created — confirm in src/viz.py.
irreps_pdf_path = f"{FIGURES_DIR}/dihedral_irreps.pdf"
fig = viz.plot_irreps(group, show=False)
plt.savefig(irreps_pdf_path, bbox_inches="tight")
plt.show()