In [2]:
# Common
import os
import numpy as np
from glob import glob
from tqdm import tqdm
from random import random

import matplotlib.pyplot as plt

In [3]:
import torch
from PIL import Image
from torchvision import transforms

# Show which device type this session can use (rendered as the cell's output).
torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")

device(type='cuda')

In [4]:
# Folders holding the two unpaired image domains (A is loaded as horses, B as zebras).
A_path = "data/trainA"
B_path = "data/trainB"


def _list_jpgs(folder):
    """Return the sorted paths of every ``.jpg`` file directly inside ``folder``."""
    # os.listdir returns entries in arbitrary order; sorting keeps the file order
    # (and therefore which images are loaded later) stable from run to run.
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.endswith(".jpg")
    )


A_paths_join = _list_jpgs(A_path)
B_paths_join = _list_jpgs(B_path)


print(f"Total A Paths: {len(A_paths_join)}")
print(f"Total B Paths: {len(B_paths_join)}")

Total A Paths: 1067
Total B Paths: 1334


In [5]:
SIZE = 256

# Transform = resize + to tensor (ToTensor also rescales pixel values to [0,1])
transform = transforms.Compose([
    transforms.Resize((SIZE, SIZE)),
    transforms.ToTensor(),   # (H,W,C) → (C,H,W) + /255 automatically
])
# NOTE(review): images stay in [0,1] while Generator ends in Tanh ([-1,1]) —
# confirm whether a [-1,1] normalisation step is intended here.


def _load_image_tensor(path):
    """Open ``path``, force RGB and return it as a [3, SIZE, SIZE] float tensor."""
    # The context manager releases the file handle once the image is converted.
    with Image.open(path) as img:
        return transform(img.convert("RGB"))


# zip() below stops at the shorter list, so both tensors are sized by the number
# of pairs that are actually loaded; sizing by len(A_paths_join) alone would
# leave silent all-zero (blank) rows whenever B has fewer files than A.
n_pairs = min(len(A_paths_join), len(B_paths_join))

horse_images = torch.zeros(n_pairs, 3, SIZE, SIZE)
zebra_images = torch.zeros(n_pairs, 3, SIZE, SIZE)

for i, (horse_path, zebra_path) in tqdm(enumerate(zip(A_paths_join, B_paths_join)), total=n_pairs, desc="Loading"):
    horse_images[i] = _load_image_tensor(horse_path)   # [3,256,256] float32 0–1
    zebra_images[i] = _load_image_tensor(zebra_path)

Loading: 100%|██████████| 1067/1067 [00:05<00:00, 200.63it/s]


In [6]:
# dataset = [horse_images, zebra_images]

In [7]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Residual unit: conv → instance-norm → ReLU → conv → instance-norm, plus a skip.

    Both convolutions are 3×3 with padding 1 and no bias, so the channel count
    and spatial size of the input are preserved and the result can be added
    back onto the input.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # Shape-preserving 3×3 convolution without bias.
            return nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

        self.block = nn.Sequential(
            conv3x3(),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the two-conv branch."""
        residual = self.block(x)
        return x + residual


In [8]:
import torch.nn.functional as F

In [9]:
class Downsample(nn.Module):
    """Strided convolution stage: conv → optional instance-norm → activation.

    Args:
        in_channels: channels of the incoming feature map.
        filters: channels produced by the convolution.
        size: square kernel size; the padding used is ``size // 2``.
        stride: convolution stride.
        norm: when truthy, an affine ``InstanceNorm2d`` follows the convolution.
        activation: module applied last; ``LeakyReLU(0.2)`` (in-place) when None.
    """

    def __init__(self, in_channels, filters, size=3, stride=2, norm=True, activation=None):
        super().__init__()

        self.conv = nn.Conv2d(
            in_channels,
            filters,
            kernel_size=size,
            stride=stride,
            padding=size // 2,
            bias=False,
        )

        if norm:
            self.norm = nn.InstanceNorm2d(filters, affine=True)
        else:
            self.norm = None

        # Fall back to the default leaky ReLU only when no activation was supplied.
        self.act = nn.LeakyReLU(0.2, inplace=True) if activation is None else activation

    def forward(self, x):
        """Apply the convolution, the normalisation (if configured) and the activation."""
        out = self.conv(x)
        if self.norm is not None:
            out = self.norm(out)
        return self.act(out)


In [10]:
class Upsample(nn.Module):
    """Transposed-convolution stage: convT → affine instance-norm → ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        filters: channels produced by the transposed convolution.
        size: square kernel size; the padding used is ``size // 2``.
        stride: transposed-convolution stride; ``output_padding`` is ``stride - 1``.
    """

    def __init__(self, in_channels, filters, size=3, stride=2):
        super().__init__()

        self.convT = nn.ConvTranspose2d(
            in_channels,
            filters,
            kernel_size=size,
            stride=stride,
            padding=size // 2,
            output_padding=stride - 1,
            bias=False,
        )
        self.norm = nn.InstanceNorm2d(filters, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the three layers in order and return the result."""
        for layer in (self.convT, self.norm, self.relu):
            x = layer(x)
        return x


In [11]:
import torch.optim as optim


In [12]:
class Generator(nn.Module):
    """Encoder → residual bottleneck → decoder generator with a tanh output.

    The output convolution always produces 3 channels, independent of
    ``in_channels``. Spatial sizes in the comments assume a 256×256 input.

    Args:
        in_channels: channels of the input image.
        n_resnet: number of ``ResidualBlock`` layers in the bottleneck.
    """

    def __init__(self, in_channels=3, n_resnet=9):
        super().__init__()

        # Encoder: a 7×7 stride-1 stem followed by two stride-2 stages.
        self.down1 = Downsample(in_channels, 64, size=7, stride=1)   # 256×256
        self.down2 = Downsample(64, 128)                            # 128×128
        self.down3 = Downsample(128, 256)                           # 64×64

        # Bottleneck: n_resnet residual blocks at 256 channels.
        bottleneck = [ResidualBlock(256) for _ in range(n_resnet)]
        self.resblocks = nn.Sequential(*bottleneck)

        # Decoder: two stride-2 transposed-conv stages back to the input size.
        self.up1 = Upsample(256, 128)                               # 128×128
        self.up2 = Upsample(128, 64)                                # 256×256

        # Output head: 7×7 conv to 3 channels, squashed to [-1, 1] by tanh.
        self.out = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        """Pass ``x`` through encoder, bottleneck, decoder and output head."""
        pipeline = (
            self.down1, self.down2, self.down3,
            self.resblocks,
            self.up1, self.up2,
            self.out,
        )
        for stage in pipeline:
            x = stage(x)
        return x

In [13]:
class Discriminator(nn.Module):
    """PatchGAN-style critic: five stride-2 ``Downsample`` stages and a 1-channel conv.

    Every stage uses a 4×4 kernel; only the first stage skips normalisation.
    With the padding ``Downsample`` applies (``size // 2``), a 256×256 input
    yields an 8×8 map of patch scores.

    Args:
        in_channels: channels of the image being judged.
    """

    def __init__(self, in_channels=3):
        super().__init__()

        stage_widths = (64, 128, 256, 512, 512)

        layers = []
        prev_width = in_channels
        for idx, width in enumerate(stage_widths):
            # Only the very first stage runs without instance normalisation.
            layers.append(Downsample(prev_width, width, size=4, stride=2, norm=(idx > 0)))
            prev_width = width

        layers.append(nn.Conv2d(prev_width, 1, kernel_size=4, stride=1, padding=1))  # PatchGAN

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-squashed) patch score map for ``x``."""
        scores = self.model(x)
        return scores


In [14]:
class CombineModel(nn.Module):
    """Composite model that updates generator ``G1`` against a fixed ``G2`` and ``D``.

    ``G2`` and ``D`` are shared with other parts of the training code, so they
    are frozen only for the duration of ``train_step`` and their previous
    ``requires_grad`` flags are restored afterwards. Freezing them permanently
    in ``__init__`` would make it impossible to build the mirrored composite
    (where the roles of the two generators are swapped) or to keep training
    the discriminator.

    Args:
        g_model1: generator being trained (A→B).
        g_model2: generator used only for the cycle terms (B→A).
        d_model: discriminator judging the output of ``g_model1``.
        lr: Adam learning rate for ``g_model1``.
    """

    def __init__(self, g_model1, g_model2, d_model, lr=2e-4):
        super().__init__()

        # Generator A→B
        self.G1 = g_model1
        # Generator B→A
        self.G2 = g_model2
        # Discriminator for B
        self.D  = d_model

        # Losses (same as Keras)
        self.loss_gan = nn.MSELoss()   # mse
        self.loss_l1  = nn.L1Loss()    # mae

        # Only G1's parameters are ever stepped by this optimizer.
        self.opt = optim.Adam(self.G1.parameters(), lr=lr, betas=(0.5, 0.999))

    def _params_to_freeze(self):
        """Return the parameters of G2 and D that are not also parameters of G1."""
        own = {id(p) for p in self.G1.parameters()}
        return [
            p
            for module in (self.G2, self.D)
            for p in module.parameters()
            if id(p) not in own
        ]

    def forward(self, input_gen, input_id):
        """Compute the four outputs the composite loss is built from.

        Args:
            input_gen: batch from the source domain (fed to ``G1``).
            input_id: batch from the target domain.

        Returns:
            Tuple ``(dis_out, output_id, output_f, output_b)``: discriminator
            scores for ``G1(input_gen)``, ``G1(input_id)``, the forward cycle
            ``G2(G1(input_gen))`` and the backward cycle ``G1(G2(input_id))``.
        """
        # Adversarial
        gen_1_out = self.G1(input_gen)
        dis_out = self.D(gen_1_out)

        # Identity
        output_id = self.G1(input_id)

        # Cycle forward
        output_f = self.G2(gen_1_out)

        # Cycle backward
        gen_2_out = self.G2(input_id)
        output_b = self.G1(gen_2_out)

        return dis_out, output_id, output_f, output_b

    def train_step(self, input_gen, input_id):
        """Run one optimizer step on ``G1`` and return the loss terms as floats.

        Returns:
            Dict with keys ``total``, ``adv``, ``id``, ``fwd`` and ``bwd``.
        """
        frozen = self._params_to_freeze()
        previous_flags = [p.requires_grad for p in frozen]
        for p in frozen:
            p.requires_grad_(False)

        try:
            self.opt.zero_grad()

            dis_out, output_id, output_f, output_b = self.forward(input_gen, input_id)

            valid = torch.ones_like(dis_out)

            # Loss weights: adversarial 1, identity 5, both cycle terms 10.
            loss_adv = self.loss_gan(dis_out, valid) * 1
            loss_id  = self.loss_l1(output_id, input_id) * 5
            loss_fwd = self.loss_l1(output_f, input_gen) * 10
            loss_bwd = self.loss_l1(output_b, input_id) * 10

            loss = loss_adv + loss_id + loss_fwd + loss_bwd
            loss.backward()
            self.opt.step()
        finally:
            # Hand G2 and D back exactly as they were, even if the step failed.
            for p, flag in zip(frozen, previous_flags):
                p.requires_grad_(flag)

        return {
            "total": loss.item(),
            "adv": loss_adv.item(),
            "id": loss_id.item(),
            "fwd": loss_fwd.item(),
            "bwd": loss_bwd.item()
        }


In [15]:
def generate_real_samples(n_samples, dataset, device="cuda"):
    """Draw ``n_samples`` random images (with replacement) and matching "real" labels.

    Args:
        n_samples: number of images to draw.
        dataset: tensor whose first dimension indexes the images.
        device: accepted but not used; outputs stay on ``dataset.device``.

    Returns:
        ``(X, y)`` where ``X`` holds the selected images and ``y`` is an
        all-ones tensor of shape ``(n_samples, 1, 8, 8)``.
    """
    home = dataset.device

    # Uniformly random row indices into the dataset.
    picks = torch.randint(0, dataset.shape[0], (n_samples,), device=home)
    real_images = dataset.index_select(0, picks)

    # Target value 1 marks every patch as "real".
    real_labels = torch.ones(n_samples, 1, 8, 8, device=home)

    return real_images, real_labels


def generate_fake_samples(g_model, dataset):
    """Run ``g_model`` on ``dataset`` without gradients and pair it with "fake" labels.

    Args:
        g_model: callable mapping an image batch to a generated batch.
        dataset: input batch; its first dimension is the batch size.

    Returns:
        ``(X, y)`` where ``X`` is the generated batch (detached from autograd)
        and ``y`` is an all-zeros tensor of shape ``(batch, 1, 8, 8)`` on
        ``dataset.device``.
    """
    batch_size = dataset.shape[0]

    # Target value 0 marks every patch as "fake".
    fake_labels = torch.zeros(batch_size, 1, 8, 8, device=dataset.device)

    # Inference only: no autograd graph is recorded for the generator pass.
    with torch.no_grad():
        fake_images = g_model(dataset)

    return fake_images, fake_labels


In [16]:
import random

def update_image_pool(pool, images, max_size=50):
    """Mix freshly generated images with older ones kept in a history pool.

    For each image in ``images``: while the pool holds fewer than ``max_size``
    entries the image is stored and used as-is; afterwards it is used directly
    with probability 0.5, otherwise a random pooled image is returned in its
    place and the new image takes that slot.

    Args:
        pool: list of ``[1, C, H, W]`` tensors; modified in place.
        images: batch of images, ``[N, C, H, W]``.
        max_size: maximum number of images kept in ``pool``.

    Returns:
        Tensor with one selected image per input image, concatenated on dim 0.
        An empty batch is returned unchanged.
    """
    # Imported locally on purpose: this notebook binds the name `random` both to
    # the function (`from random import random`) and to the module
    # (`import random`), so the global name depends on cell execution order.
    import random as _random

    # torch.cat() rejects an empty list, so short-circuit an empty batch.
    if images.size(0) == 0:
        return images

    selected = []

    for i in range(images.size(0)):
        image = images[i].unsqueeze(0)   # keep batch dim [1,C,H,W]

        if len(pool) < max_size:
            # Pool still filling up: remember the image and use it directly.
            pool.append(image)
            selected.append(image)

        elif _random.random() < 0.5:
            # Use the new image and leave the pool untouched.
            selected.append(image)

        else:
            # Swap: hand out an older image and store the new one in its slot.
            ix = _random.randint(0, len(pool) - 1)
            selected.append(pool[ix])
            pool[ix] = image

    return torch.cat(selected, dim=0)

In [17]:
import matplotlib.pyplot as plt

def _module_device(model):
    """Return the device of ``model``'s first parameter (CPU if it has none)."""
    first = next(model.parameters(), None)
    return first.device if first is not None else torch.device("cpu")


def show_preds(g_AB, g_BA, n_images=1):
    """Plot random horse/zebra samples next to their translations.

    For each of ``n_images`` figures a random index is drawn, the horse at that
    index is passed through ``g_AB`` and the zebra through ``g_BA``, and the
    four images are shown side by side via ``show_image``.

    Args:
        g_AB: horse → zebra generator.
        g_BA: zebra → horse generator.
        n_images: number of figures to draw.

    Reads the module-level ``horse_images`` / ``zebra_images`` tensors and the
    ``show_image`` helper, which must be defined before this is called.
    """
    # The image tensors live on the CPU while the generators may have been moved
    # to another device, so inputs are moved to each model's device first.
    dev_AB = _module_device(g_AB)
    dev_BA = _module_device(g_BA)

    # Remember the current modes so training can resume unchanged afterwards.
    was_training_AB = g_AB.training
    was_training_BA = g_BA.training

    g_AB.eval()
    g_BA.eval()

    try:
        with torch.no_grad():
            for i in range(n_images):

                idx = np.random.randint(len(horse_images))
                horse = horse_images[idx]     # [C,H,W]
                zebra = zebra_images[idx]

                # Add batch dimension
                horse_in = horse.unsqueeze(0).to(dev_AB)   # [1,C,H,W]
                zebra_in = zebra.unsqueeze(0).to(dev_BA)

                # Generate, then bring the results back to the CPU for plotting
                zebra_pred = g_AB(horse_in)[0].cpu()
                horse_pred = g_BA(zebra_in)[0].cpu()

                plt.figure(figsize=(10,8))

                plt.subplot(1,4,1)
                show_image(horse, title='Original Horse')

                plt.subplot(1,4,2)
                show_image(zebra_pred, title='Generated Zebra')

                plt.subplot(1,4,3)
                show_image(zebra, title='Original Zebra')

                plt.subplot(1,4,4)
                show_image(horse_pred, title='Generated Horse')

                plt.tight_layout()
                plt.show()
    finally:
        g_AB.train(was_training_AB)
        g_BA.train(was_training_BA)


In [18]:
def train(
    d_model_A, d_model_B,
    gen_AB, gen_BA,
    opt_g, opt_dA, opt_dB,
    trainA, trainB,
    epochs=100, chunk=5,
    device=None
):
    """Train both generators and both discriminators with batch size 1.

    Each epoch runs ``len(trainA)`` iterations. One iteration:

    1. draws one random image from each domain,
    2. generates fakes (without gradients) and passes them through the image pools,
    3. takes an ``opt_g`` step on the B→A loss (adversarial via ``d_model_A``,
       identity ×5, both cycle terms ×10),
    4. takes an ``opt_dA`` step on real A vs. pooled fake A,
    5. takes an ``opt_g`` step on the A→B loss (adversarial via ``d_model_B``),
    6. takes an ``opt_dB`` step on real B vs. pooled fake B.

    Every ``chunk`` epochs a preview is plotted with ``show_preds`` and both
    generator state dicts are saved to ``GeneratorHtoZ.pt`` / ``GeneratorZtoH.pt``.

    Args:
        d_model_A, d_model_B: discriminators for domain A and domain B.
        gen_AB, gen_BA: generators A→B and B→A.
        opt_g: optimizer stepped after each generator loss.
        opt_dA, opt_dB: optimizers for the two discriminators.
        trainA, trainB: image tensors indexed along dim 0.
        epochs: number of epochs to run.
        chunk: preview/checkpoint interval in epochs.
        device: device the sampled batches are moved to. ``None`` (default)
            uses the device of ``gen_AB``'s parameters, so the function also
            runs on machines without CUDA.
    """
    if device is None:
        device = next(gen_AB.parameters()).device

    mse = torch.nn.MSELoss()
    l1  = torch.nn.L1Loss()

    poolA, poolB = [], []

    n_batch = 1
    bat_per_epoch = len(trainA)

    for epoch in tqdm(range(1, epochs+1), desc="Epochs"):

        # show_preds() switches the generators to eval mode; make sure every
        # model is back in training mode at the start of each epoch.
        for model in (gen_AB, gen_BA, d_model_A, d_model_B):
            model.train()

        for i in range(bat_per_epoch):

            # -----------------------
            #  Real samples
            # -----------------------
            X_realA, y_realA = generate_real_samples(n_batch, trainA)
            X_realB, y_realB = generate_real_samples(n_batch, trainB)

            X_realA = X_realA.to(device)
            X_realB = X_realB.to(device)
            y_realA = y_realA.to(device)
            y_realB = y_realB.to(device)

            # -----------------------
            #  Generate fake images
            # -----------------------
            X_fakeA, y_fakeA = generate_fake_samples(gen_BA, X_realB)
            X_fakeB, y_fakeB = generate_fake_samples(gen_AB, X_realA)

            # The discriminators see a mix of current and older fakes.
            X_fakeA = update_image_pool(poolA, X_fakeA)
            X_fakeB = update_image_pool(poolB, X_fakeB)

            # -----------------------
            #  Train Generator BA
            # -----------------------
            opt_g.zero_grad()

            # adversarial: fake A should be scored as real by D_A
            fakeA = gen_BA(X_realB)
            pred_fake = d_model_A(fakeA)
            adv_loss = mse(pred_fake, y_realA)

            # identity
            idA = gen_BA(X_realA)
            id_loss = l1(idA, X_realA) * 5

            # cycle
            recB = gen_AB(fakeA)
            cycle_loss = l1(recB, X_realB) * 10

            # backward cycle
            fakeB = gen_AB(X_realA)
            recA = gen_BA(fakeB)
            back_loss = l1(recA, X_realA) * 10

            gen_loss_BA = adv_loss + id_loss + cycle_loss + back_loss
            gen_loss_BA.backward()
            opt_g.step()

            # -----------------------
            #  Train Discriminator A
            # -----------------------
            # zero_grad() also discards gradients D_A received during the generator step
            opt_dA.zero_grad()
            loss_real = mse(d_model_A(X_realA), y_realA)
            loss_fake = mse(d_model_A(X_fakeA.detach()), y_fakeA)
            dA_loss = (loss_real + loss_fake) * 0.5
            dA_loss.backward()
            opt_dA.step()

            # -----------------------
            #  Train Generator AB
            # -----------------------
            opt_g.zero_grad()

            # adversarial: fake B should be scored as real by D_B
            fakeB = gen_AB(X_realA)
            pred_fake = d_model_B(fakeB)
            adv_loss = mse(pred_fake, y_realB)

            # identity
            idB = gen_AB(X_realB)
            id_loss = l1(idB, X_realB) * 5

            # cycle
            recA = gen_BA(fakeB)
            cycle_loss = l1(recA, X_realA) * 10

            # backward cycle
            fakeA = gen_BA(X_realB)
            recB = gen_AB(fakeA)
            back_loss = l1(recB, X_realB) * 10

            gen_loss_AB = adv_loss + id_loss + cycle_loss + back_loss
            gen_loss_AB.backward()
            opt_g.step()

            # -----------------------
            #  Train Discriminator B
            # -----------------------
            opt_dB.zero_grad()
            loss_real = mse(d_model_B(X_realB), y_realB)
            loss_fake = mse(d_model_B(X_fakeB.detach()), y_fakeB)
            dB_loss = (loss_real + loss_fake) * 0.5
            dB_loss.backward()
            opt_dB.step()

        # -------------------------------------
        #  Visualization & Save
        # -------------------------------------
        if epoch % chunk == 0:
            show_preds(gen_AB, gen_BA, n_images=1)
            torch.save(gen_AB.state_dict(), "GeneratorHtoZ.pt")
            torch.save(gen_BA.state_dict(), "GeneratorZtoH.pt")


In [19]:
# Prefer the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Generators: g_AB maps Horse → Zebra, g_BA maps Zebra → Horse
g_AB, g_BA = Generator().to(device), Generator().to(device)

# Discriminators: d_A judges Horses, d_B judges Zebras
d_A, d_B = Discriminator().to(device), Discriminator().to(device)

In [20]:
# Adam hyper-parameters shared by all three optimizers.
lr = 2e-4
beta1 = 0.5
adam_betas = (beta1, 0.999)

# One optimizer updates both generators together.
generator_params = [*g_AB.parameters(), *g_BA.parameters()]
opt_g = optim.Adam(generator_params, lr=lr, betas=adam_betas)

# Each discriminator gets its own optimizer.
opt_dA = optim.Adam(d_A.parameters(), lr=lr, betas=adam_betas)
opt_dB = optim.Adam(d_B.parameters(), lr=lr, betas=adam_betas)

In [21]:
# Copy both image stacks ([N,3,256,256] each) onto the training device.
trainA, trainB = (images.to(device) for images in (horse_images, zebra_images))

In [24]:
# Pass the notebook's `device` explicitly so the sampled batches land on the
# same device as the models, including on machines without CUDA.
train(
    d_model_A=d_A,
    d_model_B=d_B,
    gen_AB=g_AB,
    gen_BA=g_BA,
    opt_g=opt_g,
    opt_dA=opt_dA,
    opt_dB=opt_dB,
    trainA=trainA,
    trainB=trainB,
    epochs=10,
    chunk=5,
    device=device
)


Epochs:   0%|          | 0/10 [12:24<?, ?it/s]


KeyboardInterrupt: 