In [None]:
# Environment setup: install the third-party packages imported in the next cell
# (torch/torchvision for models and transforms, matplotlib for plotting,
# Pillow for image loading, scikit-learn for the train/validation split).
# NOTE(review): `%pip install` targets the running kernel's environment more
# reliably than `!pip`, and no versions are pinned here; pin them for reproducible runs.

# Install PyTorch and Torchvision (if not pre-installed)
!pip install torch torchvision

# Install additional libraries
!pip install matplotlib Pillow scikit-learn


In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [None]:
# Dataset Definition
class TreeDataset(Dataset):
    """Pairs of (undecorated, decorated) tree images read from two directories.

    Each directory listing is sorted by file name, and files are paired by
    position: item ``i`` is the i-th undecorated image together with the i-th
    decorated image. If the directories hold different numbers of files, both
    listings are truncated to the shorter length so every index is valid in both.

    Args:
        normal_dir: Directory containing undecorated tree images.
        decorated_dir: Directory containing decorated tree images.
        indices: Optional positions (into the sorted, truncated listings) to keep,
            e.g. one side of a train/validation split. ``None`` keeps every pair.
        transform: Optional callable applied to each RGB PIL image separately.
    """

    def __init__(self, normal_dir, decorated_dir, indices=None, transform=None):
        normal_paths = sorted(os.path.join(normal_dir, name) for name in os.listdir(normal_dir))
        decorated_paths = sorted(os.path.join(decorated_dir, name) for name in os.listdir(decorated_dir))

        # Keep only positions present in both listings; otherwise __len__ (which
        # follows the normal list) could exceed the decorated list and __getitem__
        # would raise IndexError.
        pair_count = min(len(normal_paths), len(decorated_paths))
        self.normal_paths = normal_paths[:pair_count]
        self.decorated_paths = decorated_paths[:pair_count]

        if indices is not None:
            self.normal_paths = [self.normal_paths[i] for i in indices]
            self.decorated_paths = [self.decorated_paths[i] for i in indices]
        self.transform = transform

    def __len__(self):
        return len(self.normal_paths)

    def __getitem__(self, idx):
        """Return ``(normal_image, decorated_image)`` for pair ``idx``."""
        return self._load(self.normal_paths[idx]), self._load(self.decorated_paths[idx])

    def _load(self, path):
        """Read ``path`` as an RGB image and apply ``self.transform`` if one is set."""
        # The context manager closes the file handle; convert() returns an independent image.
        with Image.open(path) as image:
            rgb = image.convert('RGB')
        if self.transform is not None:
            rgb = self.transform(rgb)
        return rgb


In [None]:
# Residual Block Definition
class ResidualBlock(nn.Module):
    """Two same-size 3x3 convolutions wrapped in an identity skip connection.

    Computes ``x + F(x)`` where ``F`` is conv -> InstanceNorm -> ReLU -> conv ->
    InstanceNorm, all with ``channels`` feature maps, so the output has the same
    shape as the input.
    """

    def __init__(self, channels):
        super().__init__()

        def conv3x3():
            # stride 1 with padding 1 keeps the spatial size unchanged
            return nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)

        layers = [conv3x3(), nn.InstanceNorm2d(channels), nn.ReLU(inplace=True)]
        layers += [conv3x3(), nn.InstanceNorm2d(channels)]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.block(x)
        return x + residual

In [None]:
# Generator with Residual Blocks
class GeneratorResNet(nn.Module):
    """ResNet-style image-to-image generator.

    Stages:
      * ``initial``: 7x7 conv to 64 maps + InstanceNorm + ReLU
      * ``downsampling``: two stride-2 3x3 convs (64 -> 128 -> 256 maps)
      * ``residuals``: ``num_residuals`` ResidualBlocks at 256 maps
      * ``upsampling``: two stride-2 transposed convs (256 -> 128 -> 64 maps)
      * ``output``: 7x7 conv to ``output_channels`` + Tanh, so outputs lie in (-1, 1)

    Height and width are preserved when they are divisible by 4.
    """

    def __init__(self, input_channels=3, output_channels=3, num_residuals=9):
        super().__init__()

        def norm_relu(width):
            return [nn.InstanceNorm2d(width), nn.ReLU(inplace=True)]

        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=7, stride=1, padding=3),
            *norm_relu(64),
        )

        down_layers = []
        for in_ch, out_ch in ((64, 128), (128, 256)):
            down_layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1))
            down_layers.extend(norm_relu(out_ch))
        self.downsampling = nn.Sequential(*down_layers)

        self.residuals = nn.Sequential(*(ResidualBlock(256) for _ in range(num_residuals)))

        up_layers = []
        for in_ch, out_ch in ((256, 128), (128, 64)):
            up_layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)
            )
            up_layers.extend(norm_relu(out_ch))
        self.upsampling = nn.Sequential(*up_layers)

        self.output = nn.Sequential(
            nn.Conv2d(64, output_channels, kernel_size=7, stride=1, padding=3),
            nn.Tanh(),
        )

    def forward(self, x):
        for stage in (self.initial, self.downsampling, self.residuals, self.upsampling, self.output):
            x = stage(x)
        return x



In [None]:
# Discriminator Definition (PatchGAN)
class Discriminator(nn.Module):
    """PatchGAN discriminator that outputs a grid of unnormalised real/fake scores.

    Layout: a stride-2 4x4 conv to 64 maps (no norm), then 4x4 convs with
    InstanceNorm + LeakyReLU(0.2) to 128 (stride 2), 256 (stride 2) and 512
    (stride 1) maps, then a stride-1 4x4 projection to one channel. No sigmoid is
    applied to the output.
    """

    def __init__(self, input_channels=3):
        super().__init__()

        def conv4(in_ch, out_ch, stride):
            return nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=1)

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        layers = [conv4(input_channels, 64, 2), leaky()]
        for in_ch, out_ch, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            layers += [conv4(in_ch, out_ch, stride), nn.InstanceNorm2d(out_ch), leaky()]
        layers.append(conv4(512, 1, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


In [None]:
def _set_requires_grad(modules, flag):
    """Enable or disable gradient tracking for every parameter of ``modules``."""
    for module in modules:
        for param in module.parameters():
            param.requires_grad_(flag)


def _lsgan_loss(criterion, prediction, target_is_real):
    """Least-squares GAN loss of ``prediction`` against an all-ones (real) or all-zeros (fake) target."""
    target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
    return criterion(prediction, target)


def _to_display(image):
    """Convert one CHW tensor in [-1, 1] to an HWC numpy array in [0, 1] for ``imshow``."""
    return ((image.permute(1, 2, 0).cpu() + 1) / 2).clamp(0, 1).numpy()


def _show_val_samples(generator_g, generator_f, val_loader, device, num_batches):
    """Plot real and translated images for the first sample of the first ``num_batches`` validation batches."""
    generator_g.eval()
    generator_f.eval()
    with torch.no_grad():
        for batch_idx, (normal, decorated) in enumerate(val_loader):
            if batch_idx >= num_batches:
                break
            normal = normal.to(device)
            decorated = decorated.to(device)
            fake_decorated = generator_g(normal)
            fake_normal = generator_f(decorated)

            panels = [
                (normal[0], "Normal Tree"),
                (decorated[0], "Decorated Tree"),
                (fake_decorated[0], "Fake Decorated Tree"),
                (fake_normal[0], "Fake Normal Tree"),
            ]
            fig, axs = plt.subplots(1, len(panels), figsize=(15, 5))
            for ax, (image, title) in zip(axs, panels):
                ax.imshow(_to_display(image))
                ax.set_title(title)
                ax.axis("off")
            plt.show()
            plt.close(fig)  # avoid accumulating open figures across epochs


def train_cyclegan(generator_g, generator_f, discriminator_x, discriminator_y, train_loader, val_loader, device,
                   num_epochs=100, lr=1e-4, lambda_cycle=10.0, sample_every=10, num_sample_batches=3):
    """Train a CycleGAN between undecorated ("normal") and decorated tree images.

    Model roles (the discriminator roles match how each one is trained below):
      * ``generator_g``: normal -> decorated
      * ``generator_f``: decorated -> normal
      * ``discriminator_x``: real decorated images vs. ``generator_g`` outputs
      * ``discriminator_y``: real normal images vs. ``generator_f`` outputs

    Each batch runs one joint generator update, then one update of each
    discriminator. The generator loss is a least-squares adversarial loss for
    both directions plus ``lambda_cycle`` times the L1 cycle-consistency loss.

    Args:
        generator_g, generator_f, discriminator_x, discriminator_y: Models, assumed
            to already be on ``device``.
        train_loader: Iterable of ``(normal, decorated)`` image batches.
        val_loader: Iterable of ``(normal, decorated)`` batches used only for plots.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate for all four models.
        lambda_cycle: Weight of the cycle-consistency term.
        sample_every: Plot validation samples every this many epochs (<= 0 disables).
        num_sample_batches: Number of validation batches plotted each time.

    Returns:
        Dict mapping ``"loss_g"``, ``"loss_d_x"`` and ``"loss_d_y"`` to lists of
        per-epoch mean batch losses.
    """
    opt_g = torch.optim.Adam(list(generator_g.parameters()) + list(generator_f.parameters()), lr=lr, betas=(0.5, 0.999))
    opt_d_x = torch.optim.Adam(discriminator_x.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_d_y = torch.optim.Adam(discriminator_y.parameters(), lr=lr, betas=(0.5, 0.999))

    mse_loss = nn.MSELoss()
    l1_loss = nn.L1Loss()
    discriminators = (discriminator_x, discriminator_y)
    history = {"loss_g": [], "loss_d_x": [], "loss_d_y": []}

    for epoch in range(num_epochs):
        for model in (generator_g, generator_f, discriminator_x, discriminator_y):
            model.train()

        total_loss_g, total_loss_d_x, total_loss_d_y = 0.0, 0.0, 0.0
        num_batches = 0

        for normal, decorated in train_loader:
            normal = normal.to(device)
            decorated = decorated.to(device)

            # ---- Generators ----
            # Freeze the discriminators so the generator backward pass does not
            # compute (and later discard) gradients for their parameters.
            _set_requires_grad(discriminators, False)
            opt_g.zero_grad()
            fake_decorated = generator_g(normal)
            fake_normal = generator_f(decorated)

            # Each fake is scored by the discriminator trained on its target domain,
            # and both translation directions receive an adversarial term.
            loss_g_adv = (
                _lsgan_loss(mse_loss, discriminator_x(fake_decorated), True)
                + _lsgan_loss(mse_loss, discriminator_y(fake_normal), True)
            )
            cycle_normal = generator_f(fake_decorated)
            cycle_decorated = generator_g(fake_normal)
            loss_cycle = l1_loss(cycle_normal, normal) + l1_loss(cycle_decorated, decorated)

            loss_g = loss_g_adv + lambda_cycle * loss_cycle
            loss_g.backward()
            opt_g.step()

            # ---- Discriminators ----
            _set_requires_grad(discriminators, True)
            opt_d_x.zero_grad()
            opt_d_y.zero_grad()

            # Fakes are detached so discriminator losses do not backpropagate into the generators.
            loss_d_x = 0.5 * (
                _lsgan_loss(mse_loss, discriminator_x(decorated), True)
                + _lsgan_loss(mse_loss, discriminator_x(fake_decorated.detach()), False)
            )
            loss_d_y = 0.5 * (
                _lsgan_loss(mse_loss, discriminator_y(normal), True)
                + _lsgan_loss(mse_loss, discriminator_y(fake_normal.detach()), False)
            )
            loss_d_x.backward()
            loss_d_y.backward()
            opt_d_x.step()
            opt_d_y.step()

            total_loss_g += loss_g.item()
            total_loss_d_x += loss_d_x.item()
            total_loss_d_y += loss_d_y.item()
            num_batches += 1

        # Report mean per-batch losses so values are comparable across dataset sizes.
        denom = max(num_batches, 1)
        avg_g, avg_d_x, avg_d_y = total_loss_g / denom, total_loss_d_x / denom, total_loss_d_y / denom
        history["loss_g"].append(avg_g)
        history["loss_d_x"].append(avg_d_x)
        history["loss_d_y"].append(avg_d_y)

        if sample_every > 0 and (epoch + 1) % sample_every == 0:
            _show_val_samples(generator_g, generator_f, val_loader, device, num_sample_batches)

        print(f"Epoch [{epoch+1}/{num_epochs}] Loss G: {avg_g:.4f}, Loss D_x: {avg_d_x:.4f}, Loss D_y: {avg_d_y:.4f}")

    return history

In [None]:
# Parameters
batch_size = 4
# Images are resized to image_size x image_size; keep it divisible by 4 so the
# generator's two stride-2 down/upsampling stages restore the input size.
image_size = 128
num_epochs = 300
seed = 42  # seeds torch's RNG (weight init, DataLoader shuffling) for reproducible runs
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(seed)

In [None]:
# Transformations for the dataset:
# resize to a fixed square, convert to a float tensor in [0, 1], then shift and
# scale each RGB channel to [-1, 1] to match the generators' Tanh output range.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

In [None]:
# Paths to the dataset directories
normal_tree_dir = "/kaggle/input/christmas-trees/undecorated"
decorated_tree_dir = "/kaggle/input/christmas-trees/decorated"

# Dataset Preparation
normal_images = sorted(os.listdir(normal_tree_dir))
decorated_images = sorted(os.listdir(decorated_tree_dir))

# Images are paired by sorted position, so only the first min(len) files of each
# directory form usable pairs; these positions are what gets split into train/val.
min_dataset_size = min(len(normal_images), len(decorated_images))
if min_dataset_size == 0:
    # Fail early with a clear message instead of an opaque error from the split below.
    raise ValueError(
        f"No image pairs available: {len(normal_images)} files in {normal_tree_dir!r}, "
        f"{len(decorated_images)} files in {decorated_tree_dir!r}"
    )

normal_indices = list(range(min_dataset_size))
decorated_indices = list(range(min_dataset_size))

In [None]:
# Train-validation split (80% / 20%) over the shared pair positions
train_indices, val_indices = train_test_split(
    normal_indices, test_size=0.2, random_state=42
)

# Create dataset objects
train_dataset = TreeDataset(
    normal_dir=normal_tree_dir,
    decorated_dir=decorated_tree_dir,
    indices=train_indices,
    transform=transform,
)

val_dataset = TreeDataset(
    normal_dir=normal_tree_dir,
    decorated_dir=decorated_tree_dir,
    indices=val_indices,
    transform=transform,
)

# DataLoaders.
# Pinned host memory only speeds up host-to-GPU copies (and triggers a warning without
# CUDA), so enable it only on CUDA; cap worker processes at the available CPU count.
use_pinned_memory = device.type == "cuda"
loader_workers = min(4, os.cpu_count() or 1)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=loader_workers, pin_memory=use_pinned_memory,
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,
    num_workers=loader_workers, pin_memory=use_pinned_memory,
)

In [None]:
# Initialize CycleGAN Models
# generator_g: Normal → Decorated, generator_f: Decorated → Normal
generator_g, generator_f = GeneratorResNet().to(device), GeneratorResNet().to(device)
# discriminator_x judges decorated trees, discriminator_y judges normal trees
discriminator_x, discriminator_y = Discriminator().to(device), Discriminator().to(device)

# Training the CycleGAN (trailing ';' keeps any return value from being echoed)
train_cyclegan(
    generator_g=generator_g, generator_f=generator_f,
    discriminator_x=discriminator_x, discriminator_y=discriminator_y,
    train_loader=train_loader, val_loader=val_loader,
    device=device, num_epochs=num_epochs,
);
