In [None]:
# Structure Visualiser
from pathlib import Path
from ase.visualize import view
from ase.io import read
from IPython.display import display

structure_dir = "Carbon_Structures/Crystalline/Relaxed 64test/Final Trajectory Frame"

def view_relaxed_structures(dir, struct_num):
    """Print the file name of, and display in 3D, one ``.cif`` structure under ``dir``.

    Parameters
    ----------
    dir : str or Path
        Directory searched recursively for structure files. (The name shadows
        the ``dir`` builtin but is kept so keyword callers still work.)
    struct_num : int
        Index into the *sorted* list of ``.cif`` files found. Negative values
        count from the end, as with list indexing.

    Raises
    ------
    IndexError
        If ``struct_num`` is out of range for the ``.cif`` files found
        (including the case where none are found).
    """
    root = Path(dir)

    files = []
    # Sorting makes struct_num select the same file on every run/machine;
    # rglob itself yields entries in filesystem order.
    for file in sorted(root.rglob('*')):
        if not file.is_file():
            continue  # rglob also yields sub-directories; they are not candidates
        if file.suffix != ".cif":
            print(f"Unrecognized file suffix: {file.suffix}")
            continue
        files.append(file)

    if not -len(files) <= struct_num < len(files):
        raise IndexError(
            f"struct_num={struct_num} is out of range: found {len(files)} .cif file(s) under {root}"
        )

    chosen_file = files[struct_num]
    atoms = read(chosen_file)
    print(chosen_file.name)
    display(view(atoms, viewer='x3d'))

view_relaxed_structures(structure_dir,0)

In [None]:
# ------ GRAPHICAL ANALYSIS --------

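# Intended: graphical data points are means of all repeat runs, with errors given as 1 standard deviation.
# NOTE(review): scatter_plot below plots each imported value as-is and draws no error bars —
# confirm the raw JSON files already hold run means, or add the aggregation step.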
import pandas as pd
import json
from pathlib import Path

# ------ FIGURE FORMATTING ------
import matplotlib.pyplot as plt
plt.style.use('crystalline.mplstyle')
from cycler import cycler

# One colour per plotted series, reused for the marker face colour below.
_series_colours = ['#000000', '#D55E00', '#3399FF', '#009E73', '#777777']

colour_cycle = cycler('color', _series_colours)
marker_cycle = cycler('marker', ['o', 'o', 's', 'x', '^'])
# First series: larger hollow markers ('none' face); the rest are filled
# with their own line colour.
style_cycle = (
    cycler('markersize', [6, 3, 3, 3, 3])
    + cycler('markerfacecolor', ['none'] + _series_colours[1:])
)
# Adding cyclers of equal length zips them: series i takes entry i of every property.
total_cycle = colour_cycle + marker_cycle + style_cycle
# -------------------------------

def import_data_files(directory, potentials):
    """Load per-structure results from every ``.json`` file under ``directory``.

    Files are visited in sorted path order, so the row order of the returned
    DataFrame (and hence structure order on plots built from it) is reproducible.

    Skipped with a printed message: non-``.json`` files, and JSON files whose
    decoded content is empty/falsy. Skipped silently: records whose
    ``"structure"`` is ``"isolated_C.cif"`` and records whose ``"potential"``
    is not in ``potentials``. The ``"steps_to_relax"`` and ``"step_limit"``
    keys are dropped from every kept record.

    Parameters
    ----------
    directory : str or Path
        Root directory, searched recursively.
    potentials : collection of str
        Potential identifiers to keep (membership tested with ``in``).

    Returns
    -------
    pandas.DataFrame
        One row per kept file, one column per remaining key. Keys absent from
        some files become NaN; every row containing a NaN is printed as a warning.
    """
    directory = Path(directory)

    # Relaxation step info; not wanted in the numerical table.
    exclude = {"steps_to_relax", "step_limit"}
    rows = []

    for path in sorted(directory.rglob("*")):
        if not path.is_file():  # rglob also yields directories
            continue

        if path.suffix != ".json":
            print(f"Unrecognized file format: {path.name}. Must be .json")
            continue

        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        if not data:
            print(f"No data found in {path}")
            continue

        # Isolated-atom results are not plotted as a structure
        # (presumably a reference calculation — confirm).
        if data.get("structure") == "isolated_C.cif":
            continue

        if data.get("potential") not in potentials:
            continue

        rows.append({k: v for k, v in data.items() if k not in exclude})

    df = pd.DataFrame(rows)

    # Flag every row with at least one missing value, matching the
    # "Empty cells" wording of the warning.
    incomplete = df.isna().any(axis=1)
    if incomplete.any():
        print(f"WARNING: Empty cells in dataframe\n{df[incomplete]}")

    print(f"Imported {len(rows)} files")
    return df

# Crystalline results for the four potentials under comparison.
df = import_data_files(
    directory="Analysis/Crystalline Analysis/Raw Data",
    potentials=(
        "Carbon_GAP_20.xml",
        "medium-0b3.pt",
        "medium-mpa-0.pt",
        "medium-omat-0.pt",
    ),
)

def _legend_label(mapped_potential, sub_df, set_label):
    """Legend text for one potential: its display name, optionally with one reference energy.

    ``set_label`` is assumed already validated by ``scatter_plot``. Only the
    reference column actually shown is read, so "default" plots need neither.
    """
    if set_label == "atom":
        iso_energy = sub_df["isolated_atom_energy"].iloc[0]
        return f"{mapped_potential}: $E_{{(atom)}}$ = {iso_energy:+.2f} eV"
    if set_label == "graphite":
        graphite_energy = sub_df["graphite_energy/atom"].iloc[0]
        return f"{mapped_potential}: $E_{{(graphite)}}$ = {graphite_energy:+.2f} eV"
    return mapped_potential


def _save_figure(fig, save_dir, pot_string, save_name):
    """Write ``fig`` as PDF then PNG under ``<save_dir>/Graphs/<pot_string>/{pdf,png} Graphs/``."""
    pot_comparison_dir = Path(save_dir) / "Graphs" / pot_string
    png_dir = pot_comparison_dir / "png Graphs"
    pdf_dir = pot_comparison_dir / "pdf Graphs"
    png_dir.mkdir(parents=True, exist_ok=True)
    pdf_dir.mkdir(parents=True, exist_ok=True)
    # bbox_inches="tight" enlarges the saved canvas to include the 90°-rotated
    # tick labels and the above-axes legend, which can otherwise fall outside it.
    fig.savefig(pdf_dir / f"{save_name}.pdf", bbox_inches="tight")
    fig.savefig(png_dir / f"{save_name}.png", bbox_inches="tight")


def scatter_plot(df, df_column_to_plot, y_label, chart_title, save_dir, save_name, set_label):
    """Plot one per-structure quantity for every potential in ``df`` and save it.

    Each potential is one marker-only series: structures on the x-axis (in
    first-appearance order), ``df[df_column_to_plot]`` on the y-axis. The figure
    is saved as PDF and PNG under
    ``<save_dir>/Graphs/<display names of potentials joined by "_">/``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have ``"potential"``, ``"structure"`` and ``df_column_to_plot``
        columns, plus ``"isolated_atom_energy"`` when ``set_label == "atom"``
        or ``"graphite_energy/atom"`` when ``set_label == "graphite"``.
    df_column_to_plot : str
        Column supplying the y values.
    y_label, chart_title : str
        Y-axis label and figure title.
    save_dir : str or Path
        Root output directory.
    save_name : str
        File stem for the saved ``.pdf`` / ``.png``.
    set_label : {"default", "atom", "graphite"}
        Legend style: potential name only, or name plus that potential's
        isolated-atom / graphite reference energy (taken from its first row).

    Raises
    ------
    ValueError
        If ``set_label`` is not one of the accepted values.
    """
    valid_labels = ("default", "atom", "graphite")
    if set_label not in valid_labels:
        raise ValueError(f"set_label must be one of {valid_labels}, got {set_label!r}")
    if df.empty:
        print(f"No data to plot for {save_name}")
        return

    potentials = sorted(df["potential"].unique())
    potential_label_map = {
        "Carbon_GAP_20.xml"  : "GAP-20",
        "medium-0b3.pt"      : "MACE-0b3",
        "medium-mpa-0.pt"    : "MACE-mpa-0",
        "medium-omat-0.pt"   : "MACE-omat-0",
        "carbon.xml"         : "GAP-17"
    }
    pot_string = "_".join(potential_label_map.get(p, p) for p in potentials)

    # Each structure gets an integer x slot, in first-appearance order.
    structures = df["structure"].unique().tolist()
    x_positions = {s: i for i, s in enumerate(structures)}

    x_label_map = {
        "Nanotube_9_0.cif"           : "Nanotube-(9,0)",
        "Nanotube_9_9.cif"           : "Nanotube-(9,9)",
        "C60.cif"                    : "C60",
        "C100.cif"                   : "C100",
        "Diamond.cif"                : "Diamond",
        "Hexagonal_Diamond.cif"      : "Hexagonal Diamond",
        "Graphite.cif"               : "Graphite",
        "Graphene.cif"               : "Graphene"
    }

    # plt.rc sets the global default cycle (as before, this persists for later
    # figures); it must precede axes creation for the new axes to use it.
    plt.rc('axes', prop_cycle=total_cycle)
    fig, ax = plt.subplots()
    try:
        for potential in potentials:
            sub_df = df[df["potential"] == potential]
            x_values = [x_positions[s] for s in sub_df["structure"]]
            y_values = sub_df[df_column_to_plot]
            label = _legend_label(potential_label_map.get(potential, potential), sub_df, set_label)
            ax.plot(x_values, y_values, linestyle='', label=label)

        ax.set_xticks(list(x_positions.values()))
        ax.set_xticklabels([x_label_map.get(s, s) for s in x_positions], rotation=90)
        ax.set_ylabel(y_label)
        ax.set_title(chart_title)

        if set_label == "default":
            ax.legend()
        else:
            # Labels carrying reference energies are long: stack them above the axes.
            ax.legend(
                loc='lower center',
                bbox_to_anchor=(0.5, 1.02),
                ncol=1,     # increase if legend becomes tall
                frameon=True,
            )

        _save_figure(fig, save_dir, pot_string, save_name)
    finally:
        # Close only this figure (even on error) so it is neither leaked nor shown inline.
        plt.close(fig)

    print(f"Created {save_name} plot")

def all_analysis_wrapper(data=None, save_dir="Analysis/Crystalline Analysis"):
    """Produce every standard crystalline comparison plot via ``scatter_plot``.

    Parameters
    ----------
    data : pandas.DataFrame, optional
        Table to plot. Defaults to the module-level ``df`` so the original
        zero-argument call keeps working.
    save_dir : str or Path, optional
        Root output directory passed through to ``scatter_plot``.
    """
    if data is None:
        data = df

    # (column to plot, y-axis label, output file stem, legend style)
    plot_specs = [
        ("atomisation_energy/atom", "Atomisation Energy [eV]", "atomisation_energy", "atom"),
        ("formation_energy/atom",   "Formation Energy [eV]",   "formation_energy",   "graphite"),
        ("average_bond_length",     " Mean Bond Length (Å)",   "bond_length",        "default"),
        ("average_bond_angle",      "Mean Bond Angle (°)",     "bond_angle",         "default"),
    ]
    for column, y_label, save_name, set_label in plot_specs:
        # Empty chart title for every plot, as before.
        scatter_plot(data, column, y_label, "", save_dir, save_name, set_label)

all_analysis_wrapper()