In [None]:
import sys
from pathlib import Path
sys.path.append(f"{Path().absolute().parent}")

In [None]:
import logging
import os
from enum import Enum
from typing import Dict, List, Optional, Tuple

from radp.digital_twin.utils.gis_tools import GISTools

logging.basicConfig(level=logging.INFO)
import gpytorch
import numpy as np
import pandas as pd
import torch
from radp.digital_twin.utils.constants import (
    ANTENNA_GAIN,
    CELL_CARRIER_FREQ_MHZ,
    CELL_EL_DEG,
    CELL_ID,
    CELL_LAT,
    CELL_LON,
    CELL_RXPWR_DBM,
    HRX,
    HTX,
    LOG_DISTANCE,
    RELATIVE_BEARING,
    RXPOWER_DBM,
    RXPOWER_STDDEV_DBM,
    SIM_IDX,
)


class NormMethod(Enum):
    """Input-feature normalization schemes understood by `BayesianDigitalTwin`."""

    # scaled as (value - min) / (max - min)
    MINMAX = "minmax"
    # scaled as (value - mean) / std
    ZSCORE = "zscore"


class ExactGPModel(gpytorch.models.ExactGP):
    """Exact-inference GP: constant mean plus a scaled RBF covariance.

    The mean, the RBF kernel and its output scale are all created with batch
    shape ``[1]``. NOTE(review): when this model is trained on a batch of
    several cells, this presumably broadcasts one shared set of these
    hyperparameters across the cells -- confirm that sharing is intended.
    """

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        single_batch = torch.Size([1])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=single_batch)
        rbf_kernel = gpytorch.kernels.RBFKernel(batch_shape=single_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            rbf_kernel,
            batch_shape=single_batch,
        )

    def forward(self, x):
        """Return the GP prior at `x` as a multivariate normal."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x),
            self.covar_module(x),
        )


class BayesianDigitalTwin:
    """Batch of independent per-cell Gaussian Process (GP) regressors.

    Training tensors are stacked along a leading cell dimension: inputs have
    shape ``(num_cells, n_train, num_features)`` and targets
    ``(num_cells, n_train)``. Inputs are normalized per cell with
    `norm_method`; targets are always z-scored per cell and de-normalized
    again after prediction.
    """

    def __init__(
        self,
        data_in: List[pd.DataFrame],
        x_columns: List[str],
        y_columns: List[str],
        norm_method: NormMethod = NormMethod.MINMAX,
        x_max: Optional[Dict[str, float]] = None,
        x_min: Optional[Dict[str, float]] = None,
    ):
        """
        `data_in` is a list of Pandas dataframes, where each one corresponds to
        training data for one cell. Each dataframe needs a `cell_id` column (its
        first unique value becomes the cell's id), and all dataframes must have
        the same number of rows. Per-cell statistics (`DataFrame.describe()`)
        are computed from `data_in` and used for pre-training normalization and
        post-prediction de-normalization.

        `x_columns` specifies the columns to train on, and `y_columns` specifies
        the column to predict. These must be present in `data_in`. The target
        tensor holds one value per sample, so `y_columns` is expected to name a
        single column.

        `x_max` and `x_min` contain optional user-specified max and min values
        for columns in `data_in` -- if these are provided, they are used instead
        of observed ranges during MINMAX input normalization.

        `data_in` may be constructed using `get_percell_data`.
        """
        self.is_cuda = False
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
            self.is_cuda = True

        self.cell_ids = [data_in_cell.cell_id.unique()[0] for data_in_cell in data_in]
        self.num_cells = len(self.cell_ids)
        self.x_columns = x_columns
        self.y_columns = y_columns
        self.num_features = len(self.x_columns)
        self.n_train = data_in[0].shape[0]

        self.cell_stats = [cell_data.describe() for cell_data in data_in]
        self.xmeans = []
        self.xstds = []
        self.xmax = []
        self.xmin = []
        self.ymeans = []
        self.ystds = []

        self.norm_method = norm_method

        # Per-cell normalization statistics.
        for m in range(self.num_cells):
            stats_m = self.cell_stats[m]
            # copies, so user overrides below never write back into cell_stats
            xmax_m = stats_m.loc["max", self.x_columns].copy()
            xmin_m = stats_m.loc["min", self.x_columns].copy()
            # explicitly provided ranges replace the empirical ones
            for k, v in (x_max or {}).items():
                if k in self.x_columns:
                    xmax_m[k] = v
            for k, v in (x_min or {}).items():
                if k in self.x_columns:
                    xmin_m[k] = v
            self.xmax.append(xmax_m)
            self.xmin.append(xmin_m)
            self.xmeans.append(stats_m.loc["mean", self.x_columns])
            self.xstds.append(stats_m.loc["std", self.x_columns])
            self.ymeans.append(stats_m.loc["mean", self.y_columns])
            self.ystds.append(stats_m.loc["std", self.y_columns])

        train_X, train_Y = self._create_training_tensors(data_in)

        # One likelihood (noise level) per cell; independent outputs.
        likelihood = gpytorch.likelihoods.GaussianLikelihood(
            batch_shape=torch.Size([self.num_cells])
        )

        self.model = ExactGPModel(train_X, train_Y, likelihood)

    def _normalize_x(self, df: pd.DataFrame, m: int) -> pd.DataFrame:
        """Normalize the `x_columns` of `df` with the statistics of cell `m`.

        Raises:
            ValueError: if `self.norm_method` is not a supported `NormMethod`.
        """
        x = df[self.x_columns]
        if self.norm_method == NormMethod.MINMAX:
            return (x - self.xmin[m]) / (self.xmax[m] - self.xmin[m])
        if self.norm_method == NormMethod.ZSCORE:
            return (x - self.xmeans[m]) / self.xstds[m]
        raise ValueError(f"Unsupported normalization method: {self.norm_method}")

    def _create_training_tensors(
        self, data_in: List[pd.DataFrame]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build normalized, cell-stacked training tensors from per-cell frames.

        Returns `train_X` of shape (num_cells, n, num_features) and `train_Y` of
        shape (num_cells, n), where n is the row count of `data_in[0]`.
        """
        n_train = data_in[0].shape[0]
        train_X = torch.zeros(
            [self.num_cells, n_train, self.num_features], dtype=torch.float32
        )
        train_Y = torch.zeros([self.num_cells, n_train], dtype=torch.float32)

        for m in range(self.num_cells):
            train_x_cell = self._normalize_x(data_in[m], m)
            train_y_cell = (data_in[m][self.y_columns] - self.ymeans[m]) / self.ystds[m]

            train_X[m] = torch.tensor(train_x_cell.values, dtype=torch.float32)
            # (n, 1) -> (1, n); the leading singleton dim is dropped on assignment
            train_Y[m] = torch.transpose(
                torch.tensor(train_y_cell.values, dtype=torch.float32), 0, 1
            )

        return train_X, train_Y

    @staticmethod
    def split_training_and_test_data(
        data_in: pd.DataFrame,
        n_sim: int,
        alpha: float,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, int]:
        """Split the simulation data groups into test and training sets.

        data_in: aggregated simulation data in the form of a dataframe (test+train)
        n_sim: number of simulations aggregated in data_in
        alpha: percent (0-100) of simulation runs used for training; at least
            one run is always used for training

        Rows whose `sim_idx` is greater than the number of test groups go to
        training; the remaining (lowest) simulation indices go to test.

        Each row of `data_in` corresponds to one pixel, and the columns
        are assumed to contain :
            - settings corresponding to one or more cells, with column names
            `setting_name_<n>` for different settings of interest and where
            n refers to the index of a cell in the modeled cluster.
            - `rx_loc1` and `rx_loc2`, the geo-coordinates for the pixel
            - `rxpower_dbm_<n>`, the received powers for the cell with index n
            - `rsrp_dbm` is the max power and `cell_id` is the cell index of that cell
            - `sim_idx` is the simulation index

        Example:
        ['cell_azimuth_deg_1', 'cell_azimuth_deg_2', 'cell_azimuth_deg_3',
        'cell_elec_tilt_deg_1', 'cell_elec_tilt_deg_2', 'cell_elec_tilt_deg_3',
        'cell_mech_tilt_deg_1', 'cell_mech_tilt_deg_2', 'cell_mech_tilt_deg_3',
        'cell_txpower_dbm_1', 'cell_txpower_dbm_2', 'cell_txpower_dbm_3',
        'rxpower_dbm_1', 'rxpower_dbm_2', 'rxpower_dbm_3', 'rsrp_dbm',
        'sinr_db', 'cell_id', 'rx_loc1', 'rx_loc2', 'sim_idx']

        Returns, in order:
            training data, test data, `describe(include="all")` statistics of
            the whole `data_in`, and the number of training groups.
        """
        n_training_group = max(int(alpha * 0.01 * n_sim), 1)
        n_test_group = n_sim - n_training_group
        logging.info(
            f"Splitting data into {n_training_group} training and {n_test_group} test groups..."
        )
        training_data = data_in[data_in[SIM_IDX] > n_test_group].reset_index(drop=True)
        test_data = data_in[data_in[SIM_IDX] <= n_test_group].reset_index(drop=True)
        stats = data_in.describe(include="all")
        return training_data, test_data, stats, n_training_group

    @staticmethod
    def create_prediction_frames(
        site_config_df: pd.DataFrame,
        prediction_frame_template: pd.DataFrame,
    ) -> Dict[str, pd.DataFrame]:
        """Build one prediction dataframe per cell, keyed by `cell_id`.

        `site_config_df` : 1 unique cell per row, contains at least the columns
            [cell_lat, cell_lon, cell_el_deg, cell_az_deg, cell_id,
            cell_carrier_freq_mhz, hTx, hRx]
            Assumption : `bayesian_digital_twin` was trained with respect to `site_config_df`
        `prediction_frame_template` : 1 prediction point per row, contains columns [loc_x, loc_y]
            e.g. loc_x is longitude, and loc_y is latitude

        Each output frame is a copy of the template with the cell's site
        configuration broadcast to every row, plus per-point log distance,
        relative bearing and antenna gain computed via `GISTools`.
        """
        prediction_dfs: Dict[str, pd.DataFrame] = {}

        # (lat, lon) of every prediction point; identical for all cells
        rx_points = list(
            zip(prediction_frame_template.loc_y, prediction_frame_template.loc_x)
        )

        for c in site_config_df.itertuples():
            prediction_df = prediction_frame_template.copy()

            prediction_df[CELL_LAT] = c.cell_lat
            prediction_df[CELL_LON] = c.cell_lon
            prediction_df[CELL_EL_DEG] = c.cell_el_deg
            prediction_df[CELL_ID] = c.cell_id
            prediction_df[CELL_CARRIER_FREQ_MHZ] = c.cell_carrier_freq_mhz
            prediction_df[HTX] = c.hTx
            prediction_df[HRX] = c.hRx

            prediction_df[LOG_DISTANCE] = [
                GISTools.get_log_distance(c.cell_lat, c.cell_lon, lat, lon)
                for lat, lon in rx_points
            ]

            prediction_df[RELATIVE_BEARING] = [
                GISTools.get_relative_bearing(
                    c.cell_az_deg, c.cell_lat, c.cell_lon, lat, lon
                )
                for lat, lon in rx_points
            ]

            prediction_df[ANTENNA_GAIN] = GISTools.get_antenna_gain(
                c.hTx, c.hRx, prediction_df[LOG_DISTANCE], c.cell_el_deg
            )

            prediction_dfs[c.cell_id] = prediction_df
        return prediction_dfs

    @staticmethod
    def get_percell_data(
        data_in: pd.DataFrame,
        all_idxs: List[int],
        desired_idxs: List[int],
        n_samples: int,
        sample_cells_independently: bool = False,
        choose_strongest_samples_percell: bool = False,
        invalid_value: float = -500.0,
        seed: int = 0,
    ) -> Tuple[List[pd.DataFrame], List[pd.DataFrame]]:
        """Split training data at cell level & compute per-cell statistics.

        data_in: training data across all cells and for one or more simulations
        all_idxs: integer list of every cell index present in `data_in`; the
            `cell_*_<n>` columns of every index other than the current one are dropped
        desired_idxs: integer list of cell indexes to get data on
        n_samples: number of random samples per cell (capped at the number of
            valid rows available)
        sample_cells_independently: if False, keep only pixels where no column
            equals `invalid_value` and draw one shared sample for all cells;
            if True, sample each cell separately among its own valid rows
        choose_strongest_samples_percell: only used when
            `sample_cells_independently` is True. If False, rows are sampled at
            random; if True, the `n_samples` rows with the highest
            `cell_rxpwr_dbm` for that cell are kept
        invalid_value: sentinel marking missing/invalid power values
        seed: random seed; per-cell independent sampling uses `seed + m`

        Each row of `data_in` corresponds to one pixel, and the columns
        are assumed to contain :
            - settings corresponding to one or more cells, with column names
            `setting_name_<n>` for different settings of interest and where
            n refers to the id of a cell in the modeled cluster.
            - `rx_loc1` and `rx_loc2`, the geo-coordinates for the pixel
            - `rxpower_dbm_<n>`, the received powers for the cell with id n
            - `rsrp_dbm` is the max power and `cell_id` is the cell index of that cell
            - `sim_idx` is the simulation index
        Example:
        ['cell_azimuth_deg_1', 'cell_azimuth_deg_2', 'cell_azimuth_deg_3',
        'cell_elec_tilt_deg_1', 'cell_elec_tilt_deg_2', 'cell_elec_tilt_deg_3',
        'cell_mech_tilt_deg_1', 'cell_mech_tilt_deg_2', 'cell_mech_tilt_deg_3',
        'cell_txpower_dbm_1', 'cell_txpower_dbm_2', 'cell_txpower_dbm_3',
        'rxpower_dbm_1', 'rxpower_dbm_2', 'rxpower_dbm_3', 'rsrp_dbm',
        'sinr_db', 'cell_id', 'rx_loc1', 'rx_loc2', 'sim_idx']

        The output contains two lists, corresponding in order
        to the indices in `desired_idxs` :
            1. per cell data
            2. per cell statistics (`describe(include="all")`)
        Each element of the per cell data list contains the following columns:
            - settings corresponding to given cell, with column names
            `setting_name` for different settings of interest (the `_<m>`
            suffix is removed from every column that ends with it)
            - `rx_loc1` and `rx_loc2`, the geo-coordinates for the pixel
            - `rxpower_dbm`, the received power for the given cell
            - `rxpower_dbm_<n>`, the received powers for the cell with id n (other than given)
            - `rsrp_dbm` is the max power and `cell_id` is the cell id of that cell
            - `sim_idx` is the simulation index
        Example:
        ['cell_azimuth_deg', 'cell_elec_tilt_deg', 'cell_mech_tilt_deg',
        'cell_txpower_dbm', 'rxpower_dbm_1', 'rxpower_dbm_2', 'rxpower_dbm',
        'rsrp_dbm', 'sinr_db', 'cell_id', 'rx_loc1', 'rx_loc2', 'group',
        'sim_idx']
        """
        data_out = []
        stats = []

        data_in_sampled = data_in

        if not sample_cells_independently:
            # keep only pixels where no column holds the invalid sentinel
            data_in_valid = data_in[~(data_in == invalid_value).any(axis=1)]
            # sample before splitting, so every cell shares the same pixels
            data_in_sampled = data_in_valid.sample(
                n=min(n_samples, len(data_in_valid)), random_state=seed
            )

        for m in desired_idxs:
            suffix = f"_{m}"
            # drop the per-cell setting columns belonging to other cells
            to_strip = []
            for n in all_idxs:
                if n != m:
                    to_strip.extend(list(data_in_sampled.filter(regex=f"cell_.+_{n}$")))
            data_cell = data_in_sampled.drop(to_strip, axis=1)
            # remove this cell's suffix only where it ends the column name
            data_cell.columns = [
                col[: -len(suffix)] if col.endswith(suffix) else col
                for col in data_cell.columns
            ]

            data_cell_sampled = data_cell

            if sample_cells_independently:
                # filter out invalid values
                data_cell_valid = data_cell[data_cell.cell_rxpwr_dbm != invalid_value]
                n_keep = min(n_samples, len(data_cell_valid))
                if choose_strongest_samples_percell:
                    data_cell_sampled = data_cell_valid.sort_values(
                        CELL_RXPWR_DBM, ascending=False
                    ).head(n=n_keep)
                else:
                    # independent random samples for this cell
                    data_cell_sampled = data_cell_valid.sample(
                        n=n_keep, random_state=(seed + m)
                    )

            data_out.append(data_cell_sampled.reset_index(drop=True))
            stats.append(data_cell_sampled.describe(include="all"))

        return data_out, stats

    def train_distributed_gpmodel(
        self,
        maxiter: int = 100,
        lr: float = 0.05,
        stopping_threshold: float = 1e-4,
        load_model: bool = False,
        save_model: bool = False,
        model_path: Optional[str] = None,
        model_name: Optional[str] = None,
    ) -> np.ndarray:
        """Fit GP hyperparameters with Adam on the exact marginal log likelihood,
        or load a previously saved state dict instead.

        Args:
            maxiter: maximum number of optimizer iterations.
            lr: Adam learning rate.
            stopping_threshold: stop early once the absolute loss change between
                consecutive iterations drops below this value.
            load_model: if True, load the state dict from
                `os.path.join(model_path, model_name)` instead of training.
            save_model: if True, save the state dict to that file afterwards,
                creating `model_path` if needed.
            model_path: checkpoint directory (required to load or save).
            model_name: checkpoint file name (required to load or save).

        Returns:
            Array of length `maxiter` with the loss of each iteration; entries
            after an early stop (or all of them when loading) remain 0.

        Raises:
            ValueError: if loading or saving is requested without both
                `model_path` and `model_name`.
        """
        if (load_model or save_model) and (model_path is None or model_name is None):
            raise ValueError(
                "`model_path` and `model_name` are required to load or save a model."
            )
        model_file = (
            os.path.join(model_path, model_name)
            if model_path is not None and model_name is not None
            else None
        )

        loss_vs_iter = np.zeros(maxiter)
        if load_model:
            logging.info("Now loading GP model (this should be quick...)")
            # load onto CPU first so a checkpoint saved from a GPU also loads on
            # a CPU-only host; the model is moved to the GPU afterwards if present
            state_dict = torch.load(model_file, map_location="cpu")
            self.model.load_state_dict(state_dict)
            if self.is_cuda:
                logging.info("Cuda enabled for model.")
                self.model = self.model.cuda()
        else:
            if self.is_cuda:
                logging.info(torch.__config__.show().split("\n"))
                logging.info("Cuda enabled for training data.")
                # move the model before building the optimizer over its parameters
                self.model = self.model.cuda()

            self.model.train()
            # "Loss" for GPs - the marginal log likelihood
            mll = gpytorch.mlls.ExactMarginalLogLikelihood(
                self.model.likelihood, self.model
            )
            optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

            # `train_inputs` is a tuple of tensors; keep it a tuple so that
            # `self.model(*train_X)` passes exactly one input tensor.
            train_X = mll.model.train_inputs
            train_Y = mll.model.train_targets
            if self.is_cuda:
                train_X = tuple(x.cuda() for x in train_X)
                train_Y = train_Y.cuda()
                mll = mll.cuda()

            last_loss = float("-inf")
            for i in range(maxiter):
                optimizer.zero_grad()
                output = self.model(*train_X)
                # sum the per-cell marginal log likelihoods into one scalar loss
                loss = -mll(output, train_Y).sum()
                loss.backward()
                optimizer.step()
                this_loss = loss.item()
                loss_vs_iter[i] = this_loss
                delta = this_loss - last_loss
                last_loss = this_loss
                logging.info(
                    "Iter %d/%d - Loss: %.3f (delta=%.6f)",
                    i + 1,
                    maxiter,
                    this_loss,
                    delta,
                )
                if abs(delta) < stopping_threshold:
                    logging.info("Stopping criteria met...exiting.")
                    break

        if save_model:
            if model_path:
                os.makedirs(model_path, exist_ok=True)
            torch.save(self.model.state_dict(), model_file)
            logging.info(f"Saved trained model to: {model_file}")
        return loss_vs_iter

    def update_model(
        self,
        data_in: List[pd.DataFrame],
    ):
        """Condition the GP on additional per-cell observations.

        `data_in` is one dataframe per cell, in the same order and with the
        same columns as the training data; all frames must have the same number
        of rows. The new data is normalized with the statistics computed at
        construction time, and `self.model` is replaced by the model returned by
        gpytorch's `get_fantasy_model`. Hyperparameters are not re-optimized.

        Assumes that the model is already trained and prediction has been run
        on the model at least once.
        """
        train_X, train_Y = self._create_training_tensors(data_in)
        if self.is_cuda:
            # fantasy data must live on the same device as the model
            train_X = train_X.cuda()
            train_Y = train_Y.cuda()

        self.model = self.model.get_fantasy_model(inputs=train_X, targets=train_Y)

    def predict_distributed_gpmodel(
        self,
        prediction_dfs: List[pd.DataFrame],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Predict Rx power (mean and standard deviation) for every cell.

        `prediction_dfs` : one prediction dataframe per cell, in the same order
        as the training data, each containing the `x_columns`. All frames must
        have the same number of rows.

        It is assumed that columns loc_x and loc_y in each dataframe inside
        `prediction_dfs` are the same, and that they appear in the same order.

        Mutates `prediction_dfs`: adds the predicted Rx mean (`RXPOWER_DBM`)
        and standard deviation (`RXPOWER_STDDEV_DBM`) columns to each frame.

        Returns, in order :
            prediction mean for Rx power (numpy array; row m holds cell m)
            prediction std dev for Rx power (numpy array; row m holds cell m)
        Both are de-normalized back to the original target units.
        """
        self.model.eval()

        num_locations = prediction_dfs[0].shape[0]
        predict_X = torch.zeros(
            [self.num_cells, num_locations, self.num_features], dtype=torch.float32
        )
        for m in range(self.num_cells):
            predict_X[m] = torch.tensor(
                self._normalize_x(prediction_dfs[m], m).values, dtype=torch.float32
            )

        if self.is_cuda:
            logging.info("Cuda enabled for test data.")
            predict_X = predict_X.cuda()

        with torch.no_grad(), gpytorch.settings.fast_pred_var():
            observed_pred = self.model.likelihood(self.model(predict_X))
            mean = observed_pred.mean
            var = observed_pred.variance

        # per-cell target statistics stacked to (num_cells, len(y_columns)),
        # broadcast against the per-cell prediction rows
        ymeans = np.asarray(self.ymeans, dtype=float)
        ystds = np.asarray(self.ystds, dtype=float)
        pred_means = mean.detach().cpu().numpy() * ystds + ymeans
        pred_stds = np.sqrt(var.detach().cpu().numpy()) * ystds

        for idx, prediction_df in enumerate(prediction_dfs):
            prediction_df[RXPOWER_DBM] = pred_means[idx]
            prediction_df[RXPOWER_STDDEV_DBM] = pred_stds[idx]

        return pred_means, pred_stds


In [None]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# os.system("persistent-storage define fbc_maveric --bucket fbc_maveric")
# os.system("persistent-storage mount --auto")
# buck build //bento/kernels:bento_kernel_maveric

import copy
import logging
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

# from advanced_network_planning.common.fblearner_srtm.constants import AntennaType
from radp.digital_twin.anp.anp_simulator import AnpEngine
from radp.digital_twin.anp.utils import plot_helper, rf_dict_to_anp_sites
# from radp.energy_savings.energy_savings_gym import EnergySavingsGym
# from radp.energy_savings.energy_savings_gym_dgpco import DgpcoEnergySavingsGym
from radp.utils.gis_tools import GISTools
# from IPython.display import HTML, display


def seed_everything(seed: int):
    """Seed every RNG used in this notebook and request deterministic cuDNN.

    Seeds Python's `random`, NumPy's global RNG and torch (CPU and CUDA), and
    exports PYTHONHASHSEED. Per the Python docs, PYTHONHASHSEED is read at
    interpreter startup, so it only affects processes launched afterwards.
    """
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per run and can select
    # different kernels each time, which defeats the determinism requested above.
    torch.backends.cudnn.benchmark = False


def _save_and_plot_sim(rf_dataframe, rf_dict, sites, sim_idx):
    """Write one simulation's RF data and site config to disk, then plot RSRP.

    Files go to `/{BUCKET_PATH}/{SIM_DATA_PATH}/sim_<sim_idx, zero-padded to 3>/`
    as `full_data.csv` and `site_config.csv`.
    """
    import itertools

    sim_idx_folder = str(sim_idx).zfill(3)
    save_path = f"/{BUCKET_PATH}/{SIM_DATA_PATH}/sim_{sim_idx_folder}"
    os.makedirs(save_path, exist_ok=True)
    rf_dataframe.to_csv(f"{save_path}/full_data.csv", index=False)

    site_config_df = pd.DataFrame.from_dict(rf_dict).rename(
        columns={
            "az_boresight_angle": "cell_az_deg",
            "el_boresight_angle": "cell_el_deg",
        }
    )
    # .copy() so the column assignments below act on an independent frame,
    # not on a slice of the one above (avoids SettingWithCopyWarning)
    site_config_df = site_config_df[
        ["cell_lat", "cell_lon", "cell_az_deg", "cell_el_deg", "cell_id", "hTx", "hRx"]
    ].copy()
    # each rf_dict["cell_id"] entry is itself iterable; flatten to one id per row
    site_config_df["cell_id"] = list(itertools.chain(*site_config_df["cell_id"].tolist()))
    # NOTE(review): groupby sorts by cell_id, so this assumes the cells in rf_dict
    # are listed in sorted cell_id order and each has >= 1 receiver -- confirm.
    site_config_df["nRx"] = rf_dataframe.groupby("cell_id").size().tolist()
    site_config_df.to_csv(f"{save_path}/site_config.csv", index=False)

    tilts = ", ".join(f"{x:.1f}" for x in site_config_df.cell_el_deg.tolist())
    plot = plot_helper(
        lat=rf_dataframe.loc_y,
        lon=rf_dataframe.loc_x,
        data=rf_dataframe.rsrp_dbm,
        sites=sites,
        title=f"AnpEngine() | {len(rf_dataframe.index):,} points | cell_el_deg = [{tilts}]",
        label="RSRP (dBm)",
        colormap="twilight",
        markersize=25,
        hist_flag=True,
    )
    plot.show()


def anp_rf_sim(rf_dict, n_cell=1, n_sim=10):
    """Simulate RF coverage for a baseline plus `n_sim` randomly tilted setups.

    `rf_dict` is trimmed to its first `n_cell` cells. Simulation 0 uses the
    tilts in `rf_dict` as given; simulations 1..n_sim draw each cell's
    `el_boresight_angle` with replacement from `tilt_set`. Every simulation's
    outputs are saved and plotted by `_save_and_plot_sim`.

    Relies on `BUCKET_PATH`, `SIM_DATA_PATH` and `tilt_set`, which are not
    defined in this function (presumably notebook-level globals).
    """
    # trim every per-cell field of rf_dict to n_cell cells
    rf_dict = {k: v[:n_cell] for k, v in rf_dict.items()}
    anp_engine = AnpEngine()
    data_boundary = None

    # Init: baseline configuration
    sites = rf_dict_to_anp_sites(rf_dict)
    rf_dataframe = anp_engine.sites_to_df(
        sites,
        rf_dict,
        sim_idx=0,
        data_boundary=data_boundary,
    )
    _save_and_plot_sim(rf_dataframe, rf_dict, sites, sim_idx=0)

    # Loop: random tilt configurations
    rf_dict_random = copy.deepcopy(rf_dict)
    for sim_idx in range(1, n_sim + 1):
        print(f"Now simulating: {sim_idx} of {n_sim}...")
        rf_dict_random["el_boresight_angle"] = random.choices(tilt_set, k=n_cell)
        sites = rf_dict_to_anp_sites(rf_dict_random)
        rf_dataframe = anp_engine.sites_to_df(
            sites,
            rf_dict_random,
            sim_idx=sim_idx,
            data_boundary=data_boundary,
        )
        _save_and_plot_sim(rf_dataframe, rf_dict_random, sites, sim_idx)

def _sim_csv_path(sim_idx: int, filename: str) -> str:
    """Return the path of ``filename`` inside the folder of simulation ``sim_idx`` (``sim_XXX``)."""
    return f"/{BUCKET_PATH}/{SIM_DATA_PATH}/sim_{str(sim_idx).zfill(3)}/{filename}"


def _add_cell_features(data_df: pd.DataFrame, site_config_df: pd.DataFrame, cell_id) -> None:
    """Add site columns and geometric model features for ``cell_id`` to ``data_df`` in place.

    Site values (lat/lon/hTx/hRx) come from the row of ``site_config_df`` whose
    ``cell_id`` matches. ``data_df`` must already contain ``loc_x``, ``loc_y``,
    ``cell_az_deg`` and ``cell_el_deg`` columns.
    """
    site_mask = site_config_df["cell_id"] == cell_id

    def site_value(column):
        # First matching site row, same as the original `.values[0]` lookup.
        return site_config_df.loc[site_mask, column].values[0]

    cell_lat = site_value("cell_lat")
    cell_lon = site_value("cell_lon")
    h_tx = site_value("hTx")
    h_rx = site_value("hRx")

    # Every physical cell is relabelled as cell 1 so a single-cell twin can be
    # trained / evaluated on any of them.
    data_df["cell_id"] = 1
    data_df["cell_lat"] = cell_lat
    data_df["cell_lon"] = cell_lon
    data_df["hTx"] = h_tx
    data_df["hRx"] = h_rx
    # NOTE(review): hard-coded 1200 MHz, while the notebook's rf_dict uses
    # freq_MHz=1800 -- confirm which value is intended.
    data_df["cell_carrier_freq_mhz"] = 1200

    # Per-point geometry relative to the cell site (loc_y = lat, loc_x = lon).
    points = list(zip(data_df.loc_y, data_df.loc_x))
    data_df["log_distance"] = [
        GISTools.get_log_distance(cell_lat, cell_lon, lat, lon) for lat, lon in points
    ]
    cell_az_deg = data_df["cell_az_deg"].values[0]
    data_df["relative_bearing"] = [
        GISTools.get_relative_bearing(cell_az_deg, cell_lat, cell_lon, lat, lon)
        for lat, lon in points
    ]
    data_df["antenna_gain"] = GISTools.get_antenna_gain(
        h_tx,
        h_rx,
        data_df["log_distance"],
        data_df["cell_el_deg"],
    )


def get_training_and_test_data(
    desired_idxs_train: List[int],
    desired_idxs_test: List[int],
    p_train=20,
    p_test=100,
    n_sim=10,
):
    """Load simulated RF data and build feature-enriched training / test sets.

    Simulation 0 is only used for its ``site_config.csv``. The test set is drawn
    from simulation ``n_sim``. The training set is drawn from simulations
    ``1 .. n_sim - 1`` and concatenated per cell.

    Args:
        desired_idxs_train: cell idxs (1-based, see ``idx_cell_id_mapping``) to
            include in the training set.
        desired_idxs_test: cell idxs to include in the test set.
        p_train: percentage of the total receiver count (``nRx``) passed as
            ``n_samples`` when sampling each training simulation.
        p_test: same as ``p_train`` for the test simulation.
        n_sim: index of the test simulation; training uses the ones before it.

    Returns:
        ``(training_data_final, test_data_final)``: dicts in which all cells are
        pooled under key ``1`` via ``append_value``. The value is a single
        DataFrame for one cell, or a list of DataFrames for several. Every frame
        is sorted by ``(loc_x, loc_y)``.
    """
    site_config_df = pd.read_csv(_sim_csv_path(0, "site_config.csv"))

    nRx = site_config_df["nRx"].sum()
    n_sample_train = round(p_train * 0.01 * nRx)
    n_sample_test = round(p_test * 0.01 * nRx)
    logging.info("==========")
    logging.info(f"nRx={nRx}, n_sample_train={n_sample_train}, n_sample_test={n_sample_test}")
    logging.info("==========")

    # idx (as used by get_percell_data) -> cell_id in site_config.csv.
    # NOTE(review): hard-coded identity mapping for the 3-cell scenario.
    idx_cell_id_mapping = {1: 1, 2: 2, 3: 3}

    # Test data is read and sampled before training data so a seeded run
    # consumes randomness in the same order as before.
    tilt_test_df = pd.read_csv(_sim_csv_path(n_sim, "full_data.csv"))
    tilt_test_per_cell_df, _ = BayesianDigitalTwin.get_percell_data(
        data_in=tilt_test_df,
        all_idxs=list(idx_cell_id_mapping),
        desired_idxs=desired_idxs_test,
        sample_cells_independently=False,
        n_samples=n_sample_test,
    )
    test_data = {
        idx: pd.concat([tilt_test_per_cell_df[i]])
        for i, idx in enumerate(desired_idxs_test)
    }

    percell_data_list = []
    for sim_idx in range(1, n_sim):
        tilt_df = pd.read_csv(_sim_csv_path(sim_idx, "full_data.csv"))
        tilt_per_cell_df, _ = BayesianDigitalTwin.get_percell_data(
            data_in=tilt_df,
            all_idxs=list(idx_cell_id_mapping),
            desired_idxs=desired_idxs_train,
            sample_cells_independently=False,
            n_samples=n_sample_train,
        )
        percell_data_list.append(tilt_per_cell_df)
    training_data = {
        idx: pd.concat([per_cell_dfs[i] for per_cell_dfs in percell_data_list])
        for i, idx in enumerate(desired_idxs_train)
    }

    for idx, training_df in training_data.items():
        _add_cell_features(training_df, site_config_df, idx_cell_id_mapping[idx])
    for idx, test_df in test_data.items():
        _add_cell_features(test_df, site_config_df, idx_cell_id_mapping[idx])

    # Pool every cell under the relabelled cell id 1. This runs once, after all
    # features exist (it was previously nested inside the per-test-cell loop).
    test_data_final = {}
    for data in test_data.values():
        data.sort_values(by=["loc_x", "loc_y"], inplace=True)
        append_value(test_data_final, 1, data)

    training_data_final = {}
    for data in training_data.values():
        data.sort_values(by=["loc_x", "loc_y"], inplace=True)
        append_value(training_data_final, 1, data)

    return training_data_final, test_data_final

def append_value(dict_obj, key, value):
    """Store ``value`` under ``key``, collecting repeated keys into a list.

    The first value for a key is stored as-is. A second value turns the entry
    into ``[first, second]`` (an existing list is appended to in place), and
    later values are appended to that list.
    """
    if key not in dict_obj:
        dict_obj[key] = value
        return

    bucket = dict_obj[key]
    if not isinstance(bucket, list):
        bucket = [bucket]
        dict_obj[key] = bucket
    bucket.append(value)


def _evaluate_and_plot(bayesian_digital_twin_map, test_data_map):
    """Predict each cell's test set, log MSE/MAE/MAPE and plot actual vs. predicted RSRP.

    Args:
        bayesian_digital_twin_map: cell id -> trained ``BayesianDigitalTwin``.
        test_data_map: cell id -> test DataFrame with ``loc_x``, ``loc_y``,
            ``cell_rxpwr_dbm`` and the twin's feature columns.
    """
    for cell_id, testing_data in test_data_map.items():
        pred_means, _ = bayesian_digital_twin_map[cell_id].predict_distributed_gpmodel(
            prediction_dfs=[testing_data],
        )
        lons = testing_data["loc_x"].values
        lats = testing_data["loc_y"].values

        # RSRP = maximum over the predicted rx powers (reduced along axis 0).
        pred_rsrp = np.amax(pred_means, axis=0)
        # The prediction covers only this cell's test frame, so compare it with
        # that same frame (previously test_data[1] / all cells were used).
        true_rsrp = testing_data["cell_rxpwr_dbm"].values

        error = true_rsrp - pred_rsrp
        mae = abs(error).mean()
        mse = (error ** 2).mean()
        # Mean absolute percentage error relative to the measured dBm value.
        mape = 100 * abs(error / true_rsrp).mean()
        logging.info("==========")
        logging.info(f"MSE = {mse:0.5f}, MAE = {mae:0.5f} dB, MAPE = {mape:0.5f} %")
        logging.info("==========")

        _, axs = plt.subplots(1, 2, figsize=(12, 12))
        panels = (
            (axs[0], true_rsrp, r"Actual RSRP"),
            (axs[1], pred_rsrp, f"Predicted RSRP | MAE = {mae:0.1f} dB"),
        )
        for ax, values, title in panels:
            ax.scatter(lons, lats, c=values, cmap="twilight", s=10)
            ax.set_title(title, fontsize=14)
            ax.set_aspect("equal", "box")
            ax.set_xticks([])
            ax.set_yticks([])

        plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9, wspace=0.0, hspace=0.1)
        plt.show()


def bdt(
    p_train=20,
    p_test=100,
    maxiter=20,
    n_sim=10,
    desired_idxs_train=None,
    desired_idxs_test=None,
    desired_idxs_train_update=None,
    desired_idxs_test_update=None,
    load_model=False,
    save_model=False,
    model_path="",
    model_name="",
    n_initial_train_rows=4000,
):
    """Train per-cell Bayesian digital twins, evaluate, update with new data and re-evaluate.

    Args:
        p_train, p_test, n_sim: forwarded to ``get_training_and_test_data``.
        maxiter: training iterations passed to ``train_distributed_gpmodel``.
        desired_idxs_train, desired_idxs_test: cell idxs for the initial phase
            (``None`` -> ``[1]`` and ``[1, 2, 3]``).
        desired_idxs_train_update, desired_idxs_test_update: cell idxs for the
            update phase (``None`` -> ``[1, 2]`` each).
        load_model, save_model, model_path, model_name: forwarded to
            ``train_distributed_gpmodel``.
        n_initial_train_rows: rows kept for initial training. The update phase
            uses the rows after this position from a second data load.

    Returns:
        ``(bayesian_digital_twin_map, test_data, loss_vs_iter)``: twins keyed by
        cell id, the initial-phase test data, and the loss history of the last
        cell trained (empty if none was trained).
    """
    # None sentinels avoid shared mutable defaults; the values match the old defaults.
    if desired_idxs_train is None:
        desired_idxs_train = [1]
    if desired_idxs_test is None:
        desired_idxs_test = [1, 2, 3]
    if desired_idxs_train_update is None:
        desired_idxs_train_update = [1, 2]
    if desired_idxs_test_update is None:
        desired_idxs_test_update = [1, 2]

    # ---------- initial training / test data ----------
    # The caller's sampling settings are forwarded (they used to be hard-coded).
    training_data, test_data = get_training_and_test_data(
        desired_idxs_train,
        desired_idxs_test,
        p_train=p_train,
        p_test=p_test,
        n_sim=n_sim,
    )
    training_data = {key: val[:n_initial_train_rows] for key, val in training_data.items()}

    # ---------- train one twin per cell ----------
    bayesian_digital_twin_map = {}
    loss_vs_iter = []
    for cell_id, cell_training_df in training_data.items():
        twin = BayesianDigitalTwin(
            data_in=[cell_training_df],
            x_columns=["log_distance", "relative_bearing", "antenna_gain"],
            y_columns=["cell_rxpwr_dbm"],
            x_max=None,
            x_min=None,
        )
        bayesian_digital_twin_map[cell_id] = twin

        loss_vs_iter = twin.train_distributed_gpmodel(
            maxiter=maxiter,
            load_model=load_model,
            save_model=save_model,
            model_path=model_path,
            model_name=model_name,
        )

        logging.info(
            f"\nTrained {len(loss_vs_iter)} epochs of Bayesian Digital Twin (Gaussian Process Regression) "
            f"on {len(cell_training_df)} data points"
            f" with min learning loss {min(loss_vs_iter):0.5f}, "
            f"avg learning loss {np.mean(loss_vs_iter):0.5f} and final learning loss {loss_vs_iter[-1]:0.5f}"
        )

    if len(loss_vs_iter) > 0:
        plt.figure(figsize=(10, 8))
        plt.style.use("bmh")
        # x-axis follows the actual history length, which need not equal maxiter.
        plt.plot(np.arange(len(loss_vs_iter)), loss_vs_iter)
        plt.xlabel("iter")
        plt.ylabel("loss")
        plt.grid(True)
        plt.show()

    _evaluate_and_plot(bayesian_digital_twin_map, test_data)

    # ---------- update phase ----------
    training_update_data, test_update_data = get_training_and_test_data(
        desired_idxs_train_update,
        desired_idxs_test_update,
        p_train=p_train,
        p_test=p_test,
        n_sim=n_sim,
    )
    # NOTE(review): this is a second data load. If get_percell_data samples
    # randomly, these rows are not guaranteed to be disjoint from the initial
    # training rows.
    training_update_data = {
        key: val[n_initial_train_rows:] for key, val in training_update_data.items()
    }
    for cell_id, cell_update_df in training_update_data.items():
        # Each twin is updated with its own cell's data only.
        bayesian_digital_twin_map[cell_id].update_model(data_in=[cell_update_df])

    _evaluate_and_plot(bayesian_digital_twin_map, test_update_data)

    return bayesian_digital_twin_map, test_data, loss_vs_iter


# RF

In [None]:
# ---- experiment settings ----
n_cell = 3   # number of sectors described by rf_dict below
n_sim = 10   # number of simulations
seed = 100
seed_everything(seed=seed)

# Mount persistent storage (exit status is not checked), then locate the data.
os.system("persistent-storage mount --auto")
BUCKET_PATH = "home/sandeeprajan/rf_sim/fbc_maveric/data/digital_twin"
#BUCKET_PATH = "tmp"
SIM_DATA_PATH = "sim_data/simple3cell"

# Candidate electrical tilts: tilt_min .. tilt_max inclusive.
tilt_min, tilt_max = 10, 15
tilt_set = list(np.arange(tilt_min, tilt_max + 1))

# RF configuration with one list entry per sector. All sectors share one site;
# sector i is named str(11 + i), has cell_id [i + 1] and azimuth i * 360 // n_cell
# (0, 120, 240 for three sectors).
rf_dict = {
    "cell_name": [str(11 + i) for i in range(n_cell)],
    "cell_lat": [34.0000] * n_cell,
    "cell_lon": [-102.0000] * n_cell,
    "cell_id": [[i + 1] for i in range(n_cell)],
    "enb_tx_power": [23] * n_cell,
    "enb_noise": [5.0] * n_cell,
    "enb_antenna_gain": [6] * n_cell,
    "enb_tx_diversity_gain": [0.0] * n_cell,
    "enb_rx_diversity_gain": [0.0] * n_cell,
    "enb_tx_misc_loss": [0.0] * n_cell,
    "enb_rx_misc_loss": [0.0] * n_cell,
    "az_boresight_angle": [i * 360 // n_cell for i in range(n_cell)],
    "el_boresight_angle": [tilt_set[0]] * n_cell,
    "az_beamwidth": [60] * n_cell,
    "el_beamwidth": [10] * n_cell,
    "load_factor": [1.0] * n_cell,
    # "antenna_type": [AntennaType.triGPP, AntennaType.triGPP, AntennaType.triGPP],
    "max_range": [1.0] * n_cell,
    "hTx": [35] * n_cell,
    "hRx": [2] * n_cell,
    "freq_MHz": [1800] * n_cell,
    "penetration_loss_db": [0.0] * n_cell,
    "df_loss_factor": [0.0] * n_cell,
    "pl_model": [3] * n_cell,
    "use_clutter": [False] * n_cell,
    "urban_mode": [False] * n_cell,
    "use_openmp": [False] * n_cell,
}

# rf_dict = {
#     "cell_name": ["11"],
#     "cell_lat": [
#         34.0000,
#     ],
#     "cell_lon": [
#         -102.0000,
#     ],
#     "cell_id": [[1]],
#     "enb_tx_power": [23],
#     "enb_noise": [5.0],
#     "enb_antenna_gain": [6],
#     "enb_tx_diversity_gain": [0.0],
#     "enb_rx_diversity_gain": [0.0],
#     "enb_tx_misc_loss": [0.0],
#     "enb_rx_misc_loss": [0.0],
#     "az_boresight_angle": [0],
#     "el_boresight_angle": [tilt_set[0]],
#     "az_beamwidth": [60],
#     "el_beamwidth": [10],
#     "load_factor": [1.0],
#     "antenna_type": [AntennaType.triGPP],
#     "max_range": [1.0],
#     "hTx": [35],
#     "hRx": [2],
#     "freq_MHz": [1800],
#     "penetration_loss_db": [0.0],
#     "df_loss_factor": [0.0],
#     "pl_model": [3],
#     "use_clutter": [False],
#     "urban_mode": [False],
#     "use_openmp": [False],
# }

#anp_rf_sim(rf_dict=rf_dict, n_cell=n_cell, n_sim=n_sim)

# Bayesian digital twin training

In [None]:
# Load cell idx 1 for both training (20% sampling) and testing (100% sampling).
training_data, test_data = get_training_and_test_data(
    [1],
    [1],
    p_train=20,
    p_test=100,
    n_sim=10,
)


In [None]:
# for cell_id, testing_data in test_data.items():
#     print(testing_data)


# for idx, data in test_data.values():
#     frames = [idx, data]
#     result = pd.concat(frames)

# Quick look at the pooled training data: a dict whose key 1 holds the
# relabelled single-cell frame (displayed as the cell's last expression).
print(type(training_data))
training_data[1]
# for idx, data in test_data.items():
#     print(data)
#     data.sort_values(by=['loc_x', 'loc_y'],  inplace=True)
#     print(data)
    

In [None]:
# dictionary[new_key] = dictionary[old_key]
# del dictionary[old_key]

#test_data[1] = test_data[2]
#del training_data[2]

# def append_value(dict_obj, key, value):
#     # Check if key exist in dict or not
#     if key in dict_obj:
#         # Key exist in dict.
#         # Check if type of value of key is list or not
#         if not isinstance(dict_obj[key], list):
#             # If type is not list then make it list
#             dict_obj[key] = [dict_obj[key]]
#         # Append the value in list
#         dict_obj[key].append(value)
#     else:
#         # As key is not in dict,
#         # so, add key-value pair
#         dict_obj[key] = value

# test_data_one= []
# test_data_one.append(test_data[1])

# test_data_two = []
# test_data_two.append(test_data[2])



#test_data_final[1] = test_data[1]
#test_data_final[1].append(test_data[2])

#test_data_two[0]
# joined_list = [*test_data_one, *test_data_two]
# joined_list
# cell_id_training_data_map = {k: v for k, v in test_data_final.groupby("cell_id")}
# cell_id_training_data_map

In [None]:
# Sampling percentages, optimisation budget and model persistence settings.
p_train, p_test = 20, 100
maxiter = 25
load_model = save_model = False
bdt_model_path = "/tmp/incremental"
bdt_model_name = "/bdt.mod"

# Train on cell idx 1, evaluate, update with a second data load, re-evaluate.
bayesian_digital_twin, test_data, loss_vs_iter = bdt(
    p_train=p_train,
    p_test=p_test,
    maxiter=maxiter,
    n_sim=n_sim,
    desired_idxs_train=[1],
    desired_idxs_test=[1],
    desired_idxs_train_update=[1],
    desired_idxs_test_update=[1],
    load_model=load_model,
    save_model=save_model,
    model_path=bdt_model_path,
    model_name=bdt_model_name,
)
