Skip to content

Commit

Permalink
Merge pull request #19 from matt-gardner/fix_split_creation_bug
Browse the repository at this point in the history
Fixed a bug when creating a split from metadata with negative examples
  • Loading branch information
matt-gardner committed Jan 20, 2017
2 parents f72d25d + d2d2025 commit d54f83d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 26 deletions.
14 changes: 8 additions & 6 deletions src/main/scala/edu/cmu/ml/rtw/pra/data/SplitCreator.scala
Expand Up @@ -52,7 +52,7 @@ class SplitCreator(
createSplitFromMetadata()
}
case "add negatives to split" => {
addNegativeToSplit()
addNegativesToSplit()
}
case other => throw new IllegalStateException("Unrecognized split type!")
}
Expand Down Expand Up @@ -90,7 +90,9 @@ class SplitCreator(
val data = if (negativeExampleSelector == null) {
all_instances
} else {
addNegativeExamples(all_instances, Seq(), relation, domains.toMap, ranges.toMap, graph.nodeDict)
all_instances.merge(
selectNegativeExamples(all_instances, Seq(), relation, domains.toMap, ranges.toMap, graph.nodeDict)
)
}
outputter.info("Splitting data")
val (training, testing) = data.splitData(percentTraining)
Expand All @@ -102,7 +104,7 @@ class SplitCreator(
fileUtil.deleteFile(inProgressFile)
}

def addNegativeToSplit() {
def addNegativesToSplit() {
outputter.info(s"Creating split at $splitDir")
fileUtil.mkdirOrDie(splitDir)
fileUtil.touchFile(inProgressFile)
Expand Down Expand Up @@ -142,7 +144,7 @@ class SplitCreator(

if (fileUtil.fileExists(training_file)) {
val negative_training_instances = if (generateNegativesFor.contains("training")) {
addNegativeExamples(training_data, Seq(), relation, domains.toMap, ranges.toMap, graph.nodeDict)
selectNegativeExamples(training_data, Seq(), relation, domains.toMap, ranges.toMap, graph.nodeDict)
} else {
new Dataset[NodePairInstance](Seq())
}
Expand All @@ -162,7 +164,7 @@ class SplitCreator(
.filterNot(i => i.source == -1 || i.target == -1)
val testing_data = new Dataset[NodePairInstance](filtered_testing_instances)
val negative_testing_instances = if (generateNegativesFor.contains("testing")) {
addNegativeExamples(testing_data, training_data.instances, relation, domains.toMap, ranges.toMap, graph.nodeDict)
selectNegativeExamples(testing_data, training_data.instances, relation, domains.toMap, ranges.toMap, graph.nodeDict)
} else {
new Dataset[NodePairInstance](Seq())
}
Expand All @@ -178,7 +180,7 @@ class SplitCreator(
fileUtil.deleteFile(inProgressFile)
}

def addNegativeExamples(
def selectNegativeExamples(
data: Dataset[NodePairInstance],
other_positive_instances: Seq[NodePairInstance],
relation: String,
Expand Down
37 changes: 17 additions & 20 deletions src/test/scala/edu/cmu/ml/rtw/pra/data/SplitCreatorSpec.scala
Expand Up @@ -40,22 +40,15 @@ class SplitCreatorSpec extends FlatSpecLike with Matchers {
fakeFileUtil.addFileToBeRead("/relation_metadata/nell/domains.tsv", "rel/1\tc1\n")
fakeFileUtil.addFileToBeRead("/relation_metadata/nell/ranges.tsv", "rel/1\tc2\n")
fakeFileUtil.addFileToBeRead("/relation_metadata/nell/relations/rel_1", "node1\tnode2\n")
fakeFileUtil.addFileToBeRead("/graphs/nell/node_dict.tsv", "1\tnode1\n2\tnode2\n")
fakeFileUtil.addFileToBeRead("/graphs/nell/node_dict.tsv", "1\tnode1\n2\tnode2\n3\tnode3\n")
fakeFileUtil.addFileToBeRead("/graphs/nell/edge_dict.tsv", "1\trel/1\n")
fakeFileUtil.onlyAllowExpectedFiles()
val splitCreator = new SplitCreator(params, praBase, splitDir, outputter, fakeFileUtil)
val graph = new GraphOnDisk("/graphs/nell/", outputter, fakeFileUtil)

val positiveInstances = Seq(new NodePairInstance(1, 1, true, graph), new NodePairInstance(1, 2, true, graph))
val negativeInstances = Seq(new NodePairInstance(2, 2, false, graph), new NodePairInstance(1, 2, false, graph))
val goodData = new Dataset[NodePairInstance](positiveInstances ++ negativeInstances) {
override def splitData(percent: Double) = {
println("Splitting fake data")
val training = new Dataset[NodePairInstance](positiveInstances.take(1) ++ negativeInstances.take(1))
val testing = new Dataset[NodePairInstance](positiveInstances.drop(1) ++ negativeInstances.drop(1))
(training, testing)
}
}
val positiveInstances = Seq(new NodePairInstance(3, 3, true, graph), new NodePairInstance(3, 2, true, graph))
val negativeInstances = Seq(new NodePairInstance(1, 3, false, graph), new NodePairInstance(3, 1, false, graph))
val goodData = new Dataset[NodePairInstance](positiveInstances ++ negativeInstances)
val badData = new Dataset[NodePairInstance](Seq())

"createNegativeExampleSelector" should "return null with no input" in {
Expand All @@ -70,32 +63,32 @@ class SplitCreatorSpec extends FlatSpecLike with Matchers {
graph.numShards should be(1)
}

"addNegativeExampels" should "read domains and ranges correctly" in {
"selectNegativeExamples" should "read domains and ranges correctly" in {
val relation = "rel1"
val domains = Map(relation -> "c1")
val ranges = Map(relation -> "c2")
var creator = splitCreatorWithFakeNegativeSelector(Some(Set(1)), Some(Set(2)))
creator.addNegativeExamples(goodData, Seq(), relation, domains, ranges, graph.nodeDict) should be(goodData)
creator.selectNegativeExamples(goodData, Seq(), relation, domains, ranges, graph.nodeDict) should be(goodData)
// Adding a test with the wrong sources and targets, just to be sure the test is really
// working.
creator = splitCreatorWithFakeNegativeSelector(Some(Set(2)), Some(Set(1)))
creator.addNegativeExamples(goodData, Seq(), relation, domains, ranges, graph.nodeDict) should be(badData)
creator.selectNegativeExamples(goodData, Seq(), relation, domains, ranges, graph.nodeDict) should be(badData)
}

it should "handle null domains and ranges" in {
val creator = splitCreatorWithFakeNegativeSelector(None, None)
creator.addNegativeExamples(goodData, Seq(), "rel1", null, null, graph.nodeDict) should be(goodData)
creator.selectNegativeExamples(goodData, Seq(), "rel1", null, null, graph.nodeDict) should be(goodData)
}

it should "throw an error if the relation is missing from domain or range" in {
val creator = splitCreatorWithFakeNegativeSelector(None, None)
TestUtil.expectError(classOf[NoSuchElementException], new Function() {
def call() {
creator.addNegativeExamples(goodData, Seq(), "rel1", Map(), null, graph.nodeDict) should be(goodData)
creator.selectNegativeExamples(goodData, Seq(), "rel1", Map(), null, graph.nodeDict) should be(goodData)
}
})
TestUtil.expectError(classOf[NoSuchElementException], new Function() {
def call() {
creator.addNegativeExamples(goodData, Seq(), "rel1", null, Map(), graph.nodeDict) should be(goodData)
creator.selectNegativeExamples(goodData, Seq(), "rel1", null, Map(), graph.nodeDict) should be(goodData)
}
})
}
Expand All @@ -106,9 +99,13 @@ class SplitCreatorSpec extends FlatSpecLike with Matchers {
fakeFileUtil.addExpectedFileWritten("/splits/split_name/in_progress", "")
fakeFileUtil.addExpectedFileWritten("/splits/split_name/params.json", pretty(render(params)))
fakeFileUtil.addExpectedFileWritten("/splits/split_name/relations_to_run.tsv", "rel/1\n")
val trainingFile = "node1\tnode1\t1\nnode2\tnode2\t-1\n"
fakeFileUtil.addExpectedFileWritten("/splits/split_name/rel_1/training.tsv", trainingFile)
val testingFile = "node1\tnode2\t1\nnode1\tnode2\t-1\n"
// Because percentTraining is low, there will be no examples that actually end up in the
// training set. This is actually easier for us to test. And these nodes look funny because
// of the fake negative example selector. We have the one positive instance from the relation
// file above, plus all of the instances from `goodData`, positive and negative.
fakeFileUtil.addExpectedFileWritten("/splits/split_name/rel_1/training.tsv", "")
val testingFile = "node1\tnode2\t1\nnode3\tnode3\t1\nnode3\tnode2\t1\n" +
"node1\tnode3\t-1\nnode3\tnode1\t-1\n"
fakeFileUtil.addExpectedFileWritten("/splits/split_name/rel_1/testing.tsv", testingFile)
var creator = splitCreatorWithFakeNegativeSelector(Some(Set(1)), Some(Set(2)))
creator.createSplit()
Expand Down

0 comments on commit d54f83d

Please sign in to comment.