<a href="https://colab.research.google.com/github/fudw/satellite-imagery-to-maps/blob/main/satellite-2-maps-cycleGAN-pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Satellite2Maps
### Automated creation of maps from satellite imagery and aerial sensor data with CycleGAN
<br/>

In this project, I develop a data pipeline that takes in satellite images and outputs maps using a CycleGAN model based on [*Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks*](https://arxiv.org/abs/1703.10593) by Zhu et al. (2017). 

The model will learn to convert satellite images to maps and vice versa, by training on public datasets in the two domains.

In [None]:
from google.colab import output

# download and unzip data
!gdown --id 1GSNhusWi-GXn4bOkymluer1Un7_YDBc9
!mkdir data
!unzip satellite-2-map-dataset-kaggle.zip -d data
output.clear()
print('Data downloaded!')

In [None]:
!pip install wandb

# import libraries
import os
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
import seaborn as sns
import wandb

# Seed torch's random number generator so runs are repeatable.
torch.manual_seed(9)

# Clear the install/import chatter, then report the torch version and the
# properties of CUDA device 0 (or 'CPU' when CUDA is unavailable).
output.clear()
print(
    f"Setup complete. Using torch {torch.__version__} "
    f"{torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'}"
)

In [None]:
# Log in to Weights & Biases through the wandb client
# (presumably used for experiment tracking later in the notebook — confirm).
wandb.login()

## II. Build CycleGAN Model

In [None]:
class ResidualBlock(nn.Module):
    '''
    Residual block: conv -> instance norm -> ReLU -> conv -> instance norm,
    with the block input added back onto the result.

    Both convolutions are 3x3 with reflect padding of 1 and keep the channel
    count, so the output has the same shape as the input.

    Args:
        input_channels: number of channels of the incoming (and outgoing)
            feature map.
    '''
    def __init__(self, input_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode='reflect')
        self.conv1 = nn.Conv2d(input_channels, input_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(input_channels, input_channels, **conv_kwargs)
        # A single norm module serves both normalisation steps in forward().
        self.instancenorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        '''
        Return x plus the output of the two-convolution branch applied to x.
        '''
        branch = self.activation(self.instancenorm(self.conv1(x)))
        branch = self.instancenorm(self.conv2(branch))
        return x + branch

In [None]:
class ContractingBlock(nn.Module):
    '''
    Downsampling block: a stride-2 convolution that doubles the channel count,
    followed by optional instance normalisation and an activation.

    Args:
        input_channels: channels of the incoming feature map; the output has
            twice as many.
        use_in: whether to apply InstanceNorm2d after the convolution.
        kernel_size: kernel size of the convolution (padding is always 1,
            reflect mode).
        activation: 'relu' selects nn.ReLU; any other value (e.g. 'lrelu')
            selects nn.LeakyReLU with negative slope 0.2.
    '''
    def __init__(self, input_channels, use_in=True, kernel_size=3, activation='relu'):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size=kernel_size, padding=1, stride=2, padding_mode='reflect')
        # Honour the `activation` argument; it used to be accepted but ignored,
        # so callers asking for 'lrelu' silently got a plain ReLU.
        self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
        if use_in:
            self.instancenorm = nn.InstanceNorm2d(input_channels * 2)
        self.use_in = use_in

    def forward(self, x):
        '''
        Apply convolution, optional instance norm, then the activation.
        '''
        x = self.conv1(x)
        if self.use_in:
            x = self.instancenorm(x)
        x = self.activation(x)
        return x

In [None]:
class ExpandingBlock(nn.Module):
    '''
    Upsampling block: a stride-2 transposed convolution that halves the
    channel count and doubles the spatial size, followed by optional instance
    normalisation and a ReLU.

    Args:
        input_channels: channels of the incoming feature map; the output has
            input_channels // 2.
        use_in: whether to apply InstanceNorm2d after the transposed
            convolution.
    '''
    def __init__(self, input_channels, use_in=True):
        super().__init__()
        output_channels = input_channels // 2
        self.conv1 = nn.ConvTranspose2d(
            input_channels,
            output_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.use_in = use_in
        if self.use_in:
            self.instancenorm = nn.InstanceNorm2d(output_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        '''
        Apply transposed convolution, optional instance norm, then ReLU.
        '''
        upsampled = self.conv1(x)
        if self.use_in:
            upsampled = self.instancenorm(upsampled)
        return self.activation(upsampled)

In [None]:
class FeatureMapBlock(nn.Module):
    '''
    Single 7x7 convolution (reflect padding of 3) that maps input_channels to
    output_channels while keeping the spatial size.

    Args:
        input_channels: channels of the incoming tensor.
        output_channels: channels of the produced tensor.
    '''
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=7,
            padding=3,
            padding_mode='reflect',
        )

    def forward(self, x):
        '''
        Return the convolution of x.
        '''
        return self.conv(x)

In [None]:
class Generator(nn.Module):
    '''
    Generator network: a feature-map block, two contracting blocks, nine
    residual blocks, two expanding blocks, a final feature-map block and a
    tanh.

    Args:
        input_channels: channels of the input image.
        output_channels: channels of the generated image.
        hidden_channels: channels after the first feature-map block; the
            residual blocks run at four times this width.
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64):
        super().__init__()
        res_mult = 4
        bottleneck_channels = hidden_channels * res_mult

        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        # Register res1 .. res9 under the same attribute names a hand-written
        # list of assignments would use, so state_dict keys stay res1.*, res2.*, ...
        for index in range(1, 10):
            setattr(self, f'res{index}', ResidualBlock(bottleneck_channels))
        self.expand1 = ExpandingBlock(bottleneck_channels)
        self.expand2 = ExpandingBlock(bottleneck_channels // 2)
        self.downfeature = FeatureMapBlock(bottleneck_channels // 4, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        '''
        Run x through every stage in order and return the tanh-squashed output.
        '''
        stages = [
            self.upfeature,
            self.contract1,
            self.contract2,
            *(getattr(self, f'res{index}') for index in range(1, 10)),
            self.expand1,
            self.expand2,
            self.downfeature,
            self.tanh,
        ]
        for stage in stages:
            x = stage(x)
        return x

In [None]:
class Discriminator(nn.Module):
    '''
    Discriminator network: a feature-map block, three contracting blocks
    (kernel size 4, requesting the 'lrelu' activation; the first one without
    instance norm) and a final 1x1 convolution producing a single-channel
    score map.

    Args:
        input_channels: channels of the image being judged.
        hidden_channels: channels after the first feature-map block.
    '''
    def __init__(self, input_channels, hidden_channels=64):
        super().__init__()
        contract_kwargs = dict(kernel_size=4, activation='lrelu')
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_in=False, **contract_kwargs)
        self.contract2 = ContractingBlock(hidden_channels * 2, **contract_kwargs)
        self.contract3 = ContractingBlock(hidden_channels * 4, **contract_kwargs)
        self.final = nn.Conv2d(hidden_channels * 8, 1, kernel_size=1)

    def forward(self, x):
        '''
        Return the single-channel score map for x.
        '''
        features = self.upfeature(x)
        for block in (self.contract1, self.contract2, self.contract3):
            features = block(features)
        return self.final(features)

In [None]:
# Loss functions: squared error for the adversarial terms, L1 for
# reconstruction-style terms (presumably identity and cycle losses — confirm
# at the call site).
adv_criterion = nn.MSELoss()
recon_criterion = nn.L1Loss()

# Training hyperparameters.
n_epochs = 100
# Channel counts of domain A and domain B images (passed to the networks).
dim_A = 3
dim_B = 3
display_step = 200
batch_size = 1
# Learning rate read by the Adam optimizers created further down; it was
# missing from this cell. 0.0002 is the value used in the CycleGAN paper.
lr = 0.0002
# Images are resized to load_shape, then randomly cropped to target_shape.
load_shape = 300
target_shape = 256
# Fall back to CPU when no CUDA device is present instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Augmentation pipeline: resize to load_shape, take a random target_shape
# crop, randomly flip horizontally, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(load_shape),
        transforms.RandomCrop(target_shape),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# NOTE(review): ImageDataset is not defined in the visible part of this
# notebook, and the download cell extracts into data/ while the root passed
# here is 'maps' — confirm both before running.
dataset = ImageDataset('maps', transform=transform)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Instantiate the four networks (generators before discriminators) and move
# them to the target device.
gen_AB = Generator(dim_A, dim_B).to(device)
gen_BA = Generator(dim_B, dim_A).to(device)
disc_A = Discriminator(dim_A).to(device)
disc_B = Discriminator(dim_B).to(device)

# One Adam optimizer updates both generators jointly; each discriminator
# gets its own optimizer.
gen_opt = torch.optim.Adam(
    [*gen_AB.parameters(), *gen_BA.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)
disc_A_opt = torch.optim.Adam([*disc_A.parameters()], lr=lr, betas=(0.5, 0.999))
disc_B_opt = torch.optim.Adam([*disc_B.parameters()], lr=lr, betas=(0.5, 0.999))

In [None]:
def get_disc_loss(real_X, fake_X, disc_X, adv_criterion):
    '''
    Discriminator loss for one domain.

    Scores real_X against a target of ones and the detached fake_X against a
    target of zeros, and returns the mean of the two criterion values.

    Args:
        real_X: real samples of the domain.
        fake_X: generated samples; detached so no gradient reaches the generator.
        disc_X: callable producing the discriminator's predictions.
        adv_criterion: callable (prediction, target) -> loss.
    '''
    pred_real = disc_X(real_X)
    pred_fake = disc_X(fake_X.detach())
    loss_on_real = adv_criterion(pred_real, torch.ones_like(pred_real))
    loss_on_fake = adv_criterion(pred_fake, torch.zeros_like(pred_fake))
    return (loss_on_real + loss_on_fake) / 2

In [None]:
def get_gen_adversarial_loss(real_X, disc_Y, gen_XY, adv_criterion):
    '''
    Adversarial loss for a generator mapping X to Y.

    Generates fake_Y from real_X, scores it with disc_Y, and measures the
    score against a target of ones.

    Returns:
        (adversarial_loss, fake_Y)
    '''
    fake_Y = gen_XY(real_X)
    disc_verdict = disc_Y(fake_Y)
    wanted = torch.ones_like(disc_verdict)
    return adv_criterion(disc_verdict, wanted), fake_Y

In [None]:
def get_identity_loss(real_X, gen_YX, identity_criterion):
    '''
    Identity loss: feed real_X through the generator that outputs domain X
    and compare the result with real_X itself.

    Returns:
        (identity_loss, identity_X)
    '''
    identity_X = gen_YX(real_X)
    return identity_criterion(identity_X, real_X), identity_X

In [None]:
def get_cycle_consistency_loss(real_X, fake_Y, gen_YX, cycle_criterion):
    '''
    Cycle-consistency loss: map fake_Y back to domain X with gen_YX and
    compare the reconstruction with real_X.

    Returns:
        (cycle_loss, cycle_X)
    '''
    cycle_X = gen_YX(fake_Y)
    # Argument order (real first, reconstruction second) is kept as-is.
    return cycle_criterion(real_X, cycle_X), cycle_X

In [None]:
def get_gen_loss(real_A, real_B, gen_AB, gen_BA, disc_A, disc_B, adv_criterion, identity_criterion, cycle_criterion, lambda_identity=0.1, lambda_cycle=10):
    '''
    Total generator loss for both translation directions.

    Sums the adversarial losses of gen_AB and gen_BA, then adds the identity
    losses weighted by lambda_identity and the cycle-consistency losses
    weighted by lambda_cycle.

    Args:
        real_A, real_B: real samples from domain A and domain B.
        gen_AB, gen_BA: generators mapping A -> B and B -> A.
        disc_A, disc_B: discriminators for domain A and domain B.
        adv_criterion: criterion for the adversarial terms.
        identity_criterion: criterion for the identity terms.
        cycle_criterion: criterion for the cycle-consistency terms.
        lambda_identity: weight of the summed identity loss.
        lambda_cycle: weight of the summed cycle-consistency loss.

    Returns:
        (gen_loss, fake_A, fake_B)
    '''
    # The helper defined in this file is get_gen_adversarial_loss; the
    # previous call to `get_adversarial_loss` raised NameError.
    adv_loss_AB, fake_B = get_gen_adversarial_loss(real_A, disc_B, gen_AB, adv_criterion)
    adv_loss_BA, fake_A = get_gen_adversarial_loss(real_B, disc_A, gen_BA, adv_criterion)

    # Identity: each generator is fed a sample that is already in its output domain.
    identity_loss_AA, _ = get_identity_loss(real_A, gen_BA, identity_criterion)
    identity_loss_BB, _ = get_identity_loss(real_B, gen_AB, identity_criterion)

    # Cycle: translate the fakes back and compare with the originals.
    cycle_loss_AA, _ = get_cycle_consistency_loss(real_A, fake_B, gen_BA, cycle_criterion)
    cycle_loss_BB, _ = get_cycle_consistency_loss(real_B, fake_A, gen_AB, cycle_criterion)

    total_adv_loss = adv_loss_AB + adv_loss_BA
    total_identity_loss = identity_loss_AA + identity_loss_BB
    total_cycle_loss = cycle_loss_AA + cycle_loss_BB
    gen_loss = total_adv_loss + lambda_identity * total_identity_loss + lambda_cycle * total_cycle_loss
    return gen_loss, fake_A, fake_B

In [None]:
def train(save_model=False):
    mean_gen_loss = 0
    mean_disc_loss = 0
    current_step = 0

    for epoch in range(n_epochs):

        for real_A, real_B in tqdm(dataloader):

            real_A = nn.functional.interpolate(real_A, size=target_shape)
            real_B = nn.functional.interpolate(real_B, size=target_shape)
            current_batch_size = len(real_A)
            real_A = real_A.to(device)
            real_B = real_B.to(device)

            disc_A_opt.zero_grad()
            with torch.no_grad():
                fake_A = gen_BA(real_B)
            disc_A_loss = get_disc_loss(real_A, fake_A, disc_A, adv_criterion)
            disc_A_loss.backward(retain_graph=True)
            disc_A_opt.step()

            disc_B_opt.zero_grad()
            with torch.no_grad():
                fake_B = gen_AB(real_A)
            disc_B_loss = gen_disc_loss(real_B, fake_B, disc_B, adv_criterion)
            disc_B_loss.backward(retian_graph=True)
            disc_B_opt.step()