In [None]:
%load_ext autoreload
%autoreload 2

import os

import numpy as np
import pandas as pd
import seaborn as sbn
import statsmodels.api as sm
from matplotlib import pyplot as plt
from sklearn.preprocessing import OneHotEncoder

%aimport HER2_classifier

In [None]:
class myspace(object):
    """Plain attribute container holding the notebook's run settings."""


# Every setting is stored as a one-element list; cells in this notebook read
# the value with ``args.<name>[0]``.
args = myspace()

# Input data directory and output directory.
args.data = ['./data/HER2_SKBR3_data_6-7-21/']
args.out = ['./output/']

# Drug under study and the two control cell lines.
args.drug = ['Neratinib']           # alternative noted by the author: ['Trastuzumab']
args.sensitive_line = ['WT']
args.resistant_line = ['T798I']     # alternative noted by the author: ['ND611']

# Which version of the data to load.
args.load = ['normalized']          # alternative noted by the author: ['raw']

# Numeric settings passed to the HER2_classifier helpers; nclus is also the
# number of clusters iterated over by the per-cluster plotting cell.
args.nclus = [10]
args.resample_sz = [125]
args.burnin = [4]

--- 
---
# Load Data 
---
---

In [None]:
# Load the dataset plus two selectors through the project helper, then show
# the selector length, the first five entries of each selector and the head of the data.
# NOTE(review): clover_sel / mscarl_sel presumably select the Clover and
# mScarlet reporter measurements -- confirm in HER2_classifier.load_data.
data, clover_sel, mscarl_sel = HER2_classifier.load_data(args)
print('len selector:', len(clover_sel))
print(clover_sel[0:5])
print(mscarl_sel[0:5])
data.head()

In [None]:
# Distinct drug values present in the loaded data.
data.drug.unique()

In [None]:
# Distinct cell_line values present in the loaded data.
data.cell_line.unique()

In [None]:
# Distinct mutant values present in the loaded data.
data.mutant.unique()

--- 
---
# Filter NA 
---
---

In [None]:
# Run the project's NA filter. It returns the filtered data, a low-data flag
# table and updated selectors.
# NOTE: this cell rebinds `data`, `clover_sel` and `mscarl_sel` to its own
# outputs, so re-running it without first re-running the load cell feeds the
# already-filtered objects back into filter_na.
data, low_data_flags, clover_sel, mscarl_sel = HER2_classifier.filter_na(data, args, clover_sel, mscarl_sel)
print('len selector:', len(clover_sel))
print(clover_sel[0:5])
print(mscarl_sel[0:5])
data.head()

In [None]:
# Display the low-data flag table returned by filter_na.
low_data_flags

In [None]:
# Histogram of low_data_flags.cell_track_count.
# The 20 evenly spaced edges over [0, 250] give 19 bins; counts above 250
# fall outside the bin range and are not drawn.
fig, ax = plt.subplots()
ax.hist(low_data_flags.cell_track_count, bins=np.linspace(0, 250, 20))
plt.show()

--- 
---
# Add `Burn-in`  

Remove the first few time points  

---
---

In [None]:
# Apply the burn-in step to both selectors and show their new lengths and
# first five entries.
# NOTE: the selectors are rebound in place, so re-running this cell passes the
# already-processed selectors back into burnin (presumably trimming them a
# second time -- confirm in HER2_classifier.burnin). Re-run from the load cell
# before executing this one again.
clover_sel, mscarl_sel = HER2_classifier.burnin(args, clover_sel, mscarl_sel)
print('len selector:', len(clover_sel))
print('len selector:', len(mscarl_sel))
print(clover_sel[0:5])
print(mscarl_sel[0:5])

--- 
---
# Resample time-series
---
---

In [None]:
# Build the clustering input from the data and both selectors.
# The per-cluster plotting cell indexes X_train as [track, time point, channel],
# with channel 0 plotted as ERK (Clover) and channel 1 as AKT (mScarlet).
X_train = HER2_classifier.resample(data, args, clover_sel, mscarl_sel)

--- 
---
# Fit the time-series K-means clustering
---
---

In [None]:
# Fit the time-series k-means clustering (plotting on, nothing saved).
# y_pred is used in later cells as a per-track cluster index (0 .. nclus-1);
# km is presumably the fitted model object -- confirm in HER2_classifier.
y_pred, km = HER2_classifier.fit_timeseries_kmeans(args, X_train, plot=True, save=None)

In [None]:
# Dimensions of the clustering input array.
X_train.shape

In [None]:
# Distinct cluster ids that were actually assigned to at least one track.
np.unique(y_pred)

In [None]:
# For every k-means cluster, draw all member tracks as faint lines with the
# cluster mean on top (dashed black), one panel per reporter, and save the
# figure to ./figs/cluster_<n> (n is 1-based).
alpha_ = 0.015  # opacity of individual tracks; low so dense regions read as darker bands

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('./figs', exist_ok=True)

for clus in range(args.nclus[0]):
    # Tracks assigned to this cluster; last axis: 0 = ERK (Clover), 1 = AKT (mScarlet).
    clus_tracks = X_train[y_pred == clus, :, :]
    erk_ = clus_tracks[:, :, 0]
    akt_ = clus_tracks[:, :, 1]

    f, axes = plt.subplots(1, 2, figsize=(10, 5), sharey=True)

    panels = zip(axes, (erk_, akt_), ('CLOVER - ERK', 'MSCARLET - AKT'), ('r-', 'b-'))
    for ax, tracks, title, line_style in panels:
        ax.set_title(title)

        for t in tracks:
            ax.plot(t, line_style, alpha=alpha_)

        # An empty cluster has no mean; skipping avoids an all-NaN line and a RuntimeWarning.
        if len(tracks):
            ax.plot(tracks.mean(axis=0), 'k--', linewidth=5, label='mean')

        ax.set_xlabel('Longitudinal measurement (15 min. incr.)', fontsize=12)
        ax.set_ylim(0, 1)

    axes[0].set_ylabel('Pathway Reporter Value', fontsize=13)

    f.suptitle(f'CLUSTER: {clus + 1}')
    f.tight_layout()
    f.savefig(f'./figs/cluster_{clus+1}')

    # Render the figure, then release it so figures do not accumulate across
    # iterations (matplotlib keeps every pyplot figure alive until closed).
    plt.show()
    plt.close(f)

--- 
---
# Calculate the cluster proportion 

...within each experiment. 

---
---

In [None]:
# Compute the cluster-proportion matrix `cm` and its label encoder `lb`.
# Later cells pair the rows of cm with lb.classes_ (one label per row).
cm, lb = HER2_classifier.quantify_cluster_prop(args, data, y_pred)
cm.shape

--- 
---
# Resistance Signature 

---
---

In [None]:
# Tabulate the cluster proportions: one row per experiment label, one
# `cluster_<n>` column (1-based) per column of `cm`.
cm_df = pd.DataFrame(cm)

# Rename over the number of COLUMNS. len(lb.classes_) is the number of rows,
# and using it leaves columns unrenamed whenever there are fewer labels than
# clusters.
cm_df = cm_df.rename(columns={col: f'cluster_{col + 1}' for col in range(cm_df.shape[1])})

# Labels are split on '--': the first part is the drug, the second the mutant.
cm_df = cm_df.assign(label=lb.classes_)
cm_df = cm_df.assign(
    drug=[label.split('--')[0] for label in cm_df.label],
    mutant=[label.split('--')[1] for label in cm_df.label],
)
cm_df.head()

In [None]:
# Keep only the rows belonging to the two control lines (resistant and sensitive).
_line_sig = cm_df.loc[cm_df.mutant.isin([args.resistant_line[0], args.sensitive_line[0]])]

In [None]:
# Reshape the control-line table to long form: one row per
# (drug, mutant, cluster feature) with its membership value, plus a `group`
# label of the form '<mutant>-untreated' or '<mutant>-<args.drug[0]>'.
_line_sig2 = (
    _line_sig
    .drop(columns='label')
    .set_index(['drug', 'mutant'])
    .stack()
    .reset_index()
    .rename(columns={'level_2': 'feature', 0: 'membership_prob'})
    .assign(
        group=lambda frame: [
            f'{mutant}-{treatment}' if treatment == 'untreated' else f'{mutant}-{args.drug[0]}'
            for mutant, treatment in zip(frame.mutant.values, frame.drug.values)
        ]
    )
)
_line_sig2

In [None]:
# Box plot of cluster membership per cluster feature, split by control group.
fig, ax = plt.subplots(figsize=(20, 7))
sbn.boxplot(
    data=_line_sig2,
    x='feature',
    y='membership_prob',
    hue='group',
    width=0.7,
    ax=ax,
)
plt.show()

--- 
---
# Visualize the cluster co-occurrence within each experiment
---
---

In [None]:
# Plot the cluster correlation figure from the proportion matrix (not saved).
HER2_classifier.plot_cluster_corr(cm, save=None)

--- 
---
# Plot hierarchical clustering of cluster membership
---
---

In [None]:
# Plot the cluster-membership heatmap from the proportion matrix and its labels (not saved).
HER2_classifier.plot_cluster_heatmap(cm, lb, save=None)

--- 
---
# Dimensionality Reduction 
---
---

In [None]:
# Reduce the cluster-proportion matrix to principal components.
# `res` carries the pc1 / pc2 columns filtered on below; `pca` exposes
# components_, which is read for the loadings table.
res, pca = HER2_classifier.reduce_dim(args, cm, lb, plot=True, save=None)

In [None]:
# Display the dimensionality-reduction result table.
res

In [None]:
# Loadings of each PCA input feature on the first two principal components.
# clus_feat is numbered from 1 so it matches the 1-based cluster labels used
# elsewhere in this notebook (`cluster_<n>` columns, 'CLUSTER: <n>' figure titles).
# NOTE(review): assumes the PCA features are the columns of `cm` in their
# original order -- confirm in HER2_classifier.reduce_dim.
pc_loadings = pd.DataFrame({
    'clus_feat': range(1, pca.components_.shape[1] + 1),
    'PC1': pca.components_[0],
    'PC2': pca.components_[1],
})
pc_loadings.head()

In [None]:
# PC1 loading of every cluster feature, bars ordered by ascending loading.
fig, ax = plt.subplots(figsize=(10, 7))
sbn.barplot(
    data=pc_loadings,
    x='clus_feat',
    y='PC1',
    order=pc_loadings.sort_values(by='PC1').clus_feat,
    ax=ax,
)
plt.show()

In [None]:
# PC2 loading of every cluster feature, bars ordered by ascending loading.
fig, ax = plt.subplots(figsize=(10, 7))
sbn.barplot(
    data=pc_loadings,
    x='clus_feat',
    y='PC2',
    order=pc_loadings.sort_values(by='PC2').clus_feat,
    ax=ax,
)
plt.show()

In [None]:
# Rows falling in the region pc1 > 0.3 and pc2 < -0.2.
# NOTE(review): the thresholds are hard-coded; presumably read off the PCA
# plot -- confirm they still apply after any change to the clustering.
res.loc[(res.pc2 < -0.2) & (res.pc1 > 0.3)]

---
---
# Check for Batch Effects
---
---

In [None]:
# Run the project's batch-effect check on the PCA results (plotting on, not
# saved) and preview the returned table.
batch_res = HER2_classifier.check_batch_effects(args, res, plot=True, save=None)
batch_res.head()

--- 
---
# Train classifier on `[positive/negative]` controls
---
---

In [None]:
# Train the classifier on the PCA results; returns the model and an accuracy
# value (plotting on, nothing saved).
model, accuracy = HER2_classifier.train_classifier(args, res, plot=True, save=None)

--- 
---
# Assign mutant sensitivity/resistance calls
---
---

In [None]:
# Score every mutant with the trained model, passing along the batch-effect
# and low-data tables. Later cells read the pc1, pc2, mutant, treatment,
# prob_res, odds_ratio, call and *_flag columns of the result.
prob_res = HER2_classifier.predict_mutants(args, model, res, batch_res, low_data_flags)
prob_res

In [None]:
# PCA scatter of all mutants other than WT, ND611 and T798I, coloured by
# prob_res with the marker style set by the call.
fig, ax = plt.subplots(figsize=(10, 10))
sbn.scatterplot(
    data=prob_res.loc[~prob_res.mutant.isin(['WT', 'ND611', 'T798I'])],
    x='pc1',
    y='pc2',
    hue='prob_res',
    style='call',
    s=300,
    ax=ax,
)
plt.show()

In [None]:
# Ten rows with the highest odds_ratio, restricted to the summary columns.
(
    prob_res
    .sort_values(by='odds_ratio', ascending=False)
    [['pc1', 'pc2', 'mutant', 'treatment', 'prob_res', 'odds_ratio', 'call',
      'PC1_batch_flag', 'PC2_batch_flag', 'low_data_flag']]
    .head(10)
)

In [None]:
# Mutants left as 'no-call' whose pc1 coordinate is negative.
prob_res.loc[(prob_res['call'] == 'no-call') & (prob_res['pc1'] < 0)]