# Attention-based Autoregressive Structure Generation

This notebook trains a **causal Transformer** (self-attention) that autoregressively generates voxel structures of variable sizes.

## Key idea
We serialize each structure as:
- `<BOS>`
- `SIZE_X_n`, `SIZE_Y_n`, `SIZE_Z_n`
- flattened block-ID tokens in XYZ scan order
- `<EOS>`

At generation time, the model predicts size tokens and then voxel tokens, allowing different output sizes.

In [1]:
from pathlib import Path
from collections import Counter
import random

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from blockgen.utils.data_loader import load_schematic
from blockgen.utils.data import Structure
from blockgen.renderer.render import render_schem
from blockgen.models import VoxelTransformerAR

SEED = 1337
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)
print('torch:', torch.__version__)

device: cuda
torch: 2.10.0+cu128


## 1) Load + preprocess structures

In [14]:
repo_root = Path.cwd()
if not (repo_root / 'data' / 'raw').exists():
    repo_root = repo_root.parent

raw_dir = repo_root / 'data' / 'raw'
all_paths = sorted(raw_dir.glob('*.schematic'))

num_examples = 1000
max_dim = 12
max_size_token = 12
air_block_id = 0

paths = all_paths[:num_examples]
print('schematics used:', len(paths))


def preprocess_structure(path, max_dim=12):
    schem = load_schematic(str(path))
    st = Structure.from_schematic(schem, source_path=str(path))
    st = st.crop_to_non_air()
    st = st.downsample(max_dim=max_dim)
    sx, sy, sz = st.shape

    # Clip to representable size tokens
    sx = min(sx, max_size_token)
    sy = min(sy, max_size_token)
    sz = min(sz, max_size_token)

    ids = st.block_ids[:sx, :sy, :sz].astype(np.int32)
    return ids


grids = [preprocess_structure(path, max_dim=max_dim) for path in paths]
print('example shape:', grids[0].shape)

schematics used: 1000
example shape: (11, 4, 10)


## 2) Build tokenizer with size tokens + block tokens

In [15]:
special_tokens = ['<PAD>', '<BOS>', '<EOS>', '<UNK>']
size_tokens = []
for n in range(1, max_size_token + 1):
    size_tokens.extend([f'SIZE_X_{n}', f'SIZE_Y_{n}', f'SIZE_Z_{n}'])

block_counter = Counter()
for g in grids:
    block_counter.update(g.reshape(-1).tolist())

observed_block_ids = sorted(int(x) for x in block_counter.keys())
block_tokens = [f'BID_{bid}' for bid in observed_block_ids]

all_tokens = special_tokens + size_tokens + block_tokens
stoi = {tok: i for i, tok in enumerate(all_tokens)}
itos = {i: tok for tok, i in stoi.items()}

pad_idx = stoi['<PAD>']
bos_idx = stoi['<BOS>']
eos_idx = stoi['<EOS>']
unk_idx = stoi['<UNK>']

air_token = stoi.get('BID_0', unk_idx)
vocab_size = len(stoi)

print('vocab size:', vocab_size)
print('block token count:', len(block_tokens))
print('air token:', air_token)

vocab size: 250
block token count: 210
air token: 129


In [None]:
def size_to_tokens(shape):
    sx, sy, sz = shape
    return [stoi[f'SIZE_X_{sx}'], stoi[f'SIZE_Y_{sy}'], stoi[f'SIZE_Z_{sz}']]


def block_id_to_token(block_id):
    tok = f'BID_{int(block_id)}'
    return stoi.get(tok, unk_idx)


def token_to_block_id(token_idx):
    tok = itos.get(int(token_idx), '<UNK>')
    if isinstance(tok, str) and tok.startswith('BID_'):
        return int(tok.split('_', 1)[1])
    return 0


def serialize_grid(grid):
    sx, sy, sz = grid.shape
    seq = [bos_idx]
    seq.extend(size_to_tokens((sx, sy, sz)))
    seq.extend([block_id_to_token(x) for x in grid.reshape(-1).tolist()])
    seq.append(eos_idx)
    return seq


sequences = [serialize_grid(grid) for grid in grids]
max_seq_len = max(len(s) for s in sequences)
print('max sequence length:', max_seq_len)

max sequence length: 1733


In [16]:
perm = np.random.permutation(len(sequences))
n_train = int(0.85 * len(sequences))
train_ids = perm[:n_train]
test_ids = perm[n_train:]

train_sequences = [sequences[i] for i in train_ids]
test_sequences = [sequences[i] for i in test_ids]

train_hashes = {tuple(seq) for seq in train_sequences}


class ARTokenDataset(Dataset):
    def __init__(self, seqs):
        self.seqs = seqs

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, idx):
        seq = torch.tensor(self.seqs[idx], dtype=torch.long)
        inp = seq[:-1]
        tgt = seq[1:]
        return inp, tgt


def collate_batch(batch):
    xs, ys = zip(*batch)
    max_len = max(x.shape[0] for x in xs)

    x_pad = torch.full((len(xs), max_len), pad_idx, dtype=torch.long)
    y_pad = torch.full((len(xs), max_len), pad_idx, dtype=torch.long)
    pad_mask = torch.ones((len(xs), max_len), dtype=torch.bool)

    for i, (x, y) in enumerate(zip(xs, ys)):
        n = x.shape[0]
        x_pad[i, :n] = x
        y_pad[i, :n] = y
        pad_mask[i, :n] = False

    return x_pad, y_pad, pad_mask


train_loader = DataLoader(ARTokenDataset(train_sequences), batch_size=12, shuffle=True, collate_fn=collate_batch)
test_loader = DataLoader(ARTokenDataset(test_sequences), batch_size=12, shuffle=False, collate_fn=collate_batch)

print('train examples:', len(train_sequences), 'test examples:', len(test_sequences))

train examples: 187 test examples: 33


## 3) Define and train attention model

In [13]:
model = VoxelTransformerAR(
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    d_model=192,
    nhead=6,
    num_layers=6,
    dim_feedforward=768,
    dropout=0.1,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)


def token_loss(logits, targets, ignore_index):
    b, t, v = logits.shape
    return F.cross_entropy(logits.reshape(b * t, v), targets.reshape(b * t), ignore_index=ignore_index)


@torch.no_grad()
def evaluate(loader):
    model.eval()
    losses = []
    for x, y, pad_mask in loader:
        x, y, pad_mask = x.to(device), y.to(device), pad_mask.to(device)
        logits = model(x, pad_mask=pad_mask)
        loss = token_loss(logits, y, ignore_index=pad_idx)
        losses.append(float(loss.item()))

    mean_loss = float(np.mean(losses)) if losses else 0.0
    ppl = float(np.exp(mean_loss))
    return mean_loss, ppl


epochs = 8
for epoch in range(1, epochs + 1):
    model.train()
    train_losses = []

    for x, y, pad_mask in train_loader:
        x, y, pad_mask = x.to(device), y.to(device), pad_mask.to(device)

        logits = model(x, pad_mask=pad_mask)
        loss = token_loss(logits, y, ignore_index=pad_idx)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_losses.append(float(loss.item()))

    train_loss = float(np.mean(train_losses))
    val_loss, val_ppl = evaluate(test_loader)
    print(f'epoch {epoch:02d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_ppl={val_ppl:.2f}')

epoch 01 | train_loss=2.3029 | val_loss=1.5088 | val_ppl=4.52
epoch 02 | train_loss=1.4593 | val_loss=1.2139 | val_ppl=3.37
epoch 03 | train_loss=1.2927 | val_loss=1.1099 | val_ppl=3.03
epoch 04 | train_loss=1.1625 | val_loss=1.0407 | val_ppl=2.83
epoch 05 | train_loss=1.0860 | val_loss=0.9894 | val_ppl=2.69
epoch 06 | train_loss=1.0234 | val_loss=0.9548 | val_ppl=2.60
epoch 07 | train_loss=1.0110 | val_loss=0.9326 | val_ppl=2.54
epoch 08 | train_loss=0.9759 | val_loss=0.9185 | val_ppl=2.51


## 4) Autoregressive sampling with variable sizes

In [19]:
def sample_next_token(logits, temperature=1.0, top_k=32):
    logits = logits / max(temperature, 1e-6)
    if top_k is not None and top_k > 0 and top_k < logits.shape[-1]:
        vals, idxs = torch.topk(logits, k=top_k, dim=-1)
        probs = torch.softmax(vals, dim=-1)
        pick = torch.multinomial(probs, num_samples=1)
        return idxs.gather(-1, pick)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def parse_size_token(token_idx, axis):
    tok = itos.get(int(token_idx), '')
    prefix = f'SIZE_{axis}_'
    if isinstance(tok, str) and tok.startswith(prefix):
        return int(tok[len(prefix):])
    return None


@torch.no_grad()
def generate_structure_tokens(model, temperature=0.95, top_k=32):
    model.eval()
    seq = [bos_idx]

    # 1) sample size tokens
    for axis in ['X', 'Y', 'Z']:
        x = torch.tensor([seq], dtype=torch.long, device=device)
        logits = model(x)[:, -1, :]

        # Restrict to valid size tokens for the requested axis
        valid = [stoi[f'SIZE_{axis}_{n}'] for n in range(1, max_size_token + 1)]
        restricted = torch.full_like(logits, -1e9)
        restricted[:, valid] = logits[:, valid]

        tok = int(sample_next_token(restricted, temperature=temperature, top_k=min(top_k, len(valid))).item())
        seq.append(tok)

    sx = parse_size_token(seq[1], 'X') or 1
    sy = parse_size_token(seq[2], 'Y') or 1
    sz = parse_size_token(seq[3], 'Z') or 1

    n_voxels = sx * sy * sz

    # 2) sample voxel tokens
    block_vocab = [stoi[f'BID_{bid}'] for bid in observed_block_ids if f'BID_{bid}' in stoi]

    for _ in range(n_voxels):
        x = torch.tensor([seq], dtype=torch.long, device=device)
        logits = model(x)[:, -1, :]

        restricted = torch.full_like(logits, -1e9)
        restricted[:, block_vocab] = logits[:, block_vocab]
        tok = int(sample_next_token(restricted, temperature=temperature, top_k=min(top_k, len(block_vocab))).item())
        seq.append(tok)

    seq.append(eos_idx)
    return seq


def decode_tokens_to_structure(seq):
    sx = parse_size_token(seq[1], 'X') or 1
    sy = parse_size_token(seq[2], 'Y') or 1
    sz = parse_size_token(seq[3], 'Z') or 1

    voxel_tokens = seq[4:4 + sx * sy * sz]
    block_ids = np.array([token_to_block_id(t) for t in voxel_tokens], dtype=np.int32).reshape((sx, sy, sz))
    block_data = np.zeros_like(block_ids, dtype=np.int32)
    return Structure(block_ids=block_ids, block_data=block_data, source_path='generated')

In [20]:
generated_sequences = []
generated_structures = []

target_samples = 4
attempts = 0
max_attempts = 40

while len(generated_sequences) < target_samples and attempts < max_attempts:
    attempts += 1
    seq = generate_structure_tokens(model, temperature=1.0, top_k=24)

    # novelty filter: skip exact training sequences
    if tuple(seq) in train_hashes:
        continue

    generated_sequences.append(seq)
    generated_structures.append(decode_tokens_to_structure(seq))

print('generated novel samples:', len(generated_structures), 'in attempts:', attempts)

fig = plt.figure(figsize=(16, 8))
for i, st in enumerate(generated_structures, start=1):
    ax = fig.add_subplot(2, 2, i, projection='3d')
    render_schem(st, ax=ax, show=False, crop_non_air=True, max_dim=12)
    ax.set_title(f'Generated {i} | shape={st.shape}')

plt.tight_layout()
plt.show()

AcceleratorError: CUDA error: device-side assert triggered
Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


## 5) Compare generated distribution to training distribution

In [None]:
real_counter = Counter()
for seq in train_sequences:
    for tok in seq:
        if isinstance(tok, int) and itos.get(tok, '').startswith('BID_'):
            real_counter[tok] += 1

gen_counter = Counter()
for seq in generated_sequences:
    for tok in seq:
        if isinstance(tok, int) and itos.get(tok, '').startswith('BID_'):
            gen_counter[tok] += 1

keys = sorted(set(real_counter.keys()) | set(gen_counter.keys()))
eps = 1e-9

real = np.array([real_counter.get(k, 0) for k in keys], dtype=np.float64)
gen = np.array([gen_counter.get(k, 0) for k in keys], dtype=np.float64)

real = (real + eps) / (real.sum() + eps * len(real))
gen = (gen + eps) / (gen.sum() + eps * len(gen))

kl_real_gen = float(np.sum(real * np.log(real / gen)))
print('KL(real || generated):', round(kl_real_gen, 4))

top_real = [k for k, _ in real_counter.most_common(12)]
labels = [itos[k] for k in top_real]
rv = [real_counter[k] for k in top_real]
gv = [gen_counter.get(k, 0) for k in top_real]

x = np.arange(len(labels))
w = 0.42
plt.figure(figsize=(12, 4))
plt.bar(x - w / 2, rv, width=w, label='real')
plt.bar(x + w / 2, gv, width=w, label='generated')
plt.xticks(x, labels, rotation=45)
plt.title('Top block-token counts: training vs generated')
plt.legend()
plt.tight_layout()
plt.show()

## Notes

This model is genuinely autoregressive and attention-based. It can produce variable-size outputs and novel samples, but quality depends on training scale.

Recommended upgrades:
- train longer and on more schematics
- include `block_data` tokens (state/metadata)
- add coordinate factorization or hierarchical generation
- add validity constraints (e.g., support/physics-inspired checks)