In [1]:
%load_ext autoreload
%autoreload 2

import os
import subprocess
import sys
import warnings

warnings.simplefilter("ignore", FutureWarning)

%aimport cnv_inference_config
project_config = cnv_inference_config
os.chdir(project_config.MB_ROOT)

from collections import Counter, defaultdict, OrderedDict
import itertools
from itertools import product as cartesian
import multiprocessing as mp
import pickle

import numba
from joblib import Parallel, delayed
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats as sps
from tqdm import tqdm, tqdm_notebook

import matplotlib.pyplot as plt
import seaborn as sns

import toolkit
import util
from workspace.workspace_manager import WorkspaceManager

sns.set()

workspace = {}
for data_type in ["scDNA", "scRNA"]:
    workspace.update({ 
        data_type : WorkspaceManager(
            task_name="ase_to_cnv",
            experiment_info={"data" : data_type},
            verbose=True
        )
    })
    workspace[data_type].load_workspace()

In [2]:
data = {}
for data_type in ["scDNA", "scRNA"]:
    data[data_type] = {
        data_name :
        util.pickle_load(data_dump)
        for data_name, data_dump in tqdm_notebook(
            workspace[data_type].tmp_data.items(),
            f"{data_type}, loading datasets into RAM"
        )
    }
    
for modality in ["scDNA", "scRNA"]:
    data[modality]["block_counts"].rename(
        columns={"GENE_ID" : "BLOCK_ID"}, 
        inplace=True
    )

HBox(children=(IntProgress(value=0, description='scDNA, loading datasets into RAM', max=2, style=ProgressStyle…




HBox(children=(IntProgress(value=0, description='scRNA, loading datasets into RAM', max=2, style=ProgressStyle…




In [3]:
common_blocks = (set(data["scDNA"]["block_counts"].BLOCK_ID) 
                & set(data["scRNA"]["block_counts"].BLOCK_ID))

print("Datasets have {} blocks in common".format(len(common_blocks)))

for modality in workspace.keys():
    data[modality]["block_counts"] = util.filter_by_isin(
        data[modality]["block_counts"], 
        "BLOCK_ID", 
        common_blocks
    ).reset_index(drop=True)

common_block_ids = data["scDNA"]["block_counts"]["BLOCK_ID"].values.astype(int)
    
assert (data["scDNA"]["block_counts"].shape[0] 
        == data["scRNA"]["block_counts"].shape[0])

Datasets have 8776 blocks in common


In [None]:
M_prime = toolkit.extract_barcodes(data["scDNA"]["block_counts"]).size
M = data["scRNA"]["clustering"]["LABEL"].unique().size 
# M = toolkit.extract_barcodes(data["scRNA"]["block_counts"]).size
K = data["scDNA"]["clustering"]["LABEL"].unique().size
N_G = data["scDNA"]["block_counts"].shape[0]
T_max = 5
tau = np.concatenate([[(t - k, k) for k in range(t + 1)] 
                      for t in range(1, T_max +1)])
conf_to_num = {tuple(cnv_config) : i for i, cnv_config in enumerate(tau)}
num_to_conf = {val : key for key, val in conf_to_num.items()}

In [4]:
D_G_prime = toolkit.extract_counts(data["scDNA"]["block_counts"]).values
D_G = toolkit.extract_counts(
    toolkit.aggregate_by_barcode_groups(
        data["scRNA"]["block_counts"],
        data["scRNA"]["clustering"]
    )
).values

A_G_prime = toolkit.extract_counts(
    data["scDNA"]["block_counts"], 
    suffix="ad"
).values
A_G = toolkit.extract_counts(
    toolkit.aggregate_by_barcode_groups(
        data["scRNA"]["block_counts"], 
        data["scRNA"]["clustering"]
    ),
    suffix="ad"
).values

In [None]:
f = data["scDNA"]["clustering"]["LABEL"].value_counts().sort_index().values / M_prime
assert np.isclose(f.sum(), 1)

In [5]:
rodata_dir = "/icgc/dkfzlsdf/analysis/B260/users/v390v/cnv_inference/data/raw/first_sample"
CNV_prime = []
for i in tqdm_notebook(range(9), "reading CNV information"):
    snp_cnv = pd.read_csv(
        f"{rodata_dir}/TabHaplotypeblock_with_phasedSNPs_{i + 1}.bed", 
        usecols=[0, 1, 2, 3],
        names=["CHROM", "START", "END", "CNV"],
        sep='\t'
    )
    snp_cnv.drop_duplicates(inplace=True)
    print(snp_cnv.shape, np.median(snp_cnv["CNV"]))
    CNV_prime.append(snp_cnv["CNV"].values.astype(int))
CNV_prime = np.column_stack(CNV_prime)

HBox(children=(IntProgress(value=0, description='reading CNV information', max=9, style=ProgressStyle(descript…

(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 4.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0



In [None]:
sns.heatmap(CNV_prime, cmap="BuGn_r")

In [None]:
data["scDNA"]["clonal_block_counts"] = toolkit.aggregate_by_barcode_groups(
    data["scDNA"]["block_counts"],
    data["scDNA"]["clustering"]
)

D_C_prime = toolkit.extract_counts(
    data["scDNA"]["clonal_block_counts"]
).values
A_C_prime = toolkit.extract_counts(
    data["scDNA"]["clonal_block_counts"], "ad"
).values

In [None]:
def beta_mode(a, b):
    return (a - 1) / (a + b - 2)

In [None]:
T = np.full_like(D_C_prime, np.nan)
for block_id, clone_id in tqdm_notebook(cartesian(range(N_G), range(K))):
    t = CNV_prime[block_id, clone_id]
    ad = A_C_prime[block_id, clone_id]
    dp = D_C_prime[block_id, clone_id]
    if dp == 0 or dp is np.nan:
        continue
    ase_ratio = ad / dp
    offset = np.argmin(np.abs(np.arange(t + 1) / t - ase_ratio))
    T[block_id, clone_id] = int(conf_to_num[(t, 0)] + offset)

In [25]:
# ASE ratios are stored as (\alpha, \beta) parameter tuples
# of the underlying Beta distributions
Alpha_G = np.zeros(shape=(N_G, tau.shape[0]))
Beta_G = np.zeros(shape=(N_G, tau.shape[0]))
Theta_G = np.zeros(shape=(N_G, tau.shape[0])) 

eps = 1
for block_id, cnv_config in tqdm_notebook(itertools.product(range(N_G), tau)):
    k0, k1 = cnv_config
    t = k0 + k1
    
    if k0 == 0:
        alpha, beta = 1, 1 + eps
    elif k1 == 0:
        alpha, beta = 1 + eps, 1
    else:
        if k1 > k0:
            alpha = 1 + eps
            beta = k1 / k0 * alpha + (k0 - k1) / k0
        else:
            beta = 1 + eps
            alpha = k0 / k1 * beta + (k1 - k0) / k1
        assert np.isclose(beta_mode(alpha, beta), k0 / t)
    
    assert alpha >= 1 and beta >= 1
    config_id = conf_to_num[tuple(cnv_config)]
    Alpha_G[block_id, config_id] = alpha
    Beta_G[block_id, config_id] = beta
    Theta_G[block_id, config_id] = k0 / t

HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

In [30]:
def cell_logit(cell_id, A, D, X):
    ok_mask = ~(np.isnan(D) | (D == 0))
    probas = sps.binom(
        n=D[ok_mask], 
        p=X[ok_mask]
    ).pmf(A[ok_mask])
    return np.sum(np.log(probas[probas > 0]))


def cell_likelihood(cell_id, A, D, X):
    return np.exp(cell_logit(cell_id, A, D, X)).prod()


def cell_loglikelihood(cell_id, A, D, X):
    return cell_logit(cell_id, A, D, X).sum()

In [31]:
def update_I_G(cell_id, Theta_G):
    logprobas = np.array([
        cell_loglikelihood(
            cell_id, 
            A_G[:, cell_id], 
            D_G[:, cell_id], 
            Theta_G[:, clone_id]
        ) 
        + np.log(f[clone_id])
        for clone_id in range(K)
    ])
#     print(f"{cell_id}:\t {logprobas}")
    return sps.rv_discrete(
       a=0, b=K, 
       values=[np.arange(K), np.abs(logprobas) / np.nansum(np.abs(logprobas))]
    ).rvs()

In [9]:
I_G_prime = data["scDNA"]["clustering"]["LABEL"].astype(int).values - 1 # to zero-indexing
I_G = np.array(Parallel(16)(
    delayed(update_I_G)(cell_id, Theta_G)
    for cell_id in tqdm_notebook(range(M), "updating I_G")
))

HBox(children=(IntProgress(value=0, description='updating I_G', max=10, style=ProgressStyle(description_width=…

NameError: name 'update_I_G' is not defined

In [29]:
H_G_prime = np.full_like(D_G_prime, np.nan)
X_G_prime = np.full_like(D_G_prime, np.nan)
for cell_id in tqdm_notebook(range(M_prime), "updating H_G and X_G"):
    H_G_prime[:, cell_id] = T[:, I_G_prime[cell_id]]
    not_na_mask = ~np.isnan(H_G_prime[:, cell_id])
    X_G_prime[not_na_mask, cell_id] = Theta_G[
        not_na_mask, 
        H_G_prime[not_na_mask, cell_id].astype(int)
    ]

H_G = np.full_like(D_G, np.nan)
X_G = np.full_like(D_G, np.nan)
for cell_id in tqdm_notebook(range(M), "updating H_G and X_G"):
    H_G[:, cell_id] = T[:, I_G[cell_id]]
    not_na_mask = ~np.isnan(H_G[:, cell_id])
    X_G[not_na_mask, cell_id] = Theta_G[
        not_na_mask, 
        H_G[not_na_mask, cell_id].astype(int)
    ]

HBox(children=(IntProgress(value=0, description='updating H_G and X_G', max=268, style=ProgressStyle(descripti…

HBox(children=(IntProgress(value=0, description='updating H_G and X_G', max=10, style=ProgressStyle(descriptio…

In [128]:
def classification_report(labels, title, outfile=None, show=True):
                
        sns.set(style="whitegrid", font_scale=1.5);
        fig, ax = plt.subplots(2, 1, figsize=(20,30))
        ax[0].set_title("Cluster label assigned by XClone", fontsize=20)
        sns.countplot(
            labels, 
            palette=#sns.color_palette("muted", 
                      #                n_colors=np.unique(labels).size),
            {
#                 "1" : "#3182bd", #"C0",
#                 "2" : "#2ca25f", #"C2",
#                 "3" : "#feb24c"#"C1"
                1 : "xkcd:orange",
                2 : "xkcd:azure",
                3 : "xkcd:cyan",
                4 : "xkcd:yellow",
                5 : "xkcd:blue",
                6 : "xkcd:red",
                7 : "xkcd:pink",
                8 : "xkcd:grey",
                9 : "xkcd:black"
            },
            ax=ax[0]
        )

    
        ax[1].set_title(title)
        
        sns.scatterplot(
            x="TSNE_1", y="TSNE_2", 
            hue=labels, 
            data=data["scRNA"]["clustering"], 
            legend="full",
            palette=#sns.color_palette("muted", 
                     #                 n_colors=np.unique(labels).size),
            {
#                 1 : "#3182bd", #"C0",
#                 2 : "#2ca25f", #"C2",
#                 3 : "#feb24c"#"C1"
                1 : "xkcd:orange",
                2 : "xkcd:azure",
                3 : "xkcd:cyan",
                4 : "xkcd:yellow",
                5 : "xkcd:blue",
                6 : "xkcd:red",
                7 : "xkcd:pink",
                8 : "xkcd:grey",
                9 : "xkcd:black"
            },
            ax=ax[1]
        );
        ax[1].legend().get_frame().set_facecolor("white");
        ax[1].legend(frameon=False, bbox_to_anchor=(1,0.5), loc="center left")
        fig.subplots_adjust(right=0.75)
        if outfile is not None:
            fig.savefig(outfile, format=outfile.split('.')[-1], dpi=300)
        if show == True:
            plt.show()
        plt.close()
        plt.clf()

In [164]:
def fact(n):
    return np.prod(np.arange(n) + 1)

pd.DataFrame([[np.log(fact(n)) for n in range(1, 21)], [logfact(n) for n in range(1, 21)]]).T

Unnamed: 0,0,1
0,0.0,-0.000143
1,0.693147,0.693113
2,1.791759,1.791746
3,3.178054,3.178047
4,4.787492,4.787488
5,6.579251,6.579249
6,8.525161,8.52516
7,10.604603,10.604602
8,12.801827,12.801827
9,15.104413,15.104412


In [188]:
from time import time

def cell_logit(cell_id, A, D, X):
    logits = sps.binom(
        n=D, 
        p=X
    ).logpmf(A)
    return logits[np.isfinite(logits)]


def cell_likelihood(cell_id, A, D, X):
    return np.exp(cell_logit(cell_id, A, D, X)).prod()


def cell_loglikelihood(cell_id, A, D, X):
    return cell_logit(cell_id, A, D, X).sum()

@numba.jit(nopython=True)
def logfact(n):
    # https://math.stackexchange.com/questions/138194/approximating-log-of-factorial
    return n * np.log(n) - n + np.log(n  * ( 1 + 4 * n * (1 + 2 * n) )) / 6 + np.log(np.pi) / 2

@numba.jit(nopython=True)
def total_loglikelihood(A_G_prime, R_G_prime, X_G_prime, bincoeff_prime, 
                        A_G, R_G, X_G, bincoeff):
    
    alt_prime = A_G_prime * np.log(X_G_prime)
    ref_prime = R_G_prime * np.log(1 - X_G_prime)
    
    alt = A_G * np.log(X_G)
    ref = R_G * np.log(1 - X_G)
    
    loglik_prime = np.ravel(alt_prime + ref_prime + bincoeff_prime)
    loglik = np.ravel(alt + ref + bincoeff)

    return (
        np.sum(loglik_prime[np.isfinite(loglik_prime)])
        + np.sum(loglik[np.isfinite(loglik)])
    )
#     M_prime, M = A_G_prime.shape[1], A_G.shape[1]
#     logits_prime = sps.binom(
#         n=D_G_prime, 
#         p=X_G_prime
#     ).logpmf(A_G_prime)
#     logits = sps.binom(
#         n=D_G, 
#         p=X_G
#     ).logpmf(A_G)
#     return np.sum(logits_prime[np.isfinite(logits_prime)]) \
#             + np.sum(logits[np.isfinite(logits)])
#     return (
#         np.sum([
#             cell_loglikelihood(cell_id, A_G_prime[:, cell_id], D_G_prime[:, cell_id], X_G_prime[:, cell_id])
#             for cell_id in range(M_prime)
#         ]) + np.sum([
#             cell_loglikelihood(cell_id, A_G[:, cell_id], D_G[:, cell_id], X_G[:, cell_id])
#             for cell_id in range(M)
#         ])   
#     )
    
    
class XClone:
    def __init__(self, A_G_prime, D_G_prime, A_G, D_G, I_G_prime, CNV_prime, T_max):
        assert A_G_prime.shape == D_G_prime.shape
        assert A_G.shape == D_G.shape
        assert A_G.shape[0] == A_G_prime.shape[0]
        
        self.A_G_prime = np.nan_to_num(A_G_prime.astype(np.float64))
        self.D_G_prime = np.nan_to_num(D_G_prime.astype(np.float64))
        self.R_G_prime = D_G_prime - A_G_prime
        self.M_prime = A_G_prime.shape[1]
        self.A_G = np.nan_to_num(A_G.astype(np.float64))
        self.D_G = np.nan_to_num(D_G.astype(np.float64))
        self.R_G = D_G - A_G
        self.M = A_G.shape[1] 
        self.N_G = A_G.shape[0]
        
        self.I_G_prime = I_G_prime.astype(np.int64)
        self.clones, self.f = np.unique(I_G_prime, return_counts=True)
        self.K = self.clones.size
        self.f = self.f / self.M_prime
        assert np.isclose(self.f.sum(), 1)
        self.I_G = sps.rv_discrete(
           a=0, b=self.K, 
           values=[np.arange(self.K), self.f]
        ).rvs(size=self.M)
        
        colsum_fn = lambda mx: np.sum(mx, axis=1)
        self.D_C_prime = self._group_by_clone(self.D_G_prime, colsum_fn)
        self.A_C_prime = self._group_by_clone(self.A_G_prime, colsum_fn)
        
        self.CNV_prime = CNV_prime
        self.T_max = T_max
        self.CNV_prime[self.CNV_prime > T_max] = T_max
        self.tau = np.concatenate([[(t - k, k) for k in range(t + 1)] 
                                   for t in range(1, T_max +1)])
        self.conf_to_num = {tuple(cnv_config) : i 
                            for i, cnv_config in enumerate(self.tau)}
        self.num_to_conf = {val : key 
                            for key, val in self.conf_to_num.items()} 
        self.T = XClone._init_T(
            A_C_prime=self.A_C_prime, 
            D_C_prime=self.D_C_prime, 
            CNV_prime=self.CNV_prime, 
            N_G=self.N_G, 
            K=self.K
        )
        
        self.Alpha_G, self.Beta_G = XClone._init_alpha_beta(self.N_G, self.tau)
        self.Theta_G = sps.beta(a=self.Alpha_G, b=self.Beta_G).rvs(
            size=(self.Alpha_G.shape)
        )
        
        self.H_G_prime, self.X_G_prime = XClone._init_H_X(
            N=self.N_G, 
            M=self.M_prime, 
            I=self.I_G_prime,
            T=self.T, 
            Theta_G=self.Theta_G
        )
        
        self.H_C_prime, self.X_C_prime = XClone._init_H_X(
            N=self.N_G, 
            M=self.K, 
            I=np.arange(self.K),
            T=self.T, 
            Theta_G=self.Theta_G
        )
    
        self.H_G, self.X_G = XClone._init_H_X(
            N=self.N_G, 
            M=self.M, 
            I=self.I_G,
            T=self.T, 
            Theta_G=self.Theta_G
        )
        
        self.iter_count = 0
        
        self.bincoeff_prime = logfact(self.D_G_prime) - logfact(self.A_G_prime) - logfact(self.R_G_prime)
        self.bincoeff = logfact(self.D_G) - logfact(self.A_G) - logfact(self.R_G)
    
        
    def _group_by_clone(self, mx, agg_fn):
        return np.column_stack(
            agg_fn(mx[:, self.I_G_prime == k])
            for k in self.clones
        )
    
    @staticmethod
    @numba.jit(nopython=True)
    def _init_T(A_C_prime, D_C_prime, CNV_prime, N_G, K):
        T = np.full((N_G, K), np.nan, dtype=np.float64)
        for clone_id in range(K):
            for block_id in range(N_G):
                t = CNV_prime[block_id, clone_id]
                ad = A_C_prime[block_id, clone_id]
                dp = D_C_prime[block_id, clone_id]
                if dp == 0 or dp is np.nan:
                    continue
                ase_ratio = ad / dp
                offset = np.argmin(np.abs(np.arange(t + 1) / t - ase_ratio))
                T[block_id, clone_id] = (t * (t + 1) // 2 - 1) + offset
        return T
    
    @staticmethod
    @numba.jit(nopython=True)
    def _init_alpha_beta(N_G, tau):
        # ASE ratios are stored as (\alpha, \beta) parameter tuples
        # of the underlying Beta distributions
        Alpha_G = np.zeros(shape=(N_G, tau.shape[0]), dtype=np.float64)
        Beta_G = np.zeros(shape=(N_G, tau.shape[0]), dtype=np.float64)

        eps = 1
        for config_id in range(tau.shape[0]):
            for block_id in range(N_G):
                
                k0, k1 = tau[config_id]
                t = k0 + k1

                if k0 == 0:
                    alpha, beta = 1, 1 + eps
                elif k1 == 0:
                    alpha, beta = 1 + eps, 1
                else:
                    if k1 > k0:
                        alpha = 1 + eps
                        beta = k1 / k0 * alpha + (k0 - k1) / k0
                    else:
                        beta = 1 + eps
                        alpha = k0 / k1 * beta + (k1 - k0) / k1
    #                 assert np.isclose(beta_mode(alpha, beta), k0 / t)

    #             assert alpha >= 1 and beta >= 1
                config_id = (t * (t + 1) // 2 - 1) + k1
                Alpha_G[block_id, config_id] = alpha
                Beta_G[block_id, config_id] = beta
        return Alpha_G, Beta_G
    
    @staticmethod
    @numba.jit(nopython=True)
    def _init_H_X(N, M, T, I, Theta_G):
        H = np.full((N, M), np.nan, dtype=np.float64)
        X = np.full((N, M), np.nan, dtype=np.float64)
        
        for cell_id in range(M):
            for block_id in range(N):
                H[block_id, cell_id] = T[block_id, I[cell_id]]
                if ~np.isnan(H[block_id, cell_id]):
                    X[block_id, cell_id] = Theta_G[
                        block_id, 
                        int(H[block_id, cell_id])
                    ]
        return H, X
        
    @staticmethod
    def _predict_cell_label(cell_id, A, D, X_C_prime, f):
        K = X_C_prime.shape[1]    
        logprobas = np.array([
            cell_loglikelihood(
                cell_id, 
                A[:, cell_id], 
                D[:, cell_id], 
                X_C_prime[:, clone_id]
            ) 
            + np.log(f[clone_id])
            for clone_id in range(K)
        ])
        return sps.rv_discrete(
           a=0, b=K, 
           values=[np.arange(K), 
                   np.abs(logprobas) / np.nansum(np.abs(logprobas))]
        ).rvs()

    @staticmethod
    def _update_I_G(A, D, X_C_prime, f):
        return np.array([
            XClone._predict_cell_label(cell_id, A, D, X_C_prime, f)
            for cell_id in range(A.shape[1])
        ])


    @staticmethod
    @numba.jit(nopython=True, parallel=True)
    def _update_alpha_beta(tau, Theta_G, Alpha_G, Beta_G, A_G, D_G, H_G, changed_mask):
        N_G = Alpha_G.shape[0]
        new_Alpha_G = np.full_like(Alpha_G, np.nan)
        new_Beta_G = np.full_like(Beta_G, np.nan)
        
        A_changed = A_G[:, changed_mask]
        D_changed = D_G[:, changed_mask]
        H_changed = D_G[:, changed_mask]
        
        for cnv_code in range(tau.shape[0]):
            h_mask = H_changed == cnv_code

            alpha = Alpha_G[:, cnv_code]
            beta = Alpha_G[:, cnv_code]

            u = np.sum(A_changed * h_mask, axis=1)
            v = np.sum((D_changed - A_changed) * h_mask, axis=1)

            new_Alpha_G[:, cnv_code] = alpha + u
            new_Beta_G[:, cnv_code] = beta + v
                
        return new_Alpha_G, new_Beta_G
    
    
    def do_gibbs_sampling(self, n_iters):
        t0 = time()
        self.best_loglik = total_loglikelihood(
            self.A_G_prime, self.R_G_prime, self.X_G_prime, self.bincoeff_prime,
            self.A_G, self.R_G, self.X_G, self.bincoeff
        )
        print(f"Init loglik time: {time() - t0}")
        for iter_count in tqdm_notebook(range(n_iters)):
#             t0 = time()
            new_I_G = XClone._update_I_G(
                A=self.A_G,
                D=self.D_G,
                X_C_prime=self.X_C_prime,
                f=self.f
            )
            changed_mask = new_I_G != self.I_G
#             print(f"Label sampling time: {time() - t0}")
#             t0 = time()
            new_Alpha_G, new_Beta_G = XClone._update_alpha_beta(
                tau=self.tau,
                Theta_G=self.Theta_G,
                Alpha_G=self.Alpha_G,
                Beta_G=self.Beta_G,
                A_G=self.A_G,
                D_G=self.D_G,
                H_G=self.H_G,
                changed_mask=changed_mask
            )
            new_Theta_G = sps.beta(a=new_Alpha_G, b=new_Beta_G).rvs(
                size=self.Theta_G.shape
            )
#             print(f"Posterior update time: {time() - t0}")
#             t0 = time()
            new_H_G_prime, new_X_G_prime = XClone._init_H_X(
                N=self.N_G, 
                M=self.M_prime, 
                I=self.I_G_prime,
                T=self.T, 
                Theta_G=new_Theta_G
            )

            new_H_G, new_X_G = XClone._init_H_X(
                N=self.N_G, 
                M=self.M, 
                I=new_I_G,
                T=self.T, 
                Theta_G=new_Theta_G
            )
#             print(f"Update time: {time() - t0}")
#             t0 = time()
            new_loglik = total_loglikelihood(
                self.A_G_prime, self.R_G_prime, new_X_G_prime, self.bincoeff_prime,
                self.A_G, self.R_G, new_X_G, self.bincoeff
            )
    
            if new_loglik > self.best_loglik:
                print(f"Iteration {iter_count} — labelling update!")
                print(self.best_loglik, new_loglik)
                
                self.best_loglik = new_loglik
                self.Alpha_G = new_Alpha_G
                self.Beta_G = new_Beta_G
                self.Theta_G = new_Theta_G
                
                self.I_G = new_I_G
                self.H_G = new_H_G
                self.X_G = new_X_G
                
                self.H_G_prime = new_H_G_prime
                self.X_G_prime = new_X_G_prime
                
                self.H_C_prime, self.X_C_prime = XClone._init_H_X(
                    N=self.N_G, 
                    M=self.K, 
                    I=np.arange(self.K),
                    T=self.T, 
                    Theta_G=self.Theta_G
                )
        
                classification_report(
                    data["scRNA"]["clustering"].LABEL.apply(lambda i: self.I_G[i] + 1),#I_G + 1,
                    title=f"XClone label assignment, iteration {self.iter_count},"\
                            f" negloglikelihood {np.abs(self.best_loglik)}, "
                            f"{round(100 * np.mean(changed_mask), 2) }% cells reassigned on last iter\n"\
                            "evo_dist_9 clustering of scDNA, "\
                            "seurat clustering of scRNA, "\
                            f"{self.N_G} haplotype blocks",
                    outfile=f"{workspace['scRNA'].img_dir}/xclone/16_10_2019/"\
                            f"{np.abs(self.best_loglik)}_{changed_mask.mean()}_{self.iter_count}.png",
                    show=False
                )
            
            self.iter_count += 1


T_max = 5
I_G_prime = data["scDNA"]["clustering"]["LABEL"].astype(int).values - 1
%time xclone = XClone(A_G_prime, D_G_prime, A_G, D_G, I_G_prime, CNV_prime, T_max)
# %time xclone.do_gibbs_sampling_iter()

CPU times: user 2.34 s, sys: 0 ns, total: 2.34 s
Wall time: 2.33 s


In [None]:
%time xclone.do_gibbs_sampling(10000)

Init loglik time: 1.2457890510559082


HBox(children=(IntProgress(value=0, max=10000), HTML(value='')))

Iteration 0 — labelling update!
-6951815.771568522 -6809723.779732377
Iteration 1 — labelling update!
-6809723.779732377 -6638900.113386817
Iteration 2 — labelling update!
-6638900.113386817 -6382111.483270387
Iteration 4 — labelling update!
-6382111.483270387 -6311827.888932517
Iteration 5 — labelling update!
-6311827.888932517 -6308985.006400545
Iteration 6 — labelling update!
-6308985.006400545 -6291860.285992733
Iteration 7 — labelling update!
-6291860.285992733 -6248854.026708256
Iteration 9 — labelling update!
-6248854.026708256 -6229739.118621012
Iteration 11 — labelling update!
-6229739.118621012 -6194478.171499716
Iteration 18 — labelling update!
-6194478.171499716 -6186882.550636066
Iteration 23 — labelling update!
-6186882.550636066 -6171777.392886904
Iteration 26 — labelling update!
-6171777.392886904 -6140420.936226614
Iteration 31 — labelling update!
-6140420.936226614 -6132372.342925662
Iteration 35 — labelling update!
-6132372.342925662 -6113477.518528445
Iteration 51 —

In [31]:
def total_loglikelihood(X_G_prime, X_G):
    return (
        np.sum([
            cell_loglikelihood(cell_id, A_G_prime[:, cell_id], D_G_prime[:, cell_id], X_G_prime[:, cell_id])
            for cell_id in range(M)#tqdm_notebook(range(M_prime), "scDNA: computing log L")
        ]) + np.sum([
            cell_loglikelihood(cell_id, A_G[:, cell_id], D_G[:, cell_id], X_G[:, cell_id])
            for cell_id in range(M)#tqdm_notebook(range(M), "scRNA: computing log L")
        ])   
    )

In [33]:
prior_total_loglikelihood = total_loglikelihood(X_G_prime, X_G)
print("Initial prior total loglikelihood:\t", prior_total_loglikelihood)
for block_id, cnv_code in tqdm_notebook(cartesian(range(N_G), range(tau.shape[0])), "computing posterior"):
    a_prime = A_G_prime[block_id, :]
    d_prime = D_G_prime[block_id, :]
    h_prime_mask = H_G_prime[block_id, :] == cnv_code
    
    a = A_G[block_id, :]
    d = D_G[block_id, :]
    h_mask = H_G[block_id, :] == cnv_code
    
    u = np.nansum(np.hstack((a_prime * h_prime_mask, a * h_mask)))  
    v = np.nansum(np.hstack(((d_prime - a_prime) * h_prime_mask, (d - a) * h_mask)))

    assert u >= 0, v >= 0
    Alpha_G[block_id, cnv_code] += u
    Beta_G[block_id, cnv_code] += v

Theta_G = sps.beta(a=Alpha_G, b=Beta_G).rvs(size=Alpha_G.shape)  
    
for cell_id in tqdm_notebook(range(M_prime)):
    not_na_mask = ~np.isnan(H_G_prime[:, cell_id])
    X_G_prime[not_na_mask, cell_id] = Theta_G[not_na_mask, 
                                              H_G_prime[not_na_mask, cell_id].astype(int)]

for cell_id in tqdm_notebook(range(M)):
    not_na_mask = ~np.isnan(H_G[:, cell_id])
    X_G[not_na_mask, cell_id] = Theta_G[not_na_mask, 
                                        H_G[not_na_mask, cell_id].astype(int)]
    
posterior_total_loglikelihood = total_loglikelihood(X_G_prime, X_G)
print("Initial posterior total loglikelihood:\t", posterior_total_loglikelihood)
# assert posterior_total_loglikelihood > prior_total_loglikelihood

Initial prior total loglikelihood:	 -400198.14423451526


  import sys


HBox(children=(IntProgress(value=1, bar_style='info', description='computing posterior', max=1, style=Progress…

HBox(children=(IntProgress(value=0, max=268), HTML(value='')))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

Initial posterior total loglikelihood:	 -374280.8431529163


In [None]:
def update_alpha_beta(block_id, cnv_code, Theta_G, Alpha_G, Beta_G, H_G, changed_mask):
    a = A_G[block_id, changed_mask]
    d = D_G[block_id, changed_mask]
    h_mask = H_G[block_id, changed_mask] == cnv_code
    
    u = np.nansum(a * h_mask)
    v = np.nansum((d - a) * h_mask)
    assert u >= 0, v >= 0
    
    new_alpha = Alpha_G[block_id, cnv_code] + u
    new_beta = Beta_G[block_id, cnv_code] + v
    return new_alpha, new_beta

Alpha_G_prior = Alpha_G.copy()
Beta_G_prior = Beta_G.copy()

print(posterior_total_loglikelihood)

for ITER in tqdm_notebook(range(10000), "sampling iteration"):
    new_I_G = np.array([update_I_G(cell_id, Theta_G) for cell_id in range(M)])#tqdm_notebook(range(M), "updating I_G")])
#     new_I_G = \
#        np.array(Parallel(16)(
#            delayed(update_I_G)(cell_id, Theta_G)
#            for cell_id in tqdm_notebook(range(M), "updating I_G")
#        ))
    
    changed_mask = new_I_G != I_G
    print("% changed:\t", 100 * np.mean(changed_mask))
    
    print(Counter(new_I_G + 1))
#     plt.hist(new_I_G, bins=np.arange(K+1))
#     plt.show()
    
    new_alpha_beta = np.array([
        update_alpha_beta(block_id, cnv_code, Theta_G, Alpha_G, Beta_G, H_G, changed_mask)
        for block_id, cnv_code in cartesian(
            range(N_G), range(tau.shape[0])
        )#, "computing posterior")
    ])
#     \
#         np.vstack(Parallel(16)(
#             delayed(update_alpha_beta)(block_id, cnv_code, Theta_G, Alpha_G, Beta_G, H_G, changed_mask)
#             for block_id, cnv_code in tqdm_notebook(cartesian(
#                 range(N_G), range(tau.shape[0])
#             ), "computing posterior")
#         ))
    
    new_Alpha_G = new_alpha_beta[:, 0].reshape(Alpha_G.shape)
    new_Beta_G = new_alpha_beta[:, 1].reshape(Alpha_G.shape)
    new_Theta_G = sps.beta(a=new_Alpha_G, b=new_Beta_G).rvs(size=Alpha_G.shape)
    
#     np.vstack(update_alpha_beta(block_id, cnv_code)
#         for block_id, cnv_code in tqdm_notebook(cartesian(
#             range(N_G), range(tau.shape[0])
#         ), "computing| posterior"))


    new_H_G = H_G.copy()
    new_X_G = X_G.copy()
    for cell_id in range(M):#tqdm_notebook(range(M), "updating H_G and X_G"):
        new_H_G[:, cell_id] = T[:, new_I_G[cell_id]]
        not_na_mask = ~np.isnan(new_H_G[:, cell_id])
        new_X_G[not_na_mask, cell_id] = new_Theta_G[
            not_na_mask, 
            new_H_G[not_na_mask, cell_id].astype(int)
        ]
        
    curr_loglikelihood = total_loglikelihood(X_G_prime, new_X_G)
 
    print(f"Iter {ITER}:\t loglikelihood of {curr_loglikelihood}")
    
    if curr_loglikelihood > posterior_total_loglikelihood:
        print("UPDATE ASSIGNMENT")
        posterior_total_loglikelihood = curr_loglikelihood
        I_G = new_I_G.copy()
        Alpha_G = new_Alpha_G.copy()
        Beta_G = new_Beta_G.copy()
        Theta_G = new_Theta_G.copy()
    #     Theta_G[Theta_G == 0] = 0.01
    #     Theta_G[Theta_G == 1] = 0.99
#         print("Plotting report")
        classification_report(
            data["scRNA"]["clustering"].LABEL.apply(lambda i: I_G[i] + 1),#I_G + 1,
            title=f"XClone label assignment, iteration {ITER},"\
                    f" loglikelihood {posterior_total_loglikelihood} \n"\
                    "evo_dist_9 clustering of scDNA,\n"\
                    "seurat clustering of scRNA,\n"\
                    f"{N_G} haplotype blocks",
            outfile=f"{workspace['scRNA'].img_dir}/xclone/"\
                    f"iter_{ITER}_loglikelihood_{posterior_total_loglikelihood}.png"
        )
#         plt.show()

-374280.8431529163


HBox(children=(IntProgress(value=0, description='sampling iteration', max=10000, style=ProgressStyle(descripti…

% changed:	 90.0
Counter({9: 2, 8: 2, 6: 2, 5: 1, 7: 1, 1: 1, 2: 1})


  import sys


Iter 0:	 loglikelihood of -382604.71953687764
% changed:	 100.0
Counter({3: 2, 1: 2, 8: 1, 2: 1, 5: 1, 9: 1, 4: 1, 6: 1})
Iter 1:	 loglikelihood of -384216.7253803791
% changed:	 90.0
Counter({1: 2, 5: 2, 6: 2, 2: 1, 7: 1, 8: 1, 9: 1})
Iter 2:	 loglikelihood of -385023.0907452116
% changed:	 80.0
Counter({6: 3, 7: 2, 5: 2, 2: 1, 3: 1, 1: 1})
Iter 3:	 loglikelihood of -374947.07764005905
% changed:	 90.0
Counter({3: 3, 5: 2, 7: 1, 1: 1, 9: 1, 6: 1, 2: 1})
Iter 4:	 loglikelihood of -373442.07391015114
UPDATE ASSIGNMENT
% changed:	 80.0
Counter({5: 3, 7: 3, 8: 2, 9: 1, 2: 1})
Iter 5:	 loglikelihood of -381789.09300347936
% changed:	 90.0
Counter({5: 4, 2: 4, 6: 1, 8: 1})
Iter 6:	 loglikelihood of -397017.06126588525
% changed:	 100.0
Counter({2: 2, 8: 2, 9: 2, 3: 1, 5: 1, 7: 1, 6: 1})
Iter 7:	 loglikelihood of -372113.6197365558
UPDATE ASSIGNMENT
% changed:	 100.0
Counter({6: 3, 5: 2, 8: 2, 4: 1, 3: 1, 1: 1})
Iter 8:	 loglikelihood of -387348.3458155921
% changed:	 100.0
Counter({2: 2, 7:

% changed:	 80.0
Counter({6: 3, 1: 2, 3: 2, 9: 1, 7: 1, 8: 1})
Iter 74:	 loglikelihood of -350785.98695306387
% changed:	 80.0
Counter({5: 2, 8: 2, 1: 1, 4: 1, 7: 1, 3: 1, 9: 1, 6: 1})
Iter 75:	 loglikelihood of -362701.4125440626
% changed:	 100.0
Counter({2: 3, 1: 3, 5: 2, 8: 1, 3: 1})
Iter 76:	 loglikelihood of -351877.7436479023
% changed:	 80.0
Counter({6: 2, 8: 2, 1: 2, 5: 1, 7: 1, 2: 1, 9: 1})
Iter 77:	 loglikelihood of -360522.3832776577
% changed:	 100.0
Counter({8: 2, 2: 2, 5: 2, 9: 1, 6: 1, 4: 1, 1: 1})
Iter 78:	 loglikelihood of -366165.3069774543
% changed:	 90.0
Counter({1: 2, 3: 2, 7: 1, 5: 1, 2: 1, 9: 1, 6: 1, 8: 1})
Iter 79:	 loglikelihood of -356649.6772357491
% changed:	 70.0
Counter({7: 6, 2: 1, 4: 1, 3: 1, 5: 1})
Iter 80:	 loglikelihood of -343752.3162429412
% changed:	 80.0
Counter({1: 4, 5: 3, 3: 1, 9: 1, 6: 1})
Iter 81:	 loglikelihood of -389668.9657983347
% changed:	 90.0
Counter({3: 3, 5: 2, 1: 2, 6: 1, 8: 1, 9: 1})
Iter 82:	 loglikelihood of -351862.236285221