In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os
from tqdm import tqdm
from joblib import Parallel, delayed
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from functools import partial
from sklearn.model_selection import StratifiedKFold

In [None]:
import sys
sys.path.append('../src')

from dataset import CassavaLeafDataModule

In [None]:
def _bytes_feature(value):
    """Wrap a single string / bytes value in a bytes_list ``tf.train.Feature``.

    If ``value`` is an instance of the tensor type produced by
    ``tf.constant`` it is first converted with ``.numpy()``, because
    BytesList cannot unpack a string held inside an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
 
def _float_feature(value):
    """Returns a float_list from a float / double or an array of them.

    Inputs that expose ``reshape`` (e.g. NumPy arrays) are flattened with
    ``reshape(-1)`` exactly as before and stored element by element.
    Anything else -- a plain Python float, or a list/tuple of numbers --
    is first converted to a float64 NumPy array, so scalar inputs no
    longer fail with AttributeError.
    """
    if not hasattr(value, 'reshape'):
        # Scalars and plain sequences have no reshape(); promote them to an
        # array so the flattening below works uniformly.
        value = np.asarray(value, dtype=np.float64)
    return tf.train.Feature(float_list=tf.train.FloatList(value=value.reshape(-1)))
 
def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in an int64_list ``tf.train.Feature``."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

def serialize_example(feature1, feature2):
    """Serialize one (image, target) pair as a ``tf.train.Example`` byte string.

    ``feature1`` is stored under the key 'image' and ``feature2`` under
    'target'; both are encoded as float lists via ``_float_feature``.
    """
    features = tf.train.Features(feature={
        'image': _float_feature(feature1),
        'target': _float_feature(feature2),
    })
    return tf.train.Example(features=features).SerializeToString()
 
# Read an image from disk and convert it from BGR to RGB (no resizing is done).
def read_image(image_path):
    """Load an image file and return it as an RGB array.

    The file is read with ``cv2.imread`` (which yields BGR channel order)
    and converted with ``cv2.COLOR_BGR2RGB``. No resizing is performed.

    Raises:
        FileNotFoundError: if ``cv2.imread`` could not read ``image_path``.
    """
    # cv2 is not imported anywhere else in this notebook, so import it here
    # to keep the function usable on a fresh kernel.
    import cv2

    img = cv2.imread(image_path)
    if img is None:
        # imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        raise FileNotFoundError('Could not read image: {}'.format(image_path))
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def write_file(data, torv, foldi, records_per_file=1000):
    """Write ``data`` to numbered TFRecord shards for one fold.

    Shards are written to ``../tfrecord/fold{foldi}/{torv}{NN}.tfrec``
    where ``NN`` is the zero-padded shard index. Each item ``data[i]``
    must unpack into ``(img, target)``; the pair is serialized with
    ``serialize_example``.

    Args:
        data: indexable collection supporting ``len()`` and ``data[i]``.
        torv: file-name prefix for the shards (callers pass 'train' / 'val').
        foldi: fold identifier used in the output directory name.
        records_per_file: maximum number of examples per shard.

    Raises:
        ValueError: if ``records_per_file`` is smaller than 1.

    Note:
        No shard is written for an empty ``data``, and no empty trailing
        shard is produced when ``len(data)`` is an exact multiple of
        ``records_per_file``.
    """
    if records_per_file < 1:
        raise ValueError('records_per_file must be a positive integer')

    total = len(data)
    # Ceiling division: ``total // size + 1`` would add an empty shard
    # whenever total is an exact multiple of the shard size.
    num_shards = (total + records_per_file - 1) // records_per_file

    # The context manager closes the progress bar even if writing fails.
    with tqdm(total=total) as pbar:
        for fi in range(num_shards):
            path = '../tfrecord/fold{}/{}{:02}.tfrec'.format(foldi, torv, fi)
            start = fi * records_per_file
            stop = min(start + records_per_file, total)
            with tf.io.TFRecordWriter(path) as writer:
                for i in range(start, stop):
                    img, target = data[i]
                    writer.write(serialize_example(img, target))
                    pbar.update(1)
        

def get_tf_records(df, trn_idx, val_idx, fold):
    """Write the train and validation TFRecord shards for one fold.

    Ensures ``../tfrecord/fold{fold}`` exists, builds a
    ``CassavaLeafDataModule`` from ``df`` and the two index arrays, calls
    its ``setup()``, and passes its ``train_dataset`` and ``val_dataset``
    to ``write_file`` under the prefixes 'train' and 'val'.
    """
    out_dir = f'../tfrecord/fold{fold}'
    os.makedirs(out_dir, exist_ok=True)

    data_module = CassavaLeafDataModule(df, trn_idx, val_idx)
    data_module.setup()

    named_datasets = (
        ('train', data_module.train_dataset),
        ('val', data_module.val_dataset),
    )
    for prefix, dataset in named_datasets:
        write_file(dataset, prefix, fold)
        

In [None]:
import time

In [None]:
# Baseline timing: how long does it take to pull batches from the
# CassavaLeafDataModule training dataloader of the first fold?
df_train = pd.read_csv('../input/train.csv')
splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=1)
folds = splitter.split(np.arange(df_train.shape[0]), df_train.label.values)

# Only the first fold is benchmarked, so take a single split.
fold, (trn_idx, val_idx) = next(enumerate(folds))

dm = CassavaLeafDataModule(df_train, trn_idx, val_idx)
dm.setup()
dt, dv = dm.train_dataset, dm.val_dataset
loader = dm.train_dataloader()

t = time.time()
for i, _ in enumerate(loader):
    # Stop as soon as the batch with index 100 has been fetched.
    if i == 100:
        break
print(time.time()-t)

In [None]:
# Build a fresh split generator: 5 stratified, shuffled folds over the
# row positions of train.csv, with a fixed seed for reproducibility.
df_train = pd.read_csv('../input/train.csv')
row_positions = np.arange(df_train.shape[0])
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=1).split(
    row_positions, df_train.label.values)

In [None]:
# Write the TFRecord shards of every fold, using up to 4 joblib worker jobs.
Parallel(n_jobs=4)(
    delayed(get_tf_records)(df_train, trn_idx, val_idx, fold)
    for fold, (trn_idx, val_idx) in enumerate(folds)
)

In [7]:
import torch
import glob
from tfrecord.torch.dataset import MultiTFRecordDataset

# The '{}' placeholder is filled with the two-digit shard index.
tfrecord_pattern = "../tfrecord/fold0/train{}.tfrec"
# Count the fold-0 training shards that exist on disk.
len_file = len(glob.glob(tfrecord_pattern.format('*')))
# Map every shard key ('00', '01', ...) to the same weight 1/len_file.
splits = dict.fromkeys(('{:02}'.format(i) for i in range(len_file)), 1/len_file)

In [8]:
# Both stored features are read back as float values.
description = dict(image="float", target="float")
dataset = MultiTFRecordDataset(
    tfrecord_pattern,
    None,  # no index pattern supplied
    splits,
    description,
)
loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,
)

In [9]:
# Time how long it takes to fetch batches 0..100 from the TFRecord-backed loader.
t = time.time()
for i, _ in enumerate(loader):
    if i != 100:
        continue
    break
elapsed = time.time() - t
print(elapsed)

35.40824294090271
