<a href="https://colab.research.google.com/github/cheolhakja/fine-dust-prediction/blob/main/gnn/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Remove every torch / PyG package first to avoid version conflicts
!pip uninstall -y torch torchvision torchaudio torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric torch-geometric-temporal

# 2. Reinstall PyTorch 2.6.0 + CUDA 12.4 (matches the Colab default version)
!pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124

# 3. Install the PyG extension modules (wheels built for PyTorch 2.6.0+cu124)
# NOTE(review): none of these packages is version-pinned, so a later re-run
# may resolve to different releases than the ones this notebook was run with.
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \
    -f https://data.pyg.org/whl/torch-2.6.0+cu124.html

# 4. Temporal module (also unpinned -- see the note above)
!pip install torch-geometric-temporal



Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
[0mLooking in indexes: https://download.pytorch.org/whl/cu124
Collecting torch==2.6.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torch-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (28 kB)
Collecting torchvision==0.21.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torchvision-0.21.0%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio==2.6.0+cu124
  Downloading https://download.pytorch.org/whl/cu124/torchaudio-2.6.0%2Bcu124-cp311-cp311-linux_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch=

In [None]:


import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import dense_to_sparse
from torch_geometric_temporal.nn.recurrent import A3TGCN
from torch_geometric_temporal.signal import StaticGraphTemporalSignalBatch
import random
from sklearn.metrics import mean_squared_error, mean_absolute_error


from google.colab import drive

# Mount Google Drive; the data files read later live under /content/drive/MyDrive/gnn.
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
"""데이터 로"""

# Input feature tensors for the train / validation / test splits.
train_tensor, val_tensor, test_tensor = (
    torch.load(tensor_path)
    for tensor_path in (
        '/content/drive/MyDrive/gnn/train_tensor.pt',
        '/content/drive/MyDrive/gnn/val_tensor.pt',
        '/content/drive/MyDrive/gnn/test_tensor.pt',
    )
)

# Matching target tensors, loaded in the same split order.
train_target, val_target, test_target = (
    torch.load(target_path)
    for target_path in (
        '/content/drive/MyDrive/gnn/train_target.pt',
        '/content/drive/MyDrive/gnn/val_target.pt',
        '/content/drive/MyDrive/gnn/test_target.pt',
    )
)

# Adjacency matrix, stored as comma-separated text.
adj_matrix = np.loadtxt(
    '/content/drive/MyDrive/gnn/adj_matrix.txt',
    delimiter=',',
)


In [None]:
"""2. Edge Index, Edge Weight 만들기
"""
# Convert the dense adjacency matrix to float32, then to the sparse
# (edge_index, edge_weight) pair returned by dense_to_sparse.
adj_tensor = torch.tensor(
    data=adj_matrix,
    dtype=torch.float32,
)
(edge_index, edge_weight) = dense_to_sparse(adj_tensor)

In [None]:
"""3. 데이터셋(Loader) 클래스 정의
"""
class AirQualityDatasetLoader(object):
    """Turn a feature/target time series into sliding-window graph samples.

    Args:
        x_tensor: input features indexed by time on the first axis. It is
            permuted with ``(1, 2, 0)``, so it must be 3-D; the result is
            treated as (node, feature, time).
        y_tensor: targets indexed by time on the first axis. It is permuted
            with ``(1, 0)``, so it must be 2-D; the result is treated as
            (node, time).
        edge_index: graph connectivity, forwarded unchanged to
            ``StaticGraphTemporalSignalBatch``.
        edge_weight: per-edge weights, forwarded unchanged.
        batch_size: forwarded as the fifth positional argument of
            ``StaticGraphTemporalSignalBatch``.
            NOTE(review): that parameter appears to be a batch-index array
            (``batches``) in torch_geometric_temporal rather than a batch
            size -- confirm against the library documentation.
        n_timesteps_in: number of consecutive time steps in each input window.
        n_timesteps_out: number of consecutive time steps in each target window.
    """

    def __init__(self, x_tensor, y_tensor, edge_index, edge_weight, batch_size=64, n_timesteps_in=12, n_timesteps_out=12):
        self.x_tensor = x_tensor
        self.y_tensor = y_tensor
        self.edge_index = edge_index
        self.edge_weight = edge_weight
        self.batch_size = batch_size
        self.n_timesteps_in = n_timesteps_in
        self.n_timesteps_out = n_timesteps_out

    def get_dataset(self):
        """Build one (input window, target window) pair per valid start index.

        Returns:
            A ``StaticGraphTemporalSignalBatch`` whose features are NumPy
            arrays of shape (node, feature, n_timesteps_in) and whose targets
            are NumPy arrays of shape (node, n_timesteps_out).
        """
        window = self.n_timesteps_in + self.n_timesteps_out
        # Use the shorter series so a target slice is never silently truncated.
        T = min(len(self.x_tensor), len(self.y_tensor))
        # A series of length T holds T - window + 1 complete windows. The
        # previous bound (T - window) dropped the last one, i.e. the only
        # window whose target reaches the final time step.
        n_windows = max(T - window + 1, 0)

        features, targets = [], []
        for start in range(n_windows):
            split = start + self.n_timesteps_in
            stop = start + window
            # (time, node, feature) -> (node, feature, time)
            x = self.x_tensor[start:split].permute(1, 2, 0)
            # (time, node) -> (node, time)
            y = self.y_tensor[split:stop].permute(1, 0)
            features.append(x.numpy())
            targets.append(y.numpy())
        return StaticGraphTemporalSignalBatch(
            self.edge_index, self.edge_weight, features, targets, self.batch_size
        )

In [None]:
"""4. 데이터셋 인스턴스 만들기
"""
# Inputs per split: feature series, target series, and the shared graph.
n_timesteps_in = 12   # number of input time steps
n_timesteps_out = 12  # number of predicted time steps

# Settings that are identical for every split.
loader_kwargs = dict(
    batch_size=64,
    n_timesteps_in=n_timesteps_in,
    n_timesteps_out=n_timesteps_out,
)

train_loader = AirQualityDatasetLoader(
    train_tensor, train_target, edge_index, edge_weight, **loader_kwargs
)
val_loader = AirQualityDatasetLoader(
    val_tensor, val_target, edge_index, edge_weight, **loader_kwargs
)
test_loader = AirQualityDatasetLoader(
    test_tensor, test_target, edge_index, edge_weight, **loader_kwargs
)

# Materialise the windowed datasets in train -> val -> test order.
train_dataset, val_dataset, test_dataset = (
    split_loader.get_dataset()
    for split_loader in (train_loader, val_loader, test_loader)
)

In [None]:
"""5. 모델 클래스 정의 (A3TGCN)
"""
class TemporalGNN(nn.Module):
    """A3TGCN layer followed by ReLU and a linear output head.

    Args:
        node_features: passed to ``A3TGCN`` as ``in_channels``.
        periods: passed to ``A3TGCN`` as ``periods`` and also used as the
            output width of the linear head.
            NOTE(review): the same value therefore sets both the A3TGCN
            period count and the number of predicted steps; confirm this
            is intended if the input and output window lengths ever differ.
        hidden_channels: width of the A3TGCN output and of the linear head's
            input. Defaults to 64, the value that used to be hard-coded in
            both places.
    """

    def __init__(self, node_features, periods, hidden_channels=64):
        super().__init__()
        # Both layers take the same width, so it is defined once.
        self.tgnn = A3TGCN(
            in_channels=node_features,
            out_channels=hidden_channels,
            periods=periods,
        )
        self.linear = nn.Linear(hidden_channels, periods)

    def forward(self, x, edge_index, edge_weight):
        """Apply the A3TGCN layer, ReLU, then the linear head; return the result."""
        h = self.tgnn(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.linear(h)
        return h

In [None]:
# Sanity check on the training data: count NaN / Inf entries first ...
for checked in (train_tensor, train_target):
    print(torch.isnan(checked).sum(), torch.isinf(checked).sum())

# ... then show the value range of the same two tensors.
for checked in (train_tensor, train_target):
    print(checked.max(), checked.min())


tensor(0) tensor(0)
tensor(0) tensor(0)
tensor(22.2578) tensor(-3.4890)
tensor(6.8448) tensor(-1.8847)


In [None]:
"""6. 학습(Train) 루프
"""
# Training configuration (previously inline magic numbers).
SEED = 42
NUM_EPOCHS = 45
LEARNING_RATE = 0.01

# Seed every RNG before the model is built so that weight initialisation is
# the same after Restart & Run All.
# NOTE(review): some GPU kernels may still be non-deterministic, so losses can
# differ slightly between runs even with a fixed seed.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TemporalGNN(node_features=train_tensor.shape[2], periods=n_timesteps_out).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print("Start Training...")
for epoch in range(NUM_EPOCHS):
    model.train()
    loss_sum = 0
    steps = 0
    # One optimiser step per snapshot yielded by the dataset.
    # NOTE(review): each snapshot looks like a single window, i.e. no real
    # mini-batching takes place -- confirm.
    for snapshot in train_dataset:
        snapshot = snapshot.to(device)
        y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
        loss = F.mse_loss(y_hat, snapshot.y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        steps += 1
    if steps == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("train_dataset yielded no snapshots; nothing to train on")
    print(f"Epoch {epoch+1}: Train MSE {loss_sum/steps:.4f}")

Start Training...
Epoch 1: Train MSE 0.7358
Epoch 2: Train MSE 0.7660
Epoch 3: Train MSE 0.7571
Epoch 4: Train MSE 0.8657
Epoch 5: Train MSE 0.7648
Epoch 6: Train MSE 0.7925
Epoch 7: Train MSE 0.7085
Epoch 8: Train MSE 0.8362
Epoch 9: Train MSE 0.9442
Epoch 10: Train MSE 0.8293
Epoch 11: Train MSE 0.7818
Epoch 12: Train MSE 0.7996
Epoch 13: Train MSE 0.9257
Epoch 14: Train MSE 0.8476
Epoch 15: Train MSE 0.7349
Epoch 16: Train MSE 0.6970
Epoch 17: Train MSE 0.6274
Epoch 18: Train MSE 0.6208
Epoch 19: Train MSE 0.6669
Epoch 20: Train MSE 0.6503
Epoch 21: Train MSE 0.6556
Epoch 22: Train MSE 0.6049
Epoch 23: Train MSE 0.6276
Epoch 24: Train MSE 0.6054
Epoch 25: Train MSE 0.5650
Epoch 26: Train MSE 0.5539
Epoch 27: Train MSE 0.5274
Epoch 28: Train MSE 0.5673
Epoch 29: Train MSE 0.5391
Epoch 30: Train MSE 0.5039
Epoch 31: Train MSE 0.5109
Epoch 32: Train MSE 0.5004
Epoch 33: Train MSE 0.4757
Epoch 34: Train MSE 0.4780
Epoch 35: Train MSE 0.4920
Epoch 36: Train MSE 0.4883
Epoch 37: Train MSE

In [None]:
"""save result"""

# Persist only the model's state_dict to Drive.
trained_weights = model.state_dict()
torch.save(
    trained_weights,
    "/content/drive/MyDrive/gnn/a3t_gcn_model_interpolate.pth",
)


In [None]:
"""모델을 검증하자.테스트 데이터셋으로"""
def evaluate(model, dataset, device):
    """Return the mean per-snapshot MSE of ``model`` over ``dataset``.

    The model is switched to eval mode and is left in that mode on return.
    Gradients are disabled for the whole pass.

    Args:
        model: callable invoked as ``model(x, edge_index, edge_attr)``; must
            provide ``eval()``.
        dataset: iterable of snapshots exposing ``to(device)``, ``x``,
            ``edge_index``, ``edge_attr`` and ``y``.
        device: device each snapshot is moved to before the forward pass.

    Returns:
        float: unweighted average of the per-snapshot ``F.mse_loss`` values.

    Raises:
        ValueError: if ``dataset`` yields no snapshots (previously this
            surfaced as an unexplained ZeroDivisionError).
    """
    model.eval()
    snapshot_losses = []
    with torch.no_grad():
        for snapshot in dataset:
            snapshot = snapshot.to(device)
            y_hat = model(snapshot.x, snapshot.edge_index, snapshot.edge_attr)
            snapshot_losses.append(F.mse_loss(y_hat, snapshot.y).item())
    if not snapshot_losses:
        raise ValueError(
            "evaluate() received an empty dataset; cannot average MSE over zero snapshots"
        )
    return sum(snapshot_losses) / len(snapshot_losses)

# 2. After the training loop has finished, print validation / test MSE.
# NOTE(review): in the recorded run both values (about 3.0) are far above the
# training MSE (about 0.5); worth investigating before trusting the model.
val_mse = evaluate(model=model, dataset=val_dataset, device=device)
print(f"Validation MSE: {val_mse:.4f}")

test_mse = evaluate(model=model, dataset=test_dataset, device=device)
print(f"Test MSE: {test_mse:.4f}")


Validation MSE: 2.9736
Test MSE: 2.9965
