# Drawtcha: Unified CNN and RNN Model Training

This notebook combines the training of the two AI models for the Drawtcha project into a single file. It will:
1.  Train the **CNN Model** on image (bitmap) data.
2.  Save the model as `cnn_model_tf.keras`.
3.  Clear the system memory.
4.  Train the **RNN Model** on sequential stroke data.
5.  Save the model as `rnn_model_tf.keras`.

## 0. Initial Setup
Imports all necessary libraries for both models and defines the categories.

In [1]:
import tensorflow as tf
import numpy as np
import os
import urllib.request
import json
import gc
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split

# Class names to train on. Each name is used to build the download URLs and
# file names below, and its position in this list is the integer label given
# to its samples by both the CNN and the RNN data loaders -- so reordering the
# list changes what each model output index means.
CATEGORIES = ["cat", "bicycle", "tree", "fish", "star"]

---

## Part 1: CNN Model Training (Image Analysis)

### 1.1. CNN Data Preparation (Bitmap)

In [2]:
BASE_URL_BITMAP = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/"
RAW_DATA_DIR_CNN = "data/raw_bitmap"

# exist_ok keeps the cell re-runnable without a separate existence check.
os.makedirs(RAW_DATA_DIR_CNN, exist_ok=True)

print("--- Downloading data for the CNN ---")
for category in CATEGORIES:
    url = f"{BASE_URL_BITMAP}{category}.npy"
    filepath = os.path.join(RAW_DATA_DIR_CNN, f"{category}.npy")
    if os.path.exists(filepath):
        print(f"{category}.npy already exists.")
        continue

    print(f"Downloading {category}.npy...")
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted transfer would leave a truncated file at `filepath`, which
    # the existence check above would accept as complete on the next run.
    partial_path = filepath + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, filepath)

def load_bitmap_dataset(data_dir, categories):
    """Load the per-category bitmap files into one image array and one label array.

    Args:
        data_dir: Directory containing one ``<category>.npy`` file per category.
        categories: Category names; the index of each name is used as its label.

    Returns:
        A tuple ``(images, labels)`` where ``images`` is a float32 array of
        shape ``(N, 28, 28, 1)`` holding the stored values divided by 255, and
        ``labels`` is an integer array of length ``N``.
    """
    # Memory-map the files so the sample counts are known before any pixel
    # data is read into RAM; reshaping a memory map does not copy it.
    bitmaps = [
        np.load(os.path.join(data_dir, f"{name}.npy"), mmap_mode="r").reshape(-1, 28, 28, 1)
        for name in categories
    ]
    counts = [bitmap.shape[0] for bitmap in bitmaps]

    # Fill a single preallocated buffer instead of keeping a list of
    # per-category float32 arrays alive next to a concatenated copy.
    images = np.empty((sum(counts), 28, 28, 1), dtype=np.float32)
    start = 0
    for bitmap, count in zip(bitmaps, counts):
        images[start:start + count] = bitmap  # cast to float32 while copying
        start += count
    images /= 255.0  # same float32 division as before, done in place

    # Label i repeated once per sample of category i, in category order.
    labels = np.repeat(np.arange(len(categories)), counts)
    return images, labels


final_images, final_labels_cnn = load_bitmap_dataset(RAW_DATA_DIR_CNN, CATEGORIES)
print(f"\nImage processing complete. Shape: {final_images.shape}")

--- Downloading data for the CNN ---
Downloading cat.npy...
Downloading bicycle.npy...
Downloading tree.npy...
Downloading fish.npy...
Downloading star.npy...

Image processing complete. Shape: (666219, 28, 28, 1)


### 1.2. CNN Model Definition and Training

In [3]:
num_classes = len(CATEGORIES)

# Two Conv/MaxPool stages followed by a dense classifier head.
# The input is declared with an explicit Input layer: passing `input_shape`
# to the first layer of a Sequential model is deprecated and emits a
# UserWarning when the model is built.
cnn_model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

# Labels are integer class indices, hence the sparse categorical loss.
cnn_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
cnn_model.summary()

# Stratified 80/20 split with a fixed seed so the split is repeatable.
x_train_cnn, x_val_cnn, y_train_cnn, y_val_cnn = train_test_split(
    final_images, final_labels_cnn, test_size=0.2, random_state=42, stratify=final_labels_cnn
)

BATCH_SIZE_CNN = 128
# Shuffle buffer covers the whole training set, i.e. a full reshuffle per epoch.
train_dataset_cnn = (
    tf.data.Dataset.from_tensor_slices((x_train_cnn, y_train_cnn))
    .shuffle(len(x_train_cnn))
    .batch(BATCH_SIZE_CNN)
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset_cnn = (
    tf.data.Dataset.from_tensor_slices((x_val_cnn, y_val_cnn))
    .batch(BATCH_SIZE_CNN)
    .prefetch(tf.data.AUTOTUNE)
)

print("\n--- Starting CNN training ---")
# Stop after 3 epochs without val_loss improvement and roll back to the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
cnn_history = cnn_model.fit(train_dataset_cnn, epochs=10, validation_data=val_dataset_cnn, callbacks=[early_stopping])

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)



--- Starting CNN training ---
Epoch 1/10
[1m4164/4164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m37s[0m 7ms/step - accuracy: 0.9343 - loss: 0.1972 - val_accuracy: 0.9756 - val_loss: 0.0743
Epoch 2/10
[1m4164/4164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m34s[0m 6ms/step - accuracy: 0.9772 - loss: 0.0701 - val_accuracy: 0.9805 - val_loss: 0.0593
Epoch 3/10
[1m4164/4164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m28s[0m 7ms/step - accuracy: 0.9823 - loss: 0.0543 - val_accuracy: 0.9827 - val_loss: 0.0536
Epoch 4/10
[1m4164/4164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 6ms/step - accuracy: 0.9848 - loss: 0.0455 - val_accuracy: 0.9830 - val_loss: 0.0533
Epoch 5/10
[1m4164/4164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 6ms/step - accuracy: 0.9865 - loss: 0.0396 - val_accuracy: 0.9831 - val_loss: 0.0553
Epoch 6/10
[1m4164/4164[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m27s[0m 6ms/step - accuracy: 0.9880 - loss: 0.0348 - val_accuracy: 0.983

### 1.3. Save CNN Model

In [4]:
# Persist the trained CNN; the file name is stated once and reused in the message.
CNN_MODEL_PATH = 'cnn_model_tf.keras'
cnn_model.save(CNN_MODEL_PATH)
print(f"\nCNN model successfully saved as {CNN_MODEL_PATH}")


CNN model successfully saved as cnn_model_tf.keras


---

## Memory Cleanup
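Before training the next model, we delete the CNN arrays, datasets, and model, reset the Keras session, and run the garbage collector so that the RNN stage does not run out of memory.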

In [5]:
# Every large object created by the CNN stage, including the per-category
# loading temporaries (`data`, `images`, `labels`, `all_images`,
# `all_labels_cnn`) that would otherwise stay referenced while the RNN trains.
# `cnn_history` is intentionally kept so the training metrics remain available.
cnn_stage_names = (
    "data", "images", "labels",
    "all_images", "all_labels_cnn",
    "final_images", "final_labels_cnn",
    "x_train_cnn", "x_val_cnn", "y_train_cnn", "y_val_cnn",
    "train_dataset_cnn", "val_dataset_cnn",
    "cnn_model",
)
# pop(..., None) instead of `del` so re-running this cell (or running it when
# a name was never created) does not raise NameError.
for cnn_stage_name in cnn_stage_names:
    globals().pop(cnn_stage_name, None)

tf.keras.backend.clear_session()
gc.collect()
print("Memory cleared. Ready for the RNN model.")

Memory cleared. Ready for the RNN model.


---

## Part 2: RNN Model Training (Stroke Analysis)

### 2.1. RNN Data Preparation (Strokes)

In [6]:
BASE_URL_STROKE = "https://storage.googleapis.com/quickdraw_dataset/full/simplified/"
RAW_DATA_DIR_RNN = "data/raw_strokes"
MAX_SEQ_LENGTH = 200
MAX_DRAWINGS_PER_CATEGORY = 100000  # OPTIMIZATION: Limit drawings to prevent RAM overflow

# exist_ok keeps the cell re-runnable without a separate existence check.
os.makedirs(RAW_DATA_DIR_RNN, exist_ok=True)

print("--- Downloading data for the RNN ---")
for category in CATEGORIES:
    url = f"{BASE_URL_STROKE}{category}.ndjson"
    filepath = os.path.join(RAW_DATA_DIR_RNN, f"{category}.ndjson")
    if os.path.exists(filepath):
        print(f"{category}.ndjson already exists.")
        continue

    print(f"Downloading {category}.ndjson...")
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted transfer would leave a truncated file at `filepath`, which
    # the existence check above would accept as complete on the next run.
    partial_path = filepath + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, filepath)

def strokes_to_sequence(strokes, max_length=None):
    """Flatten a drawing's strokes into one sequence of [x, y, pen_state] points.

    Args:
        strokes: Iterable of strokes; element 0 of each stroke holds the x
            coordinates and element 1 the y coordinates of its points.
        max_length: Maximum number of points to keep. Defaults to the
            notebook-level ``MAX_SEQ_LENGTH`` when omitted.

    Returns:
        A float32 array of shape ``(n, 3)`` with ``n <= max_length``.
        ``pen_state`` is 1 for the first point of each stroke and 0 otherwise.
        Coordinates are passed through unscaled. A drawing with no points
        yields shape ``(0, 3)``.
    """
    if max_length is None:
        max_length = MAX_SEQ_LENGTH

    sequence = []
    for stroke in strokes:
        for i, (x, y) in enumerate(zip(stroke[0], stroke[1])):
            pen_state = 1 if i == 0 else 0
            sequence.append([x, y, pen_state])

    # reshape(-1, 3) keeps the result two-dimensional even when it is empty;
    # a bare np.array([]) has shape (0,) and cannot be written into a
    # (0, 3) slice of the padded buffer.
    return np.array(sequence[:max_length], dtype=np.float32).reshape(-1, 3)

# Write every padded sequence straight into one preallocated, zero-filled
# buffer. This avoids holding a Python list of per-drawing arrays and a
# stacked copy of the same data in memory at the same time.
max_total_drawings = len(CATEGORIES) * MAX_DRAWINGS_PER_CATEGORY
final_sequences = np.zeros((max_total_drawings, MAX_SEQ_LENGTH, 3), dtype=np.float32)
final_labels_rnn = np.zeros(max_total_drawings, dtype=np.uint8)

n_loaded = 0
for i, category in enumerate(CATEGORIES):
    filepath = os.path.join(RAW_DATA_DIR_RNN, f"{category}.ndjson")
    print(f"Processing {category}...")
    n_in_category = 0
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            if n_in_category >= MAX_DRAWINGS_PER_CATEGORY:
                break
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads

            drawing = json.loads(line)
            seq = strokes_to_sequence(drawing['drawing'])
            # Rows beyond len(seq) stay zero, which is the padding. The guard
            # skips drawings without points, whose array may not be 2-D.
            if len(seq):
                final_sequences[n_loaded, :len(seq)] = seq
            final_labels_rnn[n_loaded] = i
            n_loaded += 1
            n_in_category += 1

# Drop the unused tail if a file held fewer drawings than the per-category cap.
final_sequences = final_sequences[:n_loaded]
final_labels_rnn = final_labels_rnn[:n_loaded]
print(f"\nStroke processing complete. Shape: {final_sequences.shape}")

--- Downloading data for the RNN ---
Downloading cat.ndjson...
Downloading bicycle.ndjson...
Downloading tree.ndjson...
Downloading fish.ndjson...
Downloading star.ndjson...
Processing cat...
Processing bicycle...
Processing tree...
Processing fish...
Processing star...

Stroke processing complete. Shape: (500000, 200, 3)


### 2.2. RNN Model Definition and Training

In [7]:
def create_rnn_model(input_shape, num_classes):
    """Build the (uncompiled) stacked-LSTM classifier.

    Args:
        input_shape: Shape of one input sample, excluding the batch dimension.
        num_classes: Number of output classes (size of the softmax layer).

    Returns:
        A ``tf.keras.Sequential`` model: two 256-unit LSTM layers, dropout,
        and a softmax output layer.
    """
    model = tf.keras.Sequential([
        # Explicit Input layer: passing `input_shape` to the first layer of a
        # Sequential model is deprecated and emits a UserWarning.
        tf.keras.layers.Input(shape=input_shape),
        # The first LSTM returns the full sequence so the second can consume it.
        tf.keras.layers.LSTM(256, return_sequences=True),
        tf.keras.layers.LSTM(256),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    return model

# Build the stroke classifier for fixed-length (MAX_SEQ_LENGTH, 3) inputs.
num_classes = len(CATEGORIES)
input_shape = (MAX_SEQ_LENGTH, 3)
rnn_model = create_rnn_model(input_shape, num_classes)

# Labels are integer class indices, hence the sparse categorical loss.
rnn_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
rnn_model.summary()

# Stratified 80/20 train/validation split with a fixed seed.
x_train_rnn, x_val_rnn, y_train_rnn, y_val_rnn = train_test_split(
    final_sequences,
    final_labels_rnn,
    test_size=0.2,
    random_state=42,
    stratify=final_labels_rnn,
)

BATCH_SIZE_RNN = 128

# Training pipeline: shuffle over the whole training set, then batch and prefetch.
train_dataset_rnn = (
    tf.data.Dataset.from_tensor_slices((x_train_rnn, y_train_rnn))
    .shuffle(len(x_train_rnn))
    .batch(BATCH_SIZE_RNN)
    .prefetch(tf.data.AUTOTUNE)
)
# Validation pipeline: same batching, no shuffling.
val_dataset_rnn = (
    tf.data.Dataset.from_tensor_slices((x_val_rnn, y_val_rnn))
    .batch(BATCH_SIZE_RNN)
    .prefetch(tf.data.AUTOTUNE)
)

print("\n--- Starting RNN training ---")
# Stop after 3 epochs without val_loss improvement and roll back to the best weights.
early_stopping_rnn = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)
rnn_history = rnn_model.fit(
    train_dataset_rnn,
    epochs=15,
    validation_data=val_dataset_rnn,
    callbacks=[early_stopping_rnn],
)

  super().__init__(**kwargs)



--- Starting RNN training ---
Epoch 1/15
[1m3125/3125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m201s[0m 63ms/step - accuracy: 0.5935 - loss: 0.9218 - val_accuracy: 0.9318 - val_loss: 0.2195
Epoch 2/15
[1m3125/3125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m201s[0m 64ms/step - accuracy: 0.9307 - loss: 0.2221 - val_accuracy: 0.9488 - val_loss: 0.1694
Epoch 3/15
[1m3125/3125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m201s[0m 64ms/step - accuracy: 0.9488 - loss: 0.1643 - val_accuracy: 0.9548 - val_loss: 0.1448
Epoch 4/15
[1m3125/3125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m202s[0m 64ms/step - accuracy: 0.9569 - loss: 0.1365 - val_accuracy: 0.9625 - val_loss: 0.1200
Epoch 5/15
[1m3125/3125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m203s[0m 64ms/step - accuracy: 0.9620 - loss: 0.1213 - val_accuracy: 0.9670 - val_loss: 0.1028
Epoch 6/15
[1m3125/3125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m202s[0m 64ms/step - accuracy: 0.9661 - loss: 0.1068 - val_acc

### 2.3. Save RNN Model

In [8]:
# Persist the trained RNN; the file name is stated once and reused in the message.
RNN_MODEL_PATH = 'rnn_model_tf.keras'
rnn_model.save(RNN_MODEL_PATH)
print(f"\nRNN model successfully saved as {RNN_MODEL_PATH}")


RNN model successfully saved as rnn_model_tf.keras


---

## Conclusion
Training complete! The following files have been generated and are ready to be used in the FastAPI backend:
- `cnn_model_tf.keras`
- `rnn_model_tf.keras`