In [1]:
import glob
import random
import os
import tqdm
import numpy as np
import torch
from torch import nn, Tensor
from torchvision.utils import save_image, make_grid

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Name of the dataset folder under ./data/ used throughout the notebook.
dataset_name = "monet2photo"

In [2]:
device

'cuda'

In [3]:
# Output directory for the PNG grids written by sample_images(); no error if it already exists.
os.makedirs("images", exist_ok=True)

# 1. Dataset

In [1]:
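# One-time dataset setup: uncomment and run the commands below once to download the
# data and arrange it as data/monet2photo/{train,test}/{A,B}, the layout ImageDataset reads.
# !bash download_cyclegan_dataset.sh monet2photo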

# !mkdir -p "./data/monet2photo/train" "./data/monet2photo/test"

# !mv "./data/monet2photo/trainA" "./data/monet2photo/train/A"
# !mv "./data/monet2photo/trainB" "./data/monet2photo/train/B"
# !mv "./data/monet2photo/testA" "./data/monet2photo/test/A"
# !mv "./data/monet2photo/testB" "./data/monet2photo/test/B"

for details.

--2024-03-03 10:22:49--  https://efrosgans.eecs.berkeley.edu/cyclegan/datasets/monet2photo.zip
Resolving efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)... 128.32.244.190
Connecting to efrosgans.eecs.berkeley.edu (efrosgans.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 305231073 (291M) [application/zip]
Saving to: ‘./monet2photo.zip’


2024-03-03 10:23:50 (5,00 MB/s) - ‘./monet2photo.zip’ saved [305231073/305231073]

Archive:  ./monet2photo.zip
   creating: ./data/monet2photo/
   creating: ./data/monet2photo/trainA/
  inflating: ./data/monet2photo/trainA/01159.jpg  
  inflating: ./data/monet2photo/trainA/01048.jpg  
  inflating: ./data/monet2photo/trainA/01144.jpg  
  inflating: ./data/monet2photo/trainA/00799.jpg  
  inflating: ./data/monet2photo/trainA/00897.jpg  
  inflating: ./data/monet2photo/trainA/00998.jpg  
  inflating: ./data/monet2photo/trainA/00883.jpg  
  inflating: ./data/monet2photo/tra

In [5]:
import torchvision 
from torch.utils.data import Dataset
from PIL import Image

In [6]:
class ImageDataset(Dataset):
    """Image pairs drawn from the two domain folders ``<root>/<split>/A`` and ``<root>/<split>/B``.

    Each item is a dict ``{"A": ..., "B": ...}`` holding one transformed image
    from each folder. The dataset length is the size of the larger folder;
    indices wrap around (modulo) in the smaller one.

    Args:
        root: Directory containing the ``<split>/A`` and ``<split>/B`` folders.
        transforms: Sequence of callables passed to ``torchvision.transforms.Compose``.
            ``None`` (the default) means "no transforms".
        unaligned: If True, the B image is picked at random instead of by index.
        split: Sub-directory of ``root`` to read, e.g. ``"train"`` or ``"test"``.
    """

    def __init__(self, root, transforms=None, unaligned=False, split="train"):
        # Compose(None) would only fail later, at the first __getitem__ call;
        # treat the default as an empty pipeline instead.
        self.transform = torchvision.transforms.Compose(transforms or [])
        self.unaligned = unaligned

        # Sorted so that index -> file is stable across runs and platforms.
        self.files_A = sorted(glob.glob(os.path.join(root, f"{split}/A") + "/*.*"))
        self.files_B = sorted(glob.glob(os.path.join(root, f"{split}/B") + "/*.*"))

    def __getitem__(self, index):
        """Return ``{"A": transformed_A, "B": transformed_B}`` for ``index``."""
        image_A = Image.open(self.files_A[index % len(self.files_A)])

        if self.unaligned:
            # Unpaired sampling: B is independent of the requested index.
            image_B = Image.open(self.files_B[random.randint(0, len(self.files_B) - 1)])
        else:
            image_B = Image.open(self.files_B[index % len(self.files_B)])

        # Non-RGB images (e.g. grayscale) are converted to RGB.
        # The helper is named ``_to_rgb``; calling ``self.to_rgb`` raised AttributeError.
        if image_A.mode != "RGB":
            image_A = self._to_rgb(image_A)
        if image_B.mode != "RGB":
            image_B = self._to_rgb(image_B)

        item_A = self.transform(image_A)
        item_B = self.transform(image_B)
        return {"A": item_A, "B": item_B}

    def __len__(self):
        """Number of items: the size of the larger of the two folders."""
        return max(len(self.files_A), len(self.files_B))

    def _to_rgb(self, image):
        """Return a new RGB image of the same size with ``image`` pasted onto it."""
        rgb_image = Image.new("RGB", image.size)
        rgb_image.paste(image)
        return rgb_image

In [7]:
# Spatial size of the image tensors produced by the pipeline below (also used for input_shape).
img_height = 256
img_width = 256

# Preprocessing applied to every image of both domains (passed to ImageDataset, which wraps it in Compose).
transforms = [
    # Resize to 1.12x the target height with bicubic interpolation so the crop below has room to move.
    torchvision.transforms.Resize(int(img_height * 1.12), Image.BICUBIC),
    # Random (img_height, img_width) window out of the resized image.
    torchvision.transforms.RandomCrop((img_height, img_width)),
    # torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    # Per-channel (x - 0.5) / 0.5.
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]


# NOTE(review): the test split reuses the same random crop and unaligned (random B) pairing,
# so validation samples are not deterministic between calls — confirm this is intended.
train_set = ImageDataset(f"data/{dataset_name}", transforms=transforms, unaligned=True, split="train")
test_set = ImageDataset(f"data/{dataset_name}", transforms=transforms, unaligned=True, split="test")

In [8]:
from torch.utils.data import DataLoader

# Number of image pairs per batch for both loaders.
BATCH_SIZE = 8
# Training batches are reshuffled every epoch; the validation loader keeps dataset order
# (the B image of each pair is still random because test_set is built with unaligned=True).
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# 2. Model

In [9]:
class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv + instance-norm stages with a skip connection.

    The channel count and spatial size of the input are preserved; a ReLU
    follows the first stage only.
    """

    def __init__(self, in_features):
        super().__init__()
        stages = []
        for is_final_stage in (False, True):
            stages.extend(
                [
                    nn.ReflectionPad2d(1),
                    nn.Conv2d(in_features, in_features, 3),
                    nn.InstanceNorm2d(in_features),
                ]
            )
            # Only the first stage is followed by an activation.
            if not is_final_stage:
                stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 downsamplings,
    ``num_residual_blocks`` residual blocks, two x2 upsamplings and a 7x7
    Tanh output layer.

    The output has the same number of channels and the same spatial size as
    the input (for heights/widths divisible by 4).

    Args:
        input_shape: Sequence whose first element is the number of image channels.
        num_residual_blocks: Number of ``ResidualBlock`` modules in the bottleneck.
    """

    def __init__(self, input_shape, num_residual_blocks):
        super().__init__()

        channels = input_shape[0]
        out_features = 64

        # A 7x7 convolution needs 3 pixels of padding on each side to keep H x W.
        # The padding was previously set to ``channels``, which is only correct
        # by coincidence for 3-channel images.
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]

        in_features = out_features

        # Downsampling: each step halves H and W and doubles the feature count.
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling: each step doubles H and W and halves the feature count.
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer: same 7x7 / padding-3 pairing as the stem; Tanh bounds values to [-1, 1].
        model += [nn.ReflectionPad2d(3), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Run ``x`` through the full layer stack."""
        return self.model(x)



class Discriminator(nn.Module):
    """Convolutional discriminator producing a one-channel map of scores.

    Four stride-2 convolutions (64, 128, 256, 512 filters) reduce height and
    width by a factor of 16; a final padded 4x4 convolution maps to a single
    channel. ``output_shape`` records the per-sample shape of that map.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.
    """

    def __init__(self, input_shape):
        super().__init__()

        channels, height, width = input_shape

        # Four halvings -> spatial size divided by 16.
        self.output_shape = (1, height // 16, width // 16)

        stack = []
        previous = channels
        for position, filters in enumerate((64, 128, 256, 512)):
            stack.append(nn.Conv2d(previous, filters, 4, stride=2, padding=1))
            # The very first stage is left un-normalised.
            if position > 0:
                stack.append(nn.InstanceNorm2d(filters))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            previous = filters

        # Extra row/column of zeros on the top/left, then the 1-channel output conv.
        stack.append(nn.ZeroPad2d((1, 0, 1, 0)))
        stack.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return the score map for a batch of images."""
        return self.model(img)

In [10]:
# (channels, height, width) of the image tensors fed to every network.
input_shape = (3, img_height, img_width)
n_residual_blocks = 3

# Built in the order G_AB, G_BA, D_A, D_B so parameter initialisation
# consumes the random number generator in the same sequence as before.
G_AB, G_BA = (Generator(input_shape, n_residual_blocks).to(device) for _ in range(2))
D_A, D_B = (Discriminator(input_shape).to(device) for _ in range(2))

# 3. Training

In [11]:
class ReplayBuffer:
    """Fixed-capacity pool of previously seen samples.

    While the pool is not full every incoming sample is stored and returned
    unchanged. Once full, each incoming sample is, with probability 0.5,
    swapped with a randomly chosen stored sample (the stored one is returned);
    otherwise the incoming sample is returned and the pool is left untouched.
    """

    def __init__(self, max_size=50):
        assert max_size > 0, "Empty buffer or trying to create a black hole. Be careful."
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Feed a batch through the pool and return a batch of the same size.

        Samples are read through ``data.data``, so the returned tensors carry
        no autograd history from ``data``.
        """
        batch_out = []
        for sample in data.data:
            sample = torch.unsqueeze(sample, 0)

            # Pool still has room: store and pass the sample straight through.
            if len(self.data) < self.max_size:
                self.data.append(sample)
                batch_out.append(sample)
                continue

            # Pool is full: coin flip between swapping with a stored sample
            # and passing the new one through.
            if random.uniform(0, 1) > 0.5:
                slot = random.randint(0, self.max_size - 1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = sample
            else:
                batch_out.append(sample)
        return torch.cat(batch_out)


def sample_images(batches_done):
    """Saves a generated sample from the validation set

    Takes the first batch of ``val_loader``, translates up to five images per
    direction with ``G_AB`` / ``G_BA`` and writes a four-row grid
    (real A, fake B, real B, fake A) to ``images/<batches_done>.png``.

    Both generators are switched to eval mode and left that way; callers that
    continue training must call ``.train()`` again.

    Args:
        batches_done: Value used as the output file name (without extension).
    """
    G_AB.eval()
    G_BA.eval()
    imgs = next(iter(val_loader))

    # Inference only: disable autograd so no computation graph is built and
    # held in memory for tensors that are just written to disk.
    with torch.no_grad():
        # Domain "A" is read from the dataset's "B" entry and vice versa.
        real_A = imgs["B"][:5, ...].to(device)
        fake_B = G_AB(real_A)
        real_B = imgs["A"][:5, ...].to(device)
        fake_A = G_BA(real_B)

    # One row of up to five images per tensor, each rescaled for display.
    real_A = make_grid(real_A, nrow=5, normalize=True)
    real_B = make_grid(real_B, nrow=5, normalize=True)
    fake_A = make_grid(fake_A, nrow=5, normalize=True)
    fake_B = make_grid(fake_B, nrow=5, normalize=True)
    # Stack the four rows along the height axis of the (C, H, W) grids.
    image_grid = torch.cat((real_A, fake_B, real_B, fake_A), 1)
    save_image(image_grid, f"images/{batches_done}.png", normalize=False)

In [12]:
import itertools

# Learning rate shared by all three Adam optimizers below.
LR = 0.0002

# A single optimizer updates both generators together: their parameters are
# chained into one iterable, so one step() call updates G_AB and G_BA.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), 
    lr=LR, betas=(0.5, 0.999)
)

# Each discriminator has its own optimizer with the same hyper-parameters.
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=LR, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=LR, betas=(0.5, 0.999))

In [13]:
# Adversarial term: mean squared error between discriminator outputs and the all-ones / all-zeros targets.
criterion_GAN = torch.nn.MSELoss()
# Cycle-consistency term: L1 distance between a twice-translated image and the original.
criterion_cycle = torch.nn.L1Loss()
# Identity term: L1 distance between a generator's output and its input when fed an image of its target domain.
criterion_identity = torch.nn.L1Loss()

In [14]:
# Training hyper-parameters.
EPOCHS = 50
save_interval = 10  # a sample grid is written every `save_interval` epochs
lambda_cyc = 10     # weight of the cycle-consistency loss in the generator objective
lambda_id = 5       # weight of the identity loss in the generator objective

# Pools of previously generated images used when training the discriminators.
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# Per-epoch mean losses, appended once per epoch.
hist = {
        "train_G_loss": [],
        "train_D_loss": [],
}
for epoch in range(EPOCHS):
    running_G_loss = 0.0
    running_D_loss = 0.0

    for batch in tqdm.tqdm(train_loader, total=len(train_loader)):

        # NOTE(review): domain "A" is fed from the dataset's "B" entry and vice versa;
        # sample_images() uses the same swap, so this is presumably intentional — confirm.
        real_imgs_A = batch["B"].to(device)
        real_imgs_B = batch["A"].to(device)

        # Discriminator targets shaped (batch, *D_A.output_shape): all ones for "real", all zeros for "fake".
        # Sized from the actual batch so a smaller final batch still matches.
        valid = Tensor(np.ones((real_imgs_A.size(0), *D_A.output_shape))).to(device)
        fake = Tensor(np.zeros((real_imgs_A.size(0), *D_A.output_shape))).to(device)

        # --- Train Generator ---
        G_AB.train()
        G_BA.train()
        optimizer_G.zero_grad()

        # Identity loss: a generator fed an image already in its output domain should return it unchanged.
        loss_id_A = criterion_identity(G_BA(real_imgs_A), real_imgs_A)
        loss_id_B = criterion_identity(G_AB(real_imgs_B), real_imgs_B)
        loss_identity = (loss_id_A + loss_id_B) / 2
        
        # GAN loss: translated images should be scored as "valid" by the target-domain discriminator.
        fake_imgs_B = G_AB(real_imgs_A)
        loss_GAN_AB = criterion_GAN(D_B(fake_imgs_B), valid)
        fake_imgs_A = G_BA(real_imgs_B)
        loss_GAN_BA = criterion_GAN(D_A(fake_imgs_A), valid)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss: translating there and back should reproduce the original image.
        recovered_imgs_A = G_BA(fake_imgs_B)
        loss_cycle_A = criterion_cycle(recovered_imgs_A, real_imgs_A)
        recovered_imgs_B = G_AB(fake_imgs_A)
        loss_cycle_B = criterion_cycle(recovered_imgs_B, real_imgs_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + loss_cycle*lambda_cyc + loss_identity*lambda_id
        running_G_loss += loss_G.item()
        # This backward pass also deposits gradients on D_A / D_B; they are cleared by the
        # zero_grad() calls that precede each discriminator step below.
        loss_G.backward()
        optimizer_G.step()


        # --- Train Discriminator A ---
        optimizer_D_A.zero_grad()

        loss_real = criterion_GAN(D_A(real_imgs_A), valid)
        
        # push_and_pop reads the batch through `.data`, so the discriminator loss
        # does not back-propagate into the generators.
        fake_A_ = fake_A_buffer.push_and_pop(fake_imgs_A)
        loss_fake = criterion_GAN(D_A(fake_A_), fake)
        
        loss_D_A = (loss_real + loss_fake) / 2
        loss_D_A.backward()
        optimizer_D_A.step()

        # --- Train Discriminator B ---
        optimizer_D_B.zero_grad()
        
        loss_real = criterion_GAN(D_B(real_imgs_B), valid)

        fake_B_ = fake_B_buffer.push_and_pop(fake_imgs_B)
        loss_fake = criterion_GAN(D_B(fake_B_), fake)

        loss_D_B = (loss_real + loss_fake) / 2
        loss_D_B.backward()
        optimizer_D_B.step()
        
        # Total loss (mean of the two discriminators, used for logging only)
        loss_D = (loss_D_A + loss_D_B) / 2
        running_D_loss += loss_D.item()

    # Mean per-batch losses for this epoch.
    epoch_G_loss = running_G_loss / len(train_loader)
    epoch_D_loss = running_D_loss / len(train_loader)

    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train G Loss: {epoch_G_loss:.4f}, Train D Loss: {epoch_D_loss:.4f}")

    hist["train_G_loss"].append(epoch_G_loss)
    hist["train_D_loss"].append(epoch_D_loss)

    # Writes images/<epoch>.png on epochs 0, save_interval, 2*save_interval, ...
    if epoch % save_interval == 0:
        sample_images(epoch)


100%|██████████| 786/786 [09:06<00:00,  1.44it/s]

Epoch [1/50], Train G Loss: 3.9555, Train D Loss: 0.1835



100%|██████████| 786/786 [08:44<00:00,  1.50it/s]


Epoch [2/50], Train G Loss: 3.3419, Train D Loss: 0.1617


100%|██████████| 786/786 [08:59<00:00,  1.46it/s]


Epoch [3/50], Train G Loss: 3.1046, Train D Loss: 0.1725


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [4/50], Train G Loss: 2.9587, Train D Loss: 0.1761


100%|██████████| 786/786 [09:03<00:00,  1.45it/s]


Epoch [5/50], Train G Loss: 2.8967, Train D Loss: 0.1605


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [6/50], Train G Loss: 2.8164, Train D Loss: 0.1584


100%|██████████| 786/786 [08:43<00:00,  1.50it/s]


Epoch [7/50], Train G Loss: 2.6104, Train D Loss: 0.2217


100%|██████████| 786/786 [09:04<00:00,  1.44it/s]


Epoch [8/50], Train G Loss: 2.6902, Train D Loss: 0.1575


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [9/50], Train G Loss: 2.6428, Train D Loss: 0.1615


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [10/50], Train G Loss: 2.4746, Train D Loss: 0.2142


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]

Epoch [11/50], Train G Loss: 2.4874, Train D Loss: 0.1669



100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [12/50], Train G Loss: 2.4891, Train D Loss: 0.1688


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [13/50], Train G Loss: 2.5007, Train D Loss: 0.1437


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [14/50], Train G Loss: 2.4246, Train D Loss: 0.2106


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [15/50], Train G Loss: 2.4810, Train D Loss: 0.1235


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [16/50], Train G Loss: 2.5664, Train D Loss: 0.1001


100%|██████████| 786/786 [09:01<00:00,  1.45it/s]


Epoch [17/50], Train G Loss: 2.7438, Train D Loss: 0.2919


100%|██████████| 786/786 [09:12<00:00,  1.42it/s]


Epoch [18/50], Train G Loss: 2.2985, Train D Loss: 0.1579


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [19/50], Train G Loss: 2.3552, Train D Loss: 0.1326


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [20/50], Train G Loss: 2.4146, Train D Loss: 0.1132


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]

Epoch [21/50], Train G Loss: 2.7186, Train D Loss: 0.1772



100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [22/50], Train G Loss: 2.4311, Train D Loss: 0.1609


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [23/50], Train G Loss: 2.2624, Train D Loss: 0.2087


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [24/50], Train G Loss: 2.2002, Train D Loss: 0.1638


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [25/50], Train G Loss: 2.2276, Train D Loss: 0.1674


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [26/50], Train G Loss: 2.2860, Train D Loss: 0.1396


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [27/50], Train G Loss: 2.2700, Train D Loss: 0.1441


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [28/50], Train G Loss: 2.2473, Train D Loss: 0.1386


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [29/50], Train G Loss: 2.2445, Train D Loss: 0.1335


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [30/50], Train G Loss: 2.2673, Train D Loss: 0.1228


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]

Epoch [31/50], Train G Loss: 2.6849, Train D Loss: 0.2226



100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [32/50], Train G Loss: 2.2374, Train D Loss: 0.1984


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [33/50], Train G Loss: 1.9957, Train D Loss: 0.1894


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [34/50], Train G Loss: 2.0515, Train D Loss: 0.1661


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [35/50], Train G Loss: 2.1081, Train D Loss: 0.1579


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [36/50], Train G Loss: 2.1753, Train D Loss: 0.1199


100%|██████████| 786/786 [08:48<00:00,  1.49it/s]


Epoch [37/50], Train G Loss: 2.2081, Train D Loss: 0.1098


100%|██████████| 786/786 [09:23<00:00,  1.40it/s]


Epoch [38/50], Train G Loss: 2.2252, Train D Loss: 0.1037


100%|██████████| 786/786 [08:43<00:00,  1.50it/s]


Epoch [39/50], Train G Loss: 2.2443, Train D Loss: 0.0959


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [40/50], Train G Loss: 2.2628, Train D Loss: 0.0903


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]

Epoch [41/50], Train G Loss: 2.2633, Train D Loss: 0.0861



100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [42/50], Train G Loss: 2.6482, Train D Loss: 0.2345


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [43/50], Train G Loss: 2.2444, Train D Loss: 0.2210


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [44/50], Train G Loss: 2.1789, Train D Loss: 0.1477


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [45/50], Train G Loss: 2.0243, Train D Loss: 0.1480


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [46/50], Train G Loss: 2.0692, Train D Loss: 0.1242


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [47/50], Train G Loss: 2.1126, Train D Loss: 0.1129


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [48/50], Train G Loss: 2.1241, Train D Loss: 0.1100


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]


Epoch [49/50], Train G Loss: 2.1386, Train D Loss: 0.1071


100%|██████████| 786/786 [08:38<00:00,  1.52it/s]

Epoch [50/50], Train G Loss: 2.3898, Train D Loss: 0.2482





In [17]:
# Write a sample grid for the fully trained generators to images/50.png.
sample_images(50)

In [15]:
# Persist the weights of all four networks. `epoch` still holds the index of
# the last completed training epoch, which becomes the file-name suffix.
os.makedirs("saved_models", exist_ok=True)
checkpoints = (
    (G_AB, "saved_models/G_AB_%d.pth"),
    (G_BA, "saved_models/G_BA_%d.pth"),
    (D_A, "saved_models/D_A_%d.pth"),
    (D_B, "saved_models/D_B_%d.pth"),
)
for network, path_template in checkpoints:
    torch.save(network.state_dict(), path_template % epoch)