# Imports

In [None]:
import tensorflow as tf
import numpy as np
import os
import time
import psutil
import json

# Configuration

## Parameters

In [None]:
# --- Training hyperparameters ---
BATCH_SIZE = 32         # images per mini-batch
EPOCHS = 6              # maximum epochs; early stopping may end training sooner
LEARNING_RATE = 0.0001  # Adam learning rate
IMG_SIZE = 32           # height/width of the square input images (model input shape)

# --- Experiment bookkeeping ---
N_REPEATS = 1  # Number of times to repeat the training for averaging results

# Number of output classes: 10 selects CIFAR-10, 100 selects CIFAR-100.
NUM_CLASSES = 10

## Directory Setup

In [None]:
# Output folders for trained models, per-run result JSONs and training histories.
for _output_dir in ('saved_models', 'results', 'histories'):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

# Utility functions

## Data

In [None]:
def set_seeds(seed):
    """
    Make subsequent runs reproducible.

    Seeds the Python, NumPy and TensorFlow random generators in one call
    (``tf.keras.utils.set_random_seed``). It then switches TensorFlow to
    deterministic op implementations.

    Args:
        seed: Integer seed shared by all generators.
    """
    keras_utils = tf.keras.utils
    keras_utils.set_random_seed(seed)
    # Deterministic kernels can be slower, but make repeated runs comparable.
    tf.config.experimental.enable_op_determinism()

def get_cifar10_data():
    """
    Load CIFAR-10 and rescale pixel values by 1/255.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. Images are float32 arrays
        divided by 255; labels are returned unchanged from Keras.
    """
    cifar10 = tf.keras.datasets.cifar10
    train_split, test_split = cifar10.load_data()
    x_tr, y_tr = train_split
    x_te, y_te = test_split
    return (x_tr.astype("float32") / 255.0, y_tr), (x_te.astype("float32") / 255.0, y_te)

def get_cifar100_data():
    """
    Load CIFAR-100 with fine-grained labels and rescale pixel values by 1/255.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. Images are float32 arrays
        divided by 255; labels are returned unchanged from Keras.
    """
    cifar100 = tf.keras.datasets.cifar100
    train_split, test_split = cifar100.load_data(label_mode='fine')
    x_tr, y_tr = train_split
    x_te, y_te = test_split
    return (x_tr.astype("float32") / 255.0, y_tr), (x_te.astype("float32") / 255.0, y_te)

def rgb_to_grayscale(images):
    """
    Collapse the colour channels of ``images`` to a single luma channel.

    Takes a weighted sum of the first three channels of the last axis with
    weights (0.299, 0.587, 0.114). Any extra channels (e.g. alpha) are ignored.

    Args:
        images: Array whose last axis holds at least 3 colour channels.

    Returns:
        Array with the same leading shape and a last axis of size 1.
    """
    luma_weights = [0.299, 0.587, 0.114]
    rgb = images[..., :3]
    gray = np.dot(rgb, luma_weights)
    return gray[..., np.newaxis]

def grayscale_to_rgb(images):
    """
    Replicate a single-channel image array into three identical channels.

    Args:
        images: Array whose last axis has size 1 (grayscale).

    Returns:
        Array of the same dtype whose last axis is three copies of the input
        stacked side by side.
    """
    copies = (images,) * 3
    return np.concatenate(copies, axis=-1)

def data_augmentation(img):
    """
    Apply random flip, contrast, brightness and 90-degree rotation to one image.

    This function is used inside ``tf.data.Dataset.map`` (see ``make_dataset``).
    Every random draw must therefore be a TensorFlow op. A Python/NumPy random
    value would be evaluated only once, when the map function is traced, and
    would then be frozen as a constant for the whole dataset.

    Args:
        img: A single image tensor (H, W, C). The final clip to [0, 1] assumes
            inputs are normalised to that range, as the CIFAR loaders in this
            file produce.

    Returns:
        The augmented image tensor, clipped to [0, 1].
    """
    img = tf.image.random_flip_left_right(img)
    img = tf.image.random_contrast(img, 0.85, 1.15)
    img = tf.image.random_brightness(img, 0.1)
    # Draw the number of quarter turns per image with a TF op. The previous
    # np.random.randint was evaluated once at trace time, so every image got
    # the same rotation.
    quarter_turns = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    img = tf.image.rot90(img, k=quarter_turns)
    img = tf.clip_by_value(img, 0.0, 1.0)
    return img

def make_dataset(x, y, is_training=True, augment=False, grayscale=False,
                 shuffle_buffer=None, batch_size=None):
    """
    Create a batched, prefetched ``tf.data`` pipeline for training or testing.

    Optionally converts images to grayscale and/or applies on-the-fly
    augmentation (in that order).

    Args:
        x: Image array; its first axis indexes examples.
        y: Label array aligned with ``x``.
        is_training: If True, shuffle the examples (reshuffled every epoch).
        augment: If True, apply ``data_augmentation`` to every image.
        grayscale: If True, convert RGB images to 1-channel grayscale.
        shuffle_buffer: Shuffle buffer size. Defaults to ``len(x)``, a full
            uniform shuffle. A small buffer only shuffles locally. That matters
            for sets built by concatenation (e.g. originals followed by
            augmented copies), which would otherwise stay largely ordered.
            A full buffer holds a second copy of the data in memory.
        batch_size: Batch size; defaults to the module-level ``BATCH_SIZE``.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches.
    """
    AUTOTUNE = tf.data.AUTOTUNE
    if batch_size is None:
        batch_size = BATCH_SIZE

    def preprocess(img, label):
        # `grayscale` / `augment` are plain Python flags, resolved once at trace time.
        if grayscale:
            img = tf.image.rgb_to_grayscale(img)
        if augment:
            img = data_augmentation(img)
        return img, label

    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if is_training:
        if shuffle_buffer is None:
            shuffle_buffer = len(x)
        # tf.data requires a positive buffer size, even for an empty dataset.
        ds = ds.shuffle(buffer_size=max(1, shuffle_buffer))
    ds = ds.map(preprocess, num_parallel_calls=AUTOTUNE)
    ds = ds.batch(batch_size).prefetch(AUTOTUNE)
    return ds

def random_grayscale_augmentation(img):
    """
    Randomly augment one image eagerly and return it as a NumPy array.

    Meant for offline augmentation (called once per image in a Python loop).
    It applies, in order:
      - a random horizontal flip;
      - contrast in [0.8, 1.2];
      - brightness delta up to 0.1;
      - a random multiple of 90-degree rotation.
    The result is clipped to [0, 1]. The rotation count comes from NumPy's
    global RNG; the other draws come from TensorFlow's. The function does not
    convert to grayscale itself; the name reflects that it is applied to
    grayscale images in this notebook.

    Args:
        img: A single image (array-like, H x W x C).

    Returns:
        The augmented image as a NumPy array.
    """
    out = tf.convert_to_tensor(img)
    out = tf.image.random_flip_left_right(out)
    out = tf.image.random_contrast(out, 0.8, 1.2)
    out = tf.image.random_brightness(out, 0.1)
    quarter_turns = np.random.randint(0, 4)
    out = tf.image.rot90(out, k=quarter_turns)
    return tf.clip_by_value(out, 0.0, 1.0).numpy()

def generate_offline_augmented_set(x, y, augment_fn, seed=42):
    """
    Augment the entire training set offline (ahead of time) and return the combined set.

    Every example in ``x`` is passed once through ``augment_fn``. The augmented
    copies are appended after the originals, so the output has twice as many
    examples. Labels are duplicated in the same order.

    Args:
        x: Image array; its first axis indexes examples.
        y: Label array aligned with ``x``.
        augment_fn: Callable mapping one image to one augmented image.
        seed: Seed for Python, NumPy and TensorFlow RNGs, making the
            augmented set reproducible.

    Returns:
        ``(x_total, y_total)``: originals followed by augmented copies.
    """
    # Seed NumPy *and* TensorFlow. Augment functions may draw from both; for
    # example, random_grayscale_augmentation uses NumPy for the rotation and
    # TF for flip/contrast/brightness. np.random.seed alone left the TF draws
    # unseeded.
    tf.keras.utils.set_random_seed(seed)
    if len(x) == 0:
        # np.stack([]) raises on an empty list; nothing to augment.
        return np.copy(x), np.copy(y)
    x_aug = np.stack([augment_fn(img) for img in x])
    y_aug = np.copy(y)
    # Combine original and augmented data
    x_total = np.concatenate([x, x_aug], axis=0)
    y_total = np.concatenate([y, y_aug], axis=0)
    return x_total, y_total

## Measurements

In [None]:
def get_cpu_memory():
    """Return the resident set size (RSS) of the current process in MiB (1024**2 bytes)."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 * 1024)

def measure_inference_time(model, ds, num_batches=20):
    """
    Measure average per-batch inference latency and process memory.

    Runs ``model.predict`` on up to ``num_batches`` batches of ``ds``. Before
    the first batch is timed, it is predicted once untimed. The first
    ``predict`` call traces the predict function, and that one-off cost would
    otherwise inflate the average. Timing uses whatever device TensorFlow
    places the model on; nothing here pins execution to the CPU.

    Args:
        model: A Keras model.
        ds: Iterable of ``(x_batch, y_batch)`` pairs.
        num_batches: Maximum number of batches to time.

    Returns:
        ``(mean_seconds_per_batch, mean_rss_mb)``. Both are NaN when ``ds``
        yields no batches.
    """
    times = []
    mems = []
    it = iter(ds)
    warmed_up = False
    for _ in range(num_batches):
        try:
            x_batch, _ = next(it)
        except StopIteration:
            break
        if not warmed_up:
            model.predict(x_batch, verbose=0)  # untimed warm-up / tracing pass
            warmed_up = True
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        model.predict(x_batch, verbose=0)
        times.append(time.perf_counter() - start)
        mems.append(get_cpu_memory())
    if not times:
        return float("nan"), float("nan")
    return float(np.mean(times)), float(np.mean(mems))

## Model

In [None]:
def build_cnn(input_shape, num_classes=10):
    """
    Assemble a small convolutional classifier for CIFAR-10 or CIFAR-100.

    Architecture: two stages of [Conv 3x3 -> Conv 3x3 -> MaxPooling2D], with
    32 and then 64 filters ('same' padding, ReLU). These are followed by
    Flatten, a 128-unit ReLU dense layer and a softmax over ``num_classes``.

    Args:
        input_shape: Shape of one input image, e.g. ``(32, 32, 3)``.
        num_classes: Number of output classes.

    Returns:
        An uncompiled ``tf.keras.Sequential`` model.
    """
    layers = tf.keras.layers
    stack = [layers.Input(shape=input_shape)]
    for filters in (32, 64):
        stack.append(layers.Conv2D(filters, 3, activation='relu', padding='same'))
        stack.append(layers.Conv2D(filters, 3, activation='relu', padding='same'))
        stack.append(layers.MaxPooling2D())
    stack.extend([
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),
    ])
    return tf.keras.Sequential(stack)

# Experiment

In [None]:
def run_experiment(
    name, x_train, y_train, x_test, y_test, 
    grayscale=False, augment=False, repeat=0, save_hist=True, seed=None, 
    offline_augmented=False):
    """
    Train and evaluate a model for a specific experiment configuration.

    Outputs written to disk:
      - the trained model, in ``saved_models/``;
      - a results JSON, in ``results/``;
      - the per-epoch training history, in ``histories/`` (only when
        ``save_hist`` is True).

    Args:
        name: Experiment name; combined with ``repeat`` to form file names.
        x_train, y_train: Training images and integer labels.
        x_test, y_test: Test images and labels. They are also used as the
            validation data for early stopping (see NOTE in the body).
        grayscale: If True, RGB images (train and test) are converted to
            grayscale inside the tf.data pipeline, and the model takes 1 channel.
        augment: If True, apply on-the-fly augmentation to training images.
            Ignored when ``offline_augmented`` is True.
        repeat: Repetition index, recorded in results and file names.
        save_hist: Whether to write the training history JSON.
        seed: If not None, passed to ``set_seeds`` before any dataset or
            model is created.
        offline_augmented: If True, inputs are treated as already
            single-channel (and pre-augmented). The model takes 1 channel, and
            no grayscale conversion or on-the-fly augmentation is applied.

    Returns:
        Tuple ``(results, history)``: the results dict that was written to
        disk, and the training history as a dict of lists of floats.
    """
    print(f"\n=== Experiment: {name} | Grayscale: {grayscale} | Augment: {augment} | Repeat: {repeat} ===")

    # Seed BEFORE creating the datasets and the model. build_cnn's Sequential
    # model starts with an Input layer, so its weights are initialised at
    # construction time. An unseeded Dataset.shuffle also takes its seed from
    # the global seed when it is created. Seeding afterwards left both of
    # these outside the seed's control.
    if seed is not None:
        set_seeds(seed)

    # Grayscale (or offline-augmented, already single-channel) inputs have 1 channel; color has 3
    n_channels = 1 if grayscale or offline_augmented else 3
    input_shape = (IMG_SIZE, IMG_SIZE, n_channels)

    # Only apply grayscale conversion if images are not already grayscale
    convert_to_gray = grayscale and not offline_augmented

    # Make datasets
    ds_train = make_dataset(x_train, y_train, is_training=True, augment=(augment and not offline_augmented), grayscale=convert_to_gray)
    ds_test = make_dataset(x_test, y_test, is_training=False, augment=False, grayscale=convert_to_gray)

    # Build and compile model
    model = build_cnn(input_shape, num_classes=NUM_CLASSES)
    optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.summary()  # prints the summary itself (and returns None)

    # NOTE(review): the test set doubles as validation data. Early stopping
    # with restore_best_weights therefore picks weights by test loss, and the
    # reported val_acc/val_loss are optimistically biased. Consider holding
    # out part of the training set for validation.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        patience=3,
        restore_best_weights=True,
        monitor='val_loss'
    )

    start_train_time = time.perf_counter()
    history = model.fit(ds_train, epochs=EPOCHS, validation_data=ds_test, verbose=2, callbacks=[early_stopping])
    train_time = time.perf_counter() - start_train_time

    # Plain Python floats so the history is always JSON-serialisable.
    history_dict = {metric: [float(v) for v in values] for metric, values in history.history.items()}

    # Measure memory and inference
    cpu_mem_train = get_cpu_memory()
    val_loss, val_acc = model.evaluate(ds_test, verbose=0)
    inf_time_cpu, inf_mem_cpu = measure_inference_time(model, ds_test)

    model_name = f"{name}_rep{repeat}"
    model.save(f"saved_models/{model_name}.keras")

    results = {
        "experiment": name,
        "repeat": repeat,
        "grayscale": grayscale,
        "augment": augment,
        "offline_augmented": offline_augmented,
        "epochs": EPOCHS,
        # Early stopping can end training before EPOCHS; record what actually ran.
        "epochs_trained": len(history_dict.get("loss", [])),
        "batch_size": BATCH_SIZE,
        "val_acc": float(val_acc),
        "val_loss": float(val_loss),
        "train_time_sec": float(train_time),
        "cpu_mem_train_MB": float(cpu_mem_train),
        "inf_time_cpu_sec": float(inf_time_cpu),
        "inf_mem_cpu_MB": float(inf_mem_cpu),
        "model_params": model.count_params(),
    }
    with open(f"results/{model_name}.json", "w") as f:
        json.dump(results, f, indent=2)
    if save_hist:
        with open(f"histories/{model_name}_history.json", "w") as f:
            json.dump(history_dict, f, indent=2)
    print("Results:", results)
    return results, history_dict

**Cell to run all experiments:**

In [None]:
# Load the CIFAR variant that matches NUM_CLASSES. Feeding CIFAR-100 labels
# (0-99) to a NUM_CLASSES=10 softmax gives labels outside [0, NUM_CLASSES),
# which is invalid for sparse_categorical_crossentropy.
data_loaders = {10: get_cifar10_data, 100: get_cifar100_data}
if NUM_CLASSES not in data_loaders:
    raise ValueError(f"NUM_CLASSES must be 10 (CIFAR-10) or 100 (CIFAR-100), got {NUM_CLASSES}")
(x_train, y_train), (x_test, y_test) = data_loaders[NUM_CLASSES]()

# Precompute grayscale versions for efficiency
x_test_gray = rgb_to_grayscale(x_test).astype("float32")
x_test_gray_as_rgb = grayscale_to_rgb(x_test_gray)
x_train_gray = rgb_to_grayscale(x_train).astype("float32")

experiments = [
    # Baseline: Color images
    dict(name="color_baseline", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, grayscale=False, augment=False),
    # Baseline: Grayscale images (no augmentation)
    dict(name="grayscale_baseline", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, grayscale=True, augment=False),
    # Grayscale images with augmentation
    dict(name="grayscale_aug", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, grayscale=True, augment=True),
    # Extra controls: Color model on grayscale test images
    dict(name="color_on_gray_test", x_train=x_train, y_train=y_train, x_test=x_test_gray_as_rgb, y_test=y_test, grayscale=False, augment=False),
    # NOTE(review): with grayscale=True, run_experiment converts BOTH train
    # and test images to grayscale. The two "*_on_color_test" runs below are
    # therefore configured identically to grayscale_baseline / grayscale_aug,
    # and a single-channel model cannot take 3-channel color input directly.
    # Confirm the intended setup (e.g. train on gray-as-RGB, test on color).
    dict(name="grayscale_on_color_test", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, grayscale=True, augment=False),
    dict(name="grayscale_aug_on_color_test", x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test, grayscale=True, augment=True)
]

# Offline-augmented grayscale experiment (doubles training data size)
print("Generating offline-augmented grayscale training set...")
for rep in range(N_REPEATS):
    seed = 42 + rep
    x_train_gray_offline, y_train_gray_offline = generate_offline_augmented_set(
        x_train_gray, y_train, random_grayscale_augmentation, seed=seed
    )
    run_experiment(
        name="grayscale_offline_aug_100k",
        x_train=x_train_gray_offline, y_train=y_train_gray_offline,
        x_test=x_test_gray, y_test=y_test,
        grayscale=False, augment=False, repeat=rep, seed=seed,
        offline_augmented=True
    )

# Run all other experiments
for exp in experiments:
    for rep in range(N_REPEATS):
        seed = 42 + rep
        run_experiment(repeat=rep, seed=seed, **exp)