### Load Data

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch_geometric_temporal import ASTGCN
from torch_geometric.utils import dense_to_sparse
from torch.amp import GradScaler, autocast
from tqdm.auto import tqdm

# your existing utils
from utils.concatenate_data import run_pipeline
from utils.build_station_weight import build_station_weights
from model.model_core_architecture import ASTGCN_V2
# 1. Load & preprocess via the project pipeline (utils.concatenate_data.run_pipeline);
#    it returns a wide frame and a long per-station frame.
all_data_df, long_df = run_pipeline(
    final_wide_csv="all_data_df.csv",
    final_long_csv="all_data_timeseries.csv",
    root_xlsx_dir="Load-data",
    cleaned_csv_dir="cleaned_data",
    preprocessed_csv_dir="preprocessed_data",
)

# Clamp negative electricity readings to zero (in place on long_df).
long_df.loc[long_df["Electricity(kW)"].lt(0), "Electricity(kW)"] = 0

station_weights_df = build_station_weights(long_df)
print("Station weights:\n", station_weights_df)

# 2. Three-way split per station (70% train / 10% eval / 20% test)
def split_threeway(df, train_frac=0.7, eval_frac=0.1):
    """Chronologically split ``df`` into train / eval / test frames, per station.

    Each station's rows (grouped by ``station_name``) are sorted by ``Date``
    and cut into three contiguous pieces: the first ``floor(n * train_frac)``
    rows go to train, rows up to ``floor(n * (train_frac + eval_frac))`` go to
    eval, and the remaining rows go to test.

    Args:
        df: DataFrame with at least ``station_name`` and ``Date`` columns.
        train_frac: Fraction of each station's rows used for training.
        eval_frac: Fraction of each station's rows used for evaluation.
            The test share is whatever is left (``1 - train_frac - eval_frac``).

    Returns:
        ``(train_df, eval_df, test_df)``, each re-indexed with a fresh
        ``RangeIndex``. If ``df`` has no rows, three empty frames with
        ``df``'s columns are returned.

    Raises:
        ValueError: if a fraction is negative or ``train_frac + eval_frac > 1``.
    """
    # Small tolerance so e.g. 0.7 + 0.3 is accepted despite float rounding.
    if train_frac < 0 or eval_frac < 0 or train_frac + eval_frac > 1 + 1e-9:
        raise ValueError(
            "need train_frac, eval_frac >= 0 and train_frac + eval_frac <= 1; "
            f"got train_frac={train_frac}, eval_frac={eval_frac}"
        )

    train_list, eval_list, test_list = [], [], []
    for _, sdf in df.groupby('station_name'):
        # Stable sort keeps the input order of rows that share the same Date.
        sdf = sdf.sort_values('Date', kind='stable')
        n = len(sdf)
        # Round away float noise before truncating: 0.7 + 0.1 == 0.7999999999999999,
        # so a plain int(10 * (0.7 + 0.1)) gives 7 instead of 8 and eval loses a row.
        n_train = int(round(n * train_frac, 9))
        n_train_eval = int(round(n * (train_frac + eval_frac), 9))
        train_list.append(sdf.iloc[:n_train])
        eval_list.append(sdf.iloc[n_train:n_train_eval])
        test_list.append(sdf.iloc[n_train_eval:])

    if not train_list:
        # No stations at all: pd.concat([]) would raise, so return empty frames
        # that keep df's columns.
        empty = df.iloc[0:0].reset_index(drop=True)
        return empty, empty.copy(), empty.copy()

    train_df = pd.concat(train_list).reset_index(drop=True)
    eval_df = pd.concat(eval_list).reset_index(drop=True)
    test_df = pd.concat(test_list).reset_index(drop=True)
    return train_df, eval_df, test_df

# Split every station 70/10/20 in time order and report the resulting row counts.
train_df, eval_df, test_df = split_threeway(long_df, eval_frac=0.1, train_frac=0.7)
print("Rows → train: {}, eval: {}, test: {}".format(len(train_df), len(eval_df), len(test_df)))

# 3. Graph structure (fully connected, no self-loops)
# Node i corresponds to station_names[i] (stations sorted by name).
station_names = sorted(long_df['station_name'].unique())
num_nodes = len(station_names)
# All ordered pairs (i, j) with i != j as a (2, E) LongTensor, in row-major
# order (torch.nonzero returns indices lexicographically, matching a nested
# i/j loop). Building from a mask keeps the (2, 0) shape for 0 or 1 stations,
# where a tensor built from an empty Python list would be 1-D with shape (0,).
edge_index = (~torch.eye(num_nodes, dtype=torch.bool)).nonzero().t().contiguous()

# 4. Build sliding-window tensors
def pivot_to_tensor(df, seq_len, station_names):
    """Turn a long per-station frame into overlapping windows of length ``seq_len``.

    ``df`` is pivoted to one row per ``Date`` and one column per station
    (``Electricity(kW)`` values), with columns ordered as ``station_names``.
    Missing values are filled with 0.0. A station with no rows in ``df``
    becomes an all-zero column.

    Args:
        df: Long frame with ``Date``, ``station_name`` and ``Electricity(kW)``
            columns. ``pivot`` raises if a (Date, station) pair is duplicated.
        seq_len: Window length in pivot rows; must be >= 1.
        station_names: Column order for the node axis.

    Returns:
        float32 tensor of shape ``(W, N, seq_len)`` with ``W = T - seq_len + 1``
        windows, where ``T`` is the number of pivot rows and
        ``N = len(station_names)``. ``W`` is 0 (an empty tensor) when ``T < seq_len``.

    Raises:
        ValueError: if ``seq_len < 1``.
    """
    if seq_len < 1:
        raise ValueError(f"seq_len must be >= 1, got {seq_len}")

    pv = df.pivot(index='Date', columns='station_name', values='Electricity(kW)')
    # reindex instead of pv[station_names]: a station absent from this split
    # becomes a zero column rather than raising KeyError.
    pv = pv.reindex(columns=station_names).fillna(0.0)
    # NOTE(review): consecutive pivot rows are treated as consecutive time
    # steps; gaps in the Date index are not filled, so windows may span them.
    values = pv.to_numpy(dtype=np.float32)  # (T, N)

    if values.shape[0] < seq_len:
        # Too few rows for a single window (e.g. a short eval split).
        return torch.empty((0, len(station_names), seq_len), dtype=torch.float)

    # (T - seq_len + 1, N, seq_len): the window axis is appended last, so
    # windows[w, n, k] == values[w + k, n], i.e. each window is already (N, seq_len).
    windows = np.lib.stride_tricks.sliding_window_view(values, seq_len, axis=0)
    return torch.from_numpy(np.ascontiguousarray(windows))

len_input = 96  # steps the model sees as input
pred_len = 96  # steps the model must forecast

# Every window spans len_input + pred_len steps; split its last axis into
# the model input (first len_input steps) and the target (last pred_len steps).
arr_tr = pivot_to_tensor(train_df, len_input + pred_len, station_names)
X_tr, Y_tr = arr_tr.split([len_input, pred_len], dim=-1)

arr_ev = pivot_to_tensor(eval_df, len_input + pred_len, station_names)
X_ev, Y_ev = arr_ev.split([len_input, pred_len], dim=-1)

arr_te = pivot_to_tensor(test_df, len_input + pred_len, station_names)
X_te, Y_te = arr_te.split([len_input, pred_len], dim=-1)

# 5. Dataset & DataLoaders
class TemporalDataset(Dataset):
    """Map-style dataset pairing input windows ``X[i]`` with targets ``Y[i]``.

    Args:
        X: Tensor of inputs indexed along dim 0.
        Y: Tensor of targets indexed along dim 0; must have as many
            entries along dim 0 as ``X``.

    Raises:
        ValueError: if ``X`` and ``Y`` differ in length along dim 0.
    """

    def __init__(self, X, Y):
        # Fail fast here rather than with an IndexError midway through an epoch.
        if X.shape[0] != Y.shape[0]:
            raise ValueError(
                f"X and Y must have the same length along dim 0, "
                f"got {X.shape[0]} and {Y.shape[0]}"
            )
        self.X, self.Y = X, Y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, i):
        return self.X[i], self.Y[i]

batch_size = 512
# One DataLoader per split; only the training split is shuffled, eval/test
# iterate in dataset order.
train_loader, eval_loader, test_loader = (
    DataLoader(TemporalDataset(inputs, targets), batch_size=batch_size, shuffle=shuffle)
    for inputs, targets, shuffle in (
        (X_tr, Y_tr, True),
        (X_ev, Y_ev, False),
        (X_te, Y_te, False),
    )
)

### Visualize Graph

In [None]:
# Graph visualisation cell: saves a Folium map (graph_map.html) and shows a
# static Matplotlib drawing of the station graph.
import networkx as nx
import torch
import matplotlib.pyplot as plt
from matplotlib import font_manager
from pathlib import Path
from utils.visualizer import create_graph_map, plot_static_graph

# Assumes these are already defined (not in the cells shown here for some):
#   locations   - dict: station name -> coordinate pair (presumably (lat, lon); confirm)
#   edge_index  - LongTensor (2, E) of node indices
#   edge_weight - tensor with one weight per edge in edge_index

# NOTE(review): this rebinds the global ``station_names``, which the data cell
# set to the *sorted* station list used to build ``edge_index``. Node i is
# labelled with the i-th key of ``locations``; that is only correct if
# ``locations`` is ordered the same way — confirm before trusting the labels.
station_names = list(locations.keys())
ei = edge_index.cpu().numpy()
ew = edge_weight.cpu().numpy()
station_weights = ew.tolist()  # per-edge weights as plain Python floats

# Every node referenced by an edge must have a station name/position,
# otherwise the drawing fails later with a less obvious KeyError.
if ei.size and ei.max() >= len(station_names):
    raise ValueError(
        f"edge_index refers to node {int(ei.max())}, but only "
        f"{len(station_names)} stations are present in `locations`"
    )

# Build NetworkX graph. nx.Graph is undirected: when edge_index holds both
# (u, v) and (v, u), the weight of whichever pair comes later overwrites the
# earlier one.
G = nx.Graph()
G.add_nodes_from(range(len(station_names)))
for (u, v), w in zip(ei.T, ew):
    G.add_edge(int(u), int(v), weight=float(w))

# Thai-capable font for Matplotlib labels. The path is machine-specific, so
# fall back to Matplotlib's default font when it does not exist.
font_path = r"c:\Users\patar\Documents\superai-intern\superaiss5-intern-vpp\Prompt_Font\Prompt-Regular.ttf"
if Path(font_path).is_file():
    thai_font = font_manager.FontProperties(fname=font_path)
else:
    print(f"Font not found at {font_path}; using Matplotlib's default font "
          "(Thai labels may not render).")
    thai_font = font_manager.FontProperties()

# Node positions for Matplotlib: each coordinate pair reversed
# (presumably (lat, lon) -> (x=lon, y=lat); confirm against `locations`).
pos = {i: locations[name][::-1] for i, name in enumerate(station_names)}

# 1) Save Folium map
m = create_graph_map(locations, ew.tolist(), G, enable_satellite=True)
m.save("graph_map.html")
print("Saved graph_map.html with Folium visualization.")

# 2) Show static Matplotlib graph
plot_static_graph(G, pos, station_names, ew.tolist(), locations, thai_font)