<a href="https://colab.research.google.com/github/dtabuena/Workshop/blob/main/RNA_Workshop/Gabaergic_Analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install scanpy
!pip install pybiomart
!pip install python-igraph
!pip install louvain
!pip install pynndescent


In [None]:
import h5py
import numpy as np
import scipy as sci
from matplotlib import pyplot as plt
import scanpy as sc
import tarfile
import os
import anndata as ad
import pandas as pd
import pybiomart
from tqdm import tqdm
import urllib.request
from IPython.display import clear_output
from matplotlib.pyplot import rc_context
from scipy import stats as st
# Machine-specific project directory; set the ZALOCUSKY_DIR environment variable to run elsewhere.
os.chdir(os.environ.get("ZALOCUSKY_DIR", "C:/Users/dennis.tabuena/Dropbox (Gladstone)/0_Projects/_ReAnalyze_Zalocusky_2021"))

def publishable_plots(FS=6):
    """Set compact, publication-style matplotlib defaults and switch the font to Arial.

    Thin (0.5 pt) axes and tick lines, 300 dpi figures, and every text element at
    ``FS`` points. Arial is downloaded to ``./arial.ttf`` the first time it is needed
    and registered with matplotlib's font manager before being selected.

    Parameters
    ----------
    FS : int or float
        Font size applied to titles, axis labels, tick labels and legends.

    Returns
    -------
    None
    """
    import urllib.request
    from matplotlib import font_manager

    plt.rcParams.update({'font.size': FS,'axes.linewidth':.5,'figure.dpi':300,
                         'xtick.major.width': 0.5,'ytick.major.width': 0.5,
                         'figure.titlesize':FS,'axes.titlesize': FS,'xtick.labelsize': FS,
                         'ytick.labelsize':FS,'axes.labelsize': FS,'legend.fontsize': FS,
                         'figure.labelsize':FS})

    arial_link = 'https://raw.githubusercontent.com/dtabuena/Resources/main/Fonts/arial.ttf'
    filename = './arial.ttf'
    # Skip the network round-trip when an earlier call already fetched the font.
    if not os.path.exists(filename):
        urllib.request.urlretrieve(arial_link, filename)
    # Downloading the .ttf does not by itself make it known to matplotlib; register it
    # so the 'arial' family lookup can resolve even when Arial is not installed system-wide.
    font_manager.fontManager.addfont(filename)
    plt.rcParams.update({'font.family': 'arial'})
    return None
publishable_plots(FS=6)
# Report only scanpy errors (levels: errors 0, warnings 1, info 2, hints 3).
sc.settings.verbosity = 'error'
sc.settings.set_figure_params(facecolor='white', dpi=300, fontsize=6)

In [None]:
def trim_key(k):
    """Turn a GEO 10x ``.h5`` filename into its underscore-separated sample code.

    The two floxed-KI samples do not follow the standard naming scheme, so their
    codes are looked up explicitly (first matching prefix wins). Every other filename
    just has the ``_raw_gene_bc_matrices_h5.h5`` suffix removed.
    """
    special_cases = {
        'GSM5106175_YH_KZ03_01': ('E3fKI_Syn_Cre602_15m', 'GSM5106175_602_E3fKI_15_XX'),
        'GSM5106176_YH_KZ03_03': ('E4fKI_Syn_Cre475_15m', 'GSM5106176_475_E4fKI_15_XX'),
    }
    override = next((codes[1] for prefix, codes in special_cases.items() if prefix in k), None)
    if override is not None:
        return override
    return k.replace('_raw_gene_bc_matrices_h5.h5', "")
def query_capitilaziation(gene,adata):
    """Return the entry of ``adata.var.index`` that equals ``gene`` ignoring case.

    The first match in index order wins. When nothing matches, the string
    ``gene + ' not_found'`` is returned instead of raising.
    """
    target = gene.lower()
    for name in adata.var.index:
        # str() so a non-string index entry is compared rather than aborting the search
        if str(name).lower() == target:
            return name
    return gene + ' not_found'
def z_score(x,axis=-1):
    """Standardize ``x`` to zero mean and unit (population) std along ``axis``.

    Parameters
    ----------
    x : array-like
        Values to standardize; converted with ``np.array``.
    axis : int
        Axis along which mean and standard deviation are computed (default: last).

    Returns
    -------
    numpy.ndarray
        Same shape as ``x``. Zero-variance slices divide by zero (NaN/inf).
    """
    x = np.array(x)
    # keepdims retains the reduced axis so mean/std broadcast back along the correct
    # axis for N-D input; without it a 2-D array either fails or mis-broadcasts.
    mean = np.mean(x, axis=axis, keepdims=True)
    std = np.std(x, axis=axis, keepdims=True)
    return (x - mean) / std

In [None]:
os.chdir(os.environ.get("ZALOCUSKY_DIR", "C:/Users/dennis.tabuena/Dropbox (Gladstone)/0_Projects/_ReAnalyze_Zalocusky_2021"))
url = 'https://www.ncbi.nlm.nih.gov/geo/download/?acc=GSE167497&format=file'
filename = './'+'zalocusky_indiv.tar'
extract_dir = './indiv_animal_results'

# Download and unpack the GEO archive only when it has not been extracted yet.
if not os.path.isdir(extract_dir):
    urllib.request.urlretrieve(url, filename)
    with tarfile.open(filename) as my_tar:  # closed even if extraction fails
        if hasattr(tarfile, 'data_filter'):
            # Extraction filters (Python 3.12+ and security backports) block members
            # that would land outside extract_dir; the archive comes from the network.
            my_tar.extractall(extract_dir, filter='data')
        else:
            my_tar.extractall(extract_dir)
for f in os.listdir(extract_dir):
    print(f)


In [None]:
### Read, Combine, and Sample Split multiple 10x's

adata_dict = {}
# Sorted so the concatenated cell order -- and hence every seeded downstream step
# (PCA, neighbors, Louvain, UMAP) -- does not depend on filesystem listing order.
h5_files = sorted(f for f in os.listdir('./indiv_animal_results') if f.endswith('.h5'))
for f in tqdm(h5_files):
    a = sc.read_10x_h5(os.path.join('./indiv_animal_results', f))
    a.var_names_make_unique()
    sample_code = trim_key(f)
    # sample_code fields: GSM _ mouse_ID _ E_type _ age (months) _ well
    gsm, mouse_id, e_type, age_months, well = sample_code.split("_")[:5]
    # Round age up to the next multiple of 5 months, e.g. '13' -> '15m'.
    a.obs['age_bin'] = str(int(np.ceil(int(age_months) / 5) * 5)) + 'm'
    a.obs['E_type'] = e_type
    a.obs['mouse_ID'] = mouse_id
    a.obs['well'] = well
    a.obs['GSM'] = gsm
    adata_dict[gsm] = a
adata = ad.concat(adata_dict, axis=0, label="Sample", index_unique="_")
# Drop the floxed-KI (fKI) samples; copy so later in-place edits act on a real AnnData.
adata = adata[['fKI' not in t for t in adata.obs.E_type], :].copy()
adata_dict = {}
clear_output()
print('data_loaded.')




In [None]:
def pull_gene_annots(csv_loc='./mmusculus_coding_noncoding.csv',
                     my_git='https://raw.githubusercontent.com/dtabuena/Resources/main/Genetics/mmusculus_coding_noncoding.csv',
                     biomart_name='mmusculus',
                     biomart_keys=("ensembl_gene_id", "chromosome_name", "transcript_biotype", "external_gene_name")):
    """Return mouse coding-gene symbols and the annotation table they came from.

    The table is taken from the first source that works:
      1. a cached CSV at ``csv_loc``;
      2. a copy downloaded from ``my_git`` (saved to ``csv_loc``);
      3. a BioMart query for ``biomart_name`` with ``biomart_keys``, written to
         ``csv_loc`` for next time.

    Parameters
    ----------
    csv_loc : str
        Path of the cached annotation CSV (read and/or written).
    my_git : str
        URL of a pre-built annotation CSV.
    biomart_name : str
        Dataset name passed to ``sc.queries.biomart_annotations``.
    biomart_keys : sequence of str
        BioMart attributes to request; must include ``ensembl_gene_id``,
        ``transcript_biotype`` and ``external_gene_name``.

    Returns
    -------
    coding_list : list
        Gene symbols whose ``is_coding`` flag is True.
    annot_dd : pandas.DataFrame
        Annotation table indexed by ``external_gene_name``.
    """
    if os.path.exists(csv_loc):
        print( 'Use local copy of musmus')
        annot_dd = pd.read_csv(csv_loc).set_index("external_gene_name")
    else:
        try:
            print( 'attempting to pull mus mus from git...')
            urllib.request.urlretrieve(my_git, csv_loc)
            annot_dd = pd.read_csv(csv_loc).set_index("external_gene_name")
        except (OSError, ValueError, KeyError) as err:
            # Network/HTTP failure (OSError), unparsable CSV (ValueError) or a missing
            # column (KeyError): rebuild the table from BioMart instead.
            print(f'git download failed: {err!r}')
            print('attempting to pull mus mus from biomart...')
            annot = sc.queries.biomart_annotations(biomart_name, list(biomart_keys)).set_index('ensembl_gene_id')
            # A gene listed once per transcript gets all its biotypes joined with '__'
            # on every one of its rows -- one groupby instead of a scan per gene.
            biotypes = annot.groupby(level=0, sort=False)['transcript_biotype'].agg(
                lambda s: '__'.join(s.astype(str)) if len(s) > 1 else s.iloc[0])
            annot['transcript_biotype'] = biotypes.reindex(annot.index).to_numpy()
            annot['is_coding'] = annot['transcript_biotype'].str.contains('coding', na=False)
            annot_dd = annot.drop_duplicates().set_index("external_gene_name")
            annot_dd.to_csv(csv_loc)

    # A cached CSV can yield an object column containing NaN; only an explicit True counts.
    is_coding = annot_dd['is_coding'].eq(True).to_numpy()
    coding_list = annot_dd.index[is_coding].to_list()
    return coding_list, annot_dd




def preprocess_andata10x(adata_og,pct_mito=0.25,min_genes=500,max_genes=2400,min_counts=500,max_counts=4500):
    """QC-filter 10x counts by mitochondrial percentage, coding genes and gene/UMI bounds.

    Modifies ``adata_og`` in place. It adds ``var['mt']`` (names starting with 'mt-'),
    ``var['coding']``, and scanpy QC metrics such as ``obs['pct_counts_mt']``.

    Parameters
    ----------
    adata_og : AnnData
        Unfiltered counts.
    pct_mito : float
        Cells whose ``pct_counts_mt`` (a percentage) is >= this value are dropped.
    min_genes, max_genes : int
        Bounds on the number of genes detected per cell.
    min_counts, max_counts : int
        Bounds on total counts per cell.

    Returns
    -------
    (adata_QC, adata_og)
        The filtered copy and the annotated input.
    """
    print('pulling gene annotations...')
    coding_list, _ = pull_gene_annots()
    coding_set = set(coding_list)  # O(1) lookups instead of scanning a list per gene
    adata_og.var['mt'] = adata_og.var_names.str.startswith('mt-')
    adata_og.var['coding'] = [gene in coding_set for gene in adata_og.var_names]
    sc.pp.calculate_qc_metrics(adata_og, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

    print('Filtering...')
    keep_cells = (adata_og.obs.pct_counts_mt < pct_mito).to_numpy()
    adata_QC = adata_og[keep_cells, :]
    # Report the cells actually removed (the complement of the kept mask).
    print(str(int(np.sum(~keep_cells))) + f' cells with >={pct_mito}% removed')
    coding_mask = adata_QC.var['coding'].to_numpy(dtype=bool)
    # Copy so the in-place filter_cells calls below operate on a real AnnData, not a view.
    adata_QC = adata_QC[:, coding_mask].copy()
    print(str(int(np.sum(~coding_mask))) + ' non coding genes removed')
    # scanpy's filter_cells accepts a single criterion per call.
    sc.pp.filter_cells(adata_QC, min_genes=min_genes)
    sc.pp.filter_cells(adata_QC, max_genes=max_genes)
    sc.pp.filter_cells(adata_QC, min_counts=min_counts)
    sc.pp.filter_cells(adata_QC, max_counts=max_counts)
    fig,ax=plt.subplots(1,figsize=(1.5,1.5))
    sc.pl.scatter(adata_QC, x='total_counts', y='n_genes_by_counts',ax=ax)

    return adata_QC,adata_og

adata_QC = preprocess_andata10x(adata)[0]




In [None]:
def high_var_genes_dim_reduc(adata,min_mean = 0.25,max_mean = 4,min_disp=0.55,
                             n_comps=50,pca_plot_keys=('E_type','age_bin','mouse_ID')):
    ''' Log-normalize, select highly variable genes, run PCA and plot diagnostics.

    Follows the source paper's Seurat workflow: "The gene expression matrices were then
    log-normalized with a scale factor of 10,000, using the Seurat NormalizeData function.
    Highly dispersed genes were selected using the Seurat FindVariableGenes function,
    filtering for an average expression range of 0.25-4 and a minimum dispersion of 0.55,
    resulting in a list of 2,197 genes."

    ``adata`` is modified in place and returned. ``.raw`` keeps the matrix as it was on
    entry, and ``.X`` becomes log1p of counts normalized to 10,000 per cell.
    ``var['highly_variable']`` is added, and PCA results go to ``obsm['X_pca']`` /
    ``uns['pca']``.

    Parameters
    ----------
    min_mean, max_mean, min_disp : float
        Highly-variable-gene thresholds passed to ``sc.pp.highly_variable_genes``.
    n_comps : int
        Number of principal components to compute.
    pca_plot_keys : sequence of str
        ``adata.obs`` columns drawn as PC1-vs-PC2 scatter panels.
    '''
    adata.raw = adata  # snapshot of the un-normalized matrix
    sc.pp.normalize_total(adata, target_sum=10000)
    sc.pp.log1p(adata)
    sc.pp.highly_variable_genes(adata, min_mean=min_mean, max_mean=max_mean, min_disp=min_disp)
    with rc_context({'figure.figsize': (1.5, 1.5)}):
        sc.pl.highly_variable_genes(adata)
    plt.tight_layout()
    print(np.sum(adata.var['highly_variable']),'hv genes')

    #### PCA
    sc.tl.pca(adata, svd_solver='arpack',n_comps=n_comps)
    # Elbow plot of the first 25 components' variance ratios.
    fig,ax=plt.subplots(figsize=(1,1))
    ax.plot(adata.uns['pca']['variance_ratio'][:25],'ok',markersize=1)
    ax.set_xlabel('PC')
    ax.set_ylabel('Variance ratio')
    quiet_PCA_plots(adata,list(pca_plot_keys))

    return adata

def quiet_PCA_plots(adata,key_list,figsize=(2,2)):
    """Scatter PC1 vs PC2 once per ``adata.obs`` column in ``key_list``, coloured by category.

    Small, edge-less markers keep dense embeddings readable. A legend is drawn only
    when a column has fewer than 8 categories.

    Parameters
    ----------
    adata : AnnData
        Must already contain ``obsm['X_pca']``.
    key_list : list of str
        ``adata.obs`` columns; one panel each, in order.
    figsize : (float, float)
        Size of a single panel; the figure width scales with ``len(key_list)``.
    """
    fig,axes=plt.subplots(1,len(key_list),figsize=(figsize[0]*len(key_list),figsize[1]))
    axes = np.atleast_1d(axes).ravel()  # subplots returns a bare Axes for one panel
    pcs = adata.obsm['X_pca']
    # zip/enumerate-style pairing gives duplicate keys their own panels (list.index would not)
    for ax, key in zip(axes, key_list):
        labels = np.asarray(adata.obs[key])
        key_types = sorted(set(labels))
        for k in key_types:
            is_k = labels == k
            ax.scatter(pcs[is_k,0],pcs[is_k,1],s=2,marker='.',linewidth=0,edgecolors=None,label=k)
        if len(key_types)<8: ax.legend(key_types,loc='best',markerscale=3)
        ax.set_title(key)
    plt.tight_layout()  # once, after every panel is drawn
    return None

############## GABAERGIC Filter #####################
adata_GABA = adata_QC.copy()
# Zero-pad '5m' so bins sort chronologically as strings; pool the 15 m and 20 m animals.
age_dict = {'5m':'05m', '10m': '10m','15m': '15-20m','20m': '15-20m'}
adata_GABA.obs['age_bin'] = [ age_dict[a] for a in adata_GABA.obs['age_bin'] ]

# Take each marker as a 1-D vector (one-element key list, then the column) so z_score
# standardizes across cells rather than along a length-1 trailing axis.
gad1_expr = sc.get.obs_df(adata_GABA, ['Gad1'])['Gad1'].to_numpy()
syn1_expr = sc.get.obs_df(adata_GABA, ['Syn1'])['Syn1'].to_numpy()
adata_GABA.obs['Gad1_Pos'] = z_score(gad1_expr) > 0.5
adata_GABA.obs['Syn1_pos'] = z_score(syn1_expr) > 0.5

is_gaba = adata_GABA.obs['Gad1_Pos']
adata_GABA.obs['Gabaergic'] = is_gaba
# Materialize the subset so the in-place filter_genes below works on a real AnnData.
adata_GABA = adata_GABA[ adata_GABA.obs['Gabaergic'].to_numpy(dtype=bool) ,:].copy()
sc.pp.filter_genes(adata_GABA, min_counts=50)


############## Dim Reduction #####################
adata_GABA = high_var_genes_dim_reduc(adata_GABA)

In [None]:
def umap_and_cluster(adata, n_neighbors=10, n_pcs=12,resolution=.6,plot_keys=None,size = 1):
    """Build a kNN graph, Louvain-cluster it, run PAGA and UMAP, and plot the embedding.

    All stochastic steps are seeded with ``random_state=42``. The Louvain labels are
    copied to ``obs['Cluster (nn)']``. PAGA is computed and stored, but it is not
    passed to ``sc.tl.umap`` here (no ``init_pos``).

    Parameters
    ----------
    n_neighbors, n_pcs : int
        Neighbor-graph settings.
    resolution : float
        Louvain resolution.
    plot_keys : list of str or None
        obs/gene keys to colour the UMAP by; defaults to ``['Cluster (nn)']``.
    size : float
        Marker size for the UMAP plot.

    Returns
    -------
    AnnData
        The same object, modified in place.
    """
    if plot_keys is None:
        plot_keys = ['Cluster (nn)']
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, n_pcs=n_pcs,random_state=42)
    sc.tl.louvain(adata,resolution=resolution,random_state=42)
    sc.tl.paga(adata)
    sc.tl.umap(adata,random_state=42)
    adata.obs['Cluster (nn)']= adata.obs['louvain']
    with rc_context({'figure.figsize': (2.5, 2.5)}):
        sc.pl.umap(adata,add_outline=False, legend_loc='on data', color=plot_keys,size=size)
    return adata

adata_GABA = umap_and_cluster(adata_GABA,plot_keys=['Cluster (nn)'])

In [None]:
def explore_umap(adata_GABA,key_list=None,size=1,legend_loc=None):
    """Plot the existing UMAP once per key in ``key_list`` (obs columns or gene names).

    Colour scales start at 0 (``vmin=0``). ``key_list`` defaults to an empty list.
    """
    if key_list is None:
        key_list = []
    with rc_context({'figure.figsize': (1.5,1.5)}):
        sc.pl.umap(adata_GABA, legend_loc=legend_loc, color=key_list,vmin=0,size=size)
        plt.tight_layout()

def get_gene_cluster(adata,gene):
    """Return the Louvain cluster holding most of the strongly ``gene``-expressing cells.

    "Strongly expressing" means a z-score across all cells above 2. Ties go to the
    lexicographically smallest label, matching the tie-break of ``scipy.stats.mode``.

    Raises
    ------
    ValueError
        If no cell exceeds the z-score threshold.
    """
    from collections import Counter

    # One-element key list, then the column, so the values are 1-D before z-scoring.
    expr = sc.get.obs_df(adata, [gene])[gene].to_numpy()
    gene_bool = z_score(expr) > 2
    gene_cluster_list = np.asarray(adata.obs['louvain'])[gene_bool]
    if gene_cluster_list.size == 0:
        raise ValueError(f'no cells with z-scored {gene} expression above 2')
    # scipy.stats.mode no longer accepts non-numeric labels in recent SciPy (and changed
    # its return shape), so count the string cluster labels directly.
    counts = Counter(gene_cluster_list.tolist())
    top = max(counts.values())
    return min(label for label, n in counts.items() if n == top)



# Overview panels: clusters + pan-neuronal markers, interneuron subtypes, genotype/Apoe.
for umap_keys, umap_legend in ((['Cluster (nn)', 'Gad1', 'Syn1'], 'on data'),
                               (['Sst', 'Pvalb', 'Reln', 'Vip'], None),
                               (['E_type', 'Apoe'], None)):
    explore_umap(adata_GABA, umap_keys, legend_loc=umap_legend)

# Clusters enriched for SST- and PV-expressing cells.
sst_clust, pv_clust = (get_gene_cluster(adata_GABA, marker) for marker in ('Sst', 'Pvalb'))

for label, clust in (('sst_clust', sst_clust), ('pv_clust', pv_clust)):
    print(label, clust)

In [None]:
#### Get Marker Genes
age = adata_GABA.obs['age_bin']
e_type = adata_GABA.obs['E_type']
louvain = adata_GABA.obs['louvain']
clusters= sorted(list(set(louvain)))
ages  = sorted(list(set(age)))
e_types  =list(set(e_type))


# rank_genes_groups expects log-normalized values. adata_GABA.raw was captured before
# normalize_total/log1p, so use_raw=True would test raw counts and distort the fold
# changes; test the log-normalized .X (all genes retained) instead.
sc.tl.rank_genes_groups(adata_GABA, 'louvain', method='wilcoxon',key_added='maker_genes',pts=True,use_raw=False)
# sc.tl.rank_genes_groups(adata_GABA, 'louvain', method='logreg',pts=True,use_raw =False)

maker_genes_df = pd.DataFrame(adata_GABA.uns['maker_genes']['names'])
maker_genes_df_LFC = pd.DataFrame(adata_GABA.uns['maker_genes']['logfoldchanges'])

top_30 = maker_genes_df[:30]
lfc_thresh=4  # NOTE(review): chosen when LFCs came from raw counts; may need re-tuning

# Genes in any cluster's top 30 with LFC above lfc_thresh (masked-out cells become NaN,
# hence the isinstance filter); sorted so the list order is reproducible across runs.
marker_genes = sorted({m for m in top_30[maker_genes_df_LFC[:30]>lfc_thresh].values.flatten() if isinstance(m,str)})
# Always include the canonical interneuron markers, without duplicating them.
marker_genes = list(dict.fromkeys(marker_genes + ['Sst', 'Vip', 'Pvalb']))


import os
import shutil
path = './cluster_markers/'
os.makedirs(path, exist_ok=True)

# sc.pl.stacked_violin(adata_GABA, marker_genes, groupby='louvain');

# Save each cluster's top 20 ranked markers; to_csv overwrites any existing file.
for c in clusters:
    mark_df = sc.get.rank_genes_groups_df(adata_GABA, group =c,key='maker_genes')
    mark_df[:20].to_csv(os.path.join(path,f'cluster{c}_markers.csv'))

In [None]:
# Every potassium-channel (Kcn*) gene still present after filtering, alphabetically.
kcn_list = sorted(gene for gene in adata_GABA.var.index if 'kcn' in str(gene).lower())




In [None]:
def deg_analysis(adata_temp,adata_name,new_file_path,analysis_pairs,group_key,log2fc_extrema=(-15,15),label_cutoff=1,highlight='kcn'):
    '''
    Takes in anndata RNA object and a stratifying key and returns
    DEGs based on given groupings. Also builds volcano plots.

    For every ``name: (test_group, reference_group)`` in ``analysis_pairs``, a
    tie-corrected Wilcoxon test compares the two groups of ``obs[group_key]``. The
    resulting table is written to ``<new_file_path>/<adata_name>_<name>.csv``. A volcano
    panel then shows only genes whose name contains ``highlight`` (case-insensitive).

    ``adata_temp.raw`` is overwritten with the current ``adata_temp``, so tests run on
    its present ``.X``.

    Parameters
    ----------
    log2fc_extrema : (float, float)
        Rows with log2 fold change outside this open range are dropped from the tables.
    label_cutoff : float
        Highlighted genes with p below this value get a text label.
    highlight : str
        Substring selecting which genes appear on the volcano plots.

    Returns
    -------
    adata_temp, deg_df_dict, fig_volcano, ax
        ``deg_df_dict`` maps each pair name to its gene-indexed DEG table; ``ax`` is a
        1-D array with one Axes per pair.
    '''
    adata_temp.raw = adata_temp
    deg_df_dict = _run_deg_tests(adata_temp, adata_name, new_file_path, analysis_pairs, group_key, log2fc_extrema)
    fig_volcano, ax = _plot_volcanoes(deg_df_dict, label_cutoff, highlight)
    fig_volcano.suptitle(adata_name.replace('_',' ').title())
    plt.tight_layout()
    return adata_temp, deg_df_dict,fig_volcano,ax


def _run_deg_tests(adata_temp, adata_name, new_file_path, analysis_pairs, group_key, log2fc_extrema):
    """Run one Wilcoxon test per pair, save each table as CSV, and return them by pair name."""
    deg_df_dict = {}
    for key, pair in analysis_pairs.items():
        sc.tl.rank_genes_groups(adata_temp,use_raw=True,groupby=group_key,groups=pair,reference=pair[1],key_added=key,method='wilcoxon',tie_correct=True)
        deg_df = sc.get.rank_genes_groups_df(adata_temp, group=pair[0],key=key,pval_cutoff=1,log2fc_min=log2fc_extrema[0], log2fc_max=log2fc_extrema[1])
        deg_df = deg_df.set_index('names')
        deg_df_dict[key] = deg_df
        deg_df.to_csv(os.path.join(new_file_path, adata_name + '_' + key + '.csv'))
    return deg_df_dict


def _plot_volcanoes(deg_df_dict, label_cutoff, highlight):
    """Draw one volcano panel per DEG table (highlighted genes only) on shared symmetric axes."""
    n_panels = len(deg_df_dict)
    fig, ax = plt.subplots(1, n_panels, figsize=(n_panels * 2, 3))
    ax = np.atleast_1d(ax).ravel()  # a single panel comes back as a bare Axes
    needle = highlight.lower()
    for a, (key, df) in zip(ax, deg_df_dict.items()):
        a.set_title(key)
        # Positional numpy arrays avoid pandas' deprecated integer-as-position Series lookup.
        lfc = df['logfoldchanges'].to_numpy(dtype=float)
        pval = df['pvals'].to_numpy(dtype=float)
        with np.errstate(divide='ignore'):
            neg_log10_pval = -np.log10(pval)
        is_hl = np.array([needle in str(g).lower() for g in df.index], dtype=bool)
        # Non-finite points (NaN statistics, p == 0 -> inf) cannot be placed or labelled.
        show = is_hl & np.isfinite(lfc) & np.isfinite(neg_log10_pval)
        a.scatter(lfc[show], neg_log10_pval[show], c='k', s=3)
        for gene, x, y, p in zip(df.index[show], lfc[show], neg_log10_pval[show], pval[show]):
            if p < label_cutoff:
                a.text(x, y, gene)

    # Shared limits: symmetric in x, starting at 0 in y, with 10% headroom.
    x_etr = max(np.max(np.abs(a.get_xlim())) for a in ax)
    y_etr = max(np.max(np.abs(a.get_ylim())) for a in ax)
    for a in ax:
        a.set_xlim(-x_etr*1.1,x_etr*1.1)
        a.set_ylim(0,y_etr*1.1)
        a.set_xlabel('Log2 Fold Change')
        a.set_ylabel('-log10(pvalue)')
        a.axhline(-np.log10(0.05),c='k',linestyle=":",linewidth=.6,label='p=0.05')
    return fig, ax



In [None]:
###### SST Cluster KCN DEGs #####
path='./SST_CLUSTER_DEG_KCN/'
os.makedirs(path, exist_ok=True)
# Subset to the cluster first, then copy: cheaper than copying every cell, and the result
# is a real AnnData (not a view), so the .var/.obs assignments below modify it directly.
adata_sst_cluster = adata_GABA[(adata_GABA.obs['louvain']==sst_clust).to_numpy(dtype=bool),:].copy()
adata_sst_cluster.var['is_KCN'] = [g in kcn_list for g in adata_sst_cluster.var.index]
adata_sst_cluster.obs['age_geno'] = adata_sst_cluster.obs['age_bin'].str.cat(adata_sst_cluster.obs['E_type'],sep='_')
ages = sorted(list(set(adata_sst_cluster.obs['age_bin'])))
# One E4-vs-E3 comparison per age bin, e.g. '05m_E4vE3': ('05m_E4', '05m_E3').
analysis_pairs = { a+'_E4vE3':(a+"_E4", a+"_E3") for a in ages}
adata_name = 'SST_cluster_KCNs'
sst_cluster_deg_data = deg_analysis(adata_sst_cluster,adata_name,path,analysis_pairs,group_key = 'age_geno',log2fc_extrema=[-15,15])


###### PV Cluster KCN DEGs #####
path='./PV_CLUSTER_DEG_KCN/'
os.makedirs(path, exist_ok=True)
adata_pv_cluster = adata_GABA[(adata_GABA.obs['louvain']==pv_clust).to_numpy(dtype=bool),:].copy()
adata_pv_cluster.var['is_KCN'] = [g in kcn_list for g in adata_pv_cluster.var.index]
adata_pv_cluster.obs['age_geno'] = adata_pv_cluster.obs['age_bin'].str.cat(adata_pv_cluster.obs['E_type'],sep='_')
ages = sorted(list(set(adata_pv_cluster.obs['age_bin'])))
analysis_pairs = { a+'_E4vE3':(a+"_E4", a+"_E3") for a in ages}
adata_name = 'pv_cluster_KCNs'
pv_cluster_deg_data = deg_analysis(adata_pv_cluster,adata_name,path,analysis_pairs,group_key = 'age_geno',log2fc_extrema=[-15,15])

In [None]:
# Log2 fold change (E4 vs E3) of every Kcn gene in the PV cluster, one column per comparison.
kcn_list = sorted(kcn_list)
deg_dict = pv_cluster_deg_data[1]
# reindex instead of .loc: a Kcn gene can be absent from a DEG table (e.g. removed by the
# log2fc bounds used in deg_analysis) and then shows as NaN rather than raising KeyError.
mtx = pd.DataFrame({key: deg_df['logfoldchanges'].reindex(kcn_list) for key, deg_df in deg_dict.items()})

display(mtx)


In [None]:
# ###### SST Positive KCN DEGs #####
# path='./SST_Positive_DEG_KCN/'
# try: os.makedirs(path)
# except: None
# adata_sst_pos = adata_GABA.copy()
# sst_bool = z_score(sc.get.obs_df(adata_sst_pos,'Sst'))>0
# adata_sst_pos = adata_sst_pos[sst_bool,:]
# adata_sst_pos.var['is_KCN'] = [g in kcn_list for g in adata_sst_pos.var.index]
# adata_sst_pos.obs['age_geno'] = adata_sst_pos.obs['age_bin'].str.cat(adata_sst_pos.obs['E_type'],sep='_')
# ages = sorted(list(set(adata_sst_pos.obs['age_bin'])))
# analysis_pairs = { a+'_E4vE3':(a+"_E4", a+"_E3") for a in ages}
# adata_name = 'SST_cluster_KCNs'
# deg_data = deg_analysis(adata_sst_pos,adata_name,path,analysis_pairs,group_key = 'age_geno',log2fc_extrema=[-15,15])


# ###### PV Positive KCN DEGs #####
# path='./PV_Positive_DEG_KCN/'
# try: os.makedirs(path)
# except: None
# adata_pv_pos = adata_GABA.copy()
# pv_bool = z_score(sc.get.obs_df(adata_pv_pos,'Pvalb'))>0
# adata_pv_pos = adata_pv_pos[pv_bool,:]
# adata_pv_pos.obs['age_geno'] = adata_pv_pos.obs['age_bin'].str.cat(adata_pv_pos.obs['E_type'],sep='_')
# ages = sorted(list(set(adata_pv_pos.obs['age_bin'])))
# analysis_pairs = { a+'_E4vE3':(a+"_E4", a+"_E3") for a in ages}
# adata_name = 'pv_cluster_KCNs'
# deg_data = deg_analysis(adata_pv_pos,adata_name,path,analysis_pairs,group_key = 'age_geno',log2fc_extrema=[-15,15])