In [None]:
# Load the tab-separated metadata tables used below.
# siginfo: read with low_memory=False so pandas infers each column's dtype from
# the whole file instead of chunk by chunk.
siginfo = pd.read_csv(
    '../../data/siginfo_beta.txt',
    sep='\t',
    low_memory=False,
)
# druginfo: filtered on `target` and used as a pert_id -> cmap_name lookup below.
druginfo = pd.read_csv(
    '../../data/compoundinfo_beta.txt',
    sep='\t',
)
# geneinfo: not referenced in the lines visible here; kept in case later cells
# rely on it.
geneinfo = pd.read_csv(
    '../extdata/omnipath_uniprot2genesymb.tsv',
    sep='\t',
)

# Embed the model outputs with UMAP; the result `u` is consumed below as two
# columns ('umap1', 'umap2').
# NOTE(review): no random_state is passed, so the layout is presumably not
# reproducible from run to run -- confirm whether that matters here.
embedding_input = out.detach().numpy()
reducer = umap.UMAP(min_dist=1., n_neighbors=15)
u = reducer.fit_transform(embedding_input)

# Assemble the plotting table: embedding coordinates + sig_id on the left,
# the raw output values (one column per entry of out_edge_names) on the right.
# Both frames carry a default RangeIndex, so the column-wise concat lines up
# row by row.
udf = pd.concat(
    (
        pd.DataFrame(u, columns=['umap1', 'umap2']).assign(
            sig_id=np.array(sig_ids)[hiconc_drug_mask]
        ),
        pd.DataFrame(out.detach().numpy(), columns=out_edge_names),
    ),
    axis=1,
)

# Attach signature metadata; validate='1:1' makes pandas raise if sig_id is
# repeated on either side of the join.
udf = udf.merge(
    siginfo[['sig_id', 'pert_id', 'cell_iname', 'pert_dose']],
    how='left',
    on='sig_id',
    validate='1:1',
)
udf['log_pert_dose'] = np.log10(udf['pert_dose'])

# pert_id -> cmap_name, restricted to druginfo rows whose target is 'EGFR'.
brd2cmap = (
    druginfo.loc[druginfo.target == 'EGFR']
    .set_index('pert_id')['cmap_name']
    .to_dict()
)
# Perturbagens missing from the lookup keep their pert_id as the display name.
udf['cmap_name'] = [brd2cmap.get(pid, pid) for pid in udf.pert_id]

# NOTE: redundant outgoing edges.
# Repeated column labels are made unique so each column can be addressed on its
# own: the first occurrence keeps its label, later ones become '<label>.1',
# '<label>.2', ...
def _make_unique(labels):
    """Return `labels` as a list in which every entry is distinct.

    The first occurrence of a label is left unchanged. Each later occurrence
    is renamed to '<label>.<i>' using the smallest counter i >= 1 (increasing
    per label) whose result is not already present among the original or the
    newly generated labels, so a rename can never collide with another column.
    """
    taken = set(labels)   # every label that is already in use
    next_suffix = {}      # label -> next counter to try for a repeat
    unique = []
    for label in labels:
        if label not in next_suffix:
            next_suffix[label] = 1
            unique.append(label)
            continue
        i = next_suffix[label]
        candidate = f'{label}.{i}'
        while candidate in taken:
            i += 1
            candidate = f'{label}.{i}'
        next_suffix[label] = i + 1
        taken.add(candidate)
        unique.append(candidate)
    return unique


cols = pd.Series(_make_unique(udf.columns))
# rename the columns with the de-duplicated labels.
udf.columns = cols

# Two side-by-side views of the same embedding (y axis shared):
#   left  -- points coloured by log10 dose,
#   right -- points coloured and shaped by compound name.
f, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
ax_dose, ax_drug = axes
embedding_kwargs = dict(x='umap1', y='umap2', data=udf)

sbn.scatterplot(
    hue='log_pert_dose',
    marker='.',
    s=250,
    alpha=0.25,
    linewidth=0,
    ax=ax_dose,
    **embedding_kwargs,
)
g = sbn.scatterplot(
    hue='cmap_name',
    style='cmap_name',
    s=25,
    alpha=1.,
    edgecolor='k',
    ax=ax_drug,
    **embedding_kwargs,
)
# Place the compound legend outside the right-hand panel, in two columns.
g.legend(loc='center left', bbox_to_anchor=(1.0, 0.5), ncol=2)

# Disabled per-cell-line panels. They index axes 3-5, which the 1x2 grid above
# does not provide, so the grid must be enlarged before re-enabling them.
#sbn.scatterplot(x='umap1', y='umap2', data=udf[lambda x: x.cell_iname == 'MCF7'], hue='cmap_name', style='cmap_name', alpha=1., ax=axes.flat[3], s=50, legend=None)
#sbn.scatterplot(x='umap1', y='umap2', data=udf[lambda x: x.cell_iname == 'PC3'], hue='cmap_name', style='cmap_name', alpha=1., ax=axes.flat[4], s=50, legend=None)
#sbn.scatterplot(x='umap1', y='umap2', data=udf[lambda x: x.cell_iname == 'A375'], hue='cmap_name', style='cmap_name', alpha=1., ax=axes.flat[5], s=50, legend=None)
f.suptitle('EGFR Response', fontsize=18)
plt.tight_layout()
plt.show()