Instantiate train, valid, and test dataloaders for each of the three datasets

To also carve a validation split out of each torchvision training set, we follow the instructions at
https://medium.com/@sergioalves94/deep-learning-in-pytorch-with-cifar-10-dataset-858b504a6b54

In [1]:

from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#from linformer import Linformer
from PIL import Image
from torchvision.utils import make_grid

from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from torch.utils.data import random_split


import sys
sys.path.append('../vit_pytorch/')
sys.path.append('../')

from vit import ViT
from recorder import Recorder # import the Recorder and instantiate

#from vit_pytorch.efficient import ViT

# CIFAR 10 first

In [4]:
def get_CIFAR_data(number='10',
                   val_size = 5000,
                   batch_size = 64,
                   transforms=transforms.Compose([
                           transforms.ToTensor()
                                   ])):
    """Build train, validation and test DataLoaders for CIFAR-10 or CIFAR-100.

    The torchvision training set is split at random into a training part and
    a validation part of ``val_size`` samples. The test set is loaded
    separately with ``train=False`` and is not split.

    Args:
        number: Which CIFAR variant to load: ``'10'`` or ``'100'``. The
            integers ``10`` and ``100`` are accepted as well.
        val_size: Number of training samples held out for validation.
        batch_size: Batch size used by all three loaders.
        transforms: Transform applied to the images of all three splits.
            NOTE: inside this function the parameter shadows the
            ``torchvision.transforms`` module.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple. Only the
        training loader shuffles its samples.

    Raises:
        ValueError: If ``number`` is not 10 or 100, or if ``val_size`` is
            negative or larger than the training set.
    """
    dataset_classes = {'10': datasets.CIFAR10, '100': datasets.CIFAR100}
    dataset_cls = dataset_classes.get(str(number))
    if dataset_cls is None:
        # Raise instead of print + sys.exit(): exiting would report success
        # (status 0) and take the whole calling process down with it.
        raise ValueError("Must select 10 or 100, got {!r}".format(number))

    # Only the training set requests a download; the test set is read from
    # the same root directory afterwards.
    dataset = dataset_cls(root='../data/', download=True, transform=transforms)
    test_dataset = dataset_cls(root='../data/', train=False, transform=transforms)

    if not 0 <= val_size <= len(dataset):
        # A negative train_size would still sum to len(dataset) and could
        # slip through random_split, yielding wrongly sized splits.
        raise ValueError(
            "val_size must be between 0 and {}, got {!r}".format(len(dataset), val_size))
    train_size = len(dataset) - val_size

    train_ds, val_ds = random_split(dataset, [train_size, val_size])

    loader_kwargs = dict(num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_ds, batch_size, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size, **loader_kwargs)
    test_loader = DataLoader(test_dataset, batch_size, **loader_kwargs)

    return train_loader, val_loader, test_loader

In [5]:
# Build the three loaders with the function defaults: CIFAR-10, 5000
# validation samples, batch size 64.
train_loader, val_loader, test_loader = get_CIFAR_data()

Files already downloaded and verified




In [None]:



# Step-by-step version of the same pipeline for CIFAR-10.
# The training set is downloaded into ../data/ if needed; the test set is
# then loaded from the same root with train=False. Images become tensors.
dataset = datasets.CIFAR10(root='../data/', download=True, transform=transforms.ToTensor())
test_dataset = datasets.CIFAR10(root='../data/', train=False, transform=transforms.ToTensor())

In [None]:
# Class names of the dataset; indexed by integer label in the cells below.
classes = dataset.classes
classes

In [None]:
# Count how many training images belong to each class.
# The labels are read straight from dataset.targets, so no image has to be
# decoded and transformed just to look at its label.
class_count = {}
for index in dataset.targets:
    label = classes[index]
    class_count[label] = class_count.get(label, 0) + 1
class_count

In [None]:
# Hold out 5000 training samples for validation; the rest stay in training.
val_size = 5000
train_size = len(dataset) - val_size

In [None]:
# Randomly partition the training set into train / validation subsets.
# NOTE(review): no generator or seed is passed in this cell, so the split
# presumably differs between runs; confirm whether that is intended.
train_ds, val_ds = random_split(dataset, [train_size, val_size])
len(train_ds), len(val_ds)

In [None]:
# Wrap the three splits in DataLoaders. Only the training loader shuffles.
# Validation and test use twice the training batch size (presumably because
# evaluation needs less memory per sample; confirm).
batch_size=128
train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size*2, num_workers=4, pin_memory=True)

In [None]:
# Preview the first training batch as a 16-column image grid, then stop.
for images, _ in train_loader:
    print('images.shape:', images.shape)
    grid = make_grid(images, nrow=16)
    plt.figure(figsize=(16,8))
    plt.axis('off')
    # Move axis 0 of the grid to the end before handing it to imshow
    # (presumably channels-first -> channels-last; confirm).
    plt.imshow(grid.permute((1, 2, 0)))
    break

And this is basically all we need to do for these torchvision datasets. 



# CIFAR100


The cells below are essentially a copy of the CIFAR-10 cells above, with `CIFAR100` substituted for `CIFAR10`.

In [None]:

# CIFAR-100 training set (downloaded into ../data/ if needed) and test set
# (train=False), both with images converted to tensors.
dataset = datasets.CIFAR100(root='../data/', download=True, transform=transforms.ToTensor())
test_dataset = datasets.CIFAR100(root='../data/', train=False, transform=transforms.ToTensor())

In [None]:
# Class names of the dataset; indexed by integer label in the cells below.
classes = dataset.classes
classes

In [None]:
# Count how many training images belong to each class.
# The labels are read straight from dataset.targets, so no image has to be
# decoded and transformed just to look at its label.
class_count = {}
for index in dataset.targets:
    label = classes[index]
    class_count[label] = class_count.get(label, 0) + 1
class_count

In [None]:
# Hold out 5000 training samples for validation; the rest stay in training.
val_size = 5000
train_size = len(dataset) - val_size

In [None]:
# Randomly partition the training set into train / validation subsets.
# NOTE(review): no generator or seed is passed in this cell, so the split
# presumably differs between runs; confirm whether that is intended.
train_ds, val_ds = random_split(dataset, [train_size, val_size])
len(train_ds), len(val_ds)

In [None]:
print("Proof still balanced after splitting.")
# Per-class counts of the validation subset.
# val_ds comes from random_split, so its labels are looked up through
# val_ds.indices into the parent dataset's targets. This avoids decoding
# and transforming every validation image just to read its label.
# NOTE(review): a random split is only expected to be roughly balanced.
class_count = {}
for sample_idx in val_ds.indices:
    index = val_ds.dataset.targets[sample_idx]
    label = classes[index]
    class_count[label] = class_count.get(label, 0) + 1
class_count

In [None]:
# Wrap the three splits in DataLoaders. Only the training loader shuffles.
# Validation and test use twice the training batch size (presumably because
# evaluation needs less memory per sample; confirm).
batch_size=128
train_loader = DataLoader(train_ds, batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_ds, batch_size*2, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size*2, num_workers=4, pin_memory=True)

In [None]:
# Preview the first training batch as a 16-column image grid, then stop.
for images, _ in train_loader:
    print('images.shape:', images.shape)
    grid = make_grid(images, nrow=16)
    plt.figure(figsize=(16,8))
    plt.axis('off')
    # Move axis 0 of the grid to the end before handing it to imshow
    # (presumably channels-first -> channels-last; confirm).
    plt.imshow(grid.permute((1, 2, 0)))
    break

# Pets