In [None]:
import concurrent.futures
import os
import xml.etree.ElementTree as ET
import requests

def fetch_xml(xml_url, timeout=30):
    """Download the XML document at ``xml_url``.

    Args:
        xml_url: URL of the XML listing to fetch.
        timeout: Seconds to wait on the server before giving up. Without a
            timeout a stalled connection would block the notebook forever.

    Returns:
        The raw response body as ``bytes``, or ``None`` when the request
        fails or the server answers with a non-200 status. In both failure
        cases an ``[ERROR]`` line is printed.
    """
    try:
        response = requests.get(xml_url, timeout=timeout)
    except requests.RequestException as exc:
        # Network-level failures are reported the same way as bad status
        # codes, so callers only ever have to handle the ``None`` result.
        print(f"[ERROR] Failed to fetch XML: {exc}")
        return None
    if response.status_code != 200:
        print(f"[ERROR] Failed to fetch XML: {response.status_code}")
        return None
    return response.content

def parse_xml(xml_content, file_ext, prefix_filter):
    """Extract download URLs from an S3-style bucket listing.

    Args:
        xml_content: The listing document (``str`` or ``bytes``) containing
            ``<Contents><Key>...</Key></Contents>`` entries in the
            ``http://doc.s3.amazonaws.com/2006-03-01`` namespace.
        file_ext: Extension to keep, without the leading dot (e.g. ``"npy"``).
        prefix_filter: If non-empty, only keys starting with this prefix are
            kept.

    Returns:
        A list of URLs (the quickdraw bucket base URL plus each matching key),
        in document order.
    """
    root = ET.fromstring(xml_content)
    namespace = {"s3": "http://doc.s3.amazonaws.com/2006-03-01"}
    base_url = "https://storage.googleapis.com/quickdraw_dataset/"
    suffix = f".{file_ext}"
    file_urls = []

    for content in root.findall(".//s3:Contents", namespace):
        # findtext returns None for a missing <Key> and "" for an empty one;
        # skip both instead of failing with AttributeError on ``.text``.
        key = content.findtext("s3:Key", default=None, namespaces=namespace)
        if not key:
            continue
        if prefix_filter and not key.startswith(prefix_filter):
            continue
        if key.endswith(suffix):
            file_urls.append(base_url + key)
    return file_urls

def download_file(file_url, download_folder, timeout=60):
    """Download ``file_url`` into ``download_folder`` unless it already exists.

    The target name is the last path component of the URL. The body is
    streamed to ``<name>.part`` and renamed once complete, so an interrupted
    transfer never leaves a truncated file that the exists-check would later
    mistake for a finished download.

    Args:
        file_url: URL of the file to fetch.
        download_folder: Existing directory to write the file into.
        timeout: Seconds to wait on the server before giving up.

    Returns:
        None. Progress and failures are reported with ``print``.
    """
    file_path = os.path.join(download_folder, os.path.basename(file_url))
    if os.path.exists(file_path):
        print(f"[SKIP] Already exists: {file_path}")
        return
    print(f"[DOWNLOAD] {file_url}")

    partial_path = file_path + ".part"
    try:
        # stream=True avoids holding the whole file in memory, which matters
        # when several downloads run in parallel.
        with requests.get(file_url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"[FAIL] Could not download: {file_url} - Status code: {response.status_code}")
                return
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(partial_path, file_path)
    except (requests.RequestException, OSError) as exc:
        print(f"[FAIL] Could not download: {file_url} - {exc}")
        # Drop the incomplete temp file so nothing half-written is left behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)

def download_quickdraw_files(xml_url, download_folder, file_type="npy", prefix_filter=""):
    """Download every file of one type listed in a bucket XML index.

    Args:
        xml_url: URL of the bucket listing to read.
        download_folder: Directory to download into; created if missing.
        file_type: Extension (without dot) of the files to download.
        prefix_filter: If non-empty, only keys starting with this prefix are
            downloaded.

    Returns:
        None. Returns early, without downloading, if the listing cannot be
        fetched.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(download_folder, exist_ok=True)

    xml_content = fetch_xml(xml_url)
    if xml_content is None:
        return

    # NOTE(review): only this single listing response is processed. If the
    # server paginates long listings, later pages are ignored — confirm the
    # listing is complete for the prefixes used.
    file_urls = parse_xml(xml_content, file_ext=file_type, prefix_filter=prefix_filter)

    print(f"[INFO] Found {len(file_urls)} .{file_type} files. Downloading to '{download_folder}'")
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        pending = {
            executor.submit(download_file, url, download_folder): url
            for url in file_urls
        }
        # Inspect every future: exceptions raised inside a worker are stored
        # on the future and would otherwise disappear without a trace.
        for future in concurrent.futures.as_completed(pending):
            error = future.exception()
            if error is not None:
                print(f"[FAIL] Could not download: {pending[future]} - {error}")


# Two download jobs, run in order:
#   1. every .npy file in the unfiltered bucket listing -> image_folder
#   2. every .npz file under the sketchrnn/ prefix      -> strokes_data
for job in (
    {
        "xml_url": "https://storage.googleapis.com/quickdraw_dataset/",
        "download_folder": "image_folder",
        "file_type": "npy",
        "prefix_filter": "",
    },
    {
        "xml_url": "https://storage.googleapis.com/quickdraw_dataset?prefix=sketchrnn/",
        "download_folder": "strokes_data",
        "file_type": "npz",
        "prefix_filter": "sketchrnn/",
    },
):
    download_quickdraw_files(**job)


In [1]:
import os
import numpy as np
import tensorflow as tf
from sklearn.utils import shuffle
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input, LSTM, Dense, Dropout, Bidirectional,
    Conv2D, MaxPooling2D, GlobalAveragePooling2D,
    Concatenate, BatchNormalization
)
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau


# --- Stroke-sequence input -------------------------------------------------
MAX_SEQ_LEN = 130          # every stroke sequence is truncated/padded to this many rows
STROKE_FEATURES = 3        # columns per stroke row

# --- Image input -----------------------------------------------------------
IMG_HEIGHT = 28
IMG_WIDTH = 28
IMG_CHANNELS = 1

# --- Dataset ---------------------------------------------------------------
NUM_CLASSES = 345          # number of classes to load and predict
SAMPLES_PER_CLASS = 5000   # default number of samples taken per class
DATA_DIR_STROKES = "strokes_data"   # directory holding the .npz stroke files
DATA_DIR_IMAGES = "image_folder"    # directory holding the .npy image files

# --- Training --------------------------------------------------------------
VALIDATION_SPLIT = 0.1     # fraction of the shuffled data held out for validation
BATCH_SIZE = 128
EPOCHS = 20


# Set the global Keras dtype policy to 'mixed_float16'. This runs at import
# time of the cell, i.e. before build_hybrid_model creates any layers; that
# function explicitly pins only its final softmax layer to float32.
tf.keras.mixed_precision.set_global_policy('mixed_float16')


def preprocess_stroke(stroke, max_len=MAX_SEQ_LEN):
    """Turn one stroke array into a fixed-length, normalised float32 sequence.

    Columns 0 and 1 are accumulated with a running sum, centred on their mean
    and scaled (with a shared factor) so that the largest absolute value is
    100. Any other column is passed through unchanged. The result is then
    truncated or zero-padded to ``max_len`` rows.

    Args:
        stroke: 2-D array with ``STROKE_FEATURES`` columns, one row per step.
            # assumes columns 0/1 are per-step x/y offsets — TODO confirm
        max_len: Number of rows in the returned array.

    Returns:
        ``float32`` array with ``max_len`` rows. The input array is not
        modified.
    """
    # astype returns a copy, so the in-place edits below never touch the
    # caller's array.
    stroke = stroke.astype(np.float32)

    # Normalise only when there is at least one row: the mean of an empty
    # column is NaN and triggers a RuntimeWarning.
    if len(stroke) > 0:
        stroke[:, 0] = np.cumsum(stroke[:, 0])
        stroke[:, 1] = np.cumsum(stroke[:, 1])
        stroke[:, 0] -= stroke[:, 0].mean()
        stroke[:, 1] -= stroke[:, 1].mean()

        max_coord = max(np.abs(stroke[:, 0]).max(), np.abs(stroke[:, 1]).max())
        if max_coord > 0:
            # One factor for both axes keeps the aspect ratio intact.
            scale = 100.0 / max_coord
            stroke[:, 0] *= scale
            stroke[:, 1] *= scale

    if len(stroke) > max_len:
        return stroke[:max_len]
    pad = np.zeros((max_len - len(stroke), STROKE_FEATURES), dtype=np.float32)
    return np.vstack([stroke, pad])


def load_hybrid_data(N=SAMPLES_PER_CLASS):
    """Load up to ``N`` image and stroke samples per class as shuffled arrays.

    Classes are the file stems present in both ``DATA_DIR_IMAGES`` (``.npy``)
    and ``DATA_DIR_STROKES`` (``.npz``); the first ``NUM_CLASSES`` of them in
    sorted order are used and labelled by their position in that order.

    Args:
        N: Maximum number of samples taken per class. If a class provides
            fewer images or strokes, both are cut to the smaller count so
            that features and labels stay aligned.

    Returns:
        ``((X_str, X_img), y_cat)`` where ``X_str`` is float16 with shape
        ``(total, MAX_SEQ_LEN, STROKE_FEATURES)``, ``X_img`` is float16 with
        shape ``(total, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)`` scaled by
        1/255, and ``y_cat`` is the one-hot label matrix with ``NUM_CLASSES``
        columns. All three are shuffled together with a fixed seed.

    Raises:
        AssertionError: If fewer than ``NUM_CLASSES`` classes are available
            in both directories.
    """
    img_names = {
        os.path.splitext(f)[0] for f in os.listdir(DATA_DIR_IMAGES) if f.endswith(".npy")
    }
    stroke_names = {
        os.path.splitext(f)[0] for f in os.listdir(DATA_DIR_STROKES) if f.endswith(".npz")
    }
    common = sorted(img_names & stroke_names)

    assert len(common) >= NUM_CLASSES, f"Only {len(common)} classes found, need {NUM_CLASSES}"
    common = common[:NUM_CLASSES]

    X_img_list, X_str_list, y_list = [], [], []
    for idx, cls in enumerate(common):
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects from these files; only load data from a trusted source.
        img_arr = np.load(os.path.join(DATA_DIR_IMAGES, f"{cls}.npy"), allow_pickle=True, encoding="latin1")[:N]
        # The context manager closes the archive once the array is extracted.
        with np.load(os.path.join(DATA_DIR_STROKES, f"{cls}.npz"), allow_pickle=True, encoding="latin1") as data:
            strokes = data["train"][:N]

        # A class may hold fewer than N images or strokes; use the common
        # count so images, strokes and labels always have equal length.
        # NOTE(review): image i and stroke i are paired purely by position —
        # presumably they should describe the same drawing; confirm that the
        # two files are ordered consistently.
        n = min(len(img_arr), len(strokes))
        img_arr, strokes = img_arr[:n], strokes[:n]

        img_arr = img_arr.reshape(-1, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS).astype("float16") / 255.0
        X_img_list.append(img_arr)

        # Fill a preallocated float16 buffer: works for n == 0 and avoids a
        # temporary float32 copy of the whole class.
        proc = np.zeros((n, MAX_SEQ_LEN, STROKE_FEATURES), dtype=np.float16)
        for i, s in enumerate(strokes):
            proc[i] = preprocess_stroke(s, max_len=MAX_SEQ_LEN)
        X_str_list.append(proc)

        y_list.append(np.full((n,), idx, dtype=np.int32))

    X_img = np.concatenate(X_img_list, axis=0)
    X_str = np.concatenate(X_str_list, axis=0)
    y = np.concatenate(y_list, axis=0)
    X_img, X_str, y = shuffle(X_img, X_str, y, random_state=42)

    y_cat = to_categorical(y, num_classes=NUM_CLASSES)
    return (X_str, X_img), y_cat


def tf_dataset(X_img, X_str, y, batch_size, shuffle_buffer=10000):
    """Wrap the arrays in a batched, prefetching ``tf.data.Dataset``.

    Note the argument order: images come first here, but each dataset element
    is ``((stroke, image), label)`` — stroke first.

    Args:
        X_img: Image array, one sample per row.
        X_str: Stroke-sequence array, aligned with ``X_img``.
        y: Labels, aligned with ``X_img``.
        batch_size: Number of samples per batch.
        shuffle_buffer: Size of the shuffle buffer. Pass ``0`` or ``None`` to
            skip shuffling, e.g. for validation data where the order does not
            matter.

    Returns:
        A ``tf.data.Dataset`` yielding ``((stroke_batch, image_batch), label_batch)``.
    """
    ds = tf.data.Dataset.from_tensor_slices(((X_str, X_img), y))
    if shuffle_buffer:
        ds = ds.shuffle(shuffle_buffer)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


def build_hybrid_model(num_classes):
    """Build the two-branch (stroke sequence + image) classifier.

    The stroke branch stacks three bidirectional LSTMs; the image branch is a
    four-stage convolutional stack ending in global average pooling. Both
    branches are projected to 128 features, concatenated and classified.

    Args:
        num_classes: Width of the final softmax layer.

    Returns:
        An uncompiled Keras ``Model`` named ``"hybrid_model"`` whose inputs
        are ``[stroke_input, image_input]``.
    """
    # ---- stroke branch ----
    stroke_in = Input(shape=(MAX_SEQ_LEN, STROKE_FEATURES), name="stroke_input")
    seq = stroke_in
    for lstm_units in (128, 128):
        seq = Bidirectional(LSTM(lstm_units, return_sequences=True))(seq)
        seq = Dropout(0.3)(seq)
    seq = Bidirectional(LSTM(64))(seq)
    seq = Dense(128, activation="relu")(seq)

    # ---- image branch ----
    image_in = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), name="image_input")
    img = Conv2D(32, 3, activation="relu", padding="same")(image_in)
    img = MaxPooling2D()(img)
    for n_filters in (64, 128):
        img = Conv2D(n_filters, 3, activation="relu", padding="same")(img)
        img = BatchNormalization()(img)
        img = MaxPooling2D()(img)
    img = Conv2D(256, 3, activation="relu", padding="same")(img)
    img = GlobalAveragePooling2D()(img)
    img = Dense(128, activation="relu")(img)

    # ---- fusion head ----
    fused = Concatenate()([seq, img])
    fused = Dropout(0.5)(fused)
    fused = Dense(256, activation="relu")(fused)
    fused = Dropout(0.3)(fused)
    # The output layer is explicitly float32, overriding the global policy.
    probs = Dense(num_classes, activation="softmax", dtype="float32")(fused)

    return Model(inputs=[stroke_in, image_in], outputs=probs, name="hybrid_model")




In [2]:
# Build the network and compile it with Adam and label-smoothed categorical
# cross-entropy, which expects one-hot labels.
model = build_hybrid_model(NUM_CLASSES)
model.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
    optimizer="adam",
    metrics=["accuracy"],
)

model.summary()

In [None]:
(X_str, X_img), y = load_hybrid_data(N=SAMPLES_PER_CLASS)

# The model was built before the data was loaded, so verify that the label
# width matches its output layer instead of silently overwriting the
# NUM_CLASSES config constant after the fact.
assert y.shape[1] == model.output_shape[-1], (
    f"Labels have {y.shape[1]} classes but the model outputs {model.output_shape[-1]}"
)

# The arrays are already shuffled by load_hybrid_data, so a plain tail split
# gives a random validation set.
total = X_img.shape[0]
split = int((1 - VALIDATION_SPLIT) * total)

X_str_train, X_str_val = X_str[:split], X_str[split:]
X_img_train, X_img_val = X_img[:split], X_img[split:]
y_train, y_val = y[:split], y[split:]

# tf_dataset takes the image array first and the stroke array second.
train_ds = tf_dataset(X_img_train, X_str_train, y_train, batch_size=BATCH_SIZE)
val_ds = tf_dataset(X_img_val, X_str_val, y_val, batch_size=BATCH_SIZE)


callbacks = [
    # Keep the weights of the epoch with the best validation accuracy on disk.
    ModelCheckpoint("best_model_345_classes.keras", monitor="val_accuracy", save_best_only=True),
    # Stop after 5 epochs without val_loss improvement and roll back to the best weights.
    EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
    # Halve the learning rate after 3 epochs without val_loss improvement.
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
]

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
    callbacks=callbacks,
)
