In [25]:
import torch
from math import log, ceil


In [26]:
batch_size, seq_len, classes = 2, 4, 15

In [27]:
batch = []
for _ in range(batch_size):
    seq = torch.randint(0, classes, (seq_len,))
    batch.append({'x': seq})

In [28]:
x = [item['x'] for item in batch]

In [29]:
min_length = min(seq.shape[0] for seq in x)

In [30]:
x = [tensor[:min_length] for tensor in x]

In [31]:
x

[tensor([13, 13,  0,  9]), tensor([12,  4, 14,  8])]

In [32]:
base = 3
assert base >= 2, "Base must be at least 2"

In [33]:
dims = ceil(log((-(1-base) * classes - 1), base))

In [34]:
powers = base ** torch.arange(dims - 1, -1, -1)

In [35]:
powers

tensor([27,  9,  3,  1])

In [36]:
values = torch.stack(x)

In [37]:
values

tensor([[13, 13,  0,  9],
        [12,  4, 14,  8]])

In [38]:
values = values.view(-1, 1)

In [39]:
flat_rep = (values.view(-1, 1) // powers) % base

In [40]:
flat_rep.shape

torch.Size([8, 4])

In [41]:
flat_rep

tensor([[0, 1, 1, 1],
        [0, 1, 1, 1],
        [0, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 1, 1, 0],
        [0, 0, 1, 1],
        [0, 1, 1, 2],
        [0, 0, 2, 2]])

In [42]:
from torch.nn import functional as F
from einops import rearrange

In [43]:
vector_rep = F.one_hot(flat_rep, num_classes=base)

In [44]:
vector_rep.shape

torch.Size([8, 4, 3])

In [45]:
# we are able to encode the original classes sequence into a different base "vocab-size" by adding an additional encoding
# sequence dimension. However, we only want the shape (seq_len, base), so that the overall shape is still (batch, seq_len, base)
# as that is what our models expect.
vector_rep = rearrange(vector_rep, '(batch_size classes_seq_len) encoding_seq_len base_vocab_size -> batch_size (classes_seq_len encoding_seq_len) base_vocab_size', batch_size=batch_size)

In [46]:
vector_rep.shape

torch.Size([2, 16, 3])

In [47]:
vector_rep

tensor([[[1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 1, 0],
         [1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 1, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [0, 1, 0],
         [1, 0, 0],
         [1, 0, 0]],

        [[1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [1, 0, 0],
         [1, 0, 0],
         [1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [1, 0, 0],
         [0, 1, 0],
         [0, 1, 0],
         [0, 0, 1],
         [1, 0, 0],
         [1, 0, 0],
         [0, 0, 1],
         [0, 0, 1]]])

# Inverse Transformation
Now we will perform the inverse transformation to get `x` back from `vector_rep`.

In [48]:
# Reshape back to the state before the dimension merge
unarranged_vector_rep = rearrange(vector_rep, 'batch_size (classes_seq_len encoding_seq_len) base_vocab_size -> (batch_size classes_seq_len) encoding_seq_len base_vocab_size', classes_seq_len=seq_len, encoding_seq_len=dims)

# Get the indices from the one-hot encoding
flat_rep_inverse = torch.argmax(unarranged_vector_rep, dim=-1)

# Multiply by powers and sum to get the original values
values_flat_inverse = (flat_rep_inverse * powers).sum(dim=1)

# Reshape to the original batch structure
x_inverse = values_flat_inverse.view(batch_size, -1)

# Verify that the inverse transformation is correct
assert torch.equal(torch.stack(x), x_inverse)

print("Inverse transformation successful!")
print("Original x:\n", torch.stack(x))
print("Reconstructed x:\n", x_inverse)

Inverse transformation successful!
Original x:
 tensor([[13, 13,  0,  9],
        [12,  4, 14,  8]])
Reconstructed x:
 tensor([[13, 13,  0,  9],
        [12,  4, 14,  8]])
