# TensorFlow Datasets에서 PNG 이미지 저장

In [None]:
import os

import tensorflow as tf
import tensorflow_datasets as tfds
from PIL import Image

## [TensorFlow Datasets Catalog](https://www.tensorflow.org/datasets/catalog/overview)

### [CIFAR10](https://www.tensorflow.org/datasets/catalog/cifar10)

In [None]:
# Load CIFAR-10 (all splits plus dataset metadata) from TFDS, keep the
# train/test splits and the class-name list, and print their sizes/names.
target = 'cifar10'
ds, ds_info = tfds.load(target, with_info=True)
labels = ds_info.features['label'].names  # index -> class name
ds_train, ds_test = ds['train'], ds['test']
print(f'학습 이미지: {len(ds_train)}')
print(f'평가 이미지: {len(ds_test)}')
print(f'레이블: {labels}')

In [None]:
# Export CIFAR-10 to PNG files laid out as
#   datasets/<target>/<split>/<label name>/<example id>.png
# One loop body handles both splits (previously two copy-pasted loops).
for split_name, split_ds in (('train', ds_train), ('test', ds_test)):
    output = os.path.abspath(os.path.expanduser(f'datasets/{target}/{split_name}'))
    for data in split_ds:
        img = Image.fromarray(data['image'].numpy())
        # The dataset's per-example 'id' string becomes the file name.
        name = data['id'].numpy().decode('utf-8')
        # Map the integer class index to its name for the sub-directory.
        label = labels[data['label'].numpy()]
        label_dir = os.path.join(output, label)
        os.makedirs(label_dir, exist_ok=True)
        path = os.path.join(label_dir, f'{name}.png')
        img.save(path)

### [Fashion MNIST](https://www.tensorflow.org/datasets/catalog/fashion_mnist)

In [None]:
# Load Fashion-MNIST (all splits plus dataset metadata) from TFDS, keep the
# train/test splits and a folder-safe class-name list, and print them.
target = 'fashion_mnist'
ds, ds_info = tfds.load(target, with_info=True)
ds_train = ds['train']
ds_test = ds['test']
# Class names are later used as directory names; anything after a '/' would
# create a nested directory, so keep only the part before the first slash.
labels = [class_name.split('/')[0] for class_name in ds_info.features['label'].names]
print(f'학습 이미지: {len(ds_train)}')
print(f'평가 이미지: {len(ds_test)}')
print(f'레이블: {labels}')

In [None]:
# Export Fashion-MNIST to PNG files laid out as
#   datasets/<target>/<split>/<label>/<NNNN>.png
# where NNNN is a zero-padded running index per label, restarted for each split.
# One loop body handles both splits (previously two copy-pasted loops).
for split_name, split_ds in (('train', ds_train), ('test', ds_test)):
    output = os.path.abspath(os.path.expanduser(f'datasets/{target}/{split_name}'))
    counter = {}  # label -> number of images already written in this split
    for data in split_ds:
        # Drop the trailing channel axis so PIL receives a 2-D array
        # (same effect as reshaping to shape[:-1]).
        img = Image.fromarray(data['image'].numpy().squeeze(axis=-1))
        label = labels[data['label'].numpy()]
        name = counter.get(label, 0)
        label_dir = os.path.join(output, label)
        os.makedirs(label_dir, exist_ok=True)
        path = os.path.join(label_dir, f'{name:04d}.png')
        img.save(path)
        counter[label] = name + 1

### [Caltech-UCSD Birds 200](https://www.tensorflow.org/datasets/catalog/caltech_birds2010)

In [None]:
# Load Caltech-UCSD Birds 2010 (all splits plus dataset metadata) from TFDS,
# keep the train/test splits and the class-name list, and print them.
target = 'caltech_birds2010'
ds, ds_info = tfds.load(target, with_info=True)
labels = ds_info.features['label'].names  # index -> class name
ds_train, ds_test = ds['train'], ds['test']
print(f'학습 이미지: {len(ds_train)}')
print(f'평가 이미지: {len(ds_test)}')
print(f'레이블: {labels}')

In [None]:
# Export the dataset loaded in the previous cell (`target`, caltech_birds2010)
# to PNG files laid out as
#   datasets/<target>/<split>/<label_name>/<source file stem>.png
# This cell used to overwrite `target` with 'caltech_birds', which is the same
# root the caltech_birds2011 export uses, so running both cells mixed the two
# dataset versions in one directory tree. Deriving the root from the loaded
# dataset name keeps the exports apart and leaves `target` accurate.
for split_name, split_ds in (('train', ds_train), ('test', ds_test)):
    output = os.path.abspath(os.path.expanduser(f'datasets/{target}/{split_name}'))
    for data in split_ds:
        img = Image.fromarray(data['image'].numpy())
        filename = data['image/filename'].numpy().decode('utf-8')
        # splitext strips only the final extension; split('.')[0] would
        # truncate any base name that itself contains a dot.
        name = os.path.splitext(os.path.basename(filename))[0]
        label = data['label_name'].numpy().decode('utf-8')
        label_dir = os.path.join(output, label)
        os.makedirs(label_dir, exist_ok=True)
        path = os.path.join(label_dir, f'{name}.png')
        img.save(path)

In [None]:
# Load Caltech-UCSD Birds 2011 (all splits plus dataset metadata) from TFDS,
# keep the train/test splits and the class-name list, and print them.
target = 'caltech_birds2011'
ds, ds_info = tfds.load(target, with_info=True)
labels = ds_info.features['label'].names  # index -> class name
ds_train, ds_test = ds['train'], ds['test']
print(f'학습 이미지: {len(ds_train)}')
print(f'평가 이미지: {len(ds_test)}')
print(f'레이블: {labels}')

In [None]:
# Export the dataset loaded in the previous cell (`target`, caltech_birds2011)
# to PNG files laid out as
#   datasets/<target>/<split>/<label_name>/<source file stem>.png
# This cell used to overwrite `target` with 'caltech_birds', which is the same
# root the caltech_birds2010 export uses, so running both cells mixed the two
# dataset versions in one directory tree. Deriving the root from the loaded
# dataset name keeps the exports apart and leaves `target` accurate.
for split_name, split_ds in (('train', ds_train), ('test', ds_test)):
    output = os.path.abspath(os.path.expanduser(f'datasets/{target}/{split_name}'))
    for data in split_ds:
        img = Image.fromarray(data['image'].numpy())
        filename = data['image/filename'].numpy().decode('utf-8')
        # splitext strips only the final extension; split('.')[0] would
        # truncate any base name that itself contains a dot.
        name = os.path.splitext(os.path.basename(filename))[0]
        label = data['label_name'].numpy().decode('utf-8')
        label_dir = os.path.join(output, label)
        os.makedirs(label_dir, exist_ok=True)
        path = os.path.join(label_dir, f'{name}.png')
        img.save(path)