In [2]:
import os
from sam import chdir_to_repopath
chdir_to_repopath()
from sam.dose_reponse_fit import dose_response_fit, ModelPredictions, survival_to_stress, FitSettings, Transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sam.data_formats import read_data, load_files, load_datapoints
from sam.plotting import plot_fit_prediction
from sam.system_stress import pred_surv_without_hormesis
from sam.helpers import detect_hormesis_index, pad_c0, compute_lc_from_curve, weibull_2param, weibull_2param_inverse, compute_lc, ll5_inv
from sam.data_formats import ExperimentData
from scipy.optimize import brentq
import seaborn as sns
from sam.stress_addition_model import OLD_STANDARD, sam_prediction, get_sam_lcs, stress_to_survival, survival_to_stress
from tqdm import tqdm
import plotly.graph_objects as go

In [3]:
# Alternative ignore list, kept for reference:
# ignore = [
#     "huang_Flupyradifurone_2023",
#     "imrana_copper_2024",
#     "ayesha_chlorantran_2022_reference",
#     "ayesha_chlorantran_2022_agriculture",
#     "imrana_salt_2024",
#     "naemm_Clothianidin_2024_reference",  # NOTE(review): "naemm" is likely a typo for "naeem"
#     "imrana_2024_food",
# ]

# Substrings of data-file paths to exclude from the analysis.
# An empty list keeps every file.
ignore = [
    #     'bps_esf_food_2024',
    #    'ayesha-cloth_Clothianidin_2024_agricultural.xlsx',
    #    'huang_imidachloprid_2023.xlsx',
    #    'ayesha-cloth_Clothianidin_2024_agricultural_pre-contamination.xlsx',
    #    'ayesha-cloth_Clothianidin_2024_reference_pre-contamination.xlsx',
    #    'naeem_Esfenvalerate_2019.xlsx',
]

# Another alternative ignore list, kept for reference:
# ignore = [
#     'imrana_salt_2024',
#     'ayesha_chlorantran_2022_agriculture',
#     'huang_Flupyradifurone_2023',
#     'ayesha_chlorantran_2022_reference',  # comma was missing here: would concatenate with the next entry
#     'imrana_2024_food',
#     'imrana_copper_2024',
#     "ayesha_chlorantran_2022_reference",
# ]


def filter_func(path):
    """Return True when ``path`` contains none of the substrings listed in ``ignore``."""
    return all(pattern not in path for pattern in ignore)

In [4]:
def predict_cleaned_curv(data : ExperimentData):
    """Fit the hormesis-free survival curve of an experiment's main series.

    Survival is normalised by ``data.meta.max_survival`` before fitting.

    Choosing the hormesis index:
    - If ``data.meta.hormesis_concentration`` is set, the index is its
      position among the main-series concentrations.
    - Otherwise the index is detected with ``detect_hormesis_index``, falling
      back to 1 when nothing is detected.

    Args:
        data: Experiment whose ``main_series`` (``concentration``,
            ``survival_rate``) and ``meta`` are used.

    Returns:
        Tuple ``(func, hormesis_index, popt)``. ``func`` and ``popt`` are the
        first and third outputs of ``pred_surv_without_hormesis``.

    Raises:
        ValueError: if ``hormesis_concentration`` is set but matches none of
            the main-series concentrations.
    """
    concentration = pad_c0(data.main_series.concentration).copy()
    survival_tox_observed = np.copy(data.main_series.survival_rate / data.meta.max_survival)

    if data.meta.hormesis_concentration is None:
        hormesis_index = detect_hormesis_index(survival_tox_observed)

        if hormesis_index is None:
            # NOTE(review): fallback when no hormesis is detected; presumably
            # index 0 is the control, making 1 the first treated point — confirm.
            hormesis_index = 1

    else:
        # Tolerant comparison: concentrations read from files / metadata may not
        # compare exactly equal as floats, and a miss used to surface as an
        # opaque IndexError from `np.argwhere(...)[0, 0]`.
        tested = np.asarray(data.main_series.concentration, dtype=float)
        matches = np.flatnonzero(np.isclose(tested, data.meta.hormesis_concentration))
        if matches.size == 0:
            raise ValueError(
                f"hormesis_concentration={data.meta.hormesis_concentration!r} "
                f"not found among tested concentrations {tested.tolist()}"
            )
        hormesis_index = int(matches[0])

    func, _, popt = pred_surv_without_hormesis(
        concentration=concentration,
        surv_withhormesis=survival_tox_observed,
        hormesis_index=hormesis_index,
    )

    return func, hormesis_index, popt

In [5]:
def gen_func(stress, cleaned_func):
    """Build a survival curve with an extra stress added on top of ``cleaned_func``.

    For every concentration ``x`` the returned callable evaluates
    ``stress_to_survival(survival_to_stress(cleaned_func(x)) + stress)``.

    Args:
        stress: Additional stress to add to the stress implied by ``cleaned_func``.
        cleaned_func: Callable mapping a concentration to survival.

    Returns:
        An ``np.vectorize`` wrapper that applies the computation element-wise.
    """
    extra_stress = stress

    def _survival_with_extra_stress(x):
        toxicant_stress = survival_to_stress(cleaned_func(x))
        return stress_to_survival(toxicant_stress + extra_stress)

    return np.vectorize(_survival_with_extra_stress)

def find_lc_brentq(func, lc, min_v = 1e-8, max_v = 100000):
    """Find the concentration at which ``func`` has dropped by ``lc`` percent.

    The reference level is ``func(min_v)``. The returned ``x`` in
    ``[min_v, max_v]`` satisfies ``func(x) == (100 - lc) / 100 * func(min_v)``.
    It is found with Brent's method.

    Args:
        func: Scalar callable of concentration (e.g. a survival curve).
        lc: Effect level in percent, e.g. 10 for LC10 or 50 for LC50.
        min_v: Lower end of the search bracket; also where the reference is taken.
        max_v: Upper end of the search bracket.

    Returns:
        The concentration as a float.

    Raises:
        ValueError: if ``func`` does not cross the target level within
            ``[min_v, max_v]``. brentq raises the same exception type, but
            this message names the LC level and the values involved.
    """
    reference = func(min_v)
    target = (100 - lc) / 100 * reference

    def residual(x):
        return func(x) - target

    res_low, res_high = residual(min_v), residual(max_v)
    # brentq requires a sign change over the bracket (a zero at an end is accepted).
    if res_low * res_high > 0:
        raise ValueError(
            f"LC{lc} is not bracketed by [{min_v}, {max_v}]: "
            f"func({min_v})={func(min_v)}, func({max_v})={func(max_v)}, target={target}"
        )

    return brentq(residual, min_v, max_v)

In [None]:
# Environmental stress levels at which every dataset's LC10/LC50 is evaluated.
stresses = np.linspace(0, 0.6, 100)


def compute_lc_trajectory(path: str, stress_levels=None):
    """Compute LC10 and LC50 of a fitted dose-response curve under added stress.

    For each stress level, the fitted survival model (``fit.model``) is shifted
    with ``gen_func``. Its LC10 and LC50 are then located with
    ``find_lc_brentq`` on ``[1e-8, max(fit.concentration_curve)]``.

    Args:
        path: Path of the experiment file, read with ``read_data``.
        stress_levels: Stress levels to evaluate. Defaults to the module-level
            ``stresses``.

    Returns:
        Array of shape ``(len(stress_levels), 2)``. Column 0 is LC10 and
        column 1 is LC50.

    Raises:
        ValueError: propagated from ``find_lc_brentq`` when an LC is not
            reached within the fitted concentration range.
    """
    if stress_levels is None:
        stress_levels = stresses

    data = read_data(path)

    cfg = FitSettings(
        survival_max=data.meta.max_survival,
        param_d_norm=True,
    )
    fit = dose_response_fit(data.main_series, cfg)

    # NOTE(review): predict_cleaned_curv(data) used to be called here, but its
    # outputs were never used — the trajectories are built from fit.model.
    # Re-add it (and pass its curve to gen_func) if the hormesis-free curve
    # is what should be stressed.

    # Loop-invariant upper bracket for the root search.
    max_concentration = fit.concentration_curve.max()

    lcs = []
    for stress in stress_levels:
        func = gen_func(stress, cleaned_func=fit.model)
        lcs.append(
            (
                find_lc_brentq(func, 10, max_v=max_concentration),
                find_lc_brentq(func, 50, max_v=max_concentration),
            )
        )

    return np.array(lcs)


results = {}

for path, _ in tqdm(load_files(filter_func)):
    results[path] = compute_lc_trajectory(path)

In [27]:
# Rows = datasets (in `results` order), columns = stress levels.
trajectories = list(results.values())
lc10, lc50 = (np.array([t[:, k] for t in trajectories]) for k in (0, 1))

# Sensitivity increase relative to the first stress level (stresses[0] == 0):
# LC(0) / LC(stress), per dataset.
lc_10_frac, lc_50_frac = (lc[:, :1] / lc for lc in (lc10, lc50))

In [None]:
# One metadata row per dataset, plus its LC(0)/LC(stress) trajectories.
meta_infos = []
lc_10_trajectories = []
lc_50_trajectories = []
for path, data in load_files(filter_func):

    meta = data.meta
    meta_infos.append(
        {
            "Name": meta.title,
            "Chemical": meta.chemical,
            "Organism": meta.organism,
            "Experiment": meta.path.parent.name,
            "Duration": int(meta.days),
        }
    )

    # Look the trajectories up by path. The old code relied on this second
    # pass over load_files() yielding files in exactly the same order as the
    # pass that filled `results`.
    trajectory = results[path]
    lc_10_trajectories.append(trajectory[0, 0] / trajectory[:, 0])
    lc_50_trajectories.append(trajectory[0, 1] / trajectory[:, 1])

meta_infos = pd.DataFrame(meta_infos)
meta_infos["lc_10_frac"] = lc_10_trajectories
meta_infos["lc_50_frac"] = lc_50_trajectories
meta_infos.head()

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go


def gen_different_curves_fig():
    """Build an interactive plot of every dataset's sensitivity trajectory.

    Contents:
    - Left subplot: ``meta_infos.lc_10_frac`` against ``stresses``.
    - Right subplot: ``meta_infos.lc_50_frac`` against ``stresses``.
    - One line per dataset, coloured by main stressor (``Chemical``).
    - Mean curves over all datasets and over every group of the ``cleaner``
      columns.
    - A dropdown that switches which traces are visible.

    Returns:
        plotly.graph_objects.Figure
    """
    chemicals = meta_infos.Chemical.unique()
    palette = sns.color_palette("Set2", len(chemicals))  # Use a Seaborn color palette
    color_mapping = dict(zip(chemicals, palette.as_hex()))
    color_mapping["Mean"] = "black"
    color_mapping["Mean of All"] = "red"

    # name_to_id[k] identifies fig.data[k]. Every add_trace below must be
    # paired with exactly one append, in the same order.
    name_to_id = []

    fig = make_subplots(rows=1, cols=2)

    def gen_traces(y_key):
        """Return one line per dataset for column ``y_key``, coloured by chemical."""
        ts = list()
        for _, row in meta_infos.iterrows():
            ts.append(
                go.Scatter(
                    x=stresses,
                    y=row[y_key],
                    mode="lines",
                    name=row.Name,
                    line=dict(color=color_mapping[row.Chemical]),
                    hovertext=f"<br><b>Name</b>: {row.Name}<br><b>Experiment</b>: {row.Experiment}<br><b>Duration</b>: {row.Duration}<br><b>Main Stressor</b>: {row.Chemical}<br><b>Organism</b>: {row.Organism}",
                    showlegend=False,
                )
            )
            name_to_id.append(f"scatter_{row.Name}_{y_key}")
        return ts

    for col, y_key in ((1, "lc_10_frac"), (2, "lc_50_frac")):
        for trace in gen_traces(y_key):
            fig.add_trace(trace, row=1, col=col)

    # Data-less traces that only provide one legend entry per colour.
    for chemical, color in color_mapping.items():
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines",
                line=dict(color=color),
                name=chemical,  # Legend entry for each chemical
            )
        )
        name_to_id.append(f"color_{chemical}")

    # Grouping columns offered in the dropdown -> label shown to the user.
    cleaner = {
        "Chemical": "Main Stressor",
        "Experiment": "Experiment",
        "Duration": "Duration",
        "Organism": "Organism",
    }

    def gen_means(y_key):
        """Yield the overall mean curves, then one mean curve per group of each ``cleaner`` column.

        Each trace's id is appended to ``name_to_id`` just before it is yielded.
        """
        mean_curve = np.mean(np.stack(meta_infos[y_key].values), 0)

        # Black overall mean, shown in the "All" view.
        name_to_id.append(f"mean_all_{y_key}")
        yield go.Scatter(
            x=stresses,
            y=mean_curve,
            mode="lines",
            name="Mean",
            line=dict(color=color_mapping["Mean"]),
            showlegend=False,
        )

        # Same curve in red, shown as a reference when a group is selected.
        name_to_id.append(f"mean_reference_{y_key}")
        yield go.Scatter(
            x=stresses,
            y=mean_curve,
            mode="lines",
            name="Mean of All",
            line=dict(color=color_mapping["Mean of All"]),
            showlegend=False,
        )

        for key, label_name in cleaner.items():
            for val, df in meta_infos.groupby(key):
                group_mean = np.mean(np.stack(df[y_key].values), 0)
                name_to_id.append(f"mean_{label_name} = {val}_{y_key}")
                yield go.Scatter(
                    x=stresses,
                    y=group_mean,
                    mode="lines",
                    name="Mean",
                    line=dict(color=color_mapping["Mean"]),
                    showlegend=False,
                )

    for col, y_key in ((1, "lc_10_frac"), (2, "lc_50_frac")):
        for trace in gen_means(y_key):
            fig.add_trace(trace, row=1, col=col)

    for col in (1, 2):
        fig.update_xaxes(title_text="Environmental Stress", row=1, col=col)
        fig.update_yaxes(
            title_text="Increase of Toxicant Sensitivity", type="log", row=1, col=col
        )

    def get_visible(df, name):
        """Return the visibility mask over ``name_to_id`` for the datasets in ``df``.

        ``name`` is ``"all"`` or a ``"<label> = <value>"`` group label. For a
        group, the following are also shown:
        - its mean curves, when the group has more than one dataset;
        - the red overall reference mean.
        """
        valid = set()
        for n in df.Name.values:
            valid.add(f"scatter_{n}_lc_10_frac")
            valid.add(f"scatter_{n}_lc_50_frac")

        for chem in df.Chemical.unique():
            valid.add(f"color_{chem}")

        valid.add("color_Mean")
        if len(df) > 1:
            valid.add(f"mean_{name}_lc_10_frac")
            valid.add(f"mean_{name}_lc_50_frac")

        if name != "all":
            valid.add("color_Mean of All")
            valid.add("mean_reference_lc_10_frac")
            valid.add("mean_reference_lc_50_frac")

        for v in valid:
            assert v in name_to_id, f"{v} wrong!"

        return [i in valid for i in name_to_id]

    all_visible = get_visible(meta_infos, "all")
    assert len(all_visible) == len(fig.data)

    # Start in the state of the default "All" dropdown entry. Otherwise every
    # trace, including every group mean, is drawn until a button is clicked.
    for trace, visible in zip(fig.data, all_visible):
        trace.visible = visible

    buttons = [
        dict(
            label="All",
            method="update",
            args=[{"visible": all_visible}],
        )
    ]

    for key, label_name in cleaner.items():
        for val, df in meta_infos.groupby(key):
            name = f"{label_name} = {val}"
            buttons.append(
                dict(
                    label=name,
                    method="update",
                    args=[{"visible": get_visible(df, name)}],
                )
            )

    fig.update_layout(
        updatemenus=[
            dict(
                buttons=buttons,
                direction="down",
            ),
        ],
        annotations=[
            dict(
                x=0.2, y=1.1, xref="paper", yref="paper",
                text="LC 10",
                showarrow=False, font=dict(size=16)
            ),
            dict(
                x=0.8, y=1.1, xref="paper", yref="paper",
                text="LC 50",
                showarrow=False, font=dict(size=16)
            ),
        ],
    )
    return fig


fig = gen_different_curves_fig()
os.makedirs("control_imgs/plotly", exist_ok=True)
fig.write_html("control_imgs/plotly/curves.html")
fig.show()

In [None]:
# One record per (dataset, additional stressor) datapoint. Each record holds:
# - the main-series and stressed fits;
# - LC10/LC50 of the main fit;
# - the stressed and SAM-predicted LCs returned by get_sam_lcs.
dfs = []


for path, data, stress_name, stress_series in load_datapoints():
    meta = data.meta

    # Only the fits and sam_sur are used below.
    main_fit, stress_fit, sam_sur, _sam_stress, _additional_stress = sam_prediction(
        data.main_series,
        stress_series,
        meta,
        settings=OLD_STANDARD,
    )

    lcs = get_sam_lcs(stress_fit=stress_fit, sam_sur=sam_sur, meta=meta)

    main_lc10 = compute_lc(optim_param=main_fit.optim_param, lc=10)
    main_lc50 = compute_lc(optim_param=main_fit.optim_param, lc=50)

    dfs.append(
        {
            # Strip the extension whatever its length; a fixed path[:-4] would
            # leave a trailing "." for ".xlsx" files.
            "title": os.path.splitext(path)[0],
            "days": meta.days,
            "chemical": meta.chemical,
            "organism": meta.organism,
            "main_fit": main_fit,
            "stress_fit": stress_fit,
            "stress_name": stress_name,
            "main_lc10": main_lc10,
            "main_lc50": main_lc50,
            "stress_lc10": lcs.stress_lc10,
            "stress_lc50": lcs.stress_lc50,
            "sam_lc10": lcs.sam_lc10,
            "sam_lc50": lcs.sam_lc50,
            "experiment_name": meta.path.parent.name,
            "Name": meta.title,
        }
    )

df = pd.DataFrame(dfs)
df.head()

In [156]:
# Sensitivity increase per datapoint: the main-series LC divided by the
# corresponding LC from get_sam_lcs (stressed fit for true_*, SAM prediction for sam_*).
df["true_10_frac"] = df.main_lc10 / df.stress_lc10
df["true_50_frac"] = df.main_lc50 / df.stress_lc50
df["sam_10_frac"] = df.main_lc10 / df.sam_lc10
df["sam_50_frac"] = df.main_lc50 / df.sam_lc50


def _log_space_band(lc):
    """Geometric mean and ±1 log-std band of LC(first stress) / LC(stress) across datasets.

    Args:
        lc: 2-D array. Rows are datasets, columns are stress levels; column 0
            is the reference.

    Returns:
        Tuple ``(mean, upper, lower)`` with one value per stress level, i.e.
        ``exp(mu)``, ``exp(mu + sigma)``, ``exp(mu - sigma)`` of the log-ratios.
    """
    lc = np.asarray(lc, dtype=float)
    log_frac = np.log(lc[:, :1] / lc)
    log_mean = log_frac.mean(axis=0)
    log_std = log_frac.std(axis=0)
    return np.exp(log_mean), np.exp(log_mean + log_std), np.exp(log_mean - log_std)


mean_curve_10, upper_curve_10, lower_curve_10 = _log_space_band(lc10)
mean_curve_50, upper_curve_50, lower_curve_50 = _log_space_band(lc50)

# Stress level derived from the `d` parameter of each stressed fit.
df["stress_level"] = df.stress_fit.apply(lambda x: survival_to_stress(x.optim_param["d"]))

y_name = "Increase of Toxicant Sensitivity"
x_name = "Environmental Stress"
x_name = "Environmental Stress"

In [None]:
# Measured sensitivity increase vs. the SAM geometric-mean curve and its log-std band.
lc_fig = plt.figure(figsize=(10, 6))
ax1 = lc_fig.add_subplot(1, 2, 1)
ax2 = lc_fig.add_subplot(1, 2, 2, sharey=ax1)

panels = (
    (ax1, "LC 10", "true_10_frac", mean_curve_10, lower_curve_10, upper_curve_10),
    (ax2, "LC 50", "true_50_frac", mean_curve_50, lower_curve_50, upper_curve_50),
)
for ax, title, frac_col, mean_c, lower_c, upper_c in panels:
    ax.set_title(title)
    ax.scatter(df.stress_level, df[frac_col], label = "Measurements")
    ax.plot(stresses, mean_c, color='orange', label='SAM')
    ax.fill_between(stresses, lower_c, upper_c, color='gray', alpha=0.3, label='Log Std Dev')
    ax.set_yscale("log")
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)

# Single legend, on the right-hand panel.
ax2.legend()
plt.show()

In [None]:
# Copy of the results table with integer durations, used for grouping/hover text.
cleaned = df.assign(days=df["days"].astype(int))
cleaned.head()

In [None]:
def gen_dot_plotly():
    """Build an interactive plot of measured vs. SAM-predicted sensitivity increases.

    Each subplot (left LC10, right LC50) contains:
    - the SAM mean curve (``mean_curve_*``) with its ±1 log-std band;
    - one blue marker per row of ``cleaned`` for the measured value (``true_*_frac``);
    - one green marker per row for the SAM prediction (``sam_*_frac``);
    - per-group mean curves from ``meta_infos``, revealed by the dropdown.

    Returns:
        plotly.graph_objects.Figure
    """
    fig = make_subplots(rows=1, cols=2)

    color_mapping = {
        "Mean": "orange",
        "Log Std Dev": "gray",
        "This Mean": "black",
        "Measurements": "blue",
        "Predictions": "green",
    }

    # name_to_id[k] identifies fig.data[k]; keep add_trace and append in lockstep.
    name_to_id = []

    def add(mean_curve, upper, lower, col):
        """Add the ±1 log-std band and the SAM mean curve to subplot ``col``."""
        # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
        fig.add_trace(
            go.Scatter(
                x=stresses.tolist() + stresses[::-1].tolist(),
                y=upper.tolist() + lower[::-1].tolist(),
                fill="toself",
                fillcolor=color_mapping["Log Std Dev"],
                opacity=0.3,
                line=dict(color="gray"),
                name="Log Std Dev",
                showlegend=False,
            ),
            row=1,
            col=col,
        )
        fig.add_trace(
            go.Scatter(
                x=stresses,
                y=mean_curve,
                mode="lines",
                line=dict(color=color_mapping["Mean"]),
                name="SAM",
                showlegend=False,
            ),
            row=1,
            col=col,
        )

    # Arguments in signature order (upper before lower).
    add(mean_curve_10, upper_curve_10, lower_curve_10, 1)
    name_to_id.append("band_10")
    name_to_id.append("mean_10")
    add(mean_curve_50, upper_curve_50, lower_curve_50, 2)
    name_to_id.append("band_50")
    name_to_id.append("mean_50")

    def add_points(y_col, col):
        """Add a measured and a predicted marker for every row of ``cleaned``.

        Both markers get the id ``f"{Name}_{y_col}"`` so the dropdown toggles them together.
        """
        pred_col = y_col.replace("true", "sam")
        for _, row in cleaned.iterrows():
            hover = f"<br><b>Name</b>: {row.Name}<br><b>Experiment</b>: {row.experiment_name} <br><b>Main Stressor</b>: {row.chemical}<br><b>Additional Stressor</b>: {row.stress_name}<br> <b>Duration</b>: {row.days}<br><b>Organism</b>: {row.organism}"
            for value, color_key in ((row[y_col], "Measurements"), (row[pred_col], "Predictions")):
                fig.add_trace(
                    go.Scatter(
                        x=(row.stress_level,),
                        y=(value,),
                        mode="markers",
                        name=row.Name,
                        hovertext=hover,
                        showlegend=False,
                        marker=dict(color=color_mapping[color_key]),
                    ),
                    row=1,
                    col=col,
                )
                name_to_id.append(f"{row.Name}_{y_col}")

    add_points("true_10_frac", 1)
    add_points("true_50_frac", 2)

    # Grouping columns offered in the dropdown -> label shown to the user.
    cleaner = {"days": "Duration", "organism": "Organism", "experiment_name": "Experiment"}

    def add_specific_means(y_col, col):
        """Add one mean trajectory (from ``meta_infos``) per group of each ``cleaner`` column."""
        spec_col = "lc_10_frac" if "10" in y_col else "lc_50_frac"
        for key, label_name in cleaner.items():
            for val, frame in cleaned.groupby(key):
                name = f"{label_name} = {val}_{y_col}"
                names = set(frame.Name.unique())
                # NOTE(review): arithmetic mean of the fractions, whereas the SAM
                # curve (mean_curve_*) is a geometric (log-space) mean — confirm intended.
                mean_curve = np.mean(
                    np.stack(meta_infos.query("Name in @names")[spec_col].values, 1), axis=1
                )
                fig.add_trace(
                    go.Scatter(
                        x=stresses,
                        y=mean_curve,
                        mode="lines",
                        line=dict(color=color_mapping["This Mean"]),
                        name=f"{label_name} = {val} Mean",
                        hovertext=f"{label_name} = {val} Mean",
                        showlegend=False,
                    ),
                    row=1,
                    col=col,
                )
                name_to_id.append(name)

    add_specific_means("true_10_frac", 1)
    add_specific_means("true_50_frac", 2)

    # Data-less traces that only provide the colour legend.
    for name, color in color_mapping.items():
        is_marker = name in ["Predictions", "Measurements"]
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="markers" if is_marker else "lines",
                line=dict(color=color),
                marker=dict(color=color),
                name=name,
            )
        )
        name_to_id.append(f"color_{name}")

    def get_visible(df, name):
        """Return the visibility mask over ``name_to_id`` for the rows in ``df``.

        ``name`` is ``"all"`` or a ``"<label> = <value>"`` group label. For a
        group, that group's mean curves and the "This Mean" legend entry are shown too.
        """
        valid = {
            "mean_10",
            "band_10",
            "mean_50",
            "band_50",
            "color_Mean",
            "color_Measurements",
            "color_Log Std Dev",
            "color_Predictions",
        }

        for n in df.Name.values:
            valid.add(f"{n}_true_10_frac")
            valid.add(f"{n}_true_50_frac")

        if name != "all":
            valid.add("color_This Mean")
            valid.add(f"{name}_true_10_frac")
            valid.add(f"{name}_true_50_frac")

        for v in valid:
            assert v in name_to_id, f"{v} wrong!"

        return [i in valid for i in name_to_id]

    all_visible = get_visible(cleaned, "all")
    assert len(all_visible) == len(fig.data)

    # Start in the state of the default "All" dropdown entry. Otherwise every
    # group-mean curve is drawn until a button is clicked.
    for trace, visible in zip(fig.data, all_visible):
        trace.visible = visible

    buttons = [
        dict(
            label="All",
            method="update",
            args=[{"visible": all_visible}],
        )
    ]

    for key, label_name in cleaner.items():
        for val, frame in cleaned.groupby(key):
            name = f"{label_name} = {val}"
            buttons.append(
                dict(
                    label=name,
                    method="update",
                    args=[{"visible": get_visible(frame, name)}],
                )
            )

    fig.update_layout(
        updatemenus=[
            dict(
                buttons=buttons,
                direction="down",
            ),
        ],
        annotations=[
            dict(
                x=0.2, y=1.1, xref="paper", yref="paper",
                text="LC 10",
                showarrow=False, font=dict(size=16)
            ),
            dict(
                x=0.8, y=1.1, xref="paper", yref="paper",
                text="LC 50",
                showarrow=False, font=dict(size=16)
            ),
        ],
    )

    for col in (1, 2):
        fig.update_xaxes(title_text="Environmental Stress", row=1, col=col)
        fig.update_yaxes(
            title_text="Increase of Toxicant Sensitivity", type="log", row=1, col=col
        )
    return fig


dot_fig = gen_dot_plotly()
os.makedirs("control_imgs/plotly", exist_ok=True)
dot_fig.write_html("control_imgs/plotly/lcs.html")
dot_fig.show()

In [179]:
from bs4 import BeautifulSoup

PLOT_DIR = "control_imgs/plotly"


def _load_plot_div(html_path):
    """Return the first <div> of a Plotly-exported HTML file.

    NOTE(review): assumes the whole plot lives inside that first <div>,
    including its plotly.js and Plotly.newPlot <script> tags. Confirm against
    the files written by fig.write_html above. Note that each embedded div
    therefore carries its own copy of plotly.js.
    """
    with open(html_path, "r", encoding="utf-8") as f:
        soup = BeautifulSoup(f.read(), "html.parser")
    return soup.find("div")


def _append_section(page, title, content):
    """Append a large centred <h2> heading followed by ``content`` to ``page.body``."""
    # Keep a handle on the new tag: `page.body.h2` would always return the
    # FIRST <h2>, which made the second title overwrite the first one.
    heading = page.new_tag("h2", style="text-align:center; font-size:24px;")
    heading.string = title
    page.body.append(heading)
    page.body.append(content)


combined_html = BeautifulSoup(
    '<html><head><meta charset="utf-8"></head><body></body></html>', "html.parser"
)

_append_section(combined_html, "LCs Plot", _load_plot_div(os.path.join(PLOT_DIR, "lcs.html")))
_append_section(combined_html, "Curves Plot", _load_plot_div(os.path.join(PLOT_DIR, "curves.html")))

# Write the combined HTML to a new file
with open(os.path.join(PLOT_DIR, "combined_plots.html"), "w", encoding="utf-8") as f:
    f.write(str(combined_html))


In [None]:
# Scratch single-panel figure; not referenced elsewhere in the visible notebook.
combined = make_subplots(rows=1, cols=1)

In [None]:
# control_imgs/plotly/curves.html
# and control_imgs/plotly/lcs.html
# are both Plotly plots.
# I want you to join them into one page, one after the other.
# They should also have big titles.

In [None]:
# (dataset-title substring, additional-stressor substring) pairs to highlight in the plots below.
interest = [
    ("liess_copper_2001", "Food_1% + UV"),
    ("naeem_Esfenvalerate_2019", "Food_1% + Prochloraz_100"),
    ("naeem_Esfenvalerate_2019", "Food_1% + Prochloraz_32"),
    ("naeem_Esfenvalerate_2024", "Food_1% + Temp_25"),
    ("naeem_Mix13_2024", "Temp_25 + Food_1%")
]


def _is_of_interest(row):
    """True if ``row.title`` and ``row.stress_name`` contain any pair listed in ``interest``."""
    for path_part, stress_part in interest:
        if path_part in row.title and stress_part in row.stress_name:
            return True
    return False


mask = df.apply(_is_of_interest, axis=1)

In [None]:
# Same comparison as above, with measurements coloured by experiment duration.
duration_fig = plt.figure(figsize=(10, 6))
ax1 = duration_fig.add_subplot(1, 2, 1)
ax2 = duration_fig.add_subplot(1, 2, 2, sharey=ax1)

panels = (
    (ax1, "LC 10", "true_10_frac", mean_curve_10, lower_curve_10, upper_curve_10),
    (ax2, "LC 50", "true_50_frac", mean_curve_50, lower_curve_50, upper_curve_50),
)
for ax, title, frac_col, mean_c, lower_c, upper_c in panels:
    ax.set_title(title)
    for n_days, group in df.groupby("days"):
        ax.scatter(group.stress_level, group[frac_col], label = "{} - Days".format(int(n_days)))
    ax.plot(stresses, mean_c, color='orange', label='SAM')
    ax.fill_between(stresses, lower_c, upper_c, color='gray', alpha=0.3, label='Log Std Dev')
    ax.set_yscale("log")
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)

# Single legend, on the right-hand panel.
ax2.legend()
plt.show()

In [None]:
# LC10 / LC50 comparison with the (dataset, stressor) pairs selected by `mask`
# highlighted in red.
# Reuses the following from the earlier analysis cell instead of re-computing
# them here (the old cell was a verbatim copy):
# - df["true_*_frac"]
# - df["stress_level"]
# - mean_curve_*, lower_curve_*, upper_curve_*
y_name = "Increase of Toxicant Sensitivity"
x_name = "Environmental Stress"

color = ["red" if highlighted else "blue" for highlighted in mask]

highlight_fig = plt.figure(figsize=(10, 6))
ax1 = highlight_fig.add_subplot(1, 2, 1)
ax2 = highlight_fig.add_subplot(1, 2, 2, sharey=ax1)

panels = (
    (ax1, "LC 10", "true_10_frac", mean_curve_10, lower_curve_10, upper_curve_10),
    (ax2, "LC 50", "true_50_frac", mean_curve_50, lower_curve_50, upper_curve_50),
)
for ax, title, frac_col, mean_c, lower_c, upper_c in panels:
    ax.set_title(title)
    ax.scatter(df.stress_level, df[frac_col], label="Measurements", color=color)
    ax.plot(stresses, mean_c, color='orange', label='SAM')
    ax.fill_between(stresses, lower_c, upper_c, color='gray', alpha=0.3, label='Log Std Dev')
    ax.set_yscale("log")
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)

# Single legend, on the right-hand panel.
ax2.legend()
plt.show()