In [1]:
import logging
import os
from pprint import pprint

import dask.dataframe as dd
import numpy as np
import pandas as pd
import scipy.stats
from cloudpathlib import AnyPath as Path

In [2]:
import warnings

# with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [3]:
# Replace any existing root-logger handlers with a single StreamHandler whose records
# show timestamp, process id/thread name, logger name and level, then the message on
# its own line.
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter(
        "%(asctime)s %(process)d/%(threadName)s %(name)s %(levelname)s\n%(message)s"
    )
)
logging.getLogger().handlers = [handler]

In [4]:
# Notebook-level logger, set to DEBUG so the dtype dumps and per-group debug messages appear.
logger = logging.getLogger(name=__name__)
logger.setLevel(logging.DEBUG)

In [5]:
# Per-library log levels: dask and gcsfs at INFO; BigQuery, pandas and pyarrow at DEBUG.
_library_log_levels = {
    "dask": "INFO",
    "gcsfs": "INFO",
    "google.cloud.bigquery": "DEBUG",
    "pandas": "DEBUG",
    "pyarrow": "DEBUG",
}
for _library, _level in _library_log_levels.items():
    logging.getLogger(_library).setLevel(_level)

In [6]:
# List the experiment directories available under the pseudobulk optimization prefix.
!gsutil ls gs://liulab/data/pseudobulk_optimization

gs://liulab/data/pseudobulk_optimization/1_no_qc_subset/
gs://liulab/data/pseudobulk_optimization/2_no_qc/
gs://liulab/data/pseudobulk_optimization/3_with_tcga_qc/


In [11]:
uri_pseudobulks = "gs://liulab/data/pseudobulk_optimization/3_with_tcga_qc/mixtures"
# To inspect the parquet files: !gsutil ls -lhR {uri_pseudobulks} | grep data.parquet | head
# malignant_from_one_sample is presumably stored as the strings "True"/"False"; map them to
# real booleans before the cast, because astype("bool") turns any non-empty string into True.
pseudobulks = (
    dd.read_parquet(uri_pseudobulks, engine="pyarrow")
    .replace({"malignant_from_one_sample": {"True": True, "False": False}})
    .astype({"n_cells": "uint8", "malignant_from_one_sample": "bool"})
)
logger.debug(pseudobulks.dtypes)

2022-07-04 14:49:30,718 31236/MainThread __main__ DEBUG
gene_symbol                           category
tcga_aliquot_barcode_for_fractions    category
tpm                                    float64
n_cells                                  uint8
malignant_from_one_sample                 bool
dtype: object


In [12]:
# Real TCGA SKCM bulk TPMs, read lazily from a single parquet path.
uri_real_bulks = (
    "gs://liulab/data/pseudobulk_optimization/3_with_tcga_qc/"
    "mixtures_real_tcga_skcm/tpm.parquet"
)
real_bulks = dd.read_parquet(uri_real_bulks, engine="pyarrow")
logger.debug(real_bulks.dtypes)

2022-07-04 14:49:35,179 31236/MainThread __main__ DEBUG
gene_symbol        category
aliquot_barcode    category
tpm                 float64
dtype: object


In [13]:
# Inner-join pseudobulk rows to real-bulk rows on gene_symbol plus
# tcga_aliquot_barcode_for_fractions == aliquot_barcode; the overlapping `tpm`
# columns come out as tpm_pseudo (pseudobulk) and tpm_real (real bulk).
merged = pseudobulks.merge(
    real_bulks,
    how="inner",
    left_on=["gene_symbol", "tcga_aliquot_barcode_for_fractions"],
    right_on=["gene_symbol", "aliquot_barcode"],
    suffixes=("_pseudo", "_real"),
)

In [14]:
# Lazily group the merged rows by the two pseudobulk generation parameters.
merged_groupby = merged.groupby(by=["n_cells", "malignant_from_one_sample"])

# Correlations with real bulk samples

In [None]:
def compute_comparison_metrics(df: pd.DataFrame) -> pd.Series:
    """Summarize how closely pseudobulk TPMs track real-bulk TPMs for one group.

    Parameters
    ----------
    df : pd.DataFrame
        Rows of the merged pseudobulk/real-bulk frame for one
        (n_cells, malignant_from_one_sample) group. Must contain the
        ``gene_symbol``, ``tpm_pseudo`` and ``tpm_real`` columns.

    Returns
    -------
    pd.Series
        ``corr_linear``: Pearson correlation of tpm_pseudo vs tpm_real over all rows.
        ``corr_rank``: Spearman correlation of tpm_pseudo vs tpm_real over all rows.
        ``ks_test_stat``: two-sample KS statistic comparing the per-gene mean
        TPM distributions of pseudobulk vs real bulk.
    """
    logger.debug(f"computing metrics for {len(df)} length DataFrame")
    pseudo = df["tpm_pseudo"]
    real = df["tpm_real"]

    # The KS test compares distributions of per-gene means, not raw rows.
    per_gene_means = df.groupby("gene_symbol")[["tpm_pseudo", "tpm_real"]].mean()
    ks_result = scipy.stats.ks_2samp(
        per_gene_means["tpm_pseudo"], per_gene_means["tpm_real"]
    )

    pearson = np.corrcoef(pseudo, real)[0, 1]
    spearman = scipy.stats.spearmanr(pseudo, real)[0]
    return pd.Series(
        {
            "corr_linear": pearson,
            "corr_rank": spearman,
            "ks_test_stat": ks_result[0],
        }
    )


# Dask needs the output schema of the per-group apply up front.
comparison_meta = [
    ("corr_linear", "float64"),
    ("corr_rank", "float64"),
    ("ks_test_stat", "float64"),
]
# One row of metrics per (n_cells, malignant_from_one_sample) group; groups with
# any NaN metric are dropped, and the result is ordered by group key.
results = (
    merged_groupby.apply(compute_comparison_metrics, meta=comparison_meta)
    .dropna()
    .compute()
    .sort_index()
)

  return getattr(__obj, self.method)(*args, **kwargs)


In [None]:
# Similarity metrics per (n_cells, malignant_from_one_sample) group.
results

In [None]:
import plotly.express as px
import plotly.graph_objects as go

In [None]:
# Plot Spearman and Pearson similarity against n_cells: one trace per metric and
# malignant-cell sampling mode ("one sample" traces are drawn with markers).
fig = go.Figure()
modes = {True: "lines+markers", False: "lines"}
# Select metrics by name rather than by column position, so a change in the
# column order of `results` cannot silently swap which metric gets which color.
metric_colors = {"corr_rank": "blue", "corr_linear": "red"}
for column, color in metric_colors.items():
    for malignant_from_one_sample in results.index.unique(
        level="malignant_from_one_sample"
    ):
        malignant_label = "one sample" if malignant_from_one_sample else "all samples"
        trace_name = f"{column}, malignants ~ {malignant_label}"
        # xs selects on the index level directly instead of building a query string.
        subset = (
            results.xs(malignant_from_one_sample, level="malignant_from_one_sample")
            .sort_index()
            .reset_index()
        )
        logger.debug(f"adding trace: {trace_name}")
        fig.add_trace(
            go.Scatter(
                x=subset["n_cells"],
                y=subset[column],
                mode=modes[malignant_from_one_sample],
                name=trace_name,
                line_color=color,
            )
        )
fig = fig.update_layout(
    title="Similarity measures of real vs pseudobulk, by generation parameter"
)
fig = fig.update_yaxes(range=[0, 1], title="similarity measure")
fig = fig.update_xaxes(title="# cells per cell type in pseudobulk samples")

fig.show(width=1000, renderer="png")

# Inter-sample correlation

What is the inter-sample correlation of the real TCGA SKCM bulk samples?

In [None]:
# Materialize the real TCGA SKCM bulk TPMs as an in-memory pandas DataFrame.
df_real_bulks = real_bulks.compute()

In [None]:
# Pairwise Pearson correlation between real TCGA SKCM samples, using genes as observations.
real_sample_corr = df_real_bulks.pivot(
    index="gene_symbol",
    columns="aliquot_barcode",
    values="tpm",
).corr()
# Strict upper triangle (k=1): each distinct sample pair once, with the diagonal
# (each sample against itself) excluded. Trimming the tail of the flattened matrix
# would drop the last row instead, leaving self-correlations in the summary.
real_upper_rows, real_upper_cols = np.triu_indices(len(real_sample_corr), k=1)
real_pairwise_corrs = real_sample_corr.to_numpy()[real_upper_rows, real_upper_cols]
pd.Series(
    {
        "intersample_corr_median": np.median(real_pairwise_corrs),
        "intersample_corr_mean": np.mean(real_pairwise_corrs),
        "intersample_corr_stddev": np.std(real_pairwise_corrs),
    }
)

How does the inter-sample correlation of the pseudobulks decrease as `n_cells` increases?

In [None]:
def compute_intersample_metrics(
    df: pd.DataFrame,
    sample_column: str = "tcga_aliquot_barcode_for_fractions",
    value_column: str = "tpm_pseudo",
) -> pd.Series:
    """Summarize pairwise Pearson correlations between the samples in one group.

    The long-format frame is pivoted to a genes x samples matrix (``gene_symbol``
    rows, one column per ``sample_column`` value). Each pair of distinct samples
    is then correlated, using genes as observations.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format rows containing ``gene_symbol``, ``sample_column`` and
        ``value_column``. Each (gene, sample) pair must appear at most once,
        otherwise ``DataFrame.pivot`` raises.
    sample_column : str
        Column identifying the sample. Defaults to the pseudobulk column.
    value_column : str
        Expression column to correlate. Defaults to ``tpm_pseudo``.

    Returns
    -------
    pd.Series
        ``intersample_corr_median``, ``intersample_corr_mean`` and
        ``intersample_corr_stddev`` of the off-diagonal correlations. All are NaN
        when fewer than two samples are present.
    """
    sample_corr = df.pivot(
        index="gene_symbol",
        columns=sample_column,
        values=value_column,
    ).corr()
    # Strict upper triangle (k=1): every distinct sample pair exactly once, with the
    # diagonal (each sample against itself) excluded. Trimming the tail of the
    # flattened matrix would drop the last row instead and keep the diagonal.
    upper_rows, upper_cols = np.triu_indices(len(sample_corr), k=1)
    pairwise_corrs = sample_corr.to_numpy()[upper_rows, upper_cols]
    metric_names = [
        "intersample_corr_median",
        "intersample_corr_mean",
        "intersample_corr_stddev",
    ]
    if pairwise_corrs.size == 0:
        # Fewer than two samples: there are no pairs, so report NaN explicitly
        # rather than reducing an empty array.
        return pd.Series(np.nan, index=metric_names, dtype="float64")
    return pd.Series(
        [np.median(pairwise_corrs), np.mean(pairwise_corrs), np.std(pairwise_corrs)],
        index=metric_names,
    )


# Output schema of the per-group apply, required by dask up front.
intersample_meta = [
    ("intersample_corr_median", "float64"),
    ("intersample_corr_mean", "float64"),
    ("intersample_corr_stddev", "float64"),
]
# One row of inter-sample correlation summaries per (n_cells, malignant_from_one_sample)
# group; groups with any NaN summary are dropped, and rows are ordered by group key.
results_intersample = (
    merged_groupby.apply(compute_intersample_metrics, meta=intersample_meta)
    .dropna()
    .compute()
    .sort_index()
)

In [None]:
# Inter-sample correlation summaries per (n_cells, malignant_from_one_sample) group.
results_intersample

# Ended here (work in progress)

In [None]:
def _label_malignant_flag(key):
    # key is an (n_cells, malignant_from_one_sample) index tuple; only the flag is relabelled.
    n_cells, flag = key
    return (n_cells, f"malignant_from_one_sample={flag}")


# Wide view: one row per n_cells, columns = (metric, "malignant_from_one_sample=<flag>").
results_intersample.set_index(
    results_intersample.index.map(_label_malignant_flag)
).unstack(level=-1)

In [None]:
# Preview flattened column labels for the wide (n_cells x metric/flag) table.
# Named explicitly rather than `_`, which IPython reserves for the last output.
intersample_by_flag = results_intersample.unstack(level="malignant_from_one_sample")
# After unstacking, each column label is a (metric, flag) tuple. Label both parts;
# formatting the whole tuple would embed its repr in the label.
intersample_by_flag.columns.map(
    lambda column: f"{column[0]}, malignant_from_one_sample={column[1]}"
)

In [None]:
# Wide table (one row per n_cells) of inter-sample correlation summaries, plotted as lines.
intersample_wide = results_intersample.unstack(level="malignant_from_one_sample")
# Flatten the (metric, flag) column MultiIndex into unique labels. Keeping only the
# flag level would produce duplicate True/False columns (one pair per metric).
intersample_wide.columns = [
    f"{metric}, malignant_from_one_sample={flag}"
    for metric, flag in intersample_wide.columns
]
px.line(
    intersample_wide,
    title="Inter-sample correlation of pseudobulks, by generation parameter",
    labels={"value": "inter-sample correlation statistic", "variable": "metric"},
)

# Appendix 1: step-by-step inter-sample correlation for one group

In [None]:
# Column dtypes of the lazy merged frame (tpm_pseudo / tpm_real come from the merge suffixes).
merged.dtypes

In [None]:
# Appendix: redo the inter-sample correlation step by step for a single group
# (n_cells == 5 with malignant_from_one_sample set).
is_target_group = (merged["n_cells"] == 5) & merged["malignant_from_one_sample"]
x = merged[is_target_group][
    ["gene_symbol", "tcga_aliquot_barcode_for_fractions", "tpm_pseudo"]
]

In [None]:
# Materialize the filtered subset as a pandas DataFrame.
y = x.compute()

In [None]:
# Inspect the dtypes and memory usage of the materialized subset.
y.info()

In [None]:
# Genes x samples matrix of pseudobulk TPMs, then the samples x samples Pearson correlation.
y_wide = y.pivot(
    index="gene_symbol",
    columns="tcga_aliquot_barcode_for_fractions",
    values="tpm_pseudo",
)
z = y_wide.corr()

In [None]:
# Square (samples x samples) correlation matrix.
z.shape

In [None]:
# Flatten the correlation matrix into a 1-D copy; flatten() copies, so the in-place
# sort below does not reorder z.
zz = z.values.flatten()

In [None]:
# len(z) ** 2 entries, diagonal included.
zz.shape

In [None]:
# Sort in place, ascending; NumPy places any NaNs last.
zz.sort()

In [None]:
# Drop the len(z) largest values, meant to be the diagonal self-correlations.
# NOTE(review): this only removes the diagonal because of the sort above, and only if
# there are no NaNs (NaNs sort last and would be dropped instead).
# np.triu_indices(len(z), k=1) would select the off-diagonal pairs directly.
zzz = zz[: -len(z)]

In [None]:
# len(z) ** 2 - len(z) entries remain.
zzz.shape

In [None]:
# Median inter-sample correlation for this group.
np.median(zzz)