# About
This notebook analyses the results once the simulations are done and produces plots similar to those in the paper.

# Requirements

As it is only for analysis, it requires the following packages:
- `Pandas`
- `Numpy`
- `Matplotlib`
- `Seaborn`
- `Ase`



In [None]:
# Contains all general imports for the notebook
# %matplotlib inline
# %matplotlib notebook

from pathlib import Path
import pickle

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.markers import MarkerStyle

# Seaborn's "whitegrid" look, but with the grid lines switched off.
# Alternatives used at some point: sns.set_theme(style="whitegrid"), a larger
# global font via plt.rcParams["font.size"], or sns "ticks"/"talk" together with
# the matplotlib "dark_background" style for dark slides.
sns.set_style(style="whitegrid", rc={"axes.grid": False})

# Keep text as text (not paths) when figures are saved as SVG.
plt.rcParams.update({"svg.fonttype": "none"})

In [None]:
# loading the local code
# Mode 1 reloads only the modules registered with %aimport (here utils.make_plots)
# before code is executed, so edits to the plotting helpers are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 1
%aimport utils.make_plots


Expected base_dir structure
```
base_dir
    |___ DFT
        |___ energy_volume.csv
        |___ elastic_tensor.csv
        |___ gb_surfaces.csv
        |___ gb_energies.csv
        |___ mlip_dft_energy.csv
    |___ CHG2
    |___ CHG2-
    |___ mace-medium-mp-0b3  # mace parent dir
    |___ mace-exp-27-multihead  # mace fine-tuned model
    |___ mace-exp-32-freeze     # mace fine-tuned model
```

<div class="alert alert-block alert-info">
<b>Info:</b> 
The directory names are passed as a list of strings in the `experiments` variable. 
</div>

In [None]:
# Root directory that holds one results directory per model / reference.
base_dir = Path("output")

# Flip SAVE_PLOTS to True to write figures into `save_plot_dir` (created on demand).
# While it is False, `save_plot_dir` stays None.
SAVE_PLOTS = False
save_plot_dir = Path("plots/") if SAVE_PLOTS else None
if save_plot_dir is not None:
    save_plot_dir.mkdir(parents=True, exist_ok=True)

### Performance on the Fe dataset

In [None]:
remove_for_fe = True

# Axis labels use raw strings: \AA and \hat are LaTeX commands, and in a normal
# string literal they are invalid escape sequences (SyntaxWarning on Python 3.12+).
plot_info = {
    'forces': {
        "dft_key": 'dft_force',
        "mlip_key": 'mlip_force',
        'xlabel': r"DFT predicted force (eV/$\AA$)",
        'ylabel': r"MLIP predicted force (eV/$\AA$)",
        'colorscheme': "Greens",
    },
    'energy_per_atom': {
        "dft_key": 'dft_energy_per_atom',
        "mlip_key": 'mlip_energy_per_atom',
        'xlabel': r"$E_{DFT}$ (eV/atom)",
        'ylabel': r"$\hat E_{MLIP}$ (eV/atom)",
        'colorscheme': "Blues",
    },
}

# Model result directories (under base_dir) to evaluate on the Fe dataset.
experiments = [
    "CHG2",
    "MACE",
    "Sevenn",
    "CHG2-naive",
    "MACE-naive",
    "MACE-Replay",
    "Sevenn-naive",
]

# Which entry of plot_info to draw.
plot = 'energy_per_atom'

In [None]:
# Plot MLIP vs. DFT per-atom energies on the Fe dataset for every model in `experiments`
# (labels/colours taken from plot_info[plot]; implementation in utils.make_plots).
from utils.make_plots import plot_fe_energies

# NOTE(review): re-sets the same value as the previous cell; its exact effect is
# defined by plot_fe_energies -- confirm there.
remove_for_fe= True

plot_fe_energies(plot_info, experiments, remove_for_fe, plot, base_dir, file_name="fe_mlip_dft_energy.pkl", save_dir=save_plot_dir)

In [None]:

# Prediction file (per model directory) for the MPtrj evaluation.
file_name = "mptrj_w_force_mlip_dft_energy.pkl"
remove_outlier = True  # to remove very high energies from plot only

# Raw strings: \AA and \hat are LaTeX, not (invalid) Python escape sequences.
plot_info = {
    'forces': {
        "dft_key": 'dft_force',
        "mlip_key": 'mlip_force',
        'xlabel': r"$f_{DFT}$ (eV/$\AA$)",
        'ylabel': r"$\hat f_{MLIP}$ (eV/$\AA$)",
        # 'colorscheme': "Greens",
        'colorscheme': "viridis",
    },
    'energy': {
        "dft_key": 'dft_energy_per_atom',
        "mlip_key": 'mlip_energy_per_atom',
        'xlabel': r"$E_{DFT}$ (eV/atom)",
        'ylabel': r"$\hat E_{MLIP}$ (eV/atom)",
        'colorscheme': "Blues",
    },
}

experiments = [
    "CHG2",
    "MACE",
    "Sevenn",
    "CHG2-naive",
    "MACE-naive",
    "MACE-Replay",
    "Sevenn-naive",
    "MACE-freeze",
]

# Which entry of plot_info to draw.
plot = 'energy'

In [None]:
# Plot MLIP vs. DFT per-atom energies on MPtrj for every model in `experiments`
# (implementation in utils.make_plots).
from utils.make_plots import plot_mptrj_energies
plot_mptrj_energies(plot_info, experiments, remove_outlier, plot, base_dir, file_name="mptrj_w_force_mlip_dft_energy.pkl",
                     save_dir=save_plot_dir,
                    #  ncols=3, # set the number of columns manually; by default it is chosen automatically
                    # add_fig_numbering=False # disable the a) b) c) ... panel labels (enabled by default)
                     )

In [None]:

# Force predictions for MPtrj. `file_name` names the file actually passed to the
# plotting helper below (previously it held a different, unused name).
file_name = "mptrj_w_force_mlip_dft_energy.pkl"
remove_outlier = True  # to remove very high values from the plot only

plot = 'forces'
from utils.make_plots import plot_mptrj_forces

plot_mptrj_forces(plot_info, experiments,
                  remove_outlier,
                  plot,
                  base_dir,
                  file_name=file_name,
                  save_dir=save_plot_dir,
                  )

# Properties of Fe

In [None]:
# Result directories (under base_dir): DFT reference plus base and fine-tuned ("-FT") models.
experiments = ["DFT", "CHG2", "MACE", "CHG2-FT", "MACE-FT"]

In [None]:
# Vacancy formation energy: take the first `vac_energy` entry of each model's CSV,
# keep it as a one-element list in `vac_data`, and print it.
file_vac = 'lattice_vac_energy.csv'

vac_data = {}
for exp in experiments:
    vac_energy = pd.read_csv(base_dir / exp / file_vac)['vac_energy'].values[0]
    vac_data[exp] = [vac_energy]
    print(f"{exp = }:", vac_energy)

In [None]:
# Compare Fe properties across the models in `experiments` (implementation in utils.make_plots).
from utils.make_plots import plot_fe_properties

plot_fe_properties(experiments, base_dir, save_plot_dir)

In [None]:
# for figure in appendix
# Presumably the grain-boundary comparison, restricted to gb_type='twist'.
# NOTE(review): unlike the other plotting calls no save directory is passed,
# so this figure is not written even when SAVE_PLOTS is True -- confirm intent.
from utils.make_plots import plot_gb_alone

plot_gb_alone(experiments, base_dir, gb_type='twist')

# Validation

In [None]:
from utils.make_plots import plot_impurities, get_1nn_MAE

In [None]:
# for subset plot elements that have to be removed
elements_to_remove = [
                      'Mn',
                      'Si', 
                      'Mo',
                      'Cr', 
                      'V', 
                      'Ti', 
                      'Zn',
                      'vac'
                      ]

# elements_to_remove = None # to plot all elements
 
to_include = None  # this variable is to selectively add elements to the plot

experiments = ['DFT', 
               'CHG2',
               "MACE",
                 "CHG2-FT",
                 "MACE-FT",
                 "Literature1",
                 "Literature2",
                 "Literature3"
                 ]
plot_impurities(experiments, save_dir=save_plot_dir,base_dir=base_dir,
                 elements_to_remove=elements_to_remove, img_col=6)

In [None]:
# get the MAE of 1nn for the models
# (first-nearest-neighbour values; computed in utils.make_plots.get_1nn_MAE --
# presumably relative to the 'DFT' entry, confirm there). Uses the same element
# filters as the plot above.

get_1nn_MAE(experiments=experiments,
            base_dir=base_dir,
            elements_to_remove=elements_to_remove,
            to_include=to_include)


### Plot vacancies and solute interaction

In [None]:
# Vacancy-solute interaction: drop these solutes (Al and Nb are deliberately not
# removed) and add the vacancy entry back in.
elements_to_remove = ['Mn', 'Si', 'Mo', 'Cr', 'V', 'Ti', 'Zn']
to_include = ['vac']

plot_impurities(
    experiments,
    base_dir=base_dir,
    save_dir=save_plot_dir,
    elements_to_remove=elements_to_remove,
    to_include=to_include,
    img_col=3,
)

In [None]:

# 1NN MAE for the vacancy-solute selection above (implementation in utils.make_plots).
get_1nn_MAE(experiments=experiments,
            base_dir=base_dir,
            elements_to_remove=elements_to_remove,
            to_include=to_include)


# Interstitial atoms

In [None]:
# DFT reference plus base and fine-tuned ("-FT") models.
experiments = ['DFT', 'CHG2', "MACE", "CHG2-FT", "MACE-FT"]

def tet_oct_diff(exp, file_name = 'interstitials.csv', root=None):
    """Print ``e_oct - e_tet`` for every interstitial species of one model.

    Reads ``<root>/<exp>/<file_name>``, a CSV with columns ``imps``, ``e_oct``
    and ``e_tet``, and prints a separator line, the model name and one line per
    row. A positive difference means ``e_tet < e_oct``.

    Args:
        exp: Model directory name.
        file_name: CSV file inside the model directory.
        root: Results directory; defaults to the notebook-level ``base_dir``.
    """
    if root is None:
        root = base_dir
    df_int = pd.read_csv(Path(root) / exp / file_name)
    print("---------------------------------------")
    print(f"{exp = }")
    # iterrows keeps the original per-row value types in the printed repr.
    for _, row in df_int.iterrows():
        ele = row['imps']
        e_oct_tet = row['e_oct'] - row['e_tet']

        print(f"{ele = } {e_oct_tet = }")

# Only the MLIP models are reported; the DFT entry is skipped.
for exp in experiments:
    if exp == 'DFT':
        continue
    tet_oct_diff(exp)

# RMSE plot
This section plots the RMSE as a function of training epochs for different models.

For the evaluation, the code expects the directory prefix of each model in `base_name`; the epoch number is then appended to this prefix. 

For example:
```python
{
    "exp_name": "CHG2-lr-$10^{-4}$",
    "rmse_mprtj":[0.062],
    "rmse_fe": [0.2122],
    "base_name": "CHG2-experiment-3-ep"  # base name
}
```
The code will look for the following directories in the `base_dir`:
- CHG2-experiment-3-ep10
- CHG2-experiment-3-ep20
- CHG2-experiment-3-ep30
- CHG2-experiment-3-ep40
- CHG2-experiment-3-ep50

In [None]:
def _rmse_record(exp_name, rmse_mptrj, rmse_fe, base_name):
    """Return a new experiment record for the RMSE-vs-epoch plots.

    The single RMSE values are the starting points (plotted at epoch 0) on MPtrj
    and on the Fe set. Per-epoch values get appended to these lists later, so
    every call builds fresh list objects.
    """
    return {
        "exp_name": exp_name,
        "rmse_mprtj": [rmse_mptrj],
        "rmse_fe": [rmse_fe],
        "base_name": base_name,
    }


# CHG2 -- lr = 0.01, 0.001, 0.0001 (experiments 1, 2, 3)
chg2_lr2 = _rmse_record("CHG2-lr-$10^{-2}$", 0.062, 0.2122, "CHG2-experiment-1-ep")
chg2_lr1 = _rmse_record("CHG2-lr-$10^{-3}$", 0.062, 0.2122, "CHG2-experiment-2-ep")
chg2 = _rmse_record("CHG2-lr-$10^{-4}$", 0.062, 0.2122, "CHG2-experiment-3-ep")

# Sevenn -- lr = 0.01, 0.001, 0.0001 (experiments 1, 2, 3)
sevenn_lr2 = _rmse_record("Sevenn-lr-$10^{-2}$", 0.040, 0.187, "sevenn-experiment-1-ep")
sevenn_lr1 = _rmse_record("Sevenn-lr-$10^{-3}$", 0.040, 0.187, "sevenn-experiment-2-ep")
sevenn = _rmse_record("Sevenn-lr-$10^{-4}$", 0.040, 0.187, "sevenn-experiment-3-ep")

# MACE -- lr = 0.01, 0.001, 0.0001 (experiments 1, 2, 3)
mace_lr2 = _rmse_record("MACE-lr-$10^{-2}$", 0.04, 0.172, "mace-experiment-1-ep")
mace_lr1 = _rmse_record("MACE-lr-$10^{-3}$", 0.04, 0.172, "mace-experiment-2-ep")
mace = _rmse_record("MACE-lr-$10^{-4}$", 0.04, 0.172, "mace-experiment-3-ep")


In [None]:
from utils.make_plots import get_rmse

def extract_rmses(filename: str, base_dir: Path, data: dict, key: str, epochs=None):
    """Append per-epoch energy RMSEs of one experiment to ``data[key]``.

    For each epoch the prediction pickle
    ``base_dir / f"{data['base_name']}{epoch}" / filename`` is loaded, and the
    RMSE (via ``get_rmse``) between its ``dft_energy_per_atom`` and
    ``mlip_energy_per_atom`` entries is appended to ``data[key]`` in place.
    Missing checkpoints are skipped with a warning, because the RMSE plot places
    value ``i`` at epoch ``10 * i`` and a gap would shift later points.

    Args:
        filename: Pickle file name inside each epoch directory.
        base_dir: Directory that contains the per-epoch experiment directories.
        data: Experiment record with ``"base_name"`` and a list under ``key``;
            mutated in place.
        key: List in ``data`` to extend (e.g. ``"rmse_mprtj"`` or ``"rmse_fe"``).
        epochs: Epoch numbers to look for; defaults to 10, 20, 30, 40, 50.

    Note:
        ``pickle.load`` can execute arbitrary code -- only load trusted files.
    """
    import warnings

    if epochs is None:
        epochs = [10, 20, 30, 40, 50]  # change the epochs as per need

    for epoch in epochs:
        file_path = Path(base_dir) / f"{data['base_name']}{epoch}" / filename
        if not file_path.exists():
            warnings.warn(f"{file_path} not found; epoch {epoch} skipped")
            continue
        # Load into a separate name so the experiment record `data` is not shadowed.
        with open(file_path, 'rb') as f:
            predictions = pickle.load(f)
        rmse = get_rmse(predictions['dft_energy_per_atom'], predictions['mlip_energy_per_atom'])
        data[key].append(rmse)


In [None]:
data_list = [chg2, chg2_lr1, chg2_lr2, mace, mace_lr1, mace_lr2, sevenn, sevenn_lr1, sevenn_lr2]

# Fill in the per-epoch RMSEs: for each record, the MPtrj file first, then the Fe file.
# extract_rmses is meant to extend the records' lists in place, so re-run the cell that
# defines the records before re-running this one.
# NOTE(review): the Fe file name differs from "fe_mlip_dft_energy.pkl" used for the Fe
# energy plot earlier in the notebook -- confirm which file the checkpoints contain.
rmse_sources = (
    ("mptrj_w_force_mlip_dft_energy.pkl", "rmse_mprtj"),
    ("fe_dft_energy.pkl", "rmse_fe"),
)
for data in data_list:
    for pkl_name, rmse_key in rmse_sources:
        extract_rmses(pkl_name, base_dir, data, key=rmse_key)



In [None]:
# Single-panel figure; with a 1x1 grid plt.subplots returns one Axes as `axes`.
rows, cols = 1, 1
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols * 10, rows * 4))

# plt.rcParams.update({'font.size': 18})
def plot_scatter_with_dotted_lines(data_list: list[dict],
                                   plot_key: str,
                                   title: str = "RMSE Comparison",
                                   ax=None,
                                   ymin=None,
                                   ymax=None,
                                   fig_label=None,
                                   ylabel: str = "RMSE MPRTJ (ev/atom)"
                                   ):
    """Draw one RMSE-vs-epoch line per experiment.

    Each record in ``data_list`` needs ``"exp_name"`` (line label) and a list of
    RMSE values under ``plot_key``; value ``i`` is drawn at epoch ``10 * i``.

    Args:
        data_list: Experiment records.
        plot_key: RMSE series to draw (e.g. ``"rmse_mprtj"`` or ``"rmse_fe"``).
        title: Accepted for call-site compatibility; it is not drawn.
        ax: Target axes; defaults to the current pyplot axes.
        ymin, ymax: y-limits. A limit left as ``None`` is derived from the values
            of *all* series, padded by 0.01.
        fig_label: Optional panel label (e.g. ``"a)"``) drawn above the top-left corner.
        ylabel: y-axis label.
    """
    if ax is None:
        ax = plt.gca()

    # Styles come in groups of three (assumes data_list is ordered by model family,
    # three learning rates each); they cycle if more than nine series are given.
    markers = ["o", "P", "d"] * 3
    colors = ["tab:blue"] * 3 + ["tab:green"] * 3 + ["tab:orange"] * 3
    linestyles = ["-"] * 3 + ["--"] * 3 + ["-."] * 3

    all_values = []
    for idx, data in enumerate(data_list):
        exp_name = data["exp_name"]
        plot_data = data[plot_key]
        style = idx % len(colors)

        epochs = list(range(0, len(plot_data) * 10, 10))

        # Connect points with lines
        ax.plot(epochs, plot_data, alpha=0.8, color=colors[style], label=exp_name,
                marker=markers[style], markersize=8, linewidth=2, linestyle=linestyles[style],
                markerfacecolor='none', markeredgewidth=1)
        all_values.extend(plot_data)

    # Derive missing limits from every series, not just the first one.
    if all_values:
        if ymin is None:
            ymin = min(all_values) - 0.01
        if ymax is None:
            ymax = max(all_values) + 0.01

    # Add labels
    ax.set_xlabel("Epochs", fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.set_ylim(ymin, ymax)
    if fig_label is not None:
        ax.text(-0.1, 1.01, fig_label, transform=ax.transAxes,
                fontsize=16, fontweight='bold', va='bottom', ha='right')

    if plot_key == 'rmse_mprtj':
        # Annotate on the given axes (not whatever pyplot considers current).
        ax.annotate("CHG2 and Sevenn lr$10^{-2}$",
                    xy=(10, 0.7),  # arrow target
                    xytext=(10, 0.45),     # box position
                    fontsize=8, color="tab:blue",
                    arrowprops=dict(facecolor="black", shrink=0.05, width=1.5, headwidth=8),
                    bbox=dict(boxstyle="round,pad=0.3", edgecolor="tab:blue", facecolor="white"))

    # Draw a full black frame around the axes.
    for _, spine in ax.spines.items():
        spine.set_visible(True)
        spine.set_color('black')
        spine.set_linewidth(1)


In [None]:
# the ymax and ymin were chosen manually to have better visualization
ymax = 0.7
ymin = 0.0

# Fresh figure so re-running this cell does not draw on top of an earlier plot.
fig, axes = plt.subplots(1, 1, figsize=(10, 4))
plot_scatter_with_dotted_lines(data_list,
                               plot_key='rmse_mprtj',
                               title="Accuracy of MLIPs",
                               ylabel="RMSE MPRTJ (ev/atom)",
                               ax=axes,
                               ymax=ymax, ymin=ymin,
                               # fig_label="a)"
                               )

fig.tight_layout()
handles, labels = axes.get_legend_handles_labels()
# Shrink the axes to the left half of the figure to make room for the legend.
fig.subplots_adjust(right=0.49)
# put the legend at the center right of the figure
fig.legend(handles, labels, bbox_to_anchor=(0.5, 0.5), loc='center left', borderaxespad=0., ncol=1, fontsize=12)
# save_plot_dir is None unless SAVE_PLOTS is True, so only save in that case.
if SAVE_PLOTS:
    fig.savefig(save_plot_dir / "accuracy_mptrj.png", dpi=300, bbox_inches='tight')


In [None]:
# Separate figure: the previous cell's axes already hold the MPtrj curves and legend.
fig, axes = plt.subplots(1, 1, figsize=(10, 4))
plot_scatter_with_dotted_lines(data_list, plot_key='rmse_fe',
                               title="Accuracy of MLIPs",
                               ax=axes,
                               ymax=0.3,
                               ymin=0.0,
                               # fig_label="b)",
                               ylabel="RMSE Fe (ev/atom)")

fig.tight_layout()
handles, labels = axes.get_legend_handles_labels()
# Shrink the axes to the left half of the figure to make room for the legend.
fig.subplots_adjust(right=0.49)

# put the legend at the center right of the figure
fig.legend(handles, labels, bbox_to_anchor=(0.5, 0.5), loc='center left', borderaxespad=0., ncol=1, fontsize=12)
# save_plot_dir is None unless SAVE_PLOTS is True, so only save in that case.
if SAVE_PLOTS:
    fig.savefig(save_plot_dir / "accuracy_fe.png", dpi=300, bbox_inches='tight')

# Supplementary
The plot below compares freezing different layers of the CHG2 model. The values were extracted manually because the number of experiments was small. 

In [None]:

################################################################################
#
# below is comparison of freezing layers
################################################################################


def plot_scatter_(data_list: list[dict], title: str = "RMSE Comparison", ax=None, save=None):
    """Scatter each model's Fe RMSE (x) against its MPtrj RMSE (y).

    Every record in ``data_list`` needs:
      * ``"exp_Label"``  -- legend entry,
      * ``"exp_info"``   -- text annotated next to the point (``""`` for none),
      * ``"rmse_fe"`` / ``"rmse_mprtj"`` -- lists whose first value is plotted.
    Other keys (e.g. ``"exp_name"``) are optional and ignored.

    Args:
        data_list: Model records, one point each.
        title: Accepted for call-site compatibility; it is not drawn.
        ax: Axes to draw on. If ``None``, a new 8x6-inch figure is created.
        save: Write ``rmse_chgnet_.png`` (dpi 500) into ``save_plot_dir``.
            ``None`` falls back to the notebook-level ``SAVE_PLOTS`` flag.
    """
    if ax is None:
        # Only create a figure when the caller did not supply axes.
        _, ax = plt.subplots(figsize=(8, 6))
    if save is None:
        save = SAVE_PLOTS
    if not data_list:
        return

    markers = ["o", "s", "x", "+", "v", "P", "X", "*"]

    # Values are also needed for the axis limits.
    rmse_fe_values = [data["rmse_fe"][0] for data in data_list]
    rmse_mprtj_values = [data["rmse_mprtj"][0] for data in data_list]

    for idx, data in enumerate(data_list):
        exp_info = data['exp_info']
        rmse_fe = rmse_fe_values[idx]
        rmse_mprtj = rmse_mprtj_values[idx]

        # All points share one colour; markers tell the models apart (cycled if
        # there are more models than markers).
        ax.scatter(rmse_fe, rmse_mprtj, alpha=0.9, s=180, c='c',
                   marker=markers[idx % len(markers)], label=data['exp_Label'])

        # Add the description with an arrow pointing to the point
        if exp_info != "":
            ax.annotate(exp_info,
                        xy=(rmse_fe, rmse_mprtj),  # Point to annotate
                        xytext=(rmse_fe + 0.0005, rmse_mprtj),  # Text position
                        fontsize=12,
                        ha='left', va='center',
                        clip_on=True,
                        arrowprops=dict(arrowstyle='->',
                                        connectionstyle='arc3',
                                        color='black',
                                        shrinkB=2))

    # Set axis limits with padding (extra room on the right for annotations)
    ax.set_xlim(min(rmse_fe_values) - 0.0005, max(rmse_fe_values) + 0.0015)
    ax.set_ylim(min(rmse_mprtj_values) - 0.01, max(rmse_mprtj_values) + 0.01)

    ax.set_xlabel("RMSE Fe (ev/atom)", fontsize=14)
    ax.set_ylabel("RMSE MPRTJ (ev/atom)", fontsize=14)

    ax.legend()

    if save:
        ax.figure.savefig(f"{save_plot_dir}/rmse_chgnet_.png", dpi=500)



# RMSEs of the CHG2 layer-freezing runs (values copied in by hand).
chg2_1 = dict(exp_Label="Model 1", exp_info="", rmse_mprtj=[0.169], rmse_fe=[0.011])
chg2_2 = dict(exp_Label="Model 2", exp_info="", rmse_mprtj=[0.152], rmse_fe=[0.012])
chg2_3 = dict(exp_Label="Model 3", exp_info="", rmse_mprtj=[0.143], rmse_fe=[0.011])
chg2_5 = dict(exp_name="CHG2-85", exp_Label="Model 4", exp_info="",
              rmse_mprtj=[0.113], rmse_fe=[0.011])
chg2_6 = dict(exp_Label="Model 5", exp_info="Only Convolution layers trainable",
              rmse_mprtj=[0.095], rmse_fe=[0.010])
chg2_7 = dict(exp_Label="Model 6", exp_info="Naive", rmse_mprtj=[0.111], rmse_fe=[0.013])

data_list = [chg2_1, chg2_2, chg2_3, chg2_5, chg2_6, chg2_7]
plot_scatter_(data_list, title="Comparison Freezing layers for CHG2")

### Elastic properties as the training progresses.

In [None]:
# Checkpoints to read: every epoch from 1 to 9, then every 10th epoch up to 50.
epochs = list(range(1, 10)) + list(range(10, 51, 10))
base_names = ["mace-exp-1-naive-lr3-ep"]

In [None]:

def plot_for_model(base_dir, base_name, marker=None, colors=None,
                   epoch_list=None, initial=(255, 145, 54), ax=None):
    """Plot C11, C12 and C44 against training epoch for one run.

    Reads ``<base_dir>/<base_name><epoch>/elastic_tensor.csv`` for every epoch
    and uses the first row of the ``C11``, ``C12`` and ``C44`` columns. The
    ``initial`` values are plotted at epoch 0.

    Args:
        base_dir: Directory holding the per-epoch result directories.
        base_name: Directory prefix; the epoch number is appended to it.
        marker: Marker used for all three curves (default ``'o'``).
        colors: Currently unused -- the three constants keep matplotlib's colour
            cycle so they remain distinguishable. Kept for call compatibility.
        epoch_list: Epochs to read; defaults to the notebook-level ``epochs``
            list, which is no longer modified.
        initial: (C11, C12, C44) plotted at epoch 0. The defaults are the values
            that were hard-coded here -- presumably the starting model's
            prediction; confirm.
        ax: Target axes; defaults to the current pyplot axes.
    """
    if marker is None:
        marker = 'o'
    if epoch_list is None:
        epoch_list = epochs
    if ax is None:
        ax = plt.gca()

    c11s, c12s, c44s = [initial[0]], [initial[1]], [initial[2]]
    for epoch in epoch_list:
        df = pd.read_csv(Path(base_dir) / f"{base_name}{epoch}" / "elastic_tensor.csv")
        c11s.append(df['C11'].values[0])
        c12s.append(df['C12'].values[0])
        c44s.append(df['C44'].values[0])

    # Prepend epoch 0 on a new list. Inserting into the shared `epochs` list made
    # every later call also read a "<base_name>0" directory and grow the list.
    x_epochs = [0, *epoch_list]

    ax.plot(x_epochs, c11s, marker=marker, label='C11')
    ax.plot(x_epochs, c12s, marker=marker, label='C12')
    ax.plot(x_epochs, c44s, marker=marker, label='C44')

# One marker/colour per run; zip stops at the shortest list (one entry per base name).
markers = ['o', 's', 'D', 'x', '^', 'v', 'p', '*']
colors = ['#1f77b4', "#6E3706", '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f']

for base_name, marker, color in zip(base_names, markers, colors):
    print(f"{base_name = }")
    plot_for_model(base_dir, base_name, marker=marker, colors=color)

plt.gca().legend()
# NOTE(review): the title says CHGNet while `base_names` currently lists MACE runs -- confirm.
plt.gca().set_title('Elastic Constants vs Epochs for CHGNet')
plt.show()
