# Environment

In [285]:
import os
import gc
import sys
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import seaborn as sns
import scipy.sparse as sp
import matplotlib.pyplot as plt
from scipy.sparse import issparse
import plotly.graph_objects as go
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed
from IPython.display import display, clear_output

from dotenv import load_dotenv
load_dotenv()
sys.path.insert(0, os.getenv('PROJECT_FUNCTIONS_PATH'))

from evaluated_helpers import (
    load_GRNs_gene_sets,
    remove_duplicates_preserve_order_GRNs
)

from gene_scoring import score_genes

In [286]:
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80)

In [287]:
data_path = os.getenv('DATA_PATH')
root_dir = os.getenv('ROOT_DIR')

In [288]:
input_file = os.path.join(data_path, '3_FiltNormAdata.h5ad')
output_file = os.path.join(data_path, '4_GeneScores.h5ad')

In [289]:
gpu_support = False
recompute = True

In [290]:
gene_sets = ['L2-3_CUX2']
cell_types = ['L2-3_CUX2']
gois = ['AHR', 'AR', 'NR1I2', 'NR1I3', 'NR3C1', 'NR3C2', 'ESR1', 'RARA', 'ESR2', 'THRB', 'THRA']
control_condition = 'LDN'
conditions = ['FGF2-50', 'FGF2-20', 'FGF4', 'FGF8'] # FGF pathway
conditions += ['SAG1000', 'SAG250'] # SAG pathway
conditions += ['BMP4'] # BMP4 pathway
conditions += ['BMP7'] # BMP7 pathway
conditions += ['CHIR3', 'CHIR1.5'] # WNT pathway
conditions += ['IWP2'] # IWP2 pathway
conditions += ['RA100', 'RA10'] # Retinoic acid pathway
conditions += [control_condition]

# Load Data

## GRNs

In [291]:
_, gene_sets_dict_cell_type_first = load_GRNs_gene_sets(root_dir=root_dir, gene_set_list=gene_sets)

In [292]:
for cell_type, cell_type_dict in gene_sets_dict_cell_type_first.items():
    for gene_set, gene_dict in cell_type_dict.items():
        
        # SAG pathway genes
        print("\nSAG pathway genes:")
        for gene in ['GLI1', 'GLI2', 'GLI3', 'GLI4']:
            if gene in gene_dict:
                print(f"\nGene: {gene}")
                print("Targets:", gene_dict[gene]['targets'])
        
        # Retinoic Acid pathway genes
        print("\nRetinoic Acid pathway genes:")
        for gene in ['RARA', 'RARB', 'RARG', 'RXRA', 'RXRB', 'RXRG']:
            if gene in gene_dict:
                print(f"\nGene: {gene}")
                print("Targets:", gene_dict[gene]['targets'])


SAG pathway genes:

Gene: GLI1
Targets: ['MGST1', 'ASTN2', 'PDE1A', 'ANK1']

Gene: GLI2
Targets: ['ASIC2', 'ZNF804A', 'ASTN2', 'DGKB', 'PCSK5', 'MGST1', 'PDE1A', 'NRG1', 'MAN1A1', 'ITPR2', 'NKAIN3', 'TOX3', 'ACAT2', 'HS3ST4', 'AKAP13', 'SQLE', 'FIGN', 'PAG1', 'TXNDC16', 'TFDP2', 'CGGBP1', 'ABHD3', 'FOXG1']

Gene: GLI3
Targets: ['SQLE', 'NRG1', 'FDFT1', 'TXNDC16', 'EFNA5', 'SORCS1', 'PAG1', 'MSI2', 'SEMA5B', 'BMPR1A', 'ASTN2', 'MGST1', 'TUBB2A', 'SVIL', 'SCD', 'DRAXIN', 'KCNIP4', 'NNAT', 'MDK', 'SC5D', 'HS3ST4', 'ITPR2', 'RERG', 'BTG1', 'USP3']

Gene: GLI4
Targets: ['MEIS2', 'IGF2BP3']

Retinoic Acid pathway genes:

Gene: RARA
Targets: ['NRGN', 'ABHD3', 'XYLT1', 'CDH18', 'DNAJB1', 'C1QTNF4', 'PRKD1', 'HSP90AA1', 'PTPRZ1', 'DPP10', 'PDE5A', 'MTSS1', 'HMGN3', 'KDM6A', 'DRAXIN', 'SLC4A7', 'TET1', 'EBF4']

Gene: RARB
Targets: ['DPP10', 'MAML3', 'PTPRZ1', 'KIF26B', 'PRKD1', 'STMN1', 'ANTXR2', 'MAK', 'RGS20', 'KIAA0895L', 'FGFR2', 'CGGBP1']

Gene: RARG
Targets: ['CNTNAP5', 'KDM6A', 'GABRG3']

In [293]:
for cell_type, cell_type_dict in gene_sets_dict_cell_type_first.items():
    for gene_set, gene_dict in cell_type_dict.items():
        
        # Initialize pathway dictionaries
        pathways = {
            'SAG': {'targets': {}, 'weights': {}},
            'RA': {'targets': {}, 'weights': {}}
        }
        
        # Group SAG pathway genes
        for gene in ['GLI1', 'GLI2', 'GLI3', 'GLI4']:
            if gene in gene_dict:
                for target, weight in zip(gene_dict[gene]['targets'], gene_dict[gene]['scored_coef_mean']):
                    if target not in pathways['SAG']['weights']:
                        pathways['SAG']['weights'][target] = []
                    pathways['SAG']['weights'][target].append(weight)
        
        # Group Retinoic Acid pathway genes  
        for gene in ['RARA', 'RARB', 'RARG', 'RXRA', 'RXRB', 'RXRG']:
            if gene in gene_dict:
                for target, weight in zip(gene_dict[gene]['targets'], gene_dict[gene]['scored_coef_mean']):
                    if target not in pathways['RA']['weights']:
                        pathways['RA']['weights'][target] = []
                    pathways['RA']['weights'][target].append(weight)
        
        # Average weights for repeated targets
        for pathway in ['SAG', 'RA']:
            pathways[pathway]['targets'] = list(pathways[pathway]['weights'].keys())
            pathways[pathway]['weights'] = [np.mean(weights) for weights in pathways[pathway]['weights'].values()]
        
        print(f"\nCell type: {cell_type}, Gene set: {gene_set}")
        print("\nSAG pathway:")
        print("Number of targets:", len(pathways['SAG']['targets']))
        print("Average weight:", np.mean(pathways['SAG']['weights']) if pathways['SAG']['weights'] else 0)
        
        print("\nRetinoic Acid pathway:")
        print("Number of targets:", len(pathways['RA']['targets']))
        print("Average weight:", np.mean(pathways['RA']['weights']) if pathways['RA']['weights'] else 0)


Cell type: L2-3_CUX2, Gene set: L2-3_CUX2

SAG pathway:
Number of targets: 43
Average weight: -0.1346416236739442

Retinoic Acid pathway:
Number of targets: 57
Average weight: 0.3640531124939981


## Transcriptomics data

### Load

In [294]:
adata = sc.read_h5ad(input_file)
adata.var_names_make_unique()

In [295]:
adata

AnnData object with n_obs × n_vars = 36252 × 19633
    obs: 'Og_sample', 'Og_bc_index', 'Og_species', 'Og_tscp_count', 'Og_tscp_count_50dup', 'Og_gene_count', 'Og_quality', 'Og_condition', 'run_id', 'sample_id', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mito', 'log1p_total_counts_mito', 'pct_counts_mito', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'gene_UMI_ratio', 'log1p_gene_UMI_ratio', 'n_genes', 'n_counts', 'class', 'class2', 'region', 'Leiden_01', 'Leiden_02', 'Leiden_03', 'Leiden_04', 'Leiden_05', 'Leiden_06', 'Leiden_08', 'Leiden_10', 'Leiden_12', 'Leiden_Sel'
    var: 'gene_id', 'genome', 'gene_name_unique', 'mito', 'ribo', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'Leiden_01_colors', 'Leiden_02_colors', 'Leiden_03_colors', 'Leiden_04_color

In [296]:
print(list(adata.obs['class'].unique()))
print(list(adata.obs['class2'].unique()))

['neuron', 'choroid', 'progenitor', 'IPC', nan, 'astroglia', 'oligo']
['neuron.glut1', 'choroid', 'progenitor', 'IPC', 'neuron.glut2', 'neuron.gaba', 'neuron.mixed', nan, 'astroglia', 'neuron.dop', 'OPC', 'oligo']


### Subset

In [297]:
neurons = adata[adata.obs['class2'].isin(['neuron.glut1', 'neuron.glut2'])]

In [298]:
neurons = neurons[neurons.obs['Og_condition'].isin(conditions)]

In [299]:
neurons

View of AnnData object with n_obs × n_vars = 2883 × 19633
    obs: 'Og_sample', 'Og_bc_index', 'Og_species', 'Og_tscp_count', 'Og_tscp_count_50dup', 'Og_gene_count', 'Og_quality', 'Og_condition', 'run_id', 'sample_id', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'total_counts_mito', 'log1p_total_counts_mito', 'pct_counts_mito', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'gene_UMI_ratio', 'log1p_gene_UMI_ratio', 'n_genes', 'n_counts', 'class', 'class2', 'region', 'Leiden_01', 'Leiden_02', 'Leiden_03', 'Leiden_04', 'Leiden_05', 'Leiden_06', 'Leiden_08', 'Leiden_10', 'Leiden_12', 'Leiden_Sel'
    var: 'gene_id', 'genome', 'gene_name_unique', 'mito', 'ribo', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'Leiden_01_colors', 'Leiden_02_colors', 'Leiden_03_colors', 'Leiden_0

In [300]:
print(neurons.obs.Og_condition.nunique(), ": ",list(neurons.obs.Og_condition.unique()))

13 :  ['FGF8', 'LDN', 'CHIR1.5', 'BMP7', 'FGF4', 'IWP2', 'FGF2-20', 'RA10', 'CHIR3', 'FGF2-50', 'RA100', 'SAG1000', 'SAG250']


In [301]:
neurons.layers['counts'].toarray()

array([[0., 1., 2., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 2., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 2., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [302]:
neurons.X.toarray()

array([[0.        , 0.5253986 , 0.86805195, ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.87906605, 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 2.662588  , ..., 0.        , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ]], dtype=float32)

### Agreggate

In [303]:
# Convert Categorical columns to string and create a combined condition column
neurons.obs['class2_str'] = neurons.obs['class2'].astype(str)
neurons.obs['Og_condition_str'] = neurons.obs['Og_condition'].astype(str)
neurons.obs['combined_condition'] = neurons.obs['class2_str'].str.cat(neurons.obs['Og_condition_str'], sep='-')

combined_conditions = neurons.obs['combined_condition'].unique()
aggr_data = []

for cond in combined_conditions:
    subset = neurons[neurons.obs['combined_condition'] == cond]
    aggr_data.append(subset.layers['counts'].sum(axis=0))

# Create a new AnnData object with aggregated data
neurons_aggr = sc.AnnData(X=np.vstack(aggr_data),
                               obs=pd.DataFrame(index=combined_conditions, columns=['combined_condition']),
                               var=neurons.var)

neurons_aggr.obs['combined_condition'] = combined_conditions

# Split the combined condition back into class2 and Og_condition
split_conditions = neurons_aggr.obs['combined_condition'].str.split('-', n=1, expand=True)
if split_conditions.shape[1] == 2:
    neurons_aggr.obs['class2'] = split_conditions[0]
    neurons_aggr.obs['Og_condition'] = split_conditions[1]
else:
    print("Warning: Some combined conditions don't contain the '-' separator.")
    neurons_aggr.obs['class2'] = split_conditions[0]
    neurons_aggr.obs['Og_condition'] = ''

# Normalize for sequencing depth
sc.pp.normalize_total(neurons_aggr, target_sum=1e6)  # Normalize to counts per million (CPM)


Trying to modify attribute `.obs` of view, initializing view as actual.



normalizing counts per cell
    finished (0:00:00)


In [304]:
gc.collect()

2984

# Scoring

In [305]:
if sp.issparse(neurons_aggr.X):
    X_array = neurons_aggr.X.toarray()
else:
    X_array = np.array(neurons_aggr.X)

In [306]:
neurons_aggr.X = X_array

In [307]:
gc.collect()

0

In [309]:
if recompute:
    for control in [True, False]:
        for control_condition in [control_condition, None]:
            for normalize_weights in [True, False]:
                for scaling_only_based_on_control in [True, False]: 
                    for scale_by_variance in [True, False]:
                        for pathway in list(pathways.keys()):
                            try:
                                score_genes(
                                    neurons_aggr,
                                    gene_list=pathways[pathway]['targets'], 
                                    gene_weights=pathways[pathway]['weights'],   
                                    score_name = (
                                        f'gene_score_{pathway}_{control}_'
                                        f'normalized_{normalize_weights}_'
                                        f'scaled_{scale_by_variance}_'
                                        f'cc_{control_condition}_'
                                        f'sc_{scaling_only_based_on_control}'
                                    ),                                   
                                    ctrl_size=50,
                                    gene_pool=None,
                                    n_bins=25,
                                    random_state=0,
                                    copy=False,
                                    used_layer=None,
                                    return_scores=False,
                                    control=control,
                                    weighted=True,
                                    abs_diff=False,
                                    gpu=gpu_support,
                                    chunk_size=10000,
                                    disable_chunking=True,
                                    scale_by_variance=scale_by_variance,
                                    normalize_weights=normalize_weights,
                                    conditions_labels='Og_condition',
                                    control_condition=control_condition,
                                    debug=True,
                                    scaling_only_based_on_control=scaling_only_based_on_control
                                )
                            except IndexError as e:
                                print(f"Index   Error occurred for {pathway}: {str(e)}")
                                continue

In [310]:
neurons_aggr

AnnData object with n_obs × n_vars = 22 × 19633
    obs: 'combined_condition', 'class2', 'Og_condition', 'gene_score_SAG_True_normalized_True_scaled_True_cc_LDN_sc_True', 'gene_score_RA_True_normalized_True_scaled_True_cc_LDN_sc_True', 'gene_score_SAG_True_normalized_True_scaled_False_cc_LDN_sc_True', 'gene_score_RA_True_normalized_True_scaled_False_cc_LDN_sc_True', 'gene_score_SAG_True_normalized_True_scaled_True_cc_LDN_sc_False', 'gene_score_RA_True_normalized_True_scaled_True_cc_LDN_sc_False', 'gene_score_SAG_True_normalized_True_scaled_False_cc_LDN_sc_False', 'gene_score_RA_True_normalized_True_scaled_False_cc_LDN_sc_False', 'gene_score_SAG_True_normalized_False_scaled_True_cc_LDN_sc_True', 'gene_score_RA_True_normalized_False_scaled_True_cc_LDN_sc_True', 'gene_score_SAG_True_normalized_False_scaled_False_cc_LDN_sc_True', 'gene_score_RA_True_normalized_False_scaled_False_cc_LDN_sc_True', 'gene_score_SAG_True_normalized_False_scaled_True_cc_LDN_sc_False', 'gene_score_RA_True_normali

# Plots

In [319]:
SAG_conditions = ['SAG1000', 'SAG250'] # SAG pathway
RA_conditions = ['RA100', 'RA10'] # Retinoic acid pathway


In [345]:
def create_figure(data, pathway, control, normalize_weights, 
                  scale_by_variance, control_condition, scaling_only_based_on_control):
    score_of_interest = (
        f'gene_score_{pathway}_{control}_'
        f'normalized_{normalize_weights}_'
        f'scaled_{scale_by_variance}_'
        f'cc_{control_condition}_'
        f'sc_{scaling_only_based_on_control}'
    )
    
    # Create figure with secondary y-axis
    fig = go.Figure()
    
    # Filter data for relevant conditions based on pathway
    conditions_to_plot = []
    if pathway == 'SAG':
        conditions_to_plot = SAG_conditions + [control_condition]
    elif pathway == 'RA':
        conditions_to_plot = RA_conditions + [control_condition]
        
    plot_data = data[data['Og_condition'].isin(conditions_to_plot)]
    
    # Create separate traces for each class2 type
    for class_type in ['neuron.glut1', 'neuron.glut2']:
        class_data = plot_data[plot_data['class2'] == class_type]
        
        fig.add_trace(go.Scatter(
            x=class_data['Og_condition'],
            y=class_data[score_of_interest],
            mode='markers',
            marker=dict(
                size=10,
                symbol='circle' if class_type == 'neuron.glut1' else 'square'
            ),
            name=class_type
        ))
    
    # Update layout
    fig.update_layout(
        title=f'{pathway} pathway: {score_of_interest}',
        xaxis_title='Condition',
        yaxis_title=f'{pathway} score',
        xaxis=dict(
            tickangle=45,
            categoryorder='array',
            categoryarray=[control_condition] + (SAG_conditions if pathway == 'SAG' else RA_conditions)
        ),
        height=600,
        width=800,
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
    )
    
    return fig

# Create widgets for each parameter
pathway_widget = widgets.Dropdown(
    options=['SAG', 'RA'], 
    description='Pathway:'
)
control_widget = widgets.Dropdown(
    options=[True, False], 
    description='Control:'
)
normalize_weights_widget = widgets.Dropdown(
    options=[True, False], 
    description='Normalize weights:'
)
scale_by_variance_widget = widgets.Dropdown(
    options=[True, False], 
    description='Scale variance:'
)
scaling_only_based_on_control_widget = widgets.Dropdown(
    options=[True, False], 
    description='Scale control:'
)

# Create output widget to display the plot
out = widgets.Output()

@out.capture()  # This ensures the output only goes to our output widget
def update_plot(data, pathway, control, normalize_weights, 
                scale_by_variance, scaling_only_based_on_control):
    out.clear_output(wait=True)
    fig = create_figure(
        data=data,
        pathway=pathway,
        control=control,
        normalize_weights=normalize_weights,
        scale_by_variance=scale_by_variance,
        control_condition="LDN" if control else None,
        scaling_only_based_on_control=scaling_only_based_on_control
    )
    fig.show()

# Create the controls box
controls = widgets.VBox([
    pathway_widget,
    control_widget,
    normalize_weights_widget,
    scale_by_variance_widget,
    scaling_only_based_on_control_widget
])

# Wire up the controls to the update function
def on_change(change):
    update_plot(
        neurons_aggr.obs,
        pathway_widget.value,
        control_widget.value,
        normalize_weights_widget.value,
        scale_by_variance_widget.value,
        scaling_only_based_on_control_widget.value
    )

# Register the callback for all widgets
for w in [pathway_widget, control_widget, normalize_weights_widget, 
          scale_by_variance_widget, scaling_only_based_on_control_widget]:
    w.observe(on_change, 'value')

# Create the final layout
layout = widgets.VBox([controls, out])

# Show initial plot
update_plot(
    neurons_aggr.obs,
    pathway_widget.value,
    control_widget.value,
    normalize_weights_widget.value,
    scale_by_variance_widget.value,
    scaling_only_based_on_control_widget.value
)

# Display only once
display(layout)

VBox(children=(VBox(children=(Dropdown(description='Pathway:', options=('SAG', 'RA'), value='SAG'), Dropdown(d…

# Pseudobulk

In [346]:
gc.collect()

6622

In [347]:
%%script false --no-raise-error

cell_types_counts = neurons.obs['class'].value_counts()

print("Number of cells in each cell type:")
print(cell_types_counts)

In [348]:
%%script false --no-raise-error

import decoupler as dc

pseudobulk = dc.get_pseudobulk(
    adata=adata,
    sample_col='sample_id',
    groups_col="class",
    layer=None, 
    use_raw=False,
    mode='mean',
    min_cells=10,
    min_counts=1000,
    dtype=None
)

In [None]:
%%script false --no-raise-error

print(pseudobulk.obs['psbulk_n_cells'].min())
print(pseudobulk.obs['psbulk_n_cells'].max())

In [None]:
%%script false --no-raise-error

plt.figure(figsize=(10, 6))
plt.hist(pseudobulk.obs['psbulk_n_cells'], bins=30, edgecolor='black', alpha = 0.8)
plt.title('Histogram of psbulk_n_cells')
plt.xlabel('Pseudobulk Counts')
plt.ylabel('Frequency')
plt.show()

In [None]:
%%script false --no-raise-error

rara_expression = pseudobulk[:, pseudobulk.var['gene_name_unique'] == 'RARA'].X.flatten()
print(rara_expression.shape)
rara_expression[:10]

In [None]:
%%script false --no-raise-error

plot_data = pd.DataFrame({
    'class': pseudobulk.obs['class'],
    'RARA_expression': rara_expression
})

plot_data = plot_data.groupby('class')['RARA_expression'].mean().reset_index()
sorted_plot_data = plot_data.sort_values(by='RARA_expression', ascending=True)

# Then we reset the index so we can have a continuous x-axis for the scatter plot
sorted_plot_data.reset_index(drop=True, inplace=True)

# Now let's create the scatter plot
plt.figure(figsize=(10, 6))
sns.scatterplot(data=sorted_plot_data, x=sorted_plot_data.index, y='RARA_expression', hue='class', s=100)
plt.title('Scatter Plot of RARA Gene Expression Sorted by Expression Level')
plt.ylabel('RARA Expression')
plt.xlabel('Index (sorted)')
plt.legend(title='Class')
plt.tight_layout()
plt.show()

# Scores across gene sets and cell types

In [None]:
%%script false --no-raise-error

gene_set = grn_set 
gene_set.head()

In [None]:
%%script false --no-raise-error

gene_set = grn_set
gene_dict = dict_genes_grn
weight_dict = weight_genes_grn

for celltype in pseudobulk.obs['class'].unique():
    for goi in gois:
        gene_set_tmp = gene_set[gene_set.source == goi] 
        # print(f"process cell tyoe: {celltype}")
        matched_type = pseudobulk.obs['class'] == celltype
        subset_indices = pseudobulk.obs[matched_type].index
        subset_pseudobulk = pseudobulk[subset_indices]
        # print(f"subset_pseudobulk.shape: {subset_pseudobulk.shape}")

        filtered_networks = gene_set_tmp[gene_set_tmp['mapped_cell_types'].apply(lambda x: celltype in x if x is not None else False)]
        gene_weights = filtered_networks.set_index('target')['score'].to_dict()

        score_name = f"{goi}_score"
        if len(gene_weights.keys()) > 0:
            scores = score_genes_weighted(
                adata=subset_pseudobulk, 
                gene_list=pd.Series(gene_weights).index, 
                gene_weights=pd.Series(gene_weights),  
                score_name=score_name,
                ctrl_size=50,
                n_bins=25,
                return_scores=True,
                weighted=True,
                control=True,
                abs=False
            )

            for idx in subset_indices:
                if idx in scores.index:  # Ensure the score is available for this index
                    pseudobulk.obs.at[idx, score_name] = scores.loc[idx]

In [None]:
%%script false --no-raise-error

celltypes = ['neuron', 'progenitor']
filtered_df = pseudobulk.obs[pseudobulk.obs['class'].isin(celltypes)]
filtered_df.head()

In [None]:
%%script false --no-raise-error

filtered_df['condition_celltype'] = filtered_df['Og_condition'] + '_' + filtered_df['class']

In [None]:
%%script false --no-raise-error

pivot_table = filtered_df.pivot_table(index='condition_celltype', values=[f"{goi}_score" for goi in gois], aggfunc='mean')
# Reset index to have 'Og_condition' and 'class' as separate columns if needed for labeling
pivot_table = pivot_table.reset_index()

In [None]:
%%script false --no-raise-error

pivot_table.head()

In [None]:
%%script false --no-raise-error

# plt.figure(figsize=(10, 80))

# # Set the 'condition_celltype' column as the row labels
# row_labels = pivot_table['condition_celltype']

# # Create the heatmap without setting an index
# sns.heatmap(pivot_table.drop(columns=['condition_celltype']), cmap="viridis", annot=True, fmt=".2f",
#             yticklabels=row_labels, cbar_kws={"orientation": "horizontal"})

# plt.title('Gene Score Heatmap across Conditions and Cell Types')
# plt.ylabel('Condition and Cell Type')
# plt.xlabel('Gene Scores')

# plt.xticks(rotation=45, ha="right")
# plt.yticks(fontsize=8)

# # Adjust the layout to accommodate the colorbar
# plt.tight_layout()

# plt.show()

In [None]:
%%script false --no-raise-error

neurons = pseudobulk.obs[pseudobulk.obs['class'] == 'neuron']
neurons_pivot_table = neurons.pivot_table(index='Og_condition', values=[f"{goi}_score" for goi in gois], aggfunc='mean')
neurons_pivot_table.reset_index(inplace=True)
neurons_pivot_table.head()

In [None]:
%%script false --no-raise-error

plt.figure(figsize=(30, 10))

sns.heatmap(neurons_pivot_table.set_index('Og_condition').T, cmap="viridis", annot=False)

plt.title('Gene Score Heatmap across Conditions and Cell Types')
plt.xlabel('Condition and Cell Type')
plt.ylabel('Gene Scores')

plt.yticks(rotation=45, ha="right")
plt.xticks(fontsize=12)

plt.show()

In [None]:
%%script false --no-raise-error

progenitors = pseudobulk.obs[pseudobulk.obs['class'] == 'progenitor']
progenitors_pivot_table = progenitors.pivot_table(index='Og_condition', values=[f"{goi}_score" for goi in gois], aggfunc='mean')
progenitors_pivot_table.reset_index(inplace=True)
progenitors_pivot_table.head()

In [None]:
%%script false --no-raise-error

plt.figure(figsize=(30, 10))

sns.heatmap(progenitors_pivot_table.set_index('Og_condition').T, cmap="viridis", annot=False)

plt.title('Gene Score Heatmap across Conditions and Cell Types')
plt.xlabel('Condition and Cell Type')
plt.ylabel('Gene Scores')

plt.yticks(rotation=45, ha="right")
plt.xticks(fontsize=12)

plt.show()

# Single gene set and single cell type

In [None]:
%%script false --no-raise-error

gene_set = grn_set 
gene_set.head()

In [None]:
%%script false --no-raise-error

cell_type = "neuron"

goi = "ESR2"

# Filter the DataFrame to include only rows where 'mapped_cell_types' contains the cell_type
filtered_networks = gene_set[gene_set['mapped_cell_types'].apply(lambda x: cell_type in x if x is not None else False)]

In [None]:
%%script false --no-raise-error

filtered_networks.head()

In [None]:
%%script false --no-raise-error

condition_met = pseudobulk.obs['class'].isin(["neuron"])
subset_pseudobulk = pseudobulk[condition_met]

In [None]:
%%script false --no-raise-error

subset_pseudobulk.obs.head()

In [None]:
%%script false --no-raise-error

score_genes_weighted(
    adata=subset_pseudobulk,
    gene_list=filtered_networks[filtered_networks.goi == goi].target.tolist(),
    gene_weights=filtered_networks[filtered_networks.goi == goi].score.tolist(),
    score_name=f"{goi}_score",
    ctrl_size=50,
    n_bins=25
)

In [None]:
%%script false --no-raise-error

subset_pseudobulk.obs[f'{goi}_score'].head()

In [None]:
%%script false --no-raise-error

pseudobulk.obs.Og_condition.unique()

In [None]:
%%script false --no-raise-error

len(pseudobulk.obs.Og_condition.unique())

In [None]:
%%script false --no-raise-error

sorted_subset = subset_pseudobulk.obs.sort_values(by='Og_condition')
plt.figure(figsize=(10, 6))
sns.boxplot(x='Og_condition', y=f'{goi}_score', data=sorted_subset, showfliers=False)
plt.title('score_RARA Grouped by Og_condition')
plt.xticks(rotation=90, fontsize='small')  
plt.tight_layout()
plt.show()