# Make simple MDS (multidimensional scaling) layout plots of escape maps using only selected sites

The required Python modules are in the `simple_mds` environment defined in [simple_mds_environment.yml](simple_mds_environment.yml). Import them:

In [1]:
import itertools

import altair as alt

import numpy

import pandas as pd

import sklearn.manifold

In [2]:
# Load the filtered-site escape table for each strain, tag every row with the
# strain it came from, and keep one row per distinct site-level escape record.
site_columns = ["strain", "site", "serum", "cohort", "site_escape_sum"]
data = pd.concat(
    [
        pd.read_csv(f"{strain_name}_escape_df_filt_sites.csv").assign(strain=strain_name)
        for strain_name in ["hk19", "perth09"]
    ]
).loc[:, site_columns].drop_duplicates()

# every serum must be associated with exactly one strain
assert data["serum"].nunique() == data.groupby(["serum", "strain"]).ngroups

# serum name -> cohort lookup (if a serum is listed repeatedly, the last row wins)
serum_to_cohort = dict(zip(data["serum"], data["cohort"]))

for strain, strain_df in data.groupby("strain"):
    # Dissimilarity between two sera is 1 minus the cosine similarity of their
    # site-escape vectors: each serum's vector over sites is scaled to unit
    # (L2) norm, and the dissimilarity is 1 minus the dot product of the two.
    sera = strain_df["serum"].unique()

    # Pivot only THIS strain's rows into a site x serum table (pivoting the
    # combined data would mix in the other strain's sera and sites, creating
    # NaNs). Columns are ordered like `sera` so the matrix lines up with labels;
    # reindex (not []) turns a dropped all-NaN serum into a NaN column that the
    # check below reports instead of a bare KeyError.
    site_escape = (
        strain_df
        .pivot_table(index="site", columns="serum", values="site_escape_sum")
        .reindex(columns=sera)
    )
    pivoted_normed_data = site_escape / numpy.linalg.norm(site_escape.to_numpy(), axis=0)

    normed = pivoted_normed_data.to_numpy()
    if numpy.isnan(normed).any():
        # a missing site value or an all-zero escape vector makes a column NaN
        raise ValueError(
            f"{strain}: null normalized escape values; some sera lack values "
            "at some sites or have zero total escape"
        )
    # all pairwise dot products at once (sum over sites of the products)
    similarity = normed.T @ normed
    # average with the transpose to remove any round-off asymmetry before MDS
    similarity = (similarity + similarity.T) / 2
    dissimilarities = pd.DataFrame(1 - similarity, columns=sera, index=sera)

    # plot the dissimilarities as a heatmap; the melted frame has one row per
    # (serum_1, serum_2) pair
    dissimilarities_long = (
        dissimilarities
        .reset_index(names="serum_1")
        .melt(id_vars="serum_1", var_name="serum_2", value_name="dissimilarity")
        .assign(
            cohort_1=lambda x: x["serum_1"].map(serum_to_cohort),
            cohort_2=lambda x: x["serum_2"].map(serum_to_cohort),
            # same-cohort pairs are outlined with that cohort, others as "mixed"
            cohort=lambda x: x["cohort_1"].where(x["cohort_1"] == x["cohort_2"], "mixed"),
        )
    )
    display(
        alt.Chart(dissimilarities_long)
        .encode(
            x=alt.X("serum_1", sort=alt.SortField("cohort_1")),
            y=alt.Y("serum_2", sort=alt.SortField("cohort_2")),
            color="dissimilarity",
            stroke="cohort",
            tooltip=dissimilarities_long.columns.tolist(),
        )
        .mark_rect(strokeWidth=2)
        .properties(
            width=alt.Step(13),
            height=alt.Step(13),
            title=f"{strain} serum dissimilarities",
        )
    )

    # use metric multidimensional scaling on the precomputed dissimilarity
    # matrix to place the sera in 2D (fixed random_state for reproducibility)
    mds = sklearn.manifold.MDS(
        n_components=2,
        metric=True,
        max_iter=3000,
        eps=1e-6,
        random_state=1,
        dissimilarity="precomputed",
        n_jobs=1,
    )
    locs = mds.fit_transform(dissimilarities)
    mds_df = pd.DataFrame(
        {
            "serum": sera,
            "cohort": [serum_to_cohort[s] for s in sera],
            "x": locs[:, 0],
            "y": locs[:, 1],
        }
    )
    # chart width/height are proportional to the span of the MDS coordinates,
    # so the layout is drawn with roughly equal scaling on both axes
    size = 350  # pixels per unit of MDS coordinate span
    mds_chart = (
        alt.Chart(mds_df)
        .encode(
            x=alt.X("x", title=None, axis=None),
            y=alt.Y("y", title=None, axis=None),
            color="cohort",
            tooltip=["serum", "cohort"],
        )
        .mark_circle(filled=True, size=100)
        .properties(
            width=size * (mds_df["x"].max() - mds_df["x"].min()),
            height=size * (mds_df["y"].max() - mds_df["y"].min()),
            title=f"{strain} MDS plot",
        )
        .configure_axis(grid=False)
    )
    display(mds_chart)

