# 02 — Train TF2 models (ResNet + VGG) and save `.keras` weights

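This notebook trains ResNet and VGG models on CIFAR-10 with TensorFlow 2.11 and saves their weights as `.keras` files in the directory configured as `tf_model_dir` in `configs/paths.yaml` (e.g. `data/models/`).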

In [None]:
import os
import sys
import yaml
import numpy as np
import tensorflow as tf

from pathlib import Path
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
sys.path.append("./source")

def load_yaml(path):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed object.

    The file is decoded as UTF-8 so the result does not depend on the
    platform's default encoding. An empty file yields ``{}`` instead of
    ``None``, so callers indexing the result get a clear ``KeyError`` rather
    than a ``TypeError`` on ``None``.

    Args:
        path: Path (str or path-like) of the YAML file to read.

    Returns:
        The parsed YAML document, or ``{}`` if the file is empty.
    """
    with open(path, "r", encoding="utf-8") as f:
        parsed = yaml.safe_load(f)
    return {} if parsed is None else parsed

# Project configuration: filesystem locations and experiment settings.
PATHS: dict = load_yaml("./configs/paths.yaml")
EXP: dict = load_yaml("./configs/exp.yaml")

data_root    = PATHS["data_root"]
cifar10_root = PATHS["cifar10_root"]
tf_model_dir = PATHS["tf_model_dir"]  # directory the trained `.keras` files are written to

seed = int(EXP["seed"])
# `training:` may be missing or explicitly null in exp.yaml; fall back to an empty dict.
train_cfg = EXP.get("training") or {}

# Seeds Python's `random`, NumPy's global RNG and TensorFlow's global seed in one call.
tf.keras.utils.set_random_seed(seed)

## Load CIFAR-10

In [None]:

(x_all, y_all), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Flatten both label arrays to 1-D int64 (used for stratification and one-hot encoding).
y_all, y_test = (labels.ravel().astype(np.int64) for labels in (y_all, y_test))

# Hold out a stratified 20% of the official training portion as a validation set.
x_train, x_val, y_train, y_val = train_test_split(
    x_all,
    y_all,
    stratify=y_all,
    test_size=0.2,
    random_state=seed,
)

# Divide by 255 to bring pixel values into [0, 1] as float32.
x_train, x_val, x_test = (
    images.astype(np.float32) / 255.0 for images in (x_train, x_val, x_test)
)

# One-hot targets over the 10 CIFAR-10 classes.
y_train_categorical, y_val_categorical, y_test_categorical = (
    to_categorical(labels, 10) for labels in (y_train, y_val, y_test)
)


In [None]:
# Sanity-check the arrays built when loading CIFAR-10 instead of rebuilding them.
# Repeating the split/scaling here would divide the already-normalized `x_test`
# by 255 a second time (it is not reloaded), squashing it into [0, 1/255].
assert len(x_train) + len(x_val) == len(x_all)
assert len(x_train) == len(y_train) == len(y_train_categorical)
assert len(x_val) == len(y_val) == len(y_val_categorical)
assert len(x_test) == len(y_test) == len(y_test_categorical)

for _split_name, _images in (("x_train", x_train), ("x_val", x_val), ("x_test", x_test)):
    assert _images.dtype == np.float32, _split_name
    # max > 1/255 rules out a second division by 255, which would cap every value at 1/255.
    assert 0.0 <= _images.min() and 1.0 / 255.0 < _images.max() <= 1.0, _split_name
del _split_name, _images

In [None]:
MODEL_DIR = Path(tf_model_dir)
MODEL_DIR.mkdir(exist_ok=True, parents=True)


def mpath(name: str) -> str:
    """Return, as a string, the path of the `<name>.keras` file inside MODEL_DIR."""
    checkpoint = MODEL_DIR.joinpath(f"{name}.keras")
    return str(checkpoint)


# print() applies str() itself, so the resolved Path prints exactly as before.
print("MODEL_DIR:", MODEL_DIR.resolve())

## Train ResNet models (TF2)
Models are saved to `data/models/*.keras`.

In [None]:
from source.resnet import ResNet

In [None]:
# (run name, ResNet version, n). The run name is also the `.keras` file stem.
# NOTE(review): the depth suffixes d20/d56 match depth = 6n+2 for v1 and 9n+2 for v2,
# which is presumably how `source.resnet.ResNet` derives depth — confirm. Under that
# rule the `n3`/`n9` in the v2 names differ from the `n` actually passed (2/6); the
# names are left as-is so existing checkpoint paths stay stable.
resnet_specs = [
    ("resnet_v1_n3_d20",  1, 3),
    ("resnet_v1_n9_d56",  1, 9),
    ("resnet_v2_n3_d20",  2, 2),
    ("resnet_v2_n9_d56",  2, 6),
]

# Trained ResNet wrappers, in the same order as `resnet_specs`.
trained_resnets = []

for name, version, n in resnet_specs:
    path = mpath(name)

    m = ResNet(
        path,
        x_train, y_train_categorical,
        x_val,   y_val_categorical,
        subtract_pixel_mean=False,
        version=version,
        n=n,
    )
    # NOTE(review): save_best_only presumably keeps only the best checkpoint at `path` — confirm in source.resnet.
    m.train(save_best_only=True, epochs=200, loss="categorical_crossentropy")
    trained_resnets.append(m)
    

## Train VGG models (TF2)
We train raw VGG variants and save weights to `data/models/*.keras`.

In [None]:
from source.vgg import VGG

In [None]:
# (run name, VGG architecture identifier passed as `model_name`).
# The run name is also the `.keras` file stem.
vgg_specs = [
    ("vgg11_raw", "vgg11"),
    ("vgg13_raw", "vgg13"),
    ("vgg16_raw", "vgg16"),
    ("vgg19_raw", "vgg19"),
]

# Trained VGG wrappers, in the same order as `vgg_specs`.
trained_vggs = []

for name, arch in vgg_specs:
    path = mpath(name)

    m = VGG(
        path,
        x_train, y_train_categorical,
        x_val,   y_val_categorical,
        model_name=arch,
        num_classes=10,
        subtract_pixel_mean=False,
    )
    # NOTE(review): save_best_only presumably keeps only the best checkpoint at `path` — confirm in source.vgg.
    m.train(save_best_only=True, epochs=200, loss="categorical_crossentropy")
    trained_vggs.append(m)
    