Skip to content

Commit

Permalink
Recommendation Model Summary
Browse files Browse the repository at this point in the history
  • Loading branch information
madawas committed Aug 12, 2015
1 parent 0857322 commit 987c799
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ private MLConstants() {
public static final String CLASS_CLASSIFICATION_AND_REGRESSION_MODEL_SUMMARY = "ClassClassificationAndRegressionModelSummary";
public static final String PROBABILISTIC_CLASSIFICATION_MODEL_SUMMARY = "ProbabilisticClassificationModelSummary";
public static final String CLUSTER_MODEL_SUMMARY = "ClusterModelSummary";
public static final String RECOMMENDATION_MODEL_SUMMARY = "RecommendationModelSummary";

public static final int K_MEANS_SAMPLE_SIZE = 10000;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;

public class CollaborativeFiltering {
import java.io.Serializable;

public class CollaborativeFiltering implements Serializable{

private static final long serialVersionUID = 5273514743795162923L;

/**
* This method uses alternating least squares (ALS) algorithm to train a matrix factorization model given an JavaRDD
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,18 @@
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel;
import org.apache.spark.mllib.recommendation.Rating;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.constants.MLConstants.RECOMMENDATION_ALGORITHM;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import org.wso2.carbon.ml.commons.domain.Workflow;
import org.wso2.carbon.ml.core.exceptions.AlgorithmNameException;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.interfaces.MLModelBuilder;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.models.MLMatrixFactorizationModel;
import org.wso2.carbon.ml.core.spark.summary.RecommendationModelSummary;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;
import org.wso2.carbon.ml.database.DatabaseService;

import java.util.Map;

Expand All @@ -51,43 +56,40 @@ public RecommendationModelBuilder(MLModelConfigurationContext context) {
*/
@Override public MLModel build() throws MLModelBuilderException {
MLModelConfigurationContext context = getContext();

DatabaseService databaseService = MLCoreServiceValueHolder.getInstance().getDatabaseService();
try {
Workflow workflow = context.getFacts();
long modelId = context.getModelId();
MLModel mlModel = new MLModel();
ModelSummary summaryModel;
JavaRDD<Rating> trainingData;

mlModel.setAlgorithmName(workflow.getAlgorithmName());
mlModel.setAlgorithmClass(workflow.getAlgorithmClass());

MatrixFactorizationModel model;
JavaRDD<Rating> trainingData;
mlModel.setFeatures(workflow.getFeatures());

// build a recommendation model according to user selected algorithm
MLConstants.RECOMMENDATION_ALGORITHM recommendation_algorithm =
MLConstants.RECOMMENDATION_ALGORITHM.valueOf(workflow.getAlgorithmName());
RECOMMENDATION_ALGORITHM recommendation_algorithm =
RECOMMENDATION_ALGORITHM.valueOf(workflow.getAlgorithmName());
switch (recommendation_algorithm) {
case COLLABORATIVE_FILTERING:
trainingData = RecommendationUtils.preProcess(context, false);
model = buildCollaborativeFilteringModel(trainingData, workflow, mlModel, false);
summaryModel = buildCollaborativeFilteringModel(trainingData, workflow, mlModel, false);
break;
case COLLABORATIVE_FILTERING_IMPLICIT:
trainingData = RecommendationUtils.preProcess(context, true);
model = buildCollaborativeFilteringModel(trainingData, workflow, mlModel, true);
summaryModel = buildCollaborativeFilteringModel(trainingData, workflow, mlModel, true);
break;
default:
throw new AlgorithmNameException(
"Incorrect algorithm name: " + workflow.getAlgorithmName() + " for model id: " + modelId);
}
// Recommendations printing in console
Rating[] recommendedProducts = new CollaborativeFiltering().recommendProducts(model,1,50);
for (Rating recommendedProduct : recommendedProducts) {
System.out.println(recommendedProduct.user() + " " + recommendedProduct.product());
}
mlModel.setModel(new MLMatrixFactorizationModel(model));
//persist model summary
databaseService.updateModelSummary(modelId, summaryModel);
return mlModel;
} catch (Exception e) {
throw new MLModelBuilderException(
"An error occurred while building unsupervised machine learning model: " + e.getMessage(), e);
"An error occurred while building recommendation model: " + e.getMessage(), e);
}
}

Expand All @@ -97,33 +99,46 @@ public RecommendationModelBuilder(MLModelConfigurationContext context) {
* @param workflow {@link Workflow}
* @param mlModel {@link MLModel}
* @param trainImplicit train using implicit data
* @return {@link MatrixFactorizationModel}
* @return {@link ModelSummary}
* @throws MLModelBuilderException If failed to build the model
*/
private MatrixFactorizationModel buildCollaborativeFilteringModel(JavaRDD<Rating> trainingData, Workflow workflow,
private ModelSummary buildCollaborativeFilteringModel(JavaRDD<Rating> trainingData, Workflow workflow,
MLModel mlModel, boolean trainImplicit) throws MLModelBuilderException {

try {
Map<String, String> parameters = workflow.getHyperParameters();
CollaborativeFiltering collaborativeFiltering = new CollaborativeFiltering();
RecommendationModelSummary recommendationModelSummary = new RecommendationModelSummary();
MatrixFactorizationModel model;
if (trainImplicit) {
model = collaborativeFiltering.trainImplicit(trainingData,
Integer.parseInt(parameters.get(MLConstants.RANK)),
Integer.parseInt(parameters.get(MLConstants.ITERATIONS)),
Double.parseDouble(parameters.get(MLConstants.LAMBDA)),
Double.parseDouble(parameters.get(MLConstants.ALPHA)),
Integer.parseInt(parameters.get(MLConstants.BLOCKS)));
model = collaborativeFiltering
.trainImplicit(trainingData, Integer.parseInt(parameters.get(MLConstants.RANK)),
Integer.parseInt(parameters.get(MLConstants.ITERATIONS)),
Double.parseDouble(parameters.get(MLConstants.LAMBDA)),
Double.parseDouble(parameters.get(MLConstants.ALPHA)),
Integer.parseInt(parameters.get(MLConstants.BLOCKS)));
recommendationModelSummary
.setAlgorithm(RECOMMENDATION_ALGORITHM.COLLABORATIVE_FILTERING_IMPLICIT.toString());
} else {
model = collaborativeFiltering.trainExplicit(trainingData,
Integer.parseInt(parameters.get(MLConstants.RANK)),
Integer.parseInt(parameters.get(MLConstants.ITERATIONS)),
Double.parseDouble(parameters.get(MLConstants.LAMBDA)),
Integer.parseInt(parameters.get(MLConstants.BLOCKS)));
model = collaborativeFiltering
.trainExplicit(trainingData, Integer.parseInt(parameters.get(MLConstants.RANK)),
Integer.parseInt(parameters.get(MLConstants.ITERATIONS)),
Double.parseDouble(parameters.get(MLConstants.LAMBDA)),
Integer.parseInt(parameters.get(MLConstants.BLOCKS)));
recommendationModelSummary.setAlgorithm(RECOMMENDATION_ALGORITHM.COLLABORATIVE_FILTERING.toString());
}

mlModel.setModel(new MLMatrixFactorizationModel(model));
return model;
recommendationModelSummary.setUserFeatures(model.userFeatures());
recommendationModelSummary.setProductFeatures(model.productFeatures());
recommendationModelSummary.setRank(model.rank());

//Recommendations printing in console
Rating[] recommendedProducts = new CollaborativeFiltering().recommendProducts(model, 1, 50);
for (Rating recommendedProduct : recommendedProducts) {
System.out.println(recommendedProduct.user() + " " + recommendedProduct.product());
}

return recommendationModelSummary;
} catch (Exception e) {
throw new MLModelBuilderException(
"An error occurred while building recommendation model: " + e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,10 @@ public static JavaRDD<Rating> preProcess(MLModelConfigurationContext context, bo
Map<String, String> parameters = workflow.getHyperParameters();
List<Double> weightList = getWeightList(parameters.get(MLConstants.WEIGHTS));
tokens = tokens.map(new ImplicitDataToRating(userIndex, productIndex, observationList, weightList));
return tokens.map(new StringArrayToRating(0,1,2));
} else {
return tokens.map(new StringArrayToRating(userIndex, productIndex, ratingIndex));
}

return tokens.map(new StringArrayToRating(userIndex, productIndex, ratingIndex));
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Copyright (c) 2015, WSO2 Inc. (http://www.wso2.org) All Rights Reserved.
*
* WSO2 Inc. licenses this file to you under the Apache License,
* Version 2.0 (the "License"); you may not use this file except
* in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.wso2.carbon.ml.core.spark.summary;

import org.apache.spark.rdd.RDD;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.ModelSummary;
import scala.Tuple2;

import java.io.Serializable;
import java.util.Arrays;

public class RecommendationModelSummary implements ModelSummary, Serializable {

private static final long serialVersionUID = 5557613175001838756L;
private String algorithm;
private String[] features;
private int rank;
private RDD<Tuple2<Object, double[]>> userFeatures;
private RDD<Tuple2<Object, double[]>> productFeatures;

@Override
public String getModelSummaryType() {
return MLConstants.RECOMMENDATION_MODEL_SUMMARY;
}

public String getAlgorithm() {
return algorithm;
}

public void setAlgorithm(String algorithm) {
this.algorithm = algorithm;
}

public int getRank() {
return rank;
}

public void setRank(int rank) {
this.rank = rank;
}

public RDD<Tuple2<Object, double[]>> getUserFeatures() {
return userFeatures;
}

public void setUserFeatures(RDD<Tuple2<Object, double[]>> userFeatures) {
this.userFeatures = userFeatures;
}

public RDD<Tuple2<Object, double[]>> getProductFeatures() {
return productFeatures;
}

public void setProductFeatures(RDD<Tuple2<Object, double[]>> productFeatures) {
this.productFeatures = productFeatures;
}

/**
* @param features Array of names of the features
*/
public void setFeatures(String[] features) {
if (features == null) {
this.features = new String[0];
} else {
this.features = Arrays.copyOf(features, features.length);
}
}

@Override
public String[] getFeatures() {
return features;
}
}

0 comments on commit 987c799

Please sign in to comment.