Skip to content
This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Permalink
Fixed a corner case bug in mf_predict
Browse files Browse the repository at this point in the history
  • Loading branch information
myui committed Jun 27, 2016
1 parent fee7efa commit 4282b6a
Showing 1 changed file with 58 additions and 27 deletions.
85 changes: 58 additions & 27 deletions core/src/main/java/hivemall/mf/MFPredictionUDF.java
Expand Up @@ -20,10 +20,14 @@

import java.util.List;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.io.FloatWritable;

@Description(
Expand All @@ -32,77 +36,104 @@
@UDFType(deterministic = true, stateful = false)
public final class MFPredictionUDF extends UDF {

public FloatWritable evaluate(List<Float> Pu, List<Float> Qi) throws HiveException {
return evaluate(Pu, Qi, 0.d);
@Nonnull
public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
@Nullable List<FloatWritable> Qi) throws HiveException {
return evaluate(Pu, Qi, null);
}

public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double mu) throws HiveException {
@Nonnull
public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
@Nullable List<FloatWritable> Qi, @Nullable DoubleWritable mu) throws HiveException {
final double muValue = (mu == null) ? 0.d : mu.get();
if (Pu == null || Qi == null) {
return new FloatWritable((float) mu);
return new DoubleWritable(muValue);
}

final int PuSize = Pu.size();
final int QiSize = Qi.size();
// workaround for TD
if (PuSize == 0) {
return new FloatWritable((float) mu);
return new DoubleWritable(muValue);
} else if (QiSize == 0) {
return new FloatWritable((float) mu);
return new DoubleWritable(muValue);
}

if (QiSize != PuSize) {
throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}

float ret = (float) mu;
double ret = muValue;
for (int k = 0; k < PuSize; k++) {
ret += Pu.get(k) * Qi.get(k);
FloatWritable Pu_k = Pu.get(k);
if (Pu_k == null) {
continue;
}
FloatWritable Qi_k = Qi.get(k);
if (Qi_k == null) {
continue;
}
ret += Pu_k.get() * Qi_k.get();
}
return new FloatWritable(ret);
return new DoubleWritable(ret);
}

public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double Bu, double Bi)
throws HiveException {
return evaluate(Pu, Qi, Bu, Bi, 0.d);
@Nonnull
public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
@Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
@Nullable DoubleWritable Bi) throws HiveException {
return evaluate(Pu, Qi, Bu, Bi, null);
}

public FloatWritable evaluate(List<Float> Pu, List<Float> Qi, double Bu, double Bi, double mu)
throws HiveException {
@Nonnull
public DoubleWritable evaluate(@Nullable List<FloatWritable> Pu,
@Nullable List<FloatWritable> Qi, @Nullable DoubleWritable Bu,
@Nullable DoubleWritable Bi, @Nullable DoubleWritable mu) throws HiveException {
final double muValue = (mu == null) ? 0.d : mu.get();
if (Pu == null && Qi == null) {
return new FloatWritable((float) mu);
return new DoubleWritable(muValue);
}
final double BiValue = (Bi == null) ? 0.d : Bi.get();
final double BuValue = (Bu == null) ? 0.d : Bu.get();
if (Pu == null) {
float ret = (float) (mu + Bi);
return new FloatWritable(ret);
double ret = muValue + BiValue;
return new DoubleWritable(ret);
} else if (Qi == null) {
float ret = (float) (mu + Bu);
return new FloatWritable(ret);
return new DoubleWritable(muValue);
}

final int PuSize = Pu.size();
final int QiSize = Qi.size();
// workaround for TD
if (PuSize == 0) {
if (QiSize == 0) {
return new FloatWritable((float) mu);
return new DoubleWritable(muValue);
} else {
float ret = (float) (mu + Bi);
return new FloatWritable(ret);
double ret = muValue + BiValue;
return new DoubleWritable(ret);
}
} else if (QiSize == 0) {
float ret = (float) (mu + Bu);
return new FloatWritable(ret);
double ret = muValue + BuValue;
return new DoubleWritable(ret);
}

if (QiSize != PuSize) {
throw new HiveException("|Pu| " + PuSize + " was not equal to |Qi| " + QiSize);
}

float ret = (float) (mu + Bu + Bi);
double ret = muValue + BuValue + BiValue;
for (int k = 0; k < PuSize; k++) {
ret += Pu.get(k) * Qi.get(k);
FloatWritable Pu_k = Pu.get(k);
if (Pu_k == null) {
continue;
}
FloatWritable Qi_k = Qi.get(k);
if (Qi_k == null) {
continue;
}
ret += Pu_k.get() * Qi_k.get();
}
return new FloatWritable(ret);
return new DoubleWritable(ret);
}

}

0 comments on commit 4282b6a

Please sign in to comment.