In [2]:
from pathlib import Path

import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import tensorflow as tf
import keras
np.set_printoptions(precision=4)

## SketchRNN

In [90]:

DOWNLOAD_ROOT = "http://download.tensorflow.org/data/"
FILENAME = "quickdraw_tutorial_dataset_v1.tar.gz"
filepath = keras.utils.get_file(FILENAME,
                                DOWNLOAD_ROOT + FILENAME,
                                cache_subdir="datasets/quickdraw",
                                extract=True)

Downloading data from http://download.tensorflow.org/data/quickdraw_tutorial_dataset_v1.tar.gz


In [6]:
filepath

'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets/quickdraw\\quickdraw_tutorial_dataset_v1.tar.gz'

In [72]:
quickdraw_dir = Path(filepath).parent
all_files= sorted([str(file) for file in quickdraw_dir.glob("*.*")])
train_files = sorted([str(path) for path in quickdraw_dir.glob("training.tfrecord-*")])
eval_files = sorted([str(path) for path in quickdraw_dir.glob("eval.tfrecord-*")])

In [73]:
all_files

['C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00000-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00001-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00002-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00003-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00004-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00005-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00006-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00007-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00008-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00009-of-00010',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord.classes',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\dat

In [15]:
#读class 文件
with open(quickdraw_dir / "eval.tfrecord.classes") as test_classes_file:
    test_classes = test_classes_file.readlines()
    
with open(quickdraw_dir / "training.tfrecord.classes") as train_classes_file:
    train_classes = train_classes_file.readlines()

In [81]:
assert test_classes==train_classes #验证测试类别标签和训练标签类别一致
class_names = [name.strip().lower() for name in train_classes] #去除类别标签前后空格，转小写
len(class_names) #345个分类

345

In [42]:
def parse(data_batch):
    feature_descriptions = {
        "ink": tf.io.VarLenFeature(dtype=tf.float32),#变长特征
        "shape": tf.io.FixedLenFeature([2], dtype=tf.int64),#定长特征
        "class_index": tf.io.FixedLenFeature([1], dtype=tf.int64)
    }
    examples = tf.io.parse_example(data_batch, feature_descriptions)
    flat_sketches = tf.sparse.to_dense(examples["ink"])
    sketches = tf.reshape(flat_sketches, shape=[tf.size(data_batch), -1, 3])
    lengths = examples["shape"][:, 0]
    labels = examples["class_index"][:, 0]
    return sketches, lengths, labels

In [43]:
def quickdraw_dataset(filepaths, batch_size=32, shuffle_buffer_size=None,
                      n_parse_threads=5, n_read_threads=5, cache=False):
    dataset = tf.data.TFRecordDataset(filepaths,
                                      num_parallel_reads=n_read_threads)
    if cache:
        dataset = dataset.cache()
    if shuffle_buffer_size:
        dataset = dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(parse, num_parallel_calls=n_parse_threads)
    return dataset.prefetch(1)

In [74]:
examplefile='C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\quickdraw\\eval.tfrecord-00004-of-00010'
dataset = tf.data.TFRecordDataset(examplefile)
bytedataset=list(dataset)

In [75]:
feature_descriptions = {
        "ink": tf.io.VarLenFeature(dtype=tf.float32),#变长特征
        "shape": tf.io.FixedLenFeature([2], dtype=tf.int64),#定长特征
        "class_index": tf.io.FixedLenFeature([1], dtype=tf.int64)
    }
dataset_example = tf.io.parse_example(bytedataset, feature_descriptions)

In [78]:
tf.sparse.to_dense(dataset_example["ink"])

<tf.Tensor: shape=(34279, 2040), dtype=float32, numpy=
array([[-0.0863, -0.0354,  0.    , ...,  0.    ,  0.    ,  0.    ],
       [-0.0472, -0.019 ,  0.    , ...,  0.    ,  0.    ,  0.    ],
       [-0.0748, -0.0442,  0.    , ...,  0.    ,  0.    ,  0.    ],
       ...,
       [ 0.1391,  0.0157,  0.    , ...,  0.    ,  0.    ,  0.    ],
       [ 0.0235, -0.0714,  0.    , ...,  0.    ,  0.    ,  0.    ],
       [ 0.0902, -0.0127,  0.    , ...,  0.    ,  0.    ,  0.    ]],
      dtype=float32)>

## jsb_chorales

In [4]:
chorales_file = tf.keras.utils.get_file('jsb_chorales', 
                                       'https://raw.githubusercontent.com/ageron/data/main/jsb_chorales.tgz',
                                        untar=True)

chorales_root = Path(chorales_file)

train_chorales_list=sorted([str(file) for file in chorales_root.glob("train/*.csv")])
valid_chorales_list=sorted([str(file) for file in chorales_root.glob("valid/*.csv")])
test_chorales_list=sorted([str(file) for file in chorales_root.glob("test/*.csv")])
test_chorales_list

['C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_305.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_306.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_307.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_308.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_309.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_310.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_311.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_312.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_313.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_314.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_315.csv',
 'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_31

In [5]:
#看下基本数据结构，四个int64列
csv_example='C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_305.csv'

df = pd.read_csv(csv_example)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 228 entries, 0 to 227
Data columns (total 4 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   note0   228 non-null    int64
 1   note1   228 non-null    int64
 2   note2   228 non-null    int64
 3   note3   228 non-null    int64
dtypes: int64(4)
memory usage: 7.2 KB


In [6]:
#读取csv文件，最后一行做label，其他行做dataset
#数据集很小，可以直接用tf.data.Dataset.from_tensor_slices(dict(df))读取内存
#experimental.make_csv_dataset 属于高级接口。它支持列类型推断,批处理和重排
#experimental.CsvDataset 属于低级接口,提供粒度更细的控制，不支持列类型推断,须指定每个列的类型

def csvparse(csvfile):
    chorales_notes  = [tf.int64,tf.int64,tf.int64,tf.int64]
    dataset = tf.data.experimental.CsvDataset(csvfile, chorales_notes , header=True) #header必须设置为true否则读取不到列类型
    label = list(dataset)[-1]
    return list(dataset)[:-1],label

In [None]:
#拿两个文件列表map测试
csv_example=['C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_305.csv',
             'C:\\Users\\sec_zca9cu_rw\\.keras\\datasets\\jsb_chorales\\test\\chorale_305.csv']
dataset = tf.data.Dataset.from_tensor_slices(csv_example)
dataset , label = dataset.map(csvparse)

In [None]:
#创建数据集，读取csv文件列表
def create_dataset(csvfiles):
    dataset = tf.data.Dataset.from_tensor_slices(csvfiles)#读取文件列表，转换为tf.data
    dataset , label = dataset.map(csvparse)
    return dataset, label

In [None]:
test_dataset=create_dataset(test_chorales_list)
#list[test_dataset]

In [114]:
for line in test_dataset.take(10):
    print([item.numpy() for item in line])

[65, 60, 57, 53]
[65, 60, 57, 53]
[65, 60, 57, 53]
[65, 60, 57, 53]
[72, 60, 55, 52]
[72, 60, 55, 52]
[70, 60, 55, 52]
[70, 60, 55, 52]
[69, 60, 53, 53]
[69, 60, 53, 53]


In [121]:
len(list(test_dataset))

77