In [None]:
# Standard Library Imports
import time
import colorsys
from itertools import combinations

# Scientific Computing
import numpy as np
import sympy as sp
from scipy.spatial import ConvexHull

# Plotting Libraries
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
import plotly.graph_objects as go

# PyCalphad (Thermodynamics Calculations & Plotting)
from pycalphad import Database, calculate, equilibrium, variables as v
from pycalphad.plot.utils import phase_legend
from pycalphad import ternplot

# Computational Geometry
from shapely.geometry import Polygon, MultiPolygon
from shapely.ops import unary_union

# Color Processing
from skimage.color import deltaE_ciede2000, rgb2lab


In [None]:
# Load the thermodynamic database; every phase it defines will be plotted, and
# phase_legend assigns each phase its own colour.
db = Database(r'../TDDatabaseFiles_temp/alfe.tdb')

phases = [name for name in db.phases]
constituents = [element for element in db.elements]
legend_handles, color_dict = phase_legend(phases)

print(phases)
print(constituents)

In [None]:
# Calculate the enthalpy of every phase as a function of entropy and composition
def format_enthalpy(entropy_result, enthalpy_result):
    """Flatten pycalphad ``calculate`` results into matching 1-D (X, S, H) arrays.

    Parameters:
        entropy_result: result of ``calculate(..., output="SM")``; its ``X`` variable
            must have a ``component`` coordinate containing ``'FE'``.
        enthalpy_result: result of ``calculate(..., output="HM")`` computed on the
            same conditions as ``entropy_result``.

    Returns:
        (X, S, H): 1-D arrays of Fe mole fraction, ``SM`` and ``HM`` values. All three
        share one ordering: sorted by S, with ties broken by X.
    """
    # [0, 0, :, :] keeps the first entry of the two leading dimensions and all of the
    # last two. Presumably the leading ones are fixed conditions and the last two are
    # (temperature, sample point) -- TODO confirm for the pycalphad version in use.
    X = entropy_result.X.sel(component='FE').values[0, 0, :, :].ravel()
    S = entropy_result.SM.values[0, 0, :, :].ravel()
    H = enthalpy_result.HM.values[0, 0, :, :].ravel()

    # One lexsort gives a true (S, then X) ordering; np.lexsort treats the LAST key
    # as the primary one. Applying the same permutation to all three arrays keeps
    # each (X, S, H) triple together.
    order = np.lexsort((X, S))
    return X[order], S[order], H[order]

# Only a modest number of temperature points is sampled because the 3-D plots
# below struggle with very dense point clouds.
temp_points_count = 40
temperatures = np.linspace(300, 2000, temp_points_count)

enthalpy_phase_dict = dict()
for phase_name in phases:
    # SM and HM are requested in two separate calls under identical conditions.
    # format_enthalpy pairs their rows element-wise, which assumes both calls
    # sample the same internal points -- TODO confirm for the pycalphad version in use.
    entropy_result = calculate(db, constituents, phase_name, P=101325, T=temperatures, output="SM")
    enthalpy_result = calculate(db, constituents, phase_name, P=101325, T=temperatures, output="HM")

    X, S, H = format_enthalpy(entropy_result, enthalpy_result)
    enthalpy_phase_dict[phase_name] = (X, S, H)

In [None]:
def lower_convex_hull(points):
    '''
    Calculate the lower convex hull, assuming the last dimension represents energy.

    Method: every vertex of the ordinary convex hull gets a copy lifted upward in
    energy. In the hull of this augmented set, upward-facing facets consist only of
    lifted copies, and the vertical side walls generally include lifted copies too.
    Keeping the facets made exclusively of original points therefore leaves the
    lower hull.

    When the non-energy coordinates are collinear, they are replaced by the 1-D
    coordinate along that line and the hull is computed in 2-D. This is always the
    case for 2-D input, and also e.g. for a phase sampled at a single composition.

    Parameters:
        points (array): (n_points, n_dims) points, n_dims >= 2, with the last
            column representing energy.

    Returns:
        lower_hull (array): integer array of indices into ``points``. Each row is one
            facet of the lower convex hull: n_dims indices per row, or 2 in the
            collinear case.
    '''
    points = np.asarray(points, dtype=float)
    n_points = len(points)

    projected = points[:, :-1]
    centred = projected - projected[0]

    if np.linalg.matrix_rank(centred) == 1:
        # Project onto the principal direction of the line. Unlike dropping a fixed
        # column, this works for any line orientation, including one parallel to
        # the first axis and the 1-D projection of 2-D input.
        direction = np.linalg.svd(centred, full_matrices=False)[2][0]
        line_coord = centred @ direction
        processing_points = np.column_stack((line_coord, points[:, -1]))
    else:
        processing_points = points

    # Lift every (unique) hull vertex, not just a subset, so the argument above holds
    # for any positive offset.
    boundary = ConvexHull(processing_points).vertices

    # Scale the lift to the energy spread to keep the augmented set well conditioned.
    energy_offset = max(float(np.ptp(processing_points[:, -1])), 1.0)
    lifted = processing_points[boundary]  # fancy indexing returns a copy
    lifted[:, -1] += energy_offset

    hull = ConvexHull(np.vstack((processing_points, lifted)))

    # Indices >= n_points refer to lifted helper points; drop facets touching them.
    mask = np.all(hull.simplices < n_points, axis=1)
    return hull.simplices[mask]

In [None]:
# For each phase keep only the points lying on its lower convex hull; these
# equilibrium enthalpy points are what the polynomials below are fitted to.
eq_enthalpy_phase_dict = {}
for phase in phases:
    print(phase)
    phase_points = np.column_stack(enthalpy_phase_dict[phase])

    hull_facets = lower_convex_hull(phase_points)
    on_hull = phase_points[np.unique(hull_facets.ravel())]

    eq_enthalpy_phase_dict[phase] = tuple(on_hull.T)

In [None]:
# Exponent pairs (i, j) of the x**i * y**j terms, in coefficient order
# a00, a10, a01, a20, a11, a02, a21, a12, a22.
_X2Y2_EXPONENTS = ((0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (0, 2), (2, 1), (1, 2), (2, 2))


def fit_poly_x2y2_sympy(x_data, y_data, z_data):
    """
    Fit a polynomial of the form:
      f(x, y) = a00 + a10*x + a01*y + a20*x^2 + a11*x*y + a02*y^2 +
                a21*x^2*y + a12*x*y^2 + a22*x^2*y^2
    to the given data using least squares, and return it as a Sympy expression.

    Parameters:
        x_data (array_like): 1-D sample x coordinates.
        y_data (array_like): 1-D sample y coordinates, same length as x_data.
        z_data (array_like): 1-D values to fit, same length as x_data.

    Returns:
        expr (sympy.Expr): the fitted polynomial in the symbols ``x`` and ``y``,
            passed through ``sympy.simplify``.
        residuals (np.ndarray): sum of squared residuals as returned by
            ``np.linalg.lstsq``. This is an empty array when the design matrix is
            rank deficient or there are no more samples than the 9 coefficients.
    """
    x_data = np.asarray(x_data, dtype=float)
    y_data = np.asarray(y_data, dtype=float)
    z_data = np.asarray(z_data, dtype=float)

    # Design matrix: one column per term, in the order of _X2Y2_EXPONENTS.
    A = np.column_stack([x_data**i * y_data**j for i, j in _X2Y2_EXPONENTS])

    coeffs, residuals, _rank, _singular_values = np.linalg.lstsq(A, z_data, rcond=None)

    # Build the expression from the same exponent table, so the symbolic terms
    # cannot drift out of sync with the design-matrix columns.
    x, y = sp.symbols('x y')
    expr = sum(sp.Float(coeff) * x**i * y**j for coeff, (i, j) in zip(coeffs, _X2Y2_EXPONENTS))

    return sp.simplify(expr), residuals

In [None]:
# Equilibrium enthalpy points of the B2_BCC phase together with the polynomial fitted to them.
phase_name = 'B2_BCC'
X, Y, Z = eq_enthalpy_phase_dict[phase_name]
selected_color = color_dict[phase_name]

# Fit the polynomial and show it together with its least-squares residual.
energy_polynomial, res = fit_poly_x2y2_sympy(X, Y, Z)
display(energy_polynomial)
print("residual:", res)

# Evaluate the fit on a regular grid: X(FE) in [0, 1], entropy in [0, 100].
# broadcast_to guards the case where lambdify returns a scalar (constant polynomial).
x_mesh, y_mesh = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 100, 100))
energy_fn = sp.lambdify(sp.symbols('x y'), energy_polynomial, 'numpy')
z_mesh = np.broadcast_to(energy_fn(x_mesh, y_mesh), x_mesh.shape)

fig = go.Figure()
fig.add_trace(go.Scatter3d(
    x=X, y=Y, z=Z,
    mode='markers',
    name=phase_name,
    marker=dict(color=selected_color, size=1)
))
fig.add_trace(go.Surface(
    x=x_mesh,
    y=y_mesh,
    z=z_mesh,
    name=f"{phase_name} fit",
    opacity=0.5,
    showscale=False,
    colorscale='Viridis'
))

fig.update_layout(
    scene=dict(
        xaxis_title="X(FE)",
        yaxis_title="Entropy (J/mol/K)",
        zaxis_title="Enthalpy (J/mol)"
    ),
    title="Equilibrium Enthalpy Points"
)

fig.show()

In [None]:
# Fit and plot the phases whose equilibrium points the polynomial describes well.
POOR_FIT_PHASES = ('AL2FE', 'AL13FE4', 'AL5FE2')

# The symbols and evaluation grid are the same for every phase, so build them once.
x_sym, y_sym = sp.symbols('x y')
x_mesh, y_mesh = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 100, 100))

fig = go.Figure()
phase_poly_dict = dict()
for phase_name in phases:
    if phase_name in POOR_FIT_PHASES:
        continue

    X, Y, Z = eq_enthalpy_phase_dict[phase_name]
    selected_color = color_dict[phase_name]

    energy_polynomial, res = fit_poly_x2y2_sympy(X, Y, Z)
    print(phase_name)
    display(energy_polynomial)
    print("residual:", res)
    phase_poly_dict[phase_name] = energy_polynomial

    # broadcast_to guards the case where lambdify returns a scalar (constant polynomial).
    energy_fn = sp.lambdify((x_sym, y_sym), energy_polynomial, 'numpy')
    z_mesh = np.broadcast_to(energy_fn(x_mesh, y_mesh), x_mesh.shape)

    # Single-colour surface per phase; the trace name identifies it on hover.
    fig.add_trace(go.Surface(
        x=x_mesh,
        y=y_mesh,
        z=z_mesh,
        name=phase_name,
        showscale=False,
        colorscale=[[0, selected_color], [1, selected_color]]
    ))

fig.update_layout(
    scene=dict(
        xaxis_title="X(FE)",
        yaxis_title="Entropy (J/mol/K)",
        zaxis_title="Enthalpy (J/mol)"
    ),
    title="Equilibrium Enthalpy Polynomials"
)

fig.show()

In [None]:
# The same polynomial surfaces as above, with every numeric atom rounded to the
# nearest integer (coefficients that round to 0 drop their term).
x_sym, y_sym = sp.symbols('x y')
x_mesh, y_mesh = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 100, 100))

fig = go.Figure()
# phase_poly_dict holds only the well-fitting phases, in the order of `phases`.
for phase_name, energy_polynomial in phase_poly_dict.items():
    selected_color = color_dict[phase_name]

    # Integer atoms such as exponents are unchanged by rounding; floats become ints.
    rounded_polynomial = energy_polynomial.replace(
        lambda term: term.is_Number, lambda term: int(round(term, 0))
    )

    # broadcast_to guards the case where rounding leaves a constant (scalar) result.
    energy_fn = sp.lambdify((x_sym, y_sym), rounded_polynomial, 'numpy')
    z_mesh = np.broadcast_to(energy_fn(x_mesh, y_mesh), x_mesh.shape)

    fig.add_trace(go.Surface(
        x=x_mesh,
        y=y_mesh,
        z=z_mesh,
        name=phase_name,
        showscale=False,
        colorscale=[[0, selected_color], [1, selected_color]]
    ))

fig.update_layout(
    scene=dict(
        xaxis_title="X(FE)",
        yaxis_title="Entropy (J/mol/K)",
        zaxis_title="Enthalpy (J/mol)"
    ),
    title="Equilibrium Enthalpy Polynomials (Rounded)"
)

fig.show()

In [None]:
# # These are the phases that don't have a good fit
# fig = go.Figure()
# for phase_name in ['AL2FE', 'AL13FE4', 'AL5FE2']:
#     X, Y, Z = eq_enthalpy_phase_dict[phase_name]
#     selected_color = color_dict[phase_name]

#     fig.add_trace(go.Scatter3d(
#             x=X, y=Y, z=Z,
#             mode='markers',
#             name=phase_name,
#             marker=dict(color=selected_color, size=1)
#         ))


#     # Compute the fitted polynomial
#     energy_polynomial, res = fit_poly_x2y2_sympy(X, Y, Z)
#     display(energy_polynomial)
#     print("residual:", res)

#     x_mesh, y_mesh = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 100, 100))
#     z_mesh = sp.lambdify((sp.symbols('x'), sp.symbols('y')), energy_polynomial, 'numpy')(x_mesh, y_mesh)

#     fig.add_trace(go.Surface(
#         x=x_mesh,
#         y=y_mesh,
#         z=z_mesh,
#         showscale=False,
#         colorscale=[[0, selected_color], [1, selected_color]]
#     ))

# fig.update_layout(
#     scene=dict(
#         xaxis_title="X(FE)",
#         yaxis_title="Entropy (J/mol)",
#         zaxis_title="Enthalpy (J/mol)"
#     ),
#     title="Equilibrium Enthalpy Points"
# )

# fig.show()

In [None]:
# Helper functions for colouring and plotting the phase diagram
def hex_to_rgb(hex_color):
    """Convert a hex color string to an RGB tuple (0-255).

    Accepts "#RRGGBB" (extra trailing digits such as an alpha channel are ignored)
    and the CSS shorthand forms "#RGB" / "#RGBA", where each digit is doubled.
    The leading "#" is optional.

    Raises:
        ValueError: if the channel digits are not valid hexadecimal.
    """
    digits = hex_color.lstrip("#")
    # Expand shorthand like "fa0" -> "ffaa00"; these lengths previously failed to parse.
    if len(digits) in (3, 4):
        digits = "".join(ch * 2 for ch in digits)
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))

def rgb_to_hex(rgb_color):
    """Format the first three 0-255 integer channels as an upper-case "#RRGGBB" string."""
    channels = tuple(rgb_color)
    red, green, blue = channels[0], channels[1], channels[2]
    return f"#{red:02X}{green:02X}{blue:02X}"

def generate_distinct_color(existing_colors, num_candidates=1000, rng=None):
    """
    Generate a color that is most distinct from a given set of colors.

    Random RGB candidates are scored by their CIEDE2000 distance to the *closest*
    existing color. The candidate whose closest match is farthest away is returned.

    Parameters:
    - existing_colors: Iterable of colors in hex format (e.g., "#00538A").
    - num_candidates: Number of random color candidates to evaluate.
    - rng: Optional seed or numpy Generator for reproducible output. When None,
      candidates are drawn from the global ``np.random`` state.

    Returns:
    - A distinct color in hex format. If ``existing_colors`` is empty, the first
      random candidate is returned.
    """
    if rng is None:
        candidate_colors = np.random.randint(0, 256, size=(num_candidates, 3))
    else:
        candidate_colors = np.random.default_rng(rng).integers(0, 256, size=(num_candidates, 3))

    existing_rgb = np.array([hex_to_rgb(color) for color in existing_colors], dtype=float).reshape(-1, 3)
    if len(existing_rgb) == 0:
        # Nothing to be distinct from; min() over no distances would fail.
        return rgb_to_hex(tuple(candidate_colors[0]))

    # Convert both sets to LAB in one call each (rgb2lab works on (..., 3) arrays).
    existing_lab = rgb2lab(existing_rgb / 255.0)        # (M, 3)
    candidate_lab = rgb2lab(candidate_colors / 255.0)   # (N, 3)

    # All pairwise CIEDE2000 distances at once: broadcast to (N, M, 3) -> (N, M).
    cand_b, exist_b = np.broadcast_arrays(candidate_lab[:, None, :], existing_lab[None, :, :])
    closest_distance = deltaE_ciede2000(cand_b, exist_b).min(axis=1)

    # Maximize the distance to the closest existing color.
    best_idx = int(np.argmax(closest_distance))
    return rgb_to_hex(tuple(candidate_colors[best_idx]))

def plot_phase_diagram(phase_dict, color_dict):
    """
    Plot the 2D projection of the lower convex hull formed by several phases' points.

    All points are pooled, their lower convex hull is computed with
    ``lower_convex_hull``, and every hull facet is projected onto the first two
    coordinates. Each facet is labelled with the sorted, "-"-joined names of the
    phases its vertices belong to. This gives a single name for a one-phase facet,
    or e.g. "A-B" for a facet spanning phases A and B. Facets sharing a label are
    merged and filled in one colour.

    Parameters
    ----------
    phase_dict : dict
        Maps phase name -> sequence whose items 0, 1 and 2 are arrays of the first
        coordinate, the second coordinate and the energy. Arrays may be any shape;
        each is flattened, and the three must have the same number of elements.
        In this notebook they are (X(FE), entropy, enthalpy).

    color_dict : dict
        Maps facet labels (phase names or "-"-joined coexistence labels) to
        matplotlib colours. Labels missing from it are drawn in gray.

    Returns
    -------
    None
        Displays a matplotlib plot.
    """
    # Pool every phase's points into one (n, 3) array, remembering each row's phase.
    point_blocks = []
    phase_labels = []
    for phase, coords in phase_dict.items():
        block = np.column_stack((np.ravel(coords[0]), np.ravel(coords[1]), np.ravel(coords[2])))
        point_blocks.append(block)
        phase_labels.extend([phase] * len(block))

    all_points = np.vstack(point_blocks)
    phase_labels = np.array(phase_labels)

    simplices = lower_convex_hull(all_points)

    # Project each facet onto the first two coordinates and group facets by label.
    phase_polygons = {}
    for simplex in simplices:
        label = "-".join(sorted(set(phase_labels[simplex])))
        projected_polygon = Polygon(all_points[simplex][:, :2])
        # Skip projections shapely reports as invalid (presumably degenerate,
        # zero-area triangles).
        if projected_polygon.is_valid:
            phase_polygons.setdefault(label, []).append(projected_polygon)

    fig, ax = plt.subplots(figsize=(8, 6))

    for label, pieces in phase_polygons.items():
        merged = unary_union(pieces)
        color = color_dict.get(label, 'gray')

        # unary_union may return a Polygon, a MultiPolygon or a GeometryCollection.
        parts = merged.geoms if hasattr(merged, 'geoms') else [merged]
        legend_label = label  # only the first filled part of a region gets a legend entry
        for part in parts:
            if part.geom_type != 'Polygon' or part.is_empty:
                continue
            # NOTE: only the exterior ring is filled, so holes (if any) are drawn solid.
            x, y = part.exterior.xy
            ax.fill(x, y, alpha=0.5, fc=color, ec=None, label=legend_label)
            legend_label = '_nolegend_'

    ax.set_xlabel("X (FE)")
    ax.set_ylabel("Entropy (J/mol/K)")
    ax.set_title("Phase Diagram")
    if phase_polygons:
        ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), fontsize='small')
    plt.show()

In [None]:
# Add a colour for every 2- and 3-phase coexistence region. A coexistence label is the
# sorted phase names joined with "-", matching the labels plot_phase_diagram builds.
# Combinations are formed from single-phase names only, so re-running this cell does
# not combine previously generated coexistence labels.
np.random.seed(0)  # make the randomly generated colours reproducible between runs

single_phase_names = {name.upper() for name in phases}
base_phases = [name for name in color_dict if name in single_phase_names]

for combination_size in (2, 3):
    for combin in combinations(base_phases, combination_size):
        coexistence_label = "-".join(sorted(combin))
        member_colors = [color_dict[member] for member in combin]
        color_dict[coexistence_label] = generate_distinct_color(member_colors)

In [None]:
# Plot the phase diagram directly from the equilibrium points. The next cell plots it
# again from the fitted polynomials, so phases the polynomial fits poorly are left out
# here as well to keep the two diagrams comparable.
scrubbed_eq_enthalpy_phase_dict = dict(eq_enthalpy_phase_dict)
for poorly_fit_phase in ('AL2FE', 'AL13FE4', 'AL5FE2'):
    scrubbed_eq_enthalpy_phase_dict.pop(poorly_fit_phase)

plot_phase_diagram(scrubbed_eq_enthalpy_phase_dict, color_dict)

In [None]:
# Rebuild the phase diagram from the fitted polynomials, evaluated on a regular grid:
# X(FE) in [0, 1] and entropy in [20, 100].
X_mesh, S_mesh = np.meshgrid(np.linspace(0, 1, 100), np.linspace(20, 100, 500))
x_sym, y_sym = sp.symbols('x y')

phase_poly_energy_dict = dict()
for phase_name in scrubbed_eq_enthalpy_phase_dict:
    poly_fn = sp.lambdify((x_sym, y_sym), phase_poly_dict[phase_name], 'numpy')
    # lambdify returns a scalar for a constant polynomial; broadcast to the grid shape.
    H_mesh = np.broadcast_to(poly_fn(X_mesh, S_mesh), X_mesh.shape)
    phase_poly_energy_dict[phase_name] = (X_mesh.ravel(), S_mesh.ravel(), H_mesh.ravel())

plot_phase_diagram(phase_poly_energy_dict, color_dict)