Install the `pytorch_lightning` package; it plays a role for PyTorch similar to the one Keras plays for TensorFlow.

In [None]:
!pip install pytorch_lightning -qq

[K     |████████████████████████████████| 585 kB 14.4 MB/s 
[K     |████████████████████████████████| 140 kB 77.3 MB/s 
[K     |████████████████████████████████| 596 kB 92.1 MB/s 
[K     |████████████████████████████████| 419 kB 97.4 MB/s 
[K     |████████████████████████████████| 1.1 MB 74.9 MB/s 
[K     |████████████████████████████████| 271 kB 94.2 MB/s 
[K     |████████████████████████████████| 94 kB 1.6 MB/s 
[K     |████████████████████████████████| 144 kB 81.9 MB/s 
[?25h

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import distance
import random
import torch
import os
import gdown
import torch.nn as nn
import torch.nn.functional as F
from ShConv import ShConv
from utils import LayersHyperParameters
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms as T
from transforms import CutOutRectangles, RandomText, ToTensor
from pytorch_lightning import Trainer, seed_everything

from dataset import ImageInpaintingDataset

For reproducibility, fix the random seeds of all the packages used in this notebook.

In [3]:
# Use one seed value for every random number generator in this notebook.
seed = 877
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(seed)

# Kept as the last expression of the cell so its return value is displayed.
seed_everything(seed, workers=True)

Global seed set to 877


877

**Dataset**  
I used the thumbnails128x128 version of the [Flickr-Faces-HQ Dataset (FFHQ)](https://github.com/NVlabs/ffhq-dataset), which comprises 70,000 images of size 128 × 128 in PNG format.

In [4]:
# Fraction of the full dataset held out as the test set.
TEST_SPLIT = 0.15
# Fraction of the remaining (non-test) samples held out as the validation set.
VALIDATION_SPLIT = 0.1

In [5]:
# Location of the FFHQ thumbnails. The FFHQ_THUMBNAILS_DIR environment variable
# overrides the default, so the notebook can run on another machine without edits.
data_path = os.environ.get(
    "FFHQ_THUMBNAILS_DIR",
    r"C:\Users\eyad\Pictures\Images Datasets\Filcker Faces thumbnails 128x128",
)

# Batch size shared by the train, validation and test loaders.
BATCH_SIZE = 32

# Per-sample pipeline: cut rectangles out of the image, then apply ToTensor.
corruption_transform = T.Compose([
    CutOutRectangles(num_rectangles=2),
    # RandomText(text_size=25),
    ToTensor(),
])

dataset = ImageInpaintingDataset(root_dir=data_path,
                                 transform=corruption_transform,
                                 extensions=['png'],
                                 nested=True)

dataset_size = len(dataset)
if dataset_size == 0:
    # Fail early with a clear message instead of a confusing sampler error later on.
    raise RuntimeError(
        f"No images found under {data_path!r}; "
        "set FFHQ_THUMBNAILS_DIR to the dataset folder."
    )

# First split: hold out TEST_SPLIT of the whole dataset for testing.
split = int(np.floor(TEST_SPLIT * dataset_size))
train_set, test_set = random_split(
    dataset,
    [dataset_size - split, split],
    generator=torch.Generator().manual_seed(seed),
)

# Second split: hold out VALIDATION_SPLIT of what is left for validation.
trainset_size = len(train_set)
split = int(np.floor(VALIDATION_SPLIT * trainset_size))
train_set, validation_set = random_split(
    train_set,
    [trainset_size - split, split],
    generator=torch.Generator().manual_seed(seed),
)

# The three subsets must account for every sample exactly once.
assert len(train_set) + len(validation_set) + len(test_set) == dataset_size

print(len(train_set), len(validation_set), len(test_set))

# Only the training loader shuffles; validation and test keep a fixed order.
train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE,
                              shuffle=True, num_workers=7)

validation_dataloader = DataLoader(validation_set, batch_size=BATCH_SIZE,
                                   shuffle=False, num_workers=3)

test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE,
                             shuffle=False, num_workers=2)

inside init of CutOutRectangles
53550 5950 10500


In [6]:
# Learning rate handed to the Adam optimizer in ShepardNet.configure_optimizers.
LR = 1e-3

In [7]:
class ShepardNet(pl.LightningModule):
    """Inpainting network assembled from a list of layer descriptions.

    Every entry of ``layers`` becomes either an ``nn.Conv2d`` (when its
    ``layer_type`` is ``'conv'``) or a ``ShConv`` (any other ``layer_type``),
    followed by a ``ReLU``. A ``BatchNorm2d`` is added after every layer
    except the last one, so the network output is not batch-normalised.

    Args:
        layers: Non-empty sequence of ``LayersHyperParameters``. This class reads
            their ``layer_type``, ``kernels_num``, ``kernel_size``, ``stride``
            and ``padding`` attributes.
        in_channels: Number of channels of the corrupted input image. Defaults
            to 3, the value that used to be hard-coded.
    """

    def __init__(self, layers, in_channels=3):
        super().__init__()
        # Pseudo-layer that only carries the input channel count, so that each
        # real layer can read its input width from the entry before it.
        the_input_layer = LayersHyperParameters(layers[0].layer_type, in_channels, layers[0].kernel_size)
        self.layers = [the_input_layer, *layers]
        self.modules_list = nn.ModuleList()
        # `i` enumerates consecutive (input, output) pairs, so the final pair has
        # index len(self.layers) - 2. Comparing against len(self.layers) - 1 was
        # always true and wrongly put a BatchNorm after the output layer.
        last_pair_index = len(self.layers) - 2
        for i, (input_layer, output_layer) in enumerate(zip(self.layers[:-1], self.layers[1:])):
            layer_cls = nn.Conv2d if output_layer.layer_type == 'conv' else ShConv
            self.modules_list.append(
                layer_cls(
                    input_layer.kernels_num,
                    output_layer.kernels_num,
                    output_layer.kernel_size,
                    stride=output_layer.stride,
                    padding=output_layer.padding,
                )
            )
            self.modules_list.append(nn.ReLU())
            if i != last_pair_index:
                self.modules_list.append(nn.BatchNorm2d(output_layer.kernels_num))
        # saving the hyperparameters.
        self.save_hyperparameters()

    def forward(self, masks, x):
        """Run ``x`` (and ``masks``) through every module in order.

        ``ShConv`` modules take ``(masks, x)`` and return updated ``(x, masks)``;
        all other modules transform ``x`` only.

        Returns:
            Tuple ``(x, masks)`` after the last module.
        """
        for layer in self.modules_list:
            if isinstance(layer, ShConv):
                x, masks = layer(masks, x)
            else:
                x = layer(x)
        return x, masks

    def _reconstruction_loss(self, batch):
        """Return the MSE between the network output and ``batch['original']``.

        Shared by the training, validation and test steps so that all three
        run exactly the same forward pass.
        """
        original, corrupted, masks = batch['original'], batch['corrupted'], batch['mask']
        restored, _ = self(masks, corrupted)
        return F.mse_loss(restored, original)

    def training_step(self, batch, batch_idx):
        """Compute and return the training loss for one batch."""
        return self._reconstruction_loss(batch)

    def test_step(self, batch, batch_idx):
        """Compute the loss for one test batch and log it as ``test_loss``."""
        test_loss = self._reconstruction_loss(batch)
        self.log("test_loss", test_loss)

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log it as ``val_loss``."""
        val_loss = self._reconstruction_loss(batch)
        self.log("val_loss", val_loss, prog_bar=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters.

        The learning rate comes from the notebook-level constant ``LR``.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=LR)
        return optimizer

In [8]:
# Number of passes over the training set.
MAX_EPOCHS = 4

# Each entry is (layer_type, kernels_num, kernel_size). The last layer outputs
# 3 channels to match the RGB target image.
layers = [
    LayersHyperParameters('shepard', 8, 7),
    LayersHyperParameters('shepard', 8, 5),
    LayersHyperParameters('conv', 10, 3),
    LayersHyperParameters('conv', 25, 3),
    LayersHyperParameters('conv', 3, 3),
]
net = ShepardNet(layers)

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell does not fail on machines without CUDA.
accelerator = "gpu" if torch.cuda.is_available() else "cpu"
trainer = pl.Trainer(accelerator=accelerator, max_epochs=MAX_EPOCHS, deterministic=True)
trainer.fit(net, train_dataloader, validation_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type       | Params
--------------------------------------------
0 | modules_list | ModuleList | 6.6 K 
--------------------------------------------
6.6 K     Trainable params
0         Non-trainable params
6.6 K     Total params
0.026     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:

# Smoke test: run a standalone ShConv layer on random inputs and verify the
# shapes of the feature map and the mask it returns.
batch = 13
in_channels = 8
out_channels = 512
kernel_size = 5
stride = 1
padding = 'same'
height = width = 32  # spatial size of the random test inputs
# TODO: accept both int and string for padding
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
shconv = ShConv(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
shconv.to(device)

# Random values are enough here: only the output shapes are checked.
masks = torch.randn(batch, in_channels, height, width)
x = torch.randn(batch, in_channels, height, width)
x, masks = x.to(device), masks.to(device)
output_features_map, output_mask = shconv(masks, x)
print(output_features_map.shape)
print(output_mask.shape)

# With stride 1 and 'same' padding the spatial size must be preserved and the
# channel count must equal out_channels, for both returned tensors.
expected_shape = (batch, out_channels, height, width)
assert output_features_map.shape == expected_shape, output_features_map.shape
assert output_mask.shape == expected_shape, output_mask.shape

cuda:0
torch.Size([13, 512, 32, 32])
torch.Size([13, 512, 32, 32])
