# AML Cell-Cell Interaction Modeling with DIISCO

In [None]:
# IPython automagic (not plain Python): move the working directory up one level.
# The relative '../../data/AML' paths used further down are resolved from there.
cd ..

In [None]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
%matplotlib inline
from matplotlib.colors import LogNorm
import seaborn as sns
#import tensorflow as tf
#from Scalable_GPRN.model.SGPRN import SGPRN
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF
import ipywidgets as widgets
from ipywidgets import interact
import matplotlib.pyplot as plt
# reload edited modules always
%load_ext autoreload
%autoreload 2
from diisco import DIISCO
import diisco.names as names


In [None]:
# Make float64 the default floating-point dtype for tensors created from here on
# (e.g. the torch.linspace grids built in later cells).
torch.set_default_dtype(torch.float64)

In [None]:
# IPython automagic shell command: list the AML data directory read below.
ls ../../data/AML

In [None]:
# Per-sample table that is joined onto the cell table on the 'sample' column in
# the next cell.
# NOTE(review): presumably this is where the 'days_to_DLI' column comes from - confirm.
days_to_dli = pd.read_csv('../../data/AML/AML_days_to_DLI.csv')
days_to_dli.head()

In [None]:
# Read the per-cell table, keep only rows whose DATA column is 'AML', and attach
# the per-sample table. The right join keeps every row of days_to_dli, so a
# sample without matching cells still appears (with missing cell columns).
cells_df = (
    pd.read_csv('../../data/AML/umap_new_barcodes_revised.csv')
    .loc[lambda frame: frame['DATA'] == 'AML']
    .merge(days_to_dli, how='right', on='sample')
)
cells_df

In [None]:
# Cluster number -> cell-type label for the clusters that are modelled.
# Insertion order matters: it fixes the column order of the model inputs and
# the panel order of every figure.
clusters_of_interest = {
    0: 'T cell',
    1: 'B cell',
    7: 'T cell',
    8: 'Myeloid',
    14: 'AML',
    16: 'HSC',
    23: 'AML',
    28: 'Myeloid',
    30: 'NK cell',
    36: 'B cell',
    42: 'AML',
}

# Cell-type label -> matplotlib colour used when plotting that cell type.
colors = {
    'T cell': 'tab:blue',
    'B cell': 'tab:pink',
    'Myeloid': 'tab:red',
    'AML': 'tab:gray',
    'HSC': 'goldenrod',
    'NK cell': 'tab:purple',
}

In [None]:
# Drop cells from samples taken more than 1000 days after DLI (rows with a
# missing days_to_DLI fail the comparison and are dropped as well).
within_window = cells_df['days_to_DLI'] <= 1000
cells_df = cells_df.loc[within_window]
cells_df

## All Responders Model

In [None]:
# Restrict the analysis to cells whose 'response' value is 'RESPONDER'.
cells_df_responders = cells_df.loc[cells_df['response'] == 'RESPONDER']
cells_df_responders

In [None]:
# Build one row per sample holding, for every cluster seen among responders,
# the fraction of that sample's cells assigned to the cluster (0 when absent).
all_clusters = sorted(cells_df_responders['cluster_number'].unique())

sample_cluster_proportions = []
for sample_name, sample_cells in cells_df_responders.groupby('sample'):
    fractions = sample_cells['cluster_number'].value_counts(normalize=True)
    # days_to_DLI is read from the first cell of the sample.
    row = {
        'sample': sample_name,
        'days_to_DLI': sample_cells.iloc[0]['days_to_DLI'],
    }
    for cluster in all_clusters:
        row[cluster] = fractions.get(cluster, 0)
    sample_cluster_proportions.append(row)

# Order the samples chronologically relative to DLI.
sample_cluster_proportions_df = (
    pd.DataFrame(sample_cluster_proportions).sort_values('days_to_DLI')
)
sample_cluster_proportions_df

In [None]:
# Keep only the identifying columns plus the clusters of interest, in the order
# given by clusters_of_interest. The explicit copy makes the result an
# independent frame: it is modified in place when the columns are standardised,
# and without the copy pandas may flag those writes as targeting a copy of a
# selection (SettingWithCopyWarning).
sample_cluster_proportions_df = sample_cluster_proportions_df[
    ['sample', 'days_to_DLI'] + list(clusters_of_interest)
].copy()
sample_cluster_proportions_df

In [None]:
# Share of all responder cells contributed by each sample, ordered like the rows
# of sample_cluster_proportions_df; used to scale scatter marker areas.
sample_fractions = cells_df_responders['sample'].value_counts(normalize=True)
cells_per_sample = sample_fractions[sample_cluster_proportions_df['sample']]

# One panel per cluster of interest on a 3x4 grid: proportion against time.
plt.figure(figsize=(18, 11))
for panel, cluster in enumerate(clusters_of_interest, start=1):
    plt.subplot(3, 4, panel)
    plt.scatter(
        sample_cluster_proportions_df['days_to_DLI'],
        sample_cluster_proportions_df[cluster],
        s=cells_per_sample * 500,
    )
    plt.title(f'Cluster {cluster} proportion', fontsize=15)
    plt.xlabel('Days to/from DLI', fontsize=14)
plt.suptitle('Responders cluster proportions', y=1.03, fontsize=18)
plt.tight_layout()
# plt.savefig('figures/AML_paper/cell_type_proportions_R.eps', bbox_inches='tight')

In [None]:
# Z-score every cluster-of-interest column (subtract the column mean, divide by
# the column standard deviation; pandas' .std() is the sample std, ddof=1).
# The columns are selected with a list of the dict's keys: pandas >= 2.0 rejects
# a dict used as an indexer with a TypeError.
# cluster_means / cluster_std_devs are kept because the plotting cells use them
# to map standardised values back to proportions.
# NOTE: re-running this cell recomputes both statistics from the already
# standardised columns and so overwrites the values needed for that mapping.
cluster_columns = list(clusters_of_interest)
cluster_means = sample_cluster_proportions_df[cluster_columns].mean()
cluster_std_devs = sample_cluster_proportions_df[cluster_columns].std()
sample_cluster_proportions_df.loc[:, cluster_columns] -= cluster_means
sample_cluster_proportions_df.loc[:, cluster_columns] /= cluster_std_devs
sample_cluster_proportions_df

In [None]:
# Model input: sampling times as an (n_samples, 1) column vector.
days_column = sample_cluster_proportions_df['days_to_DLI']
X = days_column.to_numpy().reshape(-1, 1)
X

In [None]:
# Model targets: one row per sample, one column per cluster of interest (in
# clusters_of_interest order). A list of keys is used as the column indexer
# because pandas >= 2.0 raises TypeError when a dict is passed as an indexer.
Y = sample_cluster_proportions_df[list(clusters_of_interest)].values
Y[:5]

In [None]:
# Visual check of the model targets: each column of Y against the sampling times.
plt.figure(figsize=(40, 21))
for column, cluster in enumerate(clusters_of_interest):
    plt.subplot(3, 4, column + 1)
    plt.scatter(X, Y[:, column])
    plt.title(f'Cluster {cluster} proportion', fontsize=15)
    plt.xlabel('Days to/from DLI', fontsize=14)

In [None]:
# W_prior_variance = np.load('../../datra/AML/responder_interaction_scores.npy')
#W_prior_variance = np.array([[0, 1], [1, 0]])

# W_prior_variance = np.array([[W_prior_variance[0, 0], W_prior_variance[0, 4]], 
#                              [W_prior_variance[4,0], W_prior_variance[4,4]]])
# W_prior_variance = np.ones_like(W_prior_variance)

plt.figure(figsize=(8, 7))
ax = sns.heatmap(W_prior_variance, cmap="Reds", annot=True, fmt='.2f')
ax.set_yticklabels(clusters_of_interest, fontsize=12)
ax.set_xticklabels(clusters_of_interest, fontsize=12)
plt.xlabel('Source cluster', fontsize=12)
plt.ylabel('Target cluster', fontsize=12)
plt.yticks(rotation=0)
plt.xticks(rotation=45)
plt.title('Interaction prior variances for responders \n based on receptor-ligand expression', fontsize=14)
# plt.savefig('figures/AML_paper/prior_construction_reg_penalties_R.eps')

In [None]:
# W_prior_variance should be a matrix of (n_clusters, n_clusters) with ones everywhere except for the diagonal 
# which should be zeros.
W_prior_variance = np.ones((len(clusters_of_interest), len(clusters_of_interest)))
np.fill_diagonal(W_prior_variance, 0)


In [None]:
# Display the prior matrix that is converted to a tensor and passed to DIISCO below.
W_prior_variance

In [None]:
# Convert the numpy inputs to tensors. The dtype is passed explicitly so that
# integer day counts in X cannot yield an integer tensor; torch.tensor otherwise
# inherits the dtype of the numpy array it is given.
tensor_dtype = torch.get_default_dtype()
timepoints = torch.tensor(X, dtype=tensor_dtype)
proportions = torch.tensor(Y, dtype=tensor_dtype)
prior_matrix = torch.tensor(W_prior_variance, dtype=tensor_dtype)
n_timepoints, n_cell_types = proportions.shape

# The proportions are not standardised again here: the columns of Y were
# already z-scored in pandas before being extracted.

print('timepoints.shape:', timepoints.shape)
print('proportions.shape:', proportions.shape)

In [None]:
# Initial lengthscale: mean absolute gap between all ordered pairs of sampling
# times. X is a column vector, so X - X.T broadcasts to the square matrix of
# pairwise differences; the zero self-differences on its diagonal are part of
# the average.
pairwise_gaps = np.abs(X - X.T)
lengthscale = pairwise_gaps.mean()
lengthscale

In [None]:
# Initial hyperparameter values handed to DIISCO, keyed by the constants defined
# in diisco.names. Both lengthscales start from the data-derived `lengthscale`;
# the remaining entries are hand-picked constants.
# NOTE(review): the _F / _W / _Y suffixes presumably refer to the latent
# functions, the interaction weights and the observations - confirm against
# diisco.names before tuning.
hyper_init_vals = {
    names.LENGTHSCALE_F: lengthscale,
    names.LENGTHSCALE_W: lengthscale,
    names.SIGMA_F: 0.3,
    names.VARIANCE_F: 2,
    names.SIGMA_W: 0.1,
    names.VARIANCE_W: 2,
    names.SIGMA_Y: 0.5,
}
print(hyper_init_vals)

In [None]:
# Build the DIISCO model with the prior matrix (passed as lambda_matrix) and the
# initial hyperparameters defined above.
# NOTE(review): verbose_freq=100 presumably sets how often progress is printed
# during fitting - confirm in the DIISCO class.
model = DIISCO(lambda_matrix=prior_matrix, hypers_init_vals=hyper_init_vals, verbose=True, verbose_freq=100)

In [None]:
# Set the prior over the latent functions f_i from the data, then draw from that
# prior on an evenly spaced grid and plot the draws against the observations.
model.fit_and_set_f_prior_params(timepoints=timepoints, proportions=proportions, hypers=model.hypers_init_vals)
eval_timepoints = torch.linspace(-200, 750, 100).view(-1, 1)

n_samples = 1000
f_prior_np = model.sample_f_prior(eval_timepoints, n_samples=n_samples)

# Loop-invariant values, computed once instead of once per panel.
eval_timepoints_np = eval_timepoints.detach().numpy()
std_devs = cluster_std_devs.values
means = cluster_means.values


def unscale_clipped(z_scores, idx):
    """Map z-scores of cluster column ``idx`` back to proportions, floored at 0."""
    return np.clip(z_scores * std_devs[idx] + means[idx], 0, None)


plt.figure(figsize=(18, 11))
for i, (cluster, cell_type) in enumerate(clusters_of_interest.items()):
    # The middle axis of f_prior_np is indexed by cluster position; statistics
    # are then taken over the first axis.
    f_prior_cell_type_np = f_prior_np[:, i, :].detach().numpy()
    mean = unscale_clipped(f_prior_cell_type_np.mean(axis=0), i)
    # Central 95% band of the prior draws.
    upper = unscale_clipped(np.percentile(f_prior_cell_type_np, 97.5, axis=0), i)
    lower = unscale_clipped(np.percentile(f_prior_cell_type_np, 2.5, axis=0), i)
    observed = unscale_clipped(proportions[:, i].flatten().detach().numpy(), i)

    plt.subplot(3, 4, i+1)
    color = colors[cell_type]
    plt.plot(eval_timepoints_np, mean, c=color)
    plt.scatter(timepoints, observed, c=color, s=cells_per_sample*300)
    plt.fill_between(eval_timepoints_np.flatten(), lower, upper, color=color, alpha=0.2)
    plt.title('$f_{%s}$ latent (Cluster %s)' % (i+1, cluster), fontsize=14)
    plt.xlabel('Days to/from DLI', fontsize=12)
    if i % 4 == 0:
        # Only the left-most panel of each row gets a y-axis label.
        plt.ylabel('Proportion', fontsize=12)
plt.subplots_adjust(hspace=0.4)
plt.suptitle('Independent GP fits to set $f_i$ latent functions', fontsize=15, y=0.98)
# plt.savefig('figures/AML_paper/indep_gp_fits_R.eps')

In [None]:
# Fit the model to the standardised proportions for 100000 iterations.
# NOTE(review): hypers_to_optim=[] presumably keeps every hyperparameter at its
# initial value - confirm in DIISCO.fit.
model.fit(timepoints, proportions, n_iter=100000, lr=0.00003, hypers_to_optim=[], guide="MultivariateNormalFactorized")

In [None]:
# Smooth the training loss with a 100-step moving average, skipping the first
# 50 recorded values; 'valid' keeps only fully covered windows.
start = 50
window = 100
averaging_kernel = np.full(window, 1 / window)
loss_moving_avg = np.convolve(model.losses[start:], averaging_kernel, 'valid')
plt.plot(loss_moving_avg)

In [None]:
# Draw 1000 samples of the proportions from the fitted model and detach them
# from the autograd graph. The plotting cell below indexes the result as
# y[:, :, i] with i the cluster position and averages over the first axis.
y = model.sample_observed_proportions(n_samples=1000).detach()

In [None]:
from mimic_alpha.mimic_alpha import colorAlpha_to_rgb
import matplotlib.colors as colors2

# Model samples at the observed timepoints: mean line, 5th-95th percentile band
# and the observed values, one panel per cluster of interest.
plt.figure(figsize=(18, 11))
for i, (cluster_id, cell_type) in enumerate(clusters_of_interest.items()):
    scale = cluster_std_devs.values[i]
    shift = cluster_means.values[i]

    def unscale(z_scores, scale=scale, shift=shift):
        # Undo the z-scoring of this cluster and floor the result at 0.
        return np.clip(z_scores * scale + shift, 0, None)

    cell_type_samples = y[:, :, i]
    mean = unscale(cell_type_samples.mean(axis=0))
    band_upper = unscale(np.percentile(cell_type_samples, 95, axis=0))
    band_lower = unscale(np.percentile(cell_type_samples, 5, axis=0))

    plt.subplot(3, 4, i + 1)
    color = colors[cell_type]
    plt.plot(timepoints, mean, c=color)
    # NOTE(review): colorAlpha_to_rgb presumably converts the colour at alpha 0.2
    # into an opaque equivalent - confirm in mimic_alpha.
    color_with_alpha = [*colors2.to_rgb(color), 0.2]
    color_without_alpha = colorAlpha_to_rgb([color_with_alpha], 0.2)
    plt.fill_between(timepoints.squeeze(), band_lower, band_upper, color=color_without_alpha)
    plt.scatter(timepoints, unscale(proportions[:, i]), c=color, s=cells_per_sample * 300)
    plt.title(f'Cluster {cluster_id} proportion', fontsize=14)
    plt.xlabel('Days to/from DLI', fontsize=12)
    if i % 4 == 0:
        plt.ylabel('Proportion', fontsize=12)

plt.suptitle('DIISCO predictions', fontsize=15, y=0.98)
plt.subplots_adjust(hspace=0.4)
# plt.savefig('figures/AML_paper/gprn_predictions_R.eps')

In [None]:
# Dense grid of 1000 timepoints spanning the observed range, as a column vector,
# and 100 model samples on that grid without emission variance.
# NOTE(review): later cells read W_samples_predict, which is not assigned
# anywhere in the visible notebook; the disabled print below suggests it used to
# be produced here. Confirm what model.sample returns and restore the assignment.
predict_timepoints = torch.linspace(timepoints.min(), timepoints.max(), 1000).reshape(-1, 1)
y_samples_predict = model.sample(predict_timepoints, n_samples=100, 
                                                    include_emission_variance=False)
print(y_samples_predict.shape)
#print(W_samples_predict.shape)

In [None]:
from mimic_alpha.mimic_alpha import colorAlpha_to_rgb
import matplotlib.colors as colors2

# Predictions on the dense grid: mean line, interquartile (25th-75th percentile)
# band and the observed values, one panel per cluster of interest.
x = predict_timepoints.squeeze().numpy()

plt.figure(figsize=(18, 11))
for i, (cluster_id, cell_type) in enumerate(clusters_of_interest.items()):
    scale = cluster_std_devs.values[i]
    shift = cluster_means.values[i]

    def unscale(z_scores, scale=scale, shift=shift):
        # Undo the z-scoring of this cluster (no clipping in this figure).
        return z_scores * scale + shift

    cell_type_samples = y_samples_predict[:, :, i]
    mean = unscale(cell_type_samples.mean(axis=0))
    percentile_75 = unscale(np.percentile(cell_type_samples, 75, axis=0))
    percentile_25 = unscale(np.percentile(cell_type_samples, 25, axis=0))

    plt.subplot(3, 4, i + 1)
    color = colors[cell_type]
    plt.plot(x, mean, c=color)
    color_with_alpha = [*colors2.to_rgb(color), 0.2]
    color_without_alpha = colorAlpha_to_rgb([color_with_alpha], 0.2)
    plt.fill_between(x, percentile_25, percentile_75, color=color_without_alpha)
    plt.scatter(timepoints, unscale(proportions[:, i]), c=color, s=cells_per_sample * 300)
    plt.title(f'Cluster {cluster_id} proportion', fontsize=14)
    plt.xlabel('Days to/from DLI', fontsize=12)
    if i % 4 == 0:
        plt.ylabel('Proportion', fontsize=12)

plt.suptitle('DIISCO predictions', fontsize=15, y=0.98)
plt.subplots_adjust(hspace=0.4)
# plt.savefig('figures/AML_paper/gprn_predictions_R.eps')

In [None]:
# Positions of the first prediction timepoint later than -200, 0 and +200 days.
# The [0][0] lookups raise IndexError if no grid point exceeds a threshold.
X_200_days_pre_dli_index = np.where(predict_timepoints > -200)[0][0]
X_post_dli_index = np.where(predict_timepoints > 0)[0][0]
X_200_days_post_dli_index = np.where(predict_timepoints > 200)[0][0]

# NOTE(review): W_samples_predict is not assigned anywhere in the visible
# notebook; it has to be produced by the sampling step before this cell can run.
# Average over the first axis, then over the timepoints in each 200-day window.
W = W_samples_predict.mean(axis=0)
W_pre_dli = W[X_200_days_pre_dli_index:X_post_dli_index]
W_post_dli = W[X_post_dli_index:X_200_days_post_dli_index]
W_pre_dli_avg_over_time = W_pre_dli.mean(axis=0)
W_post_dli_avg_over_time = W_post_dli.mean(axis=0)

cluster_labels = list(clusters_of_interest)
# Raw strings keep the LaTeX backslash: '\h' in a normal string literal is an
# invalid escape sequence and triggers a warning on current Python versions.
heatmap_panels = [
    (W_pre_dli_avg_over_time, r'$\hat{W}_{avg}$ (200 days pre-DLI)'),
    (W_post_dli_avg_over_time, r'$\hat{W}_{avg}$ (200 days post-DLI)'),
]

plt.figure(figsize=(18, 7))
for panel_index, (W_avg, panel_title) in enumerate(heatmap_panels, start=1):
    plt.subplot(1, 2, panel_index)
    # Shared symmetric colour range so the two panels are directly comparable.
    ax = sns.heatmap(W_avg, cmap="RdBu_r", annot=True,
                     fmt='.2f', vmin=-5, vmax=5)
    ax.set_yticklabels(cluster_labels, fontsize=12)
    ax.set_xticklabels(cluster_labels, fontsize=12)
    plt.xlabel('Source cluster', fontsize=12)
    plt.ylabel('Target cluster', fontsize=12)
    plt.yticks(rotation=0)
    plt.xticks(rotation=45)
    plt.title(panel_title, fontsize=14)
# plt.savefig('figures/AML_paper/W_avg_pre_post_DLI_R.eps')

In [None]:
# Mean interaction weights over time, one line per ordered (source, target)
# pair of distinct clusters whose absolute value exceeds the threshold at some
# timepoint. Rows of W index the target cluster and columns the source cluster.
# NOTE(review): W_samples_predict is not assigned anywhere in the visible notebook.
W_mean_over_time = W_samples_predict.mean(axis=0)
show_line_threshold = 0

plt.figure(figsize=(10, 6))
linestyles = ['-', '--', '-.', ':', '-', '--', '-.', ':']
days = predict_timepoints.squeeze().numpy()
for target_idx, target_cluster in enumerate(clusters_of_interest):
    # All lines pointing at the same target share a linestyle.
    target_style = linestyles[target_idx % 8]
    for source_idx, source_cluster in enumerate(clusters_of_interest):
        if target_idx == source_idx:
            continue
        trajectory = W_mean_over_time[:, target_idx, source_idx]
        if not (np.abs(trajectory) > show_line_threshold).any():
            continue
        plt.plot(days,
                 trajectory,
                 linestyle=target_style,
                 label='%s --> %s interaction' % (source_cluster, target_cluster))
plt.legend(bbox_to_anchor=(1, 1.02), loc='upper left', fontsize=14)
plt.title('Strongest cluster interactions over time', fontsize=14)
plt.ylabel('$W_{i, j}$', fontsize=14)
plt.xlabel('Days to/from DLI', fontsize=14)
# plt.savefig('figures/AML_paper/W_over_time_R.eps', bbox_inches='tight')

In [None]:
# Inspect the shape: the surrounding cells index the first axis by prediction
# timepoint and the last two axes by cluster position.
W_mean_over_time.shape

In [None]:
# Draw 1000 samples of the latent functions on the dense prediction grid and
# average them over the first axis. The next cell indexes the last axis of the
# result by prediction timepoint.
f_prior_np = model.sample_f_prior(predict_timepoints, n_samples=1000)
f_mean_over_time = f_prior_np.mean(axis=0)
f_mean_over_time.shape

In [None]:
# Point prediction per timepoint: the mean interaction matrix at that timepoint
# applied to the mean latent vector at the same timepoint, stacked over time.
per_timepoint_predictions = []
for t in range(len(predict_timepoints)):
    per_timepoint_predictions.append(W_mean_over_time[t] @ f_mean_over_time[:, t])
y_pred = torch.stack(per_timepoint_predictions)
y_pred.shape

In [None]:
from mimic_alpha.mimic_alpha import colorAlpha_to_rgb
import matplotlib.colors as colors2

# Point predictions (mean W applied to mean f) on the dense grid, drawn over the
# observed values; one panel per cluster of interest.
x = predict_timepoints.squeeze().numpy()

plt.figure(figsize=(18, 11))
for i, (cluster_id, cell_type) in enumerate(clusters_of_interest.items()):
    scale = cluster_std_devs.values[i]
    shift = cluster_means.values[i]

    def unscale(z_scores, scale=scale, shift=shift):
        # Undo the z-scoring of this cluster.
        return z_scores * scale + shift

    y_i_pred = unscale(y_pred[:, i])
    plt.subplot(3, 4, i + 1)
    color = colors[cell_type]
    plt.plot(x, y_i_pred, c=color)
    plt.scatter(timepoints, unscale(proportions[:, i]), c=color, s=cells_per_sample * 300)
    plt.title(f'Cluster {cluster_id} proportion', fontsize=14)
    plt.xlabel('Days to/from DLI', fontsize=12)
    if i % 4 == 0:
        plt.ylabel('Proportion', fontsize=12)

plt.suptitle('DIISCO predictions', fontsize=15, y=0.98)
plt.subplots_adjust(hspace=0.4)
# plt.savefig('figures/AML_paper/gprn_predictions_R.eps')