# 02_model_training.ipynb

This notebook builds and trains a CNN for CIFAR-10 following the project requirements (3+ conv layers, batch norm, dropout, pooling, global avg pooling, softmax output). It also includes advanced options: residual blocks, Mixup/CutMix, transfer learning hooks, learning rate scheduling, callbacks, and evaluation (classification report + confusion matrix).

In [None]:
# Imports and helper functions
import numpy as np
import matplotlib.pyplot as plt
import json, os
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
from tensorflow.keras import layers, models, regularizers, optimizers, callbacks, applications
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.datasets import cifar10
from sklearn.model_selection import train_test_split

print('TensorFlow available')

In [None]:
# Build the train/validation/test splits directly from the Keras CIFAR-10 download.
# The original train and test arrays are pooled and re-split 70/15/15, stratified by label.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_all = np.concatenate((X_train, X_test), axis=0)
y_all = np.concatenate((y_train, y_test), axis=0)

# Hold out 30% first, then cut that holdout in half to get validation and test.
X_train, X_temp, y_train, y_temp = train_test_split(
    X_all, y_all, test_size=0.30, random_state=42, stratify=y_all
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp
)

# Normalize: cast to float32 and divide by 255
X_train, X_val, X_test = (
    split.astype('float32') / 255.0 for split in (X_train, X_val, X_test)
)

# One-hot encode the integer labels
num_classes = 10
y_train_cat, y_val_cat, y_test_cat = (
    to_categorical(labels, num_classes) for labels in (y_train, y_val, y_test)
)

print('Shapes:', X_train.shape, X_val.shape, X_test.shape)

In [None]:
# Data augmentation generators
def random_brightness(image, low=0.8, high=1.2):
    """Multiply a float image by a random factor in [low, high] and clip to [0, 1].

    Used instead of ImageDataGenerator's `brightness_range`, which routes images
    through PIL and is meant for 0-255 pixel data; on inputs that are already
    scaled to [0, 1] it does not preserve the value range, so training batches
    would no longer match the validation/test batches.

    Args:
        image: single image array with values in [0, 1].
        low: smallest brightness factor.
        high: largest brightness factor.

    Returns:
        Array of the same shape with values clipped to [0, 1].
    """
    factor = np.random.uniform(low, high)
    return np.clip(image * factor, 0.0, 1.0)


train_datagen = ImageDataGenerator(
    rotation_range=15,
    horizontal_flip=True,
    zoom_range=0.1,
    # Keras applies this hook to each image after the geometric augmentations.
    preprocessing_function=random_brightness,
)
# No train_datagen.fit(...) call: it is only required when featurewise
# centering/normalization or ZCA whitening is enabled, and none of them is.

train_gen = train_datagen.flow(X_train, y_train_cat, batch_size=64)
# Validation batches are passed through unchanged (no augmentation).
val_gen = ImageDataGenerator().flow(X_val, y_val_cat, batch_size=64)


## Model building helpers
Includes a simple residual block and a base CNN builder with 3+ conv layers, batch norm, dropout, pooling, and global average pooling.

In [None]:
def residual_block(x, filters, kernel_size=3, stride=1):
    """Residual unit: Conv-BN-ReLU-Conv-BN on the main path, added to a skip path, then ReLU.

    Args:
        x: input tensor.
        filters: number of output channels of both main-path convolutions.
        kernel_size: kernel size of the main-path convolutions.
        stride: stride of the first main-path convolution (and of the skip
            projection when one is inserted).

    Returns:
        Output tensor of the block.
    """
    def _l2_conv(strides):
        # Main-path convolution with 'same' padding and L2 weight decay of 1e-4.
        return layers.Conv2D(
            filters,
            kernel_size,
            padding='same',
            strides=strides,
            kernel_regularizer=regularizers.l2(1e-4),
        )

    skip = x

    # Main path
    out = _l2_conv(stride)(x)
    out = layers.BatchNormalization()(out)
    out = layers.ReLU()(out)
    out = _l2_conv(1)(out)
    out = layers.BatchNormalization()(out)

    # Project the skip path with a 1x1 conv + BN whenever the channel count
    # or the spatial stride would otherwise make the two paths incompatible.
    shapes_match = skip.shape[-1] == filters and stride == 1
    if not shapes_match:
        skip = layers.Conv2D(filters, 1, strides=stride, padding='same')(skip)
        skip = layers.BatchNormalization()(skip)

    merged = layers.add([out, skip])
    merged = layers.ReLU()(merged)
    return merged


def build_cnn(input_shape=(32,32,3), num_classes=10, use_residual=False, dropout_rate=0.3):
    """Assemble the CNN classifier as a Keras functional model.

    Layout: two Conv-BN-ReLU pairs (32 then 64 filters), each pair followed by
    max pooling and dropout; a 128-filter stage (two residual blocks or a single
    Conv-BN-ReLU); global average pooling; a 256-unit Dense-BN-ReLU head with
    dropout; and a softmax output.

    Args:
        input_shape: shape of one input image.
        num_classes: number of units in the softmax output layer.
        use_residual: if True, the 128-filter stage is two residual blocks;
            otherwise it is a single Conv-BN-ReLU.
        dropout_rate: rate used by every Dropout layer in the model.

    Returns:
        The (uncompiled) Keras model.
    """
    def _bn_relu(tensor):
        # Batch normalization followed by ReLU.
        tensor = layers.BatchNormalization()(tensor)
        return layers.ReLU()(tensor)

    inp = layers.Input(shape=input_shape)

    # Stage 1: 32 filters (only the very first conv carries L2 regularization)
    h = layers.Conv2D(32, 3, padding='same', kernel_regularizer=regularizers.l2(1e-4))(inp)
    h = _bn_relu(h)
    h = _bn_relu(layers.Conv2D(32, 3, padding='same')(h))
    h = layers.MaxPooling2D()(h)
    h = layers.Dropout(dropout_rate)(h)

    # Stage 2: 64 filters
    for _ in range(2):
        h = _bn_relu(layers.Conv2D(64, 3, padding='same')(h))
    h = layers.MaxPooling2D()(h)
    h = layers.Dropout(dropout_rate)(h)

    # Stage 3: 128 filters, plain or residual
    if not use_residual:
        h = _bn_relu(layers.Conv2D(128, 3, padding='same')(h))
    else:
        for _ in range(2):
            h = residual_block(h, 128, stride=1)

    # Classification head
    h = layers.GlobalAveragePooling2D()(h)
    h = layers.Dropout(dropout_rate)(h)
    h = _bn_relu(layers.Dense(256)(h))
    h = layers.Dropout(dropout_rate)(h)
    probs = layers.Dense(num_classes, activation='softmax')(h)

    return models.Model(inputs=inp, outputs=probs)

# Instantiate the residual variant of the network (other arguments spelled out
# at their default values) and print its layer summary.
model = build_cnn(
    input_shape=(32, 32, 3),
    num_classes=10,
    use_residual=True,
    dropout_rate=0.3,
)
model.summary()

## Advanced: Transfer Learning (optional)
The cell below contains a commented-out ResNet50 example; uncomment it to try transfer learning. Other backbones from `tensorflow.keras.applications` (VGG16, MobileNetV2, EfficientNetB0) can be swapped in the same way. Make sure to set `include_top=False` and `weights='imagenet'`, and then add a GlobalAveragePooling + Dense head.

In [None]:
# Example transfer learning backbone (commented by default)
# backbone = applications.ResNet50(include_top=False, weights='imagenet', input_shape=(32,32,3))
# x = backbone.output
# x = layers.GlobalAveragePooling2D()(x)
# x = layers.Dense(256, activation='relu')(x)
# out = layers.Dense(num_classes, activation='softmax')(x)
# tl_model = models.Model(inputs=backbone.input, outputs=out)
# for layer in backbone.layers:
#     layer.trainable = False
# tl_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# tl_model.summary()


## Mixup and CutMix implementations (utilities)

In [None]:
# Mixup utility
def mixup(batch_x, batch_y, alpha=0.2):
    """Blend every sample and its label vector with a random partner from the same batch.

    A single mixing weight is drawn from Beta(alpha, alpha) and shared by the
    whole batch; partners are chosen by a random permutation of the batch.

    Args:
        batch_x: array of inputs, indexed by sample along axis 0.
        batch_y: array of label vectors, indexed by sample along axis 0.
        alpha: both shape parameters of the Beta distribution.

    Returns:
        Tuple (mixed_x, mixed_y) of new arrays.
    """
    keep = np.random.beta(alpha, alpha)
    partner = np.random.permutation(batch_x.shape[0])
    take = 1 - keep
    blended_x = keep * batch_x + take * batch_x[partner]
    blended_y = keep * batch_y + take * batch_y[partner]
    return blended_x, blended_y

# CutMix utility
def cutmix(batch_x, batch_y, alpha=1.0):
    """Paste one rectangular patch from a shuffled copy of the batch into every image.

    The same box is used for the whole batch. Labels are mixed in proportion to
    the area of the box that actually landed inside the image after clipping.

    Args:
        batch_x: image batch of shape (batch, height, width, channels).
        batch_y: array of label vectors, indexed by sample along axis 0.
        alpha: both shape parameters of the Beta distribution that sets the
            nominal box size.

    Returns:
        Tuple (new_x, new_y); `batch_x` itself is not modified.
    """
    n_samples, height, width, _ = batch_x.shape
    partner = np.random.permutation(n_samples)
    lam = np.random.beta(alpha, alpha)

    # Box centre first (x, then y), then half-extents from the sampled lambda.
    centre_x = np.random.randint(width)
    centre_y = np.random.randint(height)
    side_scale = np.sqrt(1 - lam)
    half_w = int(width * side_scale) // 2
    half_h = int(height * side_scale) // 2

    # Clip the box to the image bounds.
    left, right = np.clip([centre_x - half_w, centre_x + half_w], 0, width)
    top, bottom = np.clip([centre_y - half_h, centre_y + half_h], 0, height)

    patched = batch_x.copy()
    patched[:, top:bottom, left:right, :] = batch_x[partner, top:bottom, left:right, :]

    # Recompute the mixing weight from the clipped box area.
    keep = 1 - ((right - left) * (bottom - top) / (width * height))
    mixed_y = batch_y * keep + batch_y[partner] * (1 - keep)
    return patched, mixed_y


## Training setup: optimizer, callbacks, compile

In [None]:
# Compile model: Adam at 1e-3, cross-entropy on one-hot labels, accuracy as metric
opt = optimizers.Adam(learning_rate=1e-3)
model.compile(
    optimizer=opt,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Callbacks
# Keep only the weights with the highest validation accuracy on disk.
checkpoint_cb = callbacks.ModelCheckpoint(
    'best_model.h5',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
)
# Stop after 12 epochs without val_loss improvement and roll back to the best weights.
earlystop_cb = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=12,
    restore_best_weights=True,
)
# Halve the learning rate after 5 epochs without val_loss improvement.
reduce_lr = callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=5,
)

# Learning rate scheduler example (optional)
def scheduler(epoch, lr):
    """Return the learning rate to use for `epoch`.

    The rate is halved whenever `epoch` is a positive multiple of 30 and is
    returned unchanged otherwise.

    Args:
        epoch: epoch index.
        lr: current learning rate.

    Returns:
        The (possibly halved) learning rate.
    """
    is_decay_epoch = epoch > 0 and epoch % 30 == 0
    return lr * 0.5 if is_decay_epoch else lr
# Wrap the step schedule in a Keras callback and collect every callback for training.
lr_cb = callbacks.LearningRateScheduler(schedule=scheduler)

cb_list = [
    checkpoint_cb,
    earlystop_cb,
    reduce_lr,
    lr_cb,
]


## Train the model
Set `use_mixup` or `use_cutmix` to `True` to apply those augmentations on-the-fly through the custom training loop. With both left `False`, training runs through `model.fit` on the generators.

In [None]:
# Training parameters
# NOTE(review): `epochs` is not consumed by the training calls later in this
# cell, which pass epochs=50 explicitly.
epochs = 100
use_mixup = use_cutmix = False
# One pass over the training generator per epoch.
steps_per_epoch = len(train_gen)

# Custom training loop with mixup/cutmix support
from tqdm import tqdm

def train_with_augmentation(model, train_gen, val_gen, epochs, cb_list, use_mixup=False, use_cutmix=False):
    """Manual epoch loop that can apply Mixup and/or CutMix to each training batch.

    Compared with a bare `train_on_batch` loop, this also drives the callbacks
    in `cb_list` at train/epoch boundaries (so checkpointing, early stopping
    and learning-rate callbacks take effect) and records training metrics next
    to the validation metrics. Batch-level callback hooks are not invoked.

    Args:
        model: compiled Keras model whose metrics are (loss, accuracy).
        train_gen: training batch iterator supporting `len()` and `next()`.
        val_gen: validation data passed to `model.evaluate`.
        epochs: maximum number of epochs to run.
        cb_list: list of Keras callbacks (may be empty or None).
        use_mixup: apply `mixup` to every training batch.
        use_cutmix: apply `cutmix` to every training batch (after mixup if both are set).

    Returns:
        Dict with per-epoch lists under 'loss', 'accuracy', 'val_loss' and
        'val_accuracy', all stored as Python floats.
    """
    history = { 'loss': [], 'accuracy': [], 'val_loss': [], 'val_accuracy': [] }
    # Take the epoch length from the generator that was passed in rather than
    # from a notebook-level global.
    n_steps = len(train_gen)

    cb_runner = callbacks.CallbackList(cb_list, model=model)
    model.stop_training = False
    cb_runner.on_train_begin()

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        cb_runner.on_epoch_begin(epoch)

        # train
        loss_sum, acc_sum, n_seen = 0.0, 0.0, 0
        prog = tqdm(range(n_steps))
        for _ in prog:
            X_batch, y_batch = next(train_gen)
            if use_mixup:
                X_batch, y_batch = mixup(X_batch, y_batch)
            if use_cutmix:
                X_batch, y_batch = cutmix(X_batch, y_batch)
            res = model.train_on_batch(X_batch, y_batch)
            # Sample-weighted average of the reported values over the epoch.
            # NOTE(review): depending on the Keras version, train_on_batch may
            # report per-batch or running metrics; in the latter case this
            # average is only an approximation.
            batch_n = len(X_batch)
            loss_sum += float(res[0]) * batch_n
            acc_sum += float(res[1]) * batch_n
            n_seen += batch_n
            prog.set_postfix({'loss':res[0], 'acc':res[1]})

        # validate
        val_res = model.evaluate(val_gen, verbose=0)
        print('val loss, val acc ->', val_res)

        epoch_logs = {
            'loss': loss_sum / max(n_seen, 1),
            'accuracy': acc_sum / max(n_seen, 1),
            'val_loss': float(val_res[0]),
            'val_accuracy': float(val_res[1]),
        }
        # Record before the callbacks run: they may add extra keys to the logs.
        for key, value in epoch_logs.items():
            history[key].append(value)

        cb_runner.on_epoch_end(epoch, epoch_logs)
        if model.stop_training:
            # Set by callbacks such as EarlyStopping.
            break

    cb_runner.on_train_end()
    return history

# Without mixup/cutmix the built-in Keras fit loop is used; otherwise the
# custom loop (train_with_augmentation) runs.
if not (use_mixup or use_cutmix):
    history = model.fit(train_gen, validation_data=val_gen, epochs=50, callbacks=cb_list)
else:
    history = train_with_augmentation(model, train_gen, val_gen, epochs=50, cb_list=cb_list, use_mixup=use_mixup, use_cutmix=use_cutmix)

# Save history to JSON
# model.fit returns a History object (metrics live in .history); the custom loop returns a plain dict.
hist = history.history if hasattr(history, 'history') else history
# Callbacks such as ReduceLROnPlateau may log NumPy scalars (e.g. the learning
# rate), which json cannot encode, so cast every value to a Python float first.
hist = {name: [float(value) for value in values] for name, values in hist.items()}
with open('training_history.json', 'w') as f:
    json.dump(hist, f)

print('Training complete or skipped (run cells to train).')


## Evaluation on test set and visualizations

In [None]:
# Prefer the checkpointed model on disk when there is one; otherwise keep the in-memory model
from tensorflow.keras.models import load_model
model_path = 'best_model.h5'
if os.path.exists(model_path):
    print('Loading best_model.h5')
    model = load_model(model_path)

# Evaluate
test_eval_gen = ImageDataGenerator().flow(X_test, y_test_cat, batch_size=64)
test_loss, test_acc = model.evaluate(test_eval_gen, verbose=1)
print('Test loss:', test_loss, 'Test acc:', test_acc)

# Predictions
# shuffle=False keeps the prediction rows aligned with the order of y_test.
test_pred_gen = ImageDataGenerator().flow(X_test, batch_size=64, shuffle=False)
y_pred = model.predict(test_pred_gen)
y_pred_labels = y_pred.argmax(axis=1)
y_true = y_test.ravel()

# CIFAR-10 class names in label-index order (0-9), used to make the report
# and the confusion matrix readable.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
class_ids = list(range(len(class_names)))

# Classification report
print(classification_report(y_true, y_pred_labels, labels=class_ids, target_names=class_names))

# Confusion matrix (rows are true classes, columns are predicted classes)
cm = confusion_matrix(y_true, y_pred_labels, labels=class_ids)
plt.figure(figsize=(10,8))
ax = sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
# Label the axes so the figure can be read on its own.
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
plt.title('Confusion Matrix')
plt.show()

# Plot training curves if history exists
if os.path.exists('training_history.json'):
    with open('training_history.json','r') as f:
        hist = json.load(f)

    # (train key, validation key, train legend, validation legend, title)
    curve_specs = (
        ('loss', 'val_loss', 'train_loss', 'val_loss', 'Loss'),
        ('accuracy', 'val_accuracy', 'train_acc', 'val_acc', 'Accuracy'),
    )
    for train_key, val_key, train_label, val_label, title in curve_specs:
        # A figure is only drawn when the training series is present.
        if train_key not in hist:
            continue
        plt.figure()
        plt.plot(hist.get(train_key, []), label=train_label)
        plt.plot(hist.get(val_key, []), label=val_label)
        plt.legend()
        plt.title(title)
        plt.show()
