# MOA prediction

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from scip_workflows.common import *


In [None]:
import pyarrow
from sklearn.decomposition import PCA, FactorAnalysis
from sklearn.metrics import (
    ConfusionMatrixDisplay,
    accuracy_score,
    balanced_accuracy_score,
    confusion_matrix,
)
from sklearn.model_selection import LeaveOneGroupOut, cross_validate
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler, scale
from tqdm.notebook import tqdm
from umap import UMAP


# Data

In [None]:
# Resolve the input/output locations. When a ``snakemake`` object exists in the
# namespace (presumably injected by Snakemake when it runs the notebook) its
# inputs/outputs are used; otherwise fall back to a local copy of the data.
try:
    smk_input, smk_output = snakemake.input, snakemake.output
except NameError:
    # Alternative data location kept for reference:
    # data_root = Path("/data/gent/vo/000/gvo00070/vsc42015/datasets/BBBC021")
    data_root = Path("/home/maximl/scratch/data/vsc/datasets/BBBC021/")
    data_dir = data_root / "results" / "images_subset_v4"
    path, moa_path, image_path = (
        data_dir / "features.parquet",
        data_root / "BBBC021_v1_moa.csv",
        data_root / "BBBC021_v1_image.csv",
    )
    confusion_matrix_path = data_dir / "figures" / "confusion_matrix.png"
else:
    path = smk_input.features
    moa_path = smk_input.moa
    image_path = smk_input.image
    confusion_matrix_path = smk_output.confusion_matrix


In [None]:
moa = pandas.read_csv(moa_path)
image = pandas.read_csv(image_path)


In [None]:
moa_image = moa.merge(
    image,
    left_on=["compound", "concentration"],
    right_on=["Image_Metadata_Compound", "Image_Metadata_Concentration"],
).drop(columns=["Image_Metadata_Compound", "Image_Metadata_Concentration"])


In [None]:
moa_image["batch"] = (
    moa_image["Image_Metadata_Plate_DAPI"]
    .apply(lambda p: int(p.split("_")[0][len("Week") :]))
    .astype("category")
)


In [None]:
seaborn.scatterplot(data=moa_image, x="batch", y="moa")


To apply the not-same-compound-or-batch approach of [Ando et al.](https://www.biorxiv.org/content/10.1101/161422v1.full.pdf), the Cholesterol-lowering and Kinase inhibitors MOAs should be removed from the dataset, as they are only present in one batch.

In [None]:
moa_image = moa_image[
    ~moa_image["moa"].isin(["Cholesterol-lowering", "Kinase inhibitors"])
]


In [None]:
seaborn.scatterplot(data=moa_image, x="batch", y="compound")


In [None]:
treatments = moa_image[~moa_image["compound"].isin(["DMSO"])]


According to Ando et al., 92 treatments should remain.

In [None]:
(treatments["compound"] + treatments["concentration"].astype(str)).unique().shape


In [None]:
%%time
df = pq.read_table(path).to_pandas()

In [None]:
df.columns[df.isna().all()]


In [None]:
df.shape


In [None]:
moa_image.columns = ["meta_" + c for c in moa_image.columns]


In [None]:
df = df.merge(moa_image, left_on="meta_filename", right_on="meta_Image_FileName_DAPI")


# Removing interplate variation

In [None]:
qq_dmso = (
    df[df["meta_moa"] == "DMSO"]
    .groupby("meta_Image_Metadata_Plate_DAPI")[df.filter(regex="feat").columns]
    .quantile((0.01, 0.99))
)


In [None]:
dfs = []
for idx, gdf in df.groupby("meta_Image_Metadata_Plate_DAPI"):
    print(idx)
    df_scaled = (gdf.filter(regex="feat") - qq_dmso.loc[idx, 0.01]) / (
        qq_dmso.loc[idx, 0.99] - qq_dmso.loc[idx, 0.01]
    )
    df_scaled = pandas.concat([df_scaled, gdf.filter(regex="meta")], axis=1)

    dfs.append(df_scaled)


In [None]:
df = pandas.concat(dfs)
del dfs


In [None]:
allnan = df.columns[df.isna().all()]
allnan


In [None]:
df = df.drop(columns=allnan)


In [None]:
nancols = df.columns[df.isna().any()]
nancols


In [None]:
df = df[~df.isna().any(axis=1)]
df.shape


# Feature QC

After linear scaling, features should lie roughly in the [0, 1] range.

In [None]:
df.filter(regex="feat").min().mean(), df.filter(regex="feat").max().mean()


In [None]:
df.filter(regex="feat").min().idxmin()


In [None]:
df.filter(regex="feat").max().idxmax()


In [None]:
df["feat_moments_central-0-0_DAPI"].min(), df[
    "feat_moments_central-0-1_DAPI"
].min(), df["feat_moments_central-1-1_DAPI"].min()


In [None]:
df["feat_moments_central-2-2_DAPI"].min(), df[
    "feat_moments_central-2-3_DAPI"
].min(), df["feat_moments_central-3-3_DAPI"].min()


In [None]:
df["feat_moments_hu-0_DAPI"].min(), df["feat_moments_hu-1_DAPI"].min(), df[
    "feat_moments_hu-2_DAPI"
].min(), df["feat_moments_hu-3_DAPI"].min()


In [None]:
# Drop every moments feature whose name contains one of the digits 1-6 after
# "feat_moments". The previous character class "[1, 2, 3, 4, 5, 6]" also matched
# spaces and commas, which was not intended.
df = df.drop(columns=df.filter(regex=r"feat_moments.*[1-6]").columns)


In [None]:
df.filter(regex="feat").min().min(), df.filter(regex="feat").max().max()


In [None]:
df.filter(regex="feat").min().idxmin(), df.filter(regex="feat").max().idxmax()


In [None]:
df.filter(regex="glcm_").max().sort_values()


In [None]:
df.filter(regex="glcm_").min().sort_values()


In [None]:
df = df.drop(columns=df.filter(regex="feat_glcm_std.*").columns)


In [None]:
df.filter(regex="feat").min().min(), df.filter(regex="feat").max().max()


In [None]:
df.filter(regex="feat").min().idxmin(), df.filter(regex="feat").max().idxmax()


# Exploration

In [None]:
mu = (
    df[df["meta_moa"] != "DMSO"]
    .groupby(["meta_compound", "meta_concentration", "meta_Replicate"])
    .agg(
        {c: "mean" for c in df.filter(regex="feat").columns}
        | {c: lambda x: numpy.unique(x)[0] for c in df.filter(regex="meta")}
    )
).reset_index(drop=True)

mu = mu.groupby(["meta_compound", "meta_concentration"]).agg(
    {c: "median" for c in mu.filter(regex="feat").columns}
    | {c: lambda x: numpy.unique(x)[0] for c in mu.filter(regex="meta")}
)


In [None]:
mu.shape


In [None]:
treatment_profiles = mu.reset_index(drop=True)


In [None]:
dimred = PCA().fit_transform(treatment_profiles.filter(regex="feat"))


In [None]:
seaborn.scatterplot(
    x=dimred[:, 0], y=dimred[:, 1], hue=treatment_profiles["meta_compound"]
)
plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)


In [None]:
seaborn.scatterplot(x=dimred[:, 0], y=dimred[:, 1], hue=treatment_profiles["meta_moa"])
plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)


In [None]:
dimred = UMAP(metric="cosine", min_dist=1, n_neighbors=4, random_state=0).fit_transform(
    treatment_profiles.filter(regex="feat")
)


In [None]:
seaborn.scatterplot(
    x=dimred[:, 0],
    y=dimred[:, 1],
    hue=treatment_profiles["meta_compound"],
    edgecolors="none",
    alpha=0.7,
)
plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)


In [None]:
seaborn.scatterplot(
    x=dimred[:, 0],
    y=dimred[:, 1],
    hue=treatment_profiles["meta_moa"],
    edgecolors="none",
    alpha=0.7,
)
plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)


# Classification

In [None]:
df["meta_row"] = df["meta_Image_Metadata_Well_DAPI"].map(lambda a: a[0])


In [None]:
df_dmso = (
    df[df["meta_compound"] == "DMSO"]
    .groupby(["meta_Image_Metadata_Plate_DAPI", "meta_row"])
    .agg(
        {c: "mean" for c in df.filter(regex="feat").columns}
        | {c: lambda x: numpy.unique(x)[0] for c in df.filter(regex="meta")}
    )
)


In [None]:
# n_comps = 486
n_comps = 50


In [None]:
%%time
fa = FactorAnalysis(random_state=0, n_components=50)
fa.fit(df[df["meta_compound"] == "DMSO"].filter(regex="feat").sample(n=50000))

In [None]:
def _first_unique(values):
    """Return the first element of ``numpy.unique(values)``.

    Used as the aggregation of ``meta`` columns so that they survive ``groupby().agg()``.
    """
    return numpy.unique(values)[0]


def _treatment_profiles(cells):
    """Collapse per-cell rows into one profile per (compound, concentration).

    ``feat`` columns are averaged within every (compound, concentration, replicate)
    group, then the median over the replicates is taken. Every ``meta`` column keeps
    its first unique value at both steps.
    """
    feature_columns = list(cells.filter(regex="feat").columns)
    keep_meta = dict.fromkeys(cells.filter(regex="meta").columns, _first_unique)

    per_replicate = (
        cells.groupby(["meta_compound", "meta_concentration", "meta_Replicate"])
        .agg(dict.fromkeys(feature_columns, "mean") | keep_meta)
        .reset_index(drop=True)
    )
    return (
        per_replicate.groupby(["meta_compound", "meta_concentration"])
        .agg(dict.fromkeys(feature_columns, "median") | keep_meta)
        .reset_index(drop=True)
    )


# ``fa`` is fitted once and is not refitted per fold, so all cells are projected a
# single time here instead of once per held-out compound.
# NOTE(review): a per-fold estimator (e.g. StandardScaler + PCA fitted on the training
# cells only) would have to be fitted and applied inside the loop again.
reduced = fa.transform(df.filter(regex="feat"))
df_reduced = pandas.concat(
    [
        pandas.DataFrame(
            reduced,
            # Column count is taken from the projection itself.
            columns=["feat-%d" % i for i in range(reduced.shape[1])],
            index=df.index,
        ),
        df.filter(regex="meta"),
    ],
    axis=1,
)
del reduced

true = []
preds = []

treated = df_reduced[~df_reduced["meta_compound"].isin(["DMSO"])]
for idx, df_test in tqdm(treated.groupby("meta_compound")):
    # Hold out every compound that occurs in a batch of the test compound
    # (this includes the test compound itself).
    same_batch = df_reduced["meta_batch"].isin(df_test["meta_batch"].unique())
    held_out = set(df_reduced.loc[same_batch, "meta_compound"].unique())
    # taxol is never held out; ``discard`` (unlike ``list.remove``) does not raise
    # when taxol is absent from the selected batches.
    # NOTE(review): when the test compound is taxol itself it stays in the training
    # set, so its profiles are classified against themselves - confirm this is intended.
    held_out.discard("taxol")

    df_train = df_reduced[~df_reduced["meta_compound"].isin(held_out)]

    treatment_profiles_train = _treatment_profiles(df_train)
    treatment_profiles_test = _treatment_profiles(df_test)

    ### DMSO mock-treatments
    #     cols = treatment_profiles_train.filter(regex="feat").columns
    #     treatment_profiles_train.loc[:, cols] = df_dmso.loc[[tuple(a) for a in treatment_profiles_train[["meta_Image_Metadata_Plate_DAPI", "meta_row"]].values.tolist()], cols].reset_index(drop=True)
    #     treatment_profiles_test.loc[:, cols] = df_dmso.loc[[tuple(a) for a in treatment_profiles_test[["meta_Image_Metadata_Plate_DAPI", "meta_row"]].values.tolist()], cols].reset_index(drop=True)

    # 1-nearest-neighbour MOA assignment with cosine distance.
    e2 = KNeighborsClassifier(n_neighbors=1, metric="cosine")
    e2.fit(
        X=treatment_profiles_train.filter(regex="feat"),
        y=treatment_profiles_train["meta_moa"],
    )

    true.extend(treatment_profiles_test["meta_moa"])
    preds.extend(e2.predict(treatment_profiles_test.filter(regex="feat")))


In [None]:
accuracy_score(true, preds)


In [None]:
fig, ax = plt.subplots()
cm = confusion_matrix(true, preds)
ConfusionMatrixDisplay(cm, display_labels=sorted(treatments["moa"].unique())).plot(
    ax=ax, colorbar=False, cmap="Reds"
)

for child in ax.get_children():
    if isinstance(child, matplotlib.text.Text) and child._text == "0":
        child.set_visible(False)

ax.set_xticks([])
ax.xaxis.labelpad = 20

acc = numpy.diag(cm) * 100 / cm.sum(axis=1)
for i in range(len(cm)):
    ax.text(x=11, y=i, s="%.0f%%" % acc[i], va="center", ha="right")
ax.text(x=11, y=-1, s="Acc.", va="center", ha="right")
ax.text(
    x=11,
    y=10,
    s="Overall Acc.:%.0f%%" % (accuracy_score(true, preds) * 100),
    va="center",
    ha="right",
)

# plt.savefig(confusion_matrix_path, dpi=300, bbox_inches="tight")


In [None]:
sum(numpy.asarray(true) == numpy.asarray(preds))
