In [None]:
import os, json, csv, time
from pathlib import Path
from contextlib import nullcontext

results_path = Path("./grid_results.csv")
results_path.parent.mkdir(parents=True, exist_ok=True)

# --- Search strategy ---------------------------------------------------------
use_random_search = True  # sample a subset of the grid instead of running all of it
max_configs = 24          # subset size used when random search is on
k_folds = 2               # small K keeps each config cheap
epochs_max = 6            # upper bound on epochs; pruning may stop a run sooner
warmup_epochs = 1         # epochs a run must finish before it may be pruned
early_margin = 0.01       # accuracy slack before a run counts as "clearly worse"

# Search space. The "epochs" entry only records epochs_max for bookkeeping;
# the training loop iterates over epochs_max itself.
param_grid = {
    "channels": [32, 64],
    "dropout": [0.0, 0.25, 0.5],
    "lr": [6e-4, 1e-3, 6e-3, 1e-2],
    "weight_decay": [0.0, 1e-4],
    "batch_size": [256],    # raise if memory allows
    "optimizer": ["adam"],  # a single optimizer keeps the search fast
    "epochs": [epochs_max],
}

# Expand the grid, then optionally keep a reproducible random subset of it.
# Sampling positions (rather than the dicts themselves) avoids round-tripping
# the configs through a NumPy object array.
all_params = list(param_product(param_grid))
if use_random_search and len(all_params) > max_configs:
    rng = np.random.default_rng(27)
    picked = rng.choice(len(all_params), size=max_configs, replace=False)
    all_params = [all_params[i] for i in picked]

# The fold split is computed once, so every config is scored on the same splits.
folds = list(kfold_indices(len(train_full), k=k_folds, seed=27))

# One validation loader per fold, reused by every config: fixed order and a
# large batch, since validation only needs throughput.
val_loader_kwargs = dict(
    batch_size=512,
    shuffle=False,
    num_workers=4,
    pin_memory=(device == "cuda"),
    persistent_workers=True,
)
val_loaders = []
for _, val_idx in folds:
    val_ds = Subset(train_full, val_idx)
    val_loaders.append(DataLoader(val_ds, **val_loader_kwargs))

# Utility: CSV logger
def log_result(row_dict, *, path=None):
    """Append ``row_dict`` as one row of a CSV results file.

    A header row (the dict's keys) is written when the file is missing or
    empty. When the file already has a header, values are written in that
    header's column order, so callers may pass keys in any order; columns
    missing from ``row_dict`` are left empty.

    Args:
        row_dict: Mapping of column name to value for a single row.
        path: CSV file to append to; defaults to the module-level
            ``results_path``.

    Raises:
        ValueError: If ``row_dict`` has a key that is not a column of the
            file's existing header (instead of silently writing a
            misaligned row).
    """
    target = results_path if path is None else Path(path)
    fieldnames = list(row_dict)
    write_header = True
    # Checking only exists() would skip the header for an empty file, e.g. one
    # left behind by an interrupted run.
    if target.exists() and target.stat().st_size > 0:
        with target.open("r", newline="", encoding="utf-8") as f:
            existing_header = next(csv.reader(f), None)
        if existing_header:
            # Align to the columns already on disk rather than this dict's order.
            fieldnames = existing_header
            write_header = False
    with target.open("a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(row_dict)

# Utility: build a stable string key identifying a (params, fold) pair
def result_key(p, fold):
    """Return a deterministic JSON string identifying the pair ``(p, fold)``.

    Keys are sorted before serialising, so dicts with equal contents always
    produce the same string. ``fold`` is stored under the key ``"fold"`` and
    replaces any ``"fold"`` entry already present in ``p``.
    """
    merged = {**p}
    merged["fold"] = fold
    return json.dumps(merged, sort_keys=True)

# AMP context: autocast + GradScaler are used on CUDA only.
# NOTE: other devices (including MPS) train in full precision. Entering
# torch.autocast(device_type="cpu") does not cast ops on MPS tensors, so it is
# not a route to mixed precision there. Native MPS autocast presumably needs a
# newer PyTorch -- TODO confirm the minimum version before enabling it.
use_amp = device == "cuda"
autocast_dtype = torch.float16 if device == "cuda" else torch.bfloat16

# Optional: let cuDNN benchmark and cache the fastest conv algorithms.
if device == "cuda":
    torch.backends.cudnn.benchmark = True


def _train_one_epoch(model, loader, criterion, optimizer, scaler, *,
                     device, amp, amp_dtype):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: Network to train (switched to train mode here).
        loader: Iterable of ``(X, y)`` batches.
        criterion: Loss called as ``criterion(logits, y)``.
        optimizer: Optimizer over ``model``'s parameters.
        scaler: GradScaler used for backward/step when ``amp`` is true.
        device: Device each batch is moved to.
        amp: If true, run the forward pass under CUDA autocast and scale the loss.
        amp_dtype: Autocast dtype used when ``amp`` is true.

    Returns:
        Sample-weighted mean training loss over the epoch (0.0 for an empty
        loader).
    """
    model.train()
    # Accumulate on-device; calling .item() on every batch would force a
    # host/device synchronisation per step.
    loss_sum = torch.zeros((), device=device)
    n_seen = 0
    for X, y in loader:
        X = X.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        if amp:
            with torch.autocast(device_type="cuda", dtype=amp_dtype):
                loss = criterion(model(X), y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()
        loss_sum += loss.detach().float() * X.size(0)
        n_seen += y.size(0)
    return loss_sum.item() / max(n_seen, 1)


best_score_global = -1.0
best_params = None
# Validation curves of the incumbent best config, one list per fold:
# best_curves[f][e] is its best-so-far val accuracy on fold f after epoch e + 1.
# Runs are pruned against the incumbent *at the same epoch*. Comparing an
# early-epoch accuracy with the incumbent's final CV score would tend to prune
# nearly every config right after warm-up and favour whichever config ran first.
best_curves = None

# Cache built train loaders per (fold, batch_size)
train_loaders_cache = {}
n_folds = len(folds)

for p_idx, params in enumerate(all_params):
    start_config = time.time()
    params_json = json.dumps(params, sort_keys=True)
    cv_scores = []    # best val accuracy reached on each fold that was run
    fold_curves = []  # per fold: best-so-far val accuracy after each epoch
    pruned = False

    for f_idx, (train_idx, val_idx) in enumerate(folds):
        # Build/reuse the train loader for this fold and batch size.
        key = (f_idx, params["batch_size"])
        if key not in train_loaders_cache:
            train_loaders_cache[key] = DataLoader(
                Subset(train_full, train_idx), batch_size=params["batch_size"],
                shuffle=True, num_workers=4, pin_memory=(device == "cuda"),
                persistent_workers=True,
            )
        train_loader = train_loaders_cache[key]
        val_loader = val_loaders[f_idx]

        # Fresh model, optimizer and loss scaler, so runs share no state.
        model = SimpleCNN(channels=params["channels"], dropout=params["dropout"]).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = make_optimizer(params["optimizer"], model.parameters(),
                                   lr=params["lr"], weight_decay=params["weight_decay"])
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

        # The incumbent's curve on this fold (may be shorter than epochs_max if
        # the incumbent itself was stopped early on this fold).
        if best_curves is not None and f_idx < len(best_curves):
            ref_curve = best_curves[f_idx]
        else:
            ref_curve = []

        best_val_this_fold = -1.0
        curve = []
        for epoch in range(epochs_max):
            train_loss = _train_one_epoch(
                model, train_loader, criterion, optimizer, scaler,
                device=device, amp=use_amp, amp_dtype=autocast_dtype,
            )
            _, val_acc = evaluate(model, val_loader, criterion=None)
            val_acc = float(val_acc)
            best_val_this_fold = max(best_val_this_fold, val_acc)
            curve.append(best_val_this_fold)
            # Log after each epoch for diagnostics.
            log_result({
                "time": time.time(),
                "config_idx": p_idx,
                "fold": f_idx,
                "epoch": epoch + 1,
                "val_acc": val_acc,
                "train_loss_epoch_avg": train_loss,
                "params": params_json,
            })

            # Epoch-level pruning: after warm-up, stop this fold if it is
            # clearly behind the incumbent at the same epoch.
            if (epoch + 1 >= warmup_epochs and epoch < len(ref_curve)
                    and best_val_this_fold < ref_curve[epoch] - early_margin):
                pruned = True
                break

        cv_scores.append(best_val_this_fold)
        fold_curves.append(curve)

        # Fold-level pruning: stop if even perfect (1.0) accuracy on every
        # remaining fold could not bring the mean within the margin of the best.
        remaining = n_folds - len(cv_scores)
        if best_score_global > 0 and remaining > 0:
            optimistic_max = (sum(cv_scores) + remaining * 1.0) / n_folds
            if optimistic_max < best_score_global - early_margin:
                pruned = True
                break

    if not cv_scores:
        continue
    mean_cv = float(np.mean(cv_scores))
    if mean_cv > best_score_global:
        best_score_global = mean_cv
        best_params = params
        best_curves = fold_curves
        print(f"New best: {best_score_global:.4f} with {best_params}")
    status = "pruned" if pruned else "complete"
    print(f"Config {p_idx}: mean CV {mean_cv:.4f} ({status}, "
          f"{len(cv_scores)}/{n_folds} folds, {time.time() - start_config:.1f}s)")

    # Summarize config (same columns as the per-epoch rows).
    log_result({
        "time": time.time(),
        "config_idx": p_idx,
        "fold": "ALL",
        "epoch": "final",
        "val_acc": mean_cv,
        "train_loss_epoch_avg": "",
        "params": params_json,
    })

print("Best CV accuracy:", best_score_global)
print("Best params:", best_params)
print(f"Results saved to {results_path.resolve()}")