# Dynamic batching with PyTorch (all credits to @ilblackdragon)

Ref to original blog post: https://medium.com/@ilblackdragon/pytorch-dynamic-batching-f4df3dbe09ef

Ref to original implementation: https://github.com/nearai/pytorch-tools/blob/master/pytorch_tools/torchfold.py

This notebook provides a few usage examples, starting from the very basics. It doesn't cover technical details; for those, please refer to the original article (ref above). I also highly recommend reading the article on the DyNet implementation of the same idea (it is linked from the original blog post).

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import torch
import tqdm
import numpy as np
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from torch.autograd import Variable
from torchfold import Fold

# 1. Baby steps

## Example 1.1: simple exponentiation

The simplest example, just to show the interface of the module. Sequential calculation:

![title](img/dyn_batching_simple.png)

Let's calculate the value of 2^n. Because each step depends on the result of the previous one, the task is inherently sequential, so no actual batching takes place.

In [3]:
class Exp:
    """Toy "network" with a single op that multiplies its input by a fixed base.

    Applying the op n times in a row, starting from 1, yields base**n.
    """

    def __init__(self, base):
        # Kept under the name ``b``: the notebook reads ``exp.b`` when checking the result.
        self.b = base

    def exp(self, x):
        """Perform one exponentiation step: return ``x`` multiplied by the base."""
        scaled = x * self.b
        return scaled

In [4]:
# Builder that records the operations of the example graph.
# NOTE(review): cuda=False presumably keeps the computation on the CPU -- confirm against torchfold.
fold = Fold(cuda=False)

In [5]:
n = 4   # exponent: the number of chained 'exp' steps, so the last node should equal base**n
exp = Exp(2)   # creating our "network" (base 2)
all_nodes = []  # all output nodes that we want to calculate
result = 1   # initial value (A); rebound to the newest graph node on every step

In [6]:
# building graph: every step feeds the previous result into the 'exp' op,
# producing a chain of n nodes; all of them are kept as requested outputs
for i in range(n):
    result = fold.add('exp',  result)
    all_nodes.append(result)

In [7]:
# executing graph: `exp` supplies the implementation of the recorded 'exp' op.
# The argument is a list of node lists; `out` holds one entry per node list
# (here a single tensor with the value of every node in all_nodes).
out = fold.apply(exp, [all_nodes])

In [8]:
print(out)
# the last element of the only output is the n-th step of the chain, i.e. base**n
assert out[0][-1].data[0] == exp.b**n, 'ERROR!!!'

[Variable containing:
  2
  4
  8
 16
[torch.LongTensor of size 4]
]


## Example 1.2: exponentiation with two outputs

![title](img/dyn_batching_complex.png)

In [9]:
class Tree:
    """Stateless "network" that implements the three ops of the example graph.

    With inputs a1=1, a2=2, a3=5 the graph built in the notebook evaluates to d=-1.
    """

    def __init__(self):
        # Nothing to store: every op is a pure function of its arguments.
        pass

    def B(self, x1, x2):
        """Two-output op: (doubled difference, tripled sum) of the inputs."""
        doubled_difference = (x1 - x2) * 2
        tripled_sum = (x1 + x2) * 3
        return doubled_difference, tripled_sum

    def C(self, x1, x2):
        """Product of the two inputs."""
        product = x1 * x2
        return product

    def D(self, x1, x2):
        """Sum of the two inputs."""
        total = x1 + x2
        return total

In [10]:
# graph inputs
a1 = 1
a2 = 2
a3 = 5
fold = Fold(cuda=False)
# 'B' produces two values; split(2) exposes them as two separate nodes
bc, bd = fold.add('B', a1, a2).split(2)
# first output of B, combined with a3
c = fold.add('C', bc, a3)
# second output of B, combined with the result of C -- the node we want to evaluate
d = fold.add('D', bd, c)

In [11]:
tree = Tree()  # provides the implementations of ops 'B', 'C' and 'D'
# only node d is requested; the nodes it depends on are evaluated along the way
result = fold.apply(tree, [[d]])
print('result: {}'.format(result[0].data[0]))

result: -1


# 2. MNIST example

## Example 2.1: Simple MNIST network -- batching only forward pass

Artificial task: we first set the batch size to 1 and run the forward pass one example at a time, then use TorchFold to do "dynamic" batching, and compare the times.

Results on CPU: 
* 2.5s [non-batching, ~5,000 examples]
* 1.0s [batching with TorchFold, ~5,000 examples]

In [12]:
# Download the MNIST training split into ./data (if needed) and load it;
# ToTensor() makes every image come back as a tensor instead of a PIL image
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

In [13]:
class BatchingNetworkForwardMNIST(nn.Module):
    """Fully connected MNIST classifier: 784 -> 200 -> 100 -> 10."""

    def __init__(self):
        super(BatchingNetworkForwardMNIST, self).__init__()
        # Two hidden layers followed by the 10-way output layer.
        self.lin1 = nn.Linear(784, 200)
        self.lin2 = nn.Linear(200, 100)
        self.logits = nn.Linear(100, 10)

    def forward(self, x):
        """Return unnormalised class scores, one row of 10 per input image."""
        # Flatten the images: (batch_size, 28, 28) -> (batch_size, 784).
        hidden = x.view(-1, 784)
        # Both hidden layers are followed by a ReLU; the output layer is not.
        for layer in (self.lin1, self.lin2):
            hidden = F.relu(layer(hidden))
        return self.logits(hidden)

In [14]:
simple_mnist = BatchingNetworkForwardMNIST()  # untrained network; only forward passes are timed
n_steps = 5000  # loop index at which the timing loops stop

### non-batching example

In [15]:
%%time
# Baseline: run the forward pass one example at a time (batch size 1).
# No accuracy is tracked in this cell, so the unused moving-average variables
# (`ma`, `running_acc`) have been dropped -- only the forward pass is timed.
for i, data in enumerate(trainset):
    # The label Variable is not used below; it is still built so that the
    # per-example preparation matches the batched loop this cell is compared with.
    inputs, labels = Variable(data[0]), Variable(torch.LongTensor([data[1]]))
    outputs = simple_mnist(inputs)
    # The check comes after the forward pass, so n_steps + 1 examples are processed.
    if i == n_steps:
        break

CPU times: user 15.2 s, sys: 276 ms, total: 15.5 s
Wall time: 2.59 s


### batching example

In [16]:
%%time
# Batched forward pass: examples are queued in a Fold graph and pushed through
# the network bs at a time.
bs = 32  # number of examples evaluated by one fold.apply() call
fold = Fold()
batch_outputs, batch_labels = [], []
for i, data in enumerate(trainset, 0):
    # Labels are integer class ids, so use LongTensor as the other cells do.
    inputs, labels = Variable(data[0]), Variable(torch.LongTensor([data[1]]))
    if len(batch_outputs) == bs:
        # The batch is full: evaluate it, then start a fresh graph for the next one.
        results = fold.apply(simple_mnist, [batch_outputs, batch_labels])
        fold = Fold()
        batch_outputs, batch_labels = [], []
    batch_outputs.append(fold.add('forward', inputs))
    batch_labels.append(labels)
    # The check comes after the example is queued, so n_steps + 1 examples are processed.
    if i == n_steps:
        break
# Evaluate the trailing partial batch as well; without this the last few examples
# would never reach the network and the loop would do less work than the
# one-at-a-time baseline it is timed against.
if batch_outputs:
    results = fold.apply(simple_mnist, [batch_outputs, batch_labels])

CPU times: user 5.68 s, sys: 136 ms, total: 5.81 s
Wall time: 970 ms


## Example 2.2: Simple MNIST network -- batching forward and backward passes

Artificial task: we first set the batch size to 1 and run a full training pass one example at a time, then use TorchFold to do "dynamic" batching, and compare the times.

Results on CPU: 
* 37.6s [non-batching, ~5,000 examples]
* 2.5s [batching with TorchFold, ~5,000 examples]

In [17]:
# Download the MNIST training split into ./data (if needed) and load it.
# Same call as in Example 2.1; ToTensor() makes every image come back as a tensor.
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

In [18]:
class BatchingNetworkMNIST(nn.Module):
    """Fully connected MNIST classifier (784 -> 200 -> 100 -> 10) used for the training runs."""

    def __init__(self):
        super(BatchingNetworkMNIST, self).__init__()
        # Hidden layers first, then the 10-way output layer.
        self.lin1 = nn.Linear(784, 200)
        self.lin2 = nn.Linear(200, 100)
        self.logits = nn.Linear(100, 10)

    def forward(self, x):
        """Return unnormalised class scores, one row of 10 per input image."""
        # Flatten the images: (batch_size, 28, 28) -> (batch_size, 784).
        pixels = x.view(-1, 784)
        first_hidden = F.relu(self.lin1(pixels))
        second_hidden = F.relu(self.lin2(first_hidden))
        return self.logits(second_hidden)

### non-batching example

In [19]:
training_mnist = BatchingNetworkMNIST()  # freshly initialised network
criterion = nn.CrossEntropyLoss()  # applied to the raw scores returned by forward()
opt = optim.Adam(training_mnist.parameters())  # Adam with its default settings
n_steps = 5000  # loop index at which the training loop stops

In [20]:
%%time
# Baseline training loop: one optimizer step per example (batch size 1).
ma = 0.001  # weight given to the newest example in the moving average of accuracy
running_acc = 0.0
for i, data in enumerate(trainset):
    inputs, labels = Variable(data[0]), Variable(torch.LongTensor([data[1]]))
    opt.zero_grad()
    outputs = training_mnist(inputs)
    loss = criterion(outputs, labels)
    # fraction of predictions (arg-max over the class scores) that match the labels
    # NOTE(review): assumes `/` is true division (Python 3) -- confirm
    correct = (outputs.max(1)[1].data == labels.data).sum() / labels.data.size()[0]
    loss.backward()
    opt.step()
    # exponential moving average of the accuracy
    running_acc = (1 - ma) * running_acc + ma * correct
    if i % 1000 == 999:  # report after every 1000 examples
        print('running acc: {:.2f}%'.format(running_acc * 100))    
    # the check comes after the training step, so n_steps + 1 examples are used
    if i == n_steps:
        break

running acc: 47.47%
running acc: 72.40%
running acc: 82.94%
running acc: 87.99%
running acc: 89.49%
CPU times: user 6min 49s, sys: 9.46 s, total: 6min 58s
Wall time: 38.9 s


### batching example (the optimizer steps once per batch_size examples instead of once per example, so the comparison is not completely fair; still...)

In [21]:
# Re-create the model and the optimizer so the batched run starts from freshly
# initialised weights rather than from whatever these names held before.
training_mnist = BatchingNetworkMNIST()
criterion = nn.CrossEntropyLoss()
opt = optim.Adam(training_mnist.parameters())
n_steps = 5000  # loop index at which the training loop stops

In [22]:
%%time
# Batched training loop: examples are queued in a Fold graph and one optimizer
# step is taken per bs examples.
bs = 32  # number of examples per training step
ma = 0.001 * bs  # moving-average weight 0.001 per example, scaled by bs because the average is updated once per batch
running_acc = 0.0
fold = Fold()
batch_outputs, batch_labels = [], []
for i, data in enumerate(trainset, 0):
    inputs, labels = Variable(data[0]), Variable(torch.LongTensor([data[1]]))
    if len(batch_outputs) < bs:
        # still filling the current batch
        batch_outputs.append(fold.add('forward', inputs))        
        batch_labels.append(labels)
    else:
        # the batch is full: train on it, then start the next batch with the current example
        opt.zero_grad()
        # results[0] holds the batched network outputs, results[1] the batched labels
        results = fold.apply(training_mnist, [batch_outputs, batch_labels])
        loss = criterion(results[0], results[1])
        # fraction of predictions (arg-max over the class scores) that match the labels
        # NOTE(review): assumes `/` is true division (Python 3) -- confirm
        correct = (results[0].max(1)[1].data == results[1].data).sum() / results[1].data.size()[0]
        loss.backward()
        opt.step()
        running_acc = (1 - ma) * running_acc + ma * correct
        if i % 1000 < bs:  # report roughly every 1000 examples (first training step after each multiple of 1000)
            print('running acc: {:.2f}%'.format(running_acc * 100))    
        fold = Fold()
        batch_outputs, batch_labels = [fold.add('forward', inputs)], [labels]
    if i == n_steps:
        break
# NOTE(review): examples queued after the last full batch are never trained on,
# because fold.apply() only runs in the else branch above.

running acc: 37.04%
running acc: 64.02%
running acc: 77.65%
running acc: 85.35%
CPU times: user 26.9 s, sys: 748 ms, total: 27.7 s
Wall time: 2.64 s
