In [17]:
import csv
import numpy as np
import itertools

In [18]:
# Read files of Disease/Drug node IDs and train/test edges
def _readOrientedEdges(edgesFile, allDiseases, validDiseases, validDrugs):
    """Read a tab-separated edge list and orient every edge as (disease, drug).

    An edge whose first endpoint is a known disease ID is kept as-is; any other
    edge is treated as (drug, disease) and flipped. Both endpoints are added to
    ``validDiseases`` / ``validDrugs``, which are mutated in place. Blank rows
    are skipped.

    Returns:
        set of (disease, drug) tuples.
    """
    edges = set()
    with open(edgesFile) as handle:
        for row in csv.reader(handle, delimiter="\t"):
            if not row:
                continue  # csv.reader yields [] for blank lines; don't crash on them
            node1, node2 = int(row[0]), int(row[1])
            # NOTE(review): assumes disease and drug ID spaces are disjoint -- an ID
            # present in both would always be treated as the disease endpoint.
            if node1 in allDiseases:
                disease, drug = node1, node2
            else:
                disease, drug = node2, node1
            validDiseases.add(disease)
            validDrugs.add(drug)
            edges.add((disease, drug))
    return edges


def readFiles(nodeIDsFile="disease-chemical.tsv",
              trainEdgesFile="train_edges.txt",
              deletedEdgesFile="deleted_edges.txt"):
    """Load disease IDs plus the train and deleted (held-out) edge lists.

    Args:
        nodeIDsFile: tab-separated file with one header row; column 0 holds a
            disease ID. It is used only to decide which endpoint of each edge is
            the disease.
        trainEdgesFile: tab-separated edge list (two node IDs per row, either order).
        deletedEdgesFile: tab-separated edge list of held-out edges, same format.

    Returns:
        (validDiseases, validDrugs, trainEdges, deletedEdges), where the first
        two are sets of node IDs that occur in either edge file, and the last
        two are sets of (disease, drug) tuples.
    """
    allDiseases = set()
    with open(nodeIDsFile) as handle:
        reader = csv.reader(handle, delimiter="\t")
        next(reader, None)  # skip header row
        for row in reader:
            if row:
                allDiseases.add(int(row[0]))

    # Only nodes present in the current train/test graph count as valid, since
    # some diseases/drugs in the main file may have been filtered out by k-core.
    validDiseases = set()
    validDrugs = set()
    trainEdges = _readOrientedEdges(trainEdgesFile, allDiseases, validDiseases, validDrugs)
    deletedEdges = _readOrientedEdges(deletedEdgesFile, allDiseases, validDiseases, validDrugs)

    return validDiseases, validDrugs, trainEdges, deletedEdges

In [25]:
# Generate train/test sets for classifier including both positive and negative examples
def generateTrainTestSets(seed=None):
    """Build balanced, labelled train/test example lists for link prediction.

    Positives come from ``readFiles()``: train edges go to the train set and
    deleted (held-out) edges go to the test set. Negatives are (disease, drug)
    pairs that appear in neither edge set. As many negatives as positives are
    sampled without replacement; the first ``len(trainEdges)`` go to train and
    the rest to test, so each set is 1:1 balanced.

    Prints node/edge counts and the number of sampled negatives.

    Args:
        seed: optional integer seed for negative sampling. ``None`` (default)
            draws from NumPy's global random state, as the original did.

    Returns:
        (train, test): lists of (disease, drug, label) tuples; label is 1 for an
        observed edge and 0 for a sampled non-edge.

    Raises:
        ValueError: if there are fewer candidate negatives than positives.
    """
    validDiseases, validDrugs, trainEdges, deletedEdges = readFiles()
    numTrainEdges = len(trainEdges)
    numDeletedEdges = len(deletedEdges)
    print("%d %d %d %d %d" % (len(validDiseases), len(validDrugs),
                              len(validDiseases) + len(validDrugs),
                              numTrainEdges, numDeletedEdges))

    # Candidate negatives: every disease x drug pair that is not a known edge.
    # A single filtered pass over the sorted product avoids building the full
    # product (and two set differences of it) as large intermediate sets. It
    # also makes the candidate order independent of set/hash iteration order,
    # so a fixed seed reproduces the same sample.
    positives = trainEdges | deletedEdges
    allNegExamples = [pair for pair in itertools.product(sorted(validDiseases), sorted(validDrugs))
                      if pair not in positives]

    numAllNegExamples = len(allNegExamples)
    numNeeded = numTrainEdges + numDeletedEdges
    if numNeeded > numAllNegExamples:
        raise ValueError("Only %d candidate negative pairs available but %d are needed "
                         "(one per positive edge)" % (numAllNegExamples, numNeeded))

    rng = np.random if seed is None else np.random.RandomState(seed)
    negIndices = rng.choice(numAllNegExamples, numNeeded, replace=False) if numNeeded else []
    negExamples = [allNegExamples[i] for i in negIndices]
    print(len(negExamples))

    # Each entry is (disease, drug, 1 or 0 indicating a link between the nodes).
    # Positives are emitted in sorted order so the output is deterministic.
    train = ([(disease, drug, 1) for disease, drug in sorted(trainEdges)] +
             [(disease, drug, 0) for disease, drug in negExamples[:numTrainEdges]])
    test = ([(disease, drug, 1) for disease, drug in sorted(deletedEdges)] +
            [(disease, drug, 0) for disease, drug in negExamples[numTrainEdges:]])

    return train, test

In [26]:
# Write train/test sets to output files
def _writeExamples(path, examples):
    """Write (disease, drug, label) tuples to ``path``, one tab-separated line each."""
    with open(path, 'w') as output:  # `with` closes the file; no explicit close needed
        for disease, drug, indicator in examples:
            output.write("%s\t%s\t%s\n" % (disease, drug, indicator))


def writeTrainTestOutput(trainOutputName="train.tsv", testOutputName="test.tsv"):
    """Generate the train/test example sets and write each to a TSV file.

    Prints the number of train and test examples before writing.

    Args:
        trainOutputName: output path for the train set.
        testOutputName: output path for the test set.
    """
    train, test = generateTrainTestSets()
    print("%d %d" % (len(train), len(test)))

    _writeExamples(trainOutputName, train)
    _writeExamples(testOutputName, test)

In [27]:
# Run the whole pipeline: read the node/edge files, sample negative pairs, and
# write train.tsv / test.tsv to the working directory.
# NOTE: negative sampling is not seeded here, so each re-run writes a different
# negative sample.
writeTrainTestOutput()

4766 1562 6328 463253 2000
465253
926506 4000
