-
Notifications
You must be signed in to change notification settings - Fork 283
/
SyncableDataFrame.scala
90 lines (78 loc) · 3.25 KB
/
SyncableDataFrame.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
package edu.mit.csail.db.ml.modeldb.client
import java.util.Random
import edu.mit.csail.db.ml.modeldb.client.event.RandomSplitEvent
import org.apache.spark.sql.DataFrame
/**
 * This trait defines an implicit class that
 * augments a DataFrame with the randomSplitSync functions
 * that log a RandomSplitEvent to the ModelDB.
 */
trait SyncableDataFrame {
implicit class DataFrameSync(m: DataFrame) {
/**
 * Randomly split a DataFrame into pieces (see DataFrame.randomSplit).
 * This function will generate a random seed for you.
 * @param weights - The weights used for splitting.
 * @param mdbs - The ModelDB Syncer.
 * @return The pieces of the DataFrame.
 */
def randomSplitSync(weights: Array[Double])(implicit mdbs: Option[ModelDbSyncer]): Array[DataFrame] =
randomSplitSync(weights, new Random().nextLong)
/**
 * Randomly split a DataFrame into pieces (see DataFrame.randomSplit).
 * @param weights - The weights used for splitting.
 * @param seed - The seed to use for splitting.
 * @param mdbs - The ModelDB Syncer.
 * @return The pieces of the DataFrame.
 */
def randomSplitSync(weights: Array[Double], seed: Long)(implicit mdbs: Option[ModelDbSyncer]): Array[DataFrame] = {
val splits = m.randomSplit(weights, seed)
// We can think of random splitting as performing n transformations from the original DataFrame to
// n smaller DataFrames where there are no input features or output features.
// Thus, we feed this information to the FeatureTracker so that each of the splits knows that it
// originated from the same DataFrame and remembers its features.
mdbs.foreach { syncer =>
  syncer.buffer(RandomSplitEvent(m, weights, seed, splits))
  splits.foreach(df => syncer.featureTracker.copyFeatures(m, df))
}
// Propagate the source file path (if any) to every split so the lineage is preserved.
SyncableDataFramePaths.getPath(m).foreach { path =>
  splits.foreach(spl => SyncableDataFramePaths.setPath(spl, path))
}
splits
}
}
}
object SyncableDataFrame extends SyncableDataFrame {
/**
 * Convert a Spark DataFrame into a modeldb.DataFrame.
 * @param df - The Spark DataFrame.
 * @param mdbs - The syncer (used for the id mapping).
 * @return A modeldb.DataFrame representing the Spark DataFrame.
 */
def apply(df: DataFrame)(implicit mdbs: Option[ModelDbSyncer]): modeldb.DataFrame = {
// NOTE(review): this throws if no syncer is configured — a syncer is required here.
val syncer = mdbs.get
val dfId = syncer.id(df).getOrElse(-1)
val dfTag = syncer.tag(df).getOrElse("")
val knownToServer = dfId != -1
// If this dataframe already has an ID, the columns are already stored on the server,
// so we send an empty column list instead of re-serializing the schema.
val cols =
  if (knownToServer) Seq.empty[modeldb.DataFrameColumn]
  else df.schema.map(f => modeldb.DataFrameColumn(f.name, f.dataType.simpleString))
// The row count is always reported as -1: computing df.count is a full Spark job,
// which was deemed too expensive (the count was deliberately disabled upstream).
val rowCount = -1
modeldb.DataFrame(
  dfId,
  cols,
  rowCount,
  tag = dfTag,
  filepath = SyncableDataFramePaths.getPath(df)
)
}
}