<a href="https://colab.research.google.com/github/marcmontb/Symmetric-Power-Transformers/blob/main/3_0_Notebook_Adaptive_Symmetric_Power_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adaptive Symmetric Power Transformers

The Adaptive Symmetric Power Transformer (ASPT) is a theoretical twist on the Symmetric Power Transformer (SPT) idea. The original SPT, introduced by Manifest AI, made transformers more efficient by using a power function instead of the usual softmax, and later formulating them as Linear Transformers using Symmetric Power Embeddings. Our ASPT concept takes this a step further. We're exploring what happens if we let the power change based on the input, rather than keeping it fixed. This could help the model adapt better to different parts of a sequence, potentially capturing more complex patterns. In this write-up, we'll look at how this idea might work, what challenges it brings, and what it could mean for language processing tasks.

# 1. Adaptive Power Mechanism


The Adaptive Symmetric Power Transformer introduces a dynamic, input-dependent power parameter through an adaptive power network. For an input vector $x \in \mathbb{R}^{d_{model}}$, the adaptive power network $f_\theta: \mathbb{R}^{d_{model}} \to \mathbb{R}$ is defined as:

$$f_\theta(x) = \sigma(W_2 \cdot \text{ReLU}(W_1x + b_1) + b_2)$$

where $W_1 \in \mathbb{R}^{d_{model} \times d_{model}}$, $W_2 \in \mathbb{R}^{1 \times d_{model}}$, $b_1 \in \mathbb{R}^{d_{model}}$, $b_2 \in \mathbb{R}$ are learnable parameters, ReLU is the rectified linear unit function, and $\sigma$ is the sigmoid function. The final adaptive power $p(x)$ is computed as:

$$p(x) = p_{min} + (p_{max} - p_{min}) \cdot f_\theta(x)$$

This formulation ensures that $p(x) \in (p_{min}, p_{max})$, allowing the model to dynamically adjust its behavior within a predefined range. Typically, $p_{min}$ and $p_{max}$ are set to even integers, maintaining consistency with the theoretical foundations of Symmetric Power Transformers.
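As a quick numerical illustration of the range claim, here is a minimal pure-Python sketch of the mapping $p(x)$. The variable `z` stands in for the pre-sigmoid output $W_2 \cdot \text{ReLU}(W_1x + b_1) + b_2$, and the helper names `sigmoid` and `adaptive_power` are illustrative, not part of the notebook's classes.

```python
import math

def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)

def adaptive_power(z: float, p_min: float = 2.0, p_max: float = 8.0) -> float:
    """Map a pre-sigmoid score z to p = p_min + (p_max - p_min) * sigmoid(z)."""
    return p_min + (p_max - p_min) * sigmoid(z)

# Strongly negative scores push p toward p_min, strongly positive ones toward p_max.
for z in (-4.0, 0.0, 4.0):
    p = adaptive_power(z)
    print(f"z = {z:+.1f} -> p = {p:.3f}, floor(p) = {math.floor(p)}")
```

A score of zero lands exactly at the midpoint of the range, here $p = 5$ with the default bounds.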

In the current implementation, this adaptive power is computed only for the query vectors in the attention mechanism, introducing an asymmetry that diverges from traditional attention formulations. This design choice may lead to interesting dynamics in how the model processes and attends to different parts of the input sequence.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import math
import numpy as np
from tqdm import tqdm

In [None]:
class AdaptivePowerNetwork(nn.Module):
    def __init__(self, d_model, p_min=2, p_max=8):
        super().__init__()
        self.p_min = p_min
        self.p_max = p_max
        self.p_network = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        p_scale = self.p_network(x).squeeze(-1)
        p = self.p_min + (self.p_max - self.p_min) * p_scale
        return p


The `p_network` corresponds to $f_\theta$, and the final computation of `p` matches the equation for $p(x)$. The use of `nn.Sequential` allows for a compact representation of the two-layer network with ReLU activation.

# 2. Symmetric Power Embedding

The cornerstone of the Symmetric Power Transformer is the symmetric power embedding, which we extend to incorporate the adaptive power mechanism. For a vector $v \in \mathbb{R}^d$ and a power $p \in \mathbb{R}^+$, the symmetric power embedding $\phi_p(v)$ is implemented as:

$$\phi_p(v) = (\sqrt{c_\alpha} \prod_{i=1}^{\lfloor p \rfloor} v_{\alpha_i})_{\alpha \in I_{d,\lfloor p \rfloor}}$$

where $I_{d,\lfloor p \rfloor}$ is the set of non-decreasing multi-indices of length $\lfloor p \rfloor$ with entries in $\{1, ..., d\}$, and $c_\alpha$ is the multinomial coefficient accounting for repeated indices in $\alpha$. The floor function $\lfloor p \rfloor$ is used to handle non-integer power values produced by the adaptive mechanism, effectively rounding down to the nearest integer.
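The reason for the $\sqrt{c_\alpha}$ weights is that, for an integer power $p$, the inner product of two embeddings reproduces the $p$-th power of the ordinary inner product: $\langle \phi_p(q), \phi_p(k) \rangle = (q^T k)^p$. The following is a small pure-Python sketch that checks this numerically. It is independent of the notebook's class, and the function names and test vectors are illustrative assumptions.

```python
import math
from collections import Counter
from itertools import combinations_with_replacement

def sym_power_embedding(v, p):
    """Symmetric power embedding of a list v for an integer power p >= 1."""
    out = []
    # combinations_with_replacement yields exactly the non-decreasing multi-indices.
    for alpha in combinations_with_replacement(range(len(v)), p):
        # c_alpha = p! / prod(m_i!), where m_i is the multiplicity of index i in alpha.
        c_alpha = math.factorial(p)
        for m in Counter(alpha).values():
            c_alpha //= math.factorial(m)
        out.append(math.sqrt(c_alpha) * math.prod(v[i] for i in alpha))
    return out

def dot(a, b):
    """Plain inner product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

q = [1.0, 2.0, -1.0]
k = [0.5, -1.0, 3.0]
for p in (2, 3):
    lhs = dot(sym_power_embedding(q, p), sym_power_embedding(k, p))
    rhs = dot(q, k) ** p
    print(f"p = {p}: embedded inner product = {lhs:.6f}, (q.k)^p = {rhs:.6f}")
```

The two columns agree up to floating-point error, which is what lets the embedding stand in for a power of the attention score.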

The implementation computes this embedding for each vector in the input sequence individually, applying the corresponding adaptive power. This approach allows for fine-grained control over the embedding dimension for each input element, potentially capturing varying levels of higher-order interactions across the sequence.

The dimension of the symmetric power embedding for each vector is $\binom{d+\lfloor p \rfloor-1}{\lfloor p \rfloor}$, which for fixed $d$ is a polynomial of degree $d-1$ in $\lfloor p \rfloor$. This growth in dimensionality presents computational challenges for large $d$ or $p$, and may require approximation techniques or hardware-specific optimizations for practical large-scale applications.
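To get a feel for how quickly this dimension grows, the sketch below tabulates the binomial formula for a few values of $d$ and integer $p$. The helper name `sym_embed_dim` is illustrative; $d = 64$ matches the `d_model` used in the synthetic example later in this notebook.

```python
import math

def sym_embed_dim(d: int, p: int) -> int:
    """Number of non-decreasing multi-indices of length p over d entries."""
    return math.comb(d + p - 1, p)

for d in (4, 16, 64):
    dims = {p: sym_embed_dim(d, p) for p in (2, 4, 6)}
    print(f"d = {d:>2}: {dims}")
```

For $d = 64$ this gives 2080 entries at $p = 2$ and 766480 at $p = 4$, and over a hundred million at $p = 6$, so the upper end of the default power range is out of reach without approximations.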

In [None]:
class SymmetricPowerEmbedding(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

    def forward(self, x, p):
        batch_size, seq_len, _ = x.shape
        # Pre-compute the output dimension for a single embedding
        sample_embed = self.symmetric_power_embedding(x[0, 0], p[0, 0].item())
        embed_dim = len(sample_embed)

        # Initialize the output tensor with the correct shape
        embedding = torch.zeros(
            (batch_size * seq_len, embed_dim),
            device=x.device,
            dtype=x.dtype
        )

        # Fill the tensor
        for i in range(batch_size):
            for j in range(seq_len):
                idx = i * seq_len + j
                v = x[i, j]
                p_val = p[i, j].item()
                embedding[idx] = self.symmetric_power_embedding(v, p_val)

        # Reshape to the desired output shape
        return embedding.view(batch_size, seq_len, embed_dim)

    def symmetric_power_embedding(self, v, p):
        d = v.shape[0]
        x = []
        for midx in self.non_decreasing_multiindices(int(p), d):
            c = self.count(midx, d)
            xi = math.sqrt(self.multinomial(c))
            for j in range(int(p)):
                xi *= v[midx[j]]
            x.append(xi)
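        # Stack the 0-dim tensors instead of calling torch.tensor(x): stacking keeps
        # the autograd graph, so gradients can flow back to v.
        return torch.stack(x)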

    @staticmethod
    def non_decreasing_multiindices(n, max_idx, starting_from=0):
        if n == 1:
            return [[i] for i in range(starting_from, max_idx)]
        seqs = []
        for i in range(starting_from, max_idx):
            seqs += [[i] + remainder for remainder in
                     SymmetricPowerEmbedding.non_decreasing_multiindices(n-1, max_idx, starting_from=i)]
        return seqs

    @staticmethod
    def multinomial(lst):
        res, i = 1, 1
        for a in lst:
            for j in range(1, a + 1):
                res *= i
                res //= j
                i += 1
        return res

    @staticmethod
    def count(midx, d):
        c = [0] * d
        for i in midx:
            c[i] += 1
        return c

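The code computes the symmetric power embedding as described above. The `symmetric_power_embedding` method corresponds to $\phi_p(v)$, with `non_decreasing_multiindices` generating $I_{d,\lfloor p \rfloor}$ and `multinomial` computing $c_\alpha$. The conversion `int(p)` truncates $p$ toward zero, which for the positive powers used here equals $\lfloor p \rfloor$; it does not round to the nearest integer. Note also that `forward` sizes the output tensor from the embedding of the first position, so it implicitly assumes that every position in the batch shares the same $\lfloor p \rfloor$.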

# 3. Adaptive Symmetric Power Attention

The adaptive symmetric power attention mechanism integrates the adaptive power computation with the symmetric power embedding in a multi-head attention framework. Given query, key, and value matrices $Q, K, V \in \mathbb{R}^{n \times d_{model}}$, where $n$ is the sequence length, the attention computation proceeds as follows:

1. Project inputs: $Q' = Q W_Q, K' = K W_K, V' = V W_V$, where $W_Q, W_K, W_V \in \mathbb{R}^{d_{model} \times d_{model}}$ are learnable projection matrices (the `nn.Linear` layers in the code also add a bias term).

2. Compute adaptive powers: $p_i = p(Q'_i)$ for $i = 1, ..., n$, where $Q'_i$ is the $i$-th row of $Q'$.

3. Apply symmetric power embedding: $\tilde{Q} = \phi_p(Q'), \tilde{K} = \phi_p(K')$, where the embedding is applied row-wise with the corresponding adaptive power.

4. Reshape for multi-head attention: Split $\tilde{Q}, \tilde{K}, V'$ into $h$ heads.

5. Compute attention scores: $S = \tilde{Q}\tilde{K}^T / \sqrt{d_k}$, where $d_k$ is the dimension per head.

6. Apply mask (if provided): $S_{ij} = -\infty$ where mask$_{ij} = 0$.

7. Apply even power normalization:
   $$A_{ij} = \frac{(\text{ReLU}(S_{ij}) + \epsilon)^{p_i}}{\sum_k (\text{ReLU}(S_{ik}) + \epsilon)^{p_i}}$$
   where $\epsilon$ is a small constant for numerical stability.

8. Compute attention output: $O = AV'$

9. Concatenate heads and project: $Y = W_O [O_1; ...; O_h]$, where $W_O \in \mathbb{R}^{hd_k \times d_{model}}$ is a learnable projection matrix.

This formulation differs from standard multi-head attention in several key aspects. First, the use of symmetric power embeddings allows the model to capture higher-order interactions between query and key elements. Second, the adaptive power mechanism enables the model to adjust the "sharpness" of attention for each query position. Finally, the even power normalization replaces the traditional softmax, potentially altering the attention dynamics in beneficial ways.
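To make the "sharpness" effect concrete, here is a minimal pure-Python sketch of the even power normalization from step 7 applied to a single query row. The function name `power_normalize` and the example scores are made up for illustration; the formula is the one given above.

```python
def power_normalize(scores, p, eps=1e-6):
    """A_j = (ReLU(s_j) + eps)^p / sum_k (ReLU(s_k) + eps)^p for one query row."""
    powered = [(max(s, 0.0) + eps) ** p for s in scores]
    total = sum(powered)
    return [x / total for x in powered]

row = [0.2, 1.0, 0.5, -0.3]  # raw scores of one query against four keys
for p in (2, 4, 8):
    weights = power_normalize(row, p)
    print(f"p = {p}: {[round(w, 4) for w in weights]}")
```

As $p$ increases, the weight concentrates on the largest score, and negative scores are clipped by the ReLU so they receive almost no weight at any power.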

An important note is that the adaptive power is only computed for the queries, not the keys. This asymmetry in the attention mechanism is a departure from the usual formulation of attention and may lead to interesting behavioral properties of the model.



In [None]:
class AdaptiveSymmetricPowerAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        self.adaptive_power = AdaptivePowerNetwork(d_model)
        self.symmetric_power_embedding = SymmetricPowerEmbedding(d_model)

    def forward(self, q, k, v, mask=None):
        batch_size, seq_len, _ = q.shape

        # Project inputs
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Compute adaptive p
        p = self.adaptive_power(q)  # [batch_size, seq_len]

        # Compute symmetric power embeddings
        q_embed = self.symmetric_power_embedding(q, p)
        k_embed = self.symmetric_power_embedding(k, p)

        # Reshape for multi-head attention
        q_embed = q_embed.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        k_embed = k_embed.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores
        attn_weights = torch.matmul(q_embed, k_embed.transpose(-2, -1))

        if mask is not None:
            attn_weights = attn_weights.masked_fill(mask == 0, float('-inf'))

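        # Apply even power normalization. p has shape [batch, seq_len] and holds one
        # exponent per query, so reshape it to [batch, 1, seq_len, 1] to broadcast
        # over heads and over the key dimension.
        attn_weights = torch.pow(torch.relu(attn_weights) + 1e-6, p.unsqueeze(1).unsqueeze(-1))
        attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + 1e-6)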

        # Apply attention to values
        output = torch.matmul(attn_weights, v)

        # Reshape and project output
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.out_proj(output)

        return output, attn_weights

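The code uses the `AdaptivePowerNetwork` and `SymmetricPowerEmbedding` classes defined earlier. The even power normalization is implemented using `torch.pow` and `torch.relu`, with a small epsilon (1e-6) added for numerical stability. Note that the code omits the $1/\sqrt{d_k}$ scaling from step 5: in the normalization of step 7, each row is divided by its own sum with a single exponent $p_i$, so a positive rescaling of $S$ cancels in the ratio except for its interaction with $\epsilon$.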

# 4. Transformer Layer and Full Model

The Adaptive Symmetric Power Transformer is composed of multiple layers, each containing the adaptive symmetric power attention mechanism followed by a feed-forward network. The full model also includes embedding layers and a final output projection.

The transformer layer consists of:
1. Multi-head Adaptive Symmetric Power Attention, with dropout and a residual connection
2. Layer Normalization
3. Feed-forward Network, with dropout and a residual connection
4. Another Layer Normalization

The full model includes:
1. Input Embedding
2. Positional Encoding
3. Multiple Transformer Layers
4. Final Layer Normalization
5. Output Projection

In [None]:
class AdaptiveSymmetricPowerTransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = AdaptiveSymmetricPowerAttention(d_model, num_heads)
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention
        attn_output, _ = self.self_attn(x, x, x, mask)
        x = self.norm1(x + self.dropout(attn_output))

        # Feed-forward
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))

        return x

class AdaptiveSymmetricPowerTransformer(nn.Module):
    def __init__(self, d_model, num_heads, num_layers, d_ff, max_seq_length, vocab_size, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = self.positional_encoding(max_seq_length, d_model)
        self.layers = nn.ModuleList([
            AdaptiveSymmetricPowerTransformerLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.final_norm = nn.LayerNorm(d_model)
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        seq_len = x.size(1)
        x = self.embedding(x) + self.pos_encoding[:seq_len, :].to(x.device)

        for layer in self.layers:
            x = layer(x, mask)

        x = self.final_norm(x)
        x = self.output_proj(x)
        return x

    @staticmethod
    def positional_encoding(max_seq_length, d_model):
        pos = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pos_encoding = torch.zeros(max_seq_length, d_model)
        pos_encoding[:, 0::2] = torch.sin(pos * div_term)
        pos_encoding[:, 1::2] = torch.cos(pos * div_term)
        return pos_encoding

This implementation includes both the `AdaptiveSymmetricPowerTransformerLayer` and the full `AdaptiveSymmetricPowerTransformer` model. The layer implements the structure described in the theoretical explanation, while the full model adds input embedding, positional encoding, and output projection.

# 5. Synthetic Dataset Example


In [None]:
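# Hyperparameters
# Note: the embedding dimension is C(d_model + floor(p) - 1, floor(p)), already 2080
# entries per token at p = 2 for d_model = 64, and it is built in Python loops,
# so expect this configuration to run slowly.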
d_model = 64
num_heads = 4
num_layers = 2
d_ff = 128
max_seq_length = 50
vocab_size = 1000
batch_size = 32
num_epochs = 10
learning_rate = 0.001

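# Create a small synthetic dataset. Inputs and labels are sampled independently,
# so this only exercises the training loop: accuracy should stay near chance (1/vocab_size).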
def create_synthetic_dataset(num_samples, seq_length, vocab_size):
    X = torch.randint(0, vocab_size, (num_samples, seq_length))
    y = torch.randint(0, vocab_size, (num_samples,))
    return X, y

# Create train and test datasets
train_X, train_y = create_synthetic_dataset(1000, max_seq_length, vocab_size)
test_X, test_y = create_synthetic_dataset(200, max_seq_length, vocab_size)

train_dataset = TensorDataset(train_X, train_y)
test_dataset = TensorDataset(test_X, test_y)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Initialize the model
model = AdaptiveSymmetricPowerTransformer(d_model, num_heads, num_layers, d_ff, max_seq_length, vocab_size)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for batch_X, batch_y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)

        optimizer.zero_grad()
        output = model(batch_X)
        loss = criterion(output[:, -1, :], batch_y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    # Evaluation
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            output = model(batch_X)
            _, predicted = torch.max(output[:, -1, :], 1)
            total += batch_y.size(0)
            correct += (predicted == batch_y).sum().item()

    accuracy = 100 * correct / total
    print(f"Test Accuracy: {accuracy:.2f}%")

print("Training completed!")

# Test the model
model.eval()
test_input = torch.randint(0, vocab_size, (1, max_seq_length)).to(device)
with torch.no_grad():
    output = model(test_input)
    _, predicted = torch.max(output[:, -1, :], 1)
    print(f"Input sequence: {test_input}")
    print(f"Predicted next token: {predicted.item()}")



# 6. Conclusion and Future Directions

The Adaptive Symmetric Power Transformer, as implemented, provides a starting point for exploring more flexible models for sequence processing tasks. However, bridging the gap between the theoretical formulation and a practical implementation remains a significant challenge, and this is only a first attempt.

Future research directions might include:

1. Investigating the impact of computing adaptive powers for both queries and keys, and comparing it with the current query-only approach.
2. Exploring efficient algorithms (such as the chunked algorithm) or approximations for computing symmetric power embeddings, particularly for large $p$ or $d_{model}$.
3. Analyzing the effects of the power discretization in the symmetric power embedding, and potentially exploring continuous approximations to avoid discontinuities.
4. Developing methods to address the computational and memory challenges associated with the symmetric power embeddings, possibly through techniques like low-rank approximations or sparse representations.



# Annex

# 1. Dimension Changes and Symmetry Properties

When making the power parameter p adaptive, a natural question arises about how the symmetry properties of the embeddings are maintained across different values of p. The key insight is that while the dimension of the embedding space changes with p, the fundamental symmetry properties are preserved within each p-specific space.

## 1.1 Mathematical Background

For a given power p and input dimension d, the symmetric power embedding lives in a space of dimension $\binom{d+p-1}{p}$. This follows from the theory of symmetric tensors described in Section 3.1 from the original SPT paper. The embedding $\phi^p_\text{SYM}(v)$ maps vectors to this space while preserving all relevant symmetries.

When p changes, we move between spaces of different dimensions. For instance:
$$\dim(\phi^2_\text{SYM}) = \frac{d^2 + d}{2}$$
$$\dim(\phi^4_\text{SYM}) = \frac{(d+3)(d+2)(d+1)d}{24}$$

## 1.2 Preservation of Symmetries

Let $T^p$ be a symmetric tensor of order p. The defining property of such tensors is that for all multi-indices $\alpha$ and permutations $\rho \in G_p$:
$$T^p_\alpha = T^p_{\rho(\alpha)}$$

This property holds independently for each p-specific embedding space. When the adaptive mechanism changes p from $p_1$ to $p_2$, we move from one symmetric tensor space to another, each maintaining its own complete set of symmetries. No cross-space symmetry preservation is required because:

1. Each p defines its own complete symmetric tensor space
2. The symmetry properties are defined within each space
3. The dimensions change, but the fundamental symmetric structure remains intact
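The symmetry property can be checked directly on the tensor power $T^p = v^{\otimes p}$, whose entries are $T^p_\alpha = \prod_i v_{\alpha_i}$. The sketch below is a brute-force pure-Python check for small $d$ and $p$. The function names are illustrative, and the test vector uses powers of two so that the floating-point products are exact.

```python
import math
from itertools import permutations, product

def tensor_entry(v, alpha):
    """Entry T_alpha of the p-fold tensor power of v."""
    return math.prod(v[i] for i in alpha)

def is_permutation_invariant(v, p):
    """Check T_alpha == T_rho(alpha) for every multi-index and permutation."""
    for alpha in product(range(len(v)), repeat=p):
        for rho in permutations(range(p)):
            permuted = tuple(alpha[j] for j in rho)
            if tensor_entry(v, alpha) != tensor_entry(v, permuted):
                return False
    return True

def independent_entries(d, p):
    """Number of distinct multi-indices once reorderings are identified."""
    return len({tuple(sorted(a)) for a in product(range(d), repeat=p)})

v = [2.0, -1.0, 0.5]
for p in (2, 3):
    print(p, is_permutation_invariant(v, p),
          independent_entries(len(v), p), math.comb(len(v) + p - 1, p))
```

For each $p$ the check passes on its own, and the count of independent entries matches $\binom{d+p-1}{p}$, which is the dimension of the corresponding $p$-specific space.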

## 1.3 Implications

This mathematical structure has several important implications:

1. The attention scores maintain their theoretical properties locally (within each p-specific computation):
   $$A_{ij} = \frac{(Q_i^T K_j)^{p_i}}{\sum_{k=1}^i (Q_i^T K_k)^{p_i}}$$
   where $p_i$ is the adaptive power computed for query $i$.

2. The state size becomes dynamic:
   $$\text{StateSize}(p) = \text{layer\_n} \cdot \text{head\_count} \cdot \binom{d+p-1}{p} \cdot \text{value\_size}$$

3. The efficient chunked formulation of linear transformers requires modification to handle varying dimensions.
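The state-size formula is easy to evaluate for a given configuration. The sketch below assumes that $d$ is the per-head key dimension and uses made-up configuration values (2 layers, 4 heads, $d = 16$, value size 16); the function name `state_size` is illustrative.

```python
import math

def state_size(layer_n: int, head_count: int, d: int, p: int, value_size: int) -> int:
    """StateSize(p) = layer_n * head_count * C(d + p - 1, p) * value_size."""
    return layer_n * head_count * math.comb(d + p - 1, p) * value_size

# Illustrative configuration: how the state grows as the power changes.
for p in (2, 4, 6):
    print(f"p = {p}: {state_size(layer_n=2, head_count=4, d=16, p=p, value_size=16)} entries")
```

With an adaptive power, a fixed-size buffer would have to be provisioned for the largest $\lfloor p \rfloor$ the network can produce, which is one reason the chunked formulation needs modification.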

## 1.4 Open Questions

Several theoretical questions remain:

1. How to ensure smooth transitions between spaces of different dimensions?
2. What is the optimal trajectory of p values during training?
3. Can we establish theoretical bounds on the computational cost given a distribution of p values?