In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
from pycalphad import Database, calculate, equilibrium, variables as v
from pycalphad.plot.utils import phase_legend
from pycalphad import ternplot
import plotly.graph_objects as go
import sympy as sp
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# Load the thermodynamic database and list every phase and element it defines
db = Database(r'../TDDatabaseFiles_temp/alfe.tdb')

phases = [*db.phases]
constituents = [*db.elements]
for listing in (phases, constituents):
    print(listing)

In [None]:
# Get the colors that map phase names to colors in the legend
legend_handles, color_dict = phase_legend(phases)

fig = plt.figure(figsize=(6, 6))
ax = fig.gca()

# Loop over phases, calculate the Gibbs energy, and scatter plot GM vs. X(FE)
for phase_name in ['B2_BCC']:
    result = calculate(db, constituents, phase_name, T=298.15, P=101325, output='GM')
    try:
        x_fe = result.X.sel(component='FE')
        print(x_fe.shape)
        print(result.GM.shape)
        ax.scatter(x_fe, result.GM, marker='.', s=20, color=color_dict[phase_name])
    except KeyError as err:
        # Best-effort: skip this phase, but report why instead of failing silently.
        print(f'skipping {phase_name}: missing key {err}')

# Overlay the equilibrium Gibbs energy of B2_BCC along the Al-Fe join (black points)
conds = {
    v.T: 298.15,
    v.P: 101325,
    v.X('FE'): np.linspace(0, 1, 100)
}

eq = equilibrium(db, ['FE', 'AL', 'VA'], 'B2_BCC', conds, output='GM')
# eq.GM is the molar Gibbs energy of the whole system at each overall composition, so plot it
# against the X(FE) condition coordinate rather than the composition of the first vertex
# (the two differ wherever the system splits into more than one composition set).
ax.scatter(eq.X_FE.values, eq.GM.values[0, 0, 0, :], marker='.', s=20, color='k')

# Format the plot
ax.set_xlabel('X(FE)')
ax.set_ylabel('GM')
ax.set_xlim((0, 1))
ax.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.6))

plt.show()

In [None]:
# Equilibrium Gibbs energy of B2_BCC on a coarse 10-point X(FE) grid
# (these conditions are also reused by the calculate/equilibrium cells below)
conds = {
    v.T: 298.15,
    v.P: 101325,
    v.X('FE'): np.linspace(0, 1, 10)
}

eq = equilibrium(db, ['FE', 'AL', 'VA'], 'B2_BCC', conds, output='GM')

# Draw on a fresh figure: the axes from the previous cell have already been shown by
# plt.show() (and, with the inline backend, closed), so adding points to them displays nothing.
fig, ax = plt.subplots(figsize=(6, 6))
# System GM is plotted against the overall X(FE) condition it was computed at.
ax.scatter(eq.X_FE.values, eq.GM.values[0, 0, 0, :], marker='.', s=20, color=color_dict['B2_BCC'])
ax.set_xlabel('X(FE)')
ax.set_ylabel('GM')
ax.set_title('B2_BCC')
plt.show()

In [None]:
# `calculate` takes state variables as keyword arguments (T=..., P=...), as in the first plot cell;
# its fourth positional parameter is not a conditions dict, so passing `conds` there leaves T and P
# at their defaults. Pass them explicitly. (`calculate` samples composition itself, so X(FE) is unused.)
result = calculate(db, ['FE', 'AL', 'VA'], 'B2_BCC', T=conds[v.T], P=conds[v.P], output='GM')

In [None]:
result['GM'].values  # raw GM array of the calculate() result above

In [None]:
result['X'].sel(component='FE').values  # sampled Fe mole fractions of the calculate() result

In [None]:
# Equilibrium of the single phase B2_BCC under the 10-point conditions defined above
eq = equilibrium(db, ['FE', 'AL', 'VA'], ['B2_BCC'], conds, output='GM')

In [None]:
eq['GM'].values[0, 0, 0]  # GM along the X(FE) condition axis (leading singleton axes dropped)

In [None]:
# Reshape a pycalphad calculate() result into (composition, temperature, GM) grids for plotting
def format_calc(result, component='FE'):
    """Reshape a pycalphad ``calculate`` result into gridded arrays for 3D plots and fitting.

    Parameters:
        result: Dataset returned by ``calculate`` with ``X`` and ``GM`` data and a ``T`` coordinate.
            Assumes T is the only non-singleton state-variable axis and comes before the
            ``points`` axis. It also assumes the sampled compositions are identical at every
            temperature, so only those at the first temperature are used.
        component (str): Component whose mole fraction goes on the x axis. Defaults to 'FE'.

    Returns:
        (X, Y, Z): three arrays of shape (n_T, n_points) holding the mole fraction of
        ``component``, the temperature, and GM respectively.
    """
    temps = np.atleast_1d(result.T.values)
    n_temps = temps.size
    # Collapse singleton axes with an explicit reshape instead of np.squeeze. A single
    # temperature or a single sample point (as may happen for a stoichiometric phase)
    # then still yields consistent 2-D arrays instead of a scalar X.
    comp_x = np.reshape(result.X.sel(component=component).values, (n_temps, -1))
    gm = np.reshape(result.GM.values, (n_temps, -1))
    X, Y = np.meshgrid(comp_x[0], temps)
    return X, Y, gm

In [None]:
# fig = go.Figure()

# # Loop over phases
# for phase_name in phases:
#     result = calculate(db, constituents, phase_name, P=101325, T=np.linspace(300, 2000, 100), output='GM')
#     selected_color=color_dict[phase_name]

#     X = np.squeeze(result.X.sel(component='FE')).values[0]
#     Y = result.T.values
#     X, Y = np.meshgrid(X, Y)
#     Z = np.squeeze(result.GM.values)

#     fig.add_trace(go.Surface(x=X, y=Y, z=Z, 
#                              colorscale=[[0, selected_color], [1, selected_color]],
#                              showscale=False))
    

# fig.update_layout(
#     scene=dict(
#         xaxis_title="X(FE)",
#         yaxis_title="Temperature (K)",
#         zaxis_title="Gibbs Energy (J/mol)"
#     ),
#     title="Gibbs Energy Surfaces"
# )

# fig.show('png')

In [None]:
# # Calculate the result for this phase
# result = calculate(db, constituents, 'HCP_A3', P=101325,
#                     T=np.linspace(300, 2000, 100), output='GM')
# selected_color = color_dict[phase_name]

# X, Y, Z = format_calc(result)

In [None]:
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')

# # Plot the surface for this phase
# ax.plot_surface(X, Y, Z, color=selected_color,
#                 rstride=1, cstride=1, alpha=1, shade=False)

# # Set the axis labels.
# ax.set_xlabel("X(FE)")
# ax.set_ylabel("Temperature (K)")
# ax.set_zlabel("Gibbs Energy (J/mol)")

# # Adjust the view:
# # - elev=0 puts the camera in the XY plane.
# # - azim=0 positions the camera along the positive X axis,
# #   meaning the X axis points directly into the screen.
# ax.view_init(elev=0, azim=0)

# # Set the plot title and display the plot.
# plt.title("Gibbs Energy Surfaces")
# plt.show()

In [None]:
# # Create a new figure and add a 3D axis.
# fig = plt.figure()
# ax = fig.add_subplot(111, projection='3d')

# # Loop over phases
# for phase_name in phases:
#     # Calculate the result for this phase
#     result = calculate(db, constituents, phase_name, P=101325,
#                        T=np.linspace(300, 2000, 100), output='GM')
#     selected_color = color_dict[phase_name]

#     X, Y, Z = format_calc(result)

#     # Plot the surface for this phase
#     ax.plot_surface(X, Y, Z, color=selected_color,
#                     rstride=1, cstride=1, alpha=1, shade=False)

# # Set the axis labels.
# ax.set_xlabel("X(FE)")
# ax.set_ylabel("Temperature (K)")
# ax.set_zlabel("Gibbs Energy (J/mol)")

# # Adjust the view:
# # - elev=0 puts the camera in the XY plane.
# # - azim=0 positions the camera along the positive X axis,
# #   meaning the X axis points directly into the screen.
# ax.view_init(elev=0, azim=0)

# # Set the plot title and display the plot.
# plt.title("Gibbs Energy Surfaces")
# plt.show()

In [None]:
# # Get the colors that map phase names to colors in the legend
# legend_handles, color_dict = phase_legend(phases)

# fig = plt.figure(figsize=(6,6))
# ax = fig.gca()

# result = calculate(db, constituents, 'HCP_A3', P=101325, T=1200, output='GM')
# ax.scatter(result.X.sel(component='FE'), result.GM, marker='.', s=20, color=color_dict[phase_name])

# # Format the plot
# ax.set_xlabel('X(FE)')
# ax.set_ylabel('GM')
# ax.set_xlim((0, 1))
# ax.legend(handles=legend_handles, loc='center left', bbox_to_anchor=(1, 0.6))
# plt.show()

In [None]:
# Choose a single phase and sample its Gibbs energy over composition and 300-2000 K
selected_phase = 'HCP_A3'
result = calculate(db, constituents, selected_phase,
                    P=101325,
                    T=np.linspace(300, 2000, 100),
                    output='GM')
# Use the selected phase's own legend color, not the color of whatever `phase_name`
# happened to be left over from an earlier loop.
selected_color = color_dict[selected_phase]

X, Y, Z = format_calc(result)

In [None]:
# Turn the (T, X) grids into 1-D sample vectors; the fitting cells below reuse these
X_flat, Y_flat, Z_flat = (grid.flatten() for grid in (X, Y, Z))

# 3D scatter of the sampled Gibbs energy, colored by its value
fig, ax = plt.subplots(figsize=(10, 7), subplot_kw={'projection': '3d'})
ax.scatter(X_flat, Y_flat, Z_flat, c=Z_flat, cmap='viridis', marker='o', alpha=0.8)
ax.set(xlabel="X", ylabel="T", zlabel="G", title="3D Scatter Plot")

plt.show()

In [None]:
# fig = go.Figure()

# fig.add_trace(go.Scatter3d(
#     x=X.flatten(),  # Flatten for Scatter3d
#     y=Y.flatten(),
#     z=Z.flatten(),
#     mode='markers',
#     marker=dict(
#         color=Z.flatten(),  # Color by Z values
#         colorscale="Viridis",
#         size=5
#     )
# ))

# fig.update_layout(
#     scene=dict(
#         xaxis_title="X(FE)",
#         yaxis_title="Temperature (K)",
#         zaxis_title="Gibbs Energy (J/mol)"
#     ),
#     title="Gibbs Energy Surfaces"
# )

# fig.show()

In [None]:
def fit_poly_x2y2_sympy(x_data, y_data, z_data):
    """
    Fit a polynomial of the form:
      f(x, y) = a00 + a10*x + a01*y + a20*x^2 + a11*x*y + a02*y^2 +
                a21*x^2*y + a12*x*y^2 + a22*x^2*y^2
    to the given data using least squares, and return it as a Sympy expression.

    Parameters:
        x_data (array_like): 1-D sample values of x (in this notebook, the mole fraction).
        y_data (array_like): 1-D sample values of y, same length as x_data (here, temperature).
        z_data (array_like): 1-D observed values of f(x, y), same length as x_data (here, GM).

    Returns:
        expr (sympy.Expr): The fitted polynomial in the Sympy symbols ``x`` and ``y``.
        residuals (np.ndarray): Sum of squared residuals as returned by ``np.linalg.lstsq``.
            This is an empty array when the design matrix is rank-deficient (for example,
            when every x value is the same) or when there are no more samples than the
            9 coefficients.
    """
    # Accept lists/ints too; float avoids integer overflow in the x^2*y^2 column.
    x_data = np.asarray(x_data, dtype=float)
    y_data = np.asarray(y_data, dtype=float)
    z_data = np.asarray(z_data, dtype=float)

    # (power of x, power of y) for each term, in coefficient order
    # a00, a10, a01, a20, a11, a02, a21, a12, a22. Building both the design matrix and
    # the Sympy expression from this single list keeps their term order in sync.
    exponents = [(0, 0), (1, 0), (0, 1), (2, 0), (1, 1), (0, 2), (2, 1), (1, 2), (2, 2)]

    # Construct the design matrix: one column per polynomial term
    A = np.column_stack([x_data**i * y_data**j for i, j in exponents])

    # Solve the least squares problem to get the coefficients
    coeffs, residuals, *_ = np.linalg.lstsq(A, z_data, rcond=None)

    # Build the polynomial expression by summing coeff * x^i * y^j for each term.
    x, y = sp.symbols('x y')
    expr = sum(sp.Float(coeff) * x**i * y**j for coeff, (i, j) in zip(coeffs, exponents))

    return sp.simplify(expr), residuals

In [None]:
# Fit the polynomial to the flattened (X, T, G) samples of the selected phase
poly, res = fit_poly_x2y2_sympy(X_flat, Y_flat, Z_flat)
# Compile the Sympy expression into a numpy-vectorised callable f(x, y)
poly_numeric = sp.lambdify(sp.symbols('x y'), poly, modules='numpy')
print(res)

In [None]:
# Compare the sampled Gibbs energies with the fitted polynomial at the same (X, T) points
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# Scatter plot: calculated data (colored by value) and the fit (red)
ax.scatter(X_flat, Y_flat, Z_flat, c=Z_flat, cmap='viridis', marker='o', alpha=0.8, label='calculated')
ax.scatter(X_flat, Y_flat, poly_numeric(X_flat, Y_flat), c='r', marker='.', alpha=0.8, label='polynomial fit')

# Labels match the earlier scatter: x is X(FE), y is temperature, z is Gibbs energy
ax.set_xlabel("X")
ax.set_ylabel("T")
ax.set_zlabel("G")
ax.set_title("Calculated vs. fitted Gibbs energy")
ax.legend()

plt.show()

In [None]:
# Create a dictionary to store the polynomial expressions, keyed by phase name
phase_energy_expression_dict = dict()

# First we need to go through all phases and fit polynomials to their energies
for phase in phases:
    result = calculate(db, constituents, phase,
                       P=101325,
                       T=np.linspace(300, 2000, 100),
                       output='GM')

    X, Y, Z = format_calc(result)
    X_flat = X.flatten()
    Y_flat = Y.flatten()
    Z_flat = Z.flatten()

    # Fit a polynomial to the data and keep its full-precision coefficients.
    # Rounding every number in the expression to an integer would zero any coefficient
    # with magnitude below 0.5. Such coefficients still matter here: y (temperature)
    # reaches 2000, so a y**2 coefficient of e.g. -0.01 contributes about -4e4 J/mol.
    poly_energy_expr, res = fit_poly_x2y2_sympy(X_flat, Y_flat, Z_flat)
    phase_energy_expression_dict[phase] = poly_energy_expr
    print(f'finished fitting {phase}')
    print(res)

In [None]:
# Evaluate each phase's fitted G(x, T) polynomial on a regular grid and draw it as a surface
x_range = [0, 1]        # X(FE)
y_range = [300, 2000]   # temperature (K), same range the polynomials were fitted on
x_points = 100
y_points = 100

# Create a meshgrid over the specified ranges.
x_vals = np.linspace(x_range[0], x_range[1], x_points)
y_vals = np.linspace(y_range[0], y_range[1], y_points)
X, Y = np.meshgrid(x_vals, y_vals)

# The stoichiometric compounds are not fitting right, so they are skipped for now.
skipped_phases = {'AL2FE', 'AL5FE2', 'AL13FE4'}

# Prepare a Plotly figure.
fig = go.Figure()

# Define symbols for lambdify conversion (must match the symbols used by the fit).
x_sym, y_sym = sp.symbols('x y')

# Loop through each phase and its polynomial.
for phase, expr in phase_energy_expression_dict.items():
    if phase in skipped_phases:
        continue
    # Convert the sympy expression to a numpy-aware function.
    func = sp.lambdify((x_sym, y_sym), expr, modules="numpy")
    # Evaluate on the grid. lambdify returns a plain scalar for an expression with no
    # x/y dependence, so broadcast to guarantee go.Surface gets a grid-shaped z array.
    Z = np.broadcast_to(func(X, Y), X.shape)

    # Add the surface trace for this phase, in the phase's legend color.
    selected_color = color_dict[phase]
    fig.add_trace(
        go.Surface(
            x=X,
            y=Y,
            z=Z,
            name=phase,
            showlegend=True,   # opt in to a legend entry so each phase can be identified
            showscale=False,
            colorscale=[[0, selected_color], [1, selected_color]],
        )
    )

fig.update_layout(
    scene=dict(
        xaxis_title="X(FE)",
        yaxis_title="Temperature (K)",
        zaxis_title="Gibbs Energy (J/mol)",
        zaxis=dict(range=[-140000, 1000])   # clip the Gibbs energy axis to the range of interest
    ),
    title="Gibbs Energy Surfaces"
)

# Display the figure.
fig.show()