In [2]:
from DCC import *
from Utils import *
from Plots import *
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
from collections import defaultdict

init_plotting()

dataset2name = {
    "Bala_classification_dataset.csv": "Bala Classification",
    "Bala_regression_dataset.csv": "Bala Regression",
    "bandgap.csv": "Bandgap",
    "BMDS_data.csv": "BMDS",
    "Crystal_structure.csv": "Crystal Structure",
    "Glass.csv": "Glass",
    "PUE.csv": "PUE",
}

In [3]:
from enum import Enum
from sklearn.datasets import make_classification


class PerturbationType(Enum):
    Deletion = "Deletion"
    AdditionLinear = "AdditionLinear"
    AdditionRand = "AdditionRand"
    ReplacementLinear = "ReplacementLinear"
    ReplacementRand = "ReplacementRand"


def generate_random_according_df(
    n: int, m: int, df: pd.DataFrame, task_type=None, target_col=None
):
    cols = df.columns
    new_data = []
    for i in range(n):
        new_row = []
        for j in range(m):
            new_row.append(random.uniform(df[cols[j]].min(), df[cols[j]].max()))
        new_data.append(new_row)
    new_df = pd.DataFrame(new_data, columns=cols)
    if task_type == "classification":
        if target_col is not None:
            new_df[target_col] = np.random.choice(
                df[target_col].unique(), size=n, replace=True
            )
    return new_df


def generate_linear_according_df(
    n: int, m: int, df: pd.DataFrame, task_type=None, target_col=None
):
    if task_type == "classification":
        # Generate a classification dataset
        X, y = make_classification(
            n_samples=n,
            n_features=m - 1,
            n_classes=len(df[target_col].unique()),
            n_informative= (m - 1) // 2,
        )[:2]
    else:
        X, y = make_regression(
            n_samples=n, n_features=m - 1, noise=0.1
        )[:2]

    X_cols = [x for x in df.columns if x != target_col]
    df_linear = pd.DataFrame(X, columns=X_cols)
    df_linear[target_col] = y
    # normalize the data to the range of df
    for col in df.columns:
        min_val = df[col].min()
        max_val = df[col].max()
        df_linear[col] = (df_linear[col] - df_linear[col].min()) / (
            df_linear[col].max() - df_linear[col].min()
        ) * (max_val - min_val) + min_val

    return df_linear


def perturbate(
    data: pd.DataFrame,
    ptb_type: PerturbationType,
    task_type,
    target_col,
    ratio=0.1,
):
    n, m = data.shape
    if ptb_type == PerturbationType.Deletion:
        perturbed_data = data.sample(frac=1 - ratio)
    elif ptb_type == PerturbationType.AdditionLinear:
        df_linear = generate_linear_according_df(
            n // 20, m, data, task_type, target_col
        )
        perturbed_data = pd.concat([data, df_linear], ignore_index=True)
    elif ptb_type == PerturbationType.AdditionRand:
        df_rand = generate_random_according_df(n // 20, m, data, task_type, target_col)
        perturbed_data = pd.concat([data, df_rand], ignore_index=True)
    elif ptb_type == PerturbationType.ReplacementLinear:
        df_linear = generate_linear_according_df(
            n // 20, m, data, task_type, target_col
        )
        perturbed_data = random_replace_rows(data, df_linear)
    elif ptb_type == PerturbationType.ReplacementRand:
        df_rand = generate_random_according_df(n // 20, m, data, task_type, target_col)
        perturbed_data = random_replace_rows(data, df_rand)
    else:
        raise ValueError(f"Unknown perturbation type: {ptb_type}")

    return perturbed_data

In [4]:
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import KFold, StratifiedKFold
from sklearn.metrics import r2_score, accuracy_score
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from joblib import Parallel, delayed
import shap

dataset_dir = Path("processed_data")
datasets = list(dataset_dir.glob("*.csv"))
dataset_config = {
    "Bala_classification_dataset.csv": {
        "target_col": "Formability",
        "type": "classification",
    },
    "Bala_regression_dataset.csv": {
        "target_col": "Ferroelectric_Tc_in_Kelvin",
        "type": "regression",
    },
    "bandgap.csv": {
        "target_col": "target",
        "type": "regression",
    },
    "BMDS_data.csv": {
        "target_col": "soc_bandgap",
        "type": "regression",
    },
    "Crystal_structure.csv": {
        "target_col": "Lowest distortion",
        "type": "classification",
    },
    "Glass.csv": {
        "target_col": "Type of glass",
        "type": "classification",
    },
    "PUE.csv": {
        "target_col": "logYM",
        "type": "regression",
    },
}
def get_shap_values(df, target_col, type):
    X = df.drop(columns=[target_col])
    y = df[target_col]
    if type == "regression":
        model = RandomForestRegressor(n_estimators=500, random_state=42)
    elif type == "classification":
        model = RandomForestClassifier(n_estimators=500, random_state=42)
    else:
        raise ValueError("type must be either 'regression' or 'classification'")
    model.fit(X, y)

    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)
    # shap.summary_plot(shap_values, X, plot_type="bar")
    if type == "classification":
        # For classification, shap_values is a list of arrays, one for each class
        # We take the mean absolute value across all classes
        mean_abs_shap = np.mean(np.mean(np.abs(shap_values), axis=0), axis=1)
    else:
        mean_abs_shap = np.mean(np.abs(shap_values), axis=0)
    # print(f"Mean absolute SHAP values: {mean_abs_shap}")
    feats2shap = {col: mean_abs_shap[i] for i, col in enumerate(X.columns)}
    return feats2shap


def get_corr_info(df, corr_func, target_col):
    corr_matrix = corr_func(df)
    feats2corr = {}
    for col in df.columns:
        if col == target_col:
            continue
        feats2corr[col] = corr_matrix.loc[target_col, col]
    return feats2corr

In [5]:
import pickle
from joblib import Parallel, delayed
import logging


def get_logger():
    logger = logging.getLogger("worker")
    if not logger.hasHandlers():
        logger.setLevel(logging.INFO)
        handler = logging.StreamHandler()
        formatter = logging.Formatter('%(process)d - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger

# Define correlation functions and perturbation ratios
correlation_functions = [
    pearson_matrix,
    spearman_matrix,
    kendall_matrix,
    mutual_info_matrix,
    js_corr_matrix,
    wd_corr_matrix,
    xi_matrix,
    dcor_matrix,
]

def process_single_dataset(dataset_path, ptb_ratio):
    """Process a single dataset with given perturbation ratio and correlation function."""
    logger = get_logger()
    dataset_results = {}
    try:
        df = pd.read_csv(dataset_path)
        dataset_name = dataset_path.name
        target_col = dataset_config[dataset_name]["target_col"]
        task_type = dataset_config[dataset_name]["type"]
        
        # Calculate original SHAP values
        original_shap = get_shap_values(df, target_col, task_type)
        
        dataset_results["ori_feats2shap"] = original_shap
        dataset_results["name"] = dataset2name[dataset_name]
        dataset_results["type"] = task_type
        dataset_results["target_col"] = target_col
        
        # Initialize perturbation result containers
        for ptb_type in PerturbationType:
            dataset_results[ptb_type.value] = {}
        
        # Process each perturbation type
        for ptb_type in PerturbationType:
            modified_df = perturbate(
                df, ptb_type, task_type, target_col, 
                ratio=ptb_ratio
            )
            
            perturbed_shap = get_shap_values(modified_df, target_col, task_type)
            
            # dataset_results[ptb_type.value]["dcc"] = dcc_result
            dataset_results[ptb_type.value]["feats2shap"] = perturbed_shap
            for corr_func in correlation_functions:
                correlation_info = get_corr_info(modified_df, corr_func, target_col)
                dataset_results[ptb_type.value][f"feats2corr_{corr_func.__name__}"] = correlation_info
                dataset_results[ptb_type.value][f"dcc_{corr_func.__name__}"] = dcc_diff_features(df, modified_df, target_col, eps=0.04, corr_func=corr_func)
            
        return dataset_name, dataset_results
        
    except Exception as e:
        logger.error(f"Error processing {dataset_path.name}: {str(e)}")
        return dataset_path.name, None

def process_ratio(ptb_ratio):
    """Process all datasets for a specific ratio and correlation function combination."""
    
    logger = get_logger()
    
    dataset_paths = list(Path("processed_data").glob("*.csv"))
    
    # Parallel processing of datasets
    dataset_results = Parallel(n_jobs=-1, verbose=1)(
        delayed(process_single_dataset)(dataset_path, ptb_ratio) 
        for dataset_path in dataset_paths
    )
    # Consolidate results
    consolidated_results = defaultdict(dict)
    for dataset_name, result_data in dataset_results:
        if result_data is not None:
            consolidated_results[dataset_name] = result_data
    
    # Save consolidated results
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)
    output_file = output_dir / f"SHAP_results_{ptb_ratio}.pkl"
    
    with open(output_file, "wb") as file_handle:
        pickle.dump(dict(consolidated_results), file_handle)
    
    logger.info(f"Saved results to {output_file}")
    return output_file

In [5]:
# process_single_dataset(Path("processed_data/PUE.csv"), 0.05)

In [6]:
perturbation_ratios = [0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15]


# Parallel processing of ratio-correlation combinations
# Using fewer jobs for outer loop to avoid overwhelming the system
processed_files = Parallel(n_jobs=-1, verbose=2)(
    delayed(process_ratio)(ratio)
    for ratio in perturbation_ratios
)

logging.info(f"All processing complete. Generated {len(processed_files)} result files.")

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 32 concurrent workers.
[Parallel(n_jobs=-1)]: Done   2 out of  11 | elapsed: 79.9min remaining: 359.6min
[Parallel(n_jobs=-1)]: Done   8 out of  11 | elapsed: 81.7min remaining: 30.6min
[Parallel(n_jobs=-1)]: Done  11 out of  11 | elapsed: 82.0min finished


In [8]:
perturbation_ratios = [0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15]
rows = []
for ratio in perturbation_ratios:
    file_path = Path(f"results/SHAP_results_{ratio}.pkl")
    if file_path.exists():
        with open(file_path, "rb") as file_handle:
            results = pickle.load(file_handle)
            for dataset_name in results.keys():
                dataset_results = results[dataset_name]
                print(f"Processing dataset: {dataset_name} with ratio: {ratio}")
                for ptb_type in PerturbationType:
                    if ptb_type.value in dataset_results:

                        feats2shap = dataset_results[ptb_type.value].get("feats2shap", {})
                        feats = list(feats2shap.keys())
                        for feat in feats:
                            row = {}
                            row["dataset"] = dataset_name
                            row["perturbation_type"] = ptb_type.value
                            row["perturbation_ratio"] = ratio
                            row['SHAP_feat'] = feats2shap[feat]
                            row['SHAP_feat_ratio'] = feats2shap[feat] / max(feats2shap.values())
                            row['max_shap_feat'] = list(feats2shap.keys())[np.argmax(list(feats2shap.values()))]
                            row["feat"] = feat
                            for corr_func in correlation_functions:
                                feats2corr = dataset_results[ptb_type.value].get(f"feats2corr_{corr_func.__name__}", {})
                                dcc_result = dataset_results[ptb_type.value].get(f"dcc_{corr_func.__name__}", {})
                                row[f"DCC_{corr_func.__name__}"] = dcc_result.get(feat, None)
                                row[f"Corr_{corr_func.__name__}"] = feats2corr.get(feat, None)
                            rows.append(row)
df_results = pd.DataFrame(rows)
df_results.head()

Processing dataset: Bala_classification_dataset.csv with ratio: 0.05
Processing dataset: Bala_regression_dataset.csv with ratio: 0.05
Processing dataset: bandgap.csv with ratio: 0.05
Processing dataset: BMDS_data.csv with ratio: 0.05
Processing dataset: Crystal_structure.csv with ratio: 0.05
Processing dataset: Glass.csv with ratio: 0.05
Processing dataset: PUE.csv with ratio: 0.05
Processing dataset: Bala_classification_dataset.csv with ratio: 0.06
Processing dataset: Bala_regression_dataset.csv with ratio: 0.06
Processing dataset: bandgap.csv with ratio: 0.06
Processing dataset: BMDS_data.csv with ratio: 0.06
Processing dataset: Crystal_structure.csv with ratio: 0.06
Processing dataset: Glass.csv with ratio: 0.06
Processing dataset: PUE.csv with ratio: 0.06
Processing dataset: Bala_classification_dataset.csv with ratio: 0.07
Processing dataset: Bala_regression_dataset.csv with ratio: 0.07
Processing dataset: bandgap.csv with ratio: 0.07
Processing dataset: BMDS_data.csv with ratio: 0

Unnamed: 0,dataset,perturbation_type,perturbation_ratio,SHAP_feat,SHAP_feat_ratio,max_shap_feat,feat,DCC_pearson_matrix,Corr_pearson_matrix,DCC_spearman_matrix,...,DCC_mutual_info_matrix,Corr_mutual_info_matrix,DCC_js_corr_matrix,Corr_js_corr_matrix,DCC_wd_corr_matrix,Corr_wd_corr_matrix,DCC_xi_matrix,Corr_xi_matrix,DCC_dcor_matrix,Corr_dcor_matrix
0,Bala_classification_dataset.csv,Deletion,0.05,0.031496,0.297431,Mendeleev_Number,Compound,1.0,-0.100841,1.0,...,1.0,0.024555,1.0,0.313147,1.0,0.527515,0.357143,0.180511,1.0,0.135073
1,Bala_classification_dataset.csv,Deletion,0.05,0.027763,0.262175,Mendeleev_Number,x(BiMe1Me2)O3,0.928571,-0.395085,1.0,...,0.928571,0.077559,1.0,0.528265,1.0,0.568063,0.5,0.229435,1.0,0.377583
2,Bala_classification_dataset.csv,Deletion,0.05,0.036863,0.348111,Mendeleev_Number,Me1,0.928571,-0.094558,1.0,...,1.0,0.091084,1.0,0.320797,1.0,0.437985,0.214286,0.204973,1.0,0.175186
3,Bala_classification_dataset.csv,Deletion,0.05,0.026647,0.251633,Mendeleev_Number,Me2,0.928571,-0.246314,1.0,...,1.0,0.025397,1.0,0.390185,1.0,0.42069,0.214286,0.119355,1.0,0.253483
4,Bala_classification_dataset.csv,Deletion,0.05,0.009097,0.08591,Mendeleev_Number,frac-Me1,1.0,0.017028,1.0,...,1.0,0.012468,1.0,0.547022,1.0,0.90235,0.142857,-0.002957,1.0,0.074559


In [9]:
df_results.columns

Index(['dataset', 'perturbation_type', 'perturbation_ratio', 'SHAP_feat',
       'SHAP_feat_ratio', 'max_shap_feat', 'feat', 'DCC_pearson_matrix',
       'Corr_pearson_matrix', 'DCC_spearman_matrix', 'Corr_spearman_matrix',
       'DCC_kendall_matrix', 'Corr_kendall_matrix', 'DCC_mutual_info_matrix',
       'Corr_mutual_info_matrix', 'DCC_js_corr_matrix', 'Corr_js_corr_matrix',
       'DCC_wd_corr_matrix', 'Corr_wd_corr_matrix', 'DCC_xi_matrix',
       'Corr_xi_matrix', 'DCC_dcor_matrix', 'Corr_dcor_matrix'],
      dtype='object')

In [11]:
df_results.to_csv("results/SHAP_results.csv", index=False)