### Background
Let us define a **stack** as a subset $S$ of a two-dimensional non-negative integer lattice such that
1. if $(0, j)$ is in $S$ with $j > 0$ then $(0, j-1)$ is also in $S$ and
2. if $(i, j)$ is in $S$ with $i > 0$ then either $(i-1, j)$ or $(i-1, j-1)$ is in $S$ (for $j = 0$ the point $(i-1, -1)$ lies outside the lattice, so $(i-1, 0)$ must be in $S$).
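
The two conditions can be checked mechanically. The following is a minimal sketch, not part of the project code; the name `is_stack` is hypothetical, and it assumes that points outside the lattice count as absent, so a cell $(i, 0)$ with $i > 0$ needs $(i-1, 0)$ beneath it.

```python
def is_stack(cells):
    """Check the two stack conditions for a collection of (row, col) pairs."""
    cells = set(cells)
    for (i, j) in cells:
        if i < 0 or j < 0:
            return False  # not a point of the non-negative lattice
        if i == 0:
            # Condition 1: the bottom row has no gaps to the left.
            if j > 0 and (0, j - 1) not in cells:
                return False
        elif (i - 1, j) not in cells and (i - 1, j - 1) not in cells:
            # Condition 2: a cell above the bottom row needs support
            # directly below or below-left.
            return False
    return True


print(is_stack({(0, 0), (0, 1), (1, 2)}))  # (1, 2) rests on (0, 1)
print(is_stack({(0, 0), (1, 2)}))          # (1, 2) has no support
```

Note that the empty set passes this check trivially; the stacks considered here always contain at least the cell $(0, 0)$.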

Examples can be generated and displayed using the first 3 code cells.

The central fact we are concerned with is that there are exactly $3^{n-1}$ stacks of size $n$.
I became aware of this fact through Peter Kagey, who made this related [StackExchange post](https://math.stackexchange.com/questions/3659431/number-of-ways-to-stack-lego-bricks) citing a 1988 paper by [Gouyou-Beauchamps and Viennot](https://doi.org/10.1016/0196-8858(88)90017-6).
This project is an attempt to have a neural network learn a bijection, ideally a human-intelligible one, between stacks of size $n$ and ternary sequences of length $n-1$.
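
For small $n$ the count $3^{n-1}$ can be confirmed by brute force. The sketch below is an illustration only (the helper names `grow` and `count_stacks` are hypothetical). It grows stacks one cell at a time, using the fact that every stack of size $n$ arises from a stack of size $n-1$ by adding either a new rightmost bottom cell or a cell supported from the row below.

```python
def grow(stack):
    """Return all stacks obtained from `stack` by adding one cell."""
    base_end = max(j for (i, j) in stack if i == 0)
    candidates = {(0, base_end + 1)}          # extend the bottom row
    for (i, j) in stack:
        candidates.add((i + 1, j))            # supported from directly below
        candidates.add((i + 1, j + 1))        # supported from below-left
    return {stack | {c} for c in candidates if c not in stack}


def count_stacks(n):
    """Count stacks with exactly n cells by exhaustive growth."""
    level = {frozenset({(0, 0)})}
    for _ in range(n - 1):
        level = {bigger for s in level for bigger in grow(s)}
    return len(level)


print([count_stacks(n) for n in range(1, 6)])  # → [1, 3, 9, 27, 81]
```

The set of stacks grows exponentially, so this is only practical for small sizes.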


### Stack Generator

First we need code to generate the stacks. The algorithm repeatedly chooses a diagonal $d$ and places a new cell at the greatest position $(i, i+d)$ along that diagonal that still yields a valid stack. Note that placing a cell in diagonal $d_1$ followed by one in $d_2$ may or may not result in the same stack as the reverse order; that is, these additions may or may not "commute". To mitigate the probabilistic bias towards commuting additions, each diagonal whose addition would not commute with the previous addition on $d$ (including $d$ itself) is chosen next with probability $1/3$. If none of these diagonals is chosen, the next diagonal is drawn uniformly from the remaining valid ones. This scheme generates certain extremal stacks, such as the stack with a single row, with the "correct" probability of $3^{-n+1}$.
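
To make the notion of commuting additions concrete, here is a small sketch that mirrors the placement rule of the generator (row $= \max$ of the heights of diagonals $d$ and $d+1$). The function `add_cells` is a hypothetical helper; unlike the generator it does not check that the chosen diagonal is valid, so it should only be called with diagonals adjacent to the occupied ones.

```python
from collections import defaultdict


def add_cells(diagonals):
    """Start from the single cell (0, 0) and add one cell per listed diagonal."""
    heights = defaultdict(int)   # heights[d] = next free row on diagonal d
    heights[0] = 1
    stack = {(0, 0)}
    for d in diagonals:
        row = max(heights[d], heights[d + 1])
        stack.add((row, row + d))
        heights[d] = row + 1
    return stack


# Diagonals -1 and 1 do not interact here, so the order does not matter.
print(add_cells([-1, 1]) == add_cells([1, -1]))
# Diagonals -1 and 0 do interact: the two orders give different stacks.
print(add_cells([-1, 0]) == add_cells([0, -1]))
```

In the second case, adding to diagonal $-1$ first gives the cells $(1,0)$ and $(1,1)$, while adding to diagonal $0$ first gives $(1,1)$ and then $(2,1)$ on top of it.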

One can optionally specify to the generator that stacks be formatted as "unskewed" so that cells $(i,j)$ are "supported" by either cells $(i-1, j)$ or $(i, j-1)$ rather than $(i-1, j)$ or $(i-1, j-1)$. This is more similar to the equivalent definition in the original 1988 paper. This format exhibits the inherent symmetry of the problem and also creates a more compact representation in that we can guarantee $i + j \le n$ for any cell in the stack. I did not find this setting to make a noticeable difference in learning.
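
The change of coordinates is the one used in the generator: a skewed cell $(r, c)$ is sent to $(r - c + b, c)$, where $b$ is the largest column index in the bottom row. The sketch below (with the hypothetical helper name `unskew`) applies it to a small stack, so that the bottom row ends up on the antidiagonal $i + j = b$.

```python
def unskew(stack):
    """Map skewed cells (r, c) to unskewed cells (r - c + b, c)."""
    b = max(c for (r, c) in stack if r == 0)  # last column of the bottom row
    return {(r - c + b, c) for (r, c) in stack}


skewed = {(0, 0), (0, 1), (1, 1), (2, 1), (1, 2)}
unskewed = unskew(skewed)
print(sorted(unskewed))  # → [(0, 1), (0, 2), (1, 0), (1, 1), (2, 1)]

# Every cell off the base antidiagonal is supported from above or from the left.
b = 1
for (i, j) in unskewed:
    if i + j > b:
        assert (i - 1, j) in unskewed or (i, j - 1) in unskewed
```

Under this map the support $(r-1, c)$ becomes $(i-1, j)$ and the support $(r-1, c-1)$ becomes $(i, j-1)$, which is where the symmetry between the two coordinates comes from.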

In [None]:
import torch
import random
import torch.nn.functional as F

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

class StackGenerator(torch.utils.data.IterableDataset):
    """
    A pytorch dataset for generating stacks.
    
    Keyword arguments:
    limit -- the largest size of stack to generate
    exact -- if true, generate only stacks of exactly the given size limit
    seed -- optional random seed
    skew -- if false, format output so that the stack "rows" actually lie along tensor antidiagonals
    """
    
    def __init__(self, limit=40, exact=False, seed=None, skew=True):
        super().__init__()
        self.rand = random.Random(seed)
        self.exact = exact
        self.limit = limit
        self.skew = skew

    def __iter__(self):
        while True:
            stack_set = {(0,0)}
            height = {0: 1, -1: 0, -2: 0, 1: 0, 2: 0}
            min_diag = -1
            max_diag = 1
            dependencies = [0, -1, 1]
            size = 1
            while self.continue_condition(size):
                size += 1
                r = self.rand.randint(0, 2)
                if r < len(dependencies):
                    diag = dependencies[r]
                else:
                    diag = self.rand.randint(min_diag, max_diag)
                    while diag in dependencies:
                        diag = self.rand.randint(min_diag, max_diag)
                row = max(height[diag], height[diag+1])
                col = diag + row
                stack_set.add((row,col))
                height[diag] = row + 1
                min_diag = min(diag - 1, min_diag)
                max_diag = max(diag + 1, max_diag)
                if min_diag - 1 not in height:
                    height[min_diag - 1] = 0
                if max_diag + 1 not in height:
                    height[max_diag + 1] = 0
                dependencies = [diag]
                if height[diag - 1] < height[diag]:
                    dependencies.append(diag-1)
                if max(height[diag + 1], height[diag + 2]) + 1 > height[diag]:
                    dependencies.append(diag+1)
            
            # Flip the stack left to right with probability 1/2
            base_length = max([col for (row, col) in stack_set if row == 0])
            if self.rand.random() < 1/2:
                stack_set = {(row, base_length + row - col) for (row, col) in stack_set}

            stack = torch.zeros([size, size], dtype=torch.bool)
            if self.skew:
                for row, col in stack_set:
                    stack[row, col] = True
            else:
                for row, col in stack_set:
                    stack[row - col + base_length, col] = True
            yield stack

    def continue_condition(self, count):
        """Determines whether to add additional cells."""
        if count >= self.limit:
            return False
        return self.exact or self.rand.random() < 1 - (count - 1) / count / 4
        
def collate_fn(data):
    max_size = max(len(item) for item in data)
    data = [F.pad(item, (0, max_size-len(item), 0, max_size-len(item)))[None] for item in data]
    return torch.cat(data, 0)

def worker_init(multiplier):
    def fn(worker_id):
        torch.utils.data.get_worker_info().dataset.rand.seed(multiplier*(worker_id+1))
    return fn

In [None]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(StackGenerator(limit=40), batch_size=1, num_workers=0, collate_fn=collate_fn, worker_init_fn=worker_init(571))
it = iter(train_dataloader)

In [None]:
import matplotlib.pyplot as plt

plt.imshow(next(it)[0])

### Architecture

In order to make a bijection representable as a neural network, we use the overall structure of a [variational autoencoder](https://en.wikipedia.org/wiki/Variational_autoencoder). That is, we have an **encoder** neural network that maps the input stack representation of size $n$ into a latent space of $n-1$ random samples drawn from [relaxed onehot categorical](https://pytorch.org/docs/stable/distributions.html#relaxedonehotcategorical) distributions, and a **decoder** neural network that maps the latent space back to a two-dimensional grid space. The **autoencoder** is the composition of these two modules. If the autoencoder achieves effectively perfect reconstruction while the latent distributions approach non-random values, then the encoder and decoder compute bijections between stacks and ternary sequences.
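
For intuition about the latent space, the following sketch draws one relaxed one-hot sample in plain Python using the Gumbel-softmax construction, which is the sampling scheme behind `RelaxedOneHotCategorical`. The function name `relaxed_one_hot_sample` and the example logits are illustrative assumptions, not part of the model code.

```python
import math
import random


def relaxed_one_hot_sample(logits, temperature, rng):
    """Return softmax((logits + Gumbel noise) / temperature) as a list."""
    # Gumbel(0, 1) noise; the max() guards against log(0).
    gumbels = [-math.log(-math.log(max(rng.random(), 1e-12))) for _ in logits]
    scores = [(l + g) / temperature for l, g in zip(logits, gumbels)]
    top = max(scores)                       # subtract the max for stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


rng = random.Random(0)
sample = relaxed_one_hot_sample([0.0, 1.0, -1.0], temperature=2 / 3, rng=rng)
print(sample, sum(sample))
```

Each sample is a point on the probability simplex over the three symbols. As the temperature decreases, or as one logit dominates the others, samples concentrate near a one-hot vector, which is the regime in which the latent code reads as a ternary sequence.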

In theory, achieving perfect reconstruction requires perfectly non-random distributions, so the distributions need not be explicitly penalized in the loss function; we can simply use a measure of reconstruction error. Alternatively, the autoencoder can be initialized with the `use_ELBO` flag set to `True`, in which case it returns the negated, size-normalized [evidence lower bound](https://en.wikipedia.org/wiki/Evidence_lower_bound) during training, which is intended to be used directly as the loss function.
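
Written out, the quantity returned in the `use_ELBO` case appears to be the following single-sample estimate, as read off from the `AutoEncoder.forward` code. The symbols $B$ (batch size), $n_b$ (size of the $b$-th stack), $q$ (encoder posterior) and $p$ (prior and decoder likelihood) are notation introduced here for illustration.

$$
\mathcal{L} = -\frac{1}{B} \sum_{b=1}^{B} \frac{1}{n_b} \Big[ \log p(z_b) + \log p(x_b \mid z_b) - \log q(z_b \mid x_b) \Big], \qquad z_b \sim q(\,\cdot \mid x_b)
$$

Here $p(z)$ is a relaxed one-hot categorical prior with uniform logits, $p(x \mid z)$ is a product of Bernoulli distributions over grid cells, and the prior and posterior terms are summed only over the $n_b - 1$ unmasked latent positions.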

Internally, the encoder repeatedly applies the same residual block (i.e. with the same kernels) $n-1$ times, each time decreasing the grid size by 1. The vector in the $(0,0)$ corner of the grid is collected after each application, and these $n-1$ vectors are then transformed into the distribution parameters. The decoder is very similar but in reverse. Unfortunately this makes the modules quite serial, but it ensures that the model scales to any input size $n$ without additional parameters, and does so homogeneously. Heuristically one might expect that this leads to a model that is more human-intelligible and *logically* scalable. However, this may be a premise worth revisiting in the hope that a denser, more parallel, but less homogeneous architecture can more easily find a solution that is still logically scalable. Alternatively, the repeated portions of the modules could be made deeper, but this would slow training to even more of a crawl.
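
The size bookkeeping follows from the usual stride-1 convolution formula. The sketch below traces grid sizes only, using the kernel size and padding of the repeated blocks (kernel 2 with padding 0 in the encoder, kernel 2 with padding 1 in the decoder); the helper names are hypothetical and no network is involved.

```python
def conv_output_size(size, kernel_size, padding):
    """Spatial size after a stride-1 convolution."""
    return size + 2 * padding - kernel_size + 1


def trace_sizes(start, steps, kernel_size, padding):
    """Grid sizes seen while applying the same convolution `steps` times."""
    sizes = [start]
    for _ in range(steps):
        sizes.append(conv_output_size(sizes[-1], kernel_size, padding))
    return sizes


n = 5
print(trace_sizes(n, n - 1, kernel_size=2, padding=0))  # → [5, 4, 3, 2, 1]
print(trace_sizes(1, n - 1, kernel_size=2, padding=1))  # → [1, 2, 3, 4, 5]
```

A grid of side $n$ therefore reaches $1 \times 1$ after $n-1$ size-reducing steps, and the decoder rebuilds an $n \times n$ grid from a $1 \times 1$ seed in the same number of steps, with a parameter count that does not depend on $n$.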

In [None]:
import torch.nn.functional as F
from torch import nn
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical
from torch.distributions.bernoulli import Bernoulli
import math
from mamba_ssm import Mamba

class BidirectionalMamba(nn.Module):
    """Splits input channel-wise and applies Mamba module in both directions."""
    
    def __init__(self, **kwargs):
        super().__init__()
        self.mamba = Mamba(**kwargs)

    def forward(self, x):
        (x1, x2) = torch.split(x, x.size(1) // 2, dim=1)
        x1 = torch.flip(x1, [2])
        x = torch.cat((x1, x2), dim=1)
        x = self.mamba(x.movedim(1,2)).movedim(2,1)
        (x1, x2) = torch.split(x, x.size(1) // 2, dim=1)
        x1 = torch.flip(x1, [2])
        x = torch.cat((x1, x2), dim=1)
        return x

class DepthwiseResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, internal_dim, kernel_size=3, padding=1, glu=False, prenormalize=True):
        super().__init__()
        self.norm = nn.LayerNorm(in_channels) if prenormalize else lambda x: x
        self.proj_up = nn.Conv2d(in_channels, internal_dim, 1)
        self.depth_conv = nn.Conv2d(internal_dim, internal_dim, kernel_size, groups=internal_dim, padding=padding)
        self.gate = nn.GLU(dim=1) if glu else nn.SiLU()
        self.proj_down = nn.Conv2d(internal_dim // (1 + glu), out_channels, 1)
        self.id_proj = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
        size_change = 2 * padding - kernel_size + 1
        if size_change == 0:
            self.resize = lambda x: x
        elif size_change < 0:
            self.resize = lambda x: x[:, :, :size_change, :size_change]
        else:
            self.resize = nn.ZeroPad2d((0, size_change, 0, size_change))

    def forward(self, x):
        res = self.norm(x.movedim(1, 3)).movedim(3, 1)
        res = self.proj_up(res)
        res = F.silu(res)
        res = self.depth_conv(res)
        res = self.gate(res)
        res = self.proj_down(res)
        return self.resize(self.id_proj(x)) + res

class Encoder(nn.Module):
    def __init__(self, use_ELBO=False):
        super().__init__()
        self.state_size = 8
        self.internal_dim = 64
        
        self.conv_block = DepthwiseResidualBlock(1, self.state_size, 2 * self.state_size, prenormalize=False)

        self.repeat_block = DepthwiseResidualBlock(self.state_size, self.state_size, self.internal_dim, kernel_size=2, padding=0, glu=True)

        self.flatnorm = nn.LayerNorm(self.state_size)
        self.mamba = BidirectionalMamba(d_model=self.state_size, d_conv=3)
        self.batchnorm = nn.BatchNorm1d(self.state_size)
        self.shorten = nn.Conv1d(self.state_size, 3, 1)
        self.temperature = nn.parameter.Parameter(torch.tensor([2/3]), requires_grad=False)
        self.use_ELBO = use_ELBO
    
    def forward(self, x):
        num_bricks = x.int().sum((1,2))
        lower = torch.ones([x.size(1), x.size(1)], dtype=torch.bool, device=device).tril()
        index = (num_bricks - 1).unsqueeze(1).expand(-1, x.size(1))
        mask = torch.gather(lower, 0, index).unsqueeze(1)[:, :, 1:]

        x = self.conv_block(x.unsqueeze(1))

        logits = torch.zeros((x.size(0), self.state_size, x.size(2)-1), device=device)
        for i in range(x.size(2)-1):
            x = self.repeat_block(x)
            logits[:, :, i] = x[:, :, 0, 0]
        
        logits = logits + self.mamba(self.flatnorm(logits.transpose(1, 2)).transpose(1, 2))
        logits = self.batchnorm(logits)
        logits = self.shorten(logits)

        if self.training:
            posterior = RelaxedOneHotCategorical(self.temperature, logits=logits.transpose(1, 2))
            if self.use_ELBO:
                return posterior, mask
            out = posterior.rsample()
        else:
            out = torch.argmax(logits, dim=1)
            out = F.one_hot(out, num_classes=3).float()
        out = out.transpose(1, 2)
        return out * mask

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.state_size = 8
        self.internal_dim = 64
        
        self.lengthen = nn.Conv1d(3, self.state_size, 1)
        self.batchnorm = nn.BatchNorm1d(self.state_size)
        self.mamba = BidirectionalMamba(d_model=self.state_size, d_conv=3)

        self.initial = nn.parameter.Parameter(torch.rand(1, self.state_size, 1, 1))

        self.repeat_block = DepthwiseResidualBlock(2 * self.state_size, self.state_size, self.internal_dim, kernel_size=2, padding=1, glu=True)
        
        self.conv_block = DepthwiseResidualBlock(self.state_size, 1, 2 * self.state_size)

    def forward(self, x):
        num_bricks = torch.max(x != 0, dim=1).values.sum(1) + 1
        lower = torch.ones([x.size(2) + 1, x.size(2) + 1], dtype=torch.bool, device=device).tril()
        index = (num_bricks - 1).unsqueeze(1).expand(-1, x.size(2) + 1)
        mask = torch.gather(lower, 0, index)
        square_mask = torch.einsum('bh,bw->bhw', mask, mask).unsqueeze(1)
        
        x = self.lengthen(x)
        x = x + self.mamba(self.batchnorm(x)) * mask.unsqueeze(1)[:,:,1:]

        out = self.initial.expand(x.size(0), self.state_size, 1, 1)
        for i in reversed(range(x.size(2))):
            out = F.pad(out, (0, 0, 0, 0, self.state_size, 0))
            out[:, :self.state_size, :, :] = x[:, :, i, None, None]
            out = self.repeat_block(out)

        out = self.conv_block(out)
        
        out = torch.where(square_mask, out, torch.finfo().min)
        out = out.squeeze(1)
        if self.training:
            return out
        return F.sigmoid(out)

class AutoEncoder(nn.Module):
    def __init__(self, use_ELBO=False):
        super().__init__()
        self.enc = Encoder(use_ELBO=use_ELBO)
        self.dec = Decoder()
        self.prior_temp = nn.parameter.Parameter(torch.tensor([1/10]), requires_grad=False)
        self.prior_logits = nn.parameter.Parameter(torch.tensor([0.,0.,0.]), requires_grad=False)
        self.use_ELBO = use_ELBO

    def forward(self, x):
        if self.training and self.use_ELBO:
            posterior, mask = self.enc(x)
            code = posterior.rsample()
            likelihood_logits = self.dec(code.transpose(1, 2) * mask)
            likelihood = Bernoulli(logits=likelihood_logits)
            prior = RelaxedOneHotCategorical(self.prior_temp, logits=self.prior_logits)

            log_prior = torch.sum(prior.log_prob(code) * mask[:,0,:], 1)
            log_likelihood = torch.sum(likelihood.log_prob(x), (1, 2))
            log_posterior = torch.sum(posterior.log_prob(code) * mask[:,0,:], 1)
            unnormalized_ELBO = log_prior + log_likelihood - log_posterior
            return -(unnormalized_ELBO / x.sum((1,2))).sum() / len(x)
        
        code = self.enc(x)
        return code, self.dec(code)

### Training

There isn't too much to say about the training procedure. As stated above, we can use either the evidence lower bound as a loss function, in which case the training loop below must be slightly modified, or a measure of how well the input is reconstructed. For training, this measure is the binary cross entropy, and for testing we count the number of grid cells that round to the incorrect value. In both cases we normalize by the input size $n$.
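
As a worked example of the test metric, the sketch below computes the normalized error for a single $3 \times 3$ grid in plain Python. The name `normalized_error` and the example values are illustrative assumptions; the notebook's own version operates on batched tensors.

```python
def normalized_error(target, reconstruction):
    """Number of cells that round to the wrong value, divided by stack size."""
    n = sum(sum(row) for row in target)
    wrong = sum(
        round(abs(t - r))
        for t_row, r_row in zip(target, reconstruction)
        for t, r in zip(t_row, r_row)
    )
    return wrong / n


target = [[1, 1, 0],
          [0, 1, 0],
          [0, 0, 0]]
reconstruction = [[0.9, 0.8, 0.1],
                  [0.2, 0.3, 0.0],
                  [0.0, 0.7, 0.1]]
print(normalized_error(target, reconstruction))
```

Two cells round to the wrong value here (a missed cell at $(1,1)$ and a spurious one at $(2,1)$), and the stack has three cells, so the error is $2/3$. Because the normalization is by stack size rather than grid area, the error can exceed 1.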

If we are using reconstruction loss rather than the evidence lower bound, then we can use Newton's method as a root-finding algorithm, since we want to achieve **zero** reconstruction loss. Unfortunately, this has not yielded much success.
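
The update in question moves the parameters along the gradient by $\text{loss} / \lVert \nabla \text{loss} \rVert^2$, which is the minimal-norm Newton step for a scalar function of many variables. Here is a minimal sketch on a toy quadratic loss, assuming the same damping constant $10^{-4}$ as the optimizer class in the code cell; the toy loss and the helper names are hypothetical.

```python
def loss_and_grad(theta):
    """Toy loss with its only root at (1, -2): squared distance to that point."""
    dx, dy = theta[0] - 1.0, theta[1] + 2.0
    return dx * dx + dy * dy, [2.0 * dx, 2.0 * dy]


def newton_step(theta, eps=1e-4):
    """theta <- theta - loss * grad / (||grad||^2 + eps)."""
    loss, grad = loss_and_grad(theta)
    scale = loss / (sum(g * g for g in grad) + eps)
    return [t - scale * g for t, g in zip(theta, grad)]


theta = [0.0, 0.0]
losses = []
for _ in range(20):
    losses.append(loss_and_grad(theta)[0])
    theta = newton_step(theta)
print(losses[0], losses[-1])
```

On this toy problem each step roughly halves the distance to the root, so convergence is only linear: the root of a non-negative loss is also a minimum, where the gradient vanishes. Once the gradient norm becomes comparable to the damping constant, progress slows further.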

In [None]:
import torch.nn.functional as F
from torch.distributions.dirichlet import Dirichlet
import matplotlib.pyplot as plt
import time
import einops

def reconstruction_loss(X, out):
    unnormalized = F.binary_cross_entropy_with_logits(out, X, reduction='none').sum((1,2))
    return (unnormalized / X.sum((1,2))).sum() / len(X)

def reconstruction_error(X, reconstruction):
    unnormalized = torch.round(torch.abs(X - reconstruction)).sum((1,2))
    return (unnormalized / X.sum((1,2))).sum() / len(X)

def train_loop(data_iter, model, num_batches=1000, reports=5):
    model.train()
    batches_per_report = max(1, num_batches // reports)  # at least 1, so the modulo below is safe
    for i in range(num_batches):
        X = next(data_iter).to(device, dtype=torch.float)
        # With use_ELBO=True the model returns the loss directly:
        # use `loss = model(X)` in place of the next two lines.
        _, out = model(X)
        loss = reconstruction_loss(X, out)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if i % batches_per_report == 0:
            print(f"Loss: {loss.item():>10f}  [{(1 + i)*len(X)} samples]")

def test_loop(data_iter, model, num_batches=100):
    model.eval()
    recon_error = 0
    with torch.no_grad():
        for _ in range(num_batches):
            X = next(data_iter).to(device, dtype=torch.float)
            code, reconstruction = model(X)
            recon_error += reconstruction_error(X, reconstruction)
    recon_error /= num_batches
    print(f"Average Error: {recon_error:>10f} \n")
    return recon_error

class NewtonOptimizer(torch.optim.Optimizer): 

    def __init__(self, params): 
        super().__init__(params, defaults={}) 

    def step(self, loss):
        for group in self.param_groups:
            params = [p for p in group['params'] if p.requires_grad]
            grad, ps = einops.pack([p.grad.data for p in params], '*')
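            # Newton step for a scalar root-finding problem: move against the gradient
            # by loss / ||grad||^2 (the 1e-4 guards against a vanishing gradient).
            direction = -loss * grad / (torch.linalg.vector_norm(grad) ** 2 + 1e-4)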
            direction = einops.unpack(direction, ps, '*')
            for i, p in enumerate(params): 
                p.data += direction[i]

The current settings have managed to achieve a test error of around 0.16.

In [None]:
from torch.utils.data import DataLoader
import time
model = AutoEncoder(use_ELBO=False)
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# NewtonOptimizer.step takes the loss, so train_loop must call optimizer.step(loss) when using it.
#optimizer = NewtonOptimizer(model.parameters())
dataloader = DataLoader(StackGenerator(limit=40), batch_size=64, num_workers=16, collate_fn=collate_fn, worker_init_fn=worker_init(51))
data_iter = iter(dataloader)
for i in range(40):
    t = time.time()
    train_loop(data_iter, model, num_batches=1000)
    print(f"Time Elapsed: {time.time()-t:>8f}")
    test_error = test_loop(data_iter, model, num_batches=100)


In [None]:
model.dec.eval()
with torch.no_grad():
    out = model.dec(torch.tensor([[[1,0,0],[0,0,1],[0,0,1],[0,1,0],[0,0,1]]], device=device).float().movedim(2,1)).cpu()
    plt.imshow(out.squeeze())