In [1]:
import numpy as np

In [2]:
import tabulate

In [3]:
mols = ['h2','heh+']
bases = ['sto-3g','6-31g']

In [4]:
# what are the column headings?
# 0 -> number of iterations
# 1 -> execution time
# 2 -> cost(\theta^{\ast})
# 3 -> \| \nabla cost(\theta^{\ast}) \|
# 4 -> \| a(T) - \beta \|
# 5 -> cost(\theta^{\ast}) / cost(\theta^{(0)})


In [5]:
# all grad stats
ags = np.zeros((2,2,1000,6))
# all hess stats
ahs = np.zeros((2,2,1000,6))

In [8]:
# sto-3g stats are easy
for i in range(2):
    mol = mols[i]
    j = 0
    base = bases[j]
    for n in range(10):
        tmp = np.load('./compare_'+mol+'_'+base+'/compare_'+mol+'_'+base+str(n)+'.npz')
        si = n*100
        ei = (n+1)*100
        ags[i,j,si:ei,:] = tmp['gradstats']
        ahs[i,j,si:ei,:] = tmp['hessstats']

In [9]:
# 6-31g stats will take a bit of work
for i in range(2):
    mol = mols[i]
    j = 1
    base = bases[j]
    for n in range(100):
        tmp = np.load('./compare_'+mol+'_'+base+'/compare_'+mol+'_'+base+str(n)+'.npz')
        ags[i,j,n,:] = tmp['gradstats']
        ahs[i,j,n,:] = tmp['hessstats']

In [10]:
meanags = np.mean(ags,axis=2)
print(meanags.shape)
meanahs = np.mean(ahs,axis=2)
print(meanahs.shape)


(2, 2, 6)
(2, 2, 6)


In [11]:
stdags = np.std(ags,axis=2)
stdahs = np.std(ahs,axis=2)

In [12]:
print(stdags)
print(stdahs)

[[[1.90299157e+03 2.55230873e+01 2.09168113e+00 7.13183398e+00
   5.06911122e-06 4.32184804e-06]
  [1.77231216e+03 2.83004821e+01 4.64895473e+00 1.49896026e+00
   7.34014103e-06 4.92261306e-06]]

 [[2.08894218e+03 2.56563392e+01 3.96566116e+00 9.30845745e+00
   8.09487874e-06 5.45603428e-06]
  [2.27659068e+03 3.94494687e+01 4.85388579e+00 2.94055505e+00
   5.98365659e-06 5.09204392e-06]]]
[[[8.84884676e+02 1.38345801e+01 2.00747160e+00 1.14181211e-04
   2.56469269e-06 4.28585418e-06]
  [4.64959938e+02 7.52133908e+00 4.71767422e+00 3.99188766e-05
   8.21374720e-06 5.05486240e-06]]

 [[1.34598357e+03 2.10058584e+01 2.89092423e+00 1.72347577e+00
   6.83425191e-06 5.23267624e-06]
  [7.46762555e+02 1.19940279e+01 4.73595950e+00 3.17945256e-05
   5.49311358e-06 4.96833996e-06]]]


In [13]:
# sto-3g final stats
print(meanags[:,0,:])
print(meanahs[:,0,:])

[[5.74812500e+03 7.34883775e+01 1.27696851e+01 8.82973998e-01
  1.60399999e-05 1.36284932e-05]
 [7.70769500e+03 9.20180916e+01 2.10352185e+01 2.38340916e+00
  3.23744288e-05 2.15939196e-05]]
[[1.42166800e+03 2.13366210e+01 1.27794507e+01 6.26159538e-05
  1.57277226e-05 1.36417517e-05]
 [2.72995300e+03 4.22175982e+01 2.08100566e+01 5.45807678e-02
  3.22242370e-05 2.14213573e-05]]


In [14]:
# 6-31 final stats
print(meanags[:,1,:])
print(meanahs[:,1,:])

[[5.45515000e+02 8.47509160e+00 1.48681591e+00 8.83242591e-02
  2.27653288e-06 1.54151392e-06]
 [7.17624000e+02 1.20251971e+01 1.59259389e+00 2.13287214e-01
  1.85398636e-06 1.63987537e-06]]
[[1.33504000e+02 2.15018452e+00 1.50446181e+00 5.58279408e-06
  2.50227583e-06 1.56999229e-06]
 [2.14925000e+02 3.41597297e+00 1.55392875e+00 4.04306776e-06
  1.77972182e-06 1.60017440e-06]]


In [16]:
mytab = np.concatenate([meanags[:,0,:],meanags[:,1,:],meanahs[:,0,:],meanahs[:,1,:]])
print(tabulate.tabulate(mytab))

mytab = np.concatenate([stdags[:,0,:],stdags[:,1,:],stdahs[:,0,:],stdahs[:,1,:]])
print(tabulate.tabulate(mytab))

--------  --------  --------  -----------  -----------  -----------
5748.12   73.4884   12.7697   0.882974     1.604e-05    1.36285e-05
7707.69   92.0181   21.0352   2.38341      3.23744e-05  2.15939e-05
 545.515   8.47509   1.48682  0.0883243    2.27653e-06  1.54151e-06
 717.624  12.0252    1.59259  0.213287     1.85399e-06  1.63988e-06
1421.67   21.3366   12.7795   6.2616e-05   1.57277e-05  1.36418e-05
2729.95   42.2176   20.8101   0.0545808    3.22242e-05  2.14214e-05
 133.504   2.15018   1.50446  5.58279e-06  2.50228e-06  1.56999e-06
 214.925   3.41597   1.55393  4.04307e-06  1.77972e-06  1.60017e-06
--------  --------  --------  -----------  -----------  -----------
--------  --------  -------  -----------  -----------  -----------
1902.99   25.5231   2.09168  7.13183      5.06911e-06  4.32185e-06
2088.94   25.6563   3.96566  9.30846      8.09488e-06  5.45603e-06
1772.31   28.3005   4.64895  1.49896      7.34014e-06  4.92261e-06
2276.59   39.4495   4.85389  2.94056      5.98366e-0

In [12]:
headings = ['# iters','execution time','final cost','final norm of grad','terminal constr viol','cost reduction']
rows = []
for i in range(2):
    for j in range(2):
        row = [mols[i]] + [bases[j]] + [x.item() for x in meanags[i,j,:]]
        rows.append(row)

print(tabulate.tabulate(rows, headers=headings, tablefmt="latex"))
print("")

rows = []
for i in range(2):
    for j in range(2):
        row = [mols[i]] + [bases[j]] + [x.item() for x in meanahs[i,j,:]]
        rows.append(row)

print(tabulate.tabulate(rows, headers=headings, tablefmt="latex"))

\begin{tabular}{llrrrrrr}
\hline
      &        &   \# iters &   execution time &   final cost &   final norm of grad &   terminal constr viol &   cost reduction \\
\hline
 h2   & sto-3g &   5758.72 &          72.1495 &      12.6598 &              1.12363 &            1.59526e-05 &      1.34204e-05 \\
 h2   & 6-31g  &   5180.39 &          74.529  &      14.7247 &              2.87876 &            2.64478e-05 &      1.56582e-05 \\
 heh+ & sto-3g &   7638.19 &          90.3189 &      21.6517 &              2.94015 &            3.44921e-05 &      2.19726e-05 \\
 heh+ & 6-31g  &   6607.74 &          99.3308 &      16.4207 &              3.9032  &            1.9902e-05  &      1.69499e-05 \\
\hline
\end{tabular}

\begin{tabular}{llrrrrrr}
\hline
      &        &   \# iters &   execution time &   final cost &   final norm of grad &   terminal constr viol &   cost reduction \\
\hline
 h2   & sto-3g &   1442.37 &          21.582  &      13.005  &          6.73511e-05 &            1.58275e-05 &

In [13]:
ahs[0,0,:,2]

array([12.18013485, 14.7377194 , 11.55582636, 12.5691156 , 10.89457391,
       13.88929397, 13.45661222, 14.14043219, 14.95274848, 16.23125616,
       15.68175397, 14.95274848, 11.95794099, 13.43569023, 15.29882531,
       13.79400385, 10.89457391, 11.83004743, 12.18013485, 13.43569023,
       15.81793198, 16.67690717, 15.68175397, 12.92228623, 13.97506085,
       13.97506085, 13.44429556,  9.54635123, 12.92760575,  9.42990766,
       12.20218433, 13.09265863,  9.89669147, 15.49142772, 14.65313782,
       11.58335086, 14.90630857, 14.14043219, 15.28676382, 13.57434989,
       14.65313782, 10.89457391,  8.23160512, 11.95794099, 10.74915651,
       14.14043219,  9.54635123, 12.86439326, 13.39412181,  9.89850626,
       12.10674691, 13.39412181, 10.05464019, 12.40008771, 13.36273567,
       14.78362759, 11.18530969, 12.76703528, 12.79829341, 12.57471328,
       11.8638855 , 14.2096619 , 12.18013485, 11.95794099, 13.72730077,
       11.95794099, 13.39412181, 13.45661222, 12.42106042, 15.11