<a href="https://colab.research.google.com/github/navseducation/Gen-AI-Purdue-Course/blob/main/pix2pix_satellite_to_map.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Pix2Pix: Satellite → Map Translation (PyTorch)

This notebook trains a **Pix2Pix** model to translate **satellite images → map renderings** using the original **maps** dataset from the Pix2Pix paper.

**What you'll learn**
- How a **U-Net generator** + **PatchGAN discriminator** work
- How to balance **GAN loss** with **L1 reconstruction loss**
- Tips for stability, visualization, and evaluation

> Run this in a GPU runtime for best results. The dataset is ~255MB.


## 1) Setup

In [None]:

# If needed, install deps in your environment (uncomment below).
# !pip -q install torch torchvision pillow matplotlib tqdm


In [None]:

import os, random, math, time, glob, tarfile, pathlib
from io import BytesIO
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
from tqdm import tqdm

def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the CPU generator
    and all CUDA devices), then puts cuDNN into deterministic mode with
    benchmark autotuning switched off.

    Args:
        seed: Value handed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Fix all RNG seeds before any model or data code runs.
set_seed(42)
# Last expression of the cell: the notebook echoes the selected device.
device


## 2) Download the **maps** dataset

In [None]:

import urllib.request

# Fetch and unpack the pix2pix "maps" archive unless both the train and val
# folders already exist under `root`.
root = "./data/maps"
os.makedirs(root, exist_ok=True)
archive_path = os.path.join(root, "maps.tar.gz")

if not (os.path.exists(os.path.join(root, "train")) and os.path.exists(os.path.join(root, "val"))):
    url = "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz"
    print("Downloading:", url)
    urllib.request.urlretrieve(url, archive_path)
    print("Extracting...")
    with tarfile.open(archive_path, "r:gz") as tar:
        # The archive arrives over plain HTTP, so treat its member names as
        # untrusted: the "data" filter refuses absolute paths, ".." escapes,
        # links pointing outside the destination, and device files.
        # Interpreters that predate extraction filters keep the previous
        # unfiltered behaviour.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path="./data", filter="data")
        else:
            tar.extractall(path="./data")
    print("Done.")
else:
    print("Dataset already present.")



## 3) Dataset Loader (paired images)
Each file is a **side-by-side pair**: left half is one domain, right half is the other.  
For the **maps** dataset: **left = aerial (satellite)**, **right = map**.

If your preview looks flipped, set `flip_pair=True` below to swap halves.


In [None]:

class PairedImageDataset(Dataset):
    """Dataset of side-by-side image pairs stored as single ``.jpg`` files.

    Every file is cut down the middle into a left and a right picture. Both
    halves are resized to ``image_size`` x ``image_size`` (bicubic), turned
    into tensors, and normalized per channel with mean 0.5 and std 0.5.

    Args:
        folder: Directory searched (non-recursively) for ``*.jpg`` files.
        image_size: Side length of the square output images.
        flip_pair: When true, the halves are swapped so the right half is
            returned first.
        augment: When true, each pair is mirrored horizontally with
            probability 0.5 (both halves together).
    """

    def __init__(self, folder, image_size=256, flip_pair=False, augment=True):
        pattern = os.path.join(folder, "*.jpg")
        # Sorted so that index -> file is stable across runs.
        self.paths = sorted(glob.glob(pattern))
        self.size = image_size
        self.flip_pair = flip_pair
        self.augment = augment

        self.to_tensor = T.ToTensor()
        self.normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        self.resize = T.Resize(
            (self.size, self.size),
            interpolation=T.InterpolationMode.BICUBIC,
        )

    def __len__(self):
        """Return the number of paired files found in the folder."""
        return len(self.paths)

    def _split_pair(self, img):
        """Cut ``img`` at half its width and return ``(first, second)`` halves.

        For an odd width the second crop is one pixel wider than the first.
        The order is (left, right) unless ``flip_pair`` is set.
        """
        width, height = img.size
        middle = width // 2
        halves = (
            img.crop((0, 0, middle, height)),
            img.crop((middle, 0, width, height)),
        )
        return halves[::-1] if self.flip_pair else halves

    def _augment_pair(self, a, b):
        """Mirror both images together so they stay pixel-aligned.

        A random number is drawn only when augmentation is enabled.
        """
        should_mirror = self.augment and random.random() < 0.5
        if not should_mirror:
            return a, b
        mirror = Image.FLIP_LEFT_RIGHT
        return a.transpose(mirror), b.transpose(mirror)

    def __getitem__(self, idx):
        """Load file ``idx`` and return its ``(input, target)`` tensor pair."""
        picture = Image.open(self.paths[idx]).convert("RGB")
        # With the default flip_pair=False on the maps dataset the first half
        # is the satellite image and the second is the map.
        first, second = self._augment_pair(*self._split_pair(picture))
        a, b = (
            self.normalize(self.to_tensor(self.resize(half)))
            for half in (first, second)
        )
        return a, b  # (input, target)


### Preview a few pairs

In [None]:

# Split folders that the download cell checks for under ./data/maps.
train_dir, val_dir = "./data/maps/train", "./data/maps/val"

# Same geometry for both splits; only the training split gets random flips.
train_ds = PairedImageDataset(folder=train_dir, image_size=256, flip_pair=False, augment=True)
val_ds = PairedImageDataset(folder=val_dir, image_size=256, flip_pair=False, augment=False)

def denorm(x):
    """Undo the mean-0.5 / std-0.5 normalization and clip to [0, 1].

    Args:
        x: Tensor of normalized values.

    Returns:
        A new tensor equal to ``x * 0.5 + 0.5`` with values limited to [0, 1].
    """
    rescaled = x * 0.5 + 0.5
    return torch.clamp(rescaled, min=0, max=1)

def show_batch(ds, n=4):
    """Plot ``n`` randomly chosen pairs from ``ds``, one pair per grid row.

    Args:
        ds: Dataset whose items are ``(input, target)`` tensor pairs.
        n: Number of distinct pairs to draw; must not exceed ``len(ds)``.
    """
    chosen = random.sample(range(len(ds)), n)
    tiles = []
    for index in chosen:
        first, second = ds[index]
        tiles.extend((first, second))
    # Two tiles per row keeps each input next to its target.
    grid = make_grid(torch.stack(tiles), nrow=2)
    plt.figure(figsize=(6, 6))
    plt.axis("off")
    # matplotlib expects channels last, so move (C, H, W) to (H, W, C).
    plt.imshow(denorm(grid).permute(1, 2, 0))
    plt.show()

show_batch(train_ds, n=4)


## 4) Models — U-Net Generator & PatchGAN Discriminator

In [None]:

# --- Building blocks ---
class Down(nn.Module):
    """Encoder stage: 4x4 stride-2 convolution, optional batch norm, LeakyReLU(0.2).

    With kernel 4, stride 2 and padding 1 an even spatial size is halved.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        norm: Insert ``BatchNorm2d`` after the convolution. The convolution
            carries a bias only when this is false.
    """

    def __init__(self, in_c, out_c, norm=True):
        super().__init__()
        stages = [
            nn.Conv2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=not norm)
        ]
        if norm:
            stages.append(nn.BatchNorm2d(out_c))
        stages.append(nn.LeakyReLU(0.2, inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the downsampled features."""
        return self.net(x)

class Up(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed convolution, batch norm, ReLU.

    With kernel 4, stride 2 and padding 1 the spatial size is doubled.

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels.
        dropout: Append ``Dropout(0.5)`` after the activation.
    """

    def __init__(self, in_c, out_c, dropout=False):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_c, out_c, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        if dropout:
            stages.append(nn.Dropout(p=0.5))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stage and return the upsampled features."""
        return self.net(x)

class UNetGenerator(nn.Module):
    """U-Net generator: eight ``Down`` stages mirrored by eight upsampling stages.

    Each decoder stage after the first receives the previous decoder output
    concatenated (along channels) with the encoder output of matching
    resolution. The final layer is a transposed convolution followed by
    ``Tanh``, so outputs lie in [-1, 1].

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of the generated image.
        features: Base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_ch=3, out_ch=3, features=64):
        super().__init__()
        f = features

        # Encoder plan: (in channels, out channels, use batch norm).
        # Trailing comments: spatial size after the stage for a 256x256 input.
        encoder_plan = (
            (in_ch, f, False),       # 128
            (f, f * 2, True),        # 64
            (f * 2, f * 4, True),    # 32
            (f * 4, f * 8, True),    # 16
            (f * 8, f * 8, True),    # 8
            (f * 8, f * 8, True),    # 4
            (f * 8, f * 8, True),    # 2
            (f * 8, f * 8, False),   # 1
        )
        for stage, (c_in, c_out, use_norm) in enumerate(encoder_plan, start=1):
            setattr(self, f"d{stage}", Down(c_in, c_out, norm=use_norm))

        # Decoder plan: (in channels, out channels, use dropout). From u2 on,
        # the input width includes the concatenated skip connection.
        decoder_plan = (
            (f * 8, f * 8, True),
            (f * 16, f * 8, True),
            (f * 16, f * 8, True),
            (f * 16, f * 8, False),
            (f * 16, f * 4, False),
            (f * 8, f * 2, False),
            (f * 4, f, False),
        )
        for stage, (c_in, c_out, use_dropout) in enumerate(decoder_plan, start=1):
            setattr(self, f"u{stage}", Up(c_in, c_out, dropout=use_dropout))

        self.outc = nn.Sequential(
            nn.ConvTranspose2d(f * 2, out_ch, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Translate ``x`` by encoding to the bottleneck and decoding with skips."""
        encoders = (self.d1, self.d2, self.d3, self.d4,
                    self.d5, self.d6, self.d7, self.d8)
        skips = []
        hidden = x
        for encode in encoders:
            hidden = encode(hidden)
            skips.append(hidden)

        # The deepest activation is the bottleneck; it feeds u1 without a skip.
        hidden = self.u1(skips.pop())

        # Remaining skips are consumed deepest-first: d7, d6, ..., d2.
        decoders = (self.u2, self.u3, self.u4, self.u5, self.u6, self.u7)
        for decode in decoders:
            hidden = decode(torch.cat([hidden, skips.pop()], dim=1))

        # The last skip left is d1, paired with the output layer.
        return self.outc(torch.cat([hidden, skips.pop()], dim=1))

class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator returning a grid of real/fake logits.

    The network sees the candidate image and the conditioning image stacked
    along the channel axis and ends in a single-channel convolution with no
    activation, so outputs are raw logits.

    Args:
        in_ch: Channels of the candidate (generated or real target) image.
        cond_ch: Channels of the conditioning image.
        features: Channel width of the first convolution.
    """

    def __init__(self, in_ch=3, cond_ch=3, features=64):
        super().__init__()
        f = features

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # First layer has a bias and no normalization. Trailing comments give
        # the spatial size for a 256x256 input.
        layers = [nn.Conv2d(in_ch + cond_ch, f, kernel_size=4, stride=2, padding=1), leaky()]  # 128

        # (in channels, out channels, stride) of the normalized middle layers.
        middle_plan = (
            (f, f * 2, 2),      # 64
            (f * 2, f * 4, 2),  # 32
            (f * 4, f * 8, 1),  # 31
        )
        for c_in, c_out, stride in middle_plan:
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                leaky(),
            ]

        # One logit per patch.
        layers.append(nn.Conv2d(f * 8, 1, kernel_size=4, stride=1, padding=1))  # 30
        self.net = nn.Sequential(*layers)

    def forward(self, x, cond):
        """Score ``x`` given ``cond``; channels are ordered ``[x, cond]``."""
        stacked = torch.cat([x, cond], dim=1)
        return self.net(stacked)


## 5) Training Setup

In [None]:

# Build the generator first, then the discriminator, and move each to `device`.
G, D = (net_cls().to(device) for net_cls in (UNetGenerator, PatchDiscriminator))

def init_weights(m):
    """Initialize a single layer; meant to be used as ``module.apply(init_weights)``.

    * ``Conv2d`` / ``ConvTranspose2d`` weights are drawn from N(0, 0.02).
    * ``BatchNorm2d`` scale is drawn from N(1, 0.02).
    * Biases of all three layer types are set to zero.

    Any other module type is left untouched.

    Args:
        m: The module to initialize in place.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        weight_mean = 0.0
    elif isinstance(m, nn.BatchNorm2d):
        # The BatchNorm weight is a per-channel gain: it has to be centred on
        # 1, otherwise every normalized activation starts out scaled to ~0.
        weight_mean = 1.0
    else:
        return

    # `weight` is None for BatchNorm2d(affine=False); `bias` is None for
    # layers built with bias=False. Skip whichever parameter is absent.
    if getattr(m, "weight", None) is not None:
        nn.init.normal_(m.weight, weight_mean, 0.02)
    if getattr(m, "bias", None) is not None:
        nn.init.constant_(m.bias, 0.0)

# Re-initialize every submodule of both networks with init_weights.
G.apply(init_weights)
D.apply(init_weights)

# Optimisation hyperparameters.
lr = 2e-4
beta1 = 0.5
beta2 = 0.999
lambda_L1 = 100.0  # multiplier on the L1 term of the generator loss
epochs = 40
batch_size = 8

# Training batches are shuffled and the last incomplete batch is dropped.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
val_loader = DataLoader(val_ds, batch_size=4, shuffle=False)

# The discriminator emits raw logits, hence the with-logits BCE variant.
bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

# One Adam optimizer per network, sharing learning rate and betas.
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))


## 6) Training Loop

In [None]:

def sample_and_show(G, ds, n=4, title='Samples'):
    """Plot input / generated / target triplets for ``n`` random items of ``ds``.

    The generator runs in eval mode without gradient tracking. Whatever mode
    it was in before the call (train or eval) is restored afterwards, even if
    inference or plotting raises.

    Args:
        G: Generator module mapping a batch of inputs to a batch of outputs.
        ds: Dataset whose items are ``(input, target)`` tensor pairs.
        n: Number of distinct items to draw; must not exceed ``len(ds)``.
        title: Title shown above the figure.
    """
    was_training = G.training
    G.eval()
    try:
        rows = []
        with torch.no_grad():
            for i in random.sample(range(len(ds)), n):
                a, b = ds[i]
                # Add a batch axis for the network, then drop it again so the
                # result lines up with the unbatched dataset tensors.
                fake = G(a.unsqueeze(0).to(device)).cpu().squeeze(0)
                rows.append(torch.stack([a, fake, b]))
        # Three tiles per row: input, generated, target.
        grid = make_grid(torch.cat(rows, dim=0), nrow=3)
        plt.figure(figsize=(9, 9))
        plt.axis('off')
        plt.title(title)
        # Undo the mean-0.5 / std-0.5 normalization and move channels last
        # for matplotlib.
        plt.imshow((grid * 0.5 + 0.5).clamp(0, 1).permute(1, 2, 0))
        plt.show()
    finally:
        G.train(was_training)

# Alternating GAN training: one discriminator step, then one generator step,
# for every batch.
for epoch in range(1, epochs + 1):
    G.train()
    D.train()
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")
    for a, b in pbar:  # a: satellite (cond), b: map (target)
        a, b = a.to(device), b.to(device)

        # A single generator forward pass serves both updates: the
        # discriminator sees a detached copy, the generator step reuses the
        # graph. This halves generator compute per batch and updates the
        # generator's BatchNorm running statistics once instead of twice.
        fake_b = G(a)

        # --- Train Discriminator ---
        real_logits = D(b, a)
        fake_logits = D(fake_b.detach(), a)  # detach: no gradient into G here
        loss_D = (bce(real_logits, torch.ones_like(real_logits))
                  + bce(fake_logits, torch.zeros_like(fake_logits)))
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # --- Train Generator ---
        # Re-score the same fake with the just-updated discriminator. The
        # gradients this leaves on D's parameters are cleared by
        # opt_D.zero_grad() before D's next backward pass.
        fake_logits = D(fake_b, a)
        loss_G_adv = bce(fake_logits, torch.ones_like(fake_logits))
        loss_G_l1 = l1(fake_b, b) * lambda_L1
        loss_G = loss_G_adv + loss_G_l1
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        pbar.set_postfix(loss_D=loss_D.item(), loss_G=loss_G.item())

    # Show progress each epoch
    sample_and_show(G, val_ds, n=4, title=f"Epoch {epoch}")

# Save models
torch.save(G.state_dict(), "pix2pix_sat2map_G.pt")
torch.save(D.state_dict(), "pix2pix_sat2map_D.pt")
print("Saved pix2pix models.")



## 7) Tips, Troubleshooting, and Extensions

- **If outputs look flipped**: set `flip_pair=True` in `PairedImageDataset`.
- **Stability**:
  - Reduce learning rate to `1e-4` if training oscillates.
  - Add instance noise (low-variance Gaussian noise) to the inputs of D.
  - Try one-sided label smoothing for real labels.
- **Higher resolution**: train on 512×512 tiles (requires more VRAM).
- **Metrics**: Compute FID on held-out tiles for objective comparison.
- **Other domains**: Swap dataset for your own paired tiles from QGIS/ArcGIS.
