Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

[HOPSWORKS-1982] Deequ statistics for Feature Groups/Training Datasets #96

Merged
merged 10 commits into from
Sep 25, 2020
7 changes: 7 additions & 0 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
<lombok.version>1.18.6</lombok.version>
<fasterxml.jackson.databind.version>2.6.7.1</fasterxml.jackson.databind.version>
<spark.version>2.4.3.2</spark.version>
<deequ.version>1.1.0-SNAPSHOT</deequ.version>
</properties>

<dependencies>
Expand All @@ -29,6 +30,12 @@
<version>${lombok.version}</version>
</dependency>

<dependency>
<groupId>com.logicalclocks</groupId>
<artifactId>deequ</artifactId>
<version>${deequ.version}</version>
</dependency>

<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
Expand Down
92 changes: 90 additions & 2 deletions java/src/main/java/com/logicalclocks/hsfs/FeatureGroup.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,20 @@

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.logicalclocks.hsfs.engine.FeatureGroupEngine;
import com.logicalclocks.hsfs.engine.StatisticsEngine;
import com.logicalclocks.hsfs.metadata.Query;
import com.logicalclocks.hsfs.metadata.Statistics;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.Date;
Expand Down Expand Up @@ -69,6 +74,21 @@ public class FeatureGroup {
@Getter @Setter
private String type = "cachedFeaturegroupDTO";

@Getter @Setter
@JsonProperty("descStatsEnabled")
private Boolean statisticsEnabled;

@Getter @Setter
@JsonProperty("featHistEnabled")
private Boolean histograms;

@Getter @Setter
@JsonProperty("featCorrEnabled")
private Boolean correlations;

@Getter @Setter
private List<String> statisticColumns;

@JsonIgnore
// These are only used in the client. In the server they are aggregated in the `features` field
private List<String> primaryKeys;
Expand All @@ -78,11 +98,15 @@ public class FeatureGroup {
private List<String> partitionKeys;

private FeatureGroupEngine featureGroupEngine = new FeatureGroupEngine();
private StatisticsEngine statisticsEngine = new StatisticsEngine(EntityEndpointType.FEATURE_GROUP);

private static final Logger LOGGER = LoggerFactory.getLogger(FeatureGroup.class);

@Builder
public FeatureGroup(FeatureStore featureStore, @NonNull String name, Integer version, String description,
List<String> primaryKeys, List<String> partitionKeys,
boolean onlineEnabled, Storage defaultStorage, List<Feature> features)
List<String> primaryKeys, List<String> partitionKeys, boolean onlineEnabled,
Storage defaultStorage, List<Feature> features, Boolean statisticsEnabled, Boolean histograms,
Boolean correlations, List<String> statisticColumns)
throws FeatureStoreException {

this.featureStore = featureStore;
Expand All @@ -94,6 +118,10 @@ public FeatureGroup(FeatureStore featureStore, @NonNull String name, Integer ver
this.onlineEnabled = onlineEnabled;
this.defaultStorage = defaultStorage != null ? defaultStorage : Storage.OFFLINE;
this.features = features;
this.statisticsEnabled = statisticsEnabled;
this.histograms = histograms;
this.correlations = correlations;
this.statisticColumns = statisticColumns;
}

public FeatureGroup() {
Expand Down Expand Up @@ -137,6 +165,9 @@ public void save(Dataset<Row> featureData) throws FeatureStoreException, IOExcep
public void save(Dataset<Row> featureData, Map<String, String> writeOptions)
throws FeatureStoreException, IOException {
featureGroupEngine.saveFeatureGroup(this, featureData, primaryKeys, partitionKeys, defaultStorage, writeOptions);
if (statisticsEnabled) {
statisticsEngine.computeStatistics(this, featureData);
}
Comment on lines +168 to +170
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this call the computeStatistics() method? Otherwise you might end up computing feature for the online feature store. which is not bad per se in this case, as you are not query NDB, but might confuse users.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with the confusion part, I just wanted to reuse the dataframe as we already have it, instead of rereading it. I am not sure spark is smart enough to recognize that it's already there.

On the other hand this way it would always allow the user to have the statistics from the very first creation of the featuregroup even if it is purely online.

}

public void insert(Dataset<Row> featureData, Storage storage) throws IOException, FeatureStoreException {
Expand All @@ -161,12 +192,69 @@ public void insert(Dataset<Row> featureData, Storage storage, boolean overwrite,
throws FeatureStoreException, IOException {
featureGroupEngine.saveDataframe(this, featureData, storage,
overwrite ? SaveMode.Overwrite : SaveMode.Append, writeOptions);
computeStatistics();
}

/**
 * Delete this feature group from the feature store by delegating to the
 * feature group engine.
 *
 * @throws FeatureStoreException if the backend rejects the deletion
 * @throws IOException if the request to the backend cannot be sent or read
 */
public void delete() throws FeatureStoreException, IOException {
featureGroupEngine.delete(this);
}

/**
 * Update the statistics configuration of the feature group.
 * Change the `statisticsEnabled`, `histograms`, `correlations` or `statisticColumns` attributes and persist
 * the changes by calling this method.
 *
 * @throws FeatureStoreException if the metadata update is rejected by the backend
 * @throws IOException if the request to the backend cannot be sent or read
 */
public void updateStatisticsConfig() throws FeatureStoreException, IOException {
featureGroupEngine.updateStatisticsConfig(this);
}

/**
 * Recompute the statistics for the feature group and save them to the feature store.
 * Statistics are only computed when the default storage is `OFFLINE` or `ALL`;
 * for a purely online feature group a warning is logged instead.
 *
 * @return statistics object of computed statistics, or null if statistics are
 *         disabled or the default storage is online-only
 * @throws FeatureStoreException if the statistics cannot be persisted
 * @throws IOException if the backend cannot be reached
 */
public Statistics computeStatistics() throws FeatureStoreException, IOException {
  // Null-safe check: `statisticsEnabled` is a boxed Boolean that may be null
  // when the builder did not set it; unboxing it directly would NPE.
  if (Boolean.TRUE.equals(statisticsEnabled)) {
    if (defaultStorage == Storage.ALL || defaultStorage == Storage.OFFLINE) {
      return statisticsEngine.computeStatistics(this, read(Storage.OFFLINE));
    } else {
      // Fixed mismatched backticks in the original message ("`offline and `all`").
      LOGGER.info("StorageWarning: The default storage of feature group `" + name + "`, with version `" + version
          + "`, is `" + defaultStorage + "`. Statistics are only computed for default storage `offline` and `all`.");
    }
  }
  return null;
}

/**
* Get the last statistics commit for the feature group.
*
* @return statistics object of latest commit
* @throws FeatureStoreException
* @throws IOException
*/
@JsonIgnore
public Statistics getStatistics() throws FeatureStoreException, IOException {
return statisticsEngine.getLast(this);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validates my point in the python api. Here we return an object containing commit_time, content. Which I think is good

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also the object in python contains content and commit time as only accessible members

}

/**
 * Get the statistics of a specific commit time for the feature group.
 *
 * @param commitTime commit time in the format "YYYYMMDDhhmmss"
 * @return statistics object for the commit time
 * @throws FeatureStoreException if no statistics exist for the commit time or the backend reports an error
 * @throws IOException if the backend cannot be reached
 */
@JsonIgnore
public Statistics getStatistics(String commitTime) throws FeatureStoreException, IOException {
return statisticsEngine.get(this, commitTime);
}

/**
* Add a tag without value to the feature group.
*
Expand Down
80 changes: 75 additions & 5 deletions java/src/main/java/com/logicalclocks/hsfs/TrainingDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
package com.logicalclocks.hsfs;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.logicalclocks.hsfs.engine.StatisticsEngine;
import com.logicalclocks.hsfs.engine.TrainingDatasetEngine;
import com.logicalclocks.hsfs.metadata.Query;
import com.logicalclocks.hsfs.metadata.Statistics;
import lombok.Builder;
import lombok.Getter;
import lombok.NoArgsConstructor;
Expand Down Expand Up @@ -75,13 +78,30 @@ public class TrainingDataset {
@Getter @Setter
private List<Split> splits;

@Getter @Setter
@JsonIgnore
private Boolean statisticsEnabled = true;

@Getter @Setter
@JsonIgnore
private Boolean histograms;

@Getter @Setter
@JsonIgnore
private Boolean correlations;

@Getter @Setter
@JsonIgnore
private List<String> statisticColumns;

private TrainingDatasetEngine trainingDatasetEngine = new TrainingDatasetEngine();
private StatisticsEngine statisticsEngine = new StatisticsEngine(EntityEndpointType.TRAINING_DATASET);

@Builder
public TrainingDataset(@NonNull String name, Integer version, String description,
DataFormat dataFormat, StorageConnector storageConnector,
String location, List<Split> splits, Long seed,
FeatureStore featureStore) {
public TrainingDataset(@NonNull String name, Integer version, String description, DataFormat dataFormat,
StorageConnector storageConnector, String location, List<Split> splits, Long seed,
FeatureStore featureStore, Boolean statisticsEnabled, Boolean histograms,
Boolean correlations, List<String> statisticColumns) {
this.name = name;
this.version = version;
this.description = description;
Expand All @@ -100,6 +120,10 @@ public TrainingDataset(@NonNull String name, Integer version, String description
this.splits = splits;
this.seed = seed;
this.featureStore = featureStore;
this.statisticsEnabled = statisticsEnabled;
this.histograms = histograms;
this.correlations = correlations;
this.statisticColumns = statisticColumns;
}

/**
Expand Down Expand Up @@ -133,7 +157,11 @@ public void save(Dataset<Row> dataset) throws FeatureStoreException, IOException
* @throws IOException
*/
/**
 * Materialize the query result as the training dataset and, if statistics are
 * enabled, compute and persist statistics on the same dataframe (the query is
 * read exactly once and the dataframe reused for both steps).
 *
 * @param query query whose result is written as the training dataset
 * @param writeOptions options passed to the underlying writer
 * @throws FeatureStoreException if the feature store metadata operation fails
 * @throws IOException if the backend cannot be reached
 */
public void save(Query query, Map<String, String> writeOptions) throws FeatureStoreException, IOException {
  Dataset<Row> dataset = query.read();
  trainingDatasetEngine.save(this, dataset, writeOptions);
  // Null-safe: the field defaults to true, but the builder constructor
  // overwrites it with its (possibly null) argument, so direct unboxing
  // could throw a NullPointerException.
  if (Boolean.TRUE.equals(statisticsEnabled)) {
    statisticsEngine.computeStatistics(this, dataset);
  }
}

/**
Expand All @@ -147,6 +175,9 @@ public void save(Query query, Map<String, String> writeOptions) throws FeatureSt
/**
 * Write the provided dataframe as the training dataset and, if statistics are
 * enabled, compute and persist statistics on the same dataframe.
 *
 * @param dataset dataframe to materialize as the training dataset
 * @param writeOptions options passed to the underlying writer
 * @throws FeatureStoreException if the feature store metadata operation fails
 * @throws IOException if the backend cannot be reached
 */
public void save(Dataset<Row> dataset, Map<String, String> writeOptions)
    throws FeatureStoreException, IOException {
  trainingDatasetEngine.save(this, dataset, writeOptions);
  // Null-safe: the builder can overwrite the `true` field default with null,
  // so direct unboxing of the Boolean would NPE.
  if (Boolean.TRUE.equals(statisticsEnabled)) {
    statisticsEngine.computeStatistics(this, dataset);
  }
}

/**
Expand Down Expand Up @@ -186,6 +217,7 @@ public void insert(Query query, boolean overwrite, Map<String, String> writeOpti
throws FeatureStoreException, IOException {
trainingDatasetEngine.insert(this, query.read(),
writeOptions, overwrite ? SaveMode.Overwrite : SaveMode.Append);
computeStatistics();
}

/**
Expand All @@ -201,6 +233,7 @@ public void insert(Dataset<Row> dataset, boolean overwrite, Map<String, String>
throws FeatureStoreException, IOException {
trainingDatasetEngine.insert(this, dataset,
writeOptions, overwrite ? SaveMode.Overwrite : SaveMode.Append);
computeStatistics();
}

/**
Expand Down Expand Up @@ -253,6 +286,43 @@ public void show(int numRows) {
read("").show(numRows);
}

/**
 * Recompute the statistics for the entire training dataset and save them to the feature store.
 *
 * @return statistics object of computed statistics, or null if statistics are disabled
 * @throws FeatureStoreException if the statistics cannot be persisted
 * @throws IOException if the backend cannot be reached
 */
public Statistics computeStatistics() throws FeatureStoreException, IOException {
  // Null-safe check: the builder constructor can overwrite the `true` field
  // default with a null argument, so unboxing directly would NPE.
  if (Boolean.TRUE.equals(statisticsEnabled)) {
    return statisticsEngine.computeStatistics(this, read());
  }
  return null;
}

/**
 * Get the last statistics commit for the training dataset.
 *
 * @return statistics object of latest commit
 * @throws FeatureStoreException if the backend reports an error
 * @throws IOException if the backend cannot be reached
 */
@JsonIgnore  // consistent with FeatureGroup.getStatistics(); prevents Jackson
             // from treating this getter as a serializable `statistics`
             // property and triggering a backend call during serialization
public Statistics getStatistics() throws FeatureStoreException, IOException {
  return statisticsEngine.getLast(this);
}

/**
 * Get the statistics of a specific commit time for the training dataset.
 *
 * @param commitTime commit time in the format "YYYYMMDDhhmmss"
 * @return statistics object for the commit time
 * @throws FeatureStoreException if no statistics exist for the commit time or the backend reports an error
 * @throws IOException if the backend cannot be reached
 */
@JsonIgnore  // consistent with the FeatureGroup counterpart; keeps Jackson
             // from invoking this backend-calling getter during serialization
public Statistics getStatistics(String commitTime) throws FeatureStoreException, IOException {
  return statisticsEngine.get(this, commitTime);
}

/**
* Add a tag without value to the training dataset.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,19 @@

import com.logicalclocks.hsfs.EntityEndpointType;
import com.logicalclocks.hsfs.FeatureGroup;
import com.logicalclocks.hsfs.FeatureStore;
import com.logicalclocks.hsfs.FeatureStoreException;
import com.logicalclocks.hsfs.Storage;
import com.logicalclocks.hsfs.StorageConnector;
import com.logicalclocks.hsfs.metadata.StorageConnectorApi;
import com.logicalclocks.hsfs.metadata.FeatureGroupApi;
import com.logicalclocks.hsfs.metadata.TagsApi;
import com.logicalclocks.hsfs.util.Constants;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand All @@ -46,9 +43,6 @@ public class FeatureGroupEngine {

private static final Logger LOGGER = LoggerFactory.getLogger(FeatureGroupEngine.class);

//TODO:
// Compute statistics

/**
* Create the metadata and write the data to the online/offline feature store.
*
Expand Down Expand Up @@ -102,6 +96,9 @@ public void saveFeatureGroup(FeatureGroup featureGroup, Dataset<Row> dataset,

// Update the original object - Hopsworks returns the incremented version
featureGroup.setVersion(apiFG.getVersion());
featureGroup.setId(apiFG.getId());
featureGroup.setCorrelations(apiFG.getCorrelations());
featureGroup.setHistograms(apiFG.getHistograms());

// Write the dataframe
saveDataframe(featureGroup, dataset, storage, SaveMode.Append, writeOptions);
Expand Down Expand Up @@ -179,4 +176,10 @@ public Map<String, String> getTag(FeatureGroup featureGroup, String name) throws
/**
 * Delete the tag with the given name from the feature group.
 *
 * @param featureGroup feature group to remove the tag from
 * @param name name of the tag to delete
 * @throws FeatureStoreException if the backend rejects the deletion
 * @throws IOException if the request to the backend cannot be sent or read
 */
public void deleteTag(FeatureGroup featureGroup, String name) throws FeatureStoreException, IOException {
tagsApi.deleteTag(featureGroup, name);
}

/**
 * Persist the feature group's statistics configuration to the backend and copy
 * the server-normalized values back onto the client-side object.
 *
 * NOTE(review): only `correlations` and `histograms` are copied back from the
 * API response; `statisticsEnabled` and `statisticColumns` are not — confirm
 * whether the backend can also normalize those and whether they should be
 * synced here as well.
 *
 * @param featureGroup feature group whose statistics configuration is updated
 * @throws FeatureStoreException if the backend rejects the update
 * @throws IOException if the request to the backend cannot be sent or read
 */
public void updateStatisticsConfig(FeatureGroup featureGroup) throws FeatureStoreException, IOException {
FeatureGroup apiFG = featureGroupApi.updateStatsConfig(featureGroup);
featureGroup.setCorrelations(apiFG.getCorrelations());
featureGroup.setHistograms(apiFG.getHistograms());
}
}
Loading