In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

# Switch the working directory to the project folder before importing the local
# modules below. Presumably this is what makes UnetModel / UnetDataset importable
# (Jupyter puts the current working directory on sys.path) — TODO confirm.
# NOTE(review): os.chdir changes the CWD for the whole kernel session, so any
# relative path used later resolves against D:\comm_py\v4; consider
# sys.path.insert(0, r"D:\comm_py\v4") instead.
os.chdir(r"D:\comm_py\v4")
import UnetModel    # project module providing UNetIso (used in the model cell)
import UnetDataset  # project module providing IsoDataset (used in the data cell)

In [5]:
# Path
# Every data folder sits directly under DATA_DIR and is named
# <modality>_<DATA_VERSION>_<split>. Set the COMM_PY_DATA environment variable
# to point at a different data root without editing this cell.
DATA_DIR = os.environ.get("COMM_PY_DATA", r'D:\comm_py\data')
DATA_VERSION = 'v4'


def _data_subdir(modality, split):
    """Return the data folder for a modality prefix and split name.

    Parameters
    ----------
    modality : str
        Folder prefix, e.g. 'ATN', 'BDM' or 'PGM'.
    split : str
        Split suffix, e.g. 'train' or 'val'.

    Returns
    -------
    str
        ``DATA_DIR`` joined with ``"<modality>_<DATA_VERSION>_<split>"``.
    """
    return os.path.join(DATA_DIR, f"{modality}_{DATA_VERSION}_{split}")


input_dir1_train = _data_subdir('ATN', 'train')
input_dir2_train = _data_subdir('BDM', 'train')
output_dir_train = _data_subdir('PGM', 'train')

input_dir1_val = _data_subdir('ATN', 'val')
input_dir2_val = _data_subdir('BDM', 'val')
output_dir_val = _data_subdir('PGM', 'val')

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

num_epochs = 30
learning_rate = 1e-3

# Reproducibility: torch.manual_seed seeds the CPU and all CUDA generators, which
# drive weight initialisation and DataLoader shuffling. numpy is seeded too in
# case the dataset uses numpy randomness (not visible here — TODO confirm).
# cuDNN kernels may still be non-deterministic unless
# torch.backends.cudnn.deterministic = True is also set.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)


cuda


In [6]:
# Datasets and loaders.
BATCH_SIZE = 16

train_dataset = UnetDataset.IsoDataset(input_dir1_train, input_dir2_train, output_dir_train)
val_dataset = UnetDataset.IsoDataset(input_dir1_val, input_dir2_val, output_dir_val)

# Training batches are reshuffled every epoch. Validation only accumulates a
# size-weighted average loss, so its order should not affect the result (beyond
# float rounding). It is therefore kept fixed to make evaluation deterministic.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [7]:
# Model, loss, optimiser and learning-rate scheduler.
# The network is moved to `device` before the optimiser is built, so Adam
# receives the on-device parameters.
model = UnetModel.UNetIso().to(device=device)
criterion = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)
# Multiply the LR by 0.1 once the monitored loss (fed via scheduler.step) has
# gone more than `patience` epochs without improving.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode='min',
    factor=0.1,
    patience=5,
)

In [8]:
# Training loop
import time

# Directory where weights and loss histories are written.
OUT_DIR = r"D:\comm_py\v4"


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`.

    Returns
    -------
    float
        Mean loss per sample. Each batch loss is weighted by its batch size,
        so a smaller last batch does not skew the average.
    """
    model.train()
    running_loss = 0.0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
    return running_loss / len(loader.dataset)


def _evaluate(model, loader, criterion, device):
    """Return the size-weighted mean loss of `model` on `loader`, without gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item() * inputs.size(0)
    return total_loss / len(loader.dataset)


Tloss = []
Vloss = []
LRhist = []
best_val_loss = float("inf")
best_path = os.path.join(OUT_DIR, "UnetModel_best.pth")

try:
    for epoch in range(num_epochs):
        epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        Tloss.append(epoch_loss)

        # Validation step
        val_loss = _evaluate(model, val_loader, criterion, device)
        Vloss.append(val_loss)

        # Log the LR used during this epoch, before the scheduler may lower it.
        LRhist.append(optimizer.param_groups[0]['lr'])
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
        print(f"Validation Loss: {val_loss:.4f}")
        print("learning rate: ", optimizer.param_groups[0]['lr'])
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

        # Keep the weights with the lowest validation loss seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_path)

        # Adjust learning rate based on validation loss
        scheduler.step(val_loss)

    print("Training complete.")
finally:
    # Runs on normal completion AND on KeyboardInterrupt/errors, so completed
    # epochs are not lost. The exception (if any) is re-raised afterwards.
    # The file name records the number of *completed* epochs. After an
    # interrupt, the weights may also include updates from the unfinished epoch.
    n_done = len(Tloss)
    save_path = os.path.join(OUT_DIR, f"UnetModel_{n_done:04d}.pth")
    if n_done:
        # Save weight
        torch.save(model.state_dict(), save_path)
        # Save loss
        np.save(os.path.join(OUT_DIR, "Tloss.npy"), Tloss)
        np.save(os.path.join(OUT_DIR, "Vloss.npy"), Vloss)
        np.save(os.path.join(OUT_DIR, "LRhist.npy"), LRhist)
        print(f"Saved results for {n_done} epoch(s) to {OUT_DIR}")


KeyboardInterrupt: 