Dataset

In [1]:
import gzip
import shutil

# Gzip-compress the raw ndjson export; the training code opens "bird.ndjson.gz".
with open("bird.ndjson", 'rb') as src, gzip.open("bird.ndjson.gz", 'wb') as dst:
    shutil.copyfileobj(src, dst)




Train

In [None]:
# ✅ Clean MSE-based Transformer Decoder for Stroke Completion

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gzip, json
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# ---------- Dataset (first half, second half) ----------

def filter_good_drawings(drawing, min_strokes=3, max_strokes=10, min_points=30, max_points=150, closed_threshold=0.15):
    """Return True if a QuickDraw drawing passes the size and closure filters.

    Args:
        drawing: list of strokes; each stroke's first two entries are the x and
            y coordinate lists (extra entries such as timestamps are ignored).
        min_strokes, max_strokes: inclusive bounds on ``len(drawing)``; empty
            strokes are counted too.
        min_points, max_points: inclusive bounds on the total number of points
            (sum of ``len(xs)`` over all strokes).
        closed_threshold: largest allowed distance between the very first and
            the very last point, relative to the bounding-box diagonal.

    Returns:
        bool: True when every check passes. A drawing without any points is
        always rejected.
    """
    coords = []
    total_points = 0
    for stroke in drawing:
        xs, ys = stroke[0], stroke[1]
        coords.extend(zip(xs, ys))
        total_points += len(xs)

    if not (min_points <= total_points <= max_points):
        return False
    if not (min_strokes <= len(drawing) <= max_strokes):
        return False
    if not coords:
        # No points to measure closure on (reachable only with min_points <= 0);
        # indexing coords[0] here used to raise IndexError.
        return False

    pts = np.asarray(coords, dtype=np.float64)
    gap = np.linalg.norm(pts[0] - pts[-1])
    # Bounding-box diagonal; the epsilon avoids dividing by zero for a single point.
    diagonal = np.linalg.norm(pts.max(axis=0) - pts.min(axis=0)) + 1e-5
    return bool(gap / diagonal <= closed_threshold)

class QuickDrawDataset(Dataset):
    """QuickDraw drawings as (first half, second half) stroke-3 tensor pairs.

    Only drawings flagged ``recognized`` that pass ``filter_good_drawings`` are
    used. Each one is converted to ``[dx, dy, pen]`` rows (``pen == 1`` on the
    last point of every stroke), the offsets are rescaled, and the sequence is
    split at ``len // 2``. Both halves are zero-padded to ``max_len`` rows.

    Args:
        file_path: gzip-compressed ndjson file; each line needs a ``'drawing'``
            entry and may carry a ``'recognized'`` flag.
        max_len: padded length of each half; only sequences with more than 20
            and fewer than ``max_len`` rows are kept.
        limit: stop reading once this many samples have been collected.
    """

    def __init__(self, file_path, max_len=200, limit=5000):
        self.samples = []
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                if not data.get('recognized', False):
                    continue
                if not filter_good_drawings(data['drawing']):
                    continue
                seq = self._drawing_to_stroke3(data['drawing'])
                if not (20 < len(seq) < max_len):
                    continue
                self.samples.append(self._split_halves(self._normalize(seq), max_len))
                if len(self.samples) >= limit:
                    break

    @staticmethod
    def _drawing_to_stroke3(drawing):
        """Convert absolute stroke coordinates to a float32 (N, 3) ``[dx, dy, pen]`` array.

        Each offset is relative to the previous point; the very first offset is
        measured from the origin (0, 0). ``pen`` is 1 on a stroke's last point.
        """
        rows, prev_x, prev_y = [], 0, 0
        for stroke in drawing:
            xs, ys = stroke[0], stroke[1]
            last = len(xs) - 1
            for i, x in enumerate(xs):
                y = ys[i]
                rows.append([x - prev_x, y - prev_y, 1 if i == last else 0])
                prev_x, prev_y = x, y
        return np.array(rows, dtype=np.float32).reshape(-1, 3)

    @staticmethod
    def _normalize(seq):
        """Return a copy with dx and dy each divided by their standard deviation.

        The offsets are deliberately NOT mean-centred: subtracting the mean from
        every offset adds a drift of ``-i * mean / std`` to the i-th point once
        the offsets are summed back into positions, which shears the drawing.
        """
        out = seq.copy()
        out[:, :2] /= out[:, :2].std(axis=0) + 1e-5
        return out

    @staticmethod
    def _split_halves(seq, max_len):
        """Split ``seq`` at ``len(seq) // 2`` and zero-pad each half to ``max_len`` rows."""
        mid = len(seq) // 2
        first = np.zeros((max_len, 3), dtype=np.float32)
        second = np.zeros((max_len, 3), dtype=np.float32)
        first[:mid] = seq[:mid]
        second[:len(seq) - mid] = seq[mid:]
        return first, second

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        first, second = self.samples[idx]
        return torch.tensor(first), torch.tensor(second)

# ---------- Positional Encoding ----------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Args:
        d_model: embedding width; odd widths are supported (one fewer cosine
            column than sine columns).
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there are only d_model // 2 odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading singleton axis so the table broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        return x + self.pe[:, :seq_len].to(x.device)

# ---------- Transformer Decoder (MSE version) ----------
class TransformerDecoderMSE(nn.Module):
    """Causal Transformer decoder that regresses stroke-3 rows.

    Each output row is ``[dx, dy, pen_logit]``. The decoder cross-attends to
    ``memory``; in this notebook the memory is the first half passed through
    this module's own ``input_proj`` and ``pos_encoder``.

    Args:
        input_dim: features per input row (3 for ``[dx, dy, pen]``).
        d_model: embedding width.
        nhead: attention heads per decoder layer.
        num_layers: number of stacked decoder layers.
        ff_dim: hidden width of each layer's feed-forward block.
    """

    def __init__(self, input_dim=3, d_model=128, nhead=4, num_layers=4, ff_dim=256):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, ff_dim)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, 3)  # dx, dy, pen

    def forward(self, memory, tgt, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Predict one output row per target step under a causal mask.

        Args:
            memory: (batch, src_len, d_model) encoded context.
            tgt: (batch, tgt_len, input_dim) teacher-forcing input rows.
            tgt_key_padding_mask, memory_key_padding_mask: optional
                (batch, len) masks marking padded positions to ignore, in any
                form accepted by ``nn.TransformerDecoder``. Omitting them keeps
                the original unmasked behavior.

        Returns:
            (batch, tgt_len, 3) tensor; channels 0-1 are offsets and channel 2
            is the pen logit (``mse_loss`` applies BCE-with-logits to it).
        """
        seq_len = tgt.size(1)
        tgt_emb = self.pos_encoder(self.input_proj(tgt))
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(tgt.device)
        # The decoder layers use the default sequence-first layout, so swap
        # (batch, seq) -> (seq, batch) on the way in and back on the way out.
        out = self.transformer(
            tgt_emb.transpose(0, 1),
            memory.transpose(0, 1),
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_layer(out.transpose(0, 1))

# ---------- MSE Loss ----------
def mse_loss(pred, target, pen_weight=0.01, mask=None):
    """Offset regression (MSE) plus weighted pen classification (BCE) loss.

    Args:
        pred: (..., 3) model output; channel 2 is a pen logit.
        target: (..., 3) ground truth; channel 2 is a 0/1 pen label.
        pen_weight: weight of the pen term relative to the offset term.
        mask: optional tensor shaped like ``pred[..., 0]``; nonzero/True marks
            steps that count. When omitted every step, including zero padding,
            is averaged (the original behavior).

    Returns:
        Scalar tensor ``mse(dx, dy) + pen_weight * bce(pen)``.
    """
    if mask is None:
        l2_loss = F.mse_loss(pred[..., :2], target[..., :2])
        pen_loss = F.binary_cross_entropy_with_logits(pred[..., 2], target[..., 2])
        return l2_loss + pen_weight * pen_loss

    weights = mask.to(pred.dtype)
    # clamp avoids 0/0 when every step is masked out
    denom = weights.sum().clamp_min(1.0)
    l2_per_step = (pred[..., :2] - target[..., :2]).pow(2).mean(dim=-1)
    pen_per_step = F.binary_cross_entropy_with_logits(
        pred[..., 2], target[..., 2], reduction='none'
    )
    l2_loss = (l2_per_step * weights).sum() / denom
    pen_loss = (pen_per_step * weights).sum() / denom
    return l2_loss + pen_weight * pen_loss

# ---------- Plotting ----------
def stroke_to_points(seq, pen_threshold=0.5):
    """Integrate stroke-3 offsets into absolute points with stroke breaks.

    Args:
        seq: (N, 3) tensor or array-like of ``[dx, dy, pen]`` rows. A pen value
            at or above ``pen_threshold`` marks the last point of a stroke.
        pen_threshold: cut-off for the pen channel; 0.5 suits 0/1 labels and
            probabilities, 0.0 suits raw logits.

    Returns:
        List of ``(x, y)`` positions with ``None`` inserted after every pen-up
        point to separate strokes.
    """
    if torch.is_tensor(seq):
        seq = seq.detach().cpu().numpy()
    points = []
    x = y = 0.0
    for dx, dy, pen in seq:
        x += dx
        y += dy
        # The pen-up row is the stroke's final point: keep it, then break.
        points.append((x, y))
        if pen >= pen_threshold:
            points.append(None)
    return points

def plot_stroke(points, ax, title=""):
    """Draw ``points`` on ``ax`` as black polylines split at ``None`` entries.

    Each run of consecutive points becomes one ``ax.plot`` call. The axes then
    get ``title``, an equal aspect ratio and an inverted y-axis.
    """
    segment = []
    # A trailing None flushes the final run without a separate post-loop branch.
    for pt in list(points) + [None]:
        if pt is not None:
            segment.append(pt)
        elif segment:
            ax.plot([p[0] for p in segment], [p[1] for p in segment], color='black')
            segment = []
    ax.set_title(title)
    ax.axis("equal")
    ax.invert_yaxis()

def visualize_completion_sample(model, dataset):
    """Plot ground truth, input and model completion for one random sample.

    The second half is predicted with teacher forcing (each step sees the
    ground-truth previous row), so the plot shows chained one-step-ahead
    predictions rather than a free-running generation.

    Args:
        model: decoder exposing ``input_proj`` and ``pos_encoder``; it is put
            into eval mode.
        dataset: indexable dataset returning ``(first_half, second_half)``
            tensors as produced by ``QuickDrawDataset``.
    """
    import random

    model.eval()
    device = next(model.parameters()).device
    first_half, second_half = dataset[random.randint(0, len(dataset) - 1)]
    first_half = first_half.unsqueeze(0).to(device)
    second_half = second_half.unsqueeze(0).to(device)

    with torch.no_grad():
        memory = model.pos_encoder(model.input_proj(first_half))
        pred = model(memory, second_half[:, :-1, :])[0]

    # Channel 2 is a logit (trained with BCE-with-logits); map it to a
    # probability so the 0.5 pen threshold means the same as for 0/1 labels.
    pred = torch.cat([pred[:, :2], torch.sigmoid(pred[:, 2:])], dim=-1)

    # pred[t] predicts second_half[t + 1]. Chain the predicted offsets after the
    # input and the (never predicted) first row of the second half, so the
    # completion continues from where the input ends instead of the origin.
    completion = torch.cat([first_half[0], second_half[0, :1], pred], dim=0)

    input_points = stroke_to_points(first_half[0])
    gt_points = stroke_to_points(torch.cat([first_half[0], second_half[0]], dim=0))
    completed_points = stroke_to_points(completion)

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    plot_stroke(gt_points, axs[0], "Ground Truth")
    plot_stroke(input_points, axs[1], "Input (First Half)")
    plot_stroke(completed_points, axs[2], "Model Completion")
    plt.tight_layout()
    plt.show()

# ---------- Train Loop ----------
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = QuickDrawDataset("bird.ndjson.gz")
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    model = TransformerDecoderMSE().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    for epoch in range(50):
        model.train()
        running_loss, num_batches = 0.0, 0
        for first_half, second_half in loader:
            first_half, second_half = first_half.to(device), second_half.to(device)
            # Encode the first half with the model's own embedding + positions.
            memory = model.input_proj(first_half)
            memory = model.pos_encoder(memory)
            # Teacher forcing: rows 0..T-2 predict rows 1..T-1 of the second half.
            output = model(memory, second_half[:, :-1, :])
            loss = mse_loss(output, second_half[:, 1:, :])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Report the epoch mean; the last batch's loss alone is a noisy estimate
        # (and undefined if the loader yields no batches).
        mean_loss = running_loss / max(num_batches, 1)
        print(f"Epoch {epoch+1}, Loss: {mean_loss:.4f}")
        visualize_completion_sample(model, dataset)

In [7]:
# Assumes the dataset is built as QuickDrawDataset("bird.ndjson.gz").
import torch
from torch.utils.data import Subset

# Seed the RNG so the same indices are drawn on every run.
torch.manual_seed(42)

# Build the full (filtered, limit-capped) dataset.
full_dataset = QuickDrawDataset("bird.ndjson.gz")

# Draw 100 distinct indices at random and wrap them in a Subset.
# NOTE(review): the training cell builds its dataset from the same file with
# the same defaults, so these "test" samples were also used for training.
indices = torch.randperm(len(full_dataset))[:100]
test_subset = Subset(full_dataset, indices)

# Persist only the indices (a small .pt file) so the subset can be rebuilt.
torch.save(indices, "test_indices.pt")


In [None]:
from torch.utils.data import Subset

# Rebuild the full dataset, then reapply the saved indices so that every
# model (MDN and MSE) is evaluated on exactly the same subset.
full_dataset = QuickDrawDataset("bird.ndjson.gz")

test_indices = torch.load("test_indices.pt")
test_dataset = Subset(full_dataset, test_indices)


Test

In [8]:
# ✅ Clean MSE-based Transformer Decoder for Stroke Completion

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gzip, json
from torch.utils.data import Dataset, DataLoader, Subset
import matplotlib.pyplot as plt

# ---------- Dataset (first half, second half) ----------

def filter_good_drawings(drawing, min_strokes=3, max_strokes=10, min_points=30, max_points=150, closed_threshold=0.15):
    """Return True if a QuickDraw drawing passes the size and closure filters.

    Args:
        drawing: list of strokes; each stroke's first two entries are the x and
            y coordinate lists (extra entries such as timestamps are ignored).
        min_strokes, max_strokes: inclusive bounds on ``len(drawing)``; empty
            strokes are counted too.
        min_points, max_points: inclusive bounds on the total number of points
            (sum of ``len(xs)`` over all strokes).
        closed_threshold: largest allowed distance between the very first and
            the very last point, relative to the bounding-box diagonal.

    Returns:
        bool: True when every check passes. A drawing without any points is
        always rejected.
    """
    coords = []
    total_points = 0
    for stroke in drawing:
        xs, ys = stroke[0], stroke[1]
        coords.extend(zip(xs, ys))
        total_points += len(xs)

    if not (min_points <= total_points <= max_points):
        return False
    if not (min_strokes <= len(drawing) <= max_strokes):
        return False
    if not coords:
        # No points to measure closure on (reachable only with min_points <= 0);
        # indexing coords[0] here used to raise IndexError.
        return False

    pts = np.asarray(coords, dtype=np.float64)
    gap = np.linalg.norm(pts[0] - pts[-1])
    # Bounding-box diagonal; the epsilon avoids dividing by zero for a single point.
    diagonal = np.linalg.norm(pts.max(axis=0) - pts.min(axis=0)) + 1e-5
    return bool(gap / diagonal <= closed_threshold)

class QuickDrawDataset(Dataset):
    """QuickDraw drawings as (first half, second half) stroke-3 tensor pairs.

    Only drawings flagged ``recognized`` that pass ``filter_good_drawings`` are
    used. Each one is converted to ``[dx, dy, pen]`` rows (``pen == 1`` on the
    last point of every stroke), the offsets are rescaled, and the sequence is
    split at ``len // 2``. Both halves are zero-padded to ``max_len`` rows.

    Args:
        file_path: gzip-compressed ndjson file; each line needs a ``'drawing'``
            entry and may carry a ``'recognized'`` flag.
        max_len: padded length of each half; only sequences with more than 20
            and fewer than ``max_len`` rows are kept.
        limit: stop reading once this many samples have been collected.
    """

    def __init__(self, file_path, max_len=200, limit=5000):
        self.samples = []
        with gzip.open(file_path, 'rt', encoding='utf-8') as f:
            for line in f:
                data = json.loads(line)
                if not data.get('recognized', False):
                    continue
                if not filter_good_drawings(data['drawing']):
                    continue
                seq = self._drawing_to_stroke3(data['drawing'])
                if not (20 < len(seq) < max_len):
                    continue
                self.samples.append(self._split_halves(self._normalize(seq), max_len))
                if len(self.samples) >= limit:
                    break

    @staticmethod
    def _drawing_to_stroke3(drawing):
        """Convert absolute stroke coordinates to a float32 (N, 3) ``[dx, dy, pen]`` array.

        Each offset is relative to the previous point; the very first offset is
        measured from the origin (0, 0). ``pen`` is 1 on a stroke's last point.
        """
        rows, prev_x, prev_y = [], 0, 0
        for stroke in drawing:
            xs, ys = stroke[0], stroke[1]
            last = len(xs) - 1
            for i, x in enumerate(xs):
                y = ys[i]
                rows.append([x - prev_x, y - prev_y, 1 if i == last else 0])
                prev_x, prev_y = x, y
        return np.array(rows, dtype=np.float32).reshape(-1, 3)

    @staticmethod
    def _normalize(seq):
        """Return a copy with dx and dy each divided by their standard deviation.

        The offsets are deliberately NOT mean-centred: subtracting the mean from
        every offset adds a drift of ``-i * mean / std`` to the i-th point once
        the offsets are summed back into positions, which shears the drawing.
        """
        out = seq.copy()
        out[:, :2] /= out[:, :2].std(axis=0) + 1e-5
        return out

    @staticmethod
    def _split_halves(seq, max_len):
        """Split ``seq`` at ``len(seq) // 2`` and zero-pad each half to ``max_len`` rows."""
        mid = len(seq) // 2
        first = np.zeros((max_len, 3), dtype=np.float32)
        second = np.zeros((max_len, 3), dtype=np.float32)
        first[:mid] = seq[:mid]
        second[:len(seq) - mid] = seq[mid:]
        return first, second

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        first, second = self.samples[idx]
        return torch.tensor(first), torch.tensor(second)

# ---------- Positional Encoding ----------
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Args:
        d_model: embedding width; odd widths are supported (one fewer cosine
            column than sine columns).
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-np.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there are only d_model // 2 odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Leading singleton axis so the table broadcasts over the batch dimension.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        return x + self.pe[:, :seq_len].to(x.device)

# ---------- Transformer Decoder (MSE version) ----------
class TransformerDecoderMSE(nn.Module):
    """Causal Transformer decoder that regresses stroke-3 rows.

    Each output row is ``[dx, dy, pen_logit]``. The decoder cross-attends to
    ``memory``; in this notebook the memory is the first half passed through
    a model's ``input_proj`` and ``pos_encoder``.

    Args:
        input_dim: features per input row (3 for ``[dx, dy, pen]``).
        d_model: embedding width.
        nhead: attention heads per decoder layer.
        num_layers: number of stacked decoder layers.
        ff_dim: hidden width of each layer's feed-forward block.
    """

    def __init__(self, input_dim=3, d_model=128, nhead=4, num_layers=4, ff_dim=256):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, ff_dim)
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_layer = nn.Linear(d_model, 3)  # dx, dy, pen

    def forward(self, memory, tgt, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Predict one output row per target step under a causal mask.

        Args:
            memory: (batch, src_len, d_model) encoded context.
            tgt: (batch, tgt_len, input_dim) teacher-forcing input rows.
            tgt_key_padding_mask, memory_key_padding_mask: optional
                (batch, len) masks marking padded positions to ignore, in any
                form accepted by ``nn.TransformerDecoder``. Omitting them keeps
                the original unmasked behavior.

        Returns:
            (batch, tgt_len, 3) tensor; channels 0-1 are offsets and channel 2
            is the pen logit (``mse_loss`` applies BCE-with-logits to it).
        """
        seq_len = tgt.size(1)
        tgt_emb = self.pos_encoder(self.input_proj(tgt))
        causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(tgt.device)
        # The decoder layers use the default sequence-first layout, so swap
        # (batch, seq) -> (seq, batch) on the way in and back on the way out.
        out = self.transformer(
            tgt_emb.transpose(0, 1),
            memory.transpose(0, 1),
            tgt_mask=causal_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        return self.output_layer(out.transpose(0, 1))

# ---------- MSE Loss ----------
def mse_loss(pred, target, pen_weight=0.01, mask=None):
    """Offset regression (MSE) plus weighted pen classification (BCE) loss.

    Args:
        pred: (..., 3) model output; channel 2 is a pen logit.
        target: (..., 3) ground truth; channel 2 is a 0/1 pen label.
        pen_weight: weight of the pen term relative to the offset term.
        mask: optional tensor shaped like ``pred[..., 0]``; nonzero/True marks
            steps that count. When omitted every step, including zero padding,
            is averaged (the original behavior).

    Returns:
        Scalar tensor ``mse(dx, dy) + pen_weight * bce(pen)``.
    """
    if mask is None:
        l2_loss = F.mse_loss(pred[..., :2], target[..., :2])
        pen_loss = F.binary_cross_entropy_with_logits(pred[..., 2], target[..., 2])
        return l2_loss + pen_weight * pen_loss

    weights = mask.to(pred.dtype)
    # clamp avoids 0/0 when every step is masked out
    denom = weights.sum().clamp_min(1.0)
    l2_per_step = (pred[..., :2] - target[..., :2]).pow(2).mean(dim=-1)
    pen_per_step = F.binary_cross_entropy_with_logits(
        pred[..., 2], target[..., 2], reduction='none'
    )
    l2_loss = (l2_per_step * weights).sum() / denom
    pen_loss = (pen_per_step * weights).sum() / denom
    return l2_loss + pen_weight * pen_loss

# ---------- Plotting ----------
def stroke_to_points(seq, pen_threshold=0.5):
    """Integrate stroke-3 offsets into absolute points with stroke breaks.

    Args:
        seq: (N, 3) tensor or array-like of ``[dx, dy, pen]`` rows. A pen value
            at or above ``pen_threshold`` marks the last point of a stroke.
        pen_threshold: cut-off for the pen channel; 0.5 suits 0/1 labels and
            probabilities, 0.0 suits raw logits.

    Returns:
        List of ``(x, y)`` positions with ``None`` inserted after every pen-up
        point to separate strokes.
    """
    if torch.is_tensor(seq):
        seq = seq.detach().cpu().numpy()
    points = []
    x = y = 0.0
    for dx, dy, pen in seq:
        x += dx
        y += dy
        # The pen-up row is the stroke's final point: keep it, then break.
        points.append((x, y))
        if pen >= pen_threshold:
            points.append(None)
    return points

def plot_stroke(points, ax, title=""):
    """Draw ``points`` on ``ax`` as black polylines split at ``None`` entries.

    Each run of consecutive points becomes one ``ax.plot`` call. The axes then
    get ``title``, an equal aspect ratio and an inverted y-axis.
    """
    segment = []
    # A trailing None flushes the final run without a separate post-loop branch.
    for pt in list(points) + [None]:
        if pt is not None:
            segment.append(pt)
        elif segment:
            ax.plot([p[0] for p in segment], [p[1] for p in segment], color='black')
            segment = []
    ax.set_title(title)
    ax.axis("equal")
    ax.invert_yaxis()

def visualize_mdn_vs_mse(model_mdn, model_mse, dataset):
    """Plot ground truth, input, MDN sample and MSE completion for every sample.

    Args:
        model_mdn: MDN decoder exposing ``input_proj`` and ``pos_encoder``; it
            is passed to ``sample_completion``, which must be defined elsewhere.
        model_mse: MSE decoder exposing ``input_proj`` and ``pos_encoder``;
            assumed to live on the same device as ``model_mdn``.
        dataset: indexable dataset of ``(first_half, second_half)`` tensors.
    """
    device = next(model_mdn.parameters()).device
    model_mdn.eval()
    model_mse.eval()
    for i in range(len(dataset)):
        first_half, second_half = dataset[i]
        first_half = first_half.unsqueeze(0).to(device)
        second_half = second_half.unsqueeze(0).to(device)

        with torch.no_grad():
            # Each model must encode the context with its OWN input projection
            # and positional encoding; the MSE decoder was never trained on
            # memory produced by the MDN model's layers.
            memory_mdn = model_mdn.pos_encoder(model_mdn.input_proj(first_half))
            memory_mse = model_mse.pos_encoder(model_mse.input_proj(first_half))

            # MDN sampling (sample_completion must be defined separately).
            # NOTE(review): confirm it returns points in the same frame as the
            # MSE panel (continuing from the end of the input).
            mdn_points = sample_completion(model_mdn, memory_mdn)

            # MSE: deterministic teacher-forced prediction of the second half.
            mse_pred = model_mse(memory_mse, second_half[:, :-1, :])[0]

        # Pen channel is a logit; convert to a probability for the 0.5 threshold.
        mse_pred = torch.cat([mse_pred[:, :2], torch.sigmoid(mse_pred[:, 2:])], dim=-1)
        # mse_pred[t] predicts second_half[t + 1]; chain it after the input and
        # the unpredicted first row so the completion continues the drawing.
        mse_completion = torch.cat([first_half[0], second_half[0, :1], mse_pred], dim=0)
        mse_points = stroke_to_points(mse_completion)

        input_points = stroke_to_points(first_half[0])
        gt_points = stroke_to_points(torch.cat([first_half[0], second_half[0]], dim=0))

        fig, axs = plt.subplots(1, 4, figsize=(16, 4))
        plot_stroke(gt_points, axs[0], "Ground Truth")
        plot_stroke(input_points, axs[1], "Input")
        plot_stroke(mdn_points, axs[2], "MDN Completion")
        plot_stroke(mse_points, axs[3], "MSE Completion")
        plt.tight_layout()
        print(f"Sample {i+1}")
        plt.show()

# (sample_completion should be implemented if you're using this block)
# (you can now call visualize_mdn_vs_mse(model_mdn, model_mse, test_dataset))


In [None]:
def visualize_mse_on_test(model, test_dataset):
    """Plot ground truth, input and MSE completion for every test sample.

    Predictions are teacher-forced (each step sees the ground-truth previous
    row of the second half), so they show chained one-step-ahead predictions
    rather than a free-running generation.

    Args:
        model: decoder exposing ``input_proj`` and ``pos_encoder``; it is put
            into eval mode.
        test_dataset: indexable dataset of ``(first_half, second_half)`` tensors.
    """
    device = next(model.parameters()).device
    model.eval()
    for i in range(len(test_dataset)):
        first_half, second_half = test_dataset[i]
        first_half = first_half.unsqueeze(0).to(device)
        second_half = second_half.unsqueeze(0).to(device)

        # Predict the second half deterministically, without building a graph.
        with torch.no_grad():
            memory = model.pos_encoder(model.input_proj(first_half))
            pred = model(memory, second_half[:, :-1, :])[0]

        # Pen channel is a logit; convert to a probability for the 0.5 threshold.
        pred = torch.cat([pred[:, :2], torch.sigmoid(pred[:, 2:])], dim=-1)
        # pred[t] predicts second_half[t + 1]; chain it after the input and the
        # unpredicted first row so the completion continues the drawing instead
        # of restarting at the origin.
        completion = torch.cat([first_half[0], second_half[0, :1], pred], dim=0)

        input_points = stroke_to_points(first_half[0])
        gt_points = stroke_to_points(torch.cat([first_half[0], second_half[0]], dim=0))
        completed_points = stroke_to_points(completion)

        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        plot_stroke(gt_points, axs[0], "Ground Truth")
        plot_stroke(input_points, axs[1], "Input (First Half)")
        plot_stroke(completed_points, axs[2], "MSE Completion")
        plt.suptitle(f"Sample {i+1}", fontsize=16)
        plt.tight_layout()
        plt.show()
        # Release the figure; this loop can open one per sample (e.g. 100).
        plt.close(fig)

# If no test split exists yet, draw a fixed one; otherwise reuse the saved
# indices so every model is evaluated on the same samples (regenerating them
# here would silently overwrite the shared test_indices.pt).
import os

if os.path.exists("test_indices.pt"):
    test_indices = torch.load("test_indices.pt")
else:
    torch.manual_seed(99)
    test_indices = torch.randperm(len(dataset))[:100]
    torch.save(test_indices, "test_indices.pt")
test_dataset = torch.utils.data.Subset(dataset, test_indices)
# Then visualize every sample.
visualize_mse_on_test(model, test_dataset)
