# Pix2Pix for object detection training set synthesis

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import matplotlib.pyplot as plt

import torch
import torchvision

In [None]:
import os
import torchvision.datasets as datasets

# Download/cache directory for MNIST. Defaults to the original author's
# location; set the MNIST_ROOT environment variable to use another path.
MNIST_ROOT = os.environ.get('MNIST_ROOT', '/Users/jcboyd/Data/torch')

mnist_trainset = datasets.MNIST(root=MNIST_ROOT, train=True, download=True)
mnist_testset = datasets.MNIST(root=MNIST_ROOT, train=False, download=True)


def _keep_zeros_and_ones(x, y):
    """Return the samples of ``x``/``y`` whose label is 0 or 1."""
    keep = (y == 0) | (y == 1)
    return x[keep], y[keep]


# Rescale raw pixel intensities (0-255) to [-1, 1], the same range as the
# canvases the generator learns to produce (tanh output).
x_train, y_train = _keep_zeros_and_ones((mnist_trainset.data / 127.5) - 1,
                                        mnist_trainset.targets)
x_test, y_test = _keep_zeros_and_ones((mnist_testset.data / 127.5) - 1,
                                      mnist_testset.targets)

## Assemble training set

In [None]:
from torch.nn import UpsamplingNearest2d

def get_canvas(x_data, y_data, num_samples=16, nb_classes=2, dim=256, noise_scale=4):
    """Compose a synthetic detection canvas from randomly placed digits.

    ``num_samples`` images are drawn (with replacement) from ``x_data`` and
    pasted at random positions on a ``dim`` x ``dim`` canvas filled with -1.
    Overlapping digits are merged with an element-wise max. For every pasted
    digit, the channel of the mask given by its label receives a blocky
    binary noise patch (values -1 / 1, blocks of ``noise_scale`` pixels)
    covering the digit's bounding box.

    Args:
        x_data: Images indexed as ``(N, h, w)``; presumably scaled to [-1, 1].
        y_data: Integer labels of length ``N``, each in ``[0, nb_classes)``.
        num_samples: Number of digits to paste.
        nb_classes: Number of mask channels.
        dim: Side length of the square canvas; must be >= h and >= w.
        noise_scale: Side length of each constant block in the noise patch.

    Returns:
        canvas: ``(1, 1, dim, dim)`` tensor.
        mask_img: ``(1, nb_classes, dim, dim)`` tensor.
        bboxes: ``(num_samples, 4)`` float tensor of ``[x1, y1, x2, y2]``
            boxes with exclusive end coordinates.
    """
    idx = torch.randint(x_data.shape[0], size=(num_samples,))
    images = x_data[idx]
    labels = y_data[idx]

    h, w = images.shape[1:]
    s = noise_scale

    canvas = -torch.ones((dim, dim))
    mask_img = -torch.ones((nb_classes, dim, dim))
    boxes = []

    for image, label in zip(images, labels):

        # "+ 1" lets a digit sit flush with the bottom/right edge and makes
        # dim == h (or dim == w) valid instead of an empty randint range.
        y = torch.randint(dim - h + 1, size=(1,)).item()
        x = torch.randint(dim - w + 1, size=(1,)).item()
        canvas[y:y+h, x:x+w] = torch.max(canvas[y:y+h, x:x+w], image)

        # Coarse random pattern in {-1, 1}, blown up by nearest-neighbour
        # block repetition. Ceil-division plus cropping supports h or w that
        # are not multiples of s.
        coarse = (torch.rand((h + s - 1) // s, (w + s - 1) // s) > 0.5).float() * 2 - 1
        noise = coarse.repeat_interleave(s, dim=0).repeat_interleave(s, dim=1)[:h, :w]

        mask_img[int(label), y:y+h, x:x+w] = noise
        boxes.append([x, y, x + w, y + h])

    # reshape keeps a (0, 4) shape when no digit was drawn.
    bboxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)

    canvas = torch.clamp(canvas, -1, 1)

    return canvas[None, None], mask_img[None], bboxes


def gen_canvas(x_data, y_data, batch_size=2, **canvas_kwargs):
    """Yield batches of synthetic canvases forever.

    Each batch stacks ``batch_size`` independent ``get_canvas`` samples.

    Args:
        x_data, y_data: Image and label tensors passed to ``get_canvas``.
        batch_size: Number of canvases per batch.
        **canvas_kwargs: Extra keyword arguments forwarded to ``get_canvas``
            (e.g. ``num_samples``, ``nb_classes``, ``dim``).

    Yields:
        mask_batch: ``(batch_size, nb_classes, dim, dim)``; the training
            loop feeds this to the generator.
        canvas_batch: ``(batch_size, 1, dim, dim)``; the generator target.
        bbox_batch: ``(total_boxes, 5)`` rows ``[batch_index, x1, y1, x2, y2]``.
    """
    while True:

        samples = [get_canvas(x_data, y_data, **canvas_kwargs) for _ in range(batch_size)]

        canvas_batch = torch.cat([canvas for canvas, _, _ in samples])
        mask_batch = torch.cat([mask for _, mask, _ in samples])

        # Prefix every box with the index of its image in the batch (the
        # (K, 5) layout used by torchvision's RoI pooling). The row count is
        # taken from the boxes themselves rather than assumed to be 16.
        bbox_batch = torch.cat(
            [torch.cat([torch.full((boxes.shape[0], 1), float(i)), boxes], dim=1)
             for i, (_, _, boxes) in enumerate(samples)],
            dim=0)

        yield mask_batch, canvas_batch, bbox_batch

In [None]:
# Draw one example batch (masks, canvases, boxes) for the preview plot below.
train_gen = gen_canvas(x_data=x_train, y_data=y_train, batch_size=2)
mask_batch, canvas_batch, bbox_batch = next(iter(train_gen))

In [None]:
# Left panel: the first canvas of the batch; remaining panels: its class masks.
fig, axes = plt.subplots(figsize=(15, 5), ncols=3)

axes[0].imshow(canvas_batch[0].squeeze(), cmap='Greys_r')

for ax, class_mask in zip(axes[1:], mask_batch[0]):
    ax.imshow(class_mask.squeeze())

## Build model

In [None]:
# from keras.models import Model
# from keras.layers import Input
# from keras.optimizers import Adam
# from src.models import fnet, patch_gan
# from src.utils import set_trainable

# nb_classes = 2

# input_shape = mask_batch.shape[1:]     # "images"
# output_shape = canvas_batch.shape[1:]  # "labels"

# h, w, c = output_shape
# disc_patch = (h // 2 ** 4, w // 2 ** 4, c)

# optimizer = Adam(0.0002, 0.5)

# # Build discriminator
# discriminator = patch_gan(input_shape, output_shape)
# discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

# # Build the generator
# generator = fnet(input_shape, 64, 'tanh')

# # Input images and their conditioning images
# images = Input(shape=input_shape)
# labels = Input(shape=output_shape)

# # By conditioning on B generate a fake version of A
# fake_labels = generator(images)

# set_trainable(discriminator, False)

# # Discriminators determines validity of translated images / condition pairs
# valid = discriminator([fake_labels, images])

# combined = Model(inputs=[images, labels], outputs=[valid, fake_labels])
# combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=optimizer)

## Train model

In [None]:
# epochs = 200
# batch_size = 2

# # Adversarial loss ground truths
# valid = np.ones((batch_size,) + disc_patch)
# fake = np.zeros((batch_size,) + disc_patch)

# train_gen = gen_canvas(x_train, y_train, batch_size)
# steps_per_epoch = 200 # x_train.shape[0] // batch_size // 4

# for epoch in range(epochs):

#     for batch_i in range(steps_per_epoch):

#         images, labels = next(train_gen)

#         # Train discriminator - condition on B and generate a translated version
#         fake_labels = generator.predict(images)

#         set_trainable(discriminator, True)

#         # Train the discriminators (original images = real / generated = Fake)
#         d_loss_real = discriminator.train_on_batch([images, labels], valid)
#         d_loss_fake = discriminator.train_on_batch([images, fake_labels], fake)
#         d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

#         set_trainable(discriminator, False)

#         # Train the generators
#         g_loss = combined.train_on_batch([images, labels], [valid, labels])

#         # Plot the progress
#         print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f]' % (
#             epoch, epochs, batch_i, steps_per_epoch, d_loss[0], 100 * d_loss[1], g_loss[0]))

#     generator.save_weights('./weights/pix2pix_mnist.h5')
#     discriminator.save_weights('./weights/patch_gan_mnist.h5')

## Test model

In [None]:
# generator.load_weights('./weights/pix2pix_mnist_199.h5')
# discriminator.load_weights('./weights/patch_gan_mnist_199.h5')

# gen_test = gen_canvas(x_test, y_test, 3)
# images, labels, bboxes, roi_labels = next(gen_test)

# fake_labels = generator.predict((images + 1) / 2)

# titles = ['Original', 'Generated']
# fig, axes = plt.subplots(figsize=(10, 15), nrows=3, ncols=2)

# axes[0, 0].set_title('Original')
# axes[0, 1].set_title('Generated')

# for i in range(3):

#     axes[i, 0].imshow(labels[i].squeeze())
#     axes[i, 1].imshow(fake_labels[i].squeeze())

In [None]:
# fig, axes = plt.subplots(figsize=(10, 5), nrows=4, ncols=4)

# img = fake_labels[0, ..., 0]

# for i in range(16):

#     ax = axes[i // 4][i % 4]

#     left, top, right, bottom = list(map(int, bboxes[i, 1:]))
#     ax.imshow(img[top:bottom, left:right])#[220:250, 170:200])
#     ax.axis('off')

There is a significant degree of mode collapse in the generated objects.

In [None]:
# Run on the GPU whenever PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def set_trainable(model, trainable):
    """Enable or disable gradient tracking for every parameter of ``model``.

    The training loop uses this to freeze a discriminator while the
    generator is updated and to unfreeze it for its own update. Existing
    ``.grad`` values are left as they are.
    """
    for weight in model.parameters():
        weight.requires_grad = trainable

In [None]:
from torch.nn import Module, Sequential
from torch.nn import Linear, Conv2d, ReLU, LeakyReLU, BatchNorm1d, BatchNorm2d, Tanh
from torch.nn import MaxPool2d, UpsamplingNearest2d, ZeroPad2d

from torchvision.ops import RoIAlign, RoIPool


class Downward(Module):
    """Encoder step: pad, 4x4 stride-2 conv, LeakyReLU(0.2), optional BN.

    Maps ``in_ch`` to ``out_ch`` channels and halves an even spatial size
    (H, W) -> (H/2, W/2).

    The submodule names inside ``self.pool`` ('0'..'2' and 'batch_norm')
    form the ``state_dict`` keys, so saved checkpoints depend on them.
    """

    def __init__(self, in_ch, out_ch, normalise=True):

        super().__init__()

        layers = [
            ZeroPad2d((1, 2, 1, 2)),  # left/right/top/bottom = 1/2/1/2
            Conv2d(in_ch, out_ch, kernel_size=4, stride=2),
            LeakyReLU(negative_slope=0.2, inplace=True),
        ]
        self.pool = Sequential(*layers)

        if normalise:
            # NOTE(review): PyTorch's BatchNorm ``momentum`` weights the *new*
            # batch statistic, the opposite of Keras. 0.8 looks like a direct
            # port of Keras' momentum=0.8 (PyTorch equivalent: 0.2) - confirm.
            self.pool.add_module('batch_norm', BatchNorm2d(out_ch, momentum=0.8))

    def forward(self, x):
        return self.pool(x)


class Upward(Module):
    """Decoder step with a skip connection.

    ``x1`` is upsampled 2x (nearest), padded, passed through a 4x4 stride-1
    conv (``in_ch`` -> ``out_ch``), ReLU and batch norm. The result is
    concatenated along channels with the skip tensor ``x2``, so the output
    has ``out_ch + x2.shape[1]`` channels.
    """

    def __init__(self, in_ch, out_ch):

        super().__init__()

        # Layer order fixes the state_dict keys: conv = 'depool.2',
        # batch norm = 'depool.4'.
        upsample = UpsamplingNearest2d(scale_factor=2)
        pad = ZeroPad2d((1, 2, 1, 2))
        conv = Conv2d(in_ch, out_ch, kernel_size=4, stride=1)
        relu = ReLU(inplace=True)
        norm = BatchNorm2d(out_ch, momentum=0.8)
        self.depool = Sequential(upsample, pad, conv, relu, norm)

    def forward(self, x1, x2):
        upsampled = self.depool(x1)
        return torch.cat((upsampled, x2), dim=1)


class FNet(Module):
    """U-Net style generator mapping class masks to a canvas.

    Five ``Downward`` steps (32x downsampling overall) are followed by four
    ``Upward`` steps with skip connections. A final 2x upsample, conv and
    tanh restore the input resolution. H and W must be divisible by 32 so
    that the skip tensors line up.

    Args:
        base_filters: Channel width of the first level; deeper levels use
            multiples of it.
        in_channels: Input channels (default 2, matching the ``nb_classes``
            default of ``get_canvas``).
        out_channels: Output channels (default 1: a greyscale canvas).
    """

    def __init__(self, base_filters=32, in_channels=2, out_channels=1):

        super(FNet, self).__init__()

        # Encoder (creation order fixes parameter initialisation order).
        self.down1 = Downward(in_channels, base_filters, normalise=False)
        self.down2 = Downward(base_filters, 2 * base_filters)
        self.down3 = Downward(2 * base_filters, 4 * base_filters)
        self.down4 = Downward(4 * base_filters, 8 * base_filters)
        self.down5 = Downward(8 * base_filters, 8 * base_filters)

        # Decoder: each Upward concatenates its skip, so the next step's input
        # width is its own output width plus the skip's width.
        self.up1 = Upward(8 * base_filters, 8 * base_filters)    # + x4 -> 16f
        self.up2 = Upward(16 * base_filters, 4 * base_filters)   # + x3 -> 8f
        self.up3 = Upward(8 * base_filters, 2 * base_filters)    # + x2 -> 4f
        self.up4 = Upward(4 * base_filters, base_filters)        # + x1 -> 2f

        # Back to full resolution; tanh bounds the output to [-1, 1], the
        # range of the target canvases.
        self.out_conv = Sequential(
            UpsamplingNearest2d(scale_factor=2),
            ZeroPad2d((1, 2, 1, 2)),
            Conv2d(2 * base_filters, out_channels, kernel_size=4, stride=1),
            Tanh())

    def forward(self, x):

        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x5 = self.down5(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return self.out_conv(x)


class PatchGAN(Module):
    """Conditional PatchGAN discriminator.

    Concatenates the conditioning input ``x`` and the candidate output ``y``
    along channels. There are 3 input channels in total; in the training
    loop these are the 2 class-mask channels plus the 1-channel canvas. Four
    ``Downward`` steps reduce the resolution 16x, and a same-size 4x4 conv
    returns unbounded validity scores of shape ``(N, 1, H/16, W/16)``.
    """

    def __init__(self, base_filters=32):

        super().__init__()

        f = base_filters
        # Channel widths: 3 -> f -> 2f -> 4f -> 8f.
        self.down1 = Downward(3, f, normalise=False)
        self.down2 = Downward(f, 2 * f)
        self.down3 = Downward(2 * f, 4 * f)
        self.down4 = Downward(4 * f, 8 * f)

        # Padding + stride-1 conv keeps the spatial size: one score per cell.
        self.padding = ZeroPad2d((1, 2, 1, 2))
        self.validity = Conv2d(8 * f, 1, kernel_size=4, stride=1)

    def forward(self, x, y):

        h = torch.cat([x, y], axis=1)

        for stage in (self.down1, self.down2, self.down3, self.down4):
            h = stage(h)

        return self.validity(self.padding(h))


class RoIGAN(Module):
    """Region-level discriminator that scores each bounding box.

    ``x`` and ``y`` are concatenated along channels and passed through a
    small conv stack containing one 2x max-pool. Each box is then RoI-pooled
    to 7x7 and scored by a linear layer.

    ``forward(x, y, boxes)`` expects ``boxes`` as ``(K, 5)`` rows
    ``[batch_index, x1, y1, x2, y2]`` in *feature-map* coordinates. Because
    ``spatial_scale=1``, that means input-pixel coordinates divided by 2, as
    the training functions do. It returns a ``(K, 1)`` tensor of unbounded
    scores.

    Args:
        in_channels: Channels of ``cat([x, y])`` (default 3: two class masks
            plus one canvas channel).
    """

    def __init__(self, in_channels=3):

        super(RoIGAN, self).__init__()

        feature_channels = 32
        pool_size = (7, 7)

        self.features = Sequential(
            Conv2d(in_channels, 16, kernel_size=3, padding=1),
            LeakyReLU(negative_slope=0.2, inplace=True),
            Conv2d(16, 16, kernel_size=3, padding=1),
            LeakyReLU(negative_slope=0.2, inplace=True),
            BatchNorm2d(16, momentum=0.8),
            MaxPool2d(kernel_size=2, stride=2),
            Conv2d(16, feature_channels, kernel_size=3, padding=1),
            LeakyReLU(negative_slope=0.2, inplace=True),
            BatchNorm2d(feature_channels, momentum=0.8),
            Conv2d(feature_channels, feature_channels, kernel_size=3, padding=1),
            LeakyReLU(negative_slope=0.2, inplace=True),
            BatchNorm2d(feature_channels, momentum=0.8))

        # Attribute name kept for compatibility although this is RoIPool
        # (RoIAlign was tried previously); it has no parameters.
        self.roi_align = RoIPool(output_size=pool_size, spatial_scale=1)

        # 32 * 7 * 7 = 1568 inputs, derived rather than hard-coded.
        self.classifier = Sequential(
            Linear(feature_channels * pool_size[0] * pool_size[1], 1))

    def forward(self, x, y, boxes):

        x = torch.cat([x, y], axis=1)

        features = self.features(x)
        # N.B. box coordinates correspond to features, not input_img
        roi = self.roi_align(features, boxes)

        # flatten keeps exactly one row per box; view(-1, 1568) would silently
        # regroup rows if the feature width ever changed.
        return self.classifier(roi.flatten(start_dim=1))

In [None]:
generator = FNet().to(device)
discriminator = PatchGAN().to(device)
roi_gan = RoIGAN().to(device)

# Total number of scalar weights in the generator.
total_params = sum(p.numel() for p in generator.parameters())

print(total_params)

In [None]:
from torch.nn import MSELoss, L1Loss, BCEWithLogitsLoss
from torch.optim import Adam

# --- Training configuration ---------------------------------------------------
batch_size = 2
epochs = 200
steps_per_epoch = 10 # 200

# NOTE(review): lambda_recon is not referenced by the visible training code;
# train_on_batch_combined weights its L1 term by 10 instead.
lambda_recon = 100

# Peek at one batch to size the PatchGAN target: its output is downsampled
# 16x (four stride-2 Downward steps) with a single channel.
train_gen = gen_canvas(x_train, y_train, batch_size)
mask_batch, canvas_batch, bbox_batch = next(train_gen)
c, h, w = mask_batch.shape[1:]
disc_patch = (1, h // 16, w // 16)

# Adam with lr=2e-4, beta1=0.5 for all three networks.
optimiser_discriminator = Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimiser_roi = Adam(roi_gan.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimiser_combined = Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))


def train_on_batch_discriminator(discriminator, imgs, optimiser):
    """Run one least-squares GAN update of the PatchGAN discriminator.

    Real pairs ``(images, labels)`` are pushed towards 1 and fake pairs
    ``(images, fake_labels)`` towards 0 under an MSE loss.

    Args:
        discriminator: Module called as ``discriminator(images, labels)``.
        imgs: ``[images, labels, fake_labels]`` with equal batch sizes.
            ``fake_labels`` should be detached (as the training loop does)
            so that no gradient reaches the generator.
        optimiser: Optimiser over the discriminator's parameters.

    Returns:
        The scalar loss tensor.
    """
    images, labels, fake_labels = imgs
    n = images.shape[0]

    outputs = discriminator(torch.cat([images, images], dim=0),
                            torch.cat([labels, fake_labels], dim=0))

    # Targets copy the output's shape, dtype and device. This replaces the
    # global batch_size/disc_patch/device, so any batch size works.
    targets = torch.cat([torch.ones_like(outputs[:n]),
                         torch.zeros_like(outputs[n:])], dim=0)

    optimiser.zero_grad()

    d_loss = MSELoss()(outputs, targets)
    d_loss.backward()

    optimiser.step()

    return d_loss


def train_on_batch_roi_gan(roi_gan, data, optimiser):
    """Run one least-squares GAN update of the RoI discriminator.

    Every box is scored once on the real canvas (target 1) and once on the
    fake canvas (target 0).

    Args:
        roi_gan: Module called as ``roi_gan(images, labels, boxes)`` that
            returns one score per box.
        data: ``[images, labels, fake_labels, bboxes]``. ``bboxes`` holds
            ``(K, 5)`` rows ``[batch_index, x1, y1, x2, y2]`` in input-pixel
            coordinates; the caller's tensor is not modified.
        optimiser: Optimiser over ``roi_gan``'s parameters.

    Returns:
        The scalar loss tensor.
    """
    images, labels, fake_labels, bboxes = data

    n_images = images.shape[0]
    num_roi = bboxes.shape[0]

    input_images = torch.cat([images, images], dim=0)
    input_labels = torch.cat([labels, fake_labels], dim=0)
    # torch.cat returns a new tensor, so the in-place edits below leave the
    # caller's bboxes untouched.
    input_bboxes = torch.cat([bboxes, bboxes], dim=0)

    # The fake copies occupy the second half of the doubled batch; offset by
    # the actual batch size, not a global constant.
    input_bboxes[num_roi:, 0] += n_images
    # RoIGAN pools from feature maps at half the input resolution.
    input_bboxes[:, 1:] = input_bboxes[:, 1:] / 2

    optimiser.zero_grad()

    validity = roi_gan(input_images, input_labels, input_bboxes)

    targets = torch.cat([torch.ones_like(validity[:num_roi]),
                         torch.zeros_like(validity[num_roi:])], dim=0)

    d_loss = MSELoss()(validity, targets)
    d_loss.backward()

    optimiser.step()

    return d_loss


def train_on_batch_combined(models, data, optimiser,
                            adv_weight=0, roi_weight=1, recon_weight=10):
    """Run one generator update against both discriminators.

    With ``fake = generator(images)`` the loss is::

        adv_weight   * MSE(discriminator(images, fake), 1)
      + roi_weight   * MSE(roi_gan(images, fake, boxes / 2), 1)
      + recon_weight * L1(labels, fake)

    The defaults reproduce the original weights, so the PatchGAN term is
    disabled. That term is still evaluated when its weight is 0.

    Only the parameters held by ``optimiser`` are stepped. The training loop
    freezes both discriminators with ``set_trainable`` beforehand.

    Args:
        models: ``[generator, discriminator, roi_gan]``.
        data: ``[images, labels, bboxes]``. ``bboxes`` holds ``(K, 5)`` rows
            ``[batch_index, x1, y1, x2, y2]`` in input-pixel coordinates; the
            caller's tensor is not modified.
        optimiser: Optimiser over the generator's parameters.
        adv_weight, roi_weight, recon_weight: Loss term weights.

    Returns:
        The scalar generator loss tensor.
    """
    optimiser.zero_grad()

    generator, discriminator, roi_gan = models
    images, labels, bboxes = data

    # Scale to RoIGAN's half-resolution feature maps on a copy, so the
    # caller's boxes are not halved in place.
    roi_boxes = bboxes.clone()
    roi_boxes[:, 1:] = roi_boxes[:, 1:] / 2

    fake_labels = generator(images)

    # Both discriminators judge the translated images; the generator wants
    # them scored as real (1).
    validity = discriminator(images, fake_labels)
    validity_roi = roi_gan(images, fake_labels, roi_boxes)

    g_loss = (adv_weight * MSELoss()(validity, torch.ones_like(validity))
              + roi_weight * MSELoss()(validity_roi, torch.ones_like(validity_roi))
              + recon_weight * L1Loss()(labels, fake_labels))

    g_loss.backward()

    optimiser.step()

    return g_loss


for epoch in range(epochs):

    for batch_i in range(steps_per_epoch):

        images, labels, bboxes = (t.to(device) for t in next(train_gen))

        # Fake canvases for the discriminator updates. no_grad yields the same
        # values as the former .detach() without building an autograd graph.
        with torch.no_grad():
            fake_labels = generator(images)

        # Train PatchGAN discriminator (unfrozen only for its own update)
        set_trainable(discriminator, True)

        d_loss = train_on_batch_discriminator(discriminator,
                                              [images, labels, fake_labels],
                                              optimiser_discriminator).item()

        set_trainable(discriminator, False)

        # Train roi discriminator
        set_trainable(roi_gan, True)

        roi_loss = train_on_batch_roi_gan(roi_gan,
                                          [images, labels, fake_labels, bboxes],
                                          optimiser_roi).item()

        set_trainable(roi_gan, False)

        # Train the generator against the frozen discriminators
        g_loss = train_on_batch_combined([generator, discriminator, roi_gan],
                                         [images, labels, bboxes],
                                         optimiser_combined).item()

        # Plot the progress
        print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [R loss: %f] [G loss: %05f]' % (
            epoch, epochs, batch_i, steps_per_epoch, d_loss, roi_loss, g_loss))

    # Save per-epoch weights, one file per network. The RoI discriminator must
    # not share the PatchGAN discriminator's filename, or it overwrites it.
    torch.save(generator.state_dict(), './pix2pix_generator_%d.torch' % epoch)
    torch.save(discriminator.state_dict(), './pix2pix_discriminator_%d.torch' % epoch)
    torch.save(roi_gan.state_dict(), './pix2pix_roi_gan_%d.torch' % epoch)

In [None]:
# Restore the generator saved after the final epoch and switch it to
# inference mode (batch norm uses its running statistics).
# NOTE(review): the training loop writes './pix2pix_generator_%d.torch' to the
# working directory while this reads from './weights/' - presumably the file
# is moved there by hand; confirm.
epoch = 199
generator = FNet()
generator.load_state_dict(torch.load('./weights//pix2pix_generator_%d.torch' % epoch,
                                     map_location=torch.device('cpu')))
generator = generator.to(device).eval()

In [None]:
# visualise progress: target canvas (left) versus generated canvas (right)
# for three freshly drawn test samples.
gen_test = gen_canvas(x_test, y_test, batch_size=3)
data = next(gen_test)
images, labels, bboxes = data[0].to(device), data[1].to(device), data[2].to(device)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    fake_labels = generator(images)

# Convert once instead of on every loop iteration.
real_np = labels.cpu().numpy()
fake_np = fake_labels.cpu().numpy()

titles = ['Original', 'Generated']
fig, axes = plt.subplots(figsize=(10, 15), nrows=3, ncols=2)

for ax, title in zip(axes[0], titles):
    ax.set_title(title)

for i in range(3):

    axes[i, 0].imshow(real_np[i].squeeze())
    axes[i, 1].imshow(fake_np[i].squeeze())

# fig.savefig('./outputs/torch_%04d.png' % epoch)
plt.show()