In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import os
import glob
import cv2 # 专用于处理计算机视觉场景文件的包 pip install opencv-python 执行安装
import json
import time
# 查询系统可用的 GPU
physical_devices = tf.config.experimental.list_physical_devices('GPU')
# 确保有可用的 GPU 如果没有, 则会报错
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
# 设置参数,该段务必在运行jupyter的第一段代码执行，否则会无法初始化成功
# 仅在需要时申请显存空间（程序初始运行时消耗很少的显存，随着程序的运行而动态申请显存）
tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [17]:
DP_DIR = '/data/python/tensorflow/shuffle_data_gzip/'
BASE_SIZE = 256
NCSVS = 100
NCATS = 340
STEPS = 800
EPOCHS = 16
size = 64
batchsize = 680
np.random.seed(seed = 1987)


# def f2cat(filename :str) -> str:
#     return filename.split('.')[0]

# def list_all_categories():
#     files = os.listdir('/data/python/tensorflow/shuffle_data_gzip/')
#     return sorted([f2cat(f) for f in files], key = str.lower)

# def preds2catids(predictions):
#     return pd.DataFrame(np.argsort(-predictions, axis = 1)[:, :3], columns = ['a','b','c'])

# def top_3_accuracy(y_true, y_pred):
#     return tf.keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k = 3)

In [18]:
def parse_csv(line):
    column_default = [
        tf.constant('0', dtype = tf.string),
        tf.constant(0, dtype = tf.int32)
    ]
    columns = tf.io.decode_csv(line, column_default, select_cols = [1,5])
    features = columns[0]
    label = columns[1]
    return features, label

In [19]:
def draw_cv2(raw_strokes, size = 64, lw = 6):
    raw_strokes = eval(raw_strokes.numpy())
    img = np.zeros((256,256), np.uint8)
    for stroke in raw_strokes:
        for i in range(len(stroke[0]) - 1):
            _ = cv2.line(img, (stroke[0][i],stroke[1][i]),
                (stroke[0][i+1],stroke[1][i+1]), 255, lw
            )
    return cv2.resize(img, (size, size))

In [20]:
def tf_draw_cv2(image, label):
    [image] = tf.py_function(draw_cv2, [image], [tf.float32])
    image = tf.reshape(image, (64,64,1))
    label = tf.one_hot(label, depth = NCATS)
    image.set_shape((64,64,1))
    label.set_shape((340,))

    return image, label

In [21]:
fileList = glob.glob('/data/python/tensorflow/shuffle_data_gzip/*.csv.gz')

In [22]:
fileList[0]

'/data/python/tensorflow/shuffle_data_gzip/train_k22.csv.gz'

In [23]:
train_ds = tf.data.Dataset.from_tensor_slices(fileList[:-1])

train_ds = train_ds.interleave(
        lambda x: tf.data.TextLineDataset(x, compression_type= 'GZIP').skip(1).map(parse_csv,num_parallel_calls = tf.data.experimental.AUTOTUNE),
        cycle_length = 4, block_length = 16, num_parallel_calls = tf.data.experimental.AUTOTUNE
    )

# 返回tensor
train_ds = train_ds.map(tf_draw_cv2, num_parallel_calls = tf.data.experimental.AUTOTUNE)

train_ds = train_ds.prefetch(buffer_size = tf.data.experimental.AUTOTUNE).shuffle(3000).batch(1024)

In [24]:
class MobileNetModel(tf.keras.models.Model):
    def __init__(self, size, n_labels, **kwargs):
        super(MobileNetModel, self).__init__(**kwargs)
        self.base_model = tf.keras.applications.MobileNet(
                input_shape = (size, size, 1),
                include_top = False,
                weights=None,
                classes = n_labels
            )
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(1024, activation='relu')
        self.outputs = tf.keras.layers.Dense(n_labels, activation='softmax')


    def call(self, inputs):
        x = self.base_model(inputs)
        x = self.flatten(x)
        x = self.dense(x)
        output_ = self.outputs(x)
        return output_

In [26]:
model = MobileNetModel(size = 64, n_labels = NCATS)

loss_object = tf.keras.losses.CategoricalCrossentropy()

learning_rate = 0.002
optimizer = tf.keras.optimizers.Adam(learning_rate= learning_rate)

train_loss = tf.keras.metrics.Mean(name = 'train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy('train_accuracy')
train_top3_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(k =3 , name = 'train_top_3_categorical_accuracy')

test_loss = tf.keras.metrics.Mean(name = 'test_loss')
test_accuracy = tf.keras.metrics.CategoricalAccuracy('test_accuracy')
test_top3_accuracy = tf.keras.metrics.TopKCategoricalAccuracy(k =3 , name = 'test_top_3_categorical_accuracy')

def train_one_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)
    train_top3_accuracy(labels, predictions)

def val_one_step(images, labels):
    predictions = model(images)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)
    test_top3_accuracy(labels, predictions)

In [27]:
for a, b in train_ds.take(1):
    print(a.shape, b.shape)

(1024, 64, 64, 1) (1024, 340)


In [28]:
val_ds = tf.data.TextLineDataset(fileList[-1], compression_type= 'GZIP').skip(1).map(parse_csv,num_parallel_calls = tf.data.experimental.AUTOTUNE)
val_ds = val_ds.map(tf_draw_cv2, num_parallel_calls = tf.data.experimental.AUTOTUNE)
val_ds = val_ds.batch(1024)  

In [None]:
EPOCHS = 10
for epoch in range(EPOCHS):
    # 在下一个epoch开始的时候，重置评估指标
    train_loss.reset_states()
    train_accuracy.reset_states()
    train_top3_accuracy.reset_states()
    test_loss.reset_states()
    test_accuracy.reset_states()
    test_top3_accuracy.reset_states()

    for step,(images, labels) in enumerate(train_ds):
        train_one_step(images, labels)

        if step %200 == 0:
             print('step :{0}, Samples: {1}, Train Loss: {2}, Train Accuracy: {3}, Train Top3 Accuracy: {4}'.format(
                 step, (step + 1) * 1024, train_loss.result(), train_accuracy.result() * 100,
                 train_top3_accuracy.result() * 100
             ))

        if step > 1000:
            break

        
    template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
    print(template.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        train_top3_accuracy.result() * 100,
        test_loss.result(),
        test_accuracy.result() * 100,
        test_top3_accuracy.result() * 100
    ))

##### 总结处理数据的两个关键步骤
- 通过parse_csv 解析csv数据
- 通过draw_cv2 转换为图片数据