# Test dynamic batching in PyTorch

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from torch.autograd import Variable

## loading MNIST

In [None]:
# Download the MNIST training split (train=True) into ./data. download=True
# asks torchvision to fetch the files, and transforms.ToTensor() is the
# transform applied to each image when a sample is read from the dataset.
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transforms.ToTensor())

## models

In [None]:
class NonBatchingNetworkMNIST(nn.Module):
    """MNIST classifier whose depth is decided by a coin flip on every call.

    Layers: 784 -> 100 (always), 100 -> 100 (applied only when a fresh
    ``np.random.rand()`` draw exceeds 0.5), 100 -> 10 logits (always).
    """

    def __init__(self):
        super(NonBatchingNetworkMNIST, self).__init__()
        self.lin1 = nn.Linear(784, 100)
        self.lin2 = nn.Linear(100, 100)
        self.logits = nn.Linear(100, 10)

    def forward(self, x):
        """Return the 10 class logits for ``x``.

        ``x`` is reshaped to rows of 784 values, so a single 1x28x28 image
        becomes one (1, 784) row.
        """
        flat = x.view(-1, 784)
        hidden = F.relu(self.lin1(flat))
        # One random draw per call decides whether the optional layer runs.
        use_second_layer = np.random.rand() > 0.5
        if use_second_layer:
            hidden = F.relu(self.lin2(hidden))
        return self.logits(hidden)

In [None]:
class BatchingNetworkMNIST(nn.Module):
    """MNIST classifier with an optional middle layer, used with ``Fold``.

    The architecture is 784 -> 100 -> (optional 100 -> 100) -> 10. Whether
    the optional layer runs is decided once per ``forward`` call by
    ``np.random.rand() > 0.5``, so the decision is shared by every row of
    the input passed to that call.
    """

    def __init__(self):
        super(BatchingNetworkMNIST, self).__init__()
        self.lin1 = nn.Linear(784, 100)
        self.lin2 = nn.Linear(100, 100)
        self.logits = nn.Linear(100, 10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return one logit row each."""
        rows = x.view(-1, 784)  # leading dimension is whatever remains
        activations = F.relu(self.lin1(rows))
        apply_middle = np.random.rand() > 0.5
        if apply_middle:
            activations = F.relu(self.lin2(activations))
        scores = self.logits(activations)
        return scores

## one epoch with non-batching model

In [None]:
# Step limit used as the break condition of the timing loops in this notebook.
n_steps = 50000
# Model instance for the per-example (non-batched) run.
nnb = NonBatchingNetworkMNIST()

In [None]:
# Loss function and SGD optimizer (lr 0.01, momentum 0.9) over the parameters
# of the non-batching model.
# NOTE(review): neither `loss` nor `opt` is used by the forward-only timing
# loop visible in this notebook — confirm whether a training step is intended.
loss = nn.CrossEntropyLoss()
opt = optim.SGD(nnb.parameters(), lr = 0.01, momentum=0.9)

In [None]:
%%time
# Time a forward pass over the first `n_steps` training examples, one example
# per call, with the non-batching model.
running_loss = 0.0
for i, data in enumerate(trainset):
    # Check the limit before doing any work so that exactly `n_steps` examples
    # are processed; checking after the forward pass ran one extra example.
    if i >= n_steps:
        break
    inputs, labels = data
    # Store the class index in a one-element LongTensor. `torch.Tensor(labels)`
    # with an integer label allocates an uninitialized tensor of that *length*
    # instead of a tensor holding the label value.
    inputs, labels = Variable(inputs), Variable(torch.LongTensor([int(labels)]))
    # Call the module itself (not `.forward`) so nn.Module's call machinery runs.
    outputs = nnb(inputs)

## one epoch with batching model

In [None]:
from torchfold import Fold

In [None]:
net = BatchingNetworkMNIST()

In [None]:
fold = Fold(cuda=False)

In [None]:
%%time
# Register one deferred `forward` call per training example with the fold and
# collect the returned placeholders together with the matching labels. The
# calls are evaluated later by `fold.apply`.
running_loss = 0.0
all_outputs = []
all_labels = []
for i, data in enumerate(trainset):
    # Check the limit before doing any work so that exactly `n_steps` examples
    # are registered; checking after the append added one extra example.
    if i >= n_steps:
        break
    inputs, labels = data
    # Store the class index in a one-element LongTensor. `torch.Tensor(labels)`
    # with an integer label allocates an uninitialized tensor of that *length*
    # instead of a tensor holding the label value.
    inputs, labels = Variable(inputs), Variable(torch.LongTensor([int(labels)]))
    all_outputs.append(fold.add('forward', inputs))
    all_labels.append(labels)

In [None]:
%%time
res = fold.apply(net, [all_outputs])

## toy example

In [None]:
import torch

from torch.autograd import Variable
from torchfold import Fold

In [None]:
class Oper:
    """Toy operation set: each operation adds a stored constant to its input."""

    def __init__(self, x):
        # Keep the constant as a one-element float tensor wrapped in a Variable.
        constant = torch.FloatTensor([x])
        self.x = Variable(constant)

    def make_oper_1(self, y):
        """Return the stored constant plus ``y``."""
        total = self.x + y
        return total

    def make_oper_2(self, z):
        """Return the stored constant plus ``z``."""
        total = self.x + z
        return total

In [None]:
fold = Fold(cuda=False)

In [None]:
oper = Oper(5)

In [None]:
res = fold.add('make_oper_1', Variable(torch.FloatTensor([4])))
res2 = fold.add('make_oper_2', res)

In [None]:
fold

In [None]:
fold.apply(oper, [[res2]])

In [None]:
fold.cached_nodes

# Dynamic batching with PyTorch

Ref to original blog post: https://medium.com/@ilblackdragon/pytorch-dynamic-batching-f4df3dbe09ef

Ref to original implementation: https://github.com/nearai/pytorch-tools/blob/master/pytorch_tools/torchfold.py

This notebook provides a few clarifications and usage examples, from the simplest to the more practical.

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
import torch
import numpy as np

from torch.autograd import Variable
from torchfold import Fold

# 1. Baby steps

## Example 1.1: simple exponentiation

The simplest example, just to show the interface of the module. Sequential calculation:

![title](img/dyn_batching_simple.png)

Let's calculate the value of 2^n. Due to the sequential nature of the task, no actual batching takes place.

In [None]:
class Exp:
    """One multiplication step of an exponentiation with a fixed base."""

    def __init__(self, base):
        self.b = base  # factor applied by every call to `exp`

    def exp(self, x):
        """Return ``x`` multiplied by the stored base."""
        scaled = x * self.b
        return scaled

In [None]:
fold = Fold(cuda=False)

In [None]:
n = 4   # exponent: number of chained `exp` steps added to the fold
exp = Exp(2)  # base 2, so n steps starting from 1 are checked against exp.b**n
all_nodes = []  # collects the node returned by each fold.add call
result = 1  # starting value fed into the first `exp` step

In [None]:
for i in range(n):
    result = fold.add('exp',  result)
    all_nodes.append(result)

In [None]:
out = fold.apply(exp, [all_nodes])

In [None]:
print(out)
assert out[0][-1].data[0] == exp.b**n, 'ERROR!!!'

## Example 1.2: exponentiation with two outputs

![title](img/dyn_batching_complex.png)

In [None]:
class Tree:
    """Node operations B, C and D for the computation graph of Example 1.2.

    With inputs a1=1, a2=2, a3=5 the expected output is d=-1.
    """

    def __init__(self):
        # Stateless: every operation depends only on its arguments.
        pass

    def B(self, x1, x2):
        """Return the pair ``((x1 - x2) * 2, (x1 + x2) * 3)``."""
        difference = x1 - x2
        total = x1 + x2
        return difference * 2, total * 3

    def C(self, x1, x2):
        """Return the product of the two inputs."""
        product = x1 * x2
        return product

    def D(self, x1, x2):
        """Return the sum of the two inputs."""
        total = x1 + x2
        return total

In [None]:
# Inputs of the example graph.
a1 = 1
a2 = 2
a3 = 5
fold = Fold(cuda=False)
# B returns two values, so the node is split into one node per output.
bc, bd = fold.add('B', a1, a2).split(2)
# c = C(bc, a3) and d = D(bd, c); these calls are only registered here and are
# evaluated when fold.apply is called.
c = fold.add('C', bc, a3)
d = fold.add('D', bd, c)

In [None]:
# Evaluate the registered graph with the Tree operations and request the value
# of every node. Each requested output is passed as its own list of nodes,
# matching the other fold.apply calls in this notebook (`[[res2]]`,
# `[all_nodes]`); the referenced torchfold implementation indexes into each
# entry, so bare nodes are not accepted.
# NOTE(review): the local `torchfold` module is not visible here — confirm it
# follows the referenced implementation.
tree = Tree()
fold.apply(tree, [[bc], [bd], [c], [d]])

In [None]:
n = 4
exp2 = Exp2(2, 3)
all_nodes = []
result = (1, 1)

In [None]:
for i in range(n):
    result = fold.add('exp',  *result).split(2)
    all_nodes.append(result)

In [None]:
out = fold.apply(exp2, [all_nodes])

In [None]:
fold.steps

In [None]:
def f(x, *args):
    """Print the first argument, then the tuple of remaining positional arguments."""
    for item in (x, args):
        print(item)

In [None]:
f(3,4,5,6)

In [None]:
import torch

In [None]:
torch.cat([torch.Tensor([1]), torch.Tensor([1])], 0)