In [9]:
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ============================================
# Quaternion utilities (PyTorch tensors)
# Quaternions are stored as (..., 4): (w, x, y, z)
# ============================================

def q_normalize(q):
    norm = torch.linalg.norm(q, dim=-1, keepdim=True) + 1e-8
    return q / norm

def q_conj(q):
    w, x, y, z = torch.unbind(q, dim=-1)
    return torch.stack([w, -x, -y, -z], dim=-1)

def q_mul(a, b):
    """Hamilton product of two quaternions.
    a, b: (..., 4)"""
    aw, ax, ay, az = torch.unbind(a, dim=-1)
    bw, bx, by, bz = torch.unbind(b, dim=-1)

    w = aw * bw - ax * bx - ay * by - az * bz
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw

    return torch.stack([w, x, y, z], dim=-1)


# ============================================
# Real-valued baseline network
# ============================================

class RealNet(nn.Module):
    def __init__(self, input_dim=3072, hidden_dim=256, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class RealNet_Large(nn.Module):
    def __init__(self, input_dim=3072, hidden_dim=512, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# ============================================
# Quaternion Linear Layer and Quaternion Nets
# ============================================

class QuaternionLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, 4))
        self.bias = nn.Parameter(torch.empty(out_features, 4))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight, mean=0.0, std=0.1)
        with torch.no_grad():
            self.weight[:] = q_normalize(self.weight)
            nn.init.constant_(self.bias[..., 0], 1.0)
            nn.init.constant_(self.bias[..., 1:], 0.0)

    def forward(self, x):
        B = x.size(0)
        w = self.weight.unsqueeze(0)
        x_exp = x.unsqueeze(1)
        prod = q_mul(w, x_exp)
        out = prod.sum(dim=2) + self.bias
        return out


class QuatNet_4pix(nn.Module):
    """Quaternion network: 4 pixels per quaternion (768 quaternions)"""
    def __init__(self, num_quats=768, hidden_quat=256, num_classes=10):
        super().__init__()
        self.num_quats = num_quats
        
        self.quat_fc1 = QuaternionLinear(num_quats, hidden_quat)
        self.fc2 = nn.Linear(hidden_quat * 4, num_classes)
    
    def image_to_quats(self, x):
        # x: (B, 3, 32, 32) -> (B, 768, 4)
        # Group every 4 pixels into a quaternion
        B = x.size(0)
        flat = x.view(B, -1)  # (B, 3072)
        quats = flat.view(B, self.num_quats, 4)  # (B, 768, 4)
        return quats
    
    def forward(self, x):
        B = x.size(0)
        q_in = self.image_to_quats(x)
        hq = self.quat_fc1(q_in)
        hq = q_normalize(hq)
        hq = torch.tanh(hq)
        h_flat = hq.view(B, -1)
        logits = self.fc2(h_flat)
        return logits


class QuatNet_RGB(nn.Module):
    """Quaternion network: 1 quaternion per pixel using (R,G,B,1) structure (1024 quaternions)"""
    def __init__(self, num_pixels=1024, hidden_quat=256, num_classes=10):
        super().__init__()
        self.num_pixels = num_pixels
        
        self.quat_fc1 = QuaternionLinear(num_pixels, hidden_quat)
        self.fc2 = nn.Linear(hidden_quat * 4, num_classes)
    
    def image_to_quats(self, x):
        # x: (B, 3, 32, 32) -> (B, 1024, 4) where each quaternion is (R,G,B,1)
        B = x.size(0)
        x = x.permute(0, 2, 3, 1)  # (B, 32, 32, 3)
        x = x.reshape(B, self.num_pixels, 3)  # (B, 1024, 3)
        ones = torch.ones(B, self.num_pixels, 1, device=x.device)
        quats = torch.cat([x, ones], dim=-1)  # (B, 1024, 4)
        return quats
    
    def forward(self, x):
        B = x.size(0)
        q_in = self.image_to_quats(x)
        hq = self.quat_fc1(q_in)
        hq = q_normalize(hq)
        hq = torch.tanh(hq)
        h_flat = hq.view(B, -1)
        logits = self.fc2(h_flat)
        return logits


# ============================================
# Training / evaluation helpers
# ============================================

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)


def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            preds = logits.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)
    return correct / total if total > 0 else 0.0


# ============================================
# Main experiment: RealNet vs QuatNets
# ============================================

def main():
    # Set seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
    # CIFAR-10 with normalization
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    train_ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
    test_ds  = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False, num_workers=2, pin_memory=True)

    # ---------------- RealNet ----------------
    real_net = RealNet(input_dim=3072, hidden_dim=256).to(device)
    real_opt = torch.optim.Adam(real_net.parameters(), lr=1e-3)

    print("\n=== Training RealNet (baseline) ===")
    start = time.time()
    for epoch in range(1, 6):
        loss = train_one_epoch(real_net, train_loader, real_opt, device)
        acc = evaluate(real_net, test_loader, device)
        elapsed = time.time() - start
        print(f"[RealNet] Epoch {epoch} | loss={loss:.4f} | test_acc={acc:.4f} | elapsed={elapsed:.2f}s")
    real_time = time.time() - start
    real_acc = evaluate(real_net, test_loader, device)
    print(f"RealNet final: acc={real_acc:.4f}, time={real_time:.2f}s")

    # ---------------- RealNetLarge ----------------
    realnet_large = RealNet_Large(input_dim=3072, hidden_dim=512).to(device)
    reallarge_opt = torch.optim.Adam(realnet_large.parameters(), lr=1e-3)

    print("\n=== Training RealNetLarge (large capacity) ===")
    start = time.time()
    for epoch in range(1, 6):
        loss = train_one_epoch(realnet_large, train_loader, reallarge_opt, device)
        acc = evaluate(realnet_large, test_loader, device)
        elapsed = time.time() - start
        print(f"[RealNetLarge] Epoch {epoch} | loss={loss:.4f} | test_acc={acc:.4f} | elapsed={elapsed:.2f}s")
    reallarge_time = time.time() - start
    reallarge_acc = evaluate(realnet_large, test_loader, device)
    print(f"RealNetLarge final: acc={reallarge_acc:.4f}, time={reallarge_time:.2f}s")

    # ---------------- QuatNet_4pix ----------------
    quat_net_4pix = QuatNet_4pix(num_quats=768, hidden_quat=256).to(device)
    quat_4pix_opt = torch.optim.Adam(quat_net_4pix.parameters(), lr=1e-3)

    print("\n=== Training QuatNet_4pix (4 pixels per quaternion) ===")
    start = time.time()
    for epoch in range(1, 6):
        loss = train_one_epoch(quat_net_4pix, train_loader, quat_4pix_opt, device)
        acc = evaluate(quat_net_4pix, test_loader, device)
        elapsed = time.time() - start
        print(f"[QuatNet_4pix] Epoch {epoch} | loss={loss:.4f} | test_acc={acc:.4f} | elapsed={elapsed:.2f}s")
    quat_4pix_time = time.time() - start
    quat_4pix_acc = evaluate(quat_net_4pix, test_loader, device)
    print(f"QuatNet_4pix final: acc={quat_4pix_acc:.4f}, time={quat_4pix_time:.2f}s")

    # ---------------- QuatNet_RGB ----------------
    quat_net_rgb = QuatNet_RGB(num_pixels=1024, hidden_quat=256).to(device)
    quat_rgb_opt = torch.optim.Adam(quat_net_rgb.parameters(), lr=1e-3)

    print("\n=== Training QuatNet_RGB (1 quaternion per pixel, RGB structure) ===")
    start = time.time()
    for epoch in range(1, 6):
        loss = train_one_epoch(quat_net_rgb, train_loader, quat_rgb_opt, device)
        acc = evaluate(quat_net_rgb, test_loader, device)
        elapsed = time.time() - start
        print(f"[QuatNet_RGB] Epoch {epoch} | loss={loss:.4f} | test_acc={acc:.4f} | elapsed={elapsed:.2f}s")
    quat_rgb_time = time.time() - start
    quat_rgb_acc = evaluate(quat_net_rgb, test_loader, device)
    print(f"QuatNet_RGB final: acc={quat_rgb_acc:.4f}, time={quat_rgb_time:.2f}s")

    # ---------------- Summary ----------------
    print("\n=== Parameter Counts ===")
    real_params = sum(p.numel() for p in real_net.parameters())
    reallarge_params = sum(p.numel() for p in realnet_large.parameters())
    quat_4pix_params = sum(p.numel() for p in quat_net_4pix.parameters())
    quat_rgb_params = sum(p.numel() for p in quat_net_rgb.parameters())
    print(f"RealNet:        {real_params:,} parameters")
    print(f"RealNetLarge:   {reallarge_params:,} parameters")
    print(f"QuatNet_4pix:   {quat_4pix_params:,} parameters")
    print(f"QuatNet_RGB:    {quat_rgb_params:,} parameters")
    
    print("\n=== Summary (CIFAR-10, seed=42, 5 epochs) ===")
    print(f"RealNet:        acc={real_acc:.4f}, time={real_time:.2f}s")
    print(f"RealNetLarge:   acc={reallarge_acc:.4f}, time={reallarge_time:.2f}s")
    print(f"QuatNet_4pix:   acc={quat_4pix_acc:.4f}, time={quat_4pix_time:.2f}s")
    print(f"QuatNet_RGB:    acc={quat_rgb_acc:.4f}, time={quat_rgb_time:.2f}s")
    
    print("\n=== Comparative Analysis ===")
    print("\nQuatNet_4pix vs RealNet:")
    print(f"  Accuracy: {quat_4pix_acc:.4f} vs {real_acc:.4f} ({(real_acc - quat_4pix_acc)*100:+.2f} points)")
    print(f"  Time: {quat_4pix_time:.2f}s vs {real_time:.2f}s ({quat_4pix_time/real_time:.2f}x)")
    print(f"  Parameters: {quat_4pix_params:,} vs {real_params:,} ({quat_4pix_params/real_params:.2f}x)")
    
    print("\nQuatNet_RGB vs RealNet:")
    print(f"  Accuracy: {quat_rgb_acc:.4f} vs {real_acc:.4f} ({(real_acc - quat_rgb_acc)*100:+.2f} points)")
    print(f"  Time: {quat_rgb_time:.2f}s vs {real_time:.2f}s ({quat_rgb_time/real_time:.2f}x)")
    print(f"  Parameters: {quat_rgb_params:,} vs {real_params:,} ({quat_rgb_params/real_params:.2f}x)")
    
    print("\nQuatNet_RGB vs RealNetLarge:")
    print(f"  Accuracy: {quat_rgb_acc:.4f} vs {reallarge_acc:.4f} ({(reallarge_acc - quat_rgb_acc)*100:+.2f} points)")
    print(f"  Time: {quat_rgb_time:.2f}s vs {reallarge_time:.2f}s ({quat_rgb_time/reallarge_time:.2f}x)")
    print(f"  Parameters: {quat_rgb_params:,} vs {reallarge_params:,} ({quat_rgb_params/reallarge_params:.2f}x)")


if __name__ == "__main__":
    main()

Using device: cuda

=== Training RealNet (baseline) ===
[RealNet] Epoch 1 | loss=1.6356 | test_acc=0.4718 | elapsed=1.95s
[RealNet] Epoch 2 | loss=1.4325 | test_acc=0.4929 | elapsed=3.91s
[RealNet] Epoch 3 | loss=1.3499 | test_acc=0.4987 | elapsed=5.88s
[RealNet] Epoch 4 | loss=1.2790 | test_acc=0.5054 | elapsed=7.85s
[RealNet] Epoch 5 | loss=1.2232 | test_acc=0.5185 | elapsed=9.80s
RealNet final: acc=0.5185, time=9.80s

=== Training RealNetLarge (large capacity) ===
[RealNetLarge] Epoch 1 | loss=1.6478 | test_acc=0.4690 | elapsed=1.91s
[RealNetLarge] Epoch 2 | loss=1.4409 | test_acc=0.4873 | elapsed=3.83s
[RealNetLarge] Epoch 3 | loss=1.3512 | test_acc=0.5104 | elapsed=5.76s
[RealNetLarge] Epoch 4 | loss=1.2718 | test_acc=0.5122 | elapsed=7.75s
[RealNetLarge] Epoch 5 | loss=1.2148 | test_acc=0.5152 | elapsed=9.75s
RealNetLarge final: acc=0.5152, time=9.75s

=== Training QuatNet_4pix (4 pixels per quaternion) ===
[QuatNet_4pix] Epoch 1 | loss=1.7822 | test_acc=0.4222 | elapsed=9.22s
[Q