In [None]:
try:
    from pathlib import Path
    import pandas as pd
    import matplotlib.pyplot as plt
    from lib.lib_utils import Utils
    import seaborn as sns
    from lib.lib_defect_analysis import Features
    from tqdm import tqdm
    import numpy as np
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
    import xgboost as xgb
    from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
    from tqdm import tqdm
    from matplotlib.ticker import MaxNLocator
    import cv2
    import json
    import pywt
    import xgboost as xgb
    import lightgbm as lgb
    from matplotlib.colors import LinearSegmentedColormap
    import shap
except Exception as e:
    print(f"Some module are missing: {e}\n")

In [None]:
# Every folder used by the notebook lives under "data_234", resolved against
# the current working directory.
data_path = Path().resolve() / "data_234"

xyz_files_path = data_path / "xyz_files"
yolo_model_path = data_path / "models" / "best.pt"
images_path = data_path / "images"
crops_path = data_path / "crops"
plot_path = data_path / "plots"
pred_path = data_path / "predictions"

# Only the plot and prediction folders are created here (no error if present).
plot_path.mkdir(parents=True, exist_ok=True)
pred_path.mkdir(parents=True, exist_ok=True)

# Plot styling: a two-colour palette plus a gradient colormap that runs from
# the second palette colour to the first.
plt.style.use("seaborn-v0_8-paper")
colors = ["#F0741E", "#276CB3"]
custom_cmap = LinearSegmentedColormap.from_list("custom_cmap", colors[::-1])

In [None]:
# Value forwarded to the renderer as ``max_dim``.
# NOTE(review): the meaning/units of these two numbers are not visible in this
# notebook — confirm against Utils.from_xyz_to_png.
max_dim = [39.53476932, 34.27629786]

# Project helper called with the .xyz folder and the image folder.
Utils.from_xyz_to_png(
    xyz_files_path,
    images_path,
    multiplier=6,
    max_dim=max_dim,
)

# Project helper called with the image folder, the crop folder and the YOLO
# weights file; runs on CPU with a 0.75 confidence setting.
Utils.generate_yolo_crops(
    images_path,
    crops_path,
    yolo_model_path,
    confidence=0.75,
    device="cpu",
    binary_mask=True,
)

In [None]:
def calculate_entropy(data: np.ndarray, bins: int = 256) -> float:
    """Base-2 entropy of the value histogram of ``data``.

    Args:
        data: Array of any shape; it is flattened before binning.
        bins: Number of equal-width bins. The histogram spans ``[0, bins]``,
            so values outside that interval do not contribute.

    Returns:
        ``-sum(p * log2(p + 1e-10))`` over the density-normalised histogram.
    """
    density, _edges = np.histogram(
        data.flatten(), bins=bins, range=(0, bins), density=True
    )
    # The small offset keeps log2 finite for empty bins.
    weighted_log = density * np.log2(density + 1e-10)
    return -np.sum(weighted_log)


def clalc_shift_spect(img):
    """Return the 2-D FFT of ``img`` and its two-level Haar wavelet decomposition.

    Despite the name, no ``fftshift`` is applied here; the spectrum is returned
    exactly as produced by ``np.fft.fft2``.

    Args:
        img: Image array passed unchanged to both transforms.

    Returns:
        A ``(spectrum, coeffs)`` tuple, where ``coeffs`` is the result of
        ``pywt.wavedec2(img, "haar", level=2)``.
    """
    return np.fft.fft2(img), pywt.wavedec2(img, "haar", level=2)


def extract_frequency_features(
    image: Path | np.ndarray, wavelet: str = "db4", bins: int = 256
) -> dict:
    """Compute FFT- and wavelet-based descriptors for a single image.

    Args:
        image: Either a path to an image file (read as grayscale with OpenCV)
            or an already loaded array. Three-dimensional arrays are converted
            with ``cv2.COLOR_BGR2GRAY``; the caller's array is not modified.
        wavelet: Currently unused. The decomposition comes from
            ``clalc_shift_spect``, which hard-codes the Haar wavelet. The
            parameter is kept so existing callers keep working.
        bins: Number of histogram bins (and upper histogram bound) used for
            every entropy value.

    Returns:
        A flat dict of scalars: ten statistics of the log-magnitude spectrum
        followed by 21 wavelet statistics (energy, standard deviation and
        entropy for each of the seven sub-bands).

    Raises:
        OSError: If ``image`` is a path that OpenCV cannot read.
    """
    if isinstance(image, Path):
        # Only paths can be resolved; arrays have no ``resolve`` method, so the
        # call has to live inside this branch.
        image_file = image.resolve()
        img = cv2.imread(str(image_file), cv2.IMREAD_GRAYSCALE)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {image_file}")
    else:
        img = image.copy()
        if len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    f, coeffs = clalc_shift_spect(img)
    fshift = np.fft.fftshift(f)
    # Log-scaled magnitude (natural log); the offset avoids log(0).
    magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1e-10)

    mean_freq = np.mean(magnitude_spectrum)
    std_freq = np.std(magnitude_spectrum)
    max_freq = np.max(magnitude_spectrum)
    min_freq = np.min(magnitude_spectrum)
    median_freq = np.median(magnitude_spectrum)
    # Energy is taken from the raw (non-log) shifted spectrum.
    energy = np.sum(np.abs(fshift) ** 2)

    rows, cols = magnitude_spectrum.shape
    crow, ccol = rows // 2, cols // 2

    # Same histogram entropy as used for the wavelet sub-bands below.
    entropy = calculate_entropy(magnitude_spectrum, bins)

    # NOTE(review): these are sums over the top-left and bottom-right quadrants
    # of the shifted log-magnitude spectrum. After fftshift the zero frequency
    # sits at the centre, so the "low"/"high" naming may not match a radial
    # frequency split — confirm intent before changing, as it alters features.
    low_freq_energy = np.sum(magnitude_spectrum[:crow, :ccol])
    high_freq_energy = np.sum(magnitude_spectrum[crow:, ccol:])
    frequency_contrast = high_freq_energy - low_freq_energy

    # Two-level decomposition: approximation plus (H, V, D) details per level.
    cA2, (cH2, cV2, cD2), (cH1, cV1, cD1) = coeffs

    wavelet_features = {
        "wavelet_energy_A2": np.sum(np.square(cA2)),
        "wavelet_energy_H2": np.sum(np.square(cH2)),
        "wavelet_energy_V2": np.sum(np.square(cV2)),
        "wavelet_energy_D2": np.sum(np.square(cD2)),
        "wavelet_std_A2": np.std(cA2),
        "wavelet_std_H2": np.std(cH2),
        "wavelet_std_V2": np.std(cV2),
        "wavelet_std_D2": np.std(cD2),
        "wavelet_energy_H1": np.sum(np.square(cH1)),
        "wavelet_energy_V1": np.sum(np.square(cV1)),
        "wavelet_energy_D1": np.sum(np.square(cD1)),
        "wavelet_std_H1": np.std(cH1),
        "wavelet_std_V1": np.std(cV1),
        "wavelet_std_D1": np.std(cD1),
        "wavelet_entropy_A2": calculate_entropy(cA2, bins),
        "wavelet_entropy_H2": calculate_entropy(cH2, bins),
        "wavelet_entropy_V2": calculate_entropy(cV2, bins),
        "wavelet_entropy_D2": calculate_entropy(cD2, bins),
        "wavelet_entropy_H1": calculate_entropy(cH1, bins),
        "wavelet_entropy_V1": calculate_entropy(cV1, bins),
        "wavelet_entropy_D1": calculate_entropy(cD1, bins),
    }

    frequency_features = {
        "mean_frequency": mean_freq,
        "std_frequency": std_freq,
        "max_frequency": max_freq,
        "min_frequency": min_freq,
        "median_frequency": median_freq,
        "energy": energy,
        "entropy": entropy,
        "low_frequency_energy": low_freq_energy,
        "high_frequency_energy": high_freq_energy,
        "frequency_contrast": frequency_contrast,
    }

    frequency_features.update(wavelet_features)

    return frequency_features

In [None]:
# Names of the feature columns. Each name appears exactly once: a repeated
# entry would select the same column twice when the list is used to index a
# DataFrame.
features_list = [
    # Shape and edge descriptors.
    "area",
    "perimeter",
    "circularity",
    "solidity",
    "compactness",
    "feret_diameter",
    "edge_density",
    "eccentricity",
    "number_of_edges",
    # GLCM texture descriptors.
    "GLCM_energy",
    "GLCM_correlation",
    "GLCM_homogeneity",
    "GLCM_contrast",
    # Statistics of the FFT log-magnitude spectrum.
    "max_frequency",
    "energy",
    "entropy",
    "std_frequency",
    "min_frequency",
    "mean_frequency",
    "median_frequency",
    "low_frequency_energy",
    "high_frequency_energy",
    "frequency_contrast",
    # Wavelet sub-band statistics.
    "wavelet_energy_A2",
    "wavelet_energy_H2",
    "wavelet_energy_V2",
    "wavelet_energy_D2",
    "wavelet_std_A2",
    "wavelet_std_H2",
    "wavelet_std_V2",
    "wavelet_std_D2",
    "wavelet_energy_H1",
    "wavelet_energy_V1",
    "wavelet_energy_D1",
    "wavelet_std_H1",
    "wavelet_std_V1",
    "wavelet_std_D1",
    "wavelet_entropy_A2",
    "wavelet_entropy_H2",
    "wavelet_entropy_V2",
    "wavelet_entropy_D2",
    "wavelet_entropy_H1",
    "wavelet_entropy_V1",
    "wavelet_entropy_D1",
]


# Target columns.
target_list = [
    "fermi_level_ev",
    "IP_ev",
    "EA_ev",
    "band_gap_ev",
    "energy_per_atom",
    "current",
]

# Label pairs per target, as (predicted label, true label).
# IP is the ionization potential and EA the electron affinity; the two label
# pairs were previously attached to the wrong keys.
target_labels = {
    "fermi_level_ev": ("Fermi Level [eV] - predicted", "Fermi Level[eV] - true"),
    "EA_ev": ("Electron Affinity [eV] - predicted", "Electron Affinity [eV] - true"),
    "IP_ev": (
        "Ionization Potential [eV] - predicted",
        "Ionization Potential [eV] - true",
    ),
    "band_gap_ev": ("Band Gap [eV] - predicted", "Band Gap [eV] - true"),
    "energy_per_atom": (
        "Energy Per Atom [eV] - predicted",
        "Energy Per Atom [eV] - true",
    ),
    "current": (
        "Current [μA] - predicted",
        "Current [μA] - true",
    ),
}

In [None]:
# NOTE: despite its name, `filepath` is a boolean — True when features.csv
# already exists in the data folder.
filepath = (data_path / "features.csv").exists()

In [None]:
if not filepath:
    # No features.csv yet: extract features for every crop, aggregate them per
    # source file, attach the target values and write the result to disk.
    images = [
        f for f in crops_path.iterdir() if f.suffix.lower() in Features.IMAGE_EXTENSIONS
    ]

    # One dict per crop. The frame is built once after the loop because
    # growing a DataFrame with pd.concat on every iteration is quadratic.
    records = []
    for image in tqdm(images):
        # Crop names look like "<source>_crop..."; keep the source part.
        name = image.stem.split("_crop")[0]

        # The first line of the matching .xyz file is parsed as the atom count.
        with xyz_files_path.joinpath(f"{name}.xyz").open("r") as file:
            n_atoms = int(file.readline().strip())

        features = {"file_name": name, "n_atoms": n_atoms}

        shape_features = Features.extract_shape_features(image, grayscale=True)
        if shape_features is not None:
            features.update(shape_features)

        edge_features = Features.extract_edge_features(image, grayscale=True)
        if edge_features is not None:
            features.update(edge_features)

        texture_features = Features.extract_texture_features(image)
        if texture_features is not None:
            features.update(texture_features)

        records.append(features)

    if not records:
        # Without this guard the aggregation below fails on an undefined or
        # empty frame with a much less helpful error.
        raise RuntimeError(f"No crop images found in {crops_path}")

    df = pd.DataFrame(records)
    # Strip the "_opt" marker once for the whole column.
    # NOTE(review): presumably this makes the names match the "file_name"
    # column of the target CSV — confirm.
    df["file_name"] = df["file_name"].str.replace("_opt", "", regex=False)

    def _area_weighted_mean(values: pd.Series) -> float:
        """Average one group's values, weighted by the area of each crop."""
        weights = df.loc[values.index, "area"]
        return (values * weights).sum() / weights.sum()

    # Additive quantities are summed per file; every other descriptor is
    # averaged over the file's crops using the crop area as weight.
    aggregations = {
        "n_atoms": "first",
        "area": "sum",
        "num_pixels": "sum",
        "perimeter": "sum",
    }
    for column in (
        "circularity",
        "solidity",
        "compactness",
        "feret_diameter",
        "eccentricity",
        "number_of_edges",
        "edge_density",
        "GLCM_contrast",
        "GLCM_homogeneity",
        "GLCM_energy",
        "GLCM_correlation",
    ):
        aggregations[column] = _area_weighted_mean

    grouped_df = df.groupby("file_name").agg(aggregations).reset_index()

    # Target values, looked up by file name.
    original_df = pd.read_csv(xyz_files_path.joinpath("target_graphene_dftb.csv"))
    targets_by_file = original_df.set_index("file_name")

    def _target_column(column: str) -> pd.Series:
        """Map every grouped file name to its value in the given target column."""
        return grouped_df["file_name"].map(targets_by_file[column].to_dict())

    # Add the target columns to the aggregated feature frame.
    for column in (
        "total_energy_eV",
        "fermi_level_ev",
        "IP_ev",
        "EA_ev",
        "band_gap_ev",
    ):
        grouped_df[column] = _target_column(column)
    grouped_df["energy_per_atom"] = (
        grouped_df["total_energy_eV"] / grouped_df["n_atoms"]
    )
    for column in ("current", "flake_total_area"):
        grouped_df[column] = _target_column(column)

    # Rows with any missing feature or target are discarded before caching.
    grouped_df = grouped_df.dropna()

    grouped_df.to_csv(data_path.joinpath("features.csv"), index=False)
else:
    # Cached table available: skip the extraction entirely.
    grouped_df = pd.read_csv(data_path.joinpath("features.csv"))

In [None]:
# Frequency/wavelet features are computed from the full "<file_name>_opt.png"
# image of each row in the feature table, keyed by file name.
frequency_dict = {}
for filename in tqdm(grouped_df["file_name"].tolist()):
    img = images_path / f"{filename}_opt.png"
    dict_img = extract_frequency_features(img)
    frequency_dict[filename] = dict_img

In [None]:
# Rescale the current column by 1e6.
# NOTE(review): presumably A -> μA, matching the "Current [μA]" axis labels —
# confirm. This cell is not idempotent: re-running it rescales again.
grouped_df["current"] = 1e6 * grouped_df["current"]

In [None]:
# One row per file name (index), one column per frequency/wavelet feature.
frequency_df = pd.DataFrame.from_dict(data=frequency_dict, orient="index")

# Left join: match the "file_name" column against the index of frequency_df.
grouped_df = grouped_df.join(frequency_df, on="file_name", how="left")

In [None]:
# Rows whose current lies between the column minimum and 9e-02 (both bounds
# inclusive) are selected for removal.
indices_to_remove = grouped_df.index[
    grouped_df["current"].between(grouped_df["current"].min(), 9e-02)
]

# Remove the rows at the indices found above.
grouped_df = grouped_df.drop(index=indices_to_remove)

In [None]:
# Report whether any cell of the feature table is NaN.
print(
    "Avviso: Il DataFrame contiene valori NaN."
    if grouped_df.isna().any().any()
    else "Il DataFrame non contiene valori NaN."
)