In [15]:
import os
from os.path import join, exists
import numpy as np
import pandas as pd
from tqdm import tqdm
import cv2
from torchvision.datasets import CIFAR10
from sklearn.model_selection import train_test_split
from coreml.utils.io import save_yml, read_yml

In [3]:
# Training split of CIFAR-10, stored under the raw data directory.
train = CIFAR10(root='/data/CIFAR10/raw', train=True, download=True)

Files already downloaded and verified


In [4]:
# Show the dataset summary (number of datapoints, root location, split) as the cell output.
train

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: /data/CIFAR10/raw
    Split: Train

In [5]:
# Held-out test split of CIFAR-10, read from the same raw data directory.
test = CIFAR10(root='/data/CIFAR10/raw', train=False, download=True)

Files already downloaded and verified


In [6]:
# Root directory for everything this notebook derives from the raw dataset
# (images, annotation file, version files).
processed_dir = '/data/CIFAR10/processed'

In [7]:
# exist_ok=True keeps this cell safe to re-run when the directory already exists.
os.makedirs(processed_dir, exist_ok=True)

In [8]:
# Inspect the array layout of the training images.
train.data.shape

(50000, 32, 32, 3)

In [9]:
# Number of training targets; should equal the number of training images.
len(train.targets)

50000

In [10]:
# Stack train first, then test, along the sample axis, so test rows start at
# index len(train.data) in both arrays.
all_images = np.concatenate((train.data, test.data), axis=0)
all_targets = np.concatenate((train.targets, test.targets), axis=0)

In [11]:
# Sanity check: images and targets must share the same leading dimension.
all_images.shape, all_targets.shape

((60000, 32, 32, 3), (60000,))

In [13]:
# Output locations, all rooted at the processed directory.
annotation_path = join(processed_dir, 'annotation.csv')
version_dir = join(processed_dir, 'versions')
image_dir = join(processed_dir, 'images')
version_path = join(version_dir, 'default.yml')

# Create the folders up front so later cells can write into them;
# exist_ok=True makes re-running this cell harmless.
for directory in (version_dir, image_dir):
    os.makedirs(directory, exist_ok=True)

In [16]:
# Write every image to disk as a PNG named after its row index and record the
# path, so image_paths[i] corresponds to all_images[i] / all_targets[i].
image_paths = []
for index, image in enumerate(tqdm(all_images)):
    image_path = join(image_dir, f'{index}.png')
    image_paths.append(image_path)

    # Files written by an earlier run are kept, which makes re-running cheap.
    if exists(image_path):
        continue

    # The channel axis is reversed before saving: cv2.imwrite expects BGR order
    # (the dataset arrays are presumably RGB -- TODO confirm).
    # cv2.imwrite signals failure through its return value instead of raising,
    # so check it; otherwise the annotation would silently reference files
    # that were never written.
    if not cv2.imwrite(image_path, image[:, :, ::-1]):
        raise IOError(f'Failed to write image to {image_path}')

100%|██████████| 60000/60000 [00:00<00:00, 139393.27it/s]


In [17]:
# One split tag per row of the stacked arrays: all train rows first, then all test rows.
splits = [
    split_name
    for split_name, dataset in (('train', train), ('test', test))
    for _ in range(len(dataset.data))
]

In [18]:
# Wrap each class id in a {'classification': [...]} dict; .tolist() turns the
# numpy integers into plain Python ints.
labels = [{'classification': [class_id]} for class_id in all_targets.tolist()]

In [19]:
# One row per image: where it lives on disk, its label dict and its original split.
annotation = pd.DataFrame(
    data={'path': image_paths, 'label': labels, 'split': splits},
    columns=['path', 'label', 'split'],
)

In [20]:
# Preview the first rows to confirm the path / label / split columns look right.
annotation.head()

Unnamed: 0,path,label,split
0,/data/CIFAR10/processed/images/0.png,{'classification': [6]},train
1,/data/CIFAR10/processed/images/1.png,{'classification': [9]},train
2,/data/CIFAR10/processed/images/2.png,{'classification': [9]},train
3,/data/CIFAR10/processed/images/3.png,{'classification': [4]},train
4,/data/CIFAR10/processed/images/4.png,{'classification': [1]},train


In [21]:
# index=False: the frame has only its default row index, so keep it out of the file.
# NOTE(review): the `label` column holds dicts; check how consumers of this CSV parse them back.
annotation.to_csv(annotation_path, index=False)

In [22]:
# Hold out 20% of the training images for validation.
# The seed is fixed so that Restart & Run All reproduces the same train/val
# membership; without it every run silently defines a different split.
SPLIT_SEED = 42
train_indices, val_indices = train_test_split(
    np.arange(len(train.data)), test_size=0.2, random_state=SPLIT_SEED
)

In [23]:
# 80/20 split of the 50000 training images.
assert (len(train_indices), len(val_indices)) == (40000, 10000)

In [24]:
# Container for the split definition; later cells add the 'train', 'val' and 'test' entries.
version = dict()

In [25]:
# `index in <ndarray>` scans the whole array on every test, which made these
# comprehensions quadratic. Sets give constant-time lookups with identical results.
train_index_set = set(np.asarray(train_indices).tolist())
val_index_set = set(np.asarray(val_indices).tolist())

# Iterating image_paths keeps the original row order, so the selected paths stay
# aligned with labels that are selected the same way.
train_image_paths = [image_path for index, image_path in enumerate(image_paths) if index in train_index_set]
val_image_paths = [image_path for index, image_path in enumerate(image_paths) if index in val_index_set]

In [26]:
# Every selected index must have produced exactly one path.
assert (len(train_image_paths), len(val_image_paths)) == (40000, 10000)

In [27]:
# `index in <ndarray>` scans the whole array on every test, which made these
# comprehensions quadratic. Sets give constant-time lookups with identical results.
train_index_set = set(np.asarray(train_indices).tolist())
val_index_set = set(np.asarray(val_indices).tolist())

# Iterating labels keeps the original row order, so the selected labels stay
# aligned with image paths that are selected the same way.
train_labels = [label for index, label in enumerate(labels) if index in train_index_set]
val_labels = [label for index, label in enumerate(labels) if index in val_index_set]

In [28]:
# Label counts must match the expected train/val sizes.
assert (len(train_labels), len(val_labels)) == (40000, 10000)

In [31]:
# Training entry: image files and their labels, selected by the same indices.
version['train'] = dict(file=train_image_paths, label=train_labels)

In [32]:
# Validation entry: image files and their labels, selected by the same indices.
version['val'] = dict(file=val_image_paths, label=val_labels)

In [33]:
# Test rows come after all training rows in image_paths / labels, so slice
# from the training set size onwards.
test_start = len(train.data)
version['test'] = dict(file=image_paths[test_start:], label=labels[test_start:])

In [35]:
# Each split must list the expected number of files.
split_sizes = {name: len(version[name]['file']) for name in ('train', 'val', 'test')}
assert split_sizes == {'train': 40000, 'val': 10000, 'test': 10000}

In [36]:
# Labels of every split must be a plain list whose elements are dicts.
for split_name in ('train', 'val', 'test'):
    split_labels = version[split_name]['label']
    assert isinstance(split_labels, list)
    assert isinstance(split_labels[0], dict)

In [81]:
# Persist the split definition to versions/default.yml under the processed directory.
# save_yml comes from coreml.utils.io; presumably it serializes the dict as YAML -- confirm there.
save_yml(version_path, version)

In [82]:
# Read the file back; presumably a round-trip check, though the result is not
# inspected in the visible cells.
load = read_yml(version_path)