In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tqdm import tqdm

In [2]:
# Per-image metadata table; later cells read its 'file' and 'label' columns.
labels_df = pd.read_csv(filepath_or_buffer='data/labels.csv')

In [3]:
import skimage.io as sio
import os.path as op

In [4]:
# Read each image named in labels_df['file'] (in row order) and flatten it to 1-D,
# so row i of img_data lines up with row i of labels_df.
cell_images = [
    sio.imread(op.join('data/cells/', cell_file)).ravel()
    for cell_file in labels_df['file']
]

img_data = np.array(cell_images)

In [5]:
# Label column, in the same row order as img_data.
labels_data = labels_df.loc[:, 'label']

In [6]:
# One-hot encode the labels: one row per example, one column per distinct label
# (columns ordered by sorted label value). Indexing with np.unique's inverse
# indices, rather than the raw label values, keeps this correct when labels are
# not exactly the integers 0..k-1 (e.g. 1-based, gapped, or string labels);
# for 0..k-1 labels the result is unchanged.
class_values, class_idx = np.unique(labels_data, return_inverse=True)
one_hot_labels = np.zeros((labels_data.shape[0], class_values.shape[0]))
one_hot_labels[np.arange(labels_data.shape[0]), class_idx] = 1

In [7]:
import os.path as op

# Destination paths for the serialized train and test splits.
tfrecords_train_file, tfrecords_test_file = (
    op.join('data', 'cells_train.tfrecords'),
    op.join('data', 'cells_test.tfrecords'),
)

In [8]:
# Shuffle the example indices once, then split them into train/test.
#
# One MUST randomly shuffle the data before writing it to TFRecord files;
# otherwise TensorFlow's out-of-core shuffling can't be relied on to randomize it.
#
# The seed makes the split reproducible. This matters because the write cells
# skip any output file that already exists. With an unseeded shuffle, a re-run
# that rewrites only one of the two files would use a different permutation,
# and examples could leak between train and test.
SPLIT_SEED = 42
prop_train = 0.8

idx = np.arange(img_data.shape[0])
np.random.RandomState(SPLIT_SEED).shuffle(idx)

n_train = int(prop_train * idx.shape[0])
train_idx = idx[:n_train]
test_idx = idx[n_train:]

# Invariants: the splits partition all examples with no overlap.
assert train_idx.shape[0] + test_idx.shape[0] == img_data.shape[0]
assert np.intersect1d(train_idx, test_idx).size == 0

In [9]:
def write_tfrecords(img_data, labels_data, fname, idx):
    """Serialize selected examples to a TFRecord file of tf.train.Example protos.

    Each record holds two int64 features: 'label' (a single value) and
    'image' (the example's row of img_data, cast to int64).

    Parameters
    ----------
    img_data : numpy.ndarray
        One flattened image per row; row i is selected by position i.
    labels_data : array-like or pandas.Series
        One label per example, in the same row order as img_data.
    fname : str
        Destination path. Records are first written to fname + '.tmp'.
        That file is renamed to fname only after every record has been written.
        An interrupted run therefore never leaves a truncated file at fname,
        which a caller checking only for the file's existence would treat as
        complete.
    idx : iterable of int
        Row positions to write, in output order.
    """
    # Select labels by position, like img_data. A pandas Series indexed with []
    # looks up by index label instead, which misaligns if its index is not 0..n-1.
    labels = np.asarray(labels_data)
    tmp_fname = fname + '.tmp'

    # The context manager closes the writer even if the loop raises.
    with tf.python_io.TFRecordWriter(tmp_fname) as writer:
        # Wrap the indices with tqdm for a progress bar.
        for example_idx in tqdm(idx):
            # tolist() converts the int64 row to Python ints in one vectorized
            # step instead of calling int() on every pixel.
            pixels = img_data[example_idx].astype(np.int64).tolist()
            label = int(labels[example_idx])

            # Example -> Features -> map of feature name to Feature. Each Feature
            # holds exactly one of int64_list, float_list or bytes_list.
            example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'label': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=[label])),
                        'image': tf.train.Feature(
                            int64_list=tf.train.Int64List(value=pixels)),
                    }))
            # Serialize the proto to a string and append it to the file.
            writer.write(example.SerializeToString())

    # Publish the file under its final name only once it is complete.
    tf.gfile.Rename(tmp_fname, fname, overwrite=True)


In [10]:
# Write the training split, unless a file already exists at that path.
if not op.exists(tfrecords_train_file):
    write_tfrecords(img_data, labels_data,
                    fname=tfrecords_train_file, idx=train_idx)

100%|██████████| 675/675 [07:37<00:00,  1.65it/s]


In [11]:
# Write the test split, unless a file already exists at that path.
if not op.exists(tfrecords_test_file):
    write_tfrecords(img_data, labels_data,
                    fname=tfrecords_test_file, idx=test_idx)

100%|██████████| 169/169 [01:51<00:00,  1.33it/s]
