# Data Augmentation

Many forms of augmentation are available, particularly for image tasks.
Rotating, translating, and scaling images are the most common.
Additionally, applying random crops can further augment the dataset.

The original dataset may only include samples of a class that have similar lighting.
Color jitter is an effective way of including a broader range of hue or brightness and usually leads to a model that is robust to such changes.

This notebook will demonstrate `torchvision`'s API for data augmentation.

In [2]:
import time
import torch
import torch.nn as nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torchvision import transforms
from tqdm.notebook import trange, tqdm
from torch.utils.tensorboard import SummaryWriter

%matplotlib widget

In [3]:
# Root directory passed to torchvision.datasets.CIFAR10 in the data loader cell.
# This is a machine-specific absolute path; edit it to match your environment.
root_path = "/home/alex/Data/CIFAR10"
# Alternative location used on another machine:
# root_path = "/Users/ajd/Data/CIFAR10"

def show(imgs):
    """Display one image, or a list of images, side by side in a single row.

    Args:
        imgs: A single image tensor or a list of image tensors. Anything that
            is not a ``list`` is treated as one image.

    Each image is detached, converted to a PIL image, and drawn on its own
    axis with tick marks and tick labels removed.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array 2-D even when only one image is shown.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for column, tensor in enumerate(images):
        pil_image = transforms.functional.to_pil_image(tensor.detach())
        axis = axes[0, column]
        axis.imshow(np.asarray(pil_image))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
        
        
class AverageMeter:
    """Track the latest value and the running weighted average of a metric.

    Attributes:
        val: The most recent value passed to ``update``.
        sum: Weighted sum of all values seen since the last reset.
        count: Total weight (e.g. number of samples) seen since the last reset.
        avg: ``sum / count``, refreshed on every ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` observations."""
        weighted = val * n
        self.val = val
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Per-class scores, one row per sample (``topk`` is taken along
            dimension 1).
        target: Ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list of 0-dim tensors, one per entry in ``topk``, each holding the
        percentage of samples whose target is among the k highest scores.
    """
    # Metric computation only; no autograd graph is needed.
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # Indices of the maxk best classes per sample, transposed so that
        # row r holds every sample's rank-r prediction.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape instead of view: `correct` keeps the transposed
            # (non-contiguous) layout of `pred`, so view(-1) raises a
            # RuntimeError whenever more than one row is selected (k > 1).
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

# Model Functions

In [4]:
def train_loop(dataloader, model, loss_fn, optimizer, logger=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        dataloader: Iterable of ``(input, target)`` batches; must support ``len``.
        model: Network to optimise; switched to training mode.
        loss_fn: Callable mapping ``(output, target)`` to a scalar loss.
        optimizer: Optimizer stepped once per batch.
        logger: Optional writer exposing ``add_scalar``; receives the batch
            loss under the tag "training loss".

    Note:
        Reads the module-level globals ``epoch`` and ``print_frequency`` for
        the progress text and the logging step index, so both must be defined
        before this function is called. Batches are moved to the GPU with
        ``.cuda()``.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.train()

    num_batches = len(dataloader)
    progress = tqdm(enumerate(dataloader), total=num_batches)
    for step, (images, labels) in progress:
        images, labels = images.cuda(), labels.cuda()

        logits = model(images)
        batch_loss = loss_fn(logits, labels)

        # Standard optimisation step: clear old gradients, backprop, update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        logits = logits.float()
        batch_loss = batch_loss.float()

        n_samples = images.shape[0]
        top1_acc = accuracy(logits.data, labels)[0]
        loss_meter.update(batch_loss.item(), n_samples)
        acc_meter.update(top1_acc.item(), n_samples)

        # Only refresh the progress text / log a scalar every print_frequency batches.
        if step % print_frequency != 0:
            continue

        progress.set_description(
            "Epoch [%d]\t Loss %.2f\t Prec@1 %.3f (%.3f)"
            % (epoch, loss_meter.avg, acc_meter.val, acc_meter.avg)
        )
        if logger:
            global_step = epoch * num_batches + step
            logger.add_scalar("training loss", batch_loss.item(), global_step)
           
        
def val_loop(dataloader, model, loss_fn, logger=None):
    """Evaluate ``model`` on a validation set without changing its weights.

    Args:
        dataloader: Iterable of ``(input, target)`` batches; must support ``len``.
        model: Network to evaluate; switched to eval mode.
        loss_fn: Callable mapping ``(output, target)`` to a scalar loss.
        logger: Optional writer exposing ``add_scalar``; receives the batch
            loss under "validation loss" and the epoch's average top-1
            accuracy under "validation accuracy".

    Note:
        Reads the module-level globals ``epoch`` and ``print_frequency`` for
        the progress text and the logging step indices. Batches are moved to
        the GPU with ``.cuda()``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    # Validation must never update the model: there is no backward pass and no
    # optimizer step here. no_grad() additionally avoids building the autograd
    # graph, which saves memory and time.
    with torch.no_grad():
        for i, (inputs, target) in pbar:

            inputs = inputs.cuda()
            target = target.cuda()

            output = model(inputs)
            loss = loss_fn(output, target)

            output = output.float()
            loss = loss.float()

            prec = accuracy(output, target)[0]
            losses.update(loss.item(), inputs.shape[0])
            top1.update(prec.item(), inputs.shape[0])

            if i % print_frequency == 0:
                pbar.set_description("Epoch [%d]\t Loss %.2f\t Prec@1 %.3f (%.3f)" % (epoch, losses.avg, top1.val, top1.avg))
                if logger:
                    logger.add_scalar("validation loss",
                                      loss.item(),
                                      epoch * len(dataloader) + i)

    # One accuracy point per epoch, averaged over the whole validation set.
    if logger:
        logger.add_scalar("validation accuracy",
                          top1.avg,
                          epoch)
            

def test_loop(dataloader, model, loss_fn):
    """Evaluate ``model`` on a test set and print average loss and accuracy.

    Args:
        dataloader: Iterable of ``(input, target)`` batches; must support ``len``.
        model: Network to evaluate; switched to eval mode.
        loss_fn: Callable mapping ``(output, target)`` to a scalar loss.

    Prints the sample-weighted average loss and top-1 accuracy (percent).
    Batches are moved to the GPU with ``.cuda()``.
    """
    losses = AverageMeter()
    top1 = AverageMeter()

    model.eval()

    pbar = tqdm(enumerate(dataloader), total=len(dataloader))
    # Pure evaluation: no gradients are needed, so skip the autograd graph.
    with torch.no_grad():
        for i, (inputs, target) in pbar:

            inputs = inputs.cuda()
            target = target.cuda()

            output = model(inputs)
            # Use the loss function that was passed in rather than a
            # module-level global, so the argument is actually honoured.
            loss = loss_fn(output, target)

            output = output.float()
            loss = loss.float()

            prec = accuracy(output, target)[0]
            losses.update(loss.item(), inputs.shape[0])
            top1.update(prec.item(), inputs.shape[0])

    # Print result
    print(f"Average Loss: {losses.avg:>8f}\nAccuracy: {top1.avg}\n")

# Model Definition

In [5]:
# Model Parameters
batch_size = 256  # samples per batch for the train, validation and test DataLoaders
learning_rate = 1e-3  # passed to torch.optim.Adam
epochs = 50  # number of train + validation passes in the training cell
print_frequency = 100  # train/val loops refresh progress text and log every N batches

In [15]:
# Fully connected classifier: 1024 inputs -> 512 -> 256 -> 128 -> 10 outputs.
# The 1024 inputs presumably correspond to a flattened 32x32 single-channel
# image (the data loader cell applies Grayscale followed by torch.flatten).
#
# nn.Dropout(0.2) on the input layer zeroes each input value with probability
# 0.2 during training, so on average 80% of the inputs stay active. The
# nn.Dropout() layers after each hidden ReLU use PyTorch's default p=0.5.
# Dropout is only active in training mode (model.train()).
model = nn.Sequential(
    nn.Dropout(0.2),
    nn.Linear(1024, 512),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(128, 10)
)

# Same architecture without dropout, kept for comparison:
# model = nn.Sequential(
#     nn.Linear(1024, 512),
#     nn.ReLU(),
#     nn.Linear(512, 256),
#     nn.ReLU(),
#     nn.Linear(256, 128),
#     nn.ReLU(),
#     nn.Linear(128, 10)
# )

# Move the parameters to the GPU; the train/val/test loops call .cuda() on
# every batch, so a CUDA device is required.
model.cuda()

# Cross-entropy loss on the 10 output scores, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), learning_rate)

# Data Loader

For batch sampling of our dataset, we wrap the dataset object in a `DataLoader` object.

When splitting the training data into a train and validation set, we want to make sure that no augmentations are performed on the validation set.

In [16]:
# Per-channel normalisation applied to the RGB tensor before it is converted
# to grayscale.
# NOTE(review): these values look like the commonly used ImageNet statistics
# rather than CIFAR-10's own - confirm this is intended.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training view of the CIFAR-10 training split. Augmentations belong in this
# pipeline only (e.g. by enabling RandAugment below).
train_dataset = torchvision.datasets.CIFAR10(root_path, train=True, transform=transforms.Compose([
    # transforms.RandAugment(),
    transforms.ToTensor(),
    normalize,
    transforms.Grayscale(),
    torch.flatten
]), download=True)

# Second view of the same training split with a fixed, augmentation-free
# pipeline, so validation samples are never augmented.
val_dataset = torchvision.datasets.CIFAR10(root_path, train=True, transform=transforms.Compose([
    transforms.ToTensor(),
    normalize,
    transforms.Grayscale(),
    torch.flatten
]))

# Split the data into training and validation
dataset_size = len(train_dataset)
indices = list(range(dataset_size))

# 95% / 5% split of the sample indices. No random_state is given, so the split
# differs from run to run.
train_indices, val_indices = train_test_split(indices, test_size=0.05)

# Both subsets index the same underlying split but are disjoint, and each uses
# its own transform pipeline.
train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
val_dataset = torch.utils.data.Subset(val_dataset, val_indices)

# Prepare the dataloaders for training and evaluation
# Only the training loader shuffles.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size, shuffle=True,
    num_workers=8, pin_memory=False)


val_dataloader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size, shuffle=False,
    num_workers=8, pin_memory=False)


# Held-out CIFAR-10 test split (train=False) with the same augmentation-free
# pipeline as validation.
test_dataloader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR10(root_path, train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
        transforms.Grayscale(),
        torch.flatten
    ])),
    batch_size=batch_size, shuffle=False,
    num_workers=8, pin_memory=False)

Files already downloaded and verified


In [17]:
# TensorBoard writer for this run; scalars are written under runs/baseline.
logger = SummaryWriter("runs/baseline")

# `epoch` must remain a module-level loop variable with exactly this name:
# train_loop and val_loop read it as a global for their progress text and
# for the TensorBoard step index.
for epoch in range(epochs):
    train_loop(train_dataloader, model, criterion, optimizer, logger)
    val_loop(val_dataloader, model, criterion, logger)

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/186 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

# Evaluate The Model

In [18]:
# Print the average loss and top-1 accuracy on the held-out test split.
test_loop(test_dataloader, model, criterion)

  0%|          | 0/40 [00:00<?, ?it/s]

Average Loss: 1.509141
Accuracy: 46.73



In [27]:
import os

# torch.save does not create missing parent directories, so make sure the
# target folder exists; otherwise the save fails after the whole training run.
save_path = "saved_models/augmentation.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)

# Saves the entire pickled module (not just a state_dict), so loading it later
# requires the same model code to be importable.
torch.save(model, save_path)