# Quick demo: checking equivariance (Vector Neurons)

**Goal**

This notebook illustrates a key property of *Vector Neurons (VN)*:
> **equivariance to SO(3) rotations**

In other words, a correct VN layer must satisfy:

\[
f(VR) = f(V)R
\]

where:
- \( V \) is a set of 3D vectors, stored one per row,
- \( R \in SO(3) \) is a rotation,
- \( f \) is a VN network layer (linear or with a non-linearity).
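
As a brief sketch of why this is expected to hold for the linear layer: assume \( f(V) = WV \), where \( W \) is a channel-mixing matrix of shape \( C_{out} \times C_{in} \) and \( V \) has shape \( C_{in} \times 3 \) (the convention used by `vn_linear` in the code). Associativity of matrix multiplication then gives:

\[
f(VR) = W(VR) = (WV)R = f(V)R
\]

Here \( W \) acts on the channel axis and \( R \) on the 3D coordinate axis, so the two operations commute.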

We will:
1. Implement a VN linear layer
2. Implement the VN-ReLU non-linearity (*detached* version)
3. Check **numerically** that equivariance holds
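
For reference, the non-linearity of step 2 can be written as below. This is a sketch that mirrors the `vn_relu_detached` code in this notebook, ignoring the small `eps` added to the denominator for numerical stability. For one output channel, let \( q \) and \( k \) be the corresponding rows of \( W_q V \) and \( W_k V \):

\[
f(V) =
\begin{cases}
q & \text{if } \langle q, k \rangle > 0, \\
q - \dfrac{\langle q, k \rangle}{\lVert k \rVert^2}\, k & \text{otherwise.}
\end{cases}
\]

Rotations preserve inner products and norms, so \( \langle qR, kR \rangle = \langle q, k \rangle \) and \( \lVert kR \rVert = \lVert k \rVert \). The selected branch and the projection coefficient are therefore unchanged by \( R \), and the output rotates with the input.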


In [2]:
!pip install plotly ipywidgets

import numpy as np
import torch

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Rotation utilities (Rodrigues' formula)
def random_rotation_matrix(seed=None):
    if seed is not None:
        torch.manual_seed(seed)  # note: this reseeds the global torch RNG
    # Random unit vector for the rotation axis
    v = torch.randn(3)
    v = v / v.norm()

    # Random angle in [0, 2*pi), kept as a Tensor (no .item())
    theta = torch.rand(1) * 2 * np.pi

    # Skew-symmetric cross-product matrix of the axis v
    K = torch.tensor([[0, -v[2], v[1]],
                      [v[2], 0, -v[0]],
                      [-v[1], v[0], 0]], dtype=torch.float32)

    # Rodrigues: R = I + sin(theta) K + (1 - cos(theta)) K^2
    R = torch.eye(3) + torch.sin(theta)*K + (1 - torch.cos(theta))*(K @ K)
    return R

def apply_rotation(V, R):
    # V shape (3,), (..., 3) or (..., C, 3): each vector v on the last dim becomes R v
    return V @ R.T

# VN-linear & VN-ReLU implementation (vectorized)
eps = 1e-6

def vn_linear(V, W):
    # V shape (batch, C_in, 3)
    # W shape (C_out, C_in)
    # output: (batch, C_out, 3)
    return torch.einsum('oc,bcd->bod', W, V)

def vn_relu_detached(V, Wq, Wk):
    # Detached style non-linearity
    # V: (batch, C_in, 3)
    # Wq, Wk: (C_out, C_in)
    q = vn_linear(V, Wq)   # (batch, C_out, 3)
    k = vn_linear(V, Wk)   # (batch, C_out, 3)
    # projection of q onto k: proj = <q,k> / ||k||^2 * k
    denom = (k.norm(dim=-1, keepdim=True)**2 + eps)  # (batch, C_out, 1)
    dot = (q * k).sum(dim=-1, keepdim=True)           # (batch, C_out, 1)
    proj = (dot / denom) * k                          # (batch, C_out, 3)
    # mask where <q,k> > 0
    mask = (dot > 0).float()
    out = mask * q + (1 - mask) * (q - proj)
    return out

# Equivariance test
torch.manual_seed(0)
B = 4       # batch
C_in = 5
C_out = 6

# Random input V (batch, C_in, 3)
V = torch.randn(B, C_in, 3, device=device)

# Random linear weights
W = torch.randn(C_out, C_in, device=device)
Wq = torch.randn(C_out, C_in, device=device)
Wk = torch.randn(C_out, C_in, device=device)

# Random rotation
R = random_rotation_matrix()
R = R.to(device)

# Compute f_lin(VR)
V_rot = V @ R.T        # right-multiply by R.T (itself a rotation): each vector v becomes R v
out1 = vn_linear(V_rot, W)        # f_lin(VR)
# Compute f_lin(V) then rotate result
out2 = vn_linear(V, W) @ R.T      # f_lin(V)R

diff_lin = (out1 - out2).abs().max().item()
print("Max abs diff linear layer (should be ~0):", diff_lin)

# Now test VN-ReLU (detached variant)
out1 = vn_relu_detached(V_rot, Wq, Wk)
out2 = vn_relu_detached(V, Wq, Wk) @ R.T
diff_relu = (out1 - out2).abs().max().item()
print("Max abs diff VN-ReLU detached (should be ~0):", diff_relu)


Device: cpu
Max abs diff linear layer (should be ~0): 9.5367431640625e-07
Max abs diff VN-ReLU detached (should be ~0): 9.5367431640625e-07
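
A residual near `1e-6` is what float32 round-off would produce, but the check is more convincing alongside a layer that is known to fail it. The sketch below is an addition, not part of the original notebook: the helpers `rotation_from_axis_angle` and `equivariance_error` are hypothetical names introduced here for illustration. It first confirms that a Rodrigues-built matrix is a valid rotation, then compares a VN-style linear layer against an ordinary coordinate-wise ReLU used as a negative control.

```python
import math
import torch

def rotation_from_axis_angle(axis, theta):
    """Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2."""
    x, y, z = (axis / axis.norm()).tolist()
    # Skew-symmetric cross-product matrix of the unit axis
    K = torch.tensor([[0.0, -z, y],
                      [z, 0.0, -x],
                      [-y, x, 0.0]])
    return torch.eye(3) + math.sin(theta) * K + (1.0 - math.cos(theta)) * (K @ K)

def equivariance_error(f, V, R):
    """Max absolute difference between f(rotated V) and rotated f(V)."""
    return (f(V @ R.T) - f(V) @ R.T).abs().max().item()

torch.manual_seed(0)
V = torch.randn(4, 5, 3)   # (batch, channels, 3)
W = torch.randn(6, 5)      # (C_out, C_in)
R = rotation_from_axis_angle(torch.randn(3), theta=1.0)

# Sanity check: R should be orthogonal with determinant +1.
orth_err = (R @ R.T - torch.eye(3)).abs().max().item()
det_R = torch.det(R).item()

# VN-style linear layer: mixes channels, leaves the 3D axis untouched.
def vn_lin(X):
    return torch.einsum('oc,bcd->bod', W, X)

# Negative control: ReLU applied per coordinate depends on the choice of axes.
def coord_relu(X):
    return torch.relu(X)

err_vn = equivariance_error(vn_lin, V, R)
err_relu = equivariance_error(coord_relu, V, R)

print("orthogonality error:", orth_err)
print("det(R):", det_R)
print("VN linear error:", err_vn)
print("coordinate-wise ReLU error:", err_relu)
```

The VN linear error should sit at round-off level, while the coordinate-wise ReLU error should be orders of magnitude larger, which indicates that the test can tell equivariant layers from non-equivariant ones.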
