In [None]:
import argparse
import sys

import an_cockrell
import matplotlib.pyplot as plt
import numpy as np
from an_cockrell import AnCockrellModel, epitype_one_hot_encoding
from scipy.stats import multivariate_normal
from tqdm import trange

from consts import (
    UNIFIED_STATE_SPACE_DIMENSION,
    default_params,
    init_only_params,
    state_var_indices,
    state_vars,
    variational_params,
)
from modify_epi_spatial import modify_model
from util import model_macro_data

In [None]:
################################################################################
# command-line configuration

# Observable quantities that may be requested with --measurements.
MEASUREMENT_CHOICES = [
    "T1IFN",
    "TNF",
    "IFNg",
    "IL6",
    "IL1",
    "IL8",
    "IL10",
    "IL12",
    "IL18",
    "extracellular_virus",
]

if not hasattr(sys, "ps1"):
    # Script mode: read the configuration from the command line.
    parser = argparse.ArgumentParser()
    parser.add_argument("--prefix", type=str, default="", help="output file prefix")
    parser.add_argument(
        "--measurements",
        type=str,
        choices=MEASUREMENT_CHOICES,
        nargs="+",
        required=True,
        help="which things to measure (required)",
    )
    parser.add_argument(
        "--matchmaker",
        type=str,
        choices=["yes", "no"],
        required=True,
        help="try to match resampled macrostates with microstate "
        "models to minimize change magnitude",
    )
    parser.add_argument("--graphs", help="make pdf graphs", action="store_true")
    args = parser.parse_args()
else:
    # Interactive mode (sys.ps1 exists): there is no command line to parse.
    # A bare object has none of the option attributes, so the later
    # hasattr()/getattr() checks fall back to their interactive defaults.
    args = object()

VERBOSE = False


################################################################################
# constants


# layout for graphing state variables.
# Attempts to be mostly square, with possibly more rows than columns.
# matplotlib's figsize is (width, height): width scales with the number of
# columns and height with the number of rows, giving ~1.8in per panel.
state_var_graphs_cols: int = int(np.floor(np.sqrt(len(state_vars))))
state_var_graphs_rows: int = int(np.ceil(len(state_vars) / state_var_graphs_cols))
state_var_graphs_figsize = (1.8 * state_var_graphs_cols, 1.8 * state_var_graphs_rows)

# layout for graphing parameters.
# Attempts to be mostly square, with possibly more rows than columns.
# figsize is (width, height) -> (columns, rows), as above.
variational_params_graphs_cols: int = int(np.floor(np.sqrt(len(variational_params))))
variational_params_graphs_rows: int = int(
    np.ceil(len(variational_params) / variational_params_graphs_cols)
)
variational_params_graphs_figsize = (
    1.8 * variational_params_graphs_cols,
    1.8 * variational_params_graphs_rows,
)

# every variational parameter must have a default value
assert all(param in default_params for param in variational_params)

TIME_SPAN = 2016
SAMPLE_INTERVAL = 48  # how often to make measurements
# d * (d + 1) / 2 ensemble members for d = UNIFIED_STATE_SPACE_DIMENSION
# (earlier alternative: max(50, UNIFIED_STATE_SPACE_DIMENSION + 1)).
ENSEMBLE_SIZE = (
    (UNIFIED_STATE_SPACE_DIMENSION + 1) * UNIFIED_STATE_SPACE_DIMENSION // 2
)
# Interactive default: observe only the extracellular virus.
OBSERVABLES = (
    ["extracellular_virus"] if not hasattr(args, "measurements") else args.measurements
)
OBSERVABLE_VAR_NAMES = ["total_" + name for name in OBSERVABLES]

RESAMPLE_MODELS = False

# if we are altering the models (as opposed to resampling) try to match the
# models to minimize the changes necessary.
MODEL_MATCHMAKER = (
    True if not hasattr(args, "matchmaker") else (args.matchmaker == "yes")
)

# have the models' parameters do a random walk over time (should help
# with covariance starvation)
PARAMETER_RANDOM_WALK = True

# Prefix for output file names: "<prefix>-" when a non-empty --prefix is given,
# otherwise nothing. (The argparse default is "", which must not turn into a
# stray leading "-".)
FILE_PREFIX = args.prefix + "-" if getattr(args, "prefix", "") else ""

GRAPHS = True if not hasattr(args, "graphs") else bool(args.graphs)

################################################################################
# statistical parameters

# Parameters drawn for each virtual patient, in the order used by both the
# mean vector and the covariance matrix below.
sampled_param_names = init_only_params + variational_params

# Prior mean: the default value of each sampled parameter.
init_mean_vec = np.array([default_params[name] for name in sampled_param_names])

# Diagonal prior covariance. Note that each diagonal entry (a variance, not a
# standard deviation) is 0.75 * sqrt(default value) of that parameter.
init_cov_matrix = np.diag(
    np.array([0.75 * np.sqrt(default_params[name]) for name in sampled_param_names])
)

In [None]:
################################################################################
# sample a virtual patient

# Seed for the virtual-patient parameter draw. Without an explicit random_state
# the draw differs on every run, even though the model RNG below is seeded.
# Kept distinct from the model seed (0) so the two random streams differ.
PATIENT_SAMPLE_SEED = 1
# Number of model time steps simulated before inspecting the microstate.
WARMUP_STEPS = 400

# sampled virtual patient parameters: start from the defaults, then overwrite
# every init-only and variational parameter with a sampled value.
init_params = default_params.copy()
patient_rng = np.random.default_rng(seed=PATIENT_SAMPLE_SEED)
# np.abs reflects negative draws so every sampled parameter is non-negative.
init_param_sample = np.abs(
    multivariate_normal(mean=init_mean_vec, cov=init_cov_matrix).rvs(
        random_state=patient_rng
    )
)
for sample_component, param_name in zip(
    init_param_sample,
    (init_only_params + variational_params),
):
    # Parameters whose default is an int stay integers after sampling.
    init_params[param_name] = (
        round(sample_component)
        if isinstance(default_params[param_name], int)
        else sample_component
    )

# create model
rng = np.random.default_rng(seed=0)
model: AnCockrellModel = an_cockrell.AnCockrellModel(rng=rng, **init_params)

for _ in trange(WARMUP_STEPS):
    model.time_step()

In [None]:
# Optional visual check of the agents:
#     model.plot_agents(plt.figure().gca())

macro_state = model_macro_data(model)

# Print one "<category> <count>" line per epithelial category.
epi_count_labels = ["empty", "healthy", "infected", "dead", "apoptosed"]
epi_count_values = macro_state[
    [
        state_var_indices["empty_epithelium_count"],
        state_var_indices["healthy_epithelium_count"],
        state_var_indices["infected_epithelium_count"],
        state_var_indices["dead_epithelium_count"],
        state_var_indices["apoptosed_epithelium_count"],
    ]
]
for epi_label, epi_count in zip(epi_count_labels, epi_count_values):
    print(epi_label, epi_count)

In [None]:
# Hypothetical intervention: move cells from the infected category to the
# healthy category. The transfer is capped at the current infected count so the
# modified macrostate never contains a negative cell count.
HEAL_TRANSFER = 300
healthy_idx = state_var_indices["healthy_epithelium_count"]
infected_idx = state_var_indices["infected_epithelium_count"]

new_macro_state = macro_state.copy()
heal_amount = min(HEAL_TRANSFER, new_macro_state[infected_idx])
new_macro_state[healthy_idx] += heal_amount
new_macro_state[infected_idx] -= heal_amount

In [None]:
from modify_epi_spatial import dither, quantizer
from util import cmap

# Request a microstate matching the target counts of the five epithelial
# categories (ordered empty, healthy, infected, dead, apoptosed). With
# rescaled_state_vecs=True, dither() also returns a second array, plotted later.
updated_epithelium, rescaled_state_vecs_copy = dither(
    model,
    new_macro_state[
        [
            state_var_indices["empty_epithelium_count"],
            state_var_indices["healthy_epithelium_count"],
            state_var_indices["infected_epithelium_count"],
            state_var_indices["dead_epithelium_count"],
            state_var_indices["apoptosed_epithelium_count"],
        ]
    ],
    rescaled_state_vecs=True,
)

# Integer-coded copies of both grids, computed once and reused below.
orig_epi_codes = model.epithelium.astype(int)
updated_epi_codes = updated_epithelium.astype(int)

fig, axs = plt.subplots(1, 3)
axs[0].set_title("Original")
axs[0].imshow(
    orig_epi_codes,
    vmin=0,
    vmax=4,
    interpolation="nearest",
    cmap=cmap,
)
axs[1].set_title("Updated")
axs[1].imshow(
    updated_epi_codes,
    vmin=0,
    vmax=4,
    interpolation="nearest",
    cmap=cmap,
)
# Cell-wise difference of the integer codes; nearest interpolation keeps each
# grid cell crisp, matching the two categorical panels.
axs[2].set_title("Original - Updated")
axs[2].imshow(
    orig_epi_codes - updated_epi_codes,
    vmin=-4,
    vmax=4,
    interpolation="nearest",
)

print(
    "number of changed epis",
    np.sum(orig_epi_codes != updated_epi_codes),
)

In [None]:
import matplotlib.gridspec as gridspec

# Render figure text through LaTeX.
# NOTE: this changes the global rcParams, so it stays in effect for every figure
# drawn afterwards in this kernel, and it requires a LaTeX installation.
plt.rcParams["text.usetex"] = True

epi_names = ["Empty", "Healthy", "Infect.", "Nec.", "Apop."]

# Four columns with relative widths 1.5 : 2 : 1 : 5.
# Column 0 holds the difference map, column 1 the original/updated categorical
# maps, column 2 is left empty as a spacer, and column 3 holds a 3x5 grid of
# per-category maps.
fig = plt.figure(figsize=(6.5, 4), constrained_layout=True)
gs_root = gridspec.GridSpec(nrows=1, ncols=4, figure=fig, width_ratios=[1.5, 2, 1, 5])

# Column 0: cell-wise difference of integer codes (original - updated).
gs_delta = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec=gs_root[0])
ax_delta = fig.add_subplot(gs_delta[0, 0])
ax_delta.set_title(r"$\Delta$")
ax_delta.imshow(
    model.epithelium.astype(int) - updated_epithelium.astype(int), vmin=-4, vmax=4
)
ax_delta.set_axis_off()

# Column 1: categorical maps, original on top and updated below.
# (wspace has nothing to space in this single-column sub-grid.)
gs_catgy = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs_root[1], wspace=100)
ax_catgy = [fig.add_subplot(gs_catgy[row, 0]) for row in range(2)]

catgy_panels = (
    ("Original\nEpithelial State", model.epithelium),
    ("Updated\nEpithelial State", updated_epithelium),
)
for catgy_ax, (catgy_title, catgy_epithelium) in zip(ax_catgy, catgy_panels):
    catgy_ax.set_title(catgy_title)
    catgy_ax.imshow(
        catgy_epithelium.astype(int),
        vmin=0,
        vmax=4,
        interpolation="nearest",
        cmap=cmap,
    )
    catgy_ax.set_axis_off()

# Column 3: one row per representation, one column per epithelial category.
gs_cmpnt = gridspec.GridSpecFromSubplotSpec(3, 5, subplot_spec=gs_root[3])
ax_cmpnt = np.array(
    [[fig.add_subplot(gs_cmpnt[row, col]) for col in range(5)] for row in range(3)]
)

orig_epithelium = epitype_one_hot_encoding(model.epithelium)
new_epithelium = epitype_one_hot_encoding(updated_epithelium)

# Rows, top to bottom: the one-hot encoding of the original epithelium, the
# rescaled state vectors returned by dither(), and the one-hot encoding of the
# updated epithelium. Each is sliced as [:, :, category].
cmpnt_rows = (orig_epithelium, rescaled_state_vecs_copy, new_epithelium)
for row, cmpnt_stack in enumerate(cmpnt_rows):
    for col, epi_name in enumerate(epi_names):
        cmpnt_ax = ax_cmpnt[row, col]
        cmpnt_ax.set_title(epi_name)
        cmpnt_ax.imshow(cmpnt_stack[:, :, col], vmin=0.0, vmax=2.0)
        cmpnt_ax.set_axis_off()

In [None]:
# Write the composite microstate-modification figure to disk as a PDF.
composite_pdf_path = "microstate-mod-viz.pdf"
fig.savefig(composite_pdf_path, bbox_inches="tight")

In [None]:
# Boolean map of the grid cells whose epithelial type was changed by dither().
# Drawn on its own figure/axes: bare plt.imshow would paint onto whatever axes
# are current (in a non-inline run, the last panel of the composite figure),
# and a separate name keeps `fig` pointing at the composite figure.
changed_cells = model.epithelium.astype(int) != updated_epithelium.astype(int)
fig_changed, ax_changed = plt.subplots()
ax_changed.set_title("Changed epithelial cells")
ax_changed.imshow(changed_cells, interpolation="nearest")
ax_changed.set_axis_off()