In [None]:
# Install dependencies
!pip install torch torchvision matplotlib

# Download dataset
!mkdir -p data && cd data && curl -O http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
!cd data && tar -xvzf facades.tar.gz

# Import libraries
import os
import torch
import random
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Dataset class
class Pix2PixDataset(Dataset):
    """Dataset of side-by-side image pairs stored as single combined files.

    Every regular, non-hidden file in ``root_dir`` is treated as one image.
    Its left half is returned as the input and its right half as the target.

    Args:
        root_dir: Directory containing the combined image files.
        transform: Optional callable applied to each half separately. Because
            it is called twice per sample, a transform with random behaviour
            would not be applied identically to the two halves.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories and hidden files (e.g. ".DS_Store",
        # ".ipynb_checkpoints"). Keep only regular, non-hidden files and sort
        # them so that indices are stable between runs.
        self.files = sorted(
            name
            for name in os.listdir(root_dir)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair for the file at ``idx``.

        Returns:
            A tuple ``(left_half, right_half)``; each element is the output of
            ``transform`` when one was given, otherwise a PIL image.
        """
        img_path = os.path.join(self.root_dir, self.files[idx])
        # The context manager closes the underlying file; convert() returns a
        # loaded copy, so the pixel data stays usable afterwards. Forcing RGB
        # keeps the channel count fixed at 3 regardless of the file's mode.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")
        w, h = image.size
        half = w // 2
        input_image = image.crop((0, 0, half, h))
        target_image = image.crop((half, 0, w, h))
        if self.transform:
            input_image = self.transform(input_image)
            target_image = self.transform(target_image)
        return input_image, target_image

# Generator (U-Net style, simplified)
class UNetGenerator(nn.Module):
    """Simplified encoder / bottleneck / decoder image-to-image generator.

    The encoder halves the spatial resolution four times, the bottleneck keeps
    it unchanged, and the decoder doubles it four times, so the output has the
    same height and width as the input whenever both are divisible by 16.
    The layers are applied strictly in sequence; there are no skip
    connections between encoder and decoder.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the generated image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        def block(in_c, out_c, down=True, act='relu', use_dropout=False):
            """Build conv (or transposed conv) -> BatchNorm -> activation.

            ``down=True`` uses a stride-2 convolution (halves H and W);
            ``down=False`` uses a stride-2 transposed convolution (doubles
            them). ``act`` selects ReLU for ``'relu'`` and LeakyReLU(0.2) for
            anything else. ``use_dropout`` appends Dropout(0.5).
            """
            layers = []
            if down:
                layers.append(nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False))
            else:
                layers.append(nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False))
            layers.append(nn.BatchNorm2d(out_c))
            layers.append(nn.ReLU(True) if act == 'relu' else nn.LeakyReLU(0.2, True))
            if use_dropout:
                layers.append(nn.Dropout(0.5))
            return nn.Sequential(*layers)

        # Four stride-2 stages: H x W -> H/16 x W/16.
        self.encoder = nn.Sequential(
            block(in_channels, 64, down=True, act='leaky'),
            block(64, 128, down=True, act='leaky'),
            block(128, 256, down=True, act='leaky'),
            block(256, 512, down=True, act='leaky'),
        )

        # Resolution-preserving bottleneck (3x3, stride 1, padding 1).
        # An upsampling block here would give the network five upsampling
        # stages against four downsampling ones, making the output twice the
        # size of the input and incompatible with a same-sized target.
        self.middle = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        )

        # Four stride-2 transposed-conv stages: H/16 x W/16 -> H x W.
        # Tanh bounds the output to [-1, 1].
        self.decoder = nn.Sequential(
            block(512, 256, down=False, act='relu'),
            block(256, 128, down=False, act='relu'),
            block(128, 64, down=False, act='relu'),
            nn.ConvTranspose2d(64, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Map a batch of images ``(N, in_channels, H, W)`` to generated images.

        Returns:
            Tensor of shape ``(N, out_channels, H, W)`` with values in
            [-1, 1], for H and W divisible by 16.
        """
        x = self.encoder(x)
        x = self.middle(x)
        x = self.decoder(x)
        return x

# Setup
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

# Preprocessing applied to each image half: resize to 256x256, convert to a
# tensor, then normalize with mean 0.5 and std 0.5 (i.e. (x - 0.5) / 0.5).
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Training data, served in shuffled mini-batches of four pairs.
dataset = Pix2PixDataset(root_dir="data/facades/train", transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=4, shuffle=True)

# Model, reconstruction loss (mean absolute error) and optimizer.
generator = UNetGenerator()
generator = generator.to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(params=generator.parameters(), lr=0.0002)

# Training loop
# Train the generator with the L1 criterion for a fixed number of passes
# over the dataloader.
epochs = 5
for epoch in range(epochs):
    for i, batch in enumerate(dataloader):
        # Each batch is an (input, target) pair; move both to the device the
        # generator lives on.
        input_img, target_img = (half.to(device) for half in batch)

        # One optimization step: clear old gradients, forward pass, loss,
        # backward pass, parameter update.
        optimizer.zero_grad()
        output = generator(input_img)
        loss = criterion(output, target_img)
        loss.backward()
        optimizer.step()

        # Log only every 50th step (step 0 included).
        if i % 50 != 0:
            continue
        print(f"Epoch [{epoch+1}/{epochs}], Step [{i}/{len(dataloader)}], Loss: {loss.item():.4f}")

# Visualize outputs
def show_images(inputs, outputs, targets):
    """Show input / generated / target images side by side, one figure per sample.

    Args:
        inputs: Batch of input images, indexed as (N, C, H, W).
        outputs: Batch of generated images with the same layout.
        targets: Batch of target images with the same layout.

    All three batches are assumed to be normalized with mean 0.5 and std 0.5;
    that normalization is undone before plotting. The number of figures is
    taken from ``inputs.shape[0]``.
    """

    def to_display(batch):
        # detach(): allow tensors that still track gradients (numpy() would
        # raise on them). permute: channels-first -> channels-last, which is
        # what imshow expects. "* 0.5 + 0.5" undoes the normalization, and the
        # clamp keeps rounding overshoot inside imshow's valid [0, 1] range.
        channels_last = batch.detach().cpu().permute(0, 2, 3, 1)
        return (channels_last * 0.5 + 0.5).clamp(0, 1).numpy()

    panels = (
        ("Input", to_display(inputs)),
        ("Generated", to_display(outputs)),
        ("Target", to_display(targets)),
    )
    for i in range(inputs.shape[0]):
        _fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        for ax, (title, images) in zip(axs, panels):
            ax.imshow(images[i])
            ax.set_title(title)
            ax.axis('off')
        plt.show()

# Inference
# Switch to evaluation mode so BatchNorm layers use their running statistics
# and do not update them during this forward pass. Call generator.train()
# again before resuming training.
generator.eval()

# NOTE: the batch is drawn from the training dataloader, so this previews
# results on training samples rather than on held-out data.
test_input, test_target = next(iter(dataloader))
test_input = test_input.to(device)
with torch.no_grad():
    test_output = generator(test_input)
show_images(test_input, test_output, test_target)
