In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
import sys
sys.path.append('/content/gdrive/MyDrive/stylegan2_tobigs/')

In [3]:
import os
import glob
import numpy as np
import tensorflow as tf

from tqdm import tqdm
# from tf_utils import allow_memory_growth

# TFRecords
# http://solarisailab.com/archives/2603
def _bytes_feature(value):
    """Returns a bytes_list from a string / byte."""
    if isinstance(value, type(tf.constant(0))):
        value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def _int64_list_feature(values):
    """Returns an int64_list from a bool / enum / int / uint."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_list_feature(value):
    """Returns a float_list from a float / double."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))

# serialize 직렬화? byte 형태로 변환 -> TFRecords
def serialize_example(image_str, label):
    feature = {
        'label': _float_list_feature(label),
        'image': _bytes_feature(image_str),
    }

    # Create a Features message using tf.train.Example.
    # tf.train.Example을 이용해서 Feature messeage를 생성합니다.
    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString() # SerializeToString 함수를 이용해서 binary string으로 변환


def tf_serialize_example(image_str, label):
    tf_string = tf.py_function(
        serialize_example,
        (image_str, label),
        tf.string)
    return tf.reshape(tf_string, ())  # The result is a scalar


def parse_fn(image_fn, labels, res):
    # load image: [res, res, 3]
    image = tf.io.read_file(image_fn)
    # for decode_bmp, decode_gif, decode_jpeg, and decode_png
    # to convert the input bytes string into a Tensor of type dtype.
    image = tf.io.decode_image(image, channels=3, dtype=tf.uint8)
    image.set_shape([None, None, 3]) # 3차원 RGB
    image = tf.image.resize(image, size=[res, res]) # resize!
    image.set_shape([res, res, 3])
    image = tf.cast(image, dtype=tf.uint8)
    image_str = tf.image.encode_png(image)

    # set labels
    return image_str, labels


def create_tfrecord_data(input_data, output_fn, res):
    image_fns = input_data['image_fns']
    labels = input_data['labels']

    dataset = tf.data.Dataset.from_tensor_slices((image_fns, labels))
    dataset = dataset.map(lambda f, l: parse_fn(f, l, res), num_parallel_calls=8) # image_fn -> image_str
    dataset = dataset.map(lambda s, l: tf_serialize_example(s, l), num_parallel_calls=8)
    # TFRecords 쓰기
    writer = tf.data.experimental.TFRecordWriter(output_fn)
    writer.write(dataset)
    return

def raw_data_to_npy(txt_path, parse_first_line=False):
    with open(txt_path) as f:
        raw_data = f.readlines()

    raw_data = sorted(raw_data, key=lambda x:x.split(', ')[0])

    if parse_first_line:
        raw_data = [rd.split(', ') for rd in raw_data]
    else:
        raw_data = [rd.split(', ')[1:] for rd in raw_data]
    
    npy_data = np.array(raw_data).astype(np.float)
    return npy_data



def main():
    # allow_memory_growth() # 메모리 증가를 허용

    # prepare variables
    res = 512
    divide = 1000
    data_base_dir = '/content/gdrive/MyDrive/stylegan2_tobigs/data/REAL0610'
    dst_tfrecord_dir = os.path.join(data_base_dir, 'tfrecords')
    if not os.path.exists(dst_tfrecord_dir):
        os.makedirs(dst_tfrecord_dir)

    # load image file names
    image_fns = glob.glob(os.path.join(data_base_dir, 'Resize_Image', '**/*.png')) # .png 파일 다
    image_fns = sorted(image_fns)

    # load labels
    label_fn = os.path.join(data_base_dir, 'label.txt')
    label_fn = raw_data_to_npy(label_fn)

    # start converting
    n_total = len(image_fns)
    print(n_total)
    interval_list = [(v * divide, (v + 1) * divide) for v in range(n_total // divide)]
    for fn_idx, (start_idx, end_idx) in enumerate(tqdm(interval_list)):
        output_fn = os.path.join(dst_tfrecord_dir, f'{fn_idx:04d}.tfrecord') # 1000개씩 0001.tfrecord, 0002.tfrecord 식

        sliced_data = {
            'image_fns': image_fns[start_idx:end_idx],
            'labels': label_fn[start_idx:end_idx],
        }
        create_tfrecord_data(sliced_data, output_fn, res)
    print('DONE!!')
    return


# if __name__ == '__main__':
#     main()

In [None]:
main()

  0%|          | 0/11 [00:00<?, ?it/s]

11661


 36%|███▋      | 4/11 [29:27<51:42, 443.21s/it]

KeyboardInterrupt: ignored

In [None]:
import os
import glob
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt


def parse_tfrecord(raw_record, res):
    feature_description = {
        'label': tf.io.FixedLenFeature([259], tf.float32),
        'image': tf.io.FixedLenFeature([], tf.string, default_value=''),
    }

    # parse feature
    parsed = tf.io.parse_single_example(raw_record, feature_description)

    # labels
    labels = tf.reshape(parsed['label'], shape=[259])

    # image
    image = tf.io.decode_png(parsed['image'])
    image = tf.image.resize(image, size=[res, res])
    image = tf.transpose(image, perm=[2, 0, 1])
    image = tf.cast(image, dtype=tf.dtypes.float32)
    image = image / 127.5 - 1.0
    image.set_shape([3, res, res])
    return image, labels


def input_fn(filenames, res, batch_size, epochs):
    files = tf.data.Dataset.from_tensor_slices(filenames)
    dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x),
                               # cycle_length: number of files to read concurrently
                               # If you want to read from all your files(*.tfrecord) to create a batch,
                               # set this to the number of files
                               cycle_length=len(filenames),
                               # block_length: each time we read from a file, reads block_length elements from this file
                               block_length=1)
    dataset = dataset.map(lambda x: parse_tfrecord(x, res), num_parallel_calls=8)
    dataset = dataset.shuffle(buffer_size=1000, reshuffle_each_iteration=True)
    dataset = dataset.repeat(epochs)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=None)
    return dataset


def get_dataset(tfrecord_dir, res, batch_size, is_train):
    # get dataset
    filenames = glob.glob(os.path.join(tfrecord_dir, '*.tfrecord'))
    print(filenames)
    filenames = sorted(filenames)
    train_filenames = filenames[:-1]
    val_filenames = filenames[-1:]
    filenames = train_filenames if is_train else val_filenames

    # create dataset
    dataset = input_fn(filenames, res, batch_size=batch_size, epochs=None)
    return dataset


def get_label_depth():
    # define label spec
    label_depths = {
        'label_dim': 254,
        'person_dim': 5,
        # 'label_dim': 258,
        # 'person_dim': 4,
        # 'keypoint_dim': 254,
    }
    return label_depths


def main():
    import numpy as np
    from PIL import Image

    # res = 256
    res = 512
    data_base_dir = '/content/gdrive/MyDrive/stylegan2_tobigs/data/REAL0610'
    folder_name = 'tfrecords'
    tfrecord_dir = os.path.join(data_base_dir, folder_name)
    print(tfrecord_dir)

    batch_size = 2
    is_train = True

    # label_depths = get_label_depth()
    dataset = get_dataset(tfrecord_dir, res, batch_size, is_train)

    for ii in range(3): # 이미지, 라벨 3개만 출력
        for real_images, labels in dataset.take(1):
            image_raw = real_images.numpy()
            image_raw = image_raw[0]
            image_raw = np.transpose(image_raw, axes=[1, 2, 0])
            image_raw = (image_raw + 1.0) * 127.5  # -1~1 -> 0~255
            image = Image.fromarray(np.uint8(image_raw))
            image.show()
            plt.figure(figsize=(40,40))
            plt.subplot(1,3,ii+1)
            plt.imshow(image)
            plt.scatter(list(np.array(labels[0][:-5:2])*512), list(np.array(labels[0][1:-5:2])*512), s=0.2)
            # print(labels[0])
            print(len(labels[0]))
    return


# if __name__ == '__main__':
#     main()

In [None]:
main()

Output hidden; open in https://colab.research.google.com to view.