In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import pandas as pd
import numpy as np
import scanpy as sc

import cuml
from micron2 import load_as_anndata
from micron2 import cluster_leiden_cu

from matplotlib import pyplot as plt
from matplotlib import rcParams
import seaborn as sns

# Notebook-wide figure defaults: opaque white background, 180 dpi.
rcParams.update({
    'figure.facecolor': (1, 1, 1, 1),
    'figure.dpi': 180,
})

In [None]:
# Location of the merged dataset; `path` is reused by the final cell when writing.
path = '/storage/codex/datasets_v1/merged_v3.h5ad'
adata = sc.read_h5ad(path)
adata

In [None]:
rcParams['figure.dpi'] = 100
# Aspect ratio (x extent / y extent) from the largest absolute coordinate on
# each axis; `r` is reused to size later figures.
extent = np.abs(adata.obsm['coordinates_shift']).max(axis=0)
r = extent[0] / extent[1]
fig = plt.figure(figsize=(r*6, 6))
sc.pl.embedding(
    adata,
    basis='coordinates_shift',
    color='sample_id_printing',
    s=1,
    ax=fig.gca(),
    legend_loc='on data',
)

In [None]:
use_features = [
    'CD45', 'CD20', 'CD68', 'CD31', 'CD3e', 'CD11c',
    'CD138', 'PDGFRb', 'aSMA', 'CD8', 'CD4', 'PanCytoK',
]
# Every variable whose name contains "<marker>_", collected marker by marker.
var_names = [name
             for marker in use_features
             for name in adata.var_names
             if f'{marker}_' in name]
print(len(var_names))

# feats = np.array( [v for v in var_names if ('q50' in v) or ('percent' in v) ]) 
# Keep only the 'q50' summaries of the selected markers.
feats = np.array(list(filter(lambda name: 'q50' in name, var_names)))
print(len(feats))
X = adata[:, feats].X
print(X.shape)

In [None]:
# obs columns whose name contains 'ring'; appended to the classifier inputs later.
additional_features = list(filter(lambda col: 'ring' in col, adata.obs.columns))
print(additional_features)

## Apply clean up gates

In [None]:
# One label per cell; 'x' marks "not assigned by any gate".
y = np.full(adata.shape[0], 'x', dtype=object)

In [None]:
""" All - negative class """
use_features =  ['CD45', 'CD20', 'CD68', 'CD31', 'CD3e', 
                 'CD11c', 'CD138', 'PDGFRb', 'aSMA', 'CD8', 'CD4', 
                 'PanCytoK']
feats = [f'{f}_membrane_percent_positive' for f in use_features]

# vals[cell, marker] is True when that marker's membrane percent-positive
# value is exactly zero. `bool` replaces the `np.bool` alias, which was
# removed in NumPy 1.24.
vals = np.zeros([adata.shape[0], len(feats)], dtype=bool)
for i,f in enumerate(feats):
    v = adata[:, f].X.toarray().flatten()
    vals[v == 0, i] = True  # negative for this marker only when the value is exactly 0

# A cell joins the 'neg' class only if it is negative for every listed marker.
all_negative = vals.all(axis=1)
y[all_negative] = 'neg'

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))

In [None]:
""" Scoop up B cells. No other lineage markers. """
# Channel that must be present, and channels whose presence disqualifies a cell.
positive_channels = ['CD20']
negative_channels = [
    'CD68', 'CD31', 'PDGFRb', 'aSMA',
    'CD138', 'CD8_', 'CD4_',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`.

    Matching is plain substring containment, so an entry such as 'CD4_'
    (with the trailing underscore) does not match a name starting 'CD45_'.
    """
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# Count, per cell, how many required features exceed 0.30; a cell is scooped
# only when every required feature passes.
cells = np.zeros(adata.shape[0], dtype=np.uint8)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells[ v > 0.30 ] += 1
cells = cells == len(required_feats)
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
y[cells] = 'Bcell'

# Any excluded feature above 0.05 sends the cell back to 'x'.
# `bool` replaces the `np.bool` alias, which was removed in NumPy 1.24.
toss = np.zeros(adata.shape[0], dtype=bool)
for f in exclude_feats:
    v = adata[:, f].X.toarray().flatten()
    toss[ v > 0.05 ] = True
    print(f'excluding feat: {f} ({np.sum((v > 0.05)&(y=="Bcell"))})')
# Report the cells actually removed; this previously printed the scooped count.
print(f'Tossing {np.sum(toss & (y=="Bcell"))} cells ')
y[toss & (y=="Bcell")] = 'x'

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))
    
feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
sc.pl.heatmap(adata[y=='Bcell'], feats, groupby='biopsy', log=False)

In [None]:
# """ Scoop up Plasma cells. No other lineage markers. """
# positive_channels = ['CD138']
# negative_channels = ['CD68', 'CD31', 'PDGFRb', 'aSMA', 'CD20', 'CD8_', 'CD4_', 'PanCytoK', 'CD45']
# def channel_in_list(v, channel_list):
#     for ch in channel_list:
#         if ch in v:
#             return True
#     return False
# positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
# required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

# negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
# exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# # cells = np.array(y=='CD4T', dtype=np.bool)
# cells = np.zeros(adata.shape[0], dtype=np.uint8)
# print(f'starting with {np.sum(cells)} cells')
# for f in required_feats:
#     print(f'required feat: {f} ({np.sum(cells)})')
#     v = adata[:, f].X.toarray().flatten()
#     cells[ v > 0.1 ] += 1
# cells = cells == len(required_feats)
# print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
# y[cells] = 'Plasma'
    
# toss = np.zeros(adata.shape[0], dtype=np.bool)
# for f in exclude_feats:
#     v = adata[:, f].X.toarray().flatten()
#     toss[ v > 0.05 ] = 1
#     print(f'excluding feat: {f} ({np.sum((v > 0.1)&(y=="Plasma"))})')
# print(f'Tossing {np.sum(cells)} cells ')
# y[toss & (y=="Plasma")] = 'x'

# for k in np.unique(y):
#     print(f'{k:<15}', np.sum(y==k))
    
# feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
# sc.pl.heatmap(adata[y=='Plasma'], feats, groupby='biopsy', log=False)

In [None]:
""" Scoop up CD4 Tcells. No other lineage markers. """
# Channel that must be present, and channels whose presence disqualifies a cell.
positive_channels = ['CD4_']
negative_channels = [
    'CD68', 'CD138', 'CD31', 'PDGFRb',
    'aSMA', 'CD20', 'CD8_',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`."""
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# A cell is scooped when ANY required feature exceeds 0.30 (boolean mask, not
# a count). `bool` replaces the `np.bool` alias, which was removed in NumPy 1.24.
cells = np.zeros(adata.shape[0], dtype=bool)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells[ v > 0.30 ] = True
# cells = cells == len(required_feats)
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
y[cells] = 'CD4T'

# Any excluded feature above 0.05 sends the cell back to 'x'.
toss = np.zeros(adata.shape[0], dtype=bool)    
for f in exclude_feats:
    v = adata[:, f].X.toarray().flatten()
    toss[ v > 0.05 ] = True
    print(f'excluding feat: {f} ({np.sum((v > 0.05) & (y=="CD4T"))})')
print(f'Tossing {np.sum(toss & (y=="CD4T"))} cells ')
y[toss & (y=='CD4T')] = 'x'

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))
    
feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
sc.pl.heatmap(adata[y=='CD4T'], feats, groupby='biopsy', log=False)

In [None]:
""" Scoop up CD8 Tcells. No other lineage markers. """
# Channel that must be present, and channels whose presence disqualifies a cell.
positive_channels = ['CD8_']
negative_channels = [
    'CD68', 'CD138', 'CD31', 'PDGFRb',
    'aSMA', 'CD20', 'CD4_',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`."""
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# A cell is scooped when ANY required feature exceeds 0.30 (boolean mask, not
# a count). `bool` replaces the `np.bool` alias, which was removed in NumPy 1.24.
cells = np.zeros(adata.shape[0], dtype=bool)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells[ v > 0.30 ] = True
    
# cells = cells == len(required_feats)
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
y[cells] = 'CD8T'
    
# Any excluded feature above 0.05 sends the cell back to 'x'.
toss = np.zeros(adata.shape[0], dtype=bool)
for f in exclude_feats:
    v = adata[:, f].X.toarray().flatten()
    toss[ v > 0.05 ] = True
    print(f'excluding feat: {f} ({np.sum((v>0.05) & (y=="CD8T"))})')
print(f'tossing {np.sum(toss & (y=="CD8T"))}')

y[toss & (y=='CD8T')] = 'x'

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))
    
feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
sc.pl.heatmap(adata[y=='CD8T'], feats, groupby='biopsy', log=False)

In [None]:
""" Make a general Immune class with CD45 only """
# Channel that must be present, and channels whose presence disqualifies a cell.
positive_channels = ['CD45_']
negative_channels = [
    'CD138', 'CD31', 'CD20', 'CD11c', 'PanCytoK', 'HLA-DR',
    'CD68', 'PDGFRb', 'aSMA', 'CD3e', 'CD4_', 'CD8_',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`."""
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# Count, per cell, how many required features exceed 0.30; a cell is scooped
# only when every required feature passes.
cells = np.zeros(adata.shape[0], dtype=np.uint8)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells[ v > 0.30 ] += 1
cells = cells == len(required_feats)
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
# Work on a copy so that only cells surviving the exclusion step are written
# into `y`; labels already assigned to tossed cells are left untouched.
y_tmp = y.copy()
y_tmp[cells] = 'Immune'
    
# One threshold drives both the exclusion and the per-feature diagnostic; the
# diagnostic previously counted cells above 0.01 while the toss used 0.05.
exclude_thresh = 0.05
toss = np.zeros(adata.shape[0], dtype=bool)
for f in exclude_feats:
    v = adata[:, f].X.toarray().flatten()
    over = v > exclude_thresh
    toss[over] = True
    print(f'excluding feat: {f} ({np.sum(over & (y_tmp=="Immune"))})')
print(f'Tossing {np.sum(toss&(y_tmp=="Immune"))} cells ')
y_tmp[toss&(y_tmp=='Immune')] = 'x'
y[y_tmp=='Immune'] = 'Immune'

feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
sc.pl.heatmap(adata[y=='Immune'], feats, groupby='biopsy', log=False)

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))

In [None]:
# """ Scoop up DCs. No other lineage markers. """
# positive_channels = ['CD11c']
# negative_channels = ['CD138', 'CD31', 'PanCytoK', 'CD68', 'PDGFRb', 'aSMA', 'CD3e', 'CD4_', 'CD8_']
# def channel_in_list(v, channel_list):
#     for ch in channel_list:
#         if ch in v:
#             return True
#     return False
# positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
# required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

# negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
# exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# cells = np.zeros(adata.shape[0], dtype=np.uint8)
# print(f'starting with {np.sum(cells)} cells')
# for f in required_feats:
#     print(f'required feat: {f} ({np.sum(cells)})')
#     v = adata[:, f].X.toarray().flatten()
#     cells[ v > 0.3 ] += 1
# cells = cells == len(required_feats)
# y[cells] = 'DC'
# print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
    
# toss = np.zeros(adata.shape[0], dtype=bool)
# for f in exclude_feats:
#     v = adata[:, f].X.toarray().flatten()
#     toss[ v > 0.05 ] = 1
#     print(f'excluding feat: {f} ({np.sum((v>0.05)&(y=="DC"))})')
# print(f'tossing {np.sum(toss & (y=="DC"))} cells ')

# y[toss&(y=='DC')] = 'x'
# print(f'Kept {np.sum(y=="DC")} cells')

# feats = [v for v in adata.var_names if ('nuclei_percent_positive' in v) and ('DAPI' not in v)]
# sc.pl.heatmap(adata[y=='DC'], feats, groupby='biopsy', log=False)

# for k in np.unique(y):
#     print(f'{k:<15}', np.sum(y==k))

In [None]:
""" Scoop up Macs. No other lineage markers. """
# Channels that must be present, and channels whose presence disqualifies a cell.
positive_channels = ['CD45_', 'CD68']
negative_channels = [
    'CD138', 'CD31', 'CD11c', 'PDGFRb',
    'aSMA', 'CD3e', 'CD8_',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`."""
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [name for name in adata.var_names
                  if channel_in_list(name, positive_channels)]
required_feats = [name for name in positive_feats
                  if 'membrane_percent_positive' in name]

negative_feats = [name for name in adata.var_names
                  if channel_in_list(name, negative_channels)]
exclude_feats = [name for name in negative_feats
                 if 'membrane_percent_positive' in name]

# Tally, per cell, how many required features exceed 0.25; a cell is scooped
# only when every required feature passes.
cells = np.zeros(adata.shape[0], dtype=np.uint8)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells += (v > 0.25).astype(np.uint8)
cells = (cells == len(required_feats))
y[cells] = 'Mac'
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')

# Any excluded feature above 0.05 sends the cell back to 'x'.
toss = np.zeros(adata.shape[0], dtype=bool)
for f in exclude_feats:
    above = adata[:, f].X.toarray().flatten() > 0.05
    toss |= above
    print(f'excluding feat: {f} ({np.sum(above&(y=="Mac"))})')
print(f'tossing {np.sum(toss & (y=="Mac"))} cells ')

y[toss&(y=='Mac')] = 'x'

feats = [name for name in adata.var_names
         if ('membrane_percent_positive' in name) and ('DAPI' not in name)]
sc.pl.heatmap(adata[y=='Mac'], feats, groupby='biopsy', log=False)

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))

In [None]:
""" Scoop up Epithelial. No other lineage markers. """
# Channel that must be present, and channels whose presence disqualifies a cell.
positive_channels = ['PanCytoK']
negative_channels = [
    'CD31', 'CD11c', 'CD68', 'PDGFRb',
    'aSMA', 'CD3e', 'CD4_', 'CD8_',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`."""
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# Count, per cell, how many required features exceed 0.15; a cell is scooped
# only when every required feature passes.
cells = np.zeros(adata.shape[0], dtype=np.uint8)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells[ v > 0.15 ] += 1
cells = cells == len(required_feats)
y[cells] = 'Epithelial'
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
    
# One threshold drives both the exclusion and the per-feature diagnostic; the
# diagnostic previously counted cells above 0.05 while the toss used 0.01.
exclude_thresh = 0.01
toss = np.zeros(adata.shape[0], dtype=bool)
for f in exclude_feats:
    v = adata[:, f].X.toarray().flatten()
    over = v > exclude_thresh
    toss[over] = True
    print(f'excluding feat: {f} ({np.sum(over&(y=="Epithelial"))})')
print(f'tossing {np.sum(toss & (y=="Epithelial"))} cells ')

y[toss&(y=='Epithelial')] = 'x'

feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
sc.pl.heatmap(adata[y=='Epithelial'], feats, groupby='biopsy', log=False)

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))

In [None]:
""" Scoop up Endothelial. No other lineage markers. """
# Channel that must be present, and channels whose presence disqualifies a cell.
positive_channels = ['CD31']
negative_channels = [
    'CD138', 'PanCytoK', 'CD11c', 'CD68', 'PDGFRb',
    'aSMA', 'CD3e', 'CD4_', 'CD8_',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`."""
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# Count, per cell, how many required features exceed 0.15; a cell is scooped
# only when every required feature passes.
cells = np.zeros(adata.shape[0], dtype=np.uint8)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells[ v > 0.15 ] += 1
cells = cells == len(required_feats)
y[cells] = 'Endothelial'
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
    
# One threshold drives both the exclusion and the per-feature diagnostic; the
# diagnostic previously counted cells above 0.05 while the toss used 0.01.
exclude_thresh = 0.01
toss = np.zeros(adata.shape[0], dtype=bool)
for f in exclude_feats:
    v = adata[:, f].X.toarray().flatten()
    over = v > exclude_thresh
    toss[over] = True
    print(f'excluding feat: {f} ({np.sum(over&(y=="Endothelial"))})')
print(f'tossing {np.sum(toss & (y=="Endothelial"))} cells ')

y[toss&(y=='Endothelial')] = 'x'

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))
    
feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
sc.pl.heatmap(adata[y=='Endothelial'], feats, groupby='biopsy', log=False)

In [None]:
""" Scoop up Stromal. No other lineage markers. """
# Channels that qualify a cell, and channels whose presence disqualifies it.
positive_channels = ['PDGFRb', 'aSMA']
negative_channels = [
    'CD138', 'CD11c', 'CD68', 'CD3e',
    'CD4_', 'CD8_', 'CD45',
]


def channel_in_list(v, channel_list):
    """Return True when any entry of `channel_list` is a substring of `v`."""
    return any(ch in v for ch in channel_list)
# Gate on the membrane percent-positive feature of each listed channel.
positive_feats = [v for v in adata.var_names if channel_in_list(v, positive_channels)]
required_feats = [v for v in positive_feats if 'membrane_percent_positive' in v]

negative_feats = [v for v in adata.var_names if channel_in_list(v, negative_channels)]
exclude_feats = [v for v in negative_feats if 'membrane_percent_positive' in v]

# A cell is scooped when ANY required feature exceeds 0.15 (boolean mask, not
# a count). `bool` replaces the `np.bool` alias, which was removed in NumPy 1.24.
cells = np.zeros(adata.shape[0], dtype=bool)
print(f'starting with {np.sum(cells)} cells')
for f in required_feats:
    print(f'required feat: {f} ({np.sum(cells)})')
    v = adata[:, f].X.toarray().flatten()
    cells[ v > 0.15 ] = True
# cells = cells == len(required_feats)
y[cells] = 'Stromal'
print(f'Scooped up {np.sum(cells)} cells with all positive requirements')
    
# One threshold drives both the exclusion and the per-feature diagnostic; the
# diagnostic previously counted cells above 0.05 while the toss used 0.005.
exclude_thresh = 0.005
toss = np.zeros(adata.shape[0], dtype=bool)
for f in exclude_feats:
    v = adata[:, f].X.toarray().flatten()
    over = v > exclude_thresh
    toss[over] = True
    print(f'excluding feat: {f} ({np.sum(over&(y=="Stromal"))})')
print(f'tossing {np.sum(toss & (y=="Stromal"))} cells ')

y[toss&(y=='Stromal')] = 'x'

for k in np.unique(y):
    print(f'{k:<15}', np.sum(y==k))
    
feats = [v for v in adata.var_names if ('membrane_percent_positive' in v) and ('DAPI' not in v)]
sc.pl.heatmap(adata[y=='Stromal'], feats, groupby='biopsy', log=False)

In [None]:
# Shove non-specific groups back into x
# y[y=='Immune'] = 'x'
# y[y=='neg'] = 'x'

# Number of cells currently carrying each label.
levels, counts = np.unique(y, return_counts=True)
for k, n_cells in zip(levels, counts):
    print(f'{k:<15}', n_cells)

In [None]:
## Toss low DAPI cells - for each slide
dapi = adata[:, 'DAPI_membrane_mean'].X.toarray().flatten()
biopsy = np.array(adata.obs.biopsy)

# Within each biopsy, cells below that biopsy's 1% DAPI quantile lose their label.
for biopsy_id in np.unique(biopsy):
    in_biopsy = (biopsy == biopsy_id)
    cutoff = np.quantile(dapi[in_biopsy], 0.01)

    print(cutoff)
    y[in_biopsy & (dapi < cutoff)] = 'x'

levels, counts = np.unique(y, return_counts=True)
for k, n_cells in zip(levels, counts):
    print(f'{k:<15}', n_cells)

In [None]:
rcParams['figure.dpi'] = 100
adata.obs['training_labels'] = y

# Spatial view of every cell that received a gate label (everything except 'x').
has_label = adata.obs['training_labels'] != 'x'
fig = plt.figure(figsize=(r*5, 5))
sc.pl.embedding(
    adata[has_label],
    basis='coordinates_shift',
    color='training_labels',
    s=1,
    ax=fig.gca(),
)

adata.obs['training_labels'].value_counts()

In [None]:
adata.obs.loc[:,'training_labels'].to_csv('/storage/tmp/labels.csv')
!ls /storage/tmp

# Train and predict

In [None]:
# Read back the labels saved by the gating section.
# NOTE(review): the column assignment aligns on index, so this assumes the CSV
# index parses to the same values as adata.obs.index - confirm.
labels = pd.read_csv('/storage/tmp/labels.csv', index_col=0)
# labels
adata.obs['training_labels'] = labels.loc[:, 'training_labels']
adata.obs['training_labels'].value_counts()

In [None]:
use_features =  ['CD45', 'CD20', 'CD68', 'CD31', 'CD3e', 
                 'CD11c', 'CD138', 'PDGFRb', 'aSMA', 'CD8', 'CD4', 
                 'PanCytoK']
var_names = []
for f in use_features:
    var_names += [v for v in adata.var_names if f'{f}_' in v]

# Toss features that aren't expression values
# ('percent' and 'std' summaries are kept out of the log-transformed group).
loggable_features = [v for v in var_names if ('percent' not in v) and ('std' not in v)]

# Order-preserving difference. The previous `list(set(...) - set(...))` iterated
# in hash order, which can change between kernel restarts and would shuffle the
# feature columns passed to the classifier. `dict.fromkeys` keeps the
# de-duplication that the set provided.
loggable_set = set(loggable_features)
non_loggable_features = [v for v in dict.fromkeys(var_names) if v not in loggable_set]

In [None]:
# Cells with a gate label train the classifier; cells still marked 'x' get predicted.
is_labelled = adata.obs.training_labels != 'x'
is_unlabelled = adata.obs.training_labels == 'x'

# Expression-like features are log1p-transformed ...
X_train = np.log1p(adata[is_labelled, loggable_features].X.toarray())
X_pred = np.log1p(adata[is_unlabelled, loggable_features].X.toarray())

# ... the remaining marker features are used as stored ...
X_train_nl = adata[is_labelled, non_loggable_features].X.toarray()
X_pred_nl = adata[is_unlabelled, non_loggable_features].X.toarray()

# ... and the extra obs columns are appended last.
X_train_extra = adata.obs.loc[is_labelled, additional_features].values
X_pred_extra = adata.obs.loc[is_unlabelled, additional_features].values

X_train = np.concatenate([X_train, X_train_nl, X_train_extra], axis=1)
X_pred = np.concatenate([X_pred, X_pred_nl, X_pred_extra], axis=1)

from sklearn.preprocessing import RobustScaler, MinMaxScaler
# scaler = RobustScaler().fit(X_train)
# X_train = scaler.transform(X_train)
# X_pred = scaler.transform(X_pred)

print(X_train.shape)
print(X_pred.shape)

# Integer-encode the labels: Y_levels[i] is the class name for code i.
Y_train = np.array(adata.obs.loc[is_labelled, 'training_labels'].values)
Y_levels, Y_train_numeric = np.unique(Y_train, return_inverse=True)
Y_mapper = {k: v for k,v in enumerate(Y_levels)}
print(Y_train.shape)

# Hold out 20% of the labelled cells for the confusion matrix. The split is
# seeded so the held-out set is the same after a kernel restart.
SPLIT_SEED = 0
from sklearn.model_selection import train_test_split
X_train, X_test, Y_train_numeric, Y_test_numeric = train_test_split(
    X_train, Y_train_numeric, train_size=0.8, random_state=SPLIT_SEED)

print(X_train.shape, X_test.shape)
print(Y_train_numeric.shape, Y_test_numeric.shape)


# knn = cuml.neighbors.KNeighborsClassifier(n_neighbors=50, metric='euclidean')
# knn.fit(X_train, Y_train_numeric)
# Y_pred = knn.predict(X_pred)
# Y_pred_proba = knn.predict_proba(X_pred)

# NOTE(review): the forest itself is still unseeded; consider a seed here too
# if the installed cuml version supports one - confirm against its docs.
rf = cuml.ensemble.RandomForestClassifier(max_samples=0.8, max_depth=10, 
                                          max_features=1.,
                                          n_estimators=125,
                                          n_bins=16)
rf.fit(X_train, Y_train_numeric)
Y_pred = rf.predict(X_pred)
Y_pred_test = rf.predict(X_test)
Y_pred_proba = rf.predict_proba(X_pred)


# Clobber tcells
# Y_pred[Y_pred_proba[:,1]>0.1] = 1
# Y_pred[Y_pred_proba[:,2]>0.1] = 2

# Map predicted codes back to class names and keep the winning class probability.
Y_pred_remap = np.array([Y_mapper[i] for i in Y_pred])
probs = np.max(Y_pred_proba,axis=1)


In [None]:
# 2-D UMAP embedding of the training split, used for the scatter plot below.
umap_train = cuml.UMAP(n_neighbors=20)
emb_train = umap_train.fit_transform(X_train)

In [None]:
# Scatter the training embedding, one colour per class, skipping 'neg'.
# The loop variable is no longer named `y`: that name holds the per-cell label
# array from the gating section, and overwriting it breaks re-running those cells.
plt.figure(figsize=(4,4), dpi=180)
for level_idx in np.unique(Y_train_numeric):
    i = Y_train_numeric == level_idx
    label = Y_levels[level_idx]
    print(label)
    if label=='neg':
        continue
    plt.scatter(emb_train[i,0], emb_train[i,1], label=label, s=1)
plt.legend(bbox_to_anchor=(1,1), markerscale=5)

In [None]:
# Display the class names; position i is the name for integer label i.
Y_levels

In [None]:
# NOTE(review): 7 is a hard-coded class index that presumably marks where the
# 'neg' class starts in Y_levels - confirm against the Y_levels output. The
# plotting cell below applies the same `Y_pred < 7` mask, so keep them in sync.
keep_pred = Y_pred < 7
X_pred_non_neg = X_pred[keep_pred]
umap_pred = cuml.UMAP(n_neighbors=20, n_components=3)
emb_pred = umap_pred.fit_transform(X_pred_non_neg)

In [None]:
# One 2-D density panel per predicted class over the first two UMAP components.
# The loop variable is no longer named `y`: that name holds the per-cell label
# array from the gating section, and overwriting it breaks re-running those cells.

# Same mask that was used to build `emb_pred`, so rows line up with the embedding.
kept_pred = Y_pred[Y_pred < 7]

# Extremes of the full embedding, computed once. They are appended to every
# panel's points so each histogram spans the same range.
x_range = [emb_pred[:,0].min(), emb_pred[:,0].max()]
y_range = [emb_pred[:,1].min(), emb_pred[:,1].max()]

for level in np.unique(Y_pred):
    label = Y_levels[int(level)]
    print(label)
    i = kept_pred == int(level)
    if not i.any():
        continue
        
    plt.figure(figsize=(3,3), dpi=180)
    xplot = np.concatenate([emb_pred[i,0], x_range])
    yplot = np.concatenate([emb_pred[i,1], y_range])
    plt.hist2d(xplot, yplot, bins=100, density=True)
    plt.title(label)

In [None]:
from sklearn.metrics import confusion_matrix
rcParams['figure.dpi'] = 60
# Row-normalised confusion matrix on the held-out split, labelled by class name.
cm = pd.DataFrame(
    confusion_matrix(Y_test_numeric, Y_pred_test, normalize='true'),
    index=Y_levels,
    columns=Y_levels,
)
# Zero entries are masked out of the heatmap.
sns.heatmap(cm, cmap='Reds', mask=(cm == 0))

In [None]:
# Cells without a gate label ('x') receive the classifier output; every other
# cell is tagged 'training' with probability 1.
unlabelled = (adata.obs.training_labels == 'x').values

predicted_labels = np.full(adata.shape[0], 'training', dtype=object)
predicted_labels[unlabelled] = Y_pred_remap

predicted_proba = np.ones(adata.shape[0], dtype='float')
predicted_proba[unlabelled] = probs

# Predictions whose winning probability is below 0.2 are flagged.
low_conf = predicted_proba < 0.2
print(np.sum(low_conf))
predicted_labels[low_conf] = 'low_confidence'

dapi = adata[:, 'DAPI_membrane_mean'].X.toarray().flatten()
biopsy = np.array(adata.obs.biopsy)

# Within each biopsy, cells below that biopsy's 1% DAPI quantile are marked
# 'not_a_cell' with probability 0.
for b in np.unique(biopsy):
    in_biopsy = biopsy == b
    low_dapi = in_biopsy & (dapi < np.quantile(dapi[in_biopsy], 0.01))
    predicted_labels[low_dapi] = 'not_a_cell'
    predicted_proba[low_dapi] = 0

rcParams['figure.dpi'] = 200
adata.obs['predicted_labels'] = predicted_labels
adata.obs['predicted_proba'] = predicted_proba

# Plot only the cells that were not part of the training set.
plt_idx = adata.obs['predicted_labels'] != 'training'

fig_proba = plt.figure(figsize=(r*4, 4))
sc.pl.embedding(adata[plt_idx], basis='coordinates_shift',
                color='predicted_proba',
                s=2, ax=fig_proba.gca())

fig_labels = plt.figure(figsize=(r*12, 12))
sc.pl.embedding(adata[plt_idx], basis='coordinates_shift',
                color='predicted_labels',
                s=1, ax=fig_labels.gca())

In [None]:
# Number of cells per value of the 'predicted_labels' column.
adata.obs['predicted_labels'].value_counts()

In [None]:
# Nuclear mean features of the selected markers, for cells outside the training set.
feats = np.array([name for name in var_names
                  if 'nuclei' in name and 'mean' in name])
not_training = adata.obs.predicted_labels != 'training'
sc.pl.dotplot(
    adata[not_training],
    feats,
    'predicted_labels',
    standard_scale='var',
    swap_axes=True,
)

In [None]:
""" Join the training and predicted labels """
# Start from the predicted labels, then overwrite every gated cell with its
# training label so 'celltype' covers all cells.
ix = adata.obs['training_labels'] != 'x'
all_labels = np.array(adata.obs['predicted_labels'])
all_labels[ix] = np.array(adata.obs.loc[ix, 'training_labels'])
adata.obs['celltype'] = pd.Categorical(all_labels)
adata.obs['celltype'].value_counts()

In [None]:
# Nuclear mean features of the selected markers, grouped by the joined cell type.
feats = np.array([name for name in var_names
                  if 'nuclei' in name and 'mean' in name])
# sc.pl.dotplot(adata[adata.obs.predicted_labels=='training'], feats, 'celltype', 
#               standard_scale='var',
#               swap_axes=True)

sc.pl.dotplot(
    adata,
    feats,
    'celltype',
    standard_scale='var',
    swap_axes=True,
)

In [None]:
# Spatial layout of the non-training cells, coloured by the joined cell type.
fig = plt.figure(figsize=(r*6, 6))
sc.pl.embedding(
    adata[plt_idx],
    basis='coordinates_shift',
    color='celltype',
    s=1,
    ax=fig.gca(),
)

In [None]:
rcParams['figure.dpi'] = 90
# One small histogram of prediction probabilities per predicted label.
for p in np.unique(predicted_labels):
    fig, ax = plt.subplots(figsize=(4, 1))
    _ = ax.hist(predicted_proba[predicted_labels == p], bins=100)
    ax.set_title(p)

In [None]:
# `scrna` was first imported in the cell after this one, so this call raised
# NameError on a fresh kernel; import it here before use.
import scrna
help(scrna.plot_group_percents)

In [None]:
import scrna

# Cell-type composition, grouped first by biopsy and then by sample.
for grouping in ('biopsy', 'sample_id'):
    scrna.plot_group_percents(adata, 'celltype', grouping)

# Same view for the gate-derived labels only, leaving out 'x' and 'neg'.
gated = ~adata.obs.training_labels.isin(['x', 'neg'])
scrna.plot_group_percents(
    adata[gated],
    'training_labels',
    'sample_id',
    annotate_total=True,
)

In [None]:
rcParams['figure.dpi'] = 100
# Aspect ratio (x extent / y extent) from the largest absolute coordinate on each axis.
extent = np.abs(adata.obsm['coordinates_shift']).max(axis=0)
r = extent[0] / extent[1]
fig = plt.figure(figsize=(r*6, 6))
sc.pl.embedding(
    adata,
    basis='coordinates_shift',
    color='sample_id_printing',
    s=1,
    ax=fig.gca(),
    legend_loc='on data',
)

In [None]:
rcParams['figure.dpi'] = 100
# Aspect ratio (x extent / y extent) from the largest absolute coordinate on each axis.
extent = np.abs(adata.obsm['coordinates_shift']).max(axis=0)
r = extent[0] / extent[1]
fig = plt.figure(figsize=(r*6, 6))
sc.pl.embedding(
    adata,
    basis='coordinates_shift',
    color='biopsy',
    s=1,
    ax=fig.gca(),
    legend_loc='on data',
)

In [None]:
# NOTE(review): `path` is the file the dataset was read from at the top of the
# notebook, so this overwrites the input in place with the added obs columns.
adata.write(path)