In [None]:
# Set the notebook's `.container` CSS width to 80% of the page.
# `display` and `HTML` come from the public `IPython.display` module; importing
# them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container {width: 80% !important; }</style>"))

In [None]:
# import warnings
# warnings.filterwarnings("default")

In [None]:
import sys
import time
import scanpy as sc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib
from matplotlib import colors

In [None]:
# Palette of 40 hex colour strings.
myColors = [
    "#e6194b", "#3cb44b", "#ffe119", "#4363d8",
    "#f58231", "#911eb4", "#46f0f0", "#f032e6",
    "#bcf60c", "#fabebe", "#008080", "#e6beff",
    "#9a6324", "#fffac8", "#800000", "#aaffc3",
    "#808000", "#ffd8b1", "#000075", "#808080",
    "#307D7E", "#000000", "#DDEFFF", "#000035",
    "#7B4F4B", "#A1C299", "#300018", "#C2FF99",
    "#0AA6D8", "#013349", "#00846F", "#8CD0FF",
    "#3B9700", "#04F757", "#C8A1A1", "#1E6E00",
    "#DFFB71", "#868E7E", "#513A01", "#CCAA35",
]

# Custom colormap: 20 samples from a narrow band (0.7-0.8) of the reversed
# Greys map, followed by 128 samples spanning the whole Reds map.
colors3 = plt.cm.Greys_r(np.linspace(0.7, 0.8, 20))
colors2 = plt.cm.Reds(np.linspace(0, 1, 128))
colorsComb = np.concatenate((colors3, colors2), axis=0)
mymap = colors.LinearSegmentedColormap.from_list("my_colormap", colorsComb)

In [None]:
import smashpy
# Instantiate the smashpy helper object; its `run_classifiers` method is called
# further down with the selected genes.
sf = smashpy.smashpy()

# Loading the AnnData object

In [None]:
# Read the dataset from an .h5ad file. The path is relative, so it resolves
# against the directory the notebook kernel is running in.
obj = sc.read_h5ad('../../../External_datasets/10X_Healthy_Foetal_Liver_withOrigAnnot.h5ad')

In [None]:
# Report the dimensions of the loaded object (variables = genes, observations = cells).
print(f"{obj.n_vars} genes across {obj.n_obs} cells")

#### Data split

In [None]:
# Start timestamp of the timed section; paired with `e` below, and the
# difference `e - s` is printed at the end of the notebook.
s = time.time()

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Take a copy of the expression matrix so the split never aliases obj.X.
data = obj.X.copy()

# Integer-encode the cluster annotation: every category name is mapped to its
# position in the categorical's category list.
myDict = {
    name: code
    for code, name in enumerate(obj.obs["leiden_merged_final"].cat.categories)
}
labels = np.array([myDict[name] for name in obj.obs["leiden_merged_final"].tolist()])

X, y = data, labels

# Hold out 30% of the cells, stratified by label, with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    stratify=y,
    random_state=42,
)

#### scGeneFit

In [None]:
# Import only the function this notebook calls. A wildcard import here would
# silently overwrite any same-named globals already defined in earlier cells.
from scGeneFit.functions import get_markers

In [None]:
# Run scGeneFit marker selection on the training split only; the result is
# used below as positional indices into the gene list.
markers = get_markers(
    X_train,
    y_train,
    num_markers=30,
    method="centers",
    epsilon=1.0,
    redundancy=0.25,
)

In [None]:
# Translate the marker positions into gene names taken from obj.var's index.
genes = obj.var.index.tolist()
selectedGenes = [genes[position] for position in markers]

# Same selection wrapped in a dict under the single key "group".
selectedGenes_dict = {"group": selectedGenes}

In [None]:
# End timestamp of the timed section that started at `s`.
e = time.time()

#### Classifiers

In [None]:
# Run the smashpy KNN classifier on the full object, restricted to the
# scGeneFit-selected genes and grouped by the cluster annotation.
sf.run_classifiers(
    obj,
    group_by="leiden_merged_final",
    genes=selectedGenes,
    classifier="KNN",
    balance=True,
    title="scGeneFit-KNN",
)

#### Heatmap of selected genes

In [None]:
import os

# Reset matplotlib to its defaults, then set the font size used for this figure.
matplotlib.rcdefaults()
matplotlib.rcParams.update({'font.size': 11})

# NOTE: despite its name, `ax` holds the object returned by
# `sc.pl.DotPlot(...).style(...)`, not a matplotlib Axes.
ax = sc.pl.DotPlot(obj,
                   selectedGenes,
                   groupby="leiden_merged_final",
                   standard_scale='var',
                   use_raw=True,
                   figsize=(7,15),
                   linewidths=2).style(cmap=mymap, color_on='square', grid=True, dot_edge_lw=1)
ax.swap_axes(swap_axes=True)
# ax.show()

# Create the output directory first so that saving does not fail when
# "Figures/" is missing (e.g. on a fresh checkout).
os.makedirs("Figures", exist_ok=True)
ax.savefig("Figures/scGeneFit_top30.pdf")

# Elapsed time

In [None]:
# Report the object's dimensions again at the end of the run.
print(f"{obj.n_vars} genes across {obj.n_obs} cells")

In [None]:
# Wall-clock seconds between the `s` and `e` timestamps.
# NOTE(review): the interval also covers the sklearn and scGeneFit imports that
# run between the two timestamps, plus any idle time between cell executions.
print('Elapsed time (s): ', e-s)