<a href="https://colab.research.google.com/github/fudw/satellite-imagery-to-maps/blob/main/satellite-2-maps-cycleGAN-pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Satellite2Maps
### Automated creation of maps from satellite imagery and aerial sensor data with CycleGAN
<br/>

In this project, I develop a data pipeline that takes in satellite images and outputs maps using a CycleGAN model based on [*Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks*](https://arxiv.org/abs/1703.10593) by Zhu et al. (2017). 

The model learns to convert satellite images to maps, and maps back to satellite images, by training on unpaired samples from public datasets in the two domains.
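
Cycle consistency is what lets CycleGAN learn from unpaired data. If you translate a satellite image to a map and back again, you should recover the original image, and the same holds for maps. Below is a minimal PyTorch sketch of that loss. It assumes an L1 distance and the weight λ = 10 reported by Zhu et al. (2017). `gen_ab` and `gen_ba` are hypothetical names for the satellite-to-map and map-to-satellite generators, and `nn.Identity` stands in for trained networks.

```python
import torch
from torch import nn

def cycle_consistency_loss(real_a, real_b, gen_ab, gen_ba, lambda_cycle=10.0):
    """Cycle-consistency loss in the style of Zhu et al. (2017), using L1 distance.

    real_a: batch from domain A (e.g. satellite images)
    real_b: batch from domain B (e.g. maps)
    gen_ab, gen_ba: hypothetical generators mapping A->B and B->A
    """
    l1 = nn.L1Loss()
    # A -> B -> A should reconstruct the original satellite image
    recon_a = gen_ba(gen_ab(real_a))
    # B -> A -> B should reconstruct the original map
    recon_b = gen_ab(gen_ba(real_b))
    return lambda_cycle * (l1(recon_a, real_a) + l1(recon_b, real_b))

# Usage with stand-in generators: identity mappings reconstruct perfectly
sat = torch.rand(2, 3, 64, 64)
maps = torch.rand(2, 3, 64, 64)
loss = cycle_consistency_loss(sat, maps, nn.Identity(), nn.Identity())
print(loss.item())  # 0.0
```

Trained generators will not reconstruct perfectly. This term penalizes each round trip in proportion to how far it drifts from the input.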

In [None]:
from google.colab import output

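# download the dataset archive from Google Drive and unzip it into ./data
# (recent gdown versions take the file ID directly; the --id flag is deprecated)
!gdown 1GSNhusWi-GXn4bOkymluer1Un7_YDBc9
!mkdir -p data
!unzip -q satellite-2-map-dataset-kaggle.zip -d data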
output.clear()
print('Data downloaded!')
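
This Kaggle archive appears to follow the pix2pix *maps* data, where each file stores a satellite photo and its map side by side. If so, every loaded image must be cut in half before the two domains can be fed to CycleGAN separately. The sketch below assumes that layout, with the satellite photo on the left. `split_pair` is a hypothetical helper, so check the order against a sample image in `data/` before relying on it.

```python
import torch

def split_pair(pair: torch.Tensor):
    """Split side-by-side image tensors of shape (..., C, H, W) into left and right halves.

    Assumes (unverified) that the satellite photo is on the left and the map on the right.
    """
    width = pair.shape[-1]
    satellite = pair[..., : width // 2]
    street_map = pair[..., width // 2 :]
    return satellite, street_map

# Usage: a dummy tensor whose left half is zeros ("satellite") and right half ones ("map")
pair = torch.cat([torch.zeros(3, 600, 600), torch.ones(3, 600, 600)], dim=-1)
satellite, street_map = split_pair(pair)
print(satellite.shape, street_map.shape)  # torch.Size([3, 600, 600]) torch.Size([3, 600, 600])
```

CycleGAN does not need aligned pairs. Once split, the two halves can be shuffled independently into separate satellite and map datasets.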

In [None]:
!pip install wandb

# import libraries
import os
import numpy as np
import torch
from torch import nn
import torchvision
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
import seaborn as sns
import wandb

# fix the random seed so weight initialization and data shuffling are reproducible
torch.manual_seed(9)

output.clear()
print('Setup complete. Using torch %s %s' % (torch.__version__, torch.cuda.get_device_properties(0) if torch.cuda.is_available() else 'CPU'))

In [None]:
wandb.login()

## II. Build CycleGAN Model
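
Besides the cycle-consistency term, each generator is trained against a discriminator. Zhu et al. (2017) use a least-squares (LSGAN) objective here instead of the usual negative log-likelihood. They also halve the discriminator loss to slow it down relative to the generator. The sketch below assumes the discriminator outputs a grid of per-patch scores. The function names are illustrative and do not come from this notebook.

```python
import torch
from torch import nn

mse = nn.MSELoss()

def discriminator_loss(disc_real_pred, disc_fake_pred):
    """Least-squares discriminator loss: push real patch scores to 1 and fake ones to 0."""
    real_loss = mse(disc_real_pred, torch.ones_like(disc_real_pred))
    fake_loss = mse(disc_fake_pred, torch.zeros_like(disc_fake_pred))
    # divide by 2, as Zhu et al. do, so D learns more slowly than G
    return (real_loss + fake_loss) / 2

def generator_adversarial_loss(disc_fake_pred):
    """Least-squares generator loss: the generator wants its fakes scored as real (1)."""
    return mse(disc_fake_pred, torch.ones_like(disc_fake_pred))

# Usage with stand-in discriminator outputs of shape (batch, 1, H, W)
perfect_real = torch.ones(4, 1, 30, 30)
perfect_fake = torch.zeros(4, 1, 30, 30)
print(discriminator_loss(perfect_real, perfect_fake).item())  # 0.0
print(generator_adversarial_loss(perfect_fake).item())        # 1.0
```

The squared error keeps gradients informative even for samples the discriminator already classifies confidently. This is the usual motivation for the least-squares variant.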

In [None]:
class ResidualBlock(nn.Module):
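    '''
    Residual block for the CycleGAN generator: two 3x3 convolutions with
    reflection padding and instance normalization that preserve the number
    of channels, so the block's input can be added back to its output.
    Values:
        input_channels: the number of channels expected in the input
    '''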
    def __init__(self, input_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        self.conv2 = nn.Conv2d(input_channels, input_channels, kernel_size=3, padding=1, padding_mode='reflect')
        self.instancenorm = nn.InstanceNorm2d(input_channels)
        self.activation = nn.ReLU()

    def forward(self, x):
        
