<a href="https://colab.research.google.com/github/kkk020719/DFT_Fall24/blob/main/GradDFT_TrainEx1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/XanaduAI/GradDFT.git
!pip install jax==0.4.20 jaxlib==0.4.20
!pip list | grep jax

# Training using GradDFT and direct optimization

In this notebook, we walk through the different strategies that can be used to train a neural exchange-correlation functional in Grad DFT. Using molecules such as ethanol as training data, we study how well a neural functional generalizes when calculating total energies.



In [None]:
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap
from jax import random
import math
import pdb

from grad_dft import (
	energy_predictor,
	simple_energy_loss,
	NeuralFunctional,
	molecule_from_pyscf,
  Functional
)

from pyscf import gto, dft

from jax.nn import sigmoid, gelu
from jax.random import PRNGKey
from jax import value_and_grad
from flax import linen as nn
from optax import adam, apply_updates
from tqdm import tqdm
from jax.flatten_util import ravel_pytree
import matplotlib.pyplot as plt

from grad_dft import constraints

from sklearn.model_selection import train_test_split

## Generating training/validation data and initial guesses



In [None]:
# Module-level containers. The two molecule lists are reassigned by the
# train/test split cell; ground_energies is the list that the split cell
# partitions into training and testing energies.
train_molecules, test_molecules, ground_energies = [], [], []

def generate_ethanol_variants(n_molecules, noise_scale=0.05, basis="sto-3g"):
    """Build `n_molecules` randomly perturbed copies of the base geometry.

    Each variant displaces every atom of the base geometry by Gaussian noise,
    runs an unrestricted Kohn-Sham calculation with PySCF, and converts the
    converged calculation into a Grad DFT molecule.

    Side effect: the converged total energy of every variant is appended to
    the module-level ``ground_energies`` list, in the same order as the
    returned molecules, so the two sequences stay index-aligned.

    Args:
        n_molecules: number of variants to generate; values <= 0 yield ``[]``.
        noise_scale: standard deviation of the per-coordinate displacement,
            in the units of ``base_coords``.
            NOTE(review): the magnitude is an assumption - tune as needed.
        basis: basis-set name forwarded to ``gto.M``.
            NOTE(review): the default is a small basis chosen for speed.

    Returns:
        A list with one Grad DFT molecule per generated variant.
    """
    seed = 40
    base_coords = [
        "C 0.0000 0.0000 0.0000", "C 0.0000 0.0000 1.5400",
        "O 0.0000 0.0000 -1.4300", "H 0.0000 0.9300 -1.9300",
        "H 0.0000 -0.9300 -1.9300", "H 0.0000 0.9300 2.0700",
        "H 0.0000 -0.9300 2.0700", "H 0.9200 0.0000 1.9300",
        "H -0.9200 0.0000 1.9300"
    ]

    # Split every "symbol x y z" entry into its element and its coordinates.
    symbols = [entry.split()[0] for entry in base_coords]
    positions = jnp.array(
        [[float(value) for value in entry.split()[1:]] for entry in base_coords]
    )

    molecules = []
    base_key = jax.random.PRNGKey(seed)
    for index in range(n_molecules):
        # fold_in derives an independent, reproducible key for each variant.
        variant_key = jax.random.fold_in(base_key, index)
        displacement = noise_scale * jax.random.normal(variant_key, positions.shape)
        displaced = (positions + displacement).tolist()

        mol = gto.M(
            atom=[(symbol, tuple(xyz)) for symbol, xyz in zip(symbols, displaced)],
            basis=basis,
        )
        mf = dft.UKS(mol)
        # kernel() runs the SCF cycle and returns the converged total energy.
        ground_energies.append(mf.kernel())
        molecules.append(molecule_from_pyscf(mf))

    return molecules

# Build the dataset: n_total is the total number of ethanol variants requested.
n_total = 5
ethanol_molecules = generate_ethanol_variants(n_molecules=n_total)

In [None]:
# Split molecules and their reference energies into training and testing sets.
# Both sequences go through ONE train_test_split call, so every molecule stays
# paired with its own energy and a length mismatch between the two inputs is
# reported immediately instead of silently misaligning the data.
(
    train_molecules,
    test_molecules,
    train_ground_energies,
    test_ground_energies,
) = train_test_split(ethanol_molecules, ground_energies, test_size=0.20, random_state=42)

# Print the sizes of the resulting sets
print(f"Training set size: {len(train_molecules)}")
print(f"Testing set size: {len(test_molecules)}")

# Neural functional and parameter initialization

In [None]:
def energy_densities(molecule):
    """Return the energy-density column built from the molecule's density.

    The density returned by ``molecule.density()`` is raised to the power 4/3,
    summed over axis 1 (the summed axis is kept with size 1), and scaled by
    the constant -3/2 * (3 / (4*pi))**(1/3).
    NOTE(review): this looks like an LDA exchange energy density with axis 1
    as the spin axis - confirm against the Grad DFT conventions.
    """
    spin_densities = molecule.density()
    prefactor = -3 / 2 * (3 / (4 * jnp.pi)) ** (1 / 3)
    summed_powers = jnp.sum(spin_densities ** (4 / 3), axis=1, keepdims=True)
    return prefactor * summed_powers

def coefficient_inputs(molecule):
    """Assemble the network inputs for one molecule.

    Joins ``molecule.density()`` and ``molecule.kinetic_density()`` along
    axis 1, density columns first.
    """
    features = [molecule.density(), molecule.kinetic_density()]
    return jnp.concatenate(features, axis=1)

def coefficients(self, rhoinputs):
    """Map the coefficient inputs to the functional's coefficients.

    Applies a dense layer with two output features, then layer normalisation,
    then a GELU activation.

    Args:
        self: first positional argument; not used in the body.
            NOTE(review): presumably the module instance supplied by
            NeuralFunctional - confirm.
        rhoinputs: array passed to the dense layer.

    Returns:
        The GELU-activated, layer-normalised dense output.

    NOTE(review): this produces two features while ``energy_densities``
    returns a single column - confirm the two are meant to be combined.
    """
    projected = nn.Dense(features=2)(rhoinputs)  # two output features per input row
    normalised = nn.LayerNorm()(projected)
    return gelu(normalised)  # GELU is the activation function

# Assemble the neural functional from the three callables defined above, passed
# positionally in the order (coefficients, energy_densities, coefficient_inputs).
neuralfunctional = NeuralFunctional(coefficients, energy_densities, coefficient_inputs)

In [None]:
# Initialise one parameter set per training molecule. Every call to
# neuralfunctional.init receives the same PRNG key together with that
# molecule's coefficient inputs.
seed = 40
key = jax.random.PRNGKey(seed)
cinputs = [coefficient_inputs(molecule) for molecule in train_molecules]
params_train = [neuralfunctional.init(key, cinput) for cinput in cinputs]

# Training

In [None]:
# Define optimizer
n_epochs, learning_rate, momentum = 200, 1e-2, 0.9
optimizer = adam(learning_rate=learning_rate, b1=momentum)

# `params` was never defined in this notebook, so bind it from the parameter
# sets built in the initialisation cell (they were all created from one key).
# NOTE(review): confirm that a single shared parameter set is what is intended.
params = params_train[0]
opt_state = optimizer.init(params)
compute_energy = energy_predictor(neuralfunctional)

# `mf` was never defined at module level either, so take the reference energy
# from the training split instead.
# NOTE(review): this is the energy paired with train_molecules[0]; a loop over
# several molecules should pass the matching entry of train_ground_energies.
trueenergy = train_ground_energies[0]

# Predicted energies recorded by SEloss, one entry per evaluation.
converging_energies = []

@value_and_grad
def SEloss(params, compute_energy, molecule, trueenergy):
    """Squared error between the predicted and the reference energy.

    Because of the ``value_and_grad`` decorator, calling ``SEloss(...)``
    returns ``(loss, grads)`` where ``grads`` is the gradient of the loss with
    respect to ``params`` (the first argument).

    Args:
        params: parameters forwarded to ``compute_energy``; differentiated.
        compute_energy: callable ``(params, molecule) -> (energy, fock)``.
        molecule: molecule forwarded to ``compute_energy``.
        trueenergy: reference energy to compare the prediction against.

    Side effect: the predicted energy of every evaluation is appended to the
    module-level ``converging_energies`` list.
    """
    # compute_energy returns two values; only the energy enters the loss.
    predictedenergy, _fock_matrix = compute_energy(params, molecule)

    def _record(energy):
        converging_energies.append(energy)

    # Inside value_and_grad the energy is a tracer; appending it directly would
    # store a leaked tracer that cannot be read once tracing has finished.
    # jax.debug.callback hands the concrete value to the Python side instead.
    jax.debug.callback(_record, predictedenergy)
    return (predictedenergy - trueenergy) ** 2  # squared error

# Testing
