In [16]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [52]:
def _bytes_feature(value):
    if isinstance(value, type(tf.constant(0))):
        value=value.numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [53]:
def _float_features(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

In [54]:
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

In [6]:
def save_npyTopng(src_path, img_name):
    try:
        img=np.load(src_path+img_name.replace('.png', '.npy'))
        name=img_name.replace('.npy', '.png')
        cv2.imwrite(src_path+name, img)
    except:
        return 0

In [55]:
def make_example(plain, segment, img_name):
    feature = {
        'plain_en': _bytes_feature(plain)
        , 'segment_en': _bytes_feature(segment)
        , 'file_name': _bytes_feature(img_name)
    }
    
    return tf.train.Example(features=tf.train.Features(feature=feature))

In [57]:
import random
from tqdm import tqdm 
import os
import cv2

# train 하나 만들고, vaild 하나 만들자

def valid_main(dataset_path, output_path='/root/Public_Storage/Crop_Seg/valid_crops.tfrecord'):
    samples=[]
    print("Reading data list...")
    for img_name in tqdm(os.listdir(dataset_path+'validation/plain/')):
        if '.ipynb' in img_name:
            continue
        plain_path=dataset_path+'validation/plain/'+img_name
        # save_npyTopng(dataset_path+'validation/label/', img_name)
        seg_path=dataset_path+'validation/label/'+img_name
        # print(plain_path)
        # print(seg_path)
        samples.append((plain_path, seg_path, img_name))
    # print(np.array(samples))
    
    print("Writing tfrecord file...")
    with tf.io.TFRecordWriter(output_path) as writer:
        for plain_path, seg_path, img_name in tqdm(samples):
            if '.ipynb' in img_name:
                continue
            tf_example=make_example(plain=open(plain_path, 'rb').read()
                                   , segment=open(seg_path, 'rb').read()
                                   , img_name=str.encode(img_name))
            writer.write(tf_example.SerializeToString())

In [64]:
import random
from tqdm import tqdm 
import os
import cv2

# train 하나 만들고, vaild 하나 만들자

def train_main(dataset_path, output_path='/root/Public_Storage/Crop_Seg/train_crops.tfrecord'):
    samples=[]
    print("Reading data list...")
    for img_name in tqdm(os.listdir(dataset_path+'training/plain/')):
        if '.ipynb' in img_name:
            continue
        plain_path=dataset_path+'training/plain/'+img_name
        # save_npyTopng(dataset_path+'training/label/', img_name)
        seg_path=dataset_path+'training/label/'+img_name
        # print(plain_path)
        # print(seg_path)
        samples.append((plain_path, seg_path, img_name))
    # print(np.array(samples))
    
    print("Writing tfrecord file...")
    with tf.io.TFRecordWriter(output_path) as writer:
        for plain_path, seg_path, img_name in tqdm(samples):
            if '.ipynb' in img_name:
                continue
            tf_example=make_example(plain=open(plain_path, 'rb').read()
                                   , segment=open(seg_path, 'rb').read()
                                   , img_name=str.encode(img_name))
            writer.write(tf_example.SerializeToString())
        

In [65]:
dataset_path='/root/Public_Storage/Crop_Seg/base_dir_vinyl_and_crop/'

In [59]:
valid_main(dataset_path)

Reading data list...


100% 10001/10001 [00:00<00:00, 505376.19it/s]


Writing tfrecord file...


100% 10000/10000 [02:22<00:00, 70.27it/s]


In [None]:
train_main(dataset_path)

Reading data list...


100% 81442/81442 [00:00<00:00, 942495.75it/s]

Writing tfrecord file...



 88% 71647/81441 [22:36<02:14, 72.76it/s]  

In [103]:
def _parse_tfrecord():
    def parse_tfrecord(tfrecord):
        features = {
            'image/plain_en': tf.io.FixedLenFeature([], tf.string)
            ,'image/segment_en': tf.io.FixedLenFeature([], tf.string)
            ,'image/filename': tf.io.FixedLenFeature([], tf.string)
        }
        x=tf.io.parse_single_example(tfrecord, features)
        x_train=tf.image.decode_png(x['image/plain_en'], channels=3)
        # print(x_train)
        y_train=tf.image.decode_png(x['image/segment_en'], channels=3)
        # print(y_train)
        return x_train, y_train
    return parse_tfrecord

In [108]:
def load_tfrecord_dataset(tfrecord_name, batch_size, shuffle=True, buffer_size=1024):
    raw_dataset=tf.data.TFRecordDataset(tfrecord_name)
    # print(raw_dataset)
    # raw_dataset=raw_dataset.repeat()
    if shuffle:
        raw_dataset=raw_dataset.shuffle(buffer_size=buffer_size)
    dataset=raw_dataset.map(
        _parse_tfrecord()
        , num_parallel_calls=tf.data.experimental.AUTOTUNE
    )
    dataset=dataset.batch(batch_size)
    dataset=dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dataset

In [109]:
tf_record=load_tfrecord_dataset('/root/Public_Storage/Crop_Seg/train_crops.tfrecord', 32)
tf_record = iter(tf_record)
inputs, labels=next(tf_record)

In [111]:
tf_record=load_tfrecord_dataset('/root/Public_Storage/Crop_Seg/train_crops.tfrecord', 32)

tf_record = iter(tf_record)

import time

start = time.time()

for _ in tqdm(range(2600)):
    inputs, labels=next(tf_record)
    
print(time.time()-start)

 98% 2546/2600 [04:15<00:05,  9.95it/s]


StopIteration: 

In [107]:
next(tf_record)[0]

<tf.Tensor: shape=(32, 512, 512, 3), dtype=uint8, numpy=
array([[[[ 72,  82,  50],
         [ 72,  82,  50],
         [ 75,  86,  53],
         ...,
         [ 70,  74,  46],
         [ 62,  63,  38],
         [ 62,  63,  38]],

        [[ 72,  82,  50],
         [ 72,  82,  50],
         [ 75,  86,  53],
         ...,
         [ 70,  74,  46],
         [ 62,  63,  38],
         [ 62,  63,  38]],

        [[ 71,  82,  51],
         [ 71,  82,  51],
         [ 73,  82,  50],
         ...,
         [ 62,  63,  38],
         [ 64,  65,  39],
         [ 64,  65,  39]],

        ...,

        [[ 98, 113, 104],
         [ 98, 113, 104],
         [100, 114, 104],
         ...,
         [ 66,  70,  40],
         [ 62,  68,  38],
         [ 62,  68,  38]],

        [[101, 115, 105],
         [101, 115, 105],
         [ 99, 114, 103],
         ...,
         [ 68,  70,  43],
         [ 62,  64,  38],
         [ 62,  64,  38]],

        [[101, 115, 105],
         [101, 115, 105],
         [ 99, 11