In [1]:
import os
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

### Regular random split (~70% train / 20% validation / 10% test)

In [2]:
# Directory holding the full X / Y tables; the regular split files are written back here.
input_dir = "/".join(("..", "..", "result", "input_perturb_go")) + "/"

In [3]:
# Load the feature table X and the target table Y. Both are tab-separated, and
# their first column becomes the row index (the gene labels that the splits act on).
X, Y = [pd.read_csv(input_dir + table, sep="\t", index_col=0) for table in ("X", "Y")]

In [4]:
# 0.1 test size
# Stage 1: hold out 10% of the rows as the test set; X2 / Y2 keep the other 90%.
X2, X_test, Y2, Y_test = train_test_split(
    X, Y,
    test_size=0.1,
    random_state=123,
)
# Stage 2: 0.2222 of the remaining 90% (~20% of all rows) becomes validation,
# and the rest (~0.9 * 0.7778, i.e. ~70% of all rows) becomes training.
X_train, X_valid, Y_train, Y_valid = train_test_split(
    X2, Y2,
    test_size=0.2222,
    random_state=100,
)

In [5]:
# Write every split to input_dir as a tab-separated file named after the split,
# in the order X_train, X_valid, X_test, Y_train, Y_valid, Y_test.
for split_name, split_frame in (
    ("X_train", X_train),
    ("X_valid", X_valid),
    ("X_test", X_test),
    ("Y_train", Y_train),
    ("Y_valid", Y_valid),
    ("Y_test", Y_test),
):
    split_frame.to_csv(input_dir + split_name, sep="\t")

In [6]:
# Concatenate the row labels of train, then valid, then test, and write them
# one per line with no header and no index column.
genes = [*X_train.index, *X_valid.index, *X_test.index]
pd.DataFrame(genes).to_csv(
    input_dir + "genes_train_valid_test",
    header=False,
    index=False,
)

### Split stratified by membership in the DAGMA network
Repeated with 10 different random seeds; each seed's split is written to its own numbered subdirectory of `input_dir`.

In [7]:
# Input-table directory and network-result directory used by the stratified splits.
input_dir, network_dir = (
    "../../result/input_perturb_go/",
    "../../result/network_perturb_go/",
)

In [8]:
# Reload the full feature (X) and target (Y) tables so this section does not
# depend on variables left over from the regular-split cells.
X, Y = (
    pd.read_csv(input_dir + name, sep="\t", index_col=0)
    for name in ("X", "Y")
)

In [9]:
# Load the DAGMA result (no header). Columns 0 and 1 hold node IDs (presumably
# the source/target of each edge). Translate them to gene names using the
# ID -> genes table in `valid_genes`.
dag = pd.read_csv(network_dir + "DAGMA_thresholdAdaptive.tsv", sep="\t", header=None)
id2genes = pd.read_csv(network_dir + "valid_genes", sep="\t").set_index("ID")["genes"].to_dict()
dag[0] = dag[0].map(id2genes)
dag[1] = dag[1].map(id2genes)

# IDs with no entry in `valid_genes` map to NaN. Report them instead of letting
# them slip silently into dag_genes.
n_unmapped = int(dag[[0, 1]].isna().to_numpy().sum())
if n_unmapped:
    print("warning: %d DAG endpoint IDs have no entry in valid_genes" % n_unmapped)

# Unique gene names appearing in either column, NaN excluded, in first-seen
# order. This order is deterministic across runs, unlike iterating a Python set
# of strings (hash randomization).
dag_genes = pd.concat([dag[0], dag[1]]).dropna().unique().tolist()

In [10]:
def _dag_membership(frame):
    """Return one float label per row of `frame`: 1.0 if its index label is in dag_genes, else 0.0.

    Used as the `stratify` target so that every split keeps the same fraction of
    DAG genes. The labels are float 0.0/1.0, exactly like the np.zeros-based labels
    this loop originally built, so the resulting splits are unchanged.
    """
    return frame.index.isin(dag_genes).astype(float)


# The labels for the full table do not depend on the seed, so compute them once.
group_all = _dag_membership(X)

for rs in range(10):
    # One output subdirectory per seed. os.path.join avoids the doubled slash
    # that "%s/%d/" produced, because input_dir already ends with "/".
    out_dir = os.path.join(input_dir, str(rs))
    os.makedirs(out_dir, exist_ok=True)

    # 10% test set, stratified by DAG membership.
    X2, X_test, Y2, Y_test = train_test_split(
        X, Y, test_size=0.1, random_state=123 + rs, stratify=group_all)

    # 0.2222 of the remaining 90% (~20% overall) goes to validation and the rest
    # (~70% overall) to training. The remaining rows are re-stratified first.
    group = _dag_membership(X2)
    X_train, X_valid, Y_train, Y_valid = train_test_split(
        X2, Y2, test_size=0.2222, random_state=123 + rs, stratify=group)

    # Sanity check: the three parts together account for every row of X.
    assert len(X_train) + len(X_valid) + len(X_test) == len(X)

    for split_name, split_frame in (
        ("X_train", X_train), ("X_valid", X_valid), ("X_test", X_test),
        ("Y_train", Y_train), ("Y_valid", Y_valid), ("Y_test", Y_test),
    ):
        split_frame.to_csv(os.path.join(out_dir, split_name + "_stratified"), sep="\t")

    # Row labels in train, valid, test order, one per line, with no header or index.
    genes = list(X_train.index) + list(X_valid.index) + list(X_test.index)
    pd.DataFrame(genes).to_csv(
        os.path.join(out_dir, "genes_train_valid_test_stratified"), index=False, header=False)