# Mean Rotation Error explained with an example
In this notebook we compute the Mean Rotation Error (MRE) step by step and show intermediate results, both to verify that the implementation is correct and to improve the reader's understanding.

In [74]:
from torchvision import transforms as T
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import matplotlib as mpl

Both here and in `evaluation_mre.py` we use PyTorch's notation when possible:
* $B$ denotes the batch size
* $C$ denotes the number of input channels ($C=3$ for RGB)
* $H$ and $W$ denote the height and width of the image, respectively

Furthermore, $N$ denotes the number of rotations applied to every image.

In [75]:
N = 4
B = 5
C = 3 # RGB

# dataset specific
H = 96
W = 96
num_classes = 2

Let's create a dummy dataset that returns images in which every value is the same number. Such an image is not affected by rotations, provided the rotations are exact, i.e. they need no interpolation and leave no corners to be filled in. On a square image ($H=W$) this holds for multiples of 90 degrees, i.e. for $N=2$ or $N=4$.
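
As a quick sanity check of the claim that a constant image is unchanged by exact rotations, here is a small NumPy sketch. It uses `np.rot90` as a stand-in for an exact 90-degree rotation (an assumption for illustration; the notebook itself uses torchvision's `rotate`), and the helper name `rot90_chw` is hypothetical. A constant image survives every multiple of 90 degrees, while an image with a left-to-right gradient does not.

```python
import numpy as np

C, H, W = 3, 96, 96

# Constant image, like those returned by DummyDataset (here every value is 7)
constant_img = np.full((C, H, W), 7.0)

# Non-constant image: values increase from left to right along the width axis
gradient_img = np.broadcast_to(np.arange(W, dtype=float), (C, H, W))

def rot90_chw(img, k):
    """Rotate a (C, H, W) array by k * 90 degrees counter-clockwise in the (H, W) plane."""
    return np.rot90(img, k=k, axes=(-2, -1))

for k in range(4):
    print(k * 90,
          np.array_equal(rot90_chw(constant_img, k), constant_img),   # always True
          np.array_equal(rot90_chw(gradient_img, k), gradient_img))   # True only for 0 degrees
```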

In [107]:
from torch.utils.data import Dataset, DataLoader
import random

class DummyDataset(Dataset):
    """
        Dummy dataset for testing an demonstration.
        If rotation_invariant=True, it returns an image with all values equal to its index, so the images are not
        affected by rotations, and a random label.
    """
    def __init__(self, img_height, img_width, num_channels, num_classes, rotation_invariant=True, size=10):
        self.img_height = img_height
        self.img_width = img_width
        self.num_channels = num_channels
        self.num_classes = num_classes
        self.rotation_invariant = rotation_invariant
        self.size = size
    
    def __getitem__(self, idx):
        if self.rotation_invariant:
            img = torch.ones((self.num_channels, self.img_height, self.img_width)) * idx
        else:
            raise NotImplementedError
            
        label = random.randint(0, self.num_classes-1)
        
        return img, label 
    
    def __len__(self):
        return self.size

In [77]:
dataset = DummyDataset(H, W, C, num_classes, rotation_invariant=True)
dataloader = DataLoader(dataset, batch_size=B, shuffle=False)

The operations below are repeated for every batch sampled from the dataset. Let's sample a single dummy batch with `inputs` of shape (B, C, H, W) to demonstrate MRE. 

In [83]:
inputs, labels = next(iter(dataloader))

In [84]:
inputs.shape

torch.Size([5, 3, 96, 96])

The first image contains only zeros.

In [85]:
inputs[0]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

The second image contains only ones.

In [86]:
inputs[1]

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [87]:
labels

tensor([1, 0, 1, 1, 1])

Next, we rotate every input image by $N$ angles using a custom transformation.

In [88]:
import torch
from torchvision.transforms.functional import rotate


class DiscreteRotation:
    """Rotate image by one of the given angles.

    Arguments:
        angles: list(ints). List of integer degrees to pick from. E.g. [0, 90, 180, 270] for a random 90-degree-like rotation
    """

    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = self.angles[torch.randperm(len(self.angles))[0]]
        return rotate(x, angle)

    def __repr__(self):
        return f"{self.__class__.__name__}(angles={self.angles})"


In [89]:
angles = np.linspace(start=0, stop=360, num=N, endpoint=False)
rotations = [DiscreteRotation(angles=[angle]) for angle in angles]

Given $N=4$, we will rotate every image by 0, 90, 180 and 270 degrees.
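
Because the angles come from `np.linspace(..., endpoint=False)`, the $k$-th angle is $360k/N$. The following sketch (the helper name `rotation_angles` is hypothetical) lists which small values of $N$ produce only multiples of 90 degrees, i.e. rotations that need no interpolation, which is why the dummy images stay unchanged for $N=2$ or $N=4$.

```python
import numpy as np

def rotation_angles(n):
    """Return the N equally spaced angles used in the notebook: 0, 360/N, ..., 360*(N-1)/N."""
    return np.linspace(start=0, stop=360, num=n, endpoint=False)

# Values of N (from 1 to 8) whose angles are all exact multiples of 90 degrees
exact_n = [n for n in range(1, 9) if np.all(rotation_angles(n) % 90 == 0)]
print(exact_n)  # [1, 2, 4]
```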

In [108]:
angles

array([  0.,  90., 180., 270.])

In [90]:
rotated_inputs_list = [rotation(inputs) for rotation in rotations]

This yields a list of $N$ tensors of shape (B, C, H, W), one for every rotation. We concatenate this list along dimension 0 into a single ($N \cdot B$, $C$, $H$, $W$) tensor so that all rotated images pass through the model in one forward pass.
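
Concatenating along dimension 0 places all $B$ images for the first angle before all $B$ images for the second angle, and so on. A NumPy sketch with small made-up shapes (and `np.rot90` standing in for the rotation, an assumption for illustration) shows that this is the same layout you would get by stacking into an $(N, B, C, H, W)$ array and then flattening the first two dimensions:

```python
import numpy as np

N, B, C, H, W = 4, 5, 3, 6, 6
rng = np.random.default_rng(0)
inputs = rng.random((B, C, H, W))

# One rotated copy of the whole batch per angle (k * 90 degrees in the H-W plane)
rotated_list = [np.rot90(inputs, k=k, axes=(-2, -1)) for k in range(N)]

# Counterpart of torch.cat(rotated_inputs_list, dim=0): shape (N*B, C, H, W)
rotated_cat = np.concatenate(rotated_list, axis=0)

# Stacking adds a leading rotation axis (N, B, C, H, W); flattening it gives the same order
rotated_stacked = np.stack(rotated_list, axis=0).reshape(N * B, C, H, W)

print(rotated_cat.shape)                             # (20, 3, 6, 6)
print(np.array_equal(rotated_cat, rotated_stacked))  # True
```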

In [91]:
[rot_inputs.shape for rot_inputs in rotated_inputs_list]

[torch.Size([5, 3, 96, 96]),
 torch.Size([5, 3, 96, 96]),
 torch.Size([5, 3, 96, 96]),
 torch.Size([5, 3, 96, 96])]

In [92]:
rotated_inputs = torch.cat(rotated_inputs_list, dim=0)

In [93]:
rotated_inputs.shape

torch.Size([20, 3, 96, 96])

Note that the first $B$ indices along dimension 0 correspond to each image rotated by 0 degrees. Index 0 holds the first image rotated by 0 degrees, index 1 the second image rotated by 0 degrees, and index $B$ the first image again, but rotated by 90 degrees (for $N=4$).
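
In general, row $i = k \cdot B + b$ of the concatenated tensor holds image $b$ (zero-based) rotated by the $k$-th angle, so both indices can be recovered with integer division. A small NumPy check of this mapping, using the same ordering as the concatenation (the helper name `row_to_image_and_angle` is illustrative):

```python
import numpy as np

N, B = 4, 5
angles = np.linspace(0, 360, N, endpoint=False)

def row_to_image_and_angle(i, batch_size):
    """Map a row of the (N*B, ...) tensor to (image index b, rotation index k)."""
    k, b = divmod(i, batch_size)
    return b, k

# Tag each row with its (image, rotation) pair, in the same order as torch.cat
image_ids = np.tile(np.arange(B), N)       # 0..B-1, repeated N times
rotation_ids = np.repeat(np.arange(N), B)  # 0 repeated B times, then 1, ...

for i in (0, 1, B, 2 * B + 3):
    b, k = row_to_image_and_angle(i, B)
    assert (b, k) == (image_ids[i], rotation_ids[i])
    print(f"row {i}: image {b}, rotated by {angles[k]:.0f} degrees")
```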

In [94]:
rotated_inputs[0,...]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

In [95]:
rotated_inputs[1,...]

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [110]:
rotated_inputs[B,...]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])

Now let's pass these rotated inputs through the model.

In [111]:
from torchvision.models import resnet18

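model = resnet18(num_classes=num_classes)
# The model is left in training mode here. For a real evaluation, call model.eval() (and run the
# forward pass under torch.no_grad()) so that BatchNorm uses its running statistics instead of
# statistics computed over the current batch.
# model.eval()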

In [98]:
outputs = model(rotated_inputs)

In [99]:
outputs

tensor([[ 0.0158,  0.0403],
        [ 0.0505, -0.2094],
        [ 0.0853, -0.4592],
        [ 0.1200, -0.7090],
        [ 0.1548, -0.9588],
        [ 0.0158,  0.0403],
        [ 0.0505, -0.2094],
        [ 0.0853, -0.4592],
        [ 0.1200, -0.7090],
        [ 0.1548, -0.9588],
        [ 0.0158,  0.0403],
        [ 0.0505, -0.2094],
        [ 0.0853, -0.4592],
        [ 0.1200, -0.7090],
        [ 0.1548, -0.9588],
        [ 0.0158,  0.0403],
        [ 0.0505, -0.2094],
        [ 0.0853, -0.4592],
        [ 0.1200, -0.7090],
        [ 0.1548, -0.9588]], grad_fn=<AddmmBackward0>)

The model does not end with a softmax layer, because during training `torch.nn.CrossEntropyLoss()` applies log-softmax to the raw outputs (logits) internally. We therefore pass the outputs through a softmax to obtain class probabilities.
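
For reference, softmax maps a row of logits $z$ to $p_j = e^{z_j} / \sum_i e^{z_i}$. The NumPy sketch below applies it to the first two rows of `outputs` printed above and reproduces the first two rows of `probs` to four decimals. Subtracting the row maximum first is a standard trick to avoid overflow in `exp` and does not change the result.

```python
import numpy as np

def softmax(logits, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)  # avoids overflow in exp
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

# First two rows of `outputs` printed above
logits = np.array([[0.0158, 0.0403],
                   [0.0505, -0.2094]])

probs = softmax(logits, axis=1)
print(probs.round(4))     # [[0.4939 0.5061]
                          #  [0.5646 0.4354]]
print(probs.sum(axis=1))  # each row sums to 1
```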

In [123]:
import torch.nn.functional as F

probs = F.softmax(outputs, dim=1)
probs.shape, probs

(torch.Size([20, 2]),
 tensor([[0.4939, 0.5061],
         [0.5646, 0.4354],
         [0.6329, 0.3671],
         [0.6961, 0.3039],
         [0.7528, 0.2472],
         [0.4939, 0.5061],
         [0.5646, 0.4354],
         [0.6329, 0.3671],
         [0.6961, 0.3039],
         [0.7528, 0.2472],
         [0.4939, 0.5061],
         [0.5646, 0.4354],
         [0.6329, 0.3671],
         [0.6961, 0.3039],
         [0.7528, 0.2472],
         [0.4939, 0.5061],
         [0.5646, 0.4354],
         [0.6329, 0.3671],
         [0.6961, 0.3039],
         [0.7528, 0.2472]], grad_fn=<SoftmaxBackward0>))

Every row now contains two probabilities that sum to 1. We can also see that the block of $B$ rows is repeated $N$ times: since our images are rotation-invariant, the rotated copies of an image are identical inputs, so the model produces identical outputs for them. To make this clearer, we reshape to (N, B, num_classes), which matches the order of the concatenation, and then permute to (B, N, num_classes) so that images and rotations have separate dimensions.
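
The order of these two steps matters. Because `torch.cat` stacked the rotations along dimension 0, the flat $(N \cdot B, \text{num\_classes})$ tensor has to be viewed as $(N, B, \text{num\_classes})$ first and only then permuted. A NumPy sketch in which every row stores its own image index shows why reshaping directly to $(B, N, \text{num\_classes})$ would mix up images (`transpose` plays the role of `permute`):

```python
import numpy as np

N, B, num_classes = 4, 5, 2

# Fake "probabilities": every row stores its image index b, in the torch.cat order
image_ids = np.tile(np.arange(B), N)                                      # shape (N*B,)
flat = np.repeat(image_ids[:, None], num_classes, axis=1).astype(float)   # (N*B, num_classes)

# Correct: split into (N, B, ...) to match the concatenation order, then swap the first two axes
correct = flat.reshape(N, B, -1).transpose(1, 0, 2)  # (B, N, num_classes)

# Wrong: reshaping straight to (B, N, ...) groups consecutive rows, i.e. different images
wrong = flat.reshape(B, N, -1)

print(correct[:, :, 0])  # row b contains only b
print(wrong[:, :, 0])    # rows mix image indices
```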

In [103]:
probs = probs.reshape((N, B, -1)).permute(1, 0, 2)
probs.shape, probs

(torch.Size([5, 4, 2]),
 tensor([[[0.4939, 0.5061],
          [0.4939, 0.5061],
          [0.4939, 0.5061],
          [0.4939, 0.5061]],
 
         [[0.5646, 0.4354],
          [0.5646, 0.4354],
          [0.5646, 0.4354],
          [0.5646, 0.4354]],
 
         [[0.6329, 0.3671],
          [0.6329, 0.3671],
          [0.6329, 0.3671],
          [0.6329, 0.3671]],
 
         [[0.6961, 0.3039],
          [0.6961, 0.3039],
          [0.6961, 0.3039],
          [0.6961, 0.3039]],
 
         [[0.7528, 0.2472],
          [0.7528, 0.2472],
          [0.7528, 0.2472],
          [0.7528, 0.2472]]], grad_fn=<PermuteBackward0>))

Next, we keep only the probability of the correct class for every image and rotation. We do that with `torch.gather()`, which requires an index tensor with the same number of dimensions as `probs`, so we reshape the labels from (B,) to (B, N, 1).
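
`torch.gather(probs, 2, label_indices)` picks, for every image $b$ and rotation $n$, the entry `probs[b, n, labels[b]]`. NumPy's `np.take_along_axis` follows the same indexing rule, so the sketch below, which uses the labels and the (rounded) per-image probabilities printed in this notebook, reproduces `target_probs`:

```python
import numpy as np

N = 4
labels = np.array([1, 0, 1, 1, 1])

# Per-image class probabilities from the notebook, repeated over the N rotations: (B, N, 2)
per_image = np.array([[0.4939, 0.5061],
                      [0.5646, 0.4354],
                      [0.6329, 0.3671],
                      [0.6961, 0.3039],
                      [0.7528, 0.2472]])
probs = np.repeat(per_image[:, None, :], N, axis=1)

# Same role as labels.unsqueeze(1).unsqueeze(2).repeat(1, N, 1): shape (B, N, 1)
label_indices = np.broadcast_to(labels[:, None, None], (len(labels), N, 1))

# target_probs[b, n, 0] = probs[b, n, labels[b]]
target_probs = np.take_along_axis(probs, label_indices, axis=2)
print(target_probs.shape)     # (5, 4, 1)
print(target_probs[:, 0, 0])  # [0.5061 0.5646 0.3671 0.3039 0.2472]
```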

In [104]:
label_indices = labels.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, N, 1)
label_indices.shape, label_indices

(torch.Size([5, 4, 1]),
 tensor([[[1],
          [1],
          [1],
          [1]],
 
         [[0],
          [0],
          [0],
          [0]],
 
         [[1],
          [1],
          [1],
          [1]],
 
         [[1],
          [1],
          [1],
          [1]],
 
         [[1],
          [1],
          [1],
          [1]]]))

In [105]:
target_probs = torch.gather(probs, 2, label_indices)
target_probs.shape, target_probs

(torch.Size([5, 4, 1]),
 tensor([[[0.5061],
          [0.5061],
          [0.5061],
          [0.5061]],
 
         [[0.5646],
          [0.5646],
          [0.5646],
          [0.5646]],
 
         [[0.3671],
          [0.3671],
          [0.3671],
          [0.3671]],
 
         [[0.3039],
          [0.3039],
          [0.3039],
          [0.3039]],
 
         [[0.2472],
          [0.2472],
          [0.2472],
          [0.2472]]], grad_fn=<GatherBackward0>))

Comparing `target_probs` with `probs` and with the `labels` printed earlier (`tensor([1, 0, 1, 1, 1])`), you can see that only the probabilities of the correct class were kept. Next, we compute the standard deviation across rotations for every image, which measures how robust the predictions are to rotations. Since our dummy images are rotation-invariant, we obtain zero for every image up to floating-point error (one entry is 2.98e-08 rather than exactly 0).
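
Note that `torch.std` applies Bessel's correction by default (it divides by $N-1$), whereas NumPy's `np.std` divides by $N$ unless `ddof=1` is passed, so MRE values computed with the two libraries are only comparable when the same convention is used. A sketch with made-up correct-class probabilities for a single, hypothetical image that is *not* rotation-invariant:

```python
import numpy as np

# Hypothetical probabilities of the correct class for one image under N = 4 rotations
p = np.array([0.5, 0.6, 0.5, 0.6])

std_unbiased = np.std(p, ddof=1)  # same convention as torch.std's default (divides by N - 1)
std_biased = np.std(p)            # NumPy's default (divides by N)
print(round(std_unbiased, 4), round(std_biased, 4))  # 0.0577 0.05

# Identical probabilities across rotations give a standard deviation of zero
print(np.std(np.full(4, 0.5), ddof=1))  # 0.0
```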

In [106]:
stds = torch.std(target_probs, dim=1)
stds.shape, stds

(torch.Size([5, 1]),
 tensor([[0.0000e+00],
         [0.0000e+00],
         [0.0000e+00],
         [2.9802e-08],
         [0.0000e+00]], grad_fn=<StdBackward0>))

This is the result we expect for rotation-invariant inputs: the predictions do not vary across rotations, so every standard deviation is (numerically) zero, which supports that the implementation behaves correctly. Here we only demonstrate the operations for one batch; in the actual implementation, the standard deviations are appended to a list, and once the loop has iterated over all batches, the final MRE is computed by averaging all of them.
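
Putting the steps together, a minimal NumPy sketch of the whole metric could look like the function below. It is not the code in `evaluation_mre.py`: the function name `mean_rotation_error`, the `predict_probs` callable (assumed to map a batch of images to class probabilities, i.e. model plus softmax) and the use of `np.rot90`, which restricts $N$ to 2 or 4, are assumptions made for illustration. As described above, per-image standard deviations are collected over all batches and averaged at the end.

```python
import numpy as np

def mean_rotation_error(batches, predict_probs, n_rotations=4):
    """Hypothetical MRE sketch.

    batches: iterable of (images, labels); images (B, C, H, W), labels (B,) of class indices
    predict_probs: callable mapping an (M, C, H, W) array to (M, num_classes) probabilities
    """
    if n_rotations not in (2, 4):
        raise ValueError("this sketch only supports exact rotations, i.e. N = 2 or N = 4")
    step = 4 // n_rotations  # number of 90-degree turns between consecutive angles
    stds = []
    for images, labels in batches:
        B = images.shape[0]
        # (N*B, C, H, W): all images at the first angle, then all at the next angle, ...
        rotated = np.concatenate(
            [np.rot90(images, k=k * step, axes=(-2, -1)) for k in range(n_rotations)], axis=0)
        probs = predict_probs(rotated)                                # (N*B, num_classes)
        probs = probs.reshape(n_rotations, B, -1).transpose(1, 0, 2)  # (B, N, num_classes)
        idx = np.broadcast_to(labels[:, None, None], (B, n_rotations, 1))
        target = np.take_along_axis(probs, idx, axis=2)               # (B, N, 1)
        stds.extend(np.std(target, axis=1, ddof=1).ravel())          # like torch.std's default
    return float(np.mean(stds))


def _two_class_probs(score):
    """Turn one score per image into two-class probabilities via softmax([score, -score])."""
    e = np.exp(np.stack([score, -score], axis=1))
    return e / e.sum(axis=1, keepdims=True)


# Usage with two toy "models" on random images (illustrative only)
rng = np.random.default_rng(0)
images = rng.random((5, 3, 8, 8))
labels = np.array([1, 0, 1, 1, 1])
batches = [(images, labels)]

invariant = lambda x: _two_class_probs(x.mean(axis=(1, 2, 3)))     # ignores orientation
top_left = lambda x: _two_class_probs(x[:, :, 0, 0].mean(axis=1))  # depends on orientation

print(mean_rotation_error(batches, invariant))  # close to 0 (floating-point error only)
print(mean_rotation_error(batches, top_left))   # clearly greater than 0
```

Averaging per-image standard deviations (rather than per-batch means) keeps every image equally weighted even when the last batch is smaller.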