In [1]:
import pandas as pd
import numpy as np
import random
import scanpy as sc
import os
import sys

from sklearn.metrics import mean_absolute_percentage_error as mape
from sklearn.metrics import mean_squared_error as mse
from sklearn.metrics import roc_auc_score

sys.path.insert(0, '../repos/GNNImpute/')
from GNNImpute.api import GNNImpute

In [2]:
def get_data_for_i(i, data_dir='../../data/mid_simulation'):
    """Load the simulated ground truth and the ``drp_{i}0`` dropout matrix.

    Parameters
    ----------
    i : int
        Dropout level index. Reads ``drp_{i}0.csv.gz`` (e.g. ``i=3`` ->
        ``drp_30.csv.gz``) and selects the first ``i * 10`` percent of the
        seeded-shuffled matrix entries as the evaluation mask.
    data_dir : str, optional
        Directory holding ``data.csv.gz`` and the ``drp_*.csv.gz`` files.

    Returns
    -------
    df : pandas.DataFrame
        ``log(1 + x)`` of the dropout matrix, restricted to the kept columns
        (columns whose sign-sum over rows is positive; for non-negative counts
        that means "at least one nonzero entry").
    mask : list of (row_label, col_label) tuples
        Selected entries lying in kept columns. The tuples hold index/column
        *labels* (not positions), so ``frame.loc[row, col]`` works on ``df``
        and ``original`` directly. Presumably these are the entries the
        simulation dropped out (same seed 42 shuffle) -- TODO confirm.
    original : pandas.DataFrame
        ``log(1 + x)`` of the ground-truth matrix, with the same index and
        kept columns as ``df``.
    """
    original_raw = pd.read_csv(os.path.join(data_dir, 'data.csv.gz'), index_col=0)
    dropout_raw = pd.read_csv(os.path.join(data_dir, 'drp_{}0.csv.gz'.format(i)), index_col=0)
    dropout_raw.index = [int(label) for label in dropout_raw.index]
    dropout_raw.columns = [int(label) for label in dropout_raw.columns]

    # The ground truth is relabelled positionally to match the dropout matrix;
    # assumes both files list cells/genes in the same order -- TODO confirm.
    original_raw.columns = dropout_raw.columns
    original_raw.index = dropout_raw.index

    # Seeded shuffle of every (row, col) position, so the selection is reproducible.
    positions = list(np.ndindex(original_raw.shape))
    random.Random(42).shuffle(positions)
    # Integer arithmetic: float `len/10 * i` can truncate one entry short.
    n_selected = int(len(positions) * i // 10)

    # Keep columns with a positive sign-sum. axis=0 is explicit because
    # np.sum(DataFrame) without an axis is deprecated in pandas 2.x.
    keep = np.sign(dropout_raw).sum(axis=0) > 0
    keep_cols = keep.to_numpy()
    removed = set(keep.index[~keep_cols])

    # Translate positions into labels so downstream `.loc` lookups stay correct
    # even when the index/column labels are not 0..n-1.
    row_labels = dropout_raw.index
    col_labels = dropout_raw.columns
    mask = [
        (int(row_labels[r]), int(col_labels[c]))
        for r, c in positions[:n_selected]
        if col_labels[c] not in removed
    ]

    original = np.log1p(original_raw.loc[:, keep_cols])
    df = np.log1p(dropout_raw.loc[:, keep_cols])

    return df, mask, original

In [3]:
# Per-dropout-level evaluation results, keyed by level index 0..8
# (index k uses drp_{k+1}0, i.e. the first (k+1)*10% of shuffled entries).
mses = {}     # MSE over all masked entries
corrs = {}    # Pearson r over all masked entries
mses_ = {}    # MSE over masked entries whose true value is nonzero
corrs_ = {}   # Pearson r over masked entries whose true value is nonzero
mses__ = {}   # MSE over masked entries whose true value is zero
corrs__ = {}  # left empty: true values at zero entries are constant, so r is undefined
aucs = {}     # ROC AUC of predictions for separating nonzero from zero true values
method = 'GNNImpute'

# NOTE(review): machine-specific scratch directory passed to GNNImpute as `d`
# (the logged *.pkl files land here); make configurable before sharing.
GNN_CHECKPOINT_DIR = '/export/scratch/inoue019/gnn_mid/'
SAVE_PREDICTIONS = False  # set True to write per-level predictions under result/
# NOTE(review): no torch/numpy seed is set, so GNNImpute training is not
# guaranteed to be reproducible across runs.


def _values_at(frame, entries):
    """Return ``frame`` values at the ``(row_label, col_label)`` pairs in ``entries``.

    Vectorised equivalent of ``np.array([frame.loc[e] for e in entries])``;
    raises ``KeyError`` for labels missing from ``frame``, like ``.loc`` does.
    """
    row_pos = frame.index.get_indexer([r for r, _ in entries])
    col_pos = frame.columns.get_indexer([c for _, c in entries])
    # get_indexer reports unknown labels as -1, which would silently wrap around.
    if (row_pos < 0).any() or (col_pos < 0).any():
        raise KeyError('mask references labels missing from the frame')
    return frame.to_numpy()[row_pos, col_pos]


for i in range(9):
    print(i)
    df, mask, original = get_data_for_i(i + 1)

    adata = sc.AnnData(df.values)
    adata = GNNImpute(
        adata=adata, layer='GATConv',
        no_cuda=False,
        d=GNN_CHECKPOINT_DIR,
        epochs=3000,
        lr=0.001, weight_decay=0.0005,
        hidden=50, patience=200,
        fastmode=True, heads=3,
        use_raw=False,
        verbose=True
    )

    pred = pd.DataFrame(adata.X, columns=df.columns, index=df.index)
    if SAVE_PREDICTIONS:
        os.makedirs('result', exist_ok=True)
        pred.to_csv('result/{}_{}.csv.gz'.format(method, i), compression='gzip')

    origin = _values_at(original, mask)
    predict = _values_at(pred, mask)
    nonzero = origin != 0

    mses[i] = mse(origin, predict)
    corrs[i] = np.corrcoef(origin, predict)[0][1]
    mses_[i] = mse(origin[nonzero], predict[nonzero])
    corrs_[i] = np.corrcoef(origin[nonzero], predict[nonzero])[0][1]
    mses__[i] = mse(origin[~nonzero], predict[~nonzero])
    # ROC AUC depends only on the ordering of scores, so scoring the raw
    # predictions equals scoring their ranks; labels flag nonzero true values.
    aucs[i] = roc_auc_score(nonzero.astype(int), predict)

0
Epoch: 0010 loss_train: 0.9997 loss_val: 1.0029
Epoch: 0020 loss_train: 0.9983 loss_val: 1.0017
Epoch: 0030 loss_train: 0.9928 loss_val: 0.9967
Epoch: 0040 loss_train: 0.9857 loss_val: 0.9897
Epoch: 0050 loss_train: 0.9768 loss_val: 0.9820
Epoch: 0060 loss_train: 0.9697 loss_val: 0.9751
Epoch: 0070 loss_train: 0.9640 loss_val: 0.9694
Epoch: 0080 loss_train: 0.9599 loss_val: 0.9669
Epoch: 0090 loss_train: 0.9578 loss_val: 0.9637
Epoch: 0100 loss_train: 0.9564 loss_val: 0.9634
Epoch: 0110 loss_train: 0.9554 loss_val: 0.9626
Epoch: 0120 loss_train: 0.9548 loss_val: 0.9618
Epoch: 0130 loss_train: 0.9543 loss_val: 0.9617
Epoch: 0140 loss_train: 0.9540 loss_val: 0.9610
Epoch: 0150 loss_train: 0.9537 loss_val: 0.9613
Epoch: 0160 loss_train: 0.9537 loss_val: 0.9608
Epoch: 0170 loss_train: 0.9535 loss_val: 0.9608
Epoch: 0180 loss_train: 0.9536 loss_val: 0.9609
Epoch: 0190 loss_train: 0.9531 loss_val: 0.9609
Epoch: 0200 loss_train: 0.9532 loss_val: 0.9606
Epoch: 0210 loss_train: 0.9531 loss_va

Epoch: 0450 loss_train: 0.9593 loss_val: 0.9625
Epoch: 0460 loss_train: 0.9596 loss_val: 0.9625
['/export/scratch/inoue019/gnn_mid/55.pkl', '/export/scratch/inoue019/gnn_mid/85.pkl', '/export/scratch/inoue019/gnn_mid/52.pkl', '/export/scratch/inoue019/gnn_mid/45.pkl', '/export/scratch/inoue019/gnn_mid/13.pkl', '/export/scratch/inoue019/gnn_mid/14.pkl', '/export/scratch/inoue019/gnn_mid/233.pkl', '/export/scratch/inoue019/gnn_mid/65.pkl', '/export/scratch/inoue019/gnn_mid/76.pkl', '/export/scratch/inoue019/gnn_mid/90.pkl', '/export/scratch/inoue019/gnn_mid/87.pkl', '/export/scratch/inoue019/gnn_mid/56.pkl', '/export/scratch/inoue019/gnn_mid/9.pkl', '/export/scratch/inoue019/gnn_mid/122.pkl', '/export/scratch/inoue019/gnn_mid/67.pkl', '/export/scratch/inoue019/gnn_mid/105.pkl', '/export/scratch/inoue019/gnn_mid/108.pkl', '/export/scratch/inoue019/gnn_mid/75.pkl', '/export/scratch/inoue019/gnn_mid/103.pkl', '/export/scratch/inoue019/gnn_mid/23.pkl', '/export/scratch/inoue019/gnn_mid/15.pk

Total time elapsed: 52.2824s
3
Epoch: 0010 loss_train: 0.9996 loss_val: 0.9988
Epoch: 0020 loss_train: 0.9992 loss_val: 0.9986
Epoch: 0030 loss_train: 0.9974 loss_val: 0.9973
Epoch: 0040 loss_train: 0.9946 loss_val: 0.9943
Epoch: 0050 loss_train: 0.9906 loss_val: 0.9913
Epoch: 0060 loss_train: 0.9858 loss_val: 0.9867
Epoch: 0070 loss_train: 0.9835 loss_val: 0.9850
Epoch: 0080 loss_train: 0.9815 loss_val: 0.9835
Epoch: 0090 loss_train: 0.9805 loss_val: 0.9811
Epoch: 0100 loss_train: 0.9777 loss_val: 0.9796
Epoch: 0110 loss_train: 0.9755 loss_val: 0.9780
Epoch: 0120 loss_train: 0.9744 loss_val: 0.9774
Epoch: 0130 loss_train: 0.9733 loss_val: 0.9771
Epoch: 0140 loss_train: 0.9730 loss_val: 0.9769
Epoch: 0150 loss_train: 0.9727 loss_val: 0.9766
Epoch: 0160 loss_train: 0.9723 loss_val: 0.9758
Epoch: 0170 loss_train: 0.9720 loss_val: 0.9760
Epoch: 0180 loss_train: 0.9722 loss_val: 0.9763
Epoch: 0190 loss_train: 0.9718 loss_val: 0.9757
Epoch: 0200 loss_train: 0.9718 loss_val: 0.9758
Epoch: 02

Epoch: 0370 loss_train: 0.9769 loss_val: 0.9873
Epoch: 0380 loss_train: 0.9767 loss_val: 0.9871
Epoch: 0390 loss_train: 0.9768 loss_val: 0.9875
Epoch: 0400 loss_train: 0.9772 loss_val: 0.9873
Epoch: 0410 loss_train: 0.9769 loss_val: 0.9873
Epoch: 0420 loss_train: 0.9770 loss_val: 0.9872
Epoch: 0430 loss_train: 0.9770 loss_val: 0.9873
Epoch: 0440 loss_train: 0.9771 loss_val: 0.9874
Epoch: 0450 loss_train: 0.9772 loss_val: 0.9872
Epoch: 0460 loss_train: 0.9772 loss_val: 0.9871
Epoch: 0470 loss_train: 0.9772 loss_val: 0.9872
Epoch: 0480 loss_train: 0.9770 loss_val: 0.9875
['/export/scratch/inoue019/gnn_mid/55.pkl', '/export/scratch/inoue019/gnn_mid/85.pkl', '/export/scratch/inoue019/gnn_mid/172.pkl', '/export/scratch/inoue019/gnn_mid/107.pkl', '/export/scratch/inoue019/gnn_mid/52.pkl', '/export/scratch/inoue019/gnn_mid/45.pkl', '/export/scratch/inoue019/gnn_mid/13.pkl', '/export/scratch/inoue019/gnn_mid/58.pkl', '/export/scratch/inoue019/gnn_mid/155.pkl', '/export/scratch/inoue019/gnn_mid

Total time elapsed: 44.1451s
6
Epoch: 0010 loss_train: 1.0025 loss_val: 0.9935
Epoch: 0020 loss_train: 1.0025 loss_val: 0.9935
Epoch: 0030 loss_train: 1.0024 loss_val: 0.9935
Epoch: 0040 loss_train: 1.0020 loss_val: 0.9932
Epoch: 0050 loss_train: 1.0008 loss_val: 0.9923
Epoch: 0060 loss_train: 0.9992 loss_val: 0.9910
Epoch: 0070 loss_train: 0.9970 loss_val: 0.9887
Epoch: 0080 loss_train: 0.9956 loss_val: 0.9884
Epoch: 0090 loss_train: 0.9948 loss_val: 0.9873
Epoch: 0100 loss_train: 0.9935 loss_val: 0.9865
Epoch: 0110 loss_train: 0.9920 loss_val: 0.9856
Epoch: 0120 loss_train: 0.9912 loss_val: 0.9851
Epoch: 0130 loss_train: 0.9904 loss_val: 0.9849
Epoch: 0140 loss_train: 0.9900 loss_val: 0.9847
Epoch: 0150 loss_train: 0.9898 loss_val: 0.9845
Epoch: 0160 loss_train: 0.9896 loss_val: 0.9845
Epoch: 0170 loss_train: 0.9894 loss_val: 0.9844
Epoch: 0180 loss_train: 0.9892 loss_val: 0.9845
Epoch: 0190 loss_train: 0.9890 loss_val: 0.9847
Epoch: 0200 loss_train: 0.9889 loss_val: 0.9844
Epoch: 02

Epoch: 0410 loss_train: 0.9899 loss_val: 0.9963
Epoch: 0420 loss_train: 0.9899 loss_val: 0.9962
Epoch: 0430 loss_train: 0.9901 loss_val: 0.9963
Epoch: 0440 loss_train: 0.9899 loss_val: 0.9963
Epoch: 0450 loss_train: 0.9900 loss_val: 0.9963
Epoch: 0460 loss_train: 0.9897 loss_val: 0.9964
Epoch: 0470 loss_train: 0.9900 loss_val: 0.9961
Epoch: 0480 loss_train: 0.9901 loss_val: 0.9963
Epoch: 0490 loss_train: 0.9898 loss_val: 0.9961
Epoch: 0500 loss_train: 0.9899 loss_val: 0.9961
Epoch: 0510 loss_train: 0.9899 loss_val: 0.9965
Epoch: 0520 loss_train: 0.9901 loss_val: 0.9961
Epoch: 0530 loss_train: 0.9900 loss_val: 0.9962
Epoch: 0540 loss_train: 0.9902 loss_val: 0.9961
Epoch: 0550 loss_train: 0.9902 loss_val: 0.9963
Epoch: 0560 loss_train: 0.9901 loss_val: 0.9962
Epoch: 0570 loss_train: 0.9900 loss_val: 0.9966
Epoch: 0580 loss_train: 0.9902 loss_val: 0.9962
Epoch: 0590 loss_train: 0.9903 loss_val: 0.9961
Epoch: 0600 loss_train: 0.9904 loss_val: 0.9961
Epoch: 0610 loss_train: 0.9904 loss_val:

Epoch: 0270 loss_train: 0.9970 loss_val: 0.9897
Epoch: 0280 loss_train: 0.9971 loss_val: 0.9894
Epoch: 0290 loss_train: 0.9965 loss_val: 0.9893
Epoch: 0300 loss_train: 0.9969 loss_val: 0.9893
Epoch: 0310 loss_train: 0.9970 loss_val: 0.9894
Epoch: 0320 loss_train: 0.9968 loss_val: 0.9893
Epoch: 0330 loss_train: 0.9972 loss_val: 0.9893
Epoch: 0340 loss_train: 0.9971 loss_val: 0.9894
Epoch: 0350 loss_train: 0.9962 loss_val: 0.9894
Epoch: 0360 loss_train: 0.9967 loss_val: 0.9894
Epoch: 0370 loss_train: 0.9967 loss_val: 0.9893
Epoch: 0380 loss_train: 0.9972 loss_val: 0.9893
Epoch: 0390 loss_train: 0.9968 loss_val: 0.9894
Epoch: 0400 loss_train: 0.9967 loss_val: 0.9892
Epoch: 0410 loss_train: 0.9970 loss_val: 0.9893
Epoch: 0420 loss_train: 0.9970 loss_val: 0.9892
Epoch: 0430 loss_train: 0.9969 loss_val: 0.9893
Epoch: 0440 loss_train: 0.9967 loss_val: 0.9894
Epoch: 0450 loss_train: 0.9970 loss_val: 0.9899
Epoch: 0460 loss_train: 0.9971 loss_val: 0.9893
Epoch: 0470 loss_train: 0.9969 loss_val:

Total time elapsed: 128.8874s


In [4]:
# Summary table: one row per metric, one column per dropout level index (0..8).
# Building from the dicts themselves (orient='index') aligns every value by its
# level key instead of by insertion position, so a partially filled dict (e.g.
# after an interrupted run) cannot shift values into the wrong column.
metrics_table = pd.DataFrame.from_dict({
    'mse': mses,
    'mse (nonzero)': mses_,
    'mse (zero)': mses__,
    'corr': corrs,
    'corrs (nonzero)': corrs_,
    'auc': aucs,
}, orient='index')
metrics_table

Unnamed: 0,0,1,2,3,4,5,6,7,8
mse,0.646655,0.654851,0.664911,0.672421,0.679852,0.689355,0.698951,0.715709,0.734875
mse (nonzero),1.87649,1.906685,1.943408,1.969146,1.995527,2.027546,2.05986,2.115081,2.175624
mse (zero),0.018847,0.016465,0.013548,0.011891,0.010072,0.007984,0.005775,0.002642,0.000734
corr,0.266393,0.252818,0.238831,0.226322,0.218506,0.204955,0.194336,0.181836,0.151137
corrs (nonzero),0.1472,0.131866,0.114687,0.100331,0.093449,0.082926,0.075359,0.069127,0.068195
auc,0.656489,0.654489,0.652002,0.645138,0.640988,0.632401,0.627204,0.630108,0.618531
