## predict gene upregulation/downregulation
- Using Expectation Reflection to infer the gene network and predict gene upregulation/downregulation at the next time point
- initial measurements: $t[0-4]$
- final measurements: $t[1-5]$
- midpoint measurements: $\dfrac{t[0-4] + t[1-5]}{2}$

In [1]:
from functions import *
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
%matplotlib inline

In [2]:
def ER_inference(X, y, kf=5, regu=0.005):
    """Cross-validate Expectation Reflection and return per-target accuracy.

    For each of ``kf`` consecutive (unshuffled) folds, one model per column of
    ``y`` is fitted with ``fit`` on the training rows. The sign of the
    predicted field ``h0 + X_te @ w`` is then compared with the held-out
    targets.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Input measurements, indexed by row.
    y : ndarray, shape (n_samples, n_targets)
        Target columns. Presumably +/-1 up/down-regulation labels
        (TODO: confirm against ``make_data``).
    kf : int
        Number of cross-validation folds.
    regu : float
        Regularization strength forwarded to ``fit``.

    Returns
    -------
    ndarray, shape (kf, n_targets)
        Fraction of correctly predicted held-out samples; one row per fold,
        one column per target.
    """
    # Folds are taken in order; callers shuffle the rows beforehand.
    # Passing random_state together with shuffle=False raises ValueError
    # in scikit-learn >= 0.24, so it is not passed here.
    kfold = KFold(n_splits=kf)
    n_targets = y.shape[1]
    fold_accuracies = []

    for tr_ind, te_ind in kfold.split(y):
        X_tr, y_tr = X[tr_ind], y[tr_ind]
        X_te, y_te = X[te_ind], y[te_ind]

        accuracy = np.zeros(n_targets)
        for n in range(n_targets):
            h0, w = fit(X_tr, y_tr[:, n], niter_max=100, regu=regu)
            # np.sign yields 0 when the field is exactly 0. With +/-1
            # targets, such a prediction counts as a miss.
            y_pred = np.sign(h0 + X_te.dot(w))
            accuracy[n] = accuracy_score(y_te[:, n], y_pred)
        fold_accuracies.append(accuracy)

    return np.vstack(fold_accuracies)

In [None]:
# Build the three candidate input representations (initial / final / midpoint
# measurements) and the targets with `make_data` (from `functions`).
# data_deriv is used as the prediction target below.
data_complete = np.loadtxt('../data_complete.txt')
data_init, data_fin, data_midpt, data_deriv = make_data(data_complete, n_bin=6)

names = ["initial measurements", "final measurements", "midpoint measurements"]
data_list = [data_init, data_fin, data_midpt]
markers = ["o", "x", "+"]

fig = plt.figure(figsize=(14, 10))
plt.rcParams.update({'font.size': 16})
for name, data in zip(names, data_list):
    # Shuffle inputs and targets together with a fixed seed, so that the
    # (unshuffled) CV folds in ER_inference are reproducible between runs.
    X, y = shuffle(data, data_deriv, random_state=1)
    accu = ER_inference(X, y, kf=5, regu=0.005)
    # One horizontal bar per target column: the mean accuracy over folds.
    plt.barh(range(y.shape[1]), accu.mean(axis=0), height=0.8, alpha=0.5,
             label=name)

# barh puts the categories (genes) on the y-axis and the values
# (accuracy) on the x-axis.
plt.xlabel('accuracy')
plt.ylabel('gene')
plt.xlim(0.6, 1)
plt.legend()
plt.grid()
plt.show()
    