# Introduction
In a recent video lecture (https://youtu.be/DP454c1K_vQ?t=4841), Professor Jürgen Schmidhuber claimed that Mamba, a type of State Space Model (SSM), cannot solve the parity problem. This claim sparked my curiosity and led me to investigate a simpler SSM that could potentially solve this problem.
# The Parity Problem
The parity problem involves a string of 1s and 0s. For each point in the string, we need to output whether the number of 1s encountered up to that point is even or odd.
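
As a concrete reference for the task, here is a minimal sketch in plain Python. The function name `running_parity` is my own, chosen for illustration; it outputs 1 when the count of 1s so far is odd and 0 when it is even, which is the same labeling the data generator in this notebook uses.

```python
from itertools import accumulate
from operator import xor

def running_parity(bits):
    """Return, for each position, 1 if the number of 1s seen so far is odd, else 0."""
    # A running XOR over the bits is exactly the running count of 1s modulo 2.
    return list(accumulate(bits, xor))

bits = [1, 1, 0, 1, 0]
print(running_parity(bits))  # [1, 0, 0, 1, 1]
```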
# A Simple SSM Solution
I propose a simple SSM that can solve the parity problem using the following formula:
$$
\begin{aligned}
A &= [\pi \cdot i] \\
\delta &= x \\
S_0 &= 1 \\
S_{t+1} &= e^{A \cdot \delta} \cdot S_t + 0
\end{aligned}
$$
Since $e^{i\pi \cdot 1} = -1$ and $e^{i\pi \cdot 0} = 1$, the state flips sign only when we encounter a 1 and is left unchanged by a 0. The sign of $S_t$ therefore tracks the parity, which solves the problem.
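
The recurrence can be checked numerically with Python's built-in complex numbers. This is only a sketch of the formula above, not of Mamba itself; the function name `simple_ssm_parity` and the readout rule (a negative real part means an odd count) are assumptions made for illustration.

```python
import cmath

def simple_ssm_parity(bits):
    """Run S_{t+1} = exp(A * delta) * S_t with A = i*pi, delta = x and S_0 = 1."""
    A = 1j * cmath.pi
    state = 1 + 0j
    outputs = []
    for x in bits:
        # exp(i*pi*1) is -1 (flip) and exp(i*pi*0) is 1 (no change).
        state = cmath.exp(A * x) * state
        # The state stays close to +1 (even count) or -1 (odd count).
        outputs.append(1 if state.real < 0 else 0)
    return outputs

print(simple_ssm_parity([1, 1, 0, 1, 0]))  # [1, 0, 0, 1, 1]
```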
# Challenges for Complex Mamba
While exploring this solution, I identified three reasons why it might be challenging for a complex Mamba model to implement this:

Mamba uses addition in its state update, which makes it difficult to keep the state within normalized bounds without decay. Decay is undesirable here: with continuous decay, whatever was written to the state is eventually erased.

If $\delta$ is not exactly 1, a small error accumulates at every step and gradually pushes the state off target until the model loses accuracy.

Mamba cannot assign a $\delta$ of zero to an input, because the softplus function in its $\delta$ calculation is strictly positive. It therefore cannot completely ignore the zeroes, which is necessary for the parity problem.
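
The last two points can be illustrated with a few lines of arithmetic. This is a rough sketch rather than Mamba's actual computation: the helper names, the value $\delta = 0.99$ and the count of 50 ones are all hypothetical numbers picked for the example.

```python
import math

def softplus(z):
    """Plain softplus, log(1 + exp(z)); mathematically positive for every finite z."""
    return math.log1p(math.exp(z))

def phase_after(n_ones, delta):
    """Phase, in units of pi, after n_ones inputs that each rotate the state by delta * pi."""
    return (n_ones * delta) % 2

# Softplus gets small but does not reach zero, so a 0 input still nudges the state.
print(softplus(-5.0) > 0)  # True

# With delta exactly 1, fifty 1s bring the phase back to 0 (even parity).
print(phase_after(50, 1.0))  # 0.0

# With delta = 0.99 the per-step error adds up: the phase ends a quarter turn
# away from where it should be, so the sign of the state no longer encodes parity.
print(round(phase_after(50, 0.99), 6))  # 1.5
```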

# Proposed Architecture
To address these challenges, I developed a new architecture that implements the simple SSM described above. The key insight is that we only need to track the complex angle/phase of the state, as the magnitude/radius is always one (we're moving on the unit circle).

The updated formula for this architecture is:
$$
\begin{aligned}
S_0 &= 0 \text{ (in turns, i.e. radians/}2\pi\text{)} \\
S_{t+1} &= (S_t + \text{angle}_t) \bmod 1 \\
\text{out1}_t &= \cos(S_t \cdot 2\pi) \\
\text{out2}_t &= \sin(S_t \cdot 2\pi)
\end{aligned}
$$
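To ensure the model can land exactly on a state flip when it encounters a 1, I divided the unit circle into 100 equal pieces. The input angle is rounded to the nearest division, so each input advances the state by a whole number of divisions. Rounding has zero gradient almost everywhere, so to maintain gradient flow I employed the gradient skipping method from VQ-VAE (the straight-through estimator): the forward pass uses the rounded value, while the backward pass treats the rounding as the identity.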
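
Written out, the gradient trick looks roughly as follows. The symbols $q$ (rounding to the nearest hundredth) and $\operatorname{sg}$ (stop-gradient) are notation introduced here for illustration; the middle line corresponds to the `delta + (round(delta) - delta).detach()` expression in the model code.

$$
\begin{aligned}
q(\text{angle}_t) &= \operatorname{round}(100 \cdot \text{angle}_t) / 100 \\
\widehat{\text{angle}}_t &= \text{angle}_t + \operatorname{sg}\big(q(\text{angle}_t) - \text{angle}_t\big) \\
\text{forward: } \widehat{\text{angle}}_t &= q(\text{angle}_t), \qquad \text{backward: } \frac{\partial\, \widehat{\text{angle}}_t}{\partial\, \text{angle}_t} = 1
\end{aligned}
$$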

So the final formula is:
$$
\begin{aligned}
S_0 &= 0 \text{ (in turns, i.e. radians/}2\pi\text{)} \\
S_{t+1} &= (S_t + \text{Round}(\text{angle}_t,\ \text{Num\_digits} = 2)) \bmod 1 \\
\text{out1}_t &= \cos(S_t \cdot 2\pi) \\
\text{out2}_t &= \sin(S_t \cdot 2\pi)
\end{aligned}
$$
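
A plain-Python sketch of this update may help. It assumes the state is measured in turns, so a full revolution is $2\pi$ radians as in the notebook code (`states * 2 * pi`), and it emits the outputs after each update, as the `cumsum` in the model does. The function name and the per-input angle of 0.4987 are hypothetical values chosen to show the rounding at work.

```python
import math

def quantized_phase_outputs(angles, n_digits=2):
    """Accumulate rounded angles (in turns) modulo 1 and emit (cos, sin) after each step."""
    state = 0.0
    outputs = []
    for angle in angles:
        # Rounding to 2 decimals snaps the step onto one of 100 divisions of the circle.
        state = (state + round(angle, n_digits)) % 1
        outputs.append((math.cos(2 * math.pi * state), math.sin(2 * math.pi * state)))
    return outputs

# An imprecise step of 0.4987 for a 1 snaps to exactly 0.50 (a half turn),
# and a step of 0.0 for a 0 leaves the state where it is.
bits = [1, 1, 0, 1, 0]
angles = [0.4987 * b for b in bits]
parity = [1 if cos_value < 0 else 0 for cos_value, _ in quantized_phase_outputs(angles)]
print(parity)  # [1, 0, 0, 1, 1]
```

Because the rounded step is exactly 0.5, no error accumulates from one step to the next, which is the point of the quantization.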
# Implications and Limitations
This architecture provides a flip-flop-like capability for manipulating a state over arbitrarily long contexts. However, it's important to note that the model cannot inspect its current state when deciding whether to flip it, so it can't imitate every type of automaton.

For the parity problem, we only need the current input (action) to determine the state transition, not the current state. This is why our model can solve the parity problem effectively.
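
One way to see the limitation is that state-independent rotations commute, so the final state cannot depend on the order of the inputs. The sketch below contrasts this with a set/reset latch, an example of my own that is not taken from the experiments here, whose result does depend on order.

```python
from itertools import permutations

def rotation_final_state(angles):
    """Final phase (in turns) when every input adds a fixed angle, whatever the state is."""
    return sum(angles) % 1

def latch_final_state(commands, state=0):
    """A set/reset latch: 'set' forces 1 and 'reset' forces 0, regardless of the past."""
    for command in commands:
        state = 1 if command == "set" else 0
    return state

# Rotations commute: every ordering of the same inputs ends in the same state.
angles = [0.5, 0.0, 0.5, 0.5]
rotation_results = {rotation_final_state(p) for p in permutations(angles)}
print(rotation_results)  # {0.5}

# A latch does not: the final state depends on which command came last.
latch_results = {latch_final_state(p) for p in permutations(["set", "reset"])}
print(latch_results)  # {0, 1}
```

Parity is order-independent (only the count of 1s matters), which is why it fits this architecture, while an order-dependent task like the latch does not.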

While I'm not certain about the full range of tasks this architecture can solve, I'm pleased to report that, on the parity problem, it generalizes from training on sequences of length 100 to testing on far longer ones (length 40,000 in the experiment below, with an accuracy of 1.0).

# Imports

In [None]:
import math

import torch
from torch import nn
from torch.nn import functional as F
# Note: `round` here is torch.round and shadows Python's built-in round().
from torch import optim, round, tensor

pi = math.pi

# Model

In [None]:
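class parityModel(nn.Module):
    """Phase-accumulator SSM: each input rotates a state on the unit circle."""

    def __init__(self):
        super().__init__()
        self.n_digits = 2  # round to 2 decimals, i.e. 100 divisions per turn
        # Scalar gain that maps an input bit to a rotation (in turns).
        self.input_multiplier = nn.Parameter(torch.normal(tensor(0.3), tensor(0.1)))
        # Readout weights over the (cos, sin) of the state.
        self.W = nn.Parameter(torch.normal(tensor([0.3]*2), tensor([0.1]*2)))

    def forward(self, x):
        # x: (batch, seq_len) of 0./1. values; delta is exactly 0 wherever x is 0.
        delta = torch.relu(torch.tanh(self.input_multiplier * x))
        # Straight-through estimator: rounded value forward, identity gradient backward.
        quantized_delta = delta + (round(delta, decimals = self.n_digits) - delta).detach()
        # Running sum of rotations, wrapped to a single turn.
        states = torch.cumsum(quantized_delta, dim=-1) % 1
        states = states * 2 * pi  # turns -> radians
        angles = torch.stack([torch.cos(states), torch.sin(states)], dim=-1)
        preds = torch.sigmoid(2*torch.einsum('bsd,d->bs', angles, self.W))
        return preds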

In [None]:
p = parityModel()

In [None]:
out = p(torch.ones((2,15)))
print(out)

tensor([[0.5684, 0.2825, 0.6508, 0.5822, 0.2804, 0.6402, 0.5957, 0.2790, 0.6289,
         0.6086, 0.2784, 0.6169, 0.6209, 0.2785, 0.6044],
        [0.5684, 0.2825, 0.6508, 0.5822, 0.2804, 0.6402, 0.5957, 0.2790, 0.6289,
         0.6086, 0.2784, 0.6169, 0.6209, 0.2785, 0.6044]],
       grad_fn=<SigmoidBackward0>)


# Data Generation

In [None]:
def generate_batch(batch_size, seq_len, percentage_of_zeroes = 0.5):
    """Return random bit sequences X and their running parity Y, both (batch_size, seq_len)."""
    # Each bit is 1 with probability 1 - percentage_of_zeroes.
    X = (torch.rand((batch_size, seq_len)) > percentage_of_zeroes)
    # Running parity is the cumulative count of 1s modulo 2 (same as a running XOR).
    Y = torch.cumsum(X.long(), dim=-1) % 2
    return X.float(), Y.float()

In [None]:
generate_batch(3,5)

(tensor([[1., 1., 0., 0., 0.],
         [1., 0., 1., 0., 0.],
         [0., 1., 0., 1., 0.]]),
 tensor([[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [0., 1., 1., 0., 0.]]))

# Training Loop

In [None]:
def training_loop(num_epochs, model, optimizer, criterion, batch_size, seq_len):
    for i in range(num_epochs):
        optimizer.zero_grad()
        X, Y = generate_batch(batch_size, seq_len)
        preds = model(X)

        # PyTorch losses take (input, target) in that order.
        loss = criterion(preds, Y)

        if i % 100 == 99:
            # Every 100 steps, measure accuracy on a fresh batch without tracking gradients.
            with torch.no_grad():
                X_test, Y_test = generate_batch(batch_size, seq_len)
                test_preds = model(X_test)
                accuracy = ((test_preds > 0.5).float() == Y_test).float().mean()

            print(f'accuracy: {accuracy}, loss: {loss}')

        loss.backward()
        optimizer.step()

In [None]:
p = parityModel()

optimizer = optim.Adam(p.parameters(), lr = 1e-3)

training_loop(10000, p, optimizer, nn.MSELoss(), 64, 100)

accuracy: 0.5284374952316284, loss: 0.25648486614227295
accuracy: 0.5262500047683716, loss: 0.2531069219112396
accuracy: 0.5254687666893005, loss: 0.2507232129573822
accuracy: 0.5310937762260437, loss: 0.250151127576828
accuracy: 0.5298437476158142, loss: 0.250132292509079
accuracy: 0.4921875, loss: 0.24996539950370789
accuracy: 0.5029687285423279, loss: 0.24997840821743011
accuracy: 0.5159375071525574, loss: 0.24989449977874756
accuracy: 0.5107812285423279, loss: 0.2498777210712433
accuracy: 0.4925000071525574, loss: 0.25000011920928955
accuracy: 0.5218750238418579, loss: 0.24993100762367249
accuracy: 0.4868749976158142, loss: 0.2500340938568115
accuracy: 0.514843761920929, loss: 0.24987514317035675
accuracy: 0.5087500214576721, loss: 0.25004783272743225
accuracy: 0.48531249165534973, loss: 0.2500492036342621
accuracy: 0.49031248688697815, loss: 0.2498769909143448
accuracy: 0.48515623807907104, loss: 0.2498980313539505
accuracy: 0.4778124988079071, loss: 0.2500070631504059
accuracy: 0

# Checking Performance on Longer Sequences

In [None]:
X_test, Y_test = generate_batch(64, 40000)
test_preds = p(X_test)

accuracy = ((test_preds > 0.5).float() == Y_test).float().mean()

print(f'accuracy: {accuracy}')

accuracy: 1.0


Checking the delta for an input of 1 (that is, how far an input of 1 pushes us forward on the unit circle): as expected, the result is exactly 0.5, a half turn.

In [None]:
delta = torch.relu(torch.tanh(p.input_multiplier * 1.0))
quantized_delta = delta + (round(delta, decimals = 2) - delta).detach()
print(quantized_delta)

tensor(0.5000, grad_fn=<AddBackward0>)
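
The rounding also means the learned multiplier does not need to hit one precise value. The sketch below mirrors the `relu(tanh(...))` and rounding steps in plain Python to estimate which multipliers give a step of exactly 0.5; the helper name `quantized_delta` and the sample weights 0.55 and 0.60 are illustrative assumptions, not values read from the trained model.

```python
import math

def quantized_delta(weight, x=1.0, n_digits=2):
    """Mirror relu(tanh(weight * x)) followed by rounding to n_digits decimals."""
    delta = max(0.0, math.tanh(weight * x))
    return round(delta, n_digits)

# Any multiplier whose tanh lies within about 0.005 of 0.5 rounds to exactly 0.5,
# so there is a small interval of acceptable weights rather than a single point.
low, high = math.atanh(0.495), math.atanh(0.505)
print(round(low, 4), round(high, 4))

print(quantized_delta(0.55))  # 0.5
print(quantized_delta(0.60))  # 0.54
```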
