In [11]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
from pygam import GAM, s
from scipy.ndimage import convolve
from tqdm import tqdm

In [2]:
# Read the .h5ad file into an AnnData object; every later cell works on it.
# NOTE(review): judging by the file name this holds model-predicted
# (counterfactual) expression per radiation level — confirm with the producer.
counterfactual_path = "radiation_counterfactuals.h5ad"
adata_preds = sc.read_h5ad(counterfactual_path)

  utils.warn_names_duplicates("obs")


In [19]:
# Store the radiation column as float so it can be used as a numeric predictor
# in the GAM fits. The cast is idempotent, so re-running this cell is harmless.
radiation_column = adata_preds.obs['radiation']
adata_preds.obs['radiation'] = radiation_column.astype(float)

In [55]:
# Sanity check: number of rows (observations) at each distinct radiation level.
rows_per_level = adata_preds.obs['radiation'].value_counts()
rows_per_level

radiation
0.000000    44928
0.578305    44928
0.943489    44928
0.961291    44928
0.979093    44928
1.000000    44928
Name: count, dtype: int64

In [38]:
# Fit one GAM per gene: expression as a smooth function of radiation level.
# The four dicts below are keyed by gene name and filled inside the loop.
models = {}
y_test_dict = {}
x_test_dict = {}
results = {}

# Tunables for the fit (previously inline magic numbers).
CONTROL_LEVEL_WEIGHT = 10            # weight of rows with radiation == 0.0
LOW_DOSE_LEVEL = 0.5783052351375333  # the level that receives LOW_DOSE_WEIGHT
LOW_DOSE_WEIGHT = 8
OTHER_LEVEL_WEIGHT = 1               # weight of every remaining level
N_GRID_POINTS = 50                   # number of points on the prediction grid

term = s(0, spline_order=3, n_splines=6, lam=3, penalties=['derivative', 'l2'])

# The predictor and the sample weights are the same for every gene, so they are
# built once, before the loop. The weights used to be derived from `x` before
# `x` was ever assigned, which raises NameError on a fresh kernel and otherwise
# silently reuses whatever `x` a previous run left behind.
radiation = adata_preds.obs['radiation'].astype(float).to_numpy()
radiation_levels, level_index = np.unique(radiation, return_inverse=True)

# np.isclose instead of `==` for the low-dose level: an exact float comparison
# silently falls through to OTHER_LEVEL_WEIGHT if the stored value differs from
# the literal in the last bits (e.g. after a float32 round trip).
weights_per_level = {
    level: CONTROL_LEVEL_WEIGHT if level == 0.0
    else LOW_DOSE_WEIGHT if np.isclose(level, LOW_DOSE_LEVEL)
    else OTHER_LEVEL_WEIGHT
    for level in radiation_levels
}
weights_full = np.array(
    [weights_per_level[level] for level in radiation_levels]
)[np.ravel(level_index)]

# Column-vector shapes, exactly as the fit received them before.
x = radiation.reshape((len(radiation), 1))
w = weights_full.reshape((len(weights_full), 1))

# Evenly spaced prediction grid spanning the observed radiation range.
val_start = np.min(x)
val_end = np.max(x)
x_grid = np.linspace(val_start, val_end, N_GRID_POINTS)

for gene in tqdm(adata_preds.var_names):
    gam = GAM(terms=term, distribution='normal', link='log', max_iter=2000)

    # Response: this gene's expression column as an (n_obs, 1) array.
    # Densify first if the matrix is stored sparse.
    y = adata_preds[:, gene].X
    if hasattr(y, 'toarray'):
        y = y.toarray()
    y = np.asarray(y).reshape((adata_preds.shape[0], 1))

    # Fit the GAM model (fits in place; the return value is not needed).
    gam.fit(x, y, weights=w)

    # First entry of pygam's p-value list.
    # NOTE(review): presumably the radiation spline term — confirm the term
    # order, and note pygam documents its p-values as likely too small.
    p_values = gam.statistics_['p_values'][0]
    coefficients = gam.coef_  # coefficients of the fitted model
    # NOTE(review): pygam appears to return a mapping of several pseudo-R2
    # variants here rather than a single number — confirm which one is wanted.
    variance_explained = gam.statistics_['pseudo_r2']

    results[gene] = {
        'p_value': p_values,
        'coefficients': coefficients,
        'variance_explained': variance_explained
    }

    # Predict on the shared grid; each gene keeps its own copy of the grid so
    # the stored arrays stay independent objects, as before.
    x_test = x_grid.copy()
    y_test = gam.predict(x_test)

    models[gene] = gam
    x_test_dict[gene] = x_test
    y_test_dict[gene] = y_test

100%|██████████| 15749/15749 [3:22:56<00:00,  1.29it/s]  


In [39]:
# Flatten the per-gene fit summaries into one table and save it as CSV.

# Create the output directory first: without it the write below fails with
# FileNotFoundError, and it runs right after a multi-hour fitting loop.
os.makedirs('gam', exist_ok=True)

# One row per fitted gene, in the order the models were stored.
# NOTE(review): 'coefficients' (and possibly 'variance_explained') hold
# non-scalar objects; to_csv writes them via their string form — confirm that
# is acceptable for whoever reads this file back.
results_list = [
    {
        'gene': gene,
        'p_value': results[gene]['p_value'],
        'coefficients': results[gene]['coefficients'],
        'variance_explained': results[gene]['variance_explained'],
    }
    for gene in models
]

results_df = pd.DataFrame(results_list)

results_df.to_csv('gam/gam_results.csv', index=False)

In [40]:
import pickle

# Persist the fitted models and their prediction grids/curves, one pickle file
# per object, so later sessions can reuse them without refitting.
pickle_targets = (
    ('gam/gam_models.pkl', models),
    ('gam/x_test_dict.pkl', x_test_dict),
    ('gam/y_test_dict.pkl', y_test_dict),
)

for pickle_path, payload in pickle_targets:
    with open(pickle_path, 'wb') as f:
        pickle.dump(payload, f)

In [49]:
# Rank genes by how much their fitted expression differs between the lowest and
# the highest radiation levels.
#
# The stored prediction curves are evaluated on an evenly spaced grid, not at
# the observed radiation levels, so taking the first/last two grid points does
# not give the predictions at the two lowest/highest levels (the first two grid
# points both sit next to zero). The curve is therefore interpolated at the
# actual levels instead.
dose_levels = np.unique(adata_preds.obs['radiation'].astype(float).to_numpy())
low_doses = dose_levels[:2]    # two lowest observed radiation levels
high_doses = dose_levels[-2:]  # two highest observed radiation levels

# Despite the names, the split is on the direction of the change from low to
# high radiation: "increase" = expression is higher at the high levels.
increase_at_low_levels = []
decrease_at_low_levels = []

for gene, y_pred in y_test_dict.items():
    grid = np.ravel(x_test_dict[gene])  # ascending, so valid for np.interp
    curve = np.ravel(y_pred)

    low_expression = np.median(np.interp(low_doses, grid, curve))
    high_expression = np.median(np.interp(high_doses, grid, curve))

    if low_expression < high_expression:
        increase_at_low_levels.append((gene, high_expression - low_expression))
    else:
        # Ties (no change) land here with a magnitude of 0.
        decrease_at_low_levels.append((gene, low_expression - high_expression))

# Largest change first.
increase_at_low_levels = sorted(increase_at_low_levels, key=lambda x: x[1], reverse=True)
decrease_at_low_levels = sorted(decrease_at_low_levels, key=lambda x: x[1], reverse=True)

top_20_increase = increase_at_low_levels[:20]
top_20_decrease = decrease_at_low_levels[:20]

df_increase = pd.DataFrame(top_20_increase, columns=['Gene', 'Change_Magnitude'])
df_decrease = pd.DataFrame(top_20_decrease, columns=['Gene', 'Change_Magnitude'])

In [54]:
# Write the two top-20 gene tables to CSV (row index included, as before).
signature_exports = (
    (df_increase, "gam/increase_in_radiation_signature.csv"),
    (df_decrease, "gam/decrease_in_radiation_signature.csv"),
)
for signature_df, csv_path in signature_exports:
    signature_df.to_csv(csv_path)

In [66]:
# Save the fitted expression curve of the five top-ranked genes in each
# direction, one PNG per gene. Each entry below is
# (ranked gene list, line colour, title suffix, file-name suffix).
curve_plot_specs = (
    (top_20_increase, 'green', 'Increase at Lower Levels', 'increase'),
    (top_20_decrease, 'red', 'Decrease at Higher Levels', 'decrease'),
)

for ranked_genes, line_color, caption, file_suffix in curve_plot_specs:
    for gene, _ in ranked_genes[:5]:
        x_values = x_test_dict[gene]
        y_values = y_test_dict[gene]

        fig, ax = plt.subplots(figsize=(8, 5))
        ax.plot(x_values, y_values, label=f'Gene: {gene}', color=line_color)
        ax.set_xlabel('Radiation Levels')
        ax.set_ylabel('Predicted Expression')
        ax.set_title(f'Predicted Expression for Gene: {gene} ({caption})')
        ax.legend()
        fig.savefig(f'gam/gene_{gene}_{file_suffix}.png')
        # Close each figure after saving so figures do not pile up in memory.
        plt.close(fig)
