In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import os

In [2]:
# When the kernel was started inside a `notebooks` directory, move one level up
# so relative paths used later in the notebook resolve against its parent.
# os.path.basename honours the platform's path separator, whereas splitting on
# '/' only works for POSIX-style paths.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')

In [3]:
batch_size = 128

# Per-split preprocessing: random crop + horizontal flip for training, a fixed
# resize + center crop for validation. Both end with the same normalization.
data_transforms = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

# CIFAR-100 training split (download requested) with shuffled batches.
trainset = torchvision.datasets.CIFAR100(
    root='/home/josegfer/datasets/cifar100',
    train=True,
    download=True,
    transform=data_transforms['train'],
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# CIFAR-100 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR100(
    root='/home/josegfer/datasets/cifar100',
    train=False,
    download=True,
    transform=data_transforms['val'],
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /home/josegfer/datasets/cifar100/cifar-100-python.tar.gz


  0%|          | 0/169001437 [00:00<?, ?it/s]

Extracting /home/josegfer/datasets/cifar100/cifar-100-python.tar.gz to /home/josegfer/datasets/cifar100
Files already downloaded and verified


In [4]:
# Pull a single batch from the training loader to sanity-check tensor shapes.
# next(iter(...)) replaces the loop-and-break form and its unused index.
batch = next(iter(trainloader))
inputs, labels = batch
inputs.shape, labels.shape

(torch.Size([128, 3, 224, 224]), torch.Size([128]))

In [8]:
# Pull a single batch from the test loader to sanity-check tensor shapes.
# next(iter(...)) replaces the loop-and-break form and its unused index.
batch = next(iter(testloader))
inputs, labels = batch
inputs.shape, labels.shape

(torch.Size([128, 3, 224, 224]), torch.Size([128]))

In [9]:
# Display the label tensor of the most recently fetched batch.
labels

tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3, 6, 2, 1, 2, 3, 7, 2, 6,
        8, 8, 0, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 5, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
        6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
        8, 3, 1, 2, 8, 0, 8, 3])

# Custom dataset and loaders

In [41]:
from torch.utils.data import Dataset

# Preprocessing pipelines keyed by split name; looked up by `Cifar10` below.
# 'train' augments with a random crop and flip, 'val' uses a fixed resize and
# center crop; both finish with the same normalization.
TRANSFORM = dict(
    train=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    val=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
)

class Cifar10(Dataset):
    """CIFAR-10 wrapper yielding dict samples, optionally with a stored embedding.

    Every item is ``{'image': x, 'label': y}``. When ``embedding_path`` is
    given, an ``'embedding'`` entry read from the loaded object is added.

    Args:
        split: Key into ``TRANSFORM``. ``'train'`` selects the CIFAR-10
            training set; any other valid key selects the test set.
        embedding_path: Optional path handed to ``torch.load``. The loaded
            object is indexed as ``H[idx, :, 0, 0]``, so it must support that
            indexing (presumably one row per dataset item, in dataset order;
            TODO confirm against the code that produced the file).
        root: Directory containing the CIFAR-10 files. Download is not
            requested, so the data must already be there. Defaults to the
            path this notebook has always used.

    Raises:
        KeyError: If ``split`` is not a key of ``TRANSFORM``.
    """

    def __init__(self, split, embedding_path = None, root = '/home/josegfer/datasets/cifar'):
        self.transform = TRANSFORM[split]
        self.data = torchvision.datasets.CIFAR10(root = root,
                                                 train = (split == 'train'), transform = self.transform)
        self.embedding = embedding_path is not None
        if self.embedding:
            # NOTE: torch.load unpickles the file; only pass paths to trusted files.
            self.H = torch.load(embedding_path)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict (see the class docstring)."""
        # Fetch once: the wrapped dataset applies its transform on every access,
        # so looking the item up separately for image and label repeats that work.
        x, y = self.data[idx]
        sample = {'image': x, 'label': y}
        if self.embedding:
            sample['embedding'] = self.H[idx, :, 0, 0]
        return sample

In [19]:
# Build the plain torchvision CIFAR-10 dataset for one split, paired with the
# transform registered for that split.
split = 'train'
data = torchvision.datasets.CIFAR10(
    root='/home/josegfer/datasets/cifar',
    train=(split == 'train'),
    transform=TRANSFORM[split],
)

In [26]:
# Fetch the first sample once (each access re-applies the transform) and show
# the image shape together with its label.
sample_image, sample_label = data[0]
sample_image.shape, sample_label

(torch.Size([3, 224, 224]), 6)

In [18]:
# Display the first item of `trainset` as returned by the dataset itself.
trainset[0]

(tensor([[[ 0.1083,  0.1083,  0.1083,  ...,  0.0912,  0.0912,  0.0912],
          [ 0.1083,  0.1083,  0.1083,  ...,  0.0912,  0.0912,  0.0912],
          [ 0.1083,  0.1083,  0.1083,  ...,  0.0912,  0.0912,  0.0912],
          ...,
          [ 0.0569,  0.0569,  0.0569,  ...,  0.3138,  0.3138,  0.3138],
          [ 0.0569,  0.0569,  0.0569,  ...,  0.3138,  0.3138,  0.3138],
          [ 0.0569,  0.0569,  0.0569,  ...,  0.3138,  0.3138,  0.3138]],
 
         [[-0.4251, -0.4251, -0.4251,  ..., -0.4251, -0.4251, -0.4251],
          [-0.4251, -0.4251, -0.4251,  ..., -0.4251, -0.4251, -0.4251],
          [-0.4251, -0.4251, -0.4251,  ..., -0.4251, -0.4251, -0.4251],
          ...,
          [-0.2675, -0.2675, -0.2675,  ..., -0.2325, -0.2325, -0.2325],
          [-0.2675, -0.2675, -0.2675,  ..., -0.2325, -0.2325, -0.2325],
          [-0.2675, -0.2675, -0.2675,  ..., -0.2325, -0.2325, -0.2325]],
 
         [[-0.8807, -0.8807, -0.8807,  ..., -0.8633, -0.8633, -0.8633],
          [-0.8807, -0.8807,

In [42]:
batch_size = 128

# Training split with the tensor stored at `embedding_path` attached to every
# sample; omit `embedding_path` to get plain {'image', 'label'} samples.
trainset = Cifar10(split = 'train', embedding_path = 'output/rn18_H_train.pt')
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4)

# The validation split must be bound to `testset`: `testloader` reads that name,
# and `trainset` has to keep referring to the dataset wrapped by `trainloader`.
# Batches from this loader are dicts keyed by 'image' and 'label'.
testset = Cifar10(split = 'val')
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=4)

In [43]:
# Pull a single batch from the training loader; samples are dicts, so the
# batch is indexed by key rather than unpacked positionally.
batch = next(iter(trainloader))
inputs = batch['image']
labels = batch['label']
h = batch['embedding']
inputs.shape, labels.shape, h.shape

(torch.Size([128, 3, 224, 224]), torch.Size([128]), torch.Size([128, 512]))

In [32]:
# Pull a single batch from the test loader and check shapes. Depending on the
# dataset behind `testloader`, a batch is either a dict keyed by 'image' and
# 'label' (Cifar10 samples) or an (inputs, labels) pair (torchvision samples).
# Unpacking a dict positionally would silently yield its key strings.
batch = next(iter(testloader))
if isinstance(batch, dict):
    inputs, labels = batch['image'], batch['label']
else:
    inputs, labels = batch
inputs.shape, labels.shape

(torch.Size([128, 3, 224, 224]), torch.Size([128]))

In [33]:
# Display the label tensor of the most recently fetched batch.
labels

tensor([3, 8, 8, 0, 6, 6, 1, 6, 3, 1, 0, 9, 5, 7, 9, 8, 5, 7, 8, 6, 7, 0, 4, 9,
        5, 2, 4, 0, 9, 6, 6, 5, 4, 5, 9, 2, 4, 1, 9, 5, 4, 6, 5, 6, 0, 9, 3, 9,
        7, 6, 9, 8, 0, 3, 8, 8, 7, 7, 4, 6, 7, 3, 6, 3, 6, 2, 1, 2, 3, 7, 2, 6,
        8, 8, 0, 2, 9, 3, 3, 8, 8, 1, 1, 7, 2, 5, 2, 7, 8, 9, 0, 3, 8, 6, 4, 6,
        6, 0, 0, 7, 4, 5, 6, 3, 1, 1, 3, 6, 8, 7, 4, 0, 6, 2, 1, 3, 0, 4, 2, 7,
        8, 3, 1, 2, 8, 0, 8, 3])