In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

In [3]:
# NOTE(review): if dataset.npy holds an object array (e.g. a pickled tuple of
# ragged arrays), newer NumPy versions also need allow_pickle=True here.
S, P = np.load("../data/dataset.npy")
# molecules.npy stores Python objects (later used via get_pyscf_molecule()),
# which NumPy >= 1.16.3 only loads when allow_pickle=True is passed explicitly.
# Unpickling can execute arbitrary code: only load trusted, locally generated files.
molecules = np.load("../data/molecules.npy", allow_pickle=True)

In [4]:
from SCFInitialGuess.utilities.dataset import Dataset

# Matrix dimension: overlap/density matrices are dim x dim and are stored
# flattened as dim**2 values.
dim = 26

# Fixed seed so the train/test split (and later np.random draws) is
# reproducible under Restart & Run All.
seed = 42
np.random.seed(seed)

# The first ind_cut shuffled samples are used for training, the rest for testing.
n_samples = len(S)
ind_cut = 150
index = np.arange(n_samples)
np.random.shuffle(index)
train_index, test_index = index[:ind_cut], index[ind_cut:]

S_all = np.array(S)
P_all = np.array(P)

S_test = S_all[test_index]
P_test = P_all[test_index]
molecules_test = [molecules[i] for i in test_index]

S_train = S_all[train_index]
P_train = P_all[train_index]
molecules_train = [molecules[i] for i in train_index]

dataset = Dataset(np.array(S_train), np.array(P_train), split_test=0.0)

# Normalize the test inputs with the mean/std of the *training* data, so test
# samples go through the same transformation the network is trained on.
dataset.testing = (Dataset.normalize(S_test, mean=dataset.x_mean, std=dataset.x_std)[0], P_test)

[-] 2018-03-26 12:02:33: Data set normalized. Mean value std: 0.008754558889917387


In [7]:
from SCFInitialGuess.nn.networks import EluTrNNN
from SCFInitialGuess.nn.training import Trainer, RegularizedMSE

# Fully connected network (presumably ELU-activated, per the class name) that
# maps a flattened, normalized overlap matrix (dim**2 inputs) to a flattened
# density matrix (dim**2 outputs) through four hidden layers.
hidden_layers = [600, 400, 200, 100]
layer_structure = [dim**2] + hidden_layers + [dim**2]

trainer = Trainer(
    EluTrNNN(layer_structure),
    cost_function=RegularizedMSE(alpha=1e-7),
    optimizer=tf.train.AdamOptimizer(learning_rate=1e-3),
)
trainer.setup()

# convergence_threshold is presumably compared against the "Diff" of the
# validation cost shown in the training log.
network_orig, sess_orig = trainer.train(dataset, convergence_threshold=1e-7)
graph_orig = trainer.graph

[-] 2018-03-26 12:05:06: No target graph specified for Trainer setup. Creating new graph ...
[-] 2018-03-26 12:05:06: Setting up the training in the target graph ...
[-] 2018-03-26 12:05:06: network ...
[-] 2018-03-26 12:05:06: error function ...
[-] 2018-03-26 12:05:06: cost function ...
[-] 2018-03-26 12:05:06: training step
[-] 2018-03-26 12:05:07: Starting network training ...
[ ] 2018-03-26 12:05:07: Val. Cost: 5.776E-02. Error: 5.776E-02. Diff: 1.0E+10
[ ] 2018-03-26 12:05:09: Val. Cost: 2.540E-04. Error: 2.497E-04. Diff: 5.8E-02
[ ] 2018-03-26 12:05:12: Val. Cost: 1.026E-04. Error: 9.814E-05. Diff: 1.5E-04
[ ] 2018-03-26 12:05:14: Val. Cost: 4.950E-05. Error: 4.488E-05. Diff: 5.3E-05
[ ] 2018-03-26 12:05:16: Val. Cost: 3.163E-05. Error: 2.697E-05. Diff: 1.8E-05
[ ] 2018-03-26 12:05:19: Val. Cost: 1.324E-05. Error: 8.560E-06. Diff: 1.8E-05
[ ] 2018-03-26 12:05:21: Val. Cost: 1.068E-05. Error: 6.082E-06. Diff: 2.5E-06
[ ] 2018-03-26 12:05:24: Val. Cost: 1.040E-05. Error: 5.896E-06

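McWeeny purification of a closed-shell density matrix $P$ in a non-orthogonal basis with overlap $S$ (idempotency condition $PSP = 2P$): $P' = \dfrac{3}{2} PSP - \dfrac{2}{2 \cdot 2} PSPSP = \dfrac{3}{2} PSP - \dfrac{1}{2} PSPSP$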

In [8]:
def mc_wheeny_purification(p, s):
    """Perform one McWeeny purification step on a closed-shell density matrix.

    Computes P' = 3/2 PSP - 1/2 PSPSP. This leaves P unchanged when it already
    satisfies the closed-shell idempotency condition PSP = 2P.

    Args:
        p: density matrix, square or flattened (any shape with n*n entries).
        s: overlap matrix with the same number of entries as ``p``.

    Returns:
        The purified density matrix as an (n, n) array.
    """
    # Infer the matrix size from the input instead of relying on the global
    # ``dim``, so the helper works for any basis size.
    n = int(round(np.sqrt(np.size(p))))
    p = np.reshape(p, (n, n))
    s = np.reshape(s, (n, n))
    psp = p.dot(s).dot(p)
    # PSPSP = (PSP) S P: reuse PSP instead of recomputing the whole chain.
    pspsp = psp.dot(s).dot(p)
    return (3 * psp - pspsp) / 2

def multi_mc_wheeny(p, s, n_max=4):
    """Apply ``mc_wheeny_purification`` ``n_max`` times in succession.

    Args:
        p: initial density matrix (square or flattened).
        s: overlap matrix passed unchanged to every purification step.
        n_max: number of purification steps to perform.

    Returns:
        The result of the final step, or ``p`` itself when ``n_max <= 0``.
    """
    purified = p
    for _ in range(n_max):
        purified = mc_wheeny_purification(purified, s)
    return purified

def idemp_error(p, s):
    """Mean absolute deviation of ``p`` from closed-shell idempotency.

    Returns mean(|PSP - 2P|) over all matrix elements. This is zero when P
    satisfies the closed-shell idempotency condition PSP = 2P.

    Args:
        p: density matrix, square or flattened (any shape with n*n entries).
        s: overlap matrix with the same number of entries as ``p``.

    Returns:
        A scalar error.
    """
    # Infer the matrix size from the input instead of relying on the global
    # ``dim``, so the helper works for any basis size.
    n = int(round(np.sqrt(np.size(p))))
    p = np.reshape(p, (n, n))
    s = np.reshape(s, (n, n))
    return np.mean(np.abs(p.dot(s).dot(p) - 2 * p))


In [9]:
# For every test sample, report:
# - how far the reference density, the raw network prediction and its
#   McWeeny-purified versions are from closed-shell idempotency (PSP = 2P);
# - how close the predictions are to the reference density.
for (s, p) in zip(S_test, P_test):
    p_ref = p.reshape(dim, dim)

    # dataset.input_transformation presumably applies the same normalization
    # as used for the training inputs.
    s_norm = dataset.input_transformation(s.reshape(1, dim**2))

    print("Orig:         {:0.3E}".format(idemp_error(p, s)))
    print("Orig purif:   {:0.3E}".format(idemp_error(mc_wheeny_purification(p, s), s)))

    with graph_orig.as_default():
        p_nn = network_orig.run(sess_orig, s_norm).reshape(dim, dim)

    print("NN:           {:0.3E}".format(idemp_error(p_nn, s)))
    print("NN purified:  {:0.3E}".format(idemp_error(mc_wheeny_purification(p_nn, s), s)))
    p_nn_multi = multi_mc_wheeny(p_nn, s, n_max=5)
    print("NN multified: {:0.3E}".format(idemp_error(p_nn_multi, s)))
    # Mean absolute deviation from the reference density, before and after
    # the repeated purification.
    print("Value before: {:0.3E}".format(np.mean(np.abs(p_ref - p_nn))))
    print("Value after:  {:0.3E}".format(np.mean(np.abs(p_ref - p_nn_multi))))
    # Sanity check that the repeated purification did not diverge.
    print("Is nan: " + str(np.sum(np.isnan(p_nn_multi))))
    print("Is inf: " + str(np.sum(np.isinf(p_nn_multi))))
    print("Is fin: " + str(np.sum(np.isfinite(p_nn_multi))))
    print("--------------------")

Orip:         3.068E-16
Orig prurif:  8.853E-17
NN:           1.989E-03
NN pruified:  4.080E-05
NN multified: 6.369E-17
Value before: 1.357E-03
Value after:  8.945E-04
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         3.191E-16
Orig prurif:  8.043E-17
NN:           1.571E-03
NN pruified:  2.384E-05
NN multified: 6.539E-17
Value before: 1.205E-03
Value after:  9.513E-04
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         3.424E-16
Orig prurif:  8.940E-17
NN:           2.157E-03
NN pruified:  5.321E-05
NN multified: 6.865E-17
Value before: 1.415E-03
Value after:  9.595E-04
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         3.416E-16
Orig prurif:  9.039E-17
NN:           2.103E-03
NN pruified:  4.238E-05
NN multified: 7.526E-17
Value before: 1.521E-03
Value after:  1.059E-03
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         2.625E-16
Orig prurif:  8.785E-17
NN:           1.561E-03
NN pruified:  1.988E-05
NN multified: 6.593E-17

Value after:  9.890E-04
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         3.625E-16
Orig prurif:  9.193E-17
NN:           1.787E-03
NN pruified:  3.318E-05
NN multified: 7.202E-17
Value before: 1.212E-03
Value after:  8.647E-04
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         2.924E-16
Orig prurif:  8.689E-17
NN:           2.041E-03
NN pruified:  4.826E-05
NN multified: 7.906E-17
Value before: 1.363E-03
Value after:  8.723E-04
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         3.175E-16
Orig prurif:  8.878E-17
NN:           1.969E-03
NN pruified:  4.420E-05
NN multified: 7.622E-17
Value before: 1.287E-03
Value after:  8.444E-04
Is nan: 0
Is inf: 0
Is fin: 676
--------------------
Orip:         3.151E-16
Orig prurif:  7.743E-17
NN:           2.287E-03
NN pruified:  7.093E-05
NN multified: 6.673E-17
Value before: 1.711E-03
Value after:  1.233E-03
Is nan: 0
Is inf: 0
Is fin: 676
--------------------


In [10]:
from pyscf.scf import hf
from SCFInitialGuess.utilities.plotutils import prediction_scatter  # NOTE(review): unused in this cell

# ``dim`` comes from the data-split cell. It is deliberately not redefined
# here, so the two cells cannot silently disagree about the matrix size.

# One row per test molecule. Each row holds the SCF iterations needed from each
# initial guess, in the order of the ``guesses`` list below.
iterations = []
for i, (molecule, p) in enumerate(zip(molecules_test, P_test)):

    mol = molecule.get_pyscf_molecule()

    print("Calculating: " + str(i + 1) + "/" + str(len(molecules_test)))

    s_raw = hf.get_ovlp(mol)
    s_norm = dataset.input_transformation(s_raw.reshape(1, dim**2))

    # Only the network evaluation needs the TensorFlow graph context; the
    # post-processing below is plain NumPy.
    with graph_orig.as_default():
        P_orig = network_orig.run(sess_orig, s_norm).reshape(dim, dim).astype('float64')

    P_orig_sym = (P_orig + P_orig.T) / 2
    # NOTE(review): purification starts from the unsymmetrized prediction, so
    # P_orig_idem need not be exactly symmetric.
    P_orig_idem = multi_mc_wheeny(P_orig, s_raw, n_max=5)

    p_ref = p.reshape(dim, dim)

    # check errors
    print("Accuracy (MSE):")
    print(" -Orig: {:0.3E}".format(np.mean(np.abs(p_ref - P_orig)**2)))
    print(" -Sym:  {:0.3E}".format(np.mean(np.abs(p_ref - P_orig_sym)**2)))
    print(" -Idem: {:0.3E}".format(np.mean(np.abs(p_ref - P_orig_idem)**2)))

    print("Idempotency:")
    print(" -Orig: {:0.3E}".format(idemp_error(P_orig, s_raw)))
    print(" -Sym:  {:0.3E}".format(idemp_error(P_orig_sym, s_raw)))
    print(" -Idem: {:0.3E}".format(idemp_error(P_orig_idem, s_raw)))

    # Baseline guess: the reference density plus small Gaussian noise (std 1e-4).
    p_noise = p_ref + np.random.randn(dim, dim) * 1e-4

    guesses = [p_noise, P_orig, P_orig_sym, P_orig_idem]
    iterations_molecule = []
    for guess in guesses:
        mf = hf.RHF(mol)
        # Presumably disabled so iteration counts reflect the guess quality
        # rather than DIIS acceleration.
        mf.diis = None
        mf.verbose = 1
        mf.kernel(dm0=guess)
        # NOTE(review): ``iterations`` does not appear to be a standard PySCF
        # SCF attribute; this presumably relies on a patched PySCF. Confirm.
        iterations_molecule.append(mf.iterations)

    iterations.append(iterations_molecule)

iterations = np.array(iterations)

Warn: Ipython shell catchs sys.args


Calculating: 1/50
Accuracy (MSE):
Orig: 3.325E-06
Sym:  2.881E-06
Idem: 1.421E-06
Idempotency:
Orig: 1.989E-03
Sym:  1.818E-03
Idem: 6.369E-17


Warn: Ipython shell catchs sys.args


Calculating: 2/50
Accuracy (MSE):
Orig: 2.904E-06
Sym:  2.595E-06
Idem: 1.839E-06
Idempotency:
Orig: 1.571E-03
Sym:  1.436E-03
Idem: 6.539E-17


Warn: Ipython shell catchs sys.args


Calculating: 3/50
Accuracy (MSE):
Orig: 4.289E-06
Sym:  3.878E-06
Idem: 1.755E-06
Idempotency:
Orig: 2.157E-03
Sym:  1.948E-03
Idem: 6.865E-17


Warn: Ipython shell catchs sys.args


Calculating: 4/50
Accuracy (MSE):
Orig: 3.933E-06
Sym:  3.437E-06
Idem: 2.023E-06
Idempotency:
Orig: 2.103E-03
Sym:  1.939E-03
Idem: 7.526E-17


Warn: Ipython shell catchs sys.args


Calculating: 5/50
Accuracy (MSE):
Orig: 2.145E-06
Sym:  1.805E-06
Idem: 1.131E-06
Idempotency:
Orig: 1.561E-03
Sym:  1.433E-03
Idem: 6.593E-17


Warn: Ipython shell catchs sys.args


Calculating: 6/50
Accuracy (MSE):
Orig: 3.111E-06
Sym:  2.815E-06
Idem: 1.651E-06
Idempotency:
Orig: 1.811E-03
Sym:  1.661E-03
Idem: 6.881E-17


Warn: Ipython shell catchs sys.args


Calculating: 7/50
Accuracy (MSE):
Orig: 3.494E-06
Sym:  3.085E-06
Idem: 2.234E-06
Idempotency:
Orig: 1.667E-03
Sym:  1.469E-03
Idem: 6.512E-17


Warn: Ipython shell catchs sys.args


Calculating: 8/50
Accuracy (MSE):
Orig: 3.764E-06
Sym:  3.324E-06
Idem: 2.269E-06
Idempotency:
Orig: 1.889E-03
Sym:  1.661E-03
Idem: 6.994E-17


Warn: Ipython shell catchs sys.args


Calculating: 9/50
Accuracy (MSE):
Orig: 2.311E-06
Sym:  1.934E-06
Idem: 1.143E-06
Idempotency:
Orig: 1.739E-03
Sym:  1.536E-03
Idem: 7.275E-17


Warn: Ipython shell catchs sys.args


Calculating: 10/50
Accuracy (MSE):
Orig: 1.896E-06
Sym:  1.573E-06
Idem: 8.936E-07
Idempotency:
Orig: 1.552E-03
Sym:  1.400E-03
Idem: 6.806E-17


Warn: Ipython shell catchs sys.args


Calculating: 11/50
Accuracy (MSE):
Orig: 2.861E-06
Sym:  2.468E-06
Idem: 1.458E-06
Idempotency:
Orig: 1.699E-03
Sym:  1.496E-03
Idem: 6.268E-17


Warn: Ipython shell catchs sys.args


Calculating: 12/50
Accuracy (MSE):
Orig: 4.563E-06
Sym:  3.888E-06
Idem: 2.015E-06
Idempotency:
Orig: 2.409E-03
Sym:  2.199E-03
Idem: 6.866E-17


Warn: Ipython shell catchs sys.args


Calculating: 13/50
Accuracy (MSE):
Orig: 4.690E-06
Sym:  4.217E-06
Idem: 2.195E-06
Idempotency:
Orig: 2.223E-03
Sym:  2.041E-03
Idem: 7.023E-17


Warn: Ipython shell catchs sys.args


Calculating: 14/50
Accuracy (MSE):
Orig: 3.474E-06
Sym:  3.088E-06
Idem: 2.157E-06
Idempotency:
Orig: 1.713E-03
Sym:  1.525E-03
Idem: 7.713E-17


Warn: Ipython shell catchs sys.args


Calculating: 15/50
Accuracy (MSE):
Orig: 3.706E-06
Sym:  3.322E-06
Idem: 1.998E-06
Idempotency:
Orig: 1.935E-03
Sym:  1.802E-03
Idem: 6.892E-17


Warn: Ipython shell catchs sys.args


Calculating: 16/50
Accuracy (MSE):
Orig: 3.400E-06
Sym:  2.950E-06
Idem: 1.603E-06
Idempotency:
Orig: 1.981E-03
Sym:  1.804E-03
Idem: 7.513E-17


Warn: Ipython shell catchs sys.args


Calculating: 17/50
Accuracy (MSE):
Orig: 3.884E-06
Sym:  3.419E-06
Idem: 1.993E-06
Idempotency:
Orig: 2.054E-03
Sym:  1.829E-03
Idem: 7.089E-17


Warn: Ipython shell catchs sys.args


Calculating: 18/50
Accuracy (MSE):
Orig: 4.182E-06
Sym:  3.544E-06
Idem: 2.309E-06
Idempotency:
Orig: 2.042E-03
Sym:  1.809E-03
Idem: 7.116E-17


Warn: Ipython shell catchs sys.args


Calculating: 19/50
Accuracy (MSE):
Orig: 4.006E-06
Sym:  3.515E-06
Idem: 2.246E-06
Idempotency:
Orig: 1.964E-03
Sym:  1.719E-03
Idem: 7.486E-17


Warn: Ipython shell catchs sys.args


Calculating: 20/50
Accuracy (MSE):
Orig: 2.861E-06
Sym:  2.553E-06
Idem: 1.694E-06
Idempotency:
Orig: 1.635E-03
Sym:  1.491E-03
Idem: 7.039E-17


Warn: Ipython shell catchs sys.args


Calculating: 21/50
Accuracy (MSE):
Orig: 3.068E-06
Sym:  2.765E-06
Idem: 1.611E-06
Idempotency:
Orig: 1.829E-03
Sym:  1.676E-03
Idem: 6.792E-17


Warn: Ipython shell catchs sys.args


Calculating: 22/50
Accuracy (MSE):
Orig: 3.966E-06
Sym:  3.494E-06
Idem: 1.978E-06
Idempotency:
Orig: 2.024E-03
Sym:  1.863E-03
Idem: 6.853E-17


Warn: Ipython shell catchs sys.args


Calculating: 23/50
Accuracy (MSE):
Orig: 4.959E-06
Sym:  4.600E-06
Idem: 2.522E-06
Idempotency:
Orig: 2.118E-03
Sym:  1.970E-03
Idem: 8.356E-17


Warn: Ipython shell catchs sys.args


Calculating: 24/50
Accuracy (MSE):
Orig: 2.816E-06
Sym:  2.489E-06
Idem: 1.754E-06
Idempotency:
Orig: 1.578E-03
Sym:  1.398E-03
Idem: 7.501E-17


Warn: Ipython shell catchs sys.args


Calculating: 25/50
Accuracy (MSE):
Orig: 3.648E-06
Sym:  3.268E-06
Idem: 2.056E-06
Idempotency:
Orig: 1.830E-03
Sym:  1.657E-03
Idem: 6.735E-17


Warn: Ipython shell catchs sys.args


Calculating: 26/50
Accuracy (MSE):
Orig: 5.545E-06
Sym:  5.145E-06
Idem: 3.547E-06
Idempotency:
Orig: 2.094E-03
Sym:  1.964E-03
Idem: 7.591E-17


Warn: Ipython shell catchs sys.args


Calculating: 27/50
Accuracy (MSE):
Orig: 4.294E-06
Sym:  3.878E-06
Idem: 1.515E-06
Idempotency:
Orig: 2.233E-03
Sym:  2.059E-03
Idem: 6.681E-17


Warn: Ipython shell catchs sys.args


Calculating: 28/50
Accuracy (MSE):
Orig: 3.161E-06
Sym:  2.860E-06
Idem: 1.418E-06
Idempotency:
Orig: 1.792E-03
Sym:  1.648E-03
Idem: 7.336E-17


Warn: Ipython shell catchs sys.args


Calculating: 29/50
Accuracy (MSE):
Orig: 3.364E-06
Sym:  3.021E-06
Idem: 1.989E-06
Idempotency:
Orig: 1.733E-03
Sym:  1.567E-03
Idem: 7.219E-17


Warn: Ipython shell catchs sys.args


Calculating: 30/50
Accuracy (MSE):
Orig: 3.207E-06
Sym:  2.871E-06
Idem: 2.173E-06
Idempotency:
Orig: 1.456E-03
Sym:  1.308E-03
Idem: 7.408E-17


Warn: Ipython shell catchs sys.args


Calculating: 31/50
Accuracy (MSE):
Orig: 3.644E-06
Sym:  3.260E-06
Idem: 1.780E-06
Idempotency:
Orig: 2.001E-03
Sym:  1.853E-03
Idem: 7.875E-17


Warn: Ipython shell catchs sys.args


Calculating: 32/50
Accuracy (MSE):
Orig: 3.208E-06
Sym:  2.712E-06
Idem: 1.892E-06
Idempotency:
Orig: 1.781E-03
Sym:  1.497E-03
Idem: 6.881E-17


Warn: Ipython shell catchs sys.args


Calculating: 33/50
Accuracy (MSE):
Orig: 3.986E-06
Sym:  3.480E-06
Idem: 1.901E-06
Idempotency:
Orig: 2.160E-03
Sym:  1.977E-03
Idem: 7.586E-17


Warn: Ipython shell catchs sys.args


Calculating: 34/50
Accuracy (MSE):
Orig: 3.231E-06
Sym:  2.853E-06
Idem: 1.394E-06
Idempotency:
Orig: 1.908E-03
Sym:  1.735E-03
Idem: 7.691E-17


Warn: Ipython shell catchs sys.args


Calculating: 35/50
Accuracy (MSE):
Orig: 4.609E-06
Sym:  4.147E-06
Idem: 2.324E-06
Idempotency:
Orig: 2.154E-03
Sym:  1.972E-03
Idem: 7.115E-17


Warn: Ipython shell catchs sys.args


Calculating: 36/50
Accuracy (MSE):
Orig: 1.226E-05
Sym:  1.159E-05
Idem: 7.696E-06
Idempotency:
Orig: 3.002E-03
Sym:  2.756E-03
Idem: 6.560E-17


Warn: Ipython shell catchs sys.args


Calculating: 37/50
Accuracy (MSE):
Orig: 4.740E-06
Sym:  4.346E-06
Idem: 1.551E-06
Idempotency:
Orig: 2.245E-03
Sym:  2.077E-03
Idem: 7.389E-17


Warn: Ipython shell catchs sys.args


Calculating: 38/50
Accuracy (MSE):
Orig: 2.447E-06
Sym:  2.070E-06
Idem: 1.342E-06
Idempotency:
Orig: 1.593E-03
Sym:  1.395E-03
Idem: 6.455E-17


Warn: Ipython shell catchs sys.args


Calculating: 39/50
Accuracy (MSE):
Orig: 2.644E-06
Sym:  2.335E-06
Idem: 1.577E-06
Idempotency:
Orig: 1.558E-03
Sym:  1.402E-03
Idem: 7.890E-17


Warn: Ipython shell catchs sys.args


Calculating: 40/50
Accuracy (MSE):
Orig: 2.211E-05
Sym:  2.139E-05
Idem: 1.544E-05
Idempotency:
Orig: 3.672E-03
Sym:  3.484E-03
Idem: 7.201E-17


Warn: Ipython shell catchs sys.args


Calculating: 41/50
Accuracy (MSE):
Orig: 4.461E-06
Sym:  4.085E-06
Idem: 2.523E-06
Idempotency:
Orig: 2.052E-03
Sym:  1.898E-03
Idem: 6.566E-17


Warn: Ipython shell catchs sys.args


Calculating: 42/50
Accuracy (MSE):
Orig: 4.493E-06
Sym:  4.083E-06
Idem: 2.079E-06
Idempotency:
Orig: 2.251E-03
Sym:  2.095E-03
Idem: 6.800E-17


Warn: Ipython shell catchs sys.args


Calculating: 43/50
Accuracy (MSE):
Orig: 5.723E-06
Sym:  5.226E-06
Idem: 3.003E-06
Idempotency:
Orig: 2.319E-03
Sym:  2.135E-03
Idem: 7.025E-17


Warn: Ipython shell catchs sys.args


Calculating: 44/50
Accuracy (MSE):
Orig: 5.516E-06
Sym:  4.819E-06
Idem: 2.455E-06
Idempotency:
Orig: 2.580E-03
Sym:  2.356E-03
Idem: 7.554E-17


Warn: Ipython shell catchs sys.args


Calculating: 45/50
Accuracy (MSE):
Orig: 7.035E-06
Sym:  6.601E-06
Idem: 3.707E-06
Idempotency:
Orig: 2.468E-03
Sym:  2.309E-03
Idem: 7.232E-17


Warn: Ipython shell catchs sys.args


Calculating: 46/50
Accuracy (MSE):
Orig: 3.257E-06
Sym:  2.912E-06
Idem: 2.183E-06
Idempotency:
Orig: 1.513E-03
Sym:  1.348E-03
Idem: 6.382E-17


Warn: Ipython shell catchs sys.args


Calculating: 47/50
Accuracy (MSE):
Orig: 2.569E-06
Sym:  2.177E-06
Idem: 1.302E-06
Idempotency:
Orig: 1.787E-03
Sym:  1.584E-03
Idem: 7.202E-17


Warn: Ipython shell catchs sys.args


Calculating: 48/50
Accuracy (MSE):
Orig: 3.379E-06
Sym:  2.918E-06
Idem: 1.332E-06
Idempotency:
Orig: 2.041E-03
Sym:  1.863E-03
Idem: 7.906E-17


Warn: Ipython shell catchs sys.args


Calculating: 49/50
Accuracy (MSE):
Orig: 3.126E-06
Sym:  2.660E-06
Idem: 1.240E-06
Idempotency:
Orig: 1.969E-03
Sym:  1.791E-03
Idem: 7.622E-17


Warn: Ipython shell catchs sys.args


Calculating: 50/50
Accuracy (MSE):
Orig: 6.043E-06
Sym:  5.652E-06
Idem: 3.129E-06
Idempotency:
Orig: 2.287E-03
Sym:  2.133E-03
Idem: 6.673E-17


In [11]:
# Mean number of SCF iterations over all test molecules for each initial guess.
# The column order is the guess order used in the SCF loop:
# [p_noise, P_orig, P_orig_sym, P_orig_idem].
guess_labels = ["P + noise", "NN", "NN symmetrized", "NN McWeeny (5x)"]
mean_iterations = np.mean(iterations, axis=0)
for label, n_iter in zip(guess_labels, mean_iterations):
    print("{:<16s} {:6.2f}".format(label, n_iter))

[  5.62  12.06  12.04   9.12]
