Permalink
Browse files

Move coGroup to its own builder.

  • Loading branch information...
1 parent cbf393b commit 722adfae65b35708aa051bd976a3fe6c3a2d75c0 @azymnis committed Apr 5, 2012
@@ -17,7 +17,6 @@ package com.twitter.scalding
import cascading.pipe._
import cascading.pipe.assembly._
-import cascading.pipe.joiner._
import cascading.operation._
import cascading.operation.aggregator._
import cascading.operation.filter._
@@ -46,11 +45,9 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
/**
* This is the description of this Grouping in terms of a sequence of Every operations
*/
- private var evs : List[Pipe => Every] = Nil
- private var isReversed : Boolean = false
- private var sortBy : Option[Fields] = None
- private var coGroups : List[(Fields, Pipe)] = Nil
- private var joiner : Option[Joiner] = None
+ protected var evs : List[Pipe => Every] = Nil
+ protected var isReversed : Boolean = false
+ protected var sortBy : Option[Fields] = None
/*
* maxMF is the maximum index of a "middle field" allocated for mapReduceMap operations
*/
@@ -65,17 +62,6 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
//Put any pure reduce functions into the below object
import CommonReduceFunctions._
- // Joins (cogroups) with pipe p on fields f.
- // Make sure that pipe p is smaller than the left side pipe, otherwise this
- // might take a while.
- def coGroup(f : Fields, p : Pipe) = {
- coGroups ::= (f, RichPipe.assignName(p))
- // aggregateBy replaces the grouping operation
- // but we actually want to do a coGroup
- reds = None
- this
- }
-
private def tryAggregateBy(ab : AggregateBy, ev : Pipe => Every) : Boolean = {
// Concat if there if not none
reds = reds.map(rl => ab::rl)
@@ -98,7 +84,7 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
this
}
- private def overrideReducers(p : Pipe) : Pipe = {
+ protected def overrideReducers(p : Pipe) : Pipe = {
numReducers.map { r => RichPipe.setReducers(p, r) }.getOrElse(p)
}
@@ -424,9 +410,6 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
}
def groupMode : GroupMode = {
- if(!coGroups.isEmpty) {
- return CoGroupMode
- }
return reds match {
case None => GroupByMode
case Some(Nil) => IdentityMode
@@ -463,31 +446,9 @@ class GroupBuilder(val groupFields : Fields) extends FieldConversions
overrideReducers(ag.getGroupBy())
ag
}
- case CoGroupMode => {
- assert(!sortBy.isDefined, "cannot use a sortBy when doing a coGroup")
- // overrideReducers(pipe)
- val fields = (groupFields :: coGroups.map{ _._1 }).toArray
- val pipes = (pipe :: coGroups.map{ _._2 }).toArray
- val cg : Pipe = new CoGroup(pipes, fields, null, joiner.getOrElse(new InnerJoin))
- overrideReducers(cg)
- evs.foldRight(cg)( (op : Pipe => Every, p) => op(p) )
- }
}
}
- def joiner(j : Joiner) : GroupBuilder = {
- this.joiner = this.joiner match {
- case None => Some(j)
- case Some(otherJ) => if ( otherJ != j ) {
- throw new IllegalArgumentException("trying to set joiner to: " +
- j + " while already set to: " + otherJ)
- } else {
- Some(otherJ)
- }
- }
- this
- }
-
//This invalidates aggregateBy!
def sortBy(f : Fields) : GroupBuilder = {
reds = None
@@ -603,8 +564,7 @@ object CommonReduceFunctions extends java.io.Serializable {
}
}
-sealed abstract class GroupMode
-case object AggregateByMode extends GroupMode
-case object GroupByMode extends GroupMode
-case object CoGroupMode extends GroupMode
-case object IdentityMode extends GroupMode
+sealed private[scalding] abstract class GroupMode
+private[scalding] case object AggregateByMode extends GroupMode
+private[scalding] case object GroupByMode extends GroupMode
+private[scalding] case object IdentityMode extends GroupMode
@@ -19,7 +19,7 @@ import cascading.tap._
import cascading.scheme._
import cascading.pipe._
import cascading.pipe.assembly._
-import cascading.pipe.joiner._
+import cascading.pipe.joiner.{InnerJoin => CInnerJoin, LeftJoin => CLeftJoin}
import cascading.flow._
import cascading.operation._
import cascading.operation.aggregator._
@@ -38,6 +38,10 @@ trait JoinAlgorithms {
def pipe : Pipe
+ def coGroupBy(f : Fields, j : JoinMode = InnerJoinMode)(builder : CoGroupBuilder => GroupBuilder) : Pipe = {
+ builder(new CoGroupBuilder(f, j)).schedule(pipe.getName, pipe)
+ }
+
/*
* WARNING! doing a cross product with even a moderate sized pipe can
* create ENORMOUS output. The use-case here is attaching a constant (e.g.
@@ -92,26 +96,25 @@ trait JoinAlgorithms {
* joinWithSmaller(('other1, 'other2)->('this1, 'this2), pipe, new FancyJoin)
* }
*/
- def joinWithSmaller(fs :(Fields,Fields), that : Pipe, joiner : Joiner = new InnerJoin, reducers : Int = -1) = {
+ def joinWithSmaller(fs :(Fields,Fields), that : Pipe, joiners : (JoinMode, JoinMode) = (InnerJoinMode, InnerJoinMode), reducers : Int = -1) = {
// If we are not doing an inner join, the join fields must be disjoint:
val intersection = asSet(fs._1).intersect(asSet(fs._2))
if (intersection.size == 0) {
// Common case: no intersection in names: just CoGroup, which duplicates the grouping fields:
- assignName(pipe).groupBy(fs._1) {
- _.coGroup(fs._2, that)
- .joiner(joiner)
+ assignName(pipe).coGroupBy(fs._1, joiners._1) {
+ _.coGroup(fs._2, that, joiners._2)
.reducers(reducers)
}
}
- else if (joiner.isInstanceOf[InnerJoin]) {
+ else if (joiners._1 == InnerJoinMode && joiners._2 == InnerJoinMode) {
/*
* Since it is an inner join, we only output if the key is present an equal in both sides.
* For this (common) case, it doesn't matter if we drop one of the matching grouping fields.
* So, we rename the right hand side to temporary names, then discard them after the operation
*/
val (renamedThat, newJoinFields, temp) = renameCollidingFields(that, fs._2, intersection)
- assignName(pipe).groupBy(fs._1) {
- _.coGroup(newJoinFields, renamedThat)
+ assignName(pipe).coGroupBy(fs._1, joiners._1) {
+ _.coGroup(newJoinFields, renamedThat, joiners._2)
.reducers(reducers)
}.discard(temp)
}
@@ -121,17 +124,17 @@ trait JoinAlgorithms {
}
}
- def joinWithLarger(fs : (Fields, Fields), that : Pipe, joiner : Joiner = new InnerJoin, reducers : Int = -1) = {
- that.joinWithSmaller((fs._2, fs._1), pipe, joiner, reducers)
+ def joinWithLarger(fs : (Fields, Fields), that : Pipe, joiners : (JoinMode, JoinMode) = (InnerJoinMode,InnerJoinMode), reducers : Int = -1) = {
+ that.joinWithSmaller((fs._2, fs._1), pipe, joiners, reducers)
}
def leftJoinWithSmaller(fs :(Fields,Fields), that : Pipe, reducers : Int = -1) = {
- joinWithSmaller(fs, that, new LeftJoin, reducers)
+ joinWithSmaller(fs, that, (InnerJoinMode, OuterJoinMode), reducers)
}
def leftJoinWithLarger(fs :(Fields,Fields), that : Pipe, reducers : Int = -1) = {
//We swap the order, and turn left into right:
- that.joinWithSmaller((fs._2, fs._1), pipe, new RightJoin, reducers)
+ that.joinWithSmaller((fs._2, fs._1), pipe, (OuterJoinMode, InnerJoinMode), reducers)
}
/**
@@ -148,18 +151,18 @@ trait JoinAlgorithms {
def joinWithTiny(fs :(Fields,Fields), that : Pipe) = {
val intersection = asSet(fs._1).intersect(asSet(fs._2))
if (intersection.size == 0) {
- new Join(assignName(pipe), fs._1, assignName(that), fs._2, new InnerJoin)
+ new Join(assignName(pipe), fs._1, assignName(that), fs._2, new CInnerJoin)
}
else {
val (renamedThat, newJoinFields, temp) = renameCollidingFields(that, fs._2, intersection)
- (new Join(assignName(pipe), fs._1, assignName(renamedThat), newJoinFields, new InnerJoin))
+ (new Join(assignName(pipe), fs._1, assignName(renamedThat), newJoinFields, new CInnerJoin))
.discard(temp)
}
}
def leftJoinWithTiny(fs :(Fields,Fields), that : Pipe) = {
//Rename these pipes to avoid cascading name conflicts
- new Join(assignName(pipe), fs._1, assignName(that), fs._2, new LeftJoin)
+ new Join(assignName(pipe), fs._1, assignName(that), fs._2, new CLeftJoin)
}
/*
@@ -190,11 +193,11 @@ trait JoinAlgorithms {
*/
def blockJoinWithSmaller(fs : (Fields, Fields),
otherPipe : Pipe, rightReplication : Int = 1, leftReplication : Int = 1,
- joiner : Joiner = new InnerJoin, reducers : Int = -1) : Pipe = {
+ joiners : (JoinMode, JoinMode) = (InnerJoinMode, InnerJoinMode), reducers : Int = -1) : Pipe = {
assert(rightReplication > 0, "Must specify a positive number for the right replication in block join")
assert(leftReplication > 0, "Must specify a positive number for the left replication in block join")
- assertValidJoinMode(joiner, leftReplication, rightReplication)
+ assertValidJoinMode(joiners, leftReplication, rightReplication)
// These are the new dummy fields used in the skew join
val leftFields = new Fields("__LEFT_I__", "__LEFT_J__")
@@ -208,7 +211,7 @@ trait JoinAlgorithms {
val rightJoinFields = Fields.join(fs._2, rightFields)
newLeft
- .joinWithSmaller((leftJoinFields, rightJoinFields), newRight, joiner, reducers)
+ .joinWithSmaller((leftJoinFields, rightJoinFields), newRight, joiners, reducers)
.discard(leftFields)
.discard(rightFields)
}
@@ -225,11 +228,11 @@ trait JoinAlgorithms {
}
}
- private def assertValidJoinMode(joiner : Joiner, left : Int, right : Int) {
- (joiner, left, right) match {
- case (i : InnerJoin, _, _) => true
- case (k : LeftJoin, 1, _) => true
- case (m : RightJoin, _, 1) => true
+ private def assertValidJoinMode(joiners : (JoinMode, JoinMode), left : Int, right : Int) {
+ (joiners, left, right) match {
+ case ((InnerJoinMode, InnerJoinMode), _, _) => true
+ case ((InnerJoinMode, OuterJoinMode), 1, _) => true
+ case ((OuterJoinMode, InnerJoinMode), _, 1) => true
case (j, l, r) =>
throw new InvalidJoinModeException(
"you cannot use joiner " + j + " with left replication " + l + " and right replication " + r
@@ -12,16 +12,16 @@ class InnerProductJob(args : Args) extends Job(args) {
val l = args.getOrElse("left", "1").toInt
val r = args.getOrElse("right", "1").toInt
val j = args.getOrElse("joiner", "i") match {
- case "i" => new InnerJoin
- case "l" => new LeftJoin
- case "r" => new RightJoin
- case "o" => new OuterJoin
+ case "i" => (InnerJoinMode, InnerJoinMode)
+ case "l" => (InnerJoinMode, OuterJoinMode)
+ case "r" => (OuterJoinMode, InnerJoinMode)
+ case "o" => (OuterJoinMode, OuterJoinMode)
}
val in0 = Tsv("input0").read.mapTo((0,1,2) -> ('x1, 'y1, 's1)) { input : (Int, Int, Int) => input }
val in1 = Tsv("input1").read.mapTo((0,1,2) -> ('x2, 'y2, 's2)) { input : (Int, Int, Int) => input }
in0
- .blockJoinWithSmaller('y1 -> 'y2, in1, leftReplication = l, rightReplication = r, joiner = j)
+ .blockJoinWithSmaller('y1 -> 'y2, in1, leftReplication = l, rightReplication = r, joiners = j)
.map(('s1, 's2) -> 'score) { v : (Int, Int) =>
v._1 * v._2
}
@@ -9,11 +9,10 @@ class StarJoinJob(args : Args) extends Job(args) {
val in2 = Tsv("input2").read.mapTo((0,1) -> ('x2, 'c)) { input : (Int, Int) => input }
val in3 = Tsv("input3").read.mapTo((0,1) -> ('x3, 'd)) { input : (Int, Int) => input }
- in0.groupBy('x0) {
- _.coGroup('x1, in1)
- .coGroup('x2, in2)
- .coGroup('x3, in3)
- .joiner(new MixedJoin(Array(true, false, false, false)))
+ in0.coGroupBy('x0) {
+ _.coGroup('x1, in1, OuterJoinMode)
+ .coGroup('x2, in2, OuterJoinMode)
+ .coGroup('x3, in3, OuterJoinMode)
}
.project('x0, 'a, 'b, 'c, 'd)
.write(Tsv("output"))

0 comments on commit 722adfa

Please sign in to comment.