In [2]:
import random
import time
import sys

from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score
import numpy as np
import tensorflow as tf

  _nan_object_mask = _nan_object_array != _nan_object_array


In [1]:
# 配置config
class TrainConfig(object):
    epochs = 15
    decay_rate = 0.92
    learning_rate = 0.01
    evaluate_every = 100
    checkpoint_every = 100
    max_grad_norm = 3.0


class ModelConfig(object):
    hidden_layers = [200]
    dropout_keep_prob = 0.6


class Config(object):
    batch_size = 10
    num_skills = 267
    input_size = num_skills * 2

    trainConfig = TrainConfig()
    modelConfig = ModelConfig()

In [3]:
# 生成数据
class DataGenerator(object):
    # 导入的seqs是train_seqs，或者是test_seqs
    def __init__(self, fileName, config):
        self.fileName = fileName
        self.train_seqs = []
        self.test_seqs = []
        self.infer_seqs = []
        self.batch_size = config.batch_size
        self.pos = 0
        self.end = False
        self.num_skills = config.num_skills
        self.skills_to_int = {}  # 知识点到索引的映射
        self.int_to_skills = {}  # 索引到知识点的映射

    def read_file(self):
        # 从文件中读取数据，返回读取出来的数据和知识点个数
        # 保存每个学生的做题信息 {学生id: [[知识点id，答题结果], [知识点id，答题结果], ...]}，用一个二元列表来表示一个学生的答题信息
        seqs_by_student = {}
        skills = []  # 统计知识点的数量，之后输入的向量长度就是两倍的知识点数量
        count = 0
        with open(self.fileName, 'r') as f:
            for line in f:
                count += 1
                if count > 1:
                    fields = line.strip().split(",")  # 一个列表，[学生id，知识点id，答题结果]
                    student, skill, is_correct = int(fields[0]), int(fields[1]), int(fields[2])
                    skills.append(skill)  # skill实际上是用该题所属知识点来表示的
                    seqs_by_student[student] = seqs_by_student.get(student, []) + [[skill, is_correct]]  # 保存每个学生的做题信息
        return seqs_by_student, list(set(skills))

    def gen_dict(self, unique_skills):
        """
        构建知识点映射表，将知识点id映射到[0, 1, 2...]表示
        :param unique_skills: 无重复的知识点列表
        :return:
        """
        sorted_skills = sorted(unique_skills)
        skills_to_int = {}
        int_to_skills = {}
        for i in range(len(sorted_skills)):
            skills_to_int[sorted_skills[i]] = i
            int_to_skills[i] = sorted_skills[i]

        self.skills_to_int = skills_to_int
        self.int_to_skills = int_to_skills

    def split_dataset(self, seqs_by_student, sample_rate=0.2, random_seed=1):
        # 将数据分割成测试集和训练集
        sorted_keys = sorted(seqs_by_student.keys())  # 得到排好序的学生id的列表

        random.seed(random_seed)
        # 随机抽取学生id，将这部分学生作为测试集
        test_keys = set(random.sample(sorted_keys, int(len(sorted_keys) * sample_rate)))

        # 此时是一个三层的列表来表示的，最外层的列表中的每一个列表表示一个学生的做题信息
        test_seqs = [seqs_by_student[k] for k in seqs_by_student if k in test_keys]
        train_seqs = [seqs_by_student[k] for k in seqs_by_student if k not in test_keys]
        return train_seqs, test_seqs

    def gen_attr(self, is_infer=False):
        """
        生成待处理的数据集
        :param is_infer: 判断当前是训练模型还是利用模型进行预测
        :return:
        """
        if is_infer:
            seqs_by_students, skills = self.read_file()
            self.infer_seqs = seqs_by_students
        else:
            seqs_by_students, skills = self.read_file()
            train_seqs, test_seqs = self.split_dataset(seqs_by_students)
            self.train_seqs = train_seqs
            self.test_seqs = test_seqs

        self.gen_dict(skills)  # 生成知识点到索引的映射字典

    def pad_sequences(self, sequences, maxlen=None, value=0.):
        # 按每个batch中最长的序列进行补全, 传入的sequences是二层列表
        # 统计一个batch中每个序列的长度，其实等于seqs_len
        lengths = [len(s) for s in sequences]
        # 统计下该batch中序列的数量
        nb_samples = len(sequences)
        # 如果没有传入maxlen参数就自动获取最大的序列长度
        if maxlen is None:
            maxlen = np.max(lengths)
        # 构建x矩阵
        x = (np.ones((nb_samples, maxlen)) * value).astype(np.int32)

        # 遍历batch，去除每一个序列
        for idx, s in enumerate(sequences):
            trunc = np.asarray(s, dtype=np.int32)
            x[idx, :len(trunc)] = trunc

        return x

    def num_to_one_hot(self, num, dim):
        # 将题目转换成one-hot的形式， 其中dim=num_skills * 2，前半段表示错误，后半段表示正确
        base = np.zeros(dim)
        if num >= 0:
            base[num] += 1
        return base

    def format_data(self, seqs):
        # 生成输入数据和输出数据，输入数据是每条序列的前n-1个元素，输出数据是每条序列的后n-1个元素

        # 统计一个batch_size中每条序列的长度，在这里不对序列固定长度，通过条用tf.nn.dynamic_rnn让序列长度可以不固定
        seq_len = np.array(list(map(lambda seq: len(seq) - 1, seqs)))
        max_len = max(seq_len)  # 获得一个batch_size中最大的长度
        # i表示第i条数据，j只从0到len(i)-1，x作为输入只取前len(i)-1个，sequences=[j[0] + num_skills * j[1], ....]
        # 此时要将知识点id j[0] 转换成index表示
        x_sequences = np.array([[(self.skills_to_int[j[0]] + self.num_skills * j[1]) for j in i[:-1]] for i in seqs])
        # 将输入的序列用-1进行补全，补全后的长度为当前batch的最大序列长度
        x = self.pad_sequences(x_sequences, maxlen=max_len, value=-1)

        # 构建输入值input_x，x为一个二层列表，i表示一个学生的做题信息，也就是一个序列，j就是一道题的信息
        input_x = np.array([[self.num_to_one_hot(j, self.num_skills * 2) for j in i] for i in x])

        # 遍历batch_size，然后取每条序列的后len(i)-1 个元素中的知识点id为target_id
        target_id_seqs = np.array([[self.skills_to_int[j[0]] for j in i[1:]] for i in seqs])
        target_id = self.pad_sequences(target_id_seqs, maxlen=max_len, value=0)

        # 同target_id
        target_correctness_seqs = np.array([[j[1] for j in i[1:]] for i in seqs])
        target_correctness = self.pad_sequences(target_correctness_seqs, maxlen=max_len, value=0)

        return dict(input_x=input_x, target_id=target_id, target_correctness=target_correctness,
                    seq_len=seq_len, max_len=max_len)

    def next_batch(self, seqs):
        # 接收一个序列，生成batch

        length = len(seqs)
        num_batchs = length // self.batch_size
        start = 0
        for i in range(num_batchs):
            batch_seqs = seqs[start: start + self.batch_size]
            start += self.batch_size
            params = self.format_data(batch_seqs)

            yield params

In [None]:
# 构建模型


In [4]:
def pad_sequences(sequences, maxlen=None, value=0.):
    # 按每个batch中最长的序列进行补全, 传入的sequences是二层列表
    # 统计一个batch中每个序列的长度，其实等于seqs_len
    lengths = [len(s) for s in sequences]
    # 统计下该batch中序列的数量
    nb_samples = len(sequences)
    # 如果没有传入maxlen参数就自动获取最大的序列长度
    if maxlen is None:
        maxlen = np.max(lengths)

    # take the sample shape from the first non empty sequence
    # checking for consistency in the main loop below.

    x = (np.ones((nb_samples, maxlen)) * value).astype(np.int32)
   
    # 遍历batch，去除每一个序列
    for idx, s in enumerate(sequences):
        trunc = np.asarray(s, dtype=np.int32)
        x[idx, :len(trunc)] = trunc

    return x

def num_to_one_hot(num, dim):
    # 将题目转换成one-hot的形式， 其中dim=num_skills * 2，前半段表示错误，后半段表示正确
    base = np.zeros(dim)
    if num >= 0:
        base[num] += 1
    return base


def format_data(seqs, batch_size, num_skills):
    # 生成输入数据和输出数据，输入数据是每条序列的前n-1个元素，输出数据是每条序列的后n-1个元素
    # 
    gap = batch_size - len(seqs)
    # 在生成batch的时候，整个序列不一定能整除，对于最后的长度不够batch_size的batch进行补全长度
    seqs_in = seqs + [[[0, 0]]] * gap 
    
    # 统计一个batch_size中每条序列的长度，在这里不对序列固定长度，通过条用tf.nn.dynamic_rnn让序列长度可以不固定
    seq_len = np.array(list(map(lambda seq: len(seq)-1, seqs_in)))
    max_len = max(seq_len)  # 获得一个batch_size中最大的长度
    # i表示第i条数据，j只从0到len(i)-1，x作为输入只取前len(i)-1个，sequences=[j[0] + num_skills * j[1], ....]
    x_sequences = np.array([[(j[0] + num_skills * j[1]) for j in i[:-1]] for i in seqs_in])
    # 将输入的序列用-1进行补全，补全后的长度为当前batch的最大序列长度
    x = pad_sequences(x_sequences, maxlen=max_len, value=-1)
    
    # 构建输入值input_x，x为一个二层列表，i表示一个学生的做题信息，也就是一个序列，j就是一道题的信息
    input_x = np.array([[num_to_one_hot(j, num_skills * 2) for j in i] for i in x])
    
    # 遍历batch_size，然后取每条序列的后len(i)-1 个元素中的知识点id为target_id
    target_id_seqs = np.array([[j[0] for j in i[1:]] for i in seqs_in])
    target_id = pad_sequences(target_id_seqs, maxlen=max_len, value=0)
    
    # 同target_id
    target_correctness_seqs = np.array([[j[1] for j in i[1:]] for i in seqs_in])
    target_correctness = pad_sequences(target_correctness_seqs, maxlen=max_len, value=0)
    
    return input_x, target_id, target_correctness, seq_len, max_len


class DataGenerator(object):
    # 导入的seqs是train_seqs，或者是test_seqs
    def __init__(self, seqs, batch_size, num_skills):
        self.seqs = seqs
        self.batch_size = batch_size
        self.pos = 0
        self.end = False
        self.size = len(seqs)
        self.num_skills = num_skills

    def next_batch(self):
        batch_size = self.batch_size
        if self.pos + batch_size < self.size:
            batch_seqs = self.seqs[self.pos:self.pos + batch_size]
            self.pos += batch_size
        else:
            batch_seqs = self.seqs[self.pos:]
            self.pos = self.size - 1
        if self.pos >= self.size - 1:
            self.end = True
            
        input_x, target_id, target_correctness, seqs_len, max_len = format_data(batch_seqs, batch_size, self.num_skills)
        return input_x, target_id, target_correctness, seqs_len, max_len

    def shuffle(self):
        # 用来打乱训练集
        self.pos = 0
        self.end = False
        np.random.shuffle(self.seqs)

    def reset(self):
        self.pos = 0
        self.end = False

In [5]:
batch_size = 10
train_generator = DataGenerator(train_seqs, batch_size=batch_size, num_skills=num_skills)
test_generator = DataGenerator(test_seqs, batch_size=batch_size, num_skills=num_skills)

In [8]:
# input_x shape=(batch_size, sequence_len, vector_size)
# target_id shape = (batch_size, sequence_len)
input_x, target_id, target_correctness, seqs_len, max_len = train_generator.next_batch()

print(np.array(input_x)[1, 1])
print(np.array(input_x).shape)

[ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0

In [7]:
# 定义模型流程图
class TensorFlowDKT(object):
    def __init__(self, config):
        self.hidden_neurons = hidden_neurons = config["hidden_neurons"]
        self.num_skills = num_skills = config["num_skills"]
        self.input_size = input_size = config["input_size"]
        self.batch_size = batch_size = config["batch_size"]
        self.keep_prob_value = config["keep_prob"]

        self.max_steps = tf.placeholder(tf.int32)  # 当前batch中最大序列长度
        self.input_data = tf.placeholder(tf.float32, [batch_size, None, input_size])
        self.sequence_len = tf.placeholder(tf.int32, [batch_size])
        self.keep_prob = tf.placeholder(tf.float32)  # dropout keep prob

        self.target_id = tf.placeholder(tf.int32, [batch_size, None])
        self.target_correctness = tf.placeholder(tf.float32, [batch_size, None])

        # 构建lstm模型结构
        hidden_layers = []
        for idx, hidden_size in enumerate(hidden_neurons):
            lstm_layer = tf.nn.rnn_cell.LSTMCell(num_units=hidden_size, state_is_tuple=True)
            hidden_layer = tf.nn.rnn_cell.DropoutWrapper(cell=lstm_layer,
                                                         output_keep_prob=self.keep_prob)
            hidden_layers.append(hidden_layer)
        self.hidden_cell = tf.nn.rnn_cell.MultiRNNCell(cells=hidden_layers, state_is_tuple=True)

        # 采用动态rnn，动态输入序列的长度
        outputs, self.current_state = tf.nn.dynamic_rnn(cell=self.hidden_cell,
                                                             inputs=self.input_data,
                                                             sequence_length=self.sequence_len,
                                                             dtype=tf.float32)

        # 隐层到输出层的权重系数[最后隐层的神经元数量，知识点数]
        output_w = tf.get_variable("W", [hidden_neurons[-1], num_skills])
        output_b = tf.get_variable("b", [num_skills])
       
        self.output = tf.reshape(outputs, [batch_size * self.max_steps, hidden_neurons[-1]])
        # 因为权值共享的原因，对生成的矩阵[batch_size * self.max_steps, num_skills]中的每一行都加上b
        self.logits = tf.matmul(self.output, output_w) + output_b

        self.mat_logits = tf.reshape(self.logits, [batch_size, self.max_steps, num_skills])
        # 对每个batch中每个序列中的每个时间点的输出中的每个值进行sigmoid计算，这里的值表示对某个知识点的掌握情况，
        # 每个时间点都会输出对所有知识点的掌握情况
        self.pred_all = tf.sigmoid(self.mat_logits)

        # 计算损失loss
        flat_logits = tf.reshape(self.logits, [-1])
        
        flat_target_correctness = tf.reshape(self.target_correctness, [-1])
        
        flat_base_target_index = tf.range(batch_size * self.max_steps) * num_skills
        
        # 因为flat_logits的长度为batch_size * num_steps * num_skills，我们要根据每一步的target_id将其长度变成batch_size * num_steps
        flat_base_target_id = tf.reshape(self.target_id, [-1])
        flat_base_target_correctness = tf.reshape(self.target_correctness, [-1]) * num_skills
        
        flat_target_id = flat_base_target_id + flat_base_target_index
        # gather是从一个tensor中切片一个子集
        flat_target_logits = tf.gather(flat_logits, flat_target_id)
        
        # 对切片后的数据进行sigmoid转换
        self.pred = tf.sigmoid(tf.reshape(flat_target_logits, [batch_size, self.max_steps]))
        # 将sigmoid后的值表示为0或1
        self.binary_pred = tf.cast(tf.greater_equal(self.pred, 0.5), tf.int32)
        self.loss = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=flat_target_correctness,
                                                                          logits=flat_target_logits))

        self.lr = tf.Variable(0.0, trainable=False)
        trainable_vars = tf.trainable_variables()
        self.grads, _ = tf.clip_by_global_norm(tf.gradients(self.loss, trainable_vars), 4)

        optimizer = tf.train.GradientDescentOptimizer(self.lr)
        self.train_op = optimizer.apply_gradients(zip(self.grads, trainable_vars))

    # step on batch
    def step(self, sess, input_x, target_id, target_correctness, sequence_len, is_train):
        _, max_steps, _ = input_x.shape
        input_feed = {self.input_data: input_x,
                      self.target_id: target_id,
                      self.target_correctness: target_correctness,
                      self.max_steps: max_steps,
                      self.sequence_len: sequence_len}

        if is_train:
            input_feed[self.keep_prob] = self.keep_prob_value
            train_loss, _, _ = sess.run([self.loss, self.train_op, self.current_state], input_feed)
            return train_loss
        else:
            input_feed[self.keep_prob] = 1
            bin_pred, pred, pred_all = sess.run([self.binary_pred, self.pred, self.pred_all], input_feed)
            return bin_pred, pred, pred_all

    def assign_lr(self, session, lr_value):
        session.run(tf.assign(self.lr, lr_value))

In [8]:
# 训练模型
def run():
    # process data
    seqs_by_student, num_skills = read_file("./data/assistments.txt")
    train_seqs, test_seqs = split_dataset(seqs_by_student)
    batch_size = 10
    train_generator = DataGenerator(train_seqs, batch_size=batch_size, num_skills=num_skills)
    test_generator = DataGenerator(test_seqs, batch_size=batch_size, num_skills=num_skills)

    # config and create model
    config = {"hidden_neurons": [200],
              "batch_size": batch_size,
              "keep_prob": 0.6,
              "num_skills": num_skills,
              "input_size": num_skills * 2}
    model = TensorFlowDKT(config)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    lr = 0.4
    lr_decay = 0.92
    # run epoch
    for epoch in range(10):
        # train
        model.assign_lr(sess, lr * lr_decay ** epoch)
        overall_loss = 0
        train_generator.shuffle()
        st = time.time()
        while not train_generator.end:
            input_x, target_id, target_correctness, seqs_len, max_len = train_generator.next_batch()
            loss = model.step(sess, input_x, target_id, target_correctness, seqs_len, is_train=True)
            print("\r idx:{0}, loss:{1}, time spent:{2}s".format(train_generator.pos, loss,
                                                                         time.time() - st),
            sys.stdout.flush())

        # test
        test_generator.reset()
        preds, binary_preds, targets = list(), list(), list()
        while not test_generator.end:
            input_x, target_id, target_correctness, seqs_len, max_len = test_generator.next_batch()
            binary_pred, pred, _ = model.step(sess, input_x, target_id, target_correctness, seqs_len, is_train=False)
            for seq_idx, seq_len in enumerate(seqs_len):
                preds.append(pred[seq_idx, 0:seq_len])
                binary_preds.append(binary_pred[seq_idx, 0:seq_len])
                targets.append(target_correctness[seq_idx, 0:seq_len])
        # compute metrics
        preds = np.concatenate(preds)
        print(preds.shape)
        binary_preds = np.concatenate(binary_preds)
        print(binary_preds.shape)
        targets = np.concatenate(targets)
        print(targets.shape)
        auc_value = roc_auc_score(targets, preds)
        accuracy = accuracy_score(targets, binary_preds)
        precision, recall, f_score, _ = precision_recall_fscore_support(targets, binary_preds)
        print("\n auc={0}, accuracy={1}, precision={2}, recall={3}".format(auc_value, accuracy, precision, recall))


In [9]:
run()

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "


 idx:10, loss:6411.697265625, time spent:1.2386512756347656s None
 idx:20, loss:2296.4375, time spent:2.053205728530884s None
 idx:30, loss:1869.512451171875, time spent:2.956787109375s None
 idx:40, loss:804.0730590820312, time spent:3.232927083969116s None
 idx:50, loss:1128.8704833984375, time spent:3.7211825847625732s None
 idx:60, loss:694.512939453125, time spent:3.969052791595459s None
 idx:70, loss:384.28076171875, time spent:4.200042247772217s None
 idx:80, loss:552.7459716796875, time spent:4.487602472305298s None
 idx:90, loss:768.07373046875, time spent:4.9169793128967285s None
 idx:100, loss:730.6004638671875, time spent:5.616422653198242s None
 idx:110, loss:350.6320495605469, time spent:5.7621355056762695s None
 idx:120, loss:412.0313415527344, time spent:6.129215478897095s None
 idx:130, loss:430.2217102050781, time spent:6.407047748565674s None
 idx:140, loss:373.2222595214844, time spent:6.646601676940918s None
 idx:150, loss:714.2084350585938, time spent:7.1825590133

 idx:1200, loss:410.6839904785156, time spent:66.99108290672302s None
 idx:1210, loss:768.4559326171875, time spent:67.6207480430603s None
 idx:1220, loss:399.8243103027344, time spent:68.40451049804688s None
 idx:1230, loss:228.4388885498047, time spent:68.55985689163208s None
 idx:1240, loss:842.420654296875, time spent:69.45628333091736s None
 idx:1250, loss:177.9541473388672, time spent:69.54384875297546s None
 idx:1260, loss:598.5245361328125, time spent:70.34012198448181s None
 idx:1270, loss:245.306640625, time spent:70.5792396068573s None
 idx:1280, loss:265.4898681640625, time spent:70.684410572052s None
 idx:1290, loss:618.1990966796875, time spent:71.21656918525696s None
 idx:1300, loss:712.2733764648438, time spent:71.75980758666992s None
 idx:1310, loss:1849.6370849609375, time spent:72.62880754470825s None
 idx:1320, loss:255.61744689941406, time spent:72.8344783782959s None
 idx:1330, loss:1239.06787109375, time spent:73.73756980895996s None
 idx:1340, loss:444.116577148

 idx:2370, loss:261.5081787109375, time spent:142.74173140525818s None
 idx:2380, loss:801.0623779296875, time spent:143.20153522491455s None
 idx:2390, loss:397.4521789550781, time spent:143.46886467933655s None
 idx:2400, loss:514.7828979492188, time spent:143.77644896507263s None
 idx:2410, loss:283.4164733886719, time spent:143.92335081100464s None
 idx:2420, loss:753.822509765625, time spent:144.59182476997375s None
 idx:2430, loss:711.7578125, time spent:145.3080449104309s None
 idx:2440, loss:233.35134887695312, time spent:145.38576292991638s None
 idx:2450, loss:1170.473876953125, time spent:145.90797472000122s None
 idx:2460, loss:473.2379455566406, time spent:146.94353222846985s None
 idx:2470, loss:601.5134887695312, time spent:147.56432175636292s None
 idx:2480, loss:204.04148864746094, time spent:147.68128275871277s None
 idx:2490, loss:292.52874755859375, time spent:147.81913590431213s None
 idx:2500, loss:924.77734375, time spent:148.48419284820557s None
 idx:2510, loss:

 idx:130, loss:519.2550048828125, time spent:6.281514883041382s None
 idx:140, loss:246.31854248046875, time spent:6.562132835388184s None
 idx:150, loss:725.597412109375, time spent:7.282762289047241s None
 idx:160, loss:314.1200256347656, time spent:7.412705898284912s None
 idx:170, loss:638.5130615234375, time spent:8.403774738311768s None
 idx:180, loss:185.73922729492188, time spent:8.476576805114746s None
 idx:190, loss:446.99151611328125, time spent:8.800638198852539s None
 idx:200, loss:899.3628540039062, time spent:9.67365288734436s None
 idx:210, loss:1792.068359375, time spent:11.061240911483765s None
 idx:220, loss:1046.73828125, time spent:15.521221399307251s None
 idx:230, loss:615.936279296875, time spent:15.903348922729492s None
 idx:240, loss:1120.833740234375, time spent:16.607744932174683s None
 idx:250, loss:436.341796875, time spent:17.66783356666565s None
 idx:260, loss:1072.687255859375, time spent:18.462368726730347s None
 idx:270, loss:1021.2777709960938, time 

 idx:1320, loss:682.7471923828125, time spent:86.7038402557373s None
 idx:1330, loss:756.4534301757812, time spent:87.1572585105896s None
 idx:1340, loss:645.091796875, time spent:87.69128370285034s None
 idx:1350, loss:270.00372314453125, time spent:87.88241004943848s None
 idx:1360, loss:467.2586364746094, time spent:88.5405912399292s None
 idx:1370, loss:428.13299560546875, time spent:88.99136281013489s None
 idx:1380, loss:315.8162841796875, time spent:89.19093585014343s None
 idx:1390, loss:462.8202209472656, time spent:89.37779259681702s None
 idx:1400, loss:514.3253173828125, time spent:89.70245242118835s None
 idx:1410, loss:257.3545227050781, time spent:91.86869597434998s None
 idx:1420, loss:230.70523071289062, time spent:92.12386298179626s None
 idx:1430, loss:489.19171142578125, time spent:92.3139054775238s None
 idx:1440, loss:597.9323120117188, time spent:93.69416546821594s None
 idx:1450, loss:194.90664672851562, time spent:93.80499029159546s None
 idx:1460, loss:417.469

 idx:2490, loss:443.92608642578125, time spent:144.3009147644043s None
 idx:2500, loss:330.4812927246094, time spent:144.59301614761353s None
 idx:2510, loss:460.8628845214844, time spent:145.30543565750122s None
 idx:2520, loss:735.3001098632812, time spent:146.29078769683838s None
 idx:2530, loss:665.72802734375, time spent:147.3133909702301s None
 idx:2540, loss:810.57470703125, time spent:148.00355982780457s None
 idx:2550, loss:394.7431640625, time spent:148.17621660232544s None
 idx:2560, loss:945.1940307617188, time spent:149.09766459465027s None
 idx:2570, loss:239.21142578125, time spent:149.3643229007721s None
 idx:2580, loss:232.85250854492188, time spent:149.5208888053894s None
 idx:2590, loss:470.0151672363281, time spent:150.09727478027344s None
 idx:2600, loss:607.6190795898438, time spent:150.70494842529297s None
 idx:2610, loss:703.98291015625, time spent:151.50459384918213s None
 idx:2620, loss:1309.82568359375, time spent:152.33807277679443s None
 idx:2630, loss:844.

 idx:260, loss:366.4255065917969, time spent:16.249080419540405s None
 idx:270, loss:412.31842041015625, time spent:17.073867559432983s None
 idx:280, loss:734.9402465820312, time spent:17.436878204345703s None
 idx:290, loss:565.094970703125, time spent:18.027386903762817s None
 idx:300, loss:339.8606872558594, time spent:18.217916011810303s None
 idx:310, loss:481.4337463378906, time spent:18.904322385787964s None
 idx:320, loss:1556.5628662109375, time spent:19.836849212646484s None
 idx:330, loss:925.9843139648438, time spent:20.55491352081299s None
 idx:340, loss:1117.0235595703125, time spent:21.690000772476196s None
 idx:350, loss:208.7222900390625, time spent:21.815455675125122s None
 idx:360, loss:409.28619384765625, time spent:22.05802631378174s None
 idx:370, loss:266.66619873046875, time spent:23.092935800552368s None
 idx:380, loss:204.6671142578125, time spent:23.25142502784729s None
 idx:390, loss:1064.8917236328125, time spent:24.139179468154907s None
 idx:400, loss:183

 idx:1440, loss:302.18646240234375, time spent:83.60148334503174s None
 idx:1450, loss:1034.006103515625, time spent:84.6109688282013s None
 idx:1460, loss:150.79823303222656, time spent:84.74998068809509s None
 idx:1470, loss:1010.8372802734375, time spent:85.78819274902344s None
 idx:1480, loss:516.1414794921875, time spent:86.06290102005005s None
 idx:1490, loss:62.226829528808594, time spent:86.1006350517273s None
 idx:1500, loss:1206.0478515625, time spent:87.31481122970581s None
 idx:1510, loss:479.6680908203125, time spent:88.4419014453888s None
 idx:1520, loss:256.11981201171875, time spent:88.66863465309143s None
 idx:1530, loss:1081.1768798828125, time spent:89.35572266578674s None
 idx:1540, loss:298.6546630859375, time spent:90.57133865356445s None
 idx:1550, loss:398.951171875, time spent:90.85570478439331s None
 idx:1560, loss:312.5352783203125, time spent:91.08828115463257s None
 idx:1570, loss:393.9125061035156, time spent:91.22119235992432s None
 idx:1580, loss:348.806

 idx:2610, loss:274.4029541015625, time spent:154.79099440574646s None
 idx:2620, loss:320.9650573730469, time spent:155.0629904270172s None
 idx:2630, loss:290.6509094238281, time spent:155.26835346221924s None
 idx:2640, loss:350.2647705078125, time spent:155.5593888759613s None
 idx:2650, loss:194.48257446289062, time spent:155.75509977340698s None
 idx:2660, loss:200.90286254882812, time spent:155.90821647644043s None
 idx:2670, loss:158.04579162597656, time spent:156.10122418403625s None
 idx:2680, loss:855.1045532226562, time spent:156.91870522499084s None
 idx:2690, loss:287.3997802734375, time spent:157.07012248039246s None
 idx:2700, loss:1028.1507568359375, time spent:157.98401474952698s None
 idx:2710, loss:1558.948486328125, time spent:158.94495701789856s None
 idx:2720, loss:1095.6976318359375, time spent:159.99234056472778s None
 idx:2730, loss:435.7896423339844, time spent:160.3836579322815s None
 idx:2740, loss:1475.985595703125, time spent:161.54235243797302s None
 idx

 idx:380, loss:687.5408325195312, time spent:21.9456045627594s None
 idx:390, loss:286.572509765625, time spent:22.76778221130371s None
 idx:400, loss:157.32102966308594, time spent:22.848662853240967s None
 idx:410, loss:356.65032958984375, time spent:23.007899522781372s None
 idx:420, loss:771.5635986328125, time spent:23.738478422164917s None
 idx:430, loss:1010.8618774414062, time spent:24.78722906112671s None
 idx:440, loss:198.801513671875, time spent:24.929698944091797s None
 idx:450, loss:328.73675537109375, time spent:25.521129846572876s None
 idx:460, loss:355.1948547363281, time spent:25.70844316482544s None
 idx:470, loss:159.5861053466797, time spent:25.765715837478638s None
 idx:480, loss:1689.6185302734375, time spent:26.88099503517151s None
 idx:490, loss:417.24835205078125, time spent:27.18049693107605s None
 idx:500, loss:239.5913848876953, time spent:27.3864004611969s None
 idx:510, loss:679.0472412109375, time spent:28.068891763687134s None
 idx:520, loss:113.273406

 idx:1560, loss:499.41094970703125, time spent:92.32875776290894s None
 idx:1570, loss:637.1234741210938, time spent:93.03007650375366s None
 idx:1580, loss:584.3298950195312, time spent:93.63779091835022s None
 idx:1590, loss:1268.532470703125, time spent:94.50250363349915s None
 idx:1600, loss:301.98779296875, time spent:94.68514132499695s None
 idx:1610, loss:387.2406311035156, time spent:94.91223120689392s None
 idx:1620, loss:254.57418823242188, time spent:95.40738344192505s None
 idx:1630, loss:1321.1787109375, time spent:96.19177079200745s None
 idx:1640, loss:412.93292236328125, time spent:96.38940119743347s None
 idx:1650, loss:341.48455810546875, time spent:96.81954288482666s None
 idx:1660, loss:360.3186950683594, time spent:96.97041535377502s None
 idx:1670, loss:727.785888671875, time spent:97.72850322723389s None
 idx:1680, loss:536.11767578125, time spent:98.36602878570557s None
 idx:1690, loss:762.5078125, time spent:98.93793320655823s None
 idx:1700, loss:587.853149414

 idx:2730, loss:379.95477294921875, time spent:152.41681337356567s None
 idx:2740, loss:606.5303344726562, time spent:152.66905212402344s None
 idx:2750, loss:201.6486358642578, time spent:152.74929976463318s None
 idx:2760, loss:339.46453857421875, time spent:152.99645519256592s None
 idx:2770, loss:1347.778076171875, time spent:153.72659015655518s None
 idx:2780, loss:289.2571716308594, time spent:153.88097524642944s None
 idx:2790, loss:260.98046875, time spent:154.59164786338806s None
 idx:2800, loss:670.0189208984375, time spent:155.58402919769287s None
 idx:2810, loss:214.95309448242188, time spent:155.8772099018097s None
 idx:2820, loss:205.2705841064453, time spent:156.60144090652466s None
 idx:2830, loss:560.436279296875, time spent:156.82201838493347s None
 idx:2840, loss:157.8976287841797, time spent:156.8965368270874s None
 idx:2850, loss:556.0476684570312, time spent:157.58346843719482s None
 idx:2860, loss:1287.6324462890625, time spent:158.5136525630951s None
 idx:2870, 

 idx:500, loss:429.8442077636719, time spent:26.094717025756836s None
 idx:510, loss:377.2572021484375, time spent:26.35595202445984s None
 idx:520, loss:827.9777221679688, time spent:27.458606719970703s None
 idx:530, loss:219.4903564453125, time spent:27.705005884170532s None
 idx:540, loss:760.0774536132812, time spent:28.305418729782104s None
 idx:550, loss:1140.6829833984375, time spent:29.342132091522217s None
 idx:560, loss:537.9119873046875, time spent:29.64002013206482s None
 idx:570, loss:648.543701171875, time spent:30.51952338218689s None
 idx:580, loss:279.8776550292969, time spent:30.6615948677063s None
 idx:590, loss:846.8961181640625, time spent:31.617297887802124s None
 idx:600, loss:745.362060546875, time spent:32.37601661682129s None
 idx:610, loss:294.7541198730469, time spent:32.482248306274414s None
 idx:620, loss:277.0596923828125, time spent:33.66624569892883s None
 idx:630, loss:1161.827392578125, time spent:34.438621044158936s None
 idx:640, loss:260.268524169

 idx:1680, loss:431.89935302734375, time spent:96.98751950263977s None
 idx:1690, loss:260.83056640625, time spent:97.27112913131714s None
 idx:1700, loss:900.7666625976562, time spent:98.08158731460571s None
 idx:1710, loss:626.664794921875, time spent:98.41043424606323s None
 idx:1720, loss:330.6615905761719, time spent:98.72637486457825s None
 idx:1730, loss:258.7769470214844, time spent:98.90488791465759s None
 idx:1740, loss:340.2421569824219, time spent:99.14391493797302s None
 idx:1750, loss:1029.47216796875, time spent:99.85184478759766s None


KeyboardInterrupt: 