In [71]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import os
import re
import glob
import PIL
import time
import datetime
import gc
from urllib import request
import pathlib
import random
import gzip
import struct

from IPython import display

In [72]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Clear any logs from previous runs
!rm -rf ./logs/ 

seed = 3939
np.random.seed(seed)
tf.random.set_seed(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [78]:
IMG_SHAPE = (28, 28, 1)
IMG_SIZE = np.prod(IMG_SHAPE)
LABEL_NUM = 10

In [74]:
dataset_dir = os.environ["HOME"] + "/Workspace/Dataset/mnist/"

In [75]:
!ls $dataset_dir

jpg			   TFRecord
png			   train-images-idx3-ubyte.gz
t10k-images-idx3-ubyte.gz  train-labels-idx1-ubyte.gz
t10k-labels-idx1-ubyte.gz


In [76]:
def load_mnist(x_filename, y_filename, dataset_dir):
    x_path = dataset_dir + x_filename
    y_path = dataset_dir + y_filename
    with gzip.open(x_path) as fx, gzip.open(y_path) as fy:
        fx.read(4)
        fy.read(4)
        N, = struct.unpack('>i', fy.read(4))
        if N != struct.unpack('>i', fx.read(4))[0]:
            raise RuntimeError('wrong pair of MNIST images and labels')
        fx.read(8)

        images = np.empty((N, 784), dtype=np.uint8)
        labels = np.empty(N, dtype=np.uint8)

        for i in range(N):
            labels[i] = ord(fy.read(1))
            for j in range(784):
                images[i, j] = ord(fx.read(1))
    return images.reshape(-1, *IMG_SHAPE), np.eye(LABEL_NUM)[labels]

In [77]:
X_train_valid, y_train_valid = load_mnist("train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", dataset_dir)
X_test, y_test = load_mnist("t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz", dataset_dir)

In [107]:
def get_example_proto(img, label):
    return tf.train.Example(
                features=tf.train.Features(
                    feature={
                        "img": tf.train.Feature(
                            float_list=tf.train.FloatList(
                                value=list(img.flatten())
                            )
                        ),

                        "label": tf.train.Feature(
                            int64_list=tf.train.Int64List(
                                value=[label]
                            )
                        ),
                    }
                )
            )


feature_description = {
    "img": tf.io.FixedLenFeature([IMG_SIZE], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

def _parse_example_proto(example_proto):
    parsed = tf.io.parse_single_example(example_proto, feature_description)
    parsed["img"] = tf.reshape(parsed['img'], IMG_SHAPE)
    return parsed

In [108]:
example_protos = []
for i in range(len(X_train_valid)):
    example_proto = get_example_proto(X_train_valid[i], label=np.argmax(y_train_valid[i]))
    example_protos.append(example_proto)
record_file = 'train_valid.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for example_proto in example_protos:
    writer.write(example_proto.SerializeToString())



In [109]:
raw_dataset = tf.data.TFRecordDataset([record_file])
parsed_dataset = raw_dataset.map(_parse_example_proto)
for image_features in parsed_dataset:
  # feature = tf.reshape(image_features['img'], IMG_SHAPE)
  print(image_features['label'])
  print(image_features["img"])
  break
  # print(feature)

tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(
[[[  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]]

 [[  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]]

 [[  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]]

 [[  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  [  0.]
  

: 

In [103]:
example_protos = []
for i in range(len(X_test)):
    example_proto = get_example_proto(X_test[i], label=np.argmax(y_test[i]))
    example_protos.append(example_proto)
record_file = 'test.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for example_proto in example_protos:
    writer.write(example_proto.SerializeToString())
raw_dataset = tf.data.TFRecordDataset([record_file])