In [None]:
import argparse
import sys

import an_cockrell
import matplotlib.pyplot as plt
import numpy as np
from an_cockrell import AnCockrellModel, epitype_one_hot_encoding
from scipy.stats import multivariate_normal
from tqdm import trange

from consts import (
    UNIFIED_STATE_SPACE_DIMENSION,
    default_params,
    init_only_params,
    state_var_indices,
    state_vars,
    variational_params,
)
from modify_epi_spatial import modify_model as modify_model_spatial
from modify_simple import modify_model as modify_model_simple
from util import model_macro_data

In [None]:
################################################################################
# command-line arguments
#
# When running interactively (IPython/Jupyter set sys.ps1) there is no command
# line to parse, so `args` is an attribute-less placeholder and every
# `hasattr(args, ...)` check further down falls back to its default.

_MEASUREMENT_CHOICES = [
    "T1IFN",
    "TNF",
    "IFNg",
    "IL6",
    "IL1",
    "IL8",
    "IL10",
    "IL12",
    "IL18",
    "extracellular_virus",
]

_running_interactively = hasattr(sys, "ps1")

if not _running_interactively:
    parser = argparse.ArgumentParser()

    parser.add_argument("--prefix", default="", type=str, help="output file prefix")

    parser.add_argument(
        "--measurements",
        nargs="+",
        type=str,
        choices=_MEASUREMENT_CHOICES,
        required=True,
        help="which things to measure (required)",
    )

    parser.add_argument(
        "--matchmaker",
        type=str,
        choices=["yes", "no"],
        required=True,
        help="try to match resampled macrostates with microstate "
        "models to minimize change magnitude",
    )

    parser.add_argument("--graphs", action="store_true", help="make pdf graphs")

    args = parser.parse_args()
else:
    # interactive mode
    args = object()

VERBOSE = False


################################################################################
# constants


def _square_grid_layout(n_panels: int, panel_size: float = 1.8):
    """Return ``(cols, rows, figsize)`` for a mostly-square grid of subplots.

    ``cols = floor(sqrt(n_panels))`` (at least 1) and ``rows = ceil(n_panels / cols)``,
    so the grid never has more columns than rows. ``figsize`` follows matplotlib's
    ``(width, height)`` convention: width scales with the number of columns and
    height with the number of rows, giving square ``panel_size``-inch panels.
    """
    # max(1, ...) avoids a division by zero when there is nothing to plot
    cols = max(1, int(np.floor(np.sqrt(n_panels))))
    rows = int(np.ceil(n_panels / cols))
    return cols, rows, (panel_size * cols, panel_size * rows)


# layout for graphing state variables.
# Attempts to be mostly square, with possibly more rows than columns
(
    state_var_graphs_cols,
    state_var_graphs_rows,
    state_var_graphs_figsize,
) = _square_grid_layout(len(state_vars))

# layout for graphing parameters.
# Attempts to be mostly square, with possibly more rows than columns
(
    variational_params_graphs_cols,
    variational_params_graphs_rows,
    variational_params_graphs_figsize,
) = _square_grid_layout(len(variational_params))

assert all(param in default_params for param in variational_params)

TIME_SPAN = 2016
SAMPLE_INTERVAL = 48  # how often to make measurements
ENSEMBLE_SIZE = (
    (UNIFIED_STATE_SPACE_DIMENSION + 1) * UNIFIED_STATE_SPACE_DIMENSION // 2
)  # max(50, (UNIFIED_STATE_SPACE_DIMENSION + 1))

# In interactive mode `args` has no attributes, so each getattr falls back to the
# default written here; in script mode the parsed command-line value is used.
OBSERVABLES = getattr(args, "measurements", ["extracellular_virus"])
OBSERVABLE_VAR_NAMES = ["total_" + name for name in OBSERVABLES]

RESAMPLE_MODELS = False

# if we are altering the models (as opposed to resampling) try to match the
# models to minimize the changes necessary.
MODEL_MATCHMAKER = getattr(args, "matchmaker", "yes") == "yes"

# have the models' parameters do a random walk over time (should help
# with covariance starvation)
PARAMETER_RANDOM_WALK = True

# Prefix for output files. An empty --prefix (its default) must give an empty
# FILE_PREFIX; otherwise every output file name would start with a stray "-".
_prefix_arg = getattr(args, "prefix", "")
FILE_PREFIX = _prefix_arg + "-" if _prefix_arg else ""

GRAPHS = bool(getattr(args, "graphs", True))

################################################################################
# statistical parameters

# Prior over the sampled parameters (init-only first, then variational), centred
# on each parameter's default value.
init_mean_vec = np.array(
    [default_params[name] for name in init_only_params + variational_params]
)

# Independent components: a diagonal covariance whose entries are
# 0.75 * sqrt(default value).
# NOTE(review): the diagonal of a covariance matrix holds variances, so the
# implied standard deviation is sqrt(0.75 * sqrt(default)). If
# 0.75 * sqrt(default) was meant as the standard deviation, these entries
# should be squared -- confirm intent.
init_cov_matrix = np.diag(0.75 * np.sqrt(init_mean_vec))

In [None]:
################################################################################
# sample a virtual patient

# One draw from the prior. np.abs folds negative draws back to positive values.
init_param_sample = np.abs(
    multivariate_normal(mean=init_mean_vec, cov=init_cov_matrix).rvs()
)

# Start from the defaults and overwrite every sampled parameter. Parameters
# whose default is an int are rounded so they stay integer-valued.
init_params = default_params.copy()
init_params.update(
    {
        name: (round(value) if isinstance(default_params[name], int) else value)
        for name, value in zip(
            init_only_params + variational_params, init_param_sample
        )
    }
)

# create model and run it forward 400 steps before the comparison below
model: AnCockrellModel = AnCockrellModel(**init_params)

for _step in trange(400):
    model.time_step()

In [None]:
macro_state = model_macro_data(model)

# Print the epithelium composition of the sampled patient as one
# "<label> <count>" line per category.
_epithelium_report = {
    "empty": "empty_epithelium_count",
    "healthy": "healthy_epithelium_count",
    "infected": "infected_epithelium_count",
    "dead": "dead_epithelium_count",
    "apoptosed": "apoptosed_epithelium_count",
}
_report_lines = [
    label + " " + str(macro_state[state_var_indices[state_var]])
    for label, state_var in _epithelium_report.items()
]
print("\n".join(_report_lines))

In [None]:
# Target macrostate for the comparison: move (up to) N_EPI_TO_HEAL epithelial
# cells from "infected" to "healthy", leaving every other state variable as is.
N_EPI_TO_HEAL = 300

new_macro_state = macro_state.copy()
_infected_idx = state_var_indices["infected_epithelium_count"]
_healthy_idx = state_var_indices["healthy_epithelium_count"]

# Clamp the transfer so the infected count cannot become negative when the
# sampled patient has fewer than N_EPI_TO_HEAL infected cells.
_n_healed = max(0, min(N_EPI_TO_HEAL, new_macro_state[_infected_idx]))
new_macro_state[_healthy_idx] += _n_healed
new_macro_state[_infected_idx] -= _n_healed

In [None]:
from copy import deepcopy

# number of time steps every model is advanced after the microstate update
N_FUTURE_STEPS = 20

orig_epithelium = np.array(model.epithelium.astype(int))


def _modified_copy(base_model, modify_fn):
    """Return a deep copy of `base_model` whose microstate was altered by `modify_fn`.

    `modify_fn` is called on the copy with `new_macro_state` as the desired
    state; `base_model` itself is not touched.
    """
    modified = deepcopy(base_model)
    modify_fn(
        modified,
        desired_state=new_macro_state,
        verbose=VERBOSE,
        state_var_indices=state_var_indices,
        state_vars=state_vars,
        variational_params=variational_params,
    )
    return modified


simple_mod_model = _modified_copy(model, modify_model_simple)
simple_mod_epithelium = np.array(simple_mod_model.epithelium.astype(int))

spatial_mod_model = _modified_copy(model, modify_model_spatial)
spatial_mod_epithelium = np.array(spatial_mod_model.epithelium.astype(int))

# Advance all three models. The order (original, simple, spatial) is kept fixed
# so any shared random-number stream is consumed in the same sequence.
for _desc, _model_to_advance in (
    ("original", model),
    ("simple update", simple_mod_model),
    ("spatial update", spatial_mod_model),
):
    for _ in trange(N_FUTURE_STEPS, desc=_desc):
        _model_to_advance.time_step()

In [None]:
from util import cmap

# NOTE: text.usetex requires a working LaTeX installation to render the figure.
plt.rcParams["text.usetex"] = True


def _show_epithelium(ax, epithelium, title):
    """Draw an integer-coded epithelium grid (values 0-4) on `ax` and title it."""
    ax.imshow(
        epithelium,
        vmin=0,
        vmax=4,
        interpolation="nearest",
        cmap=cmap,
    )
    ax.set_title(title)


fig = plt.figure(figsize=(6.5, 4), constrained_layout=True)
gs_root = fig.add_gridspec(nrows=1, ncols=3)

# Three columns (starting point / macrostate updated / future), each a stack of
# three equally sized panels.
ax_orig, ax_upd, ax_fut = (
    [fig.add_subplot(sub_gs[row, 0]) for row in range(3)]
    for sub_gs in (gs_root[col].subgridspec(3, 1) for col in range(3))
)

# starting point
_show_epithelium(ax_orig[2], orig_epithelium, "Original")

# macrostate updated
_show_epithelium(ax_upd[0], simple_mod_epithelium, "Simple update")
_show_epithelium(ax_upd[1], spatial_mod_epithelium, "Spatial update")

# future: titles carry $t+20$ so this column is distinguishable from the one above
_show_epithelium(
    ax_fut[0], simple_mod_model.epithelium.astype(int), r"Simple update $t+20$"
)
_show_epithelium(
    ax_fut[1], spatial_mod_model.epithelium.astype(int), r"Spatial update $t+20$"
)
_show_epithelium(ax_fut[2], model.epithelium.astype(int), r"Original $t+20$")

for ax in (*ax_orig, *ax_upd, *ax_fut):
    ax.set_axis_off()

In [None]:
# Save the gridspec-layout figure under its own name, so it is not overwritten
# by the 3x3 figure that is written to "microstate-mod-comparison.pdf".
fig.savefig("microstate-mod-comparison-gridspec.pdf", bbox_inches="tight")

In [None]:
from util import cmap

plt.rcParams["text.usetex"] = True

# 3x3 grid. Rows: simple update / spatial update / original model.
# Columns: before the update / right after the update / 20 steps later.
# Only the listed cells get an image; the rest stay blank.
_panels = {
    (2, 0): (orig_epithelium, "Orig Prediction"),
    (0, 1): (simple_mod_epithelium, "Simple update"),
    (1, 1): (spatial_mod_epithelium, "Spatial update"),
    (0, 2): (simple_mod_model.epithelium.astype(int), r"Simple upd. $t+20$"),
    (1, 2): (spatial_mod_model.epithelium.astype(int), r"Spatial upd. $t+20$"),
    (2, 2): (model.epithelium.astype(int), r"Orig $t+20$"),
}

fig, axs = plt.subplots(3, 3, figsize=(6, 4))

for ax in axs.flat:
    ax.set_axis_off()

for (row, col), (epithelium, title) in _panels.items():
    panel_ax = axs[row, col]
    panel_ax.imshow(
        epithelium,
        vmin=0,
        vmax=4,
        interpolation="nearest",
        cmap=cmap,
    )
    panel_ax.set_title(title)

fig.tight_layout()

In [None]:
# Write the 3x3 comparison figure, cropped to its content with no extra padding.
fig.savefig(
    "microstate-mod-comparison.pdf",
    pad_inches=0.0,
    bbox_inches="tight",
)