## Install requirements

In [None]:
# Notebook shell escape: install PyTorch Lightning into the environment running this notebook.
!pip install pytorch_lightning

## Imports

In [None]:
import logging
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics
from torch.optim import Optimizer

## Reproducibility

In [None]:
# Fix the global seed to 0 through Lightning's helper so repeated runs of the notebook are comparable.
pl.seed_everything(0)

## Define models

In [None]:
class MLP(nn.Module):
    """Fully connected classifier with four ReLU hidden layers.

    Layer widths are ``input -> 512 -> 512 -> 512 -> 256 -> num_classes``.
    The linear layers are stored as ``layer0`` ... ``layer4``.

    Args:
        input: Number of features per sample once the input is flattened.
        num_classes: Width of the output layer.
    """

    def __init__(self, input: int = 28 * 28, num_classes: int = 10):
        super().__init__()
        self.input = input
        self.layer0 = nn.Linear(input, 512)
        self.layer1 = nn.Linear(512, 512)
        self.layer2 = nn.Linear(512, 512)
        self.layer3 = nn.Linear(512, 256)
        self.layer4 = nn.Linear(256, num_classes)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, list]:
        """Run the network on ``x``.

        Args:
            x: Tensor whose elements can be regrouped into rows of
                ``self.input`` features.

        Returns:
            A pair ``(log_probs, embeddings)``: the log-softmax of the output
            layer, and the list ``[h0, h1, h2, h3, h4]`` holding the four
            post-ReLU hidden activations followed by the raw output of
            ``layer4``.
        """
        # ``reshape`` instead of ``view``: it also accepts non-contiguous
        # inputs (e.g. permuted or sliced tensors), copying only when needed.
        hidden = x.reshape(-1, self.input)

        embeddings = []
        for layer in (self.layer0, self.layer1, self.layer2, self.layer3):
            hidden = nn.functional.relu(layer(hidden))
            embeddings.append(hidden)

        # The last layer has no activation; its raw output is kept as the
        # final embedding and normalised separately for the first return value.
        logits = self.layer4(hidden)
        embeddings.append(logits)

        return nn.functional.log_softmax(logits, dim=-1), embeddings

In [None]:
class MyLightningModule(pl.LightningModule):
    """Lightning wrapper that trains and evaluates a classifier with cross-entropy.

    The wrapped ``model`` may return either its output alone or an
    ``(output, embeddings)`` tuple. A separate top-1 multiclass accuracy
    metric is kept for each of the train, validation and test stages.
    """

    def __init__(self, model, num_classes: int = None, *args, **kwargs) -> None:
        # Extra positional / keyword arguments are accepted but not used.
        super().__init__()

        self.num_classes = num_classes

        # One template metric, cloned once per stage.
        accuracy_template = torchmetrics.Accuracy(task="multiclass", top_k=1, num_classes=num_classes)
        self.train_accuracy = accuracy_template.clone()
        self.val_accuracy = accuracy_template.clone()
        self.test_accuracy = accuracy_template.clone()

        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate the forward pass to the wrapped model.

        Returns:
            Whatever ``self.model(x)`` returns, unchanged.
        """
        return self.model(x)

    def unwrap_batch(self, batch):
        """Split a batch into ``(x, y)``.

        Dictionary batches are read through their ``"x"`` and ``"y"`` keys;
        anything else is unpacked as a two-element sequence.
        """
        if not isinstance(batch, Dict):
            x, y = batch
            return x, y
        return batch["x"], batch["y"]

    def step(self, x, y) -> Mapping[str, Any]:
        """Run the model on ``x`` and compute the cross-entropy loss against ``y``.

        Returns:
            A dict with the detached ``"logits"``, the ``"loss"`` and the
            ``"embeddings"`` (``None`` when the model returns no tuple).
        """
        output = self(x)
        logits, embeddings = output if isinstance(output, tuple) else (output, None)

        return {
            "logits": logits.detach(),
            "loss": F.cross_entropy(logits, y),
            "embeddings": embeddings,
        }

    def _step_and_log(self, batch: Any, stage: str, accuracy, **loss_log_options) -> Mapping[str, Any]:
        """Shared body of the train / validation / test steps.

        Logs the loss under ``loss/<stage>`` using ``loss_log_options``, updates
        ``accuracy`` with the softmax of the logits and logs it under
        ``acc/<stage>`` with ``on_epoch=True``.
        """
        x, y = self.unwrap_batch(batch)
        step_out = self.step(x, y)

        logged_loss = step_out["loss"].cpu().detach()
        self.log_dict({f"loss/{stage}": logged_loss}, **loss_log_options)

        probabilities = torch.softmax(step_out["logits"], dim=-1)
        accuracy(probabilities, y)
        self.log_dict({f"acc/{stage}": accuracy}, on_epoch=True)

        return step_out

    def training_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
        """Training step: the loss is logged per step and per epoch, and shown in the progress bar."""
        return self._step_and_log(
            batch,
            "train",
            self.train_accuracy,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
        )

    def validation_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
        """Validation step: the loss is logged per epoch only, and shown in the progress bar."""
        return self._step_and_log(
            batch,
            "val",
            self.val_accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
        )

    def test_step(self, batch: Any, batch_idx: int) -> Mapping[str, Any]:
        """Test step: the loss is logged with the logger's default options."""
        return self._step_and_log(batch, "test", self.test_accuracy)

    def configure_optimizers(
        self,
    ) -> Union[Optimizer, Tuple[Sequence[Optimizer], Sequence[Any]]]:
        """Select the optimizer used for training.

        Returns:
            A one-element list holding an Adam optimizer over all parameters
            of this module, with a learning rate of ``1e-3``. No learning-rate
            scheduler is configured.
        """
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return [optimizer]

## Load data

In [None]:
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize

# Convert each image to a tensor, then normalise its single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std - confirm.
transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
# Training and test splits, both rooted in the current directory with download enabled
# and sharing the same transform.
train_dataset = MNIST(root=".", train=True, download=True, transform=transform)
test_dataset = MNIST(root=".", train=False, download=True, transform=transform)

In [None]:
from torch.utils.data import DataLoader

# Training batches of 1000 samples, loaded with num_workers=8.
# NOTE(review): shuffle is not enabled on the training loader - confirm this is intended.
train_loader = DataLoader(train_dataset, batch_size=1000, num_workers=8)

# Same batch size and worker count for the test split.
test_loader = DataLoader(test_dataset, batch_size=1000, num_workers=8)

## Train a bunch of models

In [None]:
num_classes = 10
# Pass num_classes by keyword: the first positional parameter of MLP is the
# input feature size, so MLP(num_classes) would build a network expecting
# 10 input features instead of the 28 * 28 MNIST pixels.
model_a = MyLightningModule(MLP(num_classes=num_classes), num_classes)

In [None]:
from pytorch_lightning import Trainer

# At most 50 epochs, with the progress bar enabled and the model summary disabled.
trainer = Trainer(enable_progress_bar=True, enable_model_summary=False, max_epochs=50)

# Only the training loader is passed to fit; no validation loader is supplied.
trainer.fit(model_a, train_loader)

# Evaluate the fitted model on the test loader.
trainer.test(model_a, test_loader)

## Match the models

In [None]:
from typing import NamedTuple
from collections import defaultdict


class PermutationSpec(NamedTuple):
    """Two-way lookup between permutations and the layer axes they act on."""

    # Maps each permutation to the layers it permutes, stating explicitly which axis of each layer it acts on.
    perm_to_layers_and_axes: dict

    # Maps each layer to its permutations: a layer with k dimensions maps to k entries,
    # each being the permutation acting on that dimension (or None).
    layer_and_axes_to_perm: dict


class PermutationSpecBuilder:
    """Base class for builders that describe how a model's layers are permuted.

    Subclasses override ``create_permutation_spec``; this class only supplies
    the helper that derives the permutation -> layers mapping from the
    layer -> permutations one.
    """

    def __init__(self) -> None:
        pass

    def create_permutation_spec(self) -> Optional[PermutationSpec]:
        """Build the permutation spec for a concrete architecture.

        The base implementation is a placeholder that returns ``None``;
        subclasses are expected to override it and return a
        ``PermutationSpec``.
        """
        return None

    def permutation_spec_from_axes_to_perm(self, axes_to_perm: dict) -> PermutationSpec:
        """Build a ``PermutationSpec`` from a layer -> per-axis permutations mapping.

        Args:
            axes_to_perm: Maps each layer key to a sequence with one entry per
                axis; an entry is a hashable permutation identifier, or
                ``None`` when that axis is not permuted.

        Returns:
            A ``PermutationSpec`` whose ``perm_to_layers_and_axes`` lists, for
            every permutation, the ``(layer key, axis)`` pairs it acts on, in
            the iteration order of ``axes_to_perm``. ``layer_and_axes_to_perm``
            is the ``axes_to_perm`` object itself (not a copy).
        """
        perm_to_axes = {}

        for layer_key, axis_perms in axes_to_perm.items():
            for axis, perm in enumerate(axis_perms):
                if perm is not None:
                    perm_to_axes.setdefault(perm, []).append((layer_key, axis))

        return PermutationSpec(perm_to_layers_and_axes=perm_to_axes, layer_and_axes_to_perm=axes_to_perm)

In [None]:
class MLPPermutationSpecBuilder(PermutationSpecBuilder):
    """Permutation spec builder for an MLP whose layers are named ``layer0`` ... ``layer{L}``."""

    def __init__(self, num_hidden_layers: int):
        self.num_hidden_layers = num_hidden_layers

    def create_permutation_spec(self) -> PermutationSpec:
        """Build the spec for ``num_hidden_layers`` hidden layers plus one output layer.

        Hidden layer ``i`` gets ``P_i`` on axis 0 of its weight and on its
        bias, and ``P_{i-1}`` on axis 1 of its weight (``None`` for layer 0).
        The output layer only gets ``P_{L-1}`` on axis 1 of its weight; its
        bias is not permuted.
        """
        n_hidden = self.num_hidden_layers
        assert n_hidden >= 1

        # Insertion order matters downstream: hidden weights first, then
        # hidden biases, then the output layer's weight and bias.
        axes_to_perm = {"layer0.weight": ("P_0", None)}
        for i in range(1, n_hidden):
            axes_to_perm[f"layer{i}.weight"] = (f"P_{i}", f"P_{i-1}")
        for i in range(n_hidden):
            axes_to_perm[f"layer{i}.bias"] = (f"P_{i}",)
        axes_to_perm[f"layer{n_hidden}.weight"] = (None, f"P_{n_hidden-1}")
        axes_to_perm[f"layer{n_hidden}.bias"] = (None,)

        return self.permutation_spec_from_axes_to_perm(axes_to_perm)