In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

import os
import re
import glob
import PIL
import time
import datetime
import gc
from urllib import request
import pathlib
import random
import gzip
import struct

from IPython import display

In [2]:
# Load the TensorBoard notebook extension
%load_ext tensorboard

# Clear any logs from previous runs
!rm -rf ./logs/ 

# Use one fixed seed for every random number generator the notebook touches
# (NumPy, TensorFlow and Python's `random`).
seed = 3939
np.random.seed(seed)
tf.random.set_seed(seed)
random.seed(seed)
# NOTE(review): this is set after the kernel has already started; Python reads
# PYTHONHASHSEED at interpreter startup, so this presumably only affects child
# processes spawned from here -- confirm that is the intent.
os.environ["PYTHONHASHSEED"] = str(seed)

In [3]:
# Per-image shape that load_mnist reshapes the decoded pixels to
# (presumably height, width, channels).
IMG_SHAPE = (28, 28, 1)
# Number of label classes; sets the width of the one-hot label rows.
LABEL_NUM = 10

In [4]:
# MNIST dataset directory under the user's home directory.
# os.path.expanduser("~") also resolves where the HOME environment variable is
# not set (e.g. Windows), where os.environ["HOME"] would raise KeyError.
# The trailing "" keeps the trailing path separator that load_mnist's
# `dataset_dir + filename` concatenation relies on.
dataset_dir = os.path.join(os.path.expanduser("~"), "Workspace", "Dataset", "mnist", "")

In [5]:
!ls $dataset_dir

jpg			   TFRecord
png			   train-images-idx3-ubyte.gz
t10k-images-idx3-ubyte.gz  train-labels-idx1-ubyte.gz
t10k-labels-idx1-ubyte.gz


In [6]:
def load_mnist(x_filename, y_filename, dataset_dir):
    """Load one gzip-compressed MNIST IDX image/label file pair.

    Args:
        x_filename: File name of the gzipped IDX image file.
        y_filename: File name of the gzipped IDX label file.
        dataset_dir: Directory prefix that is concatenated directly with the
            file names, so it must end with a path separator.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a writable ``uint8`` array
        reshaped to ``(-1, *IMG_SHAPE)``; ``labels`` is a one-hot array of
        shape ``(N, LABEL_NUM)``.

    Raises:
        RuntimeError: If the two headers disagree on the number of items, or
            if either file holds fewer payload bytes than its header declares.
    """
    x_path = dataset_dir + x_filename
    y_path = dataset_dir + y_filename
    with gzip.open(x_path) as fx, gzip.open(y_path) as fy:
        # IDX headers are big-endian int32 fields:
        #   image file: magic, item count, rows, cols
        #   label file: magic, item count
        _, n_images, rows, cols = struct.unpack('>iiii', fx.read(16))
        _, N = struct.unpack('>ii', fy.read(8))
        if N != n_images:
            raise RuntimeError('wrong pair of MNIST images and labels')

        # Read each payload in a single call instead of one byte at a time;
        # the per-byte Python loop dominated the load time.
        pixels_per_image = rows * cols
        x_raw = fx.read(N * pixels_per_image)
        y_raw = fy.read(N)

    if len(x_raw) != N * pixels_per_image or len(y_raw) != N:
        raise RuntimeError('truncated MNIST file')

    # np.frombuffer yields a read-only view over the bytes object; copy the
    # images so callers still receive a writable array as before.
    images = np.frombuffer(x_raw, dtype=np.uint8).reshape(N, pixels_per_image).copy()
    labels = np.frombuffer(y_raw, dtype=np.uint8)
    return images.reshape(-1, *IMG_SHAPE), np.eye(LABEL_NUM)[labels]

In [7]:
# Decode the combined train/validation split, then the held-out test split.
X_train_valid, y_train_valid = load_mnist(
    "train-images-idx3-ubyte.gz",
    "train-labels-idx1-ubyte.gz",
    dataset_dir,
)
X_test, y_test = load_mnist(
    "t10k-images-idx3-ubyte.gz",
    "t10k-labels-idx1-ubyte.gz",
    dataset_dir,
)

In [13]:
def get_example_proto(img, label):
    """Pack one image and its label into a ``tf.train.Example``.

    Args:
        img: Array-like exposing ``flatten()`` and ``shape``; its flattened
            values are stored as a float list and its shape as an int64 list.
        label: Integer class label, stored as a single-element int64 list.

    Returns:
        A ``tf.train.Example`` with the features ``"img"``, ``"img_shape"``
        and ``"label"``.
    """
    pixel_feature = tf.train.Feature(
        float_list=tf.train.FloatList(value=list(img.flatten()))
    )
    shape_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=list(img.shape))
    )
    label_feature = tf.train.Feature(
        int64_list=tf.train.Int64List(value=[label])
    )

    features = tf.train.Features(
        feature={
            "img": pixel_feature,
            "img_shape": shape_feature,
            "label": label_feature,
        }
    )
    return tf.train.Example(features=features)


# Parsing spec for the records produced by get_example_proto.
# The flattened image length is derived from IMG_SHAPE rather than a literal
# 28*28, so the spec stays in sync with the shape used when loading the data.
feature_description = {
    "img": tf.io.FixedLenFeature([int(np.prod(IMG_SHAPE))], tf.float32),
    # Variable-length feature: the stored shape may have any number of dimensions.
    "img_shape": tf.io.VarLenFeature(tf.int64),
    # Scalar integer class label.
    "label": tf.io.FixedLenFeature([], tf.int64),
}

def _parse_example_proto(example_proto):
    return tf.io.parse_single_example(example_proto, feature_description)

In [14]:
# Build a one-element sample: the first training image packed with its class index.
i = 0
# y_train_valid rows are one-hot (see load_mnist), so argmax recovers the integer label.
example_proto = get_example_proto(X_train_valid[i], label=np.argmax(y_train_valid[i]))
example_protos = [example_proto]

In [15]:
# Write every Example in example_protos to a TFRecord file (relative path, so it
# lands in the current working directory).
record_file = 'test.tfrecords'
with tf.io.TFRecordWriter(record_file) as writer:
  for example_proto in example_protos:
    # The writer takes bytes, so each proto is serialized before writing.
    writer.write(example_proto.SerializeToString())

In [16]:
# Open the TFRecord file written above as a dataset of serialized records.
raw_dataset = tf.data.TFRecordDataset([record_file])

# Decode each record with _parse_example_proto (which applies feature_description).
parsed_dataset = raw_dataset.map(_parse_example_proto)
# Bare expression so the notebook displays the dataset's repr.
parsed_dataset

2022-10-07 13:50:27.968859: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-10-07 13:50:28.426965: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 900 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3090, pci bus id: 0000:18:00.0, compute capability: 8.6


<MapDataset shapes: {img_shape: (None,), img: (784,), label: ()}, types: {img_shape: tf.int64, img: tf.float32, label: tf.int64}>

In [18]:
# Sanity check: iterate the parsed records and print each one's flattened "img" values.
for image_features in parsed_dataset:
  feature = image_features['img']
  print(feature)

tf.Tensor(
[  0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   0.   3.  18.
  18.  18. 126. 136. 175.  26. 166. 255. 247. 127.   0.   0.   0.   0.
   0.   0.   0.   0.   0.   0.   0.   0.  30.  36.  94. 154. 170. 253.
 253. 253. 253. 253. 225. 172. 253. 242. 195.  64.   0.   0.   0. 