In [None]:
from typing import List
import pytorch_lightning as pl

import torch
import torchvision
import torch.nn as nn

# GLOM

In [None]:
class ConvLayer(nn.Module):
    """Convolution -> batch normalisation -> ReLU, bundled as one module.

    Args:
        in_channels: Number of channels in the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size handed to ``nn.Conv2d``.
        *args: Extra positional arguments forwarded unchanged to ``nn.Conv2d``
            after ``kernel_size`` (stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, kernel_size, *args):
        super().__init__()
        # Attribute names double as state_dict prefixes, so they stay as-is.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, *args)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the convolution, the batch norm and the ReLU, in that order."""
        features = self.conv(x)
        normalised = self.bn(features)
        return self.relu(normalised)

In [None]:
class GlomLayer(nn.Module):
    """Three parallel ``ConvLayer`` branches (top, medium, bottom) with neighbour mixing.

    Before each branch is applied, its input is summed element-wise with the
    inputs of the adjacent level(s): top sees top + medium, medium sees all
    three, bottom sees medium + bottom.

    Args:
        config: Exactly three values, ``[in_channels, out_channels, kernel_size]``,
            shared by all three branches.
        *args: Extra positional arguments forwarded to every ``ConvLayer``.
    """

    def __init__(self, config: List[int], *args):
        super().__init__()
        assert len(config) == 3, ValueError("Each layers config length must be 3.")

        in_channels, out_channels, kernel_size = config[0], config[1], config[2]

        # Registration order (top, medium, bottom) is kept so parameter order
        # and state_dict layout do not change.
        self.top = ConvLayer(in_channels, out_channels, kernel_size, *args)
        self.medium = ConvLayer(in_channels, out_channels, kernel_size, *args)
        self.bottom = ConvLayer(in_channels, out_channels, kernel_size, *args)

    def forward(self, top, medium, bottom):
        """Mix the three level inputs and return ``(out_top, out_medium, out_bottom)``.

        The three inputs are stacked together, so they must share one shape.
        """
        # Each level is summed with its immediate neighbour(s) along a new leading axis.
        mixed_top = torch.stack((top, medium)).sum(dim=0)
        mixed_medium = torch.stack((top, medium, bottom)).sum(dim=0)
        mixed_bottom = torch.stack((medium, bottom)).sum(dim=0)

        return (
            self.top(mixed_top),
            self.medium(mixed_medium),
            self.bottom(mixed_bottom),
        )

In [None]:
class GLOM(pl.LightningModule):
    """Image classifier built from three stacked ``GlomLayer`` stages and a dense head.

    The same input image is fed to the top, medium and bottom streams of the
    first stage; the three streams produced by the last stage are summed and
    classified by a fully connected network.

    Args:
        num_classes: Number of output classes (size of the final linear layer).
        lr: Learning rate used by the Adam optimizer.
        layer_depth: Stored on the instance as ``self.layer_depth``; the stack
            built here always has three ``GlomLayer`` stages regardless of it.
    """

    def __init__(self, num_classes, lr: float = 2e-4, layer_depth: int = 3):
        super().__init__()

        self.layer_depth = layer_depth
        self.lr = lr
        self.nc = num_classes

        # glom layers: single-channel input, 3x3 kernels, widening 1 -> 64 -> 128 -> 256
        self.gl1 = GlomLayer([1, 64, 3])
        self.gl2 = GlomLayer([64, 128, 3])
        self.gl3 = GlomLayer([128, 256, 3])

        # fc dense network
        # 22 x 22 is the spatial size left after three unpadded 3x3 convolutions
        # of a 28 x 28 input (28 -> 26 -> 24 -> 22), e.g. MNIST digits.
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256 * 22 * 22, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_classes),
        )

        # The head emits one logit per class and the labels are class indices,
        # so cross-entropy is the matching loss for every class count,
        # including two. BCEWithLogitsLoss would need targets shaped like the
        # (batch, num_classes) logits and fails on (batch,) index labels.
        self.loss = nn.CrossEntropyLoss()

    def forward(self, inp):
        """Return raw, unnormalised class logits of shape ``(batch, num_classes)``.

        ``inp`` is expected to be ``(batch, 1, 28, 28)`` so that the flattened
        feature size matches the first linear layer. No sigmoid/softmax is
        applied here because the loss consumes logits directly; apply
        ``softmax`` on the result if probabilities are needed.
        """
        glom_inp = [inp] * 3  # same tensor seeds the top, medium and bottom streams
        out1 = self.gl1(*glom_inp)
        out2 = self.gl2(*out1)
        out3 = self.gl3(*out2)

        # Collapse the three streams into a single feature map.
        fc_inp = torch.sum(torch.stack(out3), dim=0)

        return self.fc(fc_inp)

    def _shared_step(self, xb):
        """Compute ``(loss, accuracy)`` for one ``(inputs, labels)`` batch."""
        inps, labels = xb
        logits = self(inps)

        loss = self.loss(logits, labels)
        # The comparison yields a bool tensor; cast before averaging because
        # ``mean`` is not defined for bool tensors.
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        return loss, acc

    def training_step(self, xb, batch_idx):
        """Log training loss/accuracy for the batch and return the loss to optimise."""
        loss, acc = self._shared_step(xb)

        self.log(
            "train_loss", loss, on_step=True, prog_bar=True, on_epoch=True, logger=True
        )
        self.log("train_acc", acc, on_step=True, prog_bar=True, on_epoch=True)

        return loss

    def validation_step(self, xb, batch_idx):
        """Log validation loss/accuracy for the batch."""
        loss, acc = self._shared_step(xb)

        self.log("val_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log(
            "val_acc", acc, on_step=True, prog_bar=True, on_epoch=True, logger=True
        )

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
        )

# Execution

In [None]:
# Run configuration.
epochs = 20
BS = 32
SEED = 42
NUM_CLASSES = 10  # MNIST has ten digit classes (0-9)

# Seed the Python, NumPy and torch RNGs so weight init and shuffling are reproducible.
pl.seed_everything(SEED)

model = GLOM(NUM_CLASSES)

# Preprocessing shared by both splits: convert to tensor, then standardise with
# the commonly used MNIST mean / standard deviation.
mnist_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

train_dl = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./data/",
        train=True,
        download=True,
        transform=mnist_transform,
    ),
    batch_size=BS,
    shuffle=True,
)

val_dl = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        "./data/",
        train=False,
        download=True,
        transform=mnist_transform,
    ),
    batch_size=BS * 2,
    # Evaluation does not benefit from shuffling; a fixed order keeps
    # per-step validation logs comparable between epochs.
    shuffle=False,
)

# NOTE(review): `gpus=` is accepted by older pytorch_lightning releases only;
# newer releases expect `accelerator=` / `devices=` instead. Confirm against
# the installed version before upgrading.
# NOTE(review): the run name "imdb" looks like a leftover from another notebook;
# it is kept so existing log directories still line up.
trainer = pl.Trainer(
    default_root_dir="logs",
    gpus=(1 if torch.cuda.is_available() else 0),
    max_epochs=epochs,
    logger=pl.loggers.TensorBoardLogger("logs/", name="imdb", version=0),
)

# The loaders are passed positionally because the keyword for the training
# loader was renamed from `train_dataloader` to `train_dataloaders` between
# Lightning releases; the positional form works with both.
trainer.fit(model, train_dl, val_dl)