In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
import itertools
import os

# Setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directories for dataset and saving checkpoints
dataset_path = "./flowers"
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Hyperparameters
batch_size = 8  # Increased batch size for more stable training if GPU allows
learning_rate = 0.0002
num_epochs = 200
image_size = 256

# Dataset
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

dataset_A = datasets.ImageFolder(root=os.path.join(dataset_path, 'A'), transform=transform)
dataset_B = datasets.ImageFolder(root=os.path.join(dataset_path, 'B'), transform=transform)
loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True)
loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True)

# Generator Model
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            # Downsampling
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(True),
            # Residual Blocks
            *[ResidualBlock(256) for _ in range(12)],  # Increased number of residual blocks for a larger model
            # Upsampling
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.Tanh()
        )

    def forward(self, x):
        return self.main(x)

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels)
        )

    def forward(self, x):
        return x + self.block(x)

# Discriminator Model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1)
        )

    def forward(self, x):
        return self.main(x)

# Initialize models
G_A2B = Generator().to(device)
G_B2A = Generator().to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)

# Optimizers
optimizer_G = optim.Adam(itertools.chain(G_A2B.parameters(), G_B2A.parameters()), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=learning_rate, betas=(0.5, 0.999))

# Loss functions
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()

# Training Loop
best_loss = float('inf')
for epoch in range(num_epochs):
    for i, (real_A, real_B) in enumerate(zip(loader_A, loader_B)):
        real_A = real_A[0].to(device)
        real_B = real_B[0].to(device)

        # Train Generators
        optimizer_G.zero_grad()

        # GAN Loss A2B
        fake_B = G_A2B(real_A)
        pred_fake = D_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, torch.ones_like(pred_fake).to(device))

        # GAN Loss B2A
        fake_A = G_B2A(real_B)
        pred_fake = D_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, torch.ones_like(pred_fake).to(device))

        # Cycle Loss
        recovered_A = G_B2A(fake_B)
        recovered_B = G_A2B(fake_A)
        loss_cycle_A = criterion_cycle(recovered_A, real_A)
        loss_cycle_B = criterion_cycle(recovered_B, real_B)

        # Total Generator Loss
        loss_G = loss_GAN_A2B + loss_GAN_B2A + 10.0 * (loss_cycle_A + loss_cycle_B)
        loss_G.backward()
        optimizer_G.step()

        # Train Discriminator A
        optimizer_D_A.zero_grad()
        pred_real = D_A(real_A)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real).to(device))
        pred_fake = D_A(fake_A.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake).to(device))
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5
        loss_D_A.backward()
        optimizer_D_A.step()

        # Train Discriminator B
        optimizer_D_B.zero_grad()
        pred_real = D_B(real_B)
        loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real).to(device))
        pred_fake = D_B(fake_B.detach())
        loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake).to(device))
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5
        loss_D_B.backward()
        optimizer_D_B.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{i}] Loss G: {loss_G.item()} Loss D_A: {loss_D_A.item()} Loss D_B: {loss_D_B.item()}")

    # Save checkpoint
    if (epoch + 1) % 10 == 0 or loss_G.item() < best_loss:
        torch.save(G_A2B.state_dict(), f"{checkpoint_dir}/G_A2B_{epoch+1}.pth")
        torch.save(G_B2A.state_dict(), f"{checkpoint_dir}/G_B2A_{epoch+1}.pth")
        best_loss = loss_G.item()

print("Training Complete!")


Epoch [0/200] Batch [0] Loss G: 12.94000244140625 Loss D_A: 0.4972791075706482 Loss D_B: 0.6664077639579773
Epoch [0/200] Batch [100] Loss G: 7.225565433502197 Loss D_A: 0.204205721616745 Loss D_B: 0.21398532390594482
Epoch [0/200] Batch [200] Loss G: 6.966743469238281 Loss D_A: 0.21361631155014038 Loss D_B: 0.18268409371376038
Epoch [0/200] Batch [300] Loss G: 5.600608825683594 Loss D_A: 0.21613283455371857 Loss D_B: 0.49126988649368286
Epoch [0/200] Batch [400] Loss G: 5.61020040512085 Loss D_A: 0.19709672033786774 Loss D_B: 0.1868860125541687
Epoch [1/200] Batch [0] Loss G: 3.9963855743408203 Loss D_A: 0.1904582679271698 Loss D_B: 0.24450989067554474
Epoch [1/200] Batch [100] Loss G: 4.690371036529541 Loss D_A: 0.2334848940372467 Loss D_B: 0.18060357868671417
Epoch [1/200] Batch [200] Loss G: 4.797698020935059 Loss D_A: 0.2057403326034546 Loss D_B: 0.16414861381053925
Epoch [1/200] Batch [300] Loss G: 4.4426703453063965 Loss D_A: 0.1968402862548828 Loss D_B: 0.16190281510353088
Epoc

In [6]:
import torch
from torchvision import transforms
from PIL import Image
import os

# Setting up device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directories for inference
inference_dir = "./inference"
inference_output_dir = "./inference_colour"
os.makedirs(inference_output_dir, exist_ok=True)

# Load pre-trained generator
G_A2B = Generator().to(device)
G_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'G_A2B_1.pth')))
G_A2B.eval()

# Inference Transformation
inference_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Inference Loop
for img_name in os.listdir(inference_dir):
    img_path = os.path.join(inference_dir, img_name)
    image = Image.open(img_path).convert('RGB')
    input_tensor = inference_transform(image).unsqueeze(0).to(device)
    
    with torch.no_grad():
        colored_image = G_A2B(input_tensor).cpu().squeeze(0)
    
    # Denormalize and save image
    colored_image = transforms.ToPILImage()(colored_image * 0.5 + 0.5)  # Denormalize
    colored_image.save(os.path.join(inference_output_dir, img_name))

print("Inference Complete!")


Inference Complete!


  G_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'G_A2B_1.pth')))
