In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import elu, instance_norm
from torch.utils.data import Dataset
from torch import optim, utils

from torchvision.datasets import Places365
from torchvision import transforms

import pytorch_lightning as pl



In [2]:
LAMBDA_ADV = 1e-2
LR_G = 1e-4
LR_D = 1e-3
B1 = 0.5
B2 = 0.9

## Model

In [3]:
class SkipConnection(nn.Module):
    def forward(self, out, old_out):
        return torch.cat([out, old_out], dim=1)

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class GatedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,):
        super(GatedConv, self).__init__()
        self.conv2d = nn.Conv2d(in_channels, 
                                out_channels, 
                                kernel_size, 
                                stride, 
                                padding, 
                                dilation, 
                                groups, 
                                bias)
        self.mask_conv2d = nn.Conv2d(in_channels, 
                                     out_channels, 
                                     kernel_size, 
                                     stride, 
                                     padding, 
                                     dilation, 
                                     groups,
                                     bias)
        self.sigmoid = nn.Sigmoid()


    def gated(self, mask):
        return self.sigmoid(mask)

    def forward(self, inp):
        x = self.conv2d(inp)
        mask = self.mask_conv2d(inp)
        x = elu(x) * self.gated(mask)
        x = instance_norm(x)
        return x

In [4]:
class Generator(pl.LightningModule):
    def __init__(self):
        super(Generator, self).__init__()
        self.skip = SkipConnection()
        self.upsampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        
        self.cache_GC_1 = GatedConv(4, 32, 5, 1, 2)
        self.cache_GC_2 = GatedConv(32, 64, 3, 2, 1)
        self.cache_GC_3 = GatedConv(64, 64, 3, 1, 1)
        self.cache_GC_4 = GatedConv(64, 128, 3, 2, 1)
        self.cache_GC_5 = GatedConv(128, 128, 3, 1, 1)

        self.mid_pile = nn.Sequential(
            GatedConv(128, 128, 3, 1, 1),
            GatedConv(128, 128, 3, 1, 2, dilation=2),
            GatedConv(128, 128, 3, 1, 4, dilation=4),
            GatedConv(128, 128, 3, 1, 8, dilation=8), 
            GatedConv(128, 128, 3, 1, 16, dilation=16),
            GatedConv(128, 128, 3, 1, 1)
        )

        self.GC_1 = GatedConv(256, 128, 3, 1, 1)
        self.GC_2 = GatedConv(256, 64, 3, 1, 1)
        self.GC_3 = GatedConv(128, 64, 3, 1, 1)
        self.GC_4 = GatedConv(128, 32, 3, 1, 1)
        self.GC_5 = GatedConv(64, 16, 3, 1, 1)

        self.final_conv = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out1 = self.cache_GC_1(x)
        out2 = self.cache_GC_2(out1)
        out3 = self.cache_GC_3(out2)
        out4 = self.cache_GC_4(out3)
        out5 = self.cache_GC_5(out4)

        out = self.mid_pile(out5)

        out = self.skip(out, out5)
        out = self.GC_1(out)

        out = self.skip(out, out4)
        out = self.upsampling(out)
        out = self.GC_2(out)

        out = self.skip(out, out3)
        out = self.GC_3(out)

        out = self.skip(out, out2)
        out = self.upsampling(out)
        out = self.GC_4(out)

        out = self.skip(out, out1)
        out = self.GC_5(out)
        
        out = self.final_conv(out)
        return out
      

In [5]:
class Discriminator(pl.LightningModule):
    def __init__(self):
        super(Discriminator, self).__init__()

        layers = []
        in_channels, out_channels = 4, 64
        for i in range(6):
            layers.append(nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2)))
            layers.append(nn.LeakyReLU())
            in_channels = out_channels
            out_channels = 2 * out_channels if out_channels < 256 else out_channels
        layers.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, kernel_size=4, stride=1, padding=0)))
        layers.append(nn.LeakyReLU())

        self.pile = nn.Sequential(*layers)
        
        self.flatten = Flatten()
        self.cond_linear = nn.utils.spectral_norm(nn.Linear(1000, 256, bias=False))
        self.final_linear = nn.utils.spectral_norm(nn.Linear(256, 1, bias=False))

    def forward(self, x, y, z):
        out = torch.cat([x, y], dim=1)
        out = self.pile(out)
        out = self.flatten(out)
        out_t = self.final_linear(out)

        z = self.cond_linear(z)
        out = (out * z).sum(1, keepdim=True)
        out = torch.add(out, out_t)
        return out

In [6]:
class GAN(pl.LightningModule):
    def __init__(self,):
        super().__init__()

        self.generator = Generator()
        self.discriminator = Discriminator()
        self.automatic_optimization = False

    def forward(self, z):
        return self.generator(z)

    def generator_step(self, input_tensor, real_image, masked_image, mask, cond):

        gen_image = self.generator(input_tensor)
        loss_rec = torch.nn.L1Loss()(gen_image, real_image)

        substituted_gen_image = gen_image * mask + masked_image
        loss_adv = -self.discriminator(substituted_gen_image, mask, cond).mean()

        loss_G = LAMBDA_ADV * loss_adv + loss_rec

        return {
            'loss_G': loss_G,
            "loss_adv": loss_adv,
            'loss_rec':  loss_rec
        }

    def discriminator_step(self, real_image, fake_image, mask, cond):
        pred_real = self.discriminator(real_image, mask, cond)
        pred_fake = self.discriminator(fake_image.detach(), mask, cond)
        loss_D = nn.ReLU()(1.0 - pred_real).mean() + nn.ReLU()(1.0 + pred_fake).mean()

        return {
            'loss_D': loss_D
        }

    def training_step(self, batch, batch_idx):
        optimizer_g, optimizer_d = self.optimizers()

        input_tensor = torch.Tensor(batch["input_tensor"]).to(torch.float32)
        cond = torch.Tensor(batch["inception_embeds"])
        
        real_image = input_tensor[:, :3, :, :]
        mask = input_tensor[:, 3:, :, :]
        masked_image = mask * real_image

        
        self.toggle_optimizer(optimizer_g, 0)

        G_output = self.generator_step(input_tensor, real_image, masked_image, mask, cond)
        loss_g =  G_output['loss_G']
        self.log('Generator loss', loss_g)
        self.manual_backward(loss_g)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)


        self.toggle_optimizer(optimizer_d, 1)

        gen_image = self.generator(input_tensor)
        fake_image = gen_image * mask + masked_image

        D_output = self.discriminator_step(real_image, fake_image, mask, cond)
        loss_d =  D_output['loss_D']
        self.log('Discriminator loss', loss_d)
        self.manual_backward(loss_d)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=LR_G, betas=(B1, B2))
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=LR_D, betas=(B1, B2))
        return [optimizer_G, optimizer_D], []

In [7]:
boundless_gan = GAN()

## Data

In [11]:
from torchvision.datasets.places365 import Places365
from pathlib import Path

import numpy as np
import torch
import torch
import torchvision
import numpy as np
from functools import partial
import torchvision.transforms.functional as F


class ExampleDataset(Dataset):
    def __init__(self):
        self.data = torch.rand((3, 256, 256))
        self.mask = torch.ones((256, 256))
        self.mask = self.mask[None, :, :]

        self.input_tensor = torch.cat([self.data, self.mask])
        self.inception_embeds = torch.zeros((1000))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return {
            "input_tensor": self.input_tensor,
            "inception_embeds": self.inception_embeds,
        }


class Places365Embedding(Places365):
    def __init__(self, embeddings_path: Path, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embeddings = torch.from_numpy(np.load(str(embeddings_path)))

    #         self.transform = transform

    def __getitem__(self, item):
        img, _ = super(Places365Embedding, self).__getitem__(item)
        print(img.shape)
        embedding = self.embeddings[item, :]
        return {"input_tensor": img, "inception_embeds": embedding}


def add_mask(image, mask_percentage, inpainting=False):
    """Add a mask to an image.
    Args:
        image (PIL.Image): Image to add the mask to.
        mask_percentage (float): Percentage of the image to mask.
        inpainting (bool): If True, the mask will be a square in the center of the image.
    Returns:
            torch.Tensor: Image with the mask."""
    image = F.pil_to_tensor(image)
    num_channels, height, width = image.shape
    mask = torch.ones((height, width), dtype=torch.bool)
    mask_width = int(width * mask_percentage)
    random_delta = np.random.randint(-4, 5)
    if inpainting:
        mask_size = int(np.sqrt(np.prod(image.shape[1:]) * mask_percentage))
        start_h = (image.shape[1] - mask_size) // 2
        start_w = (image.shape[2] - mask_size) // 2
        mask[start_h : start_h + mask_size, start_w : start_w + mask_size] = 0
    else:
        mask_width = mask_width + random_delta
        mask[:, width - mask_width :] = 0
    return torch.cat([image, mask.reshape(1, height, width)], dim=0)


# Usage example
# places365_dataset = torchvision.datasets.Places365(
#     "./data/pt_dataset/",
#     small=True,
#     download=False,
#     transform=partial(add_mask, mask_percentage=0.25, inpainting=False),
# )

places365_dataset = Places365Embedding(
    root="../data/pt_dataset/",
    small=True,
    download=False,
    transform=partial(add_mask, mask_percentage=0.25, inpainting=False),
    embeddings_path="../embeddings_inceptionv3_places365.npy"
)

In [12]:
# dataset = ExampleDataset()
train_loader = utils.data.DataLoader(places365_dataset, batch_size=8)

In [13]:
for i in train_loader:
    print(i)
    break

torch.Size([4, 256, 256])
torch.Size([4, 256, 256])
torch.Size([4, 256, 256])
torch.Size([4, 256, 256])
torch.Size([4, 256, 256])
torch.Size([4, 256, 256])
torch.Size([4, 256, 256])
torch.Size([4, 256, 256])
{'input_tensor': tensor([[[[187, 187, 187,  ..., 173, 172, 172],
          [187, 187, 188,  ..., 173, 172, 172],
          [187, 188, 188,  ..., 173, 173, 172],
          ...,
          [187, 188, 188,  ..., 175, 174, 174],
          [187, 187, 188,  ..., 175, 174, 174],
          [187, 187, 187,  ..., 175, 174, 174]],

         [[188, 188, 188,  ..., 174, 173, 173],
          [188, 188, 189,  ..., 174, 173, 173],
          [188, 189, 189,  ..., 174, 174, 173],
          ...,
          [189, 190, 190,  ..., 179, 178, 178],
          [189, 189, 190,  ..., 179, 178, 178],
          [189, 189, 189,  ..., 179, 178, 178]],

         [[190, 190, 190,  ..., 178, 177, 177],
          [190, 190, 191,  ..., 178, 177, 177],
          [190, 191, 191,  ..., 178, 178, 177],
          ...,
      

In [10]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=5, accelerator="auto", devices="auto", strategy="auto")
trainer.fit(model=boundless_gan, train_dataloaders=train_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/8
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/8
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/8
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/8
Initializing distributed: GLOBAL_RANK: 4, MEMBER: 5/8
Initializing distributed: GLOBAL_RANK: 5, MEMBER: 6/8
Initializing distributed: GLOBAL_RANK: 6, MEMBER: 7/8
Initializing distributed: GLOBAL_RANK: 7, MEMBER: 8/8
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 8 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]
LOCAL_RANK: 

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [15]:
some_a = torch.ones([32, 3, 256, 256])
some_a[0].shape

torch.Size([3, 256, 256])