In [169]:
import torch 
import torch.nn as nn
import torch.nn.functional as f
from torch.utils.data import Dataset, DataLoader
import torch.utils as vutils
import torchvision.transforms as tr

import numpy as np

import matplotlib.pyplot as plt
import itertools

from datetime import datetime
import os


# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU so the notebook
# still runs (slowly) on machines without an accelerator instead of failing
# on the first `.to(device)`.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

import albumentations as A
from albumentations.pytorch import ToTensorV2

from PIL import Image

import random

from torchinfo import summary

from tqdm.notebook import tqdm

In [170]:

# Augmentation pipeline applied to every (horse, zebra) sample:
#   resize to 256x256 -> random horizontal flip -> occasional colour jitter
#   -> scale pixels from [0, 255] to [-1, 1] (matching the generators' tanh
#   output) -> convert to a tensor.
# `additional_targets` registers the zebra image under the key "image0" so it
# receives the same augmentation as the primary "image" (the horse).
transforms = A.Compose(
    transforms=[
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.1),
        A.Normalize(
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    additional_targets=dict(image0="image"),
)

In [171]:
class Block(nn.Module):
    """Reflect-padded 4x4 convolution (padding 1, with bias) followed by LeakyReLU(0.2).

    NOTE(review): unlike the layers built by the other ``Discriminator`` in this
    notebook, there is no normalization layer here -- confirm that is intended.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        strides: convolution stride (1 or 2 in this notebook).
    """

    def __init__(self, in_channels, out_channels, strides):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=strides,
            padding=1,
            bias=True,
            padding_mode="reflect",
        )
        self.conv = nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x):
        out = self.conv(x)
        return out

In [172]:
class Discriminator(nn.Module):
    """Discriminator built from ``Block`` layers; outputs a sigmoid score map.

    NOTE(review): the next cell defines another ``Discriminator`` class that
    rebinds this name, and the model-setup cell constructs it with
    ``channels=3`` (a keyword only that version accepts). As the notebook
    stands this version is dead code -- rename or delete one of the two.

    Args:
        in_channels: number of channels in the input images.
        features: output channels of the successive conv stages. The first
            entry is used by the stride-2 stem; each remaining entry becomes a
            ``Block`` with stride 2, except the final one, which uses stride 1.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Tuple default avoids a shared mutable default; copying also accepts
        # any sequence a caller passes.
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Decide the stride by position, not by value: the old check
            # `feature == features[-1]` also matched earlier stages whose width
            # happened to equal the last one.
            layers.append(
                Block(in_channels, feature, strides=1 if index == last_index else 2)
            )
            in_channels = feature

        # Final 1-channel projection producing one score per spatial location.
        layers.append(
            nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect")
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

In [173]:
class Discriminator(nn.Module):
    """Convolutional discriminator returning a one-channel map of raw scores.

    NOTE(review): this cell rebinds ``Discriminator``, shadowing the
    ``features``-based definition in the previous cell.

    Stages: conv(stride 2, no norm) -> conv(stride 2) -> conv(stride 2) ->
    conv(stride 1), each followed by LeakyReLU(0.2) and, except the first,
    preceded by InstanceNorm; then a final 1-channel conv with no activation
    (the training loop scores it with MSE against 0/1 targets).

    Args:
        channels: number of channels in the input images.
    """

    def __init__(self, channels):
        super().__init__()

        self.channels = channels

        # (in_channels, out_channels, stride, normalize) for each stage.
        stage_specs = [
            (self.channels, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        ]
        stages = []
        for size_in, size_out, stride, normalize in stage_specs:
            stages.extend(self._create_layer_(size_in, size_out, stride, normalize=normalize))
        stages.append(nn.Conv2d(512, 1, 4, stride=1, padding=1))

        self.model = nn.Sequential(*stages)

    def _create_layer_(self, size_in, size_out, stride, normalize=True):
        """Return ``[Conv2d(k=4, p=1), optional InstanceNorm2d, LeakyReLU(0.2)]``."""
        conv = nn.Conv2d(size_in, size_out, 4, stride=stride, padding=1)
        norm = [nn.InstanceNorm2d(size_out)] if normalize else []
        return [conv, *norm, nn.LeakyReLU(0.2, inplace=True)]

    def forward(self, x):
        scores = self.model(x)
        return scores

In [174]:
class ConvBlock(nn.Module):
    """Convolution (or transposed convolution) -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: if True use a reflect-padded ``nn.Conv2d``; otherwise use
            ``nn.ConvTranspose2d`` (up-sampling path).
        use_act: if True finish with ``ReLU(inplace=True)``, else ``nn.Identity``.
        **kwargs: forwarded to the convolution (kernel_size, stride, padding,
            output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        activation = nn.ReLU(inplace=True) if use_act else nn.Identity()
        self.conv = nn.Sequential(conv, nn.InstanceNorm2d(out_channels), activation)

    def forward(self, x):
        out = self.conv(x)
        return out

    
class ResidualBlock(nn.Module):
    """Two 3x3 ``ConvBlock``s (the second without ReLU) plus an identity skip.

    Channel count and (with padding 1, stride 1) spatial size are unchanged,
    so the input can be added directly to the block's output.

    Args:
        channels: number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = {"kernel_size": 3, "padding": 1}
        self.block = nn.Sequential(
            ConvBlock(channels, channels, stride=1, **conv_kwargs),
            ConvBlock(channels, channels, use_act=False, **conv_kwargs),
        )

    def forward(self, x):
        residual = self.block(x)
        return x + residual
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder image-to-image generator.

    Layout: 7x7 reflect-padded stem -> two stride-2 down-sampling ConvBlocks
    -> ``num_residuals`` ResidualBlocks at ``4 * num_features`` channels ->
    two transposed-conv up-sampling ConvBlocks -> 7x7 projection back to
    ``img_channels`` followed by tanh, so outputs lie in [-1, 1].
    Spatial size is preserved when H and W are multiples of 4.

    Args:
        img_channels: channels of input and output images.
        num_residuals: number of residual blocks in the bottleneck.
        num_features: channel width of the stem (doubled at each down-sample).
    """

    def __init__(self, img_channels, num_residuals=9, num_features=64):
        super().__init__()
        base = num_features

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, base, kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(base),
            nn.ReLU(inplace=True),
        )

        down_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(base, base * 2, **down_kwargs),
                ConvBlock(base * 2, base * 4, **down_kwargs),
            ]
        )

        self.residual_blocks = nn.Sequential(
            *(ResidualBlock(base * 4) for _ in range(num_residuals))
        )

        up_kwargs = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(base * 4, base * 2, **up_kwargs),
                ConvBlock(base * 2, base, **up_kwargs),
            ]
        )

        self.last = nn.Conv2d(base, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect")

    def forward(self, x):
        h = self.initial(x)
        # Down-sample, transform at the bottleneck, then up-sample back.
        for stage in (*self.down_blocks, self.residual_blocks, *self.up_blocks):
            h = stage(h)
        return torch.tanh(self.last(h))

In [175]:
class HorseZebraDataset(Dataset):
    """Unpaired horse/zebra image dataset returning ``(horse_img, zebra_img)``.

    The two folders usually hold different numbers of images, so the dataset
    length is the larger count and the shorter folder wraps around via
    ``index % length``.

    Args:
        zebra: directory containing the zebra images.
        horse: directory containing the horse images.
        transform: optional callable invoked as
            ``transform(image=horse, image0=zebra)`` that returns a dict with
            ``"image"`` and ``"image0"`` entries (albumentations style).

    Raises:
        ValueError: if either directory contains no usable files.
    """

    def __init__(self, zebra, horse, transform=None):
        self.horse = horse
        self.zebra = zebra
        self.transform = transform

        self.horse_images = self._list_images(horse)
        self.zebra_images = self._list_images(zebra)
        self.horse_length = len(self.horse_images)
        self.zebra_length = len(self.zebra_images)
        self.dataset_length = max(self.horse_length, self.zebra_length)

    @staticmethod
    def _list_images(directory):
        """Return the sorted, non-hidden regular file names inside ``directory``."""
        # Sorting makes index -> file stable across filesystems (os.listdir
        # order is arbitrary); hidden entries such as macOS ".DS_Store" and
        # sub-directories would otherwise reach PIL and crash __getitem__.
        names = sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".") and os.path.isfile(os.path.join(directory, name))
        )
        if not names:
            # Without this check an empty folder surfaces later as a
            # ZeroDivisionError from `index % 0`.
            raise ValueError(f"no image files found in {directory!r}")
        return names

    @staticmethod
    def _load_rgb(path):
        """Load ``path`` as an RGB numpy array, closing the file afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        return self.dataset_length

    def __getitem__(self, index):
        zebra_path = os.path.join(self.zebra, self.zebra_images[index % self.zebra_length])
        horse_path = os.path.join(self.horse, self.horse_images[index % self.horse_length])

        zebra_img = self._load_rgb(zebra_path)
        horse_img = self._load_rgb(horse_path)

        if self.transform:
            # Horse is the primary "image" target, zebra the extra "image0".
            augmentations = self.transform(image=horse_img, image0=zebra_img)
            horse_img = augmentations["image"]
            zebra_img = augmentations["image0"]

        return horse_img, zebra_img

In [176]:
# Root of the unpaired training data; the two domains live in trainA/ and trainB/.
train_dir = "data/"

# NOTE(review): in the public horse2zebra release trainA holds horses and
# trainB holds zebras -- the opposite of the mapping below. Training is
# symmetric, so this still runs, but every "horse"/"zebra" label downstream
# (h_gen, z_gen, saved sample files) would then be swapped. Confirm against
# the contents of your data folder.
train_ds = HorseZebraDataset(
    zebra=os.path.join(train_dir, "trainA"),
    horse=os.path.join(train_dir, "trainB"),
    transform=transforms,
)

# Pinned memory only speeds up host -> CUDA copies; on MPS/CPU PyTorch ignores
# the flag (with a warning), so enable it only for CUDA.
train_dl = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    pin_memory=device.type == "cuda",
)

In [177]:
def seed_everything(seed=45):
    """Seed every random-number generator this notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's default generator
    and, when available, the MPS generator, and exports ``PYTHONHASHSEED``.

    Note: Python reads ``PYTHONHASHSEED`` only at interpreter start-up, so
    setting it here affects child processes started afterwards, not the
    string hashing of the already-running kernel.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)


seed_everything()

In [178]:
# Adam settings shared by both optimizers.
LEARNING_RATE = 2e-4
ADAM_BETAS = (0.5, 0.999)

# One discriminator and one generator per domain. In the training loop
# h_gen maps zebra -> horse and z_gen maps horse -> zebra.
h_disc = Discriminator(channels=3).to(device)
z_disc = Discriminator(channels=3).to(device)
h_gen = Generator(img_channels=3, num_residuals=9).to(device)
z_gen = Generator(img_channels=3, num_residuals=9).to(device)

# Both discriminators share one optimizer; both generators share another.
opt_disc = torch.optim.Adam(
    list(h_disc.parameters()) + list(z_disc.parameters()),
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
)

opt_gen = torch.optim.Adam(
    list(z_gen.parameters()) + list(h_gen.parameters()),
    lr=LEARNING_RATE,
    betas=ADAM_BETAS,
)

# torch.cuda.amp.GradScaler can only be active on CUDA; elsewhere it warns
# that CUDA is unavailable and disables itself. Request that explicitly so the
# scalers are silent pass-throughs on MPS/CPU. (The training loop uses no
# autocast, so fp32 losses do not need scaling in any case.)
USE_GRAD_SCALER = device.type == "cuda"
g_scaler = torch.cuda.amp.GradScaler(enabled=USE_GRAD_SCALER)
d_scaler = torch.cuda.amp.GradScaler(enabled=USE_GRAD_SCALER)

num_epochs = 10

# Weights of the identity and cycle-consistency terms in the generator loss.
LAMBDA_IDENTITY = 0.0
LAMBDA_CYCLE = 10



In [179]:
# Loss shorthands: MSE for the adversarial (real = 1 / fake = 0) terms and L1
# for the cycle-consistency and identity reconstruction terms.
mse, l1 = f.mse_loss, f.l1_loss

In [180]:
from torchvision.utils import save_image


In [181]:
from tqdm import tqdm



# Where periodic sample images are written; created up front so save_image
# does not fail on a fresh checkout.
HORSE_SAMPLE_DIR = "horses_images"
ZEBRA_SAMPLE_DIR = "zebras_images"
SAMPLE_EVERY = 200
os.makedirs(HORSE_SAMPLE_DIR, exist_ok=True)
os.makedirs(ZEBRA_SAMPLE_DIR, exist_ok=True)

for epoch in range(num_epochs):
    # Running sums of h_disc's mean output on real / generated horses, shown
    # in the progress bar as per-batch averages.
    real_horses = 0.0
    fake_horses = 0.0
    save_count = 0
    loop = tqdm(train_dl, leave=True, desc=f"epoch {epoch + 1}/{num_epochs}")

    for idx, (horse, zebra) in enumerate(loop):
        horse = horse.to(device)
        zebra = zebra.to(device)

        # ---- Discriminator step ------------------------------------------
        # Generated images are detached so this step only updates h_disc/z_disc.
        fake_horse = h_gen(zebra)
        real_horse_disc = h_disc(horse)
        fake_horse_disc = h_disc(fake_horse.detach())
        real_horses += real_horse_disc.mean().item()
        fake_horses += fake_horse_disc.mean().item()
        real_horse_disc_loss = mse(real_horse_disc, torch.ones_like(real_horse_disc))
        fake_horse_disc_loss = mse(fake_horse_disc, torch.zeros_like(fake_horse_disc))
        horse_disc_loss = real_horse_disc_loss + fake_horse_disc_loss

        fake_zebra = z_gen(horse)
        real_zebra_disc = z_disc(zebra)
        fake_zebra_disc = z_disc(fake_zebra.detach())
        real_zebra_disc_loss = mse(real_zebra_disc, torch.ones_like(real_zebra_disc))
        fake_zebra_disc_loss = mse(fake_zebra_disc, torch.zeros_like(fake_zebra_disc))
        zebra_disc_loss = real_zebra_disc_loss + fake_zebra_disc_loss

        disc_loss = (zebra_disc_loss + horse_disc_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step ----------------------------------------------
        # Adversarial loss: generators try to make the discriminators output 1.
        # This backward also writes gradients into the discriminators; they are
        # cleared by opt_disc.zero_grad() before the next discriminator update.
        fake_horse_disc = h_disc(fake_horse)
        fake_zebra_disc = z_disc(fake_zebra)
        horse_gen_loss = mse(fake_horse_disc, torch.ones_like(fake_horse_disc))
        zebra_gen_loss = mse(fake_zebra_disc, torch.ones_like(fake_zebra_disc))

        # Cycle-consistency: zebra -> horse -> zebra and horse -> zebra -> horse
        # should reproduce the original images.
        cycled_zebra = z_gen(fake_horse)  # fake_horse was generated from `zebra`
        cycled_horse = h_gen(fake_zebra)  # fake_zebra was generated from `horse`
        cycled_zebra_loss = l1(zebra, cycled_zebra)
        cycled_horse_loss = l1(horse, cycled_horse)

        gen_loss = (
            horse_gen_loss + zebra_gen_loss
            + (cycled_zebra_loss * LAMBDA_CYCLE)
            + (cycled_horse_loss * LAMBDA_CYCLE)
        )

        # Identity loss: a generator fed an image already in its target domain
        # should leave it unchanged. When LAMBDA_IDENTITY is 0 the term
        # contributes nothing, so skip the two extra generator passes.
        if LAMBDA_IDENTITY != 0:
            zebra_identity_loss = l1(zebra, z_gen(zebra))
            horse_identity_loss = l1(horse, h_gen(horse))
            gen_loss = (
                gen_loss
                + (horse_identity_loss * LAMBDA_IDENTITY)
                + (zebra_identity_loss * LAMBDA_IDENTITY)
            )

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % SAMPLE_EVERY == 0:
            # Undo the [-1, 1] normalization before writing PNGs; the epoch is
            # part of the name so later epochs do not overwrite earlier samples.
            save_image(
                fake_horse * 0.5 + 0.5,
                os.path.join(HORSE_SAMPLE_DIR, f"horse_{epoch}_{idx}_{save_count}.png"),
            )
            save_image(
                fake_zebra * 0.5 + 0.5,
                os.path.join(ZEBRA_SAMPLE_DIR, f"zebra_{epoch}_{idx}_{save_count}.png"),
            )
            save_count += 1

        loop.set_postfix(real_horse=real_horses / (idx + 1), fake_horse=fake_horses / (idx + 1))

  2%|▏         | 6/334 [00:10<09:15,  1.69s/it, fake_horse=0.366, real_horse=0.387]  


KeyboardInterrupt: 