In [None]:
# Run setup notebook from here
!git clone https://github.com/goodarzi64/GHI_Forecasting
%cd GHI_Forecasting
!git pull
%run /content/GHI_Forecasting/notebooks/00_colab_setup.ipynb
%run /content/GHI_Forecasting/notebooks/01_import_datasets.ipynb

In [None]:
import os, sys, torch
from google.colab import drive

# Mount Google Drive so the data/checkpoint paths under /content/gdrive resolve.
drive.mount("/content/gdrive")

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)

# Make the cloned repository importable (the `src` package lives inside it).
sys.path.append("/content/GHI_Forecasting")

from src.data_pipeline import (
    load_geo_and_build_static_features,
    load_temporal_artifacts,
    build_temporal_from_prep_csv,
    save_temporal_artifacts,
    get_wind_positions,
)

# Google Drive locations. Raw prep files and cached temporal artifacts
# currently share the same folder.
geo_dir = "/content/gdrive/MyDrive/CYL_geo"
raw_prep_dir = "/content/gdrive/MyDrive/CYL_GHI/prep_files"
artifacts_dir = "/content/gdrive/MyDrive/CYL_GHI/prep_files"

# Static (geo) side: keep the whole result object plus the two fields used later.
geo = load_geo_and_build_static_features(geo_dir)
vars_geo = geo.vars_geo
station_files = geo.station_files

# loading_flag = True  -> reuse the temporal artifacts already saved in artifacts_dir.
# loading_flag = False -> rebuild them from the prep files, then save them for next time.
loading_flag = True
if not loading_flag:
    temporal = build_temporal_from_prep_csv(
        station_files=station_files,
        raw_prep_dir=raw_prep_dir,
        keep_doy_tod=True,
    )
    save_temporal_artifacts(artifacts_dir, temporal)
else:
    temporal = load_temporal_artifacts(artifacts_dir)

# Expose the pieces of the temporal bundle that the following cells consume.
masks = temporal.masks
df_cols = temporal.df_cols
temporal_node_tensor = temporal.temporal_node_tensor
temporal_target_tensor = temporal.temporal_target_tensor

# Feature positions of the wind direction / wind speed columns.
wind_dir_pos, wind_sp_pos = get_wind_positions(df_cols, masks)
print("wind_dir_pos:", wind_dir_pos, "wind_sp_pos:", wind_sp_pos)


In [None]:
from src.Graph_build import build_geo_matrices, build_static_adjacency, WindAdjacency

# Geo matrices are computed a single time and reused by both the static
# adjacency and the wind adjacency.
geo_mats = build_geo_matrices(df_geo=geo.vars_geo, device=device)
dist_matrix, theta_matrix = geo_mats["dist_matrix"], geo_mats["theta_matrix"]

# ---------------------------------------------------------------------
# Static adjacency, derived from the precomputed distance matrix.
graph = build_static_adjacency(dist_matrix=dist_matrix, k=5, self_loops=False, topk_sym=False)

# Pull the individual adjacency variants out of the returned mapping.
A_raw, A_static1, A_static2, A_static3 = (
    graph[key] for key in ("A_raw", "A_topk", "A_row_norm", "A_sym_norm")
)
# -------------------------------------------------------------------

# Wind-driven adjacency built on the shared distance / theta matrices.
# Each feature index is passed to its matching keyword: the speed index to
# `wind_speed_pos` and the direction index to `wind_dir_pos` (they must not be crossed,
# otherwise the kernel reads speed where it expects direction and vice versa).
# NOTE(review): relies on get_wind_positions returning (direction, speed) in that
# order, as unpacked in the data-loading cell — confirm against src.data_pipeline.
wind_adj = WindAdjacency(
    D_ij=dist_matrix,
    Theta_ij=theta_matrix,
    wind_speed_pos=wind_sp_pos,
    wind_dir_pos=wind_dir_pos,
    R=150.0,  # NOTE(review): presumably a radius in the units of dist_matrix — confirm
    lambda_theta=1.0,
)


In [None]:
# Build datasets
from src.data_pipeline import GraphSequenceDataset
from torch.utils.data import DataLoader
from src.gconvgru import RecurrentGCN, GraphGateNodewise
from src.train_context_gated import train_joint_context_gated
from src.temporal_autoencoder import load_pretrained_embedor

# Window length (lags) and forecast horizon used to build the sequence samples.
lags = 6
horizon = 1
full_dataset = GraphSequenceDataset(
    temporal_node_tensor,
    temporal_target_tensor,
    masks,
    lags=lags,
    horizon=horizon,
)

# Chronological split by sample index:
#   [0, train_end)          -> train
#   [train_end, valid_end)  -> valid
#   [valid_end, N_total)    -> test
N_total = len(full_dataset)
train_end, valid_end = int(N_total * 0.33), int(N_total * 0.44)

train_dataset = torch.utils.data.Subset(full_dataset, range(train_end))
train_dataset2 = torch.utils.data.Subset(full_dataset, range(valid_end))  # train + valid together
valid_dataset = torch.utils.data.Subset(full_dataset, range(train_end, valid_end))
test_dataset = torch.utils.data.Subset(full_dataset, range(valid_end, N_total))

# DataLoaders (batch_size=10, order preserved). Batch layout:
#   x_seq: [B, W, N, Fx]
#   e_seq: [B, W, N, Fe]
#   w_seq: [B, W, N, Fw]
#   g_seq: [B, W, N, Fg]
#   y_seq: [B, H, N]
train_loader = DataLoader(train_dataset, shuffle=False, batch_size=10)
train_loader2 = DataLoader(train_dataset2, shuffle=False, batch_size=10)
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=10)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=10)

# Feature-group widths: size of the last axis after applying each mask.
forecast_view = temporal_node_tensor[:, :, masks['mask_forecast']]
Fx = forecast_view.shape[2]  # forecast features (without meteo_cat)
embed_view = temporal_node_tensor[:, :, masks['mask_embed']]
Fe = embed_view.shape[2]  # embed features
gate_view = temporal_node_tensor[:, :, masks['mask_gate']]
Fg = gate_view.shape[2]  # global features

# Forecaster operating on the Fx forecast features.
model_copy = RecurrentGCN(
    node_feature_dim=Fx,
    filters=[64, 16],
    horizon=horizon,
    alpha=0.0,
    mode='last',
)
model_copy = model_copy.to(device)

# Pretrained embedder restored from a checkpoint on Drive.
embed_ckpt = "/content/gdrive/MyDrive/hparam_search_encoder/ed16_ch64_seed123_fold3.pt"
Embedor_copy = load_pretrained_embedor(
    ckpt_path=embed_ckpt,
    mask_cloud=torch.tensor(masks["mask_cloud"], dtype=torch.bool),
    device=device,
)
Embedor_copy = Embedor_copy.to(device)

# Node-wise gate over three graphs.
# NOTE(review): in_dim=47 is hard-coded — confirm it matches the gate's actual input width.
gater_copy = GraphGateNodewise(in_dim=47, hidden=16, n_graphs=3)
gater_copy = gater_copy.to(device)

# Ablation switches: every non-empty combination of the static / dynamic / wind graphs.
cfg1 = dict(use_static=True, use_dynamic=True, use_wind=True)
cfg2 = dict(use_static=True, use_dynamic=True, use_wind=False)
cfg3 = dict(use_static=True, use_dynamic=False, use_wind=True)
cfg4 = dict(use_static=False, use_dynamic=True, use_wind=True)
cfg5 = dict(use_static=False, use_dynamic=False, use_wind=True)
cfg6 = dict(use_static=False, use_dynamic=True, use_wind=False)
cfg7 = dict(use_static=True, use_dynamic=False, use_wind=False)


# ---train---
# Static per-node features as a float32 tensor on the training device.
node_const = torch.tensor(vars_geo.values, dtype=torch.float32, device=device)

# Joint training run with all three graphs enabled (cfg1); the validation
# loader is what gets passed as `test_loader` here.
result = train_joint_context_gated(
    # modules
    model=model_copy,
    Embedor=Embedor_copy,
    Gater=gater_copy,
    WindKernel=wind_adj,
    # data
    train_loader=train_loader,
    test_loader=valid_loader,
    A_static=A_static1,
    node_const=node_const,
    device=device,
    # optimisation
    epochs=1,
    lr=1e-3,
    temperature=1,
    topk_each=5,
    # checkpointing / ablation
    checkpoint_path="/content/gdrive/MyDrive/checkpoints/train_A_W_f6416_a0_hor1.pt",
    resume=False,
    ablation_config=cfg1,
)
