In [1]:
import itk
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

In [2]:
# ================================
# 3. Data loading
# ================================
# NOTE: although this section was originally titled "load and split", no
# train/validation split is made here -- the whole dataset feeds the training
# loader below.
from dataset import ITKDataset

image_dir = "../dataset/images"
mask_dir = "../dataset/masks"

# Single dataset built from both directories; target_size=(256, 256) is
# forwarded to ITKDataset (presumably it resizes samples to that size --
# confirm in dataset.py).
dataset = ITKDataset(image_dir, mask_dir, target_size=(256, 256))

# Reshuffled every epoch, 16 samples per mini-batch.
train_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)


In [3]:
# ================================
# 4. Model and training setup
# ================================
from unet import UNet

SEED = 42
# Seed torch's global RNG so weight initialization and the DataLoader's
# shuffle order (drawn from the global RNG when the loader is iterated) are
# repeatable on Restart & Run All. cuDNN kernels may still be nondeterministic
# on GPU.
torch.manual_seed(SEED)

GPU_INDEX = 1  # preferred GPU on multi-GPU hosts
if torch.cuda.is_available():
    # "cuda:1" does not exist on a single-GPU machine and `.to(device)` would
    # raise; fall back to the first GPU in that case.
    gpu_index = GPU_INDEX if torch.cuda.device_count() > GPU_INDEX else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

model = UNet().to(device)
# BCEWithLogitsLoss applies the sigmoid internally, so UNet presumably returns
# raw logits (no final sigmoid) -- confirm in unet.py.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Checkpoint file the training loop writes whenever the epoch loss improves.
model_save_path = "unet_model.pth"


In [4]:
# ================================
# 5. Training loop with best-checkpoint saving
# ================================
num_epochs = 50
# Lowest epoch loss seen so far. Starting at +inf guarantees the first epoch
# is always checkpointed (a fixed sentinel like 999 silently skips saving when
# the summed loss exceeds it) and avoids shadowing the builtin `min`.
best_loss = float("inf")

for epoch in range(num_epochs):
    model.train()
    # Sum of the per-batch (mean-reduced) losses over the epoch -- not an
    # average, so its scale grows with the number of batches.
    epoch_loss = 0.0

    for images, masks in train_loader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.5f}")
    # Selection is on *training* loss (no validation set exists in this notebook).
    if epoch_loss < best_loss:
        torch.save(model.state_dict(), model_save_path)
        print(f"Model {epoch + 1} saved to {model_save_path}")
        best_loss = epoch_loss

# NOTE: `model` now holds the final-epoch weights, which may differ from the
# best checkpoint on disk; reload with
# model.load_state_dict(torch.load(model_save_path)) before evaluating.


Epoch 1/50, Loss: 12.00486
Model 1 saved to unet_model.pth
Epoch 2/50, Loss: 8.19349
Model 2 saved to unet_model.pth
Epoch 3/50, Loss: 7.26618
Model 3 saved to unet_model.pth
Epoch 4/50, Loss: 6.58707
Model 4 saved to unet_model.pth
Epoch 5/50, Loss: 6.06048
Model 5 saved to unet_model.pth
Epoch 6/50, Loss: 5.64650
Model 6 saved to unet_model.pth
Epoch 7/50, Loss: 5.24117
Model 7 saved to unet_model.pth
Epoch 8/50, Loss: 4.95324
Model 8 saved to unet_model.pth
Epoch 9/50, Loss: 4.61860
Model 9 saved to unet_model.pth
Epoch 10/50, Loss: 4.36967
Model 10 saved to unet_model.pth
Epoch 11/50, Loss: 4.13094
Model 11 saved to unet_model.pth
Epoch 12/50, Loss: 3.92842
Model 12 saved to unet_model.pth
Epoch 13/50, Loss: 3.73420
Model 13 saved to unet_model.pth
Epoch 14/50, Loss: 3.46982
Model 14 saved to unet_model.pth
Epoch 15/50, Loss: 3.31674
Model 15 saved to unet_model.pth
Epoch 16/50, Loss: 3.39757
Epoch 17/50, Loss: 3.30546
Model 17 saved to unet_model.pth
Epoch 18/50, Loss: 3.03613
Mod