# Train model from the Dominguez immune atlas in the myeloid compartment

# LR multi-tissue cross-comparison

##### Ver:: A1_V5
##### Author(s) : Issac Goh
##### Date : 220823;YYMMDD
### Author notes
    - Current defaults scrape data from the web, so leave them as default and run
    - slices model and anndata to same feature shape, scales anndata object
    - added some simple benchmarking
    - creates dynamic cutoffs for probability score (x*sd of mean) in place of more memory intensive confidence scoring
    - Does not have majority voting set on as default, but module does exist
    - Multinomial logistic relies on the (not always realistic) assumption of independence of irrelevant alternatives whereas a series of binary logistic predictions does not. collinearity is assumed to be relatively low, as it becomes difficult to differentiate between the impact of several variables if this is not the case
    
### Features to add
    - Add ability to consume anndata zarr format for sequential learning
### Modes to run in
    - Run in training mode
    - Run in projection mode

In [None]:
import sys
import subprocess

# import pkg_resources
# required = {'harmonypy','sklearn','scanpy','pandas', 'numpy', 'scipy', 'matplotlib', 'seaborn' ,'scipy'}
# installed = {pkg.key for pkg in pkg_resources.working_set}
# missing = required - installed
# if missing:
#    print("Installing missing packages:" )
#    print(missing)
#    python = sys.executable
#    subprocess.check_call([python, '-m', 'pip', 'install', *missing], stdout=subprocess.DEVNULL)

from collections import Counter
from collections import defaultdict
import scanpy as sc
import pandas as pd
import pickle as pkl
import numpy as np
import scipy
import matplotlib.pyplot as plt
import re
import glob
import os
import sys
#from geosketch import gs
from numpy import cov
import scipy.cluster.hierarchy as spc
import seaborn as sns; sns.set(color_codes=True)
from sklearn.linear_model import LogisticRegression
import sklearn
from pathlib import Path
import requests
import psutil
import random
import threading
import tracemalloc
import itertools
import math
import warnings
import sklearn.metrics as metrics

# Train ABM model & validate, project onto adult Mye

In [None]:
# Registry of pre-trained models: each key maps to a local path or a URL.
# `load_models` (defined further down) accepts a dictionary of this form together
# with one of its keys; the key used for this run is presumably `model_key`
# from the variable-assignment section below.
models = {
'pan_fetal':'/nfs/team205/ig7/resources/scripts_dont_modify/logit_regression_models/adifa_lr/celltypist_model.Pan_Fetal_Human.pkl',
'pan_fetal_wget':'https://celltypist.cog.sanger.ac.uk/models/Pan_Fetal_Suo/v2/Pan_Fetal_Human.pkl',
'adata_scvi':'/nfs/team205/ig7/mount/gdrive/g_cloud/projects/amniontic_fluid/scvi_low_dim_model.sav',
'adata_ldvae':'/nfs/team205/ig7/mount/gdrive/g_cloud/projects/amniontic_fluid/ldvae_low_dim_model.sav',
'adata_harmony':'/nfs/team205/ig7/work_backups/backup_210306/projects/amiotic_fluid/train_low_dim_model/organ_low_dim_model.sav',
'test_low_dim_ipsc_ys':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_030522_notebooks/Integrating_HM_data_030522/YS_logit/lr_model.sav',
'YS_X':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/resources/YS_X_model_080922.sav',
'YS_X_V3':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/train_YS_full_X_model/YS_X_A2_V12_lvl3_ELASTICNET_YS.sav',
'SK_model':'/nfs/team205/ig7/resources/scripts_dont_modify/logit_regression_models/LR_app_format/hudaa_skin/for_hudaa_A1_V2',
'Hudaa_model_trained':'/nfs/team298/hg6/Fetal_skin/LR_15012023/train-all_model.pkl',
'A1_V1_LUNG_MYE_model_IG':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/LR_transfer_Lung_brain/A1_V1_LUNG_MYE_model_IG',
'A1_V1_immuneatlas_model_IG':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/LR_transfer_Lung_brain/A1_V1_immuneatlas_model_IG'
}

# Registry of AnnData (.h5ad) inputs: each key maps to a local path.
# `load_adatas` (defined further down) accepts a dictionary of this form; it reads
# either one entry (selected by key) or, in merge mode, every entry.
adatas_dict = {
'human_lung_mye':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/Theiss_lung_atlas/Myeloid_subset_HLCA_lung_processed.h5ad',
'Fetal_skin_raw': '/nfs/team298/hg6/Fetal_skin/data/FS_raw_sub.h5ad',
'vascular_organoid': '/nfs/team298/hg6/Fetal_skin/data/vasc_org_raw.h5ad',
'YS':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/Submission_2_data/A2_V5_scvi_YS_integrated/A2_V5_scvi_YS_integrated_raw_qc_scr_umap.h5ad',
'YS_test':'/nfs/team205/ig7/resources/scripts_dont_modify/logit_regression_models/LR_app_format/ys_test_data.h5ad',
'YS_A2_V10_X_raw':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/Submission_2_data/A2_V10_scvi_YS_integrated/A2_V10_raw_counts_full_no_obs.h5ad',
'YS_A2_V10_X':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/Submission_2_data/A2_V10_scvi_YS_integrated/A2_V10_qc_raw.h5ad',
'ABM':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/LR_transfer_Lung_brain/ABM_re_constructed_IG.h5ad',
'mye_immune_atlas':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/tissue_immune_atlas_dominiguez/CountAdded_PIP_myeloid_object_for_cellxgene.h5ad',
'global_immune_atlas':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/tissue_immune_atlas_dominiguez/CountAdded_PIP_global_object_for_cellxgene.h5ad',
'pan_organ_mye':'/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A1_V1_LING_ADULT_IG_annot.h5ad',
}

# Variable assignment
# NOTE: several functions below read these names as module-level globals
# (e.g. `train_x_partition`, `freq_redist`, `thread_num`, `penalty`, `sparcity`,
# `l1_ratio`), so they must be defined before those functions are called.
train_model = True  # presumably toggles training mode vs projection mode -- confirm against the driver cell
feat_use = 'Manually_curated_celltype'  # presumably the obs column holding the training labels -- confirm
adata_key = 'global_immune_atlas'  # key into `adatas_dict`; the entry may be a local path or a URL (previously: 'fliv_wget_test')
data_merge = False  # read and merge multiple adatas (useful, but keep False for now)
model_key = 'A1_V1_immuneatlas_model_IG'  # key into `models`; the entry may be a local path or a URL (previously: 'test_low_dim_ipsc_ys')
train_x_partition = 'X'  # which partition the data was trained on; to keep things simple, only 'X' is accepted for now
dyn_std = 1.96  # dynamic cutoff in standard deviations above the mean class probability; flags uncertain labels (1 ~ 68% CI, 1.96 ~ 95% CI)
freq_redist = 'Manually_curated_celltype'  # False, or the obs column holding labels/clusters used for frequency redistribution (alternative: 'cell.labels')
partial_scale = False  # should data be scaled in batches?
QC_normalise = True  # should data be QC-filtered and normalised on load?

# training variables
penalty='elasticnet'  # one of ["l1", "l2", "elasticnet"]
sparcity=0.5  # passed as `C` to LogisticRegression (inverse regularisation strength)
thread_num = -1  # passed as `n_jobs` to LogisticRegression
l1_ratio = 0.5  # mix between L1 and L2 regularisation; only used with the elasticnet penalty

#

# If low dim & not in keys
batch_key = 'Donor'
batch_correction = False  # alternatives: 'Harmony' or bbknn
theta = 3  # Harmony-specific

# Sketch training?
sketch_obsm = None

In [None]:
# Read only the `obs` group of the global immune atlas .h5ad file via h5py,
# so the cell annotations can be inspected on their own.
# The path is the same file as adatas_dict['global_immune_atlas'].
import h5py
from anndata._io.specs import read_elem

with h5py.File("/nfs/team205/ig7/work_backups/backup_210306/projects/YS/YS_data/tissue_immune_atlas_dominiguez/CountAdded_PIP_global_object_for_cellxgene.h5ad") as f:
    cell_types = read_elem(f["obs"])

In [None]:
# Display the distinct labels found in the 'Manually_curated_celltype' column.
list(cell_types['Manually_curated_celltype'].unique())

In [None]:
# Same expression as the previous cell: display the distinct curated cell-type labels.
list(cell_types['Manually_curated_celltype'].unique())

# Partial scaling ver
- scale across 10 mini bulks/every 100,000 cells
- sequential learning for scaling
- sequential application of scaling

In [None]:
from collections import Counter
from collections import defaultdict
import scanpy as sc
import pandas as pd
import pickle as pkl
import numpy as np
import scipy
import matplotlib.pyplot as plt
import re
import glob
import os
import sys
#from geosketch import gs
from numpy import cov
import scipy.cluster.hierarchy as spc
import seaborn as sns; sns.set(color_codes=True)
from sklearn.linear_model import LogisticRegression
import sklearn
from pathlib import Path
import requests
import psutil
import random
import threading
import tracemalloc
import itertools
import math
import warnings
import sklearn.metrics as metrics

def load_models(model_dict,model_run):
    """Load a pickled model from a local file or an HTTP(S) address.

    Parameters
    ----------
    model_dict : dict
        Maps model keys to a local file path or a URL.
    model_run : str
        Key of the entry in ``model_dict`` to load.

    Returns
    -------
    object
        The unpickled object stored at the source.

    Raises
    ------
    KeyError
        If ``model_run`` is not a key of ``model_dict``.
    FileNotFoundError
        If the entry is neither an existing file nor an ``http`` address.
    requests.HTTPError
        If the download returns an error status.
    """
    source = str(model_dict[model_run])
    # NOTE(security): unpickling executes arbitrary code -- only load models
    # from locations you trust.
    if Path(source).is_file():
        # Load data (deserialize); the context manager closes the handle.
        with open(source, "rb") as handle:
            return pkl.load(handle)
    if 'http' in source:
        print('Loading model from web source')
        r_get = requests.get(source)
        # Fail loudly instead of pickling an HTML error page.
        r_get.raise_for_status()
        fpath = './model_temp.sav'
        with open(fpath, 'wb') as handle:
            handle.write(r_get.content)
        with open(fpath, "rb") as handle:
            return pkl.load(handle)
    raise FileNotFoundError(
        "Model '{}' could not be loaded: '{}' is neither an existing file nor a web address".format(model_run, source)
    )

def load_adatas(adatas_dict,data_merge, data_key_use,QC_normalise):
    """Read one AnnData object, or read and merge every registered one.

    Parameters
    ----------
    adatas_dict : dict
        Maps dataset keys to a local ``.h5ad`` path or an ``https`` URL.
    data_merge : bool
        If truthy, every entry of ``adatas_dict`` is read and the objects are
        concatenated on their shared genes (inner join).
    data_key_use : str
        Key of the single dataset to read when ``data_merge`` is falsy.
    QC_normalise : bool
        If True, apply basic QC filtering, per-cell normalisation to 1e4 counts
        and log1p to the single dataset. Not applied in merge mode, which
        returns before this step (unchanged from the previous behaviour).

    Returns
    -------
    (dict, AnnData) in merge mode: the individual objects and the merged one.
    AnnData otherwise.
    """
    if data_merge:
        # Read
        adatas = {}
        for dataset in adatas_dict.keys():
            if 'https' in adatas_dict[dataset]:
                print('Loading anndata from web source')
                # One cache file per dataset: `backup_url` is only used when the
                # target file is missing, so a shared file name would make every
                # web dataset resolve to the first one downloaded.
                adatas[dataset] = sc.read('./temp_adata_{}.h5ad'.format(dataset),backup_url=adatas_dict[dataset])
            else:
                adatas[dataset] = sc.read(adatas_dict[dataset])
            adatas[dataset].var_names_make_unique()
            adatas[dataset].obs['dataset_merge'] = dataset
        loaded = list(adatas.values())
        adata = loaded[0].concatenate(loaded[1:],join='inner')
        return adatas, adata

    if 'https' in adatas_dict[data_key_use]:
        print('Loading anndata from web source')
        # NOTE: an existing ./temp_adata.h5ad is reused rather than re-downloaded.
        adata = sc.read('./temp_adata.h5ad',backup_url=adatas_dict[data_key_use])
    else:
        adata = sc.read(adatas_dict[data_key_use])
    if QC_normalise == True:
        print('option to apply standardisation to data detected, performing basic QC filtering')
        sc.pp.filter_cells(adata, min_genes=200)
        sc.pp.filter_genes(adata, min_cells=3)
        sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
        sc.pp.log1p(adata)

    return adata

# resource usage logger
class DisplayCPU(threading.Thread):
    """Background thread that tracks peak CPU usage of the current process.

    Call ``start()`` before the work to benchmark and ``stop()`` afterwards.
    ``stop()`` returns the ``(current, peak)`` memory traced by ``tracemalloc``.
    Once the thread has finished (``join()``), ``peak_cpu`` holds the highest
    CPU percentage sampled and ``peak_cpu_per_core`` the same value divided by
    the number of logical cores.
    """

    # Seconds over which each CPU sample is measured. A blocking interval gives
    # a meaningful percentage and keeps the monitor itself from spinning a core.
    sample_interval = 0.1

    def run(self):
        # Publish results up-front so the attributes exist even if no sample
        # is ever taken.
        self.peak_cpu = 0
        self.peak_cpu_per_core = 0
        # Honour a stop() that arrived before the thread began executing.
        self.running = getattr(self, 'running', True)
        if not self.running:
            return
        tracemalloc.start()
        starting, starting_peak = tracemalloc.get_traced_memory()
        self.starting = starting
        currentProcess = psutil.Process()
        n_cores = psutil.cpu_count() or 1  # cpu_count() may return None
        while self.running:
            cpu = currentProcess.cpu_percent(interval=self.sample_interval)
            # track the peak utilization of the process
            if cpu > self.peak_cpu:
                self.peak_cpu = cpu
                self.peak_cpu_per_core = cpu / n_cores

    def stop(self):
        """Signal the sampling loop to end and return ``(current, peak)`` traced memory."""
        self.running = False
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return current, peak
    
# projection module
def reference_projection(adata, model, dyn_std,partial_scale):
    """Project a trained logistic-regression model onto ``adata``.

    Reads the module-level ``train_x_partition`` to decide which representation
    to predict on: ``'X'`` (gene space) or a key of ``adata.obsm``.

    Parameters
    ----------
    adata : AnnData
        Query data.
    model : object
        Either a mapping holding the classifier under ``'Model'`` or the
        classifier itself. In the ``'X'`` case the classifier is modified in
        place: ``n_features_in_``, ``features`` and ``coef_`` are sliced and
        re-ordered to the genes shared with ``adata``.
    dyn_std : float
        A call is marked ``'_uncertain'`` when the highest class probability of
        a cell is below ``mean + dyn_std * std`` of that cell's probabilities.
    partial_scale : bool
        If True, scale the sliced matrix with a ``StandardScaler`` (no
        centring) fitted incrementally in chunks of 1000 rows.

    Returns
    -------
    tuple
        ``(pred_out, train_x, model_lr, adata_temp)``: the prediction table
        (``predicted``, one probability column per class, ``confident_calls``),
        the matrix predicted on, the classifier, and a namespace holding the
        sliced ``var``/``X``/``obs`` (left empty for ``obsm`` partitions).

    Raises
    ------
    ValueError
        If no model feature is present in ``adata`` or the partition is
        neither ``'X'`` nor a key of ``adata.obsm``.
    """

    # Lightweight namespace for the sliced data; avoids building a full AnnData.
    class adata_temp:
        pass
    print('Determining model flavour')
    try:
        model_lr =  model['Model']
        print('Consuming celltypist model')
    except (TypeError, KeyError, IndexError):
        # a bare estimator is not subscriptable
        print('Consuming non-celltypist model')
        model_lr =  model
    print(model_lr)

    if train_x_partition == 'X':
        print('Matching reference genes in the model')
        k_x = np.isin(list(adata.var.index), list(model_lr.features))
        if k_x.sum() == 0:
            raise ValueError(f"🛑 No features overlap with the model. Please provide gene symbols")
        print(f"🧬 {k_x.sum()} features used for prediction")
        #slicing adata
        k_x_idx = np.where(k_x)[0]
        adata_sliced = adata[:,k_x_idx]
        adata_temp.var = adata_sliced.var
        adata_temp.X = adata_sliced.X
        adata_temp.obs = adata_sliced.obs
        # Position of every retained gene inside the model's feature list;
        # kept as an (n, 1) column, which later code relies on.
        lr_idx = pd.DataFrame(model_lr.features, columns=['features']).reset_index().set_index('features').loc[list(adata_temp.var.index)].values

        # slice and reorder model (in place)
        model_lr.n_features_in_ = lr_idx.size
        model_lr.features = np.array(model_lr.features)[lr_idx]
        model_lr.coef_ = np.squeeze(model_lr.coef_[:,lr_idx])

        if partial_scale == True:
            from sklearn.preprocessing import StandardScaler
            print('scaling input data, default option is to use incremental learning and fit in mini bulks!')
            # Partial scaling alg
            scaler = StandardScaler(with_mean=False)
            n = adata_temp.X.shape[0]  # number of rows
            batch_size =  1000  # number of rows in each call to partial_fit
            index = 0  # helper-var
            while index < n:
                partial_size = min(batch_size, n - index)  # needed because last loop is possibly incomplete
                partial_x = adata_temp.X[index:index+partial_size]
                scaler.partial_fit(partial_x)
                index += partial_size
            adata_temp.X = scaler.transform(adata_temp.X)

    # model projections
    print('Starting reference projection!')
    if train_x_partition == 'X':
        train_x = adata_temp.X
    elif train_x_partition in list(adata.obsm.keys()):
        print('{low_dim: this partition modality is still under development!}')
        train_x = adata.obsm[train_x_partition]
    else:
        print('{this partition modality is still under development!}')
        raise ValueError(
            "train_x_partition must be 'X' or a key of adata.obsm, got {!r}".format(train_x_partition)
        )
    cell_index = list(adata.obs.index)
    pred_out = pd.DataFrame(model_lr.predict(train_x),columns = ['predicted'],index = cell_index)
    proba =  pd.DataFrame(model_lr.predict_proba(train_x),columns = model_lr.classes_,index = cell_index)
    pred_out = pred_out.join(proba)

    # Simple dynamic confidence calling. Statistics are taken over the
    # probability columns only, so the string label columns never enter the
    # row-wise max/mean/std.
    pred_out['confident_calls'] = pred_out['predicted']
    uncertain = (proba.max(axis=1) < (proba.mean(axis=1) + (dyn_std*proba.std(axis=1)))).values
    pred_out.loc[uncertain,'confident_calls'] = pred_out.loc[uncertain,'confident_calls'].astype(str) + '_uncertain'
    return(pred_out,train_x,model_lr,adata_temp)

def freq_redist_68CI(adata,clusters_reassign):
    """Redistribute predicted labels within clusters by prediction frequency.

    Works on the module-level ``pred_out`` table (modified in place) and is a
    no-op returning ``None`` when the module-level ``freq_redist`` is False.

    For every cluster in ``adata.obs[clusters_reassign]`` the predicted labels
    are counted. When a cluster has more than one predicted label, only labels
    whose count exceeds ``int(mean) + 1 * std`` of the counts are kept (all are
    kept if none exceeds it). Cells whose own prediction is a kept label keep
    it; every other cell of the cluster takes the most frequent kept label.
    Finally, consensus labels carried by two cells or fewer are replaced by the
    most frequent prediction within the clusters they occur in.

    Parameters
    ----------
    adata : AnnData
        Object whose ``obs`` holds the cluster column.
    clusters_reassign : str
        Name of the cluster column in ``adata.obs``.

    Returns
    -------
    pandas.DataFrame or None
        ``pred_out`` with the added ``clusters_reassign`` and
        ``'consensus_clus_prediction'`` columns.
    """
    if freq_redist == False:
        return None
    print('Frequency redistribution commencing')
    cluster_prediction = "consensus_clus_prediction"
    lr_predicted_col = 'predicted'
    lm = 1 # lambda value
    pred_out[clusters_reassign] = adata.obs[clusters_reassign].astype(str)
    pred_out[cluster_prediction] = pred_out[clusters_reassign].astype(str)
    for z in pred_out[clusters_reassign].unique():
        in_cluster = pred_out[clusters_reassign] == z
        # Work on the counts as a Series: the column name of a value_counts
        # frame differs between pandas versions.
        counts = pred_out.loc[in_cluster, lr_predicted_col].value_counts()
        # Look for classifications > 68CI
        if len(counts) > 1:
            cutoff = int(int(counts.mean()) + (counts.std()*lm))
            above_cutoff = counts[counts > cutoff]
            if len(above_cutoff) >= 1:
                counts = above_cutoff
        freq_arranged = counts.index
        # Make the cluster assignment first (most frequent surviving label)
        pred_out.loc[in_cluster, cluster_prediction] = freq_arranged[0]
        # Cells whose own prediction survived keep that prediction
        keeps_own = (in_cluster & pred_out[lr_predicted_col].isin(freq_arranged)).values
        pred_out.loc[keeps_own, cluster_prediction] = pred_out.loc[keeps_own, lr_predicted_col].astype(str).values
    label_sizes = pred_out[cluster_prediction].value_counts()
    reassign = list(label_sizes[label_sizes <= 2].index)
    if reassign:
        is_rare = pred_out[cluster_prediction].isin(reassign)
        affected_clusters = list(pred_out.loc[is_rare, clusters_reassign])
        top_prediction = pred_out.loc[pred_out[clusters_reassign].isin(affected_clusters), lr_predicted_col].value_counts().head(1)
        # Replace whole labels only; a substring replace would also rewrite
        # unrelated labels that merely contain a rare label's text.
        pred_out.loc[is_rare.values, cluster_prediction] = str(top_prediction.index[0])
    return pred_out

### Feature importance notes
#- If feature x increases by one unit, the odds of the class are multiplied by e raised to the power of that feature's weight (coefficient). Applying this rule to every weight gives a per-feature importance.
#- Importance is therefore calculated as Euler's number raised to the power of each coefficient.
#- To sum up: a one-unit increase in feature x multiplies the odds of belonging to a given class (e.g. the versicolor class in the iris example) by x[importance], all other features being held constant.

#- For low-dim, we look at the distribution of e^coef per class, we extract the 


# class coef_extract:
#     def __init__(self, model,features, pos):
# #         self.w = list(itertools.chain(*(model.coef_[pos]).tolist())) #model.coef_[pos]
#         self.w = model.coef_[class_pred_pos]
#         self.features = features 

def long_format_features(top_loadings):
    """Reshape the wide per-class table into one long table.

    ``top_loadings`` carries, per class, the columns ``<class>_feature``,
    ``<class>_e^coef`` and ``<class>_coef``. The result has one row per
    (class, rank) with the columns ``class``, ``feature``, ``e^coef`` and ``coef``.
    """
    def _melt_columns_ending(suffix):
        # stack every column carrying this suffix into (variable, value) pairs
        selected = top_loadings.columns.str.endswith(suffix)
        return pd.melt(top_loadings.loc[:, selected])

    exp_coefs = _melt_columns_ending("_e^coef")
    names = _melt_columns_ending("_feature")
    raw_coefs = _melt_columns_ending("_coef")

    # strip the suffix so that `variable` holds the bare class name
    names = names.replace(regex=r'_feature', value='')
    names = names.rename(columns={"variable": "class", "value": "feature"})

    exp_coefs = exp_coefs.drop(columns=["variable"]).rename(columns={"value": "e^coef"})
    raw_coefs = raw_coefs.drop(columns=["variable"]).rename(columns={"value": "coef"})
    return pd.concat([names, exp_coefs, raw_coefs], axis=1)

def model_feature_sf(long_format_feature_importance, coef_use):
    """Flag features whose importance is unusually high within their class.

    For each class, the values of ``coef_use`` are compared against a normal
    distribution centred on the class median with scale ``1.4826 * MAD``; the
    survival function (upper tail) gives the p-value.

    The input frame is modified in place and returned with two added columns:
    ``<coef_use>_pval`` (float) and ``is_significant_sf`` (p-value < 0.05).
    A class whose MAD is 0 gets NaN p-values and is therefore never flagged.

    Parameters
    ----------
    long_format_feature_importance : pandas.DataFrame
        Long table holding at least the columns ``class`` and ``coef_use``.
    coef_use : str
        Column to test, e.g. ``'e^coef'``.
    """
    from scipy.stats import norm  # explicit: `import scipy` alone need not load scipy.stats

    pval_col = str(coef_use) + '_pval'
    # float placeholder keeps the column numeric (a 'NaN' string makes it object dtype)
    long_format_feature_importance[pval_col] = np.nan
    for class_lw in long_format_feature_importance['class'].unique():
        in_class = long_format_feature_importance['class'].isin([class_lw]).values
        values = long_format_feature_importance.loc[in_class, coef_use]
        med = np.median(values)
        mad = np.median(np.absolute(values - med))
        # Survival function scaled by 1.4826 of MAD (approx norm)
        # Written straight into the parent frame rather than via a filtered copy.
        long_format_feature_importance.loc[in_class, pval_col] = norm.sf(values, loc=med, scale=1.4826*mad)
    long_format_feature_importance['is_significant_sf'] = long_format_feature_importance[pval_col] < 0.05
    return long_format_feature_importance
# Apply SF to e^coeff mat data
#         pval_mat = pd.DataFrame(columns = mat.columns)
#         for class_lw in mat.index:
#             df_loadings = mat.loc[class_lw]
#             U = np.mean(df_loadings)
#             std = np.std(df_loadings)
#             med =  np.median(df_loadings)
#             mad = np.median(np.absolute(df_loadings - np.median(df_loadings)))
#             pvals = scipy.stats.norm.sf(df_loadings, loc=med, scale=1.96*U)

class estimate_important_features: # This calculates feature effect sizes of the model
    """Rank the features of a fitted linear classifier per class.

    Attributes set by the constructor
    ---------------------------------
    euler_pow_mat : pandas.DataFrame
        ``e ** coef`` for every (class, feature) pair; classes on the rows.
    top_n_features : pandas.DataFrame
        Wide table holding, per class, the ``top_n`` features with the largest
        ``e^coef`` in the columns ``<class>_feature``, ``<class>_e^coef`` and
        ``<class>_coef``.
    to_n_features_long : pandas.DataFrame
        Long version of ``top_n_features`` with p-values added by
        ``model_feature_sf``. The name is kept for existing callers;
        ``top_n_features_long`` refers to the same object.

    If the model carries no ``features`` attribute, features are named by
    position and that list is stored on ``model.features``.
    """
    def __init__(self, model, top_n):
        print('Estimating feature importance')
        # get feature names; ravel() also flattens the (n, 1) layout that the
        # projection step stores on the model
        recorded = getattr(model, 'features', None)
        if recorded is None:
            warnings.warn('no features recorded in data, naming features by position')
            print('if low-dim lr was submitted, run linear decoding function to obtain true feature set')
            model_features = list(range(0,model.coef_.shape[1]))
            model.features = model_features
        else:
            model_features = list(np.asarray(recorded).ravel())
        print('Calculating the Euler number to the power of coefficients')
        impt_ = np.exp(model.coef_)
        self.euler_pow_mat = pd.DataFrame(impt_,columns = model_features,index = list(model.classes_))
        self.top_n_features = pd.DataFrame(index = list(range(0,top_n)))
        # estimate per class feature importance

        print('Estimating feature importance for each class')
        mat = self.euler_pow_mat
        for class_pred_pos, class_pred in enumerate(mat.index):
            temp_mat = pd.DataFrame({
                'feature': list(mat.columns),
                'e^coef': mat.iloc[class_pred_pos].values,
                'coef': model.coef_[class_pred_pos],
            })
            temp_mat = temp_mat.sort_values(by = ['e^coef'], ascending=False).reset_index(drop=True)
            temp_mat.columns = str(class_pred)+ "_" + temp_mat.columns
            self.top_n_features = pd.concat([self.top_n_features,temp_mat.head(top_n)], join="inner",ignore_index = False, axis=1)
        # Build the long table once, after every class has been added.
        self.to_n_features_long = model_feature_sf(long_format_features(self.top_n_features),'e^coef')
        self.top_n_features_long = self.to_n_features_long
    
    # plot class-wise features
def model_class_feature_plots(top_loadings, classes, comps):
    """Plot, per class, the distribution of feature importances.

    For every class in ``classes`` a histogram of ``top_loadings[comps]`` is
    drawn with the median marked in blue and every significant value
    (``<comps>_pval`` < 0.05) marked in red. The significant features are
    listed in a table to the right, and their number is printed.

    The first class is drawn on the current figure, so a figure prepared by
    the caller is still used; each further class gets a figure of its own
    instead of being drawn over the previous one.

    Parameters
    ----------
    top_loadings : pandas.DataFrame
        Long-format table holding ``class``, ``feature``, ``comps`` and
        ``<comps>_pval`` columns.
    classes : iterable
        Class names to plot.
    comps : str
        Importance column to plot, e.g. ``'e^coef'``.
    """
    import matplotlib.pyplot as plt
    pval_col = str(comps) + '_pval'
    for position, class_lw in enumerate(classes):
        if position > 0:
            plt.figure()
        df_loadings = top_loadings[top_loadings['class'].isin([class_lw])]
        significant = df_loadings.loc[df_loadings[pval_col] < 0.05, comps]
        plt.hist(df_loadings[comps])
        for i in significant.unique():
            plt.axvline(x=i,color='red')
        med = np.median(df_loadings[comps])
        plt.axvline(x=med,color='blue')
        plt.xlabel('feature_importance', fontsize=12)
        plt.title(class_lw)
        print(len(significant))
        if significant.empty:
            # plt.table cannot be built from zero rows
            continue
        #Plot feature ranking
        plot_loading = pd.DataFrame(significant.sort_values(ascending=False))
        table = plt.table(cellText=plot_loading.values,colWidths = [1]*len(plot_loading.columns),
        rowLabels= list(df_loadings.loc[plot_loading.index, 'feature']),
        colLabels=plot_loading.columns,
        cellLoc = 'center', rowLoc = 'center',
        loc='right', bbox=[1.4, -0.05, 0.5,1])
        table.scale(1, 2)
        table.set_fontsize(10)
        
def report_f1(model,train_x, train_label):
    """Print a classification report and plot a confusion-matrix heatmap.

    The heatmap is normalised per column (each predicted class sums to 100),
    and the precision/recall/F1 table is drawn below it.

    Parameters
    ----------
    model : fitted classifier exposing ``predict`` and ``classes_``
    train_x : array-like
        Feature matrix to predict on.
    train_label : array-like
        True labels for ``train_x``.
    """
    from sklearn import metrics
    import seaborn as sn
    import pandas as pd
    from sklearn.metrics import confusion_matrix
    import matplotlib.pyplot as plt

    # Predict once; the result feeds the table, the matrix and the printout.
    predicted = model.predict(train_x)

    # Report Precision score
    metric = pd.DataFrame((metrics.classification_report(train_label, predicted, digits=2,output_dict=True))).T
    # Fix the label order to the model's classes so that rows/columns always
    # match the index assigned below, even when a class is absent from the data.
    cm = confusion_matrix(train_label, predicted, labels=model.classes_)
    df_cm = pd.DataFrame(cm, index = model.classes_,columns = model.classes_)
    df_cm = (df_cm / df_cm.sum(axis=0))*100
    plt.figure(figsize = (20,15))
    sn.set(font_scale=1) # for label size
    pal = sn.diverging_palette(240, 10, n=10)
    #Plot precision recall and recall
    table = plt.table(cellText=metric.values,colWidths = [1]*len(metric.columns),
    rowLabels=metric.index,
    colLabels=metric.columns,
    cellLoc = 'center', rowLoc = 'center',
    loc='bottom', bbox=[0.25, -0.6, 0.5, 0.3])
    table.scale(1, 2)
    table.set_fontsize(10)

    sn.heatmap(df_cm, annot=True, annot_kws={"size": 16},cmap=pal) # font size
    print(metrics.classification_report(train_label, predicted, digits=2))

def subset_top_hvgs(adata_lognorm, n_top_genes):
    """Keep the genes with the highest normalised dispersion.

    Parameters
    ----------
    adata_lognorm : AnnData
        Object whose ``var`` holds a ``'dispersions_norm'`` column.
    n_top_genes : int
        Number of genes to keep. If it exceeds the number of genes with a
        finite dispersion, all of those genes are kept.

    Returns
    -------
    ``adata_lognorm`` sliced to the genes whose dispersion is at least the
    cut-off (ties at the cut-off are all kept; genes with a NaN dispersion are
    dropped).

    Raises
    ------
    ValueError
        If no gene has a finite ``'dispersions_norm'`` value.
    """
    all_dispersions = adata_lognorm.var['dispersions_norm'].values
    ranked = all_dispersions.astype('float32')
    ranked = ranked[~np.isnan(ranked)]
    if ranked.size == 0:
        raise ValueError("no finite 'dispersions_norm' values to rank genes by")
    ranked = np.sort(ranked)[::-1]  # descending

    if n_top_genes > ranked.size:
        print('n_top_genes ({}) exceeds the {} genes with a dispersion estimate, keeping all of them'.format(n_top_genes, ranked.size))
        n_top_genes = ranked.size

    disp_cut_off = ranked[n_top_genes - 1]
    gene_subset = all_dispersions >= disp_cut_off
    return(adata_lognorm[:,gene_subset])

def prep_scVI(adata, 
              n_hvgs = 5000,
              remove_cc_genes = True,
              remove_tcr_bcr_genes = False
             ):
    """Drop optional gene sets, then keep the top ``n_hvgs`` variable genes.

    Cell-cycle genes are removed when ``remove_cc_genes`` is set, IG and TCR
    genes when ``remove_tcr_bcr_genes`` is set. Relies on ``panfetal_utils``
    and ``genes`` being available at module level.
    """
    # collect the names of the gene sets to strip, in removal order
    genesets_to_drop = []
    if remove_cc_genes:
        genesets_to_drop.append('cc_genes')
    if remove_tcr_bcr_genes:
        genesets_to_drop.extend(['IG_genes', 'TCR_genes'])

    for geneset_name in genesets_to_drop:
        adata = panfetal_utils.remove_geneset(adata, getattr(genes, geneset_name))

    ## HVG selection
    return subset_top_hvgs(adata, n_top_genes=n_hvgs)

# Modified LR train module, does not work with low-dim by default anymore, please use low-dim adapter
# Modified LR train module, does not work with low-dim by default anymore, please use low-dim adapter
def LR_train(adata, train_x, train_label, penalty='elasticnet', sparcity=0.2,max_iter=200,l1_ratio =0.2,tune_hyper_params =False,n_splits=5, n_repeats=3,l1_grid = [0.1,0.2,0.5,0.8], c_grid = [0.1,0.2,0.4,0.6],sketch_obsm =None):
    """Fit a one-vs-rest logistic regression on ``adata``.

    Reads the module-level ``thread_num`` for ``n_jobs``.

    Parameters
    ----------
    adata : AnnData
        Training data.
    train_x : str
        ``'X'`` to train on ``adata.X`` or a key of ``adata.obsm``.
    train_label : str
        Column of ``adata.obs`` holding the labels.
    penalty : {'l1', 'l2', 'elasticnet'}
    sparcity : float
        Passed as ``C`` (inverse regularisation strength).
    max_iter : int
    l1_ratio : float
        Only used with the elasticnet penalty.
    tune_hyper_params : bool
        If True, ``C`` and ``l1_ratio`` are replaced by the best values found
        by ``tune_lr_model``.
    n_splits, n_repeats, l1_grid, c_grid :
        Forwarded to ``tune_lr_model``. The default lists are never modified.
    sketch_obsm : str or None
        If given, the data is first sub-sampled with ``sketch_data``; should
        that fail, a warning is issued and the full data is used.

    Returns
    -------
    The fitted ``LogisticRegression`` with an added ``features`` attribute set
    to ``adata.var.index`` (also when training on an ``obsm`` partition).

    Raises
    ------
    ValueError
        If ``train_x`` is neither ``'X'`` nor a key of ``adata.obsm``.
    """
    if tune_hyper_params == True:
        train_labels = train_label
        results,adata_tuned = tune_lr_model(adata, train_x_partition = train_x, random_state = 42,  train_labels = train_labels, n_splits=n_splits, n_repeats=n_repeats,l1_grid = l1_grid, c_grid = c_grid,sketch_obsm = sketch_obsm)
        print('hyper_params tuned')
        sparcity = results.best_params_['C']
        l1_ratio = results.best_params_['l1_ratio']

    if sketch_obsm is not None:
        #sketch data (best effort: fall back to the full data, but say so)
        try:
            adata = sketch_data(adata, train_x_partition = train_x, random_state = 42,  train_labels = train_label,sketch_obsm = sketch_obsm)
        except Exception as err:
            warnings.warn('sketching failed ({!r}); training on the full data instead'.format(err))

    if (penalty == "l1"):
        # liblinear implements the dual formulation for the l2 penalty only,
        # so the primal problem has to be solved here.
        lr = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  max_iter, dual = False, solver = 'liblinear',multi_class = 'ovr', n_jobs=thread_num ) # one-vs-rest
    elif (penalty == "elasticnet"):
        lr = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  max_iter, dual=False,solver = 'saga',l1_ratio=l1_ratio,multi_class = 'ovr', n_jobs=thread_num)
    else:
        lr = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  max_iter, n_jobs=thread_num)

    # Define training parameters
    if train_x == 'X':
        train_label = adata.obs[train_label].values
        train_x = adata.X
    elif train_x in adata.obsm.keys():
        train_label = adata.obs[train_label].values
        train_x = adata.obsm[train_x]
    else:
        raise ValueError("train_x must be 'X' or a key of adata.obsm, got {!r}".format(train_x))
    # Train predictive model using user defined partition labels (train_x ,train_label)
    model = lr.fit(train_x, train_label)
    model.features = np.array(adata.var.index)
    return model

def tune_lr_model(adata, train_x_partition = 'X', random_state = 42,  train_labels = None, n_splits=5, n_repeats=3,l1_grid = [0.05,0.2,0.5,0.8], c_grid = [0.05,0.2,0.4,0.6],sketch_obsm = None):
    """Grid-search ``C`` (and ``l1_ratio``) on a sub-sample of ``adata``.

    Reads the module-level ``penalty``, ``sparcity`` and ``l1_ratio`` to build
    the base estimator.

    Sub-sampling: if ``train_x_partition`` or ``sketch_obsm`` is a key of
    ``adata.obsm``, cells are picked with ``bless`` on that representation;
    otherwise 10% of the cells are drawn at random using ``random_state``.

    Parameters
    ----------
    adata : AnnData
    train_x_partition : str
        ``'X'`` or a key of ``adata.obsm``.
    random_state : int
        Seed for the sub-sampling and the cross-validation splits.
    train_labels : str or None
        Column of ``adata.obs`` with the labels; if None, leiden clusters of
        the sub-sample are used instead.
    n_splits, n_repeats : int
        Repeated k-fold configuration.
    l1_grid, c_grid : list
        Candidate values; ``l1_grid`` is only searched with the elasticnet
        penalty. The default lists are never modified.
    sketch_obsm : str or None

    Returns
    -------
    (GridSearchCV, AnnData)
        The fitted search and the sub-sampled data it was run on.
    """
    from sklearn.gaussian_process.kernels import RBF
    from sklearn.model_selection import RepeatedKFold
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import GridSearchCV

    # If latent rep is provided, randomly sample data in spatially aware manner for initialisation
    r = np.random.RandomState(random_state)
    if train_x_partition in adata.obsm.keys():
        import bless as bless  # only needed when a latent representation is used
        latent = adata.obsm[train_x_partition][:]
        lvg = bless.bless(latent, RBF(length_scale=20), lam_final = 2, qbar = 2, random_state = r, H = 10, force_cpu=True)
        adata_tuning = adata[lvg.idx]
        print('Sketched data is {} long'.format(len(adata_tuning.obs)))
        # features must cover the same (sketched) cells as the labels
        tune_train_x = adata_tuning.obsm[train_x_partition]

    elif sketch_obsm in adata.obsm.keys():
        import bless as bless
        sketch_obsm_id = adata.obsm[sketch_obsm][:]
        lvg = bless.bless(sketch_obsm_id, RBF(length_scale=20), lam_final = 2, qbar = 2, random_state = r, H = 10, force_cpu=True)
        adata_tuning = adata[lvg.idx]
        print('Sketched data is {} long'.format(len(adata_tuning.obs)))
        tune_train_x = adata_tuning.X
    else:
        print('no latent representation provided, random sampling instead')
        prop = 0.1
        n_ixs = int(len(adata.obs) * prop)
        # drawn from the seeded generator so that `random_state` is honoured
        random_vertices = r.choice(len(adata.obs), size=n_ixs, replace=False)
        adata_tuning = adata[random_vertices]
        tune_train_x = adata_tuning.X

    if train_labels is not None:
        tune_train_label = adata_tuning.obs[train_labels]
    else:
        try:
            print('no training labels provided, defaulting to unsuperived leiden clustering, updates will change this to voronoi greedy sampling')
            sc.tl.leiden(adata_tuning)
        except (KeyError, ValueError):
            # no neighbour graph yet: build it on the sub-sample, then cluster
            print('no training labels provided, no neighbors, defaulting to unsuperived leiden clustering, updates will change this to voronoi greedy sampling')
            sc.pp.neighbors(adata_tuning, n_neighbors=15, n_pcs=50)
            sc.tl.leiden(adata_tuning)
        tune_train_label = adata_tuning.obs['leiden']
    ## tune regularization for logistic regression
    print('starting tuning loops')
    X = tune_train_x
    y = tune_train_label
    grid = dict()
    # define model
    cv = RepeatedKFold(n_splits=n_splits, n_repeats=n_repeats, random_state=random_state)
    if (penalty == "l1"):
        # liblinear implements the dual formulation for the l2 penalty only
        model = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  100, dual = False, solver = 'liblinear',multi_class = 'ovr', n_jobs=4 ) # one-vs-rest
    elif (penalty == "elasticnet"):
        model = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  100, dual=False,solver = 'saga',l1_ratio=l1_ratio,multi_class = 'ovr', n_jobs=4) # use multinomial class if probabilities are descrete
        grid['l1_ratio'] = l1_grid
    else:
        model = LogisticRegression(penalty = penalty, C = sparcity, max_iter =  100, n_jobs=4)
    grid['C'] = c_grid
    # define search; a classification metric is required because the labels
    # are nominal classes (an absolute error is undefined for string labels)
    search = GridSearchCV(model, grid, scoring='f1_weighted', cv=cv, n_jobs=-1)
    # perform the search
    results = search.fit(X, y)
    # summarize
    print('Weighted F1: %.3f' % results.best_score_)
    print('Config: %s' % results.best_params_)
    return results , adata_tuning

def prep_training_data(adata_temp,feat_use,batch_key, model_key, batch_correction=False, var_length = 7500,penalty='elasticnet',sparcity=0.2,max_iter = 200,l1_ratio = 0.1,partial_scale=True,train_x_partition ='X',theta = 3,tune_hyper_params=False,sketch_obsm = None ):
    """Select features, optionally scale / batch-correct, then train the LR model.

    Steps, in order:
      1. highly-variable-gene selection and a 100-component PCA (plotted),
      2. feature subsetting through ``subset_top_hvgs(adata_temp, var_length)``,
      3. optional incremental scaling of ``adata_temp.X`` (``partial_scale``),
      4. optional PCA recomputation and batch correction when training on an
         ``obsm`` partition,
      5. model fitting through ``LR_train``.

    Parameters
    ----------
    adata_temp : AnnData-like object holding the training data.
    feat_use : name of the ``.obs`` column passed to ``LR_train`` as the label.
    batch_key : ``.obs`` column name (str), a list of names, or None. Forwarded
        to ``sc.pp.highly_variable_genes`` and used for batch correction.
    model_key : accepted for interface compatibility; not used in this function.
    batch_correction : ``"Harmony"``, ``"BBKNN"`` or anything else (no correction).
    var_length : forwarded to ``subset_top_hvgs``.
    penalty, sparcity, max_iter, l1_ratio, tune_hyper_params, sketch_obsm :
        forwarded unchanged to ``LR_train``.
    partial_scale : when truthy, scale ``X`` with a ``StandardScaler`` fitted in
        row chunks.
    train_x_partition : ``'X'`` or the name of an ``obsm`` entry to train on.
    theta : forwarded to ``harmonypy.run_harmony``.

    Returns
    -------
    The object returned by ``LR_train`` with an added ``features`` attribute
    (list of ``adata_temp.var.index`` after subsetting).

    Raises
    ------
    ValueError
        If batch correction is requested without a ``batch_key``, or BBKNN is
        requested with more than one batch key.
    """
    print('performing highly variable gene selection')
    sc.pp.highly_variable_genes(adata_temp, batch_key = batch_key, subset=False)
    #temp inclusion
    sc.pp.pca(adata_temp, n_comps=100, use_highly_variable=True, svd_solver='arpack')
    sc.pl.pca_variance_ratio(adata_temp, log=True,n_pcs=100)

    adata_temp = subset_top_hvgs(adata_temp,var_length)
    #scale the input data
    if partial_scale:
        # Imported locally so the function does not depend on another cell
        # having imported StandardScaler first.
        from sklearn.preprocessing import StandardScaler
        print('scaling input data, default option is to use incremental learning and fit in mini bulks!')
        # with_mean=False: only the variance is scaled, no centring.
        scaler = StandardScaler(with_mean=False)
        n_rows = adata_temp.X.shape[0]
        batch_size = 1000  # number of rows in each call to partial_fit
        for start in range(0, n_rows, batch_size):
            # The final slice is simply shorter when n_rows is not a multiple of batch_size.
            scaler.partial_fit(adata_temp.X[start:start + batch_size])
        adata_temp.X = scaler.transform(adata_temp.X)

    # NOTE(review): the branch runs when the partition IS present in obsm, while the
    # message says it is not. 'X_pca' is always present here because of the PCA
    # above, so the condition is kept as-is; confirm which of the two is intended.
    if (train_x_partition != 'X') and (train_x_partition in adata_temp.obsm.keys()):
        print('train partition is not in OBSM, defaulting to PCA')
        # Now compute PCA (on the subsetted, optionally scaled data)
        sc.pp.pca(adata_temp, n_comps=100, use_highly_variable=True, svd_solver='arpack')
        sc.pl.pca_variance_ratio(adata_temp, log=True,n_pcs=100)

        # Batch correction options
        if batch_correction in ("Harmony", "BBKNN"):
            if batch_key is None:
                raise ValueError('batch_key must be provided when batch_correction is requested')
            # Normalise to a flat list of obs column names so that neither a bare
            # string nor a list ends up nested inside another list.
            batch_vars = [batch_key] if isinstance(batch_key, str) else list(batch_key)
            if len(batch_vars) == 1:
                # Single key: work on a dedicated copy of the column.
                adata_temp.obs['lr_batch'] = adata_temp.obs[batch_vars[0]]
                batch_vars = ['lr_batch']

        if batch_correction == "Harmony":
            import harmonypy as hm
            print("Commencing harmony")
            # Create hm subset
            adata_hm = adata_temp[:]
            # Set harmony variables
            data_mat = np.array(adata_hm.obsm["X_pca"])
            meta_data = adata_hm.obs
            # Run Harmony
            ho = hm.run_harmony(data_mat, meta_data, batch_vars, theta=theta)
            res = (pd.DataFrame(ho.Z_corr)).T
            res.columns = ['X{}'.format(i + 1) for i in range(res.shape[1])]
            # Keep the uncorrected PCA under "X_pca_back" and store the corrected
            # coordinates in "X_pca".
            adata_hm.obsm["X_pca_back"]= adata_hm.obsm["X_pca"][:]
            adata_hm.obsm["X_pca"] = np.array(res)
            adata_temp = adata_hm[:]
            del adata_hm
        elif batch_correction == "BBKNN":
            print("Commencing BBKNN")
            if len(batch_vars) != 1:
                raise ValueError('BBKNN batch correction requires exactly one batch key, got {}'.format(batch_vars))
            sc.external.pp.bbknn(adata_temp, batch_key=batch_vars[0], approx=True, metric='angular', copy=False, n_pcs=50, trim=None, n_trees=10, use_faiss=True, set_op_mix_ratio=1.0, local_connectivity=15)
        print("adata1 and adata2 are now combined and preprocessed in 'adata' obj - success!")

    # train model
    print('proceeding to train model')
    model = LR_train(adata_temp, train_x = train_x_partition, train_label=feat_use, penalty=penalty, sparcity=sparcity,max_iter=max_iter,l1_ratio = l1_ratio,tune_hyper_params = tune_hyper_params,sketch_obsm = sketch_obsm)
    # Record the feature names the model was trained against.
    model.features = list(adata_temp.var.index)
    return model

def regression_results(y_true, y_pred):
    """Print a summary of regression metrics for ``y_pred`` against ``y_true``.

    Every score is evaluated through the module-level ``metrics`` object before
    anything is printed, so an error in any score produces no partial report.
    Values are rounded to 4 decimals. Nothing is returned.
    """
    # Regression metrics
    ev_score = metrics.explained_variance_score(y_true, y_pred)
    mae_score = metrics.mean_absolute_error(y_true, y_pred)
    mse_score = metrics.mean_squared_error(y_true, y_pred)
    msle_score = metrics.mean_squared_log_error(y_true, y_pred)
    # The median absolute error is evaluated but not part of the printed report.
    metrics.median_absolute_error(y_true, y_pred)
    r2_score = metrics.r2_score(y_true, y_pred)

    # (label, value) pairs in the order they are reported.
    report = (
        ('explained_variance: ', ev_score),
        ('mean_squared_log_error: ', msle_score),
        ('r2: ', r2_score),
        ('MAE: ', mae_score),
        ('MSE: ', mse_score),
        ('RMSE: ', np.sqrt(mse_score)),
    )
    for label, value in report:
        print(label, round(value, 4))
    
def sketch_data(adata, train_x_partition = 'X', sketch_obsm = None, random_state = 42,  train_labels = None, batch_key = None):
    """Return a subset of ``adata`` chosen by ``bless.bless`` on a low-dimensional representation.

    The representation is picked in this order: ``adata.obsm[train_x_partition]``,
    then ``adata.obsm[sketch_obsm]``, then ``adata.obsm['X_pca']`` (computed
    in place on ``adata`` if it is missing).

    Parameters
    ----------
    adata : AnnData-like object to sub-sample.
    train_x_partition : preferred ``obsm`` key.
    sketch_obsm : fallback ``obsm`` key (may be None).
    random_state : seed for the ``numpy.random.RandomState`` handed to ``bless``.
    train_labels : accepted for interface compatibility; not used.
    batch_key : optional batch column forwarded to
        ``sc.pp.highly_variable_genes`` when a PCA has to be computed.

    Returns
    -------
    ``adata[lvg.idx]`` where ``lvg`` is the object returned by ``bless.bless``.
    """
    import bless as bless
    from sklearn.gaussian_process.kernels import RBF

    # If latent rep is provided, randomly sample data in spatially aware manner for initialisation
    r = np.random.RandomState(random_state)
    if train_x_partition in adata.obsm.keys():
        tune_train_x = adata.obsm[train_x_partition][:]
    elif sketch_obsm in adata.obsm.keys():
        tune_train_x = adata.obsm[sketch_obsm][:]
    else:
        print('No obsm partition detected! defaulting to PCA')
        if 'X_pca' not in adata.obsm.keys():
            print('performing highly variable gene selection')
            # Operate on the `adata` argument; the previous code referenced names
            # that do not exist in this function.
            sc.pp.highly_variable_genes(adata, batch_key = batch_key, subset=False)
            sc.pp.pca(adata, n_comps=100, use_highly_variable=True, svd_solver='arpack')
            sc.pl.pca_variance_ratio(adata, log=True,n_pcs=100)
        tune_train_x = adata.obsm['X_pca'][:]
    lvg = bless.bless(tune_train_x, RBF(length_scale=20), lam_final = 2, qbar = 2, random_state = r, H = 10, force_cpu=True)
    adata_tuning = adata[lvg.idx]
    # Both counts belong inside format(); previously the second one was passed to print().
    print('sketched partition is {}, original is {}'.format(len(lvg.idx), len(adata.obs)))
    return adata_tuning

# Read in query data for projection

In [None]:
# Switch read by the next cell: True runs the training branch, False loads an existing model.
train_model = False

In [None]:
import time

if train_model:
    # Training mode: load the reference data, train a model while logging
    # resource usage, pickle it to a file named after `model_key`, then reload
    # the data for projection.
    from sklearn.preprocessing import StandardScaler
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
    print('adata_loaded')
    t0 = time.time()
    display_cpu = DisplayCPU()
    display_cpu.start()
    try:
        model_trained = prep_training_data(feat_use = feat_use,
        adata_temp = adata,
        train_x_partition = train_x_partition,
        model_key = model_key + '_lr_model',
        batch_correction = False,
        var_length = 7500,
        batch_key = None, #batch_key,
        penalty='elasticnet', # can be ["l1","l2","elasticnet"],
        sparcity=sparcity, #If using LR without optimisation, this controls the sparsity in model
        max_iter = 1000, #Increase if experiencing max iter issues
        l1_ratio = l1_ratio, #If using elasticnet without optimisation, this controls the ratio between l1 and l2)
        partial_scale = False, #partial_scale,
        tune_hyper_params = True, # Current implementation is very expensive, intentionally made rigid for now
        sketch_obsm = 'X_pca'
        )
        filename = model_key
        # Context manager guarantees the pickle file is flushed and closed.
        with open(filename, 'wb') as model_file:
            pkl.dump(model_trained, model_file)
        models[model_key] = model_key
    finally: # always stop the usage logger, even if training failed
        current, peak = display_cpu.stop()
        t1 = time.time()
        time_s = t1-t0
        print('training complete!')
        time.sleep(3)
        print('training time was ' + str(time_s) + ' seconds')
        print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
        print(f"starting memory usage is" +'' + str(display_cpu.starting))
        print('peak CPU % usage = '+''+ str(display_cpu.peak_cpu))
        print('peak CPU % usage/core = '+''+ str(display_cpu.peak_cpu_per_core))
    model_lr= model_trained
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
else:
    # Projection-only mode: load the data and a previously stored model.
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
    model = load_models(models,model_key)
    model_lr =  model
# run with usage logger
t0 = time.time()
display_cpu = DisplayCPU()
display_cpu.start()
try: #code here ##
    pred_out,train_x,model_lr,adata_temp = reference_projection(adata, model_lr, dyn_std,partial_scale)
    # freq_redist is either False or the name of an obs column holding clusters/labels.
    if freq_redist:
        pred_out = freq_redist_68CI(adata,freq_redist)
        pred_out['orig_labels'] = adata.obs[freq_redist]
        adata.obs['consensus_clus_prediction'] = pred_out['consensus_clus_prediction']
    adata.obs['predicted'] = pred_out['predicted']
    adata_temp.obs = adata.obs

    # Estimate top model features for class discrimination
    feature_importance = estimate_important_features(model_lr, 100)
    mat = feature_importance.euler_pow_mat
    top_loadings = feature_importance.to_n_features_long

    # Estimate dataset specific feature impact
#     for classes in ['pDC precursor_ys_HL','AEC_ys_HL']:
#         model_class_feature_plots(top_loadings, [str(classes)], 'e^coef')
#         plt.show()
finally: # always stop the usage logger
    current, peak = display_cpu.stop()
t1 = time.time()
time_s = t1-t0
print('projection complete!')
time.sleep(3)
print('projection time was ' + str(time_s) + ' seconds')
print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
print(f"starting memory usage is" +'' + str(display_cpu.starting))
print('peak CPU % usage = '+''+ str(display_cpu.peak_cpu))
print('peak CPU % usage/core = '+''+ str(display_cpu.peak_cpu_per_core))

# regression summary: labels are mapped to integer codes in order of first appearance.
# NOTE(review): predicted labels absent from adata.obs[feat_use] map to NaN - confirm this cannot happen.
idx_map = {label: code for code, label in enumerate(adata.obs[feat_use].unique())}
regression_results(adata.obs[feat_use].map(idx_map), adata.obs['predicted'].map(idx_map))

In [None]:
# Median predicted probability (in %) of every model class, grouped by original label.
# numeric_only=True skips the non-numeric annotation columns of pred_out,
# which newer pandas versions no longer drop silently.
model_mean_probs = pred_out.loc[:, pred_out.columns != 'predicted'].groupby('orig_labels').median(numeric_only=True)
model_mean_probs = model_mean_probs*100
# Drop original labels with any missing value (how/thresh must not be passed together).
model_mean_probs = model_mean_probs.dropna(axis=0, how='any')
crs_tbl = model_mean_probs.copy()
# Sort rows and columns by their maximum value, largest first
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]
# Plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
# Global maximum of the table; computed on the raw array so the result is a scalar.
heat_max = crs_tbl.to_numpy().max()
plt.figure(figsize=(20,15))
sns.set(font_scale=0.5)
g = sns.heatmap(crs_tbl, cmap='viridis_r',  annot=False,vmin=0, vmax=heat_max, linewidths=1, center=heat_max/2, square=True, cbar_kws={"shrink": 0.5})

plt.ylabel("Original labels")
plt.xlabel("Training labels")
plt.savefig((model_key+'_X_lr_model_means_subclusters.pdf'),dpi=300)
plt.show()

In [None]:
# Report F1 metrics
# Boolean mask over adata's genes marking those listed in the model's features.
k_x = np.isin(list(adata.var.index), list(model_lr.features))
k_x_idx = np.where(k_x)[0]
# Columns are taken in adata's gene order.
# NOTE(review): presumably report_f1 expects the columns in model_lr.features order - confirm.
X = adata[:,k_x_idx].X
report_f1(model_lr,X, list(adata.obs[feat_use]))

# Project onto Mye cells in adult mye atlas

In [None]:
# Variable assignment
train_model = False
feat_use = 'IG_annot'
adata_key = 'pan_organ_mye'#'fliv_wget_test' # key of the dictionary entry holding the path to the adata/s; can be a URL or a local path
data_merge = False # read and merge multiple adata (useful, but keep false for now)
model_key = model_key#'test_low_dim_ipsc_ys' # key of the model of choice (URL or local path); re-uses the value set earlier in the notebook
train_x_partition = 'X' # what partition was the data trained on? To keep it simple, for now only accepts 'X'
dyn_std = 1.96 # Dynamic cutoff in standard deviations of the mean for each cell-type probability; adds a column flagging uncertain labels (1 = 68% CI, 1.96 = 95% CI)
freq_redist = 'leiden_res_3_IG'#'cell.labels'#'False#'cell.labels'#False # False, or the key of the obs column in the anndata object which contains labels/clusters
partial_scale = False # should data be scaled in batches?
QC_normalise = False # should data be normalised?

# training variables
penalty='elasticnet' # can be ["l1","l2","elasticnet"]
sparcity=0.5 # C penalty for degree of regularisation
thread_num = -1
l1_ratio = 0.5 # ratio between L1 and L2 regularisation, depending on the penalty method

#

# If low dim & not in keys
batch_key = ['kit','donor'] #organ
batch_correction = 'Harmony'#'Harmony' #or bbknn
theta = 3 # Harmony specific


# When True, the training cell sub-samples the data with sketch_data before training.
create_sketch_data = False

In [None]:
# Load the annotated object, select highly variable genes using the first
# batch key only, and compute a 100-component PCA.
adata = sc.read('/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A1_V1_LING_ADULT_IG_annot.h5ad')
sc.pp.highly_variable_genes(adata, batch_key = batch_key[0], subset=False)
sc.pp.pca(adata, n_comps=100, use_highly_variable=True, svd_solver='arpack')
sc.pl.pca_variance_ratio(adata, log=True,n_pcs=100)

import harmonypy as hm
# Batch correction options
# The script will test later which Harmony values we should use 
if(batch_correction == "Harmony"):
    print("Commencing harmony")
    if len(batch_key) == 1:
        adata.obs['lr_batch'] = adata.obs[batch_key]
        batch_var = "lr_batch"
    else:
        # Multiple keys: the list is handed to Harmony as-is (not wrapped in another list).
        batch_var = (batch_key)
    # Create hm subset
    adata_hm = adata[:]
    # Set harmony variables
    data_mat = np.array(adata_hm.obsm["X_pca"])
    meta_data = adata_hm.obs
#     vars_use = [batch_var]
    # Run Harmony
    ho = hm.run_harmony(data_mat, meta_data, batch_var,theta=theta)
    res = (pd.DataFrame(ho.Z_corr)).T
    res.columns = ['X{}'.format(i + 1) for i in range(res.shape[1])]
    # Insert coordinates back into object: the uncorrected PCA is kept under
    # "X_pca_back" and "X_pca" is overwritten with the corrected coordinates.
    adata_hm.obsm["X_pca_back"]= adata_hm.obsm["X_pca"][:]
    adata_hm.obsm["X_pca"] = np.array(res)
    # Run neighbours
    sc.pp.neighbors(adata_hm, n_neighbors=15, n_pcs=100)
    sc.tl.umap(adata_hm)
    # NOTE(review): clustering runs at resolution 5 but is stored under a key named "res_3" - confirm intended.
    sc.tl.leiden(adata_hm,resolution = 5,key_added = 'leiden_res_3_IG')
    adata = adata_hm[:]
    del adata_hm

In [None]:
# Save the current object as the V2 file.
adata.write('/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A1_V2_LING_ADULT_IG_annot.h5ad')

In [None]:
# Register the V2 file under `adata_key` so load_adatas reads it in the next cell.
adatas_dict[adata_key] = '/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A1_V2_LING_ADULT_IG_annot.h5ad'

In [None]:
import time

if train_model:
    # Training mode: load the data, optionally sketch it, train a model while
    # logging resource usage, pickle it to a file named after `model_key`, then
    # reload the data for projection.
    from sklearn.preprocessing import StandardScaler
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
    print('adata_loaded')
    t0 = time.time()
    display_cpu = DisplayCPU()
    display_cpu.start()

    try:
        # Sketching runs inside the try block so the usage logger is stopped
        # even when sketching fails.
        if create_sketch_data:
            #sketch data
            adata = sketch_data(adata, train_x_partition = 'X_scanvi_emb', random_state = 42,  train_labels = None)

        model_trained = prep_training_data(feat_use = feat_use,
        adata_temp = adata,
        train_x_partition = train_x_partition,
        model_key = model_key + '_lr_model',
        batch_correction = False,
        var_length = 7500,
        batch_key = batch_key,
        penalty='elasticnet', # can be ["l1","l2","elasticnet"],
        sparcity=sparcity, #If using LR without optimisation, this controls the sparsity in model
        max_iter = 1000, #Increase if experiencing max iter issues
        l1_ratio = l1_ratio, #If using elasticnet without optimisation, this controls the ratio between l1 and l2)
        partial_scale = False, #partial_scale,
        tune_hyper_params = True # Current implementation is very expensive, intentionally made rigid for now
        )
        filename = model_key
        # Context manager guarantees the pickle file is flushed and closed.
        with open(filename, 'wb') as model_file:
            pkl.dump(model_trained, model_file)
        models[model_key] = model_key
    finally: # always stop the usage logger, even if training failed
        current, peak = display_cpu.stop()
        t1 = time.time()
        time_s = t1-t0
        print('training complete!')
        time.sleep(3)
        print('training time was ' + str(time_s) + ' seconds')
        print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
        print(f"starting memory usage is" +'' + str(display_cpu.starting))
        print('peak CPU % usage = '+''+ str(display_cpu.peak_cpu))
        print('peak CPU % usage/core = '+''+ str(display_cpu.peak_cpu_per_core))
    model_lr= model_trained
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
else:
    # Projection-only mode: load the data and a previously stored model.
    adata =  load_adatas(adatas_dict, data_merge, adata_key,QC_normalise)
    #sc.tl.umap(adata)
    #sc.tl.leiden(adata,resolution = 3,key_added = 'leiden_res_3_IG')
    model = load_models(models,model_key)
    model_lr =  model
# run with usage logger
t0 = time.time()
display_cpu = DisplayCPU()
display_cpu.start()
try: #code here ##
    pred_out,train_x,model_lr,adata_temp = reference_projection(adata, model_lr, dyn_std,partial_scale)
    # freq_redist is either False or the name of an obs column holding clusters/labels.
    if freq_redist:
        pred_out = freq_redist_68CI(adata,freq_redist)
        pred_out['orig_labels'] = adata.obs[freq_redist]
        adata.obs['consensus_clus_prediction'] = pred_out['consensus_clus_prediction']
    adata.obs['predicted'] = pred_out['predicted']
    adata_temp.obs = adata.obs

    # Estimate top model features for class discrimination
    feature_importance = estimate_important_features(model_lr, 100)
    mat = feature_importance.euler_pow_mat
    top_loadings = feature_importance.to_n_features_long

    # Estimate dataset specific feature impact
#     for classes in ['pDC precursor_ys_HL','AEC_ys_HL']:
#         model_class_feature_plots(top_loadings, [str(classes)], 'e^coef')
#         plt.show()
finally: # always stop the usage logger
    current, peak = display_cpu.stop()
t1 = time.time()
time_s = t1-t0
print('projection complete!')
time.sleep(3)
print('projection time was ' + str(time_s) + ' seconds')
print(f"Current memory usage is {current / 10**6}MB; Peak was {peak / 10**6}MB")
print(f"starting memory usage is" +'' + str(display_cpu.starting))
print('peak CPU % usage = '+''+ str(display_cpu.peak_cpu))
print('peak CPU % usage/core = '+''+ str(display_cpu.peak_cpu_per_core))


In [None]:
# Column names used by the cluster-level majority vote in the next cell.
cluster_prediction = "clus_prediction_confident"  # output column
clusters_reassign = "leiden_res_3_IG"  # clustering whose members are relabelled together
lr_predicted_col = 'predicted'  # per-cell prediction column that is voted on

In [None]:
# Copy the 'confident_calls' column of the projection output onto the AnnData object.
adata.obs['confident_calls'] = pred_out['confident_calls']
# Placeholder values; every cell is overwritten by the loop below.
adata.obs[cluster_prediction] = adata.obs.index
# Majority vote: each cluster receives the most frequent per-cell prediction among its members.
for z in adata.obs[clusters_reassign].unique():
    in_cluster = adata.obs[clusters_reassign].isin([z])
    # value_counts() sorts by frequency, so the first index entry is the most common label.
    cat = adata.obs.loc[in_cluster, lr_predicted_col].value_counts().index[0]
    adata.obs.loc[in_cluster, cluster_prediction] = cat

In [None]:
# Median predicted probability (in %) of every model class, grouped by original label.
# numeric_only=True skips the non-numeric annotation columns of pred_out,
# which newer pandas versions no longer drop silently.
model_mean_probs = pred_out.loc[:, pred_out.columns != 'predicted'].groupby('orig_labels').median(numeric_only=True)
model_mean_probs = model_mean_probs*100
# Drop original labels with any missing value (how/thresh must not be passed together).
model_mean_probs = model_mean_probs.dropna(axis=0, how='any')
crs_tbl = model_mean_probs.copy()
# Sort rows and columns by their maximum value, largest first
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]
# Plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
# Global maximum of the table; computed on the raw array so the result is a scalar.
heat_max = crs_tbl.to_numpy().max()
plt.figure(figsize=(20,15))
sns.set(font_scale=0.5)
g = sns.heatmap(crs_tbl, cmap='viridis_r',  annot=False,vmin=0, vmax=heat_max, linewidths=1, center=heat_max/2, square=True, cbar_kws={"shrink": 0.5})

plt.ylabel("Original labels")
plt.xlabel("Training labels")
plt.savefig((model_key+'_X_lr_model_means_subclusters.pdf'),dpi=300)
plt.show()

In [None]:
# UMAPs coloured by the confidence flag and cluster-level prediction, then by
# the existing annotation and the Leiden clustering.
sc.set_figure_params(dpi=150, dpi_save=150,figsize=[15,15],fontsize=10)
sc.pl.umap(adata,color = ['confident_calls','clus_prediction_confident'],wspace = 0.5,size = 10)
sc.pl.umap(adata,color = ['IG_annot','leiden_res_3_IG'],wspace = 0.5,size = 10)


In [None]:
# Same cluster-level prediction with the legend drawn on the embedding.
sc.pl.umap(adata,color = ['clus_prediction_confident'],legend_loc = 'on data',wspace = 0.5,size = 10)

In [None]:
# Harmonise organ naming: replace the substring 'lung' with 'Lung'.
adata.obs['organ'] = adata.obs['organ'].str.replace('lung','Lung')

In [None]:
pd.set_option('display.max_rows', 100)

# Number of cells per (organ, cluster-level prediction), restricted to predictions containing 'mac'.
df = pd.DataFrame(adata.obs[adata.obs['clus_prediction_confident'].str.contains('mac')].groupby(['organ','clus_prediction_confident']).apply(len)).reset_index()

df



In [None]:
# Display the count table built in the previous cell.
df

In [None]:
# List the distinct values of the 'kit' column.
adata.obs['kit'].unique()

In [None]:
# UMAP of the lung cells only (either spelling), coloured by kit.
sc.pl.umap(adata[adata.obs['organ'].isin(['Lung','lung'])], color = 'kit')

In [None]:
# UMAP of all cells coloured by organ, with only the lung groups passed to `groups`.
sc.pl.umap(adata,color = ['organ'],groups = ['Lung','lung'],legend_loc = 'on data',wspace = 0.5,size = 10)

In [None]:
# Cells per organ among annotations containing 'OSTEO'.
adata.obs[adata.obs['IG_annot'].str.contains('OSTEO')].groupby(['organ']).apply(len)

In [None]:
# Export the cell metadata and the projection output to CSV in the working directory.
adata.obs.to_csv('./v2_hm_projected_immune_atlas_ling_meta.csv')
pred_out.to_csv('./v2_hm_pred_out_projected_immune_atlas_ling.csv')

# load original umap and check

In [None]:
# Reload the V1 object (the file read before the Harmony cell) to reuse its embedding.
adata_umap = sc.read('/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A1_V1_LING_ADULT_IG_annot.h5ad')

In [None]:
# Copy the cluster-level prediction across; pandas aligns the values on the obs index.
adata_umap.obs['clus_prediction_confident'] = adata.obs['clus_prediction_confident']

In [None]:
# Plot the reloaded object's UMAP coloured by kit, scVI Leiden cluster and the projected prediction.
sc.pl.umap(adata_umap,color = ['kit','leiden_scVI','clus_prediction_confident'],legend_loc = 'on data',wspace = 0.5,size = 10)

### Remove clusters 3 and 9 from data due to batch-effect

In [None]:
# Replace the embeddings with those of the reloaded object, then drop scVI Leiden clusters '3' and '9'.
# NOTE(review): the result of the boolean indexing is presumably an AnnData view - confirm whether .copy() is wanted before writing.
adata.obsm = adata_umap.obsm
adata = adata[~adata.obs['leiden_scVI'].isin(['3','9'])]

In [None]:
# Save the filtered object as the V3 file.
adata.write('/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A1_V3_LING_ADULT_IG_annot.h5ad')

### Add brain and Lung annots

In [None]:
# Reset matplotlib settings to their defaults, then plot and save the cluster-level prediction UMAP.
plt.rcdefaults()
sc.pl.umap(adata,color = ['clus_prediction_confident'],wspace = 0.1,save = 'Lung_atlas_projection')

# View by median probabilities per classification

In [None]:
# Median predicted probability (in %) of every model class, grouped by original label.
# numeric_only=True skips the non-numeric annotation columns of pred_out,
# which newer pandas versions no longer drop silently.
model_mean_probs = pred_out.loc[:, pred_out.columns != 'predicted'].groupby('orig_labels').median(numeric_only=True)
model_mean_probs = model_mean_probs*100
# Drop original labels with any missing value (how/thresh must not be passed together).
model_mean_probs = model_mean_probs.dropna(axis=0, how='any')
crs_tbl = model_mean_probs.copy()
# Sort rows and columns by their maximum value, largest first
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]
# Plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
# Global maximum of the table; computed on the raw array so the result is a scalar.
heat_max = crs_tbl.to_numpy().max()
plt.figure(figsize=(20,15))
sns.set(font_scale=0.5)
g = sns.heatmap(crs_tbl, cmap='viridis_r',  annot=False,vmin=0, vmax=heat_max, linewidths=1, center=heat_max/2, square=True, cbar_kws={"shrink": 0.5})

plt.ylabel("Original labels")
plt.xlabel("Training labels")
plt.savefig((model_key+'_X_lr_model_means_subclusters.pdf'),dpi=300)
plt.show()

# View by label assignment

In [None]:
x=feat_use
y = 'predicted'

y_attr = adata_temp.obs[y]
x_attr = adata_temp.obs[x]
# Rows: original labels (x); columns: predicted labels (y).
crs = pd.crosstab(x_attr, y_attr)
# Convert counts to the percentage of each predicted label (column) that
# falls into every original label, rounded to 2 decimals.
crs_tbl = crs.div(crs.sum(axis=0), axis=1).multiply(100).round(2)

# Sort rows and columns by their maximum value, largest first
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]

#plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
plt.figure(figsize=(20,15))
sns.set(font_scale=0.8)
g = sns.heatmap(crs_tbl, cmap='viridis_r', vmin=0, vmax=100, linewidths=1, center=50, square=True, cbar_kws={"shrink": 0.3})
# Heatmap columns run along the x-axis and rows along the y-axis, so the
# predicted labels belong on x and the original labels on y.
plt.xlabel("Predicted labels")
plt.ylabel("Original labels")
# plt.savefig(save_path + "/LR_predictions_consensus.pdf")
# crs_tbl.to_csv(save_path + "/post-freq_LR_predictions_consensus_supp_table.csv")
plt.show()

In [None]:
x='consensus_clus_prediction'
y = 'predicted'

y_attr = adata_temp.obs[y]
x_attr = adata_temp.obs[x]
# Rows: consensus cluster prediction (x); columns: per-cell predicted labels (y).
crs = pd.crosstab(x_attr, y_attr)
# Convert counts to the percentage of each predicted label (column) that
# falls into every consensus label, rounded to 2 decimals.
crs_tbl = crs.div(crs.sum(axis=0), axis=1).multiply(100).round(2)

# Sort rows and columns by their maximum value, largest first
index_order = list(crs_tbl.max(axis=1).sort_values(ascending=False).index)
col_order = list(crs_tbl.max(axis=0).sort_values(ascending=False).index)
crs_tbl = crs_tbl.loc[index_order]
crs_tbl = crs_tbl[col_order]

#plot_df_heatmap(crs_tbl, cmap='coolwarm', rotation=90, vmin=20, vmax=70)
plt.figure(figsize=(20,15))
sns.set(font_scale=0.8)
g = sns.heatmap(crs_tbl, cmap='viridis_r', vmin=0, vmax=100, linewidths=1, center=50, square=True, cbar_kws={"shrink": 0.3})
# Heatmap columns run along the x-axis and rows along the y-axis.
plt.xlabel("Predicted labels")
plt.ylabel("Consensus cluster prediction")
# NOTE(review): this file name is also used by the median-probability heatmap cells, so the PDF is overwritten - confirm intended.
plt.savefig((model_key+'_X_lr_model_means_subclusters.pdf'),dpi=300)
# crs_tbl.to_csv(save_path + "/post-freq_LR_predictions_consensus_supp_table.csv")
plt.show()

# Save predicted output

In [None]:
# Write the full projection output table to CSV in the working directory.
pred_out.to_csv('./A1_V1_LUNG_LUNG_adult_pred_outs.csv')

# Update the integrated macs data

In [None]:
# Replace `adata` with the integrated macrophage object that receives the new labels below.
adata = sc.read('/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A2_V1_130323_integrated_macs_adult_pan_organ_scored.h5ad')

In [None]:
# Work on an independent copy: later cells modify `preds` in place, and
# writing into a slice of pred_out triggers pandas' SettingWithCopy behaviour.
preds = pred_out[['predicted','consensus_clus_prediction']].copy()
# Number of cells per consensus prediction.
preds.groupby(['consensus_clus_prediction']).apply(len)

In [None]:
# Collapse the consensus prediction to two classes: every label other than
# 'Interstitial macrophages' becomes 'Alveolar macrophages'.
preds.loc[~preds['consensus_clus_prediction'].isin(['Interstitial macrophages']),'consensus_clus_prediction'] = 'Alveolar macrophages'
# Cell counts per class after collapsing.
preds.groupby(['consensus_clus_prediction']).apply(len)

In [None]:
# Derive the LVL3 label from the consensus prediction, applying the two
# renames one after the other.
preds['LVL3'] = preds['consensus_clus_prediction']
for _old_name, _new_name in (
    ('Alveolar macrophages', 'MACROPHAGES_ALVEOLAR'),
    ('Interstitial macrophages', 'MACROPHAGES_INTERSTITIAL'),
):
    preds['LVL3'] = preds['LVL3'].str.replace(_old_name, _new_name)


In [None]:
# push into main
# Cells present in both objects whose current LVL4 is NOT 'MACROPHAGE_pre_agm_hi'.
# NOTE(review): these cells are then assigned 'MACROPHAGE_pre_agm_hi' below; the
# negation (~) looks inverted if the goal is to keep existing pre_agm_hi calls - confirm.
tlf_indx = (adata.obs.loc[(~adata.obs['LVL4'].isin(['MACROPHAGE_pre_agm_hi'])) & (adata.obs.index.isin(preds.index))].index)
preds['LVL4'] = preds['LVL3'][:]
preds.loc[preds.index.isin(tlf_indx),'LVL4'] = 'MACROPHAGE_pre_agm_hi'

# Cast to plain strings before assigning the new label values.
adata.obs['LVL3'] = adata.obs['LVL3'].astype(str)
adata.obs['LVL4'] = adata.obs['LVL4'].astype(str)
# Overwrite the labels of cells that have a prediction; values are aligned on the obs index.
adata.obs.loc[(adata.obs.index.isin(preds.index)),'LVL3'] = preds['LVL3'].astype(str)
adata.obs.loc[(adata.obs.index.isin(preds.index)),'LVL4'] = preds['LVL4'].astype(str)

In [None]:
# Visual check of the updated LVL3 / LVL4 labels.
sc.pl.umap(adata,color = ['LVL3','LVL4'],wspace = 1)

In [None]:
# LVL3 labels in their display order, each paired with its hex colour.
col_pal = {
    'MACROPHAGE_ERY': '#DB3432',
    'MACROPHAGE_KUPFFER_LIKE': '#FFFFFF',
    'MACROPHAGE_LYVE1_HIGH': '#CEDE78',
    'MACROPHAGE_MHCII_HIGH': '#E38C4C',
    'MACROPHAGES_ALVEOLAR': '#BD9D93',
    'MACROPHAGES_INTERSTITIAL': '#9C8897',
    'MACROPHAGE_BAMS': '#EDCD52',
    'MACROPHAGE_MICROGLIA': '#4C7397',
    'MACROPHAGE_PERI': '#9EAFB7',
    'MACROPHAGE_PROLIFERATING': '#EB8F50',
    'OSTEOCLAST': '#DF5251',
}
# Parallel lists derived from the mapping (insertion order is preserved).
cells = list(col_pal.keys())
cols = list(col_pal.values())
# Make LVL3 categorical with its categories in the order listed above.
adata.obs['LVL3'] = adata.obs['LVL3'].astype('category').cat.reorder_categories(cells)

In [None]:
# Build a numbered legend for `var`: each category gets an integer code that is
# drawn on the embedding, next to a second panel coloured by "<code> : <label>".
plt.rcdefaults()
var = "LVL3"
#Create color dictionary_cell
adata.obs[var] = adata.obs[var].astype('category')
# NOTE: `cells` is overwritten here with the category list of `var`.
cells = list(adata.obs[var].cat.categories)
col = list(range(0, len(adata.obs[var].cat.categories)))
#col = adata_mac.uns['cell.labels_colors']
# category label -> integer code
dic = dict(zip(cells,col))

#Create a mappable field
adata.obs['num'] = adata.obs[var].astype(str)
#map to adata_mac.obs.col to create a categorical column
adata.obs['num'] = adata.obs['num'].map(dic)

##Map to a palette to use with umap
#cells_list = pd.DataFrame(adata_mac.obs["cell.labels"].cat.categories)
#cells_list['col'] = cells_list[0].map(dic)
#col_pal = list(cells_list['col'])
adata.obs['num'] = adata.obs['num'].astype(str)
adata.obs[var+'_num'] = adata.obs['num'].astype(str) + " : " + adata.obs[var].astype(str) #FF4A46

# col_pal = ['#94BFB1',     '#B49EC8',    '#E0EE70',    '#EE943E',    '#4C7BAB',    '#E78AB8',    '#AFBFCC',   "#FF4A46", '#FF993F',    "#FFFF00", "#1CE6FF", "#FF34FF", "#FF4A46", "#008941", "#006FA6", "#A30059",    "#FFDBE5", "#7A4900", "#0000A6", "#63FFAC", "#B79762", "#004D43", "#8FB0FF", "#997D87",    "#5A0007", "#809693", "#6A3A4C", "#1B4400", "#4FC601", "#3B5DFF"]

import matplotlib
matplotlib.rcdefaults() # Reset matplotlib defaults, as seaborn tends to change them
# Two axes; the second has a width ratio of 0 and receives the "<code> : <label>" plot.
fig, (ax1, ax2,) = plt.subplots(1,2, figsize=(10,10), gridspec_kw={'wspace':0,'width_ratios': [1,0]})
p2 = sc.pl.umap(adata, color = (var+'_num') ,ax=ax2,show=False,title="", palette= cols) #title=i
p3 = sc.pl.umap(adata, color = "num",legend_loc="on data",size=10,legend_fontsize='small',ax=ax1,show=False,title="Macs_adult", palette= cols) #title=i

fig.savefig('./'+var+"_mac_num.pdf",bbox_inches='tight')
plt.show()

In [None]:
# Save the relabelled integrated object as the A2_V2 file.
adata.write('/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A2_V2_130323_integrated_macs_adult_pan_organ_scored.h5ad')

# Derive a new space

In [None]:
# Recompute the embedding of the integrated object, correcting for organ with Harmony.
# NOTE: batch_key is rebound from a list to a single column name here.
batch_key = 'organ'
theta = 3

sc.pp.highly_variable_genes(adata, batch_key = batch_key, subset=False)
sc.pp.pca(adata, n_comps=100, use_highly_variable=True, svd_solver='arpack')
sc.pl.pca_variance_ratio(adata, log=True,n_pcs=100)

import harmonypy as hm
# Batch correction options
# The script will test later which Harmony values we should use 
print("Commencing harmony")
# Copy the batch column under a fixed name used as the Harmony variable.
adata.obs['lr_batch'] = adata.obs[batch_key]
batch_var = "lr_batch"
# Create hm subset
adata_hm = adata[:]
# Set harmony variables
data_mat = np.array(adata_hm.obsm["X_pca"])
meta_data = adata_hm.obs
vars_use = [batch_var]
# Run Harmony
ho = hm.run_harmony(data_mat, meta_data, vars_use,theta=theta)
res = (pd.DataFrame(ho.Z_corr)).T
res.columns = ['X{}'.format(i + 1) for i in range(res.shape[1])]
# Insert coordinates back into object: the uncorrected PCA is kept under
# "X_pca_back" and "X_pca" is overwritten with the corrected coordinates.
adata_hm.obsm["X_pca_back"]= adata_hm.obsm["X_pca"][:]
adata_hm.obsm["X_pca"] = np.array(res)
adata = adata_hm[:]
del adata_hm

In [None]:
# Build the neighbour graph on obsm['X_pca'], compute a UMAP from it, and plot
# the UMAP coloured by the 'LVL3' and 'LVL4' annotations.
sc.pp.neighbors(adata,use_rep = 'X_pca')
sc.tl.umap(adata)
sc.pl.umap(adata,color = ['LVL3','LVL4'],wspace = 1)

In [None]:
# Save `adata` to a separate "A2_V3_HM" .h5ad file. The target is an absolute
# path on an NFS share, so it must be adjusted when running elsewhere.
adata.write('/nfs/team205/ig7/work_backups/backup_210306/projects/YS/rebuttal_figs_010922/ling_adult_macs/A2_V3_HM_130323_integrated_macs_adult_pan_organ_scored.h5ad')

In [None]:
# # filter unlikely predictions
# filtered = pred_out[np.max(pred_out.loc[:,~pred_out.columns.isin(['predicted','confident_calls','annot_celltype', 'consensus_clus_prediction', 'orig_labels','clus_prediction_confident'])],axis = 1) > 0.3]
# adata_temp = adata[adata.obs.index.isin(filtered.index)]
# filtered['clus_prediction_confident'] = adata_temp.obs['clus_prediction_confident']

# Significant contributors to feature effect size per class of model
- Bear in mind these are only the top features.
- Assess the positive discriminators (markers) of the model.
- Effect sizes “…provide information about the magnitude and direction of the difference between two groups or the relationship between two variables.”

In [None]:
# List the distinct values of the 'class' column of top_loadings.
list(top_loadings['class'].unique())

In [None]:
# Display the top_loadings table (built in an earlier cell, not shown in this section).
top_loadings

In [None]:
# Plot the 'e^coef' feature effect sizes for every class listed in top_loadings,
# one figure per class.
for classes in top_loadings['class'].unique():
    try:
        model_class_feature_plots(top_loadings, [str(classes)], 'e^coef')
    except Exception as err:
        # Best-effort plotting: a class that cannot be plotted is reported and
        # skipped. Only Exception is caught, so Ctrl-C / kernel interrupts still
        # stop the loop, and the cause is printed so that genuine errors are not
        # mistaken for "no significant features".
        skip = classes
        print('No significant features detected, skipped {} ({}: {})'.format(skip, type(err).__name__, err))
    plt.show()

In [None]:
plt.rcdefaults()
# plot_states = ['Tip cell (arterial)','HSC','SPP1+ proliferating neuron proneitors']

# Keep only the loadings whose class occurs among the consensus predictions.
predicted_loadings = top_loadings[
    top_loadings['class'].isin(adata_temp.obs['consensus_clus_prediction'])
]
# Take the first five rows of each class (in the table's existing row order)
# and collect their features into a {class: [feature, ...]} dictionary.
first_five_per_class = predicted_loadings.groupby(['class']).head(5)
markers = first_five_per_class.groupby(['class'])['feature'].agg(list).to_dict()

# Dot plot of those features grouped by consensus prediction, scaled per variable.
sc.pl.dotplot(
    adata_temp,
    groupby='consensus_clus_prediction',
    var_names=markers,
    standard_scale='var',
)

In [None]:
# Show the first 15 rows per class of top_loadings for three hand-picked classes.
top_loadings[top_loadings['class'].isin(['Lymphoid progenitor','Early erythroid (embryonic)','Pre-dermal condensate'])].groupby(['class']).head(15)

# Label confidence scoring
- Let's study label stability given K-neighborhood assignments

**Author notes:** 
-  Hey! If you're reading this, I've probably messed up somewhere and you're looking for an explanation why :) 
- Code blocks marked **Prototype** are usually incomplete or an irresponsible lift from another pipeline. If the source pipeline is already distributed/published, I will leave git links associated with the module.
- If there are no links, there should be some run notes

**Run mode 2 of prototype $\alpha$-$\beta$ sampling via leverage-score**
- Mode 2 was chosen as we want to define a sampling space which satisfies same KNN distribution and density instead of prioritising variability
- Neighborhood assignment is done via majority voting
- Posterior probability computed and sampling rate for X is determined

# Running mode 2 of prototype alpha-beta sampling via leverage-score instead of NUTs

In [None]:
# define prior probability
# define likelihood of type 1 error (FP)
# define likelihood of type 2 error (FN)
# For a given geosketch, what are the error rate and posterior probability?

# Vertex calling and association array

In [None]:
# Vertex association matrix. The expression is evaluated but, not being the
# last statement of the cell, it is not displayed.
adata.obsm['nhoods']

nhood_adata = adata.uns["nhood_adata"]

# Let's now count the number of cell states given a sampled neighbourhood:
# densify the neighbourhood matrix into a table indexed by the neighbourhood
# object's obs index, with its var index as columns.
nn_membership_mt = pd.DataFrame(
    nhood_adata.X.todense(),
    index=nhood_adata.obs.index,
    columns=nhood_adata.var.index,
)

# Are there annotation fields with less than a single neighbourhood rep?
# Each neighbourhood is assigned the column holding its largest value.
nhood_adata.obs['membership'] = nn_membership_mt.idxmax(axis=1)

# Let's get the binarised relationships of all cells to neighbourhoods
# (rows: cells, columns: neighbourhoods).
knn_membership_mt = pd.DataFrame(
    adata.obsm['nhoods'].todense(),
    index=adata.obs.index,
    columns=nhood_adata.obs.index,
)

In [None]:
# Display the neighbourhood-by-column table built from adata.uns["nhood_adata"].X.
nn_membership_mt

In [None]:
# areas which are under represented require re-sampling
np.max(nn_membership_mt['ASDC'])

In [None]:
# Rows of nn_membership_mt whose 'ASDC' entry is greater than zero.
nn_membership_mt['ASDC'][nn_membership_mt['ASDC']>0]

In [None]:
# Display the 'membership' label stored for each row of the neighbourhood object's obs.
adata.uns["nhood_adata"].obs['membership']

In [None]:
nhood_adata = adata.uns["nhood_adata"]

# Is there a minimum of three neighbourhoods per cell state?
# Count the neighbourhoods per 'membership' label once, then keep the labels
# that have fewer than three.
nhoods_per_label = nhood_adata.obs.groupby(['membership']).apply(len)
undersampled = nhoods_per_label[nhoods_per_label < 3]

# Which labels are missing? This is the symmetric difference between the
# neighbourhood membership labels and the labels in adata.obs[feat_use],
# i.e. labels that appear in only one of the two sets.
membership_labels = set(nhood_adata.obs['membership'].unique())
annotation_labels = set(adata.obs[feat_use].unique())
lost = list(membership_labels ^ annotation_labels)

# How many cells are in these neighbourhoods? Sum each neighbourhood column of
# the cell-by-neighbourhood table, then show the undersampled neighbourhoods.
nhood_adata.obs['count'] = knn_membership_mt.sum(axis=0).values
nhood_adata.obs[nhood_adata.obs['membership'].isin(undersampled.index)]

In [None]:
# Display the labels found in only one of: neighbourhood memberships, adata.obs[feat_use].
lost

In [None]:
# Display the membership labels that have fewer than three neighbourhoods.
undersampled

In [None]:
# Display the neighbourhood membership table again for inspection.
nn_membership_mt

In [None]:
# Display pred_out (built in an earlier cell, not shown in this section).
pred_out

In [None]:
# Export pred_out as CSV; the path is relative, so the file lands in the
# current working directory.
pred_out.to_csv('A1_V2_X_LR_pred_out_SK_to_VASC.csv')

In [None]:
# Number of rows of adata_temp.obs per 'consensus_clus_prediction' value.
adata_temp.obs.groupby(['consensus_clus_prediction']).apply(len)