<a href="https://colab.research.google.com/github/kkk020719/DFT_Fall24/blob/main/GradDFT_TrainEx1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install Grad DFT directly from its GitHub repository.
!pip install git+https://github.com/XanaduAI/GradDFT.git
# PyTorch Geometric is installed for its MD17 dataset module, which is imported further down.
!pip install torch_geometric
# Pin jax and jaxlib to the same release (0.4.20).
!pip install jax==0.4.20 jaxlib==0.4.20
# List the installed packages whose names contain "jax".
!pip list | grep jax

In [None]:
#import os

# Mount Google Drive
#from google.colab import drive
#drive.mount('/content/drive')

# Create a folder in the root directory
#!mkdir -p "/content/drive/My Drive/Ethanol_data_MD17"

In [None]:
from torch_geometric.datasets import md17
import jax
from jax import numpy as jnp
from jax import grad, jit, vmap
from jax import random
import math
import pdb
import numpy as np

import grad_dft as gd

from grad_dft.functional import canonicalize_inputs, dm21_coefficient_inputs, dm21_densities

from grad_dft import (
	energy_predictor,
	simple_energy_loss,
	NeuralFunctional,
	molecule_from_pyscf,
  Functional,
	constraints
)

from pyscf import gto, dft

from jax.nn import sigmoid, gelu
from jax.random import PRNGKey
from jax import value_and_grad
from flax import linen as nn
from optax import adam, apply_updates
from tqdm import tqdm
from jax.flatten_util import ravel_pytree
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

# Training using GradDFT and direct optimization

In this notebook example, we run through the different strategies that can be used to train a neural functional for the exchange-correlation functional in Grad DFT. Using molecules such as ethanol as training molecules, we study how well a neural functional generalizes when predicting the total energy. More generally, the workflow is:


1.   Generate molecule train/test dataset and acquire required properties.
2.   Use the GradDFT library to train the neural network that predicts the coefficients weighting the energy densities (we use LDA energy densities); integrated together, these form the exchange-correlation functional.
3.  After training, we test the learned functional to try to predict the energy for the test conformations using SCF or direct optimization. For example, by directly minimizing
$$
\min_{C} E(C)
$$
subject to the orthonormality constraint of the molecular orbitals
$$
C^T S C = I.
$$


## Generate Training/Testing Data

The data were acquired from the MD17 dataset. According to its documentation, the energies of the molecules in MD17 were calculated at the PBE/def2-SVP level of theory using very tight SCF convergence and a very dense DFT integration grid. We provide functions below to acquire the data and return the corresponding PySCF objects for future use.

In [None]:
# Load the 'ethanol' subset of MD17 through PyTorch Geometric, using the current directory as the dataset root.
ethanol_dataset = md17.MD17(root='.', name='ethanol')

In [5]:
def obtain_data(num_data, data_set, atomic_symbols, basis='def2-svp', rng=None):
  """Randomly draw samples from a dataset and build a PySCF molecule for each.

  Args:
    num_data: Number of samples to draw; sampling is without replacement.
    data_set: Indexable collection whose samples expose `pos` (per-atom
      coordinates), `z` (per-atom atomic numbers) and `energy`.
    atomic_symbols: Mapping from atomic number (int) to element symbol (str).
    basis: Basis set name forwarded to `gto.M`. Defaults to 'def2-svp'.
    rng: Optional object with a NumPy-style `choice` method, e.g.
      `np.random.default_rng(seed)`, for reproducible sampling. When None,
      the global `np.random` state is used.

  Returns:
    A tuple `(molecules, ground_energies)` of two lists in matching order:
    the PySCF molecule objects and the sample energies divided by 627.509.

  Raises:
    ValueError: If a sample contains an atomic number that is missing from
      `atomic_symbols`.
  """
  # 627.509 is the number of kcal/mol per Hartree; dividing by it presumably
  # converts the dataset energies from kcal/mol to Hartree.
  kcal_per_hartree = 627.509

  # Randomly select the sample indices (no index is drawn twice).
  sampler = np.random if rng is None else rng
  indices = sampler.choice(len(data_set), num_data, replace=False)

  molecules = []
  ground_energies = []
  for index in indices:
    sample = data_set[index]

    # Build the atom specification, one "symbol x y z" line per atom.
    atom_lines = []
    for z, pos in zip(sample.z, sample.pos):
      atomic_number = z.item()
      if atomic_number not in atomic_symbols:
        raise ValueError(f"Unrecognized atomic number {atomic_number} found, unable to proceed with molecule creation.")
      symbol = atomic_symbols[atomic_number]
      atom_lines.append(f'{symbol} {pos[0]:.4f} {pos[1]:.4f} {pos[2]:.4f}\n')

    # Closed-shell molecule (spin=0) with coordinates given in Angstrom.
    mol = gto.M(atom=''.join(atom_lines), basis=basis, unit='Ang', spin=0, verbose=0)
    molecules.append(mol)
    ground_energies.append(sample.energy.item() / kcal_per_hartree)
  return molecules, ground_energies

# Load PySCF molecules from MD17.
# Atomic numbers present in ethanol and their element symbols.
ethanol_atomic_symbols = {1: 'H', 6: 'C', 8: 'O'}

# Draw all conformations in a single call (sampling is without replacement) and
# then split, so that no conformation can appear in both the training and the
# test set. Two independent calls could select the same sample twice.
N_TRAIN, N_TEST = 8, 2
sampled_molecules, sampled_energies = obtain_data(N_TRAIN + N_TEST, ethanol_dataset, ethanol_atomic_symbols)
train_data, ground_energy_train = sampled_molecules[:N_TRAIN], sampled_energies[:N_TRAIN]
test_data, ground_energy_test = sampled_molecules[N_TRAIN:], sampled_energies[N_TRAIN:]

[-154.88805489243978, -154.89022119603064, -154.89957116152917, -154.89427990275837, -154.88469338686775]


## Training Phase

In [None]:
def energy_densities(molecule):
    """Return the LDA-style energy density built from the molecule's density.

    Computes -3/2 * (3 / (4*pi))**(1/3) * sum(rho**(4/3)), where the sum runs
    over axis 1 of `molecule.density()` and that axis is kept with length 1.
    """
    density = molecule.density()
    prefactor = -3/2 * (3/(4*jnp.pi))**(1/3)
    summed = (density**(4/3)).sum(axis=1, keepdims=True)
    return prefactor * summed

def coefficient_inputs(molecule):
    """Return the inputs of the coefficient network for one molecule.

    The molecule's density and kinetic density are concatenated, in that
    order, along axis 1.
    """
    features = (molecule.density(), molecule.kinetic_density())
    return jnp.concatenate(features, axis=1)

def coefficients(self, rhoinputs):
    """Coefficient network: Dense layer (2 features) -> LayerNorm -> GELU.

    `rhoinputs` is fed to the dense layer; the GELU-activated, layer-normalised
    result is returned.
    """
    projected = nn.Dense(features=2)(rhoinputs)  # two output features per input row
    normalized = nn.LayerNorm()(projected)
    return gelu(normalized)  # GELU is the output activation

# Assemble the neural functional from the three callables defined above:
# the coefficient network, the energy densities and the coefficient inputs.
neuralfunctional = NeuralFunctional(coefficients, energy_densities, coefficient_inputs)

In [None]:
# Fixed seed for the JAX PRNG key.
seed = 40
key = jax.random.PRNGKey(seed)
# NOTE(review): `params_train` starts empty and is not used in the cells shown here.
params_train = []
# Coefficient-network inputs, one entry per training molecule.
# NOTE(review): `train_molecules` is not defined in the cells shown here (the
# sampling cell defines `train_data`); confirm where it is created, otherwise
# this line raises NameError on a fresh kernel.
cinputs = [coefficient_inputs(molecule) for molecule in train_molecules]
#predicted_energy = neuralfunctional.energy(params, molecule)

In [None]:
# Define optimizer
n_epochs, learning_rate, momentum = 200, 1e-2, 0.9
# `momentum` is passed to Adam as its `b1` argument.
optimizer = adam(learning_rate=learning_rate, b1=momentum)
# NOTE(review): `params` is not defined in the cells shown here; confirm that the
# functional's parameters are initialised before this cell runs.
opt_state = optimizer.init(params)
compute_energy = energy_predictor(neuralfunctional)
# NOTE(review): `mf` is not defined in the cells shown here; presumably a PySCF
# mean-field object for a single molecule. Verify it matches the training data.
trueenergy = mf.energy_tot()
# SEloss appends the predicted energy to this list on every call.
converging_energies = []

@value_and_grad
def SEloss(params, compute_energy, molecule, trueenergy):
    """Squared error between the predicted and the reference energy.

    Because of the `value_and_grad` decorator, calling `SEloss(...)` returns the
    loss together with its gradient with respect to `params` (the first argument).

    `compute_energy(params, molecule)` is unpacked as a pair: the predicted
    energy and a second value the original code names the Fock matrix, which
    is not used here.

    Side effect: the predicted energy is appended to the module-level list
    `converging_energies` on every call.
    NOTE(review): the appended value is created inside a differentiated
    function, so it may be a JAX tracer rather than a plain number. Confirm
    that downstream use of the list handles this.
    """
    energy, _fock = compute_energy(params, molecule)
    converging_energies.append(energy)
    residual = energy - trueenergy
    return residual ** 2  # squared error

## Testing Phase
