In [None]:
import sys
sys.path.append('.')  # Ensure OptimizedDataGenerator4.py is discoverable

import OptimizedDataGenerator4 as ODG
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from pathlib import Path
import matplotlib.pyplot as plt

# ─── GPU memory growth ─────────────────────────────────────────────────────────
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    try:
        # Allocate GPU memory on demand instead of reserving it all up front.
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Memory growth can only be configured before the GPU runtime is
        # initialized (e.g. this cell re-run after TensorFlow already used the
        # GPU); report it instead of aborting the whole notebook cell.
        print(f"Could not enable memory growth on {gpu.name}: {err}")

# ─── Model definition ─────────────────────────────────────────────────────────
def CreateXYProfileModel(hidden_units=(64, 16)):
    """Build the binary classifier over paired x/y profiles.

    Two inputs -- ``x_profile`` with shape (21, 1) and ``y_profile`` with
    shape (13, 1) -- are flattened, concatenated, passed through a stack of
    ReLU ``Dense`` layers, and reduced to one sigmoid output.

    Args:
        hidden_units: widths of the hidden ReLU layers, in order. The default
            ``(64, 16)`` is the architecture this notebook has always trained.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs ``[x_profile, y_profile]``.
    """
    x_profile = tf.keras.Input(shape=(21, 1), name="x_profile")
    y_profile = tf.keras.Input(shape=(13, 1), name="y_profile")
    x_flat = tf.keras.layers.Flatten(name="flatten_x")(x_profile)
    y_flat = tf.keras.layers.Flatten(name="flatten_y")(y_profile)
    features = tf.keras.layers.Concatenate(name="concat_xy")([x_flat, y_flat])
    # Layer names carry depth and actual width so they are both truthful
    # (the old "hidden_128"/"hidden_32" labelled 64- and 16-unit layers) and
    # unique even if two hidden layers share a width.
    for depth, units in enumerate(hidden_units, start=1):
        features = tf.keras.layers.Dense(
            units, activation="relu", name=f"hidden{depth}_{units}"
        )(features)
    output = tf.keras.layers.Dense(1, activation="sigmoid", name="output")(features)
    return tf.keras.Model(inputs=[x_profile, y_profile], outputs=output)

# ─── Data generators factory ───────────────────────────────────────────────────
BASE_DIR = Path("./filtering_records2000")
TRAIN_DIR = BASE_DIR.joinpath("tfrecords_train")
VALIDATION_DIR = BASE_DIR.joinpath("tfrecords_validation")


def make_gens():
    """Create the training and validation data generators.

    Both generators are built identically apart from their directory: they
    use ``load_records=True`` (presumably reading existing TFRecord files --
    see OptimizedDataGenerator4) and request the ``x_profile`` / ``y_profile``
    features, which are also the names of the two model inputs.

    Returns:
        tuple: ``(train_gen, val_gen)``.
    """
    def _records_generator(split_dir):
        # Each generator gets its own feature-name list.
        return ODG.OptimizedDataGenerator(
            load_records=True,
            tf_records_dir=str(split_dir),
            x_feature_description=["x_profile", "y_profile"],
        )

    train_gen = _records_generator(TRAIN_DIR)
    val_gen = _records_generator(VALIDATION_DIR)
    return train_gen, val_gen

# ─── Single‐run trainer (returns full history) ────────────────────────────────
def train_and_evaluate(config):
    """Train a fresh model under one learning-rate strategy; return its history.

    Args:
        config: Run description. ``config["type"]`` selects the LR strategy and
            the other keys carry its hyper-parameters (see ``_build_optimizer``).
            Optional generic keys:

            * ``"epochs"`` -- maximum number of epochs (default 120).
            * ``"early_stopping_patience"`` -- epochs without ``val_loss``
              improvement before stopping (default 20).
            * ``"seed"`` -- if present, passed to
              ``tf.keras.utils.set_random_seed`` so weight initialisation is
              repeatable and different schedulers start from the same point.

    Returns:
        dict: ``History.history`` from ``model.fit`` -- per-epoch lists such as
        ``"loss"``, ``"accuracy"``, ``"val_loss"`` and ``"val_accuracy"``.
        The lists can be shorter than ``epochs`` when early stopping fires.

    Raises:
        ValueError: if ``config["type"]`` is not a known scheduler type.
    """
    # Drop Keras state left by earlier runs so memory does not accumulate
    # while many configurations are trained back to back in one process.
    tf.keras.backend.clear_session()
    if "seed" in config:
        tf.keras.utils.set_random_seed(config["seed"])

    train_gen, val_gen = make_gens()
    steps = len(train_gen)
    epochs = config.get("epochs", 120)

    model = CreateXYProfileModel()
    optimizer, extra_callbacks = _build_optimizer(config, steps * epochs)
    callbacks = [
        EarlyStopping(
            monitor="val_loss",
            patience=config.get("early_stopping_patience", 20),
            restore_best_weights=True,
        ),
        *extra_callbacks,
    ]

    model.compile(optimizer=optimizer, loss="binary_crossentropy", metrics=["accuracy"])
    history = model.fit(
        train_gen,
        validation_data=val_gen,
        epochs=epochs,
        steps_per_epoch=steps,
        callbacks=callbacks,
        shuffle=False,
        verbose=1,
    )
    return history.history


def _scaled_steps(total_steps, divisor=1):
    """Return ``total_steps // divisor`` clamped to at least 1.

    The decay schedules below divide the current step by their step horizon,
    so a horizon of 0 (empty generator, or a divisor larger than the run)
    would produce inf/NaN learning rates.
    """
    return max(1, total_steps // divisor)


def _build_optimizer(config, total_steps):
    """Create the Adam optimizer (plus any extra callbacks) for one config.

    Args:
        config: run description; ``config["type"]`` picks the strategy and the
            remaining keys are its hyper-parameters.
        total_steps: optimizer steps in the full run (steps/epoch * epochs);
            decay horizons are expressed as this total or a fraction of it.

    Returns:
        tuple: ``(optimizer, extra_callbacks)``. ``extra_callbacks`` is only
        non-empty for ``"reduce_on_plateau"``.

    Raises:
        ValueError: if ``config["type"]`` is not a known scheduler type.
    """
    kind = config["type"]
    schedules = tf.keras.optimizers.schedules

    # Strategies that keep a plain float learning rate.
    if kind == "constant":
        return tf.keras.optimizers.Adam(learning_rate=config["lr"]), []
    if kind == "reduce_on_plateau":
        # The plateau callback adjusts the optimizer's float LR during fit.
        plateau = ReduceLROnPlateau(
            monitor="val_loss",
            factor=config.get("factor", 0.5),
            patience=config.get("patience", 10),
            verbose=1,
        )
        return tf.keras.optimizers.Adam(learning_rate=config["lr"]), [plateau]

    # Strategies driven by a LearningRateSchedule object.
    if kind == "cosine_decay":
        sched = schedules.CosineDecay(
            initial_learning_rate=config["initial_lr"],
            decay_steps=_scaled_steps(total_steps),
            alpha=config.get("alpha", 0.0),
        )
    elif kind == "cosine_restarts":
        sched = schedules.CosineDecayRestarts(
            initial_learning_rate=config["initial_lr"],
            first_decay_steps=_scaled_steps(total_steps, config.get("restarts_divisor", 3)),
            alpha=config.get("alpha", 0.0),
        )
    elif kind == "exponential_decay":
        sched = schedules.ExponentialDecay(
            initial_learning_rate=config["initial_lr"],
            decay_steps=_scaled_steps(total_steps, config.get("decay_divisor", 10)),
            decay_rate=config.get("decay_rate", 0.96),
            staircase=config.get("staircase", True),
        )
    elif kind == "polynomial_decay":
        sched = schedules.PolynomialDecay(
            initial_learning_rate=config["initial_lr"],
            decay_steps=_scaled_steps(total_steps),
            end_learning_rate=config.get("end_lr", 1e-5),
            power=config.get("power", 1.0),
        )
    elif kind == "inverse_time_decay":
        sched = schedules.InverseTimeDecay(
            initial_learning_rate=config["initial_lr"],
            decay_steps=_scaled_steps(total_steps, config.get("decay_divisor", 10)),
            decay_rate=config.get("decay_rate", 1.0),
            staircase=config.get("staircase", True),
        )
    elif kind == "piecewise":
        # Boundaries are absolute optimizer-step counts taken verbatim from config.
        sched = schedules.PiecewiseConstantDecay(
            boundaries=config["boundaries"],
            values=config["values"],
        )
    else:
        raise ValueError(f"Unknown scheduler type {kind!r}")
    return tf.keras.optimizers.Adam(learning_rate=sched), []

# ─── Configurations including all schedulers ──────────────────────────────────
configs = [
    # Fixed learning rates.
    dict(name="const_1e-3", type="constant", lr=1e-3),
    dict(name="const_1e-4", type="constant", lr=1e-4),
    dict(name="const_1e-2", type="constant", lr=1e-2),
    # Keras LearningRateSchedule objects; horizons are derived from the
    # run length (steps per epoch * epochs) inside train_and_evaluate.
    dict(name="cosine_decay", type="cosine_decay", initial_lr=1e-3, alpha=0.0),
    dict(name="cosine_restarts", type="cosine_restarts", initial_lr=1e-3,
         restarts_divisor=3, alpha=0.0),
    dict(name="exp_decay", type="exponential_decay", initial_lr=1e-3,
         decay_rate=0.96, decay_divisor=10, staircase=True),
    dict(name="poly_decay", type="polynomial_decay", initial_lr=1e-3,
         end_lr=1e-5, power=2.0),
    dict(name="inv_time_decay", type="inverse_time_decay", initial_lr=1e-3,
         decay_rate=1.0, decay_divisor=10, staircase=True),
    # NOTE(review): boundaries are absolute optimizer steps, not epochs. With
    # the 30 steps/epoch seen in the recorded output and 120 epochs (3,600
    # steps total), the 6000 boundary -- and so the 1e-5 stage -- is never
    # reached. Confirm the intended boundaries.
    dict(name="piecewise", type="piecewise",
         boundaries=[3000, 6000], values=[1e-3, 1e-4, 1e-5]),
    # Float LR halved by a ReduceLROnPlateau callback when val_loss stalls.
    dict(name="reduce_plateau", type="reduce_on_plateau", lr=1e-3,
         factor=0.5, patience=10),
]

# ─── Run sequentially and display progress ────────────────────────────────────
# Train every configuration in order, collecting each Keras history under the
# config's name; flush after each message so progress shows up immediately.
histories = {}
for cfg in configs:
    name = cfg["name"]
    print(f"\n=== Running scheduler: {name} ===", flush=True)
    histories[name] = train_and_evaluate(cfg)
    print(f"Completed: {name}", flush=True)

# ─── Plot all training & validation accuracy curves ───────────────────────────
# One colour per scheduler: its train curve is solid and its val curve dashed.
# Calling plt.plot without a colour advances the cycle on every line, so a
# run's two curves got different colours and 20 lines wrapped the cycle.
fig, ax = plt.subplots(figsize=(10, 6))
for idx, (name, h) in enumerate(histories.items()):
    color = f"C{idx}"  # "CN" indexes the active colour cycle (wraps around)
    ax.plot(h['accuracy'], color=color, label=f'{name} train')
    ax.plot(h['val_accuracy'], '--', color=color, label=f'{name} val')
ax.set_title('Training & Validation Accuracy by LR Scheduler')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
# Two legend entries per run would cover the curves; place it beside the axes.
ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1.0), fontsize='small')
ax.grid(True)
fig.tight_layout()
plt.show()



=== Running scheduler: const_1e-3 ===




Epoch 1/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 4s 110ms/step - accuracy: 0.4312 - loss: 1.6461 - val_accuracy: 0.5964 - val_loss: 1.0623
Epoch 2/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 71ms/step - accuracy: 0.3660 - loss: 1.2905 - val_accuracy: 0.6413 - val_loss: 0.7171
Epoch 3/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 70ms/step - accuracy: 0.5312 - loss: 0.8708 - val_accuracy: 0.6569 - val_loss: 0.6978
Epoch 4/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 71ms/step - accuracy: 0.5200 - loss: 0.8621 - val_accuracy: 0.6629 - val_loss: 0.6799
Epoch 5/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 70ms/step - accuracy: 0.5263 - loss: 0.8339 - val_accuracy: 0.6669 - val_loss: 0.6645
Epoch 6/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 71ms/step - accuracy: 0.5368 - loss: 0.8081 - val_accuracy: 0.6728 - val_loss: 0.6527
Epoch 7/120
30/30



Epoch 1/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 4s 113ms/step - accuracy: 0.5655 - loss: 0.9844 - val_accuracy: 0.6371 - val_loss: 0.8069
Epoch 2/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 70ms/step - accuracy: 0.5895 - loss: 0.8685 - val_accuracy: 0.6500 - val_loss: 0.7490
Epoch 3/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 72ms/step - accuracy: 0.6230 - loss: 0.7884 - val_accuracy: 0.6498 - val_loss: 0.7117
Epoch 4/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 72ms/step - accuracy: 0.6320 - loss: 0.7376 - val_accuracy: 0.6514 - val_loss: 0.6862
Epoch 5/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 71ms/step - accuracy: 0.6368 - loss: 0.7042 - val_accuracy: 0.6540 - val_loss: 0.6676
Epoch 6/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 3s 114ms/step - accuracy: 0.6391 - loss: 0.6818 - val_accuracy: 0.6549 - val_loss: 0.6541
Epoch 7/120
30/30



Epoch 1/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 5s 114ms/step - accuracy: 0.6616 - loss: 3.0542 - val_accuracy: 0.5550 - val_loss: 0.6891
Epoch 2/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 72ms/step - accuracy: 0.4161 - loss: 0.7064 - val_accuracy: 0.5550 - val_loss: 0.6895
Epoch 3/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 73ms/step - accuracy: 0.4161 - loss: 0.7045 - val_accuracy: 0.5550 - val_loss: 0.6893
Epoch 4/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 73ms/step - accuracy: 0.4161 - loss: 0.7050 - val_accuracy: 0.5550 - val_loss: 0.6892
Epoch 5/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 73ms/step - accuracy: 0.4161 - loss: 0.7056 - val_accuracy: 0.5550 - val_loss: 0.6891
Epoch 6/120
30/30 ━━━━━━━━━━━━━━━━━━━━ 2s 72ms/step - accuracy: 0.4161 - loss: 0.7061 - val_accuracy: 0.5550 - val_loss: 0.6890
Epoch 7/120
30/30