# In [20]:
import torch
from torchvision import transforms
from PIL import Image
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image
import itertools
import os

# Compute device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dataset root and checkpoint directory; the latter is created up front so
# later code can rely on it existing.
dataset_path, checkpoint_dir = "./flowers", "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Hyperparameters
batch_size = 8        # raised for more stable training, provided the GPU can hold it
learning_rate = 2e-4
num_epochs = 200
image_size = 256      # images are resized to image_size x image_size

# Preprocessing pipeline: square resize, tensor conversion, then per-channel
# normalisation with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# One ImageFolder per domain, read from the 'A' and 'B' sub-directories of
# dataset_path, each wrapped in a shuffling DataLoader.
dataset_A = datasets.ImageFolder(os.path.join(dataset_path, 'A'), transform=transform)
dataset_B = datasets.ImageFolder(os.path.join(dataset_path, 'B'), transform=transform)
loader_A = DataLoader(dataset=dataset_A, shuffle=True, batch_size=batch_size)
loader_B = DataLoader(dataset=dataset_B, shuffle=True, batch_size=batch_size)

# Generator Model
class Generator(nn.Module):
    """Convolutional image-to-image generator.

    Layer order inside ``self.main`` (an ``nn.Sequential``):

    1. 7x7 convolution from 3 to 64 channels, InstanceNorm, ReLU.
    2. Two stride-2 3x3 convolutions (64 -> 128 -> 256 channels), each
       followed by InstanceNorm and ReLU.
    3. ``num_residual_blocks`` instances of ``ResidualBlock(256)``
       (``ResidualBlock`` is defined elsewhere in this file).
    4. Two stride-2 transposed convolutions (256 -> 128 -> 64 channels),
       each followed by InstanceNorm and ReLU.
    5. 7x7 convolution back to 3 channels, then ``tanh``.

    For inputs whose height and width are divisible by 4, the two stride-2
    stages and the two transposed stages cancel out, so the output has the
    same spatial size as the input.

    Args:
        num_residual_blocks: How many residual blocks to place between the
            downsampling and upsampling stages. Defaults to 12, the value
            that used to be hard-coded. The number of blocks shifts the
            indices of all later layers in ``self.main`` and therefore the
            ``state_dict`` keys, so it must match the value the checkpoint
            being loaded was trained with.

    Raises:
        ValueError: If ``num_residual_blocks`` is negative.
    """

    def __init__(self, num_residual_blocks: int = 12) -> None:
        super().__init__()
        if num_residual_blocks < 0:
            raise ValueError(
                "num_residual_blocks must be non-negative, got "
                f"{num_residual_blocks}"
            )
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            # Downsampling
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(True),
            # Residual blocks (count is configurable; 12 by default)
            *[ResidualBlock(256) for _ in range(num_residual_blocks)],
            # Upsampling
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
            # tanh bounds the output to [-1, 1]
            nn.Tanh()
        )

    def forward(self, x):
        """Run ``x`` through the full layer stack and return the result."""
        return self.main(x)

class ResidualBlock(nn.Module):
    """Residual unit: conv -> InstanceNorm -> ReLU -> conv -> InstanceNorm, plus skip.

    Both convolutions are 3x3 with stride 1 and padding 1 and keep the
    channel count, so the transformed tensor has the same shape as the input
    and the two can be summed.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        # Shared settings for the two shape-preserving convolutions.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        layers = [
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.InstanceNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, **conv_kwargs),
            nn.InstanceNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

# Discriminator Model
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a 1-channel score map.

    Layer order inside ``self.main``:

    * 4x4 stride-2 convolution from 3 to 64 channels + LeakyReLU(0.2)
      (no normalisation on this first layer),
    * three 4x4 convolutions 64 -> 128 -> 256 -> 512 channels (strides 2, 2, 1),
      each followed by InstanceNorm and LeakyReLU(0.2),
    * a final 4x4 stride-1 convolution down to a single channel.

    No activation is applied to the final output.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # (in_channels, out_channels, stride) for the normalised stages.
        stages = ((64, 128, 2), (128, 256, 2), (256, 512, 1))
        for in_ch, out_ch, stride in stages:
            layers.extend([
                nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1),
                nn.InstanceNorm2d(out_ch),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map computed by the layer stack for ``x``."""
        return self.main(x)

# Build the two generators (A->B and B->A) and one discriminator per domain,
# moving each to the selected device.
G_A2B, G_B2A = Generator().to(device), Generator().to(device)
D_A, D_B = Discriminator().to(device), Discriminator().to(device)

# Optimizers: a single Adam instance updates both generators together, while
# each discriminator gets its own. All use the same learning rate and betas.
optimizer_G = optim.Adam(
    itertools.chain(G_A2B.parameters(), G_B2A.parameters()),
    lr=learning_rate,
    betas=(0.5, 0.999),
)
optimizer_D_A = optim.Adam(
    D_A.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
)
optimizer_D_B = optim.Adam(
    D_B.parameters(),
    lr=learning_rate,
    betas=(0.5, 0.999),
)

# Loss functions: mean squared error for the GAN criterion and L1 for the
# cycle criterion.
criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
# Setting up device (repeated here so the inference section picks the device
# itself rather than relying on an earlier assignment).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directories for inference: images are read from inference_dir and the
# translated results are written, under the same file names, to
# inference_output_dir.
inference_dir = "./inference"
inference_output_dir = "./inference_colour"
os.makedirs(inference_output_dir, exist_ok=True)

# Load pre-trained generator.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU run can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source; consider weights_only=True if the installed torch
# version supports it.
G_A2B = Generator().to(device)
G_A2B.load_state_dict(
    torch.load(os.path.join(checkpoint_dir, 'G_A2B_50.pth'), map_location=device)
)
G_A2B.eval()

# Inference Transformation: resize, convert to tensor, normalise each channel
# with mean 0.5 / std 0.5.
inference_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Built once instead of on every loop iteration.
to_pil_image = transforms.ToPILImage()

# Inference Loop (sorted so files are processed in a reproducible order)
for img_name in sorted(os.listdir(inference_dir)):
    img_path = os.path.join(inference_dir, img_name)

    # Directory entries (e.g. notebook checkpoint folders) are not images.
    if not os.path.isfile(img_path):
        continue

    # Open inside a context manager so the underlying file handle is closed
    # as soon as the RGB copy has been made; unreadable / non-image files are
    # reported and skipped instead of aborting the whole run.
    try:
        with Image.open(img_path) as source_image:
            image = source_image.convert('RGB')
    except OSError as err:
        print(f"Skipping {img_name}: {err}")
        continue

    # unsqueeze(0) adds the batch dimension the generator expects.
    input_tensor = inference_transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        colored_image = G_A2B(input_tensor).cpu().squeeze(0)

    # Denormalize (undo the mean 0.5 / std 0.5 normalisation) and save image
    colored_image = to_pil_image(colored_image * 0.5 + 0.5)
    colored_image.save(os.path.join(inference_output_dir, img_name))

print("Inference Complete!")


# Notebook cell output (captured from a run; not part of the code):
#   G_A2B.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'G_A2B_62.pth')))
#
#
# Inference Complete!
