In [149]:
import json
import numpy as np
from scipy.spatial import ConvexHull
from sklearn.linear_model import RidgeCV
import matplotlib.pyplot as plt




In [150]:
#necessary functions
def regroup_query_by_config_property(casm_query_json_data: list) -> dict:
    """Groups CASM query data by property instead of by configuration.

    Parameters
    ----------
    casm_query_json_data: list
        List of dictionaries read from casm query json file.

    Returns
    -------
    results: dict
        Dictionary of all data grouped by keys (not grouped by configuraton)

    Notes
    ------
    Casm query jsons are lists of dictionaries; each dictionary corresponds to a configuration.
    This function assumes that all dictionaries have the same keys.
    It sorts all properties by those keys instead of by configuration.
    Properties that are a single value or string are passed as a list of those properties.
    Properties that are arrays are passed as a list of lists (2D matrices) even if the
    property only has one value (a matrix of one column).
    """
    data = casm_query_json_data
    keys = data[0].keys()
    data_collect = []
    for i in range(len(keys)):
        data_collect.append([])

    for element_dict in data:
        for index, key in enumerate(keys):
            data_collect[index].append(element_dict[key])

    results = dict(zip(keys, data_collect))

    if "comp" in results.keys():
        # Enforce that composition is always rank 2.
        results["comp"] = np.array(results["comp"])
        if len(results["comp"].shape) > 2:
            results["comp"] = np.squeeze(results["comp"])
        if len(results["comp"].shape) == 1:
            results["comp"] = np.reshape(results["comp"], (-1, 1))
        results["comp"] = results["comp"].tolist()

    if "corr" in results.keys():
        # Remove redundant dimensions in correlation matrix.
        results["corr"] = np.squeeze(results["corr"]).tolist()
    return results

def calculate_slopes(x_coords: np.ndarray, y_coords: np.ndarray):
    """Calculates the slope for each line segment in a series of connected points.
    
    Parameters:
    -----------
    x_coords: np.ndarray
        Array of x coordinates.
    y_coords: np.ndarray
        Array of y coordinates.
    
    Returns:
    --------
    slopes: np.ndarray
        Array of slopes.
    """
    
    #sort x_coords and y_coords by x_coords
    x_coords, y_coords = zip(*sorted(zip(x_coords, y_coords)))
 
    slopes = np.zeros(len(x_coords) - 1)
    for i in range(len(x_coords) - 1):
        slopes[i] = (y_coords[i + 1] - y_coords[i]) / (x_coords[i + 1] - x_coords[i])
    return slopes

def full_hull(
    compositions: np.ndarray, energies: np.ndarray, qhull_options=None
) -> ConvexHull:
    """Returns the full convex hull of the points specified by appending `energies` to `compositions`.

    Parameters
    ----------
    compositions: np.ndarray of floats, shape (n_points, n_composition_axes)
        Compositions of points.
    energies: np.ndarray of floats, shape (n_points,)
        Energies of points.
    qhull_options: str
        Additional optionals that can be passed to Qhull. See details on the scipy.spatial.ConvexHull documentation. Default=None
    Returns
    -------
    ConvexHull
        Convex hull of points.
    """
    return ConvexHull(
        np.hstack((compositions, energies[:, np.newaxis])), qhull_options=qhull_options
    )

def lower_hull(
    convex_hull: ConvexHull, tolerance: float = 1e-14
):
    """Returns the vertices and simplices of the lower convex hull (with respect to the last coordinate) of `convex_hull`.

    Parameters
    ----------
    convex_hull : ConvexHull
        Complete convex hull object.
    tolerance : float, optional
        Tolerance for identifying lower hull simplices (default is 1e-14).

    Returns
    -------
    lower_hull_vertex_indices : np.ndarray of ints, shape (n_vertices,)
        Indices of points forming the vertices of the lower convex hull.
    lower_hull_simplex_indices : np.ndarray of ints, shape (n_simplices,)
        Indices of simplices (within `convex_hull.simplices`) forming the facets of the lower convex hull.
    """
    # Find lower hull simplices
    lower_hull_simplex_indices = (-convex_hull.equations[:, -2] > tolerance).nonzero()[
        0
    ]
    if lower_hull_simplex_indices.size == 0:
        raise RuntimeError("No lower hull simplices found.")

    # Gather lower hull vertices from simplices
    lower_hull_vertex_indices = np.unique(
        np.ravel(convex_hull.simplices[lower_hull_simplex_indices])
    )
    return lower_hull_vertex_indices, lower_hull_simplex_indices

def ground_state_accuracy_metric(
    composition_predicted, energy_predicted, true_ground_state_indices
) -> float:
    """Computes a scalar ground state accuracy metric. The metric varies between [0,1], where 1 is perfect accuracy. The metric is a fraction. 
        The denominator is the sum across the stable chemical potential windows (slopes) for each configuration predicted on the convex hull.
        The numerator is the sum across the stable chemical potential windows (slopes) for each configuration predicted on the convex hull, which are ALSO ground states in DFT data.

    Parameters
    ----------
    composition_predicted : np.ndarray
        nxm matrix of compositions, where n is the number of configurations and m is the number of composition axes.
    energy_predicted : np.ndarray
        nx1 matrix of predicted formation energies.
    true_ground_state_indices : np.ndarray
        nx1 matrix of true ground state indices.

    Returns
    -------
    float
        Ground state accuracy metric.
    """
    hull = full_hull(
        compositions=composition_predicted, energies=energy_predicted
    )
    vertices, _ = lower_hull(hull)
    
    slopes = calculate_slopes(
        composition_predicted[vertices], energy_predicted[vertices]
    )
    stable_chem_pot_windows = [
        slopes[i + 1] - slopes[i] for i in range(len(slopes) - 1)
    ]

    # End states will always be on the convex hull and have an infinite stable chemical potential window. Exclude these from the
    vertices = np.sort(vertices)[2:]

    vertex_indices_ordered_by_comp = np.argsort(np.ravel(composition_predicted[vertices]))

    numerator = 0
    for vertex_index in vertex_indices_ordered_by_comp:
        if vertices[vertex_index] in true_ground_state_indices:
            numerator += stable_chem_pot_windows[vertex_index]

    return numerator / np.sum(stable_chem_pot_windows)




In [151]:
#Load data
with open('ZrN_FCC_1.2.0_8_body_10-5-2022.json') as f:
    query = json.load(f)
    data = regroup_query_by_config_property(query)
corr = np.array(data['corr'])
comp = np.array(data['comp'])
formation_energy = np.array(data['formation_energy'])
name = np.array(data['name'])

In [152]:
#Find the "true" ground states as predicted by the DFT data
dft_hull = full_hull(comp, formation_energy)
dft_vertices, dft_simplices = lower_hull(dft_hull)
print(dft_vertices)

[  0   1  27  28 696 809 815 829]


In [153]:
#Run a ridgeCV fit, get optimal regularizer
coarse_fit_object = RidgeCV(alphas = np.logspace(-5, 0, 100), fit_intercept=False)
coarse_fit = coarse_fit_object.fit(corr, formation_energy)
print(coarse_fit.alpha_)
fine_fit_object = RidgeCV(alphas = np.linspace(0.0001, 0.01, 10000), fit_intercept=False)
fine_fit = fine_fit_object.fit(corr, formation_energy)
print(fine_fit.alpha_)
predicted_energies = fine_fit.predict(corr)

0.001047615752789665
0.000991089108910891


In [166]:
#Find the convex hull of the predicted energies
hull = full_hull(
    compositions=comp, energies=predicted_energies
)

#Find the vertex indices of the lower convex hull
vertices, _ = lower_hull(hull)

#Find the slopes of the lower hull simplices
slopes = calculate_slopes(
    comp[vertices], predicted_energies[vertices]
)

#Find the stable chemical potential windows for each point on the lower hull, excluding the end states
stable_chem_pot_windows = [
    slopes[i + 1] - slopes[i] for i in range(len(slopes) - 1)
]

#Compute the ground state accuracy metric
gsa = ground_state_accuracy_metric(composition_predicted=comp, energy_predicted=predicted_energies, true_ground_state_indices=dft_vertices)
print(gsa)



0.6961673783430177


In [168]:
print(dft_vertices)
print(np.setdiff1d(dft_vertices, vertices))

[  0   1  27  28 696 809 815 829]
[ 28 696 815 829]
[ 89 197 802]


In [None]:
#See above: The fit only predicted 4 of the 8 DFT ground states. This would usually lead to an accuracy metric of 1/2. However, including the stable chemical potential window weights some structures more than others. 
