# 02 – Train TinyML Model (CWRU 12k DE)

This notebook trains a **small 1D-CNN** for 4-class bearing condition classification (Normal, InnerRace, Ball, OuterRace) on the segments you created with the preprocessing notebook.

It then performs **INT8 post-training quantization** and exports a `.tflite` file and a **C header** for TensorFlow Lite for Microcontrollers.

### Prerequisites
- Run the preprocessing notebook first. You should have:
  - `data/DE_12k/npz/train.npz`, `val.npz`, `test.npz`
- Install the required packages if not already available: TensorFlow, scikit-learn, matplotlib and pandas (e.g., `pip install tensorflow scikit-learn matplotlib pandas`).


In [None]:
# Imports
import numpy as np
import pandas as pd
import os
from pathlib import Path
import matplotlib.pyplot as plt

try:
    import tensorflow as tf
except Exception as e:
    # Re-raise as a chained ImportError: SystemExit inside Jupyter hides the original
    # traceback, whereas ImportError keeps the root cause visible.
    raise ImportError(
        "TensorFlow is required for this notebook. Please install it with `pip install tensorflow`. Error: " + str(e)
    ) from e

from sklearn.metrics import classification_report, confusion_matrix

DATA_ROOT = Path("data/DE_12k/npz")
OUT_ROOT = Path("models")
OUT_ROOT.mkdir(parents=True, exist_ok=True)
WIN = 2048  # window size used in preprocessing
N_CLASSES = 4
LABELS = ['Normal','InnerRace','Ball','OuterRace']  # index i = integer class label i
assert len(LABELS) == N_CLASSES, "LABELS must have one name per class"

# Seed Python, NumPy and TensorFlow RNGs so weight init and batch shuffling are repeatable.
# (Bit-exact GPU determinism additionally needs tf.config.experimental.enable_op_determinism().)
SEED = 42
tf.keras.utils.set_random_seed(SEED)


In [None]:
# Load splits
def load_split(name, root=None):
    """Load one preprocessed split from ``<root>/<name>.npz``.

    Parameters
    ----------
    name : str
        Split name, e.g. 'train', 'val' or 'test'.
    root : str or Path, optional
        Directory holding the ``.npz`` files. Defaults to ``DATA_ROOT``.

    Returns
    -------
    (X, y) : tuple of np.ndarray
        The arrays stored under the 'X' and 'y' keys of the archive.

    Raises
    ------
    ValueError
        If X and y do not have the same number of rows.
    """
    root = DATA_ROOT if root is None else Path(root)
    # np.load on an .npz returns a lazily-reading NpzFile that holds the file open;
    # the context manager closes it once both arrays have been read into memory.
    with np.load(root / f"{name}.npz") as d:
        X, y = d['X'], d['y']
    if len(X) != len(y):
        raise ValueError(f"{name}: X has {len(X)} rows but y has {len(y)}")
    return X, y

X_train, y_train = load_split('train')
X_val,   y_val   = load_split('val')
X_test,  y_test  = load_split('test')

# Fail fast if the preprocessing output does not match what the rest of the notebook assumes:
# 2-D (N, WIN) windows (a channel axis is added later), one label per window,
# and integer labels in [0, N_CLASSES) as required by sparse categorical cross-entropy.
for split_name, X_split, y_split in (('train', X_train, y_train),
                                     ('val', X_val, y_val),
                                     ('test', X_test, y_test)):
    assert X_split.ndim == 2 and X_split.shape[1] == WIN, \
        f"{split_name}: expected X of shape (N, {WIN}), got {X_split.shape}"
    assert len(y_split) == len(X_split), \
        f"{split_name}: {len(X_split)} windows but {len(y_split)} labels"
    assert np.all((y_split >= 0) & (y_split < N_CLASSES)), \
        f"{split_name}: labels outside [0, {N_CLASSES})"

print('Shapes:')
print('  X_train', X_train.shape, 'y_train', y_train.shape)
print('  X_val  ', X_val.shape,   'y_val  ', y_val.shape)
print('  X_test ', X_test.shape,  'y_test ', y_test.shape)


In [None]:
# Standardize and reshape for 1D-CNN (N, T, 1)
def standardize(X, axis=1, eps=1e-8):
    """Z-score each slice of ``X`` along ``axis`` independently.

    With the default ``axis=1`` and an (N, T) array, every window (row) is shifted to
    zero mean and divided by its own standard deviation. The deployed model sees
    inputs normalised this way, so the device must apply the same per-window
    transform before quantizing.

    Parameters
    ----------
    X : np.ndarray
        Input array; assumed (N, T) in this notebook.
    axis : int
        Axis along which mean and std are computed (default 1, i.e. per window).
    eps : float
        Added to the std so constant (flat) windows map to 0 instead of dividing by zero.

    Returns
    -------
    np.ndarray
        Array of the same shape as ``X``.
    """
    mu = X.mean(axis=axis, keepdims=True)
    sd = X.std(axis=axis, keepdims=True) + eps
    return (X - mu) / sd

# Per-window z-scoring, cast to float32, then append a channel axis -> (N, T, 1)
X_train_s, X_val_s, X_test_s = (
    standardize(split).astype('float32')[..., None]
    for split in (X_train, X_val, X_test)
)

print('Reshaped:')
for tag, arr in (('  X_train_s', X_train_s),
                 ('  X_val_s  ', X_val_s),
                 ('  X_test_s ', X_test_s)):
    print(tag, arr.shape)


In [None]:
# Build a tiny 1D-CNN model
from tensorflow.keras import layers, models

def build_model(input_len=WIN, n_classes=N_CLASSES):
    """Create and compile the tiny 1D-CNN classifier.

    Layer stack: Conv1D(8, k=9, stride 2, relu) -> MaxPool1D(2) -> Conv1D(16, k=5, relu)
    -> GlobalAveragePooling1D -> Dense(8, relu) -> Dense(n_classes, softmax).

    Args:
        input_len: samples per window; the model input shape is (input_len, 1).
        n_classes: number of output classes.

    Returns:
        A ``keras.Model`` compiled with Adam and sparse categorical cross-entropy
        (integer class labels), tracking accuracy.
    """
    inputs = layers.Input(shape=(input_len, 1))
    stack = [
        layers.Conv1D(8, kernel_size=9, strides=2, activation='relu'),
        layers.MaxPool1D(pool_size=2),
        layers.Conv1D(16, kernel_size=5, activation='relu'),
        layers.GlobalAveragePooling1D(),
        layers.Dense(8, activation='relu'),
        layers.Dense(n_classes, activation='softmax'),
    ]
    tensor = inputs
    for layer in stack:
        tensor = layer(tensor)
    net = models.Model(inputs=inputs, outputs=tensor)
    net.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return net

model = build_model()
model.summary()


In [None]:
# Train
EPOCHS = 12
BATCH  = 64

# Halve the LR after 2 stagnant epochs; stop after 3 and roll back to the best weights
# (both callbacks watch their Keras default metric).
lr_on_plateau = tf.keras.callbacks.ReduceLROnPlateau(patience=2, factor=0.5, verbose=1)
early_stop = tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True)
callbacks = [lr_on_plateau, early_stop]

fit_kwargs = dict(
    validation_data=(X_val_s, y_val),
    epochs=EPOCHS,
    batch_size=BATCH,
    verbose=1,
    callbacks=callbacks,
)
hist = model.fit(X_train_s, y_train, **fit_kwargs)


In [None]:
# Evaluate
test_loss, test_acc = model.evaluate(X_test_s, y_test, verbose=0)
print(f"Test accuracy: {test_acc:.4f}")

y_pred = model.predict(X_test_s, verbose=0).argmax(axis=1)
# Pin the label set: without `labels`, sklearn infers classes from the data, so a class
# missing from y_test/y_pred makes target_names mismatch (error) and shrinks the matrix.
class_ids = list(range(N_CLASSES))
print(classification_report(y_test, y_pred, labels=class_ids, target_names=LABELS))
cm = confusion_matrix(y_test, y_pred, labels=class_ids)

fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest')
ax.set_title('Confusion Matrix')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_xticks(class_ids)
ax.set_xticklabels(LABELS, rotation=45, ha='right')
ax.set_yticks(class_ids)
ax.set_yticklabels(LABELS)
# Annotate counts; dark text on bright cells, light text on dark cells
threshold = cm.max() / 2 if cm.size else 0
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        ax.text(j, i, cm[i, j], ha='center', va='center',
                color='k' if cm[i, j] > threshold else 'w')
fig.colorbar(im, ax=ax)
fig.tight_layout()
plt.show()


## Export to INT8 TFLite
We perform **post-training quantization** using a small representative dataset from the training set to calibrate int8 ranges.

In [None]:
# Representative dataset generator for INT8 calibration.
# Calibrate on a seeded random subset of training windows rather than the first rows,
# so the int8 ranges still see every class if the arrays are stored grouped by class.
CALIB_SAMPLES = 500
CALIB_SEED = 0
calib_idx = np.random.default_rng(CALIB_SEED).choice(
    len(X_train_s), size=min(CALIB_SAMPLES, len(X_train_s)), replace=False
)

def rep_data_gen():
    """Yield calibration batches, each a one-element list holding a (1, T, 1) float32 slice."""
    for i in calib_idx:
        yield [X_train_s[i:i+1]]


def evaluate_tflite_int8(model_bytes, X, y):
    """Measure the accuracy of an int8-input / int8-output TFLite model on float inputs.

    Args:
        model_bytes: serialized .tflite flatbuffer.
        X: float inputs shaped like the model input without the batch axis, e.g. (N, T, 1).
        y: integer class labels, one per row of X.

    Returns:
        (accuracy, (scale, zero_point)); the pair is the model's input quantization,
        i.e. what the MCU must use to quantize its own standardized windows.
        accuracy is NaN when X is empty.
    """
    interpreter = tf.lite.Interpreter(model_content=model_bytes)
    interpreter.allocate_tensors()
    in_det = interpreter.get_input_details()[0]
    out_det = interpreter.get_output_details()[0]
    scale, zero_point = in_det['quantization']
    correct = 0
    for sample, label in zip(X, y):
        # Inverse of TFLite's affine dequantization real = (q - zero_point) * scale
        q = np.clip(np.round(sample / scale) + zero_point, -128, 127).astype(np.int8)
        interpreter.set_tensor(in_det['index'], q[None, ...])
        interpreter.invoke()
        # Dequantization is monotonic (scale > 0), so argmax on raw int8 scores gives the class
        scores = interpreter.get_tensor(out_det['index'])[0]
        correct += int(scores.argmax() == label)
    accuracy = correct / len(X) if len(X) else float('nan')
    return accuracy, (scale, zero_point)


converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = rep_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_int8 = converter.convert()

# OUT_ROOT is defined and created in the configuration cell
tflite_path = OUT_ROOT / "tinyml_cnn_int8.tflite"
tflite_path.write_bytes(tflite_int8)
print("Saved:", tflite_path, "(bytes)", len(tflite_int8))

# Check what quantization cost, and report the input quantization the MCU needs
int8_acc, (in_scale, in_zero_point) = evaluate_tflite_int8(tflite_int8, X_test_s, y_test)
print(f"INT8 TFLite test accuracy: {int8_acc:.4f}")
print(f"Input quantization: scale={in_scale}, zero_point={in_zero_point}")


In [None]:
# Create a C header for Arduino / TFLite Micro
def tflite_to_c_array(bytes_data, var_name='g_tinyml_model', bytes_per_line=12, alignment=16):
    """Render a .tflite flatbuffer as a C/C++ header defining a byte array and its length.

    Parameters
    ----------
    bytes_data : bytes-like
        Serialized model.
    var_name : str
        Array identifier; the length constant is named ``<var_name>_len``.
    bytes_per_line : int
        Hex bytes per source line (keeps the header readable / diff-friendly).
    alignment : int
        If non-zero, the array is declared ``alignas(alignment)``; TFLite Micro
        expects the model flatbuffer to be 16-byte aligned. Use 0 to omit.

    Returns
    -------
    str
        Header text, terminated by a newline.

    Raises
    ------
    ValueError
        If ``bytes_data`` is empty (a zero-length array initializer is invalid C++)
        or ``bytes_per_line`` < 1.
    """
    data = bytes(bytes_data)
    if not data:
        raise ValueError("bytes_data is empty")
    if bytes_per_line < 1:
        raise ValueError("bytes_per_line must be >= 1")
    rows = [
        '  ' + ', '.join(f'0x{b:02x}' for b in data[start:start + bytes_per_line])
        for start in range(0, len(data), bytes_per_line)
    ]
    align = f'alignas({alignment}) ' if alignment else ''
    lines = [
        '#pragma once',
        '#include <cstdint>',
        '',
        f'const unsigned int {var_name}_len = {len(data)};',
        f'{align}const unsigned char {var_name}[] = {{',
        ',\n'.join(rows),
        '};',
    ]
    # Trailing newline: a missing final newline triggers compiler warnings
    return '\n'.join(lines) + '\n'

# Read the exported flatbuffer back from disk and emit it as a C header next to it
h_path = OUT_ROOT / "tinyml_model_data.h"
model_bytes = tflite_path.read_bytes()
h_text = tflite_to_c_array(model_bytes, 'g_tinyml_model')
h_path.write_text(h_text)
print("Saved:", h_path)


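### Arduino / MCU notes
- Copy `tinyml_model_data.h` into your Arduino sketch folder.
- Use **TensorFlow Lite for Microcontrollers** and the `g_tinyml_model` array in the interpreter setup.
- Start with an arena size around **60–100 KB** and tune on device.
- Feed the model exactly what it was trained on. First z-score each 2048-sample window with its own mean and standard deviation: `(x - mean) / (std + 1e-8)`. Then quantize it to int8 with the input tensor's quantization parameters (`input->params.scale`, `input->params.zero_point`).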

Minimal sketch structure:
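```cpp
#include "tinyml_model_data.h"
#include "tensorflow/lite/micro/all_ops_resolver.h"  // newer TFLM releases drop this; use MicroMutableOpResolver
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"

constexpr int kArenaSize = 100 * 1024; // tune as needed
alignas(16) uint8_t tensor_arena[kArenaSize];

const tflite::Model* model = nullptr;
tflite::MicroInterpreter* interpreter = nullptr;

void setup() {
  model = tflite::GetModel(g_tinyml_model);
  static tflite::AllOpsResolver resolver;
  static tflite::MicroInterpreter static_interpreter(model, resolver, tensor_arena, kArenaSize);
  interpreter = &static_interpreter;
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    // Arena too small or an op is missing from the resolver: report and halt
  }
}

void loop() {
  TfLiteTensor* input = interpreter->input(0);
  // 1) Standardize the 2048-sample window: x = (x - mean) / (std + 1e-8)
  // 2) Quantize: q = round(x / input->params.scale) + input->params.zero_point, clamped to [-128, 127]
  // 3) Write q into input->data.int8[0..2047]
  interpreter->Invoke();
  TfLiteTensor* output = interpreter->output(0);
  // Read output->data.int8[0..3]; the argmax is the predicted class
  // (order: Normal, InnerRace, Ball, OuterRace)
}
```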
