In [1]:
from torchvision import transforms
from torchvision.datasets import CIFAR10
import pytorch_lightning as pl
import os

# Root directory for the CIFAR10 files. Override it by setting the PATH_DATASETS
# environment variable before launching Jupyter, e.g. in a terminal:
#   export PATH_DATASETS=/path/to/data
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")

# Per-channel (R, G, B) mean and standard deviation of the CIFAR10 training images.
# Normalize needs one value per channel; empty lists cannot broadcast against a
# 3-channel image and make every sample fail to load.
DATA_MEANS = [0.49139968, 0.48215841, 0.44653091]
DATA_STD = [0.24703223, 0.24348513, 0.26158784]

# Training transform: light augmentation (random horizontal flip plus a mild random
# resized crop back to 32x32), then tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(
        (32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize(DATA_MEANS, DATA_STD),
])

# Evaluation transform: no augmentation, only tensor conversion and normalization.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(DATA_MEANS, DATA_STD),
])

train_dataset = CIFAR10(
    root=DATASET_PATH, train=True, transform=train_transform, download=True)
# Validation data comes from the *training* split (a subset is taken from it later),
# but it is loaded with test_transform so that validation images are not augmented.
val_dataset = CIFAR10(
    root=DATASET_PATH, train=True, transform=test_transform, download=True)
test_set = CIFAR10(
    root=DATASET_PATH, train=False, transform=test_transform, download=True)


ModuleNotFoundError: No module named 'torchvision'

In [None]:
import torch
import torch.utils.data as data

# Seeding a deep learning pipeline is not trivial: every source of pseudo-randomness
# (PyTorch, NumPy, Python's `random`, ...) must be seeded. Forgetting a single library
# is a common reason results cannot be reproduced; `pl.seed_everything` seeds them together.
SEED = 42

# Hold out VAL_SIZE randomly chosen training images for validation (45k/5k for CIFAR10).
VAL_SIZE = 5000
TRAIN_SIZE = len(train_dataset) - VAL_SIZE

# train_dataset and val_dataset wrap the same training images with different transforms.
# Re-seeding with the same value before each split makes random_split draw the same
# permutation twice, so the resulting train and validation subsets are disjoint.
pl.seed_everything(SEED)
train_set, _ = data.random_split(train_dataset, [TRAIN_SIZE, VAL_SIZE])
pl.seed_everything(SEED)
_, val_set = data.random_split(val_dataset, [TRAIN_SIZE, VAL_SIZE])

In [None]:
import matplotlib.pyplot as plt
import torchvision

# Visualize a few validation images. val_set is built from val_dataset, which uses
# test_transform, so these images are not augmented.
NUM_IMAGES = 4
CIFAR_images = torch.stack(
    [val_set[idx][0] for idx in range(NUM_IMAGES)], dim=0)
# normalize=True rescales the mean/std-normalized pixels back into [0, 1] for display.
img_grid = torchvision.utils.make_grid(
    CIFAR_images, nrow=NUM_IMAGES, normalize=True, pad_value=0.9)
# make_grid returns a (C, H, W) tensor; matplotlib's imshow expects (H, W, C).
img_grid = img_grid.permute(1, 2, 0)

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title("Image examples of the CIFAR10 dataset")
ax.imshow(img_grid)
ax.axis("off")
plt.show()
plt.close(fig)

In [None]:
# Settings shared by all three loaders. Only the training loader shuffles, drops the
# final incomplete batch, and pins host memory; val/test iterate every sample in order.
_loader_kwargs = {"batch_size": 128, "num_workers": 4}

train_loader = data.DataLoader(
    train_set, shuffle=True, drop_last=True, pin_memory=True, **_loader_kwargs)
val_loader = data.DataLoader(
    val_set, shuffle=False, drop_last=False, **_loader_kwargs)
test_loader = data.DataLoader(
    test_set, shuffle=False, drop_last=False, **_loader_kwargs)
    

In [None]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Converts a batch of images into non-overlapping square patches.

    Patches are ordered row-major over the patch grid: left to right, then
    top to bottom.

    Args:
        x: tensor representing images of shape (B, C, H, W)
        patch_size: number of pixels per dimension of the patches (int);
            must be positive and divide both H and W
        flatten_channels: if True, flatten each patch into a feature vector of
            length C*patch_size*patch_size; otherwise keep it as a
            (C, patch_size, patch_size) image grid
    Returns:
        patches: tensor of shape (B, H'*W', C*patch_size*patch_size) when
            flatten_channels is True, else (B, H'*W', C, patch_size, patch_size),
            where H' = H // patch_size and W' = W // patch_size
    Raises:
        ValueError: if patch_size is not positive or does not divide H and W
    """
    B, C, H, W = x.shape
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    if H % patch_size or W % patch_size:
        raise ValueError(
            f"Image size ({H}, {W}) is not divisible by patch_size={patch_size}")

    # H and W are Python ints, so use plain integer division; torch.div requires a
    # Tensor as its first argument and raises TypeError on ints.
    num_h, num_w = H // patch_size, W // patch_size

    x = x.reshape(B, C, num_h, patch_size, num_w, patch_size)  # [B, C, H', p_H, W', p_W]
    x = x.permute(0, 2, 4, 1, 3, 5)  # [B, H', W', C, p_H, p_W]
    x = x.flatten(1, 2)  # [B, H'*W', C, p_H, p_W]

    if flatten_channels:
        x = x.flatten(2, 4)  # [B, H'*W', C*p_H*p_W]

    return x

# Split the example images into 4x4 patches, keeping each patch as a (C, 4, 4) grid
# so it can still be visualized as an image.
img_patches = img_to_patch(
    CIFAR_images, patch_size=4, flatten_channels=False)