# Drawtcha: Unified CNN and RNN Model Training

This notebook combines the training of the two AI models for the Drawtcha project into a single file. It will:
1.  Train the **CNN Model** on image (bitmap) data.
2.  Save the model as `cnn_model_tf.h5`.
3.  Clear the system memory.
4.  Train the **RNN Model** on sequential stroke data.
5.  Save the model as `rnn_model_tf.h5`.

## 0. Initial Setup
Imports all necessary libraries for both models and defines the categories.

In [None]:
import tensorflow as tf
import numpy as np
import os
import urllib.request
import json
import gc
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.model_selection import train_test_split

# Class names, in label order: the index of a category in this list is the
# integer label used by both models.
CATEGORIES = ["cat", "bicycle", "tree", "fish", "star"]

# Seed the Python, NumPy and TensorFlow random generators in one call so that
# weight initialisation, dropout and dataset shuffling start from the same
# state after a kernel restart. 42 matches the random_state used for the
# train/validation splits. GPU kernels may still cause small run-to-run
# differences.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

---

## Part 1: CNN Model Training (Image Analysis)

### 1.1. CNN Data Preparation (Bitmap)

In [None]:
BASE_URL_BITMAP = "https://storage.googleapis.com/quickdraw_dataset/full/numpy_bitmap/"
RAW_DATA_DIR_CNN = "data/raw_bitmap"

# exist_ok=True keeps the cell re-runnable without a separate existence check.
os.makedirs(RAW_DATA_DIR_CNN, exist_ok=True)

print("--- Downloading data for the CNN ---")
for category in CATEGORIES:
    url = f"{BASE_URL_BITMAP}{category}.npy"
    filepath = os.path.join(RAW_DATA_DIR_CNN, f"{category}.npy")
    if os.path.exists(filepath):
        print(f"{category}.npy already exists.")
        continue

    print(f"Downloading {category}.npy...")
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted download would leave a truncated file at `filepath`, and the
    # existence check above would accept it as complete on the next run.
    partial_path = filepath + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filepath)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)

# Memory-map each file instead of loading it, so the row counts are known
# before the single float32 output array is allocated. This avoids holding a
# per-category list of float32 arrays and a concatenated copy at the same time.
bitmaps_by_category = [
    np.load(os.path.join(RAW_DATA_DIR_CNN, f"{category}.npy"), mmap_mode="r").reshape(-1, 28, 28, 1)
    for category in CATEGORIES
]
total_images = sum(len(bitmaps) for bitmaps in bitmaps_by_category)

final_images = np.empty((total_images, 28, 28, 1), dtype=np.float32)
# np.int_ is NumPy's default integer type, i.e. what np.full(n, i) would give.
final_labels_cnn = np.empty(total_images, dtype=np.int_)

start = 0
for i, bitmaps in enumerate(bitmaps_by_category):
    stop = start + len(bitmaps)
    final_images[start:stop] = bitmaps  # cast to float32 on assignment
    final_labels_cnn[start:stop] = i    # label = index of the category
    start = stop

# Scale pixel values to [0, 1] in place (no second full-size array).
final_images /= 255.0

# Drop the memory maps so the underlying files are released.
del bitmaps_by_category, bitmaps

print(f"\nImage processing complete. Shape: {final_images.shape}")

### 1.2. CNN Model Definition and Training

In [None]:
num_classes = len(CATEGORIES)

# Two conv/pool stages followed by a dense classifier head.
# The input shape is declared with an explicit Input layer rather than the
# `input_shape=` layer argument, which current Keras versions warn against.
cnn_model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

# Labels are integer class indices, hence the sparse cross-entropy loss.
cnn_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
cnn_model.summary()

# 80/20 split, stratified so every category keeps the same share in both sets.
x_train_cnn, x_val_cnn, y_train_cnn, y_val_cnn = train_test_split(
    final_images, final_labels_cnn, test_size=0.2, random_state=42, stratify=final_labels_cnn
)

BATCH_SIZE_CNN = 128
# Only the training set is shuffled; the buffer spans the whole training set.
train_dataset_cnn = (
    tf.data.Dataset.from_tensor_slices((x_train_cnn, y_train_cnn))
    .shuffle(len(x_train_cnn))
    .batch(BATCH_SIZE_CNN)
    .prefetch(tf.data.AUTOTUNE)
)
val_dataset_cnn = (
    tf.data.Dataset.from_tensor_slices((x_val_cnn, y_val_cnn))
    .batch(BATCH_SIZE_CNN)
    .prefetch(tf.data.AUTOTUNE)
)

print("\n--- Starting CNN training ---")
# Stop once val_loss has not improved for 3 epochs and keep the best weights.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, restore_best_weights=True)
cnn_history = cnn_model.fit(
    train_dataset_cnn,
    epochs=10,
    validation_data=val_dataset_cnn,
    callbacks=[early_stopping],
)

### 1.3. Save CNN Model

In [None]:
# Write the trained CNN to an HDF5 file in the current working directory.
cnn_model.save(filepath='cnn_model_tf.h5')
print("\nCNN model successfully saved as cnn_model_tf.h5")

---

## Memory Cleanup
Before training the next model, we delete the CNN arrays, datasets, and model, reset the Keras session, and run the garbage collector to prevent resource exhaustion.

In [None]:
# Every CNN-stage name that can hold image data, labels, datasets or the model.
# The intermediate names from data preparation (data, images, labels,
# all_images, all_labels_cnn) are included: while they stay bound, the arrays
# they reference cannot be freed, even after final_images is deleted.
CNN_STAGE_NAMES = (
    "data", "images", "labels",
    "all_images", "all_labels_cnn",
    "final_images", "final_labels_cnn",
    "x_train_cnn", "x_val_cnn", "y_train_cnn", "y_val_cnn",
    "train_dataset_cnn", "val_dataset_cnn",
    "cnn_model",
)

# pop(..., None) instead of `del`, so the cell can be re-run (or run after a
# name was already removed) without raising NameError.
for _stale_name in CNN_STAGE_NAMES:
    globals().pop(_stale_name, None)

# NOTE(review): cnn_history and early_stopping are left in place (e.g. for
# plotting); they likely still reference the model internally - confirm if the
# model's own memory matters.

tf.keras.backend.clear_session()
gc.collect()
print("Memory cleared. Ready for the RNN model.")

---

## Part 2: RNN Model Training (Stroke Analysis)

### 2.1. RNN Data Preparation (Strokes)

**Memory Optimization:** Only the first `MAX_DRAWINGS_PER_CATEGORY` (50,000) drawings of each category are loaded, to avoid exceeding available RAM.

In [None]:
BASE_URL_STROKE = "https://storage.googleapis.com/quickdraw_dataset/full/simplified/"
RAW_DATA_DIR_RNN = "data/raw_strokes"
MAX_SEQ_LENGTH = 200  # points kept per drawing; shorter drawings are zero-padded
MAX_DRAWINGS_PER_CATEGORY = 50000  # OPTIMIZATION: Limit drawings to prevent RAM overflow

# exist_ok=True keeps the cell re-runnable without a separate existence check.
os.makedirs(RAW_DATA_DIR_RNN, exist_ok=True)

print("--- Downloading data for the RNN ---")
for category in CATEGORIES:
    url = f"{BASE_URL_STROKE}{category}.ndjson"
    filepath = os.path.join(RAW_DATA_DIR_RNN, f"{category}.ndjson")
    if os.path.exists(filepath):
        print(f"{category}.ndjson already exists.")
        continue

    print(f"Downloading {category}.ndjson...")
    # Download under a temporary name and rename only on success. Otherwise an
    # interrupted download would leave a truncated file at `filepath`, and the
    # existence check above would accept it as complete on the next run.
    partial_path = filepath + ".part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filepath)
    finally:
        if os.path.exists(partial_path):
            os.remove(partial_path)

def strokes_to_sequence(strokes, max_length=None):
    """Flatten a drawing into an array of [x, y, pen_state] points.

    Args:
        strokes: Iterable of strokes. Each stroke is indexable, with the x
            coordinates at index 0 and the y coordinates at index 1.
        max_length: Maximum number of points to return. Defaults to the
            module-level MAX_SEQ_LENGTH when omitted.

    Returns:
        float32 array of shape (n_points, 3) with n_points <= max_length.
        pen_state is 1 for the first point of every stroke and 0 otherwise.
        A drawing without points yields shape (0, 3), so the result can
        always be assigned into a (max_length, 3) padding buffer.
    """
    limit = MAX_SEQ_LENGTH if max_length is None else max_length

    sequence = []
    for stroke in strokes:
        # Points past the limit are discarded anyway, so stop collecting early.
        if len(sequence) >= limit:
            break
        for i, (x, y) in enumerate(zip(stroke[0], stroke[1])):
            pen_state = 1 if i == 0 else 0
            sequence.append((x, y, pen_state))

    # reshape(-1, 3) keeps the result two-dimensional even when it is empty;
    # np.array([]) alone would have shape (0,).
    return np.array(sequence[:limit], dtype=np.float32).reshape(-1, 3)

# Collect the unpadded sequences first and pad once into the final array.
# Padding every drawing before stacking would hold the padded dataset twice.
all_sequences = []
all_labels_rnn = []
for i, category in enumerate(CATEGORIES):
    filepath = os.path.join(RAW_DATA_DIR_RNN, f"{category}.ndjson")
    print(f"Processing {category}...")
    # One JSON document per line; read explicitly as UTF-8 so the result does
    # not depend on the platform's default encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        for count, line in enumerate(f):
            if count >= MAX_DRAWINGS_PER_CATEGORY:
                break
            if not line.strip():
                continue  # tolerate blank lines instead of failing in json.loads

            drawing = json.loads(line)
            seq = strokes_to_sequence(drawing['drawing'])
            if len(seq) == 0:
                continue  # a drawing without points carries no training signal
            all_sequences.append(seq)
            all_labels_rnn.append(i)

# Zero-initialised, so every row is already padded past the end of its sequence.
final_sequences = np.zeros((len(all_sequences), MAX_SEQ_LENGTH, 3), dtype=np.float32)
for row, seq in enumerate(all_sequences):
    final_sequences[row, :len(seq)] = seq

# int32 rather than uint8, so labels stay valid beyond 256 categories.
final_labels_rnn = np.array(all_labels_rnn, dtype=np.int32)

# The Python lists are no longer needed once the arrays exist.
del all_sequences, all_labels_rnn

print(f"\nStroke processing complete. Shape: {final_sequences.shape}")

### 2.2. RNN Model Definition and Training

**Memory Optimization:** The LSTM layers are reduced to 128 units and the batch size is set to 64 to lower memory consumption during training.

In [None]:
def create_rnn_model(input_shape, num_classes, lstm_units=128, dropout_rate=0.5):
    """Build the (uncompiled) stacked-LSTM stroke classifier.

    Args:
        input_shape: Shape of one input sequence, without the batch dimension.
        num_classes: Number of output classes (size of the softmax layer).
        lstm_units: Units in each of the two LSTM layers.
            OPTIMIZATION: defaults to 128 (reduced from 256) to lower memory
            use; pass 256 if your machine has more resources.
        dropout_rate: Dropout rate applied before the output layer.

    Returns:
        A tf.keras.Sequential model: LSTM -> LSTM -> Dropout -> Dense(softmax).
    """
    # The input shape is declared with an explicit Input layer rather than the
    # `input_shape=` layer argument, which current Keras versions warn against.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        # The first LSTM returns the full sequence so the second can consume it.
        tf.keras.layers.LSTM(lstm_units, return_sequences=True),
        tf.keras.layers.LSTM(lstm_units),
        tf.keras.layers.Dropout(dropout_rate),
        tf.keras.layers.Dense(num_classes, activation='softmax')
    ])
    return model

# Each sample is a padded sequence of MAX_SEQ_LENGTH points with 3 features.
input_shape = (MAX_SEQ_LENGTH, 3)
num_classes = len(CATEGORIES)
rnn_model = create_rnn_model(input_shape, num_classes)

# Labels are integer class indices, hence the sparse cross-entropy loss.
rnn_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
rnn_model.summary()

# 80/20 split, stratified so every category keeps the same share in both sets.
x_train_rnn, x_val_rnn, y_train_rnn, y_val_rnn = train_test_split(
    final_sequences,
    final_labels_rnn,
    test_size=0.2,
    random_state=42,
    stratify=final_labels_rnn,
)

# OPTIMIZATION: Use a smaller batch size for RNN training. If you have more resources on your machine, you can change it to a higher value.
BATCH_SIZE_RNN = 64

# Only the training set is shuffled; the buffer spans the whole training set.
train_dataset_rnn = (
    tf.data.Dataset.from_tensor_slices((x_train_rnn, y_train_rnn))
    .shuffle(buffer_size=len(x_train_rnn))
    .batch(batch_size=BATCH_SIZE_RNN)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
val_dataset_rnn = (
    tf.data.Dataset.from_tensor_slices((x_val_rnn, y_val_rnn))
    .batch(batch_size=BATCH_SIZE_RNN)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

print("\n--- Starting RNN training ---")
# Stop once val_loss has not improved for 3 epochs and keep the best weights.
early_stopping_rnn = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)
rnn_history = rnn_model.fit(
    x=train_dataset_rnn,
    epochs=15,
    validation_data=val_dataset_rnn,
    callbacks=[early_stopping_rnn],
)

### 2.3. Save RNN Model

In [None]:
# Write the trained RNN to an HDF5 file in the current working directory.
rnn_model.save(filepath='rnn_model_tf.h5')
print("\nRNN model successfully saved as rnn_model_tf.h5")

---

## Conclusion
Training complete! The following files have been generated and are ready to be used in the FastAPI backend:
- `cnn_model_tf.h5`
- `rnn_model_tf.h5`