In [1]:
import random
import time
import sys
import os
import datetime

from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score
import numpy as np
import tensorflow as tf

In [2]:
# 配置config
class TrainConfig(object):
    epochs = 10
    decay_rate = 0.92
    learning_rate = 0.01
    evaluate_every = 100
    checkpoint_every = 100
    max_grad_norm = 3.0


class ModelConfig(object):
    hidden_layers = [200]
    dropout_keep_prob = 0.6


class Config(object):
    batch_size = 32
    num_skills = 124
    input_size = num_skills * 2

    trainConfig = TrainConfig()
    modelConfig = ModelConfig()
    

# 实例化config
config = Config()

In [3]:
# 生成数据
class DataGenerator(object):
    # 导入的seqs是train_seqs，或者是test_seqs
    def __init__(self, fileName, config):
        self.fileName = fileName
        self.train_seqs = []
        self.test_seqs = []
        self.infer_seqs = []
        self.batch_size = config.batch_size
        self.pos = 0
        self.end = False
        self.num_skills = config.num_skills
        self.skills_to_int = {}  # 知识点到索引的映射
        self.int_to_skills = {}  # 索引到知识点的映射

    def read_file(self):
        # 从文件中读取数据，返回读取出来的数据和知识点个数
        # 保存每个学生的做题信息 {学生id: [[知识点id，答题结果], [知识点id，答题结果], ...]}，用一个二元列表来表示一个学生的答题信息
        seqs_by_student = {}
        skills = []  # 统计知识点的数量，之后输入的向量长度就是两倍的知识点数量
        count = 0
        with open(self.fileName, 'r') as f:
            for line in f:
                fields = line.strip().split(" ")  # 一个列表，[学生id，知识点id，答题结果]
                student, skill, is_correct = int(fields[0]), int(fields[1]), int(fields[2])
                skills.append(skill)  # skill实际上是用该题所属知识点来表示的
                seqs_by_student[student] = seqs_by_student.get(student, []) + [[skill, is_correct]]  # 保存每个学生的做题信息
        return seqs_by_student, list(set(skills))

    def gen_dict(self, unique_skills):
        """
        构建知识点映射表，将知识点id映射到[0, 1, 2...]表示
        :param unique_skills: 无重复的知识点列表
        :return:
        """
        sorted_skills = sorted(unique_skills)
        skills_to_int = {}
        int_to_skills = {}
        for i in range(len(sorted_skills)):
            skills_to_int[sorted_skills[i]] = i
            int_to_skills[i] = sorted_skills[i]

        self.skills_to_int = skills_to_int
        self.int_to_skills = int_to_skills

    def split_dataset(self, seqs_by_student, sample_rate=0.2, random_seed=1):
        # 将数据分割成测试集和训练集
        sorted_keys = sorted(seqs_by_student.keys())  # 得到排好序的学生id的列表

        random.seed(random_seed)
        # 随机抽取学生id，将这部分学生作为测试集
        test_keys = set(random.sample(sorted_keys, int(len(sorted_keys) * sample_rate)))

        # 此时是一个三层的列表来表示的，最外层的列表中的每一个列表表示一个学生的做题信息
        test_seqs = [seqs_by_student[k] for k in seqs_by_student if k in test_keys]
        train_seqs = [seqs_by_student[k] for k in seqs_by_student if k not in test_keys]
        return train_seqs, test_seqs

    def gen_attr(self, is_infer=False):
        """
        生成待处理的数据集
        :param is_infer: 判断当前是训练模型还是利用模型进行预测
        :return:
        """
        if is_infer:
            seqs_by_students, skills = self.read_file()
            self.infer_seqs = seqs_by_students
        else:
            seqs_by_students, skills = self.read_file()
            train_seqs, test_seqs = self.split_dataset(seqs_by_students)
            self.train_seqs = train_seqs
            self.test_seqs = test_seqs

        self.gen_dict(skills)  # 生成知识点到索引的映射字典

    def pad_sequences(self, sequences, maxlen=None, value=0.):
        # 按每个batch中最长的序列进行补全, 传入的sequences是二层列表
        # 统计一个batch中每个序列的长度，其实等于seqs_len
        lengths = [len(s) for s in sequences]
        # 统计下该batch中序列的数量
        nb_samples = len(sequences)
        # 如果没有传入maxlen参数就自动获取最大的序列长度
        if maxlen is None:
            maxlen = np.max(lengths)
        # 构建x矩阵
        x = (np.ones((nb_samples, maxlen)) * value).astype(np.int32)

        # 遍历batch，去除每一个序列
        for idx, s in enumerate(sequences):
            trunc = np.asarray(s, dtype=np.int32)
            x[idx, :len(trunc)] = trunc

        return x

    def num_to_one_hot(self, num, dim):
        # 将题目转换成one-hot的形式， 其中dim=num_skills * 2，前半段表示错误，后半段表示正确
        base = np.zeros(dim)
        if num >= 0:
            base[num] += 1
        return base

    def format_data(self, seqs):
        # 生成输入数据和输出数据，输入数据是每条序列的前n-1个元素，输出数据是每条序列的后n-1个元素

        # 统计一个batch_size中每条序列的长度，在这里不对序列固定长度，通过条用tf.nn.dynamic_rnn让序列长度可以不固定
        seq_len = np.array(list(map(lambda seq: len(seq) - 1, seqs)))
        max_len = max(seq_len)  # 获得一个batch_size中最大的长度
        # i表示第i条数据，j只从0到len(i)-1，x作为输入只取前len(i)-1个，sequences=[j[0] + num_skills * j[1], ....]
        # 此时要将知识点id j[0] 转换成index表示
        x_sequences = np.array([[(self.skills_to_int[j[0]] + self.num_skills * j[1]) for j in i[:-1]] for i in seqs])
        # 将输入的序列用-1进行补全，补全后的长度为当前batch的最大序列长度
        x = self.pad_sequences(x_sequences, maxlen=max_len, value=-1)

        # 构建输入值input_x，x为一个二层列表，i表示一个学生的做题信息，也就是一个序列，j就是一道题的信息
        input_x = np.array([[self.num_to_one_hot(j, self.num_skills * 2) for j in i] for i in x])

        # 遍历batch_size，然后取每条序列的后len(i)-1 个元素中的知识点id为target_id
        target_id_seqs = np.array([[self.skills_to_int[j[0]] for j in i[1:]] for i in seqs])
        target_id = self.pad_sequences(target_id_seqs, maxlen=max_len, value=0)

        # 同target_id
        target_correctness_seqs = np.array([[j[1] for j in i[1:]] for i in seqs])
        target_correctness = self.pad_sequences(target_correctness_seqs, maxlen=max_len, value=0)

        return dict(input_x=input_x, target_id=target_id, target_correctness=target_correctness,
                    seq_len=seq_len, max_len=max_len)

    def next_batch(self, seqs):
        # 接收一个序列，生成batch

        length = len(seqs)
        num_batchs = length // self.batch_size
        start = 0
        for i in range(num_batchs):
            batch_seqs = seqs[start: start + self.batch_size]
            start += self.batch_size
            params = self.format_data(batch_seqs)

            yield params

            
fileName = "./data/assistments.txt"
dataGen = DataGenerator(fileName, config)
dataGen.gen_attr()

In [4]:
train_seqs = dataGen.train_seqs
params = next(dataGen.next_batch(train_seqs))
print("skill num: {}".format(len(dataGen.skills_to_int)))
print("train_seqs length: {}".format(len(dataGen.train_seqs)))
print("test_seqs length: {}".format(len(dataGen.test_seqs)))
print("input_x shape: {}".format(params['input_x'].shape))
print(params["input_x"][1][0])

skill num: 124
train_seqs length: 3374
test_seqs length: 843
input_x shape: (32, 153, 248)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0.]


In [5]:
# 构建模型
class TensorFlowDKT(object):
    def __init__(self, config):
        # 导入配置好的参数
        self.hiddens = hiddens = config.modelConfig.hidden_layers
        self.num_skills = num_skills = config.num_skills
        self.input_size = input_size = config.input_size
        self.batch_size = batch_size = config.batch_size
        self.keep_prob_value = config.modelConfig.dropout_keep_prob

        # 定义需要喂给模型的参数
        self.max_steps = tf.placeholder(tf.int32, name="max_steps")  # 当前batch中最大序列长度
        self.input_data = tf.placeholder(tf.float32, [batch_size, None, input_size], name="input_x")

        self.sequence_len = tf.placeholder(tf.int32, [batch_size], name="sequence_len")
        self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")  # dropout keep prob

        self.target_id = tf.placeholder(tf.int32, [batch_size, None], name="target_id")
        self.target_correctness = tf.placeholder(tf.float32, [batch_size, None], name="target_correctness")
        self.flat_target_correctness = None

        # 构建lstm模型结构
        hidden_layers = []
        for idx, hidden_size in enumerate(hiddens):
            lstm_layer = tf.nn.rnn_cell.LSTMCell(num_units=hidden_size, state_is_tuple=True)
            hidden_layer = tf.nn.rnn_cell.DropoutWrapper(cell=lstm_layer,
                                                         output_keep_prob=self.keep_prob)
            hidden_layers.append(hidden_layer)
        self.hidden_cell = tf.nn.rnn_cell.MultiRNNCell(cells=hidden_layers, state_is_tuple=True)

        # 采用动态rnn，动态输入序列的长度
        outputs, self.current_state = tf.nn.dynamic_rnn(cell=self.hidden_cell,
                                                        inputs=self.input_data,
                                                        sequence_length=self.sequence_len,
                                                        dtype=tf.float32)

        # 隐层到输出层的权重系数[最后隐层的神经元数量，知识点数]
        output_w = tf.get_variable("W", [hiddens[-1], num_skills])
        output_b = tf.get_variable("b", [num_skills])

        self.output = tf.reshape(outputs, [batch_size * self.max_steps, hiddens[-1]])
        # 因为权值共享的原因，对生成的矩阵[batch_size * self.max_steps, num_skills]中的每一行都加上b
        self.logits = tf.matmul(self.output, output_w) + output_b

        self.mat_logits = tf.reshape(self.logits, [batch_size, self.max_steps, num_skills])

        # 对每个batch中每个序列中的每个时间点的输出中的每个值进行sigmoid计算，这里的值表示对某个知识点的掌握情况，
        # 每个时间点都会输出对所有知识点的掌握情况
        self.pred_all = tf.sigmoid(self.mat_logits, name="pred_all")

        # 计算损失loss
        flat_logits = tf.reshape(self.logits, [-1])

        flat_target_correctness = tf.reshape(self.target_correctness, [-1])
        self.flat_target_correctness = flat_target_correctness

        flat_base_target_index = tf.range(batch_size * self.max_steps) * num_skills

        # 因为flat_logits的长度为batch_size * num_steps * num_skills，我们要根据每一步的target_id将其长度变成batch_size * num_steps
        flat_base_target_id = tf.reshape(self.target_id, [-1])

        flat_target_id = flat_base_target_id + flat_base_target_index
        # gather是从一个tensor中切片一个子集
        flat_target_logits = tf.gather(flat_logits, flat_target_id)

        # 对切片后的数据进行sigmoid转换
        self.pred = tf.sigmoid(tf.reshape(flat_target_logits, [batch_size, self.max_steps]), name="pred")
        # 将sigmoid后的值表示为0或1
        self.binary_pred = tf.cast(tf.greater_equal(self.pred, 0.5), tf.float32, name="binary_pred")

        # 定义损失函数
        with tf.name_scope("loss"):
            # flat_target_logits_sigmoid = tf.nn.log_softmax(flat_target_logits)
            # self.loss = -tf.reduce_mean(flat_target_correctness * flat_target_logits_sigmoid)
            self.loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=flat_target_correctness,
                                                                               logits=flat_target_logits))


In [6]:
# 训练模型
def mean(item):
    return sum(item) / len(item)


def gen_metrics(sequence_len, binary_pred, pred, target_correctness):
    """
    生成auc和accuracy的指标值
    :param sequence_len: 每一个batch中各序列的长度组成的列表
    :param binary_pred:
    :param pred:
    :param target_correctness:
    :return:
    """
    binary_preds = []
    preds = []
    target_correctnesses = []
    for seq_idx, seq_len in enumerate(sequence_len):
        binary_preds.append(binary_pred[seq_idx, :seq_len])
        preds.append(pred[seq_idx, :seq_len])
        target_correctnesses.append(target_correctness[seq_idx, :seq_len])

    new_binary_pred = np.concatenate(binary_preds)
    new_pred = np.concatenate(preds)
    new_target_correctness = np.concatenate(target_correctnesses)

    auc = roc_auc_score(new_target_correctness, new_pred)
    accuracy = accuracy_score(new_target_correctness, new_binary_pred)
    precision = precision_score(new_target_correctness, new_binary_pred)
    recall = recall_score(new_target_correctness, new_binary_pred)

    return auc, accuracy, precision, recall

In [7]:
class DKTEngine(object):

    def __init__(self):
        self.config = Config()
        self.train_dkt = None
        self.test_dkt = None
        self.sess = None
        self.global_step = 0

    def add_gradient_noise(self, grad, stddev=1e-3, name=None):
        """
        Adds gradient noise as described in http://arxiv.org/abs/1511.06807 [2].
        """
        with tf.op_scope([grad, stddev], name, "add_gradient_noise") as name:
            grad = tf.convert_to_tensor(grad, name="grad")
            gn = tf.random_normal(tf.shape(grad), stddev=stddev)
            return tf.add(grad, gn, name=name)

    def train_step(self, params, train_op, train_summary_op, train_summary_writer):
        """
        A single training step
        """
        dkt = self.train_dkt
        sess = self.sess
        global_step = self.global_step

        feed_dict = {dkt.input_data: params['input_x'],
                     dkt.target_id: params['target_id'],
                     dkt.target_correctness: params['target_correctness'],
                     dkt.max_steps: params['max_len'],
                     dkt.sequence_len: params['seq_len'],
                     dkt.keep_prob: self.config.modelConfig.dropout_keep_prob}

        _, step, summaries, loss, binary_pred, pred, target_correctness = sess.run(
            [train_op, global_step, train_summary_op, dkt.loss, dkt.binary_pred, dkt.pred, dkt.target_correctness],
            feed_dict)

        auc, accuracy, precision, recall = gen_metrics(params['seq_len'], binary_pred, pred, target_correctness)

        time_str = datetime.datetime.now().isoformat()
        print("train: {}: step {}, loss {}, acc {}, auc: {}, precision: {}, recall: {}".format(time_str, step, loss, accuracy, 
                                                                                               auc, precision, recall))
        train_summary_writer.add_summary(summaries, step)

    def dev_step(self, params, dev_summary_op, writer=None):
        """
        Evaluates model on a dev set
        """
        dkt = self.test_dkt
        sess = self.sess
        global_step = self.global_step

        feed_dict = {dkt.input_data: params['input_x'],
                     dkt.target_id: params['target_id'],
                     dkt.target_correctness: params['target_correctness'],
                     dkt.max_steps: params['max_len'],
                     dkt.sequence_len: params['seq_len'],
                     dkt.keep_prob: 1.0}
        step, summaries, loss, pred, binary_pred, target_correctness = sess.run(
            [global_step, dev_summary_op, dkt.loss, dkt.pred, dkt.binary_pred, dkt.target_correctness],
            feed_dict)

        auc, accuracy, precision, recall = gen_metrics(params['seq_len'], binary_pred, pred, target_correctness)

        if writer:
            writer.add_summary(summaries, step)

        return loss, accuracy, auc, precision, recall

    def run_epoch(self, fileName):
        """
        训练模型
        :param filePath:
        :return:
        """

        # 实例化配置参数对象
        config = Config()

        # 实例化数据生成对象
        dataGen = DataGenerator(fileName, config)
        dataGen.gen_attr()  # 生成训练集和测试集

        train_seqs = dataGen.train_seqs
        test_seqs = dataGen.test_seqs

        session_conf = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False
        )
        sess = tf.Session(config=session_conf)
        self.sess = sess

        with sess.as_default():
            # 实例化dkt模型对象
            with tf.name_scope("train"):
                with tf.variable_scope("dkt", reuse=None):
                    train_dkt = TensorFlowDKT(config)

            with tf.name_scope("test"):
                with tf.variable_scope("dkt", reuse=True):
                    test_dkt = TensorFlowDKT(config)

            self.train_dkt = train_dkt
            self.test_dkt = test_dkt

            global_step = tf.Variable(0, name="global_step", trainable=False)
            self.global_step = global_step

            # 定义一个优化器
            optimizer = tf.train.AdamOptimizer(config.trainConfig.learning_rate)
            grads_and_vars = optimizer.compute_gradients(train_dkt.loss)

            # 对梯度进行截断，并且加上梯度噪音
            grads_and_vars = [(tf.clip_by_norm(g, config.trainConfig.max_grad_norm), v)
                              for g, v in grads_and_vars if g is not None]
            # grads_and_vars = [(self.add_gradient_noise(g), v) for g, v in grads_and_vars]

            # 定义图中最后的节点
            train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step, name="train_op")

            # 保存各种变量或结果的值
            grad_summaries = []
            for g, v in grads_and_vars:
                if g is not None:
                    grad_hist_summary = tf.summary.histogram("{}/grad/hist".format(v.name), g)
                    sparsity_summary = tf.summary.scalar("{}/grad/sparsity".format(v.name), tf.nn.zero_fraction(g))
                    grad_summaries.append(grad_hist_summary)
                    grad_summaries.append(sparsity_summary)
            grad_summaries_merged = tf.summary.merge(grad_summaries)

            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs", timestamp))
            print("writing to {}".format(out_dir))

            # 训练时的 Summaries
            train_loss_summary = tf.summary.scalar("loss", train_dkt.loss)
            train_summary_op = tf.summary.merge([train_loss_summary, grad_summaries_merged])
            train_summary_dir = os.path.join(out_dir, "summaries", "train")
            train_summary_writer = tf.summary.FileWriter(train_summary_dir, sess.graph)

            # 测试时的 summaries
            test_loss_summary = tf.summary.scalar("loss", test_dkt.loss)
            dev_summary_op = tf.summary.merge([test_loss_summary])
            dev_summary_dir = os.path.join(out_dir, "summaries", "dev")
            dev_summary_writer = tf.summary.FileWriter(dev_summary_dir, sess.graph)

            saver = tf.train.Saver(tf.global_variables())

            sess.run(tf.global_variables_initializer())

            print("初始化完毕，开始训练")
            for i in range(config.trainConfig.epochs):
                np.random.shuffle(train_seqs)
                for params in dataGen.next_batch(train_seqs):
                    # 批次获得训练集，训练模型
                    self.train_step(params, train_op, train_summary_op, train_summary_writer)

                    current_step = tf.train.global_step(sess, global_step)
                    # train_step.run(feed_dict={x: batch_train[0], y_actual: batch_train[1], keep_prob: 0.5})
                    # 对结果进行记录
                    if current_step % config.trainConfig.evaluate_every == 0:
                        print("\nEvaluation:")
                        # 获得测试数据

                        losses = []
                        accuracys = []
                        aucs = []
                        precisions = []
                        recalls = []
                        for params in dataGen.next_batch(test_seqs):
                            loss, accuracy, auc, precision, recall = self.dev_step(params, dev_summary_op, writer=None)
                            losses.append(loss)
                            accuracys.append(accuracy)
                            aucs.append(auc)
                            precisions.append(precision)
                            recalls.append(recall)

                        time_str = datetime.datetime.now().isoformat()
                        print("dev: {}, step: {}, loss: {}, acc: {}, auc: {}, precision: {}, recall: {}".
                              format(time_str, current_step, mean(losses), mean(accuracys), mean(aucs), mean(precisions), mean(recalls)))

                    if current_step % config.trainConfig.checkpoint_every == 0:
                        path = saver.save(sess, "model/my-model", global_step=current_step)
                        print("Saved model checkpoint to {}\n".format(path))


if __name__ == "__main__":
    fileName = "./data/assistments.txt"
    dktEngine = DKTEngine()
    dktEngine.run_epoch(fileName)


INFO:tensorflow:Summary name dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0/grad/hist is illegal; using dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel_0/grad/hist instead.
INFO:tensorflow:Summary name dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0/grad/sparsity is illegal; using dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/kernel_0/grad/sparsity instead.
INFO:tensorflow:Summary name dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0/grad/hist is illegal; using dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/bias_0/grad/hist instead.
INFO:tensorflow:Summary name dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0/grad/sparsity is illegal; using dkt/rnn/multi_rnn_cell/cell_0/lstm_cell/bias_0/grad/sparsity instead.
INFO:tensorflow:Summary name dkt/W:0/grad/hist is illegal; using dkt/W_0/grad/hist instead.
INFO:tensorflow:Summary name dkt/W:0/grad/sparsity is illegal; using dkt/W_0/grad/sparsity instead.
INFO:tensorflow:Summary name dkt/b:0/grad/hist is illegal; using dkt/b_0/grad/hist instead.
INFO:tensorf

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


writing to /data4T/share/jiangxinyang848/dkt/runs/1546588631
初始化完毕，开始训练
train: 2019-01-04T15:57:12.891540: step 1, loss 0.6649993062019348, acc 0.55418480890179, auc: 0.5592457988797013, precision: 0.8141906873614191, recall: 0.5631901840490797
train: 2019-01-04T15:57:14.277169: step 2, loss 0.6512970328330994, acc 0.7151464782879701, auc: 0.7212137920681752, precision: 0.7187121762285161, recall: 0.934235368156073
train: 2019-01-04T15:57:15.152428: step 3, loss 0.6603468656539917, acc 0.5564334085778782, auc: 0.5053623836658082, precision: 0.6254653760238272, recall: 0.7479964381121995
train: 2019-01-04T15:57:16.714156: step 4, loss 0.6549214720726013, acc 0.5791875855773619, auc: 0.5602000760642198, precision: 0.6704304869442484, recall: 0.6761565836298933
train: 2019-01-04T15:57:17.870086: step 5, loss 0.6498287916183472, acc 0.5918482875742994, auc: 0.5252767258138828, precision: 0.6996792301523657, recall: 0.7157506152584086
train: 2019-01-04T15:57:18.848666: step 6, loss 0.648199

train: 2019-01-04T15:58:18.999843: step 48, loss 0.4762658476829529, acc 0.7151714077315828, auc: 0.7181867295319483, precision: 0.7369353410097431, recall: 0.8989735278227985
train: 2019-01-04T15:58:20.171124: step 49, loss 0.4784357249736786, acc 0.6981175923774111, auc: 0.6922217400003674, precision: 0.7173243772739994, recall: 0.8986676016830295
train: 2019-01-04T15:58:21.560071: step 50, loss 0.46058666706085205, acc 0.7652120467117394, auc: 0.7640798273361131, precision: 0.8028134853262188, recall: 0.9085918199286303
train: 2019-01-04T15:58:22.580115: step 51, loss 0.4674523174762726, acc 0.6918044077134986, auc: 0.7021816831830733, precision: 0.7109621451104101, recall: 0.917557251908397
train: 2019-01-04T15:58:23.534580: step 52, loss 0.46458229422569275, acc 0.723998653651969, auc: 0.7074849397590361, precision: 0.745021475985943, recall: 0.9195180722891566
train: 2019-01-04T15:58:24.884757: step 53, loss 0.4617909789085388, acc 0.7369510232237296, auc: 0.7346573485665343, pre

train: 2019-01-04T15:59:27.301158: step 95, loss 0.33343759179115295, acc 0.9041342121030557, auc: 0.9294695085185142, precision: 0.9158997722095672, recall: 0.9657027572293208
train: 2019-01-04T15:59:28.165821: step 96, loss 0.34603238105773926, acc 0.7722457627118644, auc: 0.7786174408540941, precision: 0.827891156462585, recall: 0.8730272596843616
train: 2019-01-04T15:59:28.430915: step 97, loss 0.3814060688018799, acc 0.7122676579925651, auc: 0.7078177278401999, precision: 0.7487875848690592, recall: 0.8577777777777778
train: 2019-01-04T15:59:29.009505: step 98, loss 0.37157732248306274, acc 0.6977186311787072, auc: 0.7095417506352404, precision: 0.7246184472461845, recall: 0.8316831683168316
train: 2019-01-04T15:59:30.565164: step 99, loss 0.35502901673316956, acc 0.6731681602530311, auc: 0.6582868485511423, precision: 0.6997667444185272, recall: 0.8610086100861009
train: 2019-01-04T15:59:33.259114: step 100, loss 0.31888821721076965, acc 0.8605289421157685, auc: 0.944832832617376

train: 2019-01-04T16:00:47.122343: step 140, loss 0.30807214975357056, acc 0.737490377213241, auc: 0.7613077849756877, precision: 0.7783585990590696, recall: 0.8523182598740698
train: 2019-01-04T16:00:48.152935: step 141, loss 0.2701229453086853, acc 0.812448132780083, auc: 0.8486853836330623, precision: 0.8324118866620595, recall: 0.925826287471176
train: 2019-01-04T16:00:49.259717: step 142, loss 0.2753676474094391, acc 0.6902268760907504, auc: 0.6870523232592197, precision: 0.7072120559741658, recall: 0.8878378378378379
train: 2019-01-04T16:00:50.017256: step 143, loss 0.30024442076683044, acc 0.6879090595935239, auc: 0.7467521731603752, precision: 0.6878048780487804, recall: 0.8412887828162291
train: 2019-01-04T16:00:51.483599: step 144, loss 0.27432698011398315, acc 0.7447161210111893, auc: 0.8295283119680359, precision: 0.7206703910614525, recall: 0.8749058025621703
train: 2019-01-04T16:00:52.341019: step 145, loss 0.27109575271606445, acc 0.7524707996406109, auc: 0.7509244317729

train: 2019-01-04T16:01:48.710415: step 187, loss 0.2195623815059662, acc 0.7514906303236797, auc: 0.8258588127201265, precision: 0.7552570093457944, recall: 0.8871355060034305
train: 2019-01-04T16:01:49.805682: step 188, loss 0.22704535722732544, acc 0.7577068685776095, auc: 0.7984653109586964, precision: 0.7935858068918458, recall: 0.8888039740160489
train: 2019-01-04T16:01:50.937196: step 189, loss 0.23451527953147888, acc 0.753687660969328, auc: 0.8000901514910854, precision: 0.7695924764890282, recall: 0.8856421356421357
train: 2019-01-04T16:01:51.907406: step 190, loss 0.22662858664989471, acc 0.6235616438356164, auc: 0.6789980185437146, precision: 0.622568093385214, recall: 0.7984031936127745
train: 2019-01-04T16:01:53.385565: step 191, loss 0.23115701973438263, acc 0.7609092417454787, auc: 0.8018438973815591, precision: 0.7749793559042114, recall: 0.9144945188794154
train: 2019-01-04T16:01:54.952817: step 192, loss 0.20152547955513, acc 0.873011403632268, auc: 0.92704703132377,

train: 2019-01-04T16:03:17.498152: step 232, loss 0.1609184294939041, acc 0.9012, auc: 0.957236948608708, precision: 0.9081059390048154, recall: 0.9575458392101551
train: 2019-01-04T16:03:19.356662: step 233, loss 0.16590936481952667, acc 0.8884699968583097, auc: 0.93202171687992, precision: 0.899625468164794, recall: 0.9650462032944958
train: 2019-01-04T16:03:20.997836: step 234, loss 0.19495607912540436, acc 0.7104910714285714, auc: 0.7379523776729451, precision: 0.7353025523372526, recall: 0.8727025187202179
train: 2019-01-04T16:03:22.342050: step 235, loss 0.17135384678840637, acc 0.8553430821147356, auc: 0.9070425393892059, precision: 0.8773662551440329, recall: 0.9422510312315852
train: 2019-01-04T16:03:23.204154: step 236, loss 0.19826877117156982, acc 0.726742064939803, auc: 0.755937022075316, precision: 0.7362637362637363, recall: 0.8866513233601842
train: 2019-01-04T16:03:26.418379: step 237, loss 0.15361268818378448, acc 0.8965416178194607, auc: 0.9544023334463989, precision

train: 2019-01-04T16:04:15.267972: step 279, loss 0.17214913666248322, acc 0.7231732776617954, auc: 0.7571358456375641, precision: 0.7191374663072776, recall: 0.9037940379403794
train: 2019-01-04T16:04:16.534305: step 280, loss 0.20870311558246613, acc 0.6751568844592101, auc: 0.681748081472165, precision: 0.6919954904171364, recall: 0.8862258157666763
train: 2019-01-04T16:04:17.786175: step 281, loss 0.15776368975639343, acc 0.8035215804165772, auc: 0.8872905996205274, precision: 0.8091780424048794, recall: 0.9152431011826544
train: 2019-01-04T16:04:19.116546: step 282, loss 0.1532628834247589, acc 0.810786313165647, auc: 0.8826022360248448, precision: 0.8361714621256606, recall: 0.91136
train: 2019-01-04T16:04:20.134731: step 283, loss 0.1580197662115097, acc 0.7068709836875927, auc: 0.6985851041447506, precision: 0.734248284466625, recall: 0.8757440476190477
train: 2019-01-04T16:04:21.352276: step 284, loss 0.17383641004562378, acc 0.737518910741301, auc: 0.6905342000997211, precisi

train: 2019-01-04T16:05:38.178242: step 325, loss 0.14704826474189758, acc 0.7553220156292104, auc: 0.6926004886250685, precision: 0.7615062761506276, recall: 0.9586155003762227
train: 2019-01-04T16:05:39.662493: step 326, loss 0.13872931897640228, acc 0.6762225969645869, auc: 0.7095506337940574, precision: 0.6787797218483625, recall: 0.861126920887877
train: 2019-01-04T16:05:40.651861: step 327, loss 0.13896003365516663, acc 0.8482758620689655, auc: 0.9018457300275482, precision: 0.8436923076923077, recall: 0.9621052631578947
train: 2019-01-04T16:05:41.689321: step 328, loss 0.1326570212841034, acc 0.7964661909616038, auc: 0.8559812843306044, precision: 0.8114193807800563, recall: 0.9394785847299814
train: 2019-01-04T16:05:42.887995: step 329, loss 0.1535102128982544, acc 0.7601776461880089, auc: 0.7619532524455512, precision: 0.7641509433962265, recall: 0.9377713458755427
train: 2019-01-04T16:05:43.402876: step 330, loss 0.1914796233177185, acc 0.7052631578947368, auc: 0.744400963976

train: 2019-01-04T16:06:44.514649: step 372, loss 0.10401125997304916, acc 0.8758260498827543, auc: 0.9387969147443298, precision: 0.876898596846521, recall: 0.9507528230865746
train: 2019-01-04T16:06:45.837220: step 373, loss 0.14622612297534943, acc 0.7320990408208788, auc: 0.6728608235116328, precision: 0.7593020272004106, recall: 0.9183736809435133
train: 2019-01-04T16:06:46.943333: step 374, loss 0.12003070116043091, acc 0.7326687811508835, auc: 0.7477173534319929, precision: 0.7422048997772829, recall: 0.913013698630137
train: 2019-01-04T16:06:47.931747: step 375, loss 0.13405348360538483, acc 0.7488727020464794, auc: 0.738327302276281, precision: 0.7691648822269808, recall: 0.9066128218071681
train: 2019-01-04T16:06:49.190081: step 376, loss 0.14391443133354187, acc 0.7826654240447344, auc: 0.8040704915879224, precision: 0.8043960923623446, recall: 0.9270726714431935
train: 2019-01-04T16:06:50.509070: step 377, loss 0.12722639739513397, acc 0.7418077061577242, auc: 0.69194995627

train: 2019-01-04T16:08:06.519888: step 418, loss 0.15212474763393402, acc 0.7625688073394495, auc: 0.8363868904320719, precision: 0.783342119135477, recall: 0.8629500580720093
train: 2019-01-04T16:08:08.054895: step 419, loss 0.13562509417533875, acc 0.8048657506617926, auc: 0.8336440995841897, precision: 0.8257314974182444, recall: 0.945320197044335
train: 2019-01-04T16:08:09.445121: step 420, loss 0.11923951655626297, acc 0.8229005000862217, auc: 0.885558312209811, precision: 0.8378437363596682, recall: 0.9311181178753335
train: 2019-01-04T16:08:09.987817: step 421, loss 0.14279644191265106, acc 0.7505957108816521, auc: 0.8461965357126568, precision: 0.7473684210526316, recall: 0.8458304134548003
train: 2019-01-04T16:08:11.006368: step 422, loss 0.11060132086277008, acc 0.7975843398583924, auc: 0.8512666939163658, precision: 0.7533068783068783, recall: 0.9097444089456869
train: 2019-01-04T16:08:12.219435: step 423, loss 0.09079939126968384, acc 0.8398617511520737, auc: 0.91362291862

train: 2019-01-04T16:09:07.491276: step 465, loss 0.15112890303134918, acc 0.75229769110065, auc: 0.7723563804094545, precision: 0.7572204746925937, recall: 0.9118457300275482
train: 2019-01-04T16:09:09.126105: step 466, loss 0.0738774910569191, acc 0.9020240539747727, auc: 0.9552838071889912, precision: 0.9150552486187845, recall: 0.9678597516435354
train: 2019-01-04T16:09:10.361605: step 467, loss 0.10749567300081253, acc 0.702401082177883, auc: 0.7573931220305017, precision: 0.7142857142857143, recall: 0.8495821727019499
train: 2019-01-04T16:09:11.622810: step 468, loss 0.11359207332134247, acc 0.7434494673193205, auc: 0.7480004336911285, precision: 0.7513774104683195, recall: 0.9281156954487452
train: 2019-01-04T16:09:12.829440: step 469, loss 0.0987536832690239, acc 0.8367612600885186, auc: 0.8896454964805031, precision: 0.843613077182339, recall: 0.9388597149287322
train: 2019-01-04T16:09:14.204473: step 470, loss 0.10546468198299408, acc 0.7906925934017915, auc: 0.86352968777431

train: 2019-01-04T16:10:27.790135: step 510, loss 0.1201729029417038, acc 0.7036479879654005, auc: 0.6880229070446462, precision: 0.7375790424570913, recall: 0.8875
train: 2019-01-04T16:10:28.964765: step 511, loss 0.11575430631637573, acc 0.7063969382176053, auc: 0.7389344699056776, precision: 0.7280606717226435, recall: 0.8626444159178434
train: 2019-01-04T16:10:30.703275: step 512, loss 0.08741830289363861, acc 0.8337591240875912, auc: 0.9077317987852478, precision: 0.8387171740646061, recall: 0.943282801881861
train: 2019-01-04T16:10:31.871204: step 513, loss 0.08286784589290619, acc 0.8443606444977717, auc: 0.8830525956876841, precision: 0.8501026694045175, recall: 0.9587772116720704
train: 2019-01-04T16:10:33.763876: step 514, loss 0.11231787502765656, acc 0.7368421052631579, auc: 0.8002034794871979, precision: 0.7493355142097731, recall: 0.9062809099901088
train: 2019-01-04T16:10:34.178317: step 515, loss 0.16355931758880615, acc 0.691812865497076, auc: 0.7434144442061781, preci

train: 2019-01-04T16:11:41.728907: step 557, loss 0.094634510576725, acc 0.8387949260042283, auc: 0.8639372930776523, precision: 0.8503732554365466, recall: 0.9461899602744673
train: 2019-01-04T16:11:42.896607: step 558, loss 0.10488149523735046, acc 0.7693817468105987, auc: 0.7984600554297636, precision: 0.7908790879087909, recall: 0.9155956929489406
train: 2019-01-04T16:11:43.652930: step 559, loss 0.11697373539209366, acc 0.7382995319812793, auc: 0.7704309126396554, precision: 0.7615347255949491, recall: 0.897025171624714
train: 2019-01-04T16:11:44.679740: step 560, loss 0.11601527780294418, acc 0.7381552419354839, auc: 0.8211974238858476, precision: 0.741189035243859, recall: 0.8696741854636592
train: 2019-01-04T16:11:45.785460: step 561, loss 0.1248917207121849, acc 0.7366922234392114, auc: 0.7042308349274146, precision: 0.7562003968253969, recall: 0.9329865361077111
train: 2019-01-04T16:11:46.578106: step 562, loss 0.11587654054164886, acc 0.7915567282321899, auc: 0.8093715438879

train: 2019-01-04T16:12:54.875355: step 602, loss 0.05778505653142929, acc 0.8920618731939487, auc: 0.9514959627606965, precision: 0.888677901732732, recall: 0.957544757033248
train: 2019-01-04T16:12:55.891112: step 603, loss 0.09929025918245316, acc 0.7782888684452622, auc: 0.7835574156665208, precision: 0.7913074712643678, recall: 0.9394456289978678
train: 2019-01-04T16:12:57.312550: step 604, loss 0.10787048935890198, acc 0.7307279364397681, auc: 0.7479954839342726, precision: 0.7364072494669509, recall: 0.9124834874504624
train: 2019-01-04T16:12:58.474586: step 605, loss 0.1012277603149414, acc 0.7163790296320417, auc: 0.7498798506944628, precision: 0.7093658126837463, recall: 0.9041755888650964
train: 2019-01-04T16:12:59.460916: step 606, loss 0.09153982996940613, acc 0.7654916512059369, auc: 0.7196645757996424, precision: 0.7857142857142857, recall: 0.9387550200803213
train: 2019-01-04T16:12:59.896631: step 607, loss 0.10750787705183029, acc 0.7292817679558011, auc: 0.74853430208

train: 2019-01-04T16:14:00.496419: step 649, loss 0.0849335789680481, acc 0.7453610141466104, auc: 0.8030819015525215, precision: 0.7575357535753575, recall: 0.9237993023879796
train: 2019-01-04T16:14:01.758787: step 650, loss 0.08920568972826004, acc 0.7366536458333334, auc: 0.7194167154548666, precision: 0.7411373707533235, recall: 0.948936170212766
train: 2019-01-04T16:14:02.602551: step 651, loss 0.08586824685335159, acc 0.7704194260485652, auc: 0.8532384158585602, precision: 0.7761664564943254, recall: 0.8818051575931232
train: 2019-01-04T16:14:03.486535: step 652, loss 0.08262622356414795, acc 0.7471461187214612, auc: 0.7555331670184866, precision: 0.7588703837798697, recall: 0.9050086355785838
train: 2019-01-04T16:14:04.382828: step 653, loss 0.08697233349084854, acc 0.7920059215396003, auc: 0.8521453754920609, precision: 0.8227294803302574, recall: 0.8958223162347964
train: 2019-01-04T16:14:05.549943: step 654, loss 0.10176432132720947, acc 0.7283051834595224, auc: 0.7520925912

train: 2019-01-04T16:15:12.678358: step 696, loss 0.053115569055080414, acc 0.8762729898160815, auc: 0.9240052757392944, precision: 0.8730292200966996, recall: 0.9518679807471923
train: 2019-01-04T16:15:13.138247: step 697, loss 0.10824133455753326, acc 0.776875298614429, auc: 0.854515591788662, precision: 0.715929203539823, recall: 0.8471204188481676
train: 2019-01-04T16:15:14.649805: step 698, loss 0.061638254672288895, acc 0.8140963465890722, auc: 0.8732349007369331, precision: 0.8376518218623482, recall: 0.9224253232278199
train: 2019-01-04T16:15:22.311856: step 699, loss 0.046561870723962784, acc 0.8859633517167764, auc: 0.9408062910738073, precision: 0.8992248062015504, recall: 0.9582315299592786
train: 2019-01-04T16:15:23.256098: step 700, loss 0.11792723089456558, acc 0.6969525959367946, auc: 0.7690924399380599, precision: 0.6899521531100479, recall: 0.771948608137045

Evaluation:
dev: 2019-01-04T16:15:35.966478, step: 700, loss: 0.0887234963190097, acc: 0.7772196787273354, auc

train: 2019-01-04T16:16:25.781532: step 741, loss 0.20013819634914398, acc 0.6851654215581644, auc: 0.7353991785281867, precision: 0.6824902723735409, recall: 0.8281397544853636
train: 2019-01-04T16:16:27.088875: step 742, loss 0.09527431428432465, acc 0.7101262561195568, auc: 0.7363145047270753, precision: 0.7255520504731862, recall: 0.9001956947162426
train: 2019-01-04T16:16:27.461333: step 743, loss 0.1415073573589325, acc 0.730072463768116, auc: 0.8172302744160155, precision: 0.7172464840858623, recall: 0.8191039729501268
train: 2019-01-04T16:16:28.687231: step 744, loss 0.05865265056490898, acc 0.8619444444444444, auc: 0.9317916541752711, precision: 0.884828349944629, recall: 0.9283501161890008
train: 2019-01-04T16:16:30.206195: step 745, loss 0.0765272006392479, acc 0.7141579731743666, auc: 0.7454808236684601, precision: 0.7416342412451362, recall: 0.8659700136301681
train: 2019-01-04T16:16:31.657298: step 746, loss 0.05720669403672218, acc 0.8500823723228995, auc: 0.932091431752

train: 2019-01-04T16:17:20.316434: step 788, loss 0.11867163330316544, acc 0.6962190352020861, auc: 0.7364435063354388, precision: 0.711864406779661, recall: 0.8553370786516854
train: 2019-01-04T16:17:21.459600: step 789, loss 0.05881252512335777, acc 0.8380721220527045, auc: 0.8808150009952619, precision: 0.8532110091743119, recall: 0.9467838963442851
train: 2019-01-04T16:17:22.528148: step 790, loss 0.09521865844726562, acc 0.7824245632609846, auc: 0.768760598521072, precision: 0.7952706907280647, recall: 0.9397058823529412
train: 2019-01-04T16:17:24.496698: step 791, loss 0.04731446132063866, acc 0.8901408450704226, auc: 0.9491315061571528, precision: 0.8988899613899614, recall: 0.9670301142263759
train: 2019-01-04T16:17:25.692740: step 792, loss 0.06791224330663681, acc 0.8488155443172745, auc: 0.8962984530371609, precision: 0.8491228070175438, recall: 0.946051602814699
train: 2019-01-04T16:17:27.180491: step 793, loss 0.10867683589458466, acc 0.7215836526181354, auc: 0.70039367001

train: 2019-01-04T16:18:41.653609: step 833, loss 0.09390085190534592, acc 0.8214463015058746, auc: 0.8468372487567913, precision: 0.8326279338207002, recall: 0.953934317831166
train: 2019-01-04T16:18:44.337540: step 834, loss 0.04628106206655502, acc 0.8767103347889375, auc: 0.9515649300028727, precision: 0.8833138856476079, recall: 0.9484234704531217
train: 2019-01-04T16:18:45.953960: step 835, loss 0.05316237360239029, acc 0.868830591373944, auc: 0.9222106996735184, precision: 0.8880275624461671, recall: 0.9392651078044336
train: 2019-01-04T16:18:47.400288: step 836, loss 0.05335957929491997, acc 0.8373639661426844, auc: 0.8829681333555071, precision: 0.8525807649461568, recall: 0.9421419778416086
train: 2019-01-04T16:18:50.156046: step 837, loss 0.04171999543905258, acc 0.8596715717637022, auc: 0.924725335017378, precision: 0.8577860169491526, recall: 0.9639880952380953
train: 2019-01-04T16:18:51.369911: step 838, loss 0.07035387307405472, acc 0.6941870261162595, auc: 0.69981542336

train: 2019-01-04T16:19:53.468952: step 880, loss 0.09198813140392303, acc 0.7, auc: 0.7059312363528522, precision: 0.7132933375434165, recall: 0.8960729869099564
train: 2019-01-04T16:19:54.822977: step 881, loss 0.06897386163473129, acc 0.7840336134453781, auc: 0.7498033593075297, precision: 0.7964015151515151, recall: 0.9524348810872028
train: 2019-01-04T16:19:55.365102: step 882, loss 0.12395218014717102, acc 0.7472331310246341, auc: 0.8075641122951063, precision: 0.7478019435446552, recall: 0.9083754918493536
train: 2019-01-04T16:19:57.010352: step 883, loss 0.07484752684831619, acc 0.6465807174887892, auc: 0.6923073588834289, precision: 0.6585760517799353, recall: 0.7960880195599022
train: 2019-01-04T16:19:58.615125: step 884, loss 0.06897873431444168, acc 0.744410569105691, auc: 0.7534876513322981, precision: 0.7590546347452425, recall: 0.9179658500371195
train: 2019-01-04T16:20:00.235091: step 885, loss 0.06943605840206146, acc 0.8156171284634761, auc: 0.8961154352540273, precis

train: 2019-01-04T16:21:12.261440: step 925, loss 0.058000918477773666, acc 0.8321318228630278, auc: 0.8574559699453402, precision: 0.8512648582749162, recall: 0.9445383834967873
train: 2019-01-04T16:21:13.391352: step 926, loss 0.08123525232076645, acc 0.7795315682281059, auc: 0.821747579542806, precision: 0.7995018679950187, recall: 0.9204301075268817
train: 2019-01-04T16:21:14.485468: step 927, loss 0.06735634803771973, acc 0.8441265060240963, auc: 0.9131643701295655, precision: 0.8506578947368421, recall: 0.9393389030148929
train: 2019-01-04T16:21:16.115222: step 928, loss 0.06547972559928894, acc 0.7129461584996976, auc: 0.7357643890649509, precision: 0.7288854608561511, recall: 0.8848314606741573
train: 2019-01-04T16:21:16.441881: step 929, loss 0.13795475661754608, acc 0.7408036219581211, auc: 0.7499014876874766, precision: 0.7649074708704592, recall: 0.9065800162469537
train: 2019-01-04T16:21:17.007001: step 930, loss 0.055701471865177155, acc 0.8284904323175053, auc: 0.9032128

train: 2019-01-04T16:22:34.467450: step 972, loss 0.06248552352190018, acc 0.7466460268317854, auc: 0.7832390482417488, precision: 0.747599451303155, recall: 0.8985985160758451
train: 2019-01-04T16:22:35.945790: step 973, loss 0.045455701649188995, acc 0.8411949685534591, auc: 0.9018335769980506, precision: 0.854728370221328, recall: 0.9365079365079365
train: 2019-01-04T16:22:38.017132: step 974, loss 0.039307475090026855, acc 0.8848341232227488, auc: 0.9425289856678517, precision: 0.8875885432707115, recall: 0.9597069597069597
train: 2019-01-04T16:22:41.367998: step 975, loss 0.03757224604487419, acc 0.9028593811202507, auc: 0.9629865108962832, precision: 0.9117382707172389, recall: 0.9524882629107981
train: 2019-01-04T16:22:42.852377: step 976, loss 0.06415774673223495, acc 0.7780952380952381, auc: 0.8278897300407619, precision: 0.8070913461538461, recall: 0.9025537634408602
train: 2019-01-04T16:22:44.566829: step 977, loss 0.06382247805595398, acc 0.8219627873039037, auc: 0.84876471

train: 2019-01-04T16:23:45.731052: step 1017, loss 0.07470189034938812, acc 0.7759320782576596, auc: 0.7969807817544442, precision: 0.7960327727468737, recall: 0.9323232323232323
train: 2019-01-04T16:23:46.221030: step 1018, loss 0.09599720686674118, acc 0.7406593406593407, auc: 0.7998089589511471, precision: 0.7383399209486166, recall: 0.8688372093023256
train: 2019-01-04T16:23:46.901344: step 1019, loss 0.07998091727495193, acc 0.7350383631713555, auc: 0.766487689345848, precision: 0.7540214477211796, recall: 0.8816614420062696
train: 2019-01-04T16:23:47.583115: step 1020, loss 0.11057641357183456, acc 0.7272410184862226, auc: 0.7754951916550938, precision: 0.7384830153559795, recall: 0.87825124515772
train: 2019-01-04T16:23:48.742633: step 1021, loss 0.05726763978600502, acc 0.7428326914848096, auc: 0.7631263646475606, precision: 0.7526132404181185, recall: 0.9356435643564357
train: 2019-01-04T16:23:50.199906: step 1022, loss 0.06466130167245865, acc 0.7991404659579281, auc: 0.82493

In [9]:
# 模型预测
def load_model(fileName):
    # 实例化配置参数对象
    config = Config()

    # 实例化数据生成对象
    dataGen = DataGenerator(fileName, config)
    dataGen.gen_attr()  # 生成训练集和测试集

    test_seqs = dataGen.test_seqs

    with tf.Session() as sess:

        accuracys = []
        aucs = []
        step = 1

        for params in dataGen.next_batch(test_seqs):
            print("step: {}".format(step))

            saver = tf.train.import_meta_graph("model/my-model-800.meta")
            saver.restore(sess, tf.train.latest_checkpoint("model/"))

            # 获得默认的计算图结构
            graph = tf.get_default_graph()

            # 获得需要喂给模型的参数，输出的结果依赖的输入值
            input_x = graph.get_operation_by_name("test/dkt/input_x").outputs[0]
            target_id = graph.get_operation_by_name("test/dkt/target_id").outputs[0]
            keep_prob = graph.get_operation_by_name("test/dkt/keep_prob").outputs[0]
            max_steps = graph.get_operation_by_name("test/dkt/max_steps").outputs[0]
            sequence_len = graph.get_operation_by_name("test/dkt/sequence_len").outputs[0]

            # 获得输出的结果
            pred_all = graph.get_tensor_by_name("test/dkt/pred_all:0")
            pred = graph.get_tensor_by_name("test/dkt/pred:0")
            binary_pred = graph.get_tensor_by_name("test/dkt/binary_pred:0")

            target_correctness = params['target_correctness']
            pred_all, pred, binary_pred = sess.run([pred_all, pred, binary_pred],
                                                   feed_dict={input_x: params["input_x"],
                                                              target_id: params["target_id"],
                                                              keep_prob: 1.0,
                                                              max_steps: params["max_len"],
                                                              sequence_len: params["seq_len"]})

            auc, acc, precision, recall = gen_metrics(params["seq_len"], binary_pred, pred, target_correctness)
            print(auc, acc)
            accuracys.append(acc)
            aucs.append(auc)
            step += 1

        aucMean = mean(aucs)
        accMean = mean(accuracys)

        print("inference  auc: {}  acc: {}".format(aucMean, accMean))


if __name__ == "__main__":
    fileName = "./data/assistments.txt"
    load_model(fileName)

step: 1
INFO:tensorflow:Restoring parameters from model/my-model-1000


ValueError: too many values to unpack (expected 2)