In [2]:
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

os.chdir(r"D:\CK0731\v3")
import UnetModel
import UnetDataset

In [3]:
# Path
# Root folder holding every split; each split directory below is derived from it,
# so moving the data only requires changing this one line.
data_root = r'D:\CK0731'

# Per split: two input directories and one target directory for IsoDataset.
input_dir1_train = rf'{data_root}\ATN_v3_train'
input_dir2_train = rf'{data_root}\BDM_v3_train'
output_dir_train = rf'{data_root}\PGM_v3_train'

input_dir1_val = rf'{data_root}\ATN_v3_val'
input_dir2_val = rf'{data_root}\BDM_v3_val'
output_dir_val = rf'{data_root}\PGM_v3_val'

# Reproducibility: seed the RNGs behind DataLoader shuffling and any random ops.
# (Bit-exact GPU runs would additionally need deterministic cuDNN settings.)
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

num_epochs = 20
learning_rate = 1e-5

cuda


In [4]:
# Shared batch size so the train and validation loaders cannot drift apart.
batch_size = 16

train_dataset = UnetDataset.IsoDataset(input_dir1_train, input_dir2_train, output_dir_train)
val_dataset = UnetDataset.IsoDataset(input_dir1_val, input_dir2_val, output_dir_val)

# Only the training set is shuffled. Validation loss is averaged over the full
# pass, so sample order is irrelevant and a fixed order keeps evaluation deterministic.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [5]:
# Define the model, loss function, optimizer, and scheduler.
# Training resumes from a saved checkpoint (presumably after 20 epochs, per the file
# name). Only the model weights are restored: the Adam moment estimates and the
# scheduler's plateau counter start fresh in this session.
model = UnetModel.UNetIso()
# map_location lets a checkpoint saved on GPU load on a CPU-only kernel (and vice versa),
# matching the dynamically chosen `device`.
state_dict = torch.load(r"D:\CK0731\v3\UnetModel_0020.pth", map_location=device)
model.load_state_dict(state_dict)
model.to(device)
criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Use scheduler to reduce learning rate when loss is not decreasing: with patience=5,
# the LR is multiplied by 0.1 on the 6th consecutive epoch without improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

In [6]:
# Training loop: fine-tune the loaded weights for `num_epochs` more epochs, logging
# per-epoch train/val loss and learning rate, then save weights and loss histories.

# Epochs already contained in the checkpoint loaded before this cell
# (UnetModel_0020.pth); update together with that path when resuming from another file.
start_epoch = 20
end_epoch = start_epoch + num_epochs  # cumulative epoch count used in output file names
out_dir = r"D:\CK0731\v3"

Tloss = []
Vloss = []
LRhist = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the epoch value is a per-sample
        # average even when the last batch is smaller (criterion returns a batch mean).
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    Tloss.append(epoch_loss)

    # Validation step
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for val_inputs, val_targets in val_loader:
            val_inputs, val_targets = val_inputs.to(device), val_targets.to(device)
            val_outputs = model(val_inputs)
            val_loss += criterion(val_outputs, val_targets).item() * val_inputs.size(0)

    val_loss /= len(val_loader.dataset)
    Vloss.append(val_loss)

    # LR actually used during this epoch (read before the scheduler may lower it).
    current_lr = optimizer.param_groups[0]['lr']
    LRhist.append(current_lr)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
    print(f"Validation Loss: {val_loss:.4f}")
    print("learning rate: ", current_lr)
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

    # Adjust learning rate based on validation loss
    scheduler.step(val_loss)

print("Training complete.")

# Save weights; with the current settings (20 + 20 epochs) this is UnetModel_0040.pth.
save_path = os.path.join(out_dir, f"UnetModel_{end_epoch:04d}.pth")
torch.save(model.state_dict(), save_path)

# Save per-epoch histories (Tloss40.npy, Vloss40.npy, LRhist40.npy with current settings).
np.save(os.path.join(out_dir, f"Tloss{end_epoch}.npy"), Tloss)
np.save(os.path.join(out_dir, f"Vloss{end_epoch}.npy"), Vloss)
np.save(os.path.join(out_dir, f"LRhist{end_epoch}.npy"), LRhist)


Epoch 1/20, Loss: 15.5137
Validation Loss: 15.6045
learning rate:  1e-05
2024-07-31 20:30:30
Epoch 2/20, Loss: 15.3946
Validation Loss: 15.5825
learning rate:  1e-05
2024-07-31 20:35:45
Epoch 3/20, Loss: 15.3434
Validation Loss: 15.5776
learning rate:  1e-05
2024-07-31 20:41:06
Epoch 4/20, Loss: 15.3093
Validation Loss: 15.4740
learning rate:  1e-05
2024-07-31 20:46:33
Epoch 5/20, Loss: 15.2800
Validation Loss: 15.4289
learning rate:  1e-05
2024-07-31 20:51:59
Epoch 6/20, Loss: 15.2588
Validation Loss: 15.4398
learning rate:  1e-05
2024-07-31 20:57:22
Epoch 7/20, Loss: 15.2395
Validation Loss: 15.3859
learning rate:  1e-05
2024-07-31 21:02:49
Epoch 8/20, Loss: 15.2133
Validation Loss: 15.3685
learning rate:  1e-05
2024-07-31 21:08:15
Epoch 9/20, Loss: 15.1968
Validation Loss: 15.4119
learning rate:  1e-05
2024-07-31 21:13:37
Epoch 10/20, Loss: 15.1785
Validation Loss: 15.3689
learning rate:  1e-05
2024-07-31 21:18:56
Epoch 11/20, Loss: 15.1611
Validation Loss: 15.3203
learning rate:  1