<a href="https://colab.research.google.com/github/fudw/satellite-imagery-to-maps/blob/main/satellite-2-maps-cycleGAN-pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Satellite2Maps
### Automated creation of maps from satellite imagery and aerial sensor data with CycleGAN
<br/>

In this project, I develop a data pipeline that takes in satellite images and outputs maps using a CycleGAN model based on [*Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks*](https://arxiv.org/abs/1703.10593) by Zhu et al. (2017). 

The model learns to convert satellite images to maps, and maps back to satellite images, by training on public datasets from the two domains.

In [None]:
from google.colab import output

# download and unzip data
!gdown --id 1GSNhusWi-GXn4bOkymluer1Un7_YDBc9
!mkdir data
!unzip satellite-2-map-dataset-kaggle.zip -d data
output.clear()
print('Data downloaded!')

In [None]:
!pip install wandb

# import libraries
import os
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
import seaborn as sns
import wandb

torch.manual_seed(9)

output.clear()
print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))

In [None]:
# Authenticate this session with Weights & Biases (wandb, imported above).
# NOTE(review): presumably used by later cells for experiment tracking; in an
# interactive session this is expected to prompt for an API key when none is
# stored — confirm before running headless.
wandb.login()

## II. Build CycleGAN Model

In [None]:
class ResidualBlock(nn.Module):
    '''
    Residual block: conv -> instance norm -> ReLU -> conv -> instance norm,
    with the block input added back to the result (identity skip connection).

    Channel count and spatial size are preserved: both convolutions map
    `input_channels` to `input_channels` with a 3x3 kernel and 1-pixel
    reflect padding.

    Args:
        input_channels: number of channels of the input (and output) tensor.
    '''
    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        # A single InstanceNorm2d is applied after both convolutions. With the
        # default affine=False it has no learnable parameters, so sharing it is
        # equivalent to using two separate normalisation layers.
        self.instancenorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        '''
        Args:
            x: feature map, presumably (N, input_channels, H, W).

        Returns:
            Tensor of the same shape: x + F(x).
        '''
        # No clone is needed for the skip connection: every layer below
        # returns a new tensor (ReLU is not inplace), so `x` still holds the
        # original input. Cloning only cost an extra activation-sized copy.
        residual = self.conv1(x)
        residual = self.instancenorm(residual)
        residual = self.activation(residual)
        residual = self.conv2(residual)
        residual = self.instancenorm(residual)
        return x + residual

In [None]:
class ContractingBlock(nn.Module):
    '''
    Downsampling block: stride-2 conv -> optional instance norm -> activation.

    Doubles the channel count. With padding=1 and stride 2, both
    kernel_size=3 and kernel_size=4 halve even spatial sizes.

    Args:
        input_channels: number of input channels; the output has twice as many.
        use_in: whether to apply InstanceNorm2d after the convolution.
        kernel_size: convolution kernel size (3 in the Generator, 4 in the
            Discriminator).
        activation: 'relu' for nn.ReLU, or 'lrelu' for nn.LeakyReLU with
            negative slope 0.2.

    Raises:
        ValueError: if `activation` is not 'relu' or 'lrelu'.
    '''
    def __init__(self, input_channels, use_in=True, kernel_size=3, activation='relu'):
        super().__init__()
        # The Discriminator passes activation='lrelu', so the argument must be
        # honoured rather than always falling back to ReLU.
        if activation == 'relu':
            act = nn.ReLU()
        elif activation == 'lrelu':
            # Slope 0.2, the leaky-ReLU slope used in the CycleGAN discriminator.
            act = nn.LeakyReLU(0.2)
        else:
            raise ValueError("activation must be 'relu' or 'lrelu', got %r" % (activation,))
        self.conv1 = nn.Conv2d(input_channels, input_channels * 2, kernel_size=kernel_size, padding=1, stride=2, padding_mode='reflect')
        self.activation = act
        if use_in:
            self.instancenorm = nn.InstanceNorm2d(input_channels * 2)
        self.use_in = use_in

    def forward(self, x):
        x = self.conv1(x)
        if self.use_in:
            x = self.instancenorm(x)
        x = self.activation(x)
        return x

In [None]:
class ExpandingBlock(nn.Module):
    '''
    Upsampling block used by the Generator.

    A stride-2 transposed convolution (kernel 3, padding 1, output_padding 1)
    doubles height and width and halves the channel count. The result is
    optionally instance-normalised and then passed through ReLU.

    Args:
        input_channels: channels of the incoming feature map; the block
            outputs input_channels // 2 channels.
        use_in: apply InstanceNorm2d after the transposed convolution.
    '''
    def __init__(self, input_channels, use_in=True):
        super().__init__()
        out_channels = input_channels // 2
        self.conv1 = nn.ConvTranspose2d(
            input_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.use_in = use_in
        if self.use_in:
            self.instancenorm = nn.InstanceNorm2d(out_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        upsampled = self.conv1(x)
        if not self.use_in:
            return self.activation(upsampled)
        return self.activation(self.instancenorm(upsampled))

In [None]:
class FeatureMapBlock(nn.Module):
    '''
    Changes the channel count with a single 7x7 convolution.

    Reflect padding of 3 keeps height and width unchanged. Used as the entry
    (upfeature) and exit (downfeature) layers of the Generator and as the
    entry layer of the Discriminator.

    Args:
        input_channels: channels of the input tensor.
        output_channels: channels of the output tensor.
    '''
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=7,
            padding=3,
            padding_mode='reflect',
        )

    def forward(self, x):
        return self.conv(x)

In [None]:
class Generator(nn.Module):
    '''
    CycleGAN generator: encoder -> residual trunk -> decoder.

    Pipeline:
        7x7 FeatureMapBlock (input_channels -> hidden_channels)
        -> two stride-2 ContractingBlocks (channels x4, spatial size / 4)
        -> `n_residual_blocks` ResidualBlocks at hidden_channels * 4 channels
        -> two ExpandingBlocks (back to hidden_channels and the input size)
        -> 7x7 FeatureMapBlock (-> output_channels)
        -> tanh, so outputs lie in [-1, 1].

    Args:
        input_channels: channels of the source-domain image.
        output_channels: channels of the generated image.
        hidden_channels: channels after the first feature-map convolution.
        n_residual_blocks: number of residual blocks in the trunk. Defaults
            to 9, the previously hard-coded value. Zhu et al. (2017) use 6 for
            128x128 and 9 for 256x256 training images.

    The residual blocks are registered as attributes res1 ... resN in that
    order, so parameter names (state_dict keys) and seeded initialisation
    match the original hand-unrolled 9-block version.
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=64, n_residual_blocks=9):
        super().__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        # Each of the two contracting blocks doubles the channel count.
        res_mult = 4
        self.n_residual_blocks = n_residual_blocks
        for i in range(1, n_residual_blocks + 1):
            # Assigning through setattr registers the submodule under res{i},
            # exactly as the explicit self.res1 = ... assignments did.
            setattr(self, 'res%d' % i, ResidualBlock(hidden_channels * res_mult))
        self.expand1 = ExpandingBlock(hidden_channels * res_mult)
        self.expand2 = ExpandingBlock(hidden_channels * res_mult // 2)
        self.downfeature = FeatureMapBlock(hidden_channels * res_mult // 4, output_channels)
        self.tanh = nn.Tanh()

    def forward(self, x):
        '''
        Translate a batch of images to the target domain.

        Args:
            x: image batch, presumably (N, input_channels, H, W) with H and W
                divisible by 4 so the two down/up-sampling stages round-trip.

        Returns:
            Tensor of shape (N, output_channels, H, W) with values in [-1, 1].
        '''
        x = self.upfeature(x)
        x = self.contract1(x)
        x = self.contract2(x)
        for i in range(1, self.n_residual_blocks + 1):
            x = getattr(self, 'res%d' % i)(x)
        x = self.expand1(x)
        x = self.expand2(x)
        x = self.downfeature(x)
        x = self.tanh(x)
        return x

In [None]:
class Discriminator(nn.Module):
    '''
    Discriminator class.
    '''
    def __init__(self, input_channels, hidden_channels=64):
        super().__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_in=False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels * 2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels * 4, kernel_size=4, activation='lrelu')
        
