[View in Colaboratory](https://colab.research.google.com/github/cococ0j/GAIN/blob/master/GAIN_test.ipynb)

In [0]:
'''
Written by Jinsung Yoon
Date: Jul 9th 2018

Generative Adversarial Imputation Networks (GAIN) Implementation on MNIST

Reference: J. Yoon, J. Jordon, M. van der Schaar, "GAIN: Missing Data Imputation using Generative Adversarial Nets," ICML, 2018.
Paper Link: http://medianetlab.ee.ucla.edu/papers/ICML_GAIN.pdf
Appendix Link: http://medianetlab.ee.ucla.edu/papers/ICML_GAIN_Supp.pdf

Contact: jsyoon0823@g.ucla.edu
'''

#%% Packages
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from tqdm import tqdm




In [2]:
# NOTE(review): this cell comes after the imports cell above, so it cannot help
# when TensorFlow is missing at import time. The rest of the notebook uses
# TF 1.x-style APIs (tf.placeholder, tf.Session, tf.train.AdamOptimizer,
# tensorflow.examples.tutorials.mnist); an unpinned install presumably resolves
# to TF 2.x, where those names are not available at top level. Consider pinning
# a 1.x release (and using %pip) in a cell before the imports. TODO confirm.
! pip3 install tensorflow



In [4]:
#%% Data Input
# Load MNIST from ../../MNIST_data (read_data_sets prints one "Extracting ..."
# line per archive, as in the captured output). Labels are requested one-hot,
# but the training loop discards them -- only the images are used.
mnist = input_data.read_data_sets(
    '../../MNIST_data',
    one_hot=True,
)



Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [5]:
#%% System Parameters
# Mini-batch size used for every training step.
mb_size = 128
# Probability that an input pixel is treated as missing
# (passed to sample_M as p, so mask entries are 1 = observed with prob 1 - p_miss).
p_miss = 0.5
# Probability that a hint bit is 1 before it is multiplied by the mask
# (sample_M is called with p = 1 - p_hint).
p_hint = 0.9
# Weight of the reconstruction (MSE) term in the generator loss.
alpha = 10
# Input dimension (fixed): a flattened 28 x 28 MNIST image.
Dim = 28 * 28

#%% Necessary Functions
# 1. Weight initialisation
def xavier_init(size):
    """Draw a random initial weight tensor of shape `size`.

    Entries are normal with mean 0 and stddev = 1 / sqrt(fan_in / 2),
    i.e. sqrt(2 / fan_in), where fan_in = size[0]. Despite the name, only
    fan_in enters the scale (He-style scaling rather than Glorot's
    fan_in + fan_out).

    Args:
        size: weight shape, [fan_in, fan_out].

    Returns:
        A tf.random_normal tensor with shape `size`.
    """
    fan_in = size[0]
    scale = 1. / tf.sqrt(fan_in / 2.)
    return tf.random_normal(shape=size, stddev=scale)
    
# 2. Plot a grid of flattened 28 x 28 images (5 x 5 subfigures by default)
def plot(samples, n_rows=5, n_cols=5):
    """Draw each flattened 28 x 28 image in `samples` into one cell of a grid.

    Args:
        samples: sequence of arrays, each reshapeable to (28, 28); at most
            n_rows * n_cols of them. Images fill the grid row by row.
        n_rows, n_cols: grid size. The 5 x 5 defaults reproduce the original
            fixed layout (and its 5 x 5 inch figure size).

    Returns:
        The matplotlib Figure; the caller is responsible for saving/closing it.

    Raises:
        ValueError: if there are more samples than grid cells (previously this
            surfaced as an opaque IndexError from the GridSpec).
    """
    n_cells = n_rows * n_cols
    if len(samples) > n_cells:
        raise ValueError('plot: got {} samples for a {} x {} grid'.format(
            len(samples), n_rows, n_cols))

    fig = plt.figure(figsize=(n_cols, n_rows))
    gs = gridspec.GridSpec(n_rows, n_cols)
    gs.update(wspace=0.05, hspace=0.05)

    for idx, sample in enumerate(samples):
        # Explicit Axes API: draws into `fig` regardless of pyplot's "current" figure.
        ax = fig.add_subplot(gs[idx])
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        ax.imshow(sample.reshape(28, 28), cmap='Greys_r')

    return fig
   
# GAIN consists of three components:
#   - Generator
#   - Discriminator
#   - Hint mechanism

#%% GAIN Architecture

#%% 1. Input placeholders: each is a batch of Dim-wide float32 rows
X = tf.placeholder(tf.float32, shape=[None, Dim])  # 1.1 data vector
M = tf.placeholder(tf.float32, shape=[None, Dim])  # 1.2 mask vector (1 = observed)
H = tf.placeholder(tf.float32, shape=[None, Dim])  # 1.3 hint vector
Z = tf.placeholder(tf.float32, shape=[None, Dim])  # 1.4 random-noise vector


def _dense_params(n_in, n_out):
    """Create one fully connected layer's (weight, bias) variables.

    The weight (xavier_init) is created before the zero bias, i.e. in the same
    order as writing the two tf.Variable calls out by hand.
    """
    weight = tf.Variable(xavier_init([n_in, n_out]))
    bias = tf.Variable(tf.zeros(shape=[n_out]))
    return weight, bias


#%% 2. Discriminator parameters
# Input is [imputed data, hint] (2 * Dim wide); output is Dim wide (one value per feature).
D_W1, D_b1 = _dense_params(Dim * 2, 256)
D_W2, D_b2 = _dense_params(256, 128)
D_W3, D_b3 = _dense_params(128, Dim)

theta_D = [D_W1, D_W2, D_W3, D_b1, D_b2, D_b3]

#%% 3. Generator parameters
# Input is [data with noise in the missing entries, mask] (2 * Dim wide).
G_W1, G_b1 = _dense_params(Dim * 2, 256)
G_W2, G_b2 = _dense_params(256, 128)
G_W3, G_b3 = _dense_params(128, Dim)

theta_G = [G_W1, G_W2, G_W3, G_b1, G_b2, G_b3]

#%% GAIN Function

#%% 1. Generator
def generator(x, z, m):
    """Impute the entries of `x` that the mask marks as missing.

    Args:
        x: data batch.
        z: noise batch; only its entries where m == 0 are used.
        m: mask batch (1 = observed, 0 = missing).

    Returns:
        Sigmoid output in [0, 1], Dim wide (one value per feature).
    """
    # Observed entries keep their data, missing ones get noise; the mask is
    # appended so the network can tell which is which.
    filled = m * x + (1 - m) * z
    hidden = tf.concat(axis=1, values=[filled, m])
    for weight, bias in ((G_W1, G_b1), (G_W2, G_b2)):
        hidden = tf.nn.relu(tf.matmul(hidden, weight) + bias)
    return tf.nn.sigmoid(tf.matmul(hidden, G_W3) + G_b3)
    
#%% 2. Discriminator
def discriminator(x, m, g, h):
    """Score, per entry, how likely it is to be an observed (not imputed) value.

    Args:
        x: data batch.
        m: mask batch (1 = observed, 0 = missing).
        g: generator output; substituted for x where m == 0.
        h: hint batch, concatenated to the input.

    Returns:
        Sigmoid probabilities, Dim wide (one per feature).
    """
    imputed = m * x + (1 - m) * g
    hidden = tf.concat(axis=1, values=[imputed, h])
    hidden = tf.nn.relu(tf.matmul(hidden, D_W1) + D_b1)
    hidden = tf.nn.relu(tf.matmul(hidden, D_W2) + D_b2)
    logits = tf.matmul(hidden, D_W3) + D_b3
    return tf.nn.sigmoid(logits)

#%% 3. Others
# Random sample generator for Z
def sample_Z(m, n):
    """Return an (m, n) float64 array of i.i.d. Uniform[0, 1) noise."""
    shape = [m, n]
    return np.random.uniform(low=0., high=1., size=shape)

# Mask Vector and Hint Vector Generation
def sample_M(m, n, p):
    """Return an (m, n) float array of 0/1 draws, each 1.0 with probability 1 - p.

    An entry is 1.0 exactly where a Uniform[0, 1) draw is strictly greater
    than p. The notebook uses p = p_miss for the mask and p = 1 - p_hint for
    the hint.
    """
    draws = np.random.uniform(0., 1., size=[m, n])
    return (draws > p).astype(float)

#%% Structure
# G imputes the batch; D then sees observed values where M == 1 and G's values
# where M == 0, plus the hint H, and outputs one probability per entry.
G_sample = generator(X,Z,M)
D_prob = discriminator(X, M, G_sample, H)

#%% Loss
# Discriminator loss: per-entry binary cross-entropy of D_prob against the mask
# M (target 1 = observed, 0 = imputed), averaged over every entry of the batch
# and scaled by 2. The 1e-8 keeps tf.log away from log(0).
D_loss1 = -tf.reduce_mean(M * tf.log(D_prob + 1e-8) + (1-M) * tf.log(1. - D_prob + 1e-8)) * 2
# Adversarial generator loss: -log D_prob on the missing entries only; dividing
# by mean(1-M) turns the batch mean into a mean over missing entries.
G_loss1 = -tf.reduce_mean((1-M) * tf.log(D_prob + 1e-8)) / tf.reduce_mean(1-M)
# Reconstruction loss on observed entries (M == 1), normalised by the observed
# fraction so it is a mean over observed entries.
MSE_train_loss = tf.reduce_mean((M * X - M * G_sample)**2) / tf.reduce_mean(M)

D_loss = D_loss1
# alpha weights reconstruction of observed entries against fooling D.
G_loss = G_loss1  + alpha * MSE_train_loss 

#%% MSE Performance metric
# Imputation error on the missing entries (M == 0), whose true values G never
# sees (its input carries noise there). Despite the name, it is evaluated on
# whatever batch is fed -- the training loop feeds mnist.train batches.
MSE_test_loss = tf.reduce_mean(((1-M) * X - (1-M)*G_sample)**2) / tf.reduce_mean(1-M)

#%% Solver
# Adam with default settings; var_list restricts each optimizer to its own
# network's parameters, so a D step never moves G's weights and vice versa.
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)

# Sessions
# A single session shared by the training loop; initialises every variable
# created above.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#%%
# Output Initialization: figures are written as <out_dir>/001.png, 002.png, ...
out_dir = 'Multiple_Impute_out1/'
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# Training / reporting schedule
n_iterations = 10000
report_every = 100  # save a figure and print the losses every `report_every` iterations

# Figure counter, used in the output file names
i = 1


def _save_imputation_figure(path):
    """Impute a few fresh training images and save them as one figure at `path`.

    Draws 5 images from mnist.train and one mask for them, then stacks five
    5-image row groups, top to bottom:
      1. observed pixels with uniform noise in the missing ones (G's input),
      2-4. three GAIN imputations, each from freshly drawn noise,
      5. the original images.
    That is 25 images, matching plot()'s default 5 x 5 grid. Random draws
    happen in the same order as the original hand-unrolled code.
    """
    n_vis = 5
    n_imputations = 3

    X_vis, _ = mnist.train.next_batch(n_vis)
    Z_vis = sample_Z(n_vis, Dim)
    M_vis = sample_M(n_vis, Dim, p_miss)
    noisy = M_vis * X_vis + (1 - M_vis) * Z_vis

    rows = [noisy]
    for draw in range(n_imputations):
        if draw > 0:
            # Fresh noise in the missing pixels gives a different imputation.
            Z_vis = sample_Z(n_vis, Dim)
            noisy = M_vis * X_vis + (1 - M_vis) * Z_vis
        imputed = sess.run(G_sample, feed_dict={X: X_vis, M: M_vis, Z: noisy})
        # Keep the observed pixels; only the missing ones come from G.
        rows.append(M_vis * X_vis + (1 - M_vis) * imputed)
    rows.append(X_vis)

    fig = plot(np.vstack(rows))
    try:
        fig.savefig(path, bbox_inches='tight')
    finally:
        plt.close(fig)  # avoid accumulating open figures across ~100 saves


#%% Start Iterations
for it in tqdm(range(n_iterations)):

    #%% Inputs
    X_mb, _ = mnist.train.next_batch(mb_size)
    Z_mb = sample_Z(mb_size, Dim)
    M_mb = sample_M(mb_size, Dim, p_miss)
    H_mb1 = sample_M(mb_size, Dim, 1 - p_hint)
    H_mb = M_mb * H_mb1  # hint can only be 1 on observed entries

    New_X_mb = M_mb * X_mb + (1 - M_mb) * Z_mb  # Missing Data Introduce

    # One discriminator step, then one generator step, on the same batch.
    train_feed = {X: X_mb, M: M_mb, Z: New_X_mb, H: H_mb}
    _, D_loss_curr = sess.run([D_solver, D_loss1], feed_dict=train_feed)
    _, G_loss_curr, MSE_train_loss_curr, MSE_test_loss_curr = sess.run(
        [G_solver, G_loss1, MSE_train_loss, MSE_test_loss], feed_dict=train_feed)

    #%% Output figure and intermediate losses
    if it % report_every == 0:
        _save_imputation_figure(os.path.join(out_dir, '{}.png'.format(str(i).zfill(3))))
        i += 1

        # "Test_loss" is the MSE on this training batch's *missing* entries
        # (hidden from G), not a score on mnist.test.
        # tqdm.write keeps the progress bar intact, and the trailing "\n" yields
        # the blank separator line that print() failed to produce under
        # Python 2 (it printed "()" instead, as the saved output shows).
        tqdm.write('Iter: {}\nTrain_loss: {:.4}\nTest_loss: {:.4}\n'.format(
            it, MSE_train_loss_curr, MSE_test_loss_curr))
    



  0%|          | 4/10000 [00:01<3:17:13,  1.18s/it]

Iter: 0
Train_loss: 0.2546
Test_loss: 0.256
()


  1%|          | 103/10000 [00:08<35:25,  4.66it/s]

Iter: 100
Train_loss: 0.06313
Test_loss: 0.06529
()


  2%|▏         | 203/10000 [00:15<33:00,  4.95it/s]

Iter: 200
Train_loss: 0.04982
Test_loss: 0.05366
()


  3%|▎         | 303/10000 [00:22<36:37,  4.41it/s]

Iter: 300
Train_loss: 0.04373
Test_loss: 0.04752
()


  4%|▍         | 403/10000 [00:30<32:26,  4.93it/s]

Iter: 400
Train_loss: 0.03893
Test_loss: 0.04339
()


  5%|▌         | 503/10000 [00:36<31:26,  5.03it/s]

Iter: 500
Train_loss: 0.03422
Test_loss: 0.0407
()


  6%|▌         | 604/10000 [00:43<35:51,  4.37it/s]

Iter: 600
Train_loss: 0.03337
Test_loss: 0.03904
()


  7%|▋         | 704/10000 [00:50<31:06,  4.98it/s]

Iter: 700
Train_loss: 0.02999
Test_loss: 0.03678
()


  8%|▊         | 804/10000 [00:57<30:22,  5.05it/s]

Iter: 800
Train_loss: 0.02823
Test_loss: 0.03502
()


  9%|▉         | 904/10000 [01:04<36:57,  4.10it/s]

Iter: 900
Train_loss: 0.0281
Test_loss: 0.03608
()


 10%|█         | 1004/10000 [01:12<31:09,  4.81it/s]

Iter: 1000
Train_loss: 0.02621
Test_loss: 0.03319
()


 11%|█         | 1104/10000 [01:20<30:55,  4.80it/s]

Iter: 1100
Train_loss: 0.0251
Test_loss: 0.0337
()


 12%|█▏        | 1204/10000 [01:28<38:06,  3.85it/s]

Iter: 1200
Train_loss: 0.02574
Test_loss: 0.03281
()


 13%|█▎        | 1304/10000 [01:36<29:10,  4.97it/s]

Iter: 1300
Train_loss: 0.02395
Test_loss: 0.03174
()


 14%|█▍        | 1404/10000 [01:43<28:59,  4.94it/s]

Iter: 1400
Train_loss: 0.02604
Test_loss: 0.03507
()


 15%|█▌        | 1504/10000 [01:50<28:15,  5.01it/s]

Iter: 1500
Train_loss: 0.02181
Test_loss: 0.03077
()


 16%|█▌        | 1604/10000 [01:57<36:24,  3.84it/s]

Iter: 1600
Train_loss: 0.022
Test_loss: 0.03163
()


 17%|█▋        | 1704/10000 [02:04<27:22,  5.05it/s]

Iter: 1700
Train_loss: 0.02193
Test_loss: 0.03174
()


 18%|█▊        | 1804/10000 [02:10<27:20,  4.99it/s]

Iter: 1800
Train_loss: 0.02122
Test_loss: 0.03064
()


 19%|█▉        | 1904/10000 [02:18<27:38,  4.88it/s]

Iter: 1900
Train_loss: 0.02114
Test_loss: 0.03132
()


 20%|██        | 2004/10000 [02:25<27:17,  4.88it/s]

Iter: 2000
Train_loss: 0.02218
Test_loss: 0.03138
()


 21%|██        | 2104/10000 [02:33<26:45,  4.92it/s]

Iter: 2100
Train_loss: 0.02137
Test_loss: 0.03106
()


 22%|██▏       | 2204/10000 [02:41<36:30,  3.56it/s]

Iter: 2200
Train_loss: 0.02152
Test_loss: 0.03096
()


 23%|██▎       | 2304/10000 [02:48<25:42,  4.99it/s]

Iter: 2300
Train_loss: 0.02001
Test_loss: 0.02932
()


 24%|██▍       | 2404/10000 [02:55<25:15,  5.01it/s]

Iter: 2400
Train_loss: 0.02114
Test_loss: 0.03074
()


 25%|██▌       | 2504/10000 [03:02<24:50,  5.03it/s]

Iter: 2500
Train_loss: 0.02072
Test_loss: 0.03067
()


 26%|██▌       | 2604/10000 [03:08<24:04,  5.12it/s]

Iter: 2600
Train_loss: 0.0201
Test_loss: 0.03076
()


 27%|██▋       | 2704/10000 [03:15<24:35,  4.94it/s]

Iter: 2700
Train_loss: 0.0189
Test_loss: 0.02849
()


 28%|██▊       | 2802/10000 [03:23<34:35,  3.47it/s]

Iter: 2800
Train_loss: 0.01868
Test_loss: 0.02923
()


 29%|██▉       | 2904/10000 [03:33<37:26,  3.16it/s]

Iter: 2900
Train_loss: 0.01998
Test_loss: 0.02972
()


 30%|███       | 3004/10000 [03:40<24:08,  4.83it/s]

Iter: 3000
Train_loss: 0.01886
Test_loss: 0.02941
()


 31%|███       | 3104/10000 [03:47<22:56,  5.01it/s]

Iter: 3100
Train_loss: 0.01983
Test_loss: 0.02945
()


 32%|███▏      | 3204/10000 [03:55<23:08,  4.90it/s]

Iter: 3200
Train_loss: 0.01924
Test_loss: 0.02907
()


 33%|███▎      | 3304/10000 [04:02<22:48,  4.89it/s]

Iter: 3300
Train_loss: 0.01758
Test_loss: 0.02759
()


 34%|███▍      | 3404/10000 [04:09<22:11,  4.95it/s]

Iter: 3400
Train_loss: 0.01826
Test_loss: 0.02999
()


 35%|███▌      | 3504/10000 [04:16<21:25,  5.05it/s]

Iter: 3500
Train_loss: 0.01811
Test_loss: 0.02924
()


 36%|███▌      | 3604/10000 [04:23<21:19,  5.00it/s]

Iter: 3600
Train_loss: 0.01762
Test_loss: 0.02827
()


 37%|███▋      | 3704/10000 [04:31<34:09,  3.07it/s]

Iter: 3700
Train_loss: 0.01837
Test_loss: 0.02842
()


 38%|███▊      | 3804/10000 [04:38<20:21,  5.07it/s]

Iter: 3800
Train_loss: 0.01754
Test_loss: 0.0281
()


 39%|███▉      | 3903/10000 [04:44<15:22,  6.61it/s]

Iter: 3900
Train_loss: 0.01704
Test_loss: 0.02664
()


 40%|████      | 4003/10000 [04:52<20:58,  4.77it/s]

Iter: 4000
Train_loss: 0.01796
Test_loss: 0.02954
()


 41%|████      | 4103/10000 [05:00<20:41,  4.75it/s]

Iter: 4100
Train_loss: 0.01775
Test_loss: 0.02984
()


 42%|████▏     | 4203/10000 [05:08<20:15,  4.77it/s]

Iter: 4200
Train_loss: 0.01738
Test_loss: 0.02761
()


 43%|████▎     | 4303/10000 [05:16<19:47,  4.80it/s]

Iter: 4300
Train_loss: 0.01759
Test_loss: 0.02905
()


 44%|████▍     | 4403/10000 [05:23<19:14,  4.85it/s]

Iter: 4400
Train_loss: 0.0182
Test_loss: 0.02848
()


 45%|████▌     | 4503/10000 [05:31<18:46,  4.88it/s]

Iter: 4500
Train_loss: 0.01654
Test_loss: 0.02689
()


 46%|████▌     | 4603/10000 [05:38<18:21,  4.90it/s]

Iter: 4600
Train_loss: 0.01726
Test_loss: 0.02814
()


 47%|████▋     | 4703/10000 [05:45<17:50,  4.95it/s]

Iter: 4700
Train_loss: 0.01711
Test_loss: 0.0281
()


 48%|████▊     | 4803/10000 [05:53<31:00,  2.79it/s]

Iter: 4800
Train_loss: 0.01674
Test_loss: 0.02727
()


 49%|████▉     | 4903/10000 [06:00<16:59,  5.00it/s]

Iter: 4900
Train_loss: 0.0168
Test_loss: 0.02755
()


 50%|█████     | 5003/10000 [06:07<16:35,  5.02it/s]

Iter: 5000
Train_loss: 0.0167
Test_loss: 0.02723
()


 51%|█████     | 5103/10000 [06:14<16:08,  5.05it/s]

Iter: 5100
Train_loss: 0.01588
Test_loss: 0.02696
()


 52%|█████▏    | 5204/10000 [06:21<16:03,  4.98it/s]

Iter: 5200
Train_loss: 0.01749
Test_loss: 0.02832
()


 53%|█████▎    | 5304/10000 [06:29<16:23,  4.77it/s]

Iter: 5300
Train_loss: 0.01571
Test_loss: 0.02669
()


 54%|█████▍    | 5404/10000 [06:36<15:52,  4.83it/s]

Iter: 5400
Train_loss: 0.01599
Test_loss: 0.02746
()


 55%|█████▌    | 5504/10000 [06:44<15:29,  4.84it/s]

Iter: 5500
Train_loss: 0.01563
Test_loss: 0.02776
()


 56%|█████▌    | 5602/10000 [06:51<19:25,  3.77it/s]

Iter: 5600
Train_loss: 0.01654
Test_loss: 0.0276
()


 57%|█████▋    | 5702/10000 [07:00<20:04,  3.57it/s]

Iter: 5700
Train_loss: 0.01696
Test_loss: 0.02861
()


 58%|█████▊    | 5802/10000 [07:09<19:24,  3.61it/s]

Iter: 5800
Train_loss: 0.01651
Test_loss: 0.02684
()


 59%|█████▉    | 5902/10000 [07:17<18:35,  3.67it/s]

Iter: 5900
Train_loss: 0.0168
Test_loss: 0.02833
()


 60%|██████    | 6002/10000 [07:28<37:47,  1.76it/s]

Iter: 6000
Train_loss: 0.01616
Test_loss: 0.02786
()


 61%|██████    | 6103/10000 [07:34<13:10,  4.93it/s]

Iter: 6100
Train_loss: 0.01649
Test_loss: 0.02834
()


 62%|██████▏   | 6203/10000 [07:43<13:26,  4.71it/s]

Iter: 6200
Train_loss: 0.01622
Test_loss: 0.02761
()


 63%|██████▎   | 6303/10000 [07:51<12:54,  4.77it/s]

Iter: 6300
Train_loss: 0.01587
Test_loss: 0.02859
()


 64%|██████▍   | 6403/10000 [07:59<12:45,  4.70it/s]

Iter: 6400
Train_loss: 0.01644
Test_loss: 0.02722
()


 65%|██████▌   | 6504/10000 [08:06<11:34,  5.03it/s]

Iter: 6500
Train_loss: 0.01622
Test_loss: 0.02772
()


 66%|██████▌   | 6603/10000 [08:13<10:36,  5.34it/s]

Iter: 6600
Train_loss: 0.01622
Test_loss: 0.0271
()


 67%|██████▋   | 6703/10000 [08:19<08:01,  6.84it/s]

Iter: 6700
Train_loss: 0.01627
Test_loss: 0.02796
()


 68%|██████▊   | 6803/10000 [08:25<10:00,  5.32it/s]

Iter: 6800
Train_loss: 0.01606
Test_loss: 0.02702
()


 69%|██████▉   | 6904/10000 [08:32<10:42,  4.82it/s]

Iter: 6900
Train_loss: 0.01431
Test_loss: 0.02491
()


 70%|███████   | 7004/10000 [08:39<10:16,  4.86it/s]

Iter: 7000
Train_loss: 0.01655
Test_loss: 0.02852
()


 71%|███████   | 7103/10000 [08:45<09:27,  5.10it/s]

Iter: 7100
Train_loss: 0.01535
Test_loss: 0.02737
()


 72%|███████▏  | 7203/10000 [08:52<09:08,  5.10it/s]

Iter: 7200
Train_loss: 0.01595
Test_loss: 0.02715
()


 73%|███████▎  | 7302/10000 [08:58<08:48,  5.10it/s]

Iter: 7300
Train_loss: 0.0154
Test_loss: 0.02774
()


 74%|███████▍  | 7403/10000 [09:05<06:20,  6.83it/s]

Iter: 7400
Train_loss: 0.01648
Test_loss: 0.02943
()


 75%|███████▌  | 7502/10000 [09:11<10:47,  3.86it/s]

Iter: 7500
Train_loss: 0.01618
Test_loss: 0.0281
()


 76%|███████▌  | 7604/10000 [09:20<12:40,  3.15it/s]

Iter: 7600
Train_loss: 0.01556
Test_loss: 0.02709
()


 77%|███████▋  | 7704/10000 [09:28<08:01,  4.76it/s]

Iter: 7700
Train_loss: 0.01546
Test_loss: 0.02716
()


 78%|███████▊  | 7804/10000 [09:35<07:21,  4.97it/s]

Iter: 7800
Train_loss: 0.01542
Test_loss: 0.02719
()


 79%|███████▉  | 7904/10000 [09:42<07:06,  4.91it/s]

Iter: 7900
Train_loss: 0.01436
Test_loss: 0.0266
()


 80%|████████  | 8004/10000 [09:50<06:50,  4.86it/s]

Iter: 8000
Train_loss: 0.01565
Test_loss: 0.02731
()


 81%|████████  | 8104/10000 [09:57<06:26,  4.90it/s]

Iter: 8100
Train_loss: 0.01606
Test_loss: 0.02754
()


 82%|████████▏ | 8204/10000 [10:04<05:57,  5.02it/s]

Iter: 8200
Train_loss: 0.01462
Test_loss: 0.02635
()


 83%|████████▎ | 8304/10000 [10:11<05:38,  5.00it/s]

Iter: 8300
Train_loss: 0.01511
Test_loss: 0.02756
()


 84%|████████▍ | 8404/10000 [10:18<05:19,  4.99it/s]

Iter: 8400
Train_loss: 0.01519
Test_loss: 0.02725
()


 85%|████████▌ | 8504/10000 [10:26<05:02,  4.95it/s]

Iter: 8500
Train_loss: 0.01617
Test_loss: 0.02833
()


 86%|████████▌ | 8603/10000 [10:32<03:30,  6.63it/s]

Iter: 8600
Train_loss: 0.01571
Test_loss: 0.02676
()


 87%|████████▋ | 8703/10000 [10:39<04:17,  5.03it/s]

Iter: 8700
Train_loss: 0.01588
Test_loss: 0.02829
()


 88%|████████▊ | 8803/10000 [10:46<03:58,  5.02it/s]

Iter: 8800
Train_loss: 0.01473
Test_loss: 0.02512
()


 89%|████████▉ | 8903/10000 [10:54<03:47,  4.82it/s]

Iter: 8900
Train_loss: 0.01558
Test_loss: 0.02646
()


 90%|█████████ | 9003/10000 [11:01<03:28,  4.78it/s]

Iter: 9000
Train_loss: 0.01548
Test_loss: 0.02689
()


 91%|█████████ | 9103/10000 [11:09<03:01,  4.95it/s]

Iter: 9100
Train_loss: 0.01488
Test_loss: 0.02591
()


 92%|█████████▏| 9203/10000 [11:16<02:43,  4.87it/s]

Iter: 9200
Train_loss: 0.01467
Test_loss: 0.02731
()


 93%|█████████▎| 9303/10000 [11:24<02:20,  4.95it/s]

Iter: 9300
Train_loss: 0.01475
Test_loss: 0.02681
()


 94%|█████████▍| 9403/10000 [11:31<02:00,  4.94it/s]

Iter: 9400
Train_loss: 0.01569
Test_loss: 0.02781
()


 95%|█████████▌| 9503/10000 [11:38<01:38,  5.05it/s]

Iter: 9500
Train_loss: 0.01494
Test_loss: 0.02664
()


 96%|█████████▌| 9603/10000 [11:48<03:22,  1.96it/s]

Iter: 9600
Train_loss: 0.01462
Test_loss: 0.02664
()


 97%|█████████▋| 9703/10000 [11:55<00:59,  4.97it/s]

Iter: 9700
Train_loss: 0.01494
Test_loss: 0.02661
()


 98%|█████████▊| 9803/10000 [12:02<00:39,  5.01it/s]

Iter: 9800
Train_loss: 0.01382
Test_loss: 0.02504
()


 99%|█████████▉| 9903/10000 [12:08<00:18,  5.26it/s]

Iter: 9900
Train_loss: 0.01574
Test_loss: 0.02856
()


100%|██████████| 10000/10000 [12:14<00:00, 13.62it/s]
