In [11]:
import os

import tensorflow as tf
import tensorflow_datasets as tfds

import torch
from torch.utils.data import Dataset, IterableDataset, DataLoader

import tqdm as notebook_tqdm

from pprint import pprint
import random
import time
from itertools import cycle, islice, chain

from google.cloud import storage

In [12]:
PROJECT_ID = 'hybrid-vertex'

# Google Cloud Storage client bound to PROJECT_ID.
storage_client = storage.Client(project=PROJECT_ID)

# Shard paths follow the pattern 'train-XXXXX-of-01024'. Only shard 0 is used
# while testing; widen range(1) to read more shards.
train_files = [
    f'gs://imagenet-jt/train/train-{shard:05d}-of-01024'
    for shard in range(1)
]

# tf.data tuning: let TF choose parallelism, and allow non-deterministic element
# order in exchange for throughput.
# NOTE(review): these options only apply to a dataset after `ds.with_options(options)`;
# newer TF releases also expose this flag as `options.deterministic` — confirm
# against the installed version.
AUTOTUNE = tf.data.AUTOTUNE
options = tf.data.Options()
options.experimental_deterministic = False
# Sharding policy across distributed workers (disabled; FILE | DATA | AUTO):
# options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO

## MyIterable dataset

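> Partition the data into groups. Each of the `batch_size` parallel streams cycles through *all* groups, in its own random group order. Every batch takes one element from each stream.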

In [6]:
class MyIterableDataset(IterableDataset):
    """Iterable dataset that serves ``batch_size`` parallel streams as batches.

    Each stream walks through every group in ``data_list`` in its own random
    group order and yields that group's elements one by one. It repeats the
    pass for as long as the groups keep producing elements. Iterating the
    dataset yields tuples that hold one element from every stream, so a
    ``DataLoader(..., batch_size=None)`` receives ready-made batches.

    Args:
        data_list: sequence of groups; each group is an iterable of samples.
        batch_size: number of parallel streams, i.e. the width of each batch.
    """

    def __init__(self, data_list, batch_size):
        self.data_list = data_list
        self.batch_size = batch_size

    @property
    def shuffled_data_list(self):
        """A fresh random permutation of ``data_list`` on every access."""
        return random.sample(self.data_list, len(self.data_list))

    def process_data(self, data):
        """Yield the samples of one group (hook for per-group preprocessing)."""
        for x in data:
            yield x

    def get_stream(self, data_list):
        """Yield samples from the groups in ``data_list``, passing over them repeatedly.

        A ``chain.from_iterable(cycle(groups))`` formulation spins forever when
        a whole pass yields nothing. That happens when all groups are empty, or
        when they are one-shot iterators that are already exhausted. This
        version stops the stream in that case instead of hanging.
        """
        while True:
            produced_any = False
            for data in data_list:
                for x in self.process_data(data):
                    produced_any = True
                    yield x
            if not produced_any:
                return

    def get_streams(self):
        """Zip ``batch_size`` independently shuffled streams into batches."""
        # zip stops as soon as any stream ends.
        return zip(*[self.get_stream(self.shuffled_data_list) for _ in range(self.batch_size)])

    def __iter__(self):
        return self.get_streams()

In [7]:
# Four toy groups of different lengths. Each holds consecutive integers, so the
# printed batches make it easy to see which group a sample came from.
synthetic_data = [
    list(range(12, 18)),
    list(range(27, 30)),
    list(range(31, 40)),
    list(range(40, 44)),
]

In [8]:
# Seed Python's RNG so that the per-stream group shuffles (random.sample in
# MyIterableDataset), and therefore the printed batches, are reproducible.
random.seed(42)

iterable_dataset = MyIterableDataset(synthetic_data, batch_size=4)

# batch_size=None: the dataset already yields whole batches (one element per
# stream), so DataLoader must not batch again.
loader = DataLoader(iterable_dataset, batch_size=None)

# The streams cycle over the groups indefinitely, so take a fixed number of batches.
for batch in islice(loader, 12):
    print(batch)

[40, 40, 31, 31]
[41, 41, 32, 32]
[42, 42, 33, 33]
[43, 43, 34, 34]
[31, 27, 35, 35]
[32, 28, 36, 36]
[33, 29, 37, 37]
[34, 12, 38, 38]
[35, 13, 39, 39]
[36, 14, 40, 27]
[37, 15, 41, 28]
[38, 16, 42, 29]


## MyIterable dataset: TFRecords to torch tensors

In [None]:
class MyIterable_tf_Dataset(IterableDataset):
    """Stream TFRecord image shards as batches of torch tensors.

    Each iteration does the following:
    - Reads this process's share of ``filenames``: split first across
      accelerator replicas, then across DataLoader workers.
    - Parses the ``image/class/label`` and ``image/encoded`` features.
    - Resizes images to 224x224, channels-first.
    - Yields ``(labels, images)`` torch tensors of shapes ``(batch,)`` and
      ``(batch, 3, 224, 224)``.

    Batching happens here, so wrap it with ``DataLoader(ds, batch_size=None)``.

    Args:
        filenames: list of TFRecord paths (local or ``gs://``).
        batch_size: number of records per yielded batch.
        length: value reported by ``len()``; its unit is up to the caller
            (presumably batches per epoch — TODO confirm).
    """

    def __init__(self, filenames, batch_size, length):
        self.filenames = filenames
        self.batch_size = batch_size
        self.length = length

    def __len__(self):
        '''
        Needed for torch dataloader
        '''
        return self.length

    def process_data(self, data):
        """Yield the items of ``data`` unchanged."""
        for x in data:
            yield x

    @staticmethod
    def identity(x):
        """Return ``x`` unchanged."""
        return x

    @staticmethod
    def split_tfrecords_per_node(filenames, rank=None, num_replicas=None):
        """
        Split TFRecords per accelerator node (replica).

        :param filenames: sequence of TFRecord paths
        :param rank: index of this replica. On TPU, pass ``xm.get_ordinal()``.
            When omitted, it comes from ``torch.distributed`` if initialized,
            otherwise 0.
        :param num_replicas: total number of replicas. On TPU, pass
            ``xm.xrt_world_size()``. When omitted, it comes from
            ``torch.distributed`` if initialized, otherwise 1.
        :return: every ``num_replicas``-th filename, starting at ``rank``
        """
        if rank is None or num_replicas is None:
            import torch.distributed as dist

            if dist.is_available() and dist.is_initialized():
                default_rank, default_world = dist.get_rank(), dist.get_world_size()
            else:
                default_rank, default_world = 0, 1
            rank = default_rank if rank is None else rank
            num_replicas = default_world if num_replicas is None else num_replicas

        return list(filenames)[rank::num_replicas]

    @staticmethod
    def tfrecords_per_worker(filenames):
        """
        Split filenames per DataLoader worker.

        Selects a subset of filenames based on ``torch.utils.data.get_worker_info``.
        In the main process (no worker info), all filenames are returned.
        """
        filenames = list(filenames)

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return filenames
        return filenames[worker_info.id::worker_info.num_workers]

    @staticmethod
    def tfrecord_dataset(ds):
        """Open TFRecord file(s) ``ds`` with an enlarged read buffer."""
        buffer_size = 8 * 1024 * 1024  # 8 MiB read buffer per file
        return tf.data.TFRecordDataset(ds, buffer_size=buffer_size)

    def parse_tfrecord(self, example):
        """Decode one serialized Example into ``(label, image)``.

        The image is a float32 tensor in [0, 1], resized to (3, 224, 224).
        """
        feature_map = {
            'image/class/label': tf.io.FixedLenFeature([], tf.int64),
            'image/encoded': tf.io.FixedLenFeature([], tf.string),
        }

        # Each element from TFRecordDataset is a single scalar record.
        parsed = tf.io.parse_single_example(example, feature_map)

        label = parsed['image/class/label']
        raw_img = parsed['image/encoded']

        # channels=3 forces RGB output even for grayscale JPEGs, so every image
        # has the same channel count and batches can be stacked.
        img = tf.io.decode_jpeg(raw_img, channels=3)              # (H, W, 3) uint8
        img = tf.image.convert_image_dtype(img, tf.float32)       # scale to [0, 1]
        img = tf.image.resize(img, [224, 224])                    # (224, 224, 3)

        # torch models expect channels-first: (224, 224, 3) -> (3, 224, 224)
        transposed_tf_tensor = tf.transpose(img, perm=[2, 0, 1])

        return label, transposed_tf_tensor

    def _build_tf_dataset(self):
        """Assemble the batched tf.data pipeline over this process's file shard."""
        files = self.tfrecords_per_worker(self.split_tfrecords_per_node(self.filenames))
        # An explicit string dtype keeps an empty shard list valid for interleave.
        ds = tf.data.Dataset.from_tensor_slices(tf.constant(files, dtype=tf.string))
        ds = ds.interleave(
            self.tfrecord_dataset,
            cycle_length=tf.data.AUTOTUNE,
            num_parallel_calls=tf.data.AUTOTUNE,
            deterministic=False,
        )
        ds = ds.map(
            self.parse_tfrecord,
            num_parallel_calls=tf.data.AUTOTUNE,  # parallelize across many cores
        )
        return ds.batch(self.batch_size).prefetch(tf.data.AUTOTUNE)

    def __iter__(self):
        """Yield ``(labels, images)`` torch tensors, one pair per batch."""
        for labels, images in self._build_tf_dataset():
            # torch.tensor copies, so the result owns its memory independently of TF.
            yield torch.tensor(labels.numpy()), torch.tensor(images.numpy())

In [None]:
def list_train_files(bucket_name, prefix, client=None):
    """Return ``gs://`` URIs for the objects directly under ``prefix`` in a bucket.

    This replaces an earlier commented-out glob, which rewrote
    ``blob.public_url`` into a ``gs://`` path. That URL is URL-quoted, so object
    names with special characters came out wrong; ``blob.name`` is used instead.

    Args:
        bucket_name: GCS bucket name, e.g. ``'imagenet-jt'``.
        prefix: object-name prefix, e.g. ``'train/'``. With ``delimiter='/'``,
            objects in deeper "subdirectories" are not listed.
        client: a ``storage.Client``; defaults to the module-level ``storage_client``.

    Returns:
        List of ``gs://bucket/object`` strings in listing order. Directory
        placeholder objects (names ending in '/') are skipped.
    """
    if client is None:
        client = storage_client
    blobs = client.list_blobs(bucket_name, prefix=prefix, delimiter="/")
    return [
        f"gs://{bucket_name}/{blob.name}"
        for blob in blobs
        if not blob.name.endswith("/")
    ]


# For testing: only the first training shard. To use the full split instead:
# train_files = list_train_files('imagenet-jt', 'train/')
train_files = [
    'gs://imagenet-jt/train/train-00000-of-01024',
    # 'gs://imagenet-jt/train/train-00001-of-01024',
    # 'gs://imagenet-jt/train/train-00002-of-01024',
    # 'gs://imagenet-jt/train/train-00003-of-01024',
    # 'gs://imagenet-jt/train/train-00004-of-01024',
]