# Assertions about named dimensions in PyTorch

[Open this in Colab](https://githubtocolab.com/boazbk/named-tensors-asserts/blob/main/cifar10_example.ipynb)

This notebook demonstrates a small library for making assertions about tensor and model dimensions in PyTorch.

__General idea:__
We write: 
```python
T &nt // "batch=1024, channels=3, height=32, width=32" 
```

to assert that `T`'s shape is `(1024, 3, 32, 32)` and to set the global
named dimensions `batch`, `channels`, `height`, and `width` to these values.
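To make the idea concrete, here is a minimal sketch of what such a declaration has to do, assuming the named dimensions are stored in a plain dict. `bind_and_check` is a hypothetical helper written for illustration, not the library's implementation, and it only handles bare `name` and `name=int` entries.

```python
def bind_and_check(shape, spec, dims):
    """Check `shape` against a spec such as "batch=1024, channels=3".

    Hypothetical sketch: an entry `name=value` binds (or updates) `name`
    in `dims`; a bare `name` must already be bound to the matching size.
    """
    entries = [e.strip() for e in spec.split(",")]
    if len(entries) != len(shape):
        raise AssertionError(f"spec has {len(entries)} dims, shape has {len(shape)}")
    for entry, actual in zip(entries, shape):
        name, _, value = (s.strip() for s in entry.partition("="))
        expected = int(value) if value else dims.get(name)
        if expected != actual:
            raise AssertionError(f"{name}: expected {expected}, got {actual}")
        dims[name] = actual  # bind or update the named dimension
    return dims


example_dims = {}
bind_and_check((1024, 3, 32, 32), "batch=1024, channels=3, height=32, width=32", example_dims)
print(example_dims)  # {'batch': 1024, 'channels': 3, 'height': 32, 'width': 32}
```

A later check such as `bind_and_check(shape, "batch, channels", example_dims)` then only passes if the first two sizes still equal 1024 and 3.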

We can access these as `nt.batch`, `nt.width`, etc.
In later declarations we can write expressions such as:
```python
Q &nt // "batch, channels*(height+1), width"
```
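Entries such as `channels*(height+1)` are arithmetic expressions over names that are already bound. One way to evaluate them, sketched here under the assumption that dimensions live in a dict, is to parse each entry with Python's `ast` module and allow only integer constants, names, and the `+ - * //` operators. `eval_dim` is a hypothetical helper, not part of `named_asserts.py`.

```python
import ast
import operator

# Operators allowed in dimension expressions (an assumption of this sketch)
_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.FloorDiv: operator.floordiv,
}


def eval_dim(expr, dims):
    """Evaluate an expression like "channels*(height+1)" using `dims`."""

    def walk(node):
        if isinstance(node, ast.Expression):
            return walk(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, int):
            return node.value
        if isinstance(node, ast.Name):
            return dims[node.id]  # KeyError if the dimension is unbound
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](walk(node.left), walk(node.right))
        raise ValueError(f"unsupported expression: {ast.dump(node)}")

    return walk(ast.parse(expr.strip(), mode="eval"))


example_dims = {"batch": 1024, "channels": 3, "height": 32, "width": 32}
spec = "batch, channels*(height+1), width"
print([eval_dim(e, example_dims) for e in spec.split(",")])  # [1024, 99, 32]
```

Walking the syntax tree instead of calling `eval` keeps arbitrary Python out of the dimension strings.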

To declare that a model maps tensors with dimensions `['batch', 'width', 'height', 'channels']` to tensors with dimensions `['batch', 'output']`, we write:

```python
model &nt // "batch, width, height, channels -> batch, output"
```
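A model declaration pairs an input spec with an output spec, separated by `->`. The sketch below assumes a plain dict of dimensions and a model that is any callable returning an object with a `.shape`; it wraps the callable so that every call checks both sides. `checked`, `_expected_shape`, and `FakeTensor` are hypothetical names, and the real library may check models differently (for example, only once at declaration time). `FakeTensor` stands in for a tensor so the sketch runs without PyTorch.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for a tensor: only the shape matters here."""
    shape: tuple


def _expected_shape(spec, dims):
    # Only bare names or integer literals are supported in this sketch
    return tuple(
        int(e) if e.isdigit() else dims[e]
        for e in (s.strip() for s in spec.split(","))
    )


def checked(model, signature, dims):
    """Wrap `model` so inputs and outputs are checked against "in -> out"."""
    in_spec, out_spec = signature.split("->")

    def wrapper(x):
        expected_in = _expected_shape(in_spec, dims)
        if tuple(x.shape) != expected_in:
            raise AssertionError(f"input: expected {expected_in}, got {tuple(x.shape)}")
        y = model(x)
        expected_out = _expected_shape(out_spec, dims)
        if tuple(y.shape) != expected_out:
            raise AssertionError(f"output: expected {expected_out}, got {tuple(y.shape)}")
        return y

    return wrapper


example_dims = {"batch": 4, "width": 32, "height": 32, "channels": 3, "output": 10}
# Hypothetical model that maps (batch, ...) to (batch, 10)
toy_model = checked(lambda x: FakeTensor((x.shape[0], 10)),
                    "batch, width, height, channels -> batch, output", example_dims)
print(toy_model(FakeTensor((4, 32, 32, 3))).shape)  # (4, 10)
```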

```python
# Download the library (needed in Colab, or wherever named_asserts.py is not present)
!wget https://raw.githubusercontent.com/boazbk/named-tensors-asserts/main/named_asserts.py -O named_asserts.py
```


```python
%run named_asserts.py
```

We use the PyTorch CIFAR-10 tutorial as our example.

```python
import torch
import torchvision
import torchvision.transforms as transforms
```

```python
# Workaround for SSL certificate verification errors when downloading the dataset,
# see https://github.com/pytorch/vision/issues/5039
import ssl
ssl._create_default_https_context = ssl._create_unverified_context

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
```

```
Files already downloaded and verified
Files already downloaded and verified
```


Let us show some of the training images, for fun.



```python
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

# function to show an image

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
```
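The `img / 2 + 0.5` line undoes the `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))` step, which maps each channel value x to (x - 0.5) / 0.5. A quick scalar check of that round trip, with plain Python floats standing in for tensor entries (`normalize` and `unnormalize` are illustrative helpers, not torchvision functions):

```python
import math


def normalize(x, mean=0.5, std=0.5):
    # Per-channel effect of transforms.Normalize: (x - mean) / std
    return (x - mean) / std


def unnormalize(y):
    # What imshow does before plotting: y / 2 + 0.5
    return y / 2 + 0.5


for x in [0.0, 0.25, 0.5, 1.0]:
    y = normalize(x)
    assert math.isclose(unnormalize(y), x)  # the round trip recovers x
    print(f"{x} -> {y} -> {unnormalize(y)}")
```

So `ToTensor` values in [0, 1] become values in [-1, 1], and `imshow` maps them back to [0, 1] for `plt.imshow`.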



```python
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)
```

```python
# show images
#imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
```

```
frog  ship  plane cat
```


```python
images &nt // "(batch=4, channels=3, height=32, width =32)"
nt.dims
```

```
{'batch': 4, 'channels': 3, 'height': 32, 'width': 32}
```

### Define a Convolutional Neural Network

We copy the network from the Neural Networks section of the PyTorch tutorial and modify it to take 3-channel images (instead of the 1-channel images it was originally defined for). Each layer, and each step of `forward`, is annotated with the dimensions it expects.



```python
import torch.nn as nn
import torch.nn.functional as F

# From https://discuss.pytorch.org/t/utility-function-for-calculating-the-shape-of-a-conv-output/11173/5
def conv_output_shape(h_, w_, kernel_size=1, stride=1, pad=0, dilation=1):
    from math import floor
    if type(kernel_size) is not tuple:
        kernel_size = (kernel_size, kernel_size)
    h = floor(((h_ + (2 * pad) - (dilation * (kernel_size[0] - 1)) - 1) / stride) + 1)
    w = floor(((w_ + (2 * pad) - (dilation * (kernel_size[1] - 1)) - 1) / stride) + 1)
    return h, w


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        a, b = conv_output_shape(nt.height, nt.width, 5)
        self.conv1 = nn.Conv2d(3, 6, 5) &nt // "batch, channels, height, width -> batch, channel1=6, conv1h=a, conv1w=b"
        poolh, poolw = conv_output_shape(nt.conv1h, nt.conv1w, 2, 2)
        self.pool = nn.MaxPool2d(2, 2) &nt // "batch, channel1, conv1h, conv1w -> batch, channel1, pool1h=poolh, pool1w=poolw"

        a, b = conv_output_shape(nt.pool1h, nt.pool1w, 5)
        self.conv2 = nn.Conv2d(6, 16, 5) &nt // "batch, channel1, pool1h, pool1w -> batch, channel2=16, conv2h=a, conv2w=b"
        nt.pool2h, nt.pool2w = conv_output_shape(nt.conv2h, nt.conv2w, 2, 2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120) &nt // "batch, channel2*pool2h*pool2w -> batch, 120"
        self.fc2 = nn.Linear(120, 84) &nt // "batch, 120 -> batch, 84"
        self.fc3 = nn.Linear(84, 10) &nt // "batch, 84 -> batch, 10"

    def forward(self, x):
        x &nt // "batch, channels, height, width"
        x = self.pool(F.relu(self.conv1(x))) &nt // "batch, channel1, pool1h, pool1w"
        x = self.pool(F.relu(self.conv2(x))) &nt // "batch, channel2, pool2h, pool2w"
        x = torch.flatten(x, 1) &nt // "batch, channel2*pool2h*pool2w"  # flatten all dimensions except batch
        x = F.relu(self.fc1(x)) &nt // "batch, 120"
        x = F.relu(self.fc2(x)) &nt // "batch, 84"
        x = self.fc3(x) &nt // "batch, 10"
        return x


net = Net()
net(images)
nt.dims
```

```
Updating channel1 to 6
Updating conv1h to 28
Updating conv1w to 28
Updating pool1h to 14
Updating pool1w to 14
Updating channel2 to 16
Updating conv2h to 10
Updating conv2w to 10
```

```
{'batch': 4,
 'channels': 3,
 'height': 32,
 'width': 32,
 'channel1': 6,
 'conv1h': 28,
 'conv1w': 28,
 'pool1h': 14,
 'pool1w': 14,
 'channel2': 16,
 'conv2h': 10,
 'conv2w': 10,
 'pool2h': 5,
 'pool2w': 5}
```
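The sizes above follow from the formula in `conv_output_shape`: with kernel size k, stride s, padding p and dilation d, an input of size n gives floor((n + 2p - d(k - 1) - 1) / s) + 1. The standalone sketch below re-derives the chain 32 → 28 → 14 → 10 → 5 for one spatial side, which is also why `fc1` takes `16 * 5 * 5 = 400` inputs. `out_size` is a helper written only for this illustration.

```python
def out_size(n, kernel, stride=1, pad=0, dilation=1):
    # Same formula as conv_output_shape, for a single spatial dimension;
    # Python's // is floor division, matching floor(a / s) + 1
    return (n + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1


n = 32
conv1 = out_size(n, 5)             # Conv2d(3, 6, 5)
pool1 = out_size(conv1, 2, 2)      # MaxPool2d(2, 2)
conv2 = out_size(pool1, 5)         # Conv2d(6, 16, 5)
pool2 = out_size(conv2, 2, 2)      # MaxPool2d(2, 2)
print(conv1, pool1, conv2, pool2)  # 28 14 10 5
print(16 * pool2 * pool2)          # 400, the in_features of fc1
```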

We can suppress the assertions by wrapping code in a `with skip_asserts():` block. With `with skip_asserts(flag):` the assertions are skipped if and only if `flag` is True.

```python
x = torch.randn(1, 3, 32, 32)
with skip_asserts():
    x &nt // "batch, channels, height, width+1"
```

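Without `skip_asserts`, the same assertion fails. Note that `x` has batch size 1 while `nt.batch` is 4, so the first dimension does not match either:

```python
try:
    x &nt // "batch, channels, height, width+1"
except AssertionError as e:
    print(e)
```

```
Dimension expression batch, channels, height, width+1 must have the same dimensions as the tensor ([[[[ 1.4626, -1.5704, -0.55.., got (4, 3, 32, 33) vs torch.Size([1, 3, 32, 32])
```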


### Training

For fun, we can also train the network.

```python
import torch.optim as optim

def train(net):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for epoch in range(2):  # loop over the dataset multiple times

        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:    # print every 2000 mini-batches
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
                running_loss = 0.0

    print('Finished Training')
```


```python
net1, net2 = Net(), Net()
```

Let's train with and without assertions to compare the running times (this was run on the CPU of a rather overloaded laptop).

```python
%%time
train(net1)
```

```
[1,  2000] loss: 2.174
[1,  4000] loss: 1.861
[1,  6000] loss: 1.695
[1,  8000] loss: 1.601
[1, 10000] loss: 1.523
[1, 12000] loss: 1.491
[2,  2000] loss: 1.428
[2,  4000] loss: 1.386
[2,  6000] loss: 1.372
[2,  8000] loss: 1.325
[2, 10000] loss: 1.323
[2, 12000] loss: 1.287
Finished Training
Wall time: 5min 35s
```


```python
%%time
with skip_asserts():
    train(net2)
```

```
[1,  2000] loss: 2.188
[1,  4000] loss: 1.884
[1,  6000] loss: 1.663
[1,  8000] loss: 1.578
[1, 10000] loss: 1.521
[1, 12000] loss: 1.478
[2,  2000] loss: 1.408
[2,  4000] loss: 1.370
[2,  6000] loss: 1.340
[2,  8000] loss: 1.301
[2, 10000] loss: 1.297
[2, 12000] loss: 1.281
Finished Training
Wall time: 4min 39s
```
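Converting the two wall times to seconds (5 min 35 s = 335 s with assertions, 4 min 39 s = 279 s without) gives the relative overhead of the assertions in this single, noisy run:

$$
\frac{335\,\text{s} - 279\,\text{s}}{279\,\text{s}} = \frac{56}{279} \approx 0.20
$$

That is, training took roughly 20% longer with the assertions enabled.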


### Testing

For fun, we can also test the network.

```python
dataiter = iter(testloader)
images, labels = next(dataiter)

# show images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
```

Okay, now let us see what the neural network thinks these examples above are:



```python
# use one of the trained networks; `net` above was never trained
outputs = net1(images)
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
                              for j in range(batch_size)))
```