In [1]:
import numpy as np

In [2]:
import tabulate

In [3]:
mols = ['h2','heh+']
bases = ['sto-3g','6-31g']

In [4]:
# what are the column headings?
# 0 -> number of iterations
# 1 -> execution time
# 2 -> cost(\theta^{\ast})
# 3 -> \| \nabla cost(\theta^{\ast}) \|
# 4 -> \| a(T) - \beta \|
# 5 -> cost(\theta^{\ast}) / cost(\theta^{(0)})


In [5]:
# all grad stats
ags = np.zeros((2,2,100,6))
# all hess stats
ahs = np.zeros((2,2,100,6))

In [6]:
# sto-3g stats are easy
for i in range(2):
    mol = mols[i]
    j = 0
    base = bases[j]
    tmp = np.load('compare_'+mol+'_'+base+'.npz')
    ags[i,j,:,:] = tmp['gradstats']
    ahs[i,j,:,:] = tmp['hessstats']

In [7]:
# 6-31g stats will take a bit of work
for i in range(2):
    mol = mols[i]
    j = 1
    base = bases[j]
    for n in range(100):
        tmp = np.load('./compare_'+mol+'_'+base+'/compare_'+mol+'_'+base+str(n)+'.npz')
        ags[i,j,n,:] = tmp['gradstats']
        ahs[i,j,n,:] = tmp['hessstats']

In [8]:
meanags = np.mean(ags,axis=2)
print(meanags.shape)
meanahs = np.mean(ahs,axis=2)
print(meanahs.shape)


(2, 2, 6)
(2, 2, 6)


In [9]:
# sto-3g final stats
print(meanags[:,0,:])
print(meanahs[:,0,:])

[[5.75872000e+03 7.21495099e+01 1.26598428e+01 1.12362755e+00
  1.59525588e-05 1.34204388e-05]
 [7.63819000e+03 9.03189123e+01 2.16516772e+01 2.94014591e+00
  3.44920690e-05 2.19725799e-05]]
[[1.44237000e+03 2.15820073e+01 1.30050155e+01 6.73510813e-05
  1.58275067e-05 1.37647863e-05]
 [2.69143000e+03 4.11616403e+01 2.07509436e+01 5.07409487e-05
  3.28318480e-05 2.11690630e-05]]


In [10]:
# 6-31 final stats
print(meanags[:,1,:])
print(meanahs[:,1,:])

[[5.18039000e+03 7.45290432e+01 1.47247002e+01 2.87876371e+00
  2.64478220e-05 1.56582385e-05]
 [6.60774000e+03 9.93308288e+01 1.64207019e+01 3.90320351e+00
  1.99019869e-05 1.69499049e-05]]
[[1.34276000e+03 2.15856878e+01 1.42151089e+01 3.24519986e-05
  2.51896377e-05 1.51851816e-05]
 [2.10289000e+03 3.39749972e+01 1.56177253e+01 5.90312188e-05
  1.76828806e-05 1.61284316e-05]]


In [11]:
mytab = np.concatenate([meanags[:,0,:],meanags[:,1,:],meanahs[:,0,:],meanahs[:,1,:]])
print(tabulate.tabulate(mytab))

-------  -------  -------  -----------  -----------  -----------
5758.72  72.1495  12.6598  1.12363      1.59526e-05  1.34204e-05
7638.19  90.3189  21.6517  2.94015      3.44921e-05  2.19726e-05
5180.39  74.529   14.7247  2.87876      2.64478e-05  1.56582e-05
6607.74  99.3308  16.4207  3.9032       1.9902e-05   1.69499e-05
1442.37  21.582   13.005   6.73511e-05  1.58275e-05  1.37648e-05
2691.43  41.1616  20.7509  5.07409e-05  3.28318e-05  2.11691e-05
1342.76  21.5857  14.2151  3.2452e-05   2.51896e-05  1.51852e-05
2102.89  33.975   15.6177  5.90312e-05  1.76829e-05  1.61284e-05
-------  -------  -------  -----------  -----------  -----------


In [12]:
headings = ['# iters','execution time','final cost','final norm of grad','terminal constr viol','cost reduction']
rows = []
for i in range(2):
    for j in range(2):
        row = [mols[i]] + [bases[j]] + [x.item() for x in meanags[i,j,:]]
        rows.append(row)

print(tabulate.tabulate(rows, headers=headings))
print("")

rows = []
for i in range(2):
    for j in range(2):
        row = [mols[i]] + [bases[j]] + [x.item() for x in meanahs[i,j,:]]
        rows.append(row)

print(tabulate.tabulate(rows, headers=headings))

                # iters    execution time    final cost    final norm of grad    terminal constr viol    cost reduction
----  ------  ---------  ----------------  ------------  --------------------  ----------------------  ----------------
h2    sto-3g    5758.72           72.1495       12.6598               1.12363             1.59526e-05       1.34204e-05
h2    6-31g     5180.39           74.529        14.7247               2.87876             2.64478e-05       1.56582e-05
heh+  sto-3g    7638.19           90.3189       21.6517               2.94015             3.44921e-05       2.19726e-05
heh+  6-31g     6607.74           99.3308       16.4207               3.9032              1.9902e-05        1.69499e-05

                # iters    execution time    final cost    final norm of grad    terminal constr viol    cost reduction
----  ------  ---------  ----------------  ------------  --------------------  ----------------------  ----------------
h2    sto-3g    1442.37           21.58

In [13]:
ahs[0,0,:,2]

array([12.18013485, 14.7377194 , 11.55582636, 12.5691156 , 10.89457391,
       13.88929397, 13.45661222, 14.14043219, 14.95274848, 16.23125616,
       15.68175397, 14.95274848, 11.95794099, 13.43569023, 15.29882531,
       13.79400385, 10.89457391, 11.83004743, 12.18013485, 13.43569023,
       15.81793198, 16.67690717, 15.68175397, 12.92228623, 13.97506085,
       13.97506085, 13.44429556,  9.54635123, 12.92760575,  9.42990766,
       12.20218433, 13.09265863,  9.89669147, 15.49142772, 14.65313782,
       11.58335086, 14.90630857, 14.14043219, 15.28676382, 13.57434989,
       14.65313782, 10.89457391,  8.23160512, 11.95794099, 10.74915651,
       14.14043219,  9.54635123, 12.86439326, 13.39412181,  9.89850626,
       12.10674691, 13.39412181, 10.05464019, 12.40008771, 13.36273567,
       14.78362759, 11.18530969, 12.76703528, 12.79829341, 12.57471328,
       11.8638855 , 14.2096619 , 12.18013485, 11.95794099, 13.72730077,
       11.95794099, 13.39412181, 13.45661222, 12.42106042, 15.11