<a href="https://colab.research.google.com/github/conorgibbons147/cyclegan-map/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!rm -rf cyclegan-map # use if changes are made to the repo and you need to reclone

In [2]:
! git clone https://github.com/conorgibbons147/cyclegan-map.git

Cloning into 'cyclegan-map'...
remote: Enumerating objects: 2687, done.[K
remote: Counting objects: 100% (15/15), done.[K
remote: Compressing objects: 100% (15/15), done.[K
remote: Total 2687 (delta 0), reused 13 (delta 0), pack-reused 2672 (from 2)[K
Receiving objects: 100% (2687/2687), 166.74 MiB | 13.42 MiB/s, done.
Resolving deltas: 100% (38/38), done.
Updating files: 100% (2624/2624), done.


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import itertools
import sys
import torchvision.transforms as transforms

In [4]:
sys.path.append('/content/cyclegan-map/models')
from generator import Generator
from discriminator import Discriminator

In [5]:
sys.path.append('/content/cyclegan-map')
from dataset import ImageDataset, HZDataset
from utils import weight_init, ReplayBuffer, sample_images

In [24]:
# setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 10
batch_size = 1
image_size = 256
save_interval = 20  # how often to save images
dataset_path = "/content/cyclegan-map/data"

In [25]:
# create networks
G_AB = Generator().to(device) # modern to vintage
G_BA = Generator().to(device) # vintage to modern
D_A = Discriminator().to(device) # check if fake vintage is real
D_B = Discriminator().to(device) # check if fake modern is real

# set weights for the networks
G_AB.apply(weight_init)
G_BA.apply(weight_init)
D_A.apply(weight_init)
D_B.apply(weight_init)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [26]:
# image loading setup
transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC), # how each image will be transformed to make them standard
    transforms.ToTensor(),  # Converts image from [0,255] to [0.0,1.0]
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Scales [0,1] to [-1,1]
])

train_dataset = ImageDataset(dataset_path, transform=transform, mode='train')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = ImageDataset(dataset_path, transform=transform, mode='val')
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)

In [27]:
# loss functions
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

lambda_cycle = 10.0               # lambdas are used to scale/weight the cycle and identity loss in our overall loss
lambda_identity = 5.0

In [28]:
# optimizers
optimizer_G = optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [29]:
# replay buffers - ensures that the discriminator is fed older saved fake images instead of just new ones, improves model performance
fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

In [30]:
# training loop
for epoch in range(epochs):
    for i, batch in enumerate(train_loader): # form: 0 img1  1 img2  2 img3 (this is what enumerate does, saves index as i)
        real_A = batch['A'].to(device)
        real_B = batch['B'].to(device)

        # ----- Generators -----
        optimizer_G.zero_grad() # clears gradients before backpropogating on the new batch

        # identity loss
        same_B = G_AB(real_B)
        loss_identity_B = criterion_identity(same_B, real_B) * lambda_identity # testing how different a vintage map becomes when fed into the modern->vintage generator,
                                                                               # should be the same in theory
        same_A = G_BA(real_A)
        loss_identity_A = criterion_identity(same_A, real_A) * lambda_identity # same for modern map when fed into vintage->modern generator

        # GAN loss - testing loss of the fake images compared to a completely real image
        fake_B = G_AB(real_A)
        pred_fake_B = D_B(fake_B)
        loss_GAN_AB = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B)) # torch.ones_like() creates a tensor of same shape as pred_fake_B full of ones, acts as a
                                                                               # completely real image since generator outputs 1s when an image is deemed real, takes MSE loss
        fake_A = G_BA(real_B)
        pred_fake_A = D_A(fake_A)
        loss_GAN_BA = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

        # cycle loss - loss when fake image is converted back to it's previous map type (ex. comparing original modern to modern->vinatage->modern duplicate)
        recov_A = G_BA(fake_B)
        loss_cycle_A = criterion_cycle(recov_A, real_A) * lambda_cycle

        recov_B = G_AB(fake_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B) * lambda_cycle

        # calculating total loss across the entire process
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_AB + loss_GAN_BA + loss_cycle_A + loss_cycle_B
        loss_G.backward()
        optimizer_G.step()

        # ----- Discriminators -----
        # A - testing how the real modern map compares to fake one
        optimizer_D_A.zero_grad()

        loss_real_A = criterion_GAN(D_A(real_A), torch.ones_like(D_A(real_A)))
        fake_A_buffered = fake_A_buffer.push_and_pop(fake_A)
        loss_fake_A = criterion_GAN(D_A(fake_A_buffered.detach()), torch.zeros_like(D_A(fake_A_buffered)))
        loss_D_A = (loss_real_A + loss_fake_A) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # B - testing how the real vintage map compares to fake one
        optimizer_D_B.zero_grad()

        loss_real_B = criterion_GAN(D_B(real_B), torch.ones_like(D_B(real_B)))
        fake_B_buffered = fake_B_buffer.push_and_pop(fake_B)
        loss_fake_B = criterion_GAN(D_B(fake_B_buffered.detach()), torch.zeros_like(D_B(fake_B_buffered)))
        loss_D_B = (loss_real_B + loss_fake_B) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        # print statement
        print(f"[Epoch {epoch+1}/{epochs}] [Batch {i+1}/{len(train_loader)}] "
              f"[D_A loss: {loss_D_A.item():.4f}] [D_B loss: {loss_D_B.item():.4f}] [G loss: {loss_G.item():.4f}]")

        # saving images
        batches_done = epoch * len(train_loader) + i
        if batches_done % save_interval == 0:
            sample_images(batches_done, G_AB, G_BA, val_loader, device)

[Epoch 1/10] [Batch 1/107] [D_A loss: 1.6308] [D_B loss: 1.9136] [G loss: 28.7131]
[Epoch 1/10] [Batch 2/107] [D_A loss: 1.1466] [D_B loss: 1.5448] [G loss: 25.5577]
[Epoch 1/10] [Batch 3/107] [D_A loss: 1.4857] [D_B loss: 1.5781] [G loss: 21.7955]
[Epoch 1/10] [Batch 4/107] [D_A loss: 1.0988] [D_B loss: 1.5901] [G loss: 18.6709]
[Epoch 1/10] [Batch 5/107] [D_A loss: 1.6940] [D_B loss: 5.3862] [G loss: 17.7035]
[Epoch 1/10] [Batch 6/107] [D_A loss: 2.2379] [D_B loss: 1.6819] [G loss: 15.7480]
[Epoch 1/10] [Batch 7/107] [D_A loss: 2.0347] [D_B loss: 1.6523] [G loss: 13.4765]
[Epoch 1/10] [Batch 8/107] [D_A loss: 2.0024] [D_B loss: 1.8534] [G loss: 12.9674]
[Epoch 1/10] [Batch 9/107] [D_A loss: 1.7018] [D_B loss: 1.3538] [G loss: 10.8177]
[Epoch 1/10] [Batch 10/107] [D_A loss: 1.2374] [D_B loss: 1.3618] [G loss: 10.0569]
[Epoch 1/10] [Batch 11/107] [D_A loss: 1.5871] [D_B loss: 1.4194] [G loss: 10.1438]
[Epoch 1/10] [Batch 12/107] [D_A loss: 1.1720] [D_B loss: 1.3312] [G loss: 8.3800]
[E