In [1]:
# @title Download .npy & .npz Dataset

import concurrent.futures
import os
import xml.etree.ElementTree as ET

import requests

# Shared class names (without extension; download_quickdraw_files appends
# ".npy" / ".npz" depending on the requested file type).
# The entries are in alphabetical order. load_hybrid_data assigns integer
# labels from a sorted() list of class names and classification_report is later
# called with target_names=base_classes, so the order here is what makes the
# report's row names line up with the label indices.
base_classes = [
    "backpack",
    "banana",
    "bat",
    "beard",
    "bicycle",
    "bird",
    "book",
    "bread",
    "bridge",
    "bucket",
    "bush",
    "butterfly",
    "cactus",
    "camel",
    "camera",
    "candle",
    "cow",
    "crab",
    "crown",
    "cup",
    "donut",
    "dumbbell",
    "elbow",
    "eye",
    "fish",
    "flashlight",
    "flip flops",
    "flower",
    "foot",
    "hat",
    "helicopter",
    "hot air balloon",
    "leaf",
    "leg",
    "light bulb",
    "lightning",
    "motorbike",
    "mouth",
    "nail",
    "pencil",
    "pillow",
    "river",
    "school bus",
    "sock",
    "spoon",
    "table",
    "telephone",
    "tooth",
    "tree",
    "umbrella",
]


def fetch_xml(xml_url):
    """
    Download the bucket listing XML from `xml_url`.

    Args:
        xml_url: URL of the XML listing to fetch.

    Returns:
        The raw response body as bytes, or None when the request fails or the
        server answers with a non-200 status (an [ERROR] line is printed in
        both cases).
    """
    try:
        # A timeout keeps a stalled connection from hanging the notebook.
        response = requests.get(xml_url, timeout=30)
    except requests.RequestException as exc:
        # Network-level failures follow the same "return None" contract as
        # HTTP errors so the caller needs only one check.
        print(f"[ERROR] Failed to fetch XML: {exc}")
        return None
    if response.status_code != 200:
        print(f"[ERROR] Failed to fetch XML: {response.status_code}")
        return None
    return response.content


def parse_xml(xml_content, selected_files, prefix_filter):
    """
    Extract download URLs for the wanted files from a bucket listing.

    Args:
        xml_content: Listing XML (bytes or str) whose <Contents>/<Key> elements
            live in the "http://doc.s3.amazonaws.com/2006-03-01" namespace.
        selected_files: Iterable of file names (base name only, with extension)
            that should be kept.
        prefix_filter: When non-empty, only keys starting with this prefix are
            considered.

    Returns:
        List of full URLs (quickdraw bucket base URL + key), in listing order.
    """
    root = ET.fromstring(xml_content)
    namespace = {"s3": "http://doc.s3.amazonaws.com/2006-03-01"}
    base_url = "https://storage.googleapis.com/quickdraw_dataset/"
    # Set membership avoids rescanning the whole name list for every key.
    wanted = set(selected_files)
    file_urls = []

    for content in root.findall(".//s3:Contents", namespace):
        key_node = content.find("s3:Key", namespace)
        key = key_node.text if key_node is not None else None
        if not key:
            # Entry without a (non-empty) key: nothing to download.
            continue
        if prefix_filter and not key.startswith(prefix_filter):
            continue
        if os.path.basename(key) in wanted:
            file_urls.append(base_url + key)
    return file_urls


def download_file(file_url, download_folder):
    """
    Download `file_url` into `download_folder` unless the file already exists.

    The target name is the base name of the URL. The body is streamed to a
    temporary "<name>.part" file and renamed only after the transfer finished,
    so an interrupted download never leaves a truncated file that a later run
    would skip as "Already exists".

    Args:
        file_url: URL of the file to fetch.
        download_folder: Existing directory to store the file in.

    Returns:
        None. Progress and failures are reported through printed
        [SKIP] / [DOWNLOAD] / [FAIL] lines.
    """
    file_path = os.path.join(download_folder, os.path.basename(file_url))
    if os.path.exists(file_path):
        print(f"[SKIP] Already exists: {file_path}")
        return
    print(f"[DOWNLOAD] {file_url}")
    tmp_path = file_path + ".part"
    try:
        # stream=True avoids holding a whole dataset file in memory at once.
        with requests.get(file_url, stream=True, timeout=60) as response:
            if response.status_code != 200:
                print(
                    f"[FAIL] Could not download: {file_url} - Status code: {response.status_code}"
                )
                return
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, file_path)
    except (requests.RequestException, OSError) as exc:
        print(f"[FAIL] Could not download: {file_url} - {exc}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def download_quickdraw_files(
    xml_url, download_folder, file_type="npy", prefix_filter=""
):
    """
    Download one file per entry of `base_classes` from a bucket listing.

    Args:
        xml_url: URL of the XML listing that enumerates the available files.
        download_folder: Directory to store the files in (created if missing).
        file_type: Extension (without dot) appended to every class name.
        prefix_filter: Optional key prefix forwarded to parse_xml.

    Returns:
        None. Returns early when the listing cannot be fetched.
    """
    selected_files = [f"{name}.{file_type}" for name in base_classes]

    os.makedirs(download_folder, exist_ok=True)

    xml_content = fetch_xml(xml_url)
    if xml_content is None:
        return

    file_urls = parse_xml(xml_content, selected_files, prefix_filter)

    # Report classes the listing did not contain instead of silently
    # downloading fewer files than requested.
    found = {os.path.basename(url) for url in file_urls}
    missing = [name for name in selected_files if name not in found]
    if missing:
        print(f"[WARN] {len(missing)} requested files not found in listing: {missing}")

    print(f"[INFO] Downloading {len(file_urls)} files to '{download_folder}'")
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = {
            executor.submit(download_file, url, download_folder): url
            for url in file_urls
        }
        # Collect every result: an unconsumed executor.map() would swallow
        # any exception raised inside a worker thread.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as exc:
                print(f"[FAIL] Could not download: {futures[future]} - {exc}")


# One download job per QuickDraw file type: the bitmap images (.npy) first,
# then the stroke sequences (.npz) stored under the "sketchrnn/" prefix.
for job in (
    {
        "xml_url": "https://storage.googleapis.com/quickdraw_dataset/",
        "download_folder": "image_folder",
        "file_type": "npy",
        "prefix_filter": "",
    },
    {
        "xml_url": "https://storage.googleapis.com/quickdraw_dataset?prefix=sketchrnn/",
        "download_folder": "strokes_data",
        "file_type": "npz",
        "prefix_filter": "sketchrnn/",
    },
):
    download_quickdraw_files(**job)

[INFO] Downloading 50 files to 'image_folder'
[SKIP] Already exists: image_folder/backpack.npy
[SKIP] Already exists: image_folder/banana.npy
[SKIP] Already exists: image_folder/beard.npy
[SKIP] Already exists: image_folder/bird.npy
[SKIP] Already exists: image_folder/bread.npy
[SKIP] Already exists: image_folder/bridge.npy
[SKIP] Already exists: image_folder/bush.npy
[SKIP] Already exists: image_folder/cactus.npy
[SKIP] Already exists: image_folder/candle.npy
[SKIP] Already exists: image_folder/camel.npy
[SKIP] Already exists: image_folder/bat.npy
[SKIP] Already exists: image_folder/butterfly.npy
[SKIP] Already exists: image_folder/bicycle.npy
[SKIP] Already exists: image_folder/book.npy
[SKIP] Already exists: image_folder/bucket.npy
[SKIP] Already exists: image_folder/camera.npy
[SKIP] Already exists: image_folder/cow.npy
[SKIP] Already exists: image_folder/crab.npy
[SKIP] Already exists: image_folder/cup.npy
[SKIP] Already exists: image_folder/dumbbell.npy
[SKIP] Already exists: ima

In [2]:
# All necessary imports
import os
import numpy as np
from tensorflow.keras.utils import to_categorical
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.layers import (
    Input,
    LSTM,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Dropout,
    Concatenate,
)
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.layers import (
    Bidirectional,
    Input,
    LSTM,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Dropout,
    Concatenate,
)
from tensorflow.keras.models import Model

In [3]:
# Global Settings

# Fixed number of time steps per stroke sequence: the default max_len of
# preprocess_stroke and the time dimension of the model's stroke input.
MAX_SEQ_LEN = 130
# Number of columns per stroke point: used as the padding width in
# preprocess_stroke and as the feature dimension of the stroke input.
STROKE_FEATURES = 3
# Bitmap dimensions the loaded image arrays are reshaped to.
IMG_HEIGHT, IMG_WIDTH = 28, 28
IMG_CHANNELS = 1
# Upper bound on the number of classes loaded; reassigned after loading to the
# number of classes actually found.
NUM_CLASSES = 50
# Default number of samples taken from the start of each class file.
SAMPLES_PER_CLASS = 5000
# Folders the download cell writes to and load_hybrid_data reads from.
DATA_DIR_STROKES = "strokes_data"
DATA_DIR_IMAGES = "image_folder"
# Fraction of the shuffled data held out (as the tail) for validation.
VALIDATION_SPLIT = 0.1
# Batch size used when extracting features with predict().
BATCH_SIZE = 128
# NOTE(review): not referenced by any cell visible in this notebook.
EPOCHS = 20

In [4]:

def preprocess_stroke(stroke, max_len=MAX_SEQ_LEN):
    """
    Normalize one stroke sequence and force it to a fixed length.

    Columns 0 and 1 are treated as per-step offsets: they are accumulated into
    absolute coordinates, centered on their mean, and scaled by a single factor
    so that the largest absolute coordinate becomes 100 (range [-100, 100]).
    Any further column is passed through unchanged.

    Args:
        stroke: 2-D numpy array with one row per point and STROKE_FEATURES
            columns. An empty (0-row) array is allowed.
        max_len: Number of rows in the result.

    Returns:
        float32 array with `max_len` rows: the normalized sequence truncated to
        `max_len` rows, or zero-padded at the end when it is shorter.
    """
    # astype() copies, so the caller's array is never modified.
    stroke = stroke.astype(np.float32)

    # Skip normalization for an empty sequence: the mean of zero rows is NaN
    # (and emits a RuntimeWarning); the result is all padding anyway.
    if len(stroke) > 0:
        # Convert to absolute coordinates
        stroke[:, 0] = np.cumsum(stroke[:, 0])
        stroke[:, 1] = np.cumsum(stroke[:, 1])

        # Center to (0, 0)
        stroke[:, 0] -= stroke[:, 0].mean()
        stroke[:, 1] -= stroke[:, 1].mean()

        # One shared factor for both axes keeps the drawing's aspect ratio.
        max_coord = max(np.abs(stroke[:, 0]).max(), np.abs(stroke[:, 1]).max())

        # A sequence whose points all coincide has max_coord == 0: leave it.
        if max_coord > 0:
            scale_factor = 100.0 / max_coord
            stroke[:, 0] *= scale_factor
            stroke[:, 1] *= scale_factor

    # Truncate or pad to exactly max_len rows
    if len(stroke) > max_len:
        return stroke[:max_len]

    pad = np.zeros((max_len - len(stroke), STROKE_FEATURES), dtype=np.float32)
    return np.vstack([stroke, pad])


In [5]:
def load_hybrid_data(N=SAMPLES_PER_CLASS):
    """
    Load up to N image/stroke samples per class and build a shuffled dataset.

    Classes are the names present both as "<name>.npy" in DATA_DIR_IMAGES and
    as "<name>.npz" in DATA_DIR_STROKES, sorted alphabetically and capped at
    NUM_CLASSES; the position in that sorted list is the integer label.

    Images and strokes are paired purely by position within their files.
    NOTE(review): confirm both files list the same drawings in the same order.

    Args:
        N: Maximum number of samples taken from the start of each class file.

    Returns:
        ((X_str, X_img), y_cat) where X_str holds the preprocessed stroke
        sequences, X_img the images reshaped to
        (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) and scaled by 1/255, and y_cat the
        one-hot labels; all three are shuffled with the same permutation.

    Raises:
        ValueError: If no class has both files, or a class has no samples.
    """
    img_names = {
        os.path.splitext(f)[0]
        for f in os.listdir(DATA_DIR_IMAGES)
        if f.endswith(".npy")
    }
    stroke_names = {
        os.path.splitext(f)[0]
        for f in os.listdir(DATA_DIR_STROKES)
        if f.endswith(".npz")
    }
    common = sorted(img_names & stroke_names)[:NUM_CLASSES]
    if not common:
        raise ValueError(
            f"No class has both a .npy file in '{DATA_DIR_IMAGES}' "
            f"and a .npz file in '{DATA_DIR_STROKES}'"
        )

    X_img_list, X_str_list, y_list = [], [], []
    for idx, cls in enumerate(common):
        # NOTE(review): allow_pickle=True unpickles data from the downloaded
        # files; only load files from a trusted source.
        img_arr = np.load(
            os.path.join(DATA_DIR_IMAGES, f"{cls}.npy"),
            allow_pickle=True,
            encoding="latin1",
        )[:N]

        # Context manager closes the archive's file handle after reading.
        with np.load(
            os.path.join(DATA_DIR_STROKES, f"{cls}.npz"),
            allow_pickle=True,
            encoding="latin1",
        ) as data:
            strokes = data["train"][:N]

        # A class file may hold fewer than N samples; keep images, strokes and
        # labels the same length so the three arrays stay aligned.
        n_samples = min(len(img_arr), len(strokes))
        if n_samples == 0:
            raise ValueError(f"Class '{cls}' has no samples")
        if n_samples < N:
            print(f"[WARN] Class '{cls}' has only {n_samples} of {N} requested samples")

        img_arr = (
            img_arr[:n_samples]
            .reshape(-1, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
            .astype("float32")
            / 255.0
        )
        X_img_list.append(img_arr)

        proc = np.stack([preprocess_stroke(s) for s in strokes[:n_samples]], axis=0)
        X_str_list.append(proc)

        # Same integer label for every sample of this class
        y_list.append(np.full((n_samples,), idx, dtype=np.int32))

    X_img = np.concatenate(X_img_list, axis=0)
    X_str = np.concatenate(X_str_list, axis=0)
    y = np.concatenate(y_list, axis=0)
    # Fixed seed: the later train/validation split relies on this order.
    X_img, X_str, y = shuffle(X_img, X_str, y, random_state=42)

    true_num_classes = len(common)
    y_cat = to_categorical(y, num_classes=true_num_classes)
    return (X_str, X_img), y_cat


In [6]:
def build_hybrid_model():
    """
    Assemble the two-input classifier: a recurrent branch for the stroke
    sequence and a convolutional branch for the bitmap, fused by concatenation.

    Returns:
        An uncompiled Keras Model named "hybrid_model" taking
        [stroke_input, image_input] and producing NUM_CLASSES softmax scores.
    """
    # --- Stroke branch: two stacked bidirectional LSTMs, then a compact Dense ---
    stroke_in = Input(shape=(MAX_SEQ_LEN, STROKE_FEATURES), name="stroke_input")
    sequence = Bidirectional(LSTM(128, return_sequences=True))(stroke_in)
    sequence = Bidirectional(LSTM(64))(sequence)
    stroke_features = Dense(64, activation="relu")(sequence)

    # --- Image branch: three conv + pool stages with growing filter counts ---
    image_in = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), name="image_input")
    feature_map = image_in
    for n_filters in (32, 64, 128):
        feature_map = Conv2D(n_filters, 3, activation="relu", padding="same")(
            feature_map
        )
        feature_map = MaxPooling2D()(feature_map)
    flat = Flatten()(feature_map)
    image_features = Dense(128, activation="relu")(flat)

    # --- Fusion head: concatenate both feature vectors and classify ---
    fused = Concatenate()([stroke_features, image_features])
    fused = Dropout(0.5)(fused)
    fused = Dense(128, activation="relu")(fused)
    class_scores = Dense(NUM_CLASSES, activation="softmax")(fused)

    return Model(
        inputs=[stroke_in, image_in], outputs=class_scores, name="hybrid_model"
    )


In [7]:
# Load the shuffled inputs together with their one-hot labels
(X_str, X_img), y = load_hybrid_data(N=SAMPLES_PER_CLASS)
NUM_CLASSES = y.shape[1]  # number of classes that were actually loaded
total = len(X_img)
split = int((1 - VALIDATION_SPLIT) * total)

# Train/validation split: the first `split` rows train, the remainder validates
X_str_train, X_str_val = np.split(X_str, [split])
X_img_train, X_img_val = np.split(X_img, [split])
y_train, y_val = np.split(y, [split])


In [10]:
import os

from tensorflow.keras.models import load_model

# The checkpoint location is machine-specific, so it can be overridden through
# the HYBRID_MODEL_PATH environment variable. The original absolute path stays
# as the fallback so existing runs behave exactly as before.
MODEL_PATH = os.environ.get(
    "HYBRID_MODEL_PATH",
    "/Users/aryanagarwal/Desktop/cv/project/main_paper/Doodle-vision/inference/models/best_hybrid_model_strokes_scaled.keras",
)
model = load_model(MODEL_PATH)

In [37]:
# Show the layer-by-layer structure and parameter counts of the loaded model.
model.summary()

In [17]:
# Share of the hybrid model's parameters held by the two Dense layers after the
# concatenation: Dense(128) on the 192 fused features has 24704 parameters and
# the final Dense(50) has 6450, out of 579186 in total (figures match the
# architecture defined in build_hybrid_model).
(24704+6450)/579186

0.05378928358074953

In [12]:
from tensorflow.keras.models import Model
# Truncated view of the loaded model that stops at the fusion layer, i.e. it
# outputs the concatenated stroke + image feature vector instead of class scores.
# The layer is looked up by the name "concatenate"; build_hybrid_model passes no
# explicit name, so this is Keras' auto-generated one.
# NOTE(review): a checkpoint saved with a numbered name such as "concatenate_1"
# would make this lookup fail.
feature_model = Model(inputs=model.input, outputs=model.get_layer("concatenate").output)

In [14]:
# Run both splits through the truncated network to obtain the fused feature
# vectors that the tree-based classifiers below are trained on.
X_train_feats = feature_model.predict([X_str_train, X_img_train], batch_size=BATCH_SIZE)
X_val_feats = feature_model.predict([X_str_val, X_img_val], batch_size=BATCH_SIZE)

# Collapse the one-hot label matrices back to integer class indices.
y_train_labels = np.argmax(y_train, axis=1)
y_val_labels = np.argmax(y_val, axis=1)

# Imported here; first used by the classifier cells further down.
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score



[1m1758/1758[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m176s[0m 100ms/step
[1m196/196[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m20s[0m 103ms/step
Random Forest Validation Accuracy: 0.9366


In [29]:
import os

import joblib

# Small, shallow forest fitted on the fused features
rf = RandomForestClassifier(
    n_estimators=30, max_depth=5, random_state=42, n_jobs=-1
).fit(X_train_feats, y_train_labels)

# Predictions for both splits (kept as globals for the report cell)
y_pred = rf.predict(X_val_feats)
y_pred_train = rf.predict(X_train_feats)

print("Random Forest Validation Accuracy:", accuracy_score(y_val_labels, y_pred))
print("Random Forest Training Accuracy:", accuracy_score(y_train_labels, y_pred_train))

# Persist the fitted forest and report its size on disk
rf_path = "rf_model.pkl"
joblib.dump(rf, rf_path)
print(f"Model size: {os.path.getsize(rf_path) / (1024*1024):.2f} MB")

Random Forest Validation Accuracy: 0.7964
Random Forest Training Accuracy: 0.8232533333333333
Model size: 0.86 MB


In [35]:
from sklearn.ensemble import GradientBoostingClassifier

# scikit-learn gradient boosting on the fused features; deliberately tiny
# (3 boosting rounds) to keep training time and the saved model small.
gb_model = GradientBoostingClassifier(
    n_estimators=3,       # reduce number of boosting rounds
    learning_rate=0.1,     # step size
    max_depth=5,           # depth of each tree
    # subsample=0.8,         # stochastic gradient boosting
    random_state=42,
    verbose=True
)

gb_model.fit(X_train_feats, y_train_labels)
y_pred = gb_model.predict(X_val_feats)
# The labels previously said "XGBoost", but this model is scikit-learn's
# GradientBoostingClassifier; the messages now name what was actually trained.
print("Gradient Boosting Validation Accuracy:", accuracy_score(y_val_labels, y_pred))

y_pred_train = gb_model.predict(X_train_feats)
print("Gradient Boosting Training Accuracy:", accuracy_score(y_train_labels, y_pred_train))

import joblib
import os
joblib.dump(gb_model, "gb_model.pkl")
print(f"Model size: {os.path.getsize('gb_model.pkl') / (1024*1024):.2f} MB")

      Iter       Train Loss   Remaining Time 
         1           1.2991           12.50m
         2           1.1455            6.26m
         3           1.0171            0.00s
XGBoost Validation Accuracy: 0.83344
XGBoost Training Accuracy: 0.8822
Model size: 0.91 MB


In [16]:
from sklearn.metrics import classification_report
# NOTE(review): y_pred_train / y_pred hold whatever the most recently executed
# classifier cell assigned (both the random-forest and the gradient-boosting
# cells write these names), so run the intended classifier cell right before
# this one.
# target_names=base_classes assumes label index i corresponds to
# base_classes[i] and that every class in the list was loaded — TODO confirm.
print('train cf')
print(classification_report(y_train_labels, y_pred_train, target_names=base_classes))


print('validation cf')
print(classification_report(y_val_labels, y_pred, target_names=base_classes))

train cf
                 precision    recall  f1-score   support

       backpack       1.00      1.00      1.00      4476
         banana       1.00      1.00      1.00      4484
            bat       1.00      1.00      1.00      4520
          beard       1.00      1.00      1.00      4457
        bicycle       1.00      1.00      1.00      4474
           bird       1.00      1.00      1.00      4496
           book       1.00      1.00      1.00      4521
          bread       1.00      1.00      1.00      4492
         bridge       1.00      1.00      1.00      4509
         bucket       1.00      1.00      1.00      4505
           bush       1.00      1.00      1.00      4524
      butterfly       1.00      1.00      1.00      4491
         cactus       1.00      1.00      1.00      4543
          camel       1.00      1.00      1.00      4482
         camera       1.00      1.00      1.00      4532
         candle       1.00      1.00      1.00      4497
            cow      

In [36]:
# trying knowledge distillation and quantisation

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow_model_optimization.quantization.keras import vitis_quantize

# Build a smaller student model
def build_student_model():
    """
    Build a reduced-capacity version of the hybrid stroke + image classifier.

    Uses the same two inputs ("stroke_input", "image_input") as the hybrid
    model with fewer LSTM units, two instead of three conv stages, and narrower
    Dense layers. Input shapes come from the shared notebook constants so the
    student always accepts exactly the data the teacher was built for.

    Returns:
        An uncompiled keras.Model taking [stroke_input, image_input] and
        producing NUM_CLASSES softmax scores.
    """
    inp_str = keras.Input(shape=(MAX_SEQ_LEN, STROKE_FEATURES), name="stroke_input")
    x = keras.layers.Bidirectional(keras.layers.LSTM(64, return_sequences=True))(inp_str)
    x = keras.layers.Bidirectional(keras.layers.LSTM(32))(x)
    x = keras.layers.Dense(32, activation="relu")(x)

    inp_img = keras.Input(
        shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS), name="image_input"
    )
    y = keras.layers.Conv2D(16, 3, activation="relu", padding="same")(inp_img)
    y = keras.layers.MaxPooling2D()(y)
    y = keras.layers.Conv2D(32, 3, activation="relu", padding="same")(y)
    y = keras.layers.MaxPooling2D()(y)
    y = keras.layers.Flatten()(y)
    y = keras.layers.Dense(64, activation="relu")(y)

    # Fuse both branches and classify
    merged = keras.layers.Concatenate()([x, y])
    merged = keras.layers.Dropout(0.3)(merged)
    merged = keras.layers.Dense(64, activation="relu")(merged)
    out = keras.layers.Dense(NUM_CLASSES, activation="softmax")(merged)

    return keras.Model(inputs=[inp_str, inp_img], outputs=out)

student_model = build_student_model()

# Distillation class
class Distiller(keras.Model):
    """
    Trains `student` to match both the true labels and a frozen `teacher`.

    Both networks are expected to output softmax probabilities (their final
    layers use activation="softmax"), so temperature scaling is applied in log
    space: softmax(log(p) / T) equals softmax(logits / T).
    """

    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher
        # The teacher only provides targets; it must never be updated.
        self.teacher.trainable = False

    def compile(self, optimizer, metrics, student_loss_fn, distill_loss_fn, alpha=0.5, temperature=5):
        """
        Configure the distiller.

        Args:
            optimizer: Optimizer applied to the student's weights.
            metrics: Metrics evaluated on the student's predictions.
            student_loss_fn: Loss between true labels and student predictions.
            distill_loss_fn: Loss between softened teacher and student outputs.
            alpha: Weight of the student loss; (1 - alpha) weights distillation.
            temperature: Softening temperature for the distillation targets.
        """
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distill_loss_fn = distill_loss_fn
        self.alpha = alpha
        self.temperature = temperature

    def call(self, inputs, training=False):
        """Forward pass of the wrapped student (lets Keras build/evaluate the model)."""
        return self.student(inputs, training=training)

    def _soften(self, probs):
        """Re-scale a probability distribution with the distillation temperature."""
        # Clipping keeps log() finite for probabilities that underflowed to 0.
        log_probs = tf.math.log(tf.clip_by_value(probs, 1e-7, 1.0))
        return tf.nn.softmax(log_probs / self.temperature, axis=-1)

    def _update_metrics(self, y, preds, loss):
        """Update the loss tracker (if present) and compiled metrics; return results."""
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, preds)
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        x, y = data
        # Teacher forward pass stays outside the tape: no gradients are needed.
        teacher_preds = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_preds = self.student(x, training=True)
            student_loss = self.student_loss_fn(y, student_preds)
            # T**2 keeps the soft-target gradients on the same scale as the
            # hard-label gradients when the temperature changes.
            distill_loss = self.distill_loss_fn(
                self._soften(teacher_preds),
                self._soften(student_preds),
            ) * (self.temperature ** 2)
            loss = self.alpha * student_loss + (1 - self.alpha) * distill_loss

        grads = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.student.trainable_variables))
        return self._update_metrics(y, student_preds, loss)

    def test_step(self, data):
        """Evaluate the student alone; needed because fit() receives validation_data."""
        x, y = data
        student_preds = self.student(x, training=False)
        student_loss = self.student_loss_fn(y, student_preds)
        return self._update_metrics(y, student_preds, student_loss)

# Compile distiller
# The teacher is the pretrained hybrid network loaded earlier as `model`;
# no variable named `teacher_model` is defined in this notebook.
distiller = Distiller(student=student_model, teacher=model)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"],
    student_loss_fn=keras.losses.CategoricalCrossentropy(),
    distill_loss_fn=keras.losses.KLDivergence(),
    alpha=0.5,
    temperature=5,
)

# Train
distiller.fit(
    [X_str_train, X_img_train],
    y_train,
    validation_data=([X_str_val, X_img_val], y_val),
    batch_size=128,
    epochs=10,
)

# Quantization (optional)
# NOTE(review): VitisQuantizer is part of the Vitis AI build of
# tensorflow_model_optimization; presumably the stock PyPI package does not
# ship `vitis_quantize` — confirm the environment before running this step.
quantizer = vitis_quantize.VitisQuantizer(student_model)
quantized_model = quantizer.quantize_model(calib_dataset=([X_str_train[:200], X_img_train[:200]]))

# Save quantized model
quantized_model.save("quantized_student_model.keras")


In [40]:
pip install tensorflow-model-optimization

Note: you may need to restart the kernel to use updated packages.


In [43]:
# Create converter
# NOTE(review): `model` is the hybrid model loaded from the .keras checkpoint
# earlier in the notebook, not `student_model`, even though the output file
# below is named "student_model_...".
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Allow unsupported ops (like TensorList ops used by LSTM)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]

# Prevent lowering of TensorList ops
# (underscore-prefixed, i.e. private, converter attribute)
converter._experimental_lower_tensor_list_ops = False

# Optimize (optional quantization); no representative dataset is supplied here
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# Convert
tflite_model = converter.convert()

# Save the serialized model bytes
with open("student_model_lstm_compatible.tflite", "wb") as f:
    f.write(tflite_model)

print("Saved TFLite model with Select TF Ops support.")


INFO:tensorflow:Assets written to: /var/folders/gh/1nq7ydsd0ybb433gq3p_mk5c0000gn/T/tmpgw8a17rt/assets


INFO:tensorflow:Assets written to: /var/folders/gh/1nq7ydsd0ybb433gq3p_mk5c0000gn/T/tmpgw8a17rt/assets


Saved artifact at '/var/folders/gh/1nq7ydsd0ybb433gq3p_mk5c0000gn/T/tmpgw8a17rt'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): List[TensorSpec(shape=(None, 130, 3), dtype=tf.float32, name='stroke_input'), TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='image_input')]
Output Type:
  TensorSpec(shape=(None, 50), dtype=tf.float32, name=None)
Captures:
  14250378064: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250378640: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250379792: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250379024: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250378832: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250377872: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250382096: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250381520: TensorSpec(shape=(), dtype=tf.resource, name=None)
  14250384016: TensorSpec(shape=(), dtype=tf.resource, name=None)
 

W0000 00:00:1749715053.365126  124644 tf_tfl_flatbuffer_helpers.cc:365] Ignored output_format.
W0000 00:00:1749715053.365135  124644 tf_tfl_flatbuffer_helpers.cc:368] Ignored drop_control_dependency.
2025-06-12 13:27:33.573073: W tensorflow/compiler/mlir/lite/flatbuffer_export.cc:3993] TFLite interpreter needs to link Flex delegate in order to run the model since it contains the following Select TFop(s):
Flex ops: FlexTensorListReserve, FlexTensorListSetItem, FlexTensorListStack
Details:
	tf.TensorListReserve(tensor<2xi32>, tensor<i32>) -> (tensor<!tf_type.variant<tensor<?x128xf32>>>) : {device = ""}
	tf.TensorListReserve(tensor<2xi32>, tensor<i32>) -> (tensor<!tf_type.variant<tensor<?x64xf32>>>) : {device = ""}
	tf.TensorListSetItem(tensor<!tf_type.variant<tensor<?x128xf32>>>, tensor<i32>, tensor<?x128xf32>) -> (tensor<!tf_type.variant<tensor<?x128xf32>>>) : {device = "", resize_if_index_out_of_bounds = false}
	tf.TensorListSetItem(tensor<!tf_type.variant<tensor<?x64xf32>>>, tensor<i3

In [44]:
import os

# Size of the converted model on disk, in mebibytes
file_path = "student_model_lstm_compatible.tflite"
file_size_mb = os.stat(file_path).st_size / (1024 * 1024)
print(f"Model size: {file_size_mb:.2f} MB")

Model size: 0.64 MB


In [45]:
import tensorflow as tf
import numpy as np

# Load the converted model and allocate its tensors so it can be invoked.
# NOTE(review): the captured output of this cell carries a deprecation notice
# pointing to the LiteRT interpreter from the ai_edge_litert package.
interpreter = tf.lite.Interpreter(model_path="student_model_lstm_compatible.tflite")
interpreter.allocate_tensors()

# Per-tensor metadata (name, index, shape, dtype, quantization) for the model's
# inputs and outputs; the evaluation loop uses the 'index' entries.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

print("Input details:", input_details)
print("Output details:", output_details)


Input details: [{'name': 'serving_default_image_input:0', 'index': 0, 'shape': array([ 1, 28, 28,  1], dtype=int32), 'shape_signature': array([-1, 28, 28,  1], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}, {'name': 'serving_default_stroke_input:0', 'index': 1, 'shape': array([  1, 130,   3], dtype=int32), 'shape_signature': array([ -1, 130,   3], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_parameters': {'scales': array([], dtype=float32), 'zero_points': array([], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
Output details: [{'name': 'StatefulPartitionedCall_1:0', 'index': 92, 'shape': array([ 1, 50], dtype=int32), 'shape_signature': array([-1, 50], dtype=int32), 'dtype': <class 'numpy.float32'>, 'quantization': (0.0, 0), 'quantization_p

    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
INFO: Created TensorFlow Lite delegate for select TF ops.
INFO: TfLiteFlexDelegate delegate: 6 nodes delegated out of 42 nodes with 3 partitions.

INFO: Created TensorFlow Lite XNNPACK delegate for CPU.


In [None]:
# Train/validation split.
# NOTE(review): this repeats the split already performed right after
# load_hybrid_data; it needs X_str, X_img, y and split from that cell.
X_str_train, X_str_val = X_str[:split], X_str[split:]
X_img_train, X_img_val = X_img[:split], X_img[split:]
y_train, y_val = y[:split], y[split:]

In [48]:
# One line per interpreter input: position, tensor name, shape and dtype
print("Input details:")
for position, detail in enumerate(input_details):
    print(f"Input {position}: name = {detail['name']}, shape = {detail['shape']}, dtype = {detail['dtype']}")


Input details:
Input 0: name = serving_default_image_input:0, shape = [ 1 28 28  1], dtype = <class 'numpy.float32'>
Input 1: name = serving_default_stroke_input:0, shape = [  1 130   3], dtype = <class 'numpy.float32'>


In [49]:
def _tflite_input_index(details, keyword):
    """Return the tensor index of the single interpreter input whose name contains `keyword`."""
    matches = [d["index"] for d in details if keyword in d["name"]]
    if len(matches) != 1:
        names = [d["name"] for d in details]
        raise ValueError(
            f"Expected exactly one input containing '{keyword}', found {len(matches)} in {names}"
        )
    return matches[0]


# Resolve the inputs by name rather than by position: the converter does not
# keep the Keras input order (the image input is listed first here).
stroke_input_index = _tflite_input_index(input_details, "stroke_input")
image_input_index = _tflite_input_index(input_details, "image_input")
output_index = output_details[0]["index"]

correct = 0
total = len(X_str_val)

# The interpreter was allocated for batch size 1, so samples go in one by one.
for i in range(total):
    input_stroke = X_str_val[i:i+1].astype(np.float32)  # shape [1, 130, 3]
    input_image = X_img_val[i:i+1].astype(np.float32)   # shape [1, 28, 28, 1]

    interpreter.set_tensor(stroke_input_index, input_stroke)
    interpreter.set_tensor(image_input_index, input_image)

    interpreter.invoke()

    output = interpreter.get_tensor(output_index)
    pred_label = np.argmax(output)
    true_label = np.argmax(y_val[i])

    if pred_label == true_label:
        correct += 1

# Guard against an empty validation set
accuracy = correct / total if total else 0.0
print(f"TFLite model accuracy: {accuracy * 100:.2f}%")


TFLite model accuracy: 94.02%
