In [1]:
import tensorflow as tf

# 單一CSV檔案

In [2]:
# Download the Titanic training CSV (skipped when already in the Keras cache)
# and keep the local path to the downloaded file.
titanic_file_path = tf.keras.utils.get_file(
    fname="train.csv",
    origin="https://storage.googleapis.com/tf-datasets/titanic/train.csv",
)

Downloading data from https://storage.googleapis.com/tf-datasets/titanic/train.csv


In [3]:
# Build a batched tf.data pipeline straight from the Titanic CSV file.
#
# Parameter notes for tf.data.experimental.make_csv_dataset:
#   file_pattern            - the data source (file path(s) or pattern) to read
#   batch_size              - number of records combined into a single batch
#   column_names            - column names; inferred from the file when omitted
#   column_defaults         - per-column types/defaults: float32, float64,
#                             int32, int64 or string
#   label_name              - the column the model is supposed to predict
#   select_columns          - restrict parsing to the listed columns
#   field_delim             - field separator, defaults to the CSV ","
#   use_quote_delim         - defaults to True; when False, double quotes are
#                             read as regular characters
#   na_value                - additional string to recognise as NA / NaN
#   header                  - whether the file contains a header row
#   num_epochs              - how many times the data is repeated
#   shuffle                 - randomise the order of the records
#   shuffle_buffer_size     - size of the shuffle buffer; larger values cost
#                             more memory
#   shuffle_seed            - random seed used for shuffling
#   prefetch_buffer_size    - auto-tuned by default; usually chosen to match
#                             the batch consumption rate
#   num_parallel_reads      - number of threads used to read files, default 1
#   sloppy                  - when True, reads as efficiently as possible but
#                             the record order is no longer guaranteed
#   num_rows_for_inference  - rows inspected for type inference, default 100;
#                             None inspects every row
#   compression_type        - no compression by default; ZLIB and GZIP are
#                             supported
#   ignore_errors           - skip records that fail CSV parsing instead of
#                             raising
titanic_csv_ds = tf.data.experimental.make_csv_dataset(
    file_pattern=titanic_file_path,
    label_name='survived',
    batch_size=5,
    ignore_errors=True,
)

In [4]:
# Peek at a single batch: one line per feature column, then the labels.
for features, labels in titanic_csv_ds.take(1):
    for name, values in features.items():
        print(f"{name:1s}: {values}")
    print(f"{'label':1s}: {labels}")

sex: [b'male' b'male' b'male' b'female' b'female']
age: [27. 21. 32. 21.  2.]
n_siblings_spouses: [0 0 0 2 0]
parch: [2 0 0 2 1]
fare: [211.5      7.925    8.3625 262.375   10.4625]
class: [b'First' b'Third' b'Third' b'First' b'Third']
deck: [b'C' b'unknown' b'unknown' b'B' b'G']
embark_town: [b'Cherbourg' b'Southampton' b'Southampton' b'Cherbourg' b'Southampton']
alone: [b'n' b'y' b'y' b'n' b'n']
label: [0 0 0 1 0]


# 單一gz檔案

In [5]:
# Download the gzip-compressed traffic-volume CSV into ./traffic (relative to
# the working directory) and keep the local path to the archive.
traffic_volume_csv_gz = tf.keras.utils.get_file(
    fname='Metro_Interstate_Traffic_Volume.csv.gz',
    origin="https://archive.ics.uci.edu/ml/machine-learning-databases/00492/Metro_Interstate_Traffic_Volume.csv.gz",
    cache_subdir='traffic',
    cache_dir='.',
)

Downloading data from https://archive.ics.uci.edu/ml/machine-learning-databases/00492/Metro_Interstate_Traffic_Volume.csv.gz


In [6]:
# Read the gzip archive directly (no manual decompression) as batches of 256
# records. num_epochs=1 makes the dataset finite: one pass over the file.
traffic_volume_csv_gz_ds = tf.data.experimental.make_csv_dataset(
    file_pattern=traffic_volume_csv_gz,
    compression_type="GZIP",
    label_name='traffic_volume',
    batch_size=256,
    num_epochs=1,
)

# Show the first five values of every feature column in one batch, then a
# blank line, then the first five labels.
for features, labels in traffic_volume_csv_gz_ds.take(1):
    for name, values in features.items():
        print(f"{name:20s}: {values[:5]}")
    print()
    print(f"{'label':20s}: {labels[:5]}")

holiday             : [b'None' b'None' b'Labor Day' b'None' b'None']
temp                : [289.18 282.53 288.78 261.8  264.76]
rain_1h             : [0. 0. 0. 0. 0.]
snow_1h             : [0. 0. 0. 0. 0.]
clouds_all          : [ 1 90  0  1 20]
weather_main        : [b'Clear' b'Clouds' b'Clear' b'Mist' b'Clouds']
weather_description : [b'sky is clear' b'overcast clouds' b'Sky is Clear' b'mist' b'few clouds']
date_time           : [b'2012-10-04 10:00:00' b'2012-10-11 23:00:00' b'2013-09-02 00:00:00'
 b'2013-01-06 02:00:00' b'2013-02-17 14:00:00']

label               : [4603 1284 1041  507 4422]


In [7]:
# Bare expression: the notebook displays the dataset's repr as the cell output.
traffic_volume_csv_gz_ds

<PrefetchDataset shapes: (OrderedDict([(holiday, (None,)), (temp, (None,)), (rain_1h, (None,)), (snow_1h, (None,)), (clouds_all, (None,)), (weather_main, (None,)), (weather_description, (None,)), (date_time, (None,))]), (None,)), types: (OrderedDict([(holiday, tf.string), (temp, tf.float32), (rain_1h, tf.float32), (snow_1h, tf.float32), (clouds_all, tf.int32), (weather_main, tf.string), (weather_description, tf.string), (date_time, tf.string)]), tf.int32)>

# 透過快取(Caching)或快照(Snapshot)處理數據

In [8]:
%%time
for i, (batch, label) in enumerate(traffic_volume_csv_gz_ds.repeat(20)):
  if i % 40 == 0:
    print('.', end='')
print()

...............................................................................................
CPU times: user 16 s, sys: 2.92 s, total: 18.9 s
Wall time: 13 s


In [9]:
#快取(Caching)將數據在第一次epoch就做快取
%%time
caching = traffic_volume_csv_gz_ds.cache().shuffle(1000)

for i, (batch, label) in enumerate(caching.shuffle(1000).repeat(20)):
  if i % 40 == 0:
    print('.', end='')
print()

...............................................................................................
CPU times: user 1.58 s, sys: 121 ms, total: 1.7 s
Wall time: 1.42 s


In [10]:
#快照(Snapshot)將數據臨時儲存
%%tim
snapshot = tf.data.experimental.snapshot('titanic.tfsnap')
snapshotting = traffic_volume_csv_gz_ds.apply(snapshot).shuffle(1000)

for i, (batch, label) in enumerate(snapshotting.shuffle(1000).repeat(20)):
  if i % 40 == 0:
    print('.', end='')
print()

UsageError: Cell magic `%%tim` not found.
