In [14]:
import torch
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import numpy as np
import os
import cv2
from datasets import load_dataset
import torch.nn as nn  

# ===== Hyperparameters (dataset split ratios) =====
# Share of the full dataset held out for validation + test (20%).
# (Spelling of these names is kept as-is: the dataset split below uses them.)
super_valid_and_tes_vs_train_ratio: float = 0.2
# Share of that held-out part that becomes the test split (valid:test = 1:1).
super_test_vs_alid_ratio: float = 0.5

# When True, the generator input is converted RGB -> YUV and its output YUV -> RGB.
use_yuv: bool = True

# Generator definition (UNet, "type8" variant).
class UNetGenerator_type8(nn.Module):
    """UNet generator with skip connections and a global input residual.

    Six stride-2 4x4 convolutions downsample the input by 64x. Six stride-2
    transposed convolutions upsample it back. Every decoder stage after the
    first consumes the channel-wise concatenation of the previous decoder
    output and the matching encoder feature map. A final 3x3 convolution maps
    to ``out_channels``. The input ``x`` is added before ``tanh``, so the
    output lies in [-1, 1].

    Because every conv halves (and every transposed conv doubles) H and W,
    H and W must be multiples of 64 for the skip concatenations to line up.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; the residual ``+ x`` in
            ``forward`` requires this to equal ``in_channels``.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        def encoder_stage(c_in, c_out, *, norm=True):
            # Conv(k=4, s=2, p=1) [+ BatchNorm] + LeakyReLU(0.2); halves H and W.
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*layers)

        def decoder_stage(c_in, c_out, *, dropout=False):
            # ConvTranspose(k=4, s=2, p=1) + BatchNorm [+ Dropout(0.5)] + ReLU; doubles H and W.
            layers = [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out)]
            if dropout:
                layers.append(nn.Dropout(0.5))
            layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        # Encoder (the first stage has no BatchNorm).
        self.down1 = encoder_stage(in_channels, 64, norm=False)
        self.down2 = encoder_stage(64, 128)
        self.down3 = encoder_stage(128, 256)
        self.down4 = encoder_stage(256, 512)
        self.down5 = encoder_stage(512, 512)
        self.down6 = encoder_stage(512, 512)

        # Decoder; input widths include the concatenated skip channels.
        self.up1 = decoder_stage(512, 512, dropout=True)
        self.up2 = decoder_stage(1024, 512, dropout=True)
        self.up3 = decoder_stage(1024, 256, dropout=True)
        self.up4 = decoder_stage(512, 128)
        self.up5 = decoder_stage(256, 64)
        self.up6 = decoder_stage(128, 32)
        self.up7 = nn.Sequential(nn.Conv2d(32, out_channels, 3, 1, 1))
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Run the first five encoder stages, remembering each output as a skip.
        skips = []
        h = x
        for encode in (self.down1, self.down2, self.down3, self.down4, self.down5):
            h = encode(h)
            skips.append(h)

        h = self.up1(self.down6(h))

        # Deepest skip first: up2<-d5, up3<-d4, up4<-d3, up5<-d2, up6<-d1.
        decoders = (self.up2, self.up3, self.up4, self.up5, self.up6)
        for decode, skip in zip(decoders, reversed(skips)):
            h = decode(torch.cat([h, skip], dim=1))

        return self.tanh(self.up7(h) + x)

# 👇 Self-attention module
class Self_Attn(nn.Module):
    """Self-attention over all spatial positions of a feature map.

    Queries and keys are projected to ``in_dim // 8`` channels and values
    keep ``in_dim`` channels, all with 1x1 convolutions. The attended
    features are scaled by a learnable ``gamma``, initialised to 0, and added
    to the input. At initialisation the module is therefore the identity.

    Args:
        in_dim: channel count of the input feature map.
        activation: module stored as ``self.activation``; it is not applied
            in ``forward``.

    ``forward`` returns ``(out, attention)``. ``out`` has the input's shape.
    ``attention`` has shape (B, N, N) with N = H * W, and each row is
    softmax-normalised.
    """

    def __init__(self, in_dim, activation):
        super().__init__()
        reduced_dim = in_dim // 8
        self.query_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_dim, reduced_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)
        self.activation = activation  # kept for interface parity; unused in forward

    def forward(self, x):
        batch, channels, height, width = x.size()
        positions = height * width

        queries = self.query_conv(x).view(batch, -1, positions).permute(0, 2, 1)  # B x N x C'
        keys = self.key_conv(x).view(batch, -1, positions)                         # B x C' x N
        attention = self.softmax(torch.bmm(queries, keys))                         # B x N x N

        values = self.value_conv(x).view(batch, -1, positions)                     # B x C x N
        attended = torch.bmm(values, attention.permute(0, 2, 1))                   # B x C x N
        attended = attended.view(batch, channels, height, width)

        return self.gamma * attended + x, attention

# 👇 type9 Generator
class UNetGenerator_type9(nn.Module):
    """UNet generator (type8 layout) with two self-attention layers.

    The encoder/decoder matches the type8 variant: six stride-2 4x4
    convolutions down, six stride-2 transposed convolutions up, skip
    concatenations, a final 3x3 convolution, and ``tanh(out + x)``.
    Additionally, ``Self_Attn`` is applied to the output of ``down3``
    (256 channels) and of ``down6`` (512 channels). The attended ``down3``
    map is also the skip tensor fed to ``up4``.

    H and W must be multiples of 64 so encoder and decoder sizes line up.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; must equal ``in_channels``
            for the residual ``+ x`` in ``forward``.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        def encoder_stage(c_in, c_out, *, norm=True):
            # Conv(k=4, s=2, p=1) [+ BatchNorm] + LeakyReLU(0.2); halves H and W.
            layers = [nn.Conv2d(c_in, c_out, 4, 2, 1)]
            if norm:
                layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2))
            return nn.Sequential(*layers)

        def decoder_stage(c_in, c_out, *, dropout=False):
            # ConvTranspose(k=4, s=2, p=1) + BatchNorm [+ Dropout(0.5)] + ReLU; doubles H and W.
            layers = [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out)]
            if dropout:
                layers.append(nn.Dropout(0.5))
            layers.append(nn.ReLU())
            return nn.Sequential(*layers)

        # Encoder with attention after the third and the last stage.
        self.down1 = encoder_stage(in_channels, 64, norm=False)
        self.down2 = encoder_stage(64, 128)
        self.down3 = encoder_stage(128, 256)
        self.attention1 = Self_Attn(256, activation=nn.LeakyReLU(0.2))
        self.down4 = encoder_stage(256, 512)
        self.down5 = encoder_stage(512, 512)
        self.down6 = encoder_stage(512, 512)
        self.attention2 = Self_Attn(512, activation=nn.LeakyReLU(0.2))

        # Decoder; input widths include the concatenated skip channels.
        self.up1 = decoder_stage(512, 512, dropout=True)
        self.up2 = decoder_stage(1024, 512, dropout=True)
        self.up3 = decoder_stage(1024, 256, dropout=True)
        self.up4 = decoder_stage(512, 128)
        self.up5 = decoder_stage(256, 64)
        self.up6 = decoder_stage(128, 32)
        self.up7 = nn.Sequential(nn.Conv2d(32, out_channels, 3, 1, 1))
        self.tanh = nn.Tanh()

    def forward(self, x):
        e1 = self.down1(x)
        e2 = self.down2(e1)
        e3, _ = self.attention1(self.down3(e2))  # attention map discarded
        e4 = self.down4(e3)
        e5 = self.down5(e4)
        bottleneck, _ = self.attention2(self.down6(e5))

        h = self.up1(bottleneck)
        for decode, skip in (
            (self.up2, e5),
            (self.up3, e4),
            (self.up4, e3),
            (self.up5, e2),
            (self.up6, e1),
        ):
            h = decode(torch.cat([h, skip], dim=1))

        return self.tanh(self.up7(h) + x)
# Initialize the model and load the trained weights.
model = 'type93_netG_epoch_0020'  # checkpoint name; also used for the output folder name
#model = 'type93_netG_epoch_0050'
netG = UNetGenerator_type9().cuda()
# map_location='cuda' remaps stored tensors onto the available GPU, so loading
# does not fail when the checkpoint was saved from another device index
# (e.g. cuda:1). NOTE: torch.load unpickles the file; only load trusted checkpoints.
netG.load_state_dict(torch.load(f'./output_pix2pix/{model}.pth', map_location='cuda'))
netG.eval()  # use BatchNorm running stats and disable Dropout for inference

# YUV -> RGB conversion for generator outputs.
def yuv_tensor_to_rgb_tensor(yuv_tensor):
    """Convert a normalised YUV image tensor to an RGB tensor in [0, 1].

    Args:
        yuv_tensor: channels-first tensor (dim 0 is the channel axis, as the
            ``permute(1, 2, 0)`` below assumes) with 3 channels of YUV data,
            expected in [-1, 1], i.e. after ``Normalize((0.5,), (0.5,))``.
            Values outside that range are clipped.

    Returns:
        float32 CPU tensor laid out (3, H, W) holding RGB values in [0, 1].
    """
    # detach() so the conversion also works on tensors that require grad.
    yuv = yuv_tensor.detach().permute(1, 2, 0).cpu().numpy()
    # Undo the [-1, 1] normalisation back to 0..255. Clip first, because a
    # float->uint8 cast of out-of-range values is undefined/wraps around.
    # Round instead of truncating, which would bias every pixel down by 0.5 LSB.
    yuv = ((yuv * 0.5 + 0.5) * 255.0).clip(0.0, 255.0).round().astype('uint8')
    # One call; equivalent to COLOR_YUV2BGR followed by COLOR_BGR2RGB.
    rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB)
    rgb = rgb.astype('float32') / 255.0
    return torch.from_numpy(rgb).permute(2, 0, 1).clamp(0, 1)

# Preprocessing: resize to 256x256, convert to a [0, 1] tensor, then map to
# [-1, 1]. The single mean/std value applies to every channel.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Output folder named after the checkpoint being evaluated (reused if it exists).
output_dir = f'./{model}_gen'
os.makedirs(output_dir, exist_ok=True)

# ========================
# 📦 Load and split the dataset
# ========================
dataset = load_dataset(
    "MichaelP84/manga-colorization-dataset",
    split="train",
    cache_dir="/home/dcmc/Data/xinjia/GAI",
)
# Stage 1: hold out super_valid_and_tes_vs_train_ratio of the rows for valid+test.
train_validtest = dataset.train_test_split(
    test_size=super_valid_and_tes_vs_train_ratio,
    seed=42,
)
# Stage 2: split the held-out rows into valid / test. The fixed seed keeps the
# test split identical across runs.
valid_test = train_validtest["test"].train_test_split(
    test_size=super_test_vs_alid_ratio,
    seed=42,
)
test_dataset = valid_test["test"]

# ========================
# 🖌 Colorize the test dataset
# ========================
for i, sample in enumerate(test_dataset):
    # Check the limit before doing any work, so exactly 100 images
    # (indices 0-99) are generated rather than 101.
    if i >= 100:
        break

    # Take the 'bw_image' field and force it to 3-channel RGB for the generator.
    bw_img = Image.fromarray(np.array(sample['bw_image'])).convert("RGB")
    if use_yuv:
        # Convert to YUV and wrap back into PIL so the same transform applies.
        bw_yuv = cv2.cvtColor(np.array(bw_img), cv2.COLOR_RGB2YUV)
        bw_img = Image.fromarray(bw_yuv)

    input_tensor = transform(bw_img).unsqueeze(0).cuda()  # add batch dimension

    with torch.no_grad():
        fake_color = netG(input_tensor)

    if use_yuv:
        output_tensor = yuv_tensor_to_rgb_tensor(fake_color[0])
    else:
        # Undo Normalize((0.5,), (0.5,)): [-1, 1] -> [0, 1].
        output_tensor = (fake_color[0] * 0.5 + 0.5).clamp(0, 1)

    save_path = os.path.join(output_dir, f"test_{i:04d}.png")
    save_image(output_tensor, save_path)
    print(f"[{i}] ✅ 已上色並儲存：{save_path}")


Resolving data files:   0%|          | 0/43 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/43 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/36 [00:00<?, ?it/s]

[0] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0000.png
[1] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0001.png
[2] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0002.png
[3] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0003.png
[4] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0004.png
[5] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0005.png
[6] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0006.png
[7] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0007.png
[8] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0008.png
[9] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0009.png
[10] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0010.png
[11] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0011.png
[12] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0012.png
[13] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0013.png
[14] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0014.png
[15] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0015.png
[16] ✅ 已上色並儲存：./type93_netG_epoch_0020_gen\test_0016.png
[17] ✅ 已上色並儲存：./type93_netG_epoch_0020_ge