使用 FlyAI 进行文本分类，使用的模型是 BERT 模型。

加载参数和配置文件

In [None]:
import argparse
import tensorflow as tf
from flyai.dataset import Dataset
from model import Model
from path import MODEL_PATH, LOG_PATH
from data_helper import *
import bert.modeling as modeling


# Hyper-parameters exposed on the command line.
parser = argparse.ArgumentParser()
parser.add_argument("-e", "--EPOCHS", default=10, type=int, help="train epochs")
parser.add_argument("-b", "--BATCH", default=16, type=int, help="batch size")
# parse_known_args() rather than parse_args(): when this cell runs inside a
# notebook kernel, sys.argv carries launcher flags (e.g. "-f <connection file>")
# that parse_args() would reject with SystemExit. Known flags parse as before.
args, _unknown_args = parser.parse_known_args()
# FlyAI data helper; batches are pulled from it in the training loop.
dataset = Dataset(epochs=args.EPOCHS, batch=args.BATCH)
# FlyAI model helper; used here to fetch the remote archive and to save the model.
modelpp = Model(dataset)

# Request the BERT archive named in the URL; the returned value is printed
# for inspection only and is not used further below.
path = modelpp.get_remote_date("https://www.flyai.com/m/multi_cased_L-12_H-768_A-12.zip")
print(path)

# Implement your own algorithm with TensorFlow below.

# Parameters
learning_rate = 0.0006     # learning rate
num_labels = 3  # number of classes
# NOTE(review): the BERT graph is built with is_training=False even though it
# is trained below — presumably so the saved graph can be reused for
# inference; confirm this is intended.
is_training = False

# ---- Configuration files ----
# The directory is built with os.path.join so the separator is correct on every
# platform (the previous raw string hard-coded a Windows backslash).
# NOTE(review): an earlier variant derived this from `path`
# (os.path.splitext(path)[0]); confirm which location holds the extracted files.
data_root = os.path.join('multi_cased_L-12_H-768_A-12', 'multi_cased_L-12_H-768_A-12')
bert_config_file = os.path.join(data_root, 'bert_config.json')
bert_config = modeling.BertConfig.from_json_file(bert_config_file)
init_checkpoint = os.path.join(data_root, 'bert_model.ckpt')
bert_vocab_file = os.path.join(data_root, 'vocab.txt')

初始化bert模型

In [None]:
# ---- Graph inputs ----
# Three int32 placeholders of shape [None, None], created in the same order and
# with the same graph names as before (the mask tensor is named 'input_masks').
input_ids, input_mask, segment_ids = (
    tf.placeholder(tf.int32, shape=[None, None], name=tensor_name)
    for tensor_name in ('input_ids', 'input_masks', 'segment_ids')
)
labels = tf.placeholder(tf.int32, shape=[None], name='labels')

# ---- Network variables ----
# Build the BERT encoder on top of the placeholders.
model = modeling.BertModel(
    config=bert_config,
    is_training=is_training,
    input_ids=input_ids,
    input_mask=input_mask,
    token_type_ids=segment_ids,
    use_one_hot_embeddings=False,
)

# Map the trainable variables onto the checkpoint and register the
# initialisation from it.
tvars = tf.trainable_variables()
assignment, initialized_variable_names = modeling.get_assignment_map_from_checkpoint(
    tvars, init_checkpoint)
tf.train.init_from_checkpoint(init_checkpoint, assignment)

# Sentence-level (pooled) output of the last layer. For per-token output
# (e.g. seq2seq or NER) model.get_sequence_output() would be used instead.
pooled_output = model.get_pooled_output()
output_layer = tf.identity(pooled_output, name='output_layer_pooled')

# Width of the sentence vector, read from its static shape.
hidden_size = output_layer.get_shape()[-1].value

# Weight matrix W and bias b of the classification layer.
output_weights = tf.get_variable(
    "output_weights",
    shape=[hidden_size, num_labels],
    initializer=tf.truncated_normal_initializer(stddev=0.02),
)
output_bias = tf.get_variable(
    "output_bias",
    shape=[num_labels],
    initializer=tf.zeros_initializer(),
)

进行预测，计算准确率和损失，并定义优化器

In [None]:
with tf.name_scope("predict"):
    # Classification head on the pooled sentence vector:
    # [..., hidden_size] x [hidden_size, num_labels] + bias -> per-class logits.
    # These tensors were referenced below but never defined, which raised
    # NameError when the cell ran.
    logits = tf.nn.bias_add(tf.matmul(output_layer, output_weights), output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    # NOTE(review): if inference code looks this tensor up by graph name,
    # confirm the name it expects (here it is created as "pred" inside the
    # "predict" name scope).
    pred = tf.argmax(logits, axis=-1, name="pred")

with tf.name_scope("accuracy"):
    # Fraction of examples whose predicted class equals the label.
    # argmax yields int64 by default, so it is cast to match the int32 labels.
    correct_pred = tf.equal(labels, tf.cast(pred, tf.int32))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='acc')

with tf.name_scope("optimize"):
    # One-hot encode the integer labels.
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
    # Cross-entropy: negative log-probability of the true class, averaged
    # over the batch.
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)

    # Optimizer
    train_op = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

进行训练

In [None]:
with tf.name_scope("summary"):
    # Scalar summaries for loss and accuracy, merged into a single op.
    tf.summary.scalar("loss", loss)
    tf.summary.scalar("acc", accuracy)
    merged_summary = tf.summary.merge_all()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    train_writer = tf.summary.FileWriter(LOG_PATH, sess.graph)
    try:
        # Read the step count once; it is used for the loop bound and the
        # progress message.
        total_steps = dataset.get_step()
        print('dataset.get_step:', total_steps)
        for step in range(total_steps):
            x_train, y_train, _, _ = dataset.next_batch(args.BATCH)
            # x_train is indexed as (input ids, input mask, segment ids).
            feed_dict = {
                input_ids: x_train[0],
                input_mask: x_train[1],
                segment_ids: x_train[2],
                labels: y_train,
            }

            # Fetch the summary in the same run as the training op: a second
            # sess.run would repeat the whole forward pass and would log
            # values computed after the weight update instead of the ones
            # printed below.
            fetches = [train_op, loss, accuracy, merged_summary]
            _, loss_, acc_, summary = sess.run(fetches, feed_dict=feed_dict)
            train_writer.add_summary(summary, step)

            cur_step = str(step + 1) + "/" + str(total_steps)
            print('The Current step per total: {} | The Current loss: {} | The current acc: {}'.
                  format(cur_step, loss_, acc_))
            # Periodic checkpoint every 50 steps (including step 0).
            if step % 50 == 0:
                modelpp.save_model(sess, MODEL_PATH, overwrite=True)
        # Final save after the last step.
        modelpp.save_model(sess, MODEL_PATH, overwrite=True)
    finally:
        # Flush pending events and release the log file even if training fails.
        train_writer.close()
