In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import save_image

In [None]:
# Dataset locations: the augmented training pairs and the validation pairs.
train_img_dir = "Training_dataset/aug_train/aug_imgs"
train_label_dir = "Training_dataset/aug_train/aug_gts"
val_img_dir = "Training_dataset/validation/imgs"
val_label_dir = "Training_dataset/validation/gts"


def _list_files(directory, extension):
    """Return the sorted paths of the entries in `directory` whose names end with `extension`."""
    matches = []
    for filename in os.listdir(directory):
        if filename.endswith(extension):
            matches.append(os.path.join(directory, filename))
    matches.sort()
    return matches


# Images are collected as .jpg files and ground truths as .png files. Each list is
# sorted independently; image i is later paired with label i by position, which
# presumably relies on matching file names in both folders — TODO confirm.
train_imgs = _list_files(train_img_dir, ".jpg")
train_labels = _list_files(train_label_dir, ".png")

val_imgs = _list_files(val_img_dir, ".jpg")
val_labels = _list_files(val_label_dir, ".png")

In [None]:
# Report how many files were found for the training and validation sets.
train_imgs_length, train_labels_length = len(train_imgs), len(train_labels)
val_imgs_length, val_labels_length = len(val_imgs), len(val_labels)

# One line per list, in the order: train images, train labels, val images, val labels.
for _caption, _count in (
    ("train_imgs 長度:", train_imgs_length),
    ("train_labels 長度:", train_labels_length),
    ("val_imgs 長度:", val_imgs_length),
    ("val_labels 長度:", val_labels_length),
):
    print(_caption, _count)

#### Dataset, DataLoader, and Transforms

In [None]:
class GeneralDataset(Dataset):
    """Paired image / ground-truth dataset used for the training and validation sets.

    Args:
        images: list of image file paths.
        labels: list of ground-truth file paths, index-aligned with `images`.
        img_size: target side length; images and ground truths are both resized
            to (img_size, img_size).
        mode: 'train' or 'val'. The value is validated and stored on the
            instance, but this class applies no mode-specific processing.

    Raises:
        AssertionError: if `mode` is invalid or the two path lists differ in length.
    """

    def __init__(self, images, labels, img_size, mode='train'):
        assert mode in ['train', 'val'], "mode must be 'train' or 'val'"
        self.img_size = img_size
        self.mode = mode
        self.images = images
        self.gts = labels
        # Drop pairs whose image and ground truth have different pixel dimensions.
        self.filter_files()

        # Image preprocessing: resize, convert to a tensor, normalise each channel
        # with mean 0.5 and std 0.5.
        self.img_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        # Ground-truth preprocessing: resize and convert to a tensor (no normalisation).
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the (image, ground_truth) tensor pair stored at `index`."""
        # convert() returns a new, fully loaded image, so the file handle opened by
        # Image.open can be released immediately instead of lingering until GC.
        with Image.open(self.images[index]) as img_file:
            image = img_file.convert('RGB')
        with Image.open(self.gts[index]) as gt_file:
            gt = gt_file.convert('L')  # single-channel greyscale

        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        return image, gt

    def filter_files(self):
        """Keep only the pairs whose image and ground truth have identical sizes."""
        # The two lists must describe the same number of samples.
        assert len(self.images) == len(self.gts), \
            "images and labels must contain the same number of paths"
        kept_images, kept_gts = [], []
        for img_path, gt_path in zip(self.images, self.gts):
            # Only the size is needed; close both files right after reading it so
            # scanning a large dataset does not exhaust file descriptors.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                same_size = img.size == gt.size
            if same_size:
                kept_images.append(img_path)
                kept_gts.append(gt_path)
        self.images, self.gts = kept_images, kept_gts

    def __len__(self):
        """Number of image / ground-truth pairs that survived filtering."""
        return len(self.images)

In [None]:
class TestlDataset(Dataset):
    """Image-only dataset for the test set; yields each image with an output file name.

    Args:
        images: list of image file paths.
        img_size: target side length; every image is resized to (img_size, img_size).
    """

    def __init__(self, images, img_size):
        self.images = images
        self.img_size = img_size
        # Same preprocessing as the training images: resize, convert to a tensor,
        # normalise each channel with mean 0.5 and std 0.5.
        self.img_transform = transforms.Compose([
            transforms.Resize((self.img_size, self.img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])

    def __getitem__(self, index):
        """Return (image_tensor, name); `name` is the source base name with a .png extension."""
        path = self.images[index]
        # convert() returns a loaded copy, so the underlying file can be closed at once.
        with Image.open(path) as img_file:
            image = img_file.convert('RGB')
        image = self.img_transform(image)
        # Swap only the final extension. A plain str.replace('.jpg', '.png') would
        # also rewrite '.jpg' in the middle of a name and would leave '.JPG' or
        # '.jpeg' files without a .png name.
        stem, _ = os.path.splitext(os.path.basename(path))
        name = stem + '.png'
        return image, name

    def __len__(self):
        """Number of test images."""
        return len(self.images)

In [None]:
# Build the training and validation datasets; both resize samples to 256x256.
# - mode: labels the dataset as "train" or "val". GeneralDataset as defined in this
#   notebook validates and stores the value but performs no augmentation based on it.
tr_datastet = GeneralDataset(
    images=train_imgs,
    labels=train_labels,
    img_size=256,
    mode="train",
)
val_datastet = GeneralDataset(
    images=val_imgs,
    labels=val_labels,
    img_size=256,
    mode="val",
)

# Wrap both datasets in DataLoaders. They share every setting except shuffling:
# training batches are shuffled, validation batches keep the dataset order.
_loader_options = dict(batch_size=64, num_workers=2, pin_memory=False)
tr_loader = DataLoader(dataset=tr_datastet, shuffle=True, **_loader_options)
val_loader = DataLoader(dataset=val_datastet, shuffle=False, **_loader_options)

In [None]:
# Sanity check: pull one batch from the training loader and display the first
# image together with its ground truth.
data_iter = iter(tr_loader)
images, labels = next(data_iter)

image, label = images[0], labels[0]

# Move the channel axis last (axes 1, 2, 0), the layout plt.imshow expects.
image = np.transpose(image.numpy(), (1, 2, 0))
label = np.transpose(label.numpy(), (1, 2, 0))

# Undo Normalize(mean=0.5, std=0.5) exactly: x * std + mean. Min-max rescaling would
# stretch the contrast, shift colours, and divide by zero on a constant image.
# The clip only guards against floating-point overshoot outside [0, 1].
image = np.clip(image * 0.5 + 0.5, 0.0, 1.0)

plt.imshow(image)
plt.title("Training image (de-normalised)")
plt.show()
plt.imshow(label, cmap='gray')
plt.title("Ground truth")
plt.show()

#### Training stage

In [None]:
# Loss functions used by the GAN training loop.
# Mean absolute error between two tensors.
l1_loss = nn.L1Loss()
# Binary cross-entropy that applies the sigmoid internally, so it must be fed raw
# scores — presumably the discriminator returns un-activated logits; TODO confirm.
bce_loss_module = nn.BCEWithLogitsLoss()

In [None]:
from models.autoencoderPatchGAN import VGG16Generator, ConditionalDiscriminator

# Pick the compute device: the first CUDA GPU when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# autoencoder-PatchGAN: instantiate the generator / discriminator pair with their
# default constructor arguments and move both onto the selected device.
autoencoder_generator = VGG16Generator()
autoencoder_generator = autoencoder_generator.to(device)
autoencoder_discriminator = ConditionalDiscriminator()
autoencoder_discriminator = autoencoder_discriminator.to(device)

In [None]:
# Optimisers for the autoencoder-PatchGAN pair.
# Generator and discriminator each get their own Adam instance with the same
# hyper-parameters: learning rate 0.0002 and betas (0.5, 0.99).
_adam_settings = {"lr": 0.0002, "betas": (0.5, 0.99)}
optimizer_autoencoder_generator = optim.Adam(
    autoencoder_generator.parameters(), **_adam_settings
)
optimizer_autoencoder_discriminator = optim.Adam(
    autoencoder_discriminator.parameters(), **_adam_settings
)

##### training loop (autoencoder-PatchGAN)

In [None]:
autoencoder_models_path = 'models_autoencoder'  # directory where generator weights are saved
os.makedirs(autoencoder_models_path, exist_ok=True)
os.makedirs("autoencoder_validation_output", exist_ok=True)

epoch_num = 1000
generator_update_interval = 5  # the generator is trained only on epochs divisible by this
checkpoint_interval = 5        # weights and a validation preview are written on these epochs

for epoch in range(epoch_num):
    # NOTE(review): the discriminator is trained on every epoch but the generator
    # only on epochs 0, 5, 10, ... — confirm this update ratio is intended.
    update_generator = epoch % generator_update_interval == 0

    ## ------------- Training stage --------------
    autoencoder_generator.train()
    autoencoder_discriminator.train()

    loss_all_g = 0
    loss_all_d = 0
    epoch_step = 0
    for images, gts in tr_loader:
        images, gts = images.to(device), gts.to(device)

        # --- Discriminator update: (input, ground truth) pairs are labelled 1,
        # (input, generated) pairs are labelled 0.
        real_d = autoencoder_discriminator(images, gts)
        real_loss_d = bce_loss_module(real_d, torch.ones_like(real_d))
        # No generator gradients are needed for this step, so skip building the
        # autograd graph instead of building it and detaching afterwards.
        with torch.no_grad():
            fake_images = autoencoder_generator(images)
        fake_d = autoencoder_discriminator(images, fake_images)
        fake_loss_d = bce_loss_module(fake_d, torch.zeros_like(fake_d))
        loss_d = 0.5 * (real_loss_d + fake_loss_d)

        optimizer_autoencoder_discriminator.zero_grad()
        loss_d.backward()
        optimizer_autoencoder_discriminator.step()
        loss_all_d += loss_d.item()

        # --- Generator update (only on generator epochs): adversarial term that
        # pushes the discriminator towards 1, plus an L1 term weighted by 100.
        if update_generator:
            fake_images = autoencoder_generator(images)
            fake_d = autoencoder_discriminator(images, fake_images)
            fake_loss_g = bce_loss_module(fake_d, torch.ones_like(fake_d))
            l1 = l1_loss(fake_images, gts) * 100
            loss_g = fake_loss_g + l1

            optimizer_autoencoder_generator.zero_grad()
            loss_g.backward()
            optimizer_autoencoder_generator.step()
            loss_all_g += loss_g.item()

        epoch_step += 1

    if epoch_step == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError below.
        raise RuntimeError("tr_loader yielded no batches; cannot compute average losses")

    avg_train_loss_d = loss_all_d / epoch_step
    # On epochs without a generator update there is no generator loss to report;
    # use NaN rather than re-printing a stale value from an earlier epoch.
    avg_train_loss_g = loss_all_g / epoch_step if update_generator else float('nan')

    if epoch % checkpoint_interval == 0:
        # Save the generator weights.
        model_save_path = os.path.join(autoencoder_models_path, f'Net_epoch_{epoch}.pth')
        torch.save(autoencoder_generator.state_dict(), model_save_path)
        print(f'Model saved at epoch {epoch}')

        # Save a preview generated from the first validation batch. Switch to eval
        # mode first: in train mode BatchNorm would update its running statistics
        # from validation data (even under no_grad) and Dropout would stay active.
        # train() is re-enabled at the top of the next epoch.
        autoencoder_generator.eval()
        with torch.no_grad():
            for val_img, val_label in val_loader:
                val_img = val_img.to(device)
                gen_val_label = autoencoder_generator(val_img)
                save_image(gen_val_label, f"autoencoder_validation_output/generated_{epoch}.png")
                break

    print(f"Epoch [{epoch+1}/{epoch_num}], Generator Train Loss: {avg_train_loss_g:.4f}, Discriminator Train Loss: {avg_train_loss_d:.4f}")