In [1]:
import itk
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [2]:
from dataset import ITKDataset

# ================================
# 3. Load the data
# ================================
# NOTE(review): this cell does not split the data into train/validation sets;
# the whole dataset feeds the training loader below.
image_dir = "../dataset/images"  # directory of input images
mask_dir = "../dataset/masks"    # directory of masks (presumably paired 1:1 with images — verify in ITKDataset)

# Full dataset. target_size=(256, 256) is passed straight to ITKDataset
# (presumably the resize target; confirm in dataset.py).
dataset = ITKDataset(
    image_dir,
    mask_dir,
    target_size=(256, 256),
)

# Training loader: mini-batches of 4, reshuffled every epoch.
train_loader = DataLoader(
    dataset=dataset,
    batch_size=4,
    shuffle=True,
)


In [3]:
from unet import UNet

# ================================
# 4. Model and training setup
# ================================
# Prefer GPU index 1 (as originally configured), but fall back to GPU 0 when
# only one CUDA device exists: "cuda:1" on a single-GPU machine makes
# `.to(device)` fail with an invalid device ordinal error.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

model = UNet().to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the model should output raw
# logits rather than probabilities (assumed for UNet — confirm in unet.py).
criterion = nn.BCEWithLogitsLoss()

# RMSprop at lr=1e-4 (Adam with the same learning rate was an earlier alternative).
optimizer = optim.RMSprop(model.parameters(), lr=1e-4)

# File path where the model's state_dict is written during training.
model_save_path = "unet_model.pth"


In [None]:
# ================================
# 5. Training loop with best-checkpoint saving
# ================================
num_epochs = 60

# Lowest epoch loss seen so far. Starting at +inf guarantees the first epoch
# is checkpointed. The previous sentinel (999, stored in a variable that
# shadowed the built-in `min`) silently skipped every save whenever the summed
# epoch loss was >= 999, which happens easily with many batches.
best_loss = float("inf")

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0  # sum of the per-batch losses for this epoch

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float (via .item()) rather than the loss tensor.
        epoch_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.5f}")

    # Keep only the weights with the lowest training loss so far on disk;
    # this file is the final saved model, so no extra save is needed after the loop.
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), model_save_path)
        print(f"Model {epoch + 1} saved to {model_save_path}")
