<a href="https://colab.research.google.com/github/hicbcb/AI_course/blob/main/week10(%E6%95%99%E5%AD%B8%E7%94%A8)_%E9%BB%91%E7%99%BD%E7%85%A7%E8%BD%89%E5%BD%A9%E8%89%B2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import os
import numpy as np

from PIL import Image

In [None]:
# Delete the output folders (they are re-created by the next cells), so images
# saved by an earlier run do not mix with the new ones.
# These are IPython shell escapes; `rm -rf` assumes a POSIX shell (e.g. Colab).
!rm -rf test_results/
!rm -rf colorized_images/
!rm -rf images/

In [None]:
# Create the output directories:
#   colorized_images/ - images saved by the training cells
#   test_results/     - images saved by the test cells
# exist_ok=True makes this cell safe to re-run.
for output_dir in ("colorized_images", "test_results"):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Create images/ (safe to re-run).
# NOTE(review): no visible cell writes to images/ — confirm it is still needed.
images_dir = "images"
os.makedirs(images_dir, exist_ok=True)

In [None]:
# Hyperparameters and runtime configuration.
# Seed torch's global RNG so weight initialisation and DataLoader shuffling are
# reproducible on Restart & Run All (GPU kernels may still add small nondeterminism).
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # use the GPU when available
batch_size = 8         # images per training batch
learning_rate = 2e-4   # Adam learning rate (used for both generator and discriminator)
num_epochs = 20        # passes over the training set per training cell

In [None]:
# Generator: convolutional encoder-decoder that turns a 1-channel image into a 3-channel color image.
class Generator(nn.Module):
    """Encoder-decoder colorization network.

    Input:  tensor of shape (N, 1, H, W) — the single-channel input image.
    Output: tensor of shape (N, 3, H, W) with values in [-1, 1] (final Tanh),
            the same range as images normalized with mean 0.5 / std 0.5.
    H and W should be multiples of 4 so the output size equals the input size
    (two stride-2 downsamplings followed by two stride-2 upsamplings).
    """

    def __init__(self):
        super().__init__()
        # Encoder: each stride-2 conv halves H and W (channels 1 -> 64 -> 128).
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(negative_slope=0.2),
            # More layers can be added here.
        )
        # Decoder: each stride-2 transposed conv doubles H and W (channels 128 -> 64 -> 3).
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=64),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # squash outputs into [-1, 1]
        )

    def forward(self, x):
        """Colorize ``x``: (N, 1, H, W) -> (N, 3, H, W)."""
        latent = self.encoder(x)
        return self.decoder(latent)



In [None]:
# Discriminator: scores (condition image, color image) pairs as real or generated.
class Discriminator(nn.Module):
    """Conditional discriminator that outputs a spatial map of scores (PatchGAN-style).

    ``img_A`` and ``img_B`` are concatenated along the channel axis and must
    together have 4 channels (the training cells pass the 1-channel input image
    and a 3-channel color image). Output: (N, 1, h, w) values in (0, 1) from the
    final Sigmoid; for 256x256 inputs h = w = 63.

    NOTE(review): Sigmoid + nn.BCELoss is less numerically stable than returning
    logits and using nn.BCEWithLogitsLoss; changing it would also require
    changing the loss definition.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(num_features=128),
            nn.LeakyReLU(negative_slope=0.2),
            # More layers can be added here.
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=4, padding=1),  # stride 1: one score per patch
            nn.Sigmoid(),
        )

    def forward(self, img_A, img_B):
        """Stack ``img_A`` and ``img_B`` channel-wise and return the score map."""
        paired = torch.cat([img_A, img_B], dim=1)
        return self.model(paired)

In [None]:
# Load the Oxford-IIIT Pet dataset (downloaded into ./data if missing).
# Each image is resized to 256x256, converted to a tensor in [0, 1] and then
# normalized to [-1, 1] (the single mean/std value 0.5 is broadcast over all
# channels) — the same range as the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

data_root = './data'

# Training split: shuffled mini-batches of `batch_size` images.
train_dataset = OxfordIIITPet(root=data_root, split='trainval', download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Test split: one image at a time, in a fixed order.
test_dataset = OxfordIIITPet(root=data_root, split='test', download=True, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)


In [None]:
# Build both networks on the selected device. Tuple elements are evaluated left
# to right, so the generator is still constructed (and initialised) first.
generator, discriminator = (
    Generator().to(device),
    Discriminator().to(device),
)

In [None]:
# Optimizers and loss functions.
# Adam with beta1 = 0.5 instead of PyTorch's default 0.9: this notebook already
# follows the pix2pix recipe (lr 2e-4, L1 weight 100), which pairs those
# settings with betas=(0.5, 0.999) to stabilise adversarial training.
ADAM_BETAS = (0.5, 0.999)
g_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=ADAM_BETAS)
d_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=ADAM_BETAS)
criterion_GAN = nn.BCELoss()        # adversarial loss on the discriminator's Sigmoid outputs
criterion_pixelwise = nn.L1Loss()   # pixel-wise loss between generated and real color images

### 只生成結果

In [None]:
# Train the model; after every epoch, save the last batch of generated images.
# Switch both networks back to training mode: the test cells call
# generator.eval(), and training in eval mode would freeze BatchNorm statistics.
generator.train()
discriminator.train()

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_dataloader):
        # NOTE(review): the "grayscale" input is the first (red) channel of the
        # normalized RGB image, not a luminance conversion.
        gray_imgs = imgs[:, :1, :, :].to(device)  # single-channel input
        color_imgs = imgs.to(device)              # 3-channel color target

        # Generator step: fool the discriminator + stay close to the target (L1, weight 100).
        fake_color = generator(gray_imgs)
        g_optimizer.zero_grad()

        pred_fake = discriminator(gray_imgs, fake_color)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_pixel = criterion_pixelwise(fake_color, color_imgs)

        g_loss = loss_GAN + 100 * loss_pixel
        g_loss.backward()
        g_optimizer.step()

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        # zero_grad also clears the discriminator gradients produced by g_loss.backward().
        d_optimizer.zero_grad()

        pred_real = discriminator(gray_imgs, color_imgs)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        # detach() keeps the discriminator loss from backpropagating into the generator.
        pred_fake = discriminator(gray_imgs, fake_color.detach())
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        d_loss = (loss_real + loss_fake) / 2
        d_loss.backward()
        d_optimizer.step()

    # Losses of the epoch's last batch (not an epoch average).
    print(f"Epoch [{epoch+1}/{num_epochs}] Generator Loss: {g_loss.item():.4f}, Discriminator Loss: {d_loss.item():.4f}")

    # Save the last generated batch. Map the fixed Tanh range [-1, 1] to [0, 1]
    # instead of min-max stretching each grid, so epochs are visually comparable.
    save_image(fake_color.detach(), f"colorized_images/epoch_{epoch+1}.png",
               normalize=True, value_range=(-1, 1))

Epoch [1/100] Generator Loss: 18.0594, Discriminator Loss: 0.2018
Epoch [2/100] Generator Loss: 19.0097, Discriminator Loss: 0.2038
Epoch [3/100] Generator Loss: 17.7795, Discriminator Loss: 0.1470
Epoch [4/100] Generator Loss: 18.6045, Discriminator Loss: 0.2541
Epoch [5/100] Generator Loss: 17.0424, Discriminator Loss: 0.2798
Epoch [6/100] Generator Loss: 17.1462, Discriminator Loss: 0.2292
Epoch [7/100] Generator Loss: 15.4228, Discriminator Loss: 0.1239
Epoch [8/100] Generator Loss: 17.7591, Discriminator Loss: 0.1299
Epoch [9/100] Generator Loss: 29.9288, Discriminator Loss: 0.1334
Epoch [10/100] Generator Loss: 19.4556, Discriminator Loss: 0.0547
Epoch [11/100] Generator Loss: 23.6850, Discriminator Loss: 0.0474
Epoch [12/100] Generator Loss: 26.4254, Discriminator Loss: 0.0477
Epoch [13/100] Generator Loss: 19.3647, Discriminator Loss: 0.0576
Epoch [14/100] Generator Loss: 21.8561, Discriminator Loss: 0.0243
Epoch [15/100] Generator Loss: 23.4359, Discriminator Loss: 0.0512
Epoc

KeyboardInterrupt: 

In [None]:
# Test the model: colorize every test image and save each result to test_results/.
# The previous mode is restored afterwards (even if interrupted), so training
# cells run later do not train with BatchNorm stuck in eval mode.
was_training = generator.training
generator.eval()  # BatchNorm uses its running statistics during inference
try:
    with torch.no_grad():
        for i, (img, _) in enumerate(test_dataloader):
            # Same single-channel input as in training (first channel of the image).
            gray_img = img[:, :1, :, :].to(device)
            fake_color = generator(gray_img)
            # Map the fixed Tanh range [-1, 1] to [0, 1] rather than min-max stretching each image.
            save_image(fake_color, f"test_results/test_{i+1}.png", normalize=True, value_range=(-1, 1))
finally:
    generator.train(was_training)

### 生成圖片與真實圖片對比版

In [None]:
# Train the model (version 1): every 100 batches, display a comparison of
# input / generated / real images and save it with plt.imsave.
# Switch both networks back to training mode (the test cells call generator.eval()).
generator.train()
discriminator.train()

vis_interval = 100  # batches between visualisations

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_dataloader):
        # NOTE(review): the "grayscale" input is the first (red) channel of the
        # normalized RGB image, not a luminance conversion.
        gray_imgs = imgs[:, :1, :, :].to(device)  # single-channel input
        color_imgs = imgs.to(device)              # 3-channel color target

        # Generator step: adversarial loss + 100 * L1 loss.
        fake_color = generator(gray_imgs)
        g_optimizer.zero_grad()

        pred_fake = discriminator(gray_imgs, fake_color)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_pixel = criterion_pixelwise(fake_color, color_imgs)

        g_loss = loss_GAN + 100 * loss_pixel
        g_loss.backward()
        g_optimizer.step()

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        d_optimizer.zero_grad()

        pred_real = discriminator(gray_imgs, color_imgs)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = discriminator(gray_imgs, fake_color.detach())
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        d_loss = (loss_real + loss_fake) / 2
        d_loss.backward()
        d_optimizer.step()

        # Every `vis_interval` batches, display and save generated vs. real images.
        if (i+1) % vis_interval == 0:
            # [input | generated | real] side by side along the width axis.
            # Undo Normalize(mean=0.5, std=0.5) with a fixed [-1, 1] -> [0, 1] map;
            # per-image min-max scaling would make the panels incomparable and
            # divides by zero on a constant image.
            with torch.no_grad():
                gray_imgs_repeated = gray_imgs.repeat(1, 3, 1, 1)
                comparison = torch.cat([gray_imgs_repeated, fake_color, color_imgs], dim=3)
                comparison = (comparison * 0.5 + 0.5).clamp(0, 1).cpu()

            # Whole batch as one grid, 4 comparisons per row.
            comparison_grid = make_grid(comparison, nrow=4)
            display(transforms.ToPILImage()(comparison_grid))

            # First sample of the batch as a titled matplotlib figure.
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            panels = comparison[0].split(gray_imgs.shape[-1], dim=2)  # three (3, H, W) panels
            for axis, panel, title in zip(axes, panels, ("Input Grayscale", "Generated", "Real")):
                axis.imshow(panel.permute(1, 2, 0).numpy())
                axis.set_title(title)
                axis.axis('off')

            fig.suptitle(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{len(train_dataloader)}]")
            fig.tight_layout()

            # Render via the public Agg RGBA buffer (uint8, already in [0, 255])
            # instead of the private renderer attribute, then save it.
            fig.canvas.draw()
            img_array = np.asarray(fig.canvas.buffer_rgba())
            plt.imsave(f"colorized_images/comparison_epoch_{epoch+1}_batch_{i+1}.png", img_array)

            plt.show()
            plt.close(fig)

    # Losses of the epoch's last batch (not an epoch average).
    print(f"Epoch [{epoch+1}/{num_epochs}] Generator Loss: {g_loss.item():.4f}, Discriminator Loss: {d_loss.item():.4f}")



In [None]:
# Train the model (version 2): every 300 batches, display a comparison of
# input / generated / real images and save it with PIL.
# Switch both networks back to training mode (the test cells call generator.eval()).
generator.train()
discriminator.train()

vis_interval = 300  # batches between visualisations

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_dataloader):
        # NOTE(review): the "grayscale" input is the first (red) channel of the
        # normalized RGB image, not a luminance conversion.
        gray_imgs = imgs[:, :1, :, :].to(device)  # single-channel input
        color_imgs = imgs.to(device)              # 3-channel color target

        # Generator step: adversarial loss + 100 * L1 loss.
        fake_color = generator(gray_imgs)
        g_optimizer.zero_grad()

        pred_fake = discriminator(gray_imgs, fake_color)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_pixel = criterion_pixelwise(fake_color, color_imgs)

        g_loss = loss_GAN + 100 * loss_pixel
        g_loss.backward()
        g_optimizer.step()

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        d_optimizer.zero_grad()

        pred_real = discriminator(gray_imgs, color_imgs)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = discriminator(gray_imgs, fake_color.detach())
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        d_loss = (loss_real + loss_fake) / 2
        d_loss.backward()
        d_optimizer.step()

        # Every `vis_interval` batches, display and save generated vs. real images.
        if (i+1) % vis_interval == 0:
            # [input | generated | real] side by side along the width axis, mapped
            # from the normalized range [-1, 1] to [0, 1] with a fixed transform.
            with torch.no_grad():
                gray_imgs_repeated = gray_imgs.repeat(1, 3, 1, 1)
                comparison = torch.cat([gray_imgs_repeated, fake_color, color_imgs], dim=3)
                comparison = (comparison * 0.5 + 0.5).clamp(0, 1).cpu()

            # Whole batch as one grid, 4 comparisons per row.
            comparison_grid = make_grid(comparison, nrow=4)
            display(transforms.ToPILImage()(comparison_grid))

            # First sample of the batch as a titled matplotlib figure.
            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            panels = comparison[0].split(gray_imgs.shape[-1], dim=2)  # three (3, H, W) panels
            for axis, panel, title in zip(axes, panels, ("Input Grayscale", "Generated", "Real")):
                axis.imshow(panel.permute(1, 2, 0).numpy())
                axis.set_title(title)
                axis.axis('off')

            fig.suptitle(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{len(train_dataloader)}]")
            fig.tight_layout()

            # Render the figure and save it with PIL. The Agg RGBA buffer is already
            # uint8 in [0, 255]; do not multiply it by 255 again (uint8 arithmetic
            # wraps around and corrupts the image).
            fig.canvas.draw()
            img_array = np.asarray(fig.canvas.buffer_rgba())
            pil_img = Image.fromarray(img_array)
            pil_img.save(f"colorized_images/comparison_epoch_{epoch+1}_batch_{i+1}.png")

            plt.show()
            plt.close(fig)

    # Losses of the epoch's last batch (not an epoch average).
    print(f"Epoch [{epoch+1}/{num_epochs}] Generator Loss: {g_loss.item():.4f}, Discriminator Loss: {d_loss.item():.4f}")



In [None]:
# Test the model: colorize every test image and save each result to test_results/.
# The previous mode is restored afterwards (even if interrupted), so training
# cells run later do not train with BatchNorm stuck in eval mode.
was_training = generator.training
generator.eval()  # BatchNorm uses its running statistics during inference
try:
    with torch.no_grad():
        for i, (img, _) in enumerate(test_dataloader):
            # Same single-channel input as in training (first channel of the image).
            gray_img = img[:, :1, :, :].to(device)
            fake_color = generator(gray_img)
            # Map the fixed Tanh range [-1, 1] to [0, 1] rather than min-max stretching each image.
            save_image(fake_color, f"test_results/test_{i+1}.png", normalize=True, value_range=(-1, 1))
finally:
    generator.train(was_training)

### 黑白對比版

In [None]:
# Train the model; after every epoch, save [input | generated | real] comparisons
# of the last batch.
# Switch both networks back to training mode (the test cells call generator.eval()).
generator.train()
discriminator.train()

for epoch in range(num_epochs):
    for i, (imgs, _) in enumerate(train_dataloader):
        # NOTE(review): the "grayscale" input is the first (red) channel of the
        # normalized RGB image, not a luminance conversion.
        gray_imgs = imgs[:, :1, :, :].to(device)  # single-channel input
        color_imgs = imgs.to(device)              # 3-channel color target

        # Generator step: adversarial loss + 100 * L1 loss.
        fake_color = generator(gray_imgs)
        g_optimizer.zero_grad()

        pred_fake = discriminator(gray_imgs, fake_color)
        loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake))
        loss_pixel = criterion_pixelwise(fake_color, color_imgs)

        g_loss = loss_GAN + 100 * loss_pixel
        g_loss.backward()
        g_optimizer.step()

        # Discriminator step: real pairs -> 1, generated pairs -> 0.
        d_optimizer.zero_grad()

        pred_real = discriminator(gray_imgs, color_imgs)
        loss_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

        pred_fake = discriminator(gray_imgs, fake_color.detach())
        loss_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

        d_loss = (loss_real + loss_fake) / 2
        d_loss.backward()
        d_optimizer.step()

    # Losses of the epoch's last batch (not an epoch average).
    print(f"Epoch [{epoch+1}/{num_epochs}] Generator Loss: {g_loss.item():.4f}, Discriminator Loss: {d_loss.item():.4f}")

    # Save the last batch as input / colorized / real comparisons.
    # The 1-channel input is repeated to 3 channels: torch.cat requires every
    # dimension except the concatenation axis (width) to match.
    with torch.no_grad():
        comparison = torch.cat([gray_imgs.repeat(1, 3, 1, 1), fake_color, color_imgs], dim=3)
    save_image(comparison, f"colorized_images/comparison_epoch_{epoch+1}.png", nrow=4,
               normalize=True, value_range=(-1, 1))

In [None]:
# Test the model: save [input | generated | real] side by side for every test image.
# The previous mode is restored afterwards (even if interrupted), so training
# cells run later do not train with BatchNorm stuck in eval mode.
was_training = generator.training
generator.eval()  # BatchNorm uses its running statistics during inference
try:
    with torch.no_grad():
        for i, (img, _) in enumerate(test_dataloader):
            # Move the real image too: all tensors passed to torch.cat must share a device.
            img = img.to(device)
            gray_img = img[:, :1, :, :]  # same single-channel input as in training
            fake_color = generator(gray_img)
            # Repeat the 1-channel input to 3 channels so it can be concatenated
            # with the RGB images along the width axis.
            comparison = torch.cat([gray_img.repeat(1, 3, 1, 1), fake_color, img], dim=3)
            save_image(comparison, f"test_results/comparison_test_{i+1}.png",
                       normalize=True, value_range=(-1, 1))
finally:
    generator.train(was_training)