In [1]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

import os
import subprocess
import sys
import warnings

# Silence every FutureWarning for the rest of the session.
# NOTE(review): this also hides seaborn/pandas API deprecation notices, which can
# mask calls whose meaning changes in newer library versions — consider narrowing.
warnings.simplefilter("ignore", FutureWarning)

# Import the project configuration module (registered for autoreload) and switch
# the working directory to its MB_ROOT; presumably the project-local imports
# below (toolkit, util, workspace) resolve relative to that root — confirm.
%aimport cnv_inference_config
project_config = cnv_inference_config
os.chdir(project_config.MB_ROOT)

from collections import Counter, defaultdict, OrderedDict
import itertools
from itertools import product as cartesian
import multiprocessing as mp
import pickle

import numba
from joblib import Parallel, delayed
import numpy as np
import pandas as pd
import scipy as sp
import scipy.stats as sps
from tqdm import tqdm, tqdm_notebook

import matplotlib.pyplot as plt
import seaborn as sns

import toolkit
import util
from workspace.workspace_manager import WorkspaceManager

# Apply seaborn's default theme to all subsequently created figures.
sns.set()

# Register one WorkspaceManager per sequencing modality for the "ase_to_cnv"
# task, then immediately load that modality's workspace.
workspace = {}
for data_type in ("scDNA", "scRNA"):
    workspace[data_type] = WorkspaceManager(
        task_name="ase_to_cnv",
        experiment_info={"data": data_type},
        verbose=True,
    )
    workspace[data_type].load_workspace()

In [2]:
# Load every intermediate dataset registered in each modality's workspace
# (workspace[...].tmp_data maps dataset name -> dump location) into RAM.
# NOTE(review): util.pickle_load presumably unpickles these dumps; unpickling can
# execute arbitrary code, so only point this at trusted workspace files.
data = {}
for data_type in ["scDNA", "scRNA"]:
    data[data_type] = {
        data_name: util.pickle_load(data_dump)
        for data_name, data_dump in tqdm_notebook(
            workspace[data_type].tmp_data.items(),
            f"{data_type}, loading datasets into RAM"
        )
    }
    # The dumps call the block identifier column "GENE_ID"; expose it as
    # "BLOCK_ID". Assign a renamed copy instead of mutating in place so the
    # cell has no hidden in-place side effects when re-run.
    data[data_type]["block_counts"] = data[data_type]["block_counts"].rename(
        columns={"GENE_ID": "BLOCK_ID"}
    )

HBox(children=(IntProgress(value=0, description='scDNA, loading datasets into RAM', max=2, style=ProgressStyle…




HBox(children=(IntProgress(value=0, description='scRNA, loading datasets into RAM', max=2, style=ProgressStyle…




In [3]:
# Keep only the blocks present in both modalities so the scDNA and scRNA count
# matrices extracted in the next cell describe the same set of blocks.
common_blocks = (set(data["scDNA"]["block_counts"].BLOCK_ID)
                 & set(data["scRNA"]["block_counts"].BLOCK_ID))

print("Datasets have {} blocks in common".format(len(common_blocks)))

for modality in workspace.keys():
    data[modality]["block_counts"] = util.filter_by_isin(
        data[modality]["block_counts"],
        "BLOCK_ID",
        common_blocks
    ).reset_index(drop=True)

common_block_ids = data["scDNA"]["block_counts"]["BLOCK_ID"].values.astype(int)

# The next cell takes `.values` of both block tables and treats row i as the
# same block in each. Equal length alone does not guarantee that: the two
# tables could list the common blocks in different orders, or one could hold
# duplicate BLOCK_IDs. Check the BLOCK_ID sequences match element-wise.
assert (data["scDNA"]["block_counts"].shape[0]
        == data["scRNA"]["block_counts"].shape[0])
assert np.array_equal(
    data["scDNA"]["block_counts"]["BLOCK_ID"].values,
    data["scRNA"]["block_counts"]["BLOCK_ID"].values
), "scDNA and scRNA block tables are not row-aligned by BLOCK_ID"

Datasets have 8776 blocks in common


In [4]:
# Build the count matrices handed to XClone.
# - *_G_prime: scDNA block counts, one column per scDNA barcode (not aggregated).
# - *_G: scRNA block counts aggregated over the barcode groups defined by
#   data["scRNA"]["clustering"].
# D_* use extract_counts' default suffix and A_* use suffix="ad" (presumably
# total depth vs. alternative-allele depth — TODO confirm against toolkit).
#
# The scRNA aggregation only depends on the two inputs below, so it is computed
# once and reused for both D_G and A_G instead of being recomputed.
rna_block_counts_by_group = toolkit.aggregate_by_barcode_groups(
    data["scRNA"]["block_counts"],
    data["scRNA"]["clustering"]
)

D_G_prime = toolkit.extract_counts(data["scDNA"]["block_counts"]).values
D_G = toolkit.extract_counts(rna_block_counts_by_group).values

A_G_prime = toolkit.extract_counts(
    data["scDNA"]["block_counts"],
    suffix="ad"
).values
A_G = toolkit.extract_counts(
    rna_block_counts_by_group,
    suffix="ad"
).values

In [5]:
# Read the copy-number column from each TabHaplotypeblock BED file (numbered
# 1..N_CNV_FILES) and stack them so CNV_prime[:, k] comes from file k + 1
# (presumably one file per clone — TODO confirm).
# NOTE(review): CNV_prime keeps every block in the BED files (18011 rows in the
# recorded output), not just the common blocks kept for the count matrices —
# confirm XClone does not expect the two to be row-aligned.
rodata_dir = "/icgc/dkfzlsdf/analysis/B260/users/v390v/cnv_inference/data/raw/first_sample"
N_CNV_FILES = 9

CNV_prime = []
reference_coords = None
for i in tqdm_notebook(range(N_CNV_FILES), "reading CNV information"):
    snp_cnv = pd.read_csv(
        f"{rodata_dir}/TabHaplotypeblock_with_phasedSNPs_{i + 1}.bed",
        usecols=[0, 1, 2, 3],
        names=["CHROM", "START", "END", "CNV"],
        sep='\t'
    ).drop_duplicates()
    print(snp_cnv.shape, np.median(snp_cnv["CNV"]))

    # Column-stacking below pairs row r of every file with row r of the first
    # file, so all files must list the same blocks in the same order.
    coords = snp_cnv[["CHROM", "START", "END"]].reset_index(drop=True)
    if reference_coords is None:
        reference_coords = coords
    else:
        assert coords.equals(reference_coords), (
            f"block coordinates in file {i + 1} differ from file 1; "
            "CNV columns would be misaligned"
        )

    CNV_prime.append(snp_cnv["CNV"].values.astype(int))
CNV_prime = np.column_stack(CNV_prime)

HBox(children=(IntProgress(value=0, description='reading CNV information', max=9, style=ProgressStyle(descript…

(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 4.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0
(18011, 4) 2.0



In [17]:
def classification_report(labels, title, outfile=None, show=True, clustering=None):
    """Plot a count of cells per cluster label and a t-SNE scatter coloured by label.

    Parameters
    ----------
    labels : pandas.Series
        Per-cell cluster labels. They are used as ``hue`` for the scatter, so
        presumably aligned with the rows of ``clustering`` — confirm. The colour
        palette below defines colours for label values 1-9 only.
    title : str
        Title of the t-SNE panel.
    outfile : str, optional
        If given, the figure is saved there at 300 dpi. matplotlib infers the
        file format from the extension.
    show : bool, default True
        If truthy, display the figure with ``plt.show()`` before closing it.
    clustering : pandas.DataFrame, optional
        Table providing the "TSNE_1" and "TSNE_2" columns. Defaults to the
        notebook-global ``data["scRNA"]["clustering"]``.
    """
    if clustering is None:
        clustering = data["scRNA"]["clustering"]

    palette = {
        1: "xkcd:orange",
        2: "xkcd:azure",
        3: "xkcd:cyan",
        4: "xkcd:yellow",
        5: "xkcd:blue",
        6: "xkcd:red",
        7: "xkcd:pink",
        8: "xkcd:grey",
        9: "xkcd:black",
    }

    sns.set(style="whitegrid", font_scale=1.5)
    fig, ax = plt.subplots(2, 1, figsize=(20, 30))

    ax[0].set_title("Cluster label assigned by XClone", fontsize=20)
    # Pass the labels as `x=`: in newer seaborn the first positional argument
    # is `data`, not the vector to count.
    sns.countplot(x=labels, palette=palette, ax=ax[0])

    ax[1].set_title(title)
    sns.scatterplot(
        x="TSNE_1", y="TSNE_2",
        hue=labels,
        data=clustering,
        legend="full",
        palette=palette,
        ax=ax[1]
    )
    # Create the legend exactly once: each ax.legend() call replaces the
    # previous legend, so styling applied to an earlier one would be lost.
    ax[1].legend(frameon=False, bbox_to_anchor=(1, 0.5), loc="center left")
    fig.subplots_adjust(right=0.75)

    if outfile is not None:
        fig.savefig(outfile, dpi=300)
    if show:
        plt.show()
    # Release only this figure. Calling plt.clf() after closing would create a
    # brand-new empty figure on every call.
    plt.close(fig)

In [20]:
%%time
from classification.models.xclone import XClone
import classification.models.xclone_routines as xclone_routines

# T_max is forwarded to XClone unchanged.
T_max = 5

# scDNA cluster LABELs are shifted down by one before being passed to XClone
# (presumably 1-based cluster ids -> 0-based indices — confirm).
scdna_labels = data["scDNA"]["clustering"]["LABEL"]
I_G_prime = scdna_labels.astype(int).values - 1

xclone_report_dir = (
    "/icgc/dkfzlsdf/analysis/B260/users/v390v/"
    "cnv_inference/data/tmp/xclone/22_10_2019/evening_run"
)

xclone = XClone(
    A_G_prime=A_G_prime,
    D_G_prime=D_G_prime,
    I_G_prime=I_G_prime,
    CNV_prime=CNV_prime,
    T_max=T_max,
    A_G=A_G,
    D_G=D_G,
    report_dir=xclone_report_dir,
    verbose=True,
)

CPU times: user 7.01 s, sys: 743 ms, total: 7.76 s
Wall time: 4.19 s


In [22]:
%%time
def callback(xclone_instance):
    """Save a classification report image for the current state of an XClone fit.

    The scRNA per-cell LABELs are mapped through the model's current cluster
    assignment ``_params.I_G``, shifted by one for display. The figure is written
    to the evening_run/img directory as ``<iteration>.png``, without showing it.
    """
    # NOTE(review): unlike the scDNA labels, no 1 is subtracted from scRNA
    # LABEL before it is used as an index into I_G — confirm scRNA LABEL values
    # already index the aggregated groups 0-based.
    cell_labels = data["scRNA"]["clustering"]["LABEL"].apply(
        lambda i: xclone_instance._params.I_G[i] + 1
    )
    iteration = xclone_instance._iter_count
    changed_pct = round(100 * xclone_instance._changed_mask.mean())

    classification_report(
        cell_labels,
        title=(f"Refactored XClone, iteration {iteration},"
               f" {changed_pct}% labels reassigned"),
        outfile=("/icgc/dkfzlsdf/analysis/B260/users/v390v/"
                 "cnv_inference/data/tmp/xclone/22_10_2019/evening_run/img/"
                 f"{iteration}.png"),
        show=False,
    )


xclone.fit(10000, callback)

Iteration 217 — labelling update!-4481883.73 --> -4450712.03, 70% labels reassigned
Iteration 2914 — labelling update!-4450712.03 --> -4448228.10, 80% labels reassigned
Iteration 3049 — labelling update!-4448228.10 --> -4442557.80, 90% labels reassigned
Iteration 8576 — labelling update!-4442557.80 --> -4433074.46, 70% labels reassigned
Iteration 9746 — labelling update!-4433074.46 --> -4426099.07, 80% labels reassigned
CPU times: user 2h 40min 12s, sys: 16min 15s, total: 2h 56min 27s
Wall time: 2h 28min 27s


<Figure size 432x288 with 0 Axes>