# Stacking Bilinear Layers
The goal of this notebook is to get experience with stacking bilinear layers.

In [6]:
import torch
from torch import nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import einsum
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
import pickle
from typing import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Bilinear layer similar to the other notebook (left/right via chunk)
class Bilinear(nn.Linear):
    def __init__(self, d_in: int, d_out: int, bias=False) -> None:
        super().__init__(d_in, 2 * d_out, bias=bias)
    def forward(self, x):
        left, right = super().forward(x).chunk(2, dim=-1)
        return left * right
    @property
    def w_l(self):
        return self.weight.chunk(2, dim=0)[0]
    @property
    def w_r(self):
        return self.weight.chunk(2, dim=0)[1]

@dataclass
class BiLayerConfig:
    d_in: int
    d_hid: int
    d_out: int
    bias: bool = False

class BiLayer(nn.Module):
    def __init__(self, cfg: BiLayerConfig):
        super().__init__()
        self.bi_linear = Bilinear(d_in=2*cfg.d_in, d_out=cfg.d_hid, bias=cfg.bias)
        self.projection = nn.Linear(cfg.d_hid, cfg.d_out, bias=cfg.bias)
    def forward(self, x):
        return self.projection(self.bi_linear(x))
    @property
    def w_l(self):
        return self.bi_linear.w_l
    @property
    def w_r(self):
        return self.bi_linear.w_r
    @property
    def w_p(self):
        return self.projection.weight

@dataclass
class BiStackConfig:
    dims: List[int]
    bias: bool = False

class BiStack(nn.Module):
    def __init__(self, cfg: BiStackConfig):
        super().__init__()
        # dims in 0, 3, 5, 7, 9, ...
        assert len(cfg.dims) >= 3 and len(cfg.dims) % 2 == 1
        layers = []
        for i in range(0, len(cfg.dims) - 2, 2):  # Step by 2
            if len(cfg.dims) - i - 2 < 0:
                break
            layer_cfg = BiLayerConfig(
                d_in=cfg.dims[i], d_hid=cfg.dims[i+1], 
                d_out=cfg.dims[i+2], bias=cfg.bias
                )
            layers.append(BiLayer(layer_cfg))
        self.layers = nn.ModuleList(layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        for layer in self.layers:
            x = layer(x)
        return x

    @property
    def w_l(self, layer: int):
        return self.layers[layer].bi_linear.w_l
    @property
    def w_r(self, layer: int):
        return self.layers[layer].bi_linear.w_r
    @property
    def w_p(self, layer: int):
        return self.layers[layer].projection.weight

def init_stack(dims: List[int], bias: bool = False) -> BiStack:
    cfg = BiStackConfig(dims=dims, bias=bias)
    return BiStack(cfg)

Device: cuda


In [8]:
init_stack([6, 2, 3, 4, 5])

BiStack(
  (layers): ModuleList(
    (0): BiLayer(
      (bi_linear): Bilinear(in_features=12, out_features=4, bias=False)
      (projection): Linear(in_features=2, out_features=3, bias=False)
    )
    (1): BiLayer(
      (bi_linear): Bilinear(in_features=6, out_features=8, bias=False)
      (projection): Linear(in_features=4, out_features=5, bias=False)
    )
  )
)

In [None]:
# Utilities: data generation, loaders, criterion and training loop (adapted from application.ipynb)
import math
from typing import Tuple
from dataclasses import dataclass

@dataclass
class DataConfig:
    P: int = 113
    train_size: float = 0.8

def generate_modular_addition_data(P: int):
    a_vals = torch.arange(P).repeat_interleave(P)
    b_vals = torch.arange(P).repeat(P)
    x_vals = torch.stack((a_vals, b_vals), dim=-1)

    a_1hot = torch.nn.functional.one_hot(a_vals, num_classes=P).float()
    b_1hot = torch.nn.functional.one_hot(b_vals, num_classes=P).float()
    x_1hot = torch.cat((a_1hot, b_1hot), dim=-1)

    targets = (a_vals + b_vals) % P
    return x_vals, x_1hot, targets

# Loss map and optimizer map (minimal)
_LOSS_MAP = {
    "crossentropy": nn.CrossEntropyLoss,
}
_OPTIMIZER_MAP = {
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
    "sgd": torch.optim.SGD,
}

@dataclass
class CriterionConfig:
    name: str = "crossentropy"
    kwargs: dict | None = None

def get_criterion(cfg: CriterionConfig) -> nn.Module:
    key = cfg.name.replace(" ", "").replace("-", "").lower()
    return _LOSS_MAP[key](**(cfg.kwargs or {}))

@dataclass
class OptimizerConfig:
    name: str = "adamw"
    lr: float = 0.003
    weight_decay: float | None = 1e-4

def get_optimizer(params, config: OptimizerConfig) -> torch.optim.Optimizer:
    key = config.name.replace("-", "").replace("_", "").lower()
    cls = _OPTIMIZER_MAP[key]
    kwargs = {"lr": config.lr}
    if config.weight_decay is not None: kwargs['weight_decay'] = config.weight_decay
    return cls(params, **kwargs)

def get_loaders(x: torch.Tensor, t: torch.LongTensor, train_size: float, batch_size: int):
    perm = torch.randperm(x.size(0))
    x, t = x[perm], t[perm]
    split = math.ceil(train_size * x.size(0))
    x_train, x_valid = x.split(split)
    t_train, t_valid = t.split(split)
    train_data = torch.utils.data.TensorDataset(x_train, t_train)
    valid_data = torch.utils.data.TensorDataset(x_valid, t_valid)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)
    valid_loader = torch.utils.data.DataLoader(valid_data, batch_size=batch_size, shuffle=False)
    return train_loader, valid_loader

def eval_model(model, loader, criterion):
    model.eval()
    losses = []
    accs = []
    with torch.no_grad():
        for x, t in loader:
            x = x.to(device)
            t = t.to(device)
            logit = model(x)
            loss = criterion(logit, t)
            losses.append(loss.item())
            pred = logit.argmax(dim=-1)
            accs.append((pred == t).float().mean().item())
    return sum(losses)/len(losses), sum(accs)/len(accs)

@dataclass
class TrainConfig:
    batch_size: int = 256
    epochs: int = 400
    optimizer = OptimizerConfig
    criterion = CriterionConfig

def train_loop(model, loaders, cfg: TrainConfig):
    criterion = get_criterion(cfg.criterion)
    optimizer = get_optimizer(model.parameters(), cfg.optimizer)
    train_loader, valid_loader = loaders
    train_losses, train_accs, valid_losses, valid_accs = [], [], [], []
    best_valid = float('inf')
    best_state = None
    pbar = tqdm(range(cfg.epochs), desc="Training")
    for epoch in pbar:
        model.train()
        epoch_losses, epoch_accs = [], []
        for x, t in train_loader:
            x = x.to(device)
            t = t.to(device)
            logit = model(x)
            loss = criterion(logit, t)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
            pred = logit.argmax(dim=-1)
            epoch_accs.append((pred == t).float().mean().item())
        train_losses.append(sum(epoch_losses)/len(epoch_losses))
        train_accs.append(sum(epoch_accs)/len(epoch_accs))
        vloss, vacc = eval_model(model, valid_loader, criterion)
        valid_losses.append(vloss)
        valid_accs.append(vacc)
        if vloss < best_valid:
            best_valid = vloss
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        pbar.set_postfix({'train_loss': f'{train_losses[-1]:.4f}', 'valid_loss': f'{vloss:.4f}', 'valid_acc': f'{vacc:.4f}'})
    if best_state is not None:
        model.load_state_dict(best_state)
    return {'train_losses': train_losses, 'train_accs': train_accs, 'valid_losses': valid_losses, 'valid_accs': valid_accs, 'best_valid_loss': best_valid}

def plot_training_results(train_losses, valid_losses, train_accs, valid_accs, arc="Stack", y_scale="linear", cut_off_epoch=None):
    if cut_off_epoch is None: cut_off_epoch = len(train_losses)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14,5))
    ax1.plot(train_losses[:cut_off_epoch], label='Train Loss')
    ax1.plot(valid_losses[:cut_off_epoch], label='Valid Loss')
    ax1.set_yscale(y_scale)
    ax1.legend()
    ax2.plot(train_accs[:cut_off_epoch], label='Train Acc')
    ax2.plot(valid_accs[:cut_off_epoch], label='Valid Acc')
    ax2.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# Run the experiment with a two-layer BiStack (dims length = 5 creates 2 BiLayers)
P = DataConfig.P
pairs, onehots, labels = generate_modular_addition_data(P=P)
print(f'pairs: {pairs.size()}, onehots: {onehots.size()}, labels: {labels.size()}')
loaders = get_loaders(onehots, labels, DataConfig.train_size, TrainConfig.batch_size)

# dims = [P, hid1, mid, hid2, P] -> creates two BiLayer layers (i=0 and i=2)
dims = [P, P, P, P, P]
stack = init_stack(dims, bias=False).to(device)

# Train
cfg = TrainConfig()
cfg.epochs = 400
cfg.batch_size = 256
results = train_loop(stack, loaders, cfg)

# Plot results
plot_training_results(results['train_losses'], results['valid_losses'], results['train_accs'], results['valid_accs'], arc="BiStack (2 layers)", y_scale="log", cut_off_epoch=50)

# Sanity check on a single example
test_point = 1337
one_hot = onehots[test_point].unsqueeze(0).to(device)
print('label:', labels[test_point].item())
print('pred:', torch.argmax(stack(one_hot), dim=-1).item())
total_params = sum(p.numel() for p in stack.parameters() if p.requires_grad)
print(f'Total parameters: {total_params}')