In [None]:
# Directory that holds the black outline (contour) images.
black_images_dir = '../Dataset/A'
# Directory that holds the colour photos.
color_images_dir = '../Dataset/B'

In [None]:
# 更新数据集类

import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class FloorPlanDataset(Dataset):
    """Paired dataset of black outline images and their colour counterparts.

    The image files of ``black_images_dir`` and ``color_images_dir`` are each
    sorted by file name and paired by position: item ``i`` is the i-th outline
    together with the i-th colour image.

    Args:
        black_images_dir: Directory containing the outline (input) images.
        color_images_dir: Directory containing the colour (target) images.
        transform: Optional callable applied to each PIL image. It is called
            once per image, so a transform with random behaviour would not be
            synchronised between the two images of a pair.
        extensions: File-name suffixes (matched case-insensitively) that count
            as images. Everything else in the directories -- sub-directories,
            ``.DS_Store``, ``Thumbs.db``, ``.ipynb_checkpoints`` ... -- is
            ignored so it can neither break loading nor shift the pairing.

    Raises:
        ValueError: If the two directories contain a different number of
            images, because positional pairing would then be wrong or run
            out of range.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif', '.tif', '.tiff', '.webp')

    def __init__(self, black_images_dir, color_images_dir, transform=None,
                 extensions=IMAGE_EXTENSIONS):
        self.black_images_dir = black_images_dir
        self.color_images_dir = color_images_dir
        self.transform = transform
        self.extensions = tuple(ext.lower() for ext in extensions)
        self.black_images = self._list_images(black_images_dir)
        self.color_images = self._list_images(color_images_dir)

        # Fail fast instead of raising IndexError (or silently mis-pairing)
        # somewhere in the middle of an epoch.
        if len(self.black_images) != len(self.color_images):
            raise ValueError(
                f"Image count mismatch: {len(self.black_images)} in "
                f"{black_images_dir!r} vs {len(self.color_images)} in "
                f"{color_images_dir!r}"
            )

    def _list_images(self, directory):
        """Return the sorted names of the regular image files in ``directory``."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(self.extensions)
            and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, convert it to RGB and release the file handle."""
        # convert() returns a fully loaded copy, so the file can be closed
        # as soon as the with-block exits.
        with Image.open(path) as image:
            return image.convert('RGB')

    def __len__(self):
        """Number of (outline, colour) pairs."""
        return len(self.black_images)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(black_image, color_image)`` pair."""
        black_image = self._load_rgb(
            os.path.join(self.black_images_dir, self.black_images[idx]))
        color_image = self._load_rgb(
            os.path.join(self.color_images_dir, self.color_images[idx]))

        if self.transform:
            black_image = self.transform(black_image)
            color_image = self.transform(color_image)

        return black_image, color_image


In [None]:
# 图像预处理

transform = transforms.Compose([
    transforms.Resize((256, 256)),  # resize every image to 256x256
    transforms.ToTensor(),          # PIL image -> float tensor in [0, 1]
    # Rescale each channel from [0, 1] to [-1, 1] so that the target images
    # live in the same range as the generator's Tanh output.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


In [None]:
# 初始化数据加载器

# Build the paired dataset from the two image folders.
dataset = FloorPlanDataset(black_images_dir, color_images_dir, transform)

# Wrap it in a loader that yields shuffled mini-batches of 16 pairs.
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=16)


In [None]:
# 验证数据加载器

# Fetch a single batch and show the shape of both tensors.
# Expected for each: [batch_size, 3, 256, 256]
for black_images, color_images in dataloader:
    print(black_images.shape, color_images.shape, sep="\n")
    break


In [None]:
# 生成器

import torch
import torch.nn as nn

class Generator(nn.Module):
    """Plain encoder-decoder generator (no skip connections).

    Four stride-2 convolutions shrink the input while widening it from
    ``input_nc`` to ``ngf * 8`` channels; four stride-2 transposed
    convolutions mirror that path back to ``output_nc`` channels. With
    kernel 4 / stride 2 / padding 1 each stage halves (or doubles) height
    and width, so an input whose sides are multiples of 16 comes out at the
    same spatial size. The final ``Tanh`` bounds the output to [-1, 1].

    Args:
        input_nc: Number of channels of the input image.
        output_nc: Number of channels of the generated image.
        ngf: Channel count after the first convolution.
    """

    def __init__(self, input_nc, output_nc, ngf=64):
        super().__init__()

        encoder_widths = [input_nc, ngf, ngf * 2, ngf * 4, ngf * 8]
        decoder_widths = [ngf * 8, ngf * 4, ngf * 2, ngf]

        stack = []

        # Down-sampling path: Conv -> BatchNorm -> ReLU per stage.
        for c_in, c_out in zip(encoder_widths[:-1], encoder_widths[1:]):
            stack.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.ReLU(True))

        # Up-sampling path: ConvTranspose -> BatchNorm -> ReLU per stage.
        for c_in, c_out in zip(decoder_widths[:-1], decoder_widths[1:]):
            stack.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.ReLU(True))

        # Output stage: back to the requested channel count, squashed by Tanh.
        stack.append(nn.ConvTranspose2d(ngf, output_nc, kernel_size=4, stride=2, padding=1))
        stack.append(nn.Tanh())

        self.main = nn.Sequential(*stack)

    def forward(self, input):
        """Run ``input`` through the encoder-decoder stack."""
        return self.main(input)

# Instantiate the generator: 3-channel input, 3-channel output.
generator = Generator(3, 3)
print(generator)


In [None]:
# 判别器

class Discriminator(nn.Module):
    """Convolutional discriminator that scores a channel-concatenated image pair.

    Four stride-2 convolutions (kernel 4, padding 1) reduce the spatial size
    by a factor of 16 while widening to ``ndf * 8`` channels; a final 4x4
    convolution without padding maps that to a single-channel score map
    (13x13 for a 256x256 input).

    By default the raw scores (logits) are returned. This is the form
    ``nn.BCEWithLogitsLoss`` expects, because that loss applies the sigmoid
    itself; a ``Sigmoid`` inside the network would squash the scores twice.

    Args:
        input_nc: Number of input channels (condition and image concatenated).
        ndf: Channel count after the first convolution.
        use_sigmoid: If True, append ``nn.Sigmoid`` so the output is a
            probability map. Only use this together with ``nn.BCELoss``.
    """

    def __init__(self, input_nc, ndf=64, use_sigmoid=False):
        super(Discriminator, self).__init__()
        layers = [
            # No BatchNorm on the first stage.
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),

            # Score head: one logit per receptive-field patch.
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0),
        ]
        if use_sigmoid:
            # Appended last, so the indices / state_dict keys of all layers
            # above are the same in both configurations.
            layers.append(nn.Sigmoid())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return the score map for ``input`` (logits unless ``use_sigmoid``)."""
        return self.main(input)

# Instantiate the discriminator. It sees 6 channels:
# 3 from the black outline image plus 3 from the colour image.
discriminator = Discriminator(6)
print(discriminator)


In [None]:
# 定义损失函数

import torch.nn as nn

# Adversarial loss. BCEWithLogitsLoss applies the sigmoid internally,
# so it must be fed raw scores (logits).
criterion_GAN = nn.BCEWithLogitsLoss(reduction='mean')

# L1 loss for the generator: mean absolute pixel difference between the
# generated image and the target image.
criterion_L1 = nn.L1Loss(reduction='mean')


In [None]:
# 定义优化器

import torch.optim as optim

# One Adam optimizer per network, both with the same hyper-parameters.
optimizer_G = optim.Adam(
    params=generator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
optimizer_D = optim.Adam(
    params=discriminator.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Move both networks onto the selected device.
generator.to(device)
discriminator.to(device)


In [None]:
import torch

# Environment report: PyTorch version, CUDA availability and the CUDA
# version PyTorch reports.
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version (reported by PyTorch):", torch.version.cuda)

# List the visible devices, or say that there are none.
if not torch.cuda.is_available():
    print("No CUDA devices detected.")
else:
    print("CUDA device count:", torch.cuda.device_count())
    print("Device name:", torch.cuda.get_device_name(0))


In [None]:
# Discriminator forward pass on one batch; defines output_real / output_fake.
#
# real_input and fake_input are built here from a freshly drawn batch, the
# same way the training loop builds them, so this cell runs on a fresh kernel
# instead of relying on variables left behind by a later cell.
# Note: both networks are in their default training mode here, so this
# forward pass also updates their BatchNorm running statistics.
sample_black, sample_color = next(iter(dataloader))
sample_black = sample_black.to(device)
sample_color = sample_color.to(device)

real_input = torch.cat((sample_black, sample_color), dim=1)
# detach(): the discriminator inputs must not carry gradients back into
# the generator.
fake_input = torch.cat((sample_black, generator(sample_black).detach()), dim=1)

output_real = discriminator(real_input)
output_fake = discriminator(fake_input)

# Target labels: all ones for real pairs, all zeros for generated pairs.
real_labels = torch.ones_like(output_real)
fake_labels = torch.zeros_like(output_fake)


In [None]:
# Discriminator loss terms: real pairs against the "real" labels,
# generated pairs against the "fake" labels.
loss_D_real = criterion_GAN(input=output_real, target=real_labels)
loss_D_fake = criterion_GAN(input=output_fake, target=fake_labels)

In [None]:
import torchvision.utils as vutils
import os

In [None]:
# Directory where one grid of generated images is written per epoch.
output_dir = 'Training/generated_images'
os.makedirs(output_dir, exist_ok=True)

# `num_epochs` is not assigned in any earlier cell. Keep a value the user has
# already set; otherwise fall back to a default so the cell runs on a fresh
# kernel.
if "num_epochs" not in globals():
    num_epochs = 200  # assumed default -- TODO confirm the intended epoch count

# Training loop
for epoch in range(num_epochs):
    for i, (black_images, color_images) in enumerate(dataloader):
        black_images = black_images.to(device)
        color_images = color_images.to(device)

        # Generator forward pass
        fake_images = generator(black_images)

        # ---- Discriminator update ----
        real_input = torch.cat((black_images, color_images), dim=1)
        # detach(): the discriminator step must not back-propagate into the
        # generator.
        fake_input = torch.cat((black_images, fake_images.detach()), dim=1)

        output_real = discriminator(real_input)
        output_fake = discriminator(fake_input)

        real_labels = torch.ones_like(output_real)
        fake_labels = torch.zeros_like(output_fake)

        loss_D_real = criterion_GAN(output_real, real_labels)
        loss_D_fake = criterion_GAN(output_fake, fake_labels)
        loss_D = (loss_D_real + loss_D_fake) / 2

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator update ----
        # Score the NON-detached fake images with the just-updated
        # discriminator. Reusing `fake_input` here would cut the graph at
        # detach(), so the adversarial loss would send no gradient to the
        # generator and only the L1 term would train it.
        gen_input = torch.cat((black_images, fake_images), dim=1)
        output_fake = discriminator(gen_input)
        # The generator wants its fakes to be classified as real.
        loss_G_GAN = criterion_GAN(output_fake, real_labels)
        loss_G_L1 = criterion_L1(fake_images, color_images)
        loss_G = loss_G_GAN + 100 * loss_G_L1

        optimizer_G.zero_grad()
        # This also accumulates gradients in the discriminator's parameters;
        # they are discarded by optimizer_D.zero_grad() in the next step.
        loss_G.backward()
        optimizer_G.step()

        # Print the losses once per batch
        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
              f"Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

    # At the end of every epoch, save the images generated for the last batch
    vutils.save_image(fake_images, os.path.join(output_dir, f"epoch_{epoch+1}.png"), normalize=True)
    print(f"Saved generated images for epoch {epoch+1}")
