In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from rice import Rice
import numpy as np

# ---- Simulations ------------------------------------------------------------
# Environment state key -> human-readable label. The keys are the variables
# read back with `get_state` after each step; the labels become the "Variable"
# column of the results and the y-axis labels of the plots.
var_name_mapping = dict(
    global_temperature="Global Temperature",
    global_carbon_mass="Global Carbon Mass",
    gross_output_all_regions="Global Gross Output",
    capital_all_regions="Global Capital",
    aggregate_consumption="Global Consumption",
    utility_all_regions="Total Utility",
    damages_all_regions="Mean Region Damage",
)

# State keys to record, in mapping order.
tracked_vars = [*var_name_mapping]

# Default-configured reference environment. Its action layout
# (`total_possible_actions`, `get_actions_index`, `get_actions_len`) and its
# region count are what the helpers in this cell build actions from.
rice_env = Rice()
rice_env.reset()
total_actions = rice_env.total_possible_actions
regions = [*range(rice_env.num_regions)]

class RiceAction:
    """Flat action vector for one region, built from a ``{action_group: value}`` mapping.

    Action groups not named in the mapping are left at 0.

    Parameters
    ----------
    dict : mapping
        Action-group name (e.g. ``"savings"``, ``"mitigation_rate"``) -> value(s)
        written into that group's slice of the vector. A scalar fills the whole
        slice; an array-like must match the slice length. (The name shadows the
        builtin but is kept so existing callers are unaffected.)
    env : optional
        Environment defining the action layout through ``total_possible_actions``,
        ``get_actions_index(name)`` and ``get_actions_len(name)``. Defaults to the
        module-level ``rice_env``, which was the only layout supported before.

    Attributes
    ----------
    actions : numpy.ndarray
        Float array of length ``len(env.total_possible_actions)``.
    """

    def __init__(self, dict, env=None):
        layout = rice_env if env is None else env
        self.actions = np.zeros(len(layout.total_possible_actions))

        for group_name, group_value in dict.items():
            start = layout.get_actions_index(group_name)
            stop = start + layout.get_actions_len(group_name)
            self.actions[start:stop] = group_value

def _summarize_state(var, val):
    """Reduce the raw state returned by ``env.get_state(var)`` to one plot value."""
    if var == "global_temperature":
        # Only the first entry is plotted. NOTE(review): presumably the
        # atmospheric temperature layer — confirm against the env.
        return val[0]
    if var == "damages_all_regions":
        return np.mean(val).item()
    if var in (
        "global_carbon_mass",
        "gross_output_all_regions",
        "capital_all_regions",
        "aggregate_consumption",
        "utility_all_regions",
    ):
        return np.sum(val).item()
    # Any other variable is passed through unreduced (same as before).
    return val


def _state_records(env, dmg_function_name, mitigation_rate, savings_rate, years_per_step):
    """One long-format record per tracked variable for the env's current timestep."""
    year = env.current_timestep * years_per_step
    return [
        {
            "Mitigation rate": mitigation_rate,
            "Savings rate": savings_rate,
            "Value": _summarize_state(var, env.get_state(var)),
            "Variable": var_name_mapping[var],
            "Damage function": dmg_function_name,
            "Year": year,
        }
        for var in tracked_vars
    ]


def run_simulations(env, dmg_function_name, mitigation_rates, savings_rate, years_per_step=5):
    """Run one full episode per mitigation rate and collect the tracked state.

    In each episode every region plays the same fixed action (``savings_rate`` and
    the episode's mitigation rate) at every step until ``done["__all__"]``. After
    every step, including the terminal one, each variable in ``tracked_vars`` is
    read with ``env.get_state`` and reduced to a single value.

    Parameters
    ----------
    env
        Environment to step. It is reset before the first episode and after each
        episode ends. Region ids are taken from ``env.num_regions``.
    dmg_function_name : str
        Label stored in the "Damage function" column.
    mitigation_rates : iterable
        Mitigation-rate action values, one episode each.
    savings_rate
        Savings action value used in every episode.
    years_per_step : int, default 5
        Factor converting ``env.current_timestep`` into the "Year" column.
        NOTE(review): 5 is the previously hard-coded value; confirm it matches
        the env's step length.

    Returns
    -------
    list of dict
        Records with keys "Mitigation rate", "Savings rate", "Value", "Variable",
        "Damage function" and "Year", ready for ``pd.DataFrame``.
    """
    region_ids = range(env.num_regions)
    data = []
    env.reset()
    for mitigation_rate in mitigation_rates:
        # The policy is constant within an episode, so build it once.
        policy = RiceAction({
            "savings": savings_rate,
            "mitigation_rate": mitigation_rate,
        }).actions

        while True:
            # Fresh array each step, shared by all regions (same as the original).
            ind_actions = policy.copy()
            actions = {region_id: ind_actions for region_id in region_ids}
            _, _, done, _, _ = env.step(actions)

            # Record before checking for termination so the final timestep is kept.
            data.extend(
                _state_records(env, dmg_function_name, mitigation_rate, savings_rate, years_per_step)
            )
            if done["__all__"]:
                env.reset()
                break
    return data





In [3]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import ipywidgets as widgets
from matplotlib.lines import Line2D



palette = sns.color_palette("rocket", as_cmap=True)

def plot_df(df, selected_model=None):
    """Plot simulation results as a 2x3 grid of time series, one panel per variable.

    Every damage-function variant in ``df["Damage function"]`` is drawn on every
    panel: "Base" with solid lines, any other variant dashed. Lines are coloured by
    "Mitigation rate" with the module-level ``palette`` colormap. The variant equal
    to ``selected_model`` is drawn opaque, as is every variant when
    ``selected_model == "All"``. The others are faded to alpha 0.3.

    A single figure-level legend lists every mitigation rate and the line style of
    each damage function.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format records with at least the columns "Year", "Value", "Variable",
        "Mitigation rate" and "Damage function" (as built from ``run_simulations``).
    selected_model : str or None
        Damage function to highlight, or "All". ``None`` fades every variant.
    """
    # (variable label, row, column) of each panel.
    panels = (
        ("Global Temperature", 0, 0),
        ("Global Carbon Mass", 0, 1),
        ("Mean Region Damage", 0, 2),
        ("Global Capital", 1, 0),
        ("Global Gross Output", 1, 1),
        ("Global Consumption", 1, 2),
    )
    fig, axes = plt.subplots(2, 3, figsize=(16, 10))
    model_versions = df["Damage function"].unique()

    for model_version in model_versions:
        alpha = 1 if selected_model in (model_version, "All") else 0.3
        linestyle = "-" if model_version == "Base" else "--"
        model_df = df[df["Damage function"] == model_version]
        for var_name, row, col in panels:
            ax = axes[row, col]
            # Per-axes legends are disabled. With many numeric hue levels seaborn
            # would emit a "brief" legend that omits some rates. A single
            # figure legend is built below instead.
            sns.lineplot(
                data=model_df[model_df["Variable"] == var_name],
                x="Year",
                y="Value",
                hue="Mitigation rate",
                linestyle=linestyle,
                ax=ax,
                palette=palette,
                alpha=alpha,
                legend=False,
            )
            ax.set_ylabel(var_name)

    mitigation_rates = sorted(df["Mitigation rate"].unique())
    if not mitigation_rates:
        return

    # Seaborn maps a numeric hue through a colormap normalised to the data's
    # min/max, so the same mapping reproduces the plotted colours exactly.
    norm = plt.Normalize(vmin=mitigation_rates[0], vmax=mitigation_rates[-1])
    handles = [Line2D([0], [0], linestyle="")]  # invisible header entry
    labels = ["Mitigation rate"]
    for rate in mitigation_rates:
        handles.append(Line2D([0], [0], color=palette(norm(rate))))
        labels.append(str(rate))

    # Line-style key so solid vs dashed is readable without the docstring.
    handles.append(Line2D([0], [0], linestyle=""))
    labels.append("Damage function")
    for model_version in model_versions:
        style = "-" if model_version == "Base" else "--"
        handles.append(Line2D([0], [0], color="black", linestyle=style))
        labels.append(str(model_version))

    fig.legend(handles, labels, loc="upper center", ncol=15, bbox_to_anchor=(0.5, 0.92))


In [4]:
# ---- Main -------------------------------------------------------------------
# Action values are written directly into each region's action vector via
# RiceAction. NOTE(review): presumably discrete action-level indices rather than
# fractions — confirm against the env's action space.
MITIGATION_RATES = [0, 2, 4, 6, 8, 9, 10]
SAVINGS_RATE = 2
MODEL_VERSIONS = ["Base", "Updated"]

# Simulations: one freshly constructed environment per damage function.
all_data = []
for model_version in MODEL_VERSIONS:
    rice = Rice(dmg_function=model_version)
    sim_data = run_simulations(rice, model_version, MITIGATION_RATES, SAVINGS_RATE)  # list of records
    all_data.extend(sim_data)

assert all_data, "simulations produced no records"

# Plotting. plot_df also accepts "All", which highlights every damage function.
df = pd.DataFrame(all_data)
model_version_button = widgets.ToggleButtons(
    options=MODEL_VERSIONS + ["All"],
    description="Damage Function",
    button_style="",
)
# Trailing ';' suppresses the repr of the function returned by `interact`.
widgets.interact(lambda selected_model: plot_df(df, selected_model), selected_model=model_version_button);


interactive(children=(ToggleButtons(description='Damage Function', options=('Base', 'Updated'), value='Base'),…

<function __main__.<lambda>(selected_model)>