In [2]:
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch_lightning
  Downloading pytorch_lightning-2.0.0-py3-none-any.whl (715 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m715.6/715.6 KB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.11.4-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.2/519.2 KB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.7.0
  Downloading lightning_utilities-0.8.0-py3-none-any.whl (20 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1
  Downloading aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
Collecting aiosignal>=1.1.2
  Downloading aiosignal-1.3.1-py3-none-any.whl (7.6 kB)
Collecting yarl<2.0,>=1.0
  Dow

In [15]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import elu, instance_norm
from torch.utils.data import Dataset
from torch import optim, utils

from torchvision.datasets import Places365
from torchvision import transforms

import pytorch_lightning as pl

In [4]:
LAMBDA_ADV = 1e-2
LR_G = 1e-4
LR_D = 1e-3
B1 = 0.5
B2 = 0.9

## Model

In [6]:
class SkipConnection(nn.Module):
    def forward(self, out, old_out):
        return torch.cat([out, old_out], dim=1)

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class GatedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,):
        super(GatedConv, self).__init__()
        self.conv2d = nn.Conv2d(in_channels, 
                                out_channels, 
                                kernel_size, 
                                stride, 
                                padding, 
                                dilation, 
                                groups, 
                                bias)
        self.mask_conv2d = nn.Conv2d(in_channels, 
                                     out_channels, 
                                     kernel_size, 
                                     stride, 
                                     padding, 
                                     dilation, 
                                     groups,
                                     bias)
        self.sigmoid = nn.Sigmoid()


    def gated(self, mask):
        return self.sigmoid(mask)

    def forward(self, input):
        x = self.conv2d(input)
        mask = self.mask_conv2d(input)
        x = elu(x) * self.gated(mask)
        x = instance_norm(x)
        return 

In [7]:
class Generator(pl.LightningModule):
    def __init__(self):
        super(Generator, self).__init__()
        self.skip = SkipConnection()
        self.upsampling = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        
        self.cache_GC_1 = GatedConv(5, 32, 5, 1, 2)
        self.cache_GC_2 = GatedConv(32, 64, 3, 2, 1)
        self.cache_GC_3 = GatedConv(64, 64, 3, 1, 1)
        self.cache_GC_4 = GatedConv(64, 128, 3, 2, 1)
        self.cache_GC_5 = GatedConv(128, 128, 3, 1, 1)

        self.mid_pile = nn.Sequential(
            GatedConv(128, 128, 3, 1, 1),
            GatedConv(128, 128, 3, 1, 2, dilation=2),
            GatedConv(128, 128, 3, 1, 4, dilation=4),
            GatedConv(128, 128, 3, 1, 8, dilation=8), 
            GatedConv(128, 128, 3, 1, 16, dilation=16),
            GatedConv(128, 128, 3, 1, 1)
        )

        self.GC_1 = GatedConv(256, 128, 3, 1, 1)
        self.GC_2 = GatedConv(256, 64, 3, 1, 1)
        self.GC_3 = GatedConv(128, 64, 3, 1, 1)
        self.GC_4 = GatedConv(128, 32, 3, 1, 1)
        self.GC_5 = GatedConv(64, 16, 3, 1, 1)

        self.final_conv = nn.Conv2d(16, 3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        out1 = self.cache_GC_1(x)
        out2 = self.cache_GC_2(out1)
        out3 = self.cache_GC_3(out2)
        out4 = self.cache_GC_4(out3)
        out5 = self.cache_GC_5(out4)

        out = self.mid_pile(out5)

        out = self.skip(out, out5)
        out = self.GC_1(out)

        out = self.skip(out, out4)
        out = self.upsampling(out)
        out = self.GC_2(out)

        out = self.skip(out, out3)
        out = self.GC_3(out)

        out = self.skip(out, out2)
        out = self.upsampling(out)
        out = self.GC_4(out)

        out = self.skip(out, out1)
        out = self.GC_5(out)
        
        out = self.final_conv(out)
        return out
      

In [8]:
class Discriminator(pl.LightningModule):
    def __init__(self):
        super(Discriminator, self).__init__()

        layers = []
        in_channels, out_channels = 4, 64
        for i in range(6):
            layers.append(nn.utils.spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2)))
            layers.append(nn.LeakyReLU())
            in_channels = out_channels
            out_channels = 2 * out_channels if out_channels<256 else out_channels
        layers.append(nn.utils.spectral_norm(nn.Conv2d(256, 256, kernel_size=4, stride=1, padding=0)))
        layers.append(nn.LeakyReLU())

        self.pile = nn.Sequential(*layers)
        
        self.flatten = Flatten()
        self.cond_linear = nn.utils.spectral_norm(nn.Linear(1000, 256, bias=False))
        self.final_linear = nn.utils.spectral_norm(nn.Linear(256, 1, bias=False))

    def forward(self, x, y, z):
        out = torch.cat([x, y], dim=1)
        out = self.pile(out)
        out = self.flatten(out)
        out_t = self.final_linear(out)

        z = self.cond_linear(z)
        out = (out * z).sum(1, keepdim=True)
        out = torch.add(out, out_t)
        return 

In [30]:
class GAN(pl.LightningModule):
    def __init__(self,):
        super().__init__()

        self.generator = Generator()
        self.discriminator = Discriminator()
        self.automatic_optimization = False

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return torch.nn.L1Loss()

    def generator_step(self, real_image, masked_image, mask, cond):

        squeezed_input = torch.cat([masked_image, mask])
        gen_image = self.generator(squeezed_input)
        loss_rec = self.criterion_pixel(gen_image, real_image)

        substituted_gen_image = gen_image * mask + masked_image
        loss_adv = -self.discriminator(substituted_gen_image, mask, cond).mean()

        loss_G = LAMBDA_ADV * loss_adv + loss_rec

        return {
            'loss_G': loss_G,
            "loss_adv": loss_adv,
            'loss_rec':  loss_rec
        }

    def discriminator_step(self, real_image, fake_image, mask, cond):
        pred_real = self.discriminator(real_image, mask, cond)
        pred_fake = self.discriminator(fake_image.detach(), mask, cond)
        loss_D = nn.ReLU()(1.0 - pred_real).mean() + nn.ReLU()(1.0 + pred_fake).mean()

        return {
            'loss_D': loss_D
        }

    def training_step(self, batch, batch_idx):
        optimizer_g, optimizer_d = self.optimizers()

        masked_image = torch.Tensor(batch["masked_image"])
        real_image = torch.Tensor(batch["real_image"]) # [3, 256, 256]
        mask = torch.Tensor(batch["mask"])
        cond = torch.Tensor(0) # inception_embed(real_image)

        
        self.toggle_optimizer(optimizer_g)

        G_output = self.generator_step(real_image, masked_image, mask, cond)
        loss_g =  G_output['loss_G']
        self.log('Generator loss', loss_g)
        self.manual_backward(loss_g)
        optimizer_g.step()
        optimizer_g.zero_grad()
        self.untoggle_optimizer(optimizer_g)


        self.toggle_optimizer(optimizer_d)

        squeezed_input = torch.cat([masked_image, mask])
        gen_image = self.generator(squeezed_input)
        fake_image = gen_image * mask + masked_image

        D_output = self.discriminator_step(real_image, fake_image, mask, cond)
        loss_d =  D_output['loss_D']
        self.log('Discriminator loss', loss_d)
        self.manual_backward(loss_d)
        optimizer_d.step()
        optimizer_d.zero_grad()
        self.untoggle_optimizer(optimizer_d)

    def configure_optimizers(self):
        optimizer_G = torch.optim.Adam(self.generator.parameters(), lr=LR_G, betas=(B1, B2))
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=LR_D, betas=(B1, B2))
        return [optimizer_G, optimizer_D], []

In [31]:
boundless_gan = GAN()

## Data

In [32]:
class SmallDataset(Dataset):
    def __init__(self):
        self.data = torch.rand((2, 3, 256, 256))
        self.mask = torch.ones((2, 256, 256))

        all_mask = self.mask[:,None,:,:].repeat(1,3,1, 1)
        self.masked_data = all_mask * self.data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, item):
        return {
            "real_image" : self.data[item],
            "masked_image" : self.masked_data,
            "mask" : self.mask
        }


In [33]:
dataset = SmallDataset()
train_loader = utils.data.DataLoader(dataset)

In [34]:
trainer = pl.Trainer(limit_train_batches=100, max_epochs=1)
trainer.fit(model=boundless_gan, train_dataloaders=train_loader)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 3.5 M 
1 | discriminator | Discriminator | 7.3 M 
------------------------------------------------
10.7 M    Trainable params
0         Non-trainable params
10.7 M    Total params
42.840    Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

RuntimeError: ignored

In [None]:
transform = transforms.Compose([
    transforms.PILToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

data = Places365(root='.', download=True, small=True, split='val', transform=transform)