# DSTGCN Model for Accident Prediction

Đây là notebook gộp toàn bộ mã nguồn từ các file Python để huấn luyện mô hình DSTGCN. Notebook được chia thành các phần chính:
1. **Cài đặt & Imports**: Cài đặt các thư viện cần thiết và import chúng.
2. **Cấu hình (Configuration)**: Tạo và tải file cấu hình.
3. **Hàm tiện ích (Utilities)**: Bao gồm các hàm tính loss, đánh giá, chuyển đổi tọa độ, và các hàm hỗ trợ khác.
4. **Kiến trúc Model (Model Architecture)**: Định nghĩa các lớp mạng neural, bao gồm các layer và mô hình DSTGCN chính.
5. **Tải dữ liệu (Data Loading)**: Định nghĩa Dataset và DataLoader để nạp và xử lý dữ liệu.
6. **Hàm Huấn luyện (Training Function)**: Chứa logic cho vòng lặp huấn luyện, validation và test.
7. **Thực thi chính (Main Execution)**: Chạy toàn bộ quy trình huấn luyện và lưu kết quả.

### 1. Cài đặt & Imports

In [None]:
!pip install torchdata==0.7.1 --quiet
!pip install dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html --quiet

In [None]:
!pip install tensorboardX pandas numpy networkx tqdm scikit-learn scipy --quiet

In [None]:
# Report the installed PyTorch build, whether CUDA is usable, and the CUDA
# version PyTorch was built against.
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

In [None]:
import torch
import dgl
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from tensorboardX import SummaryWriter
import pandas as pd
import numpy as np
import networkx as nx
from tqdm import tqdm
import json
import os
import shutil
import copy
import time
import datetime
import math
import sys
import warnings
from typing import List, Dict, Tuple
from dgl import init as g_init
from dgl.nn.pytorch import GraphConv
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr
from sklearn import metrics

# Suppress warnings for the rest of the session (the 'ignore' filter applies
# to every warning category, not only the noisy ones).
warnings.filterwarnings('ignore')

### 2. Cấu hình (Configuration)

In [None]:
config_content = '''{
  "model_name": "DSTGCN",
  "train_repeat_times": 1,
  "epochs": 100,
  "cuda": 0,
  "temporal_features_number": 1,
  "weather_features_number": 38,
  "external_temporal_features_number": 5,
  "poi_features_number": 22,
  "batch_size": 64,
  "learning_rate": 1e-3,
  "loss_function": "bce_loss",
  "optim": "Adam",
  "weight_decay": 1e-3,
  "K_hop": 10,
  "spatial_features_mean": [
    6.40254292e+00,
    4.81553347e+00,
    2.64811501e+01,
    5.54260681e+00,
    7.84915710e+00,
    9.01424670e+00,
    1.05751666e+00,
    2.83014290e+01,
    1.77408007e+01,
    3.24747843e+01,
    1.05794121e+00,
    3.47343126e+00,
    6.26610390e+00,
    3.39390529e+00,
    1.28002351e+01,
    5.71796174e-02,
    2.36710361e-02,
    5.17406913e-02,
    8.18431583e-01,
    4.57574675e-01,
    1.61370302e+02,
    2.00511484e+00
  ],
  "spatial_features_std": [
    8.25421040e+00,
    6.98772605e+00,
    3.12181863e+01,
    7.78804699e+00,
    1.12482891e+01,
    8.86203753e+00,
    4.88544842e+00,
    4.65677369e+01,
    1.88312515e+01,
    6.70090744e+01,
    2.95457633e+00,
    4.29173854e+00,
    9.78727429e+00,
    3.76843985e+00,
    1.19405955e+01,
    2.71230076e-01,
    1.88342431e-01,
    2.82777835e-01,
    2.18729099e+00,
    1.55007715e+00,
    2.04879612e+02,
    7.87830458e-02
  ],
  "temporal_features_mean": [
    18.54884516
  ],
  "temporal_features_std": [
    19.87195773
  ],
  "external_features_mean": [
    7.02588369e+01,
    5.28922463e+01,
    5.95480331e+01,
    2.98048133e+01,
    6.23360889e+00,
    7.17909065e+01,
    1.31128848e-02,
    8.48346636e-01,
    4.86031927e-02,
    1.65336374e-02,
    1.71037628e-03,
    1.71037628e-02,
    1.51083238e-02,
    7.41163056e-03,
    2.28050171e-03,
    1.58209806e-02,
    1.25427594e-02,
    0.00000000e+00,
    1.42531357e-04,
    1.28278221e-03,
    2.10946408e-02,
    5.23090080e-02,
    2.53705815e-02,
    4.53249715e-02,
    1.62058153e-01,
    1.96693273e-02,
    4.53249715e-02,
    6.01482326e-02,
    6.14310148e-02,
    5.60148233e-02,
    6.07183580e-02,
    5.91505131e-02,
    1.58209806e-02,
    2.28050171e-03,
    2.29618016e-01,
    2.87913341e-02,
    5.11687571e-02,
    3.70581528e-03,
    8.90992018e+00,
    1.67015393e+01,
    2.72548461e+00,
    1.22075257e+01,
    1.92702395e-01
  ],
  "external_features_std": [
    1.45939773e+01,
    1.86117750e+01,
    2.39727858e+01,
    2.15191858e-01,
    5.02840584e+00,
    1.71451636e+01,
    1.05974340e-01,
    3.50647069e-01,
    2.05374727e-01,
    1.19435494e-01,
    2.91935396e-02,
    1.22304197e-01,
    1.10326636e-01,
    6.60570950e-02,
    3.95302952e-02,
    1.18032057e-01,
    9.83734863e-02,
    1.00000000e+00,
    8.44069685e-03,
    2.52931923e-02,
    1.17508093e-01,
    1.82744266e-01,
    1.15431393e-01,
    1.69085448e-01,
    3.23940225e-01,
    9.79301412e-02,
    1.64816793e-01,
    1.96033008e-01,
    2.05755011e-01,
    1.93423139e-01,
    2.05099602e-01,
    1.97963022e-01,
    1.01824764e-01,
    3.76843851e-02,
    3.46547020e-01,
    1.50627796e-01,
    1.96944453e-01,
    4.76105938e-02,
    8.32550240e-01,
    8.70795832e+00,
    1.86828552e+00,
    5.96074486e+00,
    3.94421325e-01
  ]
}'''  # (cut for brevity — insert your full JSON here)

# Persist the JSON text held in config_content to config.json in the current
# working directory so that it can be re-read with json.load.
with open("config.json", "w") as f:
    f.write(config_content)

In [None]:
# Load the configuration from file into the module-level `config` dict that
# get_attribute reads from.
with open("config.json") as file:
    config = json.load(file)

def get_attribute(name, defaultValue=None):
    """Return the value stored under ``name`` in the global ``config`` mapping.

    Args:
        name: configuration key to look up.
        defaultValue: value returned when ``name`` is not a key of ``config``.

    Returns:
        ``config[name]`` when the key exists (even if the stored value is
        ``None``), otherwise ``defaultValue``.
    """
    return config.get(name, defaultValue)

# Select the compute device: the configured CUDA index when CUDA is usable,
# otherwise the CPU. Looking the index up with a default of -1 makes a config
# without a "cuda" entry fall back to the CPU instead of failing on the
# `None >= 0` comparison.
config['device'] = f'cuda:{get_attribute("cuda")}' if torch.cuda.is_available() and get_attribute("cuda", -1) >= 0 else 'cpu'
print(f"Using device: {config['device']}")

### 3. Hàm tiện ích (Utilities)

In [None]:
# Module: coordTransform_utils.py & coord_converter.py

# Constants shared by the coordinate-conversion functions below.
# x_pi is the angular factor used by the GCJ-02 <-> BD-09 formulas; a and ee
# are used by the WGS-84 <-> GCJ-02 formulas.
x_pi = 3.14159265358979324 * 3000.0 / 180.0
pi = 3.1415926535897932384626  # π
a = 6378245.0  # semi-major axis
ee = 0.00669342162296594323  # eccentricity squared

def gcj02_to_bd09(lng, lat):
    """Convert a GCJ-02 longitude/latitude pair to BD-09.

    Args:
        lng: GCJ-02 longitude.
        lat: GCJ-02 latitude.

    Returns:
        A two-element list ``[bd_lng, bd_lat]``.
    """
    radius = math.sqrt(lng * lng + lat * lat) + 0.00002 * math.sin(lat * x_pi)
    angle = math.atan2(lat, lng) + 0.000003 * math.cos(lng * x_pi)
    return [radius * math.cos(angle) + 0.0065, radius * math.sin(angle) + 0.006]

def bd09_to_gcj02(bd_lon, bd_lat):
    """Convert a BD-09 longitude/latitude pair to GCJ-02.

    Args:
        bd_lon: BD-09 longitude.
        bd_lat: BD-09 latitude.

    Returns:
        A two-element list ``[gcj_lng, gcj_lat]``.
    """
    # Remove the constant BD-09 offsets before undoing the polar distortion.
    dx = bd_lon - 0.0065
    dy = bd_lat - 0.006
    radius = math.sqrt(dx * dx + dy * dy) - 0.00002 * math.sin(dy * x_pi)
    angle = math.atan2(dy, dx) - 0.000003 * math.cos(dx * x_pi)
    return [radius * math.cos(angle), radius * math.sin(angle)]

def wgs84_to_gcj02(lng, lat):
    """Convert a WGS-84 longitude/latitude pair to GCJ-02.

    Points for which ``out_of_china`` is true are returned unchanged.

    Args:
        lng: WGS-84 longitude.
        lat: WGS-84 latitude.

    Returns:
        A two-element list ``[lng, lat]`` in GCJ-02.
    """
    if out_of_china(lng, lat):
        return [lng, lat]

    lat_shift = _transformlat(lng - 105.0, lat - 35.0)
    lng_shift = _transformlng(lng - 105.0, lat - 35.0)

    lat_rad = lat / 180.0 * pi
    sin_lat = math.sin(lat_rad)
    magic = 1 - ee * sin_lat * sin_lat
    root_magic = math.sqrt(magic)

    # Scale the raw shifts using the ellipsoid constants a and ee.
    lat_shift = (lat_shift * 180.0) / ((a * (1 - ee)) / (magic * root_magic) * pi)
    lng_shift = (lng_shift * 180.0) / (a / root_magic * math.cos(lat_rad) * pi)

    return [lng + lng_shift, lat + lat_shift]

def gcj02_to_wgs84(lng, lat):
    """Convert a GCJ-02 longitude/latitude pair to WGS-84.

    The forward WGS-84 -> GCJ-02 shift is evaluated at the given point and
    subtracted from it (``2 * p - forward(p)``). Points for which
    ``out_of_china`` is true are returned unchanged.

    Args:
        lng: GCJ-02 longitude.
        lat: GCJ-02 latitude.

    Returns:
        A two-element list ``[lng, lat]`` in WGS-84.
    """
    if out_of_china(lng, lat):
        return [lng, lat]
    shifted_lng, shifted_lat = wgs84_to_gcj02(lng, lat)
    return [lng * 2 - shifted_lng, lat * 2 - shifted_lat]

def bd09_to_wgs84(bd_lon, bd_lat):
    """Convert BD-09 coordinates to WGS-84 by going through GCJ-02.

    Returns:
        A two-element list ``[lng, lat]`` in WGS-84.
    """
    gcj_lng, gcj_lat = bd09_to_gcj02(bd_lon, bd_lat)
    return gcj02_to_wgs84(gcj_lng, gcj_lat)

def wgs84_to_bd09(lon, lat):
    """Convert WGS-84 coordinates to BD-09 by going through GCJ-02.

    Returns:
        A two-element list ``[bd_lng, bd_lat]``.
    """
    gcj_lng, gcj_lat = wgs84_to_gcj02(lon, lat)
    return gcj02_to_bd09(gcj_lng, gcj_lat)

def _transformlat(lng, lat):
    """Return the raw latitude shift used by the WGS-84 <-> GCJ-02 formulas.

    The callers in this file pass offsets relative to (105.0, 35.0).
    The result is a polynomial term plus three sinusoidal terms, summed in a
    fixed left-to-right order.
    """
    polynomial = -100.0 + 2.0 * lng + 3.0 * lat + 0.2 * lat * lat + \
        0.1 * lng * lat + 0.2 * math.sqrt(math.fabs(lng))
    lng_wave = (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
                math.sin(2.0 * lng * pi)) * 2.0 / 3.0
    lat_wave = (20.0 * math.sin(lat * pi) + 40.0 *
                math.sin(lat / 3.0 * pi)) * 2.0 / 3.0
    lat_long_wave = (160.0 * math.sin(lat / 12.0 * pi) + 320 *
                     math.sin(lat * pi / 30.0)) * 2.0 / 3.0
    return polynomial + lng_wave + lat_wave + lat_long_wave

def _transformlng(lng, lat):
    """Return the raw longitude shift used by the WGS-84 <-> GCJ-02 formulas.

    The callers in this file pass offsets relative to (105.0, 35.0).
    The result is a polynomial term plus three sinusoidal terms in ``lng``,
    summed in a fixed left-to-right order.
    """
    polynomial = 300.0 + lng + 2.0 * lat + 0.1 * lng * lng + \
        0.1 * lng * lat + 0.1 * math.sqrt(math.fabs(lng))
    short_wave = (20.0 * math.sin(6.0 * lng * pi) + 20.0 *
                  math.sin(2.0 * lng * pi)) * 2.0 / 3.0
    mid_wave = (20.0 * math.sin(lng * pi) + 40.0 *
                math.sin(lng / 3.0 * pi)) * 2.0 / 3.0
    long_wave = (150.0 * math.sin(lng / 12.0 * pi) + 300.0 *
                 math.sin(lng / 30.0 * pi)) * 2.0 / 3.0
    return polynomial + short_wave + mid_wave + long_wave

def out_of_china(lng, lat):
    """Return True when the point lies outside the conversion bounding box.

    The box is open: longitude strictly between 73.66 and 135.05 and latitude
    strictly between 3.86 and 53.55 count as inside; the edges count as outside.
    """
    inside_box = 73.66 < lng < 135.05 and 3.86 < lat < 53.55
    return not inside_box

def convert_by_type(lng, lat, type):
    """Convert a coordinate pair using the conversion selected by ``type``.

    Args:
        lng: longitude in the source coordinate system.
        lat: latitude in the source coordinate system.
        type: one of ``'g2b'``, ``'b2g'``, ``'w2g'``, ``'g2w'``, ``'b2w'``,
            ``'w2b'`` (g = GCJ-02, b = BD-09, w = WGS-84).

    Returns:
        The two-element list produced by the selected converter.

    Any other ``type`` prints a usage message and terminates via
    ``sys.exit()``.
    """
    converters = {
        'g2b': gcj02_to_bd09,
        'b2g': bd09_to_gcj02,
        'w2g': wgs84_to_gcj02,
        'g2w': gcj02_to_wgs84,
        'b2w': bd09_to_wgs84,
        'w2b': wgs84_to_bd09,
    }
    converter = converters.get(type)
    if converter is None:
        print('Usage: type must be in one of g2b, b2g, w2g, g2w, b2w, w2b')
        sys.exit()
    return converter(lng, lat)

In [None]:
# Module: loss.py
class MSELoss(nn.Module):
    """Sum-reduced squared-error loss called as ``loss(truth, predict)``.

    Note the argument order: the ground truth comes first, the prediction
    second.
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction='sum')

    def forward(self, truth, predict):
        """Return the summed squared error between ``predict`` and ``truth``."""
        return self.mse_loss(predict, truth)

class BCELoss(nn.Module):
    """Binary cross-entropy loss called as ``loss(truth, predict)``.

    Wraps ``nn.BCELoss`` with its default reduction. Note the argument order:
    the ground truth comes first, the predicted probability second.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCELoss()

    def forward(self, truth, predict):
        """Return the binary cross-entropy of ``predict`` against ``truth``."""
        return self.bce_loss(predict, truth)

# Module: metric.py
def evaluate(y_predictions: np.ndarray, y_targets: np.ndarray, threshold: float = 0.5):
    """Compute regression and classification metrics for binary predictions.

    Args:
        y_predictions: predicted scores, same shape as ``y_targets``.
        y_targets: ground-truth labels; entries equal to 1 are the positives.
        threshold: scores ``>= threshold`` are counted as positive predictions.

    Returns:
        A dict with the keys ``'RMSE'``, ``'PCC'``, ``'PRECISION'``,
        ``'RECALL'``, ``'F1-SCORE'`` and ``'AUC'``. Precision, recall and
        F1 fall back to ``0.0`` when their denominator is zero.

    Raises:
        AssertionError: if the two arrays differ in shape.
    """
    # The message reports both shapes (the targets' shape was previously
    # printed as a second copy of the predictions' shape).
    assert y_predictions.shape == y_targets.shape, \
        f'Predictions of shape {y_predictions.shape} while targets of shape {y_targets.shape}.'
    rmse = mean_squared_error(y_targets, y_predictions) ** 0.5
    pcc, _ = pearsonr(y_predictions, y_targets)

    predicted_positive = y_predictions >= threshold
    actual_positive = y_targets == 1

    # Confusion-matrix counts on the boolean masks.
    tp = (predicted_positive & actual_positive).sum()
    fp = (predicted_positive & ~actual_positive).sum()
    fn = (~predicted_positive & actual_positive).sum()

    precision = tp / (tp + fp) if tp + fp != 0 else 0.0
    recall = tp / (tp + fn) if tp + fn != 0 else 0.0
    if precision + recall != 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
    else:
        f1_score = 0.0

    # AUC is computed from the raw scores, not the thresholded classes.
    auc = metrics.roc_auc_score(actual_positive, y_predictions)

    return {
        'RMSE': rmse,
        'PCC': pcc,
        'PRECISION': precision,
        'RECALL': recall,
        'F1-SCORE': f1_score,
        'AUC': auc
    }

# Module: util.py
def convert_to_gpu(data):
    """Move ``data`` to the device stored under the ``'device'`` config key.

    ``data`` may be anything exposing ``.to(device)``; the result of that call
    is returned.
    """
    target_device = get_attribute('device')
    return data.to(target_device)

def convert_train_truth_to_gpu(train_data, truth_data):
    """Move a collection of inputs and the target to the configured device.

    Args:
        train_data: iterable of objects accepted by ``convert_to_gpu``.
        truth_data: the target object, also moved with ``convert_to_gpu``.

    Returns:
        ``(inputs, truth)`` where ``inputs`` is a list holding the moved
        elements of ``train_data`` in their original order.
    """
    moved_inputs = list(map(convert_to_gpu, train_data))
    moved_truth = convert_to_gpu(truth_data)
    return moved_inputs, moved_truth

def load_model(model, modelFilePath):
    """Load saved weights from ``modelFilePath`` into ``model``.

    Accepts both layouts used in this file: a bare ``state_dict`` and the
    wrapper dict written by ``save_model`` / ``train_model``, whose weights
    live under the ``'model_state_dict'`` key.

    Args:
        model: module whose parameters are overwritten in place.
        modelFilePath: path of the file written with ``torch.save``.

    Returns:
        The same ``model`` instance, for chaining.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # read on a CPU-only one; load_state_dict then copies the values into the
    # model's own parameters on whatever device they live on.
    checkpoint = torch.load(modelFilePath, map_location='cpu')
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        checkpoint = checkpoint['model_state_dict']
    model.load_state_dict(checkpoint)
    return model

def save_model(path: str, **save_dict):
    """Serialize the keyword arguments as one dict to ``path`` with ``torch.save``.

    Args:
        path: destination file. Missing parent directories are created.
        **save_dict: entries of the checkpoint (e.g. ``model_state_dict``,
            ``epoch``).
    """
    parent_dir = os.path.dirname(path)
    # A bare file name has no directory part; os.makedirs('') would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(save_dict, path)

### 4. Kiến trúc Model (Model Architecture)

In [None]:
# Module: fully_connected.py
class fully_connected_layer(nn.Module):
    """Multi-layer perceptron: ``Linear -> ReLU`` per hidden size, then a final ``Linear``.

    Args:
        input_size: width of the input features.
        hidden_size: sequence of hidden-layer widths; may be empty, in which
            case the network is a single ``Linear(input_size, output_size)``.
        output_size: width of the output features.

    Each hidden layer is registered twice, as the attribute ``fc{i}`` /
    ``relu{i}`` and inside ``fcList`` / ``reluList``. Both registrations and
    their order are kept so that ``state_dict`` keys stay compatible with
    existing checkpoints.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        fcList = []
        reluList = []

        current_size = input_size
        for index, width in enumerate(self.hidden_size):
            fc = nn.Linear(current_size, width)
            setattr(self, f'fc{index}', fc)
            fcList.append(fc)
            relu = nn.ReLU()
            setattr(self, f'relu{index}', relu)
            reluList.append(relu)
            current_size = width

        # current_size is the last hidden width, or input_size when there are
        # no hidden layers (indexing hidden_size[-1] would fail in that case).
        self.last_fc = nn.Linear(current_size, self.output_size)

        self.fcList = nn.ModuleList(fcList)
        self.reluList = nn.ModuleList(reluList)

    def forward(self, input_tensor):
        """Apply every hidden ``Linear``/``ReLU`` pair, then the output layer."""
        hidden = input_tensor
        for fc, relu in zip(self.fcList, self.reluList):
            hidden = relu(fc(hidden))
        return self.last_fc(hidden)

In [None]:
# Module: spatial_layer.py
class GCN(nn.Module):
    """Stack of ``GraphConv -> BatchNorm1d -> ReLU`` layers.

    Args:
        in_features: feature width of the input node features.
        hidden_sizes: output widths of the intermediate graph convolutions
            (must be non-empty).
        out_features: output width of the final graph convolution, which is
            also followed by batch norm and ReLU.
    """

    def __init__(self, in_features: int, hidden_sizes: List[int], out_features: int):
        super().__init__()
        gcns, relus, bns = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()

        # Layer i maps the previous layer's width (in_features for the first
        # layer) to hidden_sizes[i].
        layer_inputs = [in_features] + list(hidden_sizes[:-1])
        for fan_in, fan_out in zip(layer_inputs, hidden_sizes):
            gcns.append(GraphConv(fan_in, fan_out))
            relus.append(nn.ReLU())
            bns.append(nn.BatchNorm1d(fan_out))

        # Output layer.
        relus.append(nn.ReLU())
        bns.append(nn.BatchNorm1d(out_features))
        gcns.append(GraphConv(hidden_sizes[-1], out_features))

        self.gcns, self.relus, self.bns = gcns, relus, bns

    def forward(self, g: dgl.DGLGraph, node_features: torch.Tensor):
        """Run ``node_features`` through every conv/norm/activation triple.

        For features with more than two dimensions, dims 1 and -1 are swapped
        around the batch norm so that it normalizes over the last (feature)
        dimension.
        """
        g = g.local_var()  # work on a local view of the graph
        h = node_features
        for conv, norm, activation in zip(self.gcns, self.bns, self.relus):
            h = conv(g, h)
            if h.dim() > 2:
                h = norm(h.transpose(1, -1)).transpose(1, -1)
            else:
                h = norm(h)
            h = activation(h)
        return h

class StackedSBlocks(nn.ModuleList):
    """Sequence of graph blocks: residual for all but the last one.

    Called as ``stack(g, h)``. Every block except the last contributes
    ``h = h + block(g, h)``; the last block's output is returned directly.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args):
        g, h = args
        residual_blocks = list(self)[:-1]
        for block in residual_blocks:
            h = h + block(g, h)
        final_block = self[-1]
        return final_block(g, h)

In [None]:
# Module: spatial_temporal_layer.py
class STBlock(nn.Module):
    """Spatial-temporal block: a ``GCN`` followed by a width-3 ``Conv1d``.

    Args:
        f_in: number of input channels.
        f_out: number of output channels.

    The GCN's two hidden widths are taken at ``i = 1`` and ``i = 4`` of the
    linear blend ``(f_in * (4 - i) + f_out * i) // 4``.
    """

    def __init__(self, f_in: int, f_out: int):
        super().__init__()
        gcn_hidden = [(f_in * (4 - i) + f_out * i) // 4 for i in (1, 4)]
        self.spatial_embedding = GCN(f_in, gcn_hidden, f_out)
        self.temporal_embedding = nn.Conv1d(f_out, f_out, 3, padding=1)

    def forward(self, g: dgl.DGLGraph, temporal_features: torch.Tensor):
        """Apply the GCN on the channel-last view, then the temporal convolution.

        Assumes ``temporal_features`` is laid out as (nodes, channels, time)
        — TODO confirm against the caller.
        """
        # Swap the last two dims so the channel dim is last for the GCN ...
        gcn_out = self.spatial_embedding(g, temporal_features.transpose(-2, -1))
        # ... and swap back so Conv1d sees channels in dim 1.
        return self.temporal_embedding(gcn_out.transpose(-2, -1))

class StackedSTBlocks(nn.ModuleList):
    """Sequence of blocks whose outputs are concatenated onto their inputs.

    Called as ``stack(g, h)``. After each block, ``h`` becomes
    ``cat((h, block(g, h)), dim=1)``, so dim 1 grows by the block's output
    width at every step.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *args):
        g, features = args
        for block in self:
            block_out = block(g, features)
            # Dense-style growth along the feature dimension (dim=1).
            features = torch.cat((features, block_out), dim=1)
        return features

In [None]:
# Module: DSTGCN.py
class DSTGCN(nn.Module):
    """Three-stream model (spatial, temporal, external) with a sigmoid output.

    Args:
        f_1: width of the per-node spatial features.
        f_2: number of channels of the per-node temporal features.
        f_3: width of the per-sample external features.
    """

    def __init__(self, f_1: int, f_2: int, f_3: int):
        super().__init__()

        # Spatial stream: MLP embedding to 15 features, then three GCN blocks
        # ending with 10 features.
        self.spatial_embedding = fully_connected_layer(f_1, [20], 15)
        spatial_blocks = [GCN(15, [15, 15, 15], 15),
                          GCN(15, [15, 15, 15], 15),
                          GCN(15, [14, 13, 12, 11], 10)]
        self.spatial_gcn = StackedSBlocks(spatial_blocks)

        # Temporal stream: StackedSTBlocks concatenates each block's output
        # onto its input, so the channel count grows
        # f_2 -> f_2 + 4 -> f_2 + 4 + 5 -> f_2 + 4 + 5 + 10.
        temporal_blocks = [STBlock(f_2, 4), STBlock(f_2 + 4, 5), STBlock(f_2 + 4 + 5, 10)]
        self.temporal_embedding = StackedSTBlocks(temporal_blocks)

        # Width of the time-averaged temporal embedding fed to temporal_agg.
        final_temporal_dim = f_2 + 4 + 5 + 10
        self.temporal_agg = nn.Sequential(
            nn.Linear(final_temporal_dim, 20),
            nn.ReLU()
        )

        # External stream: MLP down to 10 features.
        external_hidden = [(f_3 * (4 - i) + 10 * i) // 4 for i in (1, 4)]
        self.external_embedding = fully_connected_layer(f_3, external_hidden, 10)

        # Head input: 10 (spatial) + 20 (temporal) + 10 (external).
        self.output_layer = nn.Sequential(nn.Linear(10 + 20 + 10, 1),
                                          nn.Sigmoid())

    def forward(self,
                bg: dgl.DGLGraph,
                spatial_features: torch.Tensor,
                temporal_features: torch.Tensor,
                external_features: torch.Tensor):
        """Return one sigmoid score per graph in the batched graph ``bg``.

        The spatial and temporal embeddings of the first node of each
        sub-graph are combined with that sample's external embedding.
        Assumes the first node of every sub-graph is the target node and
        that ``external_features`` has one row per sub-graph — TODO confirm
        against the dataset.
        """
        # Spatial stream.
        s_out = self.spatial_gcn(bg, self.spatial_embedding(spatial_features))

        # Temporal stream: average over the last dim, then project to 20.
        temporal_embeddings = self.temporal_embedding(bg, temporal_features)
        t_out = self.temporal_agg(temporal_embeddings.mean(dim=2))

        # External stream.
        e_out = self.external_embedding(external_features)

        # Row index of the first node of every sub-graph in the batch.
        first_node_ids = []
        offset = 0
        for graph_size in bg.batch_num_nodes().tolist():
            first_node_ids.append(offset)
            offset += graph_size

        s_features = torch.stack([s_out[i] for i in first_node_ids])
        t_features = torch.stack([t_out[i] for i in first_node_ids])

        combined = torch.cat((s_features, t_features, e_out), dim=-1)
        return self.output_layer(combined)

### 5. Tải dữ liệu (Data Loading)

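**Lưu ý quan trọng:** Trước khi chạy cell bên dưới, bạn cần chuẩn bị dữ liệu. Mã nguồn hiện tại đọc dữ liệu từ thư mục `/kaggle/input/dstgcn-dataset/` (Kaggle). Nếu chạy trên Colab, hãy tạo một thư mục tên là `data` trong cây thư mục của Colab (ngang hàng với `sample_data`), tải các file sau vào đó và sửa lại đường dẫn trong hàm `get_data_loaders` cho phù hợp: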
- `beijing_roadnet.gpickle`
- `edges_data.h5`
- `accident.h5`
- `weather.h5`
- `all_grids_speed.h5`

Nếu bạn dùng Google Drive, hãy mount Drive và thay đổi đường dẫn cho phù hợp.

In [None]:
# Module: data_container.py

# Corners of the study-area bounding box, in degrees. They are fed to the
# "g2w" converter below, so they are presumably GCJ-02 values — TODO confirm.
longitudeMin = 116.09608
longitudeMax = 116.71040
latitudeMin = 39.69086
latitudeMax = 40.17647

# Re-express both corners through convert_by_type(..., type="g2w").
longitudeMin, latitudeMin = convert_by_type(lng=longitudeMin, lat=latitudeMin, type="g2w")
longitudeMax, latitudeMax = convert_by_type(lng=longitudeMax, lat=latitudeMax, type="g2w")

# Each 0.01-degree step is split into this many grid cells.
divideBound = 5

# Cell width in degrees of longitude (0.01 stretched by 1 / cos(latitudeMin))
# and the number of whole cells that fit across the box.
widthSingle = 0.01 / math.cos(latitudeMin / 180 * math.pi) / divideBound
width = math.floor((longitudeMax - longitudeMin) / widthSingle)
# Cell height in degrees of latitude and the number of whole cells vertically.
heightSingle = 0.01 / divideBound
height = math.floor((latitudeMax - latitudeMin) / heightSingle)

def collate_fn(batch):
    """Merge a list of sample tuples into one batch tuple, field by field.

    For each field position:
      * tensors at positions 0-2 are concatenated along dim 0 (per-node
        features, whose first dim differs between samples);
      * tensors at later positions are stacked into a new leading dim
        (per-sample features and targets);
      * ``dgl.DGLGraph`` objects are merged with ``dgl.batch``.

    Raises:
        ValueError: if a field holds anything other than tensors or graphs.
    """
    concat_positions = 3

    def merge(position, column):
        sample = column[0]
        if isinstance(sample, torch.Tensor):
            if position < concat_positions:
                return torch.cat(column)
            return torch.stack(column)
        if isinstance(sample, dgl.DGLGraph):
            return dgl.batch(column)
        raise ValueError(f'batch must contain tensors or graphs; found {type(sample)}')

    return tuple(merge(position, column) for position, column in enumerate(zip(*batch)))

def fill_speed(speed_data):
    """Resample speed data to hourly means and fill hours that contain NaNs.

    Args:
        speed_data: time-indexed DataFrame (it is resampled with
            ``rule="1H"``); it must cover 2018-08-01 00:00 through
            2018-10-31 23:00.

    Returns:
        The hourly DataFrame in which every incomplete hour of that range has
        been replaced by the same hour one week before/after or two weeks
        before/after (first complete candidate wins), or by the column means
        when no candidate is complete.
    """
    date_range = pd.date_range(start="2018-08-01", end="2018-11-01", freq="1H")[:-1]
    speed_data = speed_data.resample(rule="1H").mean()
    assert date_range[0] in speed_data.index and date_range[-1] in speed_data.index

    one_week, two_week = datetime.timedelta(days=7), datetime.timedelta(days=14)

    # First pass: collect the hours that have at least one missing value.
    missing_dates = [date for date in tqdm(date_range, 'Finding missing dates')
                     if any(speed_data.loc[date].isna())]

    print(f"Found {len(missing_dates)} dates with missing data")

    # Second pass: only the incomplete hours are rewritten.
    for date in tqdm(missing_dates, 'Fill speed'):
        candidates = (date - one_week, date + one_week, date - two_week, date + two_week)
        donor = next((candidate for candidate in candidates
                      if candidate in speed_data.index and all(speed_data.loc[candidate].notna())),
                     None)
        if donor is None:
            print(f"Warning: Cannot find replacement for {date}")
            # Rather than raising, fall back to the per-column mean.
            speed_data.loc[date] = speed_data.mean()
        else:
            speed_data.loc[date] = speed_data.loc[donor]

    return speed_data

class AccidentDataset(Dataset):
    """Dataset yielding one k-hop road sub-graph with features per accident row.

    Args:
        k_order: hop cutoff used when collecting the neighbourhood.
        network: road network graph.
        node_attr: per-node table, indexed by node id, providing the
            ``spatial_features``, ``XCoord`` and ``YCoord`` columns.
        accident: table of samples; each row unpacks into five values, the
            last three being the accident time, the node id and the target.
        weather: time-indexed weather table.
        speed: time-indexed speed table whose columns are named ``"y,x"``
            after grid cell indices.
        sf_scaler, tf_scaler, ef_scaler: optional ``(mean, std)`` pairs used to
            standardize the spatial, temporal and external features.
    """

    def __init__(self,
                 k_order: int,
                 network: nx.Graph,
                 node_attr: pd.DataFrame,
                 accident: pd.DataFrame,
                 weather: pd.DataFrame,
                 speed: pd.DataFrame,
                 sf_scaler: Tuple[np.ndarray, np.ndarray] = None,
                 tf_scaler: Tuple[np.ndarray, np.ndarray] = None,
                 ef_scaler: Tuple[np.ndarray, np.ndarray] = None):
        self.k_order = k_order
        self.network = network
        self.nodes = node_attr
        self.accident = accident
        self.weather = weather
        self.speed = speed
        self.sf_scaler = sf_scaler
        self.tf_scaler = tf_scaler
        self.ef_scaler = ef_scaler

    def __getitem__(self, sample_id: int):
        """Build the sample at position ``sample_id``.

        Returns:
            ``(g, spatial_features, temporal_features, external_features,
            target)`` where ``g`` is the DGL sub-graph (with self-loops) whose
            node 0 is the accident node.
        """
        sample_row = self.accident.iloc[sample_id]
        _, _, accident_time, node_id, target = sample_row

        # Nodes within k_order hops; the accident node goes first, the rest
        # follow in sorted order.
        reachable = nx.single_source_shortest_path_length(self.network, node_id, cutoff=self.k_order)
        reachable.pop(node_id, None)
        neighbors = [node_id] + sorted(reachable)

        # Relabel to contiguous integers from 0 (in `neighbors` order) and add
        # a self-loop on every node before converting to DGL.
        relabel_map = {old_label: new_label for new_label, old_label in enumerate(neighbors)}
        sub_graph_nx = nx.relabel_nodes(nx.subgraph(self.network, neighbors), relabel_map)
        sub_graph_nx.add_edges_from([(v, v) for v in sub_graph_nx.nodes])
        g = dgl.from_networkx(sub_graph_nx)

        # The 24 hourly steps ending at the accident hour.
        date_range = pd.date_range(end=accident_time.strftime("%Y%m%d %H"), freq="1H", periods=24)
        selected_time = self.speed.loc[date_range]
        selected_nodes = self.nodes.loc[neighbors]
        spatial_features = selected_nodes['spatial_features'].tolist()

        # Grid cell of every node; the speed columns are keyed "y,x".
        x_ids = np.floor((selected_nodes['XCoord'].values - longitudeMin) / widthSingle).astype(np.int64)
        y_ids = np.floor((selected_nodes['YCoord'].values - latitudeMin) / heightSingle).astype(np.int64)
        cell_columns = [f'{y_id},{x_id}' for y_id, x_id in zip(y_ids, x_ids)]
        temporal_features = selected_time[cell_columns].values.transpose()

        # Weather at the accident hour plus calendar features.
        weather_data = self.weather.loc[date_range[-1]].tolist()
        calendar_features = [accident_time.month, accident_time.day, accident_time.dayofweek,
                             accident_time.hour, int(accident_time.dayofweek >= 5)]
        external_features = weather_data + calendar_features

        if self.sf_scaler is not None:
            mean, std = self.sf_scaler
            spatial_features = (np.array(spatial_features) - mean) / std
        if self.tf_scaler is not None:
            mean, std = self.tf_scaler
            temporal_features = (np.array(temporal_features) - mean) / std
        if self.ef_scaler is not None:
            mean, std = self.ef_scaler
            external_features = (np.array(external_features) - mean) / std

        spatial_features = torch.tensor(spatial_features).float()
        # Insert a channel dimension at position 1.
        temporal_features = torch.tensor(temporal_features).unsqueeze(1).float()
        external_features = torch.tensor(external_features).float()
        target = torch.tensor(target).float()

        return g, spatial_features, temporal_features, external_features, target

    def __len__(self):
        """Return the number of rows of the accident table."""
        return len(self.accident)

def get_data_loaders(k_order, batch_size, data_dir='/kaggle/input/dstgcn-dataset', num_workers=2):
    """Build the train / validate / test ``DataLoader`` objects.

    Args:
        k_order: hop cutoff forwarded to ``AccidentDataset``.
        batch_size: number of samples per batch.
        data_dir: directory holding ``beijing_roadnet.gpickle``,
            ``edges_data.h5``, ``accident.h5``, ``weather.h5`` and
            ``all_grids_speed.h5``. Defaults to the Kaggle input folder.
        num_workers: worker processes per loader (kept small for hosted
            notebooks).

    Returns:
        A dict with the keys ``'train'``, ``'validate'`` and ``'test'``.
        Only the training loader shuffles; evaluation order is deterministic.
    """
    import pickle

    network_path = os.path.join(data_dir, 'beijing_roadnet.gpickle')
    node_attr_path = os.path.join(data_dir, 'edges_data.h5')
    accident_path = os.path.join(data_dir, 'accident.h5')
    weather_path = os.path.join(data_dir, 'weather.h5')
    speed_path = os.path.join(data_dir, 'all_grids_speed.h5')

    # Standardization statistics come from the configuration file.
    sf_mean, sf_std = np.array(get_attribute('spatial_features_mean')), np.array(get_attribute('spatial_features_std'))
    tf_mean, tf_std = np.array(get_attribute('temporal_features_mean')), np.array(get_attribute('temporal_features_std'))
    ef_mean, ef_std = np.array(get_attribute('external_features_mean')), np.array(get_attribute('external_features_std'))

    # SECURITY NOTE: pickle.load executes arbitrary code from the file; only
    # load a road-network pickle that comes from a trusted source.
    with open(network_path, 'rb') as f:
        network = pickle.load(f)
    nodes = pd.read_hdf(node_attr_path)
    weather = pd.read_hdf(weather_path)
    speed = fill_speed(pd.read_hdf(speed_path))

    dls = dict()
    for key in ['train', 'validate', 'test']:
        accident = pd.read_hdf(accident_path, key=key)
        dataset = AccidentDataset(k_order, network, nodes, accident, weather, speed,
                                  sf_scaler=(sf_mean, sf_std),
                                  tf_scaler=(tf_mean, tf_std),
                                  ef_scaler=(ef_mean, ef_std))
        dls[key] = DataLoader(dataset=dataset,
                              batch_size=batch_size,
                              # Shuffling only matters for optimization.
                              shuffle=(key == 'train'),
                              drop_last=False,
                              collate_fn=collate_fn,
                              num_workers=num_workers)
    return dls

### 6. Hàm Huấn luyện (Training Function)

In [None]:
# Module: train_model.py
def train_model(model: nn.Module,
                data_loaders: Dict[str, DataLoader],
                loss_func: callable,
                optimizer,
                model_folder: str,
                tensorboard_folder: str):
    """Train ``model`` and evaluate it on the validate/test splits every epoch.

    Args:
        model: network called as ``model(g, spatial, temporal, external)``.
        data_loaders: loaders under the keys ``'train'``, ``'validate'`` and
            ``'test'``.
        loss_func: module called as ``loss_func(truth, prediction)``.
        optimizer: optimizer over ``model``'s parameters.
        model_folder: directory receiving ``model_{epoch}.pkl`` for every
            validation-F1 improvement and ``best_model.pkl`` at the end.
        tensorboard_folder: log directory for the ``SummaryWriter``.

    Returns:
        The test-set metric dict of the epoch with the best validation
        F1-score, or ``None`` if validation F1 never rose above 0.

    The number of epochs is read from the ``'epochs'`` config key. The best
    checkpoint is written in a ``finally`` block, so it is saved even when
    training is interrupted.
    """
    phases = ['train', 'validate', 'test']
    writer = SummaryWriter(tensorboard_folder)
    num_epochs = get_attribute('epochs')

    since = time.time()

    model = convert_to_gpu(model)
    loss_func = convert_to_gpu(loss_func)

    save_dict, best_f1_score = {'model_state_dict': copy.deepcopy(model.state_dict()), 'epoch': 0}, 0

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=.5, patience=2, threshold=1e-3, min_lr=1e-6)
    test_metric = None

    try:
        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)
            running_loss, running_metrics = {phase: 0.0 for phase in phases}, {phase: dict() for phase in phases}
            save_validate_this_epoch = False

            for phase in phases:
                if phase == 'train':
                    model.train()
                else:
                    model.eval()

                steps, predictions, targets = 0, list(), list()
                tqdm_loader = tqdm(data_loaders[phase], ncols=120)
                for g, spatial_features, temporal_features, external_features, truth_data in tqdm_loader:
                    features, truth_data = convert_train_truth_to_gpu(
                        [spatial_features, temporal_features, external_features], truth_data)
                    g = convert_to_gpu(g)

                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(g, *features)
                        # Drop only the trailing singleton dimension: a bare
                        # squeeze() would also collapse a one-sample batch to
                        # a 0-d tensor, which breaks the loss and the
                        # np.concatenate below.
                        outputs = torch.squeeze(outputs, dim=-1)

                        loss = loss_func(truth_data, outputs)

                        if phase == 'train':
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()

                    targets.append(truth_data.cpu().numpy())
                    # detach() is required: on a CPU device .cpu() returns the
                    # same tensor, and numpy() refuses tensors that require grad.
                    predictions.append(outputs.detach().cpu().numpy())

                    running_loss[phase] += loss.item() * truth_data.size(0)
                    steps += truth_data.size(0)

                    tqdm_loader.set_description(
                        f'{phase:8} epoch: {epoch:3}, {phase:8} loss: {running_loss[phase] / steps:3.6}')

                    torch.cuda.empty_cache()

                print(f'{phase} metric ...')
                scores = evaluate(np.concatenate(predictions), np.concatenate(targets))
                running_metrics[phase] = scores
                print(scores)

                # Model selection is driven by the validation F1-score.
                if phase == 'validate' and scores['F1-SCORE'] > best_f1_score:
                    best_f1_score = scores['F1-SCORE']
                    save_validate_this_epoch = True
                    save_dict.update(model_state_dict=copy.deepcopy(model.state_dict()),
                                     epoch=epoch,
                                     optimizer_state_dict=copy.deepcopy(optimizer.state_dict()))
                    print(f"save model as {model_folder}/model_{epoch}.pkl")
                    save_model(f"{model_folder}/model_{epoch}.pkl", **save_dict)

            # The learning-rate schedule follows the accumulated training loss.
            scheduler.step(running_loss['train'])

            # The test phase runs after validation, so its metrics for this
            # epoch are available here.
            if save_validate_this_epoch:
                test_metric = running_metrics["test"].copy()

            for metric in running_metrics['train'].keys():
                writer.add_scalars(metric, {
                    f'{phase} {metric}': running_metrics[phase][metric] for phase in phases},
                                   global_step=epoch)
            writer.add_scalars('Loss', {
                f'{phase} loss': running_loss[phase] / len(data_loaders[phase].dataset) for phase in phases},
                               global_step=epoch)
    finally:
        time_elapsed = time.time() - since
        print(f"cost {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s")
        save_model(f"{model_folder}/best_model.pkl", **save_dict)
        # Flush and release the TensorBoard event file.
        writer.close()

    return test_metric

In [None]:
# Module: train_main.py (phần thực thi chính)

def create_model() -> nn.Module:
    """Instantiate ``DSTGCN`` with the feature widths from the configuration.

    The external width is the sum of the ``weather_features_number`` and
    ``external_temporal_features_number`` config entries.
    """
    spatial_width = get_attribute('poi_features_number')
    temporal_width = get_attribute('temporal_features_number')
    external_width = (get_attribute('weather_features_number')
                      + get_attribute('external_temporal_features_number'))
    return DSTGCN(spatial_width, temporal_width, external_width)


In [None]:
import torch
import numpy as np
import os
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy.stats import pearsonr
from tqdm import tqdm

# --- 1. Re-establish the device & load the configuration ---
# get_attribute and the config variable must already exist (run the
# configuration cell first).
device = torch.device(get_attribute('device'))
print(f"Using device: {device}")

# --- 2. Detailed evaluation function (MAE, RMSE, PCC) ---
def evaluate_metrics(model, data_loader, device):
    """Run ``model`` over ``data_loader`` and return ``(mae, rmse, pcc)``.

    Args:
        model: network called as ``model(g, spatial, temporal, external)``.
        data_loader: loader yielding ``(g, spatial_features,
            temporal_features, external_features, truth_data)`` batches.
        device: device every batch is moved to before the forward pass.

    Returns:
        Mean absolute error, root mean squared error and Pearson correlation
        between all predictions and targets.
    """
    model.eval()  # evaluation mode
    predictions = []
    targets = []

    print("Đang chạy đánh giá trên tập Test...")
    with torch.no_grad():
        for g, spatial_features, temporal_features, external_features, truth_data in tqdm(data_loader, ncols=100):
            # Move the batch to the target device.
            spatial_features = spatial_features.to(device)
            temporal_features = temporal_features.to(device)
            external_features = external_features.to(device)
            truth_data = truth_data.to(device)
            g = g.to(device)

            # Forward pass.
            outputs = model(g, spatial_features, temporal_features, external_features)
            # Drop only the trailing singleton dimension: a bare squeeze()
            # would turn a one-sample batch into a 0-d tensor, which cannot be
            # passed to extend() below.
            outputs = torch.squeeze(outputs, dim=-1)

            # Collect results on the CPU for the metric computations.
            predictions.extend(outputs.cpu().numpy())
            targets.extend(truth_data.cpu().numpy())

    y_pred = np.array(predictions)
    y_true = np.array(targets)

    mae = mean_absolute_error(y_true, y_pred)
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    pcc, _ = pearsonr(y_pred.flatten(), y_true.flatten())

    return mae, rmse, pcc

# --- 3. Build the model & load the weights ---
print("Khởi tạo model...")
model = create_model() # create_model must already be defined in the notebook
model = model.to(device)

# Path of the saved model file (written by the earlier training code)
save_path = f"/kaggle/input/dstgcn/saves/DSTGCN/best_model.pkl"

if os.path.exists(save_path):
    print(f"Đang tải trọng số từ: {save_path}")
    checkpoint = torch.load(save_path, map_location=device)
    
    # Check the checkpoint layout (wrapper dict or a bare state_dict)
    if 'model_state_dict' in checkpoint:
        model.load_state_dict(checkpoint['model_state_dict'])
    else:
        model.load_state_dict(checkpoint)
    print("Load model thành công!")
else:
    print(f"CẢNH BÁO: Không tìm thấy file {save_path}. Hãy đảm bảo bạn đã train xong và file được lưu.")
    # Without a checkpoint the model keeps its random initial weights; the
    # evaluation in step 5 is skipped in that case.

# --- 4. Load the test data ---
print("Đang load dữ liệu test...")
# Build the data loaders again and keep the test split
loaders = get_data_loaders(get_attribute('K_hop'), get_attribute('batch_size'))
test_loader = loaders['test']

# --- 5. Run the evaluation (only when a checkpoint was found) ---
if os.path.exists(save_path):
    mae, rmse, pcc = evaluate_metrics(model, test_loader, device)
    
    print("\n" + "="*30)
    print("KẾT QUẢ ĐÁNH GIÁ (TEST SET)")
    print("="*30)
    print(f"MAE  : {mae:.6f}")
    print(f"RMSE : {rmse:.6f}")
    print(f"PCC  : {pcc:.6f}")
    print("="*30)

### 7. Thực thi chính (Main Execution)