In [None]:
import numpy as np
from pandas import DataFrame


class DataHelper:
    """Stateless helpers for inspecting a tabular dataset."""

    @staticmethod
    def get_dataset_min_max(dataset: DataFrame):
        """Return the per-column minimum and maximum of ``dataset``.

        :param dataset: frame whose values can be stacked into one numpy array
            (in practice: all-numeric columns).
        :return: ``[mins, maxs]`` -- two 1-D numpy arrays with one entry per
            column, in column order. A float column containing NaN yields NaN.
        """
        values = np.array(dataset)
        return [np.amin(values, axis=0), np.amax(values, axis=0)]

    @staticmethod
    def find_category_columns(dataset: DataFrame, max_unique_values: int = 10):
        """Return the labels of columns that look categorical.

        A column counts as categorical when it has fewer than
        ``max_unique_values`` distinct values. NaN is counted as one value,
        exactly like ``len(Series.unique())``.

        :param dataset: frame to inspect.
        :param max_unique_values: exclusive upper bound on the number of
            distinct values; defaults to the previously hard-coded 10.
        :return: list of column labels, in the frame's column order.
        """
        return [
            col
            for col in dataset.columns
            if dataset[col].nunique(dropna=False) < max_unique_values
        ]


In [None]:
import os
import shutil


class CreatedPlotsSaver:
    """Copies rendered plot images from a working directory to a user-chosen one."""

    @staticmethod
    def save_plots(from_path: str, to_path: str) -> None:
        """Copy every ``.png`` file found directly in ``from_path`` into ``to_path``.

        Files are copied with ``shutil.copy2`` (metadata preserved); a file of
        the same name in ``to_path`` is overwritten. Sub-directories are not
        descended into.

        :param from_path: directory holding the generated images.
        :param to_path: existing destination directory.
        """
        for file_name in os.listdir(from_path):
            source = os.path.join(from_path, file_name)
            # endswith() instead of a substring test so names like "a.png.tmp"
            # are skipped; isfile() skips directories that end in ".png".
            if file_name.lower().endswith(".png") and os.path.isfile(source):
                # os.path.join instead of a hard-coded "\\" so this works off Windows.
                shutil.copy2(source, os.path.join(to_path, file_name))


In [None]:
from enum import Enum
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.neural_network import MLPClassifier
from typing import Union


class DefaultModelType(Enum):
    """Preset estimators for quick experiments.

    Each member's value is an ``(estimator_class, constructor_kwargs)`` pair;
    ``DefaultModelCreator`` instantiates ``estimator_class(**constructor_kwargs)``.
    Every preset is seeded with ``random_state=42`` so repeated runs produce
    the same model.
    """

    FOREST = (
        RandomForestClassifier,
        {
            "n_estimators": 40,
            "random_state": 42,
            "criterion": "gini",
            "max_depth": None,
            "min_samples_split": 2,
            "min_samples_leaf": 1,
            "min_weight_fraction_leaf": 0.0,
            "max_features": "sqrt",
        },
    )
    XGBOOST = (
        XGBClassifier,
        {
            "n_estimators": 100,
            "random_state": 42,
            "max_depth": None,
        },
    )
    MLP = (
        MLPClassifier,
        {
            "hidden_layer_sizes": (100,),
            "activation": "relu",
            "solver": "lbfgs",
            "alpha": 1e-5,
            "batch_size": "auto",
            # Seeded like the other presets: sklearn's MLPClassifier uses
            # random_state for weight initialisation, so without it every
            # debug run trained a different network.
            "random_state": 42,
        },
    )


class DefaultModelCreator:
    """Instantiate, train and score one of the presets in ``DefaultModelType``.

    ``fit_model`` splits the given features/target into train and test parts
    (``test_size`` fraction goes to test; the split is seeded by
    ``random_state``), fits a fresh estimator on the train part only, and keeps
    both parts so ``get_model_accuracy`` can score them afterwards.
    """

    def __init__(
        self,
        model_type: DefaultModelType,
        test_size: float = 0.15,
        random_state: int = 42,
    ) -> None:
        self._model_class = model_type.value[0]
        # Copy the preset kwargs so a change on one creator can never leak
        # into the shared enum member (and thus into every other creator).
        self._model_params = dict(model_type.value[1])
        self._test_size = test_size
        self._random_state = random_state

        # Populated by fit_model(); declared here so the attributes always exist.
        self._model = None
        self._model_fitted = False

    def fit_model(
        self, X: pd.DataFrame, y: pd.DataFrame
    ) -> Union[RandomForestClassifier, XGBClassifier, MLPClassifier]:
        """Split ``X``/``y``, fit a new estimator on the train part and return it.

        Every call builds a brand-new estimator and replaces the stored split.
        """
        model = self._init_model()
        fitted_model = self._fit_model(model, X, y)
        self._model = fitted_model
        return fitted_model

    def get_model_accuracy(self):
        """Return ``{"train": score, "test": score}`` for the last fitted model.

        Scores come from the estimator's ``score`` method on the stored split.
        Returns ``None`` when ``fit_model`` has not been called yet.
        """
        if not self._model_fitted:
            return None
        return {
            "train": self._model.score(self._X_train, self._y_train),
            "test": self._model.score(self._X_test, self._y_test),
        }

    def _init_model(self):
        """Create an unfitted estimator from the preset class and kwargs."""
        return self._model_class(**self._model_params)

    def _fit_model(self, model, X, y):
        """Split the data, fit ``model`` on the train part and remember the split."""
        self._X_train, self._X_test, self._y_train, self._y_test = (
            self._get_splited_dataset(
                X, y, test_size=self._test_size, random_state=self._random_state
            )
        )
        fitted_model = model.fit(self._X_train, self._y_train)
        self._model_fitted = True
        return fitted_model

    @staticmethod
    def _get_splited_dataset(X, y, test_size=0.15, random_state=42):
        """Return ``X_train, X_test, y_train, y_test`` from sklearn's ``train_test_split``."""
        return train_test_split(X, y, test_size=test_size, random_state=random_state)


In [None]:
import seaborn as sns
from pandas import DataFrame


class DefaultDatasetCreator:
    """Build a fully numeric version of seaborn's Titanic dataset for debugging.

    The frame is prepared once on construction: load, drop unused columns,
    fill missing values, then encode text/boolean columns as integers. It is
    exposed whole, as the features ``X`` (every column except ``survived``)
    and as the target ``y`` (the ``survived`` column).
    """

    def __init__(self) -> None:
        prepared = self._get_default_dataset()
        self.dataset = prepared
        self.X = prepared.drop(columns="survived")
        self.y = prepared["survived"]

    def get_dataset(self) -> DataFrame:
        """Return the full prepared frame, target column included."""
        return self.dataset

    def get_X(self):
        """Return the feature columns (the prepared frame minus ``survived``)."""
        return self.X

    def get_y(self):
        """Return the ``survived`` target column."""
        return self.y

    def _get_default_dataset(self):
        """Run the load-and-clean pipeline and return the resulting frame."""
        frame = self._get_original_dataset()
        for step in (
            self._drop_extra_fields,
            self._fill_missing_values,
            self._map_not_numeric_feature,
        ):
            frame = step(frame)
        return frame

    @staticmethod
    def _get_original_dataset():
        """Load the raw Titanic frame via ``seaborn.load_dataset``."""
        return sns.load_dataset("titanic")

    @staticmethod
    def _drop_extra_fields(dataset):
        """Return a new frame without the columns that are not used for training."""
        unused_columns = [
            "embark_town",
            "alive",
            "who",
            "pclass",
            "deck",
            "parch",
            "adult_male",
        ]
        return dataset.drop(columns=unused_columns)

    @staticmethod
    def _fill_missing_values(dataset):
        """Fill gaps in ``age`` (column mean) and ``embarked`` (most frequent value).

        The frame is modified in place and also returned.
        """
        mean_age = dataset["age"].mean()
        dataset["age"] = dataset["age"].fillna(mean_age)

        most_frequent_port = dataset["embarked"].dropna().mode()[0]
        dataset["embarked"] = dataset["embarked"].fillna(most_frequent_port)

        return dataset

    @staticmethod
    def _map_not_numeric_feature(dataset):
        """Replace the text/boolean columns by integer codes (in place; frame returned).

        Any value missing from a column's mapping becomes NaN and makes the
        ``astype(int)`` conversion fail.
        """
        encodings = {
            "sex": {"female": 1, "male": 0},
            "class": {"Third": 3, "Second": 2, "First": 1},
            "alone": {True: 1, False: 0},
            "embarked": {"S": 0, "C": 1, "Q": 2},
        }
        for column, encoding in encodings.items():
            dataset[column] = dataset[column].map(encoding).astype(int)

        return dataset


In [None]:
from core.services import DefaultDatasetCreator, DefaultModelType, DefaultModelCreator


class DebugStarter:
    """Build the default Titanic dataset and a preset model for quick debugging.

    ``model`` picks the preset: ``"neuro"`` -> MLP, ``"boost"`` -> XGBoost,
    ``"forest"`` or any other value -> random forest.
    """

    def __init__(self, model: str):
        presets = {
            "forest": DefaultModelType.FOREST,
            "neuro": DefaultModelType.MLP,
            "boost": DefaultModelType.XGBOOST,
        }
        # Unknown names silently fall back to the random forest preset.
        chosen_type = presets.get(model, DefaultModelType.FOREST)

        self._dataset_creator = DefaultDatasetCreator()
        self._model_creator = DefaultModelCreator(chosen_type)

    def get_dataset_and_model(self):
        """Return ``(dataset, fitted_model)``; the model is (re)trained on every call."""
        return self._get_dataset(), self._get_model()

    def _get_dataset(self):
        """Return the full prepared dataset (target column included)."""
        return self._dataset_creator.get_dataset()

    def _get_model(self):
        """Fit the chosen preset on the dataset's features and target and return it."""
        features = self._dataset_creator.get_X()
        target = self._dataset_creator.get_y()
        return self._model_creator.fit_model(features, target)


In [None]:
from typing import Optional

from pydantic import BaseModel


class PlotSettings(BaseModel):
    """Parameters chosen in the plot-settings dialog for one plotting run.

    NOTE(review): nothing here checks that ``min_value <= max_value`` or that
    ``category_column`` differs from ``column``; presumably the dialog
    validates this before building the object -- confirm.
    """

    min_value: float  # lower bound of the range over which `column` is varied
    max_value: float  # upper bound of that range
    column: str  # name of the feature being varied
    category_column: Optional[str] = None  # optional categorical feature to split plots by; None = no split


In [None]:
from PyQt5 import QtGui


import os


# Qt stylesheet applied to each plot container: 1px solid border on every side,
# square corners.
STYLESHEET = """
    border-bottom-width: 1px;
    border-bottom-style: solid;
    border-top-width: 1px;
    border-top-style: solid;
    border-left-width: 1px;
    border-left-style: solid;
    border-right-width: 1px;
    border-right-style: solid;
    border-radius: 0px;
"""


# Plot appearance.
COLORS = ["brown", "teal", "blue", "coral", "limegreen", "pink", "olive", "navy", "red"]
FONT_SIZE = 18
FIG_COUNT = 1

PLOT_WIDTH = 12
PLOT_HEIGHT = 6

# PLOT_* scaled by 100; MainApp uses these as the maximum pixel size of a plot widget.
FIG_WIDTH = PLOT_WIDTH * 100
FIG_HEIGHT = PLOT_HEIGHT * 100

FIG_SIZE = (FIG_WIDTH, FIG_HEIGHT)
# NOTE(review): 2 is an unusually small dots-per-inch value -- confirm how it is used.
DPI = 2


# User-facing status-bar messages (Russian UI).
SAVE_COMPLETE_MESSAGE = "Сохранение завершено"

DATASET_LOAD_ERROR_MESSAGE = "Загруженный объект не является выборкой"
# Was a copy of the dataset message ("... is not a sample"); it must say "model".
MODEL_LOAD_ERROR = "Загруженный объект не является моделью"

DATA_IS_AWAITED_MESSAGE = "Ожидается загрузка датасета и модели"
ONLY_MODEL_LOADED_MESSAGE = "Загрузка модели произведена. Ожидается загрузка выборки"
ONLY_DATASET_LOADED_MESSAGE = "Загрузка датасета произведена. Ожидается загрузка модели"
DATASET_AND_MODEL_LOADED_MESSAGE = "Датасет и модель загружены"

NO_PLOTS_MESSAGE = "Для сохранения постройте графики"

DEBUG_FILE_PATH = "../data/obj_v2"

# RGB components of white (user chat bubbles) and beige (LLM chat bubbles).
r_w = 255
g_w = 255
b_w = 255

r_b = 235
g_b = 204
b_b = 153

QT_COLOR_WHITE = QtGui.QColor(r_w, g_w, b_w)
QT_COLOR_BEIGE = QtGui.QColor(r_b, g_b, b_b)

CHAT_BASE_STYLESHEET = "border: 1px solid black; background-color: rgb({}, {}, {}); "
CHAT_USER_STYLESHEET = CHAT_BASE_STYLESHEET.format(r_w, g_w, b_w)
CHAT_LLM_STYLESHEET = CHAT_BASE_STYLESHEET.format(r_b, g_b, b_b)


# Plot titles and axis labels; "{}" is filled with a variable name.
TOP5_CENTERED_IMPORTANCE_TITLE = "Центрированный график изменения важности переменных"

ICE_IMPORTANCE_Y_LABEL = "Важность переменной {}"
ICE_IMPORTANCE_TITLE = "с-ICE график изменения важности переменной {}"
BASE_IMPORTANCE_TITLE = "График важности переменной {} для исходных данных"

ICE_PREDICTIONS_Y_LABEL = "Вероятность удачного исхода при изменении переменной {}"
ICE_PREDICTIONS_TITLE = (
    "с-ICE график вероятности удачного исхода при изменении переменной {}"
)

# Typo fixed: "Именна" -> "Имена".
PLOT_SETTINGS_ERROR_MESSAGE = (
    "Имена переменной варьирования и категориальной переменной не могут совпадать"
)


# Yandex Cloud credentials/endpoints for the LLM chat.
# SECURITY: the token used to be hard-coded here and is therefore exposed in the
# repository history -- revoke/rotate it. It is now read from the environment;
# presumably it is exchanged at TOKEN_URL for an IAM token (confirm in LLMController).
TOKEN = os.environ.get("YANDEX_OAUTH_TOKEN", "")
TOKEN_URL = "https://iam.api.cloud.yandex.net/iam/v1/tokens"
LLM_FOLDER = "b1gqlhmqdst9ejbptsnp"
LLM_URL = "https://llm.api.cloud.yandex.net/foundationModels/v1/completion"

# System prompt sent to the LLM together with the extracted model rules.
LLM_SYSTEM_PROMPT = """
Ты опытный аналитик, который профессионально умеет искать внутренние закономерности в данных. 
Тебе передали набор правил, которые описывают зависимость значений нескольких переменных от параметра PROBA. 
Чем выше значение параметра PROBA, тем выше вероятность отнесения конкретной записи к категории GOOD.
Необходимо проанализировать правила и сделать следующие действия:
1. Приведи пример двух записей с высокой вероятностью попадания в категорию GOOD
2. Приведи пример двух записей со средней вероятностью попадания в категорию GOOD
3. Приведи пример двух записей с низкой вероятностью попадания в категорию GOOD
4. Приведи пример записей с самой низкой вероятностью попадания в категорию GOOD и с самой высокой
Обязательно укажи вероятность в скобках у примеров
Пиши кратко и по существу, основываясь на переданных данных, 
не пиши о предварительных данных и что нужно больше информации. 
Сравнение проводи по строгим математическим правилам
"""

# Short version of the task shown to the user in the chat window.
LLM_SYSTEM_PROMPT_FOR_USER = """
Проанализируй правила и приведи примеры:
1. Приведи пример двух записей с высокой вероятностью попадания в категорию GOOD
2. Приведи пример двух записей со средней вероятностью попадания в категорию GOOD
3. Приведи пример двух записей с низкой вероятностью попадания в категорию GOOD
4. Приведи пример записей с самой низкой вероятностью попадания в категорию GOOD и с самой высокой
"""


In [None]:
import tempfile

from app.constants import (
    STYLESHEET,
    DATASET_AND_MODEL_LOADED_MESSAGE,
    DATA_IS_AWAITED_MESSAGE,
    ONLY_DATASET_LOADED_MESSAGE,
    ONLY_MODEL_LOADED_MESSAGE,
    SAVE_COMPLETE_MESSAGE,
    NO_PLOTS_MESSAGE,
    FIG_WIDTH,
    FIG_HEIGHT,
    LLM_SYSTEM_PROMPT_FOR_USER,
)
from core.schemes import PlotSettings
from core.debug_starter import DebugStarter
from app.services.created_plots_saver import CreatedPlotsSaver
from core.services.data_helper import DataHelper
from pandas import DataFrame
from xgboost import XGBClassifier
from core.services.pickle_service import PickleService
from PyQt5 import QtCore, QtWidgets, uic
from core.schemes.pickled_data import DatasetModelMonoObject
from app.services.qt_helper import QtHelper
from app.components.plot_settings_app import PlotDataDialog
from app.components.ai_settings_app import AISettingDialog
from app.components.plot_container import PlotContainer
from app.services.dataset_renderer import DatasetRendered
from app.services.plot_creator import PlotCreator
from core.services.llm_controller import LLMController
from core.services.simple_tree_model_fitter import SimpleTreeModelFitter
from core.services.model_rules_aggregator import ModelRulesAggregator
from app.components.text_chat_scroll_widget import TextChatScrollArea


class MainApp(QtWidgets.QMainWindow):
    """Main application window.

    Loads its UI from ``../ui/main2.ui`` and lets the user:
    - open/save a dataset and a model, individually or together as one project file;
    - build plot sets through ``PlotCreator`` and switch between them;
    - chat with an LLM seeded with rules extracted from the model.
    On start-up it loads a debug dataset and model via ``DebugStarter``.
    """

    def __init__(self, parent=None):
        # Forward ``parent``; it used to be accepted but silently ignored.
        super().__init__(parent)

        self.__qt_helper = QtHelper()
        self.__pickle_service = PickleService()
        self.__data_helper = DataHelper()

        uic.loadUi("../ui/main2.ui", self)

        self.open_file_action.setStatusTip("Открыть файл проекта")
        self.open_file_action.triggered.connect(self.open_saved_file)

        self.save_file_action.setStatusTip("Сохранить файл модели и выборки")
        self.save_file_action.triggered.connect(self.save_file)

        self.load_model_action.setStatusTip("Открыть файл модели")
        self.load_model_action.triggered.connect(self.open_model)

        self.load_dataset_action.setStatusTip("Открыть файл выборки")
        self.load_dataset_action.triggered.connect(self.open_dataset)

        self.save_plot_action.triggered.connect(self.save_plots)

        self.push_button_plot.clicked.connect(self.make_plots)

        self.user_input_button.clicked.connect(self.send_chat_request)
        self.llm_init_button.clicked.connect(self.init_llm_button)

        self._widget = QtWidgets.QWidget()
        self._layout = self.horizontalLayout_8
        self._scroll = TextChatScrollArea()
        self._layout.addWidget(self._scroll)

        # NOTE(review): horizontalLayout_8 already belongs to the .ui hierarchy and
        # self._widget is never shown, so this call looks like a leftover -- confirm
        # before removing it.
        self._widget.setLayout(self._layout)

        self.plots_combo_box.currentIndexChanged.connect(self.switch_plots)

        self.statusBar().showMessage("Ожидается загрузка модели и датасета")

        self.is_model_loaded = False
        self.is_dataset_loaded = False
        # Must exist before the combo box can emit currentIndexChanged,
        # otherwise switch_plots() raises AttributeError.
        self.is_plots_in_progress = False

        self.__model = None
        self.__dataset = None
        # Created by the LLM settings dialog callback; None until then.
        self.llm_controller = None

        # Plot-set name (as shown in the combo box) -> image paths of that set.
        self.__paths: dict[str, list[str]] = {}

        # Rendered plots are kept here until the user saves them; removed on close.
        self.temp_dir = tempfile.TemporaryDirectory()

        self.setWindowState(QtCore.Qt.WindowMaximized)
        self.debug_start()
        self.show()

    @property
    def dataset(self) -> DataFrame:
        """The loaded dataset (target column ``survived`` included), or None."""
        return self.__dataset

    @property
    def clean_dataset(self) -> DataFrame:
        """The loaded dataset without the ``survived`` target column."""
        return self.__dataset.drop('survived', axis=1)

    @dataset.setter
    def dataset(self, dataset) -> None:
        self.is_dataset_loaded = True
        self.__dataset = dataset

        self.create_columns()
        self.create_rows()
        self.show_data_status()

    @property
    def model(self) -> XGBClassifier:
        """The loaded model, or None."""
        return self.__model

    @model.setter
    def model(self, model) -> None:
        self.is_model_loaded = True
        self.__model = model
        self.show_data_status()

    def closeEvent(self, event):
        """Delete the temporary plot directory, then let Qt close the window."""
        self.temp_dir.cleanup()
        super().closeEvent(event)

    def debug_start(self):
        """Load the default debug dataset and an XGBoost model."""
        debug_starter = DebugStarter("boost")
        dataset, model = debug_starter.get_dataset_and_model()
        self.dataset = dataset
        self.model = model

    def show_data_status(self) -> None:
        """Show in the status bar which of dataset/model are loaded."""
        if self.is_dataset_loaded and self.is_model_loaded:
            self.statusBar().showMessage(DATASET_AND_MODEL_LOADED_MESSAGE)
        elif self.is_dataset_loaded:
            self.statusBar().showMessage(ONLY_DATASET_LOADED_MESSAGE)
        elif self.is_model_loaded:
            self.statusBar().showMessage(ONLY_MODEL_LOADED_MESSAGE)
        else:
            self.statusBar().showMessage(DATA_IS_AWAITED_MESSAGE)

    def open_saved_file(self, path: str = ""):
        """Load a project file holding both a dataset and a model.

        Asks for a file when ``path`` is empty. NOTE(review): PickleService
        presumably unpickles the file -- only open trusted files.
        """
        if not path:
            file_name = self.__qt_helper.get_path_to_open_file(self)
        else:
            file_name = path

        if not file_name:
            return

        mono_object = self.__pickle_service.get_dataset_and_model(file_name)

        self.model = mono_object.model
        self.dataset = mono_object.dataset

    def save_file(self):
        """Save the loaded dataset and model together into one project file."""
        if not all([self.is_dataset_loaded, self.is_model_loaded]):
            self.show_data_status()
            return

        file_name = self.__qt_helper.get_path_to_save_file(self)

        if not file_name:
            return

        mono = DatasetModelMonoObject(dataset=self.__dataset, model=self.model)
        self.__pickle_service.save_dataset_and_model(mono, file_name)

        self.statusBar().showMessage(SAVE_COMPLETE_MESSAGE)

    def open_dataset(self):
        """Ask for a dataset file and load it (only open trusted files)."""
        path = self.__qt_helper.get_path_to_open_file(self)

        if not path:
            return None

        self.dataset = self.__pickle_service.get_dataset(path)

    def create_columns(self):
        """Render the per-column info plots of the dataset into the columns tab."""
        layout = DatasetRendered(self.dataset).get_rendered_info_plots_layout()
        self._set_tab_layout(self.tabCollumns, layout)

    def create_rows(self):
        """Render the dataset rows into the data tab."""
        layout = DatasetRendered(self.dataset).get_rendered_data_layout()
        self._set_tab_layout(self.tabData, layout)

    @staticmethod
    def _set_tab_layout(tab, layout) -> None:
        """Install ``layout`` on ``tab``, discarding any layout it already has.

        Qt ignores ``setLayout`` on a widget that already owns a layout (it only
        prints a warning), so every dataset loaded after the debug one was never
        shown. The old layout is first moved onto a throwaway widget, which takes
        it and its widgets along when it is garbage-collected.
        """
        old_layout = tab.layout()
        if old_layout is not None:
            QtWidgets.QWidget().setLayout(old_layout)
        tab.setLayout(layout)

    def open_model(self):
        """Ask for a model file and load it (only open trusted files)."""
        file_name = self.__qt_helper.get_path_to_open_file(self)

        if not file_name:
            return

        self.model = self.__pickle_service.get_model(file_name)

    def save_plots(self):
        """Copy all rendered plot images into a directory chosen by the user."""
        if self.plots_combo_box.count() == 0:
            self.statusBar().showMessage(NO_PLOTS_MESSAGE)
            return

        path = self.__qt_helper.get_existing_dir_path(self)

        if not path:
            return

        CreatedPlotsSaver.save_plots(self.temp_dir.name, path)

    def switch_plots(self):
        """Show the plot set currently selected in the combo box."""
        if not self.is_plots_in_progress:
            text = self.plots_combo_box.currentText()
            if text:
                self.show_plots(text)

    def make_plots(self):
        """Open the plot-settings dialog; plots are built in its callback."""
        # Both a dataset and a model are required. The old condition
        # `not dataset_loaded and model_loaded` only fired when the model was
        # loaded without a dataset, and even then fell through and crashed.
        if not (self.is_dataset_loaded and self.is_model_loaded):
            self.show_data_status()
            return

        # NOTE(review): `dialog` has no parent and is only referenced locally;
        # unless setupUi keeps a reference it may be garbage-collected -- confirm.
        dialog = QtWidgets.QDialog()
        self.ui = PlotDataDialog()
        min_max = self.__data_helper.get_dataset_min_max(self.clean_dataset)
        category_columns = self.__data_helper.find_category_columns(self.clean_dataset)
        self.ui.setupUi(
            dialog,
            list(self.clean_dataset.columns),
            category_columns,
            min_max,
            self.retrieve_data_from_child,
        )
        dialog.show()

    def retrieve_data_from_child(self, settings: PlotSettings):
        """Build the plots requested by the settings dialog and display them."""
        # Keep switch_plots() inert until the combo box has been rebuilt:
        # clear()/addItems()/setCurrentIndex() each emit currentIndexChanged and
        # would redraw the grid every time. try/finally so a plotting error does
        # not leave plot switching disabled forever.
        self.is_plots_in_progress = True
        try:
            plot_creator = PlotCreator(
                self.model,
                self.dataset,
                settings.column,
                settings.min_value,
                settings.max_value,
                self.temp_dir,
            )

            if settings.category_column:
                full_paths = plot_creator.plot_category_plots(settings.category_column)
                self.__save_category_paths(full_paths, settings)
            else:
                paths = plot_creator.plot_regular_plots()
                self.__save_regular_paths(paths, settings)

            self.__update_plot_names_in_combo_box()
        finally:
            self.is_plots_in_progress = False

        self.switch_plots()

    def __save_category_paths(
        self, full_paths: list[list[str]], settings: PlotSettings
    ) -> None:
        """Register one numbered plot set per category."""
        name = self.__get_name_for_settings(settings)
        for index, path in enumerate(full_paths, start=1):
            self.__paths[f"{name} {index}"] = path

    def __save_regular_paths(self, paths: list[str], settings: PlotSettings) -> None:
        """Register a single plot set under the name derived from ``settings``."""
        name = self.__get_name_for_settings(settings)
        self.__paths[name] = paths

    def __get_name_for_settings(self, settings: PlotSettings) -> str:
        """Return the human-readable combo-box label for ``settings``."""
        name = f"Колонка {settings.column} от {settings.min_value} и до {settings.max_value}"

        if settings.category_column:
            name += f" по категории {settings.category_column}"

        return name

    def __update_plot_names_in_combo_box(self):
        """Refill the combo box with all plot sets and select the newest one."""
        plot_names = list(self.__paths.keys())

        self.plots_combo_box.clear()
        self.plots_combo_box.addItems(plot_names)
        self.plots_combo_box.setEnabled(True)
        self.plots_combo_box.setCurrentIndex(len(plot_names) - 1)

    def show_plots(self, plots_name: str):
        """Put up to four images of the named set into the 2x2 plot grid."""
        paths = self.__paths[plots_name]
        indexes = [(0, 0), (0, 1), (1, 0), (1, 1)]

        for path, position in zip(paths, indexes):
            pixmap = PlotContainer(path)
            pixmap.setMaximumSize(FIG_WIDTH, FIG_HEIGHT)
            pixmap.setStyleSheet(STYLESHEET)
            pixmap.setSizePolicy(
                QtWidgets.QSizePolicy.Maximum, QtWidgets.QSizePolicy.Maximum
            )
            # The cell may be empty; itemAtPosition() then returns None.
            existing = self.gridLayout.itemAtPosition(*position)
            if existing is not None and existing.widget() is not None:
                existing.widget().setParent(None)
            self.gridLayout.addWidget(pixmap, *position)

    def send_chat_request(self):
        """Send the user's text to the LLM and append both messages to the chat."""
        # The controller only exists after the LLM settings dialog was confirmed.
        if self.llm_controller is None:
            self.statusBar().showMessage("Сначала инициализируйте LLM")
            return

        text = self.user_input_widget.document().toPlainText()
        if text:
            self._scroll.add_text(text)

            response = self.llm_controller.get_answer(text)
            self._scroll.add_text(response)

            scroll_bar = self._scroll.verticalScrollBar()
            scroll_bar.setValue(scroll_bar.maximum() + 500)

    def init_llm_button(self):
        """Open the LLM settings dialog (temperature, tree depth)."""
        # Same guard as make_plots(): the callback needs both model and dataset.
        if not (self.is_dataset_loaded and self.is_model_loaded):
            self.show_data_status()
            return

        # NOTE(review): parentless local dialog, see make_plots().
        dialog = QtWidgets.QDialog()
        self.ui = AISettingDialog()
        self.ui.setupUi(dialog, self._init_llm_button_callback)
        dialog.show()

    def _init_llm_button_callback(self, temperature, depth):
        """Extract rules from a simple tree fitted to the model and ask the LLM about them."""
        fitter = SimpleTreeModelFitter(self.model, self.dataset, depth)
        tree = fitter.get_simple_tree()
        columns = fitter.get_X().columns

        aggregator = ModelRulesAggregator(tree, columns)
        rules = aggregator.get_formatted_rules()

        self.llm_controller = LLMController(temperature)
        self.proba_text.setText(rules)
        self._scroll.add_text(LLM_SYSTEM_PROMPT_FOR_USER)

        response = self.llm_controller.get_answer(rules)
        self._scroll.add_text(response)


In [None]:
import sys
from PyQt5 import QtWidgets

from app.components.main_app import MainApp


def dalex_test():
    """Debug helper: explain the first passenger's prediction with dalex.

    Fits the random-forest preset via ``DebugStarter``, builds a dalex
    ``Explainer`` and returns the ``break_down`` result for ``X.iloc[0]`` so it
    can be inspected (it used to be computed and then thrown away).
    """
    import dalex as dx
    from core.debug_starter import DebugStarter

    starter = DebugStarter('forest')
    df, model = starter.get_dataset_and_model()
    X = df.drop('survived', axis=1)
    y = df['survived']

    titanic_rf_exp = dx.Explainer(model, X, y, label="Titanic RF Pipeline")

    first_passenger = X.iloc[0]
    return titanic_rf_exp.predict_parts(first_passenger, type='break_down')


def get_GP_plots(output_dir: str = "."):
    """Approximate the debug random forest with GP trees and save two diagrams.

    Writes ``model.png`` and ``frontier.png`` into ``output_dir`` (the current
    directory by default, as before). On Windows the default Graphviz ``bin``
    directory is appended to ``PATH`` when it exists and is not already listed.
    """
    from libs.XAI.src.xai import GP
    from core.debug_starter import DebugStarter
    import pathlib
    import os

    if sys.platform == 'win32':
        graphviz_bin = pathlib.Path(r'C:\Program Files\Graphviz\bin')
        # .get(): PATH may be unset; compare whole entries, not substrings.
        current_path = os.environ.get('PATH', '')
        if graphviz_bin.is_dir() and str(graphviz_bin) not in current_path.split(os.pathsep):
            os.environ['PATH'] = (
                f'{current_path}{os.pathsep}{graphviz_bin}' if current_path else str(graphviz_bin)
            )

    starter = DebugStarter('forest')
    df, model = starter.get_dataset_and_model()
    X = df.drop('survived', axis=1)
    # The model was fitted on a DataFrame, so predict on the DataFrame too
    # (not .values) to keep the feature names sklearn checks for.
    predictions = [[x] for x in model.predict(X)]

    # Use GP to approximate the blackbox predictions.
    explainer = GP(max_trees=100, num_generations=1)
    explainer.fit(X.values, predictions)

    out_dir = pathlib.Path(output_dir)
    explainer.plot(str(out_dir / "model.png"))
    explainer.plot_pareto(str(out_dir / "frontier.png"))


def main():
    """Start the Qt application, show the main window and exit with Qt's exit code."""
    app = QtWidgets.QApplication(sys.argv)
    # Keep a Python reference so the parentless window is not garbage-collected
    # while the event loop runs.
    main_app = MainApp()  # noqa: F841
    sys.exit(app.exec_())


# Launch the GUI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()


