In [1]:
import os
import tempfile
from typing import Any

import matplotlib.pyplot as plt
from pandas import DataFrame
from PyQt5 import QtCore, QtWidgets, uic
from xgboost import XGBClassifier



class MainApp(QtWidgets.QMainWindow):
    """Main window: loads a model and a dataset and shows explanation plots.

    The widgets used as attributes (menu actions, ``plots_combo_box``,
    ``push_button_plot``, ``tabCollumns``, ``tabData``, ``gridLayout``) are
    created by ``uic.loadUi`` from ``../ui/main.ui``.
    """

    def __init__(self, parent=None):
        """Build the window, wire the signals and show it maximized.

        Args:
            parent: Optional parent widget, forwarded to ``QMainWindow``.
        """
        super().__init__(parent)

        self.__qt_helper = QtHelper()
        self.__pickle_service = PickleService()

        # All state is initialised before any signal is connected so that a
        # slot fired during start-up never meets a missing attribute.
        self.is_model_loaded = False
        self.is_dataset_loaded = False
        self.is_plots_in_progress = False

        self.__model = None
        self.__dataset = None

        # Combo-box item text -> paths of the plot images shown for that item.
        self.__paths: dict[str, list[str]] = {}
        self.__plot_dialog = None

        self.temp_dir = tempfile.TemporaryDirectory()

        uic.loadUi("../ui/main.ui", self)

        self.open_file_action.setStatusTip("Открыть файл проекта")
        self.open_file_action.triggered.connect(self.open_saved_file)

        self.save_file_action.setStatusTip("Сохранить файл модели и выборки")
        self.save_file_action.triggered.connect(self.save_file)

        self.load_model_action.setStatusTip("Открыть файл модели")
        self.load_model_action.triggered.connect(self.open_model)

        self.load_dataset_action.setStatusTip("Открыть файл выборки")
        self.load_dataset_action.triggered.connect(self.open_dataset)

        self.save_plot_action.triggered.connect(self.save_plots)
        self.push_button_plot.clicked.connect(self.make_plots)
        self.plots_combo_box.currentIndexChanged.connect(self.switch_plots)

        # Same status text as everywhere else instead of a second literal.
        self.show_data_status()

        self.setWindowState(QtCore.Qt.WindowMaximized)
        self.debug_start()
        self.show()

    @property
    def dataset(self) -> DataFrame:
        """The loaded dataset, or ``None`` while nothing is loaded."""
        return self.__dataset

    @dataset.setter
    def dataset(self, dataset) -> None:
        """Store the dataset, rebuild both dataset tabs and refresh the status."""
        self.is_dataset_loaded = True
        self.__dataset = dataset

        self.create_columns()
        self.create_rows()
        self.show_data_status()

    @property
    def model(self) -> XGBClassifier:
        """The loaded model, or ``None`` while nothing is loaded."""
        return self.__model

    @model.setter
    def model(self, model) -> None:
        """Store the model and refresh the status bar."""
        self.is_model_loaded = True
        self.__model = model
        self.show_data_status()

    def closeEvent(self, event):
        """Delete the temporary plot directory, then let Qt close the window."""
        self.temp_dir.cleanup()
        super().closeEvent(event)

    def debug_start(self):
        """Load the project file named by ``debug_file_path``.

        NOTE(review): this runs unconditionally from ``__init__``; it looks
        like a development shortcut - confirm before shipping.
        """
        self.open_saved_file(debug_file_path)

    def show_data_status(self) -> None:
        """Show in the status bar which of dataset/model are loaded."""
        if self.is_dataset_loaded and self.is_model_loaded:
            message = dataset_and_model_loaded_message
        elif self.is_dataset_loaded:
            message = only_dataset_loaded_message
        elif self.is_model_loaded:
            message = only_model_loaded_message
        else:
            message = data_is_awaited_message

        self.statusBar().showMessage(message)

    def open_saved_file(self, path: str = ""):
        """Load a model and a dataset from one project file.

        Args:
            path: File to load. Any falsy value (this method is also a slot
                of a ``triggered`` signal, which may pass a boolean) opens a
                file dialog instead.
        """
        file_name = path or self.__qt_helper.get_path_to_open_file(self)

        if not file_name:
            return

        # NOTE(review): the file is presumably unpickled by PickleService;
        # unpickling an untrusted file can execute arbitrary code.
        mono_object = self.__pickle_service.get_dataset_and_model(file_name)

        self.model = mono_object.model
        self.dataset = mono_object.dataset

    def save_file(self):
        """Save the model and the dataset together into a user-chosen file."""
        if not (self.is_dataset_loaded and self.is_model_loaded):
            self.show_data_status()
            return

        file_name = self.__qt_helper.get_path_to_save_file(self)

        if not file_name:
            return

        mono = DatasetModelMonoObject(dataset=self.__dataset, model=self.model)
        self.__pickle_service.save_dataset_and_model(mono, file_name)

        self.statusBar().showMessage(save_complete_message)

    def open_dataset(self):
        """Ask for a file and load it as the dataset."""
        path = self.__qt_helper.get_path_to_open_file(self)

        if not path:
            return None

        self.dataset = self.__pickle_service.get_dataset(path)

    @staticmethod
    def __replace_layout(container, layout) -> None:
        """Install ``layout`` on ``container``, discarding a previous layout."""
        previous = container.layout()
        if previous is not None:
            # Qt will not install a layout on a widget that already has one.
            # Handing the old layout to a throw-away widget detaches it from
            # the container so that the new one can be installed.
            QtWidgets.QWidget().setLayout(previous)
        container.setLayout(layout)

    def create_columns(self):
        """Fill the columns tab with the rendered per-column info plots."""
        layout = DatasetRendered(self.dataset).get_rendered_info_plots_layout()
        self.__replace_layout(self.tabCollumns, layout)

    def create_rows(self):
        """Fill the data tab with the rendered dataset rows."""
        layout = DatasetRendered(self.dataset).get_rendered_data_layout()
        self.__replace_layout(self.tabData, layout)

    def open_model(self):
        """Ask for a file and load it as the model."""
        file_name = self.__qt_helper.get_path_to_open_file(self)

        if not file_name:
            return

        self.model = self.__pickle_service.get_model(file_name)

    def save_plots(self):
        """Copy the plots from the temporary directory to a chosen folder."""
        if self.plots_combo_box.count() == 0:
            self.statusBar().showMessage(no_plots_message)
            return

        path = self.__qt_helper.get_existing_dir_path(self)

        if not path:
            return

        CreatedPlotsSaver.save_plots(self.temp_dir.name, path)

    def switch_plots(self):
        """Show the plots of the current combo-box item.

        Does nothing while plots are being generated or when the current text
        has no registered plot paths (for example an empty combo box).
        """
        if self.is_plots_in_progress:
            return

        text = self.plots_combo_box.currentText()
        if text in self.__paths:
            self.show_plots(text)

    def make_plots(self):
        """Open the plot-settings dialog; needs both a dataset and a model."""
        if not (self.is_dataset_loaded and self.is_model_loaded):
            self.show_data_status()
            return

        dialog = QtWidgets.QDialog()
        # Keep a reference so that the non-modal dialog outlives this call.
        self.__plot_dialog = dialog
        self.ui = PlotDataDialog()
        min_max = getMinMax(self.dataset)
        category_columns = find_category_columns(self.dataset)
        self.ui.setupUi(
            dialog,
            list(self.dataset.columns),
            category_columns,
            min_max,
            self,
            self.retrieve_data_from_child,
        )
        dialog.show()

    def retrieve_data_from_child(self, settings: PlotSettings):
        """Create the plots described by ``settings`` and display them.

        Passed to the settings dialog as its callback (see ``make_plots``).
        """
        self.is_plots_in_progress = True
        try:
            plot_creator = PlotCreator(
                self.model,
                self.dataset,
                settings.column,
                settings.min_value,
                settings.max_value,
            )

            if settings.category_column:
                full_paths = plot_creator.plot_category_plots(
                    settings.category_column
                )
                self.__save_category_paths(full_paths, settings)
            else:
                paths = plot_creator.plot_regular_plots()
                self.__save_regular_paths(paths, settings)
        finally:
            # Reset even on failure, otherwise switching plots stays disabled.
            self.is_plots_in_progress = False

        self.__update_plot_names_in_combo_box()
        self.switch_plots()

    def createPlots(self, colName, minVal, maxVal, categorcal_col=""):
        """Legacy plot pipeline: render four plots per item into ``temp_dir``.

        Args:
            colName: Column that is varied and used for the value filter.
            minVal: Lower bound (inclusive) for ``colName``.
            maxVal: Upper bound (inclusive) for ``colName``.
            categorcal_col: When non-empty, one set of plots is produced for
                every distinct value of this column.

        NOTE(review): nothing visible in this file calls this method any more,
        and ``show_grath`` (called at the end) is not defined in this class -
        confirm whether the method is still needed.
        """
        column = self.dataset[colName]
        in_range = (column >= minVal) & (column <= maxVal)
        ranged_data = self.dataset[in_range].reset_index(drop=True)

        plt.figure(1, figsize=(1000, 1000), dpi=1)

        self.is_plots_in_progress = True
        try:
            if categorcal_col:
                categories = sorted(self.dataset[categorcal_col].unique())
                item_names = [
                    f"{colName}:{categorcal_col}:{category}"
                    for category in categories
                ]
                first_new_index = self.plots_combo_box.count()
                self.plots_combo_box.addItems(item_names)
                self.plots_combo_box.setEnabled(True)
                self.plots_combo_box.setCurrentIndex(first_new_index)

                for category, item_name in zip(categories, item_names):
                    in_category = self.dataset[categorcal_col] == category
                    category_data = self.dataset[
                        in_range & in_category
                    ].reset_index(drop=True)
                    # When the varied column is the category column itself the
                    # ICE plots use the whole value-filtered dataset.
                    if colName == categorcal_col:
                        ice_data = ranged_data
                    else:
                        ice_data = category_data
                    self.__export_legacy_plots(
                        category_data,
                        ice_data,
                        colName,
                        item_name.replace(":", ""),
                    )

                shown_name = self.plots_combo_box.itemText(
                    first_new_index
                ).replace(":", "")
            else:
                self.plots_combo_box.addItems([colName])
                self.plots_combo_box.setEnabled(True)
                self.plots_combo_box.setCurrentIndex(
                    self.plots_combo_box.count() - 1
                )
                self.__export_legacy_plots(
                    ranged_data,
                    ranged_data,
                    colName,
                    colName,
                    report_progress=True,
                )
                shown_name = colName

            self.show_grath(shown_name)
        finally:
            self.is_plots_in_progress = False

    def __export_legacy_plots(
        self, importance_data, ice_data, column, file_stem, report_progress=False
    ) -> None:
        """Draw the four legacy plots and save each one as SVG and PNG.

        Files are written to ``temp_dir`` as ``img<k><file_stem>.<ext>`` in
        the order img1, img2, img0, img3.
        """
        width, height = 12, 6
        steps = (
            (
                "img1",
                lambda: plot_top5_centered_importance(
                    self.model, importance_data, column, True
                ),
            ),
            ("img2", lambda: plot_ice_plot(self.model, ice_data, column, True)),
            (
                "img0",
                lambda: plot_top5_centered_importance(
                    self.model, importance_data, column
                ),
            ),
            ("img3", lambda: plot_ice_plot(self.model, ice_data, column)),
        )

        for step, (prefix, draw) in enumerate(steps, start=1):
            plt.clf()
            figure = draw().get_figure()
            figure.set_size_inches(width, height)
            # os.path.join instead of a hard-coded "\\" keeps this portable.
            base = os.path.join(self.temp_dir.name, f"{prefix}{file_stem}")
            figure.savefig(f"{base}.svg", bbox_inches="tight")
            figure.savefig(f"{base}.png", bbox_inches="tight", format="png")
            plt.clf()

            if report_progress:
                self.statusBar().showMessage(str(step))

    def __save_category_paths(
        self, full_paths: list[list[str]], settings: PlotSettings
    ) -> None:
        """Register one numbered entry per path list in ``full_paths``."""
        name = self.__get_name_for_settings(settings)
        for index, path in enumerate(full_paths, start=1):
            self.__paths[f"{name} {index}"] = path

    def __save_regular_paths(self, paths: list[str], settings: PlotSettings) -> None:
        """Register the plot paths of a non-categorical run."""
        name = self.__get_name_for_settings(settings)
        self.__paths[name] = paths

    def __get_name_for_settings(self, settings: PlotSettings) -> str:
        """Build the combo-box label that describes ``settings``."""
        name = f"Колонка {settings.column} от {settings.min_value} и до {settings.max_value}"

        if settings.category_column:
            name += f" по категории {settings.category_column}"

        return name

    def __update_plot_names_in_combo_box(self):
        """Rebuild the combo box from the registered plots, selecting the last."""
        plot_names = list(self.__paths)
        combo_box = self.plots_combo_box

        # The list is rebuilt from scratch so that names are never duplicated.
        # Signals are blocked meanwhile: clearing would otherwise fire
        # ``switch_plots`` with an empty selection.
        was_blocked = combo_box.blockSignals(True)
        try:
            combo_box.clear()
            combo_box.addItems(plot_names)
            combo_box.setEnabled(True)
            combo_box.setCurrentIndex(len(plot_names) - 1)
        finally:
            combo_box.blockSignals(was_blocked)

    def show_plots(self, plots_name: str):
        """Put the plots registered under ``plots_name`` into the 2x2 grid.

        Raises:
            KeyError: If ``plots_name`` has no registered plots.
        """
        paths = self.__paths[plots_name]
        positions = [(0, 0), (0, 1), (1, 0), (1, 1)]

        for path, position in zip(paths, positions):
            pixmap = PlotContainer(path)
            pixmap.setMaximumSize(800, 600)
            pixmap.setStyleSheet(stylesheet)
            pixmap.setSizePolicy(
                QtWidgets.QSizePolicy.Preferred, QtWidgets.QSizePolicy.Preferred
            )

            # The cell may be empty, or hold something that is not a widget.
            current_item = self.gridLayout.itemAtPosition(*position)
            current_widget = None if current_item is None else current_item.widget()
            if current_widget is not None:
                current_widget.setParent(None)

            self.gridLayout.addWidget(pixmap, *position)


NameError: name 'PlotSettings' is not defined

In [None]:
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap



def find_category_columns(data, max_unique=10):
    """Return the names of the columns that look categorical.

    A column counts as categorical when it has fewer than ``max_unique``
    distinct values (as reported by ``Series.unique``).

    Args:
        data: DataFrame to inspect; it is only read, never modified.
        max_unique: Exclusive upper bound on the number of distinct values.

    Returns:
        Column names in their original order.
    """
    return [col for col in data.columns if len(data[col].unique()) < max_unique]


def getMinMax(data):
    """Return ``[column minima, column maxima]`` of ``data``.

    The frame is converted to a NumPy array and reduced along axis 0, so each
    element of the returned list has one entry per column.
    """
    values = np.array(data.copy())
    lowest = values.min(axis=0)
    highest = values.max(axis=0)
    return [lowest, highest]


def top5_centered_importance(explainer, data, col_name):
    """Compute SHAP values while ``col_name`` is forced to each grid value.

    For every value produced by ``create_variable_list`` the column is
    overwritten with that value in all rows of a working copy and the
    explainer is evaluated on the copy without the "Survived" column.

    Returns:
        A tuple ``(shap_per_value, grid)``: the explainer output for each
        grid value, and the grid itself.
    """
    working = data.copy()
    grid = create_variable_list(data[col_name])
    row_count = len(working[col_name])

    shap_per_value = []
    for grid_value in grid:
        working[col_name] = [grid_value] * row_count
        features = working.drop("Survived", axis=1)
        shap_per_value.append(
            explainer.shap_values(features, y=working["Survived"])
        )

    return shap_per_value, grid


def top5_find_by_importance(explainer, data):
    """Find the (up to) five features with the largest mean absolute SHAP value.

    Args:
        explainer: Object with a ``shap_values(X, y=...)`` method.
        data: DataFrame that contains the features and a "Survived" column.

    Returns:
        A tuple ``(col_names, indexes)`` ordered from the least to the most
        important of the selected features. ``indexes`` are positions in the
        feature matrix, i.e. in ``data`` without the "Survived" column.
    """
    features = data.drop("Survived", axis=1)
    shap_values = explainer.shap_values(features, y=data["Survived"])

    mean_importance = np.mean(np.absolute(shap_values), axis=0)

    # argsort yields each position exactly once, so features with equal
    # importance are not reported twice.
    top_positions = np.argsort(mean_importance, kind="stable")[-5:]

    indexes = [int(position) for position in top_positions]
    # Names come from the same frame the SHAP values were computed on, so a
    # name and its index always describe the same feature.
    col_names = [features.columns[position] for position in indexes]

    return col_names, indexes


def plot_top5_centered_importance(model, data, col_name, absolute=False):
    """Plot how the mean importance of the top features changes with ``col_name``.

    One line is drawn per feature returned by ``top5_find_by_importance``;
    each point is the mean SHAP value of that feature over all rows
    (mean of absolute values when ``absolute`` is true) for one grid value of
    ``col_name``.

    Returns:
        The matplotlib axes the lines were drawn on.
    """
    axes = plt.axes()
    axes.figure.set_size_inches(16, 8)

    title = (
        "Центрированный график изменения абсолютной важности переменных"
        if absolute
        else "Центрированный график изменения важности переменных"
    )
    axes.set_title(title, fontsize=18)

    explainer = shap.TreeExplainer(model)
    frame = data.copy()

    top_names, top_indexes = top5_find_by_importance(explainer, frame)
    shap_by_value, grid = top5_centered_importance(explainer, frame, col_name)
    shap_by_value = np.array(shap_by_value)

    for rank, feature_index in enumerate(top_indexes):
        curve = []
        for per_value in shap_by_value:
            contributions = per_value[:, feature_index]
            if absolute:
                contributions = np.absolute(contributions)
            curve.append(contributions.mean())

        axes.plot(
            grid, curve, color=COLORS[rank], linewidth=4, label=top_names[rank]
        )

    axes.grid()
    axes.set_xlabel(col_name, fontsize=16)
    axes.set_ylabel("Важность переменных", fontsize=16)
    axes.legend()

    return axes


def ice_plot_data_y(model, data, col_name):
    """Predict for every row while ``col_name`` is forced to each grid value.

    For each value from ``create_variable_list`` the column is overwritten in
    all rows of a working copy, and column 1 of ``model.predict_proba`` is
    collected for the copy without the "Survived" column.

    Returns:
        A tuple ``(predictions_per_value, grid)``.
    """
    working = data.copy()
    grid = create_variable_list(data[col_name])
    row_count = len(working[col_name])

    predictions_per_value = []
    for grid_value in grid:
        working[col_name] = [grid_value] * row_count
        probabilities = model.predict_proba(working.drop("Survived", axis=1))
        predictions_per_value.append(probabilities[:, 1])

    return predictions_per_value, grid


def create_variable_list(col):
    """Return the values a column is set to when it is varied.

    Args:
        col: pandas Series to take the values from.

    Returns:
        The sorted distinct values when there are fewer than 50 of them;
        otherwise 101 evenly spaced values from the column minimum to the
        column maximum, both ends included.
    """
    unique = col.unique()

    if len(unique) < 50:
        return sorted(unique)

    # linspace always ends exactly on the maximum. Repeatedly adding a float
    # step can overshoot it by a rounding error and silently drop the last
    # grid point.
    return np.linspace(col.min(), col.max(), 101).tolist()


def ice_plot_data_importance(explainer, data, col_name):
    """Collect the SHAP values of ``col_name`` while that column is varied.

    For each value from ``create_variable_list`` the column is overwritten in
    all rows of a working copy, the explainer is evaluated on the copy
    without the "Survived" column, and the SHAP column that belongs to
    ``col_name`` is kept.

    Returns:
        A tuple ``(shap_per_value, grid)``.

    Raises:
        ValueError: If ``col_name`` is not one of the feature columns.
    """
    working = data.copy()
    grid = create_variable_list(data[col_name])

    # The SHAP matrix has one column per feature, i.e. per column of the
    # frame without "Survived", so the position must be looked up among
    # those columns and not among all columns of ``data``.
    feature_names = [name for name in data.columns if name != "Survived"]
    feature_index = feature_names.index(col_name)

    shap_per_value = []
    for grid_value in grid:
        working[col_name] = grid_value

        shap_values = np.array(
            explainer.shap_values(
                working.drop("Survived", axis=1), y=working["Survived"]
            )
        )
        shap_per_value.append(shap_values[:, feature_index])

    return shap_per_value, grid


def plot_ice_plot(model, data, col_name, importance=False):
    """Draw one thin line per dataset row plus the thick mean line.

    With ``importance`` the lines show the SHAP value of ``col_name``
    (``ice_plot_data_importance``); otherwise they show the model prediction
    (``ice_plot_data_y``). The x axis is the grid of values of ``col_name``.

    Returns:
        The matplotlib axes the lines were drawn on.
    """
    axes = plt.axes()

    if not importance:
        curves, grid = ice_plot_data_y(model, data, col_name)
        y_label = "Вероятность удачного исхода"
        title = f"с-ICE график вероятности удачного исхода при изменении переменной {col_name}"
    else:
        explainer = shap.TreeExplainer(model)
        curves, grid = ice_plot_data_importance(explainer, data, col_name)
        y_label = f"Важность переменной {col_name}"
        title = f"с-ICE график изменения важности переменной {col_name}"

    # After the transpose: one row per dataset row, one column per grid value.
    per_row = pd.DataFrame(np.array(curves)).T
    average = per_row.mean()

    axes.figure.set_size_inches(16, 8)
    axes.set_title(title, fontsize=18)

    for row_label in per_row.index:
        axes.plot(grid, per_row.loc[row_label], color="black", linewidth=0.1)

    axes.plot(grid, average, color="lime", linewidth=6)

    axes.grid()
    axes.set_xlabel(col_name, fontsize=16)
    axes.set_ylabel(y_label, fontsize=16)

    return axes


In [2]:
from PyQt5 import QtGui


# Qt style sheet passed to ``setStyleSheet`` of every plot container in
# ``MainApp.show_plots``: a 1px solid border on all four sides, square corners.
stylesheet = """
    border-bottom-width: 1px;
    border-bottom-style: solid;
    border-top-width: 1px;
    border-top-style: solid;
    border-left-width: 1px;
    border-left-style: solid;
    border-right-width: 1px;
    border-right-style: solid;
    border-radius: 0px;
"""

# Line colours; ``plot_top5_centered_importance`` picks one per plotted column.
COLORS = ["brown", "teal", "blue", "coral", "limegreen", "pink", "olive", "navy", "red"]
# NOTE(review): the four constants below are not referenced by the code visible
# in this file; they presumably mirror literals used when plotting - confirm.
FONT_SIZE = 18
FIG_COUNT = 1
FIG_SIZE = (1000, 1000)
DPI = 1

# Status-bar text shown after a project file has been written.
save_complete_message = "Сохранение завершено"

# Errors for a loaded object of the wrong kind. The model message used to be
# a copy of the dataset message ("... is not a dataset").
dataset_load_error = "Загруженный объект не является выборкой"
model_load_error = "Загруженный объект не является моделью"

# Status-bar texts that describe which of dataset/model are loaded.
data_is_awaited_message = "Ожидается загрузка датасета и модели"
only_model_loaded_message = "Загрузка модели произведена. Ожидается загрузка выборки"
only_dataset_loaded_message = "Загрузка датасета произведена. Ожидается загрузка модели"
dataset_and_model_loaded_message = "Датасет и модель загружены"

# Shown when the user tries to save plots before any were created.
no_plots_message = "Для сохранения постройте графики"

# Project file that ``MainApp.debug_start`` loads.
debug_file_path = "../data/obj_v2"


# Qt colours (RGB). NOTE(review): not referenced by the code visible in this
# file - presumably used by the dataset table renderer; confirm.
qt_color_white = QtGui.QColor(255, 255, 255)
qt_color_beige = QtGui.QColor(235, 204, 153)


# Plot size; the same numbers are passed to ``set_size_inches`` by the
# plotting helpers in this file.
PLOT_WIDTH = 16
PLOT_HEIGHT = 8

# Title of the centered importance plot.
TOP5_CENTERED_IMPORTANCE_TITLE = "Центрированный график изменения важности переменных"

# ICE importance plot texts; "{}" is a placeholder, presumably filled with the
# column name via ``str.format`` - confirm at the call site.
ICE_IMPORTANCE_Y_LABEL = "Важность переменной {}"
ICE_IMPORTANCE_TITLE = "с-ICE график изменения важности переменной {}"

# ICE prediction plot texts; same "{}" placeholder convention.
ICE_PREDICTIONS_Y_LABEL = "Вероятность удачного исхода при изменении переменной {}"
ICE_PREDICTIONS_TITLE = (
    "с-ICE график вероятности удачного исхода при изменении переменной {}"
)


In [3]:
from xgboost import XGBClassifier
import shap
from pandas import DataFrame, Series
import numpy as np
from app.schemes.model_explainer import MostImportantColumns


class ModelExplainer:
    """Computes SHAP importances and predictions while one column is varied.

    Call ``calculate_for_dataset`` first; the ``get_*`` methods then return
    the results of the most recent calculation.
    """

    def __init__(self, model: XGBClassifier, target_column: str = "Survived") -> None:
        """Create the SHAP tree explainer for ``model``.

        Args:
            model: Model to explain.
            target_column: Name of the label column; it is removed from the
                dataset before the model or the explainer is called.
        """
        self.__model = model
        self.__explainer = shap.TreeExplainer(model)
        self.__target_column = target_column
        self.__most_important_columns = []
        self.__importance = {}
        self.__predicts = {}
        self.__dataset = None
        self.__column_to_vary = None

    def get_n_most_important_columns(self, n: int) -> list[MostImportantColumns]:
        """Return the ``n`` most important columns, most important first."""
        return self.__most_important_columns[: max(n, 0)]

    def get_centered_importance(self):
        """Return ``{varied value: SHAP matrix}`` of the last calculation."""
        return self.__importance

    def get_ice_importance(self):
        """Return ``{varied value: SHAP values of the varied column}``.

        Returns an empty dict when nothing has been calculated yet.
        """
        if not self.__importance:
            return {}

        # The position is looked up among the feature columns because the
        # SHAP matrices have no column for the target.
        column_index = self.__feature_names(self.__dataset).index(
            self.__column_to_vary
        )

        return {
            key: shap_values[:, column_index]
            for key, shap_values in self.__importance.items()
        }

    def get_ice_predictions(self):
        """Return ``{varied value: predictions}`` of the last calculation."""
        return self.__predicts

    def calculate_for_dataset(self, dataset: DataFrame, column_to_vary: str) -> None:
        """Run all calculations for ``dataset`` with ``column_to_vary`` varied."""
        self.__dataset = dataset
        self.__column_to_vary = column_to_vary

        self.__calculate_most_important_columns()
        self.__calculate_importance()

    def __feature_names(self, dataset: DataFrame) -> list:
        """Return the column names of ``dataset`` without the target column."""
        return [name for name in dataset.columns if name != self.__target_column]

    def __calculate_most_important_columns(self) -> None:
        """Rank the features by mean absolute SHAP value, largest first."""
        shap_values = self.__get_shap_values(self.__dataset)
        feature_names = self.__feature_names(self.__dataset)

        mean_importance = np.mean(np.absolute(shap_values), axis=0)

        # A stable argsort of the negated values gives a descending order in
        # which every feature appears exactly once, also when values tie.
        ranking = np.argsort(-mean_importance, kind="stable")

        self.__most_important_columns = [
            MostImportantColumns(index=int(index), name=feature_names[int(index)])
            for index in ranking
        ]

    def __calculate_importance(self) -> None:
        """Compute SHAP values and predictions for every varied value."""
        dataset_copy = self.__dataset.copy()
        column_to_vary = self.__column_to_vary

        importance = {}
        predicts = {}

        for value in self.__get_variables_to_vary(dataset_copy[column_to_vary]):
            # Scalar assignment fills every row, whatever the row count and
            # whether or not the column contained missing values.
            dataset_copy[column_to_vary] = value

            importance[value] = self.__get_shap_values(dataset_copy)
            predicts[value] = self.__get_predict_values(dataset_copy)

        self.__importance = importance
        self.__predicts = predicts

    def __get_variables_to_vary(self, column: Series) -> list[float]:
        """Return the values the varied column is set to.

        Fewer than 50 distinct values are used as they are (sorted); otherwise
        101 evenly spaced values from the minimum to the maximum are used.
        """
        unique = column.unique()

        if len(unique) < 50:
            return sorted(unique)

        # linspace ends exactly on the maximum; accumulating a float step can
        # overshoot it and lose the last point.
        return np.linspace(column.min(), column.max(), 101).tolist()

    def __get_shap_values(self, dataset: DataFrame):
        """Return the SHAP values of ``dataset`` as an array."""
        return np.array(
            self.__explainer.shap_values(
                dataset.drop(self.__target_column, axis=1),
                y=dataset[self.__target_column],
            )
        )

    def __get_predict_values(self, dataset: DataFrame):
        """Return column 1 of ``predict_proba`` for ``dataset`` as an array."""
        return np.array(
            self.__model.predict_proba(
                dataset.drop(self.__target_column, axis=1)
            )[:, 1]
        )


Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)


ModuleNotFoundError: No module named 'app'

In [None]:
from ..