Instantiate train, validation, and test DataLoaders for each of the three datasets (CIFAR-10, CIFAR-100, and Oxford-IIIT Pets).

The torchvision CIFAR datasets only ship train and test splits, so a validation split is carved out of the training set, following the instructions at
https://medium.com/@sergioalves94/deep-learning-in-pytorch-with-cifar-10-dataset-858b504a6b54

In [1]:

from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#from linformer import Linformer
from PIL import Image
from torchvision.utils import make_grid

from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from torch.utils.data import random_split


import sys
sys.path.append('../vit_pytorch/')
sys.path.append('../')

from vit import ViT
from recorder import Recorder # import the Recorder and instantiate

#from vit_pytorch.efficient import ViT

# CIFAR 10 first

In [4]:
def get_CIFAR_data(number='10',
                   val_size = 5000,
                   batch_size = 64,
                   transforms=transforms.Compose([transforms.ToTensor()]),
                   root='../data/',
                   num_workers=4,
                   seed=None):
    """Build train/validation/test DataLoaders for CIFAR-10 or CIFAR-100.

    The official training set is split at random into a training part and a
    validation part of ``val_size`` images; the official test set is used
    unchanged.

    Args:
        number: ``'10'`` for CIFAR-10 or ``'100'`` for CIFAR-100; the integers
            10 and 100 are accepted as well.
        val_size: Number of training images held out for validation.
        batch_size: Batch size shared by all three loaders.
        transforms: Transform applied to every image (train, val and test).
            Defaults to ``ToTensor()`` only. Inside this function the name
            shadows the ``torchvision.transforms`` module; it is kept for
            callers that pass it by keyword.
        root: Directory the dataset is stored in / downloaded to.
        num_workers: Worker processes per DataLoader.
        seed: If given, seeds the train/validation split so it is the same on
            every run; ``None`` keeps the unseeded behaviour.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.

    Raises:
        ValueError: If ``number`` is not 10 or 100.
    """
    dataset_classes = {'10': datasets.CIFAR10, '100': datasets.CIFAR100}
    key = str(number)
    if key not in dataset_classes:
        # Raise rather than sys.exit(): exiting tears down the caller (in a
        # notebook it surfaces as a bare SystemExit) instead of reporting
        # the bad argument.
        raise ValueError(f"number must be '10' or '100', got {number!r}")
    dataset_cls = dataset_classes[key]

    dataset = dataset_cls(root=root, download=True, transform=transforms)
    # The test split is not downloaded separately; it relies on the files
    # fetched by the call above being present under ``root``.
    test_dataset = dataset_cls(root=root, train=False, transform=transforms)

    train_size = len(dataset) - val_size
    split_kwargs = {}
    if seed is not None:
        split_kwargs['generator'] = torch.Generator().manual_seed(seed)
    train_ds, val_ds = random_split(dataset, [train_size, val_size], **split_kwargs)

    loader_kwargs = dict(num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_ds, batch_size, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_ds, batch_size, **loader_kwargs)
    test_loader = DataLoader(test_dataset, batch_size, **loader_kwargs)

    return train_loader, val_loader, test_loader

In [5]:
# CIFAR-10 with 5,000 validation images and batch size 64 (the function defaults).
train_loader, val_loader, test_loader = get_CIFAR_data(number='10', val_size=5000, batch_size=64)

Files already downloaded and verified




In [None]:



# CIFAR-10 training and test sets; every image is converted to a tensor.
dataset = datasets.CIFAR10(
    root='../data/',
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.CIFAR10(
    root='../data/',
    train=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Class names in label order: classes[i] is the name of integer label i.
classes = dataset.classes
# Display the list as the cell output.
classes

In [None]:
from collections import Counter

# Count images per class from the stored integer labels. Iterating the
# dataset itself would decode and transform all images just to read labels.
# Counter keeps first-seen order, so the dict matches the per-sample loop.
class_count = dict(Counter(classes[target] for target in dataset.targets))
class_count

In [None]:
# Hold out 5,000 training images for validation; the rest remain for training.
val_size = 5_000
train_size = len(dataset) - val_size

In [None]:
# Seed the split so the validation set is identical on every run; an unseeded
# random_split draws a different hold-out set after each kernel restart.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_generator)
len(train_ds), len(val_ds)

In [None]:
batch_size = 128
# Validation and test loaders use twice the training batch size.
eval_batch_size = batch_size * 2
loader_kwargs = dict(num_workers=4, pin_memory=True)

train_loader = DataLoader(train_ds, batch_size, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, eval_batch_size, **loader_kwargs)
test_loader = DataLoader(test_dataset, eval_batch_size, **loader_kwargs)

In [None]:
# Plot the first training batch as a grid with 16 images per row.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, _ = first_batch
    print('images.shape:', images.shape)
    grid = make_grid(images, nrow=16).permute((1, 2, 0))
    plt.figure(figsize=(16,8))
    plt.axis('off')
    plt.imshow(grid)

That is all the setup these torchvision datasets need.



# CIFAR100


The cells below repeat the CIFAR-10 steps above for CIFAR-100.

In [None]:

# CIFAR-100 training and test sets; every image is converted to a tensor.
dataset = datasets.CIFAR100(
    root='../data/',
    download=True,
    transform=transforms.ToTensor(),
)
test_dataset = datasets.CIFAR100(
    root='../data/',
    train=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Class names in label order: classes[i] is the name of integer label i.
classes = dataset.classes
# Display the list as the cell output.
classes

In [None]:
from collections import Counter

# Count images per class from the stored integer labels. Iterating the
# dataset itself would decode and transform all images just to read labels.
# Counter keeps first-seen order, so the dict matches the per-sample loop.
class_count = dict(Counter(classes[target] for target in dataset.targets))
class_count

In [None]:
# Hold out 5,000 training images for validation; the rest remain for training.
val_size = 5_000
train_size = len(dataset) - val_size

In [None]:
# Seed the split so the validation set is identical on every run; an unseeded
# random_split draws a different hold-out set after each kernel restart.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(dataset, [train_size, val_size], generator=split_generator)
len(train_ds), len(val_ds)

In [None]:
from collections import Counter

print("Proof still balanced after splitting.")
# Count validation images per class via the Subset's indices into the full
# dataset's labels, without decoding any images. random_split is not
# stratified, so counts only approximate val_size / len(classes) per class.
all_targets = val_ds.dataset.targets
class_count = dict(Counter(classes[all_targets[i]] for i in val_ds.indices))
class_count

In [None]:
batch_size = 128
# Validation and test loaders use twice the training batch size.
eval_batch_size = batch_size * 2
loader_kwargs = dict(num_workers=4, pin_memory=True)

train_loader = DataLoader(train_ds, batch_size, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, eval_batch_size, **loader_kwargs)
test_loader = DataLoader(test_dataset, eval_batch_size, **loader_kwargs)

In [None]:
# Plot the first training batch as a grid with 16 images per row.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    images, _ = first_batch
    print('images.shape:', images.shape)
    grid = make_grid(images, nrow=16).permute((1, 2, 0))
    plt.figure(figsize=(16,8))
    plt.axis('off')
    plt.imshow(grid)

# Pets

The dataset is available at: https://www.robots.ox.ac.uk/~vgg/data/pets/

The data-loading code follows this notebook: https://github.com/benihime91/pytorch_examples/blob/master/image_classification.ipynb

In [1]:
# !pip install --upgrade albumentations

In [2]:
import pandas as pd
import os
import shutil
import re
from tqdm.notebook import tqdm
from pathlib import Path
from sklearn import preprocessing, model_selection
import cv2
import seaborn as sns
import matplotlib.pyplot as plt

# Display settings: never truncate long cell contents (e.g. file paths) and
# show every row. pandas matches option names as regex patterns, so
# "display.max_row" selects "display.max_rows".
pd.set_option("display.max_colwidth", None, "display.max_row", None)
# IPython: reload edited modules automatically before running each cell.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

import torch
from torch.utils.data import DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [3]:
def download_pets(root_dir = '../data/oxford_iiit_pet2'):
    """Download and extract the Oxford-IIIT Pet images and annotations.

    For each of ``images`` and ``annotations``, nothing happens if the
    sub-directory already exists under ``root_dir``. Otherwise the
    ``<name>.tar.gz`` archive is downloaded into ``root_dir``, unless it is
    already there, and extracted in place.

    Uses only the standard library (``urllib`` + ``tarfile``) instead of
    ``!wget`` / ``!tar`` shell magics. It therefore also runs outside IPython,
    without those tools installed, and with paths containing spaces.

    Args:
        root_dir: Directory that will contain ``images/`` and ``annotations/``.
    """
    import tarfile
    import urllib.request

    base_url = 'https://www.robots.ox.ac.uk/~vgg/data/pets/data/'

    for name in ('images', 'annotations'):
        if os.path.exists(os.path.join(root_dir, name)):
            continue  # already extracted

        os.makedirs(root_dir, exist_ok=True)
        archive = os.path.join(root_dir, name + '.tar.gz')
        if not os.path.exists(archive):
            # Download under a temporary name and rename when complete, so an
            # interrupted transfer is never mistaken for a finished archive.
            partial = archive + '.part'
            urllib.request.urlretrieve(base_url + name + '.tar.gz', partial)
            os.replace(partial, archive)

        with tarfile.open(archive, 'r:gz') as tar:
            if hasattr(tarfile, 'data_filter'):
                # Reject absolute paths, '..' members and special files where
                # this Python supports extraction filters.
                tar.extractall(root_dir, filter='data')
            else:
                tar.extractall(root_dir)

In [4]:
h = 128 #@param{type:"integer"}
w = 128 #@param{type:"integer"}

def transforms(trn:bool=False):
    """Build the albumentations pipeline for the Oxford-IIIT Pet images.

    Note: once this cell runs, this function replaces the
    ``torchvision.transforms`` module imported at the top of the notebook.

    Args:
        trn: When True, prepend the random training augmentations
            (CLAHE, perspective, sharpen, brightness, rotation up to 60
            degrees, horizontal flip).

    Returns:
        An ``A.Compose`` that applies any augmentations, then resizes to
        ``h`` x ``w``, normalizes, and converts the image to a tensor.
    """
    # NOTE(review): IAAPerspective / IAASharpen / RandomBrightness exist only in
    # older albumentations releases (later replaced by Perspective, Sharpen and
    # RandomBrightnessContrast); confirm the pinned version before upgrading.
    augmentations = (
        [A.CLAHE(), A.IAAPerspective(), A.IAASharpen(), A.RandomBrightness(),
         A.Rotate(limit=60), A.HorizontalFlip()]
        if trn
        else []
    )
    finishing_steps = [
        A.Resize(h, w, always_apply=True),
        A.Normalize(always_apply=True),
        ToTensorV2(always_apply=True),
    ]
    return A.Compose(augmentations + finishing_steps)


class ParseData(Dataset):
    """Image-classification dataset backed by a CSV file.

    The CSV must contain an ``fnames`` column (image path) and a ``targets``
    column (class id), as written by ``get_PETS_data``.

    Args:
        pth: Path to the CSV file.
        tfms_fn: Callable invoked as ``tfms_fn(image=rgb_array)`` that returns a
            mapping with an ``"image"`` entry (the albumentations convention).
    """

    def __init__(self, pth, tfms_fn):
        self.df = pd.read_csv(pth)
        self.tfms = tfms_fn

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(transformed_image, target)`` for row ``idx``.

        Raises:
            IndexError: If ``idx`` is outside the dataset.
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        # Fail loudly with the offending index instead of printing it and then
        # crashing later on an unbound local variable.
        try:
            pth = self.df["fnames"].iloc[idx]
        except IndexError as err:
            raise IndexError(
                f"index {idx} out of range for dataset of length {len(self)}"
            ) from err

        bgr = cv2.imread(pth)
        if bgr is None:
            # cv2.imread signals an unreadable file by returning None.
            raise FileNotFoundError(f"could not read image: {pth}")
        im = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        im = self.tfms(image=im)["image"]
        lbl = self.df["targets"].iloc[idx]
        return im, lbl

In [5]:
_PET_FILENAME_RE = re.compile(r'^(.+)_\d+\.jpg$')


def _breed_from_path(path):
    """Return the lower-cased breed encoded in a file name like ``Abyssinian_12.jpg``."""
    match = _PET_FILENAME_RE.match(os.path.basename(path))
    if match is None:
        raise ValueError(f"unexpected image file name: {path}")
    return match.group(1).lower()


def _is_readable_image(path):
    """Return True if OpenCV can decode ``path`` and convert it to RGB."""
    img = cv2.imread(path)
    if img is None:  # cv2.imread returns None instead of raising on failure
        return False
    try:
        cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except cv2.error:
        return False
    return True


def get_PETS_data(root_dir ='../data/oxford_iiit_pet',
                  test_size = 0.20,
                  val_size = 0.20,
                  batch_size = 128,
                  num_workers = 0,
                  seed = 42):
    """Build train/validation/test DataLoaders for the Oxford-IIIT Pet dataset.

    Downloads the data if needed (via ``download_pets``), drops images OpenCV
    cannot decode, derives the breed label from each file name, and writes the
    splits to ``train_images.csv`` / ``val_images.csv`` / ``test_images.csv``
    inside ``root_dir``. The loaders read those CSVs through ``ParseData``.

    Args:
        root_dir: Dataset directory (will contain ``images/``).
        test_size: Fraction of all images used for the test split.
        val_size: Fraction of all images (not of the remainder) used for
            validation.
        batch_size: Batch size of all three loaders.
        num_workers: DataLoader worker processes (0 loads in the main process).
        seed: Seed for the splits; with a fixed seed the splits are identical
            across runs. ``None`` gives a random split each call.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles and applies augmentation.
    """
    download_pets(root_dir)

    images_dir = Path(root_dir) / "images"
    # Sort so the row order, and therefore the seeded split, does not depend
    # on the file system's directory-listing order.
    candidates = sorted(str(p) for p in images_dir.iterdir() if p.suffix == ".jpg")

    # Build a new list: removing from the list being iterated skips the entry
    # that follows each removed one.
    print('checking for corrupted images')
    im_list = []
    for im in tqdm(candidates):
        if _is_readable_image(im):
            im_list.append(im)
        else:
            print(f"[INFO] Corrupted Image: {im}")

    df = pd.DataFrame({"fnames": im_list})
    df["labels"] = [_breed_from_path(fname) for fname in im_list]
    # LabelEncoder numbers labels in sorted order, so ids are stable across runs.
    df["targets"] = preprocessing.LabelEncoder().fit_transform(df.labels.values)

    # train_test_split shuffles before splitting. val_size is rescaled so it is
    # a fraction of the whole dataset rather than of what remains.
    X_train, X_test = model_selection.train_test_split(
        df, test_size=test_size, random_state=seed)
    X_train, X_val = model_selection.train_test_split(
        X_train, test_size=val_size / (1 - test_size), random_state=seed)

    loaders = []
    for split, frame, is_train in (("train", X_train, True),
                                   ("val", X_val, False),
                                   ("test", X_test, False)):
        csv_path = os.path.join(root_dir, f"{split}_images.csv")
        frame.to_csv(csv_path, index=False)
        loaders.append(DataLoader(ParseData(csv_path, transforms(is_train)),
                                  batch_size=batch_size, shuffle=is_train,
                                  pin_memory=True, num_workers=num_workers))

    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader

In [6]:
# import matplotlib.pyplot as plt
# from torchvision.utils import make_grid

# # Extract and plot 1 batch for sanity-check
# batch = next(iter(train_loader))

# im, _ = batch
# grid = make_grid(im[:64], normalize=True, padding=True).permute(1, 2, 0)

# _, ax = plt.subplots(1, 1, figsize=(22, 15))
# ax.imshow(grid.numpy())
# ax.set_axis_off()