In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

import pytorch_lightning as pl

### W&B (Weights & Biases) to visualize the training and validation curves

In [2]:
import wandb
from pytorch_lightning.loggers import WandbLogger

# Never paste the API key into the notebook: it would be saved (and shared) with the file.
# With no explicit key, wandb falls back to its own credential lookup
# (the WANDB_API_KEY environment variable, the netrc file, or an interactive prompt).
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mblurry-mood[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

### Downloading and reading the MNIST dataset

In [3]:
class MNISTDataModule(pl.LightningDataModule):
    """LightningDataModule that downloads MNIST and serves train/val/test loaders.

    Args:
        data_dir: directory the dataset is downloaded to / read from.
        batch_size: batch size used by all three dataloaders.
        num_workers: number of worker processes used by all three dataloaders.
        seed: seed of the generator that draws the train/validation split, so that
            the split is identical every time ``setup`` runs.
    """

    def __init__(self, data_dir: str = './', batch_size: int = 64, num_workers: int = 8, seed: int = 42):
        super().__init__()

        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        # Resize is a no-op safeguard for 28x28 inputs; the Normalize constants are
        # presumably the MNIST pixel mean / std (verify if the dataset changes).
        self.transform = transforms.Compose([
            transforms.Resize((28, 28)),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)
        self.num_classes = 10

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` (no state is assigned here)."""
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage`` ('fit', 'test', or None for both)."""

        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            # A dedicated, seeded generator keeps the split reproducible across runs and
            # across repeated setup() calls, so validation samples never drift into training.
            split_rng = torch.Generator().manual_seed(self.seed)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=split_rng
            )

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Training loader; reshuffled every epoch so batches differ between epochs."""
        return DataLoader(
            self.mnist_train,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=True,
        )

    def val_dataloader(self):
        """Validation loader (fixed order)."""
        return DataLoader(self.mnist_val, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Test loader (fixed order)."""
        return DataLoader(self.mnist_test, batch_size=self.batch_size, num_workers=self.num_workers)

### The definition of the Layer
Alter the values, add and remove layers to explore the best layer architecture

In [4]:
class Hidden(nn.Module):
    """Maps a feature map to one ``kernel_size x kernel_size`` patch per output channel.

    The input goes through conv -> adaptive max-pool (to ``kernel_size``) -> dropout -> conv,
    then a grouped convolution whose window covers the whole pooled patch re-mixes each
    channel independently. The result has shape
    ``(batch, out_channels, kernel_size, kernel_size)`` whatever the input resolution.
    """

    def __init__(self, in_channels, out_channels, mid_channels, drop_pb, kernel_size):
        super().__init__()

        # Both feature convolutions are 3x3 and keep the spatial size.
        same_size_3x3 = dict(kernel_size=3, stride=1, padding=1)

        self.conv1 = nn.Conv2d(in_channels, mid_channels, **same_size_3x3)
        self.adapt = nn.AdaptiveMaxPool2d(kernel_size)
        self.drop = nn.Dropout(drop_pb)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, **same_size_3x3)

        self.conv_block = nn.Sequential(self.conv1, self.adapt, self.drop, self.conv2)

        # groups=out_channels: each channel gets its own set of kernel_size**2 filters,
        # i.e. a private linear map over its pooled patch.
        self.linears = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels * kernel_size ** 2,
            kernel_size=kernel_size,
            padding=0,
            groups=out_channels,
        )

    def forward(self, z):
        # Pooled feature patches: (batch, out_channels, kernel_size, kernel_size).
        patches = self.conv_block(z)
        _, n_channels, height, width = patches.shape

        # The grouped conv window spans the whole patch, so the spatial output is 1x1.
        mixed = self.linears(patches)

        # Drop the 1x1 spatial dims, then unpack the channel axis back into
        # (channel, height, width).
        return mixed.flatten(start_dim=1).unflatten(1, (n_channels, height, width))

In [5]:
# Smoke test: the output should be one 3x3 patch per output channel, for any input size.
hidden = Hidden(in_channels=3, out_channels=10, mid_channels=5, drop_pb=0.1, kernel_size=3)
a = torch.rand(2, 3, 256, 256)
y = hidden(a)
y.shape

torch.Size([2, 10, 3, 3])

In [39]:
class ConvBlock(nn.Module):
    """Convolution whose kernels are generated from the input itself.

    ``Hidden`` produces, for every sample, ``out_channels * in_channels`` kernels of size
    ``kernel_size x kernel_size``; each sample is then convolved with its own kernels in a
    single grouped ``F.conv2d`` call.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the produced feature map.
        mid_channels: width of the intermediate layer inside ``Hidden``.
        drop_pb: dropout probability inside ``Hidden``.
        kernel_size: side length of the generated kernels.
        stride: stride of the dynamic convolution.
        padding: padding of the dynamic convolution.
    """

    def __init__(self, in_channels, out_channels, mid_channels, drop_pb, kernel_size=3, stride=1, padding=1):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.stride, self.padding = stride, padding
        self.hidden = Hidden(in_channels, in_channels * out_channels, mid_channels, drop_pb, kernel_size)
        # NOTE(review): this parameter is registered but never used in forward(); it is kept
        # so existing checkpoints still load. Decide whether it should be added to the
        # output or removed.
        self.bias = nn.Parameter(torch.randn(out_channels))

    def forward(self, x):
        batch = x.size(0)
        # (batch, out_channels * in_channels, k, k)
        kernels = self.hidden(x)

        # Fold the batch into the channel axis: one "image" with batch * in_channels channels.
        x = x.flatten(0, 1).unsqueeze(0)

        # Reshape the kernels to (batch * out_channels, in_channels, k, k).
        kernels = kernels.unflatten(1, (self.out_channels, self.in_channels)).flatten(0, 1)

        # groups=batch: group i holds the in_channels of sample i and the out_channels
        # kernels generated from sample i, so every sample is filtered by its own kernels.
        # The configured stride / padding are honoured here instead of being hard-coded.
        z = F.conv2d(x, kernels, stride=self.stride, padding=self.padding, groups=batch)

        # Split the fused channel axis back into (batch, out_channels) and drop the dummy
        # leading dimension.
        return z.unflatten(1, (batch, self.out_channels)).squeeze(0)

In [40]:
# Smoke test on the GPU when one is present, otherwise on the CPU, so the cell also
# runs on machines without CUDA.
demo_device = 'cuda' if torch.cuda.is_available() else 'cpu'
block = ConvBlock(3, 10, 20, .1).to(demo_device)
a = torch.rand(1, 3, 256, 256).to(demo_device)
y = block(a)
y.shape

torch.Size([1, 10, 256, 256])

### Compare the performance with and without the InputAware layer
`Model` uses the InputAware layer (`ConvBlock`); `Model2` is a typical CNN architecture.

In [41]:
class Model(nn.Module):
    """Classifier built around the input-aware ``ConvBlock``.

    Two ordinary 3x3 convolutions lift the single-channel image to 16 channels, a
    ``ConvBlock`` maps them to 32 channels, and global average pooling followed by a
    linear layer produces 10 logits.
    """

    def __init__(self):
        super().__init__()

        stem = [
            nn.Conv2d(1, 8, 3, 1, 1),
            nn.PReLU(),
            nn.Conv2d(8, 16, 3, 1, 1),
            nn.ReLU(),
        ]
        input_aware = ConvBlock(16, 32, 32, .1)
        head = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(32, 10),
        ]

        self.sequence = nn.Sequential(*stem, input_aware, *head)

    def forward(self, img):
        """Return the logits produced by the layer stack for ``img``."""
        return self.sequence(img)

class Model2(nn.Module):
    """Plain CNN baseline: six 3x3 convolutions, global average pooling, linear head.

    The channel widths grow 1 -> 4 -> 8 -> 16 -> 64 -> 128 -> 256; the first convolution
    is followed by PReLU, the others by ReLU. The output is a vector of 10 logits.
    """

    def __init__(self):
        super().__init__()

        widths = (1, 4, 8, 16, 64, 128, 256)

        layers = []
        for position, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 3, 1, 1))
            # Only the very first activation is a PReLU.
            layers.append(nn.PReLU() if position == 0 else nn.ReLU())

        layers.extend([
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(widths[-1], 10),
        ])

        self.sequence = nn.Sequential(*layers)

    def forward(self, img):
        """Return the logits produced by the layer stack for ``img``."""
        return self.sequence(img)

In [42]:
# Smoke test: a single 32x32 grayscale image should yield 10 logits.
block = Model2()
a = torch.rand(1, 1, 32, 32)
y = block(a)
y.shape

torch.Size([1, 10])

### LightningModule for the training
Modify the learning rate, weight decay, lr_scheduler, etc.

In [43]:
 class Learner(pl.LightningModule):
    """LightningModule that trains one of the two classifiers with cross-entropy loss.

    Args:
        model: ``'model1'`` selects ``Model``; any other value selects ``Model2``.
        lr: learning rate passed to Adam.
        betas: ``(beta1, beta2)`` coefficients passed to Adam.
        lr_step_size: ``step_size`` passed to ``StepLR``.
        lr_gamma: multiplicative decay factor passed to ``StepLR``.

    All constructor arguments are recorded through ``save_hyperparameters`` so that
    they are stored with the run.
    """

    def __init__(self, model, lr=1e-3, betas=(.9, .99), lr_step_size=2, lr_gamma=0.5):
        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.betas = tuple(betas)
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma

        self.model = Model() if model == 'model1' else Model2()
        self.cost = nn.CrossEntropyLoss()

    def configure_optimizers(self):
        """Return Adam together with a StepLR schedule built from the stored settings."""
        opt = torch.optim.Adam(self.parameters(), lr=self.lr, betas=self.betas)
        scheduler = torch.optim.lr_scheduler.StepLR(opt, self.lr_step_size, gamma=self.lr_gamma)

        return [opt], [scheduler]

    def forward(self, z):
        """Return the logits of the wrapped model."""
        return self.model(z)

    def _shared_step(self, batch, stage):
        """Compute the cross-entropy loss of ``batch`` and log it as ``'<stage>_loss'``."""
        imgs, labels = batch
        logits = self(imgs)
        cost = self.cost(logits, labels)
        self.log(f'{stage}_loss', cost, prog_bar=True)
        return cost

    def training_step(self, batch, batch_idx):
        """Loss of one training batch (logged as ``train_loss``)."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Loss of one validation batch (logged as ``val_loss``)."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Loss of one test batch (logged as ``test_loss``)."""
        return self._shared_step(batch, 'test')

### Train with different configurations: MODEL, num_epochs, device, accumulated gradients...

In [None]:
wandb_logger = WandbLogger(project='xxxxxxxxxxx', entity='xxxxxxx')

dm = MNISTDataModule()
learner = Learner('model1')
trainer = pl.Trainer(
    logger=wandb_logger,
    gpus=-1 if torch.cuda.is_available() else 0,  # all visible GPUs, or CPU when none
    max_epochs=20,
    accumulate_grad_batches=2,
    progress_bar_refresh_rate=1,
)

try:
    trainer.fit(learner, dm)
    # Pass the datamodule explicitly instead of relying on the one attached during fit().
    trainer.test(learner, datamodule=dm)
finally:
    # Close the W&B run even when training is interrupted or raises, so no run is
    # left dangling in the kernel.
    wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



  | Name  | Type             | Params
-------------------------------------------
0 | model | Model            | 200 K 
1 | cost  | CrossEntropyLoss | 0     
-------------------------------------------
200 K     Trainable params
0         Non-trainable params
200 K     Total params
0.801     Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…