In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
from sklearn.metrics import f1_score,precision_recall_fscore_support
import torch.nn.functional as F
import torch
import math

def f1(y_true, y_pred):
    """Compute macro- and micro-averaged precision, recall and F1.

    ``y_true`` is cast to ``np.int64`` before scoring; ``y_pred`` is passed
    through unchanged. Both must expose ``.size`` and contain the same number
    of elements (checked with an ``assert``).

    Returns a pair ``(macro, micro)``; each element is the
    ``(precision, recall, f1)`` triple reported by
    ``precision_recall_fscore_support`` for that averaging mode.
    """
    labels = y_true.astype(np.int64)
    assert y_pred.size == labels.size
    # The fourth value returned by sklearn (support) is not needed here.
    macro, micro = (
        precision_recall_fscore_support(labels, y_pred, average=mode)[:3]
        for mode in ('macro', 'micro')
    )
    return macro, micro

# Load the three input tables (marker table, expression data, per-cell
# labels). The first CSV column of each file becomes the row index.
marker, df, cell_type = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        '/data_path/marker.csv',
        "/data_path/data.csv",
        "/data_path/cell_type.csv",
    )
)

# Get real cell labels for calculating classification accuracy
adata_t_marker_nouk = sc.AnnData(df)

# Map every cell-type name (a row label of the marker table) to its row
# position; that position is the integer class id used everywhere below.
dicts_cell_type = {name: position for position, name in enumerate(marker.index)}

# `cell_type` is a DataFrame, so iterating it directly (as `list(cell_type)`
# did) yields its column names rather than one label per cell. Read the labels
# from the first data column instead.
# NOTE(review): assumes the label column is the first one after the index and
# that its rows are in the same order as the rows of `df` -- confirm against
# the CSV layout.
# A label missing from the marker table still raises KeyError, as before.
y_true = [dicts_cell_type[label] for label in cell_type.iloc[:, 0]]


# Expression matrix held by the AnnData object.
adata_dense = adata_t_marker_nouk.X

# Gene name -> column position in the expression matrix.
dict_gene_index = {
    gene_name: column
    for column, gene_name in enumerate(adata_t_marker_nouk.var_names)
}

# For each cell type, collect the expression-matrix column positions of its
# marker genes: marker-table entries equal to 1 whose gene is present in the
# expression data.
dicts_count = {}
all_marker = []
for row_pos, type_name in enumerate(marker.index):
    flags = marker.iloc[row_pos]
    hits = [
        dict_gene_index[gene_name]
        for gene_name, flag in zip(marker.columns, flags)
        if flag == 1 and gene_name in dict_gene_index
    ]
    dicts_count[dicts_cell_type[type_name]] = hits
    # Flat list over all types; a gene listed for several types repeats here.
    all_marker.extend(hits)

# How many times each marker column occurs in `all_marker`.
dict_marker_count = {
    gene_col: all_marker.count(gene_col) for gene_col in set(all_marker)
}
# The same counts as a flat list (same iteration order as the dict above);
# only its min and max are used further down.
marker_count = list(dict_marker_count.values())

# First-pass score: for every cell and cell type, a weighted sum of the type's
# marker expression in which markers shared by many types are down-weighted.
if not marker_count:
    raise ValueError("no marker gene from the marker table is present in the data")

# The weights depend only on the marker table, so they are computed once per
# cell type rather than once per (cell, cell type) pair.
count_min = min(marker_count)
count_span = max(marker_count) - count_min + 1  # +1 keeps the divisor non-zero

# One row per cell, one column per cell-type id.
list_count = np.zeros((len(adata_dense), len(dicts_count)))
for key, value in dicts_count.items():
    if not value:
        # A type without usable markers keeps a score of 0; the original
        # formula divided by sqrt(0) in this case.
        continue
    specificity = np.array(
        [1 - (dict_marker_count[val] - count_min) / count_span for val in value],
        dtype=np.float64,
    )
    # Softmax over this type's markers (explicit dim; the implicit form is
    # deprecated), rescaled so the weights sum to the number of markers.
    temp_weight = F.softmax(torch.from_numpy(specificity), dim=0).numpy() * len(value)
    # Weighted expression sum for all cells at once, normalised by sqrt(#markers).
    list_count[:, key] = (adata_dense[:, value] @ temp_weight) / math.sqrt(len(value))

sim = F.softmax(torch.from_numpy(list_count), dim=1)
# Index of the highest-scoring cell type per cell.
y_pred = torch.max(sim, dim=1)[1].numpy()

# Using the first-pass prediction `y_pred`, derive a data-driven weight for
# every (cell type, marker) pair (described in the original comment as the
# "Unique, Frequent and co-occurrence of Marker"). Three terms are computed:
#   wei1 - share of the cells expressing the marker that were predicted as
#          this type
#   wei2 - mean expression of the marker among the cells predicted as this
#          type, floored at 0
#   wei3 - log10(total number of cells / number of cells expressing the marker)
# NOTE(review): the previous code computed these terms but never stored them,
# leaving every weight list empty (IndexError in the second scoring pass).
# They are combined here as a product -- confirm this is the intended formula.

dicts_weight = {}
total_cells = len(adata_dense)
for key, values in dicts_count.items():
    in_type = y_pred == key  # cells assigned to this type in the first pass
    weight = []
    for value in values:
        column = adata_dense[:, value]
        expressed = column > 0
        n_expressed = int(np.count_nonzero(expressed))
        if n_expressed == 0 or not in_type.any():
            # Marker never expressed, or no cell predicted as this type: the
            # ratios below are undefined, so fall back to a neutral weight.
            weight.append(0.0)
            continue
        wei1 = np.count_nonzero(expressed & in_type) / n_expressed
        wei2 = max(float(column[in_type].mean()), 0.0)
        wei3 = math.log10(total_cells / n_expressed)
        weight.append(wei1 * wei2 * wei3)
    dicts_weight[key] = weight

# Second-pass score: the same weighted marker sum as the first pass, but each
# marker's table-based specificity is multiplied by its data-driven weight
# from `dicts_weight` before the softmax.
if not marker_count:
    raise ValueError("no marker gene from the marker table is present in the data")

# Loop-invariant pieces of the specificity term.
count_min = min(marker_count)
count_span = max(marker_count) - count_min + 1  # +1 keeps the divisor non-zero

# One row per cell, one column per cell-type id.
list_count = np.zeros((len(adata_dense), len(dicts_count)))
for indexs, value in dicts_count.items():
    if not value:
        # A type without usable markers keeps a score of 0; the original
        # formula divided by sqrt(0) in this case.
        continue
    # Positional lookup on purpose: a weight list shorter than the marker
    # list raises IndexError, exactly as the original per-marker loop did.
    temp_weight = np.array(
        [float(dicts_weight[indexs][ind]) for ind, _ in enumerate(value)],
        dtype=np.float64,
    )
    temp_weight_2 = np.array(
        [1 - (dict_marker_count[val] - count_min) / count_span for val in value],
        dtype=np.float64,
    )
    # Softmax over this type's markers (explicit dim; the implicit form is
    # deprecated), rescaled so the weights sum to the number of markers.
    new_weight = (
        F.softmax(torch.from_numpy(temp_weight * temp_weight_2), dim=0).numpy()
        * len(value)
    )
    # Weighted expression sum for all cells at once, normalised by sqrt(#markers).
    list_count[:, indexs] = (adata_dense[:, value] @ new_weight) / math.sqrt(len(value))

sim = torch.from_numpy(list_count)
# Final hard assignment (highest raw score) and soft assignment (row softmax).
y_pred = torch.max(sim, dim=1)[1].numpy()
y_soft = F.softmax(sim, dim=1).numpy()
# Score the final prediction against the reference labels; every metric is
# rounded to 5 decimals before printing.
macro_scores, micro_scores = np.round(f1(np.array(y_true), np.array(y_pred)), 5)
precision_macro, recall_macro, f1_macro = macro_scores
precision_micro, recall_micro, f1_micro = micro_scores
for template, macro_value, micro_value in (
    ('F1 score: f1_macro = {}, f1_micro = {}', f1_macro, f1_micro),
    ('precision score: precision_macro = {}, precision_micro = {}', precision_macro, precision_micro),
    ('recall score: recall_macro = {}, recall_micro = {}', recall_macro, recall_micro),
):
    print(template.format(macro_value, micro_value))

# Per-cell annotations: reference class id and final predicted class id.
for obs_column, obs_values in (("broad_cell_type", y_true), ("y_pred", y_pred)):
    adata_t_marker_nouk.obs[obs_column] = obs_values

# Unstructured annotations: soft assignment matrix and the
# cell-type name -> class id mapping.
for uns_key, uns_value in (
    ("Celltype_soft", y_soft),
    ("dicts_cell_type", dicts_cell_type),
):
    adata_t_marker_nouk.uns[uns_key] = uns_value

# Persist the annotated object.
adata_t_marker_nouk.write("/data_path/data_CAS.h5ad")