# Paper study - Anomaly detection

## Init Env

In [None]:
%load_ext autoreload
%autoreload 2

import os
import logging
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.stats as st

from settings import Metadata, EnvMetadata, ExperimentPhase, \
    PROJECT_FOLDER, DATA_FOLDER, TEST_EPISODES_FOLDER, setup_matplotlib_config
from bin.main import run_simulation
from src.utils import lineplot_ci
from src.data_classes import Episode
from src.environment.channel.utils import create_K_MPR_matrix
from src.view.plot import plot_episode, plot_validation_metrics, plot_per_step_metrics, plot_validation_metrics_per_n_steps_model_learning, \
    plot_p_transmit_aac, plot_critic_value_aac, plot_training_actor_critic, plot_coma_actors_critic, \
    plot_tdma_actors_intermediary_probabilities
from src.view.plot_model import plot_dirichlet_data_generation, plot_dirichlet_mpr_channel
from src.view.metrics import get_experiment_throughput, get_experiment_fairness, experiment_throughput_mean_dev, get_return, \
    get_buffer_info, get_channel_collisions, get_aac_info, get_training_info, get_state_distribution,  \
    get_coma_info, get_loss, get_critic_training_info, get_gradients_info, estimate_Q_value_COMA

# Apply the project-wide matplotlib configuration imported from settings
setup_matplotlib_config()

# Reset the root logger so that logging.basicConfig below actually takes effect:
# basicConfig does nothing while the root logger still has at least one handler
# (for instance one installed by the notebook environment).
root = logging.getLogger()
# Iterate over a copy: removeHandler mutates root.handlers, and removing items from
# the list being iterated would skip every other handler.
for handler in list(root.handlers):
    root.removeHandler(handler)
logging.basicConfig(format='%(asctime)s %(message)s', level=logging.INFO)

def does_experiment_exist(experiment_name):
    """Return True if the data folder of ``experiment_name`` already exists."""
    experiment_folder = os.path.join(PROJECT_FOLDER, DATA_FOLDER, experiment_name)
    return os.path.isdir(experiment_folder)

def does_experiment_test_exist(experiment_name):
    """Return True if the test-episodes sub-folder of ``experiment_name`` already exists."""
    test_episodes_folder = os.path.join(
        PROJECT_FOLDER, DATA_FOLDER, experiment_name, TEST_EPISODES_FOLDER
    )
    return os.path.isdir(test_episodes_folder)


## Experimental setup

In [None]:
%autoreload 2

from experimental_setup import *

# Log-likelihood
EPSILON = 10**(-10)  # To avoid log(0)

# Number of steps used to create the in-distribution and out-of-distribution datasets
N_STEPS_DATA = 8000


# Perturbed system : we collect out-of-distribution data by simulating a perturbed environment where only
# the data generation probabilities differ from the initial system.
# Perturbed data generation probabilities
PERTURBED_DATA_GEN_PROBABILITIES_MAP = None  # No time-dependent data generation
PERTURBED_DEFAULT_JOINT_DISTRIBUTION = {
    "0,0": 0.6,
    "1,0": 0.4,
    "0,1": 0,  # Device is disconnected
    "1,1": 0
}
PERTURBED_DATA_GEN_PROBABILITIES_MAPS_KWARGS = [
    {
        "name": "cluster_1",
        "n_joint_agents": 2,
        "probabilities_map": PERTURBED_DATA_GEN_PROBABILITIES_MAP,
        "default_joint_distribution": PERTURBED_DEFAULT_JOINT_DISTRIBUTION
    },
    {
        "name": "cluster_2",
        "n_joint_agents": 2,
        "probabilities_map": PERTURBED_DATA_GEN_PROBABILITIES_MAP,
        "default_joint_distribution": PERTURBED_DEFAULT_JOINT_DISTRIBUTION
    }
]

# Number of samples used for tests
N_TEST_SAMPLES = 100  # n samples per data distribution (in-dist and out-dist)
N_STEPS_PER_SAMPLE = 5


# Log
EXPERIMENT_NAME_PREFIX = "paper_policy_monitor"


# Print
_JOINT_SINGLE_PACKET_GENERATED_PROBABILITY = DEFAULT_JOINT_DISTRIBUTION["1,0"]
_MARGINAL_DATA_GEN_PROBABILITY = _JOINT_SINGLE_PACKET_GENERATED_PROBABILITY / (1 + _JOINT_SINGLE_PACKET_GENERATED_PROBABILITY)
print(f"Marginal packet generation probability per agent : {_MARGINAL_DATA_GEN_PROBABILITY}")
print(f"Joint node distribution : {DEFAULT_JOINT_DISTRIBUTION}")
print(f"MPR Matrix : {MPR_MATRIX}")



# Previous experiments
# EXPERIMENT_VERSION = "02"
# PERTURBED_DEFAULT_JOINT_DISTRIBUTION = {
#     "0,0": 0.6,
#     "1,0": 0.2,
#     "0,1": 0.2,
#     "1,1": 0
# }

# EXPERIMENT_VERSION = "03"
# PERTURBED_DEFAULT_JOINT_DISTRIBUTION = {
#     "0,0": 0.2,
#     "1,0": 0.8,
#     "0,1": 0,  # Device is disconnected
#     "1,1": 0
# }

## Learn model

In [None]:
EXPERIMENT_VERSION = "04"  # Version of runs (embedded in every experiment name)

# Number of independently learned models per model-learning dataset size
N_MODELS_PER_STEP = 50
# Model-learning dataset sizes (max number of steps used to learn a model)
N_STEPS_LIST = [10, 20, 50]
# Every (model index, n_steps) pair to run; the model index varies slowest
RANGE_MODEL_STEPS = list(product(range(N_MODELS_PER_STEP), N_STEPS_LIST))

def get_metadata_model_learning(n_steps):
    """Build the run metadata for learning an environment model from random exploration.

    Only the model-learning phase runs: one episode of at most ``n_steps`` steps.
    No policy is trained and no test episode is played.
    """
    env_metadata = dict(ENV_METADATA, test_n_episodes=0)
    digital_twin_kwargs = {
        "n_packets_rollouts": N_PACKETS,
        "prior_dirichlet_concentration": PRIOR_DIRICHLET_CONCETRATION,
        "model_sampling_method": "posterior_sample",
        "exploration_policy_type": "random",
    }
    train_metadata = {
        "digital_twin_class": "DigitalTwinSeparateModel",
        "digital_twin_kwargs": digital_twin_kwargs,
        "train_model_n_episodes": 1,
        "train_model_max_steps": n_steps,
        "train_policy_n_episodes": 0,
        "train_policy_max_steps": 0,
        "policy_optimizer_class": "PolicyOptimizerCOMA",
        "policy_optimizer_kwargs": POLICY_OPTIMIZER_METADATA,
    }
    return Metadata.from_dict(
        {"env_metadata": env_metadata, "train_metadata": train_metadata}
    )


def get_model_learning_experiment_name(n_model, n_steps):
    """Return the experiment name of model ``n_model`` learned from ``n_steps`` random steps."""
    return "{}_model_learning_random_{}_n_model_{}_n_steps_{}".format(
        EXPERIMENT_NAME_PREFIX, EXPERIMENT_VERSION, n_model, n_steps
    )

In [None]:
%autoreload 2

# Learn one model per (model index, dataset size) pair. Experiments whose data
# folder already exists are skipped, so the cell can be re-run.
for n_model, n_steps in RANGE_MODEL_STEPS:
    metadata_model_learning = get_metadata_model_learning(n_steps)
    experiment_name_model_learning = get_model_learning_experiment_name(n_model, n_steps)
    if does_experiment_exist(experiment_name_model_learning):
        print(f"Experiment '{experiment_name_model_learning}' already done...")
    else:
        run_simulation(
            metadata_model_learning,
            log=True,
            log_train=True,  # Keep the training log (read back by the plotting cell)
            log_experiment_name=experiment_name_model_learning,
            suffix_log_experiment_name=False,
        )

In [None]:
# Visualise the Dirichlet posteriors of one learned model (model index 2, learned
# from 50 steps) against the true distributions: data generation of each cluster,
# then the MPR channel (plotted with n_steps_per_plot=1).
episode_max_steps = Episode.load_experiment(
    get_model_learning_experiment_name(2, 50),
    experiment_phase=ExperimentPhase.TRAIN_MODEL
)[0]
plot_dirichlet_data_generation(
    episode_max_steps,
    "cluster_1",
    prior_dirichlet_concentration=PRIOR_DIRICHLET_CONCETRATION,
    prior_dirichlet_concentration_map=PRIOR_DIRICHLET_CONCETRATION_MAP,
    n_steps_per_plot=1,
    true_transition_probability={"": DEFAULT_JOINT_DISTRIBUTION},
    x_axis_range=[0, 1],
    x_axis_step=0.001
)
plot_dirichlet_data_generation(
    episode_max_steps,
    "cluster_2",
    prior_dirichlet_concentration=PRIOR_DIRICHLET_CONCETRATION,
    prior_dirichlet_concentration_map=PRIOR_DIRICHLET_CONCETRATION_MAP,
    n_steps_per_plot=1,
    true_transition_probability={"": DEFAULT_JOINT_DISTRIBUTION},
    x_axis_range=[0, 1],
    x_axis_step=0.001
)
plot_dirichlet_mpr_channel(
    episode_max_steps,
    prior_dirichlet_concentration=PRIOR_DIRICHLET_CONCETRATION,
    prior_dirichlet_concentration_map=PRIOR_DIRICHLET_CONCETRATION_MAP,
    n_steps_per_plot=1,
    max_packets_transmitted=4,
    true_mpr_matrix=MPR_MATRIX
)

## Collect data

### Train operational policy

In [None]:
def get_metadata_policy_opt_on_model(lr_actor=POLICY_OPTIMIZER_METADATA["learning_rate_actor"]):
    """Build the run metadata for optimising a COMA policy on a learned model.

    Model learning and testing are disabled. Only the policy-training phase runs:
    N_EPISODES_TRAINING_POLICY episodes of at most N_STEPS_TRAINING_POLICY steps.

    Args:
        lr_actor: actor learning rate overriding the one in POLICY_OPTIMIZER_METADATA.
            Its default is read from POLICY_OPTIMIZER_METADATA when the function is defined.
    """
    env_metadata = dict(ENV_METADATA, test_n_episodes=0)
    digital_twin_kwargs = {
        "n_packets_rollouts": N_PACKETS,
        "prior_dirichlet_concentration": PRIOR_DIRICHLET_CONCETRATION,
        "model_sampling_method": "posterior_sample",
        "n_steps_between_model_update": N_STEPS_BETWEEN_POSTERIOR_SAMPLE,
        "exploration_policy_type": "random",
    }
    policy_optimizer_kwargs = dict(POLICY_OPTIMIZER_METADATA, learning_rate_actor=lr_actor)
    train_metadata = {
        "digital_twin_class": "DigitalTwinSeparateModel",
        "digital_twin_kwargs": digital_twin_kwargs,
        "train_model_n_episodes": 0,
        "train_model_max_steps": 0,
        "train_policy_n_episodes": N_EPISODES_TRAINING_POLICY,
        "train_policy_max_steps": N_STEPS_TRAINING_POLICY,
        "policy_optimizer_class": "PolicyOptimizerCOMA",
        "policy_optimizer_kwargs": policy_optimizer_kwargs,
    }
    return Metadata.from_dict(
        {"env_metadata": env_metadata, "train_metadata": train_metadata}
    )

def get_experiment_name_policy_opt_on_model(n_model, n_steps):
    """Return the experiment name of the policy optimised on model ``n_model`` learned from ``n_steps`` steps."""
    return "{}_policy_opt_on_model_{}_n_model_{}_n_steps_{}".format(
        EXPERIMENT_NAME_PREFIX, EXPERIMENT_VERSION, n_model, n_steps
    )

In [None]:
%autoreload 2

# Optimise one policy on each learned model. If the run raises, any partial experiment
# folder is renamed with an "_ERROR" suffix and the run is retried once with half the
# actor learning rate. A failure of the retry propagates and stops the loop.
for n_model, n_steps in RANGE_MODEL_STEPS:
    policy_opt_experiment_name = get_experiment_name_policy_opt_on_model(n_model, n_steps)
    model_experiment_name = get_model_learning_experiment_name(n_model, n_steps)
    metadata = get_metadata_policy_opt_on_model()
    if does_experiment_exist(policy_opt_experiment_name):
        print(f"Experiment '{policy_opt_experiment_name}' already done...")
    else:
        try:
            run_simulation(
                metadata,
                log=True,
                log_train=False,
                log_trained_model_or_policy=True,
                log_experiment_name=policy_opt_experiment_name,
                load_model_experiment_name=model_experiment_name,
                suffix_log_experiment_name=False
            )
        except Exception as e:
            print(e)
            print("ERROR ! Trying with half the learning rate for the actor...")
            # Move the partial output aside so the retry can reuse the experiment name
            if does_experiment_exist(policy_opt_experiment_name):
                os.rename(
                    os.path.join(PROJECT_FOLDER, DATA_FOLDER, policy_opt_experiment_name),
                    os.path.join(PROJECT_FOLDER, DATA_FOLDER, f"{policy_opt_experiment_name}_ERROR"),
                )
            new_actor_lr = POLICY_OPTIMIZER_METADATA["learning_rate_actor"] / 2
            metadata = get_metadata_policy_opt_on_model(new_actor_lr)
            run_simulation(
                metadata,
                log=True,
                log_train=False,
                log_trained_model_or_policy=True,
                log_experiment_name=policy_opt_experiment_name,
                load_model_experiment_name=model_experiment_name,
                suffix_log_experiment_name=False
            )

### Select policy with highest reward for each model learning dataset size

In [None]:
# Name of the experiment holding the single test episode on which every trained policy is evaluated
TEST_EPISODE_NAME = f"{EXPERIMENT_NAME_PREFIX}_test_episode_{EXPERIMENT_VERSION}"

def get_metadata_generate_test_episode():
    """Build the run metadata that plays one test episode with an untrained Aloha policy.

    Every agent transmits with probability 1 / N_AGENTS. All training phases are disabled.
    """
    env_metadata = dict(ENV_METADATA, test_n_episodes=1)
    train_metadata = {
        "digital_twin_class": "DigitalTwinPolicyPassthrough",
        "digital_twin_kwargs": {},
        "train_model_n_episodes": 0,
        "train_model_max_steps": 0,
        "train_policy_n_episodes": 0,
        "train_policy_max_steps": 0,
        "policy_optimizer_class": "PolicyOptimizerAloha",
        "policy_optimizer_kwargs": {"p_transmit": 1 / N_AGENTS},
    }
    return Metadata.from_dict(
        {"env_metadata": env_metadata, "train_metadata": train_metadata}
    )

# Test policy
def get_metadata_test_policy():
    """Build the run metadata used to evaluate an already-trained COMA policy.

    All training phases are disabled and ``test_n_episodes`` is 0. The episode to
    replay is expected to be supplied by the caller, e.g. through
    ``load_forced_test_experiment_name`` of ``run_simulation``.
    """
    env_metadata = dict(ENV_METADATA, test_n_episodes=0)
    train_metadata = {
        "digital_twin_class": "DigitalTwinPolicyPassthrough",
        "digital_twin_kwargs": {},
        "train_model_n_episodes": 0,
        "train_model_max_steps": 0,
        "train_policy_n_episodes": 0,
        "train_policy_max_steps": 0,
        "policy_optimizer_class": "PolicyOptimizerCOMA",
        "policy_optimizer_kwargs": POLICY_OPTIMIZER_METADATA,
    }
    return Metadata.from_dict(
        {"env_metadata": env_metadata, "train_metadata": train_metadata}
    )

In [None]:
# Generate the shared test episode once (Aloha policy, see get_metadata_generate_test_episode)
if does_experiment_exist(TEST_EPISODE_NAME):
    print(f"Experiment '{TEST_EPISODE_NAME}' already done...")
else:
    run_simulation(
        get_metadata_generate_test_episode(),
        log=True,
        log_train=False,
        log_experiment_name=TEST_EPISODE_NAME,
        suffix_log_experiment_name=False
    )

# Evaluate every trained policy on that same test episode (passed through
# load_forced_test_experiment_name). The results are logged into the policy's own
# experiment; policies whose test folder already exists are skipped.
for n_model, n_steps in RANGE_MODEL_STEPS:
    policy_experiment_name = get_experiment_name_policy_opt_on_model(n_model, n_steps)
    if does_experiment_test_exist(policy_experiment_name):
        print(f"Experiment '{policy_experiment_name}' policy already tested...")
    else:
        run_simulation(
            get_metadata_test_policy(),
            log=True,
            log_train=False,
            log_experiment_name=policy_experiment_name,
            load_policy_experiment_name=policy_experiment_name,
            load_forced_test_experiment_name=TEST_EPISODE_NAME,
            suffix_log_experiment_name=False
        )

In [None]:
# For each model-learning dataset size, keep the policy whose test episode has the
# highest total reward (the first one encountered wins ties).
SELECTED_POLICY = dict()  # n_steps -> name of the selected policy experiment

# NOTE(review): light_selection is never passed to Episode.load_episode below, so it has
# no effect as written. Presumably it was meant to restrict which fields are loaded; confirm.
light_selection = {
    "rewards": True,
    "actions": False,
    "info": {},
    "state": False,
    "digital_twin_info": {},
    "train_info": {}
}
current_best = dict()  # n_steps -> best total reward seen so far
for n_model, n_steps in RANGE_MODEL_STEPS:
    policy_experiment_name = get_experiment_name_policy_opt_on_model(n_model, n_steps)
    episode = Episode.load_episode(
        policy_experiment_name,
        "ep_0",
        experiment_phase=ExperimentPhase.TEST_POLICY
    )
    # Sum of every reward of the episode (over steps, and over agents if rewards are per agent)
    total_reward = np.sum([step.rewards for step in episode.history])
    if (
        (current_best.get(n_steps, None) is None) or
        (current_best[n_steps] < total_reward)
    ):
        current_best[n_steps] = total_reward
        SELECTED_POLICY[n_steps] = policy_experiment_name

print(SELECTED_POLICY)

### Collect in-distribution data

In [None]:
def get_metadata_in_distribution_data_collection():
    """Build the run metadata that plays one N_STEPS_DATA-step test episode in the nominal environment.

    All training phases are disabled; the episode is played by the COMA policy
    supplied by the caller.
    """
    env_metadata = dict(ENV_METADATA, test_n_episodes=1, test_max_steps=N_STEPS_DATA)
    digital_twin_kwargs = {
        "n_packets_rollouts": N_PACKETS,
        "prior_dirichlet_concentration": PRIOR_DIRICHLET_CONCETRATION,
        "model_sampling_method": "posterior_sample",
        "n_steps_between_model_update": N_STEPS_BETWEEN_POSTERIOR_SAMPLE,
        "exploration_policy_type": "random",
    }
    train_metadata = {
        "digital_twin_class": "DigitalTwinSeparateModel",
        "digital_twin_kwargs": digital_twin_kwargs,
        "train_model_n_episodes": 0,
        "train_model_max_steps": 0,
        "train_policy_n_episodes": 0,
        "train_policy_max_steps": 0,
        "policy_optimizer_class": "PolicyOptimizerCOMA",
        "policy_optimizer_kwargs": POLICY_OPTIMIZER_METADATA,
    }
    return Metadata.from_dict(
        {"env_metadata": env_metadata, "train_metadata": train_metadata}
    )

def get_experiment_name_in_distribution_data_collection(n_steps):
    """Return the experiment name of the in-distribution data collected with the policy selected for ``n_steps``."""
    return "{}_in_distribution_data_collection_{}_n_steps_{}".format(
        EXPERIMENT_NAME_PREFIX, EXPERIMENT_VERSION, n_steps
    )

In [None]:
# Collect one long in-distribution episode per dataset size, played by the policy
# selected for that size; already collected experiments are skipped.
for n_steps in N_STEPS_LIST:
    data_collect_experiment_name = get_experiment_name_in_distribution_data_collection(n_steps)
    policy_experiment_name = SELECTED_POLICY[n_steps]
    metadata = get_metadata_in_distribution_data_collection()
    if does_experiment_exist(data_collect_experiment_name):
        print(f"Experiment '{data_collect_experiment_name}' already done...")
    else:
        run_simulation(
            metadata,
            log=True,
            log_train=False,
            log_experiment_name=data_collect_experiment_name,
            load_policy_experiment_name=policy_experiment_name,
            suffix_log_experiment_name=False
        )

### Collect out-of-distribution data

In [None]:
def get_metadata_out_of_distribution_data_collection():
    """Build the run metadata that plays one N_STEPS_DATA-step test episode in the perturbed environment.

    Only the data generation probabilities are replaced, by
    PERTURBED_DATA_GEN_PROBABILITIES_MAPS_KWARGS. The cluster topology
    (DATA_GEN_DEPENDENCIES_KWARGS) is kept, and all training phases are disabled.
    """
    env_metadata = dict(
        ENV_METADATA,
        test_n_episodes=1,
        test_max_steps=N_STEPS_DATA,
        data_generator_probabilities_maps_kwargs=PERTURBED_DATA_GEN_PROBABILITIES_MAPS_KWARGS,
        data_generator_dependencies_kwargs=DATA_GEN_DEPENDENCIES_KWARGS,  # Same cluster topology
    )
    digital_twin_kwargs = {
        "n_packets_rollouts": N_PACKETS,
        "prior_dirichlet_concentration": PRIOR_DIRICHLET_CONCETRATION,
        "model_sampling_method": "posterior_sample",
        "n_steps_between_model_update": N_STEPS_BETWEEN_POSTERIOR_SAMPLE,
        "exploration_policy_type": "random",
    }
    train_metadata = {
        "digital_twin_class": "DigitalTwinSeparateModel",
        "digital_twin_kwargs": digital_twin_kwargs,
        "train_model_n_episodes": 0,
        "train_model_max_steps": 0,
        "train_policy_n_episodes": 0,
        "train_policy_max_steps": 0,
        "policy_optimizer_class": "PolicyOptimizerCOMA",
        "policy_optimizer_kwargs": POLICY_OPTIMIZER_METADATA,
    }
    return Metadata.from_dict(
        {"env_metadata": env_metadata, "train_metadata": train_metadata}
    )

def get_experiment_name_out_of_distribution_data_collection(n_steps):
    """Return the experiment name of the out-of-distribution data collected with the policy selected for ``n_steps``."""
    return "{}_out_of_distribution_data_collection_{}_n_steps_{}".format(
        EXPERIMENT_NAME_PREFIX, EXPERIMENT_VERSION, n_steps
    )

In [None]:
# Collect one long out-of-distribution episode per dataset size in the perturbed
# environment, played by the policy selected for that size; already collected
# experiments are skipped.
for n_steps in N_STEPS_LIST:
    data_collect_experiment_name = get_experiment_name_out_of_distribution_data_collection(n_steps)
    policy_experiment_name = SELECTED_POLICY[n_steps]
    metadata = get_metadata_out_of_distribution_data_collection()
    if does_experiment_exist(data_collect_experiment_name):
        print(f"Experiment '{data_collect_experiment_name}' already done...")
    else:
        run_simulation(
            metadata,
            log=True,
            log_train=False,
            log_experiment_name=data_collect_experiment_name,
            load_policy_experiment_name=policy_experiment_name,
            suffix_log_experiment_name=False
        )

## Statistical test

### Build dataset

In [None]:
import random
from typing import List
from src.data_classes import Observation, Transition, Step
from src.digital_twin.environment_model import EnvironmentModel
from src.policy.coma.coma_optimizer import PolicyOptimizerCOMA
from src.policy.coma.utils import format_actor_input

def format_steps(n_step_start: int, list_steps: List[Step]) -> List[Transition]:
    """Turn consecutive episode steps into agent-level transitions.

    Every step becomes one list of per-agent ``Observation`` objects, built from
    the step's channel ack, generated data and buffer occupancy. Its time step is
    ``n_step_start`` plus the step's position in ``list_steps``. Each pair of
    consecutive observation lists then forms one ``Transition`` carrying the
    earlier step's actions, so ``len(list_steps) - 1`` transitions are returned
    (none for fewer than two steps). Rewards are not needed and are set to None.
    """
    def _agents_observations(time_step, step):
        # One Observation per agent, aligned across the three state fields
        return [
            Observation(
                n_packets_max=N_PACKETS_MAX,
                n_packets_buffer=n_packets_buffer,
                data_input=data_input,
                ack=ack,
                time_step=time_step,
            )
            for ack, data_input, n_packets_buffer in zip(
                step.state.channel_ack, step.state.data_generated, step.state.agents_buffer
            )
        ]

    observations_per_step = [
        _agents_observations(n_step_start + position, step)
        for position, step in enumerate(list_steps)
    ]

    transitions = []
    for step, observations, next_observations in zip(
        list_steps, observations_per_step, observations_per_step[1:]
    ):
        transitions.append(
            Transition(
                agents_observations=observations,
                agents_actions=step.actions,
                agents_next_observations=next_observations,
                agents_rewards=None,  # Not useful
            )
        )
    return transitions
    

def build_test_dataset(n_steps, n_samples=N_TEST_SAMPLES, n_steps_per_sample=N_STEPS_PER_SAMPLE):
    """Build a shuffled, labelled list of short transition windows for the anomaly test.

    ``n_samples`` windows are drawn uniformly (with replacement) from the
    in-distribution episode (label 1). Another ``n_samples`` windows are drawn
    from the out-of-distribution episode (label 0). Both episodes were collected
    with the policy selected for ``n_steps``. A window spans
    ``n_steps_per_sample + 1`` consecutive steps, which ``format_steps`` turns
    into ``n_steps_per_sample`` transitions.

    Returns:
        List of ``2 * n_samples`` tuples ``(transitions, label)`` in random order.

    Raises:
        ValueError: if an episode is too short to contain a single window.
    """
    # Load data
    in_distribution_experiment_name = get_experiment_name_in_distribution_data_collection(n_steps)
    in_distribution_episode = Episode.load_episode(in_distribution_experiment_name, "ep_0")
    out_of_distribution_experiment_name = get_experiment_name_out_of_distribution_data_collection(n_steps)
    out_of_distribution_episode = Episode.load_episode(out_of_distribution_experiment_name, "ep_0")

    window_length = n_steps_per_sample + 1

    def _sample_window_starts(episode, experiment_name):
        # Valid window starts are 0 .. len(history) - window_length (inclusive).
        # Bounding by the loaded history (instead of N_STEPS_DATA) keeps every slice
        # complete even if the stored episode length differs from the configured one,
        # and no longer excludes the last valid window.
        n_valid_starts = len(episode.history) - window_length + 1
        if n_valid_starts < 1:
            raise ValueError(
                f"Episode of '{experiment_name}' has {len(episode.history)} steps, "
                f"fewer than the {window_length} needed for one sample"
            )
        return np.random.randint(0, n_valid_starts, size=n_samples)

    # Sample data per <n_steps_per_sample> batches
    in_distribution_sample_indexes = _sample_window_starts(
        in_distribution_episode, in_distribution_experiment_name
    )
    out_of_distribution_sample_indexes = _sample_window_starts(
        out_of_distribution_episode, out_of_distribution_experiment_name
    )

    samples = []
    for sample_indexes, episode, label in zip(
        [in_distribution_sample_indexes, out_of_distribution_sample_indexes],
        [in_distribution_episode, out_of_distribution_episode],
        [1, 0]
    ):
        samples += [
            (
                format_steps(sample_idx, episode.history[sample_idx:sample_idx + window_length]),
                label
            )
            for sample_idx in sample_indexes
        ]

    # Mix both distributions; note this uses Python's `random`, not NumPy's RNG
    random.shuffle(samples)

    return samples


In [None]:
# DATASETS dict topology
# Dict:
# - key : n_steps (model-learning dataset size; selects the policy used to collect the data)
# - value : List (2 * N_TEST_SAMPLES entries, in random order) :
#           - one entry per sample (either from in or out of distribution)
#           - value : Tuple
#                     - 1st entry : List[Transition]  (number of entries = N_STEPS_PER_SAMPLE)
#                     - 2nd entry : label (0 -> Out-of-distribution ; 1 -> in-distribution)

DATASETS=dict()
for n_steps in N_STEPS_LIST:
    print(f"Steps={n_steps}")
    DATASETS[n_steps] = build_test_dataset(
        n_steps,
        n_samples=N_TEST_SAMPLES,
        n_steps_per_sample=N_STEPS_PER_SAMPLE
    )

In [None]:
# Display the last sample of the 10-step dataset
DATASETS[10][2*N_TEST_SAMPLES - 1]

### Get likelihood estimates

In [None]:
%autoreload 2

from typing import List, Tuple
import logging

# Root logger, used for the debug traces of the likelihood computations below
logger = logging.getLogger()
# logger.setLevel("DEBUG")  # Uncomment to print per-sample / per-transition debug traces

def get_joint_action_prob(policies, transition: Transition):
    """Return the probability that the agents' policies jointly took ``transition.agents_actions``.

    Agent i transmits with probability ``p_i = policies[i].get_p_transmit(obs_i)``.
    Actions are read as binary (1 = packet sent, 0 = not sent), so agent i
    contributes ``p_i`` when it sent and ``1 - p_i`` otherwise. The per-agent
    terms are multiplied together.
    """
    transmit_probs = np.array([
        agent_policy.get_p_transmit(agent_observation)
        for agent_policy, agent_observation in zip(policies, transition.agents_observations)
    ])
    taken = np.array(transition.agents_actions)
    per_agent_prob = transmit_probs * taken + (1 - transmit_probs) * (1 - taken)
    return np.prod(per_agent_prob)

def get_selected_submodels_joint_probability(env, transition, selected_transition_submodels):
    """Return the probability ``env`` assigns to ``transition``, restricted to selected submodels.

    ``env.get_transition_probabilities(transition)`` is indexed by submodel name.
    The result is the product of the entries named in
    ``selected_transition_submodels`` (e.g. ``["data_gen_cluster_1"]``).
    """
    # This runs once per transition, per sample and per posterior model, so the
    # debug-only bookkeeping and message formatting are skipped unless DEBUG is on.
    debug_enabled = logger.isEnabledFor(logging.DEBUG)
    if debug_enabled:
        logger.debug("-------------------")
        data_gen = [obs.data_input for obs in transition.agents_observations]
        data_gen_next = [obs.data_input for obs in transition.agents_next_observations]
        time_step = transition.agents_next_observations[0].time_step
        logger.debug("t=%s / data_gen=%s / data_gen_next=%s", time_step, data_gen, data_gen_next)

    transition_probabilities = env.get_transition_probabilities(transition)

    if debug_enabled:
        logger.debug(transition_probabilities)

    return np.prod([transition_probabilities[submodel_name] for submodel_name in selected_transition_submodels])
    

def _transition_probability_matrix(env, transitions_samples, selected_transition_submodels, log_samples=False):
    """Return an array whose entry (i, j) is the probability ``env`` gives transition j of sample i.

    Each entry comes from ``get_selected_submodels_joint_probability`` (a product
    over ``selected_transition_submodels``). When ``log_samples`` is True, each
    sample's label and per-transition probabilities are logged at DEBUG level.
    """
    rows = []
    for transitions, label in transitions_samples:
        if log_samples:
            logger.debug("=========================")
            logger.debug("label = %s", label)
        joint_probs = [
            get_selected_submodels_joint_probability(env, transition, selected_transition_submodels)
            for transition in transitions
        ]
        if log_samples:
            logger.debug("JOINT PROBS = %s", joint_probs)
        rows.append(joint_probs)
    return np.array(rows)


def get_likelihood_samples(
    n_model: int,
    n_steps: int,
    transitions_samples: List[Tuple[List[Transition], int]],
    policy_experiment_name: str = None,  # None means action probability is not taken into account,
    selected_transition_submodels: list = None,
    use_action_probability: bool = True,
    n_posterior_samples: int = 1,
    frequentist_method: str = "maximum_a_posteriori",
    frequentist_prior_dirichlet_concentration: float = None
):
    """Compute the likelihood of each transition sample under learned environment models.

    The model saved by experiment ``get_model_learning_experiment_name(n_model, n_steps)``
    is loaded twice:
    - once to build a single point-estimate environment with ``frequentist_method``,
      optionally after overriding its prior Dirichlet concentration;
    - once to draw ``n_posterior_samples`` environments with "posterior_sample".

    A sample's likelihood is the product, over its consecutive transitions, of
    (joint action probability) x (model probability of the transition restricted
    to ``selected_transition_submodels``). Action probabilities come from the COMA
    policy saved in ``policy_experiment_name``. They are replaced by 1 when
    ``policy_experiment_name`` is None or ``use_action_probability`` is False.

    Args:
        n_model, n_steps: identify the model-learning experiment to load.
        transitions_samples: list of ``(transitions, label)`` tuples; all samples
            are expected to hold the same number of transitions.
        policy_experiment_name: experiment holding the policy, or None to ignore actions.
        selected_transition_submodels: names of the submodels whose probabilities are
            multiplied. NOTE(review): the default None is not handled downstream
            (iterating None raises TypeError); pass an explicit list.
        use_action_probability: set to False to ignore actions even if a policy is given.
        n_posterior_samples: number of environments sampled from the posterior.
        frequentist_method: sampling method passed to ``init_env`` for the point estimate.
        frequentist_prior_dirichlet_concentration: if not None, prior concentration
            applied to the point-estimate model before building its environment.

    Returns:
        Tuple ``(map_env_likelihoods, posterior_sample_env_likelihoods)`` of arrays of
        shape ``(n_samples, 1)`` and ``(n_samples, n_posterior_samples)``.

    Raises:
        ValueError: if ``transitions_samples`` is empty.
    """
    if len(transitions_samples) == 0:
        raise ValueError("'transitions_samples' must contain at least one sample")
    model_experiment_name = get_model_learning_experiment_name(n_model, n_steps)

    # Per-step joint action probabilities, shape (n_samples, n_transitions)
    if (policy_experiment_name is None) or (not use_action_probability):
        if use_action_probability and (policy_experiment_name is None):
            logging.warning("'policy_experiment_name' set to 'None', ignoring action probabilities in the likelihood !")
        if (not use_action_probability) and (policy_experiment_name is not None):
            logging.warning("'use_action_probability' set to 'False', ignoring action probabilities in the likelihood !")
        actions_prob = np.ones((len(transitions_samples), len(transitions_samples[0][0])))
    else:
        policy = PolicyOptimizerCOMA.load(policy_experiment_name)
        policies = policy.get_agents_policies()
        actions_prob = np.array([
            [
                get_joint_action_prob(policies, transition)
                for transition in transitions
            ]
            for transitions, _ in transitions_samples
        ])

    # Point-estimate ("frequentist") model: one likelihood per sample
    frequentist_model = EnvironmentModel.load(model_experiment_name)
    if frequentist_prior_dirichlet_concentration is not None:
        frequentist_model.update_prior_dirichlet_concentration(frequentist_prior_dirichlet_concentration)
    map_env = frequentist_model.init_env(None, None, frequentist_method)
    map_env_model_prob = _transition_probability_matrix(
        map_env, transitions_samples, selected_transition_submodels, log_samples=True
    )
    map_env_likelihoods = np.prod(  # Multiply probabilities across the consecutive time steps
        actions_prob * map_env_model_prob,  # Overall transition probability per step
        axis=1
    ).reshape(-1, 1)

    # Bayesian model: one likelihood per sample and per posterior-sampled environment.
    # The model is reloaded so the prior override above does not affect the posterior.
    bayesian_model = EnvironmentModel.load(model_experiment_name)
    posterior_sample_env_likelihoods = np.zeros((len(transitions_samples), n_posterior_samples))
    for idx_model_sample in range(n_posterior_samples):
        posterior_sample_env = bayesian_model.init_env(None, None, "posterior_sample")
        posterior_sample_env_model_prob = _transition_probability_matrix(
            posterior_sample_env, transitions_samples, selected_transition_submodels
        )
        posterior_sample_env_likelihoods[:, idx_model_sample] = np.prod(  # Multiply across time steps
            actions_prob * posterior_sample_env_model_prob,  # Overall transition probability per step
            axis=1
        )

    return map_env_likelihoods, posterior_sample_env_likelihoods


def compute_log_likelihood_estimates(likelihoods, epsilon=EPSILON):
    """Return the per-row mean of ``log(likelihoods + epsilon)`` as a column of shape (n_rows, 1).

    With a single column this is just the log-likelihood. With one column per
    posterior model it is the average log-likelihood across models.
    """
    log_likelihoods = np.log(likelihoods + epsilon)
    row_means = np.mean(log_likelihoods, axis=1)
    return row_means.reshape(-1, 1)

def compute_disagreement_score(likelihoods, epsilon=EPSILON):
    """Return ``1 / sum_i w_i**2`` per row, where ``w`` are the row's likelihoods normalised to sum to 1.

    From paper : https://arxiv.org/pdf/1912.05651.pdf
    The score ranges from 1 (one model dominates) to the number of columns
    (all models equally likely). A row of all-zero likelihoods yields NaN (0 / 0).
    ``epsilon`` is accepted for signature compatibility with the other
    estimators but is not used.
    """
    row_totals = np.sum(likelihoods, axis=1).reshape(-1, 1)
    weights = likelihoods / row_totals
    inverse_concentration = 1 / np.sum(np.power(weights, 2), axis=1)
    return inverse_concentration.reshape(-1, 1)

def compute_log_likelihood_opposite_std(likelihoods, epsilon=EPSILON):
    """Return minus the per-row (population) standard deviation of ``log(likelihoods + epsilon)``.

    The output has shape (n_rows, 1). Values closer to 0 mean the posterior models
    agree more on the sample's log-likelihood.
    """
    logs = np.log(likelihoods + epsilon)
    centred = logs - np.mean(logs, axis=1).reshape(-1, 1)
    variances = np.mean(np.power(centred, 2), axis=1).reshape(-1, 1)
    return -np.sqrt(variances)

def compute_log_likelihood_std_v2(ml_likelihoods, posterior_likelihoods, epsilon=EPSILON):
    """Return the per-row root-mean-square gap between posterior and point-estimate log-likelihoods.

    ``ml_likelihoods`` (shape (n_rows, 1)) is broadcast against every column of
    ``posterior_likelihoods``. It is used as the centre instead of the posterior
    mean. The output has shape (n_rows, 1).
    """
    reference_logs = np.log(ml_likelihoods + epsilon)
    sampled_logs = np.log(posterior_likelihoods + epsilon)
    squared_gaps = np.power(sampled_logs - reference_logs, 2)
    mean_squared_gap = np.mean(squared_gaps, axis=1).reshape(-1, 1)
    return np.sqrt(mean_squared_gap)

def compute_mutual_information(likelihoods, epsilon=EPSILON):
    """Return the per-row mean of ``log((L + eps) / (mean(L) + eps))`` as a column (n_rows, 1).

    Mathematically this equals mean(log(L + eps)) - log(mean(L) + eps). That is
    roughly minus the Jensen gap between the log of the mean likelihood and the
    mean log-likelihood, so it is <= 0 up to epsilon effects.
    TODO: THIS IS NOT CORRECT !!! (not a proper mutual-information estimate)
    """
    mean_likelihood = np.mean(likelihoods, axis=1).reshape(-1, 1)
    log_ratios = np.log((likelihoods + epsilon) / (mean_likelihood + epsilon))
    return np.mean(log_ratios, axis=1).reshape(-1, 1)
    

### Get test estimators (soft predictions only, i.e. likelihood-based scores without a decision threshold)

In [None]:
# Transition submodels whose probabilities enter the likelihood
# (only cluster 1's data generation; cluster 2's is commented out)
SELECTED_TRANSITION_SUBMODELS = [
    "data_gen_cluster_1", #"data_gen_cluster_2"
]
N_POSTERIOR_SAMPLES = 1000  # Number of models sampled from the posterior
# "maximum_a_posteriori" is similar to "maximum_likelihood" if the prior concentration is low
FREQUENTIST_MODEL_METHOD = "maximum_a_posteriori"
FREQUENTIST_PRIOR_DIRICHLET_CONCENTRATION = PRIOR_DIRICHLET_CONCETRATION_MAP

# Threshold range swept for the ROC curve of the frequentist log-likelihood estimator
LOG_LIKELIHOOD_THRESHOLD_RANGE = (np.log(EPSILON), np.log(1 + EPSILON))

# Select which function to use for the estimation of epistemic uncertainty, together
# with the threshold range swept for its ROC curve (keep exactly one option uncommented).

# Standard deviation of the posterior models' log-likelihoods, centred on the ML/MAP log-likelihood
# Note: technically, centring on the ML estimate is not the ensemble predictor, but the two are close for a small prior concentration
# BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC = compute_log_likelihood_std_v2
# BAYESIAN_ESTIMATOR_THRESHOLD_RANGE = (np.log(1 + EPSILON), -np.log(EPSILON))

# Opposite standard deviation of the posterior models' log-likelihoods around their own mean
BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC = compute_log_likelihood_opposite_std
BAYESIAN_ESTIMATOR_THRESHOLD_RANGE = (np.log(EPSILON), 0)

# Disagreement score of posterior models
# BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC = compute_disagreement_score
# BAYESIAN_ESTIMATOR_THRESHOLD_RANGE = (1, N_POSTERIOR_SAMPLES)

# Average log-likelihood of posterior models
# BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC = compute_log_likelihood_estimates
# BAYESIAN_ESTIMATOR_THRESHOLD_RANGE = (np.log(EPSILON), np.log(1 + EPSILON))

# Mutual information
# BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC = compute_mutual_information
# BAYESIAN_ESTIMATOR_THRESHOLD_RANGE = (-np.log(1 + (1 / (N_POSTERIOR_SAMPLES * EPSILON))), 0)

In [None]:
%autoreload 2

# Estimators dict topology
# Dict:
# - key : n_steps
# - value : List :
#           - one entry per n_model
#           - value : Array[2 * N_TEST_SAMPLES ; 3]  (one row per sample of DATASETS[n_steps])
#                     - Column 1 : frequentist log-likelihood, i.e. log(map_env_likelihoods + EPSILON)
#                     - Column 2 : Bayesian score (BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC of posterior_env_likelihoods)
#                     - Column 3 : labels (0 -> out-of-distribution ; 1 -> in-distribution)
ESTIMATORS = dict()

for n_model, n_steps in RANGE_MODEL_STEPS:
    print(f"Steps={n_steps} / Model={n_model}")
    map_env_likelihoods, posterior_env_likelihoods = get_likelihood_samples(
        n_model,
        n_steps,
        DATASETS[n_steps],
        policy_experiment_name=None, # DO NOT TAKE ACTION PROB INTO ACCOUNT, OTHERWISE SET TO "SELECTED_POLICY[n_steps]",
        selected_transition_submodels=SELECTED_TRANSITION_SUBMODELS,
        use_action_probability=False,
        n_posterior_samples=N_POSTERIOR_SAMPLES,
        frequentist_method=FREQUENTIST_MODEL_METHOD,
        frequentist_prior_dirichlet_concentration=FREQUENTIST_PRIOR_DIRICHLET_CONCENTRATION
    )
    map_env_estimator = compute_log_likelihood_estimates(map_env_likelihoods)
    # For compute_log_likelihood_std_v2 (two arguments), use instead:
    #     posterior_env_estimator = BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC(map_env_likelihoods, posterior_env_likelihoods)
    posterior_env_estimator = BAYESIAN_EPISTEMIC_UNCERTAINTY_ESTIMATOR_FUNC(posterior_env_likelihoods)
    labels = np.array([label for _, label in DATASETS[n_steps]]).reshape(-1, 1)
    if n_steps not in ESTIMATORS.keys():
        ESTIMATORS[n_steps] = []
    ESTIMATORS[n_steps].append(np.c_[
        map_env_estimator,
        posterior_env_estimator,
        labels
    ])

In [None]:
# Display the estimator array of model index 3 for the 20-step dataset
ESTIMATORS[20][3]

## Plot ROC curves

### Get ROC data

In [None]:
def get_confusion_matrix(log_likelihoods, labels, threshold, normalize=False):
    """Return the confusion-matrix counts of thresholding ``log_likelihoods``.

    A sample is predicted positive (1, in-distribution) when its score is
    ``>= threshold`` and negative (0) otherwise; NaN scores are therefore negative.
    ``labels`` are expected to be 0/1.

    Returns:
        Dict with keys "TN", "FN", "FP", "TP" in that order. Cells that never occur
        are 0. With ``normalize=True`` the counts are divided by ``len(labels)``
        (as float32).
    """
    cell_names = {
        (0, 0): "TN",
        (0, 1): "FN",
        (1, 0): "FP",
        (1, 1): "TP",
    }
    predictions = np.where(log_likelihoods >= threshold, 1, 0)
    pairs, occurrences = np.unique(
        np.c_[predictions, labels], axis=0, return_counts=True
    )
    if normalize:
        occurrences = occurrences.astype(np.float32) / len(labels)

    confusion = {"TN": 0, "FN": 0, "FP": 0, "TP": 0}
    for prediction_label_pair, occurrence in zip(pairs, occurrences):
        confusion[cell_names[tuple(prediction_label_pair)]] = occurrence
    return confusion

def get_roc_data(estimator, labels, threshold_range, num_points=100):
    """Compute an ROC curve by sweeping a decision threshold over a score.

    A sample is predicted positive (1) when its score is ``>= threshold``,
    the same rule as ``get_confusion_matrix``. Label 1 is the positive class
    and label 0 the negative class.

    Parameters
    ----------
    estimator : array-like
        One score per sample.
    labels : array-like
        One 0/1 label per sample.
    threshold_range : tuple of float
        ``(low, high)`` bounds of the threshold sweep.
    num_points : int
        Number of evenly spaced thresholds in ``threshold_range``.

    Returns
    -------
    false_positive_rate : list of float
    true_positive_rate : list of float
        One rate per threshold. The rate is ``nan`` when the corresponding
        class is absent from ``labels``, instead of raising ZeroDivisionError.
    thresholds : np.ndarray
    """
    thresholds = np.linspace(threshold_range[0], threshold_range[1], num_points)
    scores = np.asarray(estimator, dtype=float).ravel()
    labels = np.asarray(labels).ravel()

    def _count_at_or_above(class_scores):
        # Sort once, then count scores >= each threshold with a binary search.
        # This is O(n log n + T log n) instead of one np.unique sort per threshold.
        # NaN scores are never >= a threshold (always predicted 0), so drop them.
        ranked = np.sort(class_scores[~np.isnan(class_scores)])
        return ranked.size - np.searchsorted(ranked, thresholds, side="left")

    positives = labels == 1
    negatives = labels == 0
    true_positives = _count_at_or_above(scores[positives])
    false_positives = _count_at_or_above(scores[negatives])

    # FP + TN == number of negatives and TP + FN == number of positives;
    # an empty class gives 0/0 -> nan rather than an exception.
    with np.errstate(divide="ignore", invalid="ignore"):
        false_positive_rate = false_positives / np.count_nonzero(negatives)
        true_positive_rate = true_positives / np.count_nonzero(positives)
    return false_positive_rate.tolist(), true_positive_rate.tolist(), thresholds
            



In [None]:
# Get all roc_data estimators
# For every n_steps, compute one ROC curve per model for both estimators and
# stack them column-wise, so ROC_DATA_*[n_steps]["fpr"/"tpr"] has one column
# per model (always 2-D, even with a single model).
NUM_POINTS = 10000

ROC_DATA_MAP = dict()
ROC_DATA_POSTERIOR = dict()

for n_steps, estimators_list in ESTIMATORS.items():
    map_fprs, map_tprs = [], []
    posterior_fprs, posterior_tprs = [], []
    for n_model, estimator in enumerate(estimators_list):
        print(f"Model={n_model} / Steps={n_steps}")
        # MAP estimator: column 0 scored against the labels in column 2
        fpr_ml, tpr_ml, _ = get_roc_data(
            estimator[:, 0],
            estimator[:, 2],
            LOG_LIKELIHOOD_THRESHOLD_RANGE,
            num_points=NUM_POINTS
        )
        map_fprs.append(fpr_ml)
        map_tprs.append(tpr_ml)

        # Posterior estimator: column 1 scored against the labels in column 2
        fpr_posterior, tpr_posterior, _ = get_roc_data(
            estimator[:, 1],
            estimator[:, 2],
            BAYESIAN_ESTIMATOR_THRESHOLD_RANGE,
            num_points=NUM_POINTS
        )
        posterior_fprs.append(fpr_posterior)
        posterior_tprs.append(tpr_posterior)

    # np.column_stack turns k curves of length NUM_POINTS into a
    # (NUM_POINTS, k) array, including k == 1, so downstream
    # np.mean(..., axis=1) works for any number of models.
    ROC_DATA_MAP[n_steps] = {
        "fpr": np.column_stack(map_fprs) if map_fprs else None,
        "tpr": np.column_stack(map_tprs) if map_tprs else None,
    }
    ROC_DATA_POSTERIOR[n_steps] = {
        "fpr": np.column_stack(posterior_fprs) if posterior_fprs else None,
        "tpr": np.column_stack(posterior_tprs) if posterior_tprs else None,
    }

### Plot mean ROC curves with a ±1.96 standard-deviation band across models

In [None]:
def plot_roc_curve_with_errors(
    ax: plt.Axes,
    roc_data,
    lineplot_kwargs: dict = None,
    n_std: float = 1.96
):
    """Plot the mean ROC curve across models with a TPR band of +/- ``n_std`` std.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    roc_data : dict
        ``{"fpr": ..., "tpr": ...}``. Each value holds one ROC curve per
        column (one column per model). A 1-D array is treated as a single model.
    lineplot_kwargs : dict, optional
        Keyword arguments for ``ax.plot``. They are also passed to
        ``ax.fill_between`` for the band, except ``label``, so each curve
        appears only once in a legend.
    n_std : float
        Half-width of the band, in standard deviations across models
        (default 1.96). The band is computed on the TPR only and drawn at the
        mean FPR.
    """
    lineplot_kwargs = lineplot_kwargs or dict()
    band_kwargs = {key: value for key, value in lineplot_kwargs.items() if key != "label"}

    def _as_columns(rate):
        # One curve per column; a 1-D array is a single model's curve.
        rate = np.asarray(rate, dtype=float)
        return rate[:, np.newaxis] if rate.ndim == 1 else rate

    fpr = _as_columns(roc_data["fpr"])
    tpr = _as_columns(roc_data["tpr"])
    fpr_mean = np.mean(fpr, axis=1)
    tpr_mean = np.mean(tpr, axis=1)
    tpr_std = np.std(tpr, axis=1)

    # Axis formatting
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_xlim([-0.05, 1.05])
    ax.set_ylim([-0.05, 1.05])

    # Plot mean ROC and its band
    ax.plot(fpr_mean, tpr_mean, **lineplot_kwargs)
    ax.fill_between(
        fpr_mean,
        tpr_mean - n_std * tpr_std,
        tpr_mean + n_std * tpr_std,
        alpha=.20,
        **band_kwargs
    )

    # Plot random pred diagonal
    ax.plot([0, 1], [0, 1], color="black", linestyle="dashed")

In [None]:
select_n_steps = 20

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

# Overlay the mean ROC curve (with its spread across models) of both
# estimators on the same axes.
# NOTE(review): the first label says "Maximum Likelihood" although the data
# is ROC_DATA_MAP — confirm the intended legend text.
map_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve_with_errors(ax, ROC_DATA_MAP[select_n_steps], lineplot_kwargs=map_kwargs)
plot_roc_curve_with_errors(ax, ROC_DATA_POSTERIOR[select_n_steps], lineplot_kwargs=posterior_kwargs)


### Plot multiple mean ROC curves

In [None]:
from typing import Dict, List

from src.view.utils import add_annotation

def plot_multiple_mean_roc_curves(
    ax: plt.Axes,
    all_roc_data: Dict[int, dict],
    n_steps_models_to_plot: List[int],
    all_lineplot_kwargs: Dict[int, dict] = None,
    label_annotation: str = "",
    y_pos_text_annotation: float = 0.4,
    y_pos_arrow_annotation: float = 0.7
):
    """Plot the model-averaged ROC curve of several n_steps values on one axes.

    For each n_steps, the FPR/TPR curves of all models are averaged across
    columns and drawn. The point where the mean TPR first drops to <= 0.8 is
    printed. ``add_annotation`` is then called once per curve, with the curve
    point where the mean TPR first drops to <= ``y_pos_arrow_annotation`` and
    a shared text position ``(0.75, y_pos_text_annotation)``.

    Parameters
    ----------
    ax : plt.Axes or tuple
        Axes to draw on, or the ``(fig, ax)`` tuple returned by
        ``plt.subplots()``, which is what the calls in this notebook pass.
        The figure is resized to 10x10 inches.
    all_roc_data : dict
        Maps n_steps -> ``{"fpr": ..., "tpr": ...}`` with one ROC curve per
        column (one column per model). A 1-D array counts as a single model.
    n_steps_models_to_plot : list of int
        Keys of ``all_roc_data`` to draw, in order.
    all_lineplot_kwargs : dict, optional
        Maps n_steps -> kwargs forwarded to ``ax.plot`` for that curve.
    label_annotation : str
        Annotation text shared by all curves.
    y_pos_text_annotation : float
        y-coordinate of the annotation text (x is fixed at 0.75).
    y_pos_arrow_annotation : float
        TPR level used to pick the annotated point on each curve.
    """
    all_lineplot_kwargs = all_lineplot_kwargs or dict()

    # Resolve figure/axes from the argument (not from a global variable).
    if isinstance(ax, tuple):
        fig, ax = ax
    else:
        fig = ax.figure

    def _mean_over_models(rate):
        # One curve per column; a 1-D array is already a single model's curve.
        rate = np.asarray(rate, dtype=float)
        return np.mean(rate, axis=1) if rate.ndim > 1 else rate

    def _first_index_at_or_below(values, level):
        # Index of the first entry <= level (np.argmax gives 0 if none match).
        return int(np.argmax(np.asarray(values) <= level))

    # Get means
    all_rates_means = []
    for n_steps in n_steps_models_to_plot:
        roc_data = all_roc_data[n_steps]
        all_rates_means.append(
            (_mean_over_models(roc_data["fpr"]), _mean_over_models(roc_data["tpr"]))
        )

    # Plot formatting
    fig.set_size_inches(10, 10)
    min_ax_lim = -0.05
    max_ax_lim = 1.05
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_xlim([min_ax_lim, max_ax_lim])
    ax.set_ylim([min_ax_lim, max_ax_lim])
    ax.tick_params(labelsize=24)
    ax.xaxis.label.set_size(24)
    ax.yaxis.label.set_size(24)

    # Plot ROC curves
    for n_steps, rates_mean in zip(n_steps_models_to_plot, all_rates_means):
        ax.plot(rates_mean[0], rates_mean[1], **all_lineplot_kwargs.get(n_steps, dict()))

    # Print fpr at tpr=0.8
    indicator_tpr_value = 0.8
    for n_steps, rates_mean in zip(n_steps_models_to_plot, all_rates_means):
        idx_indicator = _first_index_at_or_below(rates_mean[1], indicator_tpr_value)
        indicator_tpr = rates_mean[1][idx_indicator]
        indicator_fpr = rates_mean[0][idx_indicator]
        print(f"Indicator point for n_steps={n_steps}: {(indicator_fpr, indicator_tpr)}")

    # Plot random pred diagonal
    ax.plot([0, 1], [0, 1], color="black", linestyle=":")

    # Add annotations
    x_pos_text_annotation = 0.75
    for rates_mean in all_rates_means:
        idx_pos = _first_index_at_or_below(rates_mean[1], y_pos_arrow_annotation)
        add_annotation(
            ax,
            label_annotation,
            (rates_mean[0][idx_pos], rates_mean[1][idx_pos]),
            (x_pos_text_annotation, y_pos_text_annotation),
            arrow_kwargs=dict(
                mutation_scale=15,
                linewidth=1
            ),
            text_kwargs=dict(
                fontsize=28
            )
        )

In [None]:
n_steps_to_plot = [20, 50]

# Both calls below draw on this shared (fig, ax) pair.
subplot = plt.subplots()

# Frequentist (MAP) estimator: solid line for the first n_steps value,
# dash-dot for the second.
map_style = {"color": "green", "label": "MAP Model", "linewidth": 2.5}
plot_multiple_mean_roc_curves(
    subplot,
    ROC_DATA_MAP,
    n_steps_to_plot,
    all_lineplot_kwargs={
        n_steps_to_plot[0]: map_style,
        n_steps_to_plot[1]: {**map_style, "linestyle": "-."},
    },
    label_annotation="Frequentist",
    y_pos_text_annotation=0.5,
    y_pos_arrow_annotation=0.8,
)

# Bayesian (posterior-sampled) estimator, same line-style convention.
posterior_style = {"color": "orange", "label": "Posterior Sampled Model", "linewidth": 2.5}
plot_multiple_mean_roc_curves(
    subplot,
    ROC_DATA_POSTERIOR,
    n_steps_to_plot,
    all_lineplot_kwargs={
        n_steps_to_plot[0]: posterior_style,
        n_steps_to_plot[1]: {**posterior_style, "linestyle": "-."},
    },
    label_annotation="Bayesian",
    y_pos_text_annotation=0.3,
    y_pos_arrow_annotation=0.65,
)

## Annex

### Check likelihood mean and std for MAP

In [None]:
def get_model_likelihood_mean_and_std_for_each_ground_truth_dynamics(n_model, n_step, n_steps_per_sample):
    """Mean and std of the MAP model's cluster-1 likelihood under normal and perturbed dynamics.

    The MAP transition probabilities of cluster 1 (model ``n_model`` trained
    with ``n_step`` steps) are weighted by the ground-truth probabilities
    (``DEFAULT_JOINT_DISTRIBUTION`` -> "normal",
    ``PERTURBED_DEFAULT_JOINT_DISTRIBUTION`` -> "abnormal") to get a mean and
    a variance. Both are then multiplied by ``n_steps_per_sample``.

    NOTE(review): reading these weighted sums as an expectation assumes the
    ground-truth arrays act as weights summing to 1 — confirm.

    Returns
    -------
    dict
        ``{"normal": {"mean": ..., "std": ...}, "abnormal": {"mean": ..., "std": ...}}``
    """
    _cluster_joint_states = [(0,0), (0,1), (1,0), (1,1)]
    _cluster_joint_states_encoded = ["0,0", "0,1", "1,0", "1,1"]

    # Ground truth dynamics for cluster 1 (these tables are keyed by "i,j" strings)
    normal_dynamics_cluster = np.array([
        DEFAULT_JOINT_DISTRIBUTION[joint_state_data_gen]
        for joint_state_data_gen in _cluster_joint_states_encoded
    ])
    abnormal_dynamics_cluster = np.array([
        PERTURBED_DEFAULT_JOINT_DISTRIBUTION[joint_state_data_gen]
        for joint_state_data_gen in _cluster_joint_states_encoded
    ])

    # MAP model cluster 1 dynamics (keyed by (i, j) tuples). Uses the `n_step`
    # argument rather than any global `n_steps`.
    model_experiment_name = get_model_learning_experiment_name(n_model, n_step)
    model = EnvironmentModel.load(model_experiment_name)
    model.update_prior_dirichlet_concentration(PRIOR_DIRICHLET_CONCETRATION_MAP)
    env = model.init_env(None, None, "maximum_a_posteriori")
    p_map_cluster = env.data_gen.transition._indexed_probabilities_maps["cluster_1"]
    model_dynamics_cluster = np.array([
        p_map_cluster._dict_default_joint_distribution[joint_state_data_gen]
        for joint_state_data_gen in _cluster_joint_states
    ])

    def _scaled_mean_and_std(ground_truth_dynamics):
        # Weighted mean and variance of the model probabilities under the given dynamics.
        mean = np.sum(ground_truth_dynamics * model_dynamics_cluster)
        var = np.sum(ground_truth_dynamics * np.power(model_dynamics_cluster - mean, 2))
        # Scale for consecutive time steps: the mean and variance of a sum of
        # per-step values grow linearly (only true if data gen is independent
        # across time steps).
        mean *= n_steps_per_sample
        var *= n_steps_per_sample
        return {"mean": mean, "std": var**(1/2)}

    return {
        "normal": _scaled_mean_and_std(normal_dynamics_cluster),
        "abnormal": _scaled_mean_and_std(abnormal_dynamics_cluster),
    }

def plot_average_likelihoods_clt_gaussians(ax, n_model, n_steps, n_steps_per_sample, epsilon=0):
    """Plot Gaussian approximations of the MAP model likelihood on ``ax``.

    Two normal pdfs are drawn from the means/stds returned by
    ``get_model_likelihood_mean_and_std_for_each_ground_truth_dynamics``:
    green for the normal dynamics, red for the perturbed ("abnormal") ones.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on.
    n_model, n_steps, n_steps_per_sample :
        Forwarded to ``get_model_likelihood_mean_and_std_for_each_ground_truth_dynamics``.
    epsilon : float
        Accepted for call-site compatibility but unused: the likelihood
        statistics function takes no epsilon argument.
    """
    import scipy.stats as stats

    likelihood_means_stds = get_model_likelihood_mean_and_std_for_each_ground_truth_dynamics(
        n_model, n_steps, n_steps_per_sample
    )
    normal = likelihood_means_stds["normal"]
    abnormal = likelihood_means_stds["abnormal"]

    # x-range covering both means +/- 3 of the larger std
    n_points = 1000
    max_std = max(normal["std"], abnormal["std"])
    min_mean = min(normal["mean"], abnormal["mean"])
    max_mean = max(normal["mean"], abnormal["mean"])
    x_space = np.linspace(min_mean - 3 * max_std, max_mean + 3 * max_std, n_points)

    # The CLT scaling (std / np.sqrt(N_TEST_SAMPLES)) is left disabled, so the
    # curves show the per-sample spread rather than that of a mean over test samples.
    for stats_for_dynamics, color in ((normal, "green"), (abnormal, "red")):
        ax.plot(
            x_space,
            stats.norm.pdf(x_space, stats_for_dynamics["mean"], stats_for_dynamics["std"]),
            color=color
        )
    

# For each of the first `n_models` MAP models trained with `n_steps` steps,
# plot the Gaussian approximations of its likelihood (one subplot per model,
# on a grid with `n_cols` columns).
n_steps = 10
n_models = 50
n_cols = 5

# Ceil division gives exactly enough rows (no fully empty trailing row);
# squeeze=False keeps `axs` a 2-D array even for a single row or column.
subplots = plt.subplots(int(np.ceil(n_models / n_cols)), n_cols, squeeze=False)
fig, axs = subplots
fig.set_size_inches(15, 10)
axs = axs.flatten()

n_steps_per_sample = N_STEPS_PER_SAMPLE
for n_model in range(n_models):
    plot_average_likelihoods_clt_gaussians(axs[n_model], n_model, n_steps, n_steps_per_sample, epsilon=EPSILON)

# Hide grid cells left over when n_models is not a multiple of n_cols
for unused_ax in axs[n_models:]:
    unused_ax.set_visible(False)

In [None]:
# Mutual_info
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 20

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)


In [None]:
# Disagreement
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 20

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)


In [None]:
# ML
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 20

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)


In [None]:
# Var
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 20

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)


In [None]:
# score
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 10

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)


In [None]:
# LL
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 10

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)


In [None]:
# Var
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 10

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)


In [None]:
# Mutual_info
# NOTE(review): `plot_roc_curve`, `roc_data_ml` and `roc_data_posterior` are
# not defined in this part of the notebook — confirm they still exist.
select_n_steps = 10

fig, ax = plt.subplots()
fig.set_size_inches(10, 10)

ml_kwargs = {"color": "green", "label": "Maximum Likelihood Model"}
posterior_kwargs = {"color": "orange", "label": "Posterior Sampled Model"}

plot_roc_curve(ax, roc_data_ml[select_n_steps], lineplot_kwargs=ml_kwargs)
plot_roc_curve(ax, roc_data_posterior[select_n_steps], lineplot_kwargs=posterior_kwargs)
