-
Notifications
You must be signed in to change notification settings - Fork 185
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
277 additions
and
0 deletions.
There are no files selected for viewing
25 changes: 25 additions & 0 deletions
25
photon-avro-schemas/src/main/avro/ResponsePredictionAvro.avsc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
{ | ||
"type" : "record", | ||
"name" : "SimplifiedResponsePrediction", | ||
"namespace" : "com.linkedin.photon.avro.generated", | ||
"doc" : "Response prediction format truncated with the only field photon is expecting", | ||
"fields" : [ | ||
{ | ||
"name" : "response", | ||
"type" : "double" | ||
}, { | ||
"name" : "features", | ||
"type": { | ||
"type": "array", | ||
"items": "FeatureAvro" | ||
} | ||
}, { | ||
"name" : "weight", | ||
"type" : "double", | ||
"default": 1.0 | ||
}, { | ||
"name" : "offset", | ||
"type" : "double", | ||
"default": 0.0 | ||
}] | ||
} |
51 changes: 51 additions & 0 deletions
51
...client/src/integTest/scala/com/linkedin/photon/ml/data/avro/AvroDataWriterIntegTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
package com.linkedin.photon.ml.data.avro | ||
|
||
import org.apache.hadoop.fs.{FileSystem, Path} | ||
import org.testng.Assert._ | ||
import org.testng.annotations.Test | ||
|
||
import com.linkedin.photon.ml.io.FeatureShardConfiguration | ||
import com.linkedin.photon.ml.test.{SparkTestUtils, TestTemplateWithTmpDir} | ||
|
||
/** | ||
* Integeration test for AvroDataWriter | ||
*/ | ||
class AvroDataWriterIntegTest extends SparkTestUtils with TestTemplateWithTmpDir { | ||
|
||
import AvroDataWriterIntegTest._ | ||
|
||
@Test | ||
def testWrite(): Unit = sparkTest("testRead") { | ||
val dr = new AvroDataReader() | ||
val (df, indexMapLoadersMap) = dr.readMerged(inputPath.toString, featureShardConfigurationsMap, numPartitions) | ||
val outputDir = new Path(getTmpDir) | ||
|
||
assertTrue(df.columns.contains(featureColumn)) | ||
assertTrue(df.columns.contains(responseColumn)) | ||
assertEquals(df.count, 34810) | ||
assertTrue(indexMapLoadersMap.contains(featureColumn)) | ||
|
||
val indexMapLoader = indexMapLoadersMap(featureColumn) | ||
val writer = new AvroDataWriter | ||
writer.write(df, outputDir.toString, indexMapLoader, responseColumn, featureColumn, overwrite = true) | ||
|
||
val fs = FileSystem.get(sc.hadoopConfiguration) | ||
val files = fs.listStatus(outputDir).filter(_.getPath.getName.startsWith("part")) | ||
assertEquals(files.length, numPartitions) | ||
|
||
val (writeData, _) = dr.read(outputDir.toString, numPartitions) | ||
assertTrue(writeData.columns.contains(responseColumn)) | ||
assertTrue(writeData.columns.contains(featureColumn)) | ||
assertEquals(writeData.count(), 34810) | ||
} | ||
} | ||
|
||
object AvroDataWriterIntegTest { | ||
private val inputDir = getClass.getClassLoader.getResource("GameIntegTest/input").getPath | ||
private val inputPath = new Path(inputDir, "train") | ||
private val numPartitions = 4 | ||
private val featureColumn = "features" | ||
private val responseColumn = "response" | ||
private val featureShardConfigurationsMap = Map( | ||
featureColumn -> FeatureShardConfiguration(Set("userFeatures", "songFeatures"), hasIntercept = false)) | ||
} |
128 changes: 128 additions & 0 deletions
128
photon-client/src/main/scala/com/linkedin/photon/ml/data/avro/AvroDataWriter.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
package com.linkedin.photon.ml.data.avro | ||
|
||
import scala.collection.mutable.ListBuffer | ||
import scala.collection.JavaConversions._ | ||
|
||
import org.apache.hadoop.fs.{FileSystem, Path} | ||
import org.apache.spark.ml.linalg.Vector | ||
import org.apache.spark.sql.{DataFrame, Row, SparkSession} | ||
|
||
import com.linkedin.photon.avro.generated.{FeatureAvro, SimplifiedResponsePrediction} | ||
import com.linkedin.photon.ml.Constants.DELIMITER | ||
import com.linkedin.photon.ml.index.{IndexMap, IndexMapLoader} | ||
|
||
/** | ||
* Write dataframe to Avro files on HDFS in [[SimplifiedResponsePrediction]] format | ||
*/ | ||
class AvroDataWriter { | ||
|
||
import AvroDataWriter._ | ||
|
||
private val sparkSession = SparkSession.builder.getOrCreate() | ||
private val sc = sparkSession.sparkContext | ||
|
||
/** | ||
* Write the DataFrame into avro records using the given indexMapLoader | ||
* | ||
* @param df The DataFrame | ||
* @param outputPath The output path to store the avro files | ||
* @param indexMapLoader The IndexMapLoader store feature to index information | ||
* @param responseColumn The response column name in df | ||
* @param featureColumn The feature column name in df | ||
*/ | ||
def write( | ||
df: DataFrame, | ||
outputPath: String, | ||
indexMapLoader: IndexMapLoader, | ||
responseColumn: String, | ||
featureColumn: String, | ||
overwrite: Boolean = false): Unit = { | ||
|
||
// TODO: Save other fields in the dataset, i.e. feature columns | ||
val columns = df.columns | ||
require(columns.contains(responseColumn), s"There must be a $responseColumn column present in dataframe") | ||
require(columns.contains(featureColumn), s"There must be a $featureColumn column present in dataframe") | ||
|
||
val hasOffset = columns.contains("offset") | ||
val hasWeight = columns.contains("weight") | ||
|
||
val avroDataset = df.rdd.mapPartitions { rows => | ||
val indexMap = indexMapLoader.indexMapForRDD() | ||
val rowBuilder = SimplifiedResponsePrediction.newBuilder() | ||
|
||
rows.map { r: Row => | ||
val features = r.getAs[Vector](featureColumn) | ||
val response = getValueAsDouble(r, responseColumn) | ||
val offset = if (hasOffset) getValueAsDouble(r, "offset") else 0.0D | ||
val weight = if (hasWeight) getValueAsDouble(r, "weight") else 1.0D | ||
rowBuilder | ||
.setResponse(response) | ||
.setOffset(offset) | ||
.setWeight(weight) | ||
.setFeatures(buildAvroFeatures(features, indexMap)) | ||
.build() | ||
} | ||
} | ||
|
||
// Write the converted dataset back to HDFS | ||
if (overwrite) { | ||
val fs = FileSystem.get(sc.hadoopConfiguration) | ||
val output = new Path(outputPath) | ||
if (fs.exists(output)) { | ||
fs.delete(output, true) | ||
} | ||
} | ||
|
||
AvroUtils.saveAsAvro[SimplifiedResponsePrediction]( | ||
avroDataset, | ||
outputPath, | ||
SimplifiedResponsePrediction.getClassSchema.toString) | ||
} | ||
} | ||
|
||
object AvroDataWriter { | ||
/** | ||
* Helper function to convert Row index field to double | ||
* | ||
* @param row A training record in [[Row]] format | ||
* @param fieldName The index of particular field | ||
* @return A double in this field | ||
*/ | ||
protected[data] def getValueAsDouble(row: Row, fieldName: String): Double = { | ||
|
||
row.getAs[Any](fieldName) match { | ||
case n: Number => n.doubleValue() | ||
case b: Boolean => if (b) 1.0D else 0.0D | ||
case _ => | ||
throw new IllegalArgumentException(s"Unsupported data type") | ||
} | ||
} | ||
|
||
/** | ||
* Build a list of Avro Feature instances for the given list [[Vector]] and [[IndexMap]] | ||
* | ||
* @param vector The extracted feature in [[Vector]] for a particular training instance | ||
* @param indexMap The reverse index map from feature to index | ||
* @return A list of Avro Feature instances built from the vector | ||
*/ | ||
protected[data] def buildAvroFeatures(vector: Vector, indexMap: IndexMap): java.util.List[FeatureAvro] = { | ||
|
||
val builder = FeatureAvro.newBuilder() | ||
val avroFeatures = new ListBuffer[FeatureAvro] | ||
vector.foreachActive { | ||
case (vectorIdx, vectorValue) => | ||
val feature = indexMap.getFeatureName(vectorIdx).get | ||
feature.split(DELIMITER) match { | ||
case Array(name, term) => | ||
builder.setName(name).setTerm(term) | ||
case Array(name) => | ||
builder.setName(name).setTerm("") | ||
case _ => | ||
throw new IllegalArgumentException(s"Error parsing the name and term for this feature $feature") | ||
} | ||
builder.setValue(vectorValue) | ||
avroFeatures += builder.build() | ||
} | ||
avroFeatures.toList | ||
} | ||
} |
73 changes: 73 additions & 0 deletions
73
photon-client/src/test/scala/com/linkedin/photon/ml/data/avro/AvroDataWriterTest.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
package com.linkedin.photon.ml.data.avro | ||
|
||
import org.apache.spark.sql.Row | ||
import org.apache.spark.sql.types.DataTypes._ | ||
import org.testng.Assert._ | ||
import org.apache.spark.ml.linalg.Vectors | ||
import org.apache.spark.mllib.linalg.VectorUDT | ||
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema | ||
import org.apache.spark.sql.types.{StructField, StructType} | ||
import org.testng.annotations.{DataProvider, Test} | ||
|
||
import com.linkedin.photon.ml.Constants.DELIMITER | ||
import com.linkedin.photon.ml.index.{DefaultIndexMap, IndexMap} | ||
|
||
class AvroDataWriterTest { | ||
|
||
@DataProvider | ||
def rowsProvider(): Array[Array[GenericRowWithSchema]] = { | ||
|
||
val vector = Vectors.sparse(3, Array(0, 2), Array(0.0, 1.0)) | ||
val arrays = Array( | ||
Array(1, 0, 1, vector), | ||
Array(true, false, true, vector), | ||
Array(1.0f, 0.0f, 1.0f, vector), | ||
Array(1L, 0L, 1L, vector), | ||
Array(1.0D, 0.0D, 1.0D, vector) | ||
) | ||
val types = Array(IntegerType, BooleanType, FloatType, LongType, DoubleType) | ||
|
||
arrays.zip(types).map { case (a, t) => | ||
val schema = new StructType( | ||
Array( | ||
StructField("response", t), | ||
StructField("offset", t), | ||
StructField("weight", t), | ||
StructField("features", new VectorUDT))) | ||
Array(new GenericRowWithSchema(a, schema)) | ||
} | ||
} | ||
|
||
@Test(dataProvider = "rowsProvider") | ||
def testGetValueAsDouble(row: Row): Unit = { | ||
|
||
val label = AvroDataWriter.getValueAsDouble(row, "response") | ||
assertEquals(label, 1.0D) | ||
val offset = AvroDataWriter.getValueAsDouble(row, "offset") | ||
assertEquals(offset, 0.0D) | ||
val weight = AvroDataWriter.getValueAsDouble(row, "weight") | ||
assertEquals(weight, 1.0D) | ||
} | ||
|
||
@Test | ||
def testBuildAvroFeatures(): Unit = { | ||
|
||
val vector = Vectors.sparse(3, Array(0, 1, 2), Array(1.0, 2.0, 3.0)) | ||
val indexMap: IndexMap = new DefaultIndexMap( | ||
featureNameToIdMap = Map( | ||
s"name0${DELIMITER}term0" -> 0, | ||
s"name1$DELIMITER" -> 1, | ||
s"${DELIMITER}term2" -> 2)) | ||
val results = AvroDataWriter.buildAvroFeatures(vector, indexMap) | ||
assertEquals(results.size(), 3) | ||
assertEquals(results.get(0).getName, "name0") | ||
assertEquals(results.get(0).getTerm, "term0") | ||
assertEquals(results.get(0).getValue, 1.0) | ||
assertEquals(results.get(1).getName, "name1") | ||
assertEquals(results.get(1).getTerm, "") | ||
assertEquals(results.get(1).getValue, 2.0) | ||
assertEquals(results.get(2).getName, "") | ||
assertEquals(results.get(2).getTerm, "term2") | ||
assertEquals(results.get(2).getValue, 3.0) | ||
} | ||
} |