# Graph Attention Network

\[[paper](https://arxiv.org/abs/1710.10903)\] , \[[original code](https://github.com/PetarV-/GAT)\] , \[[all other implementations](https://paperswithcode.com/paper/graph-attention-networks)\]

From [Graph Convolutional Network (GCN)](https://arxiv.org/abs/1609.02907), we learned that combining local graph structure with node-level features yields good performance on the node classification task. However, the way GCN aggregates neighbor features depends on the graph structure, which may hurt its generalizability.

One workaround is to simply average over all neighbor node features, as in GraphSAGE. Graph Attention Network proposes an alternative: weighting neighbor features with a feature-dependent, structure-free normalization, in the style of attention.

The goals of this tutorial:

- Explain what a Graph Attention Network is.
- Understand the learned attention weights.
- Introduce inductive learning.

Introducing Attention to GCN
----------------------------

The key difference between GAT and GCN is how the information from the one-hop neighborhood is aggregated.

For GCN, a graph convolution operation produces the normalized sum of the node features of neighbors:


$$h_i^{(l+1)}=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\frac{1}{c_{ij}} W^{(l)}h^{(l)}_j}\right)$$

where $\mathcal{N}(i)$ is the set of one-hop neighbors of node $v_i$ (to include $v_i$ itself in the set, simply add a self-loop to each node),
$c_{ij}=\sqrt{|\mathcal{N}(i)|}\sqrt{|\mathcal{N}(j)|}$ is a normalization constant based on graph structure, $\sigma$ is an activation function (GCN uses ReLU), and $W^{(l)}$ is a shared weight matrix for node-wise feature transformation. Another model proposed in
[GraphSAGE](https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf)
employs the same update rule except that they set
$c_{ij}=|\mathcal{N}(i)|$.
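
To make the two normalizations concrete, here is a minimal dense NumPy sketch of the update rule above. It is an illustration only, not the code used later in this notebook: the function name `normalized_sum`, the `norm` argument, and the toy graph are assumptions, and the adjacency matrix is assumed to be symmetric with self-loops already added.

```python
import numpy as np

def normalized_sum(A, H, W, norm="gcn"):
    """One propagation step: ReLU(sum_j (1 / c_ij) * W h_j).

    A: (N, N) symmetric adjacency with self-loops, H: (N, F_in), W: (F_in, F_out).
    """
    deg = A.sum(axis=1)                      # |N(i)|, self-loop included
    if norm == "gcn":
        c = np.sqrt(np.outer(deg, deg))      # c_ij = sqrt(|N(i)|) * sqrt(|N(j)|)
    elif norm == "mean":
        c = deg[:, None]                     # c_ij = |N(i)| (plain neighbor average)
    else:
        raise ValueError("norm must be 'gcn' or 'mean'")
    A_norm = A / c                           # non-edges stay zero
    return np.maximum(A_norm @ (H @ W), 0.0)

# Toy path graph 0 - 1 - 2 with self-loops
A = np.array([[1., 1., 0.],
              [1., 1., 1.],
              [0., 1., 1.]])
H = np.eye(3)
W = np.ones((3, 2))

out_gcn = normalized_sum(A, H, W, norm="gcn")
out_mean = normalized_sum(A, H, W, norm="mean")
```

In both cases the weights $1/c_{ij}$ are fixed by the node degrees alone; the node features play no role in deciding how much each neighbor contributes.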

GAT introduces the attention mechanism as a substitute for the statically
normalized convolution operation. Below are the equations to compute the node
embedding $h_i^{(l+1)}$ of layer $l+1$ from the embeddings of
layer $l$:

<img src="https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/gat.png" height="350" width="450" align="center">
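
As a rough companion to the figure, the following NumPy sketch computes one attention head on a dense adjacency matrix, following the formulation in the GAT paper: a shared linear transform $z_i = W h_i$, pairwise scores $e_{ij} = \text{LeakyReLU}(a^T [z_i \| z_j])$, a softmax over each node's neighbors, and an attention-weighted sum. The names `gat_head` and `leaky_relu`, the random toy inputs, and the dense masking are assumptions made for illustration; the output nonlinearity $\sigma$ is left to the caller.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_head(A, H, W, a):
    """Single attention head on a dense graph.

    A: (N, N) adjacency with self-loops, H: (N, F_in),
    W: (F_in, F_out), a: (2 * F_out,) attention vector.
    Returns (alpha, Z_agg) where alpha[i, j] is the weight of neighbor j for node i.
    """
    Z = H @ W                                   # z_i = W h_i
    F = Z.shape[1]
    # a^T [z_i || z_j] splits into a[:F] . z_i + a[F:] . z_j
    e = leaky_relu((Z @ a[:F])[:, None] + (Z @ a[F:])[None, :])
    e = np.where(A > 0, e, -np.inf)             # only neighbors compete in the softmax
    e = e - e.max(axis=1, keepdims=True)        # numerical stability
    alpha = np.exp(e)
    alpha /= alpha.sum(axis=1, keepdims=True)
    return alpha, alpha @ Z

rng = np.random.default_rng(0)
A = np.array([[1., 1., 0.],
              [1., 1., 1.],
              [0., 1., 1.]])
H = rng.normal(size=(3, 4))
W = rng.normal(size=(4, 2))
a = rng.normal(size=4)                          # 2 * F_out entries
alpha, H_next = gat_head(A, H, W, a)
```

Each row of `alpha` sums to one and is zero wherever there is no edge, so the normalization depends on the features through the scores rather than on node degrees.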

In [2]:
import os, sys, inspect
import joblib
import tensorflow as tf
import numpy as np
import h5py
import scipy.sparse.linalg as la
import scipy.sparse as sp
import scipy
import time
import pickle

import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
%matplotlib inline

import scipy.io as sio

import process_data

In [13]:
def count_no_weights():
    """Print the number of trainable parameters in the current default graph."""
    total_parameters = 0
    for variable in tf.trainable_variables():
        # shape is an array of tf.Dimension
        shape = variable.get_shape()
        variable_parameters = 1
        for dim in shape:
            variable_parameters *= dim.value
        total_parameters += variable_parameters
    print('#weights in the model: %d' % (total_parameters,))

def frobenius_norm(tensor):
    """Return the Frobenius (L2) norm of a tensor."""
    square_tensor = tf.square(tensor)
    tensor_sum = tf.reduce_sum(square_tensor)
    return tf.sqrt(tensor_sum)


In [14]:
class GAT:
    
    """
    The neural network model.
    """
    def __init__(self, idx_rows, idx_cols, A_shape, X, Y, num_hidden_feat, n_heads, learning_rate=5e-2, gamma=1e-3, idx_gpu = '/gpu:3'):
        
        self.num_hidden_feat = num_hidden_feat
        self.learning_rate = learning_rate
        self.gamma=gamma
        with tf.Graph().as_default() as g:
                self.graph = g
                
                with tf.device(idx_gpu):
                            
                        # list of weights' tensors l2-loss 
                        self.regularizers = []
                            
                        #definition of constant matrices
                        self.X = tf.constant(X, dtype=tf.float32) 
                        self.Y = tf.constant(Y, dtype=tf.float32)
                        
                        #placeholder definition
                        self.idx_nodes = tf.placeholder(tf.int32)
                        self.keep_prob = tf.placeholder(tf.float32)
                        
                        #model definition
                        
                        self.X0 = []
                        for k in range(n_heads):
                            with tf.variable_scope('GCL_1_{}'.format(k+1)):
                                self.X0.append(self.GAT_layer(self.X, num_hidden_feat, idx_rows, idx_cols, A_shape, tf.nn.elu))
                        self.X0 = tf.concat(self.X0, 1)
                        
                        with tf.variable_scope('GCL_2'):
                            self.logits = self.GAT_layer(self.X0, Y.shape[1], idx_rows, idx_cols, A_shape, tf.identity)
                        
                        self.l_out = tf.gather(self.logits, self.idx_nodes)
                        self.c_Y = tf.gather(self.Y, self.idx_nodes)
                        
                        #loss function definition
                        self.l2_reg = tf.reduce_sum(self.regularizers)
                        self.data_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.l_out, labels=self.c_Y)) 
                        
                        self.loss = self.data_loss + self.gamma*self.l2_reg
                        
                        #solver definition
                        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
                        self.opt_step = self.optimizer.minimize(self.loss)
                        
                        #predictions and accuracy extraction
                        self.c_predictions = tf.argmax(tf.nn.softmax(self.l_out), 1)
                        self.accuracy = tf.contrib.metrics.accuracy(self.c_predictions, tf.argmax(self.c_Y, 1))
                        
                        #gradients computation
                        self.trainable_variables = tf.trainable_variables()
                        self.var_grad = tf.gradients(self.loss, tf.trainable_variables())
                        self.norm_grad = frobenius_norm(tf.concat([tf.reshape(g, [-1]) for g in self.var_grad], 0))
                        
                        #session creation
                        config = tf.ConfigProto(allow_soft_placement = True)
                        config.gpu_options.allow_growth = True
                        self.session = tf.Session(config=config)

                        #session initialization
                        init = tf.global_variables_initializer()
                        self.session.run(init)
                        
                        count_no_weights()

    def GAT_layer(self, X, Fout, idx_rows, idx_cols, A_shape, activation):
        X = tf.nn.dropout(X,  self.keep_prob)
        
        W = tf.get_variable("W", shape=[X.shape[1], Fout], initializer=tf.glorot_uniform_initializer())
        self.regularizers.append(tf.nn.l2_loss(W))
        X_w = tf.matmul(X, W)

        # simplest possible attention mechanism
        W_att1 = tf.get_variable("W_att1", shape=[X_w.shape[1], 1], initializer=tf.glorot_uniform_initializer())
        b_att1 = tf.get_variable("b_att1", shape=[1,], initializer=tf.zeros_initializer())
        self.regularizers.append(tf.nn.l2_loss(W_att1))
        W_att2 = tf.get_variable("W_att2", shape=[X_w.shape[1], 1], initializer=tf.glorot_uniform_initializer())
        b_att2 = tf.get_variable("b_att2", shape=[1,], initializer=tf.zeros_initializer())
        self.regularizers.append(tf.nn.l2_loss(W_att2))
                            
        X_att_1 = tf.squeeze(tf.matmul(X_w, W_att1)) + b_att1
        X_att_2 = tf.squeeze(tf.matmul(X_w, W_att2)) + b_att2
        
        logits = tf.gather(X_att_1, idx_rows) +  tf.gather(X_att_2, idx_cols)
                            
        A_att = tf.SparseTensor(indices=np.vstack([idx_rows, idx_cols]).T, 
                                values=tf.nn.leaky_relu(logits), 
                                dense_shape=A_shape)
        A_att = tf.sparse_reorder(A_att)
        A_att = tf.sparse_softmax(A_att)
        
        # apply dropout
        A_att = tf.SparseTensor(indices=A_att.indices,
                                values=tf.nn.dropout(A_att.values, self.keep_prob),
                                dense_shape=A_shape)
        A_att = tf.sparse_reorder(A_att)

        X_w = tf.nn.dropout(X_w, self.keep_prob)
        res = tf.sparse_tensor_dense_matmul(A_att, X_w)
        res = tf.contrib.layers.bias_add(res)

        return activation(res)
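
The layer above stores one attention logit per edge and relies on `tf.sparse_softmax` to normalize the logits that share a row. As a rough illustration of what that per-row normalization does, here is a small NumPy sketch working directly on an edge list. The helper `edge_softmax` and the toy edges are hypothetical; this is not how TensorFlow implements the op.

```python
import numpy as np

def edge_softmax(rows, logits, num_nodes):
    """Softmax over all edges that share the same row (target node) index."""
    row_max = np.full(num_nodes, -np.inf)
    np.maximum.at(row_max, rows, logits)        # per-row max, for numerical stability
    ex = np.exp(logits - row_max[rows])
    row_sum = np.zeros(num_nodes)
    np.add.at(row_sum, rows, ex)                # per-row normalizer
    return ex / row_sum[rows]

# Edge list of a 3-node path graph with self-loops: entry k is the edge (rows[k], cols[k])
rows = np.array([0, 0, 1, 1, 1, 2, 2])
cols = np.array([0, 1, 0, 1, 2, 1, 2])
logits = np.array([0.0, 0.0, 1.0, 2.0, 3.0, -1.0, 1.0])

alpha = edge_softmax(rows, logits, num_nodes=3)
```

Working on the edge list keeps memory proportional to the number of edges rather than to $N^2$, which is why the layer builds a `tf.SparseTensor` instead of a dense attention matrix.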
     

### Multi-head Attention

Analogous to multiple channels in a ConvNet, GAT introduces **multi-head
attention** to enrich the model capacity and to stabilize the learning
process. Each attention head has its own parameters, and the head outputs can be
merged in two ways:

\begin{align}\text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)\end{align}

or

\begin{align}\text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)\end{align}

where $K$ is the number of heads. The authors suggest using
concatenation for intermediate layers and averaging for the final layer.
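
The two merge strategies differ mainly in the shape of the result. The sketch below assumes each head has already produced an `(N, F)` array; the helper `merge_heads` and the constant toy arrays are hypothetical, and the nonlinearity $\sigma$ is omitted (per head before concatenation, or once after averaging).

```python
import numpy as np

def merge_heads(head_outputs, mode="concat"):
    """Merge K per-head outputs, each of shape (N, F)."""
    if mode == "concat":
        return np.concatenate(head_outputs, axis=1)   # (N, K * F)
    if mode == "average":
        return np.mean(head_outputs, axis=0)          # (N, F)
    raise ValueError("mode must be 'concat' or 'average'")

# Four heads, three nodes, two features per head
heads = [np.full((3, 2), float(k)) for k in range(4)]
concat = merge_heads(heads, "concat")
avg = merge_heads(heads, "average")
```

The `GAT` class defined earlier concatenates `n_heads` heads in its first layer and uses a single head for the output layer, so no averaging is needed there.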


In [4]:
#learning parameters and path dataset

learning_rate = 5e-3
val_test_interval = 1
num_hidden_feat = 8
n_heads = 8
gamma = 5e-4
patience = 100
path_dataset = './CORA/dataset.pickle'
    
#dataset loading
#ds = Dataset(path_dataset, normalize_feat=1)

A, X, Y, train_idx, val_idx, test_idx = process_data.load_data("cora")
X = process_data.preprocess_features(X)

(2708, 2708)
(2708, 1433)


In [5]:
# extracts rows and cols of adjacency matrix
A = sp.csr_matrix(A)
A.setdiag(1)

idx_rows, idx_cols = A.nonzero()



In [10]:
from tqdm import tqdm

In [7]:
# num_exp = 10  # number of times to train the GAT on the given dataset
num_exp = 1  # number of times to train the GAT on the given dataset

list_all_acc = []
list_all_cost_val_avg  = []
list_all_data_cost_val_avg = []
list_all_acc_val_avg   = []
list_all_cost_test_avg = []
list_all_acc_test_avg  = []

num_done = 0

In [12]:
num_total_iter_training = int(10e4)

GCNN = GAT(idx_rows, idx_cols, A.shape, X, Y, num_hidden_feat, n_heads, learning_rate=learning_rate, gamma=gamma)

cost_train_avg      = []
grad_norm_train_avg = []
acc_train_avg       = []
cost_test_avg       = []
grad_norm_test_avg  = []
acc_test_avg        = []
cost_val_avg        = []
data_cost_val_avg   = []
acc_val_avg         = []
iter_test           = []
list_training_time = list()

max_val_acc = 0
min_val_loss = np.inf

#Training code
for i in tqdm(range(num_total_iter_training)):
    if (len(cost_train_avg) % val_test_interval) == 0:
        #Print last training performance
        if (len(cost_train_avg)>0):
            tqdm.write("[TRN] epoch = %03i, cost = %3.2e, |grad| = %.2e, acc = %3.2e (%03.2fs)" % \
            (len(cost_train_avg), cost_train_avg[-1], grad_norm_train_avg[-1], acc_train_avg[-1], time.time() - tic))

        #Validate the model
        tic = time.time()

        feed_dict = {GCNN.idx_nodes: val_idx, GCNN.keep_prob:1.0}
        acc_val, cost_val, data_cost_val = GCNN.session.run([GCNN.accuracy, GCNN.loss, GCNN.data_loss], feed_dict)

        data_cost_val_avg.append(data_cost_val)
        cost_val_avg.append(cost_val)
        acc_val_avg.append(acc_val)
        tqdm.write("[VAL] epoch = %03i, data_cost = %3.2e, cost = %3.2e, acc = %3.2e (%03.2fs)" % \
            (len(cost_train_avg), data_cost_val_avg[-1], cost_val_avg[-1], acc_val_avg[-1],  time.time() - tic))

        #Test the model
        tic = time.time()

        feed_dict = {GCNN.idx_nodes: test_idx, GCNN.keep_prob:1.0}
        acc_test, cost_test = GCNN.session.run([GCNN.accuracy, GCNN.loss], feed_dict)

        cost_test_avg.append(cost_test)
        acc_test_avg.append(acc_test)
        tqdm.write("[TST] epoch = %03i, cost = %3.2e, acc = %3.2e (%03.2fs)" % \
            (len(cost_train_avg), cost_test_avg[-1], acc_test_avg[-1],  time.time() - tic))
        iter_test.append(len(cost_train_avg))


        if acc_val_avg[-1] >= max_val_acc or data_cost_val_avg[-1] <= min_val_loss:
            max_val_acc = np.maximum(acc_val_avg[-1], max_val_acc)
            min_val_loss = np.minimum(data_cost_val_avg[-1], min_val_loss)
            if acc_val_avg[-1] >= max_val_acc and data_cost_val_avg[-1] <= min_val_loss:
                best_model_test_acc = acc_test_avg[-1]
            curr_step = 0
        else:
            curr_step += 1
            if curr_step == patience:
                # tqdm.write takes a single string, so format the values into it
                tqdm.write('Early stop! Min loss: %s, Max accuracy: %s' % (min_val_loss, max_val_acc))
                break

    tic = time.time()
    feed_dict = {GCNN.idx_nodes: train_idx, GCNN.keep_prob: 0.4}

    _, current_training_loss, norm_grad, current_acc_training = GCNN.session.run([GCNN.opt_step, GCNN.loss, GCNN.norm_grad, GCNN.accuracy], feed_dict) 

    training_time = time.time() - tic   

    cost_train_avg.append(current_training_loss)
    grad_norm_train_avg.append(norm_grad)
    acc_train_avg.append(current_acc_training)


#Compute and print statistics of the last realized experiment
list_all_acc.append(100*best_model_test_acc)
list_all_cost_val_avg.append(cost_val_avg)
list_all_data_cost_val_avg.append(data_cost_val_avg)
list_all_acc_val_avg.append(acc_val_avg)
list_all_cost_test_avg.append(cost_test_avg)
list_all_acc_test_avg.append(acc_test_avg)

print('Num done: %d' % num_done)
print('Max accuracy on test set achieved: %f%%' % np.max(np.asarray(acc_test_avg)*100))
print('Max suggested accuracy: %f%%' % (100*best_model_test_acc))
print('Current mean: %f%%' % np.mean(list_all_acc))
print('Current std: %f' % np.std(list_all_acc))

num_done += 1


  0%|          | 0/100000 [00:00<?, ?it/s][A

#weights in the model: 92391


                                                                
  0%|          | 77/10000000000 [00:50<1147570:33:28,  2.42it/s]
  0%|          | 0/100000 [00:00<?, ?it/s][A

[VAL] epoch = 000, data_cost = 1.95e+00, cost = 1.99e+00, acc = 1.46e-01 (0.45s)


                                                                
  0%|          | 77/10000000000 [00:51<1147570:33:28,  2.42it/s]
  0%|          | 0/100000 [00:00<?, ?it/s][A

[TST] epoch = 000, cost = 1.99e+00, acc = 1.28e-01 (0.43s)



                                                                
  0%|          | 77/10000000000 [00:52<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
  0%|          | 1/100000 [00:02<74:23:27,  2.68s/it][A

[TRN] epoch = 001, cost = 1.99e+00, |grad| = 2.57e-01, acc = 1.50e-01 (1.78s)
[VAL] epoch = 001, data_cost = 1.95e+00, cost = 1.98e+00, acc = 5.80e-02 (0.09s)


                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
  0%|          | 1/100000 [00:02<74:23:27,  2.68s/it][A
                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
  0%|          | 2/100000 [00:03<55:10:19,  1.99s/it][A

[TST] epoch = 001, cost = 1.98e+00, acc = 6.70e-02 (0.10s)
[TRN] epoch = 002, cost = 1.98e+00, |grad| = 2.91e-01, acc = 1.71e-01 (0.15s)


                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
  0%|          | 2/100000 [00:03<55:10:19,  1.99s/it][A

[VAL] epoch = 002, data_cost = 1.95e+00, cost = 1.98e+00, acc = 5.80e-02 (0.10s)
[TST] epoch = 002, cost = 1.98e+00, acc = 6.50e-02 (0.10s)



                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
  0%|          | 3/100000 [00:03<42:10:09,  1.52s/it][A

[TRN] epoch = 003, cost = 1.97e+00, |grad| = 2.14e-01, acc = 1.79e-01 (0.21s)
[VAL] epoch = 003, data_cost = 1.95e+00, cost = 1.98e+00, acc = 5.80e-02 (0.10s)


                                                                
  0%|          | 77/10000000000 [00:53<1147570:33:28,  2.42it/s]
  0%|          | 3/100000 [00:03<42:10:09,  1.52s/it][A
                                                                
  0%|          | 77/10000000000 [00:54<1147570:33:28,  2.42it/s]
  0%|          | 4/100000 [00:03<32:59:55,  1.19s/it][A

[TST] epoch = 003, cost = 1.97e+00, acc = 6.40e-02 (0.12s)
[TRN] epoch = 004, cost = 1.98e+00, |grad| = 2.69e-01, acc = 1.07e-01 (0.18s)


                                                                
  0%|          | 77/10000000000 [00:54<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:54<1147570:33:28,  2.42it/s]
  0%|          | 4/100000 [00:04<32:59:55,  1.19s/it][A

[VAL] epoch = 004, data_cost = 1.94e+00, cost = 1.97e+00, acc = 5.80e-02 (0.11s)
[TST] epoch = 004, cost = 1.97e+00, acc = 6.50e-02 (0.11s)



                                                                
  0%|          | 77/10000000000 [00:54<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:54<1147570:33:28,  2.42it/s]
  0%|          | 5/100000 [00:04<26:38:06,  1.04it/s][A

[TRN] epoch = 005, cost = 1.97e+00, |grad| = 2.06e-01, acc = 1.29e-01 (0.18s)
[VAL] epoch = 005, data_cost = 1.94e+00, cost = 1.97e+00, acc = 6.60e-02 (0.09s)


                                                                
  0%|          | 77/10000000000 [00:54<1147570:33:28,  2.42it/s]
  0%|          | 5/100000 [00:04<26:38:06,  1.04it/s][A
                                                                
  0%|          | 77/10000000000 [00:54<1147570:33:28,  2.42it/s]
  0%|          | 6/100000 [00:04<21:44:21,  1.28it/s][A

[TST] epoch = 005, cost = 1.96e+00, acc = 6.90e-02 (0.09s)
[TRN] epoch = 006, cost = 1.96e+00, |grad| = 2.32e-01, acc = 1.64e-01 (0.16s)


                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
  0%|          | 6/100000 [00:04<21:44:21,  1.28it/s][A

[VAL] epoch = 006, data_cost = 1.94e+00, cost = 1.96e+00, acc = 9.00e-02 (0.10s)
[TST] epoch = 006, cost = 1.96e+00, acc = 1.00e-01 (0.09s)



                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
  0%|          | 7/100000 [00:05<18:20:54,  1.51it/s][A

[TRN] epoch = 007, cost = 1.96e+00, |grad| = 2.26e-01, acc = 2.07e-01 (0.17s)
[VAL] epoch = 007, data_cost = 1.93e+00, cost = 1.96e+00, acc = 1.58e-01 (0.10s)


                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
  0%|          | 7/100000 [00:05<18:20:54,  1.51it/s][A
                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
  0%|          | 8/100000 [00:05<16:04:54,  1.73it/s][A

[TST] epoch = 007, cost = 1.96e+00, acc = 1.93e-01 (0.09s)
[TRN] epoch = 008, cost = 1.95e+00, |grad| = 2.06e-01, acc = 2.86e-01 (0.17s)


                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:55<1147570:33:28,  2.42it/s]
  0%|          | 8/100000 [00:05<16:04:54,  1.73it/s][A

[VAL] epoch = 008, data_cost = 1.93e+00, cost = 1.96e+00, acc = 3.28e-01 (0.09s)
[TST] epoch = 008, cost = 1.95e+00, acc = 4.17e-01 (0.10s)



                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
  0%|          | 9/100000 [00:05<14:15:26,  1.95it/s][A

[TRN] epoch = 009, cost = 1.96e+00, |grad| = 2.62e-01, acc = 1.79e-01 (0.15s)
[VAL] epoch = 009, data_cost = 1.93e+00, cost = 1.95e+00, acc = 4.68e-01 (0.11s)


                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
  0%|          | 9/100000 [00:06<14:15:26,  1.95it/s][A
                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
  0%|          | 10/100000 [00:06<13:01:18,  2.13it/s][A

[TST] epoch = 009, cost = 1.95e+00, acc = 5.49e-01 (0.09s)
[TRN] epoch = 010, cost = 1.95e+00, |grad| = 2.40e-01, acc = 2.71e-01 (0.16s)


                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
  0%|          | 10/100000 [00:06<13:01:18,  2.13it/s][A

[VAL] epoch = 010, data_cost = 1.93e+00, cost = 1.95e+00, acc = 5.58e-01 (0.10s)
[TST] epoch = 010, cost = 1.95e+00, acc = 6.22e-01 (0.10s)



                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:56<1147570:33:28,  2.42it/s]
  0%|          | 11/100000 [00:06<12:27:51,  2.23it/s][A

[TRN] epoch = 011, cost = 1.95e+00, |grad| = 2.33e-01, acc = 2.29e-01 (0.18s)
[VAL] epoch = 011, data_cost = 1.92e+00, cost = 1.95e+00, acc = 6.14e-01 (0.10s)


                                                                
  0%|          | 77/10000000000 [00:57<1147570:33:28,  2.42it/s]
  0%|          | 11/100000 [00:06<12:27:51,  2.23it/s][A
                                                                
  0%|          | 77/10000000000 [00:57<1147570:33:28,  2.42it/s]
  0%|          | 12/100000 [00:07<12:09:49,  2.28it/s][A

[TST] epoch = 011, cost = 1.95e+00, acc = 6.54e-01 (0.10s)
[TRN] epoch = 012, cost = 1.93e+00, |grad| = 2.50e-01, acc = 3.07e-01 (0.18s)


                                                                
  0%|          | 77/10000000000 [00:57<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:57<1147570:33:28,  2.42it/s]
  0%|          | 12/100000 [00:07<12:09:49,  2.28it/s][A

[VAL] epoch = 012, data_cost = 1.92e+00, cost = 1.94e+00, acc = 6.42e-01 (0.10s)
[TST] epoch = 012, cost = 1.94e+00, acc = 6.58e-01 (0.10s)



                                                                
  0%|          | 77/10000000000 [00:57<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:57<1147570:33:28,  2.42it/s]
  0%|          | 13/100000 [00:07<11:53:58,  2.33it/s][A

[TRN] epoch = 013, cost = 1.94e+00, |grad| = 2.48e-01, acc = 3.29e-01 (0.18s)
[VAL] epoch = 013, data_cost = 1.92e+00, cost = 1.94e+00, acc = 6.20e-01 (0.10s)


                                                                
  0%|          | 77/10000000000 [00:57<1147570:33:28,  2.42it/s]
  0%|          | 13/100000 [00:07<11:53:58,  2.33it/s][A
                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
  0%|          | 14/100000 [00:07<11:48:13,  2.35it/s][A

[TST] epoch = 013, cost = 1.94e+00, acc = 6.21e-01 (0.09s)
[TRN] epoch = 014, cost = 1.94e+00, |grad| = 2.26e-01, acc = 3.07e-01 (0.17s)


                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
  0%|          | 14/100000 [00:08<11:48:13,  2.35it/s][A

[VAL] epoch = 014, data_cost = 1.91e+00, cost = 1.94e+00, acc = 5.96e-01 (0.09s)
[TST] epoch = 014, cost = 1.94e+00, acc = 6.05e-01 (0.10s)



                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
  0%|          | 15/100000 [00:08<11:33:13,  2.40it/s]

[TRN] epoch = 015, cost = 1.93e+00, |grad| = 2.37e-01, acc = 2.93e-01 (0.18s)
[VAL] epoch = 015, data_cost = 1.91e+00, cost = 1.94e+00, acc = 5.86e-01 (0.09s)
[TST] epoch = 015, cost = 1.93e+00, acc = 5.79e-01 (0.08s)


[A
                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:58<1147570:33:28,  2.42it/s]
                                                                

[TRN] epoch = 016, cost = 1.92e+00, |grad| = 2.10e-01, acc = 3.86e-01 (0.17s)
[VAL] epoch = 016, data_cost = 1.91e+00, cost = 1.93e+00, acc = 5.72e-01 (0.09s)



  0%|          | 77/10000000000 [00:59<1147570:33:28,  2.42it/s]
  0%|          | 16/100000 [00:08<11:08:57,  2.49it/s][A
                                                                
  0%|          | 77/10000000000 [00:59<1147570:33:28,  2.42it/s]
  0%|          | 17/100000 [00:08<11:00:34,  2.52it/s][A

[TST] epoch = 016, cost = 1.93e+00, acc = 5.68e-01 (0.09s)
[TRN] epoch = 017, cost = 1.92e+00, |grad| = 2.30e-01, acc = 3.21e-01 (0.17s)


                                                                
  0%|          | 77/10000000000 [00:59<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:59<1147570:33:28,  2.42it/s]
  0%|          | 17/100000 [00:09<11:00:34,  2.52it/s][A

[VAL] epoch = 017, data_cost = 1.91e+00, cost = 1.93e+00, acc = 5.82e-01 (0.10s)
[TST] epoch = 017, cost = 1.93e+00, acc = 5.70e-01 (0.09s)



                                                                
  0%|          | 77/10000000000 [00:59<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [00:59<1147570:33:28,  2.42it/s]
  0%|          | 18/100000 [00:09<11:22:34,  2.44it/s][A

[TRN] epoch = 018, cost = 1.92e+00, |grad| = 3.10e-01, acc = 3.00e-01 (0.23s)
[VAL] epoch = 018, data_cost = 1.90e+00, cost = 1.93e+00, acc = 5.74e-01 (0.11s)


                                                                
  0%|          | 77/10000000000 [00:59<1147570:33:28,  2.42it/s]
  0%|          | 18/100000 [00:09<11:22:34,  2.44it/s][A

[TST] epoch = 018, cost = 1.93e+00, acc = 5.53e-01 (0.11s)



                                                                
  0%|          | 77/10000000000 [01:00<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [01:00<1147570:33:28,  2.42it/s]
  0%|          | 19/100000 [00:09<11:46:03,  2.36it/s][A

[TRN] epoch = 019, cost = 1.92e+00, |grad| = 2.32e-01, acc = 3.21e-01 (0.21s)
[VAL] epoch = 019, data_cost = 1.90e+00, cost = 1.93e+00, acc = 5.82e-01 (0.09s)


                                                                
  0%|          | 77/10000000000 [01:00<1147570:33:28,  2.42it/s]
  0%|          | 19/100000 [00:10<11:46:03,  2.36it/s][A
                                                                
  0%|          | 77/10000000000 [01:00<1147570:33:28,  2.42it/s]
  0%|          | 20/100000 [00:10<11:34:09,  2.40it/s][A

[TST] epoch = 019, cost = 1.93e+00, acc = 5.82e-01 (0.10s)
[TRN] epoch = 020, cost = 1.92e+00, |grad| = 2.33e-01, acc = 3.07e-01 (0.18s)


                                                                
  0%|          | 77/10000000000 [01:00<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [01:00<1147570:33:28,  2.42it/s]
  0%|          | 20/100000 [00:10<11:34:09,  2.40it/s][A

[VAL] epoch = 020, data_cost = 1.90e+00, cost = 1.93e+00, acc = 5.66e-01 (0.11s)
[TST] epoch = 020, cost = 1.92e+00, acc = 5.61e-01 (0.11s)



                                                                
  0%|          | 77/10000000000 [01:00<1147570:33:28,  2.42it/s]
                                                                
  0%|          | 77/10000000000 [01:01<1147570:33:28,  2.42it/s]
  0%|          | 21/100000 [00:10<11:43:50,  2.37it/s][A

[TRN] epoch = 021, cost = 1.91e+00, |grad| = 2.45e-01, acc = 3.71e-01 (0.19s)
[VAL] epoch = 021, data_cost = 1.90e+00, cost = 1.92e+00, acc = 5.64e-01 (0.11s)


                                                                
  0%|          | 77/10000000000 [01:01<1147570:33:28,  2.42it/s]
  0%|          | 21/100000 [00:10<11:43:50,  2.37it/s][A
                                                                
  0%|          | 77/10000000000 [01:01<1147570:33:28,  2.42it/s]
  0%|          | 22/100000 [00:11<11:37:36,  2.39it/s][A

[TST] epoch = 021, cost = 1.92e+00, acc = 5.60e-01 (0.10s)
[TRN] epoch = 022, cost = 1.89e+00, |grad| = 2.39e-01, acc = 3.64e-01 (0.18s)


                                                                
[VAL] epoch = 022, data_cost = 1.90e+00, cost = 1.92e+00, acc = 5.72e-01 (0.10s)
[TST] epoch = 022, cost = 1.92e+00, acc = 5.54e-01 (0.12s)
[TRN] epoch = 023, cost = 1.91e+00, |grad| = 2.41e-01, acc = 3.36e-01 (0.18s)
[VAL] epoch = 023, data_cost = 1.90e+00, cost = 1.92e+00, acc = 5.50e-01 (0.10s)
[TST] epoch = 023, cost = 1.92e+00, acc = 5.48e-01 (0.10s)
[TRN] epoch = 024, cost = 1.90e+00, |grad| = 2.80e-01, acc = 4.14e-01 (0.21s)
[VAL] epoch = 024, data_cost = 1.90e+00, cost = 1.92e+00, acc = 5.34e-01 (0.12s)
[TST] epoch = 024, cost = 1.92e+00, acc = 5.55e-01 (0.10s)
[TRN] epoch = 025, cost = 1.90e+00, |grad| = 2.60e-01, acc = 3.64e-01 (0.18s)
[VAL] epoch = 025, data_cost = 1.89e+00, cost = 1.92e+00, acc = 5.42e-01 (0.10s)
[TST] epoch = 025, cost = 1.92e+00, acc = 5.58e-01 (0.09s)
[TRN] epoch = 026, cost = 1.90e+00, |grad| = 2.40e-01, acc = 3.57e-01 (0.18s)
[VAL] epoch = 026, data_cost = 1.89e+00, cost = 1.92e+00, acc = 5.42e-01 (0.10s)
[TST] epoch = 026, cost = 1.91e+00, acc = 5.62e-01 (0.10s)
[TRN] epoch = 027, cost = 1.89e+00, |grad| = 3.49e-01, acc = 3.36e-01 (0.16s)
[VAL] epoch = 027, data_cost = 1.89e+00, cost = 1.92e+00, acc = 5.44e-01 (0.10s)
[TST] epoch = 027, cost = 1.91e+00, acc = 5.65e-01 (0.10s)
[TRN] epoch = 028, cost = 1.89e+00, |grad| = 2.14e-01, acc = 3.07e-01 (0.18s)
[VAL] epoch = 028, data_cost = 1.89e+00, cost = 1.91e+00, acc = 5.46e-01 (0.11s)
[TST] epoch = 028, cost = 1.91e+00, acc = 5.73e-01 (0.10s)
[TRN] epoch = 029, cost = 1.89e+00, |grad| = 2.48e-01, acc = 3.57e-01 (0.18s)
[VAL] epoch = 029, data_cost = 1.88e+00, cost = 1.91e+00, acc = 5.58e-01 (0.10s)
[TST] epoch = 029, cost = 1.91e+00, acc = 5.65e-01 (0.10s)
[TRN] epoch = 030, cost = 1.87e+00, |grad| = 2.47e-01, acc = 3.64e-01 (0.18s)
[VAL] epoch = 030, data_cost = 1.88e+00, cost = 1.91e+00, acc = 5.60e-01 (0.09s)
[TST] epoch = 030, cost = 1.91e+00, acc = 5.60e-01 (0.12s)
[TRN] epoch = 031, cost = 1.86e+00, |grad| = 2.51e-01, acc = 4.36e-01 (0.18s)
[VAL] epoch = 031, data_cost = 1.88e+00, cost = 1.91e+00, acc = 5.52e-01 (0.10s)
[TST] epoch = 031, cost = 1.90e+00, acc = 5.54e-01 (0.10s)
[TRN] epoch = 032, cost = 1.85e+00, |grad| = 3.16e-01, acc = 4.86e-01 (0.19s)
[VAL] epoch = 032, data_cost = 1.88e+00, cost = 1.91e+00, acc = 5.54e-01 (0.11s)
[TST] epoch = 032, cost = 1.90e+00, acc = 5.58e-01 (0.13s)
[TRN] epoch = 033, cost = 1.87e+00, |grad| = 3.03e-01, acc = 4.50e-01 (0.17s)
[VAL] epoch = 033, data_cost = 1.88e+00, cost = 1.91e+00, acc = 5.86e-01 (0.12s)
[TST] epoch = 033, cost = 1.90e+00, acc = 5.90e-01 (0.12s)
[TRN] epoch = 034, cost = 1.84e+00, |grad| = 2.49e-01, acc = 4.50e-01 (0.21s)
[VAL] epoch = 034, data_cost = 1.87e+00, cost = 1.90e+00, acc = 6.18e-01 (0.12s)
[TST] epoch = 034, cost = 1.90e+00, acc = 6.19e-01 (0.10s)
[TRN] epoch = 035, cost = 1.89e+00, |grad| = 2.25e-01, acc = 3.21e-01 (0.18s)
[VAL] epoch = 035, data_cost = 1.87e+00, cost = 1.90e+00, acc = 6.52e-01 (0.10s)
[TST] epoch = 035, cost = 1.89e+00, acc = 6.59e-01 (0.09s)
[TRN] epoch = 036, cost = 1.87e+00, |grad| = 2.40e-01, acc = 3.86e-01 (0.17s)
[VAL] epoch = 036, data_cost = 1.86e+00, cost = 1.90e+00, acc = 6.78e-01 (0.11s)
[TST] epoch = 036, cost = 1.89e+00, acc = 6.92e-01 (0.11s)
[TRN] epoch = 037, cost = 1.83e+00, |grad| = 2.89e-01, acc = 4.57e-01 (0.17s)
[VAL] epoch = 037, data_cost = 1.86e+00, cost = 1.89e+00, acc = 7.18e-01 (0.10s)
[TST] epoch = 037, cost = 1.88e+00, acc = 7.23e-01 (0.10s)
[TRN] epoch = 038, cost = 1.91e+00, |grad| = 2.92e-01, acc = 3.43e-01 (0.19s)
[VAL] epoch = 038, data_cost = 1.85e+00, cost = 1.89e+00, acc = 7.52e-01 (0.10s)
[TST] epoch = 038, cost = 1.88e+00, acc = 7.73e-01 (0.10s)
[TRN] epoch = 039, cost = 1.80e+00, |grad| = 3.42e-01, acc = 5.50e-01 (0.19s)
[VAL] epoch = 039, data_cost = 1.85e+00, cost = 1.88e+00, acc = 7.64e-01 (0.11s)
[TST] epoch = 039, cost = 1.88e+00, acc = 7.89e-01 (0.11s)
[TRN] epoch = 040, cost = 1.86e+00, |grad| = 3.27e-01, acc = 4.57e-01 (0.17s)
[VAL] epoch = 040, data_cost = 1.84e+00, cost = 1.88e+00, acc = 7.80e-01 (0.10s)
[TST] epoch = 040, cost = 1.87e+00, acc = 8.02e-01 (0.10s)
[TRN] epoch = 041, cost = 1.86e+00, |grad| = 3.15e-01, acc = 4.43e-01 (0.18s)
[VAL] epoch = 041, data_cost = 1.84e+00, cost = 1.87e+00, acc = 7.70e-01 (0.11s)
[TST] epoch = 041, cost = 1.87e+00, acc = 8.09e-01 (0.09s)
[TRN] epoch = 042, cost = 1.86e+00, |grad| = 2.77e-01, acc = 4.07e-01 (0.15s)
[VAL] epoch = 042, data_cost = 1.83e+00, cost = 1.87e+00, acc = 7.52e-01 (0.10s)
[TST] epoch = 042, cost = 1.86e+00, acc = 7.87e-01 (0.11s)
[TRN] epoch = 043, cost = 1.85e+00, |grad| = 2.96e-01, acc = 4.14e-01 (0.18s)
[VAL] epoch = 043, data_cost = 1.83e+00, cost = 1.87e+00, acc = 7.28e-01 (0.11s)
[TST] epoch = 043, cost = 1.86e+00, acc = 7.48e-01 (0.10s)
[TRN] epoch = 044, cost = 1.83e+00, |grad| = 3.11e-01, acc = 4.36e-01 (0.18s)
[VAL] epoch = 044, data_cost = 1.83e+00, cost = 1.86e+00, acc = 7.16e-01 (0.10s)
[TST] epoch = 044, cost = 1.86e+00, acc = 7.35e-01 (0.11s)
[TRN] epoch = 045, cost = 1.84e+00, |grad| = 2.77e-01, acc = 4.07e-01 (0.17s)
[VAL] epoch = 045, data_cost = 1.82e+00, cost = 1.86e+00, acc = 7.16e-01 (0.11s)
[TST] epoch = 045, cost = 1.85e+00, acc = 7.35e-01 (0.11s)
[TRN] epoch = 046, cost = 1.83e+00, |grad| = 3.31e-01, acc = 4.93e-01 (0.19s)
[VAL] epoch = 046, data_cost = 1.82e+00, cost = 1.86e+00, acc = 7.06e-01 (0.10s)
[TST] epoch = 046, cost = 1.85e+00, acc = 7.34e-01 (0.11s)
[TRN] epoch = 047, cost = 1.73e+00, |grad| = 3.42e-01, acc = 5.29e-01 (0.22s)
[VAL] epoch = 047, data_cost = 1.82e+00, cost = 1.86e+00, acc = 7.02e-01 (0.11s)
[TST] epoch = 047, cost = 1.85e+00, acc = 7.33e-01 (0.14s)
[TRN] epoch = 048, cost = 1.80e+00, |grad| = 2.73e-01, acc = 4.79e-01 (0.23s)
[VAL] epoch = 048, data_cost = 1.82e+00, cost = 1.86e+00, acc = 6.92e-01 (0.12s)
[TST] epoch = 048, cost = 1.85e+00, acc = 7.34e-01 (0.11s)
[TRN] epoch = 049, cost = 1.80e+00, |grad| = 3.07e-01, acc = 4.79e-01 (0.26s)
[VAL] epoch = 049, data_cost = 1.81e+00, cost = 1.85e+00, acc = 6.98e-01 (0.12s)
[TST] epoch = 049, cost = 1.85e+00, acc = 7.38e-01 (0.11s)
[TRN] epoch = 050, cost = 1.81e+00, |grad| = 2.74e-01, acc = 4.36e-01 (0.19s)
[VAL] epoch = 050, data_cost = 1.81e+00, cost = 1.85e+00, acc = 7.04e-01 (0.11s)
[TST] epoch = 050, cost = 1.84e+00, acc = 7.36e-01 (0.10s)
[TRN] epoch = 051, cost = 1.77e+00, |grad| = 3.87e-01, acc = 5.21e-01 (0.18s)
[VAL] epoch = 051, data_cost = 1.81e+00, cost = 1.85e+00, acc = 7.22e-01 (0.11s)
[TST] epoch = 051, cost = 1.84e+00, acc = 7.63e-01 (0.11s)
[TRN] epoch = 052, cost = 1.74e+00, |grad| = 3.21e-01, acc = 5.79e-01 (0.20s)
[VAL] epoch = 052, data_cost = 1.80e+00, cost = 1.85e+00, acc = 7.36e-01 (0.12s)
[TST] epoch = 052, cost = 1.84e+00, acc = 7.85e-01 (0.11s)
[TRN] epoch = 053, cost = 1.79e+00, |grad| = 3.38e-01, acc = 4.86e-01 (0.18s)
[VAL] epoch = 053, data_cost = 1.80e+00, cost = 1.84e+00, acc = 7.68e-01 (0.11s)
[TST] epoch = 053, cost = 1.83e+00, acc = 8.09e-01 (0.09s)

KeyboardInterrupt: 

In [7]:
# Print average performance: mean and standard deviation of the accuracies in list_all_acc
print(np.mean(list_all_acc))
print(np.std(list_all_acc))

83.25999975204468
0.5765412469708043
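
The two numbers above are the mean and the standard deviation of the values in `list_all_acc`. As a minimal sketch of how such a summary could be reported in one line, assume `list_all_acc` holds one accuracy (in percent) per run. The helper `summarize_accuracy` and the values in `example_accs` are hypothetical and are not results from this notebook. Note that `np.std` computes the population standard deviation by default (`ddof=0`); passing `ddof=1` gives the sample estimate, which is slightly larger when there are only a few runs.

```python
import numpy as np


def summarize_accuracy(accs, ddof=0):
    """Return (mean, std) for a sequence of per-run accuracies.

    ddof=0 matches the default of np.std (population standard deviation);
    ddof=1 gives the sample standard deviation.
    """
    accs = np.asarray(accs, dtype=float)
    return float(accs.mean()), float(accs.std(ddof=ddof))


# Hypothetical per-run accuracies in percent (illustrative values only).
example_accs = [82.0, 83.0, 84.0]

mean, std_pop = summarize_accuracy(example_accs)          # same as np.mean / np.std
_, std_sample = summarize_accuracy(example_accs, ddof=1)  # sample estimate

print(f"accuracy: {mean:.2f} +/- {std_pop:.2f} (population std)")
print(f"accuracy: {mean:.2f} +/- {std_sample:.2f} (sample std)")
```

With only a handful of runs the two estimates can differ noticeably, so it helps to state which one is being reported.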
