In [1]:
import torch
from torch import optim
from torch.utils.data import Subset
import torchvision
from torchvision import datasets
from torchvision import transforms

from sklearn.model_selection import StratifiedShuffleSplit

import os
import collections
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(0)
np.random.seed(0)

In [2]:
transform = transforms.Compose([
    transforms.ToTensor()
])

train_ds = datasets.STL10('./data', split='train', download=True, transform=transform)

train_ds.data.shape

Files already downloaded and verified


(5000, 3, 96, 96)

In [3]:
labels = [labels for _, labels in train_ds]
labels_counter = collections.Counter(labels)
labels_counter

Counter({1: 500,
         5: 500,
         6: 500,
         3: 500,
         9: 500,
         7: 500,
         4: 500,
         8: 500,
         0: 500,
         2: 500})

In [4]:
test_ds = datasets.STL10('./data', split='test', download=True, transform=transform)
test_ds.data.shape

Files already downloaded and verified


(8000, 3, 96, 96)

In [5]:
labels_t = [labels for _, labels in test_ds]
labels_t_counter = collections.Counter(labels)
labels_t_counter

Counter({1: 500,
         5: 500,
         6: 500,
         3: 500,
         9: 500,
         7: 500,
         4: 500,
         8: 500,
         0: 500,
         2: 500})

In [6]:
shuffle = StratifiedShuffleSplit(n_splits=1, test_size=.2, random_state=0)
indicies = list(range(len(test_ds)))

# unpack generator
test_i, val_i = shuffle.split(indicies, labels_t).__next__()
print(test_i, test_i.shape)
print(val_i, val_i.shape)

[2096 4321 2767 ... 3206 3910 2902] (6400,)
[6332 6852 1532 ... 5766 4469 1011] (1600,)


In [7]:
val_ds = Subset(test_ds, val_i)
test_ds = Subset(test_ds, test_i)

print(len(val_ds))
print(len(test_ds))

1600
6400


In [9]:
val_labels = [label for _, label in val_ds]
test_labels = [label for _, label in test_ds]

val_counter = collections.Counter(val_labels)
test_counter = collections.Counter(test_labels)

print(val_counter)
print(test_counter)

Counter({2: 160, 8: 160, 3: 160, 6: 160, 4: 160, 1: 160, 5: 160, 9: 160, 0: 160, 7: 160})
Counter({6: 640, 0: 640, 4: 640, 5: 640, 9: 640, 2: 640, 3: 640, 1: 640, 7: 640, 8: 640})
