## Image Classification: CIFAR-10

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
def set_seed(seed):
    """Seed NumPy's global RNG and PyTorch's RNG for reproducible runs.

    The CUDA seed is additionally set only when a GPU is visible.
    Python's built-in `random` module is not seeded here.
    """
    for seed_fn in (np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Data

In [4]:
import os
import numpy as np
import pickle

def unpickle(filename):
    """Load one pickled CIFAR-10 batch file and return `(images, labels)`.

    Extract the archive first: `tar -zxvf cifar-10-python.tar.gz`.

    Args:
        filename: path to a batch file (e.g. `data_batch_1` or `test_batch`).

    Returns:
        images: array of shape (N, 32, 32, 3). Each flat 3072-value row is read as
            channel-major (3, 32, 32) and then moved to channels-last.
        labels: array built from the batch's `b'labels'` entry.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # trusted copies of the dataset.
    with open(filename, 'rb') as fh:
        batch = pickle.load(fh, encoding='bytes')

    pixels = np.array(batch[b'data'])
    images = pixels.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
    targets = np.array(batch[b'labels'])
    return images, targets


def load_cifar10(data_dir):
    """Load the full CIFAR-10 train/test split from an extracted batches directory.

    The five files `data_batch_1` ... `data_batch_5` are read with `unpickle` and
    concatenated along the first axis to form the training set. `test_batch`
    forms the test set.

    Args:
        data_dir: directory containing the extracted `cifar-10-batches-py` files.

    Returns:
        ((x_train, y_train), (x_test, y_test))
    """
    train_parts = [
        unpickle(os.path.join(data_dir, f"data_batch_{k}"))
        for k in range(1, 6)
    ]
    x_train = np.concatenate([xs for xs, _ in train_parts], axis=0)
    y_train = np.concatenate([ys for _, ys in train_parts], axis=0)

    x_test, y_test = unpickle(os.path.join(data_dir, "test_batch"))
    return (x_train, y_train), (x_test, y_test)

# Location of the extracted `cifar-10-batches-py` directory. Set the CIFAR10_DIR
# environment variable to point elsewhere instead of editing a hardcoded path.
# data_dir = r"D:\datasets\cifar10_178M\cifar-10-batches-py"    ## windows
data_dir = os.environ.get("CIFAR10_DIR", "/mnt/d/datasets/cifar10_178M/cifar-10-batches-py")   ## wsl default
(x_train, y_train), (x_test, y_test) = load_cifar10(data_dir)

print(f">> Train images: {x_train.shape}, {x_train.dtype}")
print(f">> Train labels: {y_train.shape}, {y_train.dtype}")
print(f">> Test images:  {x_test.shape}, {x_test.dtype}")
print(f">> Test labels:  {y_test.shape}, {y_test.dtype}")

>> Train images: (50000, 32, 32, 3), uint8
>> Train labels: (50000,), int64
>> Test images:  (10000, 32, 32, 3), uint8
>> Test labels:  (10000,), int64


In [6]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

class CIFAR10(Dataset):
    """Map-style dataset over in-memory CIFAR-10 arrays.

    Args:
        images: indexable collection of images (in this notebook, the uint8
            arrays returned by `load_cifar10`).
        labels: indexable collection of integer class ids aligned with `images`.
        transform: optional callable applied to each image when it is accessed.

    Each item is `(image, label)`, where `label` is a 0-d int64 tensor.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        if self.transform:
            sample = self.transform(sample)
        target = torch.tensor(self.labels[idx]).long()
        return sample, target

# Training pipeline: array -> PIL image -> random flips -> float tensor in [0, 1]
# (ToTensor) -> Normalize((x - 0.5) / 0.5), i.e. each channel mapped to [-1, 1].
# NOTE(review): vertical flips produce upside-down scenes, which CIFAR-10 test
# images presumably do not contain; consider dropping RandomVerticalFlip —
# confirm with an ablation.
transform_train = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomHorizontalFlip(0.3),   # p=0.3 (torchvision's default is 0.5)
    transforms.RandomVerticalFlip(0.3),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Evaluation pipeline: no augmentation. ToTensor is applied to the raw array
# directly (no ToPILImage step) and the same normalization is used as in training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

train_dataset = CIFAR10(x_train, y_train, transform=transform_train)
test_dataset = CIFAR10(x_test, y_test, transform=transform_test)

batch_size = 32
# Shuffle only the training set; the test set is iterated in a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Sanity-check one training batch: shapes, dtypes and value ranges.
x, y = next(iter(train_loader))
for tag, batch in (("x", x), ("y", y)):
    print(f">> {tag}: {batch.shape}, {batch.dtype}, min={batch.min()}, max={batch.max()}")

>> x: torch.Size([32, 3, 32, 32]), torch.float32, min=-1.0, max=1.0
>> y: torch.Size([32]), torch.int64, min=0, max=8


### Modeling

In [None]:
## Model
class ConvBlock(nn.Module):
    """Conv3x3 (padding=1) -> BatchNorm -> LeakyReLU -> MaxPool(2).

    Maps (N, in_channels, H, W) to (N, out_channels, H // 2, W // 2).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        return self.conv_block(x)


class Encoder(nn.Module):
    """Three ConvBlocks followed by a linear head producing `latent_dim` outputs.

    Designed for 32x32 RGB input: three 2x poolings reduce it to 128 x 4 x 4
    features, which the linear layer maps to `latent_dim` values (used here as
    class logits).
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.conv_block1 = ConvBlock(3, 32)
        self.conv_block2 = ConvBlock(32, 64)
        self.conv_block3 = ConvBlock(64, 128)
        self.fc = nn.Linear(128 * 4 * 4, latent_dim)

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        # Flatten per sample while keeping the batch axis explicit.
        # `view(-1, 128 * 4 * 4)` would silently fold extra spatial positions into
        # the batch axis for non-32x32 inputs (returning too many rows).
        # flatten(1) makes such a mismatch raise a shape error in `self.fc`.
        x = torch.flatten(x, 1)
        return self.fc(x)

def accuracy(y_pred, y):
    """Fraction of samples whose highest-scoring class (argmax over dim 1) equals `y`.

    Assumes `y_pred` is (N, C) scores/logits and `y` holds (N,) integer labels.
    Returns a 0-d float tensor.
    """
    hits = y_pred.argmax(dim=1) == y
    return hits.float().mean()

### Training

In [None]:
import sys
from tqdm import tqdm
from torchvision.utils import save_image

## Hyperparameters
set_seed(42)        # seed before building the model so weight initialization is reproducible
n_epochs = 10
learning_rate = 1e-3
step_size = 1       # the training loop prints a summary line when epoch % step_size == 0

## Modeling
model = Encoder(latent_dim=10).to(device)   # one output per CIFAR-10 class
loss_fn = nn.CrossEntropyLoss()     # takes raw logits (log-softmax is applied internally)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [26]:
## Training loop
# Reported loss/acc are means of per-batch means, so the last (possibly smaller)
# batch of each loader is weighted like a full batch.
for epoch in range(1, n_epochs + 1):
    cur_epoch = f"[{epoch:3d}/{n_epochs}]"

    ## Training
    model.train()
    train_loss, train_acc = 0, 0
    with tqdm(train_loader, leave=False, file=sys.stdout, dynamic_ncols=True, ascii=True) as pbar:
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            acc = accuracy(y_pred, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            train_loss += loss.item()
            train_acc += acc.item()

            desc = f"loss: {train_loss/(i + 1):.3f} acc: {train_acc/(i + 1):.3f}"
            pbar.set_description(cur_epoch + " " + desc)

    ## Validation
    # eval() switches BatchNorm to its running statistics. no_grad() stops autograd
    # from recording a graph for every test batch; backward() is never called here,
    # so building one only wastes memory and time.
    model.eval()
    valid_loss, valid_acc = 0, 0
    with torch.no_grad(), tqdm(test_loader, leave=False, file=sys.stdout, dynamic_ncols=True, ascii=True) as pbar:
        for i, (x, y) in enumerate(pbar):
            x, y = x.to(device), y.to(device)
            y_pred = model(x)

            valid_loss += loss_fn(y_pred, y).item()
            valid_acc += accuracy(y_pred, y).item()

            # `desc` still holds this epoch's final training summary.
            val_desc = f"val_loss: {valid_loss/(i + 1):.3f} val_acc: {valid_acc/(i + 1):.3f}"
            pbar.set_description(cur_epoch + " " + desc + " | " + val_desc)

    if epoch % step_size == 0:
        print(cur_epoch + " " + desc + " | " + val_desc)

[  1/10] loss: 1.340 acc: 0.523 | val_loss: 1.166 val_acc: 0.598                                                    
[  2/10] loss: 1.022 acc: 0.642 | val_loss: 0.935 val_acc: 0.676                                                    
[  3/10] loss: 0.896 acc: 0.685 | val_loss: 0.830 val_acc: 0.708                                                    
[  4/10] loss: 0.824 acc: 0.713 | val_loss: 0.796 val_acc: 0.723                                                    
[  5/10] loss: 0.772 acc: 0.734 | val_loss: 0.709 val_acc: 0.758                                                    
[  6/10] loss: 0.721 acc: 0.749 | val_loss: 0.746 val_acc: 0.746                                                    
[  7/10] loss: 0.687 acc: 0.763 | val_loss: 0.681 val_acc: 0.771                                                    
[  8/10] loss: 0.654 acc: 0.773 | val_loss: 0.697 val_acc: 0.762                                                    
[  9/10] loss: 0.632 acc: 0.781 | val_loss: 0.709 val_acc: 0.752