Add coGroup as a method of GroupBuilder.

commit cbf393b18639b16003f6555d8b4b411bef880499 (1 parent: 76e1e49)
@azymnis authored
src/main/scala/com/twitter/scalding/GroupBuilder.scala
@@ -15,10 +15,9 @@ limitations under the License.
*/
package com.twitter.scalding
-import cascading.pipe.Pipe
-import cascading.pipe.Every
-import cascading.pipe.GroupBy
+import cascading.pipe._
import cascading.pipe.assembly._
+import cascading.pipe.joiner._
import cascading.operation._
import cascading.operation.aggregator._
import cascading.operation.filter._
@@ -29,6 +28,8 @@ import scala.collection.JavaConverters._
import scala.annotation.tailrec
import scala.math.Ordering
+import java.lang.IllegalArgumentException
+
// This controls the sequence of reductions that happen inside a
// particular grouping operation. Not all elements can be combined,
// for instance, a scanLeft/foldLeft generally requires a sorting
@@ -48,6 +49,8 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
private var evs : List[Pipe => Every] = Nil
private var isReversed : Boolean = false
private var sortBy : Option[Fields] = None
+ private var coGroups : List[(Fields, Pipe)] = Nil
+ private var joiner : Option[Joiner] = None
/*
* maxMF is the maximum index of a "middle field" allocated for mapReduceMap operations
*/
@@ -62,6 +65,17 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
//Put any pure reduce functions into the below object
import CommonReduceFunctions._
+ // Joins (cogroups) with pipe p on fields f.
+ // Make sure that pipe p is smaller than the left side pipe, otherwise this
+ // might take a while.
+ def coGroup(f : Fields, p : Pipe) = {
+ coGroups ::= (f, RichPipe.assignName(p))
+ // aggregateBy replaces the grouping operation
+ // but we actually want to do a coGroup
+ reds = None
+ this
+ }
+
private def tryAggregateBy(ab : AggregateBy, ev : Pipe => Every) : Boolean = {
// Concat if there if not none
reds = reds.map(rl => ab::rl)
@@ -409,14 +423,25 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
every(pipe => new Every(pipe, inFields, b))
}
- def schedule(name : String, allpipes : Pipe*) : Pipe = {
- val mpipes : Array[Pipe] = allpipes.toArray
- reds match {
- case None => {
- //We cannot aggregate, so group:
+ def groupMode : GroupMode = {
+ if(!coGroups.isEmpty) {
+ return CoGroupMode
+ }
+ return reds match {
+ case None => GroupByMode
+ case Some(Nil) => IdentityMode
+ case Some(redList) => AggregateByMode
+ }
+ }
+
+ def schedule(name : String, pipe : Pipe) : Pipe = {
+
+ groupMode match {
+ //In this case we cannot aggregate, so group:
+ case GroupByMode => {
val startPipe : Pipe = sortBy match {
- case None => new GroupBy(name, mpipes, groupFields)
- case Some(sf) => new GroupBy(name, mpipes, groupFields, sf, isReversed)
+ case None => new GroupBy(name, pipe, groupFields)
+ case Some(sf) => new GroupBy(name, pipe, groupFields, sf, isReversed)
}
overrideReducers(startPipe)
@@ -424,20 +449,43 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
evs.foldRight(startPipe)( (op : Pipe => Every, p) => op(p) )
}
//This is the case where the group function is identity: { g => g }
- case Some(Nil) => {
- val gb = new GroupBy(name, mpipes, groupFields)
+ case IdentityMode => {
+ val gb = new GroupBy(name, pipe, groupFields)
overrideReducers(gb)
gb
}
//There is some non-empty AggregateBy to do:
- case Some(redlist) => {
+ case AggregateByMode => {
+ val redlist = reds.get
val THRESHOLD = 100000 //tune this, default is 10k
- val ag = new AggregateBy(name, mpipes, groupFields,
+ val ag = new AggregateBy(name, pipe, groupFields,
THRESHOLD, redlist.reverse.toArray : _*)
overrideReducers(ag.getGroupBy())
ag
}
+ case CoGroupMode => {
+ assert(!sortBy.isDefined, "cannot use a sortBy when doing a coGroup")
+ // overrideReducers(pipe)
+ val fields = (groupFields :: coGroups.map{ _._1 }).toArray
+ val pipes = (pipe :: coGroups.map{ _._2 }).toArray
+ val cg : Pipe = new CoGroup(pipes, fields, null, joiner.getOrElse(new InnerJoin))
+ overrideReducers(cg)
+ evs.foldRight(cg)( (op : Pipe => Every, p) => op(p) )
+ }
+ }
+ }
+
+ def joiner(j : Joiner) : GroupBuilder = {
+ this.joiner = this.joiner match {
+ case None => Some(j)
+ case Some(otherJ) => if ( otherJ != j ) {
+ throw new IllegalArgumentException("trying to set joiner to: " +
+ j + " while already set to: " + otherJ)
+ } else {
+ Some(otherJ)
+ }
}
+ this
}
//This invalidates aggregateBy!
@@ -554,3 +602,9 @@ object CommonReduceFunctions extends java.io.Serializable {
mergeSortR(Nil, v1, v2, k).reverse
}
}
+
+sealed abstract class GroupMode
+case object AggregateByMode extends GroupMode
+case object GroupByMode extends GroupMode
+case object CoGroupMode extends GroupMode
+case object IdentityMode extends GroupMode
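
As a usage sketch (not part of this commit; the pipes `users` and `clicks` and their field names are hypothetical), the new methods compose inside a groupBy block: coGroup switches the grouping into CoGroupMode and clears any pending AggregateBy, and joiner selects the Cascading Joiner, falling back to InnerJoin when unset:

import cascading.pipe.joiner.OuterJoin

// Hypothetical pipes: `users` keyed by 'userId, `clicks` keyed by 'clickUserId.
// Field names on the two sides are assumed distinct; per the comment above,
// the coGrouped pipe should be the smaller one.
val joined = users.groupBy('userId) {
  _.coGroup('clickUserId, clicks)
   .joiner(new OuterJoin)   // optional; defaults to InnerJoin
   .reducers(10)            // still honored through overrideReducers
}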
src/main/scala/com/twitter/scalding/JoinAlgorithms.scala
@@ -97,7 +97,11 @@ trait JoinAlgorithms {
val intersection = asSet(fs._1).intersect(asSet(fs._2))
if (intersection.size == 0) {
// Common case: no intersection in names: just CoGroup, which duplicates the grouping fields:
- setReducers(new CoGroup(assignName(pipe), fs._1, assignName(that), fs._2, joiner), reducers)
+ assignName(pipe).groupBy(fs._1) {
+ _.coGroup(fs._2, that)
+ .joiner(joiner)
+ .reducers(reducers)
+ }
}
else if (joiner.isInstanceOf[InnerJoin]) {
/*
@@ -106,9 +110,10 @@ trait JoinAlgorithms {
* So, we rename the right hand side to temporary names, then discard them after the operation
*/
val (renamedThat, newJoinFields, temp) = renameCollidingFields(that, fs._2, intersection)
- setReducers(new CoGroup(assignName(pipe), fs._1,
- assignName(renamedThat), newJoinFields, joiner), reducers)
- .discard(temp)
+ assignName(pipe).groupBy(fs._1) {
+ _.coGroup(newJoinFields, renamedThat)
+ .reducers(reducers)
+ }.discard(temp)
}
else {
throw new IllegalArgumentException("join keys must be disjoint unless you are doing an InnerJoin. Found: " +
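
Caller-side this change is invisible: a join such as the sketch below (hypothetical pipes `big` and `small`, joining on disjoint field names) is now scheduled through GroupBuilder's CoGroupMode rather than through a CoGroup built directly in JoinAlgorithms.

// For the common no-collision case, roughly rewritten internally as
//   assignName(big).groupBy('x) { _.coGroup('y, small).joiner(joiner).reducers(reducers) }
val result = big.joinWithSmaller('x -> 'y, small)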
src/test/scala/com/twitter/scalding/CoGroupTest.scala
@@ -0,0 +1,40 @@
+package com.twitter.scalding
+
+import cascading.pipe.joiner._
+import org.specs._
+
+class StarJoinJob(args : Args) extends Job(args) {
+ val in0 = Tsv("input0").read.mapTo((0,1) -> ('x0, 'a)) { input : (Int, Int) => input }
+ val in1 = Tsv("input1").read.mapTo((0,1) -> ('x1, 'b)) { input : (Int, Int) => input }
+ val in2 = Tsv("input2").read.mapTo((0,1) -> ('x2, 'c)) { input : (Int, Int) => input }
+ val in3 = Tsv("input3").read.mapTo((0,1) -> ('x3, 'd)) { input : (Int, Int) => input }
+
+ in0.groupBy('x0) {
+ _.coGroup('x1, in1)
+ .coGroup('x2, in2)
+ .coGroup('x3, in3)
+ .joiner(new MixedJoin(Array(true, false, false, false)))
+ }
+ .project('x0, 'a, 'b, 'c, 'd)
+ .write(Tsv("output"))
+}
+
+class CoGroupTest extends Specification with TupleConversions {
+ noDetailedDiffs()
+ "A StarJoinJob" should {
+ JobTest("com.twitter.scalding.StarJoinJob")
+ .source(Tsv("input0"), List((0, 1), (1, 1), (2, 1), (3, 2)))
+ .source(Tsv("input1"), List((0, 1), (2, 5), (3, 2)))
+ .source(Tsv("input2"), List((1, 1), (2, 8)))
+ .source(Tsv("input3"), List((0, 9), (2, 11)))
+ .sink[(Int, Int, Int, Int, Int)](Tsv("output")) { outputBuf =>
+ "be able to work" in {
+ val out = outputBuf.toSet
+ val expected = Set((0,1,1,0,9), (1,1,0,1,0), (2,1,5,8,11), (3,2,2,0,0))
+ out must_== expected
+ }
+ }
+ .run
+ .finish
+ }
+}
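
One way to read the expectation (an interpretive note, assuming Cascading's MixedJoin treats a `true` entry as an inner side and `false` as an outer side): input0 is the only required pipe, so every 'x0 key survives, and keys absent from input1, input2, or input3 come through as nulls that the Int sink coerces to 0. Key 3, for instance, appears only in input0 (value 2) and input1 (value 2), yielding the expected row (3, 2, 2, 0, 0).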