<a href="https://colab.research.google.com/github/iskra3138/ImageSr/blob/master/3_DataSet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
%tensorflow_version 2.x

import tensorflow as tf

TensorFlow 2.x selected.


In [0]:
import numpy as np

In [0]:
# Flower class names.
# NOTE(review): presumably ordered to match the integer `label` stored in the
# TFRecords (index i <-> CLASSES[i]); confirm against the TFRecord writer.
CLASSES = [
    'daisy',
    'dandelion',
    'roses',
    'sunflowers',
    'tulips',
]

In [0]:
import os

# Settings used to read and parse the TFRecord files.
AUTO = tf.data.experimental.AUTOTUNE

IMG_WIDTH = 224
IMG_HEIGHT = 224
IMAGE_SIZE = [IMG_HEIGHT, IMG_WIDTH]

batch_size = 64

## This experiment splits the 16 tfrecord files into train/validation sets.
## If dedicated train-only / validation-only tfrecord files exist, pass them
## explicitly as lists instead.
BUCKET = "gs://iskra3138_share"
gcs_pattern = os.path.join(BUCKET, '*.tfrec')
validation_split = 0.19
# tf.io.gfile.glob does not document the order of its results, and the split
# below is positional, so sort to keep the train/validation split reproducible.
filenames = sorted(tf.io.gfile.glob(gcs_pattern))
if not filenames:
    # Fail loudly instead of silently building empty train/validation datasets.
    raise FileNotFoundError(f"No TFRecord files match {gcs_pattern!r}")
# The last int(len(filenames) * validation_split) files go to validation.
split = len(filenames) - int(len(filenames) * validation_split)
train_fns = filenames[:split]
validation_fns = filenames[split:]

## TFRecord parsing function (its feature spec mirrors the TFRecord-writing function).
def parse_tfrecord(example):
    """Parse one serialized TFRecord example into model inputs plus metadata.

    The feature spec below must match the one used when the TFRecords were written.

    Args:
        example: a scalar string tensor holding one serialized `tf.train.Example`.

    Returns:
        A tuple ``(image, label, file_name, label_name)``:
        ``image`` is a 3-channel float32 tensor resized to ``IMAGE_SIZE`` with
        values rescaled from uint8 to [0, 1]; ``label`` is a one-hot float32
        vector of length ``len(CLASSES)``; ``file_name`` and ``label_name`` are
        the raw bytestrings stored in the record.
    """
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string = bytestring (not text string)
        "file_name": tf.io.FixedLenFeature([], tf.string),  # one bytestring
        "label_name": tf.io.FixedLenFeature([], tf.string),  # one bytestring
        "label": tf.io.FixedLenFeature([], tf.int64),  # shape [] means scalar, one integer
    }
    # decode the TFRecord
    example = tf.io.parse_single_example(example, features)

    # Derive the one-hot depth from CLASSES instead of hard-coding 5, so the
    # label encoding cannot drift from the class list.
    label = tf.one_hot(indices=example['label'], depth=len(CLASSES))

    image = tf.image.decode_jpeg(example['image'], channels=3)
    # decode_jpeg yields uint8; convert_image_dtype rescales integer input to
    # [0, 1] (255 -> 1.0). It must run BEFORE resize: resize already returns
    # float32 (for every method except NEAREST_NEIGHBOR), and convert_image_dtype
    # does not rescale input that is already floating point.
    image = tf.image.convert_image_dtype(image, tf.float32)
    image = tf.image.resize(image, IMAGE_SIZE)

    file_name = example['file_name']
    label_name = example['label_name']

    return image, label, file_name, label_name

def load_dataset(filenames):
  """Build a dataset of parsed ``(image, label, file_name, label_name)`` tuples.

  Args:
    filenames: a TFRecord path or a list of TFRecord paths.
  """
  # num_parallel_reads=AUTO reads several files concurrently (interleaved) for
  # throughput; each record is then decoded by parse_tfrecord in parallel.
  raw_records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
  parsed = raw_records.map(parse_tfrecord, num_parallel_calls=AUTO)
  return parsed

def get_training_dataset(shuffle_buffer=2048):
  """Return the endless, augmented, shuffled and batched training dataset.

  Elements are batches of ``(image, label, file_name, label_name)`` built from
  ``train_fns``.

  Args:
    shuffle_buffer: number of elements held in the shuffle buffer (default
      2048, the previously hard-coded value). The buffer holds decoded,
      resized float32 images, so at 224x224x3 the default costs roughly
      2048 * 224 * 224 * 3 * 4 bytes ~= 1.2 GB of memory.
  """
  dataset = load_dataset(train_fns)

  # Augment the training set with random horizontal and vertical flips.
  # Only flips are applied here; there is no color/saturation augmentation.
  def data_augment(image, label, file_name, label_name):
    modified = tf.image.random_flip_left_right(image)
    modified = tf.image.random_flip_up_down(modified)
    return modified, label, file_name, label_name

  augmented = dataset.map(data_augment, num_parallel_calls=AUTO)

  # repeat() comes before shuffle(): the stream never ends and the shuffle
  # buffer may mix elements from adjacent epochs. prefetch(AUTO) prepares the
  # next batch while the current one is consumed.
  return augmented.repeat().shuffle(shuffle_buffer).batch(batch_size).prefetch(AUTO)

# Training: endless, flip-augmented, shuffled batches.
training_dataset = get_training_dataset()
# Validation: a single pass (no repeat), no augmentation and no shuffling.
validation_dataset = (
    load_dataset(validation_fns)
    .batch(batch_size)
    .prefetch(AUTO)
)

위 코드를 실행시키면, training_dataset, validation_dataset 모두 (image, label, file_name, label_name)구조로 element들이 구성됨

training_dataset은 
- augmented = dataset.map(data_augment, num_parallel_calls=AUTO)에서 랜덤으로 좌우플립/상하플립이 적용되고, 
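- augmented.repeat().shuffle(2048).batch(batch_size).prefetch(AUTO)에서 무한히 반복(repeat)하면서 2048개 크기의 버퍼로 섞고(shuffle), batch_size개씩 묶어(batch) 생성됨. repeat()가 shuffle()보다 앞에 있으므로 서로 다른 epoch의 데이터가 한 배치에 섞일 수 있음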

validation_dataset은 
- shuffle과 repeat 없이 batch_size개씩 뽑히는데, 전체 element들이 다 뽑히면 끝남 (마지막 배치는 batch_size보다 작을 수 있음)

In [16]:
## Count the records inside one TFRecord file.
#### File names such as "00-230.tfrec" appear to encode the record count;
#### run the line below to check the count directly instead of trusting the name.
#### Template: sum(1 for _ in tf.data.TFRecordDataset(<TFRecord file name>))
# A generator expression counts without materializing a throwaway list of 1s.
sum(1 for _ in tf.data.TFRecordDataset(filenames[0]))

230

In [17]:
## Count the records inside several TFRecord files.
#### Run the line below to check the count directly instead of trusting the file names.
#### TFRecordDataset also accepts a list of files.

#### Template: sum(1 for _ in tf.data.TFRecordDataset([<list of TFRecord files>]))
# A generator expression counts without materializing a throwaway list of 1s.
sum(1 for _ in tf.data.TFRecordDataset(filenames[0:3]))

690

In [21]:
## Count the training records across every file in train_fns.
print(train_fns)
print(sum(1 for _ in tf.data.TFRecordDataset(train_fns)))

['gs://iskra3138_share/00-230.tfrec', 'gs://iskra3138_share/01-230.tfrec', 'gs://iskra3138_share/02-230.tfrec', 'gs://iskra3138_share/03-230.tfrec', 'gs://iskra3138_share/04-230.tfrec', 'gs://iskra3138_share/05-230.tfrec', 'gs://iskra3138_share/06-230.tfrec', 'gs://iskra3138_share/07-230.tfrec', 'gs://iskra3138_share/08-230.tfrec', 'gs://iskra3138_share/09-230.tfrec', 'gs://iskra3138_share/10-230.tfrec', 'gs://iskra3138_share/11-230.tfrec', 'gs://iskra3138_share/12-230.tfrec']
2990


In [22]:
## Count the validation records across every file in validation_fns.
print(validation_fns)
print(sum(1 for _ in tf.data.TFRecordDataset(validation_fns)))

['gs://iskra3138_share/13-230.tfrec', 'gs://iskra3138_share/14-230.tfrec', 'gs://iskra3138_share/15-220.tfrec']
680


train data는 2990장, validation data는 680장임

In [65]:
## Draw one batch from the dataset, method 1: take()
### take(n) yields the first n batches, each of the batch size defined above.
### NOTE: validation_dataset is not shuffled, so this is its first batch, not a random one.
for img, label, file_name, label_name in validation_dataset.take(1):
  label1 = label
  print(label.shape)

(64, 5)


In [66]:
## Draw one batch from the dataset, method 2: next() on as_numpy_iterator()
## NOTE: not random either (validation_dataset is unshuffled); this iterator
## yields numpy arrays instead of tensors.
# NOTE(review): the name `iter` shadows the builtin; it is kept because a later
# cell calls next(iter) on this same iterator.
iter = validation_dataset.as_numpy_iterator()
img, label, file_name, label_name = next(iter)
label2 = label
print(label.shape)

(64, 5)


In [68]:
# (row, column) indices where the two label batches differ; empty arrays mean identical.
np.where(label1 != label2)

(array([], dtype=int64), array([], dtype=int64))

- 두 방법 모두 첫 번째 배치의 label이 완전히 같음. validation_dataset에는 shuffle이 없으므로 두 방법이 같은 순서로 데이터를 생성하는 것으로 보임

In [69]:
## However, when the two run back to back, the results differ:
## looping over take(1) starts a fresh pass and yields the first batch again,
## while `iter` was already advanced once by next() in the earlier cell, so
## next(iter) now yields the second batch.
for img, label, file_name, label_name in validation_dataset.take(1):
  label1 = label
img2, label2, file_name2, label_name2 = next(iter)

np.where(label1 != label2)

(array([ 1,  1,  2,  2,  3,  3,  4,  4,  5,  5,  6,  6,  7,  7, 10, 10, 11,
        11, 12, 12, 13, 13, 14, 14, 15, 15, 19, 19, 21, 21, 22, 22, 24, 24,
        25, 25, 26, 26, 27, 27, 29, 29, 30, 30, 31, 31, 32, 32, 33, 33, 34,
        34, 35, 35, 36, 36, 37, 37, 38, 38, 39, 39, 40, 40, 41, 41, 42, 42,
        43, 43, 44, 44, 45, 45, 46, 46, 47, 47, 48, 48, 49, 49, 51, 51, 52,
        52, 54, 54, 56, 56, 57, 57, 58, 58, 59, 59, 60, 60, 61, 61, 62, 62,
        63, 63]),
 array([1, 4, 0, 4, 0, 2, 0, 4, 2, 4, 1, 2, 0, 1, 1, 2, 0, 1, 2, 4, 0, 2,
        1, 3, 1, 3, 2, 3, 0, 1, 2, 3, 2, 3, 2, 4, 3, 4, 1, 3, 3, 4, 3, 4,
        0, 1, 1, 4, 3, 4, 1, 3, 3, 4, 1, 4, 1, 3, 1, 3, 1, 3, 2, 3, 1, 2,
        1, 4, 0, 2, 0, 4, 1, 3, 0, 2, 0, 1, 0, 1, 2, 4, 1, 2, 1, 3, 3, 4,
        0, 3, 2, 4, 1, 4, 1, 4, 0, 3, 0, 1, 0, 3, 1, 2]))

In [75]:
## Iterating as_numpy_iterator() to exhaustion covers exactly one epoch,
## because validation_dataset has no repeat().
num_samples = 0
for img, label, file_name, label_name in validation_dataset.as_numpy_iterator():
  num_samples += len(label)  # len() of a numpy batch is its first dimension
print(num_samples)

680


전체 데이터 개수와 똑같아졌음

In [83]:
## With take(n), the sample count grows with n until one full epoch has been
## drawn, then stops growing.
for i in range(15):
  num_samples = 0  # reset the count for each n
  for img, label, file_name, label_name in validation_dataset.take(i):
    num_samples = num_samples + label.shape[0]
  print(i, num_samples)

0 0
1 64
2 128
3 192
4 256
5 320
6 384
7 448
8 512
9 576
10 640
11 680
12 680
13 680
14 680


i가 1씩 늘어날 때마다 64개씩 늘어나다가, i=11에서는 남은 40개(680 = 10 × 64 + 40)만큼만 늘어나고, 그 이후에는 더 뽑을 배치가 없으므로 변화가 없음

In [0]:
# Build per-class datasets
## To keep only the classes you want, apply Dataset.filter.
### Example: build a dataset that contains only label 1.

def make_label_filter(class_index):
  """Return a `Dataset.filter` predicate keeping elements of class `class_index`.

  The predicate receives one unbatched ``(img, label, file_name, label_name)``
  element whose ``label`` is a one-hot vector and returns a scalar bool tensor.
  """
  def _select(img, label, file_name, label_name):
    # label is one-hot, so its argmax is the integer class index.
    return tf.math.equal(tf.argmax(label), class_index)
  return _select


def select_label1(img, label, file_name, label_name):
  """Filter predicate: keep only elements whose one-hot label is class 1."""
  return make_label_filter(1)(img, label, file_name, label_name)


# filter() evaluates the predicate once per element and needs a scalar bool, so
# unbatch first: on a (batch, 5) label block, argmax would return a vector instead.
label1dataset = validation_dataset.unbatch().filter(select_label1)

In [96]:
# Print every label in the class-1 dataset; each should be one-hot at index 1.
for img, label, file_name, label_name in label1dataset.as_numpy_iterator():
  print(label)

[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0.]
[0. 1. 0. 0. 0

In [92]:
# Shape of the current `label` batch.
# NOTE(review): assumes `label` is still a tf.Tensor (from a take() loop);
# .numpy() fails if a later as_numpy_iterator() loop rebound it to a numpy array.
np.shape(label.numpy())

(64, 5)