# Maximum batch size that fits in GPU memory

It is common practice to train with the largest batch size that fits in GPU memory. This example shows how to determine that batch size automatically.

In [None]:
!pip install labml

Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
from torchvision import datasets, transforms

from labml import lab, tracker, experiment, monit, logger
from labml.logger import Text

VGG Net for CIFAR-10

In [3]:
class Net(nn.Module):
    """VGG-16-style classifier with batch normalization for 3-channel images.

    Thirteen 3x3 convolutions (each followed by BatchNorm and ReLU) are grouped
    into five stages, and every stage is closed by 2x2 max-pooling. A linear
    layer then maps the flattened features to 10 classes. That layer expects 512
    features, which holds for 32x32 inputs (five halvings leave a 1x1 map),
    presumably CIFAR-10 images.
    """

    def __init__(self):
        super().__init__()
        # Output widths of the 3x3 convolutions, grouped by stage.
        stage_widths = ((64, 64), (128, 128), (256, 256, 256),
                        (512, 512, 512), (512, 512, 512))

        modules = []
        prev_channels = 3
        for widths in stage_widths:
            for width in widths:
                modules.extend((
                    nn.Conv2d(prev_channels, width, kernel_size=3, padding=1),
                    nn.BatchNorm2d(width),
                    nn.ReLU(inplace=True),
                ))
                prev_channels = width
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        # A 1x1 average pool with stride 1 is effectively a no-op; it is kept so
        # the layer list (and module indices) stay unchanged.
        modules.append(nn.AvgPool2d(kernel_size=1, stride=1))

        self.layers = nn.Sequential(*modules)
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        """Return class logits of shape (batch, 10)."""
        features = self.layers(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

### Create data loaders with a given batch size

In [4]:
class DataLoaderFactory:
    """Builds CIFAR-10 data loaders with a requested batch size.

    Both splits are created once, at construction time (``download=True``),
    using ``ToTensor`` followed by per-channel normalization with mean 0.5 and
    std 0.5. ``self.dataset[0]`` is the test split and ``self.dataset[1]`` is
    the training split, so a boolean ``train`` flag indexes it directly.
    """

    def __init__(self):
        pipeline = [transforms.ToTensor(),
                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
        data_transform = transforms.Compose(pipeline)

        # Index 0 -> train=False (test split), index 1 -> train=True (training split).
        self.dataset = [
            datasets.CIFAR10(str(lab.get_data_path()),
                             train=is_train,
                             download=True,
                             transform=data_transform)
            for is_train in (False, True)
        ]

    def __call__(self, train, batch_size):
        """Return a shuffling DataLoader over the training (``train=True``) or test split."""
        split = self.dataset[train]
        return torch.utils.data.DataLoader(split, batch_size=batch_size, shuffle=True)

### Determine if a given batch size can fit the GPU memory.

It runs the model with the given batch size and does one optimization step. If the GPU runs out of memory, PyTorch raises a `RuntimeError` whose message starts with `CUDA out of memory.`. We catch that error and treat it as "batch size too large"; any other error is re-raised.
Note that we call `torch.cuda.empty_cache()` at the beginning to release cached but unused GPU memory left over from previous attempts before we try to allocate new memory.

In [5]:
def _is_cuda_oom(error):
    """Return True if ``error`` is PyTorch's CUDA out-of-memory error."""
    # Newer PyTorch raises a dedicated RuntimeError subclass; older versions
    # signal OOM only through the message text, so check both.
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(error, oom_type):
        return True
    return (len(error.args) == 1
            and isinstance(error.args[0], str)
            and error.args[0].startswith('CUDA out of memory.'))


def check_batch_size(model, optimizer, batch_size, dl_factory=None, device=None):
    """Return whether one full training step with ``batch_size`` fits in memory.

    Loads a single training batch and runs forward, cross-entropy loss, backward
    and ``optimizer.step()``. A CUDA out-of-memory error is reported as ``False``;
    any other exception propagates. A successful check really updates ``model``
    and ``optimizer`` state, so probe with a throwaway model.

    Args:
        model: Module to probe, assumed to already be on the target device.
        optimizer: Optimizer over ``model``'s parameters.
        batch_size: Candidate batch size.
        dl_factory: Callable ``(train, batch_size) -> DataLoader``. Defaults to
            the notebook-level ``dl_factory`` (the original behaviour).
        device: Device the batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        ``True`` if the step succeeded, ``False`` if it ran out of CUDA memory.

    Raises:
        ValueError: If no factory is passed and no global ``dl_factory`` exists.
    """
    if dl_factory is None:
        # Backward-compatible fallback: earlier versions always read this global.
        dl_factory = globals().get("dl_factory")
        if dl_factory is None:
            raise ValueError("check_batch_size: no data loader factory; pass dl_factory=...")
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    data_loader = dl_factory(True, batch_size)
    # Release cached-but-unused blocks left over from earlier (possibly failed) probes.
    torch.cuda.empty_cache()

    try:
        data, target = next(iter(data_loader))
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
    except RuntimeError as e:
        if not _is_cuda_oom(e):
            raise
        logger.log(f"batch_size: {batch_size}", Text.danger)
        return False

    logger.log(f"batch_size: {batch_size}", Text.success)
    return True

### Find the largest batch size

Run a binary search over batch sizes from 1 to `max_bs` (default `2 ** 14`) to find the largest batch size for which a full training step succeeds.

In [6]:
def find_batch_size(dl_factory, device, max_bs = 2 ** 14):
    """Binary-search the largest batch size in ``[1, max_bs]`` whose training step fits.

    A fresh ``Net`` with an Adam optimizer is created on ``device`` and probed
    with ``check_batch_size`` for each candidate. The search assumes that if a
    batch size fits, every smaller one fits too.

    NOTE(review): ``dl_factory`` is not forwarded; ``check_batch_size`` is called
    without it, so the probe batches come from the notebook-level ``dl_factory``
    global. Keep the two the same object.

    Args:
        dl_factory: Data loader factory (see the note above).
        device: Device the probe model is placed on.
        max_bs: Inclusive upper bound of the search.

    Returns:
        The largest batch size that fitted.

    Raises:
        RuntimeError: If not even a batch size of 1 fits (or ``max_bs < 1``);
            returning 0 here would only fail later inside ``DataLoader``.
    """
    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    model.train()

    # Invariant: ``lo`` fits (0 trivially) and every size above ``hi`` fails.
    lo, hi = 0, max_bs
    while lo < hi:
        # Round the midpoint up so ``lo = mid`` always makes progress.
        mid = (lo + hi + 1) // 2
        if check_batch_size(model, optimizer, mid):
            lo = mid
        else:
            hi = mid - 1

    if lo == 0:
        raise RuntimeError(f"No batch size in [1, {max_bs}] fits in memory on {device}")
    return lo

Train the model for an epoch

In [7]:
def train(model, optimizer, train_loader, device):
    """Run one training epoch over ``train_loader``.

    For every batch: move it to ``device``, compute the cross-entropy loss of the
    model's output, backpropagate and take an optimizer step. The global step is
    advanced by the number of samples in the batch, and the batch loss is logged
    under ``'loss.train'``.
    """
    model.train()
    for _, (inputs, labels) in monit.enum("Train", train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), labels)
        loss.backward()
        optimizer.step()

        tracker.add_global_step(inputs.size(0))
        tracker.save({'loss.train': loss})

Get model validation loss and accuracy

In [8]:
def validate(model, valid_loader, device):
    """Evaluate ``model`` on ``valid_loader`` without gradients and log the results.

    Logs ``'loss.valid'`` (the summed cross-entropy divided by the dataset size)
    and ``'accuracy.valid'`` (the percentage of samples whose argmax prediction
    matches the label).
    """
    model.eval()
    total_loss = 0
    num_correct = 0
    with torch.no_grad():
        for inputs, labels in monit.iterate("valid", valid_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            total_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            num_correct += (logits.argmax(dim=1) == labels).sum().item()

    num_samples = len(valid_loader.dataset)
    tracker.save({'loss.valid': total_loss / num_samples,
                  'accuracy.valid': 100. * num_correct / num_samples})

Configurations

In [9]:
# Run configuration; logged to the experiment later via experiment.configs().
configs = dict(
    epochs=50,
    learning_rate=2.5e-4,
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

device = torch.device(configs['device'])
# Builds (and downloads if needed) both CIFAR-10 splits once; loaders are made on demand.
dl_factory = DataLoaderFactory()

Files already downloaded and verified
Files already downloaded and verified


Find the largest batch size that fits in GPU memory

In [10]:
# Probe with a throwaway model; the search upper bound is find_batch_size's default max_bs.
batch_size = find_batch_size(dl_factory=dl_factory, device=device)
batch_size

5424

Create data loaders

In [11]:
configs['batch_size'] = batch_size
# Training split first (train=True), then the test split used for validation.
train_loader, valid_loader = (dl_factory(is_train, batch_size) for is_train in (True, False))

Create the model and optimizer

In [12]:
# Fresh model for the real run; the probe model built inside find_batch_size is not reused.
model = Net().to(device)
optimizer = optim.Adam(params=model.parameters(), lr=configs['learning_rate'])

Run the training loop

In [13]:
experiment.create(name='cifar10')
experiment.configs(configs)

with experiment.start():
    for _ in monit.loop(range(1, 1 + configs['epochs'])):
        # Return cached-but-unused GPU memory before each epoch.
        torch.cuda.empty_cache()
        train(model, optimizer, train_loader, device)
        validate(model, valid_loader, device)
        # Emit an (empty) log line after each epoch.
        logger.log()

KeyboardInterrupt: ignored