# Large PyG Graph Generator (Scalable)

This notebook trains a larger torch_geometric model that:
- encodes full structure graphs with attention (`TransformerConv`)
- learns a latent distribution (VAE-style)
- autoregressively decodes block-token sequences
- samples **new** graphs not memorized from training

In [2]:
from pathlib import Path
from collections import Counter
import random
import tqdm

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

from blockgen.utils.graph_data import SchematicGraphDataset, NODE_BLOCK, structure_to_pyg_data
from blockgen.utils.data import Structure
from blockgen.models import LargePyGGraphGenerator, LargePyGGraphGeneratorConfig
from blockgen.renderer.render import render_schem

SEED = 2026
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device:', device)
print('torch:', torch.__version__)

device: cuda
torch: 2.10.0+cu128


## 1) Load full dataset as graphs

In [3]:
repo_root = Path.cwd()
if not (repo_root / 'data' / 'raw').exists():
    repo_root = repo_root.parent

raw_dir = repo_root / 'data' / 'raw'
all_paths = sorted(raw_dir.glob('*.schematic'))
print('total schematics found:', len(all_paths))

# Use the entire dataset by default.
paths = [str(p) for p in all_paths]

dataset = SchematicGraphDataset(
    paths,
    include_air=False,
    crop_non_air=True,
    max_dim=20,
)

print('graph dataset size:', len(dataset))

total schematics found: 10963
graph dataset size: 10963


## 2) Build block-token vocabulary from all graphs

In [4]:
special_tokens = ['<PAD>', '<BOS>', '<UNK>']
block_counter = Counter()
num_blocks_list = []

for i in range(len(dataset)):
    g = dataset[i]
    block_ids = g.block_id[g.node_type == NODE_BLOCK].tolist()
    block_counter.update(int(x) for x in block_ids)
    num_blocks_list.append(len(block_ids))

observed_block_ids = sorted(block_counter.keys())

stoi = {tok: i for i, tok in enumerate(special_tokens)}
for bid in observed_block_ids:
    stoi[f'BID_{int(bid)}'] = len(stoi)

itos = {v: k for k, v in stoi.items()}

pad_idx = stoi['<PAD>']
bos_idx = stoi['<BOS>']
unk_idx = stoi['<UNK>']

def block_id_to_token(block_id):
    return stoi.get(f'BID_{int(block_id)}', unk_idx)

def token_to_block_id(token):
    label = itos.get(int(token), '<UNK>')
    if isinstance(label, str) and label.startswith('BID_'):
        return int(label.split('_', 1)[1])
    return 0

num_block_tokens = len(stoi)
max_block_nodes = int(np.percentile(num_blocks_list, 95))
max_block_nodes = max(max_block_nodes, 32)

print('num block tokens:', num_block_tokens)
print('max_block_nodes (95th pct):', max_block_nodes)

BadGzipFile: Not a gzipped file (b'Ra')

## 3) Prepare train/test splits and batch collation

In [None]:
indices = np.arange(len(dataset))
np.random.shuffle(indices)
split = int(0.9 * len(indices))
train_indices = indices[:split]
test_indices = indices[split:]

print('train graphs:', len(train_indices), 'test graphs:', len(test_indices))


def graph_to_block_token_sequence(g):
    block_mask = g.node_type == NODE_BLOCK
    block_ids = g.block_id[block_mask].tolist()
    return [block_id_to_token(bid) for bid in block_ids]


class GraphTokenDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, id_list):
        self.base_dataset = base_dataset
        self.id_list = [int(i) for i in id_list]

    def __len__(self):
        return len(self.id_list)

    def __getitem__(self, idx):
        g = self.base_dataset[self.id_list[idx]]
        tokens = graph_to_block_token_sequence(g)
        n_blocks = len(tokens)

        if n_blocks > max_block_nodes:
            tokens = tokens[:max_block_nodes]
            n_blocks = max_block_nodes

        decoder_input = [bos_idx] + tokens[:-1] if n_blocks > 0 else [bos_idx]
        decoder_target = tokens if n_blocks > 0 else [unk_idx]

        return {
            'graph': g,
            'decoder_input': torch.tensor(decoder_input, dtype=torch.long),
            'decoder_target': torch.tensor(decoder_target, dtype=torch.long),
            'num_blocks': torch.tensor(n_blocks, dtype=torch.long),
            'sequence_tuple': tuple(tokens),
        }


def collate_graph_tokens(batch):
    graphs = [item['graph'] for item in batch]
    batched_graph = Batch.from_data_list(graphs)

    decoder_inputs = [item['decoder_input'] for item in batch]
    decoder_targets = [item['decoder_target'] for item in batch]
    num_blocks = torch.stack([item['num_blocks'] for item in batch], dim=0)
    seq_tuples = [item['sequence_tuple'] for item in batch]

    max_len = max(x.shape[0] for x in decoder_inputs)
    x_pad = torch.full((len(batch), max_len), pad_idx, dtype=torch.long)
    y_pad = torch.full((len(batch), max_len), pad_idx, dtype=torch.long)

    for i, (x, y) in enumerate(zip(decoder_inputs, decoder_targets)):
        n = x.shape[0]
        x_pad[i, :n] = x
        y_pad[i, :n] = y

    return batched_graph, x_pad, y_pad, num_blocks, seq_tuples


train_ds = GraphTokenDataset(dataset, train_indices)
test_ds = GraphTokenDataset(dataset, test_indices)

train_loader = DataLoader(train_ds, batch_size=8, shuffle=True, collate_fn=collate_graph_tokens)
test_loader = DataLoader(test_ds, batch_size=8, shuffle=False, collate_fn=collate_graph_tokens)

## 4) Train large model

In [None]:
config = LargePyGGraphGeneratorConfig(
    num_block_tokens=num_block_tokens,
    max_block_nodes=max_block_nodes,
    hidden_dim=384,
    latent_dim=192,
    encoder_layers=6,
    decoder_layers=3,
    num_heads=6,
    dropout=0.1,
)

model = LargePyGGraphGenerator(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)


def kl_divergence(mu, logvar):
    return -0.5 * torch.mean(1.0 + logvar - mu.pow(2) - logvar.exp())


def compute_losses(output, token_targets, size_targets, pad_idx, beta=0.02):
    logits = output['token_logits']
    b, t, v = logits.shape

    token_loss = F.cross_entropy(
        logits.reshape(b * t, v),
        token_targets.reshape(b * t),
        ignore_index=pad_idx,
    )

    size_targets = torch.clamp(size_targets, min=0, max=config.max_block_nodes)
    size_loss = F.cross_entropy(output['size_logits'], size_targets)

    kl = kl_divergence(output['mu'], output['logvar'])
    total = token_loss + 0.5 * size_loss + beta * kl

    return total, token_loss, size_loss, kl


@torch.no_grad()
def evaluate(loader):
    model.eval()
    losses = []
    for graph_batch, dec_inp, dec_tgt, num_blocks, _ in loader:
        graph_batch = graph_batch.to(device)
        dec_inp = dec_inp.to(device)
        dec_tgt = dec_tgt.to(device)
        num_blocks = num_blocks.to(device)

        out = model(graph_batch, dec_inp)
        total, _, _, _ = compute_losses(out, dec_tgt, num_blocks, pad_idx)
        losses.append(float(total.item()))

    return float(np.mean(losses)) if losses else 0.0


epochs = 6
for epoch in range(1, epochs + 1):
    model.train()
    total_losses = []

    for graph_batch, dec_inp, dec_tgt, num_blocks, _ in train_loader:
        graph_batch = graph_batch.to(device)
        dec_inp = dec_inp.to(device)
        dec_tgt = dec_tgt.to(device)
        num_blocks = num_blocks.to(device)

        out = model(graph_batch, dec_inp)
        total, token_loss, size_loss, kl = compute_losses(out, dec_tgt, num_blocks, pad_idx)

        optimizer.zero_grad()
        total.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_losses.append(float(total.item()))

    train_loss = float(np.mean(total_losses))
    val_loss = evaluate(test_loader)
    print(f'epoch {epoch:02d} | train={train_loss:.4f} | val={val_loss:.4f}')

## 5) Generate novel graphs from prior

In [None]:
train_token_sequences = set()
for _, _, _, _, seq_tuples in train_loader:
    for seq in seq_tuples:
        train_token_sequences.add(tuple(seq))


def sequence_to_structure(token_sequence):
    block_ids = np.array([token_to_block_id(tok) for tok in token_sequence], dtype=np.int32)
    n = max(1, len(block_ids))

    side = int(np.ceil(n ** (1 / 3)))
    grid = np.zeros((side, side, side), dtype=np.int32)
    grid.reshape(-1)[:n] = block_ids
    block_data = np.zeros_like(grid, dtype=np.int32)
    return Structure(block_ids=grid, block_data=block_data, source_path='generated_prior')


valid_block_tokens = [stoi[f'BID_{int(bid)}'] for bid in observed_block_ids if f'BID_{int(bid)}' in stoi]

novel_structures = []
novel_sequences = []
attempts = 0
max_attempts = 80
target_novel = 4

while len(novel_structures) < target_novel and attempts < max_attempts:
    attempts += 1

    z = torch.randn(1, config.latent_dim, device=device)
    sampled_n = int(model.sample_num_blocks(z, temperature=1.0, min_blocks=8).item())
    sampled_n = max(8, min(sampled_n, config.max_block_nodes))

    token_seq = model.sample_block_tokens(
        z,
        bos_token_id=bos_idx,
        num_tokens=sampled_n,
        valid_token_ids=valid_block_tokens,
        temperature=1.0,
        top_k=32,
    )

    if tuple(token_seq) in train_token_sequences:
        continue

    novel_sequences.append(tuple(token_seq))
    novel_structures.append(sequence_to_structure(token_seq))

print('novel graphs generated:', len(novel_structures), 'after attempts:', attempts)

fig = plt.figure(figsize=(16, 8))
for i, st in enumerate(novel_structures, start=1):
    ax = fig.add_subplot(2, 2, i, projection='3d')
    render_schem(st, ax=ax, show=False, crop_non_air=True, max_dim=20)
    ax.set_title(f'Novel Sample {i} | shape={st.shape}')

plt.tight_layout()
plt.show()

## 6) Compare generated vs training token distribution

In [None]:
real_counter = Counter()
for seq in train_token_sequences:
    real_counter.update(seq)

gen_counter = Counter()
for seq in novel_sequences:
    gen_counter.update(seq)

keys = sorted(set(real_counter.keys()) | set(gen_counter.keys()))
eps = 1e-8

real = np.array([real_counter.get(k, 0) for k in keys], dtype=np.float64)
gen = np.array([gen_counter.get(k, 0) for k in keys], dtype=np.float64)

real = (real + eps) / (real.sum() + eps * len(real))
gen = (gen + eps) / (gen.sum() + eps * len(gen))

kl = float(np.sum(real * np.log(real / gen)))
print('KL(train || generated):', round(kl, 4))

top = [k for k, _ in real_counter.most_common(12)]
labels = [itos[k] for k in top]
rv = [real_counter[k] for k in top]
gv = [gen_counter.get(k, 0) for k in top]

x = np.arange(len(top))
w = 0.42
plt.figure(figsize=(12, 4))
plt.bar(x - w/2, rv, width=w, label='train')
plt.bar(x + w/2, gv, width=w, label='generated')
plt.xticks(x, labels, rotation=45)
plt.title('Top block-token counts: train vs generated')
plt.legend()
plt.tight_layout()
plt.show()

## Notes

Scaling knobs:
- Increase `max_dim` in dataset
- Increase `hidden_dim`, `encoder_layers`, `num_heads`
- Train for more epochs
- Increase batch size if GPU memory allows

This model is graph-based, attention-based, and generates novel samples by sampling from the latent prior.