# Pix2pix Baseline Model

In [1]:
# set GPU 
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sun Dec  6 21:19:24 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P0    23W / 300W |      0MiB / 16130MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
# Reference: https://zhangruochi.com/Pix2Pix/2020/11/09/
# Mount Google Drive at /content/gdrive so files stored there can be opened by path.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [3]:
cd gdrive/MyDrive/

/content/gdrive/MyDrive


In [4]:
ls

 [0m[01;34mArchive[0m/
 [01;34mclassifier[0m/
 ClassifierExperiments.ipynb
[01;34m'Colab Notebooks'[0m/
 [01;34mData[0m/
[01;34m'Data Outputs'[0m/
[01;34m'Data processing and Augmentation'[0m/
 edges2shoes.tar.gz
 [01;34mExperiments[0m/
 pix2pix_11000.pth
 pix2pix_13750.pth
 pix2pix_16500.pth
 pix2pix_19250.pth
 pix2pix_22000.pth
 pix2pix_24750.pth
 pix2pix_27500.pth
 pix2pix_2750.pth
 pix2pix_5500.pth
 pix2pix_8250.pth
 pix2pixlabelsupervised_11000.pth
 pix2pixlabelsupervised_13750.pth
 pix2pixlabelsupervised_16500.pth
 pix2pixlabelsupervised_19250.pth
 pix2pixlabelsupervised_22000.pth
 pix2pixlabelsupervised_24750.pth
 pix2pixlabelsupervised_27500.pth
 pix2pixlabelsupervised_2750.pth
 pix2pixlabelsupervised_5500.pth
 pix2pixlabelsupervised_8250.pth
 pix2pixlabelsupervisedprogressivetraining_11000.pth
 pix2pixlabelsupervisedprogressivetraining_13750.pth
 pix2pixlabelsupervisedprogressivetraining_16500.pth
 pix2pixlabelsupervisedprogressivetraining_19250.pth
 pix2pixlabe

In [5]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torch.utils.data import random_split, ConcatDataset
import matplotlib.pyplot as plt
torch.manual_seed(0)

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Display the leading images of a batch as a grid with five images per row.
    Parameters:
        image_tensor: tensor holding the images; it is detached, moved to the CPU
                      and viewed as (-1, *size) before plotting
        num_images: maximum number of images taken from the front of the batch
        size: per-image shape as (channels, height, width)
    '''
    images = image_tensor.detach().cpu().view(-1, *size)[:num_images]
    grid = make_grid(images, nrow=5)
    # The grid is channels-first; move the channel axis last for imshow.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

# Baseline Model (U-Net code)

In [6]:
def crop(image, new_shape):
    '''
    Center-crop a batch of images to the height and width given by new_shape.
    Parameters:
        image: image tensor of shape (batch size, channels, height, width)
        new_shape: a torch.Size object (or any sequence) whose entries 2 and 3 are the
                   target height and width; entries 0 and 1 are not read
    Returns:
        the slice of image spanning the target height and width around the middle pixel
    '''
    target_height, target_width = new_shape[2], new_shape[3]
    # The window starts half a target size before the middle pixel of each axis.
    top = image.shape[2] // 2 - round(target_height / 2)
    left = image.shape[3] // 2 - round(target_width / 2)
    return image[:, :, top:top + target_height, left:left + target_width]

class ContractingBlock(nn.Module):
    '''
    ContractingBlock Class
    Two 3x3 convolutions, each followed by optional batch norm, optional dropout and a
    LeakyReLU(0.2), then a 2x2 max pool. The channel count is doubled by the first
    convolution and the spatial size is halved by the pooling.
    Values:
        input_channels: the number of channels to expect from a given input
        use_dropout: whether to apply dropout after each convolution
        use_bn: whether to apply batch normalization after each convolution
                (one BatchNorm2d instance is shared by both convolutions)
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super().__init__()
        doubled_channels = input_channels * 2
        self.conv1 = nn.Conv2d(input_channels, doubled_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(doubled_channels, doubled_channels, kernel_size=3, padding=1)
        self.activation = nn.LeakyReLU(0.2)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.use_bn = use_bn
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(doubled_channels)
        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout()

    def forward(self, x):
        '''
        Run the block on an image tensor and return the downsampled feature map.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        # Both convolution stages share the same normalize / dropout / activate steps.
        for conv in (self.conv1, self.conv2):
            x = conv(x)
            if self.use_bn:
                x = self.batchnorm(x)
            if self.use_dropout:
                x = self.dropout(x)
            x = self.activation(x)
        return self.maxpool(x)

class ExpandingBlock(nn.Module):
    '''
    ExpandingBlock Class:
    Upsamples the input by a factor of two, halves its channels with a 2x2 convolution,
    concatenates the center-cropped skip connection along the channel axis, and applies
    two further convolutions, each followed by optional batch norm, optional dropout
    and a ReLU.
    Values:
        input_channels: the number of channels to expect from a given input
        use_dropout: whether to apply dropout after the last two convolutions
        use_bn: whether to apply batch normalization after the last two convolutions
                (one BatchNorm2d instance is shared by both)
    '''
    def __init__(self, input_channels, use_dropout=False, use_bn=True):
        super().__init__()
        halved_channels = input_channels // 2
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv1 = nn.Conv2d(input_channels, halved_channels, kernel_size=2)
        # conv2 consumes the concatenation of conv1's output and the skip connection.
        self.conv2 = nn.Conv2d(input_channels, halved_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(halved_channels, halved_channels, kernel_size=2, padding=1)
        self.use_bn = use_bn
        if use_bn:
            self.batchnorm = nn.BatchNorm2d(halved_channels)
        self.activation = nn.ReLU()
        self.use_dropout = use_dropout
        if use_dropout:
            self.dropout = nn.Dropout()

    def forward(self, x, skip_con_x):
        '''
        Run the block on an image tensor and its skip connection and return the
        upsampled feature map.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
            skip_con_x: the image tensor from the contracting path (from the opposing block of x)
                    for the skip connection
        '''
        upsampled = self.conv1(self.upsample(x))
        # The skip tensor is cropped to the spatial size of the upsampled features.
        skip = crop(skip_con_x, upsampled.shape)
        x = torch.cat([upsampled, skip], dim=1)
        for conv in (self.conv2, self.conv3):
            x = conv(x)
            if self.use_bn:
                x = self.batchnorm(x)
            if self.use_dropout:
                x = self.dropout(x)
            x = self.activation(x)
        return x

class FeatureMapBlock(nn.Module):
    '''
    FeatureMapBlock Class
    A single 1x1 convolution that changes the number of channels of every pixel
    without touching the spatial size. Used as the first and last layer of the U-Net.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
    '''
    def __init__(self, input_channels, output_channels):
        super().__init__()
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1)

    def forward(self, x):
        '''
        Map an image tensor to the configured number of output channels.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        return self.conv(x)

class UNet(nn.Module):
    '''
    UNet Class
    A series of 5 contracting blocks followed by 5 expanding blocks that transform an
    input image into the corresponding paired image, with an upfeature layer at the
    start and a downfeature layer plus a sigmoid at the end. Each expanding block
    receives the output of the mirrored stage of the contracting path as its skip
    connection. Dropout is enabled in the first three contracting blocks only.
    Values:
        input_channels: the number of channels to expect from a given input
        output_channels: the number of channels to expect for a given output
        hidden_channels: the number of channels produced by the upfeature layer;
                         every contracting block doubles it, every expanding block halves it
    '''
    def __init__(self, input_channels, output_channels, hidden_channels=32):
        super().__init__()
        width = hidden_channels
        self.upfeature = FeatureMapBlock(input_channels, width)
        # Contracting path: channels go width -> 32 * width.
        self.contract1 = ContractingBlock(width, use_dropout=True)
        self.contract2 = ContractingBlock(width * 2, use_dropout=True)
        self.contract3 = ContractingBlock(width * 4, use_dropout=True)
        self.contract4 = ContractingBlock(width * 8)
        self.contract5 = ContractingBlock(width * 16)
        # Expanding path: channels go 32 * width -> width.
        self.expand1 = ExpandingBlock(width * 32)
        self.expand2 = ExpandingBlock(width * 16)
        self.expand3 = ExpandingBlock(width * 8)
        self.expand4 = ExpandingBlock(width * 4)
        self.expand5 = ExpandingBlock(width * 2)
        self.downfeature = FeatureMapBlock(width, output_channels)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        '''
        Pass an image tensor through the U-Net and return the sigmoid-activated output.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
        '''
        # Walk down the contracting path, remembering every intermediate feature map.
        skips = [self.upfeature(x)]
        for contract in (self.contract1, self.contract2, self.contract3,
                         self.contract4, self.contract5):
            skips.append(contract(skips[-1]))

        # The deepest feature map is the starting point of the expanding path; the
        # remaining maps are consumed deepest-first as skip connections.
        out = skips.pop()
        for expand in (self.expand1, self.expand2, self.expand3,
                       self.expand4, self.expand5):
            out = expand(out, skips.pop())

        return self.sigmoid(self.downfeature(out))

PatchGAN discriminator


In [7]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Structured like the contracting path of the U-Net. It takes two image tensors,
    concatenates them along the channel axis, and outputs a one-channel matrix of raw
    scores (no sigmoid is applied) classifying corresponding portions of the image
    as real or fake.
    Parameters:
        input_channels: the number of channels of the concatenated input pair
        hidden_channels: the initial number of discriminator convolutional filters
    '''
    def __init__(self, input_channels, hidden_channels=8):
        super().__init__()
        self.upfeature = FeatureMapBlock(input_channels, hidden_channels)
        # Batch norm is switched off in the first contracting block only.
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False)
        self.contract2 = ContractingBlock(hidden_channels * 2)
        self.contract3 = ContractingBlock(hidden_channels * 4)
        self.contract4 = ContractingBlock(hidden_channels * 8)
        self.final = nn.Conv2d(hidden_channels * 16, 1, kernel_size=1)

    def forward(self, x, y):
        '''
        Score an image pair and return the patch prediction matrix.
        Parameters:
            x: image tensor of shape (batch size, channels, height, width)
            y: second image tensor, concatenated to x along the channel axis
        '''
        features = self.upfeature(torch.cat([x, y], dim=1))
        for contract in (self.contract1, self.contract2, self.contract3, self.contract4):
            features = contract(features)
        return self.final(features)


Training Preparation

In [8]:
import torch.nn.functional as F

# Loss functions and their relative weight in the generator objective.
adv_criterion = nn.BCEWithLogitsLoss()  # adversarial loss, applied to raw discriminator scores
recon_criterion = nn.L1Loss()           # reconstruction loss between generated and real images
lambda_recon = 200                      # multiplier of the reconstruction loss in the generator loss

# Training schedule and data dimensions.
n_epochs = 50        # number of passes over the training set
input_dim = 3        # channels of the condition image
real_dim = 3         # channels of the target image
display_step = 200   # log losses and show sample images every this many steps
batch_size = 4
lr = 0.0002          # learning rate of both Adam optimizers
target_shape = 96    # size both image halves are resized to before entering the networks

# Use the GPU when one is attached; otherwise fall back to the CPU instead of failing
# on the first .to(device) call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [23]:
import torchvision

# Images are only converted to tensors when they are loaded.
transform = transforms.Compose([transforms.ToTensor()])

stl10data = torchvision.datasets.ImageFolder(
    "Data/STL-10/Final Concatenated/", transform=transform)
sketchydata = torchvision.datasets.ImageFolder(
    "Data/sketchy-database/Final Concatenated/", transform=transform)

# Train / validation / test splits, drawn with a fixed generator seed.
stl10train, stl10val, stl10test = random_split(
    stl10data,
    [2200, 125, 50],  # change to 2200, 175, 125 after
    generator=torch.Generator().manual_seed(0),
)
sketchytrain, sketchyval, sketchytest = random_split(
    sketchydata,
    [2200, 175, 125],
    generator=torch.Generator().manual_seed(0),
)

# Each final split is the concatenation of the matching split of both sources.
trainData = ConcatDataset([stl10train, sketchytrain])
valData = ConcatDataset([stl10val, sketchyval])
testDat = ConcatDataset([stl10test, sketchytest])

In [20]:
# Generator: a U-Net mapping the condition image (input_dim channels) to an image
# with real_dim channels, plus its Adam optimizer.
gen = UNet(input_dim, real_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
# Discriminator: it concatenates an image and the condition along the channel axis,
# hence input_dim + real_dim input channels. It gets its own Adam optimizer.
disc = Discriminator(input_dim + real_dim).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def weights_init(m):
    '''
    Re-initialize one module in place; meant to be passed to nn.Module.apply.
    Conv2d and ConvTranspose2d weights, and BatchNorm2d weights, are drawn from a
    normal distribution with mean 0.0 and standard deviation 0.02. BatchNorm2d biases
    are set to 0. Convolution biases and all other module types are left untouched.
    Parameters:
        m: the module to initialize
    '''
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 0.0, 0.02)
        nn.init.constant_(m.bias, 0)

# Either resume both networks and both optimizers from a saved checkpoint, or start
# from weights freshly drawn by weights_init.
pretrained = False
if pretrained:
    # map_location remaps the stored tensors onto `device`, so a checkpoint written
    # on one device can still be restored when running on another.
    loaded_state = torch.load("pix2pix_35000.pth", map_location=device)
    gen.load_state_dict(loaded_state["gen"])
    gen_opt.load_state_dict(loaded_state["gen_opt"])
    disc.load_state_dict(loaded_state["disc"])
    disc_opt.load_state_dict(loaded_state["disc_opt"])
else:
    gen = gen.apply(weights_init)
    disc = disc.apply(weights_init)
from torchvision.utils import save_image


In [21]:
def get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon):
    '''
    Compute the generator loss: adversarial loss plus weighted reconstruction loss.
    Parameters:
        gen: the generator; called with the condition, returns generated images
        disc: the discriminator; called with (generated images, condition), returns
          its prediction matrix
        real: the real images the generated images are compared against
        condition: the source images fed to the generator
        adv_criterion: the adversarial loss function; called with the discriminator
                  predictions and a same-shaped tensor of ones (the "real" label)
        recon_criterion: the reconstruction loss function; called with the generated
                    images and the real images
        lambda_recon: the weight of the reconstruction loss in the sum
    Returns:
        adversarial loss + lambda_recon * reconstruction loss
    '''
    fake = gen(condition)
    disc_pred = disc(fake, condition)
    # The generator is rewarded when the discriminator labels its output as real.
    adversarial = adv_criterion(disc_pred, torch.ones_like(disc_pred))
    reconstruction = recon_criterion(fake, real)
    return adversarial + lambda_recon * reconstruction

Pix2Pix Training

In [24]:
from skimage import color
import numpy as np
import matplotlib.pyplot as plt
import math
from torchvision.utils import save_image



def _split_pair(image, size):
    '''
    Split a batch of side-by-side image pairs into its two halves and resize both.
    Parameters:
        image: image tensor of shape (batch size, channels, height, width); the left
               half of the width is the condition, the right half the real image
        size: target size passed to nn.functional.interpolate for both halves
    Returns:
        the tuple (condition, real)
    '''
    half_width = image.shape[3] // 2
    condition = nn.functional.interpolate(image[:, :, :, :half_width], size=size)
    real = nn.functional.interpolate(image[:, :, :, half_width:], size=size)
    return condition, real


def _count_correct(real_logits, fake_logits):
    '''
    Count how many images of a batch the discriminator classified correctly.
    The patch scores of every image are averaged into one score. The scores are raw
    logits (the adversarial criterion is a with-logits loss), so the decision
    boundary is 0: a positive mean counts as "real", a negative mean as "fake".
    Parameters:
        real_logits: discriminator output for real images, shape (batch, 1, h, w)
        fake_logits: discriminator output for generated images, same layout
    Returns:
        number of correct decisions as a Python int (at most 2 * batch size)
    '''
    real_hits = (real_logits.mean(dim=[1, 2, 3]) > 0).sum().item()
    fake_hits = (fake_logits.mean(dim=[1, 2, 3]) < 0).sum().item()
    return real_hits + fake_hits


def _plot_curve(title, x_label, y_label, x_values, curves):
    '''
    Show one line plot; curves is a list of (legend label, y values) pairs that all
    share x_values.
    '''
    plt.title(title)
    for label, y_values in curves:
        plt.plot(x_values, y_values, label=label)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend(loc='best')
    plt.show()


def train(save_model=False):
    '''
    Train the pix2pix generator/discriminator pair and plot the training curves.

    Works on the notebook-level objects: gen, disc, gen_opt, disc_opt, the loss
    criteria, trainData / valData / stl10val and the hyper-parameters n_epochs,
    batch_size, display_step, target_shape, lambda_recon, input_dim, real_dim, device.

    Per epoch it
      1. alternates one discriminator and one generator update per training batch,
      2. logs the running mean losses and shows sample images every display_step steps,
      3. reports discriminator accuracy on the training and validation sets,
      4. every fifth epoch saves five fixed validation samples as images and,
         when save_model is True, a checkpoint named pix2pix_<step>.pth.
    After the last epoch the loss and accuracy curves are plotted.

    Parameters:
        save_model: whether to write a checkpoint every fifth epoch
    '''
    import os  # only needed to make sure the sample output folder exists

    sample_dir = 'pytoch-pix2pix/final/baseline-val-results/'
    mean_generator_loss = 0
    mean_discriminator_loss = 0
    cur_step = 0
    curStepCount = []
    genLossCount = []
    disLossCount = []
    accCount = []
    valAccCount = []
    epochCount = []
    dataloader = DataLoader(trainData, batch_size=batch_size, shuffle=True)
    valDataLoader = DataLoader(valData, batch_size=batch_size, shuffle=True)

    # NOTE(review): no .eval()/.train() switch is made anywhere in this function, so
    # dropout and batch norm stay in training mode during validation and sampling --
    # confirm this is intended.
    for epoch in range(n_epochs):
        epochCount.append(epoch)

        ### Training pass ###
        correct, total = 0, 0
        for image, _ in tqdm(dataloader):
            condition, real = _split_pair(image, target_shape)
            cur_batch_size = len(condition)
            condition = condition.to(device)
            real = real.to(device)

            ### Update discriminator ###
            disc_opt.zero_grad() # Zero out the gradient before backpropagation
            with torch.no_grad():
                # No generator graph is needed for the discriminator update.
                fake = gen(condition)
            disc_fake_hat = disc(fake, condition)
            disc_fake_loss = adv_criterion(disc_fake_hat, torch.zeros_like(disc_fake_hat))
            disc_real_hat = disc(real, condition)
            disc_real_loss = adv_criterion(disc_real_hat, torch.ones_like(disc_real_hat))
            disc_loss = (disc_fake_loss + disc_real_loss) / 2
            # The generator loss below builds its own graph, so this one need not be retained.
            disc_loss.backward() # Update gradients
            disc_opt.step() # Update optimizer

            ### Update generator ###
            gen_opt.zero_grad()
            gen_loss = get_gen_loss(gen, disc, real, condition, adv_criterion, recon_criterion, lambda_recon)
            gen_loss.backward() # Update gradients
            gen_opt.step() # Update optimizer

            ### Evaluate discriminator accuracy ###
            correct += _count_correct(disc_real_hat.detach(), disc_fake_hat.detach())
            total += 2 * cur_batch_size  # one real and one fake decision per sample

            # Keep track of the average discriminator loss
            mean_discriminator_loss += disc_loss.item() / display_step
            # Keep track of the average generator loss
            mean_generator_loss += gen_loss.item() / display_step

            ### Visualization code ###
            if cur_step % display_step == 0:
                if cur_step > 0:
                    print(f"Epoch {epoch}: Step {cur_step}: G train loss: {mean_generator_loss}, D train loss: {mean_discriminator_loss}")
                else:
                    print("Pretrained initial state")
                genLossCount.append(mean_generator_loss)
                disLossCount.append(mean_discriminator_loss)
                curStepCount.append(cur_step)
                mean_generator_loss = 0
                mean_discriminator_loss = 0
                show_tensor_images(condition, size=(input_dim, target_shape, target_shape))
                show_tensor_images(real, size=(real_dim, target_shape, target_shape))
                show_tensor_images(fake, size=(real_dim, target_shape, target_shape))
            cur_step += 1

        # Divide by the number of decisions actually made instead of a fixed constant,
        # so the figure stays right when the dataset split sizes change.
        train_acc = correct / max(total, 1)
        accCount.append(train_acc)
        print(f"D training acc: {train_acc}")

        ### Validation pass (no parameter updates) ###
        correct, total = 0, 0
        with torch.no_grad():
            for valImage, _ in valDataLoader:
                # The width is read from the validation batch itself.
                valCondition, valReal = _split_pair(valImage, target_shape)
                valCondition = valCondition.to(device)
                valReal = valReal.to(device)

                valFake = gen(valCondition)
                valdisc_fake_hat = disc(valFake, valCondition)
                valdisc_real_hat = disc(valReal, valCondition)

                correct += _count_correct(valdisc_real_hat, valdisc_fake_hat)
                total += 2 * len(valCondition)
        val_acc = correct / max(total, 1)
        print(f"D val acc: {val_acc}")
        valAccCount.append(val_acc)

        ### Every fifth epoch: checkpoint and sample images ###
        if epoch % 5 == 4:
            if save_model:
                torch.save({'gen': gen.state_dict(),
                            'gen_opt': gen_opt.state_dict(),
                            'disc': disc.state_dict(),
                            'disc_opt': disc_opt.state_dict()
                            }, f"pix2pix_{cur_step}.pth")
            valIndices = [0,5,22,41,53]
            dset = torch.utils.data.Subset(stl10val, valIndices)
            sampleDataLoader = DataLoader(dset)
            os.makedirs(sample_dir, exist_ok=True)
            for count, (image, _) in enumerate(sampleDataLoader):
                image = image.to(device)
                condition, _ = _split_pair(image, target_shape)
                with torch.no_grad():
                    fake = gen(condition)
                # Concatenating along the width requires the generated image to have
                # the same height as the original pair.
                fake = F.interpolate(fake, size=image.shape[2])
                result = torch.cat((image, fake), 3)
                save_image(result, sample_dir + str(cur_step) + str(count) + '.png')

    _plot_curve("Generator Loss Curve", "Number of Steps", "Loss",
                curStepCount, [("Training", genLossCount)])
    _plot_curve("Discriminator Loss Curve", "Number of Steps", "Loss",
                curStepCount, [("Training", disLossCount)])
    _plot_curve("Discriminator Accuracy Curve", "Number of Epochs", "Accuracy",
                epochCount, [("Training", accCount), ("Validation", valAccCount)])

# Run the full training loop; save_model=True also writes a checkpoint every fifth epoch.
train(save_model = True)

Output hidden; open in https://colab.research.google.com to view.

In [26]:
import os

from torchvision.utils import save_image

# Run the generator over the test split and save, for every sample, the original
# side-by-side image with the generated image appended on the right.
test_results_dir = 'pytoch-pix2pix/final/baseline-test-results/'
os.makedirs(test_results_dir, exist_ok=True)  # make sure the output folder exists before writing

# NOTE(review): no gen.eval() call is visible before this cell, so dropout and batch
# norm run in training mode here -- confirm this is intended.
testDataloader = DataLoader(testDat)  # default batch size of 1: one output file per test image
for count, (image, _) in enumerate(testDataloader, start=1):
  image = image.to(device)
  # The left half of each image is the condition handed to the generator.
  condition = image[:, :, :, :image.shape[3] // 2]
  condition = nn.functional.interpolate(condition, size=target_shape)

  with torch.no_grad():
    fake = gen(condition)
  # Concatenating along the width requires the generated image to have the same
  # height as the original pair.
  fake = F.interpolate(fake, size=image.shape[2])
  result = torch.cat((image, fake), 3)
  save_image(result, test_results_dir + str(count) + '.png')