In [1]:
# 📦 라이브러리 불러오기
import h5py
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm

# ⚙️ 디바이스 설정
device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cpu


In [2]:
# 📁 HDF5 데이터셋 정의
class HDF5Dataset(Dataset):
    def __init__(self, input_path, output_path):
        self.input_file = h5py.File(input_path, 'r')
        self.output_file = h5py.File(output_path, 'r')
        self.X = self.input_file['subcubes']
        self.Y = self.output_file['subcubes']
    
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        x = torch.from_numpy(self.X[idx]).float().unsqueeze(0)  # (1, 122, 122, 122)
        y = torch.from_numpy(self.Y[idx]).float().unsqueeze(0)
        return x, y


In [3]:
# 🧠 간단한 U-Net (crop 추가 적용)
class UNet3D(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv3d(1, 32, 3, padding=1), nn.BatchNorm3d(32), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv3d(32, 64, 3, padding=1), nn.BatchNorm3d(64), nn.ReLU())
        self.pool = nn.MaxPool3d(2)
        self.middle = nn.Sequential(nn.Conv3d(64, 64, 3, padding=1), nn.ReLU())
        self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)
        self.dec2 = nn.Sequential(nn.Conv3d(128, 32, 3, padding=1), nn.ReLU())
        self.dec1 = nn.Sequential(nn.Conv3d(64, 1, 3, padding=1))

    def forward(self, x):
        x1 = self.enc1(x)
        x2 = self.enc2(self.pool(x1))
        x_mid = self.middle(x2)
        x_up = self.up(x_mid)[:, :, :x2.shape[2], :x2.shape[3], :x2.shape[4]]
        x_concat2 = torch.cat([x2, x_up], dim=1)
        x3 = self.dec2(x_concat2)
        x_up2 = self.up(x3)[:, :, :x1.shape[2], :x1.shape[3], :x1.shape[4]]
        x_concat1 = torch.cat([x1, x_up2], dim=1)
        out = self.dec1(x_concat1)
        return out


In [4]:
# 🧪 데이터 로딩
input_path = "/caefs/data/IllustrisTNG/subcube/input/subcubes_stride2_50mpc_parallel.h5"
output_path = "/caefs/data/IllustrisTNG/subcube/output/subcubes_stride2_50mpc_parallel.h5"
dataset = HDF5Dataset(input_path, output_path)
loader = DataLoader(dataset, batch_size=1, shuffle=True)


In [5]:
# 🧠 모델 초기화
model = UNet3D().to(device)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
# 🔁 학습 루프 + 손실 추적
num_epochs = 10
train_losses = []

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for step, (x, y) in enumerate(tqdm(loader)):
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        pred = model(x)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        if step >= 100:  # 학습 속도 고려해 step 수 제한 (선택)
            break

    avg_loss = total_loss / (step + 1)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {avg_loss:.6f}")


  1%|          | 100/10000 [04:16<7:03:05,  2.56s/it]


Epoch 1/10 - Avg Loss: 0.113047


  1%|          | 100/10000 [04:21<7:11:28,  2.61s/it]


Epoch 2/10 - Avg Loss: 0.021190


  0%|          | 28/10000 [01:09<6:32:16,  2.36s/it]

In [None]:
# 📉 학습 곡선 시각화
plt.figure(figsize=(6, 4))
plt.plot(range(1, num_epochs + 1), train_losses, marker='o')
plt.title("Training Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("MSE Loss")
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# 🖼️ 예측 결과 시각화 (중앙 슬라이스만)
model.eval()
with torch.no_grad():
    x, y = dataset[0]
    x = x.unsqueeze(0).to(device)
    pred = model(x).cpu().squeeze().numpy()
    truth = y.squeeze().numpy()

    slice_idx = 61
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(truth[slice_idx], cmap='viridis')
    plt.title("True")
    plt.colorbar()
    plt.subplot(1, 2, 2)
    plt.imshow(pred[slice_idx], cmap='viridis')
    plt.title("Predicted")
    plt.colorbar()
    plt.show()
