# Setup

In [None]:
#l = list(df_genes.columns.sort_values())
#[i for i in l if i.startswith("GD")]


# Reference antigens used to sanity-check the scoring further down.
clinical = ["L1CAM"]  # GD2 not in dataset
novel_targets = ["GPC2", "CD276", "ALK", "NCAM1"]

# NOTE(review): `other_targets` is not defined in the visible part of the
# notebook; it is assumed to be a list of gene names defined elsewhere.
# The "cat" column must have exactly one label per name. The original built it
# from `clinical` and `novel_targets` only, which raises a length-mismatch
# ValueError as soon as `other_targets` is non-empty.
promising = pd.DataFrame(
    {
        "name": clinical + novel_targets + other_targets,
        "cat": (
            ["clinical"] * len(clinical)
            + ["novel"] * len(novel_targets)
            + ["other"] * len(other_targets)
        ),
    }
).set_index("name")
promising

In [None]:
def df_from(selection):
    """Return the classification columns next to the selected expression columns.

    Reads the notebook-level frames `df` and `df_classif`.

    Parameters
    ----------
    selection : column label(s) used to index `df`.

    Returns
    -------
    pandas.DataFrame
        `df_classif` and `df[selection]` concatenated column-wise, plus a
        "tpm_sum" column holding the row-wise sum over all numeric columns of
        that concatenated frame.
    """
    picked = df[selection]
    combined = pd.concat([df_classif, picked], axis=1)
    row_totals = combined.sum(axis=1, numeric_only=True)
    combined["tpm_sum"] = row_totals
    return combined

def score_by_group_ind(selection, key="single"):
    """Score every gene in `selection` on its own and collect the results.

    Parameters
    ----------
    selection : iterable of gene names; each one is passed to `score_by_group`.
    key : forwarded unchanged to `score_by_group` (default "single").

    Returns
    -------
    pandas.DataFrame
        One column per gene, named after the gene, built by concatenating the
        per-gene results column-wise.

    Raises
    ------
    ValueError
        From `pd.concat` when `selection` is empty.
    """
    # The original also built an unused frame indexed by the global
    # `h_tissues`; it never reached the result, so it is dropped along with the
    # hidden dependency on that global.
    per_gene = [score_by_group(gene, key).rename(gene) for gene in selection]
    return pd.concat(per_gene, axis=1)

In [None]:
# Quick check of the scoring and plotting helpers on the first five gene columns.
sel = list(df_genes.columns)[0:5]
# Only the last expression of a cell is rendered, so the two scoring results
# below are computed and then discarded.
scoring.score_by_group(sel).head()
scoring.score(sel)
plotting.boxplot(sel)

# Single gene analysis

In [None]:
%%time
# Score every gene column individually, once per key ('average' and 'single').
scores_avg = score_by_group_ind(df_genes.columns, 'average')
scores_single = score_by_group_ind(df_genes.columns, 'single')
# Only the final expression is displayed; the first .head() has no visible effect.
scores_avg.head()
scores_single.head()

In [None]:
# Attach per-gene summary statistics to the reference antigens: the column-wise
# mean and min of each score table, looked up by the gene names in promising.index.
promising["avg_mean"] = scores_avg.mean()[promising.index]
promising["avg_min"] = scores_avg.min()[promising.index]
promising["single_mean"] = scores_single.mean()[promising.index]
promising["single_min"] = scores_single.min()[promising.index]
promising
# iloc[:, 1:] skips the first column.
# NOTE(review): presumably that is the non-numeric "cat" label column — confirm.
sns.heatmap(promising.iloc[:, 1:], cmap="vlag")

In [None]:
boxplot(["L1CAM"])

In [None]:
# The ten columns of scores_avg with the highest column mean.
avg_best = scores_avg.mean().sort_values(ascending=False).head(10)

In [None]:
import numpy as np


def select_genes(scores, quantile=0.25, percentage=0.8):
    """Keep the columns of `scores` that exceed a global cut-off in enough rows.

    The cut-off is the `quantile`-quantile over all values of `scores` taken
    together. A column is kept when its number of values strictly above the
    cut-off is at least ``floor(percentage * len(tissues))``, where `tissues`
    is a notebook-level global.

    Parameters
    ----------
    scores : pandas.DataFrame of scores (one column per gene in the callers
        visible in this notebook).
    quantile : quantile of all values used as the cut-off.
    percentage : fraction of ``len(tissues)`` a column must reach.

    Returns
    -------
    pandas.Index
        Labels of the selected columns.
    """
    cutoff = np.quantile(scores.values, quantile)
    hits_per_column = (scores > cutoff).sum()
    required_hits = np.floor(percentage * len(tissues))
    enough = hits_per_column >= required_hits
    return hits_per_column[enough].index

In [None]:
# Strict selection: cut-off at the 0.90 quantile, required in 99% of len(tissues).
sel = select_genes(scores_avg, quantile=0.90, percentage=0.99)
# Reference antigens that did not survive the selection (not displayed: only the
# last expression of the cell is rendered).
set(promising.index) - set(sel)
# NOTE(review): `best_sel` is only assigned further down in the notebook
# ("Individual solutions"), so this line fails on a fresh top-to-bottom run.
set(best_sel) - set(sel)
len(sel)
# Fraction of columns of scores_avg that were filtered out.
1 - (len(sel) / len(scores_avg.columns))

In [None]:
import itertools
from tqdm import tqdm

# Exhaustive search over all 5-gene combinations of the pre-selected genes.
# A combination is kept, together with its score, when the score exceeds 3.5.
# NOTE(review): tqdm gets a bare iterator without a length, so it can only show
# a running count, not a percentage or ETA.
sols = []
for combination in tqdm(itertools.combinations(sel, 5)):
    s = score_single(list(combination))
    if s > 3.5:
        # Stored as (gene_1, ..., gene_5, score).
        sols.append((*combination, s))

In [None]:
sorted(sols, key=lambda x: (x[-1]), reverse=True)

In [None]:
score_by_group_ind(sel).mean().sort_values(ascending=False).head(50)

# Collect results

# Eval

## Load results

In [None]:
# NOTE(review): `dfs` is not created in the visible part of the notebook.
# This rebinds `sols`, which the exhaustive-search cell above uses for a list of tuples.
sols = clean_and_eval([dfs[0]], 5)
# Ten rows of the first result frame with the highest "single" value.
dfs[0].sort_values("single", ascending=False).head(10)

In [None]:
# Box plot for the "ga" row of dfs[2] with the highest "single" value:
# after sorting, .iloc[0, 0:20] takes the first 20 columns of the top row
# (the preceding .head(10) does not change which row that is).
boxplot(
    dfs[2][dfs[2].approach == "ga"]
    .sort_values("single", ascending=False)
    .head(10)
    .iloc[0, 0:20]
    .values
)

In [None]:
sns.set_style("ticks")

# Presentation figures: for each (frame, sort column) pair, plot the two
# best-ranked rows. The first five columns of a row are passed to `boxplot`,
# along with the output path. This replaces six copy-pasted calls; the call
# order and the file names (avg1, avg2, single1, single2, mix1, mix2) are
# unchanged, and each frame is sorted once instead of twice.
# NOTE(review): `sols1` is not defined in the visible part of the notebook.
for source, metric, stem in (
    (dfs[0], "average", "avg"),
    (dfs[0], "single", "single"),
    (sols1, "average", "mix"),
):
    ranked = source.sort_values(metric, ascending=False).head(5)
    for rank in (0, 1):
        boxplot(ranked.iloc[rank, 0:5].values, f"eval/pres/{stem}{rank + 1}.svg")

## Melt solution dfs

In [None]:
# Integer labels 0..19.
# NOTE(review): not referenced by the functions below, which detect antigen
# columns through pd.melt instead — confirm whether this is still needed.
ag_cols = list(range(20))
# Live view of the keys of `linkage_map`; `asdf` uses these as identifier
# columns when melting.
linkages = linkage_map.keys()


def asdf(df, approach, n, diff, sort_by):
    """Melt the `n` best rows of one approach into long format.

    Rows of `df` whose "approach" equals `approach` are sorted ascending by
    `sort_by`, and the last `n` are kept and renumbered from 0. Every column
    that is not a linkage column (the notebook-level `linkages`), "approach"
    or the new row number is melted into "variable"/"value" pairs.

    Parameters
    ----------
    df : pandas.DataFrame with an "approach" column, the linkage columns and
        columns whose labels can be cast to int.
    approach : value to match in the "approach" column.
    n : number of rows to keep from the end of the sorted frame.
    diff : optional collection; when truthy, only rows whose "value" is in it
        are returned.
    sort_by : column (or columns) passed to `sort_values`.

    Returns
    -------
    pandas.DataFrame
        Long frame with the linkage columns, "approach", "solution_id",
        "variable", "value" and "sol_size" (largest integer "variable" plus one).
    """
    best = df.loc[df.approach == approach].sort_values(sort_by).tail(n)
    # The first reset discards the old labels, the second exposes 0..k-1 as "index".
    best = best.reset_index(drop=True).reset_index()
    long_form = best.melt(id_vars=[*linkages, "approach", "index"])
    long_form = long_form.rename(columns={"index": "solution_id"})
    long_form = long_form[long_form.variable != "x"]
    largest_label = long_form.variable.astype("int").max()
    long_form["sol_size"] = largest_label + 1
    if not diff:
        return long_form
    return long_form[long_form["value"].isin(diff)]


def _melt(df, n=None, diff=None, sort_by=None):
    n = len(df) if not n else n
    ss1 = asdf(df, "ga", n, diff, sort_by)
    ss2 = asdf(df, "rf", n, diff, sort_by)
    ss2.loc[:, "solution_id"] = ss2.loc[:, "solution_id"] + n

    return pd.concat([ss1, ss2], ignore_index=True)


def melt(dfs, n=None, diff=None, sort_by="single"):
    """Apply `_melt` to every frame in `dfs` and stack the results.

    Parameters
    ----------
    dfs : iterable of result frames (a list, not a single DataFrame).
    n, diff, sort_by : forwarded unchanged to `_melt` for each frame.

    Returns
    -------
    pandas.DataFrame
        Row-wise concatenation of the per-frame results. The index is not
        reset, so labels repeat across the input frames.
    """
    return pd.concat([_melt(frame, n, diff, sort_by) for frame in dfs])


def melt_uniq(dfs, n=None, diff=None, sort_by="single"):
    """Melt `dfs` and shift solution ids by solution size.

    After `melt`, rows with ``sol_size == 10`` get `n` added to their
    "solution_id" and rows with ``sol_size == 20`` get ``2 * n`` added; other
    rows are left alone.

    Parameters
    ----------
    dfs : list of result frames, forwarded to `melt`.
    n : rows kept per approach and also the id offset. The default ``None``
        cannot be added to the ids, so callers need to pass an integer.
    diff, sort_by : forwarded to `melt`.

    Returns
    -------
    pandas.DataFrame
        The melted frame with adjusted "solution_id" values.
    """
    melted = melt(dfs, n, diff, sort_by)

    size_ten = melted.sol_size == 10
    melted.loc[size_ten, "solution_id"] = melted.loc[size_ten, "solution_id"] + n

    size_twenty = melted.sol_size == 20
    shifted = melted.loc[size_twenty, "solution_id"] + 2 * n
    melted.loc[size_twenty, "solution_id"] = shifted

    return melted

In [None]:
# NOTE(review): `sols1` and `sols2` are not created in the visible part of the notebook.
# Both are labelled "ga" in place so that the melt helpers pick their rows up.
sols1["approach"] = "ga"
# Keep only rows with a positive "single" value.
m1 = sols1[sols1.single > 0]
sols2["approach"] = "ga"
m2 = sols2[sols2.single > 0]
# Top 10 rows per frame, ranked by "average".
m = melt_uniq([m1, m2], 10, sort_by="average")
m

In [None]:
# Long-format view of every row in all result frames.
eval_df_full = melt(dfs)
n_sols = 10
# Only the first two result frames are compared here.
ds = dfs[0:2]
# Top n_sols rows per approach and frame, ranked by "single" and by "average".
eval_df_best_s = melt_uniq(ds, n_sols, sort_by="single")
eval_df_best_a = melt_uniq(ds, n_sols, sort_by="average")
# Only the last of these four means is displayed; the first three are discarded.
eval_df_best_s["average"].mean()
eval_df_best_a["average"].mean()
eval_df_best_s["single"].mean()
eval_df_best_a["single"].mean()

In [None]:
# Show the ten rows with the highest "single" value of every result frame.
# A bare expression inside a loop is never rendered by Jupyter, so the original
# loop produced no output; display() makes each table visible.
# The loop variable is no longer called `df`, because that rebinds the
# notebook-level `df` that `df_from` reads.
for solutions in dfs:
    display(solutions.sort_values("single").tail(10))

## Individual solutions

In [None]:
def pivot(data, approach=None):
    """Count how often each value occurs per solution and pivot to a wide table.

    Parameters
    ----------
    data : long-format frame with "sol_size", "solution_id", "value" and
        "approach" columns (as produced by the melt helpers).
    approach : when truthy, only rows of that approach are used and the columns
        are keyed by (sol_size, solution_id); otherwise all rows are used and
        the columns are keyed by (sol_size, approach, solution_id).

    Returns
    -------
    pandas.DataFrame
        One row per value, ordered by total count (descending), missing cells
        filled with 0, limited to the first 15 rows.
    """
    if approach:
        rows = data[data.approach == approach]
        group_keys = ["sol_size", "solution_id", "value"]
        column_keys = ["sol_size", "solution_id"]
    else:
        rows = data
        group_keys = ["approach", "sol_size", "solution_id", "value"]
        column_keys = ["sol_size", "approach", "solution_id"]

    # size() yields an unnamed Series; after reset_index the counts sit in column 0.
    counts = rows.groupby(group_keys).size().reset_index()
    table = counts.pivot(columns=column_keys, index="value", values=0)

    totals = table.sum(axis=1).sort_values(ascending=False)
    return table.reindex(totals.index).fillna(0).head(15)


# sns.heatmap(piv)
# df_plot.to_csv("piv_new.csv")
# NOTE(review): `piv` is not assigned anywhere above this line (only further
# down, in the co-occurrence cell), so this fails on a fresh top-to-bottom run.
best_sel = piv.index
# piv1 = pivot(eval_df_best_s, "ga")
# piv1
# Export the per-approach count tables for the "single"-ranked (s) and
# "average"-ranked (a) selections, plus the `m` frame built earlier.
pivot(eval_df_best_s, "ga").to_csv("piv_s_ga.csv")
pivot(eval_df_best_s, "rf").to_csv("piv_s_rf.csv")
pivot(eval_df_best_a, "ga").to_csv("piv_a_ga.csv")
pivot(eval_df_best_a, "rf").to_csv("piv_a_rf.csv")
pivot(m, "ga").to_csv("piv_a_sga.csv")
# piv1.to_csv("piv_ga.csv")
# piv2.to_csv("piv_rf.csv")
# piv1
# promising.loc[set(promising.index) - set(piv1.index)]
# promising.loc[set(promising.index).intersection(set(piv1.index))]
# piv2
# pivot(eval_df_best_s)
# pivot(eval_df_best_a)

In [None]:
import random

# NOTE(review): `eval_df_best` and `df_plot` are not defined in the visible part
# of the notebook (only `eval_df_best_s` / `eval_df_best_a` are) — confirm.
# Number of distinct values to colour.
s = len(eval_df_best["value"].unique())
# Not used in this cell; the name `a` is read again by later cells.
a = sns.color_palette("Paired")
# col = a + sns.color_palette("hls", s - len(a))
# The "Spectral" palette is overwritten on the next line and has no effect.
col = sns.color_palette("Spectral", s)
col = sns.color_palette("Set1") + sns.color_palette("Set2")

# The combined palette must offer at least one colour per distinct value.
assert len(col) >= s
# Unseeded shuffle: the colour assignment differs between runs.
random.shuffle(col)

ax = df_plot.plot(kind="barh", stacked=True, figsize=(10, 10), color=col, legend=False)
# plt.axis("off")


def annotateBars(row, ax=ax):
    """Write each column label at the centre of its segment of a stacked bar.

    Intended for ``DataFrame.apply(..., axis=1)``: `row` is one row of the
    plotted frame, and ``row.name`` is used as the position of the bar.

    Parameters
    ----------
    row : pandas.Series mapping column label to segment length.
    ax : axes to draw on. The default is whatever the notebook-level `ax`
        referred to when this function was defined.
    """
    left_edge = 0
    for label in row.index:
        length = row[label]
        # Missing segments (NaN) take no space and get no label.
        if str(length) == "nan":
            continue
        _ = ax.text(
            left_edge + (length) / 2,
            row.name,
            str(label),
            ha="center",
            va="center",
        )
        left_edge += length


_ = df_plot.apply(annotateBars, ax=ax, axis=1)

In [None]:
import numpy as np

# fil = eval_df_best[eval_df_best.sol_size < 10]
# fil = dfs[0].iloc[:30]
# Top 10 rows per approach of the first three result frames, in long format.
fil = melt_uniq(dfs[0:3], 10)
# One row per (sol_size, solution_id), one column per value. A cell is
# np.any over the "single" entries of that group, i.e. True when at least one
# of them is truthy.
# NOTE(review): presumably meant as a presence indicator; a solution whose
# "single" is exactly 0 would come out False — confirm.
piv = pd.pivot_table(
    fil,
    values="single",
    index=["sol_size", "solution_id"],
    columns=["value"],
    aggfunc=np.any,
)
piv = piv.fillna(False)
# Order the columns by how many rows contain them. This rebinds `s`.
s = piv.sum()
piv = piv[s.sort_values(ascending=False).index]
# Keep only columns that occur in more than four rows.
piv = piv[piv.columns[piv.sum() > 4]]
piv = piv.replace({True: 1, False: 0})
piv
# Pairwise correlation between the 0/1 columns.
sns.heatmap(piv.corr(), cmap="vlag")
# piv.to_csv("piv.csv")
# sns.heatmap(piv)
piv.corr().to_csv("corr2.csv")

## Solution quality distribution

### Generate random solutions

In [None]:
# Baseline: 1000 random gene sets for each solution size (5, 10, 20).
dfs_rand = []
for n in [5, 10, 20]:
    rows = []
    for i in range(1000):
        # random_state=i makes draw i reproducible; the same seeds are reused
        # for every solution size.
        rows.append(list(df_genes.sample(n, axis=1, random_state=i).columns))

    # Column labels are the strings "0" .. str(n - 1).
    df_rand = pd.DataFrame.from_records(rows, columns=[str(x) for x in range(n)])
    dfs_rand.append(clean_and_eval([df_rand], n))

dfs_rand[0]

In [None]:
# Label the random baselines in place.
# NOTE(review): the loop variable rebinds the notebook-level `df`, which
# `df_from` reads; after this cell `df` is the last random frame.
for df in dfs_rand:
    df["approach"] = "random"

# NOTE(review): `_melt` only collects rows whose approach is "ga" or "rf", so
# rows labelled "random" appear to be dropped here — confirm this is intended.
eval_df_rand = melt(dfs_rand)

In [None]:
# Stack the real and the random long frames. reset_index() without drop=True
# keeps the old labels as an "index" column.
data_eval_full_with_rand = pd.concat([eval_df_full, eval_df_rand]).reset_index()
data_eval_full_with_rand

In [None]:
from collections import Counter

In [None]:
# First 20 rows and first five columns of dfs[0], pivoted so that every
# distinct combination of the "0".."4" columns becomes a column and the cell
# holds the summed row label.
# NOTE(review): exploratory; the purpose of summing the row labels is unclear.
dfs[0].iloc[:20, 0:5].reset_index().pivot_table(
    index="index", columns=["0", "1", "2", "3", "4"], values=["index"], aggfunc=np.sum
)

In [None]:
# Co-occurrence analysis over the first `topk` rows of dfs[0].
topk = 40
# All entries of the first five columns, flattened to one array.
ag_names = dfs[0].iloc[:topk, 0:5].to_numpy().reshape(-1)
# Position of each name in the matrix: the most frequent name gets index 0.
ag_indices = {ag[0]: i for i, ag in enumerate(Counter(ag_names).most_common())}
# Square integer matrix, one row and column per distinct name.
count_M = np.zeros((len(ag_indices), len(ag_indices)), dtype=int)

In [None]:
# For every unordered pair of entries within a row, increment both mirrored
# cells so that count_M stays symmetric.
# Re-running this cell without re-running the initialisation cell adds the
# counts a second time.
for row in dfs[0].iloc[:topk, 0:5].to_numpy():
    for i in range(len(row)):
        for j in range(i + 1, len(row)):
            count_M[ag_indices[row[i]], ag_indices[row[j]]] += 1
            count_M[ag_indices[row[j]], ag_indices[row[i]]] += 1

In [None]:
count_M.shape

In [None]:
# Log of the counts relative to the largest count. Cells with a zero count
# evaluate to log(0) = -inf, and numpy emits a divide-by-zero warning.
plt.imshow(np.log(count_M / count_M.max()))

In [None]:
# Hide the upper triangle including the diagonal; the matrix is symmetric, so
# the lower triangle carries all the information.
mask = np.triu(np.ones_like(count_M, dtype=bool))

plt.figure(figsize=(6, 6), dpi=140)
sns.heatmap(
    count_M,
    # Pairs that never co-occur are hidden as well.
    mask=mask | (count_M == 0),
    vmin=0,
    cmap=sns.color_palette("Blues", as_cmap=True),
    center=0,
    square=True,
    linewidths=0.1,
    cbar_kws={"shrink": 0.5},
)

### Plots

In [None]:
# Keep one row per (approach, sol_size, index) combination. The surviving rows
# keep their original labels, so x.index is no longer contiguous.
x = data_eval_full_with_rand.drop_duplicates(subset=["approach", "sol_size", "index"])
# Group sizes after de-duplication.
x.groupby(["approach", "sol_size"]).size()

In [None]:
import numpy as np

# Draw the same number of rows from every (approach, sol_size) group so that
# the distributions below are comparable.
size = 850
replace = False
# Unseeded np.random.choice: the sample differs between runs. With
# replace=False it raises ValueError for any group with fewer than `size` rows.
fn = lambda obj: obj.loc[np.random.choice(obj.index, size, replace), :]
x2 = x.groupby(["approach", "sol_size"], as_index=False).apply(fn)
x2.groupby(["approach", "sol_size"]).size()

In [None]:
# The 850 rows with the largest "single" value of every (approach, sol_size) group.
# groupby(...).nlargest() returns a Series whose index is
# (approach, sol_size, original row label); `fn` extracts that last level.
fn = lambda x: list(map(lambda y: y[-1], x))
# Those are index *labels* of `x`, not positions: `x` comes from
# drop_duplicates and has gaps in its index, so the rows have to be fetched
# with .loc. The original used .iloc, which selects the wrong rows or runs out
# of bounds.
x4 = x.loc[fn(x.groupby(["approach", "sol_size"])["single"].nlargest(850).index)]
x4

In [None]:
# Human-readable axis and legend labels for the plots below.
name_dict = {
    "single": "Single linkage distance",
    "average": "Average linkage distance",
    "sol_size": "Number of selected antigens",
    "approach": "Approach",
}
# Plot-ready copy of the sampled frame with renamed columns.
x3 = x2.rename(name_dict, axis=1)
# Approach codes that are missing from this mapping become NaN.
x3["Approach"] = x3["Approach"].map(
    {"ga": "Genetic algorithm", "rf": "Random forest", "random": "Random"}
)
x3

In [None]:
# One KDE figure per linkage: a panel per solution size, a curve per approach.
sns.set_style("whitegrid")
for linkage in ["single", "average"]:
    plot = sns.displot(
        data=x3,
        x=name_dict[linkage],
        col=name_dict["sol_size"],
        hue=name_dict["approach"],
        kind="kde",
        height=3,
    )
    # displot returns a FacetGrid; the underlying matplotlib figure is reached
    # through its `.fig` attribute.
    plot.fig.savefig(f"dist_{linkage}.svg", format="svg", bbox_inches="tight")

In [None]:
# Single versus average linkage distance for every sampled row, coloured by approach.
plot = sns.scatterplot(
    data=x3,
    x=name_dict["single"],
    y=name_dict["average"],
    hue=name_dict["approach"],
    s=5,
)
# scatterplot returns a matplotlib Axes, so get_figure() is available here.
# NOTE(review): the faceted plot further down writes to the same file name.
plot.get_figure().savefig("single_avg_corr.svg", format="svg", bbox_inches="tight")

In [None]:
sns.heatmap(x4.corr())

In [None]:
# Same relation as the scatter plot above, with one panel per solution size.
p = sns.relplot(
    data=x3,
    x=name_dict["single"],
    y=name_dict["average"],
    col=name_dict["sol_size"],
    hue=name_dict["approach"],
    height=4,
    aspect=1,
    s=5,
    # facet_kws={"sharey": False, "sharex": False},
)

# relplot returns a FacetGrid, which has no get_figure(); the original call
# raised AttributeError. FacetGrid.savefig forwards to the underlying figure.
# NOTE(review): this overwrites the file written by the un-faceted scatter plot
# above — confirm whether a separate file name is wanted.
p.savefig('single_avg_corr.svg', format="svg", bbox_inches="tight")

In [None]:
# x4 = x3[x3[name_dict["single"]] > 0]
# "single" against "single_rel" for the "ga" rows of size 5 only.
# NOTE(review): no visible cell creates a "single_rel" column — confirm where
# it comes from.
p = sns.relplot(
    data=x4[(x4.approach == "ga") & (x4.sol_size == 5)],
    x="single",
    y="single_rel",
    col="sol_size",
    hue="approach",
    height=4,
    aspect=1,
    s=5,
    # facet_kws={"sharey": False, "sharex": False},
)

# Per-group correlation between "single" and "single_rel".
x4.groupby(["approach", "sol_size"]).corr()["single"].loc[:, :, "single_rel"]

## Unsorted exploration: antigen frequencies and SGA parameter sweep

In [None]:
def hist(df, diff=None):
    """Plot how often each value occurs among the melted solutions of `df`.

    Parameters
    ----------
    df : result frame, melted with `_melt`.
    diff : optional collection of values; when given, only those values are
        counted.
    """
    # `diff` must be passed by keyword: `_melt`'s second positional parameter
    # is `n`, so the original call `_melt(df, diff)` used the filter as the row
    # count and applied no filter at all. The sort column is given explicitly
    # (the default of `melt`) because `_melt` has no usable default; the row
    # order does not affect a histogram.
    subset = _melt(df, diff=diff, sort_by="single")
    sns.set_style("darkgrid")
    plot = sns.histplot(x=subset["value"])
    _ = plt.setp(plot.get_xticklabels(), rotation=90)
    plot.get_figure().set_size_inches(15, 5)

In [None]:
# Long format of the 50 best rows per approach and result frame.
eval_df = melt(dfs, 50)
# Distinct values that occur in those rows.
ags = eval_df["value"].unique()
len(ags)
# NOTE(review): `asdf` renames "index" to "solution_id", so an "index" column
# is not expected in `eval_df` as built above — confirm.
eval_df.groupby(["approach", "sol_size"])["index"].max()
# Number of distinct values per (approach, sol_size).
eval_df.drop_duplicates(["approach", "sol_size", "value"]).groupby(
    ["approach", "sol_size"]
).size()

In [None]:
# Occurrences of every value within each (approach, sol_size) group.
vc = eval_df.groupby(["approach", "sol_size"])["value"].value_counts()
# Ten most frequent values per group, as a flat table.
# NOTE(review): the two renames assume the counts column is called "value"
# after pd.DataFrame(vc); newer pandas versions name it differently — verify
# against the installed version.
pd.DataFrame(vc).rename({"value": "count"}, axis=1).reset_index().rename(
    {"value": "antigen"}, axis=1
).groupby(["approach", "sol_size"]).apply(
    lambda x: x.nlargest(10, "count")
).reset_index(
    drop=True
)

In [None]:
# NOTE(review): `dfs_best100` is not defined in the visible part of the
# notebook. The expression in the loop body is neither stored nor displayed,
# so this cell has no visible effect.
for i in range(3):
    melt(dfs_best100[i])["index"].max()

In [None]:
# Load the saved parameter-sweep table.
# NOTE(review): presumably one row per (p, q) setting — confirm.
a = pd.read_csv("sga_eval.csv")
# "Unnamed: 0" is the index column left over from a to_csv call without index=False.
a = a.drop("Unnamed: 0", axis=1)
a = a.fillna(0)
# Long format: one row per (p, q, metric).
a = a.melt(id_vars=["p", "q"], value_vars=["num sa", "num diff", "dropped"])
sns.lineplot(data=a, x="p", y="value", hue="variable")

In [None]:
# NOTE(review): `x` was last bound to the de-duplicated solution frame, which
# has no "dropped" or "num sa" columns; this cell apparently expects `x` to be
# the sweep table — confirm.
# NOTE(review): 291 and 3169 are unexplained totals; verify and name them.
x["perc dropped"] = x["dropped"] / 291
x["perc selected"] = x["num sa"] / 3169

In [None]:
x

In [None]:
# Same line plot as for the raw counts, now on the two percentage columns.
a = x
a = a.fillna(0)
a = a.melt(id_vars=["p", "q"], value_vars=["perc selected", "perc dropped"])
sns.lineplot(data=a, x="p", y="value", hue="variable")

In [None]:
sns.lineplot(data=a, x="q", y="value", hue="variable")

In [None]:
sns.jointplot(data=a, x="q", y="p", kind="hex", color="#4CB391")

In [None]:
g = sns.scatterplot(
    data=x,
    x="q",
    y="p",
    hue="perc dropped",
)

In [None]:
# Overwrites the file that the loading cell above reads. The index is written
# too, which is why that cell has to drop an "Unnamed: 0" column.
x.to_csv("sga_eval.csv")

In [None]:
hist(dfs[0], diff)

In [None]:
hist(dfs[1], diff)

In [None]:
hist(dfs[2], diff)

In [None]:
eval_df

In [None]:
# Series of occurrence counts indexed by value, most frequent first.
srt = eval_df.groupby("value").size().sort_values(ascending=False)

In [None]:
srt

In [None]:
%%time
sns.set_style('darkgrid')
plot = sns.catplot(
    data=eval_df, x="value", y="single", hue="index", kind="strip", row="sol_size", order=srt, sharey=False, aspect=3, s=6
)
plot.set_xticklabels(rotation=90)
for ax in plot.axes.flat:
    ax.grid(True, axis='both')

In [None]:
df

In [None]:
# Long-format view of the first three result frames with a fresh 0..k-1 index.
# `melt` iterates over its first argument and expects a list of frames. The
# original passed a single DataFrame per call, so iteration yielded column
# labels instead of frames. `melt` already concatenates the per-frame results,
# so one call replaces the outer pd.concat.
eval_df_full = melt([dfs[0], dfs[1], dfs[2]]).reset_index()

In [None]:
dfs[0]

In [None]:
%%time
plot = sns.catplot(
    data=eval_df, x="value", y="single", hue="value", row="sol_size", legend=False, kind="violin"
)
_ = plt.setp(plot.get_xticklabels(), rotation=90)
_ = plt.tight_layout()
plot.get_figure().set_size_inches(24, 24)

In [None]:
# A bare name in the middle of a cell is not rendered; this line has no effect.
a


# NOTE(review): `subset` is only created by the next cell down, so this cell
# depends on out-of-order execution.
subset = subset.drop("variable", axis=1)
# Per-value mean of the remaining columns, and per-value row count.
# This rebinds `a`.
a = subset.groupby("value").mean()
b = subset.groupby(["value"]).size()
# size() yields an unnamed Series, which becomes column 0 after the concat;
# it is renamed to "count".
subset = (
    pd.concat([a, b], axis=1).reset_index(level=["value"]).rename({0: "count"}, axis=1)
)
subset = subset.sort_values("count", ascending=False).reset_index(drop=True)
subset

# plot = sns.barplot(x=subset["value"], y=subset["count"])
# _ = plt.setp(plot.get_xticklabels(), rotation=90)
# plot.get_figure().set_size_inches(14, 4)
# plot
# prettify_axes(plot)

In [None]:
# This rebinds `ag_cols` (previously the integers 0..19) to the five column
# labels of the size-5 result frame.
ag_cols = ["0", "1", "2", "3", "4"]
subset = dfs[0]
# Only rows with a "single" value above 3.5.
subset = subset[subset["single"] > 3.5]
# subset
subset = subset.loc[:, ["0", "1", "2", "3", "4", "approach", "single", "median"]]
# subset

# Long format: one row per (solution, position).
subset = pd.melt(subset, id_vars=["approach", "single", "median"], value_vars=ag_cols)
subset = subset.drop("variable", axis=1)
subset
# Keep only values that occur more than three times.
subset = subset[subset["value"].map(subset["value"].value_counts()) > 3]
# subset.value[subset["value"].value_counts() > 5]
# violinplot returns a matplotlib Axes, so the Axes helpers below are valid.
plot = sns.violinplot(
    data=subset, x="value", y="single", hue="approach", split=True, inner=None
)
_ = plt.setp(plot.get_xticklabels(), rotation=90)
plot.get_figure().set_size_inches(10, 5)
# plot
# prettify_axes(plot)

In [None]:
subset = subset.join(subset.groupby("value").size().rename("count"), on="value")

In [None]:
subset

In [None]:
# One point per row of `subset`; the marker size encodes the "count" column.
plot = sns.scatterplot(data=subset, x="value", y="single", size="count", sizes=(1, 400))
_ = plt.setp(plot.get_xticklabels(), rotation=90)
plot.get_figure().set_size_inches(20, 10)

# Trash

In [None]:
# NOTE(review): this replaces the notebook-level `df_genes` and `df` with
# whatever `a` currently holds. `a` is reused for several unrelated objects
# above, so the result depends on execution order. This is in the "Trash"
# section and is probably not meant to run.
df_genes = a
df = pd.concat([df_classif, df_genes], axis=1)

In [None]:
df

In [None]:
%%time
# Column-wise 2**x - 1, the inverse of the log2(x + 1) mapping applied in the next cell.
a = df_genes.apply(lambda x: 2**x - 1)

In [None]:
import math
a.applymap(lambda x: math.log2(x+1))


In [None]:
df_genes.head()

In [None]:
l = list(df_genes.columns)
l.sort()
l