/
MeasureDatasetFairnessMetrics.scala
54 lines (47 loc) · 2.05 KB
/
MeasureDatasetFairnessMetrics.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package com.linkedin.lift.eval.jobs
import com.linkedin.lift.eval.{FairnessMetricsUtils, MeasureDatasetFairnessMetricsCmdLineArgs}
import com.linkedin.lift.types.Distribution
import org.apache.spark.sql.SparkSession
/**
 * A basic dataset-level fairness metrics measurement program. If your use case
 * is more involved, you can create a similar wrapper driver program that
 * prepares the data and calls the computeDatasetMetrics API.
 */
object MeasureDatasetFairnessMetrics {
  /**
   * Driver program to measure various dataset-level fairness metrics.
   *
   * Loads the dataset and the protected-attribute dataset, joins them on the
   * UID fields, optionally derives a reference distribution, and then computes
   * and writes out the fairness metrics.
   *
   * @param progArgs Command line arguments; parsed by
   *                 MeasureDatasetFairnessMetricsCmdLineArgs.parseArgs
   */
  def main(progArgs: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName(getClass.getSimpleName)
      .getOrCreate()
    try {
      val args = MeasureDatasetFairnessMetricsCmdLineArgs.parseArgs(progArgs)

      // One could choose to do their own preprocessing here
      // For example, filtering out only certain records based on some threshold
      val dfReader = spark.read.format(args.dataFormat).options(args.dataOptions)
      val df = dfReader.load(args.datasetPath)
        .select(args.uidField, args.labelField)
      val protectedDF = dfReader.load(args.protectedDatasetPath)

      // Similar preprocessing can be done with the protected attribute data
      val joinedDF = FairnessMetricsUtils.computeJoinedDF(protectedDF, df, args.uidField,
        args.protectedDatasetPath, args.uidProtectedAttributeField,
        args.protectedAttributeField)

      // Input distributions are computed using the joined data. An empty
      // referenceDistribution argument means no reference distribution is used.
      val referenceDistrOpt =
        if (args.referenceDistribution.isEmpty) {
          None
        } else {
          val distribution = Distribution.compute(joinedDF,
            Set(args.labelField, args.protectedAttributeField))
          FairnessMetricsUtils.computeReferenceDistributionOpt(
            distribution, args.referenceDistribution)
        }

      // Passing in the appropriate parameters to this API computes and writes
      // out the fairness metrics
      FairnessMetricsUtils.computeAndWriteDatasetMetrics(joinedDF,
        referenceDistrOpt, args)
    } finally {
      // Fix: the original never stopped the SparkSession. Release the
      // underlying SparkContext even if parsing or metric computation throws,
      // so the driver JVM and cluster resources are cleaned up.
      spark.stop()
    }
  }
}