# `015` Residual connections

Requirements: 008 Batch normalization, 011 Convolutional Neural Networks

⚠️⚠️⚠️⚠️⚠️⚠️⚠️⚠️ WIP

Residual connections are a technique proposed by [He et al. 2015](https://arxiv.org/pdf/1512.03385) that makes very deep neural networks easier to train, in particular by mitigating the vanishing gradient problem. The idea is to add the input of a layer to its output. For instance, if we have a linear layer $L$ that receives an input $x$, instead of passing the signal $L(x)$ to the next layer, we pass $L(x) + x$. This simple change has a few implications:
* Now $L$ is not a layer that just transforms $x$ into something else, but one that finds a "patch" to be applied to $x$.
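* If we stack many of these operations, the gradients are much less likely to vanish, because $x$ is passed on unchanged to the next layer: the derivative of $L(x) + x$ with respect to $x$ always contains an identity term.
* If we see it the other way around, as $x + L(x)$, we can think of the network as a main branch that is $x$, that keeps flowing through the network, with forks that apply patches to it. Each fork only has to learn the difference (the residual) between the desired output and $x$, which is the intuition behind the name "residual".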
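
The points above can be sketched in a few lines. This is an illustrative toy, not the model used later: the `Residual` wrapper, the `input_grad_norm` helper, and the depth and width values are assumptions made for this example. It stacks 50 small `Linear` + `Tanh` layers and measures how much gradient reaches the input, with and without the $x + L(x)$ shortcut. Since $\frac{\partial (x + L(x))}{\partial x} = I + \frac{\partial L}{\partial x}$, the residual version should keep a much larger gradient.

```python
import torch

class Residual(torch.nn.Module):
    """Wraps a layer L so that the block computes x + L(x) (hypothetical helper)."""
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        # Out-of-place addition: the layer only has to produce a "patch" for x.
        return x + self.layer(x)

def input_grad_norm(use_residual, depth=50, width=16):
    """Norm of the gradient that reaches the input of a deep stack of layers."""
    torch.manual_seed(0)  # same initial weights in both configurations
    layers = []
    for _ in range(depth):
        layer = torch.nn.Sequential(torch.nn.Linear(width, width), torch.nn.Tanh())
        layers.append(Residual(layer) if use_residual else layer)
    net = torch.nn.Sequential(*layers)
    x = torch.randn(8, width, requires_grad=True)
    net(x).sum().backward()
    return x.grad.norm().item()

plain_norm = input_grad_norm(use_residual=False)
residual_norm = input_grad_norm(use_residual=True)
print(f'plain: {plain_norm:.3e}, residual: {residual_norm:.3e}')
```

The exact numbers depend on the initialization, but the gradient of the plain stack is expected to be many orders of magnitude smaller than that of the residual stack.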

Let's compare a deep network with and without residual connections. We will use the MNIST dataset and a simple convolutional neural network.

In [1]:
from time import time
import torch
import torchvision

In [2]:
train = torchvision.datasets.MNIST(root='data', download=True)
test = torchvision.datasets.MNIST(root='data', download=True, train=False)
x = train.data.float()
xt = test.data.float()
y = torch.nn.functional.one_hot(train.targets).float()
yt = torch.nn.functional.one_hot(test.targets).float()
print(f'Train: {x.shape}, test: {xt.shape}')

Train: torch.Size([60000, 28, 28]), test: torch.Size([10000, 28, 28])


In [3]:
class DigitRecognizer(torch.nn.Module):
	def __init__(self, use_residual=False):
		super().__init__()
		self.use_residual = use_residual
		self.conv1_1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
		self.conv1_2 = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)
		self.conv1_3 = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)
		self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
		self.conv2_1 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
		self.conv2_2 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
		self.conv2_3 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
		self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
		self.conv3_1 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
		self.conv3_2 = torch.nn.Conv2d(128, 128, kernel_size=3, padding=1)
		self.conv3_3 = torch.nn.Conv2d(128, 128, kernel_size=3, padding=1)
		self.pool3 = torch.nn.MaxPool2d(kernel_size=2)
		self.fc = torch.nn.Linear(128 * 3 * 3, 10)

	def forward(self, x):
		x = x.unsqueeze(1)
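		residual = x  # the 1-channel input; it broadcasts over the 32 output channels when added
		x = self.conv1_1(x).relu()
		x = self.conv1_2(x).relu()
		x = self.conv1_3(x).relu()
		if self.use_residual: x += residual  # in-place add on a ReLU output: the cause of the RuntimeError in cell 6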
		x = self.pool1(x)
		x = self.conv2_1(x).relu()
		residual = x
		x = self.conv2_2(x).relu()
		x = self.conv2_3(x).relu()
		if self.use_residual: x += residual
		x = self.pool2(x)
		x = self.conv3_1(x).relu()
		residual = x
		x = self.conv3_2(x).relu()
		x = self.conv3_3(x).relu()
		if self.use_residual: x += residual
		x = self.pool3(x)
		x = x.view(-1, 128 * 3 * 3)
		x = self.fc(x)
		return x

model = DigitRecognizer(use_residual=False)
model_with_residual = DigitRecognizer(use_residual=True)
print(f'There are {sum(p.numel() for p in model.parameters())} parameters in the model')

There are 491722 parameters in the model
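
The `128 * 3 * 3` input size of `fc` and the parameter count above can be checked by hand. The following is a small sketch (the helper `conv_params` and the variable names are introduced here for illustration and are not part of the notebook): each `MaxPool2d(kernel_size=2)` halves the spatial size, rounding down, and each 3×3 convolution has `in * out * 9` weights plus `out` biases.

```python
import torch

# Spatial size after each of the three 2x2 max-pooling layers
# (the padding=1 convolutions keep the size unchanged).
pool = torch.nn.MaxPool2d(kernel_size=2)
t = torch.zeros(1, 1, 28, 28)
sizes = [t.shape[-1]]
for _ in range(3):
    t = pool(t)
    sizes.append(t.shape[-1])
print(sizes)  # [28, 14, 7, 3]

def conv_params(c_in, c_out, k=3):
    """Weights plus biases of a k x k convolution (hypothetical helper)."""
    return c_in * c_out * k * k + c_out

# (in_channels, out_channels) of the nine convolutions in DigitRecognizer
channels = [(1, 32), (32, 32), (32, 32),
            (32, 64), (64, 64), (64, 64),
            (64, 128), (128, 128), (128, 128)]
fc_params = 128 * 3 * 3 * 10 + 10  # final linear layer: weights plus biases
total = sum(conv_params(a, b) for a, b in channels) + fc_params
print(total)  # 491722
```

Note that the count is the same with and without residual connections, because the shortcut is a plain addition and adds no parameters.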


In [4]:
def train(model, lr=0.01, num_iterations=5000, batch_size=32, print_each=500):
	model.train()
	optimizer = torch.optim.SGD(model.parameters(), lr=lr)
	start = time()
	for i in range(num_iterations):
		ix = torch.randint(0, x.shape[0], (batch_size,))
		xb, yb = x[ix], y[ix]
		logits = model(xb)
		loss = torch.nn.functional.cross_entropy(logits, yb.argmax(dim=1))
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		if i % print_each == 0:
			remaining = (time() - start) * (num_iterations - i) / (i + 1)
			print(f'Iteration {i:5}, loss: {loss.item():.6f}, remaining: {remaining//60:2.0f}m{remaining%60:2.0f}s')
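	model.eval()
	with torch.no_grad():  # gradients are not needed for evaluation; this saves memory on the full test set
		accuracy = (model(xt).argmax(dim=1) == yt.argmax(dim=1)).float().mean().item()
	print(f'Accuracy on the test set: {100 * accuracy:.2f}%')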

In [5]:
print('Training the model without residual connections')
train(model)

Training the model without residual connections
Iteration     0, loss: 2.298797, remaining: 10m15s
Iteration   500, loss: 0.121160, remaining:  4m23s
Iteration  1000, loss: 0.002111, remaining:  3m54s
Iteration  1500, loss: 0.069511, remaining:  3m26s
Iteration  2000, loss: 0.001473, remaining:  2m57s
Iteration  2500, loss: 0.022585, remaining:  2m27s
Iteration  3000, loss: 0.020457, remaining:  1m58s
Iteration  3500, loss: 0.000480, remaining:  1m28s
Iteration  4000, loss: 0.066322, remaining:  0m59s
Iteration  4500, loss: 0.002706, remaining:  0m29s
Accuracy on the test set: 98.93%


In [6]:
print('Training the model with residual connections')
train(model_with_residual)

Training the model with residual connections


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [32, 128, 7, 7]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
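
The error message points at the output of a ReLU being modified in place, which matches the `x += residual` lines in `forward`: `relu()` keeps its output for the backward pass, and `+=` overwrites that tensor. Writing the addition out of place (`x = x + residual`) creates a new tensor and leaves the ReLU output untouched. The snippet below is a minimal sketch of both variants, using a hypothetical single convolution instead of the notebook's model.

```python
import torch

torch.manual_seed(0)
conv = torch.nn.Conv2d(1, 4, kernel_size=3, padding=1)
inp = torch.randn(2, 1, 8, 8)

# In-place version: reproduces the failure.
out = conv(inp).relu()
out += inp  # overwrites the ReLU output that autograd saved for backward
try:
    out.sum().backward()
    inplace_failed = False
except RuntimeError as error:
    inplace_failed = True
    print(f'In-place add failed: {str(error)[:60]}...')

# Out-of-place version: a new tensor holds the sum, the ReLU output is intact.
out = conv(inp).relu()
out = out + inp
out.sum().backward()
grad_ok = conv.weight.grad is not None
print(f'Out-of-place add produced gradients: {grad_ok}')
```

Under this reading, replacing the three `x += residual` statements in `DigitRecognizer.forward` with `x = x + residual` should let the residual model train.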