Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
Added bpr_sampling, train_bprmf, and bprmf_predict
  • Loading branch information
myui committed Apr 1, 2016
1 parent a5aed1c commit d427bd8
Show file tree
Hide file tree
Showing 10 changed files with 1,142 additions and 41 deletions.
4 changes: 4 additions & 0 deletions core/src/main/java/hivemall/common/ConversionState.java
Expand Up @@ -69,6 +69,10 @@ public void incrLoss(double loss) {
/**
 * Multiplies the {@code currLosses} field by the given factor, in place.
 *
 * @param multi factor applied to {@code currLosses}
 */
public void multiplyLoss(double multi) {
    currLosses *= multi;
}

/**
 * Compares the {@code currLosses} field against the {@code prevLosses} field.
 *
 * @return {@code true} only when {@code currLosses} is strictly greater than {@code prevLosses}
 */
public boolean isLossIncreased() {
    return prevLosses < currLosses;
}

public boolean isConverged(final int iter, final long obserbedTrainingExamples) {
if(conversionCheck == false) {
Expand Down
33 changes: 29 additions & 4 deletions core/src/main/java/hivemall/common/EtaEstimator.java
Expand Up @@ -18,6 +18,7 @@
*/
package hivemall.common;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

Expand Down Expand Up @@ -57,7 +58,7 @@ public SimpleEtaEstimator(double eta0, long total_steps) {

@Override
public float eta(final long t) {
if(t > total_steps) {
if (t > total_steps) {
return finalEta;
}
return (float) (eta0 / (1.d + (t / total_steps)));
Expand All @@ -82,20 +83,44 @@ public float eta(final long t) {

}

/**
 * Learning rate estimator whose rate is rescaled explicitly by the caller through
 * {@link #update(float)} ("bold driver" heuristic).
 *
 * bold driver: Gemulla et al., Large-scale matrix factorization with distributed stochastic
 * gradient descent, KDD 2011.
 */
public static final class AdjustingEtaEstimator extends EtaEstimator {

    /** Initial learning rate; also the upper bound of the adjusted learning rate. */
    private final float eta0;
    /** Current learning rate, returned as-is by {@link #eta(long)}. */
    private float eta;

    /**
     * @param eta initial learning rate
     */
    public AdjustingEtaEstimator(float eta) {
        this.eta0 = eta;
        this.eta = eta;
    }

    /**
     * Rescales the current learning rate by the given factor.
     *
     * The result is capped at the initial learning rate so that the rate can shrink (and later
     * recover) but never exceeds its starting value. A product that is NaN or infinite is
     * ignored and the current rate is kept.
     *
     * @param multipler non-negative factor applied to the current learning rate
     */
    public void update(@Nonnegative float multipler) {
        final float newEta = eta * multipler;
        if (Float.isNaN(newEta) || Float.isInfinite(newEta)) {
            // keep the last valid learning rate rather than poisoning later updates
            return;
        }
        // never be larger than eta0 (the previous Math.max made eta0 a floor, so the rate
        // could never be reduced and was free to grow without bound)
        this.eta = Math.min(eta0, newEta);
    }

    /**
     * @param t iteration/step counter; not used by this estimator
     * @return the current learning rate
     */
    @Override
    public float eta(long t) {
        return eta;
    }

}

@Nonnull
public static EtaEstimator get(@Nullable CommandLine cl) throws UDFArgumentException {
if(cl == null) {
if (cl == null) {
return new InvscalingEtaEstimator(0.1f, 0.1f);
}

String etaValue = cl.getOptionValue("eta");
if(etaValue != null) {
if (etaValue != null) {
float eta = Float.parseFloat(etaValue);
return new FixedEtaEstimator(eta);
}

double eta0 = Double.parseDouble(cl.getOptionValue("eta0", "0.1"));
if(cl.hasOption("t")) {
if (cl.hasOption("t")) {
long t = Long.parseLong(cl.getOptionValue("t"));
return new SimpleEtaEstimator(eta0, t);
}
Expand Down
232 changes: 232 additions & 0 deletions core/src/main/java/hivemall/ftvec/sampling/BprSamplingUDTF.java
@@ -0,0 +1,232 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.ftvec.sampling;

import hivemall.utils.hadoop.HiveUtils;

import java.util.ArrayList;
import java.util.BitSet;
import java.util.Random;

import javax.annotation.Nonnull;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.IntWritable;

@Description(
        name = "bpr_sampling",
        value = "_FUNC_(array<int> pos_items, const int max_item_id [, const double sampling_rate=1.0, const boolean withoutReplacement=false])"
                + " - Returns a relation consisting of <int pos_item_id, int neg_item_id>")
/**
 * UDTF that forwards (pos_item, neg_item) pairs for each input row.
 *
 * For every row, the positive item ids are loaded into a {@link BitSet}; a negative item is any
 * id in [0, maxItemId] whose bit is clear. Pairs are drawn uniformly at random, either with
 * replacement (default) or without replacement.
 */
public final class BprSamplingUDTF extends GenericUDTF {

    private ListObjectInspector listOI;
    private PrimitiveObjectInspector listElemOI;
    private int maxItemId;
    private float samplingRate;
    private boolean withoutReplacement;

    // Reused output row; posItemId/negItemId are overwritten before each forward().
    private Object[] forwardObjs;
    private IntWritable posItemId;
    private IntWritable negItemId;

    // Lazily created on the first process() call and reused (cleared) for later rows.
    private BitSet bitset;
    private Random rand;

    public BprSamplingUDTF() {}

    /**
     * Validates the arguments and declares the output schema {@code <int pos_item, int neg_item>}.
     *
     * @param argOIs [0] list of positive items, [1] constant max item id, [2] optional constant
     *        sampling rate (default 1.0), [3] optional constant withoutReplacement flag (default
     *        false)
     * @throws UDFArgumentException if the argument count is not in [2, 4], max_item_id is not
     *         positive, sampling_rate is not positive, or sampling_rate exceeds 1 while
     *         withoutReplacement is true
     */
    @Override
    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if (argOIs.length < 2 || argOIs.length > 4) {
            throw new UDFArgumentException("bpr_sampling(array<long>, const long max_item_id "
                    + "[, const double sampling_rate, const boolean withoutReplacement=false])"
                    + " takes at least two arguments");
        }
        this.listOI = HiveUtils.asListOI(argOIs[0]);
        this.listElemOI = HiveUtils.asPrimitiveObjectInspector(listOI.getListElementObjectInspector());

        this.maxItemId = HiveUtils.getAsConstInt(argOIs[1]);
        if (maxItemId <= 0) {
            throw new UDFArgumentException("maxItemId MUST be greater than 0: " + maxItemId);
        }

        // The flag is resolved first because the sampling rate validation below depends on it.
        if (argOIs.length == 4) {
            this.withoutReplacement = HiveUtils.getConstBoolean(argOIs[3]);
        } else {
            this.withoutReplacement = false;
        }

        float rate = 1.f;
        if (argOIs.length >= 3) {
            rate = HiveUtils.getAsConstFloat(argOIs[2]);
            if (rate <= 0.f) {
                throw new UDFArgumentException("sampling_rate MUST be greater than 0: " + rate);
            }
            if (withoutReplacement && rate > 1.f) {
                throw new UDFArgumentException(
                    "sampling_rate MUST be in less than or equals to 1 where withoutReplacement is true: "
                            + rate);
            }
        }
        this.samplingRate = rate;

        this.posItemId = new IntWritable();
        this.negItemId = new IntWritable();
        this.forwardObjs = new Object[] {posItemId, negItemId};

        ArrayList<String> fieldNames = new ArrayList<String>();
        ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
        fieldNames.add("pos_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
        fieldNames.add("neg_item");
        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);

        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    }

    /**
     * Samples and forwards (pos_item, neg_item) pairs for one input row.
     *
     * Nothing is forwarded when the row has no positive items or no negative items.
     *
     * @param args args[0] holds the list of positive items of the row
     * @throws HiveException if the number of positive items exceeds maxItemId + 1
     */
    @Override
    public void process(Object[] args) throws HiveException {
        if (bitset == null) {
            this.bitset = new BitSet();
            // fixed seed: the sampled sequence is the same on every run
            this.rand = new Random(43);
        } else {
            bitset.clear();
        }

        // presumably sets one bit per positive item id and returns how many were set;
        // verify against HiveUtils.setBits
        final int numPosItems = HiveUtils.setBits(args[0], listOI, listElemOI, bitset);
        if (numPosItems == 0) {
            return;
        }
        // NOTE(review): this count is only exact when every positive id lies in [0, maxItemId]
        final int numNegItems = maxItemId + 1 - numPosItems;
        if (numNegItems == 0) {
            return;
        } else if (numNegItems < 0) {
            throw new UDFArgumentException("maxItemId + 1 - numPosItems = " + maxItemId + " + 1 - "
                    + numPosItems + " = " + numNegItems);
        }

        if (withoutReplacement) {
            sampleWithoutReplacement(numPosItems, numNegItems, bitset);
        } else {
            sampleWithReplacement(numPosItems, numNegItems, bitset);
        }
    }

    /**
     * Forwards pairs in which no positive item and no negative item is used more than once.
     *
     * Stops early as soon as either the positive or the negative candidates are exhausted.
     *
     * @param numPosItems number of set bits in {@code bitset}
     * @param numNegItems number of negative candidates in [0, maxItemId]
     * @param bitset positive items of the current row; consumed (bits are cleared) by this method
     */
    private void sampleWithoutReplacement(int numPosItems, int numNegItems,
            @Nonnull final BitSet bitset) throws HiveException {
        // Set bits are the positive items that can still be drawn.
        final BitSet bitsetForPosSampling = bitset;
        // Clear bits are the negative items that can still be drawn. It MUST start as a copy of
        // the positive items; starting from an empty set would let positive items be emitted as
        // negatives.
        final BitSet bitsetForNegSampling = (BitSet) bitset.clone();

        final int numSamples = Math.max(1, Math.round(numPosItems * samplingRate));
        for (int s = 0; s < numSamples; s++) {
            int nth = rand.nextInt(numPosItems);
            final int i = nthSetBit(bitsetForPosSampling, nth);
            if (i == -1) {
                throw new UDFArgumentException("Illegal i value: " + i);
            }
            bitsetForPosSampling.set(i, false);
            --numPosItems;

            nth = rand.nextInt(numNegItems);
            final int j = nthClearBit(bitsetForNegSampling, nth, maxItemId);
            if (j < 0 || j > maxItemId) {
                throw new UDFArgumentException("j MUST be in [0," + maxItemId + "] but j was " + j);
            }
            bitsetForNegSampling.set(j, true);
            --numNegItems;

            posItemId.set(i);
            negItemId.set(j);
            forward(forwardObjs);

            if (numPosItems <= 0) {
                // cannot draw a positive example anymore
                return;
            } else if (numNegItems <= 0) {
                // cannot draw a negative example anymore
                return;
            }
        }
    }

    /**
     * Forwards independently drawn pairs; the same positive or negative item may appear in
     * several pairs.
     *
     * @param numPosItems number of set bits in {@code bitset}
     * @param numNegItems number of negative candidates in [0, maxItemId]
     * @param bitset positive items of the current row; not modified
     */
    private void sampleWithReplacement(final int numPosItems, final int numNegItems,
            @Nonnull final BitSet bitset) throws HiveException {
        final int numSamples = Math.max(1, Math.round(numPosItems * samplingRate));
        for (int s = 0; s < numSamples; s++) {
            int nth = rand.nextInt(numPosItems);
            final int i = nthSetBit(bitset, nth);
            if (i == -1) {
                throw new UDFArgumentException("Illegal i value: " + i);
            }

            nth = rand.nextInt(numNegItems);
            final int j = nthClearBit(bitset, nth, maxItemId);
            if (j < 0 || j > maxItemId) {
                throw new UDFArgumentException("j MUST be in [0," + maxItemId + "] but j was " + j);
            }

            posItemId.set(i);
            negItemId.set(j);
            forward(forwardObjs);
        }
    }

    /** Returns the index of the nth (0-based) set bit, or -1 if there are not that many set bits. */
    private static int nthSetBit(@Nonnull final BitSet bits, final int nth) {
        int pos = bits.nextSetBit(0);
        for (int c = 0; pos >= 0 && c < nth; c++) {
            pos = bits.nextSetBit(pos + 1);
        }
        return pos;
    }

    /**
     * Returns the index of the nth (0-based) clear bit, or a value greater than {@code limit} if
     * there are not that many clear bits in [0, limit].
     */
    private static int nthClearBit(@Nonnull final BitSet bits, final int nth, final int limit) {
        int pos = bits.nextClearBit(0);
        for (int c = 0; pos <= limit && c < nth; c++) {
            pos = bits.nextClearBit(pos + 1);
        }
        return pos;
    }

    /** Drops the references held by this instance. */
    @Override
    public void close() throws HiveException {
        this.listOI = null;
        this.listElemOI = null;
        this.forwardObjs = null;
        this.posItemId = null;
        this.negItemId = null;
        this.bitset = null;
        this.rand = null;
    }

}
73 changes: 73 additions & 0 deletions core/src/main/java/hivemall/mf/BPRMFPredictionUDF.java
@@ -0,0 +1,73 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
* Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.mf;

import java.util.List;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.io.FloatWritable;

@Description(
        name = "bprmf_predict",
        value = "_FUNC_(List<Float> Pu, List<Float> Qi[, double Bi]) - Returns the prediction value")
@UDFType(deterministic = true, stateful = false)
public final class BPRMFPredictionUDF extends UDF {

    /**
     * Predicts without a bias term; same as calling the three-argument variant with a bias of 0.
     */
    public FloatWritable evaluate(List<Float> Pu, List<Float> Qi) throws HiveException {
        return evaluate(Pu, Qi, 0.d);
    }

    /**
     * Computes {@code Bi + sum_k Pu[k] * Qi[k]}, accumulated in single precision.
     *
     * Missing inputs are resolved as follows, in this order:
     * <ul>
     * <li>{@code Pu} is null: 0 if {@code Qi} is null as well, otherwise {@code Bi}</li>
     * <li>{@code Qi} is null: 0</li>
     * <li>{@code Qi} is empty: 0</li>
     * <li>{@code Pu} is empty: {@code Bi}</li>
     * </ul>
     *
     * @param Pu first factor vector
     * @param Qi second factor vector
     * @param Bi bias added to the dot product
     * @return the prediction value
     * @throws HiveException if both vectors are non-empty and differ in size
     */
    public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double Bi) throws HiveException {
        if (Pu == null) {
            return new FloatWritable(Qi == null ? 0.f : (float) Bi);
        }
        if (Qi == null) {
            return new FloatWritable(0.f);
        }

        final int numPu = Pu.size();
        final int numQi = Qi.size();
        // workaround for TD
        if (numQi == 0) {
            return new FloatWritable(0.f);
        }
        if (numPu == 0) {
            return new FloatWritable((float) Bi);
        }

        if (numPu != numQi) {
            throw new HiveException("|Pu| " + numPu + " was not equal to |Qi| " + numQi);
        }

        float score = (float) Bi;
        int k = 0;
        while (k < numPu) {
            score += Pu.get(k) * Qi.get(k);
            k++;
        }
        return new FloatWritable(score);
    }

}

0 comments on commit d427bd8

Please sign in to comment.