In [1]:
import os
import sys

# Put the parent directory (presumably the project root containing `src/`)
# first on the import path so `from src import loading_data` works when the
# notebook is run from its own directory.
sys.path.insert(0, os.path.abspath(".."))

In [2]:
import  json
import pandas as pd
from sklearn.metrics import mean_absolute_error
from src import loading_data as ld

# -------------------------------------------------------------------
# Load host–density index mapping from JSON
# -------------------------------------------------------------------
# JSON text is UTF-8 by specification; pass the encoding explicitly so the
# result does not depend on the platform's default (locale) encoding.
# Presumably maps host -> density -> list of sample labels; it is indexed as
# index_dict[host][density] by filt_host below.
with open("../dataset/labels/2dmd_host_density.json", "r", encoding="utf-8") as file:
    index_dict = json.load(file)

# Host–density pairs to evaluate separately. ("all", "combine") is treated as
# the full test set by SepHostPrediction.get_sep_results instead of being
# looked up in index_dict.
pair_host_density = [
    ("all", "combine"),
    ("P", "high"),
    ("GaSe", "high"),
    ("InSe", "high"),
    ("MoS2", "high"),
    ("WSe2", "high"),
    ("BN", "high"),
    ("MoS2", "low"),
    ("WSe2", "low"),
]

# -------------------------------------------------------------------
# Helper function: filter DataFrame rows by host窶電ensity subset
# -------------------------------------------------------------------
def filt_host(df, host, density, index_dict=index_dict):
    """Return the rows of ``df`` that belong to one host–density subset.

    Args:
        df (pd.DataFrame): Frame whose index holds sample labels.
        host (str): Host key into ``index_dict`` (e.g. ``"MoS2"``).
        density (str): Density key under that host (e.g. ``"high"``).
        index_dict (dict): Nested ``{host: {density: labels}}`` mapping;
            defaults to the mapping loaded from JSON when this cell ran.

    Returns:
        pd.DataFrame: The rows of ``df`` whose labels are listed for the
        subset. Listed labels that are absent from ``df`` are ignored.

    Raises:
        KeyError: If ``host`` or ``density`` is missing from ``index_dict``.
    """
    subset_labels = index_dict[host][density]
    # NOTE(review): labels read from JSON must match the dtype of df.index
    # (e.g. str vs int); otherwise the intersection is silently empty.
    present_labels = df.index.intersection(subset_labels)
    return df.loc[present_labels, :]


# -------------------------------------------------------------------
# Load global configs
# -------------------------------------------------------------------
configs = ld.load_config_file("../configs/r4_config_allv2.yaml")

# First database / dataset entry from the config, plus all targets and
# feature sets.
# NOTE(review): `main` only passes a target to SepHostPrediction, so
# `database`, `dataset` and `target` defined here are not used by it; the
# class uses its own "2dmd" / "all_density" defaults instead.
database, dataset = configs["database"][0], configs["dataset"][0]
targets = configs["target_column"]
feature_sets = configs["feature_set"]

# For now, just take the first target
target = targets[0]


# -------------------------------------------------------------------
# Class for host窶電ensity specific predictions
# -------------------------------------------------------------------
class SepHostPrediction:
    """Host–density specific evaluation of trained CatBoost models.

    For one target property, this class loads the saved CatBoost model and
    test split for every feature set. It then computes the sample-weighted MAE
    on each host–density subset listed in the module-level
    ``pair_host_density``. Finally, it formats a subset of those results as a
    LaTeX-ready table next to the reference numbers of Kazeev et al.

    Relies on the module-level names ``ld``, ``configs``, ``index_dict``,
    ``pair_host_density``, ``filt_host``, ``mean_absolute_error`` and ``pd``.
    """

    # Feature sets shown in the LaTeX table, in column order, mapped to their
    # LaTeX column headers. This is the single source used for selecting,
    # ordering and renaming columns in ``filter_sep_results``.
    _TABLE_FEATURE_SETS = {
        "cfid": "Original\\\\ CFID",
        "vpa_divi_chem_hellinger_l1_pristine_alldist_cfid": "PF-Division\\\\ + Hellinger Distances",
        "chem_dist0_cfid": "Original\\\\ CFID$^*$",
        "vpa_divi_chem_hellinger_l1_pristine_cfid": "PF-Division\\\\ + Hellinger Distances$^*$",
    }

    # Host names rendered through LaTeX macros in the table.
    _HOST_LATEX_NAMES = {
        "MoS2": r"\mos2",
        "WSe2": r"\wse2",
    }

    # Header of the reference-results column. The backslash before "cite" is
    # escaped explicitly: a bare "\c" is an invalid escape sequence
    # (SyntaxWarning on Python 3.12+). The runtime string is unchanged.
    _KAZEEV_COLUMN = "previous \\\\ work \\cite{kazeev2023sparse}"

    # Fixed coordinates of the saved results passed to ``ld.load_results``.
    _MODEL_NAME = "CatBoostRegressor"
    _RESULT_DIRNAME = "results_2"
    _TUNED_OPTIMIZE = "selected_best_random_100"

    def __init__(self, target, database="2dmd", dataset="all_density", feature_sets=None):
        """
        Manage host–density specific evaluation for a given target property.

        Args:
            target (str): Target property (e.g., 'formation_energy_per_site').
            database (str): Database name.
            dataset (str): Dataset name.
            feature_sets (list): List of feature sets to evaluate. Defaults to
                ``configs["feature_set"]`` from the module-level config.
        """
        self.target = target
        self.database = database
        self.dataset = dataset
        self.feature_sets = feature_sets if feature_sets is not None else configs["feature_set"]

        # Reference results from Kazeev et al. (for comparison in LaTeX tables).
        # Each list has one entry per host–density pair. The entries are
        # assigned positionally by ``add_kazeev_results``, so they must follow
        # the order of ``pair_host_density``.
        self.kazeev_results = {
            "formation_energy_per_site": [
                "164$\\pm$5",
                "382$\\pm$30",
                "103$\\pm$4",
                "137$\\pm$5",
                "136$\\pm$5",
                "162$\\pm$6",
                "363$\\pm$17",
                "12.6$\\pm$0.4",
                "16.3$\\pm$0.8",
            ],
            "homo_lumo_gap_min": [
                "117$\\pm$1",
                "174$\\pm$2",
                "173$\\pm$4",
                "155$\\pm$1",
                "71$\\pm$4",
                "106$\\pm$6",
                "208$\\pm$3",
                "26.7$\\pm$0.8",
                "18.3$\\pm$0.6",
            ],
        }

        # Last table produced by ``filter_sep_results``.
        self.selected_feature_results = None

    # ---------------------------------------------------------------
    # Compute host–density specific results for all feature sets
    # ---------------------------------------------------------------
    def _load_artifact(self, filename, feature_set, optimize=None):
        """Load one saved artifact (model or CSV) for ``feature_set`` via ``ld.load_results``."""
        # NOTE(review): "model.pkl" is presumably unpickled by ld.load_results;
        # only point this at result directories you trust.
        return ld.load_results(
            filename,
            self.database,
            feature_set,
            self.dataset,
            self.target,
            self._MODEL_NAME,
            optimize=optimize,
            result_dirname=self._RESULT_DIRNAME,
        )

    def get_sep_results(self):
        """
        Load trained models and evaluate the weighted MAE for each host–density subset.

        Returns:
            pd.DataFrame: MAE results indexed by feature_set, with one
            ``"{host}_{density}"`` column per entry of ``pair_host_density``.

        Raises:
            ValueError: If a host–density subset selects no test rows.
        """
        sep_results = []

        for feature_set in self.feature_sets:
            # Trained (tuned) model plus the untuned test split and weights.
            model = self._load_artifact("model.pkl", feature_set, optimize=self._TUNED_OPTIMIZE)
            X_test = self._load_artifact("X_test.csv", feature_set)
            # The first column of each CSV holds the values.
            y_test = self._load_artifact("y_test.csv", feature_set).iloc[:, 0]
            sample_weight = self._load_artifact("sample_weight.csv", feature_set).iloc[:, 0]

            host_density_mae_dict = {"feature_set": feature_set}

            for host, density in pair_host_density:
                if (host, density) == ("all", "combine"):
                    X_subset = X_test
                else:
                    X_subset = filt_host(X_test, host, density, index_dict=index_dict)

                if len(X_subset) == 0:
                    raise ValueError(
                        f"No test rows for host={host!r}, density={density!r} "
                        f"(feature_set={feature_set!r}); check that index_dict labels "
                        "match the test-set index."
                    )

                # Align targets and weights with the evaluated rows by label,
                # for the full set as well as for subsets.
                y_subset = y_test.loc[X_subset.index]
                weight_subset = sample_weight.loc[X_subset.index]

                host_density_mae_dict[f"{host}_{density}"] = mean_absolute_error(
                    y_subset, model.predict(X_subset), sample_weight=weight_subset
                )

            sep_results.append(host_density_mae_dict)

        return pd.DataFrame(sep_results).set_index("feature_set")

    # ---------------------------------------------------------------
    # Filter and format results for selected feature sets
    # ---------------------------------------------------------------
    def filter_sep_results(self, sep_results_df):
        """
        Select the table's feature sets, rescale MAE by 1000 (to meV), and
        reshape into one row per host–density pair for a LaTeX table.

        Args:
            sep_results_df (pd.DataFrame): Output of ``get_sep_results``. It must
                contain every feature set in ``_TABLE_FEATURE_SETS`` as a row.

        Returns:
            pd.DataFrame: Columns ``host``, ``density`` and one LaTeX-labelled
            column per selected feature set. The table is also stored on
            ``self.selected_feature_results``.
        """
        feature_sets = list(self._TABLE_FEATURE_SETS)

        table = (
            (sep_results_df.loc[feature_sets, :] * 1000)
            .round(1)
            .T
            .rename_axis(index="host_density")
            .reset_index()
        )

        # Column keys look like "MoS2_low": host before the first "_", density after.
        key_parts = table["host_density"].str.split("_")
        table = table.assign(host=key_parts.str[0], density=key_parts.str[1])

        table = (
            table.loc[:, ["host", "density", *feature_sets]]
            .rename(columns=self._TABLE_FEATURE_SETS)
            .assign(host=lambda frame: frame["host"].replace(self._HOST_LATEX_NAMES))
        )

        self.selected_feature_results = table
        return table

    # ---------------------------------------------------------------
    # Add Kazeev et al. results for comparison
    # ---------------------------------------------------------------
    def add_kazeev_results(self, df):
        """
        Return a copy of ``df`` with the Kazeev et al. reference results appended.

        The reference values are assigned positionally, so ``df`` rows must
        follow the order of ``pair_host_density``. ``df`` itself is not modified.

        Raises:
            KeyError: If there are no reference results for ``self.target``.
            ValueError: If ``df`` does not have one row per reference value.
        """
        kz_result = self.kazeev_results[self.target]
        if len(kz_result) != len(df):
            raise ValueError(
                f"Expected {len(kz_result)} rows for target {self.target!r}, got {len(df)}."
            )
        return df.assign(**{self._KAZEEV_COLUMN: kz_result})


In [3]:
def main(target):
    """Build the LaTeX-ready per host–density MAE table for ``target``.

    Uses ``SepHostPrediction``'s default database, dataset and feature sets.
    It evaluates every feature set, keeps the table's selected ones, and
    appends the Kazeev et al. reference column.
    """
    predictor = SepHostPrediction(target)
    table = predictor.filter_sep_results(predictor.get_sep_results())
    return predictor.add_kazeev_results(table)


In [4]:

# Formation energy per site: MAE (x1000, i.e. meV) for each host–density subset.
eform_prediction_df = main("formation_energy_per_site")
eform_prediction_df

feature_set,host,density,Original\\ CFID,PF-Division\\ + Hellinger Distances,Original\\ CFID$^*$,PF-Division\\ + Hellinger Distances$^*$,previous \\ work \cite{kazeev2023sparse}
0,all,combine,154.6,141.0,126.9,135.7,164$\pm$5
1,P,high,82.3,76.3,81.8,76.2,382$\pm$30
2,GaSe,high,140.5,135.0,108.7,113.6,103$\pm$4
3,InSe,high,131.8,106.1,88.1,88.1,137$\pm$5
4,\mos2,high,170.1,159.1,121.7,177.8,136$\pm$5
5,\wse2,high,260.3,225.6,192.4,223.7,162$\pm$6
6,BN,high,359.1,341.6,356.5,331.4,363$\pm$17
7,\mos2,low,50.9,44.9,33.7,40.0,12.6$\pm$0.4
8,\wse2,low,41.9,39.3,32.4,34.9,16.3$\pm$0.8


In [5]:
# Minimum HOMO–LUMO gap: MAE (x1000, i.e. meV) for each host–density subset.
gap_prediction_df = main("homo_lumo_gap_min")
gap_prediction_df

feature_set,host,density,Original\\ CFID,PF-Division\\ + Hellinger Distances,Original\\ CFID$^*$,PF-Division\\ + Hellinger Distances$^*$,previous \\ work \cite{kazeev2023sparse}
0,all,combine,112.7,112.0,116.9,117.8,117$\pm$1
1,P,high,152.1,148.1,152.1,145.7,174$\pm$2
2,GaSe,high,173.3,181.7,194.5,195.8,173$\pm$4
3,InSe,high,149.6,146.9,152.7,147.8,155$\pm$1
4,\mos2,high,64.2,60.4,81.0,88.2,71$\pm$4
5,\wse2,high,89.7,84.9,75.6,80.9,106$\pm$6
6,BN,high,193.0,199.0,210.5,213.9,208$\pm$3
7,\mos2,low,43.3,39.9,36.5,37.8,26.7$\pm$0.8
8,\wse2,low,36.5,35.2,32.2,32.7,18.3$\pm$0.6
