In [1]:
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
import torch.utils.data as data

In [4]:
# Load the pre-built PyG graph from disk.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects (presumably required because the file stores a full PyG `Data`
# object) -- only load files from a trusted source.
_pyg_data_path = '../dataset/graph/pyg_data.pt'
pyg_data = torch.load(_pyg_data_path, weights_only=False)

In [5]:
class GraphDataset(data.Dataset):
    """Sliding-window forecasting dataset over a single static graph.

    Each sample pairs an input window of ``window_size`` consecutive time
    steps with the immediately following ``window_size`` steps as the target.
    Window start positions advance by ``stride`` along dimension 1 of
    ``pyg_data.x``. Every sample reuses the same ``edge_index`` and
    ``edge_type`` tensors; they are not copied.

    Args:
        pyg_data: Source graph. It must expose ``x`` (time is dimension 1,
            indexed as ``x[:, t]``), ``edge_index``, ``edge_type``,
            ``num_nodes`` and ``num_edge_types``.
        window_size: Length of both the input window and the target window.
        stride: Step between the start indices of consecutive samples.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive.
    """

    def __init__(self, pyg_data, window_size=28, stride=28):
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}")
        if stride <= 0:
            # Guards the floor division below, which would otherwise raise an
            # opaque ZeroDivisionError for stride == 0.
            raise ValueError(f"stride must be positive, got {stride}")
        self.pyg_data = pyg_data
        self.time_length = pyg_data.x.shape[1]
        self.window_size = window_size
        self.stride = stride
        # A sample starting at s needs s + 2 * window_size <= time_length.
        # max() clamps to 0 when the series is shorter than one
        # input+target pair.
        self.num_windows = max(0, (self.time_length - 2 * window_size) // stride + 1)

    def __len__(self):
        return self.num_windows

    def __getitem__(self, idx):
        """Return sample ``idx`` as a ``Data`` with ``x`` = input and ``y`` = target.

        Negative indices count from the end. Out-of-range indices raise
        ``IndexError`` instead of silently returning truncated or empty
        windows. This also lets plain ``iter(dataset)`` terminate, because
        ``Dataset`` defines no ``__iter__`` and Python falls back to
        ``__getitem__`` until ``IndexError``.
        """
        if idx < 0:
            idx += self.num_windows
        if not 0 <= idx < self.num_windows:
            raise IndexError(
                f"index out of range for GraphDataset of length {self.num_windows}"
            )

        start = idx * self.stride
        split = start + self.window_size
        end = split + self.window_size
        # .clone() so each sample owns its storage instead of viewing pyg_data.x.
        input_window = self.pyg_data.x[:, start:split].clone()
        output_window = self.pyg_data.x[:, split:end].clone()

        # Return directly instead of binding to a local named `data`, which
        # would shadow the `torch.utils.data` module alias.
        return Data(
            x=input_window,
            y=output_window,
            edge_index=self.pyg_data.edge_index,
            edge_type=self.pyg_data.edge_type,
            num_nodes=self.pyg_data.num_nodes,
            num_edge_types=self.pyg_data.num_edge_types,
        )

# Dataset splitting
def create_train_val_test_datasets(graph_data, window_size=28, stride=1, train_ratio=0.7, val_ratio=0.15):
    """Build chronological train/validation/test subsets of a ``GraphDataset``.

    Samples are split in index (time) order without shuffling, so every
    validation window starts after every training window. Every test window
    starts after every validation window. The test split receives whatever
    remains after the train and validation shares; each share is rounded
    down with ``int()``.

    Args:
        graph_data: Source graph, passed through to ``GraphDataset``.
        window_size: Input/target window length, passed to ``GraphDataset``.
        stride: Step between window starts, passed to ``GraphDataset``.
        train_ratio: Fraction of samples assigned to the training split.
        val_ratio: Fraction of samples assigned to the validation split.

    Returns:
        Tuple ``(train_dataset, val_dataset, test_dataset)`` of
        ``torch.utils.data.Subset`` objects over one shared ``GraphDataset``.

    Raises:
        ValueError: If either ratio is negative, if the ratios sum to more
            than 1, or if the dataset contains no samples.

    NOTE: when ``stride < 2 * window_size``, consecutive samples overlap in
    time. Targets of the last training samples therefore cover the same time
    steps as the first validation samples, and likewise at the val/test
    boundary. No samples are purged at split boundaries here.
    """
    # Validate before slicing: a negative ratio would produce negative slice
    # bounds (e.g. indices[:-k]), and a sum > 1 would silently truncate the
    # validation split and leave the test split empty.
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            f"train_ratio and val_ratio must be non-negative, got {train_ratio} and {val_ratio}"
        )
    if train_ratio + val_ratio > 1 + 1e-9:  # tolerance for float rounding
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, got {train_ratio + val_ratio}"
        )

    dataset = GraphDataset(graph_data, window_size, stride)
    n_samples = len(dataset)

    if n_samples == 0:
        raise ValueError("데이터셋에 샘플이 없습니다. 윈도우 크기와 시계열 길이를 확인하세요.")

    # Split boundaries; the test split takes every index from val_end onward.
    train_end = int(n_samples * train_ratio)
    val_end = train_end + int(n_samples * val_ratio)

    # Contiguous index ranges keep the split chronological.
    train_dataset = torch.utils.data.Subset(dataset, list(range(0, train_end)))
    val_dataset = torch.utils.data.Subset(dataset, list(range(train_end, val_end)))
    test_dataset = torch.utils.data.Subset(dataset, list(range(val_end, n_samples)))

    return train_dataset, val_dataset, test_dataset

# Build the chronological train/val/test splits from the loaded graph.
train_dataset, val_dataset, test_dataset = create_train_val_test_datasets(pyg_data)

# Build the data loaders. Each sample already holds every node of the graph,
# so batch_size = 1 means one full-graph window per step.
batch_size = 1
# Only the training loader shuffles; val/test iterate in subset index order.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=shuffle)
    for split, shuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)