In [1]:
import tensorflow as tf
import numpy as np
import cv2 as cv
import os
import random
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import time

In [2]:
def _image_feature(value):
    """Wrap an image as a bytes_list Feature holding its PNG encoding.

    ``value`` is handed to ``tf.io.encode_png`` with compression level 9;
    the resulting PNG byte string becomes the single element of the
    BytesList stored in the returned ``tf.train.Feature``.
    """
    png_bytes = tf.io.encode_png(value, compression=9).numpy()
    payload = tf.train.BytesList(value=[png_bytes])
    return tf.train.Feature(bytes_list=payload)

def _int64_feature(value):
    """Wrap one bool / enum / int / uint value in an int64_list Feature."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def _bytes_feature(value):
    """Return a bytes_list Feature built from a string or a bytes object.

    Args:
        value: ``str`` (encoded with the default UTF-8 codec) or ``bytes``
            (stored unchanged).

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly one element.
    """
    # ``bytes`` has no ``.encode``; only encode when we were not already
    # handed raw bytes, so both advertised input kinds work.
    if not isinstance(value, bytes):
        value = value.encode()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

In [3]:
# Root of the extracted dataset. The trailing "/" matters: the paths below and
# the helpers in later cells build paths by plain string concatenation.
base = "./Quark_Gluon_Mini/Quark_Gluon_Mini/" 
# The return value is discarded here (not the last expression of the cell);
# the call still raises if `base` does not exist.
os.listdir(base)
# Split directories derived from `base`.
TRAIN = base + "Train/"
TEST =  base + "Test/" 

In [4]:
def lister(path, seed=None):
    """Collect the files of both class folders under ``path`` in shuffled order.

    ``path`` must contain the sub-directories ``0/`` and ``1/``. Every entry
    of those two directories is returned as ``<path>/<label>/<entry>``, using
    "/" as the separator (``investigator`` splits on "/").

    Args:
        path: Dataset split directory. A missing trailing "/" is tolerated.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            module-level ``random`` state is used, as before.

    Returns:
        A new list of path strings in shuffled order.

    Raises:
        OSError: if either class directory cannot be listed.
    """
    # An empty path keeps its old meaning (relative "0/" and "1/").
    root = path + "/" if path and not path.endswith("/") else path

    files = []
    for label in ("0", "1"):
        class_dir = f"{root}{label}/"
        # One listdir call per directory; sorting removes the dependence on
        # filesystem enumeration order so a seeded shuffle is repeatable.
        files.extend(class_dir + entry for entry in sorted(os.listdir(class_dir)))

    if seed is None:
        random.shuffle(files)
    else:
        # Private generator: does not disturb the global random state.
        random.Random(seed).shuffle(files)
    return files

def investigator(path):
    """Split an image path into ``(label, name)``.

    The label is the name of the file's parent directory converted with
    ``int``; the name is the final path component. Components are separated
    by "/".
    """
    # Only the last two components are needed, so split from the right.
    pieces = path.rsplit("/", 2)
    file_name = pieces[-1]
    class_id = int(pieces[-2])
    return class_id, file_name

def parser(path):
    """Build a ``tf.train.Example`` for the image file at ``path``.

    The example carries three features:
        'raw_image': the image re-encoded as PNG bytes,
        'label':     the integer class taken from the parent directory name,
        'name':      the file name as bytes.

    Args:
        path: "/"-separated path of the form ``.../<label>/<file>``.

    Returns:
        The populated ``tf.train.Example``.
    """
    image = plt.imread(path)
    if np.issubdtype(image.dtype, np.floating):
        # matplotlib returns PNGs as floats in [0, 1]. Round before casting:
        # a bare astype() truncates, so 254.99998 would become 254.
        image = np.clip(np.rint(image * 255.0), 0, 255)
    # Integer arrays (non-PNG formats) are already in pixel units and must
    # not be multiplied by 255 again.
    # NOTE(review): assumes integer inputs are 8-bit — confirm if other
    # bit depths can occur.
    image = image.astype(np.uint8)
    if image.ndim == 2:
        # encode_png expects [height, width, channels]; give grayscale
        # images an explicit single channel.
        image = image[..., np.newaxis]
    label,name = investigator(path)

    data = {
        'raw_image' : _image_feature(image),
        'label' : _int64_feature(label),
        'name' : _bytes_feature(name)
    }
    out = tf.train.Example(features=tf.train.Features(feature=data))
    return out


In [5]:
def shard_calc(path, no_per_shard = 128):
    """Return how many shards are needed to hold every file under ``path``.

    Args:
        path: Dataset split directory, as accepted by ``lister``.
        no_per_shard: Maximum number of files per shard; must be positive.

    Returns:
        ``ceil(total_files / no_per_shard)`` as an int (0 for an empty
        dataset).

    Raises:
        ValueError: if ``no_per_shard`` is not positive.
    """
    if no_per_shard <= 0:
        raise ValueError("no_per_shard must be a positive integer")
    tot =  len(lister(path))
    # Ceiling division. The former int(tot/no_per_shard)+1 reported one
    # shard too many whenever tot was an exact multiple of no_per_shard.
    shards = int(-(-tot // no_per_shard))
    return shards
def convert(path, no_per_shard, shard_name, ini_path = ""):
    """Write every image under ``path`` into sharded TFRecord files.

    Files are listed (and shuffled) once via ``lister``, cut into consecutive
    groups of at most ``no_per_shard`` and each group is serialized with
    ``parser`` into
    ``./<ini_path>/<shard_name>_shard_<k>_of<total>.tfrecords``.

    Args:
        path: Dataset split directory, as accepted by ``lister``.
        no_per_shard: Maximum number of examples per shard; must be positive.
        shard_name: Prefix used in the shard file names.
        ini_path: Output directory relative to the current directory; it is
            created if it does not exist.

    Raises:
        ValueError: if ``no_per_shard`` is not positive.
    """
    if no_per_shard <= 0:
        raise ValueError("no_per_shard must be a positive integer")

    list_ = lister(path)
    # Count shards from the list that is actually written (ceiling division):
    # this avoids listing the dataset a second time and never produces an
    # empty trailing shard when the file count is an exact multiple.
    no_shards = int(-(-len(list_) // no_per_shard))

    # TFRecordWriter does not create missing directories.
    os.makedirs(f"./{ini_path}", exist_ok=True)

    for i in range(no_shards):
        # Slicing past the end is safe, so the last shard needs no special case.
        shard_paths = list_[i*no_per_shard:(i+1)*no_per_shard]
        shard_file = f"./{ini_path}/{shard_name}_shard_{i+1}_of{no_shards}.tfrecords"
        # The context manager closes the file even if parsing an image fails.
        with tf.io.TFRecordWriter(shard_file) as writer:
            for image_path in tqdm(shard_paths):
                item = parser(image_path)
                writer.write(item.SerializeToString())

        print(f"Shard: {shard_name}_shard_{(i+1)}/{no_shards} done, moving to next")
    

In [6]:

# NOTE(review): this value is not displayed (not the last expression of the
# cell) and not stored, so the call only re-lists the dataset.
# NOTE(review): the literal "./Test/Test/" differs from the TEST constant
# defined in an earlier cell — confirm which location is intended.
shard_calc("./Test/Test/", no_per_shard=10240)
# Writes the shards to ./Test1/ with the file-name prefix "test".
convert(path="./Test/Test/", no_per_shard=10240,shard_name= "test",ini_path = "Test1")

  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_1/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_2/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_3/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_4/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_5/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_6/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_7/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_8/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_9/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_10/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_11/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_12/14 done, moving to next


  0%|          | 0/10240 [00:00<?, ?it/s]

Shard: test_shard_13/14 done, moving to next


  0%|          | 0/6186 [00:00<?, ?it/s]

Shard: test_shard_14/14 done, moving to next
