# Guess Core and use other methods to interpolate

In [None]:
import os, sys
import pickle
from scf_guess_tools import  Backend, load, calculate, guess
import numpy as np
import matplotlib.pyplot as plt
from BlockMatrix import BlockMatrix
from tensorflow.keras.callbacks import TensorBoard
import datetime


In [None]:
# Load the pickled test/train splits. Every record is a 4-tuple laid out as
# (fock, overlap, file, density).
with open('../c5h4n2o2/models/test_data.pkl', 'rb') as f:
    test = pickle.load(f)
with open('../c5h4n2o2/models/train_data.pkl', 'rb') as f:
    train = pickle.load(f)

# Model inputs: overlap matrices.
train_X = np.array([s_mat for _fock, s_mat, _path, _dens in train])
test_X = np.array([s_mat for _fock, s_mat, _path, _dens in test])

# Model targets: Fock matrices.
train_y = np.array([f_mat for f_mat, _ovlp, _path, _dens in train])
test_y = np.array([f_mat for f_mat, _ovlp, _path, _dens in test])

# Bookkeeping: source file of every sample and its reference density.
train_files = np.array([path for _fock, _ovlp, path, _dens in train])
test_files = np.array([path for _fock, _ovlp, path, _dens in test])
train_densities = np.array([dens for _fock, _ovlp, _path, dens in train])
test_densities = np.array([dens for _fock, _ovlp, _path, dens in test])

# Leading dimension of the first overlap matrix (shown as the cell output).
mat_dim = train_X[0].shape[0]
mat_dim


In [None]:
# Wrap training sample 10 (its molecule plus the matching train_X entry) in a
# BlockMatrix and plot its blocks.
mol = load(train_files[10], Backend.PY).native
dm = BlockMatrix(mol, train_X[10]) # can also override the matrix
# NOTE(review): train_X holds overlap matrices, yet the plot title reads
# "Fock test" - confirm whether the title or the data is the intended one.
dm.plot_blocks_by_type("homo", labels="atoms", figsize=(8, 7), title="Fock test", imshow_args={"cmap": "RdBu"})
# Trace of the wrapped matrix (displayed as the cell output).
dm.Matrix.trace()

## 1) First Trial
Just learn the Fock center block matrices for the general case, for all given elements!

In [None]:
# Gather the "center" blocks of overlap (X) and Fock (y) matrices per element,
# separately for the training and the test split.
elements = ["C", "N", "O", "H"]
training_centers_X = {element: [] for element in elements}
training_centers_y = {element: [] for element in elements}
test_centers_X = {element: [] for element in elements}
test_centers_y = {element: [] for element in elements}

# Training split first, then test split; both are processed identically.
for dataset, centers_X, centers_y in (
    (train, training_centers_X, training_centers_y),
    (test, test_centers_X, test_centers_y),
):
    for fock, overlap, file, _ in dataset:
        mol = load(file, Backend.PY).native
        dmX = BlockMatrix(mol, overlap)
        dmY = BlockMatrix(mol, fock)
        for element in elements:
            centers_X[element].extend(dmX.get_blocks_by_atom(element, block_type="center"))
            centers_y[element].extend(dmY.get_blocks_by_atom(element, block_type="center"))


In [None]:
# Check whether all carbon center blocks of the overlap matrices coincide.
c_s = training_centers_X["C"]
for other in c_s[1:]:
    # Compare element-wise: summing the difference would let positive and
    # negative deviations cancel and hide blocks that are actually unequal.
    if not np.allclose(c_s[0], other):
        print("Not equal")
# it seems that everything in sto-3g is actually the same

# Cross-check on two molecules: average the carbon center blocks of each
# BlockMatrix (built without an explicit matrix) and compare them.
mol1 = load(train_files[0], Backend.PY).native
mol2 = load(train_files[1], Backend.PY).native
center1 = BlockMatrix(mol1)
center2 = BlockMatrix(mol2)
mean_c_ovlp1 = np.mean(center1.get_blocks_by_atom("C", block_type="center"), axis=0)
mean_c_ovlp2 = np.mean(center2.get_blocks_by_atom("C", block_type="center"), axis=0)
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
ax[0].imshow(mean_c_ovlp1, cmap="RdBu", vmin=-0.5, vmax=0.5)
ax[0].set_title("Overlap")
ax[1].imshow(mean_c_ovlp2, cmap="RdBu", vmin=-0.5, vmax=0.5)
ax[1].set_title("Overlap2")
# True when both mean blocks agree within np.allclose's default tolerances.
np.allclose(mean_c_ovlp1, mean_c_ovlp2)

In [None]:
# try different basis sets
# Recompute overlap, Fock and density for the first five training files with
# the 6-31g basis (DFT / b3lypg, MINAO starting guess).
ex_overlap, ex_fock, ex_density = [], [], []
mols = []
for i, file in enumerate(train_files[:5]):  # `i` is not used in the loop body
    # NOTE: `mol` stays bound to the last loaded molecule after the loop and
    # is visible to later cells.
    mol = load(file, Backend.PY)
    mols.append(mol)
    wf = calculate(mol, basis="6-31g", guess="minao", method="dft", functional="b3lypg")
    ex_overlap.append(wf.overlap())
    ex_fock.append(wf.fock())
    ex_density.append(wf.density())


In [None]:
# Compare the recomputed 6-31g matrices of two molecules (indices ex1, ex2).
ex1, ex2 = 0,3
center1 = BlockMatrix(mols[ex1].native, ex_fock[ex1].numpy)
center2 = BlockMatrix(mols[ex2].native, ex_fock[ex2].numpy)
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
# NOTE(review): the matrices shown here come from ex_fock (Fock matrices),
# but the titles read "Overlap" / "Overlap2" - confirm the labels.
ax[0].imshow(center1.Matrix, cmap="RdBu", vmin=-0.5, vmax=0.5)
ax[0].set_title("Overlap")
ax[1].imshow(center2.Matrix, cmap="RdBu", vmin=-0.5, vmax=0.5)
ax[1].set_title("Overlap2")

# They are also the SAME!!! - makes sense because center block is basically coef. 
# Mean carbon center block per molecule (only used by the disabled plot below).
c_mean1 = np.mean(center1.get_blocks_by_atom("C", block_type="center"), axis=0)
c_mean2 = np.mean(center2.get_blocks_by_atom("C", block_type="center"), axis=0)
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
# ax[0].imshow(c_mean1, cmap="RdBu", vmin=-0.5, vmax=0.5)
# ax[0].set_title("Overlap mean")
# ax[1].imshow(c_mean2, cmap="RdBu", vmin=-0.5, vmax=0.5)
# ax[1].set_title("Overlap2 mean")
# np.allclose(c_mean1, c_mean2)
# NOTE(review): c1 and c2 are both taken from center1 (carbon blocks 0 and 2
# of the same molecule); center2 is not used here - confirm this is intended.
c1 = center1.get_blocks_by_atom("C", block_type="center")[0]
c2 = center1.get_blocks_by_atom("C", block_type="center")[2]
print(np.allclose(c1, c2))
ax[0].imshow(c1, cmap="RdBu", vmin=-0.5, vmax=0.5)
ax[1].imshow(c2, cmap="RdBu", vmin=-0.5, vmax=0.5)

### 1a) Only learn the diagonal!

In [None]:
# Records are (fock, overlap, file, density) tuples.
# Targets: only the diagonal of each Fock matrix.
train_fock_diag = np.array([np.diag(f_mat) for f_mat, _ovlp, _path, _dens in train])
test_fock_diag = np.array([np.diag(f_mat) for f_mat, _ovlp, _path, _dens in test])

# Inputs: the full overlap matrices.
train_overlap_X = np.array([s_mat for _fock, s_mat, _path, _dens in train])
test_overlap_X = np.array([s_mat for _fock, s_mat, _path, _dens in test])

# Full test references (Fock and density) kept for the later evaluation.
test_fock_full = np.array([f_mat for f_mat, _ovlp, _path, _dens in test])
test_density_full = np.array([dens for _fock, _ovlp, _path, dens in test])

In [None]:
import tensorflow as tf
from tensorflow import keras
import pickle, os, sys
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
sys.path.append("..")
from utils import plot_mat_comp, flatten_triang, unflatten_triang, flatten_triang_batch, unflatten_triang_batch, perform_calculation, density_from_fock, reverse_mat_permutation, benchmark_cycles

In [None]:
# Flatten every overlap matrix with flatten_triang_batch (imported from utils)
# so each sample becomes one input vector for the dense network.
train_overlap_X_flat = flatten_triang_batch(train_overlap_X)
test_overlap_X_flat = flatten_triang_batch(test_overlap_X)

# Feature/target scaling is currently disabled.
# scalerX = StandardScaler()
# scalerY = StandardScaler()

# Shape of the flattened training inputs (displayed as the cell output).
train_overlap_X_flat.shape

In [None]:
# Shape of the diagonal targets; the first axis indexes the training samples.
train_fock_diag.shape

In [None]:
# Fixed seed so the weight initialisation is reproducible.
tf.random.set_seed(42)

# Input width = flattened overlap vector, output width = Fock diagonal length.
flattened_dim = train_overlap_X_flat.shape[1]
diag_dim = train_fock_diag.shape[1]

# Small fully connected regressor: 512 -> batch norm -> 256 -> dropout -> diag.
layers = tf.keras.layers
inputs = tf.keras.Input(shape=(flattened_dim,))
hidden = layers.Dense(512, activation="relu")(inputs)
hidden = layers.BatchNormalization()(hidden)
hidden = layers.Dense(256, activation="relu")(hidden)
hidden = layers.Dropout(0.3)(hidden)
outputs = layers.Dense(diag_dim)(hidden)

basic_model = tf.keras.Model(inputs=inputs, outputs=outputs)
# Optimised on MAE; MSE and MAE are additionally tracked as metrics.
basic_model.compile(optimizer="adam", loss="mae", metrics=["mse", "mae"])

In [None]:
# TensorBoard logs go into a timestamped sub-directory of logs/fit/.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = "logs/fit/" + run_stamp
tensorboard_callback = TensorBoard(log_dir=log_dir, histogram_freq=1)

# Fit on the flattened overlaps -> Fock diagonals, holding out 20 % for validation.
fit_options = dict(
    epochs=250,
    batch_size=32,
    validation_split=0.2,
    verbose=1,
    callbacks=[tensorboard_callback],
)
basic_history = basic_model.fit(train_overlap_X_flat, train_fock_diag, **fit_options)

# Training vs. validation loss on a logarithmic axis.
for history_key, curve_label in (("loss", "train"), ("val_loss", "validation")):
    plt.plot(basic_history.history[history_key], label=curve_label)
plt.yscale("log")
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("mae")
plt.title("Model loss")

In [None]:
# save models
os.makedirs("models", exist_ok=True)
model_path = "models/basic_model.keras"
model_exists = os.path.exists(model_path)
# Ask before replacing an existing file; any answer other than "y" keeps it.
# (Previously a "no" answer fell through to the else-branch and overwrote
# the model anyway.)
if model_exists and input("Overwrite existing model? (y/n)") != "y":
    print("Existing model kept - nothing saved")
else:
    if model_exists:
        os.remove(model_path)
    basic_model.save(model_path)
    print("Model saved")


In [None]:
# load model
basic_model = tf.keras.models.load_model("models/basic_model.keras")
# summary() prints the layer table itself and returns None, so wrapping it in
# print() would only add a stray "None" line to the output.
basic_model.summary()

In [None]:
# Predict the Fock diagonal for every flattened test overlap matrix.
test_pred_fock_diag = basic_model.predict(test_overlap_X_flat)
# Shape of the predictions (displayed as the cell output).
test_pred_fock_diag.shape

In [None]:
def reconstruct_Fock(diag, ovlp, K = 1.75):
    """Rebuild a full Fock matrix from its diagonal using the GWH rule.

    The diagonal is copied unchanged; every off-diagonal element is set to
    ``K * ovlp[i, j] * (diag[i] + diag[j]) / 2``.

    Args:
        diag: 1-D array-like with the diagonal Fock elements (length n).
        ovlp: 2-D array-like overlap matrix; its leading n x n part is used.
        K: Scaling constant of the off-diagonal elements (default 1.75).

    Returns:
        A new float ``numpy.ndarray`` of shape (n, n).
    """
    diag = np.asarray(diag)
    mat_dim = diag.shape[0]
    ovlp = np.asarray(ovlp)[:mat_dim, :mat_dim]
    # Pairwise sums diag[i] + diag[j] via broadcasting replace the Python
    # double loop; the arithmetic per element is unchanged.
    pair_sum = diag[:, None] + diag[None, :]
    out = np.zeros((mat_dim, mat_dim))
    out[...] = K * ovlp * pair_sum / 2
    # The diagonal is taken verbatim rather than from the GWH expression.
    np.fill_diagonal(out, diag)
    return out

In [None]:
# Example comparison on one randomly drawn test sample.
rand_test_sample = np.random.randint(0, len(test_pred_fock_diag))
ground_truth_fock_example = test_fock_full[rand_test_sample]
pred_fock_example = reconstruct_Fock(
    test_pred_fock_diag[rand_test_sample],
    test_overlap_X[rand_test_sample],
)

# Classical reference guesses (MINAO and Hückel) for the same molecule.
test_mol = load(test_files[rand_test_sample], Backend.PY)
minao_guess = guess(test_mol, method="hf", basis="sto-3g", scheme="minao")
hueckel_guess = guess(test_mol, method="hf", basis="sto-3g", scheme="huckel")

# Ground truth against: NN prediction, MINAO guess, Hückel guess.
plot_mat_comp(ground_truth_fock_example, pred_fock_example, title="Fock matrix prediction Basic NN")
minao_fock = minao_guess.fock().numpy
plot_mat_comp(ground_truth_fock_example, minao_fock, title="Fock matrix prediction MINAO")
hueckel_fock = hueckel_guess.fock().numpy
plot_mat_comp(ground_truth_fock_example, hueckel_fock, title="Fock matrix prediction Hückel")

Fock matrices look promising - let's check out the density matrices!

In [None]:
# Density comparison for the sample chosen in the Fock comparison cell.
# The stored reference density is doubled before it is compared.
ground_truth_density = 2*test_density_full[rand_test_sample]
# Disabled sanity check of the density reconstruction:
#ground_truth_density_reconstucted = density_from_fock(test_fock_full[rand_test_sample], test_overlap_X[rand_test_sample], 32)
#plot_mat_comp(2*ground_truth_density, ground_truth_density_reconstucted, title="Sanity check density reconstruction (Ref: 2*ref density vs. density_from_fock(ref fock, ref overlap))")

# Density derived from the predicted Fock matrix.
# NOTE(review): 32 is passed as the third argument of density_from_fock -
# presumably the electron count of the molecule; confirm.
sample_overlap = test_overlap_X[rand_test_sample]
pred_density_example = density_from_fock(pred_fock_example, sample_overlap, 32)

plot_mat_comp(ground_truth_density, pred_density_example, title="Density matrix prediction Basic NN")
minao_density = minao_guess.density().numpy
plot_mat_comp(ground_truth_density, minao_density, title="Density matrix prediction MINAO")
hueckel_density = hueckel_guess.density().numpy
plot_mat_comp(ground_truth_density, hueckel_density, title="Density matrix prediction Hückel")

In [None]:
# sanity check guessing: 
rand_test_sample = 0
test_mol = load(test_files[rand_test_sample], Backend.PY)
wf = calculate(test_mol, basis="sto-3g", guess="minao", method="dft", functional="b3lypg")
print(wf.native.cycles)
# compare with perform_calculation, seeded with the MINAO density of the SAME
# molecule. The guess used to be built from `mol`, a variable left over from
# an earlier cell that refers to a different molecule.
minao_density = guess(test_mol, method="dft", basis="sto-3g", scheme="minao", functional="b3lypg").density().numpy
res = perform_calculation(test_files[rand_test_sample], density_guess=2*minao_density, basis_set="sto-3g", method="dft", functional="b3lypg")
print(res)
# not exactly the same? idk why - let internal alg guess!
# NOTE(review): the mismatch noted above was observed while the guess came
# from the stale `mol`; re-check whether it persists with test_mol.

In [None]:
def _nn_density(diag, ovlp):
    """Turn one predicted Fock diagonal into a density via reconstruct_Fock."""
    fock_guess = reconstruct_Fock(diag, ovlp)
    return density_from_fock(fock_guess, ovlp, 32)


# NN density guess for every test sample, plus the loaded test molecules.
NN_density_pred = [_nn_density(diag, ovlp) for diag, ovlp in zip(test_pred_fock_diag, test_overlap_X)]
test_mols = [load(path, Backend.PY) for path in test_files]
# Explicit baseline guesses are disabled:
# Minao_guesses = [guess(mol, method="dft", basis="sto-3g", scheme="minao", functional="b3lypg").density().numpy for mol in test_mols]
# One_e_guesses = [guess(mol, method="dft", basis="sto-3g", scheme="1e", functional="b3lypg").density().numpy  for mol in test_mols]
# Atom_guesses = [guess(mol, method="dft", basis="sto-3g", scheme="atom", functional="b3lypg").density().numpy  for mol in test_mols]
# Huckel_guesses = [guess(mol, method="dft", basis="sto-3g", scheme="huckel", functional="b3lypg").density().numpy  for mol in test_mols]
# VSAP_guesses = [guess(mol, method="dft", basis="sto-3g", scheme="vsap", functional="b3lypg").density().numpy  for mol in test_mols]

In [None]:
from uncertainties import ufloat
import time

In [None]:
# inference speed of models: 
# Intervals are measured with time.perf_counter(): unlike time.time() it is
# monotonic and uses the highest-resolution clock available.
# NOTE(review): other cells pass lowercase scheme names ("minao", "huckel")
# to guess(); confirm that the capitalised names used here are accepted.
raw_times = {}
for scheme in ["NN", "Minao", "1e", "Atom", "Huckel", "VSAP"]: 
    samples = []
    if scheme == "NN": 
        # NN path: predict the diagonal, rebuild the Fock matrix, derive the density.
        for ovlp in test_overlap_X:
            start = time.perf_counter()
            diag = basic_model.predict(flatten_triang(ovlp).reshape(1, -1)).squeeze()
            density_from_fock(reconstruct_Fock(diag, ovlp), ovlp, 32)
            samples.append(time.perf_counter() - start)
    else:
        # Classical path: let the backend build the guess density.
        for mol in test_mols:
            start = time.perf_counter()
            guess(mol, method="dft", basis="sto-3g", scheme=scheme, functional="b3lypg").density().numpy
            samples.append(time.perf_counter() - start)
    raw_times[scheme] = samples

# Summarise every scheme as mean +/- standard deviation (seconds).
times_ = {scheme: ufloat(np.mean(samples), np.std(samples)) for scheme, samples in raw_times.items()}
print("Inference times for different schemes: ")
for key, value in times_.items():
    print(f"{key}: {value:.2f} s")

In [None]:
# Report every scheme's timing and its ratio to the MINAO timing.
benchmark_val = times_["Minao"]
for scheme_name, scheme_time in times_.items():
    ratio = scheme_time/benchmark_val
    print(f"{scheme_name}: {scheme_time:.4f} s - {ratio:.3f}x comp. MINAO")

In [None]:
# Benchmark SCF cycles for the NN guess against the built-in schemes.
scheme_names = ["NN", "Minao", "1e", "Atom", "Huckel", "VSAP"]
# Only the NN entry carries explicit densities; the other schemes get None
# (presumably benchmark_cycles then builds the guess itself - confirm).
density_guesses = [NN_density_pred] + [None] * (len(scheme_names) - 1)
res = benchmark_cycles(
    test_files,
    density_guesses=density_guesses,
    scheme_names=scheme_names,
    basis_set="sto-3g",
    method="dft",
    functional="b3lypg",
    max_samples=50,
)

In [None]:
# Persist the benchmark results as a pickle file.
import os
results_dir = "benchmark/basic_model"
results_file = "benchmark/basic_model/benchmark_results.pkl"
os.makedirs(results_dir, exist_ok=True)
with open(results_file, "wb") as f:
    pickle.dump(res, f)

In [None]:
# load results
# Restores `res` from the pickle written by the save cell.
# NOTE: pickle.load can execute arbitrary code - only load files you trust.
with open("benchmark/basic_model/benchmark_results.pkl", "rb") as f:
    res = pickle.load(f)

In [None]:
# Per scheme: mean +/- std of the SCF cycle count over the converged runs,
# plus the fraction of runs that converged.
for scheme_name, stats in res.items():
    converged = stats["converged"]
    converged_cycles = np.array(stats["cycles"])[converged]
    n_ok = len(converged_cycles)
    n_total = len(converged)
    cycle_avg = ufloat(np.mean(converged_cycles), np.std(converged_cycles))
    print(f"{scheme_name}: {cycle_avg:.2f} cycles - converged: {n_ok}/{n_total} - {n_ok/n_total:.2f}x converged")

### 1b) Try on other dataset!

In [None]:
import sys, os
from glob import glob
scripts_path = "../../scripts"
if scripts_path not in sys.path:
    sys.path.append(scripts_path)
from to_cache import density_fock_overlap


In [None]:
# Seed intended for the train/test split (not used in this cell).
train_test_seed = 42

# Collect every .xyz file of the c7h10o2 subset.
source_path = '../../datasets/QM9/xyz_c7h10o2/'
xyz_pattern = os.path.join(source_path, '*.xyz')
all_file_paths = glob(xyz_pattern)

# Number of files found (displayed as the cell output).
len(all_file_paths)

In [None]:
def load_cached(file_paths, cache_path, basis, guess="minao", method="dft", functional="b3lypg", backend="pyscf"):
    """Collect Fock matrices and reference densities for a list of files.

    Calls ``density_fock_overlap`` once per file. Files for which the call
    raises, returns ``None`` or returns a tuple containing ``None`` are
    skipped and only reported on stdout.

    Args:
        file_paths: Iterable of molecule file paths.
        cache_path: Forwarded to ``density_fock_overlap`` as ``cache``.
        basis: Currently NOT forwarded; ``basis=None`` is passed on instead.
            NOTE(review): confirm whether ignoring this argument is intended.
        guess: Forwarded as ``guess``.
        method: Forwarded as ``method``.
        functional: Forwarded as ``functional``.
        backend: Forwarded as ``backend``.

    Returns:
        Tuple ``(focks, reference_densities, used_files)`` of three lists of
        equal length; the first two hold the ``.numpy`` attribute of the
        second and first returned item, respectively.
    """
    error_list = []
    error_files = []
    focks = []
    used_files = []
    reference_densities = []
    for file in file_paths:
        mol_name = os.path.basename(file).strip()
        try:
            ret = density_fock_overlap(filepath = file,
                                filename = mol_name,
                                method = method,
                                basis = None,
                                functional = functional,
                                guess = guess,
                                backend = backend,
                                cache = cache_path)
        except Exception as e:
            # Best effort: remember the failure and continue with the next file.
            error_list.append(e)
            error_files.append(mol_name)
            print(f"File {mol_name} error - skipping")
            continue
        # Identity checks avoid calling a possibly overloaded __eq__ on the
        # returned objects; a bare None result is treated as a bad file too.
        if ret is None or any(r is None for r in ret):
            print(f"File {mol_name} bad - skipping")
            continue
        focks.append(ret[1].numpy)
        used_files.append(file)
        reference_densities.append(ret[0].numpy)
    print(f"Got data for: {len(focks)} - bad / no ret: {len(file_paths) - len(focks) - len(error_list)} - errors: {len(error_list)}")
    print(error_files[:5])
    return focks, reference_densities, used_files


In [None]:
# Load the cached c7h10o2 results; `ret` is the tuple
# (focks, reference_densities, used_files) returned by load_cached.
# NOTE(review): load_cached does not forward `basis` (it passes basis=None on),
# so the basis file name given here has no effect - confirm.
ret = load_cached(all_file_paths, "../../datasets/QM9/out/c7h10o2_b3lypg_6-31G(2df,p)/pyscf", basis="6-31g_2df_p_custom_nwchem.gbs")