In [13]:
import os
from sam import chdir_to_repopath

chdir_to_repopath()
from sam.dose_reponse_fit import (
    dose_response_fit,
    ModelPredictions,
    survival_to_stress,
    FitSettings,
    Transforms,
)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from sam.data_formats import read_data, load_files
from sam.plotting import plot_fit_prediction
from sam.system_stress import pred_surv_without_hormesis
from sam.helpers import (
    detect_hormesis_index,
    pad_c0,
    compute_lc_from_curve,
    weibull_2param,
    weibull_2param_inverse,
)
from sam.data_formats import ExperimentData

In [14]:
def predict_cleaned_curv(data: ExperimentData):
    """Fit a hormesis-free survival curve to one experiment.

    The observed survival is normalised by ``data.meta.max_survival`` and
    handed to ``pred_surv_without_hormesis`` together with the
    ``pad_c0``-processed concentrations and the index of the hormesis point.

    The hormesis index comes from ``data.meta.hormesis_concentration`` when it
    is set. Otherwise it is detected with ``detect_hormesis_index``, falling
    back to 1 when detection returns None.

    Returns
    -------
    tuple
        ``(func, hormesis_index, popt)``: the first and third values returned
        by ``pred_surv_without_hormesis`` and the hormesis index that was used.

    Raises
    ------
    ValueError
        If ``meta.hormesis_concentration`` is set but matches none of the
        measured concentrations (previously an opaque IndexError).
    """
    concentration = pad_c0(data.main_series.concentration).copy()
    survival_tox_observed = np.copy(
        data.main_series.survival_rate / data.meta.max_survival
    )

    hormesis_concentration = data.meta.hormesis_concentration
    if hormesis_concentration is None:
        hormesis_index = detect_hormesis_index(survival_tox_observed)

        if hormesis_index is None:
            # NOTE(review): fixed fallback to index 1 when no hormesis peak is
            # detected; confirm this is the intended default.
            hormesis_index = 1

    else:
        # The index is looked up in the unpadded concentrations. This assumes
        # pad_c0 keeps element positions unchanged (TODO confirm).
        measured = np.asarray(data.main_series.concentration, dtype=float)
        # A tiny purely relative tolerance absorbs float round-off between the
        # meta value and the series; exact matches (including 0) still match.
        matches = np.flatnonzero(
            np.isclose(measured, hormesis_concentration, rtol=1e-9, atol=0.0)
        )
        if matches.size == 0:
            raise ValueError(
                f"hormesis_concentration={hormesis_concentration!r} not found in "
                f"measured concentrations {measured.tolist()}"
            )
        hormesis_index = int(matches[0])

    func, _, popt = pred_surv_without_hormesis(
        concentration=concentration,
        surv_withhormesis=survival_tox_observed,
        hormesis_index=hormesis_index,
    )

    return func, hormesis_index, popt

In [15]:
# Per-experiment control plots: the raw fit ("with hormesis") next to the
# hormesis-free fit ("with out"), with the cleaned fit's LC1/LC99 marked.
# One PNG is written per input file.
cleaned_curves_dir = os.path.join("control_imgs", "cleaned_curves")
# savefig does not create missing directories.
os.makedirs(cleaned_curves_dir, exist_ok=True)

for path, data in load_files():
    meta = data.meta
    res: ModelPredictions = dose_response_fit(
        data.main_series, FitSettings(param_d_norm=True, survival_max=meta.max_survival)
    )

    cleaned_func, hormesis_index, popt = predict_cleaned_curv(data)

    # LCx = concentration at which the cleaned (normalised) survival equals 1 - x/100.
    lc1 = weibull_2param_inverse(1 - 1 / 100, *popt)
    lc99 = weibull_2param_inverse(1 - 99 / 100, *popt)

    title = os.path.split(path)[-1]
    # Build the PNG name from the file stem. str.replace only handled ".xlsx";
    # any other extension made savefig pick an unsupported format.
    png_name = os.path.splitext(title)[0] + ".png"

    # Highlight the point treated as the hormesis peak.
    point_colors = [
        "blue" if i != hormesis_index else "red"
        for i in range(len(data.main_series.concentration))
    ]

    fig, ax = plt.subplots()
    ax.scatter(
        pad_c0(data.main_series.concentration),
        data.main_series.survival_rate,
        label="orig",
        color=point_colors,
    )
    ax.plot(res.concentration_curve, res.survival_curve, label="with hormesis")
    ax.plot(
        res.concentration_curve,
        cleaned_func(res.concentration_curve) * meta.max_survival,
        label="with out",
    )
    ax.axvline(lc1, 0, 1, color="red", ls="--")
    ax.axvline(lc99, 0, 1, color="red", ls="--")
    ax.legend()
    ax.set_title(title)
    ax.set_xscale("log")
    fig.savefig(os.path.join(cleaned_curves_dir, png_name))
    # Close explicitly to avoid accumulating figures across the loop.
    plt.close(fig)

In [None]:
# One record per experiment: the raw fit, the hormesis-free fit with its
# LC1/LC99, and metadata used for grouping/hover text. Collected into `df`.
dfs = []

for path, data in load_files():
    meta = data.meta
    res: ModelPredictions = dose_response_fit(
        data.main_series, FitSettings(param_d_norm=True, survival_max=meta.max_survival)
    )

    # Not named `_`, so IPython's last-output variable is not shadowed.
    cleaned_func, _hormesis_index, popt = predict_cleaned_curv(data)

    # LCx = concentration at which the cleaned (normalised) survival equals 1 - x/100.
    lc1 = weibull_2param_inverse(1 - 1 / 100, *popt)
    lc99 = weibull_2param_inverse(1 - 99 / 100, *popt)

    dfs.append(
        {
            # File name without directory or extension. The old path[:-5]
            # assumed a 5-character ".xlsx" suffix.
            "title": os.path.splitext(os.path.basename(path))[0],
            "chemical": meta.chemical,
            "Organism": meta.organism,
            "model": res,
            "cleaned_func": cleaned_func,
            "lc1": lc1,
            "lc99": lc99,
            "Name": meta.title,
            "Duration": int(meta.days),
            "Experiment": meta.path.parent.name,
        }
    )

df = pd.DataFrame(dfs)
df.head()

In [23]:
def compute_normalised_curve(model: ModelPredictions):
    """Sample a fitted dose-response model between its LC1 and LC99.

    The model is evaluated on 1000 evenly spaced concentrations from
    ``model.lc1`` to ``model.lc99``, and the result is multiplied by 100.

    If ``model.lc1`` is NaN, the lower bound falls back to 0.0. The model
    object itself is NOT modified. The previous version overwrote
    ``model.lc1`` on the shared object stored in ``df``.
    """
    lc1 = model.lc1
    if np.isnan(lc1):
        print(f"lc1 is NaN (lc99={model.lc99}); falling back to 0.0")
        lc1 = 0.0
    x = np.linspace(lc1, model.lc99, 1000)

    return model.model(x) * 100


# Raw (with-hormesis) fits sampled between their LC1 and LC99 (x100),
# followed by the matching stress transform of each curve.
df["normed_curves"] = df["model"].apply(compute_normalised_curve)
df["stress"] = df["normed_curves"].apply(
    lambda curve: survival_to_stress(curve / 100)
)

In [24]:
def compute_cleaned_curve(row):
    """Evaluate a row's hormesis-free fit between its LC1 and LC99.

    ``row.cleaned_func`` is sampled on 1000 evenly spaced concentrations from
    ``row.lc1`` to ``row.lc99``, and the result is scaled by 100.
    """
    grid = np.linspace(start=row.lc1, stop=row.lc99, num=1000)
    survival_fraction = row.cleaned_func(grid)
    return survival_fraction * 100


# Hormesis-free fits sampled between their own LC1 and LC99 (x100),
# followed by the matching stress transform of each curve.
df["cleaned_curves"] = df.apply(compute_cleaned_curve, axis=1)
df["cleaned_stress"] = df["cleaned_curves"].apply(
    lambda curve: survival_to_stress(curve / 100)
)

In [64]:
# One Set2 colour per chemical, plus fixed colours for the overall mean
# and the per-selection mean lines.
chemicals = df["chemical"].unique()
color_map = dict(
    zip(chemicals, sns.color_palette("Set2", len(chemicals)).as_hex())
)
color_map.update({"Mean": "black", "Selection Mean": "red"})

In [65]:
def plot_curves(chemicals, curves_df=None):
    """Plot survival (left) and stress (right) curves for the given chemicals.

    Every column of ``curves_df`` whose name contains ``"<chemical>_survival"``
    (or ``"<chemical>_stress"``) is drawn against ``curves_df["dose"]`` on a
    log x-axis, coloured by the notebook-level ``color_map``. The legend lists
    every entry of ``color_map``.

    Parameters
    ----------
    chemicals : iterable of str
        Chemicals to draw; each one that has matching columns must be a key
        of ``color_map``.
    curves_df : pandas.DataFrame, optional
        Wide table of curves. Defaults to the notebook-level ``curves_df``
        global, which the previous version required.
        NOTE(review): no visible cell defines that global, so pass it explicitly.

    Raises
    ------
    NameError
        If ``curves_df`` is neither passed nor defined globally.
    """
    if curves_df is None:
        # Same lookup the old `global curves_df` did, but with a clear message.
        if "curves_df" not in globals():
            raise NameError(
                "curves_df is not defined; call plot_curves(chemicals, curves_df)"
            )
        curves_df = globals()["curves_df"]

    fig, axs = plt.subplots(1, 2, figsize=(14, 6))

    for chemical in chemicals:
        # Left panel gets "<chemical>_survival" columns, right panel "<chemical>_stress".
        for ax, kind in zip(axs, ("survival", "stress")):
            columns = [
                col for col in curves_df.columns if f"{chemical}_{kind}" in col
            ]
            for _, values in curves_df[columns].items():
                ax.plot(
                    curves_df["dose"],
                    values,
                    label=chemical,
                    color=color_map[chemical],
                )

    for ax, title, ylabel in (
        (axs[0], "Survival Curves with Bands", "Survival Rate"),
        (axs[1], "Stress Curves with Bands", "Stress"),
    ):
        ax.set_title(title)
        ax.set_xlabel("LC")
        ax.set_ylabel(ylabel)
        ax.set_xscale("log")

    # Custom legend: one coloured dot per colour_map entry, not one per line.
    legend_elements = [
        plt.Line2D(
            [0],
            [0],
            marker="o",
            color="w",
            markerfacecolor=color,
            markersize=10,
            label=chemical,
        )
        for chemical, color in color_map.items()
    ]

    axs[1].legend(handles=legend_elements, title="Chemicals")

    fig.tight_layout()
    plt.show()

In [None]:
df.head(n=5)

In [76]:
def make_fig(surv_col, stres_col):
    """Build an interactive two-panel Plotly figure of all dose curves.

    The left subplot shows ``df[surv_col]`` and the right subplot shows
    ``df[stres_col]`` for every row of the notebook-level ``df``. Both use the
    x grid ``np.linspace(1, 99, 1000)`` (labelled "LC", log scale) and are
    coloured by chemical via ``color_map``.

    A dropdown ("All" plus one entry per value of Main Stressor, Experiment,
    Duration and Organism) toggles trace visibility. The overall "Mean" is
    always visible. A "Selection Mean" appears for groups with more than one
    row. Mean stress is ``survival_to_stress(mean_survival / 100)``, not the
    mean of the stress column.

    Parameters
    ----------
    surv_col, stres_col : str
        Columns of ``df`` holding per-row survival and stress arrays. These
        presumably have 1000 samples to match ``x``; TODO confirm.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Parallel to fig.data: one string id per trace, in insertion order, so
    # the visibility masks built by `visible` line up with the traces.
    name_to_id = []

    x = np.linspace(1, 99, 1000)
    fig = make_subplots(rows=1, cols=2)

    def gen_traces(y_key, col):
        """Add one line per row of `df` for column `y_key` to subplot `col`."""
        for _, row in df.iterrows():
            fig.add_trace(
                go.Scatter(
                    x=x,
                    y=row[y_key],
                    mode="lines",
                    name=row.Name,
                    line=dict(color=color_map[row.chemical]),
                    hovertext=f"<br><b>Name</b>: {row.Name}<br><b>Experiment</b>: {row.Experiment}<br><b>Duration</b>: {row.Duration}<br><b>Main Stressor</b>: {row.chemical}<br><b>Organism</b>: {row.Organism}",
                    showlegend=False,
                ),
                col=col,
                row=1,
            )

            name_to_id.append(f"line_{row.Name}_{y_key}")

    def add_means(df: pd.DataFrame, name):
        """Add mean-survival and derived mean-stress traces for the rows in `df`."""
        mean_curve = np.mean(np.stack(df[surv_col].values), axis=0)
        mean_stress = survival_to_stress(mean_curve / 100)
        key = "Mean" if name == "Mean" else "Selection Mean"

        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean_curve,
                mode="lines",
                name=name,
                line=dict(color=color_map[key]),
                showlegend=False,
                opacity=0.7 if key != "Mean" else 1,
            ),
            col=1,
            row=1,
        )

        name_to_id.append(f"mean_{name}_surv")

        fig.add_trace(
            go.Scatter(
                x=x,
                y=mean_stress,
                mode="lines",
                name=name,
                line=dict(color=color_map[key]),
                showlegend=False,
                opacity=0.7 if key != "Mean" else 1,
            ),
            col=2,
            row=1,
        )

        name_to_id.append(f"mean_{name}_stress")

    add_means(df, "Mean")

    # df column -> label used in dropdown entries.
    cleaner = {
        "chemical": "Main Stressor",
        "Experiment": "Experiment",
        "Duration": "Duration",
        "Organism": "Organism",
    }

    for key, label_name in cleaner.items():
        for val, frame in df.groupby(key):
            name = f"{label_name} = {val}"
            add_means(frame, name)
    gen_traces(surv_col, 1)
    gen_traces(stres_col, 2)

    # Legend-only dummy traces (no data points), one per colour_map entry.
    for chemical, color in color_map.items():
        fig.add_trace(
            go.Scatter(
                x=[None],
                y=[None],
                mode="lines" if chemical in ["Mean", "Selection Mean"] else "markers",
                line=dict(color=color),
                # Marker-mode traces take their colour from `marker`, not `line`.
                # Without this, the chemical legend swatches ignore colour_map.
                marker=dict(color=color),
                name=chemical,
            )
        )
        name_to_id.append(f"color_{chemical}")

    def visible(df, name):
        """Return a per-trace visibility mask for the selection `df` named `name`."""
        valid = {"mean_Mean_stress", "mean_Mean_surv", "color_Mean"}

        for n in df.Name.values:
            valid.add(f"line_{n}_{surv_col}")
            valid.add(f"line_{n}_{stres_col}")

        for chem in df.chemical.unique():
            valid.add(f"color_{chem}")

        # A selection mean is only meaningful for groups with several curves.
        if len(df) > 1 and name != "All":
            valid.add(f"mean_{name}_surv")
            valid.add(f"mean_{name}_stress")
            valid.add("color_Selection Mean")

        for v in valid:
            assert v in name_to_id, f"{v} wrong!"

        return [i in valid for i in name_to_id]

    buttons = [
        dict(
            label="All",
            method="update",
            args=[{"visible": visible(df, "All")}],
        )
    ]
    assert len(visible(df, "All")) == len(fig.data)

    for key, label_name in cleaner.items():
        for val, frame in df.groupby(key):
            name = f"{label_name} = {val}"
            buttons.append(
                dict(
                    label=name,
                    method="update",
                    args=[{"visible": visible(frame, name)}],
                )
            )
    fig.update_yaxes(title_text="Survival Rate", row=1, col=1)
    fig.update_xaxes(title_text="LC", type="log", row=1, col=1)
    fig.update_yaxes(title_text="Stress", row=1, col=2)
    fig.update_xaxes(title_text="LC", type="log", row=1, col=2)

    fig.update_layout(
        updatemenus=[
            dict(
                buttons=buttons,
                direction="down",
            ),
        ],
        annotations=[
            dict(
                x=0.2,
                y=1.1,
                xref="paper",
                yref="paper",
                text="Survival",
                showarrow=False,
                font=dict(size=16),
            ),
            dict(
                x=0.8,
                y=1.1,
                xref="paper",
                yref="paper",
                text="Stress",
                showarrow=False,
                font=dict(size=16),
            ),
        ],
    )
    return fig

In [78]:
# Export the interactive explorers: "cleaned" uses the hormesis-free fits,
# "raw" the original fits. write_html does not create missing directories.
plotly_dir = os.path.join("control_imgs", "plotly")
os.makedirs(plotly_dir, exist_ok=True)

make_fig("cleaned_curves", "cleaned_stress").write_html(
    os.path.join(plotly_dir, "cleaned_dosecurves.html")
)
make_fig("normed_curves", "stress").write_html(
    os.path.join(plotly_dir, "raw_dosecurves.html")
)

In [8]:
def graphic(surv_col, stres_col):
    """Static matplotlib overview of every curve in the notebook-level `df`.

    The left panel draws ``df[surv_col]`` and the right panel ``df[stres_col]``
    against ``np.linspace(1, 99, 1000)`` on a log x-axis, coloured by chemical
    via ``color_map``. The overall mean survival and its stress transform are
    overlaid in the "Mean" colour with a thicker line.
    """
    fig, (surv_ax, stress_ax) = plt.subplots(1, 2, figsize=(14, 6))

    lc_axis = np.linspace(1, 99, 1000)

    # Each Axes receives its lines in the same row order as before.
    for _, record in df.iterrows():
        surv_ax.plot(
            lc_axis, record[surv_col], label=record.chemical, color=color_map[record.chemical]
        )
    for _, record in df.iterrows():
        stress_ax.plot(
            lc_axis, record[stres_col], label=record.chemical, color=color_map[record.chemical]
        )

    mean_survival = np.stack(df[surv_col].values).mean(axis=0)
    mean_stress = survival_to_stress(mean_survival / 100)

    surv_ax.plot(lc_axis, mean_survival, label="Mean", color=color_map["Mean"], linewidth=3)
    stress_ax.plot(lc_axis, mean_stress, label="Mean", color=color_map["Mean"], linewidth=3)

    handles = []
    for chemical, color in color_map.items():
        handles.append(
            plt.Line2D(
                [0],
                [0],
                marker="o",
                color="w",
                markerfacecolor=color,
                markersize=10,
                label=chemical,
            )
        )
    stress_ax.legend(handles=handles, title="Chemicals")

    panels = (
        (surv_ax, "Survival Curves with Bands", "Survival Rate"),
        (stress_ax, "Stress Curves with Bands", "Stress"),
    )
    for ax, panel_title, panel_ylabel in panels:
        ax.set_title(panel_title)
        ax.set_xlabel("LC")
        ax.set_ylabel(panel_ylabel)
        ax.set_xscale("log")

    fig.tight_layout()
    plt.show()

In [None]:
graphic(surv_col="cleaned_curves", stres_col="cleaned_stress")

In [None]:
graphic(surv_col="normed_curves", stres_col="stress")

In [34]:
def plot_graphic_text(
    df, with_mean=False, surv_col="cleaned_curves", stres_col="cleaned_stress"
):
    """Plot each row's survival and stress curves, labelled by its title.

    Parameters
    ----------
    df : pandas.DataFrame
        Rows to plot. Needs a ``"title"`` column plus the two curve columns.
    with_mean : bool, default False
        If True, overlay the mean survival curve and its stress transform
        (``survival_to_stress(mean / 100)``) in the "Mean" colour.
    surv_col, stres_col : str
        Curve columns to plot. The defaults reproduce the previously
        hard-coded hormesis-free curves.
    """
    fig, (surv_ax, stress_ax) = plt.subplots(1, 2, figsize=(14, 6))

    x = np.linspace(1, 99, 1000)
    # Each Axes has its own colour cycle, so drawing both panels in one pass
    # gives the same per-panel colours as two separate loops.
    for _, row in df.iterrows():
        # Subscript rather than attribute access, so "title" unambiguously
        # refers to the column.
        surv_ax.plot(x, row[surv_col], label=row["title"])
        stress_ax.plot(x, row[stres_col], label=row["title"])

    if with_mean:
        mean_curve = np.mean(np.stack(df[surv_col].values), axis=0)
        mean_stress = survival_to_stress(mean_curve / 100)

        surv_ax.plot(x, mean_curve, label="Mean", color=color_map["Mean"], linewidth=3)
        stress_ax.plot(x, mean_stress, label="Mean", color=color_map["Mean"], linewidth=3)

    stress_ax.legend()
    surv_ax.set_title("Survival Curves with Bands")
    surv_ax.set_xlabel("LC")
    surv_ax.set_ylabel("Survival Rate")
    surv_ax.set_xscale("log")

    stress_ax.set_title("Stress Curves with Bands")
    stress_ax.set_xlabel("LC")
    stress_ax.set_ylabel("Stress")
    stress_ax.set_xscale("log")

    fig.tight_layout()
    plt.show()

In [35]:
# Rank experiments by the fraction of sampled points at which their cleaned
# curve lies above the mean cleaned curve. `ar` holds df index labels, ascending.
mean_curve = np.stack(df["cleaned_curves"].values).mean(axis=0)

ar = (
    df["cleaned_curves"]
    .apply(lambda curve: np.mean(curve > mean_curve))
    .sort_values()
    .index.values
)

In [None]:
# Drop the experiments whose cleaned curve most often lies above the mean,
# then replot the remainder with its mean.
N_MOST_ABOVE_MEAN = 6
# Slice from an explicit start: `ar[-0:]` would select everything if N were 0.
drop_index = ar[max(len(ar) - N_MOST_ABOVE_MEAN, 0):]

# `ar` holds index *labels*, so select with .loc to match df.drop(index=...).
# The old .iloc only worked because df has a default RangeIndex.
print(df.loc[drop_index, "title"].values.tolist())

used = df.drop(index=drop_index)
plot_graphic_text(used, with_mean=True)

mean_curve = np.mean(np.stack(used["cleaned_curves"].values), axis=0)

In [None]:
# Keep only experiments whose cleaned curve lies above the mean cleaned curve
# at more than half of the sampled points.
mean_curve = np.stack(df["cleaned_curves"].values).mean(axis=0)

mask = (
    df["cleaned_curves"]
    .apply(lambda curve: np.mean(curve > mean_curve) > 0.5)
    .to_numpy()
)

bigger = df.loc[mask]
plot_graphic_text(bigger)