In [None]:
import os
# Server-side working directory. Only switch into it when it exists, so the
# notebook can still run top-to-bottom on machines that instead use the
# absolute `data_path` set in the "Read PBMC dataset" cell.
WORK_DIR = '/home/evanlee/PBMC_Hao/'
if os.path.isdir(WORK_DIR):
    os.chdir(WORK_DIR)
else:
    print('Working directory not found, staying in', os.getcwd())


In [1]:
# import AD2_w_utils as ad2
from AD2_w_utils import *
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import time
import scipy
import scanpy as sc
import sklearn
import copy
# %matplotlib inline


In [None]:
# Demo: ADlasso2 on the CRC (Zeller) example data from the ADlasso repository
url_1 = 'https://raw.githubusercontent.com/YinchengChen23/ADlasso/main/data/crc_zeller/ASV_vst.txt'
url_2 = 'https://raw.githubusercontent.com/YinchengChen23/ADlasso/main/data/crc_zeller/ASV_table.txt'
url_3 = 'https://raw.githubusercontent.com/YinchengChen23/ADlasso/main/data/crc_zeller/metadata.txt'

# Variance-stabilized abundances (transformed with DESeq2), transposed so rows are samples
Data = pd.read_csv(url_1, sep="\t").T
# Column-wise z-scores are the model input
Data_std = scipy.stats.zscore(Data, axis=0, ddof=0)
# Raw counts, transposed the same way; only used to estimate prevalence
RawData = pd.read_csv(url_2, sep="\t").T
# Sample metadata; the 'Class' column provides the labels
Cohort = pd.read_csv(url_3, sep="\t")
Label = Cohort['Class'].tolist()

n_samples, n_features = Data_std.shape
print('This data contains', n_samples, 'samples and', n_features, 'features')
print(Label[:10], np.unique(Label))

# Prevalence over every sample, i.e. indices 0 .. n-1
pvl0 = get_prevalence(RawData, np.arange(RawData.shape[0]))
res0 = ADlasso2(lmbd=1e-5, alpha=0.9, echo=True)

start = time.time()
res0.fit(Data_std, Label, pvl0)  # .fit(X, y, prevalence)
end = time.time()
# Recorded run: minimum epoch =  9999 ; minimum lost =  6.27363842795603e-05 ; diff weight =  0.002454951871186495

# Prevalence values of the features that received a non-zero selection value
selected_pvl0 = [pvl0[idx] for idx, weight in enumerate(res0.feature_set) if weight != 0]
print('median of selected prevalence :', np.median(selected_pvl0))
print('total selected feature :', np.sum(res0.feature_set))
print("Total cost：%f sec" % (end - start))

# Export the selection result, naming features by the input's column labels
res0.writeList('./demo_selectedList.txt', Data_std.columns)

## Read PBMC dataset

In [2]:
# Local copy of the PBMC_Hao dataset. On the server, run the os.chdir cell
# and use the relative alternative below instead:
#   os.chdir('/home/evanlee/PBMC_Hao/')
#   data_path = ''
data_path = '/Users/evanli/Documents/Research_datasets/PBMC_Hao/'
adata_raw = sc.read(os.path.join(data_path, 'Hao_PBMC.h5ad'))
print(adata_raw.shape)  # rows are cells, columns are genes
# Recorded output: (161764, 20568)


(161764, 20568)


In [3]:
# Metadata: cell-type annotations at three granularities (counts from the saved run)
types_l1 = adata_raw.obs['celltype.l1'].unique()  # 8
types_l2 = adata_raw.obs['celltype.l2'].unique()  # 31
types_l3 = adata_raw.obs['celltype.l3'].unique()  # 58

celltype_df = adata_raw.obs[['celltype.l1', 'celltype.l2', 'cell_type']]
celltype_df = celltype_df.sort_values(['celltype.l1', 'celltype.l2'])

# Map each level-1 type to its level-2 subtypes.
# Deduplicate the (l1, l2) pairs in one vectorised step instead of visiting all
# ~160k cells with row-wise .iloc lookups. drop_duplicates keeps the first
# occurrence, so the sorted order established above is preserved, and each
# pair is unique, so no membership check is needed.
celltype_dict = {k: [] for k in sorted(types_l1)}
level_pairs = celltype_df[['celltype.l1', 'celltype.l2']].drop_duplicates()
for level_1, level_2 in level_pairs.itertuples(index=False, name=None):
    celltype_dict[level_1].append(level_2)

celltype_dict

{'B': ['B intermediate', 'B memory', 'B naive', 'Plasmablast'],
 'CD4 T': ['CD4 CTL',
  'CD4 Naive',
  'CD4 Proliferating',
  'CD4 TCM',
  'CD4 TEM',
  'Treg'],
 'CD8 T': ['CD8 Naive', 'CD8 Proliferating', 'CD8 TCM', 'CD8 TEM'],
 'DC': ['ASDC', 'cDC1', 'cDC2', 'pDC'],
 'Mono': ['CD14 Mono', 'CD16 Mono'],
 'NK': ['NK', 'NK Proliferating', 'NK_CD56bright'],
 'other': ['Doublet', 'Eryth', 'HSPC', 'ILC', 'Platelet'],
 'other T': ['MAIT', 'dnT', 'gdT']}

In [4]:
# Standardized expression: work on an independent copy so adata_raw keeps the raw counts.
adata_std = adata_raw.copy()
# Library-size normalisation: scale every cell to a total of 10,000 counts,
sc.pp.normalize_total(adata_std, target_sum=1e4)
# then apply log(1 + x).
sc.pp.log1p(adata_std)


In [None]:
# Binary one-vs-rest target: 1 for 'B naive' cells, 0 for every other cell
labels = [int(cell_type == 'B naive') for cell_type in adata_raw.obs['celltype.l2']]

# Row positions of the 'B naive' cells (exactly where the label above is 1)
b_naive_indices = [pos for pos, flag in enumerate(labels) if flag == 1]


In [None]:
# Prevalence is computed over the B naive cells only (b_naive_indices).
# Alternative over all cells: get_prevalence(adata_raw.X, np.arange(adata_raw.shape[0]))
pvl = get_prevalence(adata_raw.X, b_naive_indices)
res = ADlasso2(lmbd=1e-5, echo=True, device='cpu')

st = time.time()
res.fit(adata_std.X, labels, pvl)  # .fit(X, y, prevalence)
et = time.time()

print('median of selected prevalence :', np.median(
    [pvl[i] for i, w in enumerate(res.feature_set) if w != 0]))
print('total selected feature :', np.sum(res.feature_set))
print("Total cost：%f sec" % (et - st))

# Write the feature list. AnnData.X is a bare (dense or sparse) matrix with no
# `.columns` attribute, so the gene names must come from `var_names`.
res.writeList('./Bnaive_selectedList.txt', adata_std.var_names)


In [9]:
# Feature (gene) identifiers; AnnData's `var_names` is an alias for `var.index`
adata_std.var.index

Index(['ENSG00000238009', 'ENSG00000237491', 'ENSG00000225880',
       'ENSG00000230368', 'ENSG00000188976', 'ENSG00000187961',
       'ENSG00000187583', 'ENSG00000272512', 'ENSG00000188290',
       'ENSG00000187608',
       ...
       'ENSG00000260213', 'ENSG00000274363', 'ENSG00000006042',
       'ENSG00000101280', 'ENSG00000089101', 'ENSG00000267124',
       'ENSG00000105523', 'ENSG00000282602', 'ENSG00000228404',
       'ENSG00000228137'],
      dtype='object', name='gene_ids', length=20568)

In [10]:
# Multithreaded version of ADLasso2 feature selection algorithm
import time
import numpy as np
from typing import List
from anndata import AnnData
from concurrent.futures import ThreadPoolExecutor

def select_features_by_celltype(adata_raw: "AnnData", adata_std: "AnnData", celltype: str,
                                lmbd: float = 1e-5, device: str = 'cpu',
                                obs_key: str = 'celltype.l2') -> "pd.Index":
    """Fit a one-vs-rest ADlasso2 model for one cell type and export its selected genes.

    Parameters
    ----------
    adata_raw : AnnData
        Unnormalized data. ``.X`` feeds ``get_prevalence``, ``.obs[obs_key]``
        defines the labels and ``.var_names`` names the returned features.
    adata_std : AnnData
        Normalized counterpart of ``adata_raw`` (assumed to hold the same cells
        and genes in the same order); its ``.X`` is the model input and its
        ``.var_names`` are passed to ``writeList``.
    celltype : str
        Value of ``obs[obs_key]`` treated as the positive class (label 1).
    lmbd : float, default 1e-5
        Forwarded to ``ADlasso2(lmbd=...)``.
    device : str, default 'cpu'
        Forwarded to ``ADlasso2(device=...)``.
    obs_key : str, default 'celltype.l2'
        Column of ``adata_raw.obs`` holding the cell-type annotation.

    Returns
    -------
    pandas.Index
        Entries of ``adata_raw.var_names`` whose ``res.feature_set`` value is non-zero.

    Side effects: prints progress and summary lines and writes
    ``./{celltype}_selectedList.txt`` via ``res.writeList``.
    """
    print('=====================')
    print('Selecting features for', celltype)

    # One vectorised comparison gives both the labels and the target indices.
    is_target = (adata_raw.obs[obs_key] == celltype).to_numpy()
    labels = is_target.astype(int).tolist()
    celltype_indices = np.flatnonzero(is_target).tolist()

    # Prevalence computed over the cells of the given cell type only.
    pvl = get_prevalence(adata_raw.X, celltype_indices)

    res = ADlasso2(lmbd=lmbd, echo=True, device=device)
    st = time.time()
    res.fit(adata_std.X, labels, pvl)
    et = time.time()

    # feature_set may be a list or an array; convert before building a mask
    # (a plain `list != 0` is a single bool, not an element-wise mask).
    selected_mask = np.asarray(res.feature_set) != 0
    selected_pvl = [pvl[i] for i, w in enumerate(res.feature_set) if w != 0]
    # np.median of an empty list warns and returns nan; report nan quietly instead.
    median_pvl = np.median(selected_pvl) if selected_pvl else float('nan')
    total_features = int(np.count_nonzero(selected_mask))
    print('---------------------')
    print('Summary statistics for ', celltype)
    print(f"Median of selected prevalence: {median_pvl}")
    print(f"Total selected features: {total_features}")
    print(f"Total cost: {et - st} sec")

    feature_list = adata_raw.var_names[selected_mask]
    out_path = f"./{celltype}_selectedList.txt"
    try:
        print('writing var_names')
        res.writeList(out_path, adata_std.var_names)
    except Exception as exc:
        # Not a bare except: KeyboardInterrupt/SystemExit must still propagate,
        # and the reason for falling back is reported instead of hidden.
        print(f'writeList with var_names failed: {exc!r}')
        print('writing feature_list')
        res.writeList(out_path, feature_list)

    return feature_list

def select_features_by_celltype_multithread(adata_raw: AnnData, adata_std: AnnData, celltypes: List[str], max_workers: int = 4) -> List[List[str]]:
    """Run select_features_by_celltype for several cell types on a thread pool.

    Every cell type is submitted before any result is awaited. Results are
    collected in the order of ``celltypes``. The first failing task (in that
    order) re-raises its exception once the pool has finished its work.
    """
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = []
        for celltype in celltypes:
            pending.append(pool.submit(select_features_by_celltype, adata_raw, adata_std, celltype))
        collected = []
        for fut in pending:
            collected.append(fut.result())
    return collected

In [11]:
# Alternative query: ['B intermediate', 'B memory', 'NK', 'NK Proliferating']
celltype_query = ['NK Proliferating']

results = select_features_by_celltype_multithread(
    adata_raw, adata_std, celltype_query, max_workers=4
)

# Recorded run: NK Proliferating -> 885 selected features;
# res.writeList() wrote those 885 features with their corresponding weights

Selecting features for NK Proliferating
---------------------
Summary statistics for  NK Proliferating
Median of selected prevalence: 0.09124087591240876
Total selected features: 885
Total cost: 1156.285015821457 sec
writing var_names


In [12]:
# Display the selected-gene Index returned for each queried cell type
results

[Index(['ENSG00000157933', 'ENSG00000215788', 'ENSG00000049245',
        'ENSG00000162444', 'ENSG00000054523', 'ENSG00000142657',
        'ENSG00000177000', 'ENSG00000116691', 'ENSG00000159339',
        'ENSG00000158825',
        ...
        'ENSG00000156273', 'ENSG00000156299', 'ENSG00000159128',
        'ENSG00000205726', 'ENSG00000157557', 'ENSG00000183486',
        'ENSG00000160216', 'ENSG00000227039', 'ENSG00000129195',
        'ENSG00000237541'],
       dtype='object', name='gene_ids', length=885)]

In [None]:
# Save each queried cell type's selected genes, one identifier per line.
# Pair queries with their results instead of looping over a hard-coded count:
# range(4) raised IndexError whenever fewer than four cell types were queried
# (the current query has one).
for celltype, features in zip(celltype_query, results):
    print(celltype, len(features))
    filename = celltype + '_feature.txt'
    np.savetxt(filename, features, fmt='%s', delimiter='\t')

In [None]:
# Lambda tuning for the B naive problem: `labels` is the B naive target built
# in the labels cell above (the per-cell-type function keeps its labels local).
# auto_scale(X_input, X_raw, y, step=50); its output is treated as natural-log
# lambda values, hence the np.exp below.
log_lmbd_range = auto_scale(adata_std.X, adata_raw.X, labels, step=50)
lmbd_range = np.exp(log_lmbd_range)
print(lmbd_range)
