## Libraries

In [1]:
# Pytorch
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# CIFAR-10
from torchvision import datasets, transforms

# Pytorch Lightning
import pytorch_lightning as pl

# WandB
from pytorch_lightning.loggers import WandbLogger

In [3]:
# Paths and hyperparameters for this run.
DATA_DIR = './data'

# Seed the global RNGs (Python, NumPy, torch) so augmentation, shuffling and
# weight initialisation are reproducible across Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

NUM_CLASSES = 10
NUM_WORKERS = 18  # DataLoader worker processes; lower this on machines with fewer CPU cores.
BATCH_SIZE = 64
EPOCHS = 50

## CIFAR-10 DataModule

In [4]:
# Training-time augmentation pipeline, applied in the order listed.
train_transform = transforms.Compose(
    [
        # Geometric augmentation: random mirror, then a random 32x32 crop of the 4-pixel-padded image.
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=32, padding=4),
        # Photometric augmentation.
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        # Convert to a tensor; note that no Normalize step is applied.
        transforms.ToTensor(),
    ]
)

In [5]:
class CIFAR10DataModule(pl.LightningDataModule):
    """LightningDataModule over the torchvision CIFAR-10 splits.

    The torchvision ``train=True`` split is used for training and the
    ``train=False`` split is used for validation.

    Args:
        train_transform: transform applied to every training image.
        data_dir: root directory CIFAR-10 is downloaded to and read from.
        batch_size: batch size used by both dataloaders.
        num_workers: number of DataLoader worker processes.
        val_transform: transform applied to validation images. Defaults to
            ``transforms.ToTensor()`` (no augmentation).
    """

    def __init__(self, train_transform, data_dir='./', batch_size=32, num_workers=8,
                 val_transform=None):
        super().__init__()
        self.train_transform = train_transform
        # Validation data is only converted to tensors unless a transform is supplied.
        self.val_transform = val_transform if val_transform is not None else transforms.ToTensor()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir``.

        Only triggers the download; the Dataset objects are built in ``setup``.
        """
        datasets.CIFAR10(root=self.data_dir, train=True, download=True)
        datasets.CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        ``'fit'`` and ``None`` build both the training and validation sets.
        ``'validate'`` builds only the validation set. Any other stage builds nothing.
        """
        if stage in ('fit', None):
            self.train_dset = datasets.CIFAR10(root=self.data_dir, train=True,
                                               transform=self.train_transform)
        # 'validate' needs val_dset too; otherwise trainer.validate() fails in val_dataloader.
        if stage in ('fit', 'validate', None):
            self.val_dset = datasets.CIFAR10(root=self.data_dir, train=False,
                                             transform=self.val_transform)

    def train_dataloader(self):
        """Return the training DataLoader (reshuffled every epoch)."""
        # shuffle=True: without it every epoch visits the samples in the same fixed order.
        return torch.utils.data.DataLoader(self.train_dset, batch_size=self.batch_size,
                                           shuffle=True, num_workers=self.num_workers,
                                           pin_memory=True)

    def val_dataloader(self):
        """Return the validation DataLoader (fixed order, no shuffling)."""
        return torch.utils.data.DataLoader(self.val_dset, batch_size=self.batch_size,
                                           num_workers=self.num_workers, pin_memory=True)

In [7]:
# Instantiate the data module with the run configuration (all arguments by keyword).
data_module = CIFAR10DataModule(
    train_transform=train_transform,
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

In [8]:
# Download the CIFAR-10 train and test splits into DATA_DIR (reuses a previously verified archive).
data_module.prepare_data()

Using downloaded and verified file: ./data/cifar-10-python.tar.gz
Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [9]:
# stage=None builds both the training and validation Dataset objects.
data_module.setup(stage=None)