In [0]:
import tensorflow as tf

from googleapiclient.discovery import build
from google.colab import auth

import io
from googleapiclient.http import MediaIoBaseDownload

import os

In [0]:
session = tf.InteractiveSession()

In [0]:
BATCH_SIZE = 8

# LOCAL_DATA_DIR = "./data/diagrams"
# GDRIVE_DATA_FOLDER_ID = "1-TLAH-EpZnCeX8KPVERIlFEQchmozqaA"
# TFRECORD_PATTERN = "diagrams.tf_record-?????-of-?????"
# FEATURE_SIZE = 4

LOCAL_DATA_DIR = "./data/quickdraw"
GDRIVE_DATA_FOLDER_ID = "1vMusQ1HQjrJbKO-ebnn5OLQc4iIm1MFL"
TFRECORD_PATTERN = "training.tfrecord-?????-of-?????"
FEATURE_SIZE = 3

In [0]:
auth.authenticate_user()
drive_service = build('drive', 'v3')
gdrive_query = drive_service.files().list(q="'" + GDRIVE_DATA_FOLDER_ID + "' in parents", includeTeamDriveItems=True, supportsTeamDrives=True).execute()
file_ids = [f['id'] for f in gdrive_query['files']]
file_names = [f['name'] for f in gdrive_query['files']]
print(file_ids)
print(file_names)

['10l498QaMa898uPM4HPaAlHJEY1vJ-FBc', '1jY88LyfvZXtAzHGVj7azAdm3g6yqSctO', '1D597gVPD5VOeMnXFxagQS4XPkv3kpfsZ']
['training.tfrecord-00002-of-00010', 'training.tfrecord-00001-of-00010', 'training.tfrecord-00000-of-00010']


In [0]:
def download_and_save_tf_file(file_id, save_path):
  request = drive_service.files().get_media(fileId=file_id)
  with open(save_path, 'wb') as f:
    downloader = MediaIoBaseDownload(f, request)
    done = False
    while not done:
      # _ is a placeholder for a progress object that we ignore.
      # (Our file is small, so we skip reporting progress.)
      status, done = downloader.next_chunk()
      # print("Download %d%%." % int(status.progress() * 100))

In [0]:
os.makedirs(LOCAL_DATA_DIR, exist_ok=True)

# Download data.
if len(os.listdir(LOCAL_DATA_DIR)) == 0:
  print("Downloading Data...")
  for file_id, file_name in zip(file_ids, file_names):
    download_and_save_tf_file(file_id, os.path.join(LOCAL_DATA_DIR, file_name))

Downloading Data...


In [0]:
def _parse_tfexample_fn(example_proto):
    feature_to_type = {
        "ink": tf.VarLenFeature(dtype=tf.float32),
        "shape": tf.FixedLenFeature([2], dtype=tf.int64),
        }

    parsed_features = tf.parse_single_example(example_proto, feature_to_type)
    parsed_features["ink"] = tf.sparse.to_dense(parsed_features["ink"])
    return parsed_features

In [0]:
dataset = tf.data.TFRecordDataset.list_files(os.path.join(LOCAL_DATA_DIR, TFRECORD_PATTERN))
dataset = dataset.repeat()
dataset = dataset.interleave(tf.data.TFRecordDataset, cycle_length=10, block_length=1)
dataset = dataset.map(_parse_tfexample_fn, num_parallel_calls=4)
dataset = dataset.prefetch(100)
dataset = dataset.padded_batch(BATCH_SIZE, padded_shapes=dataset.output_shapes)
features = dataset.make_one_shot_iterator().get_next()

Instructions for updating:
Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.


In [0]:
op_ink = tf.reshape(features["ink"], [BATCH_SIZE, -1, FEATURE_SIZE])
op_shape = features["shape"]
print(op_shape.eval())
print(op_ink.eval().shape)

[[55  3]
 [80  3]
 [42  3]
 [80  3]
 [42  3]
 [55  3]
 [80  3]
 [42  3]]
(8, 55, 3)


# Instead of downloading the whole data at the beginning, try to stream it. It doesn't work until so far.

In [0]:
def download_tf_file(file_id):
  request = drive_service.files().get_media(fileId=file_id)
  fh = io.BytesIO()
  downloader = MediaIoBaseDownload(fh, request)
  done = False
  while done is False:
      status, done = downloader.next_chunk()
      print("Download %d%%." % int(status.progress() * 100))

  return fh.getvalue()

In [0]:
serialized_example = download_tf_file(file_ids[0])
feature_to_type = {
        "ink": tf.VarLenFeature(dtype=tf.float32),
        "shape": tf.FixedLenFeature([2], dtype=tf.int64),
        }

parsed_features = tf.parse_single_example(serialized_example, features=feature_to_type)
parsed_features["ink"] = tf.sparse.to_dense(parsed_features["ink"])

Download 50%.
Download 100%.
Instructions for updating:
Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.


In [0]:
# ink = tf.sparse_tensor_to_dense(parsed_features["ink"])
# print(ink)
# print(ink.shape)
shape_strokes = parsed_features["shape"]
print(shape_strokes.eval())