# BRCA

In [None]:
import sage
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Load the stored result for the BRCA experiment from disk. The plotting code
# further down reads its `.values` and `.std` attributes.
# NOTE(review): presumably a SAGE explanation object saved by the `sage`
# package -- confirm against the script that produced the pickle.
brca_sage = sage.load('results/brca_sage.pkl')

In [None]:
# Genes shown in the BRCA plot. The plotting code indexes this list with the
# argsort of the SAGE values, so the order here has to match the order of
# those values.
gene_names = [
    'BCL11A', 'IGF1R', 'CCND1', 'CDK6', 'BRCA1',
    'BRCA2', 'EZH2', 'SFTPD', 'CDC5L', 'ADMR',
    'TSPAN2', 'EIF5B', 'ADRA2C', 'MRCL3', 'CCDC69',
    'ADCY4', 'TEX14', 'RRM2B', 'SLC22A5', 'HRH1',
    'SLC25A1', 'CEBPE', 'IWS1', 'FLJ10213', 'PSMD10',
    'MARCH6', 'PDLIM4', 'SNTB1', 'CHCHD1', 'SCMH1',
    'FLJ20489', 'MDP-1', 'FLJ30092', 'YTHDC2', 'LFNG',
    'HOXD10', 'RPS6KA5', 'WDR40B', 'CST9L', 'ISLR',
    'TMBIM1', 'TRABD', 'ARHGAP29', 'C15orf29', 'SCAMP4',
    'TTC31', 'ZNF570', 'RAB42', 'SERPINI2', 'C9orf21',
]

# Genes flagged True, i.e. drawn as "BRCA Associated" in the plot.
# Per-gene notes:
#   TEX14   - an alternative label of False was noted.
#   SLC22A5 - https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3416040/
#   LFNG    - an alternative label of False was noted;
#             https://www.sciencedirect.com/science/article/pii/S1535610812001341
_flagged_brca_genes = {
    'BCL11A', 'IGF1R', 'CCND1', 'CDK6', 'BRCA1', 'BRCA2', 'EZH2',
    'TEX14', 'SLC22A5', 'LFNG',
}

# Notes on genes that stay flagged False:
#   ADCY4, HRH1, CEBPE, PSMD10, FLJ30092, HOXD10, ARHGAP29
#       - an alternative label of True was noted for each of these.
#         NOTE(review): presumably candidates for relabelling -- confirm.
#   SLC25A1 - Associated with tumor growth
#             http://www.oncotarget.com/index.php?journal=oncotarget&page=article&op=view&path[]=1831&path[]=2259

# Map every gene (in `gene_names` order) to its True/False flag.
brca_colors = {gene: gene in _flagged_brca_genes for gene in gene_names}

In [None]:
# Bar chart of the BRCA SAGE values, sorted from largest to smallest, with
# the genes flagged in `brca_colors` drawn in their own colour.
plt.figure(figsize=(16, 5))

# Sort values, their std and the gene names by descending value.
order = np.argsort(brca_sage.values)[::-1]
values = brca_sage.values[order]
std = brca_sage.std[order]
sorted_genes = np.array(gene_names)[order]
brca_associated = np.array([brca_colors[gene] for gene in sorted_genes])

positions = np.arange(len(values))

# (boolean mask over the sorted genes, bar colour, legend label)
bar_groups = (
    (brca_associated, 'orchid', 'BRCA Associated'),
    (np.logical_not(brca_associated), 'tab:blue', 'Not BRCA Associated'),
)
for mask, color, label in bar_groups:
    # Error bars are 1.96 * std on either side.
    # NOTE(review): presumably a 95% interval under a normal approximation
    # -- confirm what `brca_sage.std` represents.
    plt.bar(positions[mask], values[mask], yerr=1.96 * std[mask],
            capsize=5, color=color, label=label)

# Gene names as slanted x tick labels, right-aligned on their anchor.
plt.xticks(positions, sorted_genes, fontsize=14,
           rotation=45, rotation_mode='anchor', ha='right')

plt.legend(loc='upper right', fontsize=18)
plt.title('Breast Cancer Gene Identification', fontsize=20)
plt.ylabel('SAGE Values', fontsize=18)
plt.tick_params(axis='y', labelsize=16)
plt.tight_layout()

# The figure is written to disk; call plt.show() instead to display it.
plt.savefig('figures/brca_sage.pdf')

# MNIST

In [None]:
import sage
import pickle

In [None]:
# Load the stored result for the MNIST experiment from disk. It is shown as
# the 'SAGE' panel of the comparison figure, which reads `.values` from any
# result that is not a plain numpy array.
# NOTE(review): presumably a SAGE explanation object saved by the `sage`
# package -- confirm against the script that produced the pickle.
mnist_sage = sage.load('results/mnist_sage.pkl')

In [None]:
# Unpickle the result shown as the 'Mean Importance' panel of the MNIST
# comparison figure.
with open('results/mnist mean_importance.pkl', mode='rb') as mean_imp_file:
    mean_imp = pickle.load(mean_imp_file)

In [None]:
# Collect the 'scores' entry of each of the 1024 pickled permutation-test
# results (files numbered 0..1023), then average them across files.
permutation = []
for run in range(1024):
    run_path = 'results/mnist permutation_test {}.pkl'.format(run)
    with open(run_path, mode='rb') as run_file:
        run_result = pickle.load(run_file)
    permutation.append(run_result['scores'])

# Mean over axis 0, i.e. over the runs.
permutation = np.mean(np.array(permutation), axis=0)

In [None]:
# Unpickle the result shown as the 'Feature Ablation' panel of the MNIST
# comparison figure.
with open('results/mnist feature_ablation.pkl', mode='rb') as ablation_file:
    ablation = pickle.load(ablation_file)

In [None]:
# Unpickle the result shown as the 'Univariate Predictors' panel of the MNIST
# comparison figure.
with open('results/mnist univariate.pkl', mode='rb') as univariate_file:
    univariate = pickle.load(univariate_file)

In [None]:
# Each panel title next to the result it displays. The pairs are split into
# the two parallel tuples that the plotting code iterates over.
_mnist_panels = (
    ('SAGE', mnist_sage),
    ('Permutation Test', permutation),
    ('Mean Importance', mean_imp),
    ('Feature Ablation', ablation),
    ('Univariate Predictors', univariate),
)
mnist_names, mnist_results = zip(*_mnist_panels)

In [None]:
# One heat-map panel per result, laid out side by side in a single row.
fig, axarr = plt.subplots(1, len(mnist_results), figsize=(16, 6))

for ax, result, name in zip(axarr, mnist_results, mnist_names):
    plt.sca(ax)

    # Plain numpy arrays are plotted as they are; any other result object
    # (the SAGE one) provides its scores through `.values`.
    values = result if isinstance(result, np.ndarray) else result.values

    # Symmetric colour limits keep zero at the centre of the colormap.
    limit = np.max(np.abs(values))

    # The scores are negated and reshaped to a 28 x 28 grid for display.
    plt.imshow(np.reshape(-values, (28, 28)), cmap='seismic',
               vmin=-limit, vmax=limit)

    # Bare panels: no grid, no axis ticks, just the title.
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.title(name, fontsize=20)

plt.tight_layout()

# The figure is written to disk; call plt.show() instead to display it.
plt.savefig('figures/mnist_sage.pdf')