In [14]:
import os
import time
import numpy as np
import pandas as pd
import pickle
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import shutil
import zipfile
import scipy


# Folder holding the per-embedding archives that get_batch reads ("<emb_name>.zip").
data_folder = "../DATA/holo4k preprocessed"

# Number of graphs merged into each batched pickle.
batch_size = 32

# Values of the "pdb_id" column of the P2Rank test list; batch_data keeps only
# split entries whose first four characters appear in this list.
p2rank_test = list(
    pd.read_csv("../DATA/p2rank_test.csv")["pdb_id"]
)

In [15]:
def get_batch(a,mode,emb_name,graphs):
    """Load ``graphs[a:a+batch_size]`` from the embedding archive and batch them.

    Parameters
    ----------
    a : int
        Index in ``graphs`` of the first graph of the batch.
    mode : str
        Split label; only used in the progress message.
    emb_name : str
        Embedding name; the archive read is ``data_folder + "/<emb_name>.zip"``.
    graphs : list of str
        Archive member names of the pickled graphs.

    Returns
    -------
    The graph returned by ``dgl.batch`` over the loaded graphs.

    Raises
    ------
    ValueError
        If ``graphs[a:a+batch_size]`` is empty (previously this surfaced as an
        ``UnboundLocalError`` on the return statement).
    """
    b=a+batch_size
    print(f"{mode} batch : {a} - {b}")
    members=graphs[a:b]
    if not members:
        raise ValueError(f"no graphs in slice {a}:{b} (len(graphs)={len(graphs)})")

    loaded=[]
    with zipfile.ZipFile(data_folder+f'/{emb_name}.zip') as thezip:
        for member in members:
            # Close each member handle as soon as it has been read.
            # NOTE: pickle.load runs arbitrary code from the file; only use on trusted archives.
            with thezip.open(member,"r") as handle:
                loaded.append(pickle.load(handle))

    # One dgl.batch call over the whole list instead of re-batching the
    # accumulated graph after every single load.
    return dgl.batch(loaded)

def batch_data(mode,emb_name):
    """Batch the graphs of one split and store the batches in a zip archive.

    Reads ``../DATA/holo4k splits/holo4k_<mode>.txt``, keeps the lines whose
    first four characters are in ``p2rank_test``, loads them ``batch_size`` at
    a time with ``get_batch`` and writes each batched graph as a BZIP2-compressed
    pickle into
    ``../DATA/holo4k batched/<emb_name>_<mode>_p2rank_batched_<batch_size>.zip``.

    Parameters
    ----------
    mode : str
        Split name used in the split-file name and in the output names.
    emb_name : str
        Embedding name used as prefix of the graph and output names.

    NOTE(review): ".p" is appended to every name read from the split file,
    while the fold-writing cell of this notebook writes names that already end
    in ".p" -- confirm the split files on disk match the archive member names.
    """
    # Set lookup instead of scanning the whole list for every line.
    wanted_ids=set(p2rank_test)
    with open(f"../DATA/holo4k splits/holo4k_{mode}.txt") as f:
        graphs=[f"{emb_name}_{x.strip()}.p" for x in f if x[:4] in wanted_ids]

    with zipfile.ZipFile(f'../DATA/holo4k batched/{emb_name}_{mode}_p2rank_batched_{batch_size}.zip',"w") as thezip:
        for a in range(0,len(graphs),batch_size):
            # load batch
            graph=get_batch(a,mode,emb_name,graphs)
            filename=f"{emb_name}_{mode}_{a}_{a+batch_size}.p"
            # Write the pickle straight into the archive: no scratch file in the
            # working directory and no unclosed file handle.
            thezip.writestr(filename,pickle.dumps(graph),compress_type=zipfile.ZIP_BZIP2)
            

#### Make Folds for CV

In [112]:
emb_name="onehot"
data_folder="../DATA/holo4k preprocessed/"

# Read the member list once; both derived lists are built from it.
with zipfile.ZipFile(data_folder+f'/{emb_name}.zip') as thezip:
    member_names=thezip.namelist()

# Member names with the "<emb_name>_" prefix removed.
all_files=[name.replace(emb_name+"_","") for name in member_names]

# Unique IDs cut out of each name: after removing emb_name, the first 8 and the
# last 2 characters are dropped.
# NOTE(review): presumably a fixed-width prefix and the ".p" suffix -- confirm
# against the archive's naming scheme.
all_unps=np.unique([name.replace(emb_name,"")[8:-2] for name in member_names])


In [98]:
from sklearn.model_selection import train_test_split,KFold

# Fixed seed so the held-out test IDs are the same on every run; without it each
# re-run would produce a different test set and overwrite the test split file.
SPLIT_SEED=0
# 10% of the unique IDs are held out as test set, the rest is used for the CV folds.
model_unps,test_unps=train_test_split(all_unps,test_size=0.1,random_state=SPLIT_SEED)

In [99]:
n_folds=5
# Fixed seed so the shuffled fold assignment is reproducible across runs.
FOLD_SEED=0
splitter=KFold(n_splits=n_folds,shuffle=True,random_state=FOLD_SEED)

# Materialise the index pairs once, then map them to the corresponding IDs.
fold_indices=list(splitter.split(model_unps))
train_folds=[model_unps[train_idx] for train_idx,_ in fold_indices]
val_folds=[model_unps[val_idx] for _,val_idx in fold_indices]

In [108]:
def _write_split(path,unp_ids):
    """Write to ``path`` every entry of ``all_files`` that contains one of ``unp_ids``.

    One entry per line, in ``all_files`` order. Matching is a substring test of
    the ID against the file name. Each file is written at most once, even if
    several IDs match it.
    """
    unp_ids=list(unp_ids)
    with open(path,"w") as f:
        for file in all_files:
            # any() stops at the first matching ID, so no duplicate lines.
            if any(unp_id in file for unp_id in unp_ids):
                f.write(file)
                f.write("\n")


for k in range(n_folds):
    _write_split(f"../DATA/holo4k splits/holo4k_fold_{k}_train.txt",train_folds[k])
    _write_split(f"../DATA/holo4k splits/holo4k_fold_{k}_val.txt",val_folds[k])

    # Sanity check: prints the IDs shared by train and val of this fold.
    print(set(train_folds[k]).intersection(set(val_folds[k])))

_write_split("../DATA/holo4k splits/holo4k_test.txt",test_unps)



set()
set()
set()
set()
set()


In [16]:
# Batch the "test" split for each selected embedding.
# Only "xlnet" is selected here; full list kept for reference: ["onehot","bert","xlnet"]
for EMB_NAME in ("xlnet",):
    print(EMB_NAME)
    batch_data(mode="test",emb_name=EMB_NAME)
    # Batching of the CV folds is disabled:
    # for k in range(5):
    #     print(f"FOLD {k} TRAIN _______________________________________")
    #     batch_data(f"fold_{k}_train",EMB_NAME)
    #     batch_data(f"fold_{k}_val",EMB_NAME)

xlnet
test batch : 0 - 32
test batch : 32 - 64
test batch : 64 - 96
test batch : 96 - 128
test batch : 128 - 160
test batch : 160 - 192
test batch : 192 - 224
test batch : 224 - 256
test batch : 256 - 288
test batch : 288 - 320
test batch : 320 - 352
test batch : 352 - 384
test batch : 384 - 416
test batch : 416 - 448
test batch : 448 - 480
test batch : 480 - 512
test batch : 512 - 544
test batch : 544 - 576
test batch : 576 - 608


In [5]:
# Rewrite every batched graph of the source archive as a graph with the same
# node "label" and "feat" data but an identity adjacency (one self-loop per
# node, no other edges), and store it under the same member name.
with zipfile.ZipFile("../DATA/holo4k batched/onehot_model_dev_no_graph_batched_32.zip","w") as thezip:
    with zipfile.ZipFile("../DATA/holo4k batched/onehot_model_dev_batched_32.zip") as source:
        for filename in source.namelist():
            print(filename)
            # NOTE: pickle.load runs arbitrary code from the file; only use on trusted archives.
            with source.open(filename,"r") as handle:
                graph=pickle.load(handle)
            n_nodes=graph.num_nodes()
            # Build the identity directly as a sparse CSR matrix; np.eye would
            # first allocate a dense n_nodes x n_nodes array.
            new_graph=dgl.from_scipy(scipy.sparse.identity(n_nodes,format="csr"))
            new_graph.ndata["label"]=graph.ndata["label"]
            new_graph.ndata["feat"]=graph.ndata["feat"]
            # Write the pickle straight into the archive: no scratch file is left
            # in the working directory if the run is interrupted.
            thezip.writestr(filename,pickle.dumps(new_graph),compress_type=zipfile.ZIP_BZIP2)

onehot_model_dev_0_32.p
onehot_model_dev_32_64.p
onehot_model_dev_64_96.p
onehot_model_dev_96_128.p
onehot_model_dev_128_160.p
onehot_model_dev_160_192.p
onehot_model_dev_192_224.p
onehot_model_dev_224_256.p
onehot_model_dev_256_288.p
onehot_model_dev_288_320.p
onehot_model_dev_320_352.p
onehot_model_dev_352_384.p
onehot_model_dev_384_416.p
onehot_model_dev_416_448.p
onehot_model_dev_448_480.p
onehot_model_dev_480_512.p
onehot_model_dev_512_544.p
onehot_model_dev_544_576.p
onehot_model_dev_576_608.p
onehot_model_dev_608_640.p
onehot_model_dev_640_672.p
onehot_model_dev_672_704.p
onehot_model_dev_704_736.p
onehot_model_dev_736_768.p
onehot_model_dev_768_800.p
onehot_model_dev_800_832.p
onehot_model_dev_832_864.p
onehot_model_dev_864_896.p
onehot_model_dev_896_928.p
onehot_model_dev_928_960.p
onehot_model_dev_960_992.p
onehot_model_dev_992_1024.p
onehot_model_dev_1024_1056.p
onehot_model_dev_1056_1088.p
onehot_model_dev_1088_1120.p
onehot_model_dev_1120_1152.p
onehot_model_dev_1152_1184.

KeyboardInterrupt: 

In [17]:
# Hard-coded confusion-matrix counts (their origin is not shown in this notebook).
TN, FP, FN, TP = 220815, 25901, 12008, 11077

# Each metric has the form hits / (hits + misses).
for label, hits, misses in (
    ("precision", TP, FP),
    ("Sensitivity", TP, FN),
    ("Negative predictive value", TN, FN),
    ("Specificity", TN, FP),
):
    print(label, hits/(hits+misses))

precision 0.2995564930499216
Sensitivity 0.47983539094650207
Negative predictive value 0.9484243395197209
Specificity 0.8950169425574345
