### Importing Libraries

In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import VOCSegmentation
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """
    Plot the first `num_images` images of a batch in a grid with 5 images per row.

    Parameters:
        image_tensor: tensor holding the images; it is reshaped with
            `view(-1, *size)`, so its element count must be a multiple of prod(size)
        num_images: maximum number of images drawn
        size: per-image shape used to un-flatten the batch, e.g. (channels, height, width)
    """
    host_images = image_tensor.detach().cpu()
    selected = host_images.view(-1, *size)[:num_images]
    grid = make_grid(selected, nrow=5)
    # The grid is channels-first; matplotlib's imshow expects channels-last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

### U-Net Code

In [2]:
def crop(image, new_shape):
    '''
    Center-crop an image tensor to a target spatial size.

    The crop window is centered on pixel (H // 2, W // 2) of the input; the
    centering is exact only when both the input and target sizes are even.
    Parameters:
        image: image tensor of shape (batch size, channels, height, width)
        new_shape: a torch.Size (or other indexable) whose entries [2] and [3]
            are the target height and width
    Returns:
        the slice of `image` covering the centered window
    '''
    target_h, target_w = new_shape[2], new_shape[3]
    top = image.shape[2] // 2 - target_h // 2
    left = image.shape[3] // 2 - target_w // 2
    return image[:, :, top:top + target_h, left:left + target_w]

class ContractingBlock(nn.Module):
    '''
    One downsampling stage of the contracting path.

    Two 3x3 convolutions (padding 1), each followed by optional batch norm,
    optional dropout and LeakyReLU(0.2), then a 2x2 max pool with stride 2.
    Output has twice the input channels and half the height/width (floored).
    Parameters:
        input_channels: number of channels of the incoming feature map
        use_dropout: apply nn.Dropout after each convolution
        use_bn: apply batch norm after each convolution
    '''

    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super(ContractingBlock, self).__init__()
        out_channels = input_channels * 2
        self.conv1 = nn.Conv2d(input_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.use_bn = use_bn
        if use_bn:
            # NOTE(review): this single BatchNorm2d is reused after BOTH convolutions,
            # so the two layers share affine parameters and running statistics.
            self.batchnorm = nn.BatchNorm2d(out_channels)
        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout()

    def _post_conv(self, features):
        # Shared post-convolution pipeline: [batch norm] -> [dropout] -> LeakyReLU.
        if self.use_bn:
            features = self.batchnorm(features)
        if self.use_dropout:
            features = self.dropout(features)
        return self.activation(features)

    def forward(self, x):
        for conv in (self.conv1, self.conv2):
            x = self._post_conv(conv(x))
        return self.maxpool(x)
    
class ExpandingBlock(nn.Module):
    '''
    One upsampling stage of the expanding path.

    Bilinearly upsamples by 2, halves the channels with an unpadded 2x2 conv
    (spatial size -1), concatenates the center-cropped skip connection, then
    applies two conv units (conv -> optional batch norm -> optional dropout -> ReLU):
    a padded 3x3 conv (size unchanged) and a padded 2x2 conv (size +1). Net effect:
    output height/width are exactly twice the input's, provided the skip tensor is
    at least that large (crop can only shrink it).
    Parameters:
        input_channels: channels of the incoming (deeper) feature map; the skip
            connection must supply the remaining channels so the concatenation
            has input_channels channels (input_channels // 2 when it is even)
        use_dropout: apply nn.Dropout after each conv unit
        use_bn: apply batch norm after each conv unit
    '''

    def __init__(self, input_channels, use_dropout=False, use_bn=False):
        super(ExpandingBlock, self).__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=2)
        # padding=1 keeps the spatial size; unpadded, every stage would lose 2 pixels.
        self.conv2 = nn.Conv2d(input_channels, input_channels // 2, kernel_size=3, padding=1)
        # kernel 2 + padding 1 adds one pixel back, undoing conv1's shrink.
        self.conv3 = nn.Conv2d(input_channels // 2, input_channels // 2, kernel_size=2, padding=1)
        if use_bn:
            # NOTE(review): shared by both conv units (same as ContractingBlock).
            self.batchnorm = nn.BatchNorm2d(input_channels // 2)
        self.use_bn = use_bn
        self.activation = nn.ReLU()
        if use_dropout:
            self.dropout = nn.Dropout()
        self.use_dropout = use_dropout

    def _conv_unit(self, conv, x):
        # conv -> [batch norm] -> [dropout] -> ReLU
        x = conv(x)
        if self.use_bn:
            x = self.batchnorm(x)
        if self.use_dropout:
            x = self.dropout(x)
        return self.activation(x)

    def forward(self, x, skip_con_x):
        '''
        Parameters:
            x: feature map from the previous (deeper) stage
            skip_con_x: feature map from the matching contracting stage
        Returns:
            feature map with input_channels // 2 channels and twice x's height/width
        '''
        x = self.upsample(x)
        x = self.conv1(x)
        # crop(image, new_shape): trim the skip tensor to x's spatial size.
        skip_con_x = crop(skip_con_x, x.shape)
        x = torch.cat([x, skip_con_x], dim=1)
        x = self._conv_unit(self.conv2, x)
        x = self._conv_unit(self.conv3, x)
        return x
    
class FeatureMapBlock(nn.Module):
    '''
    Pointwise (1x1) convolution mapping `input_channels` channels to
    `output_channels` channels without changing height or width.
    Parameters:
        input_channels: number of channels of the incoming tensor
        output_channels: number of channels produced
    '''

    def __init__(self, input_channels, output_channels):
        super(FeatureMapBlock, self).__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
    
class UNet(nn.Module):
    '''
    U-Net generator.

    A 1x1 projection to `hidden_channels`, six contracting blocks (each doubles
    the channels and halves height/width; the first three use dropout), six
    expanding blocks fed with the matching skip connections, a 1x1 projection to
    `output_channels`, and a sigmoid so outputs lie in (0, 1).
    Because of the six halvings, input height and width should be multiples of 64.
    Parameters:
        input_channels: number of channels of the input image
        output_channels: number of channels of the output image
        hidden_channels: channel width after the first projection
    '''

    def __init__(self, input_channels, output_channels, hidden_channels=32):
        super(UNet, self).__init__()
        # The building blocks are module-level classes, not UNet attributes,
        # so they are referenced directly rather than through `self.`.
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        self.contract1 = ContractingBlock(hidden_channels, use_dropout=True)
        self.contract2 = ContractingBlock(hidden_channels * 2, use_dropout=True)
        self.contract3 = ContractingBlock(hidden_channels * 4, use_dropout=True)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.contract5 = ContractingBlock(hidden_channels * 16)
        self.contract6 = ContractingBlock(hidden_channels * 32)
        self.expand0 = ExpandingBlock(hidden_channels * 64)
        self.expand1 = ExpandingBlock(hidden_channels * 32)
        self.expand2 = ExpandingBlock(hidden_channels * 16)
        self.expand3 = ExpandingBlock(hidden_channels * 8)
        self.expand4 = ExpandingBlock(hidden_channels * 4)
        self.expand5 = ExpandingBlock(hidden_channels * 2)
        self.downfeature = FeatureMapBlock(hidden_channels, output_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        '''
        Parameters:
            x: image batch of shape (batch, input_channels, height, width)
        Returns:
            sigmoid-activated tensor with `output_channels` channels
        '''
        x0 = self.upfeature(x)
        x1 = self.contract1(x0)
        x2 = self.contract2(x1)
        x3 = self.contract3(x2)
        x4 = self.contract4(x3)
        x5 = self.contract5(x4)
        x6 = self.contract6(x5)
        # Each expanding block consumes the skip connection of the mirrored level.
        x7 = self.expand0(x6, x5)
        x8 = self.expand1(x7, x4)
        x9 = self.expand2(x8, x3)
        x10 = self.expand3(x9, x2)
        x11 = self.expand4(x10, x1)
        x12 = self.expand5(x11, x0)
        xn = self.downfeature(x12)
        return self.sigmoid(xn)

### PatchGAN Discriminator

In [3]:
class Discriminator(nn.Module):
    '''
    PatchGAN discriminator.

    Concatenates its two inputs `x` and `y` along the channel axis (presumably the
    condition image and the real/generated image — confirm at the call site),
    projects to `hidden_channels` with a 1x1 conv, runs four contracting blocks
    (each doubling channels and halving height/width), and maps the result with a
    1x1 conv to a single-channel grid of unnormalized scores (no sigmoid), one
    per image patch.
    Parameters:
        input_channels: total number of channels of x and y combined
        hidden_channels: the initial number of discriminator convolutional filters
    '''

    def __init__(self, input_channels, hidden_channels=8):
        super(Discriminator, self).__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        # Only the first stage skips batch norm; later stages keep ContractingBlock's default.
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        # Four doublings: hidden_channels * 2**4 channels reach the scoring head.
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y):
        features = self.upfeature(torch.cat([x, y], dim=1))
        for stage in (self.contract1, self.contract2, self.contract3, self.contract4):
            features = stage(features)
        return self.final(features)

In [7]:
# UNIT TEST: two 5-channel 256x256 inputs (10 channels once concatenated) pass
# through four halving stages (256 / 2**4 = 16), giving a 1 x 16 x 16 score map.
test_discriminator = Discriminator(10, 1)
test_output = test_discriminator(
    torch.randn(1, 5, 256, 256),
    torch.randn(1, 5, 256, 256),
)
assert tuple(test_output.shape) == (1, 1, 16, 16)
print("Success!")

Success!


### Training Preparation

In [5]:
import torch.nn.functional as F

# The Discriminator returns raw scores (no sigmoid), which BCEWithLogitsLoss expects.
adv_criterion = nn.BCEWithLogitsLoss()
recon_criterion = nn.L1Loss()
# NOTE(review): presumably the weight of recon_criterion relative to adv_criterion
# in the generator loss — confirm in the training loop.
lambda_recon = 200

n_epochs = 20
input_dim = 3
real_dim = 3
display_step = 200
batch_size = 4
lr = 0.0002
target_shape = 256
# Fall back to the CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [10]:
import torchvision

# ToTensor converts each PIL image to a float (C, H, W) tensor scaled to [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

# ImageFolder expects a `maps/<subfolder>/<image files>` layout relative to the
# current working directory. The transform must be passed explicitly; otherwise
# samples come back as raw PIL images rather than tensors.
dataset = torchvision.datasets.ImageFolder(root="maps", transform=transform)

FileNotFoundError: [WinError 3] The system cannot find the path specified: 'maps'