In [None]:
import os
from pathlib import Path

try:
    import google.colab

    IN_COLAB = True
except:
    IN_COLAB = False

if IN_COLAB:
    from google.colab import userdata

    repo_name = "dgcnz/dl2"
    url = f"https://{userdata.get('gh_pat')}@github.com/{repo_name}.git"
    !git clone {url}
    print("\nCurrent Directory:")
    %cd dl2
    #!pip install torch torchvision numpy matplotlib git+https://github.com/AMLab-Amsterdam/lie_learn escnn scipy
    !pip install torchvision git+https://github.com/AMLab-Amsterdam/lie_learn escnn lightning wandb
    #!pip install -r requirements.txt


else:  # automatically checks if the current directory is 'repo name'
    curdir = Path.cwd()
    print("Current Directory", curdir)
    assert (
        curdir.name == "dl2" or curdir.parent.name == "dl2"
    ), "Notebook cwd has to be on the project root"
    if curdir.name == "notebooks":
        %cd ..
        print("New Current Directory:", curdir)

In [None]:
import wandb

# `userdata` (the Colab secrets store) only exists on Colab; anywhere else let
# wandb locate the API key itself (WANDB_API_KEY env var, ~/.netrc, or a prompt).
if IN_COLAB:
    wandb.login(key=userdata.get("wandb_key"))
else:
    wandb.login()

In [None]:
# NOTE(review): this starts a standalone W&B run, but the training cell calls
# wandb.finish() right before creating its WandbLogger, so this run appears to
# log nothing -- confirm whether this cell is still needed.
wandb.init(settings=wandb.Settings(start_method="fork"))

In [None]:
import sys

# The setup cell is expected to leave the working directory at the project root.
# Make that root importable so the `src.*` imports below resolve; "../" would point
# to its parent, outside the repository. Guarded against duplicates on re-run.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

import os
from typing import Any, Dict, Tuple

import lightning as L
import torch
from escnn import gspaces, nn
from lightning.pytorch.loggers import WandbLogger
from torch import Tensor, optim, utils
from torch.utils.data import Dataset
from torchmetrics import MaxMetric, MeanMetric
from torchmetrics.classification.accuracy import Accuracy

# from torchvision.datasets import MNIST
from torchvision.transforms import (
    Compose,
    InterpolationMode,
    Pad,
    RandomRotation,
    Resize,
    ToTensor,
)

from src.data.rotated_mnist_datamodule import MnistRotDataset
from src.models.image_module import ImageLightningModule

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Set up the PyTorch Lightning module


In [None]:
# define the LightningModule
class PlModule(L.LightningModule):
    """LightningModule that trains an image classifier with cross-entropy loss.

    Logs ``train/``, ``val/`` and ``test/`` ``loss`` and ``acc`` once per epoch.

    :param net: Network mapping a batch of images to class logits; ``argmax`` over
        dim 1 is used as the prediction.
    :param num_classes: Number of target classes (defaults to 10).
    :param lr: Learning rate for Adam.
    :param weight_decay: Weight decay for Adam.
    """

    def __init__(
        self,
        net: torch.nn.Module,
        num_classes: int = 10,
        lr: float = 5e-5,
        weight_decay: float = 1e-5,
    ):
        super().__init__()
        # `net` is an nn.Module whose weights are already part of the checkpoint's
        # state_dict; storing it as a hyperparameter as well is redundant.
        self.save_hyperparameters(ignore=["net"])

        self.net = net

        # Accuracy metrics. They are handed to self.log as metric objects, so
        # Lightning computes the exact epoch accuracy and resets them at epoch end;
        # resetting them by hand in an epoch-end hook can empty them before that.
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)

        # for averaging loss across batches (also logged as metric objects)
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation accuracy (not updated anywhere yet)
        self.val_acc_best = MaxMetric()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.

        :return: A tuple containing (in order):
            - A tensor of losses.
            - A tensor of predictions.
            - A tensor of target labels.
        """
        x, y = batch
        logits = self.net(x)
        loss = torch.nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.test_loss(loss)
        self.test_acc(preds, targets)
        self.log("test/loss", self.test_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True, prog_bar=True)

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.val_loss(loss)
        self.val_acc(preds, targets)
        self.log("val/loss", self.val_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target
            labels.
        :param batch_idx: The index of the current batch.
        :return: A tensor of losses between model predictions and targets.
        """
        loss, preds, targets = self.model_step(batch)

        # update and log metrics
        self.train_loss(loss)
        self.train_acc(preds, targets)
        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True, prog_bar=True)

        # return loss or backpropagation will fail
        return loss

    def configure_optimizers(self):
        """Adam over the wrapped network's parameters, using the constructor's lr/weight decay."""
        return torch.optim.Adam(
            self.net.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

### C8 steerable CNN


In [None]:
class C8SteerableCNN(torch.nn.Module):
    """Rotation-equivariant (C8) steerable CNN classifier built with escnn.

    Six equivariant convolution blocks (R2Conv -> InnerBatchNorm -> ReLU) produce
    regular fields of C8, the group of rotations by multiples of 45 degrees.
    Antialiased average pooling follows blocks 2, 4 and 6. The features are then
    pooled over the group and classified by a small fully connected head.

    :param n_classes: Number of output logits.
    """

    def __init__(self, n_classes=10):
        super().__init__()

        # C8: planar rotations by multiples of 45 degrees.
        self.r2_act = gspaces.rot2dOnR2(N=8)

        # A grayscale image is a single scalar field (trivial representation);
        # forward() uses this type to wrap raw tensors into GeometricTensors.
        self.input_type = nn.FieldType(self.r2_act, [self.r2_act.trivial_repr])

        # Block 1 masks the input (29x29 mask, margin 1) before its 7x7 convolution.
        self.block1 = self._conv_block(
            self.input_type,
            24,
            kernel_size=7,
            padding=1,
            prefix=(nn.MaskModule(self.input_type, 29, margin=1),),
        )
        self.block2 = self._conv_block(self.block1.out_type, 48, kernel_size=5, padding=2)
        self.pool1 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(self.block2.out_type, sigma=0.66, stride=2)
        )

        self.block3 = self._conv_block(self.block2.out_type, 48, kernel_size=5, padding=2)
        self.block4 = self._conv_block(self.block3.out_type, 96, kernel_size=5, padding=2)
        self.pool2 = nn.SequentialModule(
            nn.PointwiseAvgPoolAntialiased(self.block4.out_type, sigma=0.66, stride=2)
        )

        self.block5 = self._conv_block(self.block4.out_type, 96, kernel_size=5, padding=2)
        self.block6 = self._conv_block(self.block5.out_type, 64, kernel_size=5, padding=1)
        self.pool3 = nn.PointwiseAvgPoolAntialiased(
            self.block6.out_type, sigma=0.66, stride=1, padding=0
        )

        # Pool each regular field over the group.
        self.gpool = nn.GroupPooling(self.block6.out_type)
        n_invariant = self.gpool.out_type.size

        # Fully connected classification head.
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(n_invariant, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )

    def _conv_block(self, in_type, n_fields, kernel_size, padding, prefix=()):
        """Build ``[*prefix, R2Conv, InnerBatchNorm, ReLU]`` outputting ``n_fields`` regular C8 fields."""
        out_type = nn.FieldType(self.r2_act, n_fields * [self.r2_act.regular_repr])
        return nn.SequentialModule(
            *prefix,
            nn.R2Conv(in_type, out_type, kernel_size=kernel_size, padding=padding, bias=False),
            nn.InnerBatchNorm(out_type),
            nn.ReLU(out_type, inplace=True),
        )

    def forward(self, input: torch.Tensor):
        """Return class logits for a batch of single-channel images.

        ``input`` is wrapped as a GeometricTensor of ``self.input_type``; presumably
        shaped (batch, 1, 29, 29) to match block 1's 29x29 mask -- TODO confirm.
        """
        features = nn.GeometricTensor(input, self.input_type)

        # Every stage consumes the GeometricTensor produced by the previous one,
        # so consecutive input/output field types must agree.
        stages = (
            self.block1,
            self.block2,
            self.pool1,
            self.block3,
            self.block4,
            self.pool2,
            self.block5,
            self.block6,
            self.pool3,  # spatial pooling
            self.gpool,  # pooling over the group
        )
        for stage in stages:
            features = stage(features)

        # Drop the field-type annotation and classify the plain tensor.
        flat = features.tensor
        return self.fully_net(flat.reshape(flat.shape[0], -1))

In [None]:
# Seed python/numpy/torch (and DataLoader workers) before the model is initialised,
# so re-running the notebook reproduces weights, augmentations and batch order.
SEED = 42
L.seed_everything(SEED, workers=True)

# DataLoader worker processes, capped by os.cpu_count().
NUM_WORKERS = min(7, os.cpu_count() or 1)

# init model
net = C8SteerableCNN().to(device)
equivariantmodel = PlModule(net)

# define transforms
# Pad one pixel on the right and bottom: 28x28 -> 29x29, matching the 29x29 input
# mask of C8SteerableCNN. NOTE(review): assumes the dataset yields 28x28 images.
pad = Pad((0, 0, 1, 1), fill=0)

# Training images are upsampled before the random rotation and downsampled again
# afterwards, presumably to reduce bilinear interpolation artefacts.
resize1 = Resize(87)
resize2 = Resize(29)

totensor = ToTensor()

train_transform = Compose(
    [
        pad,
        resize1,
        RandomRotation(180.0, interpolation=InterpolationMode.BILINEAR, expand=False),
        resize2,
        totensor,
    ]
)

test_transform = Compose(
    [
        pad,
        totensor,
    ]
)


dataset = MnistRotDataset("data/mnist/", download=True, transform=train_transform)
test_dataset = MnistRotDataset(
    "data/mnist/", download=False, train=False, transform=test_transform
)

# Shuffle the training set so each epoch sees a different batch order; the test
# loader keeps a fixed order.
train_loader1 = utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=NUM_WORKERS
)
test_loader1 = utils.data.DataLoader(test_dataset, batch_size=64, num_workers=NUM_WORKERS)

In [None]:
# create logger
# Close any W&B run still open from an earlier cell so the logger starts a fresh one.
wandb.finish()
c8steerable_logger = WandbLogger(
    project="C8steerable_rotMNIST", log_model="all", name="first_epoch"
)

# Save a checkpoint every 2 epochs and keep all of them (save_top_k=-1). The epoch
# is part of the filename, so each saved file/artifact identifies its epoch instead
# of getting a "-vN" collision suffix.
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
    every_n_epochs=2, save_top_k=-1, filename="c8model-{epoch:02d}"
)

# NOTE(review): the original author reported an error when max_epochs=1; the
# cause was not recorded -- confirm before shortening the schedule.
trainer = L.Trainer(
    max_epochs=30, logger=c8steerable_logger, callbacks=[checkpoint_callback]
)
# No validation loader is passed: the test split is not used during training.
trainer.fit(model=equivariantmodel, train_dataloaders=train_loader1)

In [None]:
# Finish the current W&B run (the C8 training run) so the next experiment is
# logged as a separate run.
wandb.finish()

### Now create a similarly sized non-equivariant CNN


In [None]:
import torch
import torch.nn.functional as F


class BasicInvertedBottleneckBlock(torch.nn.Module):
    """Residual block: expand ``Cin`` -> ``N`` channels, project to ``Cout``, add the skip path.

    Main path: conv(k) -> BatchNorm -> ELU (``N`` channels) -> conv(k, stride s)
    (``Cout`` channels). Skip path: average-pooled with the same geometry when
    ``downsample`` is set, then a 1x1 convolution when ``Cin != Cout`` so that its
    shape matches the main path.

    ``k`` is 7 (padding 3) for the first block and 3 (padding 1) otherwise; ``s`` is
    2 when ``downsample`` is set and 1 otherwise.

    :param Cin: Input channels.
    :param N: Hidden (expanded) channels of the main path.
    :param Cout: Output channels.
    :param downsample: Use stride 2 (spatial size becomes ceil(H/2)).
    :param first_block: Use the 7x7 kernel of the stem block.
    """

    def __init__(self, Cin, N, Cout, downsample=False, first_block=False):
        super().__init__()

        kernel_size = 3
        padding = 1

        if first_block:
            kernel_size = 7
            padding = 3

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(Cin, N, kernel_size=kernel_size, stride=1, padding=padding),
            # NOTE(review): choice of normalisation layer was left open by the author.
            torch.nn.BatchNorm2d(N),
            torch.nn.ELU(inplace=True),
            # TODO(author): this is the only reading of Appendix H4 the author could
            # think of; possibly the paper applies either the second conv or the
            # downsampling, but never both -- unresolved.
        )

        # A 1x1 projection on the skip path is needed only when channel counts differ.
        self.one_by_one: bool = Cin != Cout

        if self.one_by_one:
            self.conv1x1 = torch.nn.Conv2d(Cin, Cout, kernel_size=1, stride=1, padding=0)

        self.downsample = downsample
        if self.downsample:
            # NOTE(review): the author marked this downsampling layout as a guess.
            self.second_conv = torch.nn.Conv2d(
                N, Cout, kernel_size=kernel_size, stride=2, padding=padding
            )
            # Same geometry as second_conv so both paths end at the same spatial size.
            self.avg_pool = torch.nn.AvgPool2d(kernel_size=kernel_size, stride=2, padding=padding)
        else:
            # Using kernel_size=1 here would greatly reduce the parameter count.
            self.second_conv = torch.nn.Conv2d(
                N, Cout, kernel_size=kernel_size, stride=1, padding=padding
            )

    def forward(self, x):
        """Return ``main_path(x) + skip_path(x)`` with ``Cout`` channels."""
        out = self.second_conv(self.block(x))

        if self.downsample:
            x = self.avg_pool(x)

        if self.one_by_one:
            x = self.conv1x1(x)

        return out + x


class CNN(torch.nn.Module):
    """Non-equivariant baseline (V1): BasicInvertedBottleneckBlock stack + max-pool + MLP head.

    Block ``i`` has ``backbone_channels[i]`` input channels, ``residual_channels[i]``
    hidden channels and ``backbone_channels[i + 1]`` output channels; the last block
    outputs 128 channels. Block 0 is the 7x7 stem and every odd-indexed block
    downsamples.

    NOTE: a later cell defines another class named ``CNN`` (the V2 variant), which
    replaces this one once that cell runs.

    NOTE(review): the head expects exactly 128 features after max-pooling, which
    holds for 29x29 inputs (spatial size 29 -> 15 -> 8 -> 4 -> 1); other input
    sizes need a different ``Linear`` width.

    :param backbone_channels: Input channel count of each block.
    :param residual_channels: Hidden channel count of each block.
    :param n_classes: Number of output logits.
    """

    def __init__(self, backbone_channels, residual_channels, n_classes):
        super().__init__()
        self.backbone_channels = backbone_channels
        self.residual_channels = residual_channels
        self.blocks = self._make_blocks()
        self.max_pool = torch.nn.MaxPool2d(kernel_size=3)

        # Fully Connected
        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )

    def _make_blocks(self):
        """Build the sequential stack of residual blocks described in the class docstring."""
        blocks = []
        for i in range(len(self.backbone_channels)):
            Cin = self.backbone_channels[i]
            N = self.residual_channels[i]
            # The reference description does not give the output width of the last
            # block, so it is chosen here.
            if i < len(self.backbone_channels) - 1:
                Cout = self.backbone_channels[i + 1]
            else:
                Cout = 128  # NOTE(review): not specified by the original paper

            if i == 0:
                blocks.append(BasicInvertedBottleneckBlock(Cin, N, Cout, first_block=True))
            elif (i + 1) % 2 == 0:  # every second block downsamples
                blocks.append(BasicInvertedBottleneckBlock(Cin, N, Cout, downsample=True))
            else:
                blocks.append(BasicInvertedBottleneckBlock(Cin, N, Cout))
        return torch.nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = self.blocks(x)
        out = self.max_pool(out)
        return self.fully_net(out.flatten(start_dim=1))


# Example usage: per-block channel widths of the V1 baseline.
backbone_channels = [1, 21, 54, 72, 108, 168]  # C_in of each block
residual_channels = [96, 192, 288, 288, 576, 576]  # expanded hidden width N of each block
model = CNN(
    backbone_channels=backbone_channels,
    residual_channels=residual_channels,
    n_classes=10,
)

In [None]:
import torch
import torch.nn.functional as F

# V2


class BasicInvertedBottleneckBlockV2(torch.nn.Module):
    """Residual block (V2): expand to ``N``, project back to ``Cin``, add skip, then map to ``Cout``.

    Data flow (channel counts in brackets)::

        x [Cin] -> conv(k) -> BN -> ELU [N] -> conv(k, stride s) [Cin]
                -> + skip [Cin] -> optional 1x1 conv [Cout]

    ``k`` is 7 (padding 3) for the first block and 3 (padding 1) otherwise; ``s`` is 2
    when ``downsample`` is set, in which case the skip path is average-pooled with
    the same geometry. The trailing 1x1 convolution exists only when ``Cin != Cout``.

    TODO(author): it is unresolved whether the reference (Appendix H4) intends the
    second convolution or the downsampling, but never both.

    :param Cin: Input channels.
    :param N: Hidden (expanded) channels.
    :param Cout: Output channels.
    :param downsample: Use stride 2 on the main path and pool the skip path.
    :param first_block: Use the 7x7 stem kernel.
    """

    def __init__(self, Cin, N, Cout, downsample=False, first_block=False):
        super().__init__()

        kernel_size, padding = (7, 3) if first_block else (3, 1)
        stride = 2 if downsample else 1

        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(Cin, N, kernel_size=kernel_size, stride=1, padding=padding),
            torch.nn.BatchNorm2d(N),  # NOTE(review): normalisation choice left open
            torch.nn.ELU(inplace=True),
        )

        # The channel change to Cout happens after the residual sum.
        self.one_by_one: bool = Cin != Cout
        if self.one_by_one:
            self.conv1x1 = torch.nn.Conv2d(Cin, Cout, kernel_size=1, stride=1, padding=0)

        self.downsample = downsample
        # Projects back to Cin so the sum with the (possibly pooled) input is valid.
        # Using kernel_size=1 here would greatly reduce the parameter count.
        self.second_conv = torch.nn.Conv2d(
            N, Cin, kernel_size=kernel_size, stride=stride, padding=padding
        )
        if downsample:
            # Shrinks the skip path to the main path's spatial size.
            self.avg_pool = torch.nn.AvgPool2d(kernel_size=kernel_size, stride=2, padding=padding)

    def forward(self, x):
        """Return the block output with ``Cout`` channels."""
        residual = self.second_conv(self.block(x))  # Cin channels
        shortcut = self.avg_pool(x) if self.downsample else x
        merged = residual + shortcut
        return self.conv1x1(merged) if self.one_by_one else merged


class CNN(torch.nn.Module):
    """Non-equivariant baseline (V2): BasicInvertedBottleneckBlockV2 stack + max-pool + MLP head.

    Block ``i`` has ``backbone_channels[i]`` input channels, ``residual_channels[i]``
    hidden channels and ``backbone_channels[i + 1]`` output channels; the last block
    outputs 32 channels. Block 0 is the 7x7 stem and every odd-indexed block
    downsamples. This definition replaces the earlier ``CNN`` class of the same name.

    NOTE(review): the head expects 128 = 32 channels x 2 x 2 features after the
    stride-1 max-pool, which holds for 29x29 inputs (29 -> 15 -> 8 -> 4 -> 2).

    :param backbone_channels: Input channel count of each block.
    :param residual_channels: Hidden channel count of each block.
    :param n_classes: Number of output logits.
    """

    def __init__(self, backbone_channels, residual_channels, n_classes):
        super().__init__()
        self.backbone_channels = backbone_channels
        self.residual_channels = residual_channels
        self.blocks = self._make_blocks()
        # NOTE(review): the author flagged this stride-1 setting as uncertain.
        self.max_pool = torch.nn.MaxPool2d(kernel_size=3, stride=1)

        self.fully_net = torch.nn.Sequential(
            torch.nn.Linear(128, 64),
            torch.nn.BatchNorm1d(64),
            torch.nn.ELU(inplace=True),
            torch.nn.Linear(64, n_classes),
        )

    def _make_blocks(self):
        """Build the sequential stack of V2 residual blocks."""
        n_stages = len(self.backbone_channels)
        stages = []
        for idx in range(n_stages):
            c_in = self.backbone_channels[idx]
            n_hidden = self.residual_channels[idx]
            # The last block's width is unspecified in the reference; 32 channels on
            # a 2x2 map gives the 128 features the head expects.
            c_out = 32 if idx == n_stages - 1 else self.backbone_channels[idx + 1]

            if idx == 0:
                options = {"first_block": True}
            elif idx % 2 == 1:
                options = {"downsample": True}
            else:
                options = {}
            stages.append(BasicInvertedBottleneckBlockV2(c_in, n_hidden, c_out, **options))
        return torch.nn.Sequential(*stages)

    def forward(self, x):
        """Return class logits for a batch of images."""
        pooled = self.max_pool(self.blocks(x))
        return self.fully_net(torch.flatten(pooled, start_dim=1))


# Example usage: per-block channel widths of the V2 baseline (the model trained below).
backbone_channels = [1, 21, 54, 72, 108, 168]  # C_in of each block
residual_channels = [96, 192, 288, 288, 576, 576]  # expanded hidden width N of each block
model = CNN(
    backbone_channels=backbone_channels,
    residual_channels=residual_channels,
    n_classes=10,
)

In [None]:
normal_cnn = PlModule(model)

# Same training data as the C8 run, shuffled so each epoch sees a new batch order.
train_loader2 = utils.data.DataLoader(
    dataset, batch_size=64, shuffle=True, num_workers=min(7, os.cpu_count() or 1)
)

# Close any W&B run still open so this experiment is logged separately.
wandb.finish()
cnn_logger = WandbLogger(project="CNN_rotMNIST", log_model="all", name="first_cnn_run")

# Save a checkpoint every 2 epochs and keep all of them; the epoch is part of the
# filename so each file/artifact identifies its epoch.
checkpoint_callback = L.pytorch.callbacks.ModelCheckpoint(
    every_n_epochs=2, save_top_k=-1, filename="cnn_model-{epoch:02d}"
)

# NOTE(review): the original author reported an error when max_epochs=1; the
# cause was not recorded -- confirm before shortening the schedule.
cnn_trainer = L.Trainer(max_epochs=30, logger=cnn_logger, callbacks=[checkpoint_callback])
# Train on this cell's loader; no validation loader, so the test split stays unused.
cnn_trainer.fit(model=normal_cnn, train_dataloaders=train_loader2)

In [None]:
# Finish the current W&B run (the baseline CNN training run) so its logs and
# checkpoint artifacts are closed out.
wandb.finish()