In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
import statsmodels.api as sm
from sklearn.linear_model import LogisticRegression
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import matthews_corrcoef

## Summary of Results:

$\hat Q$ is the outcome estimator and $\hat G$ is the propensity score estimator. Their respective columns indicate whether each model is misspecified ($\times$) or correctly specified ($\checkmark$).

'Reduction' is the relative percent error reduction when compared against the plug-in estimator using the outcome model alone. The results are averages over 1000 simulations.

|Exp. No. | Method | $\hat Q$ | $\hat G$ | Reduction $\%$ | Rel. Error $\%$ |
| ---| --- | --- | --- | --- |--- |
| 1| TMLE | $\times$ | $\times$ | -0.128 | 1.212 |



## Problem Setup:

This example is taken from https://arxiv.org/abs/2107.00681 by Hines, Dukes, Diaz-Ordaz, and Vansteelandt (2021) and the empirical evaluation follows https://onlinelibrary.wiley.com/doi/full/10.1002/sim.7628 by Miguel Angel Luque-Fernandez, Michael Schomaker, Bernard Rachet, Mireille E. Schnitzer (2018).


The following experiments are very similar to the ones in ATE.ipynb, but this time we will fit the estimators using a neural network.

## 1. Define the DGP and some helper functions:

In [4]:

def sigm(x):
    """Logistic sigmoid, 1 / (1 + exp(-x)), evaluated elementwise with NumPy."""
    decay = np.exp(-x)
    return 1 / (1 + decay)

def inv_sigm(x):
    """Logit: the natural log of the odds x / (1 - x), evaluated elementwise with NumPy."""
    odds = x / (1 - x)
    return np.log(odds)

def generate_data(N, seed):
    """Simulate one dataset with four covariates, a binary treatment and a binary outcome.

    NumPy's global RNG is reseeded with `seed` first, so the same `(N, seed)`
    pair always yields the same draws (and later `np.random` calls continue
    from that state).

    Args:
        N: number of samples to draw.
        seed: value handed to `np.random.seed`.

    Returns:
        Tuple `(Z, X, Y, Y1, Y0)`:
        `Z` is the (N, 4) covariate matrix `[z1, z2, z3, z4]`, `X` the (N, 1)
        treatment indicator, `Y1` / `Y0` the (N, 1) potential outcomes under
        treatment / no treatment, and `Y` the (N, 1) observed outcome, i.e.
        `Y1` where `X == 1` and `Y0` where `X == 0`.
    """
    np.random.seed(seed=seed)
    column = (N, 1)

    # Covariates: two Bernoulli draws and two uniforms rounded to 3 decimals.
    z1 = np.random.binomial(1, 0.5, column)
    z2 = np.random.binomial(1, 0.65, column)
    z3 = np.round(np.random.uniform(0, 4, column), 3)
    z4 = np.round(np.random.uniform(0, 5, column), 3)

    # Treatment depends on z2, z3, z4 and the z2*z4 interaction.
    p_treat = sigm(-0.4 + 0.2*z2 + 0.15*z3 + 0.2*z4 + 0.15*z2*z4)
    X = np.random.binomial(1, p_treat, column)

    # Potential outcomes share one linear predictor; treatment adds 1 on the logit scale.
    p_y1 = sigm(-1 + 1 - 0.1*z1 + 0.3*z2 + 0.25*z3 + 0.2*z4 + 0.15*z2*z4)
    Y1 = np.random.binomial(1, p_y1, column)
    p_y0 = sigm(-1 + 0 - 0.1*z1 + 0.3*z2 + 0.25*z3 + 0.2*z4 + 0.15*z2*z4)
    Y0 = np.random.binomial(1, p_y0, column)

    # Only the potential outcome matching the received treatment is observed.
    Y = X * Y1 + (1 - X) * Y0
    Z = np.concatenate((z1, z2, z3, z4), axis=1)
    return Z, X, Y, Y1, Y0

## 2. Define the Neural Network Objects/Classes

In [5]:


def init_weights(m):
    """Initialise one module in place; meant to be passed to `nn.Module.apply`.

    `nn.Linear` modules get Xavier-normal weights and a constant bias of 0.01.
    Every other module type is left untouched.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_normal_(m.weight)
    m.bias.data.fill_(0.01)

class QNet(nn.Module):
    """Outcome network: a shared trunk followed by one output head per treatment arm.

    The treatment `X` is concatenated with the covariates `Z` and fed through
    the shared trunk (`net`). Both heads (`pos_arm`, `neg_arm`) are evaluated
    on the trunk output, and `forward` returns, row by row, the head matching
    that row's treatment value.

    Args:
        input_size: width of the trunk input (columns of X plus columns of Z).
        num_layers: number of Linear layers in the shared trunk.
        layers_size: hidden width of the trunk and of each head.
        output_size: width of each head's output.
        output_type: 'categorical' appends a Sigmoid to both heads; any other
            value (e.g. 'continuous') leaves the heads linear.
        dropout: dropout probability, applied after every trunk block except the first.
    """

    def __init__(self, input_size, num_layers, layers_size, output_size, output_type, dropout):
        super().__init__()

        # Shared trunk: the first block has no dropout, the following ones do.
        trunk = [nn.Linear(input_size, layers_size), nn.LeakyReLU()]
        for _ in range(num_layers - 1):
            trunk += [nn.Linear(layers_size, layers_size), nn.LeakyReLU(), nn.Dropout(p=dropout)]
        self.net = nn.Sequential(*trunk)

        # Heads are built and registered positive-first so the parameter order is stable.
        squash = output_type == 'categorical'
        self.pos_arm = self._make_arm(layers_size, output_size, squash)
        self.neg_arm = self._make_arm(layers_size, output_size, squash)

        for part in (self.net, self.neg_arm, self.pos_arm):
            part.apply(init_weights)

    @staticmethod
    def _make_arm(layers_size, output_size, squash):
        """Build one head: Linear -> LeakyReLU -> Linear, plus a Sigmoid when `squash` is true."""
        arm = [
            nn.Linear(layers_size, layers_size),
            nn.LeakyReLU(),
            nn.Linear(layers_size, output_size),
        ]
        if squash:
            arm.append(nn.Sigmoid())
        return nn.Sequential(*arm)

    def forward(self, X, Z):
        """Return the prediction of the head selected by `X` (non-zero -> `pos_arm`)."""
        shared = self.net(torch.cat([X, Z], 1))
        pred_untreated = self.neg_arm(shared)
        pred_treated = self.pos_arm(shared)
        return torch.where(X.bool(), pred_treated, pred_untreated)

    
    
class GNet(nn.Module):
    """Propensity network: a plain feed-forward map from the covariates `Z` to a prediction.

    Args:
        input_size: number of covariate columns.
        num_layers: number of hidden Linear layers (an output Linear layer is added on top).
        layers_size: hidden width.
        output_size: width of the output.
        output_type: 'categorical' appends a Sigmoid; any other value
            (e.g. 'continuous') leaves the output linear.
        dropout: dropout probability, applied after every hidden block except the first.
    """

    def __init__(self, input_size, num_layers, layers_size, output_size, output_type, dropout):
        super().__init__()

        # First hidden block has no dropout; the remaining hidden blocks do.
        stack = [nn.Linear(input_size, layers_size), nn.LeakyReLU()]
        for _ in range(num_layers - 1):
            stack += [nn.Linear(layers_size, layers_size), nn.LeakyReLU(), nn.Dropout(p=dropout)]
        stack.append(nn.Linear(layers_size, output_size))
        if output_type == 'categorical':
            stack.append(nn.Sigmoid())

        self.net = nn.Sequential(*stack)
        self.net.apply(init_weights)

    def forward(self, Z):
        """Run the covariates through the network."""
        return self.net(Z)

## 3. Create a Neural Network training class

In [99]:
class Trainer(object):
    """Fits an outcome network (Q) and a propensity network (G) side by side.

    Each training step draws one mini-batch (with replacement) and updates both
    networks with their own Adam optimizer: Q with mean squared error against
    `y`, G with a class-weighted binary cross-entropy against `x`.

    Args:
        qnet: outcome model, called as `qnet(x, z)`.
        gnet: propensity model, called as `gnet(z)`; its output must be a
            probability because it is fed to `nn.BCELoss`.
        iterations: number of training steps (required by `train`).
        batch_size: mini-batch size (required by `train`).
        test_iter: evaluate and log every `test_iter` steps (required by `train`).
        lr: Adam learning rate. When `None`, no optimizers are created and the
            instance can only be used for `test`.
        g_class_weights: `(weight for x == 0, weight for x == 1)` applied to the
            per-sample G loss during training.
        train_eval_size: number of leading training rows used for the periodic
            training-loss estimate.
    """

    def __init__(self, qnet, gnet, iterations=None, batch_size=None, test_iter=None, lr=None,
                 g_class_weights=(0.7, 0.3), train_eval_size=800):
        self.qnet = qnet
        self.gnet = gnet
        self.iterations = iterations
        self.batch_size = batch_size
        self.test_iter = test_iter
        self.train_eval_size = train_eval_size
        # Built once here rather than on every training step.
        self.g_class_weights = torch.tensor(list(g_class_weights))

        # Optimizers stay None for evaluation-only trainers (lr not given).
        self.q_optimizer = None
        self.g_optimizer = None
        if lr is not None:
            self.q_optimizer = optim.Adam(qnet.parameters(), lr=lr)
            self.g_optimizer = optim.Adam(gnet.parameters(), lr=lr)
        # reduction='none' keeps per-sample losses so they can be class-weighted.
        self.bce_loss = nn.BCELoss(reduction='none')
        self.mse_loss = nn.MSELoss()

    def train(self, x, y, z):
        """Train both networks and return the logged loss histories.

        One eighth of the rows (chosen via `np.random.shuffle`) is held out as
        a validation set; the rest is used for training.

        Args:
            x: treatment tensor with values in {0, 1}, one row per sample.
            y: outcome tensor, one row per sample.
            z: covariate tensor, one row per sample.

        Returns:
            `(train_losses_q, train_losses_g, test_losses_q, test_losses_g)`:
            lists of floats with one entry per evaluation point (every
            `test_iter` steps and at the final step). The "test" lists are
            measured on the held-out validation rows.

        Raises:
            ValueError: if the trainer was built without `lr`, `iterations`,
                `batch_size` or `test_iter`.
        """
        if self.q_optimizer is None or self.g_optimizer is None:
            raise ValueError('Trainer was created without `lr`, so it has no optimizers; '
                             'pass `lr` to enable training.')
        if any(v is None for v in (self.iterations, self.batch_size, self.test_iter)):
            raise ValueError('`iterations`, `batch_size` and `test_iter` must be set to train.')

        # Make sure dropout is active even if the networks were left in eval mode.
        self.qnet.train()
        self.gnet.train()

        # create a small validation set
        indices = np.arange(len(x))
        np.random.shuffle(indices)
        n_val = len(x) // 8
        val_inds, train_inds = indices[:n_val], indices[n_val:]
        x_val, y_val, z_val = x[val_inds], y[val_inds], z[val_inds]
        x_train, y_train, z_train = x[train_inds], y[train_inds], z[train_inds]

        indices = np.arange(len(x_train))
        class_weights = self.g_class_weights.to(x.device)
        n_eval = self.train_eval_size

        train_losses_q = []
        train_losses_g = []
        test_losses_q = []
        test_losses_g = []

        for it in range(self.iterations):
            # Mini-batch sampled with replacement.
            inds = np.random.choice(indices, self.batch_size)
            x_batch, y_batch, z_batch = x_train[inds], y_train[inds], z_train[inds]

            x_pred = self.gnet(z_batch)
            y_pred = self.qnet(x_batch, z_batch)

            q_loss = self.mse_loss(y_pred, y_batch)

            # Per-sample weight looked up by the sample's treatment value.
            weight_ = class_weights[x_batch.data.view(-1).long()].view_as(x_batch)
            g_loss = (self.bce_loss(x_pred, x_batch) * weight_).mean()

            q_loss.backward()
            g_loss.backward()

            self.q_optimizer.step()
            self.g_optimizer.step()
            self.q_optimizer.zero_grad()
            self.g_optimizer.zero_grad()

            if (it % self.test_iter == 0) or (it == (self.iterations-1)):
                # Evaluation only needs numbers, so skip autograd bookkeeping.
                with torch.no_grad():
                    q_loss, g_loss, _, _ = self.test(x_train[:n_eval], y_train[:n_eval],
                                                     z_train[:n_eval])
                    q_loss_test, g_loss_test, _, _ = self.test(x_val, y_val, z_val)
                train_losses_q.append(q_loss.item())
                train_losses_g.append(g_loss.item())
                test_losses_q.append(q_loss_test.item())
                test_losses_g.append(g_loss_test.item())
                print('== Iteration {} =='.format(it))
                print('Test Loss Q:', q_loss_test.item(), '  Test Loss G:', g_loss_test.item())

                # `test` switched both networks to eval mode; resume training mode.
                self.qnet.train()
                self.gnet.train()

        return train_losses_q, train_losses_g, test_losses_q, test_losses_g

    def test(self, x, y, z):
        """Evaluate both networks on `(x, y, z)`.

        Both networks are switched to eval mode and left there.

        Returns:
            `(q_loss, g_loss, x_pred, y_pred)`: the MSE of Q, the unweighted
            mean BCE of G, and the raw predictions of G and Q (all tensors).
        """
        self.qnet.eval()
        self.gnet.eval()

        x_pred = self.gnet(z)
        y_pred = self.qnet(x, z)

        q_loss = self.mse_loss(y_pred, y)
        g_loss = self.bce_loss(x_pred, x).mean()

        return q_loss, g_loss, x_pred, y_pred
    

## 4. Create a hyperparameter tuning class

In [110]:
class Tuner(object):
    """Random-search hyperparameter tuner for the Q (outcome) / G (propensity) network pair.

    Args:
        x: treatment tensor, one row per sample.
        y: outcome tensor, one row per sample.
        z: covariate tensor, one row per sample; its last dimension fixes the
            input widths of both networks.
        trials: number of configurations to train.
        best_params: optional dict with keys 'batch_size', 'iters', 'lr',
            'layers', 'dropout' and 'layer_size'. When given, every trial uses
            these values instead of sampling.
    """

    def __init__(self, x, y, z, trials, best_params=None):
        self.x = x
        self.y = y
        self.z = z
        self.trials = trials
        self.test_iter = 500  # evaluation/logging interval handed to Trainer
        self.best_params = best_params
        self.qnet = None
        self.gnet = None
        self.best_model_q = None
        self.best_model_g = None

    def _choose(self, name, candidates):
        """Return the fixed value for `name` if `best_params` is set, else sample one candidate."""
        if self.best_params is None:
            return np.random.choice(candidates)
        return self.best_params[name]

    def tune(self):
        """Train `trials` configurations and keep the pair with the lowest validation loss.

        The selection score is the sum of the final Q and G validation losses
        reported by `Trainer.train`.

        Returns:
            `(tuning_dict, best_model_q, best_model_g)`, where `tuning_dict`
            maps each hyperparameter name and each of 'train_loss_q',
            'train_loss_g', 'val_loss_q', 'val_loss_g' to a list with one
            entry per trial.
        """
        # Order matters: hyperparameters are sampled in exactly this sequence.
        search_space = [
            ('batch_size', [30, 60, 90, 120]),
            ('iters', [5000, 10000, 15000, 20000, 50000, 100000]),
            ('lr', [0.0001, 0.0005, 0.001, 0.005, 0.01]),
            ('layers', [2, 4, 6]),
            ('dropout', [0.1, 0.2, 0.4]),
            ('layer_size', [16, 32, 64, 128]),
        ]
        output_type_Q = 'categorical'
        output_size_Q = 1
        output_type_G = 'categorical'
        output_size_G = 1
        # Use the tuner's own covariates (not a notebook-level global) for the input widths.
        input_size_G = self.z.shape[-1]
        input_size_Q = input_size_G + 1  # we will concatenate the treatment var inside the qnet class

        tuning_dict = {key: [] for key in (
            'batch_size', 'layers', 'dropout', 'layer_size', 'lr', 'iters',
            'train_loss_q', 'train_loss_g', 'val_loss_q', 'val_loss_g')}
        best_loss = 1e10

        for trial in range(self.trials):
            # sample hyper params and store the history
            params = {name: self._choose(name, candidates) for name, candidates in search_space}
            for name, value in params.items():
                tuning_dict[name].append(value)

            print('======== Trial {} of {} ========='.format(trial, self.trials-1))
            print('Batch size', params['batch_size'], ' Iters', params['iters'],
                  ' Lr', params['lr'], ' Layers', params['layers'],
                  ' Dropout', params['dropout'], ' Layer Size', params['layer_size'])

            # Fresh networks for every trial.
            self.qnet = QNet(input_size=input_size_Q, num_layers=params['layers'],
                             layers_size=params['layer_size'], output_size=output_size_Q,
                             output_type=output_type_Q, dropout=params['dropout'])

            self.gnet = GNet(input_size=input_size_G, num_layers=params['layers'],
                             layers_size=params['layer_size'], output_size=output_size_G,
                             output_type=output_type_G, dropout=params['dropout'])

            trainer = Trainer(qnet=self.qnet, gnet=self.gnet, iterations=params['iters'],
                              batch_size=params['batch_size'], test_iter=self.test_iter,
                              lr=params['lr'])
            train_loss_q_, train_loss_g_, val_loss_q_, val_loss_g_ = trainer.train(self.x,
                                                                                  self.y,
                                                                                  self.z)
            # Only the last logged value of each history is kept per trial.
            tuning_dict['train_loss_q'].append(train_loss_q_[-1])
            tuning_dict['train_loss_g'].append(train_loss_g_[-1])
            tuning_dict['val_loss_q'].append(val_loss_q_[-1])
            tuning_dict['val_loss_g'].append(val_loss_g_[-1])

            total_val_loss = val_loss_q_[-1] + val_loss_g_[-1]

            if total_val_loss < best_loss:
                print('old loss:', best_loss)
                print('new loss:', total_val_loss)
                print('best model updated')
                best_loss = total_val_loss
                self.best_model_q = self.qnet
                self.best_model_g = self.gnet

        return tuning_dict, self.best_model_q, self.best_model_g
        

## 5. Run Hyperparameter Search

Now that we have everything we need, we can initialize the neural networks and run a hyperparameter search to identify the best parameters.

In [88]:
# First establish ground truth treatment effect:
# the mean difference of the potential outcomes over a very large sample.
N = 5000000
Z, x, y, Y1, Y0 = generate_data(N, seed=0)
true_psi = (Y1-Y0).mean()


# Set some params
N = 10000
seed = 0
num_tuning_trials = 60

# data generation (generate_data also reseeds NumPy's global RNG with `seed`):
z, x, y, _, _ = generate_data(N, seed)
x = torch.tensor(x, dtype=torch.float32)
z = torch.tensor(z, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

# Seed torch too: weight initialisation and dropout draw from torch's RNG,
# so without this the search is not reproducible across kernel restarts.
torch.manual_seed(seed)

tuner = Tuner(x=x,y=y,z=z,trials=num_tuning_trials)
tuning_history, best_q, best_g = tuner.tune()

# The best trial is the one with the lowest summed final validation loss.
total_losses = np.asarray(tuning_history['val_loss_g']) + np.asarray(tuning_history['val_loss_q'])
best_index = np.argmin(total_losses)

# Record that trial's entry for every tracked key (hyperparameters and losses).
best_params = {key: values[best_index] for key, values in tuning_history.items()}

Batch size 90  Iters 5000  Lr 0.0001  Layers 4  Dropout 0.4  Layer Size 128
== Iteration 0 ==
Test Loss Q: 0.23505303263664246   Test Loss G: 0.6583337187767029
== Iteration 500 ==
Test Loss Q: 0.18650640547275543   Test Loss G: 0.6955683827400208
== Iteration 1000 ==
Test Loss Q: 0.18443691730499268   Test Loss G: 0.7021484375
== Iteration 1500 ==
Test Loss Q: 0.1827680766582489   Test Loss G: 0.7017059326171875
== Iteration 2000 ==
Test Loss Q: 0.1817612498998642   Test Loss G: 0.6955196261405945
== Iteration 2500 ==
Test Loss Q: 0.18313805758953094   Test Loss G: 0.6959470510482788
== Iteration 3000 ==
Test Loss Q: 0.1808604896068573   Test Loss G: 0.7102521657943726
== Iteration 3500 ==
Test Loss Q: 0.1827697455883026   Test Loss G: 0.6995244026184082
== Iteration 4000 ==
Test Loss Q: 0.18118160963058472   Test Loss G: 0.7010411620140076
== Iteration 4500 ==
Test Loss Q: 0.18189026415348053   Test Loss G: 0.6989416480064392
== Iteration 4999 ==
Test Loss Q: 0.1820671558380127   Tes

== Iteration 6000 ==
Test Loss Q: 0.17381247878074646   Test Loss G: 0.6827932000160217
== Iteration 6500 ==
Test Loss Q: 0.17513717710971832   Test Loss G: 0.6874383091926575
== Iteration 7000 ==
Test Loss Q: 0.1752253919839859   Test Loss G: 0.6938260793685913
== Iteration 7500 ==
Test Loss Q: 0.1740250438451767   Test Loss G: 0.7062925696372986
== Iteration 8000 ==
Test Loss Q: 0.17190100252628326   Test Loss G: 0.7072523236274719
== Iteration 8500 ==
Test Loss Q: 0.1739846020936966   Test Loss G: 0.7038698792457581
== Iteration 9000 ==
Test Loss Q: 0.17233605682849884   Test Loss G: 0.671136736869812
== Iteration 9500 ==
Test Loss Q: 0.17022594809532166   Test Loss G: 0.6880619525909424
== Iteration 10000 ==
Test Loss Q: 0.1711665689945221   Test Loss G: 0.6859276294708252
== Iteration 10500 ==
Test Loss Q: 0.17470912635326385   Test Loss G: 0.7332396507263184
== Iteration 11000 ==
Test Loss Q: 0.17240367829799652   Test Loss G: 0.7153126001358032
== Iteration 11500 ==
Test Loss Q:

== Iteration 20500 ==
Test Loss Q: 0.18782973289489746   Test Loss G: 0.6870893836021423
== Iteration 21000 ==
Test Loss Q: 0.1838790476322174   Test Loss G: 0.6716381311416626
== Iteration 21500 ==
Test Loss Q: 0.18567998707294464   Test Loss G: 0.6854738593101501
== Iteration 22000 ==
Test Loss Q: 0.1870380938053131   Test Loss G: 0.697191059589386
== Iteration 22500 ==
Test Loss Q: 0.18699055910110474   Test Loss G: 0.7257840037345886
== Iteration 23000 ==
Test Loss Q: 0.18581131100654602   Test Loss G: 0.6819648146629333
== Iteration 23500 ==
Test Loss Q: 0.18724779784679413   Test Loss G: 0.6951056718826294
== Iteration 24000 ==
Test Loss Q: 0.18666759133338928   Test Loss G: 0.6761404871940613
== Iteration 24500 ==
Test Loss Q: 0.1819620430469513   Test Loss G: 0.7098820805549622
== Iteration 25000 ==
Test Loss Q: 0.18762798607349396   Test Loss G: 0.7225524187088013
== Iteration 25500 ==
Test Loss Q: 0.18845969438552856   Test Loss G: 0.6848759055137634
== Iteration 26000 ==
Tes

== Iteration 67000 ==
Test Loss Q: 0.1912180483341217   Test Loss G: 0.7923637628555298
== Iteration 67500 ==
Test Loss Q: 0.1908757984638214   Test Loss G: 0.801216185092926
== Iteration 68000 ==
Test Loss Q: 0.19080978631973267   Test Loss G: 0.8068796396255493
== Iteration 68500 ==
Test Loss Q: 0.19428478181362152   Test Loss G: 0.7715816497802734
== Iteration 69000 ==
Test Loss Q: 0.19443123042583466   Test Loss G: 0.8131933808326721
== Iteration 69500 ==
Test Loss Q: 0.1924627125263214   Test Loss G: 0.7988831400871277
== Iteration 70000 ==
Test Loss Q: 0.1962345540523529   Test Loss G: 0.7927321195602417
== Iteration 70500 ==
Test Loss Q: 0.19536404311656952   Test Loss G: 0.7764102816581726
== Iteration 71000 ==
Test Loss Q: 0.19648206233978271   Test Loss G: 0.7911187410354614
== Iteration 71500 ==
Test Loss Q: 0.19154632091522217   Test Loss G: 0.7939262986183167
== Iteration 72000 ==
Test Loss Q: 0.19311034679412842   Test Loss G: 0.8042961359024048
== Iteration 72500 ==
Test

== Iteration 12500 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6982542872428894
== Iteration 13000 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6804891228675842
== Iteration 13500 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6640640497207642
== Iteration 14000 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.67203289270401
== Iteration 14500 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6649522185325623
== Iteration 15000 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6765958070755005
== Iteration 15500 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6912574768066406
== Iteration 16000 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6828961968421936
== Iteration 16500 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.7426938414573669
== Iteration 17000 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6755930185317993
== Iteration 17500 ==
Test Loss Q: 0.3240000009536743   Test Loss G: 0.6835403442382812
== Iteration 18000 ==
Test Loss Q:

== Iteration 17000 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.710736870765686
== Iteration 17500 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.694762647151947
== Iteration 18000 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.6919097304344177
== Iteration 18500 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.6437207460403442
== Iteration 19000 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.7041012048721313
== Iteration 19500 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.6773781180381775
== Iteration 20000 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.6441006660461426
== Iteration 20500 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.734000563621521
== Iteration 21000 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.8477976322174072
== Iteration 21500 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.6911922693252563
== Iteration 22000 ==
Test Loss Q: 0.30720001459121704   Test Loss G: 0.6844871640205383
== Iteration 22500 ==
Te

== Iteration 12500 ==
Test Loss Q: 0.1814546138048172   Test Loss G: 0.6849516034126282
== Iteration 13000 ==
Test Loss Q: 0.18072634935379028   Test Loss G: 0.6827531456947327
== Iteration 13500 ==
Test Loss Q: 0.18229779601097107   Test Loss G: 0.6767551302909851
== Iteration 14000 ==
Test Loss Q: 0.1819417029619217   Test Loss G: 0.6467821002006531
== Iteration 14500 ==
Test Loss Q: 0.1819426268339157   Test Loss G: 0.677806556224823
== Iteration 15000 ==
Test Loss Q: 0.182791069149971   Test Loss G: 0.6722108125686646
== Iteration 15500 ==
Test Loss Q: 0.18146198987960815   Test Loss G: 0.6793279647827148
== Iteration 16000 ==
Test Loss Q: 0.1846313625574112   Test Loss G: 0.6659914255142212
== Iteration 16500 ==
Test Loss Q: 0.18104897439479828   Test Loss G: 0.6640620827674866
== Iteration 17000 ==
Test Loss Q: 0.18378004431724548   Test Loss G: 0.675186038017273
== Iteration 17500 ==
Test Loss Q: 0.18212679028511047   Test Loss G: 0.6908836960792542
== Iteration 18000 ==
Test Lo

== Iteration 21500 ==
Test Loss Q: 0.19154947996139526   Test Loss G: 0.6918519139289856
== Iteration 22000 ==
Test Loss Q: 0.19160079956054688   Test Loss G: 0.6793621778488159
== Iteration 22500 ==
Test Loss Q: 0.19053372740745544   Test Loss G: 0.6873390078544617
== Iteration 23000 ==
Test Loss Q: 0.19205904006958008   Test Loss G: 0.6900007724761963
== Iteration 23500 ==
Test Loss Q: 0.19143493473529816   Test Loss G: 0.6969752907752991
== Iteration 24000 ==
Test Loss Q: 0.19140870869159698   Test Loss G: 0.6914201378822327
== Iteration 24500 ==
Test Loss Q: 0.19130130112171173   Test Loss G: 0.6922194957733154
== Iteration 25000 ==
Test Loss Q: 0.1922672539949417   Test Loss G: 0.6902480721473694
== Iteration 25500 ==
Test Loss Q: 0.19095712900161743   Test Loss G: 0.6936924457550049
== Iteration 26000 ==
Test Loss Q: 0.19101548194885254   Test Loss G: 0.6922116279602051
== Iteration 26500 ==
Test Loss Q: 0.19069518148899078   Test Loss G: 0.6913910508155823
== Iteration 27000 ==


== Iteration 68000 ==
Test Loss Q: 0.19202712178230286   Test Loss G: 0.6927400827407837
== Iteration 68500 ==
Test Loss Q: 0.1921338438987732   Test Loss G: 0.6905883550643921
== Iteration 69000 ==
Test Loss Q: 0.19235554337501526   Test Loss G: 0.6922495365142822
== Iteration 69500 ==
Test Loss Q: 0.19221533834934235   Test Loss G: 0.6921907663345337
== Iteration 70000 ==
Test Loss Q: 0.19173435866832733   Test Loss G: 0.6865063309669495
== Iteration 70500 ==
Test Loss Q: 0.19144277274608612   Test Loss G: 0.6879963874816895
== Iteration 71000 ==
Test Loss Q: 0.19269241392612457   Test Loss G: 0.6874966025352478
== Iteration 71500 ==
Test Loss Q: 0.19136881828308105   Test Loss G: 0.7016867995262146
== Iteration 72000 ==
Test Loss Q: 0.19102045893669128   Test Loss G: 0.6932217478752136
== Iteration 72500 ==
Test Loss Q: 0.1903759390115738   Test Loss G: 0.6883541941642761
== Iteration 73000 ==
Test Loss Q: 0.19015683233737946   Test Loss G: 0.7033345699310303
== Iteration 73500 ==
T

== Iteration 13500 ==
Test Loss Q: 0.17417965829372406   Test Loss G: 0.6778079271316528
== Iteration 14000 ==
Test Loss Q: 0.17265290021896362   Test Loss G: 0.6654467582702637
== Iteration 14500 ==
Test Loss Q: 0.17464011907577515   Test Loss G: 0.6780683398246765
== Iteration 15000 ==
Test Loss Q: 0.17356112599372864   Test Loss G: 0.6727290749549866
== Iteration 15500 ==
Test Loss Q: 0.17219968140125275   Test Loss G: 0.6876581311225891
== Iteration 16000 ==
Test Loss Q: 0.173326313495636   Test Loss G: 0.6714311838150024
== Iteration 16500 ==
Test Loss Q: 0.1725558042526245   Test Loss G: 0.6883122324943542
== Iteration 17000 ==
Test Loss Q: 0.1762521117925644   Test Loss G: 0.7013314366340637
== Iteration 17500 ==
Test Loss Q: 0.17344795167446136   Test Loss G: 0.6868994832038879
== Iteration 18000 ==
Test Loss Q: 0.1725660115480423   Test Loss G: 0.7083485126495361
== Iteration 18500 ==
Test Loss Q: 0.17348523437976837   Test Loss G: 0.6891160011291504
== Iteration 19000 ==
Test

== Iteration 9000 ==
Test Loss Q: 0.18322575092315674   Test Loss G: 0.6819992065429688
== Iteration 9500 ==
Test Loss Q: 0.18331152200698853   Test Loss G: 0.6900646686553955
== Iteration 10000 ==
Test Loss Q: 0.18368475139141083   Test Loss G: 0.6722742319107056
== Iteration 10500 ==
Test Loss Q: 0.18355520069599152   Test Loss G: 0.6779845952987671
== Iteration 11000 ==
Test Loss Q: 0.18349392712116241   Test Loss G: 0.6829582452774048
== Iteration 11500 ==
Test Loss Q: 0.18363338708877563   Test Loss G: 0.6790221929550171
== Iteration 12000 ==
Test Loss Q: 0.18364489078521729   Test Loss G: 0.664513111114502
== Iteration 12500 ==
Test Loss Q: 0.1847286969423294   Test Loss G: 0.6773478388786316
== Iteration 13000 ==
Test Loss Q: 0.18459108471870422   Test Loss G: 0.6801100373268127
== Iteration 13500 ==
Test Loss Q: 0.18639688193798065   Test Loss G: 0.6984339952468872
== Iteration 14000 ==
Test Loss Q: 0.18405641615390778   Test Loss G: 0.68467777967453
== Iteration 14500 ==
Test 

== Iteration 55500 ==
Test Loss Q: 0.1945737600326538   Test Loss G: 0.6793989539146423
== Iteration 56000 ==
Test Loss Q: 0.19375263154506683   Test Loss G: 0.6806169152259827
== Iteration 56500 ==
Test Loss Q: 0.1931113302707672   Test Loss G: 0.7016875743865967
== Iteration 57000 ==
Test Loss Q: 0.19372889399528503   Test Loss G: 0.6833723187446594
== Iteration 57500 ==
Test Loss Q: 0.1944606751203537   Test Loss G: 0.6753908395767212
== Iteration 58000 ==
Test Loss Q: 0.19327077269554138   Test Loss G: 0.6971405148506165
== Iteration 58500 ==
Test Loss Q: 0.1926957070827484   Test Loss G: 0.6768116354942322
== Iteration 59000 ==
Test Loss Q: 0.19481855630874634   Test Loss G: 0.680568277835846
== Iteration 59500 ==
Test Loss Q: 0.1936831772327423   Test Loss G: 0.6802725791931152
== Iteration 60000 ==
Test Loss Q: 0.19535478949546814   Test Loss G: 0.7032316327095032
== Iteration 60500 ==
Test Loss Q: 0.19305488467216492   Test Loss G: 0.6901383399963379
== Iteration 61000 ==
Test 

== Iteration 1000 ==
Test Loss Q: 0.1884543001651764   Test Loss G: 0.6796543002128601
== Iteration 1500 ==
Test Loss Q: 0.18739201128482819   Test Loss G: 0.703567385673523
== Iteration 2000 ==
Test Loss Q: 0.1869073510169983   Test Loss G: 0.6848145723342896
== Iteration 2500 ==
Test Loss Q: 0.18602629005908966   Test Loss G: 0.6912984848022461
== Iteration 3000 ==
Test Loss Q: 0.18667583167552948   Test Loss G: 0.6900853514671326
== Iteration 3500 ==
Test Loss Q: 0.1861119121313095   Test Loss G: 0.6855696439743042
== Iteration 4000 ==
Test Loss Q: 0.186296746134758   Test Loss G: 0.6914005279541016
== Iteration 4500 ==
Test Loss Q: 0.1868424266576767   Test Loss G: 0.6893094182014465
== Iteration 5000 ==
Test Loss Q: 0.18657752871513367   Test Loss G: 0.7082334756851196
== Iteration 5500 ==
Test Loss Q: 0.1867188662290573   Test Loss G: 0.698915958404541
== Iteration 6000 ==
Test Loss Q: 0.18676632642745972   Test Loss G: 0.7020937204360962
== Iteration 6500 ==
Test Loss Q: 0.18748

== Iteration 500 ==
Test Loss Q: 0.16766205430030823   Test Loss G: 0.6649878621101379
== Iteration 1000 ==
Test Loss Q: 0.16696390509605408   Test Loss G: 0.6890805959701538
== Iteration 1500 ==
Test Loss Q: 0.16811572015285492   Test Loss G: 0.6945591568946838
== Iteration 2000 ==
Test Loss Q: 0.16857117414474487   Test Loss G: 0.7386770248413086
== Iteration 2500 ==
Test Loss Q: 0.16822326183319092   Test Loss G: 0.702465832233429
== Iteration 3000 ==
Test Loss Q: 0.16786399483680725   Test Loss G: 0.6976260542869568
== Iteration 3500 ==
Test Loss Q: 0.1713264137506485   Test Loss G: 0.7140806913375854
== Iteration 4000 ==
Test Loss Q: 0.1699521839618683   Test Loss G: 0.6863245368003845
== Iteration 4500 ==
Test Loss Q: 0.1698186695575714   Test Loss G: 0.7056494355201721
== Iteration 5000 ==
Test Loss Q: 0.16821454465389252   Test Loss G: 0.7000263929367065
== Iteration 5500 ==
Test Loss Q: 0.1683085709810257   Test Loss G: 0.6726741790771484
== Iteration 6000 ==
Test Loss Q: 0.16

== Iteration 30000 ==
Test Loss Q: 0.18778762221336365   Test Loss G: 0.7069936990737915
== Iteration 30500 ==
Test Loss Q: 0.1863366961479187   Test Loss G: 0.6674565672874451
== Iteration 31000 ==
Test Loss Q: 0.18699203431606293   Test Loss G: 0.6831660866737366
== Iteration 31500 ==
Test Loss Q: 0.1856677532196045   Test Loss G: 0.6876856684684753
== Iteration 32000 ==
Test Loss Q: 0.18710772693157196   Test Loss G: 0.6737488508224487
== Iteration 32500 ==
Test Loss Q: 0.1853792816400528   Test Loss G: 0.6995728015899658
== Iteration 33000 ==
Test Loss Q: 0.18708953261375427   Test Loss G: 0.7010156512260437
== Iteration 33500 ==
Test Loss Q: 0.1871134489774704   Test Loss G: 0.7001780867576599
== Iteration 34000 ==
Test Loss Q: 0.18778395652770996   Test Loss G: 0.6846804022789001
== Iteration 34500 ==
Test Loss Q: 0.18725799024105072   Test Loss G: 0.7134044170379639
== Iteration 35000 ==
Test Loss Q: 0.18552635610103607   Test Loss G: 0.6732379198074341
== Iteration 35500 ==
Tes

== Iteration 9500 ==
Test Loss Q: 0.18250785768032074   Test Loss G: 0.6971862316131592
== Iteration 10000 ==
Test Loss Q: 0.18158023059368134   Test Loss G: 0.7318454384803772
== Iteration 10500 ==
Test Loss Q: 0.18139204382896423   Test Loss G: 0.7116947174072266
== Iteration 11000 ==
Test Loss Q: 0.18217667937278748   Test Loss G: 0.7084564566612244
== Iteration 11500 ==
Test Loss Q: 0.18312515318393707   Test Loss G: 0.7292243242263794
== Iteration 12000 ==
Test Loss Q: 0.18302898108959198   Test Loss G: 0.7141613960266113
== Iteration 12500 ==
Test Loss Q: 0.18116098642349243   Test Loss G: 0.7276664972305298
== Iteration 13000 ==
Test Loss Q: 0.1804289072751999   Test Loss G: 0.7323824167251587
== Iteration 13500 ==
Test Loss Q: 0.17977283895015717   Test Loss G: 0.7167496681213379
== Iteration 14000 ==
Test Loss Q: 0.1805604100227356   Test Loss G: 0.7070012092590332
== Iteration 14500 ==
Test Loss Q: 0.18048591911792755   Test Loss G: 0.7110646367073059
== Iteration 15000 ==
Te

== Iteration 5000 ==
Test Loss Q: 0.1759868711233139   Test Loss G: 0.712501049041748
== Iteration 5500 ==
Test Loss Q: 0.17340461909770966   Test Loss G: 0.7184669375419617
== Iteration 6000 ==
Test Loss Q: 0.17541441321372986   Test Loss G: 0.7177584171295166
== Iteration 6500 ==
Test Loss Q: 0.17224733531475067   Test Loss G: 0.7238606214523315
== Iteration 7000 ==
Test Loss Q: 0.17716260254383087   Test Loss G: 0.7159110307693481
== Iteration 7500 ==
Test Loss Q: 0.17361316084861755   Test Loss G: 0.7095708250999451
== Iteration 8000 ==
Test Loss Q: 0.17575839161872864   Test Loss G: 0.7291364073753357
== Iteration 8500 ==
Test Loss Q: 0.17533431947231293   Test Loss G: 0.7139137983322144
== Iteration 9000 ==
Test Loss Q: 0.17390838265419006   Test Loss G: 0.7209062576293945
== Iteration 9500 ==
Test Loss Q: 0.176883727312088   Test Loss G: 0.7114713788032532
== Iteration 10000 ==
Test Loss Q: 0.17323791980743408   Test Loss G: 0.7097844481468201
== Iteration 10500 ==
Test Loss Q: 

== Iteration 13500 ==
Test Loss Q: 0.1825185865163803   Test Loss G: 0.6951906681060791
== Iteration 14000 ==
Test Loss Q: 0.18160861730575562   Test Loss G: 0.6921228766441345
== Iteration 14500 ==
Test Loss Q: 0.18170320987701416   Test Loss G: 0.6931688189506531
== Iteration 15000 ==
Test Loss Q: 0.1818031519651413   Test Loss G: 0.6926295161247253
== Iteration 15500 ==
Test Loss Q: 0.18175791203975677   Test Loss G: 0.700348436832428
== Iteration 16000 ==
Test Loss Q: 0.18178607523441315   Test Loss G: 0.6918933391571045
== Iteration 16500 ==
Test Loss Q: 0.1827702522277832   Test Loss G: 0.6922820210456848
== Iteration 17000 ==
Test Loss Q: 0.18284793198108673   Test Loss G: 0.6986324787139893
== Iteration 17500 ==
Test Loss Q: 0.18270620703697205   Test Loss G: 0.6979674696922302
== Iteration 18000 ==
Test Loss Q: 0.18283730745315552   Test Loss G: 0.6977812051773071
== Iteration 18500 ==
Test Loss Q: 0.1824936419725418   Test Loss G: 0.6902759075164795
== Iteration 19000 ==
Test

== Iteration 39000 ==
Test Loss Q: 0.182348370552063   Test Loss G: 0.6936487555503845
== Iteration 39500 ==
Test Loss Q: 0.18177171051502228   Test Loss G: 0.6862478256225586
== Iteration 40000 ==
Test Loss Q: 0.18274761736392975   Test Loss G: 0.6966877579689026
== Iteration 40500 ==
Test Loss Q: 0.18262244760990143   Test Loss G: 0.7088489532470703
== Iteration 41000 ==
Test Loss Q: 0.18199381232261658   Test Loss G: 0.6881476044654846
== Iteration 41500 ==
Test Loss Q: 0.18335986137390137   Test Loss G: 0.6759116053581238
== Iteration 42000 ==
Test Loss Q: 0.18159480392932892   Test Loss G: 0.6933313012123108
== Iteration 42500 ==
Test Loss Q: 0.18210336565971375   Test Loss G: 0.7039247751235962
== Iteration 43000 ==
Test Loss Q: 0.18420642614364624   Test Loss G: 0.6927090287208557
== Iteration 43500 ==
Test Loss Q: 0.18000511825084686   Test Loss G: 0.7040703892707825
== Iteration 44000 ==
Test Loss Q: 0.18233709037303925   Test Loss G: 0.6933858394622803
== Iteration 44500 ==
T

== Iteration 85500 ==
Test Loss Q: 0.18735165894031525   Test Loss G: 0.6873472929000854
== Iteration 86000 ==
Test Loss Q: 0.18522138893604279   Test Loss G: 0.6909536123275757
== Iteration 86500 ==
Test Loss Q: 0.18723982572555542   Test Loss G: 0.7041051983833313
== Iteration 87000 ==
Test Loss Q: 0.188147634267807   Test Loss G: 0.7024465203285217
== Iteration 87500 ==
Test Loss Q: 0.1879187971353531   Test Loss G: 0.6987071633338928
== Iteration 88000 ==
Test Loss Q: 0.1885499805212021   Test Loss G: 0.6764321327209473
== Iteration 88500 ==
Test Loss Q: 0.18495668470859528   Test Loss G: 0.6886206269264221
== Iteration 89000 ==
Test Loss Q: 0.18655480444431305   Test Loss G: 0.6863211989402771
== Iteration 89500 ==
Test Loss Q: 0.18516628444194794   Test Loss G: 0.7001267671585083
== Iteration 90000 ==
Test Loss Q: 0.1884308159351349   Test Loss G: 0.676123857498169
== Iteration 90500 ==
Test Loss Q: 0.18632206320762634   Test Loss G: 0.6921068429946899
== Iteration 91000 ==
Test 

== Iteration 4000 ==
Test Loss Q: 0.17347773909568787   Test Loss G: 0.6983563303947449
== Iteration 4500 ==
Test Loss Q: 0.17268823087215424   Test Loss G: 0.69570392370224
== Iteration 5000 ==
Test Loss Q: 0.17240816354751587   Test Loss G: 0.7012048959732056
== Iteration 5500 ==
Test Loss Q: 0.17262279987335205   Test Loss G: 0.7003812789916992
== Iteration 6000 ==
Test Loss Q: 0.1722903996706009   Test Loss G: 0.708561897277832
== Iteration 6500 ==
Test Loss Q: 0.17194265127182007   Test Loss G: 0.7029297947883606
== Iteration 7000 ==
Test Loss Q: 0.17298708856105804   Test Loss G: 0.6944859027862549
== Iteration 7500 ==
Test Loss Q: 0.17231428623199463   Test Loss G: 0.7047679424285889
== Iteration 8000 ==
Test Loss Q: 0.17187874019145966   Test Loss G: 0.6908693313598633
== Iteration 8500 ==
Test Loss Q: 0.17205938696861267   Test Loss G: 0.7059761881828308
== Iteration 9000 ==
Test Loss Q: 0.1716877818107605   Test Loss G: 0.6937819719314575
== Iteration 9500 ==
Test Loss Q: 0.1

== Iteration 50500 ==
Test Loss Q: 0.17279550433158875   Test Loss G: 0.7084691524505615
== Iteration 51000 ==
Test Loss Q: 0.17277123034000397   Test Loss G: 0.7081215381622314
== Iteration 51500 ==
Test Loss Q: 0.17245304584503174   Test Loss G: 0.7054833173751831
== Iteration 52000 ==
Test Loss Q: 0.1730843037366867   Test Loss G: 0.705377995967865
== Iteration 52500 ==
Test Loss Q: 0.17219533026218414   Test Loss G: 0.7086652517318726
== Iteration 53000 ==
Test Loss Q: 0.17288517951965332   Test Loss G: 0.7106065154075623
== Iteration 53500 ==
Test Loss Q: 0.17252278327941895   Test Loss G: 0.710138201713562
== Iteration 54000 ==
Test Loss Q: 0.17271916568279266   Test Loss G: 0.7109314203262329
== Iteration 54500 ==
Test Loss Q: 0.17240069806575775   Test Loss G: 0.7124814391136169
== Iteration 55000 ==
Test Loss Q: 0.17276833951473236   Test Loss G: 0.7064903974533081
== Iteration 55500 ==
Test Loss Q: 0.17306752502918243   Test Loss G: 0.707477331161499
== Iteration 56000 ==
Tes

== Iteration 97000 ==
Test Loss Q: 0.17352020740509033   Test Loss G: 0.7066400647163391
== Iteration 97500 ==
Test Loss Q: 0.17332495748996735   Test Loss G: 0.703781247138977
== Iteration 98000 ==
Test Loss Q: 0.17353785037994385   Test Loss G: 0.7108004093170166
== Iteration 98500 ==
Test Loss Q: 0.1731758862733841   Test Loss G: 0.7150132060050964
== Iteration 99000 ==
Test Loss Q: 0.17375607788562775   Test Loss G: 0.7095832228660583
== Iteration 99500 ==
Test Loss Q: 0.17383426427841187   Test Loss G: 0.713049590587616
== Iteration 99999 ==
Test Loss Q: 0.17334622144699097   Test Loss G: 0.7059768438339233
Batch size 60  Iters 5000  Lr 0.01  Layers 6  Dropout 0.1  Layer Size 128
== Iteration 0 ==
Test Loss Q: 0.20500971376895905   Test Loss G: 1.386708378791809
== Iteration 500 ==
Test Loss Q: 0.17753483355045319   Test Loss G: 0.6894322037696838
== Iteration 1000 ==
Test Loss Q: 0.17426200211048126   Test Loss G: 0.6927869915962219
== Iteration 1500 ==
Test Loss Q: 0.17629402875

== Iteration 36500 ==
Test Loss Q: 0.18749649822711945   Test Loss G: 0.6926975250244141
== Iteration 37000 ==
Test Loss Q: 0.19010117650032043   Test Loss G: 0.6940495371818542
== Iteration 37500 ==
Test Loss Q: 0.18757009506225586   Test Loss G: 0.6872508525848389
== Iteration 38000 ==
Test Loss Q: 0.1890677958726883   Test Loss G: 0.6977527141571045
== Iteration 38500 ==
Test Loss Q: 0.18767443299293518   Test Loss G: 0.6972540616989136
== Iteration 39000 ==
Test Loss Q: 0.19193784892559052   Test Loss G: 0.701156735420227
== Iteration 39500 ==
Test Loss Q: 0.1889665573835373   Test Loss G: 0.71675705909729
== Iteration 40000 ==
Test Loss Q: 0.18819932639598846   Test Loss G: 0.6917896270751953
== Iteration 40500 ==
Test Loss Q: 0.19127829372882843   Test Loss G: 0.6952453851699829
== Iteration 41000 ==
Test Loss Q: 0.18816129863262177   Test Loss G: 0.7099752426147461
== Iteration 41500 ==
Test Loss Q: 0.19050589203834534   Test Loss G: 0.7069671154022217
== Iteration 42000 ==
Test

== Iteration 32000 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 1.197070837020874
== Iteration 32500 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 0.673439085483551
== Iteration 33000 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 0.7408543229103088
== Iteration 33500 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 0.6900140643119812
== Iteration 34000 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 0.6828205585479736
== Iteration 34500 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 0.687020480632782
== Iteration 35000 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 0.7708791494369507
== Iteration 35500 ==
Test Loss Q: 0.7160000205039978   Test Loss G: 0.650674045085907
== Iteration 36000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.8148590922355652
== Iteration 36500 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.6347727179527283
== Iteration 37000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.6603394746780396
== Iteration 37500 ==
Test Loss Q: 0

== Iteration 79000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.6231100559234619
== Iteration 79500 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.6907343864440918
== Iteration 80000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.7333908081054688
== Iteration 80500 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.6352111101150513
== Iteration 81000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.68365877866745
== Iteration 81500 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.6812145113945007
== Iteration 82000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.7203167676925659
== Iteration 82500 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.7175868153572083
== Iteration 83000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.6729703545570374
== Iteration 83500 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.7017717361450195
== Iteration 84000 ==
Test Loss Q: 0.6984000205993652   Test Loss G: 0.8000083565711975
== Iteration 84500 ==
Test Loss Q:

== Iteration 18500 ==
Test Loss Q: 0.17182350158691406   Test Loss G: 0.6943220496177673
== Iteration 19000 ==
Test Loss Q: 0.17239366471767426   Test Loss G: 0.6808977127075195
== Iteration 19500 ==
Test Loss Q: 0.17252317070960999   Test Loss G: 0.7004799842834473
== Iteration 20000 ==
Test Loss Q: 0.17231649160385132   Test Loss G: 0.6632526516914368
== Iteration 20500 ==
Test Loss Q: 0.17368151247501373   Test Loss G: 0.6742224097251892
== Iteration 21000 ==
Test Loss Q: 0.1722625494003296   Test Loss G: 0.6906134486198425
== Iteration 21500 ==
Test Loss Q: 0.17408159375190735   Test Loss G: 0.6935139894485474
== Iteration 22000 ==
Test Loss Q: 0.1725430190563202   Test Loss G: 0.6824207305908203
== Iteration 22500 ==
Test Loss Q: 0.17315711081027985   Test Loss G: 0.6921439170837402
== Iteration 23000 ==
Test Loss Q: 0.17268496751785278   Test Loss G: 0.6934488415718079
== Iteration 23500 ==
Test Loss Q: 0.17277388274669647   Test Loss G: 0.6704308390617371
== Iteration 24000 ==
T

== Iteration 65000 ==
Test Loss Q: 0.17577116191387177   Test Loss G: 0.6922573447227478
== Iteration 65500 ==
Test Loss Q: 0.17499808967113495   Test Loss G: 0.6906676292419434
== Iteration 66000 ==
Test Loss Q: 0.17537739872932434   Test Loss G: 0.6910187602043152
== Iteration 66500 ==
Test Loss Q: 0.1743890941143036   Test Loss G: 0.6855812668800354
== Iteration 67000 ==
Test Loss Q: 0.17402124404907227   Test Loss G: 0.70308917760849
== Iteration 67500 ==
Test Loss Q: 0.17500633001327515   Test Loss G: 0.6803197264671326
== Iteration 68000 ==
Test Loss Q: 0.1756395399570465   Test Loss G: 0.6993808746337891
== Iteration 68500 ==
Test Loss Q: 0.17451278865337372   Test Loss G: 0.6921050548553467
== Iteration 69000 ==
Test Loss Q: 0.17457984387874603   Test Loss G: 0.6831175088882446
== Iteration 69500 ==
Test Loss Q: 0.1746816784143448   Test Loss G: 0.6849364042282104
== Iteration 70000 ==
Test Loss Q: 0.17384879291057587   Test Loss G: 0.6917880773544312
== Iteration 70500 ==
Test

== Iteration 10500 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6803154945373535
== Iteration 11000 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.7595877051353455
== Iteration 11500 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6858105659484863
== Iteration 12000 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6702488660812378
== Iteration 12500 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6556137204170227
== Iteration 13000 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6817623972892761
== Iteration 13500 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.7008931040763855
== Iteration 14000 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6747809648513794
== Iteration 14500 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.7292377948760986
== Iteration 15000 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6774585247039795
== Iteration 15500 ==
Test Loss Q: 0.2824000120162964   Test Loss G: 0.6644779443740845
== Iteration 16000 ==
Test Loss 

== Iteration 15000 ==
Test Loss Q: 0.17040421068668365   Test Loss G: 0.6905936598777771
== Iteration 15500 ==
Test Loss Q: 0.17265550792217255   Test Loss G: 0.680040180683136
== Iteration 16000 ==
Test Loss Q: 0.1712970733642578   Test Loss G: 0.7056118845939636
== Iteration 16500 ==
Test Loss Q: 0.16998493671417236   Test Loss G: 0.7058541774749756
== Iteration 17000 ==
Test Loss Q: 0.16995252668857574   Test Loss G: 0.7032962441444397
== Iteration 17500 ==
Test Loss Q: 0.17011147737503052   Test Loss G: 0.7032678723335266
== Iteration 18000 ==
Test Loss Q: 0.17142413556575775   Test Loss G: 0.6889143586158752
== Iteration 18500 ==
Test Loss Q: 0.17003290355205536   Test Loss G: 0.7122353911399841
== Iteration 19000 ==
Test Loss Q: 0.16938070952892303   Test Loss G: 0.6971900463104248
== Iteration 19500 ==
Test Loss Q: 0.17014069855213165   Test Loss G: 0.7025092840194702
== Iteration 20000 ==
Test Loss Q: 0.17085276544094086   Test Loss G: 0.7185493111610413
== Iteration 20500 ==
T

== Iteration 10500 ==
Test Loss Q: 0.16811279952526093   Test Loss G: 0.6889650821685791
== Iteration 11000 ==
Test Loss Q: 0.165914386510849   Test Loss G: 0.6892839670181274
== Iteration 11500 ==
Test Loss Q: 0.16753976047039032   Test Loss G: 0.6836807727813721
== Iteration 12000 ==
Test Loss Q: 0.1661079078912735   Test Loss G: 0.687276303768158
== Iteration 12500 ==
Test Loss Q: 0.16676463186740875   Test Loss G: 0.6797053217887878
== Iteration 13000 ==
Test Loss Q: 0.16633324325084686   Test Loss G: 0.6727791428565979
== Iteration 13500 ==
Test Loss Q: 0.1692163646221161   Test Loss G: 0.7006909251213074
== Iteration 14000 ==
Test Loss Q: 0.16634683310985565   Test Loss G: 0.6883022785186768
== Iteration 14500 ==
Test Loss Q: 0.1675066351890564   Test Loss G: 0.6854677796363831
== Iteration 15000 ==
Test Loss Q: 0.16779233515262604   Test Loss G: 0.6814206838607788
== Iteration 15500 ==
Test Loss Q: 0.1677626222372055   Test Loss G: 0.6647778749465942
== Iteration 16000 ==
Test L

== Iteration 7000 ==
Test Loss Q: 0.16689525544643402   Test Loss G: 0.6754405498504639
== Iteration 7500 ==
Test Loss Q: 0.1694214791059494   Test Loss G: 0.6710491180419922
== Iteration 8000 ==
Test Loss Q: 0.16743633151054382   Test Loss G: 0.7055472731590271
== Iteration 8500 ==
Test Loss Q: 0.17058594524860382   Test Loss G: 0.6757087707519531
== Iteration 9000 ==
Test Loss Q: 0.16878491640090942   Test Loss G: 0.6962423324584961
== Iteration 9500 ==
Test Loss Q: 0.16744378209114075   Test Loss G: 0.6647951602935791
== Iteration 10000 ==
Test Loss Q: 0.1673804521560669   Test Loss G: 0.7097601294517517
== Iteration 10500 ==
Test Loss Q: 0.16635890305042267   Test Loss G: 0.6817145943641663
== Iteration 11000 ==
Test Loss Q: 0.1691356897354126   Test Loss G: 0.6908190250396729
== Iteration 11500 ==
Test Loss Q: 0.16849984228610992   Test Loss G: 0.6565067172050476
== Iteration 12000 ==
Test Loss Q: 0.16912652552127838   Test Loss G: 0.684526801109314
== Iteration 12500 ==
Test Loss

== Iteration 25000 ==
Test Loss Q: 0.17959216237068176   Test Loss G: 0.6892229318618774
== Iteration 25500 ==
Test Loss Q: 0.1792462319135666   Test Loss G: 0.6828190088272095
== Iteration 26000 ==
Test Loss Q: 0.1789359152317047   Test Loss G: 0.7039712071418762
== Iteration 26500 ==
Test Loss Q: 0.17810066044330597   Test Loss G: 0.6977355480194092
== Iteration 27000 ==
Test Loss Q: 0.17872926592826843   Test Loss G: 0.7013384699821472
== Iteration 27500 ==
Test Loss Q: 0.17894987761974335   Test Loss G: 0.6959096789360046
== Iteration 28000 ==
Test Loss Q: 0.17975641787052155   Test Loss G: 0.6783769726753235
== Iteration 28500 ==
Test Loss Q: 0.18007619678974152   Test Loss G: 0.6895323395729065
== Iteration 29000 ==
Test Loss Q: 0.17882344126701355   Test Loss G: 0.6923514008522034
== Iteration 29500 ==
Test Loss Q: 0.18091027438640594   Test Loss G: 0.6698578596115112
== Iteration 30000 ==
Test Loss Q: 0.17920702695846558   Test Loss G: 0.6906929612159729
== Iteration 30500 ==
T

== Iteration 71500 ==
Test Loss Q: 0.17917399108409882   Test Loss G: 0.6694291234016418
== Iteration 72000 ==
Test Loss Q: 0.18147385120391846   Test Loss G: 0.690719723701477
== Iteration 72500 ==
Test Loss Q: 0.18100903928279877   Test Loss G: 0.6876659393310547
== Iteration 73000 ==
Test Loss Q: 0.18085192143917084   Test Loss G: 0.6981260776519775
== Iteration 73500 ==
Test Loss Q: 0.18013611435890198   Test Loss G: 0.6910934448242188
== Iteration 74000 ==
Test Loss Q: 0.18075117468833923   Test Loss G: 0.6748955249786377
== Iteration 74500 ==
Test Loss Q: 0.18173035979270935   Test Loss G: 0.7005760669708252
== Iteration 75000 ==
Test Loss Q: 0.18162710964679718   Test Loss G: 0.6956323385238647
== Iteration 75500 ==
Test Loss Q: 0.18206310272216797   Test Loss G: 0.6877596974372864
== Iteration 76000 ==
Test Loss Q: 0.17984680831432343   Test Loss G: 0.688770055770874
== Iteration 76500 ==
Test Loss Q: 0.18061400949954987   Test Loss G: 0.6884720921516418
== Iteration 77000 ==
T

== Iteration 1000 ==
Test Loss Q: 0.18672378361225128   Test Loss G: 0.7080886363983154
== Iteration 1500 ==
Test Loss Q: 0.18216684460639954   Test Loss G: 0.7201002240180969
== Iteration 2000 ==
Test Loss Q: 0.18436074256896973   Test Loss G: 0.7156857252120972
== Iteration 2500 ==
Test Loss Q: 0.18222273886203766   Test Loss G: 0.7133666276931763
== Iteration 3000 ==
Test Loss Q: 0.18176887929439545   Test Loss G: 0.6931232213973999
== Iteration 3500 ==
Test Loss Q: 0.18136781454086304   Test Loss G: 0.6998529434204102
== Iteration 4000 ==
Test Loss Q: 0.1804734319448471   Test Loss G: 0.711435854434967
== Iteration 4500 ==
Test Loss Q: 0.18467742204666138   Test Loss G: 0.7128562927246094
== Iteration 5000 ==
Test Loss Q: 0.1821896731853485   Test Loss G: 0.7006216049194336
== Iteration 5500 ==
Test Loss Q: 0.18057331442832947   Test Loss G: 0.6916778087615967
== Iteration 6000 ==
Test Loss Q: 0.18170346319675446   Test Loss G: 0.6979748010635376
== Iteration 6500 ==
Test Loss Q: 0

== Iteration 47500 ==
Test Loss Q: 0.18198110163211823   Test Loss G: 0.6978926062583923
== Iteration 48000 ==
Test Loss Q: 0.18168869614601135   Test Loss G: 0.6899076700210571
== Iteration 48500 ==
Test Loss Q: 0.18116413056850433   Test Loss G: 0.6958121061325073
== Iteration 49000 ==
Test Loss Q: 0.18234708905220032   Test Loss G: 0.6884369850158691
== Iteration 49500 ==
Test Loss Q: 0.181391641497612   Test Loss G: 0.7011067271232605
== Iteration 50000 ==
Test Loss Q: 0.18070088326931   Test Loss G: 0.7027832269668579
== Iteration 50500 ==
Test Loss Q: 0.18050554394721985   Test Loss G: 0.6957954168319702
== Iteration 51000 ==
Test Loss Q: 0.18389147520065308   Test Loss G: 0.6972837448120117
== Iteration 51500 ==
Test Loss Q: 0.18190984427928925   Test Loss G: 0.7015995383262634
== Iteration 52000 ==
Test Loss Q: 0.18188180029392242   Test Loss G: 0.6945618987083435
== Iteration 52500 ==
Test Loss Q: 0.18233183026313782   Test Loss G: 0.6989120841026306
== Iteration 53000 ==
Test

== Iteration 94000 ==
Test Loss Q: 0.18247312307357788   Test Loss G: 0.6937842965126038
== Iteration 94500 ==
Test Loss Q: 0.1815388947725296   Test Loss G: 0.6964937448501587
== Iteration 95000 ==
Test Loss Q: 0.18264150619506836   Test Loss G: 0.7018594145774841
== Iteration 95500 ==
Test Loss Q: 0.18213610351085663   Test Loss G: 0.6985418796539307
== Iteration 96000 ==
Test Loss Q: 0.1819833368062973   Test Loss G: 0.6969555020332336
== Iteration 96500 ==
Test Loss Q: 0.18253909051418304   Test Loss G: 0.6907672882080078
== Iteration 97000 ==
Test Loss Q: 0.18342739343643188   Test Loss G: 0.7020024657249451
== Iteration 97500 ==
Test Loss Q: 0.1819019466638565   Test Loss G: 0.6968094706535339
== Iteration 98000 ==
Test Loss Q: 0.18159931898117065   Test Loss G: 0.6898543238639832
== Iteration 98500 ==
Test Loss Q: 0.18160520493984222   Test Loss G: 0.6999301910400391
== Iteration 99000 ==
Test Loss Q: 0.18272808194160461   Test Loss G: 0.7053169012069702
== Iteration 99500 ==
Te

== Iteration 39500 ==
Test Loss Q: 0.17671389877796173   Test Loss G: 0.6685027480125427
== Iteration 40000 ==
Test Loss Q: 0.1749652475118637   Test Loss G: 0.657773494720459
== Iteration 40500 ==
Test Loss Q: 0.17788925766944885   Test Loss G: 0.6755630970001221
== Iteration 41000 ==
Test Loss Q: 0.17644524574279785   Test Loss G: 0.6746232509613037
== Iteration 41500 ==
Test Loss Q: 0.1753307580947876   Test Loss G: 0.6959572434425354
== Iteration 42000 ==
Test Loss Q: 0.17423322796821594   Test Loss G: 0.7073897123336792
== Iteration 42500 ==
Test Loss Q: 0.1760009378194809   Test Loss G: 0.6872761845588684
== Iteration 43000 ==
Test Loss Q: 0.17414973676204681   Test Loss G: 0.6800538301467896
== Iteration 43500 ==
Test Loss Q: 0.1744261085987091   Test Loss G: 0.6872334480285645
== Iteration 44000 ==
Test Loss Q: 0.1751985400915146   Test Loss G: 0.6779305338859558
== Iteration 44500 ==
Test Loss Q: 0.1783718764781952   Test Loss G: 0.6980707049369812
== Iteration 45000 ==
Test L

== Iteration 86000 ==
Test Loss Q: 0.17768162488937378   Test Loss G: 0.7102522253990173
== Iteration 86500 ==
Test Loss Q: 0.1790330559015274   Test Loss G: 0.6735350489616394
== Iteration 87000 ==
Test Loss Q: 0.17887654900550842   Test Loss G: 0.6715177297592163
== Iteration 87500 ==
Test Loss Q: 0.17754462361335754   Test Loss G: 0.6818963885307312
== Iteration 88000 ==
Test Loss Q: 0.17836035788059235   Test Loss G: 0.6829521059989929
== Iteration 88500 ==
Test Loss Q: 0.17800429463386536   Test Loss G: 0.7066662907600403
== Iteration 89000 ==
Test Loss Q: 0.17863383889198303   Test Loss G: 0.6633716821670532
== Iteration 89500 ==
Test Loss Q: 0.18237018585205078   Test Loss G: 0.6985359191894531
== Iteration 90000 ==
Test Loss Q: 0.17798104882240295   Test Loss G: 0.6800640821456909
== Iteration 90500 ==
Test Loss Q: 0.17817434668540955   Test Loss G: 0.683019757270813
== Iteration 91000 ==
Test Loss Q: 0.17882081866264343   Test Loss G: 0.6698856949806213
== Iteration 91500 ==
T

== Iteration 31500 ==
Test Loss Q: 0.17568549513816833   Test Loss G: 0.700706422328949
== Iteration 32000 ==
Test Loss Q: 0.1793489307165146   Test Loss G: 0.7515236139297485
== Iteration 32500 ==
Test Loss Q: 0.17526963353157043   Test Loss G: 0.6793081760406494
== Iteration 33000 ==
Test Loss Q: 0.17674724757671356   Test Loss G: 0.6500585675239563
== Iteration 33500 ==
Test Loss Q: 0.17798161506652832   Test Loss G: 0.6675557494163513
== Iteration 34000 ==
Test Loss Q: 0.17655685544013977   Test Loss G: 0.6812456250190735
== Iteration 34500 ==
Test Loss Q: 0.17661991715431213   Test Loss G: 0.718497097492218
== Iteration 35000 ==
Test Loss Q: 0.17851167917251587   Test Loss G: 0.6922922730445862
== Iteration 35500 ==
Test Loss Q: 0.17589274048805237   Test Loss G: 0.6819581985473633
== Iteration 36000 ==
Test Loss Q: 0.17578347027301788   Test Loss G: 0.7039024829864502
== Iteration 36500 ==
Test Loss Q: 0.17632471024990082   Test Loss G: 0.6972925662994385
== Iteration 37000 ==
Te

== Iteration 27000 ==
Test Loss Q: 0.18730928003787994   Test Loss G: 0.6968567371368408
== Iteration 27500 ==
Test Loss Q: 0.18455693125724792   Test Loss G: 0.6687642335891724
== Iteration 28000 ==
Test Loss Q: 0.18393929302692413   Test Loss G: 0.6993321180343628
== Iteration 28500 ==
Test Loss Q: 0.1842130869626999   Test Loss G: 0.6556938886642456
== Iteration 29000 ==
Test Loss Q: 0.18361210823059082   Test Loss G: 0.6935719847679138
== Iteration 29500 ==
Test Loss Q: 0.18457624316215515   Test Loss G: 0.666802167892456
== Iteration 30000 ==
Test Loss Q: 0.18351168930530548   Test Loss G: 0.7019025683403015
== Iteration 30500 ==
Test Loss Q: 0.1840737760066986   Test Loss G: 0.6597602367401123
== Iteration 31000 ==
Test Loss Q: 0.1827949732542038   Test Loss G: 0.6873127222061157
== Iteration 31500 ==
Test Loss Q: 0.18461164832115173   Test Loss G: 0.670718789100647
== Iteration 32000 ==
Test Loss Q: 0.18408045172691345   Test Loss G: 0.690346896648407
== Iteration 32500 ==
Test 

== Iteration 73500 ==
Test Loss Q: 0.18316762149333954   Test Loss G: 0.6911557912826538
== Iteration 74000 ==
Test Loss Q: 0.21774227917194366   Test Loss G: 0.6983012557029724
== Iteration 74500 ==
Test Loss Q: 0.1966433972120285   Test Loss G: 0.7083687782287598
== Iteration 75000 ==
Test Loss Q: 0.18802496790885925   Test Loss G: 0.6684520244598389
== Iteration 75500 ==
Test Loss Q: 0.1899406611919403   Test Loss G: 0.7056620717048645
== Iteration 76000 ==
Test Loss Q: 0.1884833127260208   Test Loss G: 0.6774833798408508
== Iteration 76500 ==
Test Loss Q: 0.18508698046207428   Test Loss G: 0.7184089422225952
== Iteration 77000 ==
Test Loss Q: 0.18680933117866516   Test Loss G: 0.7061823606491089
== Iteration 77500 ==
Test Loss Q: 0.19327697157859802   Test Loss G: 0.7122042179107666
== Iteration 78000 ==
Test Loss Q: 0.1856795996427536   Test Loss G: 0.6979976892471313
== Iteration 78500 ==
Test Loss Q: 0.18679454922676086   Test Loss G: 0.6905664801597595
== Iteration 79000 ==
Tes

== Iteration 19000 ==
Test Loss Q: 0.17862462997436523   Test Loss G: 0.6927662491798401
== Iteration 19500 ==
Test Loss Q: 0.17925891280174255   Test Loss G: 0.6909160017967224
== Iteration 19999 ==
Test Loss Q: 0.1795935332775116   Test Loss G: 0.7200976610183716
Batch size 120  Iters 20000  Lr 0.0001  Layers 6  Dropout 0.1  Layer Size 128
== Iteration 0 ==
Test Loss Q: 0.24532721936702728   Test Loss G: 0.7182278037071228
== Iteration 500 ==
Test Loss Q: 0.18545083701610565   Test Loss G: 0.669610857963562
== Iteration 1000 ==
Test Loss Q: 0.18630751967430115   Test Loss G: 0.6651753187179565
== Iteration 1500 ==
Test Loss Q: 0.18520234525203705   Test Loss G: 0.6664732694625854
== Iteration 2000 ==
Test Loss Q: 0.18552018702030182   Test Loss G: 0.682130753993988
== Iteration 2500 ==
Test Loss Q: 0.18626412749290466   Test Loss G: 0.6875709295272827
== Iteration 3000 ==
Test Loss Q: 0.18654027581214905   Test Loss G: 0.6807236075401306
== Iteration 3500 ==
Test Loss Q: 0.1862826645

== Iteration 23500 ==
Test Loss Q: 0.18294477462768555   Test Loss G: 0.7061948180198669
== Iteration 24000 ==
Test Loss Q: 0.18029525876045227   Test Loss G: 0.7235721945762634
== Iteration 24500 ==
Test Loss Q: 0.1779109239578247   Test Loss G: 0.7157276272773743
== Iteration 25000 ==
Test Loss Q: 0.17946520447731018   Test Loss G: 0.6826443076133728
== Iteration 25500 ==
Test Loss Q: 0.1810818910598755   Test Loss G: 0.7086109519004822
== Iteration 26000 ==
Test Loss Q: 0.1796783059835434   Test Loss G: 0.6900830268859863
== Iteration 26500 ==
Test Loss Q: 0.18088312447071075   Test Loss G: 0.6890251636505127
== Iteration 27000 ==
Test Loss Q: 0.17952094972133636   Test Loss G: 0.690195620059967
== Iteration 27500 ==
Test Loss Q: 0.18033824861049652   Test Loss G: 0.6935831308364868
== Iteration 28000 ==
Test Loss Q: 0.17949460446834564   Test Loss G: 0.6867151260375977
== Iteration 28500 ==
Test Loss Q: 0.18024539947509766   Test Loss G: 0.693331778049469
== Iteration 29000 ==
Test

== Iteration 70000 ==
Test Loss Q: 0.1800711303949356   Test Loss G: 0.6963858008384705
== Iteration 70500 ==
Test Loss Q: 0.17970237135887146   Test Loss G: 0.67628014087677
== Iteration 71000 ==
Test Loss Q: 0.1813019961118698   Test Loss G: 0.6767477989196777
== Iteration 71500 ==
Test Loss Q: 0.1804761290550232   Test Loss G: 0.7042042016983032
== Iteration 72000 ==
Test Loss Q: 0.18000872433185577   Test Loss G: 0.7274227142333984
== Iteration 72500 ==
Test Loss Q: 0.18139415979385376   Test Loss G: 0.7684697508811951
== Iteration 73000 ==
Test Loss Q: 0.18082788586616516   Test Loss G: 0.7876018285751343
== Iteration 73500 ==
Test Loss Q: 0.1802658885717392   Test Loss G: 0.7996652126312256
== Iteration 74000 ==
Test Loss Q: 0.18023991584777832   Test Loss G: 0.7616757154464722
== Iteration 74500 ==
Test Loss Q: 0.18057823181152344   Test Loss G: 0.865592896938324
== Iteration 75000 ==
Test Loss Q: 0.17971543967723846   Test Loss G: 0.7795681357383728
== Iteration 75500 ==
Test L

== Iteration 15500 ==
Test Loss Q: 0.18209704756736755   Test Loss G: 0.6887077689170837
== Iteration 16000 ==
Test Loss Q: 0.18766731023788452   Test Loss G: 0.6888712048530579
== Iteration 16500 ==
Test Loss Q: 0.18419763445854187   Test Loss G: 0.7199059724807739
== Iteration 17000 ==
Test Loss Q: 0.18489857017993927   Test Loss G: 0.6984457969665527
== Iteration 17500 ==
Test Loss Q: 0.18586888909339905   Test Loss G: 0.7129437923431396
== Iteration 18000 ==
Test Loss Q: 0.1831354796886444   Test Loss G: 0.7070049047470093
== Iteration 18500 ==
Test Loss Q: 0.18509498238563538   Test Loss G: 0.7203938961029053
== Iteration 19000 ==
Test Loss Q: 0.1848967969417572   Test Loss G: 0.7015427350997925
== Iteration 19500 ==
Test Loss Q: 0.1859242022037506   Test Loss G: 0.721469521522522
== Iteration 20000 ==
Test Loss Q: 0.1887514889240265   Test Loss G: 0.7357023358345032
== Iteration 20500 ==
Test Loss Q: 0.18457303941249847   Test Loss G: 0.6865736842155457
== Iteration 21000 ==
Test

== Iteration 62000 ==
Test Loss Q: 0.18820138275623322   Test Loss G: 0.7322945594787598
== Iteration 62500 ==
Test Loss Q: 0.1890404224395752   Test Loss G: 0.7717601656913757
== Iteration 63000 ==
Test Loss Q: 0.18546609580516815   Test Loss G: 0.7637732028961182
== Iteration 63500 ==
Test Loss Q: 0.18345347046852112   Test Loss G: 0.7623404860496521
== Iteration 64000 ==
Test Loss Q: 0.1877463459968567   Test Loss G: 0.770463764667511
== Iteration 64500 ==
Test Loss Q: 0.19091515243053436   Test Loss G: 0.7670603394508362
== Iteration 65000 ==
Test Loss Q: 0.1887199878692627   Test Loss G: 0.7545042037963867
== Iteration 65500 ==
Test Loss Q: 0.18981485068798065   Test Loss G: 0.7445299029350281
== Iteration 66000 ==
Test Loss Q: 0.18396499752998352   Test Loss G: 0.7933180928230286
== Iteration 66500 ==
Test Loss Q: 0.18573172390460968   Test Loss G: 0.7837985157966614
== Iteration 67000 ==
Test Loss Q: 0.185205340385437   Test Loss G: 0.762081503868103
== Iteration 67500 ==
Test L

== Iteration 7500 ==
Test Loss Q: 0.17347846925258636   Test Loss G: 0.65046226978302
== Iteration 8000 ==
Test Loss Q: 0.1732139140367508   Test Loss G: 0.6759098768234253
== Iteration 8500 ==
Test Loss Q: 0.1739744395017624   Test Loss G: 0.6761988997459412
== Iteration 9000 ==
Test Loss Q: 0.17386981844902039   Test Loss G: 0.6622137427330017
== Iteration 9500 ==
Test Loss Q: 0.1771010309457779   Test Loss G: 0.6939374804496765
== Iteration 10000 ==
Test Loss Q: 0.1753312200307846   Test Loss G: 0.6760265827178955
== Iteration 10500 ==
Test Loss Q: 0.17532868683338165   Test Loss G: 0.6944417953491211
== Iteration 11000 ==
Test Loss Q: 0.17610780894756317   Test Loss G: 0.6863893866539001
== Iteration 11500 ==
Test Loss Q: 0.17442277073860168   Test Loss G: 0.7205119132995605
== Iteration 12000 ==
Test Loss Q: 0.17487433552742004   Test Loss G: 0.6874871850013733
== Iteration 12500 ==
Test Loss Q: 0.17625857889652252   Test Loss G: 0.6887350678443909
== Iteration 13000 ==
Test Loss 

== Iteration 17000 ==
Test Loss Q: 0.17895053327083588   Test Loss G: 0.672584593296051
== Iteration 17500 ==
Test Loss Q: 0.17914071679115295   Test Loss G: 0.6539864540100098
== Iteration 18000 ==
Test Loss Q: 0.179704487323761   Test Loss G: 0.6493998765945435
== Iteration 18500 ==
Test Loss Q: 0.1785597950220108   Test Loss G: 0.652703583240509
== Iteration 19000 ==
Test Loss Q: 0.17968711256980896   Test Loss G: 0.6489287614822388
== Iteration 19500 ==
Test Loss Q: 0.1787949651479721   Test Loss G: 0.6723790764808655
== Iteration 19999 ==
Test Loss Q: 0.1795770525932312   Test Loss G: 0.6464273929595947
old loss: 0.8367812931537628
new loss: 0.8260044455528259
best model updated


## 6. Run Simulation

Now that we have the best hyperparameters, we run the simulations using them.

In [151]:
# Simulation driver: for each of `num_runs` freshly generated datasets, train the
# outcome network (QNet) and the propensity network (GNet) with the selected
# hyperparameters, then record the plug-in effect estimate and its
# bias-corrected counterpart.
print(best_params)
N = 10000      # sample size passed to generate_data for every run
seed = 0       # incremented before each run, so the runs use seeds 1..num_runs
num_runs = 60

output_type_Q = 'categorical'
output_size_Q = 1
output_type_G = 'categorical'
output_size_G = 1
# NOTE(review): `z` is read here before this cell assigns it, so it must already
# exist from an earlier cell (it is only rebound inside the loop below).
input_size_Q = z.shape[-1] + 1  # we will concatenate the treatment var inside the qnet class
input_size_G = z.shape[-1]
layers = best_params['layers']
dropout = best_params['dropout']
layer_size = best_params['layer_size']
iters = best_params['iters']
lr = best_params['lr']
batch_size = best_params['batch_size']

estimates_naive = []  # plug-in estimates mean(Q1 - Q0), one per run
estimates_upd = []    # bias-corrected estimates, one per run
for i in range(num_runs):
    print('=====================RUN {}==================='.format(i))
    seed += 1
    # generate_data only seeds NumPy's global RNG. Seed torch's global RNG too,
    # so that whatever the networks/trainer draw from it is repeatable per run.
    torch.manual_seed(seed)
    # data generation:
    z, x, y, _, _ = generate_data(N, seed=seed)
    x = torch.tensor(x).type(torch.float32)
    z = torch.tensor(z).type(torch.float32)
    y = torch.tensor(y).type(torch.float32)
    x_int1 = torch.ones_like(x)  # this is the 'intervention data'
    x_int0 = torch.zeros_like(x)

    # Fresh, untrained networks for every run.
    qnet = QNet(input_size=input_size_Q, num_layers=layers,
                          layers_size=layer_size, output_size=output_size_Q,
                         output_type=output_type_Q, dropout=dropout)

    gnet = GNet(input_size=input_size_G, num_layers=layers,
                          layers_size=layer_size, output_size=output_size_G,
                         output_type=output_type_G, dropout=dropout)


    trainer = Trainer(qnet=qnet, gnet=gnet, iterations=iters,
                      batch_size=batch_size, test_iter=500, lr=lr)

    train_loss_q_, train_loss_g_, val_loss_q_, val_loss_g_ = trainer.train(x, y, z)

    # Predictions on the observed data. They are not used again inside this
    # loop, but the names stay bound at cell level after the final run.
    _, _, x_pred, y_pred = trainer.test(x, y, z)
    x_pred, y_pred = x_pred.detach().numpy(), y_pred.detach().numpy()

    # Outcome predictions with the treatment forced to 1 and to 0 for every row.
    # NOTE(review): G10 (third value returned by trainer.test) is used below as
    # the propensity P(x=1 | z); presumably it is GNet's prediction -- confirm
    # against Trainer.test.
    _, _, G10, Q1 = trainer.test(x_int1, y, z)
    _, _, _, Q0 = trainer.test(x_int0, y, z)

    Q1 = Q1.detach().numpy()
    Q0 = Q0.detach().numpy()
    biased_psi = (Q1-Q0).mean()  # plug-in estimate from the outcome model alone

    # Clipping to [0.001, 0.999] keeps both inverse weights below bounded by 1000.
    G10 = np.clip(G10.detach().numpy(), a_min=0.001, a_max=0.999)

    H1 = 1/(G10)
    H0 = 1 / (1 - G10)

    x_ = x.detach().numpy()
    y_ = y.detach().numpy()
    # Per-sample correction terms: inverse-propensity-weighted residual of the
    # arm actually observed, plus the centred outcome prediction.
    D1 = x_ * H1 * (y_ - Q1) + Q1 - Q1.mean()
    D0 = (1 - x_) * H0 * (y_ - Q0) + Q0 - Q0.mean()

    Q1_star = Q1 + D1
    Q0_star = Q0 + D0

    # The centred parts (Q - Q.mean()) average to zero, so this equals
    # biased_psi plus the mean difference of the weighted residual terms.
    upd_psi = (Q1_star - Q0_star).mean()

    estimates_naive.append(biased_psi)
    estimates_upd.append(upd_psi)
 


{'batch_size': 60, 'layers': 4, 'dropout': 0.2, 'layer_size': 64, 'lr': 0.005, 'iters': 20000, 'train_loss_q': 0.16899374127388, 'train_loss_g': 0.6496564745903015, 'val_loss_q': 0.1795770525932312, 'val_loss_g': 0.6464273929595947}
== Iteration 0 ==
Test Loss Q: 0.20961526036262512   Test Loss G: 0.6375666856765747
== Iteration 500 ==
Test Loss Q: 0.17310187220573425   Test Loss G: 0.704255998134613
== Iteration 1000 ==
Test Loss Q: 0.17446810007095337   Test Loss G: 0.7013257145881653
== Iteration 1500 ==
Test Loss Q: 0.17449995875358582   Test Loss G: 0.7027065753936768
== Iteration 2000 ==
Test Loss Q: 0.17378754913806915   Test Loss G: 0.714412271976471
== Iteration 2500 ==
Test Loss Q: 0.1746959090232849   Test Loss G: 0.6836843490600586
== Iteration 3000 ==
Test Loss Q: 0.17220377922058105   Test Loss G: 0.7178129553794861
== Iteration 3500 ==
Test Loss Q: 0.17506831884384155   Test Loss G: 0.6949360966682434
== Iteration 4000 ==
Test Loss Q: 0.17506663501262665   Test Loss G: 0

== Iteration 4000 ==
Test Loss Q: 0.176979660987854   Test Loss G: 0.6616436243057251
== Iteration 4500 ==
Test Loss Q: 0.17912252247333527   Test Loss G: 0.6680048704147339
== Iteration 5000 ==
Test Loss Q: 0.17734304070472717   Test Loss G: 0.65729159116745
== Iteration 5500 ==
Test Loss Q: 0.17780260741710663   Test Loss G: 0.654613196849823
== Iteration 6000 ==
Test Loss Q: 0.17895393073558807   Test Loss G: 0.642874002456665
== Iteration 6500 ==
Test Loss Q: 0.17671532928943634   Test Loss G: 0.6663792133331299
== Iteration 7000 ==
Test Loss Q: 0.17928814888000488   Test Loss G: 0.6749085187911987
== Iteration 7500 ==
Test Loss Q: 0.17874492704868317   Test Loss G: 0.6385259032249451
== Iteration 8000 ==
Test Loss Q: 0.17718827724456787   Test Loss G: 0.6853234171867371
== Iteration 8500 ==
Test Loss Q: 0.17760993540287018   Test Loss G: 0.6710301041603088
== Iteration 9000 ==
Test Loss Q: 0.177336648106575   Test Loss G: 0.6828089952468872
== Iteration 9500 ==
Test Loss Q: 0.1771

== Iteration 9500 ==
Test Loss Q: 0.18178677558898926   Test Loss G: 0.6556441187858582
== Iteration 10000 ==
Test Loss Q: 0.18131624162197113   Test Loss G: 0.667153000831604
== Iteration 10500 ==
Test Loss Q: 0.1825079321861267   Test Loss G: 0.6684373021125793
== Iteration 11000 ==
Test Loss Q: 0.18446648120880127   Test Loss G: 0.653491735458374
== Iteration 11500 ==
Test Loss Q: 0.18187089264392853   Test Loss G: 0.6582018733024597
== Iteration 12000 ==
Test Loss Q: 0.18322548270225525   Test Loss G: 0.6912370324134827
== Iteration 12500 ==
Test Loss Q: 0.1820419728755951   Test Loss G: 0.6726122498512268
== Iteration 13000 ==
Test Loss Q: 0.18126635253429413   Test Loss G: 0.6588608622550964
== Iteration 13500 ==
Test Loss Q: 0.18243034183979034   Test Loss G: 0.7135289311408997
== Iteration 14000 ==
Test Loss Q: 0.18231719732284546   Test Loss G: 0.657837986946106
== Iteration 14500 ==
Test Loss Q: 0.18447941541671753   Test Loss G: 0.6654677391052246
== Iteration 15000 ==
Test 

== Iteration 15000 ==
Test Loss Q: 0.18620625138282776   Test Loss G: 0.7274943590164185
== Iteration 15500 ==
Test Loss Q: 0.1848737597465515   Test Loss G: 0.6832261681556702
== Iteration 16000 ==
Test Loss Q: 0.18537801504135132   Test Loss G: 0.717913806438446
== Iteration 16500 ==
Test Loss Q: 0.18502943217754364   Test Loss G: 0.689757227897644
== Iteration 17000 ==
Test Loss Q: 0.1853286325931549   Test Loss G: 0.7070826292037964
== Iteration 17500 ==
Test Loss Q: 0.18616870045661926   Test Loss G: 0.7193867564201355
== Iteration 18000 ==
Test Loss Q: 0.1862832009792328   Test Loss G: 0.6815193295478821
== Iteration 18500 ==
Test Loss Q: 0.18526101112365723   Test Loss G: 0.7080547213554382
== Iteration 19000 ==
Test Loss Q: 0.18587148189544678   Test Loss G: 0.696033775806427
== Iteration 19500 ==
Test Loss Q: 0.18674863874912262   Test Loss G: 0.6955092549324036
== Iteration 19999 ==
Test Loss Q: 0.18545366823673248   Test Loss G: 0.6724504232406616
== Iteration 0 ==
Test Loss

== Iteration 500 ==
Test Loss Q: 0.19295740127563477   Test Loss G: 0.7065145969390869
== Iteration 1000 ==
Test Loss Q: 0.18333519995212555   Test Loss G: 0.6937850713729858
== Iteration 1500 ==
Test Loss Q: 0.1839827597141266   Test Loss G: 0.72364741563797
== Iteration 2000 ==
Test Loss Q: 0.1832769215106964   Test Loss G: 0.6575930118560791
== Iteration 2500 ==
Test Loss Q: 0.18395625054836273   Test Loss G: 0.6631726026535034
== Iteration 3000 ==
Test Loss Q: 0.18346118927001953   Test Loss G: 0.6960019469261169
== Iteration 3500 ==
Test Loss Q: 0.18285730481147766   Test Loss G: 0.6739802360534668
== Iteration 4000 ==
Test Loss Q: 0.18153272569179535   Test Loss G: 0.6734603047370911
== Iteration 4500 ==
Test Loss Q: 0.1834358274936676   Test Loss G: 0.7261164784431458
== Iteration 5000 ==
Test Loss Q: 0.18321894109249115   Test Loss G: 0.6857438087463379
== Iteration 5500 ==
Test Loss Q: 0.18348002433776855   Test Loss G: 0.76627117395401
== Iteration 6000 ==
Test Loss Q: 0.1836

== Iteration 6000 ==
Test Loss Q: 0.17940956354141235   Test Loss G: 0.6783599257469177
== Iteration 6500 ==
Test Loss Q: 0.17926941812038422   Test Loss G: 0.6607931852340698
== Iteration 7000 ==
Test Loss Q: 0.17821045219898224   Test Loss G: 0.6883957982063293
== Iteration 7500 ==
Test Loss Q: 0.178594172000885   Test Loss G: 0.6593534350395203
== Iteration 8000 ==
Test Loss Q: 0.18187080323696136   Test Loss G: 0.6662944555282593
== Iteration 8500 ==
Test Loss Q: 0.17834427952766418   Test Loss G: 0.6532649397850037
== Iteration 9000 ==
Test Loss Q: 0.17853204905986786   Test Loss G: 0.648842990398407
== Iteration 9500 ==
Test Loss Q: 0.17870473861694336   Test Loss G: 0.6718217730522156
== Iteration 10000 ==
Test Loss Q: 0.1770128607749939   Test Loss G: 0.6790705323219299
== Iteration 10500 ==
Test Loss Q: 0.1785958707332611   Test Loss G: 0.6789431571960449
== Iteration 11000 ==
Test Loss Q: 0.1777798980474472   Test Loss G: 0.6680405139923096
== Iteration 11500 ==
Test Loss Q: 

== Iteration 11500 ==
Test Loss Q: 0.17295467853546143   Test Loss G: 0.7281212210655212
== Iteration 12000 ==
Test Loss Q: 0.1735040545463562   Test Loss G: 0.6556169986724854
== Iteration 12500 ==
Test Loss Q: 0.17296482622623444   Test Loss G: 0.6636937260627747
== Iteration 13000 ==
Test Loss Q: 0.173307403922081   Test Loss G: 0.7019376158714294
== Iteration 13500 ==
Test Loss Q: 0.17370285093784332   Test Loss G: 0.693516731262207
== Iteration 14000 ==
Test Loss Q: 0.173162043094635   Test Loss G: 0.6850184798240662
== Iteration 14500 ==
Test Loss Q: 0.1766209900379181   Test Loss G: 0.7138059139251709
== Iteration 15000 ==
Test Loss Q: 0.17550566792488098   Test Loss G: 0.6989105939865112
== Iteration 15500 ==
Test Loss Q: 0.17375615239143372   Test Loss G: 0.7023426294326782
== Iteration 16000 ==
Test Loss Q: 0.17475497722625732   Test Loss G: 0.7011775374412537
== Iteration 16500 ==
Test Loss Q: 0.17390751838684082   Test Loss G: 0.6937759518623352
== Iteration 17000 ==
Test L

== Iteration 17000 ==
Test Loss Q: 0.18184080719947815   Test Loss G: 0.6899331212043762
== Iteration 17500 ==
Test Loss Q: 0.18208572268486023   Test Loss G: 0.669420063495636
== Iteration 18000 ==
Test Loss Q: 0.18150481581687927   Test Loss G: 0.675213098526001
== Iteration 18500 ==
Test Loss Q: 0.18382006883621216   Test Loss G: 0.6842193007469177
== Iteration 19000 ==
Test Loss Q: 0.18350756168365479   Test Loss G: 0.6735725998878479
== Iteration 19500 ==
Test Loss Q: 0.18280883133411407   Test Loss G: 0.6801804900169373
== Iteration 19999 ==
Test Loss Q: 0.18346881866455078   Test Loss G: 0.6785328388214111
== Iteration 0 ==
Test Loss Q: 0.23963434994220734   Test Loss G: 0.6889318227767944
== Iteration 500 ==
Test Loss Q: 0.17800068855285645   Test Loss G: 0.7212652564048767
== Iteration 1000 ==
Test Loss Q: 0.17278151214122772   Test Loss G: 0.6627787351608276
== Iteration 1500 ==
Test Loss Q: 0.1768960952758789   Test Loss G: 0.714195728302002
== Iteration 2000 ==
Test Loss Q:

== Iteration 1500 ==
Test Loss Q: 0.1843692809343338   Test Loss G: 0.7267025113105774
== Iteration 2000 ==
Test Loss Q: 0.18278385698795319   Test Loss G: 0.6931779980659485
== Iteration 2500 ==
Test Loss Q: 0.18394611775875092   Test Loss G: 0.6998668313026428
== Iteration 3000 ==
Test Loss Q: 0.18209950625896454   Test Loss G: 0.7138081789016724
== Iteration 3500 ==
Test Loss Q: 0.18276391923427582   Test Loss G: 0.6953714489936829
== Iteration 4000 ==
Test Loss Q: 0.1826734095811844   Test Loss G: 0.686363697052002
== Iteration 4500 ==
Test Loss Q: 0.18276257812976837   Test Loss G: 0.6901052594184875
== Iteration 5000 ==
Test Loss Q: 0.18373852968215942   Test Loss G: 0.6900219917297363
== Iteration 5500 ==
Test Loss Q: 0.18275001645088196   Test Loss G: 0.6739869713783264
== Iteration 6000 ==
Test Loss Q: 0.1838310956954956   Test Loss G: 0.6909189224243164
== Iteration 6500 ==
Test Loss Q: 0.18423622846603394   Test Loss G: 0.692876398563385
== Iteration 7000 ==
Test Loss Q: 0.1

== Iteration 7000 ==
Test Loss Q: 0.17943915724754333   Test Loss G: 0.6730067133903503
== Iteration 7500 ==
Test Loss Q: 0.1805681437253952   Test Loss G: 0.7150858640670776
== Iteration 8000 ==
Test Loss Q: 0.18273739516735077   Test Loss G: 0.7339034080505371
== Iteration 8500 ==
Test Loss Q: 0.18259070813655853   Test Loss G: 0.7077101469039917
== Iteration 9000 ==
Test Loss Q: 0.180620014667511   Test Loss G: 0.696278989315033
== Iteration 9500 ==
Test Loss Q: 0.18114639818668365   Test Loss G: 0.7527666091918945
== Iteration 10000 ==
Test Loss Q: 0.17912766337394714   Test Loss G: 0.6950328350067139
== Iteration 10500 ==
Test Loss Q: 0.18063020706176758   Test Loss G: 0.7069690227508545
== Iteration 11000 ==
Test Loss Q: 0.18185773491859436   Test Loss G: 0.6971595287322998
== Iteration 11500 ==
Test Loss Q: 0.18094873428344727   Test Loss G: 0.6964148283004761
== Iteration 12000 ==
Test Loss Q: 0.1828429251909256   Test Loss G: 0.6896679997444153
== Iteration 12500 ==
Test Loss 

== Iteration 12500 ==
Test Loss Q: 0.1801891326904297   Test Loss G: 0.657367467880249
== Iteration 13000 ==
Test Loss Q: 0.18170665204524994   Test Loss G: 0.6875192523002625
== Iteration 13500 ==
Test Loss Q: 0.1815187782049179   Test Loss G: 0.6864371299743652
== Iteration 14000 ==
Test Loss Q: 0.18061265349388123   Test Loss G: 0.6935982704162598
== Iteration 14500 ==
Test Loss Q: 0.18160338699817657   Test Loss G: 0.6788687705993652
== Iteration 15000 ==
Test Loss Q: 0.18066363036632538   Test Loss G: 0.6903789043426514
== Iteration 15500 ==
Test Loss Q: 0.18103092908859253   Test Loss G: 0.6947528123855591
== Iteration 16000 ==
Test Loss Q: 0.18145357072353363   Test Loss G: 0.6919958591461182
== Iteration 16500 ==
Test Loss Q: 0.18160806596279144   Test Loss G: 0.6716675758361816
== Iteration 17000 ==
Test Loss Q: 0.18226853013038635   Test Loss G: 0.6650800108909607
== Iteration 17500 ==
Test Loss Q: 0.18241143226623535   Test Loss G: 0.6765407919883728
== Iteration 18000 ==
Te

== Iteration 18000 ==
Test Loss Q: 0.1755191832780838   Test Loss G: 0.6775267124176025
== Iteration 18500 ==
Test Loss Q: 0.17529414594173431   Test Loss G: 0.6831822991371155
== Iteration 19000 ==
Test Loss Q: 0.1749587059020996   Test Loss G: 0.6771113872528076
== Iteration 19500 ==
Test Loss Q: 0.17524151504039764   Test Loss G: 0.688499391078949
== Iteration 19999 ==
Test Loss Q: 0.17523762583732605   Test Loss G: 0.6761201024055481
== Iteration 0 ==
Test Loss Q: 0.1964791715145111   Test Loss G: 0.6630813479423523
== Iteration 500 ==
Test Loss Q: 0.1758769154548645   Test Loss G: 0.6535782814025879
== Iteration 1000 ==
Test Loss Q: 0.1744491159915924   Test Loss G: 0.6767114996910095
== Iteration 1500 ==
Test Loss Q: 0.17856518924236298   Test Loss G: 0.6995299458503723
== Iteration 2000 ==
Test Loss Q: 0.1773291826248169   Test Loss G: 0.66745924949646
== Iteration 2500 ==
Test Loss Q: 0.17638036608695984   Test Loss G: 0.7026110887527466
== Iteration 3000 ==
Test Loss Q: 0.1731

== Iteration 2500 ==
Test Loss Q: 0.17376551032066345   Test Loss G: 0.6669844388961792
== Iteration 3000 ==
Test Loss Q: 0.17426247894763947   Test Loss G: 0.659171462059021
== Iteration 3500 ==
Test Loss Q: 0.17363940179347992   Test Loss G: 0.6684436202049255
== Iteration 4000 ==
Test Loss Q: 0.1761481910943985   Test Loss G: 0.6626639366149902
== Iteration 4500 ==
Test Loss Q: 0.17545974254608154   Test Loss G: 0.7059547305107117
== Iteration 5000 ==
Test Loss Q: 0.17670083045959473   Test Loss G: 0.6849573850631714
== Iteration 5500 ==
Test Loss Q: 0.18038727343082428   Test Loss G: 0.7180644273757935
== Iteration 6000 ==
Test Loss Q: 0.17542515695095062   Test Loss G: 0.7047001719474792
== Iteration 6500 ==
Test Loss Q: 0.1752128154039383   Test Loss G: 0.6780794858932495
== Iteration 7000 ==
Test Loss Q: 0.1773621141910553   Test Loss G: 0.6603214144706726
== Iteration 7500 ==
Test Loss Q: 0.1735861599445343   Test Loss G: 0.6740401387214661
== Iteration 8000 ==
Test Loss Q: 0.1

== Iteration 8000 ==
Test Loss Q: 0.1699555367231369   Test Loss G: 0.7023746371269226
== Iteration 8500 ==
Test Loss Q: 0.1694471538066864   Test Loss G: 0.7111669182777405
== Iteration 9000 ==
Test Loss Q: 0.16955259442329407   Test Loss G: 0.6855321526527405
== Iteration 9500 ==
Test Loss Q: 0.16907939314842224   Test Loss G: 0.7015407681465149
== Iteration 10000 ==
Test Loss Q: 0.1687878668308258   Test Loss G: 0.6955600380897522
== Iteration 10500 ==
Test Loss Q: 0.1675090342760086   Test Loss G: 0.7080205678939819
== Iteration 11000 ==
Test Loss Q: 0.16653117537498474   Test Loss G: 0.6820518374443054
== Iteration 11500 ==
Test Loss Q: 0.16722561419010162   Test Loss G: 0.6791002750396729
== Iteration 12000 ==
Test Loss Q: 0.16742223501205444   Test Loss G: 0.7085676789283752
== Iteration 12500 ==
Test Loss Q: 0.16904319822788239   Test Loss G: 0.6900407671928406
== Iteration 13000 ==
Test Loss Q: 0.16814154386520386   Test Loss G: 0.7055850625038147
== Iteration 13500 ==
Test Lo

== Iteration 13000 ==
Test Loss Q: 0.17931048572063446   Test Loss G: 0.6793591976165771
== Iteration 13500 ==
Test Loss Q: 0.18072830140590668   Test Loss G: 0.7105761170387268
== Iteration 14000 ==
Test Loss Q: 0.18036232888698578   Test Loss G: 0.6718204617500305
== Iteration 14500 ==
Test Loss Q: 0.17985016107559204   Test Loss G: 0.684692919254303
== Iteration 15000 ==
Test Loss Q: 0.18052442371845245   Test Loss G: 0.7151145935058594
== Iteration 15500 ==
Test Loss Q: 0.18183156847953796   Test Loss G: 0.6756598353385925
== Iteration 16000 ==
Test Loss Q: 0.1803176999092102   Test Loss G: 0.7016437649726868
== Iteration 16500 ==
Test Loss Q: 0.1823655068874359   Test Loss G: 0.7193499803543091
== Iteration 17000 ==
Test Loss Q: 0.18085914850234985   Test Loss G: 0.6883047223091125
== Iteration 17500 ==
Test Loss Q: 0.17946158349514008   Test Loss G: 0.6930254101753235
== Iteration 18000 ==
Test Loss Q: 0.18181954324245453   Test Loss G: 0.6827499270439148
== Iteration 18500 ==
Te

== Iteration 18500 ==
Test Loss Q: 0.17483139038085938   Test Loss G: 0.7171825766563416
== Iteration 19000 ==
Test Loss Q: 0.1751554310321808   Test Loss G: 0.7095181345939636
== Iteration 19500 ==
Test Loss Q: 0.17598316073417664   Test Loss G: 0.7114952802658081
== Iteration 19999 ==
Test Loss Q: 0.1733643114566803   Test Loss G: 0.7025995850563049
== Iteration 0 ==
Test Loss Q: 0.22332753241062164   Test Loss G: 0.6576361060142517
== Iteration 500 ==
Test Loss Q: 0.17210350930690765   Test Loss G: 0.6450904011726379
== Iteration 1000 ==
Test Loss Q: 0.17294807732105255   Test Loss G: 0.6660062074661255
== Iteration 1500 ==
Test Loss Q: 0.17378999292850494   Test Loss G: 0.6531586050987244
== Iteration 2000 ==
Test Loss Q: 0.17397351562976837   Test Loss G: 0.6354658007621765
== Iteration 2500 ==
Test Loss Q: 0.17234845459461212   Test Loss G: 0.7002955079078674
== Iteration 3000 ==
Test Loss Q: 0.17225930094718933   Test Loss G: 0.6875597238540649
== Iteration 3500 ==
Test Loss Q: 

== Iteration 3000 ==
Test Loss Q: 0.18060512840747833   Test Loss G: 0.6787151098251343
== Iteration 3500 ==
Test Loss Q: 0.18353591859340668   Test Loss G: 0.6405331492424011
== Iteration 4000 ==
Test Loss Q: 0.18373025953769684   Test Loss G: 0.6885733008384705
== Iteration 4500 ==
Test Loss Q: 0.18087320029735565   Test Loss G: 0.6782990097999573
== Iteration 5000 ==
Test Loss Q: 0.18029093742370605   Test Loss G: 0.6758188009262085
== Iteration 5500 ==
Test Loss Q: 0.17982475459575653   Test Loss G: 0.6776731014251709
== Iteration 6000 ==
Test Loss Q: 0.1855340152978897   Test Loss G: 0.6786414980888367
== Iteration 6500 ==
Test Loss Q: 0.18275150656700134   Test Loss G: 0.6798882484436035
== Iteration 7000 ==
Test Loss Q: 0.18050667643547058   Test Loss G: 0.6559973359107971
== Iteration 7500 ==
Test Loss Q: 0.1818876713514328   Test Loss G: 0.659511387348175
== Iteration 8000 ==
Test Loss Q: 0.18168707191944122   Test Loss G: 0.6924180388450623
== Iteration 8500 ==
Test Loss Q: 0

== Iteration 8500 ==
Test Loss Q: 0.18425700068473816   Test Loss G: 0.6593115925788879
== Iteration 9000 ==
Test Loss Q: 0.1832277774810791   Test Loss G: 0.6661067008972168
== Iteration 9500 ==
Test Loss Q: 0.18366973102092743   Test Loss G: 0.6710860133171082
== Iteration 10000 ==
Test Loss Q: 0.18308700621128082   Test Loss G: 0.6850507855415344
== Iteration 10500 ==
Test Loss Q: 0.1831311583518982   Test Loss G: 0.6477508544921875
== Iteration 11000 ==
Test Loss Q: 0.18486665189266205   Test Loss G: 0.6892833709716797
== Iteration 11500 ==
Test Loss Q: 0.1835259050130844   Test Loss G: 0.673839807510376
== Iteration 12000 ==
Test Loss Q: 0.1832408308982849   Test Loss G: 0.6901683807373047
== Iteration 12500 ==
Test Loss Q: 0.18232914805412292   Test Loss G: 0.6810283660888672
== Iteration 13000 ==
Test Loss Q: 0.1856795847415924   Test Loss G: 0.6647350788116455
== Iteration 13500 ==
Test Loss Q: 0.18525096774101257   Test Loss G: 0.6828808784484863
== Iteration 14000 ==
Test Los

== Iteration 13500 ==
Test Loss Q: 0.17511266469955444   Test Loss G: 0.6900567412376404
== Iteration 14000 ==
Test Loss Q: 0.17583206295967102   Test Loss G: 0.6733956336975098
== Iteration 14500 ==
Test Loss Q: 0.1747262179851532   Test Loss G: 0.697076141834259
== Iteration 15000 ==
Test Loss Q: 0.17469501495361328   Test Loss G: 0.6587107181549072
== Iteration 15500 ==
Test Loss Q: 0.17544059455394745   Test Loss G: 0.659650981426239
== Iteration 16000 ==
Test Loss Q: 0.17657093703746796   Test Loss G: 0.6695981025695801
== Iteration 16500 ==
Test Loss Q: 0.1755305528640747   Test Loss G: 0.7067788243293762
== Iteration 17000 ==
Test Loss Q: 0.17610342800617218   Test Loss G: 0.7065867185592651
== Iteration 17500 ==
Test Loss Q: 0.17567385733127594   Test Loss G: 0.6733907461166382
== Iteration 18000 ==
Test Loss Q: 0.17628248035907745   Test Loss G: 0.6861890554428101
== Iteration 18500 ==
Test Loss Q: 0.1748637557029724   Test Loss G: 0.6796565055847168
== Iteration 19000 ==
Test

== Iteration 18500 ==
Test Loss Q: 0.1753997802734375   Test Loss G: 0.7054415941238403
== Iteration 19000 ==
Test Loss Q: 0.1765657663345337   Test Loss G: 0.6743814945220947
== Iteration 19500 ==
Test Loss Q: 0.17357079684734344   Test Loss G: 0.674368679523468
== Iteration 19999 ==
Test Loss Q: 0.17312929034233093   Test Loss G: 0.6931431889533997
== Iteration 0 ==
Test Loss Q: 0.22664491832256317   Test Loss G: 0.6307252645492554
== Iteration 500 ==
Test Loss Q: 0.1705324649810791   Test Loss G: 0.6969653367996216
== Iteration 1000 ==
Test Loss Q: 0.17560341954231262   Test Loss G: 0.7118491530418396
== Iteration 1500 ==
Test Loss Q: 0.1664411723613739   Test Loss G: 0.7234565615653992
== Iteration 2000 ==
Test Loss Q: 0.1678534746170044   Test Loss G: 0.6617574095726013
== Iteration 2500 ==
Test Loss Q: 0.16783712804317474   Test Loss G: 0.671438455581665
== Iteration 3000 ==
Test Loss Q: 0.1680363118648529   Test Loss G: 0.656317412853241
== Iteration 3500 ==
Test Loss Q: 0.16662

== Iteration 3000 ==
Test Loss Q: 0.1717139482498169   Test Loss G: 0.6632900834083557
== Iteration 3500 ==
Test Loss Q: 0.1702544242143631   Test Loss G: 0.6806473731994629
== Iteration 4000 ==
Test Loss Q: 0.17029261589050293   Test Loss G: 0.665524423122406
== Iteration 4500 ==
Test Loss Q: 0.1692996621131897   Test Loss G: 0.6888720989227295
== Iteration 5000 ==
Test Loss Q: 0.17113520205020905   Test Loss G: 0.6809288263320923
== Iteration 5500 ==
Test Loss Q: 0.17034219205379486   Test Loss G: 0.6523844003677368
== Iteration 6000 ==
Test Loss Q: 0.16903124749660492   Test Loss G: 0.7000190615653992
== Iteration 6500 ==
Test Loss Q: 0.17113511264324188   Test Loss G: 0.6585346460342407
== Iteration 7000 ==
Test Loss Q: 0.1712079793214798   Test Loss G: 0.6679783463478088
== Iteration 7500 ==
Test Loss Q: 0.16995827853679657   Test Loss G: 0.6887739896774292
== Iteration 8000 ==
Test Loss Q: 0.16942983865737915   Test Loss G: 0.6843189597129822
== Iteration 8500 ==
Test Loss Q: 0.1

== Iteration 8500 ==
Test Loss Q: 0.17658326029777527   Test Loss G: 0.6777448654174805
== Iteration 9000 ==
Test Loss Q: 0.17698794603347778   Test Loss G: 0.6822851300239563
== Iteration 9500 ==
Test Loss Q: 0.1757459193468094   Test Loss G: 0.6874244213104248
== Iteration 10000 ==
Test Loss Q: 0.17736320197582245   Test Loss G: 0.6725486516952515
== Iteration 10500 ==
Test Loss Q: 0.17689533531665802   Test Loss G: 0.6807124614715576
== Iteration 11000 ==
Test Loss Q: 0.17696042358875275   Test Loss G: 0.6832022070884705
== Iteration 11500 ==
Test Loss Q: 0.1758650690317154   Test Loss G: 0.6825625896453857
== Iteration 12000 ==
Test Loss Q: 0.17541588842868805   Test Loss G: 0.6738331913948059
== Iteration 12500 ==
Test Loss Q: 0.17804445326328278   Test Loss G: 0.6813597679138184
== Iteration 13000 ==
Test Loss Q: 0.17852579057216644   Test Loss G: 0.7096474170684814
== Iteration 13500 ==
Test Loss Q: 0.17831557989120483   Test Loss G: 0.7048301696777344
== Iteration 14000 ==
Test

== Iteration 14000 ==
Test Loss Q: 0.17976313829421997   Test Loss G: 0.6572891473770142
== Iteration 14500 ==
Test Loss Q: 0.17982560396194458   Test Loss G: 0.676898717880249
== Iteration 15000 ==
Test Loss Q: 0.18033039569854736   Test Loss G: 0.667427122592926
== Iteration 15500 ==
Test Loss Q: 0.1787419617176056   Test Loss G: 0.6660481095314026
== Iteration 16000 ==
Test Loss Q: 0.17849688231945038   Test Loss G: 0.696666419506073
== Iteration 16500 ==
Test Loss Q: 0.18008731305599213   Test Loss G: 0.6762767434120178
== Iteration 17000 ==
Test Loss Q: 0.18058675527572632   Test Loss G: 0.6814953684806824
== Iteration 17500 ==
Test Loss Q: 0.1800428032875061   Test Loss G: 0.6663069128990173
== Iteration 18000 ==
Test Loss Q: 0.17974835634231567   Test Loss G: 0.6661312580108643
== Iteration 18500 ==
Test Loss Q: 0.1811143308877945   Test Loss G: 0.6887983679771423
== Iteration 19000 ==
Test Loss Q: 0.18103919923305511   Test Loss G: 0.668516993522644
== Iteration 19500 ==
Test L

== Iteration 19000 ==
Test Loss Q: 0.1767008900642395   Test Loss G: 0.6773212552070618
== Iteration 19500 ==
Test Loss Q: 0.17618483304977417   Test Loss G: 0.6686252355575562
== Iteration 19999 ==
Test Loss Q: 0.1757931262254715   Test Loss G: 0.6612309813499451
== Iteration 0 ==
Test Loss Q: 0.2119615226984024   Test Loss G: 0.7353114485740662
== Iteration 500 ==
Test Loss Q: 0.19337384402751923   Test Loss G: 0.6783080101013184
== Iteration 1000 ==
Test Loss Q: 0.19336192309856415   Test Loss G: 0.7159327268600464
== Iteration 1500 ==
Test Loss Q: 0.19258245825767517   Test Loss G: 0.7066288590431213
== Iteration 2000 ==
Test Loss Q: 0.1904962956905365   Test Loss G: 0.6842225193977356
== Iteration 2500 ==
Test Loss Q: 0.1936902105808258   Test Loss G: 0.695909857749939
== Iteration 3000 ==
Test Loss Q: 0.19440260529518127   Test Loss G: 0.6781843304634094
== Iteration 3500 ==
Test Loss Q: 0.1915491223335266   Test Loss G: 0.6799047589302063
== Iteration 4000 ==
Test Loss Q: 0.1922

== Iteration 3500 ==
Test Loss Q: 0.18478943407535553   Test Loss G: 0.7237517237663269
== Iteration 4000 ==
Test Loss Q: 0.18377278745174408   Test Loss G: 0.687369167804718
== Iteration 4500 ==
Test Loss Q: 0.18267473578453064   Test Loss G: 0.693062961101532
== Iteration 5000 ==
Test Loss Q: 0.18384170532226562   Test Loss G: 0.6915326714515686
== Iteration 5500 ==
Test Loss Q: 0.18339982628822327   Test Loss G: 0.7080062627792358
== Iteration 6000 ==
Test Loss Q: 0.1832306832075119   Test Loss G: 0.7126919031143188
== Iteration 6500 ==
Test Loss Q: 0.18256688117980957   Test Loss G: 0.6895517706871033
== Iteration 7000 ==
Test Loss Q: 0.18266846239566803   Test Loss G: 0.7031660079956055
== Iteration 7500 ==
Test Loss Q: 0.18301111459732056   Test Loss G: 0.659832775592804
== Iteration 8000 ==
Test Loss Q: 0.1832013726234436   Test Loss G: 0.737763524055481
== Iteration 8500 ==
Test Loss Q: 0.18285924196243286   Test Loss G: 0.6800064444541931
== Iteration 9000 ==
Test Loss Q: 0.18

== Iteration 9000 ==
Test Loss Q: 0.18032705783843994   Test Loss G: 0.6748831272125244
== Iteration 9500 ==
Test Loss Q: 0.17963488399982452   Test Loss G: 0.6674371361732483
== Iteration 10000 ==
Test Loss Q: 0.1817970871925354   Test Loss G: 0.6820816397666931
== Iteration 10500 ==
Test Loss Q: 0.1816866248846054   Test Loss G: 0.6751052737236023
== Iteration 11000 ==
Test Loss Q: 0.1782212108373642   Test Loss G: 0.6874918937683105
== Iteration 11500 ==
Test Loss Q: 0.17951752245426178   Test Loss G: 0.6949000954627991
== Iteration 12000 ==
Test Loss Q: 0.17859698832035065   Test Loss G: 0.6689848899841309
== Iteration 12500 ==
Test Loss Q: 0.18097886443138123   Test Loss G: 0.6891542673110962
== Iteration 13000 ==
Test Loss Q: 0.18370085954666138   Test Loss G: 0.6796928644180298
== Iteration 13500 ==
Test Loss Q: 0.18167760968208313   Test Loss G: 0.6669120192527771
== Iteration 14000 ==
Test Loss Q: 0.17890724539756775   Test Loss G: 0.6944514513015747
== Iteration 14500 ==
Test

== Iteration 14000 ==
Test Loss Q: 0.18438486754894257   Test Loss G: 0.7014032006263733
== Iteration 14500 ==
Test Loss Q: 0.183928981423378   Test Loss G: 0.7117851972579956
== Iteration 15000 ==
Test Loss Q: 0.183023601770401   Test Loss G: 0.6945000290870667
== Iteration 15500 ==
Test Loss Q: 0.1828463077545166   Test Loss G: 0.6939460635185242
== Iteration 16000 ==
Test Loss Q: 0.18411384522914886   Test Loss G: 0.6742557883262634
== Iteration 16500 ==
Test Loss Q: 0.18337029218673706   Test Loss G: 0.701088547706604
== Iteration 17000 ==
Test Loss Q: 0.18477758765220642   Test Loss G: 0.6856302618980408
== Iteration 17500 ==
Test Loss Q: 0.1849672794342041   Test Loss G: 0.6931094527244568
== Iteration 18000 ==
Test Loss Q: 0.18218855559825897   Test Loss G: 0.7041956782341003
== Iteration 18500 ==
Test Loss Q: 0.18432675302028656   Test Loss G: 0.7015549540519714
== Iteration 19000 ==
Test Loss Q: 0.18515963852405548   Test Loss G: 0.7041631937026978
== Iteration 19500 ==
Test L

In [152]:
# Summarise the simulation: average each estimator over the runs and report its
# bias relative to the true effect `true_psi`, in percent.
estimates_upd = np.asarray(estimates_upd)
estimates_naive = np.asarray(estimates_naive)

naive_mean = estimates_naive.mean()
upd_mean = estimates_upd.mean()

# Signed relative bias: positive means the estimator overshoots when true_psi > 0.
naive_rel_bias = (naive_mean - true_psi)/true_psi * 100
upd_rel_bias = (upd_mean - true_psi)/true_psi * 100

# The reduction compares error magnitudes, so normalise by |true_psi|; dividing
# by the signed value would flip the sign of the result for a negative effect.
abs_true_psi = np.abs(true_psi)
bias_reduction = (np.abs(naive_mean - true_psi)/abs_true_psi * 100 -
                  np.abs(upd_mean - true_psi)/abs_true_psi * 100)

print('True psi: ', true_psi)
print('naive psi: ', naive_mean, ' relative bias:', naive_rel_bias, '%')
print('updated TMLE psi: ', upd_mean, ' relative bias:', upd_rel_bias, '%')
print('Reduction in bias:', bias_reduction, '%')

True psi:  0.1956508
naive psi:  0.2035916  relative bias: 4.058659642824327 %
updated TMLE psi:  0.20075402  relative bias: 2.6083290422539087 %
Reduction in bias: 1.450330600570418 %
