In [1]:
import sys
# ^^^ pyforest auto-imports - don't write above this line
import numpy as np
import pandas as pd
import scanpy as sc

from sklearn.metrics import adjusted_rand_score, adjusted_mutual_info_score, normalized_mutual_info_score
from sklearn.cluster import KMeans, SpectralClustering
from sklearn.cluster import SpectralClustering

from sklearn.decomposition import PCA, SparsePCA, KernelPCA
from sklearn.manifold import TSNE

from rpy2.robjects import r, pandas2ri
from rpy2.robjects.vectors import StrVector

pandas2ri.activate()

import magic
import scprep

%matplotlib inline

# from sklearnex import patch_sklearn
# patch_sklearn()

import warnings

from sklearn.cluster import KMeans
from tqdm import tqdm

sys.path.insert(0, '../../imputation2/notebooks/repos/GNNImpute/')
from GNNImpute.api import GNNImpute

In [2]:
def get_data(i):
    """Load one dataset and return its filtered, transformed matrix plus row labels.

    Reads ``../data/<i>/data.csv.gz`` (first CSV column becomes the row index),
    filters sparse columns and rows, then applies ``log(x + 1)``, scprep
    library-size normalisation and scprep's square-root transform, in that order.

    Filtering uses the sum of ``np.sign`` of the raw values; for non-negative
    data this is the number of non-zero entries. A column is kept when that
    count exceeds ``int(n_rows * 0.05)`` and a row is kept when it exceeds
    ``int(n_cols * 0.05)``. Both masks are computed on the unfiltered frame.

    Parameters
    ----------
    i : str
        Name of the dataset directory under ``../data/``.

    Returns
    -------
    X_norm : pandas.DataFrame
        The transformed matrix restricted to the retained rows and columns.
    labels : pandas.Index
        Row index of the filtered frame (presumably the per-row class labels;
        confirm against the data file).
    """
    df = pd.read_csv('../data/{}/data.csv.gz'.format(i), index_col=0)
    n_rows, n_cols = df.shape

    tmp = np.sign(df)
    # Reduce along explicit axes. ``np.sum(DataFrame)`` with no axis currently
    # behaves like axis=0, but that implicit behaviour is deprecated in pandas
    # and is slated to become a full reduction to a scalar.
    cols = tmp.sum(axis=0) > int(n_rows * 0.05)
    rows = tmp.sum(axis=1) > int(n_cols * 0.05)

    df = np.log(df.loc[rows, cols] + 1)

    # Work on a copy so the log-transformed frame is left untouched.
    df_norm = scprep.normalize.library_size_normalize(df.copy())
    df_norm = scprep.transform.sqrt(df_norm)

    X_norm = pd.DataFrame(df_norm, columns=df.columns)
    labels = df.index
    return X_norm, labels

In [3]:
# Load the filtered, transformed expression matrix for the 'brosens' dataset.
X_norm, labels = get_data('brosens')

# Settings forwarded unchanged to GNNImpute.
gnn_options = dict(
    layer='GATConv',
    no_cuda=False,
    d='/export/scratch/inoue019/gnn/',
    epochs=3000,
    lr=0.001,
    weight_decay=0.0005,
    hidden=50,
    patience=200,
    fastmode=True,
    heads=3,
    use_raw=False,
    verbose=True,
)

# Wrap the raw values in an AnnData object and run the imputation on it.
adata = GNNImpute(adata=sc.AnnData(X_norm.values), **gnn_options)

# Re-attach the original row and column labels to the returned matrix.
pred = pd.DataFrame(adata.X, columns=X_norm.columns, index=X_norm.index)

Epoch: 0010 loss_train: 1.0015 loss_val: 1.0043
Epoch: 0020 loss_train: 1.0004 loss_val: 1.0033
Epoch: 0030 loss_train: 0.9934 loss_val: 0.9968
Epoch: 0040 loss_train: 0.9823 loss_val: 0.9862
Epoch: 0050 loss_train: 0.9741 loss_val: 0.9788
Epoch: 0060 loss_train: 0.9687 loss_val: 0.9741
Epoch: 0070 loss_train: 0.9663 loss_val: 0.9722
Epoch: 0080 loss_train: 0.9646 loss_val: 0.9715
Epoch: 0090 loss_train: 0.9640 loss_val: 0.9707
Epoch: 0100 loss_train: 0.9638 loss_val: 0.9706
Epoch: 0110 loss_train: 0.9630 loss_val: 0.9703
Epoch: 0120 loss_train: 0.9628 loss_val: 0.9702
Epoch: 0130 loss_train: 0.9625 loss_val: 0.9696
Epoch: 0140 loss_train: 0.9618 loss_val: 0.9694
Epoch: 0150 loss_train: 0.9617 loss_val: 0.9690
Epoch: 0160 loss_train: 0.9612 loss_val: 0.9690
Epoch: 0170 loss_train: 0.9610 loss_val: 0.9687
Epoch: 0180 loss_train: 0.9608 loss_val: 0.9693
Epoch: 0190 loss_train: 0.9613 loss_val: 0.9684
Epoch: 0200 loss_train: 0.9613 loss_val: 0.9695
Epoch: 0210 loss_train: 0.9607 loss_val:

In [5]:
# Show the imputed matrix as this cell's output.
pred

Unnamed: 0,A2M,A4GALT,AAAS,AACS,AADAT,AAGAB,AAK1,AAMDC,AAMP,AAR2,...,ZSWIM6,ZSWIM7,ZSWIM8,ZSWIM9,ZUP1,ZWILCH,ZWINT,ZYG11B,ZYX,ZZEF1
SS2,0.303880,0.000000,0.151850,0.248136,0.226868,0.221843,0.195860,0.217089,0.407839,0.236042,...,0.316320,0.239315,0.059002,0.115658,0.077305,0.252361,0.068709,0.249620,0.499381,0.139167
SS2,0.329073,0.000000,0.193669,0.264413,0.274464,0.198994,0.224000,0.268463,0.409397,0.256556,...,0.289232,0.248243,0.070246,0.079380,0.080437,0.157756,0.007012,0.262602,0.558678,0.159994
SS2,0.344039,0.000000,0.197376,0.357389,0.393534,0.300404,0.199509,0.228877,0.562303,0.269647,...,0.384918,0.388639,0.000000,0.143464,0.052437,0.319977,0.078552,0.376002,0.623488,0.240350
SS2,0.424885,0.000000,0.068215,0.218071,0.238444,0.250544,0.198645,0.237818,0.369851,0.240262,...,0.307168,0.218632,0.048718,0.092333,0.011054,0.163839,0.000000,0.269002,0.614939,0.129824
SS2,0.365599,0.000000,0.209065,0.335089,0.362439,0.253681,0.223193,0.261014,0.491484,0.237775,...,0.368551,0.347219,0.021325,0.118936,0.074922,0.232977,0.034227,0.333316,0.571686,0.211969
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
EpS4,0.000000,0.112224,0.000000,0.000000,0.000000,0.005347,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.011284,0.000000,0.000000,0.000000
SS4,0.000000,0.009245,0.000000,0.000000,0.000000,0.004600,0.000000,0.000000,0.013264,0.010146,...,0.000000,0.000000,0.004335,0.000000,0.000000,0.000000,0.000000,0.000000,0.066452,0.000000
EpS4,0.000000,0.062549,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.005466,0.182172,0.000000,0.000000,0.000000
EpS4,0.000000,0.077393,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000


In [6]:
# Persist the imputed matrix. Create the output directory first so the write
# does not fail with an OSError when 'result/' is missing.
from pathlib import Path

Path('result').mkdir(parents=True, exist_ok=True)
pred.to_csv('result/gnnimpute.csv')