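# Test dynamic batching in PyTorch

This notebook compares two ways of running a small MNIST MLP whose depth is chosen at random on every call:

- calling the model one example at a time;
- recording the same per-example calls with torchfold's `Fold` and executing them together in one batched call.

Only forward passes are timed; no training takes place.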

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from torch.autograd import Variable

## Load MNIST

In [3]:
# Fetch MNIST into ./data (torchvision skips the download if the files are already there)
# and load the training split; ToTensor() turns each image into a float tensor in [0, 1].
trainset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

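## Models

Both classes define the same MLP: 784 → 100 (ReLU) → optional 100 → 100 (ReLU) → 10 logits. The middle layer is applied only when `np.random.rand() > 0.5`.

- `NonBatchingNetworkMNIST` is called on one example at a time.
- `BatchingNetworkMNIST` is run through torchfold; its `forward` is invoked by `fold.apply`.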

In [4]:
class NonBatchingNetworkMNIST(nn.Module):
    """Small MNIST MLP whose depth is decided at random on every forward call.

    Layers: 784 -> 100 (ReLU), then an optional 100 -> 100 (ReLU) taken when
    ``np.random.rand() > 0.5``, then a 100 -> 10 logits layer. In this notebook the
    model is fed a single example per call, so every example gets its own coin flip.
    """

    def __init__(self):
        super(NonBatchingNetworkMNIST, self).__init__()
        n_pixels, n_hidden, n_classes = 784, 100, 10
        self.lin1 = nn.Linear(n_pixels, n_hidden)
        self.lin2 = nn.Linear(n_hidden, n_hidden)
        self.logits = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        """Flatten ``x`` to (batch, 784) and return class logits of shape (batch, 10)."""
        flat = x.view(-1, self.lin1.in_features)
        hidden = F.relu(self.lin1(flat))
        use_extra_layer = np.random.rand() > 0.5  # one draw per call
        if use_extra_layer:
            hidden = F.relu(self.lin2(hidden))
        return self.logits(hidden)

In [51]:
class BatchingNetworkMNIST(nn.Module):
    """MNIST MLP with a randomly skipped middle layer, meant to be run through torchfold.

    Same architecture as NonBatchingNetworkMNIST: 784 -> 100 (ReLU), an optional
    100 -> 100 (ReLU) taken when ``np.random.rand() > 0.5``, then 100 -> 10 logits.

    Caveat: the coin is flipped once per ``forward`` call. When the call receives a
    batch of examples, every example in that batch takes the same branch. The
    per-example model instead draws independently for each example.
    """

    def __init__(self):
        super(BatchingNetworkMNIST, self).__init__()
        self.lin1 = nn.Linear(784, 100)
        self.lin2 = nn.Linear(100, 100)
        self.logits = nn.Linear(100, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10).

        Args:
            x: input whose elements form 784-pixel images; it is flattened to
               (batch, 784). Presumably a stack of (1, 28, 28) MNIST images built by
               torchfold from the per-example inputs — TODO confirm.
        """
        # Flatten to (batch, 784); in this model the batch may hold many folded examples.
        x = x.view(-1, 784)
        y_pred = F.relu(self.lin1(x))
        # Single draw for the whole batch (see class docstring).
        if np.random.rand() > 0.5:
            y_pred = F.relu(self.lin2(y_pred))
        y_pred = self.logits(y_pred)
        return y_pred

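## Forward passes with the non-batching model

This runs on about `n_steps` training examples — not a full epoch, and no training step is taken.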

In [61]:
# Seed both RNGs so the run is reproducible: numpy decides whether lin2 is applied
# in both models, torch decides weight initialisation.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

n_steps = 50000  # number of training examples the timing loops below stop at
nnb = NonBatchingNetworkMNIST()  # per-example model; only its forward pass is timed

In [62]:
# SGD optimizer and cross-entropy criterion for the per-example model.
# NOTE(review): the timing cell below uses neither (no backward pass or opt.step()).
opt = optim.SGD(nnb.parameters(), momentum=0.9, lr=0.01)
loss = nn.CrossEntropyLoss()

In [63]:
%%time
running_loss = 0.0
for i, data in enumerate(trainset):
    inputs, labels = data
    inputs, labels = Variable(inputs), Variable(torch.Tensor(labels))
    outputs = nnb.forward(inputs)
    if i == n_steps:
        break

CPU times: user 1min 6s, sys: 1.06 s, total: 1min 7s
Wall time: 17.6 s


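## Forward passes with the batching model (torchfold)

This runs on the same number of training examples as above — not a full epoch, and no training step is taken.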

In [64]:
from torchfold import Fold

In [65]:
# Batching model; it is passed to fold.apply below, which runs its 'forward' on the recorded inputs.
net = BatchingNetworkMNIST()

In [66]:
# torchfold graph builder on CPU (cuda=False). Ops are registered with fold.add and
# evaluated by fold.apply (presumably in batches of identical ops — see torchfold docs).
fold = Fold(cuda=False)

In [67]:
%%time
running_loss = 0.0
all_outputs = []
all_labels = []
for i, data in enumerate(trainset):
    inputs, labels = data
    inputs, labels = Variable(inputs), Variable(torch.Tensor(labels))
    all_outputs.append(fold.add('forward', inputs))
    all_labels.append(labels)
    if i == n_steps:
        break

CPU times: user 5.39 s, sys: 80.3 ms, total: 5.47 s
Wall time: 5.43 s


In [68]:
%%time
# Execute the nodes recorded via fold.add using `net` (whose 'forward' was named in fold.add).
# NOTE(review): the second argument is a list of node lists; the layout of `res`
# (presumably one result per list) is defined by torchfold — confirm before indexing it.
res = fold.apply(net, [all_outputs])

CPU times: user 2.33 s, sys: 72 ms, total: 2.4 s
Wall time: 1.42 s
