In [None]:
import numpy as np
import pickle
import random
from sklearn.linear_model import OrthogonalMatchingPursuit, OrthogonalMatchingPursuitCV
import bz2
import matplotlib
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns

# Global matplotlib styling for every figure in this notebook: 10 pt
# Times New Roman text, legends without a frame, and LaTeX text rendering.
matplotlib.rcParams.update({
    'font.size': 10,
    'font.family': 'Times New Roman',
    'legend.frameon': False,
    'text.usetex': True,
})

In [None]:
def make_quad(X):
    """Append all pairwise products of distinct columns to ``X``.

    Parameters
    ----------
    X : numpy.ndarray
        2-D array of shape ``(n_samples, n_features)``.

    Returns
    -------
    numpy.ndarray
        Float array of shape
        ``(n_samples, n_features + n_features * (n_features - 1) / 2)``.
        The first ``n_features`` columns are a copy of ``X``; the remaining
        columns hold ``X[:, i] * X[:, j]`` for every ``i < j``, ordered
        ``(0, 1), (0, 2), ..., (1, 2), ...``.
    """
    n_samples, n_features = X.shape
    n_pairs = n_features * (n_features - 1) // 2

    quad = np.zeros((n_samples, n_features + n_pairs))
    quad[:, :n_features] = X

    # The product columns start right after the linear ones. The offset is
    # derived from X instead of being hard-coded to 99, so the function works
    # for any number of features.
    # np.triu_indices(n, k=1) enumerates the (i, j) pairs with i < j in
    # row-major order, i.e. the same order as a nested i/j loop.
    first, second = np.triu_indices(n_features, k=1)
    quad[:, n_features:] = X[:, first] * X[:, second]
    return quad

In [None]:
# Load the pickled training record and keep the last stored weights and bias.
# NOTE(review): `PATH` is not defined in the cells visible here; it must be
# set before this cell runs.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open(PATH + 'LAD_(99,99)_tvt_e8.pkl', 'rb') as model_file:
    LAD = pickle.load(model_file)

w = LAD['w'][-1]
bias = LAD['bias'][-1]

# `init` is the first of six equal row blocks of the data file.
data = np.loadtxt('../data_complete_tf_new.txt')
init = np.vsplit(data, 6)[0]

# Model output: input plus bias plus the (linear + pairwise) features times
# the weights, with negative values clipped to zero.
quad = make_quad(init)
f = init + bias + np.dot(quad, w)
f[f < 0] = 0

In [None]:
# Transpose after loading so that rows are cells: the resulting shape is
# 6078 cells x (3 + 6*99) columns. The first three columns are kept as `xyz`.
raw_data = np.transpose(np.loadtxt('../dmel_data.txt'))
xyz = raw_data[:, 0:3]

In [None]:
# Perturb the effect of gene 22 (DOC2) on gene 91 (trn): scale the weight
# w[source, target] by 1.15, recompute the model output, and compare it with
# the unperturbed output `f` for cells grouped by source-gene expression.
# NOTE(review): `figures` (output directory prefix) is not defined in the
# cells visible here; it must be set before this cell runs.

source = 22
target = 91

if w[source, target] > 0:
    print('the source gene has positive effect on the target gene')
elif w[source, target] < 0:
    print('the source gene has negative effect on the target gene')
else:
    # A zero weight makes the perturbation below a no-op; say so explicitly.
    print('the source gene has no effect on the target gene')

# Source-gene expression per cell and the 30th / 80th percentiles of its
# strictly positive values, used as the low/mid/high thresholds.
source_expr = init[:, source]
lower = np.percentile(source_expr[source_expr > 0], 30)
higher = np.percentile(source_expr[source_expr > 0], 80)

# Expression groups. Each mask is defined once so that both figures and both
# embryo halves use exactly the same grouping.
zero_cells = source_expr == 0
low_cells = (source_expr > 0) & (source_expr <= lower)
mid_cells = (source_expr > lower) & (source_expr < higher)
high_cells = source_expr >= higher

# The embryo is drawn as two halves split on the sign of the y coordinate;
# cells with y < 0 are drawn at 190 - z instead of z.
y_pos = xyz[:, 1] > 0
y_neg = xyz[:, 1] < 0

new_f = np.copy(f)

new_w = np.copy(w)
new_w[source, target] = 1.15*new_w[source, target]

pred_f = init + bias + quad.dot(new_w)
pred_f[pred_f < 0] = 0

current_palette = [sns.color_palette('Set3')[8], sns.color_palette('Paired')[6],
                   sns.color_palette('Paired')[7], sns.color_palette('Paired')[5]]

fig = plt.figure(figsize=(5,2), dpi=600)
fig.subplots_adjust(wspace=0.4)

# Panel (A): source-gene expression as a colour map.
ax = fig.add_subplot(1,2,1)
ax.axis('off')
# Both halves share one colour scale; without explicit vmin/vmax each scatter
# call autoscales to its own subset and the colorbar would only describe the
# first half.
shown = y_pos | y_neg
expr_scale = dict(s=1.5, cmap=plt.cm.OrRd,
                  vmin=source_expr[shown].min(), vmax=source_expr[shown].max())
xy = ax.scatter(xyz[y_pos, 0], xyz[y_pos, 2], c=source_expr[y_pos], **expr_scale)
ax.scatter(xyz[y_neg, 0], 190 - xyz[y_neg, 2], c=source_expr[y_neg], **expr_scale)
ax.text(-245,95, 'Anterior', rotation=90, va='center')
ax.text(0,92, 'Dorsal', va='center', ha='center')
ax.text(185,100, 'Posterior', rotation=-90, va='center')
ax.text(0,280, 'Ventral', ha='center')
ax.text(0,-115, 'Ventral', ha='center')
ax.text(-220,265, '(A)')
cbar_ax = fig.add_axes([0.48, 0.25, 0.02, 0.5])
cbar = fig.colorbar(xy, cax=cbar_ax, ticks=[0,0.6], orientation='vertical')

# Panel (B): cells coloured by expression group. Only the first scatter of
# each group carries a label so the legend has one entry per group.
ax = fig.add_subplot(1,2,2)
ax.axis('off')
expression_groups = [
    (zero_cells, current_palette[0], 'zero expression'),
    (low_cells, current_palette[1], 'low expression'),
    (mid_cells, current_palette[2], 'mid expression'),
    (high_cells, current_palette[3], 'high expression'),
]
for group_cells, group_color, group_label in expression_groups:
    # The low group excludes zero-expression cells on BOTH halves; previously
    # the y < 0 half omitted the `> 0` test and repainted zero cells as low.
    ax.scatter(xyz[group_cells & y_pos, 0], xyz[group_cells & y_pos, 2],
               color=group_color, alpha=1, lw=1, s=2, label=group_label)
    ax.scatter(xyz[group_cells & y_neg, 0], 190 - xyz[group_cells & y_neg, 2],
               color=group_color, alpha=1, lw=1, s=2)
ax.text(-220,265, '(B)')
ax.legend(bbox_to_anchor=(1,0.85))
fig.savefig(figures+'22to91emb.png', bbox_inches='tight', facecolor='white', transparent=False)
plt.show()

# Absolute change of the target gene's output caused by the perturbation,
# normalised by the mean change over the high-expression cells.
nom = np.mean(np.abs(new_f[:,target][high_cells] - pred_f[:,target][high_cells]))
error = np.abs(new_f[:,target] - pred_f[:,target])/nom

plt.figure(figsize=(2,2), dpi=600)
# One violinplot call per group on the same axes; `order=[1,2,3]` with a
# constant x value of 1, 2 or 3 selects which slot each call fills.
for group_id, group_cells in enumerate((low_cells, mid_cells, high_cells), start=1):
    sns.violinplot(x=group_id*np.ones(len(error[group_cells])),
                   y=error[group_cells], order=[1,2,3],
                   width=0.6, palette=current_palette[1:], saturation=1)
plt.xticks(ticks=[0,1,2], labels=[r'cell$_{\rm{low}}$', r'cell$_{\rm{mid}}$', r'cell$_{\rm{high}}$'])
plt.xlabel('cell groups')
plt.ylabel('normalized difference')
plt.savefig(figures+'22to91normdiff.png', bbox_inches='tight')