参考：
1. https://blog.csdn.net/qq_57886603/article/details/122051538
2. https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/util/get_data.py

## 下载facades数据集
链接：http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz

解压后，数据已划分为训练集（train）、测试集（test）和验证集（val）三个子目录。

In [1]:
import os
import requests
import shutil
from tqdm import tqdm  

# Dataset names; presumably the archives offered by the pix2pix mirror used
# in download_dataset (only 'facades' is exercised in this notebook).
dataset_list = [
    'cityscapes',
    'edges2handbags',
    'edges2shoes',
    'facades',
    'maps',
    'night2day',
]

def download_dataset(dataset_name='facades'):
    """Download a pix2pix dataset archive and extract it locally.

    The archive ``{dataset_name}.tar.gz`` is streamed from the Berkeley
    pix2pix server into ``./../data/pix2pix_datasets``, unpacked into that
    same directory, and the archive file is deleted afterwards (also when
    the download or the extraction fails, so no partial file is left).

    Args:
        dataset_name: Base name of the archive on the server.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server stops responding for 30 seconds.
    """
    extract_dir = './../data/pix2pix_datasets'
    os.makedirs(extract_dir, exist_ok=True)

    dataset_url = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'
    tar_file = os.path.join(extract_dir, f'{dataset_name}.tar.gz')
    block_size = 64 * 1024  # 64 KB chunks keep per-chunk overhead low

    try:
        # Without a timeout a stalled connection would block forever.
        with requests.get(dataset_url, stream=True, timeout=30) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            # The context managers close both the file and the progress bar
            # even if the transfer is interrupted.
            with open(tar_file, 'wb') as file, \
                    tqdm(total=total_size, unit='B', unit_scale=True) as progress_bar:
                for data in response.iter_content(block_size):
                    file.write(data)
                    progress_bar.update(len(data))

        # NOTE(review): the archive comes over plain HTTP and is extracted
        # without member filtering; on Python 3.12+ consider passing
        # filter='data' to shutil.unpack_archive.
        shutil.unpack_archive(tar_file, extract_dir, 'gztar')
    finally:
        # Remove the archive whether or not the steps above succeeded.
        if os.path.exists(tar_file):
            os.remove(tar_file)

# Fetch and extract the default dataset ('facades').
download_dataset()

31.7MB [00:04, 7.66MB/s]                            


## 准备训练数据

In [2]:
from torch.utils.data.dataset import Dataset
import torch
from torchvision import transforms
from PIL import Image
import glob
import os

def get_data_path(data_type, dir_root='./pix2pix_datasets/facades'):
    """Return the sorted ``.jpg`` paths found under ``dir_root/data_type``.

    Args:
        data_type: Name of the sub-directory to scan; this file calls it
            with 'train', 'val' and 'test'.
        dir_root: Directory that contains the split sub-directories.
            NOTE(review): download_dataset in this file extracts to
            './../data/pix2pix_datasets'; confirm that the default here
            points at the extracted data.

    Returns:
        A list of matching file paths, empty if nothing matches. The list is
        sorted because glob itself returns matches in arbitrary order, which
        would make unshuffled iteration differ between runs or machines.
    """
    pattern = os.path.join(dir_root, data_type, '*.jpg')
    return sorted(glob.glob(pattern))
    

class CustomDataset(Dataset):
    """Dataset of side-by-side image pairs stored in single image files.

    Each file is opened as RGB and cut vertically at half its width: the
    left half is returned as the origin image and the right half as the
    target image. Both halves go through the same transform pipeline.
    """

    def __init__(self, img_path_list, img_size):
        """Store the file list and build the per-image transform pipeline.

        Args:
            img_path_list: Sequence of image file paths.
            img_size: Side length both halves are resized to (square).
        """
        self.img_path_list = img_path_list
        to_tensor = transforms.ToTensor()
        resize = transforms.Resize((img_size, img_size))
        # Per channel: (x - 0.5) / 0.5
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        self.transforms = transforms.Compose([to_tensor, resize, normalize])

    def __len__(self):
        """Return the number of image files in the dataset."""
        return len(self.img_path_list)

    def __getitem__(self, index):
        """Return the ``(origin_img, target_img)`` pair for ``index``."""
        combined = Image.open(self.img_path_list[index]).convert('RGB')

        # Split the combined image into its left and right halves.
        width, height = combined.size
        half = int(width / 2)
        left_box = (0, 0, half, height)
        right_box = (half, 0, width, height)

        origin_img, target_img = (
            self.transforms(combined.crop(box))
            for box in (left_box, right_box)
        )
        return (origin_img, target_img)
    
# One dataset per split; every image half is resized to 256x256.
train_dataset = CustomDataset(img_path_list=get_data_path('train'), img_size=256)
val_dataset = CustomDataset(img_path_list=get_data_path('val'), img_size=256)
test_dataset = CustomDataset(img_path_list=get_data_path('test'), img_size=256)

batch_size = 64

# Train and val batches are shuffled; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


# 构建生成器generator

In [4]:
import torch.nn as nn

class InnerMost(nn.Module):
    """Down-sample / up-sample layer stack for the generator.

    The layers are, in order: LeakyReLU(0.2) -> Conv2d (4x4, stride 2) ->
    ReLU -> ConvTranspose2d (4x4, stride 2) -> norm layer. They are stored
    as ``self.model``. No ``forward`` is defined in this block.

    NOTE(review): the original ``__init__`` declared no parameters (not even
    ``self``) although its body used all of the names below, so the class
    could not be instantiated. The signature here is a reconstruction --
    confirm the parameter order and defaults against the intended callers.

    Args:
        input_nc: Number of channels of the input to the down convolution.
        inner_nc: Number of channels between the two convolutions.
        outer_nc: Number of channels produced by the up convolution.
        norm_layer: Callable that builds the normalisation layer from a
            channel count.
        use_bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, input_nc, inner_nc, outer_nc,
                 norm_layer=nn.BatchNorm2d, use_bias=False):
        super().__init__()
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        # The original also built an unused `downnorm` from a misspelled
        # name (`inzxedner_nc`), which raised NameError; it is dropped
        # because it was never part of the layer list.
        uprelu = nn.ReLU(True)
        upconv = nn.ConvTranspose2d(
            inner_nc,
            outer_nc,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=use_bias,
        )
        upnorm = norm_layer(outer_nc)

        down = [downrelu, downconv]
        up = [uprelu, upconv, upnorm]
        model = down + up
        # Register the layers on the module so their parameters are tracked.
        self.model = nn.Sequential(*model)
    