# `014` Residual connections

Requirements: 008 Batch normalization, 011 Convolutional Neural Networks

Residual connections are a technique proposed by [He et al. 2015](https://arxiv.org/pdf/1512.03385) to train deep neural networks while mitigating the vanishing gradient problem. The idea is to add the input of a layer to its output. For instance, if we have a linear layer $L$ that receives an input $x$, instead of passing the signal $L(x)$ to the next layer, we pass $L(x) + x$. This simple change has a few implications:
* Now $L$ is not a layer that just transforms $x$ into something else, but one that finds a "patch" to be applied to $x$.
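* If we stack many of these operations, the gradients are far less likely to vanish: the Jacobian of $L(x) + x$ with respect to $x$ is $\frac{\partial L}{\partial x} + I$, so the gradient always has a direct path back through the $+ x$ term, no matter how small $\frac{\partial L}{\partial x}$ becomes.
* If we see it the other way around, as $x + L(x)$, we can think of the network as a main branch, $x$, that keeps flowing through the network, with forks that apply patches to it. Each $L$ only has to learn the difference between the desired output and $x$, which is the intuition behind the name "residual".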
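
As a minimal sketch of the idea (the `Residual` wrapper and the `input_grad_norm` helper are names made up for this illustration, and the depth and width are arbitrary), the code below stacks 50 small `Linear` + `tanh` layers and measures how much gradient reaches the input, with and without the $+ x$ term.

```python
import torch

class Residual(torch.nn.Module):
    """Wraps a layer L so that the output is x + L(x)."""
    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return x + self.layer(x)

def input_grad_norm(use_residual, depth=50, width=16):
    """Norm of the gradient that reaches the input of a deep stack."""
    torch.manual_seed(0)
    layers = []
    for _ in range(depth):
        block = torch.nn.Sequential(torch.nn.Linear(width, width), torch.nn.Tanh())
        layers.append(Residual(block) if use_residual else block)
    net = torch.nn.Sequential(*layers)
    x = torch.randn(8, width, requires_grad=True)
    net(x).sum().backward()
    return x.grad.norm().item()

plain_norm = input_grad_norm(use_residual=False)
residual_norm = input_grad_norm(use_residual=True)
print(f'plain: {plain_norm:.3e}, residual: {residual_norm:.3e}')
```

With PyTorch's default initialization, the plain stack should yield a gradient norm far smaller than the residual one, which is the vanishing gradient effect that the skip connection counteracts.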

Let's compare a deep network with and without residual connections. We will use the MNIST dataset, as in the CNN notebook.

In [1]:
from matplotlib import pyplot as plt
from time import time
import torch
import torchvision

In [2]:
# Use names that the train() function defined later will not overwrite
train_set = torchvision.datasets.MNIST(root='data', download=True)
test_set = torchvision.datasets.MNIST(root='data', download=True, train=False)
x = train_set.data.float()
xt = test_set.data.float()
y = torch.nn.functional.one_hot(train_set.targets).float()
yt = torch.nn.functional.one_hot(test_set.targets).float()
print(f'Train: {x.shape}, test: {xt.shape}')

Train: torch.Size([60000, 28, 28]), test: torch.Size([10000, 28, 28])


Now, let's define a CNN like the one in the previous notebook but with many more layers, so that the gradients vanish quickly and we create a need for residual connections.

In [3]:
class DigitRecognizer(torch.nn.Module):
	def __init__(self, use_residual=False):
		super().__init__()
		self.use_residual = use_residual
		self.conv1_1 = torch.nn.Conv2d(1, 32, kernel_size=3, padding=1)
		self.conv1_2 = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)
		self.conv1_3 = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)
		self.conv1_4 = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)
		self.conv1_5 = torch.nn.Conv2d(32, 32, kernel_size=3, padding=1)
		self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
		self.conv2_1 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
		self.conv2_2 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
		self.conv2_3 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
		self.conv2_4 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
		self.conv2_5 = torch.nn.Conv2d(64, 64, kernel_size=3, padding=1)
		self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
		self.conv3_1 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
		self.conv3_2 = torch.nn.Conv2d(128, 128, kernel_size=3, padding=1)
		self.conv3_3 = torch.nn.Conv2d(128, 128, kernel_size=3, padding=1)
		self.conv3_4 = torch.nn.Conv2d(128, 128, kernel_size=3, padding=1)
		self.conv3_5 = torch.nn.Conv2d(128, 128, kernel_size=3, padding=1)
		self.pool3 = torch.nn.MaxPool2d(kernel_size=2)
		self.fc = torch.nn.Linear(128 * 3 * 3, 10)

	def forward(self, x):
		x = x.unsqueeze(1)
		residual = x  # single-channel input; broadcast over the 32 channels when added below
		x = self.conv1_1(x).tanh()
		x = self.conv1_2(x).tanh()
		x = self.conv1_3(x).tanh()
		x = self.conv1_4(x).tanh()
		x = self.conv1_5(x).tanh()
		if self.use_residual: x = x + residual
		x = self.pool1(x)
		x = self.conv2_1(x).tanh()
		residual = x
		x = self.conv2_2(x).tanh()
		x = self.conv2_3(x).tanh()
		x = self.conv2_4(x).tanh()
		x = self.conv2_5(x).tanh()
		if self.use_residual: x = x + residual
		x = self.pool2(x)
		x = self.conv3_1(x).tanh()
		residual = x
		x = self.conv3_2(x).tanh()
		x = self.conv3_3(x).tanh()
		x = self.conv3_4(x).tanh()
		x = self.conv3_5(x).tanh()
		if self.use_residual: x = x + residual
		x = self.pool3(x)
		x = x.view(-1, 128 * 3 * 3)
		x = self.fc(x)
		return x

torch.manual_seed(1234)
model = DigitRecognizer(use_residual=False)
torch.manual_seed(1234)
model_with_residual = DigitRecognizer(use_residual=True)
print(f'There are {sum(p.numel() for p in model.parameters())} parameters in the model')

There are 879242 parameters in the model


Note that residual connections are not a special kind of layer, but an operation performed in the forward pass. In the architecture above:
* We define three convolutional blocks, each of which works on a smaller version of the image than the previous one.
* Each block has 5 convolutional layers followed by a max pooling, so each block processes images half the size of the previous one (28x28, 14x14 and 7x7 pixels).
* In total there are 19 layers (15 convolutions, 3 max poolings and 1 linear layer), which is already a fairly deep network.
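
To see where the `128 * 3 * 3` input size of the final linear layer comes from, the following sketch traces the tensor shape through one convolution and one max pooling per block (the remaining convolutions keep the shape unchanged because of `padding=1`). The last lines also check a detail of the first skip connection: the saved input has a single channel, so the addition relies on broadcasting across the 32 channels. This is a standalone illustration with freshly created layers, not the model trained in this notebook.

```python
import torch

x = torch.zeros(1, 1, 28, 28)  # one dummy 28x28 grayscale image
pool = torch.nn.MaxPool2d(kernel_size=2)
shapes = []
for c_in, c_out in [(1, 32), (32, 64), (64, 128)]:
    conv = torch.nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
    x = pool(conv(x))  # the convolution keeps the size, the pooling halves it
    shapes.append(tuple(x.shape))
    print(shapes[-1])
# (1, 32, 14, 14)
# (1, 64, 7, 7)
# (1, 128, 3, 3)

flat_features = x.flatten(1).shape[1]
print(flat_features)  # 1152, that is, 128 * 3 * 3

# First skip connection: 1 channel added to 32 channels via broadcasting
skip = torch.zeros(1, 1, 28, 28)
block_out = torch.zeros(1, 32, 28, 28)
summed_shape = tuple((block_out + skip).shape)
print(summed_shape)  # (1, 32, 28, 28)
```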

In [4]:
def train(model, lr=0.01, num_iterations=2000, batch_size=32, print_each=200):
	torch.manual_seed(1234)
	model.train()
	optimizer = torch.optim.SGD(model.parameters(), lr=lr)
	start = time()
	losses = []
	for i in range(num_iterations):
		ix = torch.randint(0, x.shape[0], (batch_size,))
		xb, yb = x[ix], y[ix]
		logits = model(xb)
		loss = torch.nn.functional.cross_entropy(logits, yb.argmax(dim=1))
		losses.append(loss.item())  # store a float so the computation graph can be freed
		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		if (i % print_each == 0 and i > 0) or i == 20:
			remaining = (time() - start) * (num_iterations - i) / (i + 1)
			print(f'Iteration {i:5}, loss: {loss.item():.6f}, remaining: {remaining//60:2.0f}m{remaining%60:2.0f}s')
	model.eval()
	with torch.no_grad():  # evaluation does not need gradients
		accuracy = (model(xt).argmax(dim=1) == yt.argmax(dim=1)).float().mean().item()
	print(f'Accuracy on the test set: {100 * accuracy:.2f}%')
	return losses

In [5]:
print('Training the model without residual connections')
losses = train(model)

Training the model without residual connections
Iteration    20, loss: 2.305718, remaining:  7m30s
Iteration   200, loss: 2.287104, remaining:  6m31s
Iteration   400, loss: 2.304785, remaining:  5m46s
Iteration   600, loss: 2.302989, remaining:  5m 8s
Iteration   800, loss: 2.304319, remaining:  4m26s
Iteration  1000, loss: 2.289454, remaining:  3m42s
Iteration  1200, loss: 2.306574, remaining:  2m58s
Iteration  1400, loss: 2.292930, remaining:  2m20s
Iteration  1600, loss: 2.307819, remaining:  1m36s
Iteration  1800, loss: 2.304099, remaining:  0m47s
Accuracy on the test set: 11.35%
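
A loss that hovers around 2.30 is what a classifier produces when it assigns the same probability to each of the 10 digits, since $-\ln(1/10) \approx 2.3026$. The short check below (a standalone sketch with made-up targets) confirms it. Together with an accuracy close to chance level, it suggests that the network without residual connections has not learned anything.

```python
import math
import torch

logits = torch.zeros(4, 10)           # identical scores for all 10 classes
targets = torch.tensor([0, 3, 7, 9])  # arbitrary labels, chosen for illustration
uniform_loss = torch.nn.functional.cross_entropy(logits, targets).item()
print(f'{uniform_loss:.4f} vs ln(10) = {math.log(10):.4f}')
```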


In [6]:
print('Training the model with residual connections')
losses_with_residual = train(model_with_residual)

Training the model with residual connections
Iteration    20, loss: 2.156486, remaining:  7m28s
Iteration   200, loss: 0.602912, remaining:  6m40s
Iteration   400, loss: 0.399495, remaining:  5m37s
Iteration   600, loss: 0.372218, remaining:  4m47s
Iteration   800, loss: 0.497465, remaining:  4m 2s
Iteration  1000, loss: 0.195034, remaining:  3m21s
Iteration  1200, loss: 0.078198, remaining:  2m40s
Iteration  1400, loss: 0.277008, remaining:  1m60s
Iteration  1600, loss: 0.402492, remaining:  1m20s
Iteration  1800, loss: 0.197989, remaining:  0m40s
Accuracy on the test set: 95.55%
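
The per-iteration losses returned by `train` are noisy because each one is computed on a single batch of 32 images, as the printed values above show. To compare the two runs visually, one option is to smooth them first. The helper below is a hypothetical addition: the name `moving_average` and the window size are assumptions, not part of the original notebook.

```python
import torch

def moving_average(values, window=50):
    """Average each run of `window` consecutive values."""
    # float() accepts plain numbers as well as 0-dimensional tensors
    series = torch.tensor([float(v) for v in values])
    # unfold builds overlapping windows along dimension 0, one per position
    return series.unfold(0, window, 1).mean(dim=1)

smoothed = moving_average([1.0, 2.0, 3.0, 4.0, 5.0], window=2)
print(smoothed.tolist())  # [1.5, 2.5, 3.5, 4.5]
```

With the lists from the two runs, `plt.plot(moving_average(losses).tolist())` and `plt.plot(moving_average(losses_with_residual).tolist())` would then draw both curves on the same axes.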


Just by adding residual connections, the test accuracy jumps from 11.35% (roughly chance level) to 95.55%, and we are able to train models with many layers. This opens the door to much deeper neural networks, in which the number of layers we can stack is no longer the main bottleneck.

This and other improvements, which allowed for the creation of extremely deep neural networks with hundreds of layers, changed the paradigm of artificial intelligence. Now the biggest bottleneck is the hardware, which puts big players with the resources to train really large networks in a privileged position. This is a sad reality, as it makes it much harder for small companies to push the frontiers of the field, at least in the current state. But hey, at least we get to have really smart models!