In [None]:
from src.temporal_autoencoder import pretrain_en_de_with_regularizers

# Example usage: pretrain the temporal encoder/decoder on the whole tensor
# (no validation split). Assumes `temporal_node_tensor`, `masks` and `device`
# are defined in earlier cells.
#
# The masks are cast to bool and moved to `device` once here, the same way the
# hyperparameter-search cell passes them to `run_hparam_search`, so this
# example and the search hand the model masks with the same dtype/device.
pretrain_mask_embed = masks["mask_embed"].bool().to(device)
pretrain_mask_cloud = masks["mask_cloud"].bool().to(device)

model, logs, best_epoch = pretrain_en_de_with_regularizers(
    train_tensor=temporal_node_tensor,
    val_tensor=None,  # trains on the whole tensor; see the CV-fold cells for held-out splits
    # Count of True entries in the embed mask (presumably the selected input
    # channels). Summing the bool mask guarantees an integer width even if the
    # stored mask is float.
    in_dim=int(pretrain_mask_embed.sum().item()),
    embed_dim=16,
    conv_hidden=128,
    window=6,
    use_attention=True,
    batch_size=8,
    lr=1e-3,
    epochs=20,
    device=device,
    early_stopping_patience=5,
    mask_embed=pretrain_mask_embed,
    mask_cloud=pretrain_mask_cloud,
    save_path=None,  # presumably disables checkpoint saving for this example run
    verbose=True,
)


In [None]:
from src.cv_splits import make_expanding_folds, extract_fold_data

# Expanding-window cross-validation splits over the leading axis of
# `temporal_node_tensor` (assumed to be the time axis — TODO confirm).
# Exact slice construction is defined by `make_expanding_folds`.
TRAIN_FRAC = 0.33  # fraction of T used as the first fold's training end
VAL_FRAC = 0.11    # fraction of T used for each validation window
n_folds = 3

T = temporal_node_tensor.shape[0]
first_train_end = int(T * TRAIN_FRAC)
val_window = int(T * VAL_FRAC)

# int() truncates, so a short series can give an empty first training span
# (train_start=0 .. first_train_end) or a zero-length validation window.
# Fail early with a clear message instead of building degenerate folds.
assert first_train_end > 0 and val_window > 0, (
    f"T={T} is too short: first_train_end={first_train_end}, val_window={val_window}"
)

folds = make_expanding_folds(
    T=T,
    train_start=0,
    first_train_end=first_train_end,
    val_window=val_window,
    n_folds=n_folds,
)

print("FOLD SPLITS:")
for fold_num, fold in enumerate(folds, start=1):
    print(f"Fold {fold_num}: Train {fold['train_slice']}  Val {fold['val_slice']}")

# Only the first fold is materialized here; the hyperparameter search below
# receives `folds` itself.
train_data, val_data = extract_fold_data(temporal_node_tensor, folds, fold_idx=0)


In [None]:
from src.hparam_search_encoder import HParamConfig, run_hparam_search, select_best

# Grid search over embedding size x conv width, evaluated on every CV fold
# built in the split cell. Run outputs presumably go under HPARAM_BASE_DIR,
# a Colab Google Drive path (Drive must be mounted — confirm before running).
HPARAM_BASE_DIR = "/content/gdrive/MyDrive/hparam_search_encoder"

cfg = HParamConfig(
    embed_dims=[8, 16, 32],
    conv_hiddens=[32, 64, 128, 256],
    seeds=[123],
    # Derived from `folds` rather than hard-coded as [0, 1, 2], so changing
    # n_folds in the split cell cannot leave fold indices out of range or unused.
    folds=list(range(len(folds))),
    # NOTE(review): the standalone pretrain example uses window=6; confirm 12
    # is intended for the search.
    window=12,
    use_attention=True,
    batch_size=8,
    epochs=20,
    lr=1e-3,
    early_stopping_patience=5,
)

results = run_hparam_search(
    temporal_node_tensor=temporal_node_tensor,
    folds=folds,
    mask_embed=masks["mask_embed"].bool().to(device),
    mask_cloud=masks["mask_cloud"].bool().to(device),
    base_dir=HPARAM_BASE_DIR,
    cfg=cfg,
    device=device,
    verbose=True,
)

best_cfg = select_best(results)
print(best_cfg)


In [None]:
from src.hparam_results import collect_hparam_records, summarize_hparam_results

# Collect the records stored under the same base_dir that was passed to
# run_hparam_search, and summarize them.
base_dir = "/content/gdrive/MyDrive/hparam_search_encoder"
df = collect_hparam_records(base_dir)

# An empty result usually means the search has not run or Drive is not
# mounted; say so explicitly rather than summarizing nothing.
assert len(df) > 0, f"No hparam records found under {base_dir!r}; run the search cell first."

summary = summarize_hparam_results(df)

# Last expression -> Jupyter rich display (HTML table if this is a DataFrame).
summary
