Dataset

In [1]:
import gzip
import shutil

# Gzip the raw QuickDraw export; the training cell streams bird.ndjson.gz via gzip.open.
with open("bird.ndjson", 'rb') as f_in, gzip.open("bird.ndjson.gz", 'wb') as f_out:
    shutil.copyfileobj(f_in, f_out)




Train

In [None]:
# ✅ Transformer Decoder + MDN for Stroke Completion (predicting remaining points)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gzip, json
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# ---------- Dataset (return split: first_half, second_half) ----------

def filter_good_drawings(drawing, min_strokes=3, max_strokes=10, min_points=30, max_points=150, closed_threshold=0.15):
    """Return True if a QuickDraw drawing passes simple size and closure checks.

    Args:
        drawing: list of strokes. The first two entries of each stroke are
            the x and y coordinate lists. Any further entries (such as the
            timestamps of the raw QuickDraw format) are ignored.
        min_strokes, max_strokes: inclusive bounds on the number of strokes.
        min_points, max_points: inclusive bounds on the total point count.
        closed_threshold: maximum allowed start-to-end distance, as a
            fraction of the bounding-box diagonal.

    Returns:
        True when every check passes. False otherwise, including when the
        drawing contains no points at all.
    """
    total_points = 0
    coords = []
    for stroke in drawing:
        # Index instead of unpacking so 3-channel (x, y, t) strokes also work.
        xs, ys = stroke[0], stroke[1]
        coords.extend(zip(xs, ys))
        total_points += len(xs)

    if not (min_points <= total_points <= max_points):
        return False
    if not (min_strokes <= len(drawing) <= max_strokes):
        return False
    # No points means no start/end to compare (coords[0] would raise).
    if not coords:
        return False

    # Closure check: start-to-end distance relative to the bounding-box diagonal.
    pts = np.asarray(coords, dtype=np.float64)
    dist = np.linalg.norm(pts[0] - pts[-1])
    diagonal = np.linalg.norm(np.ptp(pts, axis=0)) + 1e-5
    closedness = dist / diagonal
    if closedness > closed_threshold:
        return False

    return True


class QuickDrawDataset(Dataset):
    """QuickDraw drawings as (first half, second half) pairs of [dx, dy, pen] rows.

    Each line of the gzipped NDJSON file is parsed. Drawings that are not
    ``recognized`` or that fail ``filter_good_drawings`` are skipped.

    Absolute stroke coordinates are converted to deltas, with pen = 1 on the
    last point of each stroke. dx and dy are standardized per drawing. The
    sequence is then cut at its midpoint, and both halves are zero-padded to
    ``(max_len, 3)`` float32 arrays.

    Args:
        file_path: path to a gzip-compressed NDJSON file.
        max_len: padded length; only drawings with 20 < len < max_len rows are kept.
        limit: maximum number of samples to keep; None keeps all of them.
    """

    def __init__(self, file_path, max_len=200, limit=5000):
        self.samples = []
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            for line in f:
                # Checked before parsing, so limit <= 0 yields no samples.
                if limit is not None and len(self.samples) >= limit:
                    break
                line = line.strip()
                if not line:
                    continue  # tolerate blank lines, e.g. a trailing newline
                data = json.loads(line)
                if not data.get('recognized', False):
                    continue
                drawing = data['drawing']
                if not filter_good_drawings(drawing):
                    continue
                seq = self._to_deltas(drawing)
                if 20 < len(seq) < max_len:
                    self.samples.append(self._split_and_pad(seq, max_len))

    @staticmethod
    def _to_deltas(drawing):
        """Flatten strokes into [dx, dy, pen] rows; the first delta is taken from (0, 0)."""
        seq = []
        prev_x, prev_y = 0, 0
        for stroke in drawing:
            # Index rather than unpack so raw-format (x, y, t) strokes also work.
            xs, ys = stroke[0], stroke[1]
            last = len(xs) - 1
            for i, (x, y) in enumerate(zip(xs, ys)):
                seq.append([x - prev_x, y - prev_y, 1 if i == last else 0])
                prev_x, prev_y = x, y
        return seq

    @staticmethod
    def _split_and_pad(seq, max_len):
        """Standardize dx/dy, split at the midpoint and zero-pad both halves to max_len rows."""
        seq = np.array(seq, dtype=np.float32)
        for col in (0, 1):
            seq[:, col] = (seq[:, col] - np.mean(seq[:, col])) / (np.std(seq[:, col]) + 1e-5)
        mid = len(seq) // 2
        first = np.zeros((max_len, 3), dtype=np.float32)
        second = np.zeros((max_len, 3), dtype=np.float32)
        first[:mid] = seq[:mid]
        second[:len(seq) - mid] = seq[mid:]
        return first, second

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (first_half, second_half) as newly allocated float32 tensors."""
        first, second = self.samples[idx]
        return torch.tensor(first), torch.tensor(second)

# ---------- Positional Encoding ----------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first input.

    Even feature columns hold sin(pos * w_k) and odd columns hold
    cos(pos * w_k), with w_k = 10000^(-2k / d_model).

    Args:
        d_model: feature size of the input; odd sizes are supported.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions (dim 1 = time)."""
        return x + self.pe[:, :x.size(1), :]

# ---------- Transformer Decoder + MDN ----------
class TransformerDecoderMDN(nn.Module):
    """Causal Transformer decoder whose per-step output parameterizes an MDN.

    Each output step has ``1 + 6 * num_mixtures`` channels: one pen logit,
    then six equal groups consumed by ``split_mdn_params``.
    """

    def __init__(self, input_dim=3, d_model=128, nhead=4, num_layers=4, ff_dim=256, num_mixtures=20):
        super().__init__()
        # Submodules are created in this order so parameter init/registration order is fixed.
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=ff_dim)
        self.transformer = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.output_layer = nn.Linear(d_model, 1 + 6 * num_mixtures)
        self.num_mixtures = num_mixtures

    def forward(self, memory, tgt):
        """Decode ``tgt`` (batch, time, input_dim) against batch-first ``memory``.

        Returns a (batch, time, 1 + 6 * num_mixtures) tensor.
        """
        _, steps, _ = tgt.size()
        queries = self.pos_encoder(self.input_proj(tgt))
        causal_mask = nn.Transformer.generate_square_subsequent_mask(steps).to(tgt.device)
        # The decoder is built without batch_first, so swap batch and time around it.
        hidden = self.transformer(
            tgt=queries.transpose(0, 1),
            memory=memory.transpose(0, 1),
            tgt_mask=causal_mask,
        )
        return self.output_layer(hidden.transpose(0, 1))

# ---------- MDN helper functions ----------



def split_mdn_params(output, num_mixtures):
    """Turn raw decoder outputs into constrained MDN parameters.

    Channel 0 of the last dim is the pen logit. The remaining channels are
    cut into six equal chunks: mixture logits, mu1, mu2, log sigma1,
    log sigma2 and raw rho. ``num_mixtures`` is accepted but not used; the
    chunk size comes from the channel count.

    Returns:
        (pi, mu1, mu2, sigma1, sigma2, rho, eos), where:
        - pi is the softmax over mixtures;
        - sigmas are exp'd and clamped to [1e-3, 1];
        - rho is tanh'd;
        - eos is sigmoid'd.
    """
    eos_logit, mixture = output[:, :, :1], output[:, :, 1:]
    logits, mu1, mu2, log_s1, log_s2, rho_raw = torch.chunk(mixture, 6, dim=-1)
    return (
        torch.softmax(logits, dim=-1),
        mu1,
        mu2,
        torch.exp(log_s1).clamp(1e-3, 1.0),
        torch.exp(log_s2).clamp(1e-3, 1.0),
        torch.tanh(rho_raw),
        torch.sigmoid(eos_logit),
    )

def mdn_loss(y_true, mdn_params, pen_loss_weight=0.05, sigma_weight=0.01, entropy_weight=0.01):
    """Negative log-likelihood of bivariate-Gaussian MDN targets plus regularizers.

    Args:
        y_true: target tensor indexed as (batch, time, 3) with [dx, dy, pen].
        mdn_params: tuple (pi, mu1, mu2, sigma1, sigma2, rho, eos) as returned
            by split_mdn_params. Mixture tensors are (batch, time, M) and eos
            is (batch, time, 1).
        pen_loss_weight: weight of the pen-state log-likelihood term.
        sigma_weight: weight of the penalty on mean sigma1 / sigma2.
        entropy_weight: weight of the bonus for high mixture-weight entropy.

    Returns:
        A scalar tensor, the sum of:
        - mean per-step NLL minus pen_loss_weight * pen log-likelihood;
        - the sigma penalty;
        - minus the entropy bonus;
        - 0.05 * MSE of the pi-weighted mixture mean.
    """
    x1 = y_true[:, :, 0:1]
    x2 = y_true[:, :, 1:2]
    eos_true = y_true[:, :, 2:3]
    pi, mu1, mu2, sigma1, sigma2, rho, eos = mdn_params

    # tanh saturates to exactly +/-1 in float32, which would make 1 - rho**2
    # zero and the density inf/NaN; keep the correlation strictly inside (-1, 1).
    rho = torch.clamp(rho, -1.0 + 1e-5, 1.0 - 1e-5)
    one_minus_rho2 = 1 - rho ** 2

    norm1 = (x1 - mu1) / sigma1
    norm2 = (x2 - mu2) / sigma2
    z = norm1 ** 2 + norm2 ** 2 - 2 * rho * norm1 * norm2
    # Per-component log-density of the bivariate normal.
    log_gauss = (
        -z / (2 * one_minus_rho2)
        - np.log(2 * np.pi)
        - torch.log(sigma1)
        - torch.log(sigma2)
        - 0.5 * torch.log(one_minus_rho2)
    )
    # log(sum_k pi_k * N_k) is computed with logsumexp. The direct exp-then-log
    # form underflows to 0 for targets far from every component, which would
    # pin the loss at -log(1e-8) and zero its gradient.
    log_prob = torch.logsumexp(torch.log(pi.clamp_min(1e-12)) + log_gauss, dim=-1, keepdim=True)

    pen_loss = eos_true * torch.log(eos + 1e-8) + (1 - eos_true) * torch.log(1 - eos + 1e-8)
    loss_seq = -log_prob - pen_loss_weight * pen_loss

    sigma_penalty = sigma_weight * (torch.mean(sigma1) + torch.mean(sigma2))
    entropy = -torch.sum(pi * torch.log(pi + 1e-8), dim=-1, keepdim=True)
    # Negative sign: higher mixture entropy lowers the loss.
    entropy_penalty = -entropy_weight * torch.mean(entropy)

    # Auxiliary regression of the pi-weighted mixture mean onto the target offset.
    mu_x_mean = torch.sum(pi * mu1, dim=-1)
    mu_y_mean = torch.sum(pi * mu2, dim=-1)
    l2_loss = F.mse_loss(mu_x_mean, x1.squeeze(-1)) + F.mse_loss(mu_y_mean, x2.squeeze(-1))

    return torch.mean(loss_seq) + sigma_penalty + entropy_penalty + 0.05 * l2_loss


# ---------- Sampling, visualization and training loop ----------

def sample_completion(model, memory, max_len=100):
    """Autoregressively sample up to ``max_len`` steps from the MDN decoder.

    Decoding starts from the token [0, 0, 1]. At each step the function
    draws, in order:
    1. one mixture component from pi;
    2. a (dx, dy) offset from that component's bivariate Gaussian;
    3. a pen bit from Bernoulli(eos).

    It stops after ``max_len`` steps, or earlier at a sampled pen lift once
    the token list holds more than ten entries.

    NOTE(review): the training loop feeds ``second_half[:, :-1]`` as decoder
    input (its first real point acts as the start token). Sampling instead
    starts from [0, 0, 1], so the two conditioning schemes differ.

    Args:
        model: decoder exposing ``num_mixtures`` and ``forward(memory, tgt)``;
            it is switched to eval mode.
        memory: encoded first half on the target device. Presumably batch
            size 1, since the start token has batch size 1.
        max_len: maximum number of sampled steps.

    Returns:
        List of absolute (x, y) points accumulated from (0, 0). Steps with a
        sampled pen lift are recorded as None.
    """
    model.eval()
    device = memory.device
    x, y = 0.0, 0.0
    seq = [torch.tensor([[0.0, 0.0, 1.0]], dtype=torch.float32, device=device)]
    points = []

    # Pure inference: avoid building an autograd graph that grows every step.
    with torch.no_grad():
        for _ in range(max_len):
            tgt_seq = torch.stack(seq, dim=1)
            output = model(memory, tgt_seq)
            pi, mu1, mu2, sigma1, sigma2, rho, eos = split_mdn_params(output, model.num_mixtures)

            # np.random.choice checks sum(p) == 1 in float64. A float32 softmax
            # can miss that tolerance, so renormalize in float64.
            weights = pi[0, -1].double().cpu().numpy()
            idx = np.random.choice(model.num_mixtures, p=weights / weights.sum())

            m1, m2 = mu1[0, -1, idx].item(), mu2[0, -1, idx].item()
            s1, s2 = sigma1[0, -1, idx].item(), sigma2[0, -1, idx].item()
            cross = rho[0, -1, idx].item() * s1 * s2
            dx, dy = np.random.multivariate_normal([m1, m2], [[s1 ** 2, cross], [cross, s2 ** 2]])
            eos_sample = np.random.binomial(1, eos[0, -1, 0].item())

            x += dx
            y += dy
            points.append(None if eos_sample else (x, y))
            seq.append(torch.tensor([[dx, dy, eos_sample]], dtype=torch.float32, device=device))
            if eos_sample and len(seq) > 10:
                break

    return points

def stroke_to_points(seq):
    """Accumulate [dx, dy, pen] rows into absolute (x, y) positions.

    ``seq`` may be a torch tensor (detached, moved to CPU and converted to
    NumPy) or any iterable of 3-element rows. Positions start at (0, 0).

    A row whose pen value is not below 0.5 still advances the position, but
    is emitted as None to mark a stroke break.
    """
    rows = seq.detach().cpu().numpy() if torch.is_tensor(seq) else seq
    cur_x = cur_y = 0.0
    result = []
    for step_x, step_y, pen in rows:
        cur_x = cur_x + step_x
        cur_y = cur_y + step_y
        result.append((cur_x, cur_y) if pen < 0.5 else None)
    return result


def plot_stroke(points, ax, title=""):
    """Draw a point list on a Matplotlib-style axes as black polylines.

    ``points`` holds (x, y) pairs. Each None ends the current polyline, and
    empty runs are skipped. Afterwards the axes gets ``title``, an equal
    aspect ratio and an inverted y axis.
    """
    runs = [[]]
    for pt in points:
        if pt is None:
            runs.append([])
        else:
            runs[-1].append(pt)

    for run in runs:
        if not run:
            continue
        ax.plot([p[0] for p in run], [p[1] for p in run], color='black')

    ax.set_title(title)
    ax.axis("equal")
    # Presumably the coordinates are image-style (y grows downward) — confirm.
    ax.invert_yaxis()

def visualize_completion_sample(model, dataset):
    """Plot ground truth, input and a sampled completion for one random dataset item.

    ``dataset[i]`` must return a ``(first_half, second_half)`` pair of
    zero-padded [dx, dy, pen] tensors. Three panels are drawn side by side,
    then ``plt.show()`` is called.
    """
    import random
    model.eval()
    device = next(model.parameters()).device
    i = random.randint(0, len(dataset) - 1)
    first_half, second_half = dataset[i]

    def _trim_padding(seq):
        # Drop trailing all-zero rows (the dataset's padding). If left in,
        # padding between the halves restarts a stroke at the last pen-lift
        # position and draws a spurious connecting segment.
        keep = torch.nonzero(seq.abs().sum(dim=-1) > 0)
        return seq[: int(keep[-1]) + 1] if len(keep) else seq[:0]

    # Inference only: no autograd graph is needed for the memory or sampling.
    with torch.no_grad():
        # Encode the padded first half the same way the training loop does.
        memory = model.pos_encoder(model.input_proj(first_half.unsqueeze(0).to(device)))
        pred_points = sample_completion(model, memory)

    first = _trim_padding(first_half.detach().cpu())
    second = _trim_padding(second_half.detach().cpu())
    input_points = stroke_to_points(first)
    gt_points = stroke_to_points(torch.cat([first, second], dim=0))

    # sample_completion accumulates its deltas from (0, 0); shift the points
    # so the completion continues from the input's final position.
    off_x, off_y = first[:, 0].sum().item(), first[:, 1].sum().item()
    completed_points = input_points + [
        None if pt is None else (pt[0] + off_x, pt[1] + off_y) for pt in pred_points
    ]

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    plot_stroke(gt_points, axs[0], "Ground Truth")
    plot_stroke(input_points, axs[1], "Input (First Half)")
    plot_stroke(completed_points, axs[2], "Model Completion (Full)")
    plt.tight_layout()
    plt.show()



if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = QuickDrawDataset("bird.ndjson.gz")
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    model = TransformerDecoderMDN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(50):
        model.train()
        total_loss, num_batches = 0.0, 0
        for first_half, second_half in loader:
            first_half, second_half = first_half.to(device), second_half.to(device)
            # The first half is embedded with the decoder's own input
            # projection and positional encoding to serve as memory.
            memory = model.input_proj(first_half)
            memory = model.pos_encoder(memory)
            # Teacher forcing: predict second-half step t+1 from steps <= t.
            output = model(memory, second_half[:, :-1, :])
            loss = mdn_loss(second_half[:, 1:, :], split_mdn_params(output, model.num_mixtures))
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Report the epoch mean rather than the (noisy) last-batch loss.
        mean_loss = total_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")
        visualize_completion_sample(model, dataset)



Test

In [None]:
from torch.utils.data import Subset

# Indices of the evaluation subset. NOTE(review): nothing in this notebook
# writes test_indices.pt; confirm where it is produced and that it is
# disjoint from the drawings used for training.
test_indices = torch.load("test_indices.pt")
test_dataset = Subset(dataset=dataset, indices=test_indices)

def visualize_mdn_on_test(model, test_dataset, max_samples=None):
    """Plot ground truth, input and an MDN-sampled completion for test items.

    Args:
        model: trained decoder exposing ``input_proj`` and ``pos_encoder``.
        test_dataset: indexable dataset of ``(first_half, second_half)``
            zero-padded [dx, dy, pen] tensors.
        max_samples: plot at most this many items; None plots every item.

    One three-panel figure is shown, then closed, per item.
    """
    device = next(model.parameters()).device
    model.eval()

    def _trim_padding(seq):
        # Drop trailing all-zero padding rows. If left in, padding between the
        # halves draws a spurious segment after a pen lift.
        keep = torch.nonzero(seq.abs().sum(dim=-1) > 0)
        return seq[: int(keep[-1]) + 1] if len(keep) else seq[:0]

    count = len(test_dataset) if max_samples is None else min(max_samples, len(test_dataset))
    for i in range(count):
        first_half, second_half = test_dataset[i]

        # Inference only: encode the padded first half as in training, then sample.
        with torch.no_grad():
            memory = model.pos_encoder(model.input_proj(first_half.unsqueeze(0).to(device)))
            mdn_points = sample_completion(model, memory)

        first = _trim_padding(first_half.detach().cpu())
        second = _trim_padding(second_half.detach().cpu())
        input_points = stroke_to_points(first)
        gt_points = stroke_to_points(torch.cat([first, second], dim=0))

        # Sampled points start at (0, 0); continue them from the input's end.
        off_x, off_y = first[:, 0].sum().item(), first[:, 1].sum().item()
        completed_points = input_points + [
            None if pt is None else (pt[0] + off_x, pt[1] + off_y) for pt in mdn_points
        ]

        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        plot_stroke(gt_points, axs[0], "Ground Truth")
        plot_stroke(input_points, axs[1], "Input (First Half)")
        plot_stroke(completed_points, axs[2], "MDN Completion")
        plt.suptitle(f"Sample {i+1}", fontsize=16)
        plt.tight_layout()
        plt.show()
        # Release the figure so looping over many samples does not accumulate them.
        plt.close(fig)


# Render every held-out sample with the trained model.
visualize_mdn_on_test(model=model, test_dataset=test_dataset)
