This repository has been archived by the owner on Oct 8, 2019. It is now read-only.

Commit

Added support for loading a prediction model ("-loadmodel" option) and
modified CW/AROW/SCW to return covariance
myui committed Jul 2, 2014
1 parent 0e536dc commit e7ae2f1
Showing 13 changed files with 202 additions and 80 deletions.
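The commit adds a "-loadmodel" option in LearnerBaseUDTF so that a previously trained model can be read back into the learner's weight map, and it extends the output schema of CW/AROW/SCW with a "covar" column. As a rough, self-contained sketch of the kind of option handling involved (the getOptions()/processOptions() hooks in the hunks below appear to build on Apache Commons CLI; the option description, class name, and everything outside the "loadmodel" option name are illustrative assumptions, not the committed code):

import org.apache.commons.cli.BasicParser;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;

public class LoadModelOptionSketch {
    public static void main(String[] args) throws ParseException {
        // register the option the commit message refers to
        Options opts = new Options();
        opts.addOption("loadmodel", true, "File path of a prediction model to load before training");

        // parse and consume it, the way processOptions() presumably does
        CommandLine cl = new BasicParser().parse(opts, new String[] {"-loadmodel", "/tmp/prev_model"});
        if (cl.hasOption("loadmodel")) {
            String modelfile = cl.getOptionValue("loadmodel");
            System.out.println("would call loadPredictionModel(weights, \"" + modelfile + "\", keyOI)");
        }
    }
}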
31 changes: 20 additions & 11 deletions src/main/hivemall/LearnerBaseUDTF.java
@@ -22,6 +22,7 @@

import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableFloatObjectInspector;
import hivemall.common.WeightValue;
import hivemall.common.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.OpenHashMap;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
@@ -33,16 +34,17 @@

import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableFloatObjectInspector;
import org.apache.hadoop.io.Text;

public abstract class LearnerBaseUDTF extends UDTFWithOptions {

protected boolean returnCovariance() {
return true;
return false;
}

protected void loadPredictionModel(OpenHashMap<Object, WeightValue> map, String filename, PrimitiveObjectInspector keyOI) {
@@ -72,6 +74,8 @@ private static void loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
} else {
LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI);
StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
StructField keyRef = lineOI.getStructFieldRef("key");
StructField valueRef = lineOI.getStructFieldRef("value");

final BufferedReader reader = HadoopUtils.getBufferedReader(file);
try {
@@ -85,8 +89,8 @@ private static void loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
if(f0 == null || f1 == null) {
continue; // avoid the case that key or value is null
}
Object k = ObjectInspectorUtils.copyToStandardObject(f0, keyOI);
float v = valueOI.get(f1);
Object k = ((PrimitiveObjectInspector) keyRef.getFieldObjectInspector()).getPrimitiveWritableObject(f0);
float v = ((FloatObjectInspector) valueRef.getFieldObjectInspector()).get(f1);
map.put(k, new WeightValue(v));
}
} finally {
@@ -104,11 +108,14 @@ private static void loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
if(!file.getName().endsWith(".crc")) {
if(file.isDirectory()) {
for(File f : file.listFiles()) {
loadPredictionModel(map, f, keyOI, valueOI);
loadPredictionModel(map, f, keyOI, valueOI, covarOI);
}
} else {
LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI);
LazySimpleSerDe serde = HiveUtils.getLineSerde(keyOI, valueOI, covarOI);
StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
StructField c1ref = lineOI.getStructFieldRef("c1");
StructField c2ref = lineOI.getStructFieldRef("c2");
StructField c3ref = lineOI.getStructFieldRef("c3");

final BufferedReader reader = HadoopUtils.getBufferedReader(file);
try {
@@ -119,12 +126,14 @@ private static void loadPredictionModel(OpenHashMap<Object, WeightValue> map, Fi
List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
Object f0 = fields.get(0);
Object f1 = fields.get(1);
if(f0 == null || f1 == null) {
continue; // avoid the case that key or value is null
Object f2 = fields.get(2);
if(f0 == null || f1 == null || f2 == null) {
continue; // avoid unexpected case
}
Object k = ObjectInspectorUtils.copyToStandardObject(f0, keyOI);
float v = valueOI.get(f1);
map.put(k, new WeightValue(v));
Object k = ((PrimitiveObjectInspector) c1ref.getFieldObjectInspector()).getPrimitiveWritableObject(f0);
float v = ((FloatObjectInspector) c2ref.getFieldObjectInspector()).get(f1);
float cov = ((FloatObjectInspector) c3ref.getFieldObjectInspector()).get(f2);
map.put(k, new WeightValueWithCovar(v, cov));
}
} finally {
reader.close();
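The loader above stores plain entries as WeightValue and, when a covariance column is present, as WeightValue.WeightValueWithCovar, which the per-algorithm close() methods later read back via getValue() and getCovariance(). That class is not part of this diff; the following is only a minimal sketch of the shape implied by how this commit uses it (field names and the inheritance layout are assumptions):

// Sketch inferred from usage in this commit; not the actual hivemall.common.WeightValue source.
public class WeightValue {
    protected final float value;

    public WeightValue(float value) {
        this.value = value;
    }

    public float getValue() {
        return value;
    }

    public static class WeightValueWithCovar extends WeightValue {
        private final float covariance;

        public WeightValueWithCovar(float value, float covariance) {
            super(value);
            this.covariance = covariance;
        }

        public float getCovariance() {
            return covariance;
        }
    }
}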
5 changes: 5 additions & 0 deletions src/main/hivemall/classifier/AROWClassifierUDTF.java
@@ -57,6 +57,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
return super.initialize(argOIs);
}

@Override
protected boolean returnCovariance() {
return true;
}

@Override
protected Options getOptions() {
Options opts = super.getOptions();
63 changes: 43 additions & 20 deletions src/main/hivemall/classifier/BinaryOnlineClassifierUDTF.java
@@ -26,10 +26,10 @@
import static hivemall.HivemallConstants.INT_TYPE_NAME;
import static hivemall.HivemallConstants.STRING_TYPE_NAME;
import hivemall.LearnerBaseUDTF;
import hivemall.UDTFWithOptions;
import hivemall.common.FeatureValue;
import hivemall.common.PredictionResult;
import hivemall.common.WeightValue;
import hivemall.common.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.OpenHashMap;
import hivemall.utils.collections.OpenHashMap.IMapIterator;
import hivemall.utils.hadoop.HiveUtils;
@@ -94,18 +94,9 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
this.biasKey = null;
}

ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("feature");
ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI);
fieldOIs.add(featureOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);

this.weights = new OpenHashMap<Object, WeightValue>(16384);

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
return getReturnOI(featureRawOI);
}

@Override
@@ -141,6 +132,23 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen
return cl;
}

protected StructObjectInspector getReturnOI(ObjectInspector featureRawOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("feature");
ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI);
fieldOIs.add(featureOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
if(returnCovariance()) {
fieldNames.add("covar");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
}

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}

@Override
public void process(Object[] args) throws HiveException {
List<?> features = (List<?>) featureListOI.getList(args[0]);
@@ -318,15 +326,30 @@ protected void update(final List<?> features, final float coeff) {
@Override
public void close() throws HiveException {
if(weights != null) {
final Object[] forwardMapObj = new Object[2];
IMapIterator<Object, WeightValue> itor = weights.entries();
while(itor.next() != -1) {
Object k = itor.unsafeGetAndFreeKey();
WeightValue v = itor.unsafeGetAndFreeValue();
FloatWritable fv = new FloatWritable(v.getValue());
forwardMapObj[0] = k;
forwardMapObj[1] = fv;
forward(forwardMapObj);
if(returnCovariance()) {
final Object[] forwardMapObj = new Object[3];
IMapIterator<Object, WeightValue> itor = weights.entries();
while(itor.next() != -1) {
Object k = itor.unsafeGetAndFreeKey();
WeightValueWithCovar v = (WeightValueWithCovar) itor.unsafeGetAndFreeValue();
FloatWritable fv = new FloatWritable(v.getValue());
FloatWritable cov = new FloatWritable(v.getCovariance());
forwardMapObj[0] = k;
forwardMapObj[1] = fv;
forwardMapObj[2] = cov;
forward(forwardMapObj);
}
} else {
final Object[] forwardMapObj = new Object[2];
IMapIterator<Object, WeightValue> itor = weights.entries();
while(itor.next() != -1) {
Object k = itor.unsafeGetAndFreeKey();
WeightValue v = itor.unsafeGetAndFreeValue();
FloatWritable fv = new FloatWritable(v.getValue());
forwardMapObj[0] = k;
forwardMapObj[1] = fv;
forward(forwardMapObj);
}
}
this.weights = null;
}
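Note the contract in close() above: when returnCovariance() is true, every value in the weights map is cast to WeightValueWithCovar, so a subclass that opts in (AROW, CW, and SCW below) has to store that subtype from its update step. The snippet below, reusing the WeightValueWithCovar sketch shown earlier, is only an illustration of such an update in the AROW style; the helper name is assumed and alpha/beta stand for the algorithm's update coefficients computed elsewhere, so this is not the committed Hivemall code:

// Illustrative AROW-like diagonal update that keeps weight and covariance together,
// so close() can safely cast the stored value and forward the extra "covar" column.
WeightValueWithCovar arowUpdate(WeightValueWithCovar old, float x, float alpha, float beta) {
    float w = (old == null) ? 0.f : old.getValue();
    float cov = (old == null) ? 1.f : old.getCovariance(); // per-feature covariance is typically initialized to 1
    float newWeight = w + alpha * cov * x;                 // mean update; alpha carries the label/margin term
    float newCov = cov - beta * cov * cov * x * x;         // confidence grows, so the covariance shrinks
    return new WeightValueWithCovar(newWeight, newCov);
}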
5 changes: 5 additions & 0 deletions src/main/hivemall/classifier/ConfidenceWeightedUDTF.java
@@ -59,6 +59,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
return super.initialize(argOIs);
}

@Override
protected boolean returnCovariance() {
return true;
}

@Override
protected Options getOptions() {
Options opts = super.getOptions();
5 changes: 5 additions & 0 deletions src/main/hivemall/classifier/SoftConfideceWeightedUDTF.java
@@ -61,6 +61,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
return super.initialize(argOIs);
}

@Override
protected boolean returnCovariance() {
return true;
}

@Override
protected Options getOptions() {
Options opts = super.getOptions();
(changes to another file; file name not shown)
@@ -57,6 +57,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
return super.initialize(argOIs);
}

@Override
protected boolean returnCovariance() {
return true;
}

@Override
protected Options getOptions() {
Options opts = super.getOptions();
(changes to another file; file name not shown)
@@ -63,6 +63,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
return super.initialize(argOIs);
}

@Override
protected boolean returnCovariance() {
return true;
}

@Override
protected Options getOptions() {
Options opts = super.getOptions();
(changes to another file; file name not shown)
@@ -30,6 +30,7 @@
import hivemall.common.Margin;
import hivemall.common.PredictionResult;
import hivemall.common.WeightValue;
import hivemall.common.WeightValue.WeightValueWithCovar;
import hivemall.utils.collections.OpenHashMap;
import hivemall.utils.collections.OpenHashMap.IMapIterator;

@@ -99,21 +100,9 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
this.biasKey = null;
}

ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("label");
ObjectInspector labelOI = ObjectInspectorUtils.getStandardObjectInspector(labelRawOI);
fieldOIs.add(labelOI);
fieldNames.add("feature");
ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI);
fieldOIs.add(featureOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);

this.label2FeatureWeight = new HashMap<Object, OpenHashMap<Object, WeightValue>>(64);

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
return getReturnOI(labelRawOI, featureRawOI);
}

@Override
@@ -149,6 +138,26 @@ protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumen
return cl;
}

protected StructObjectInspector getReturnOI(ObjectInspector labelRawOI, ObjectInspector featureRawOI) {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("label");
ObjectInspector labelOI = ObjectInspectorUtils.getStandardObjectInspector(labelRawOI);
fieldOIs.add(labelOI);
fieldNames.add("feature");
ObjectInspector featureOI = ObjectInspectorUtils.getStandardObjectInspector(featureRawOI);
fieldOIs.add(featureOI);
fieldNames.add("weight");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
if(returnCovariance()) {
fieldNames.add("covar");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
}

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}

@Override
public void process(Object[] args) throws HiveException {
List<?> features = (List<?>) featureListOI.getList(args[0]);
@@ -438,19 +447,39 @@ protected void update(List<?> features, float coeff, Object actual_label, Object
@Override
public void close() throws HiveException {
if(label2FeatureWeight != null) {
final Object[] forwardMapObj = new Object[3];
for(Map.Entry<Object, OpenHashMap<Object, WeightValue>> label2map : label2FeatureWeight.entrySet()) {
Object label = label2map.getKey();
forwardMapObj[0] = label;
OpenHashMap<Object, WeightValue> fvmap = label2map.getValue();
IMapIterator<Object, WeightValue> fvmapItor = fvmap.entries();
while(fvmapItor.next() != -1) {
Object k = fvmapItor.unsafeGetAndFreeKey();
WeightValue v = fvmapItor.unsafeGetAndFreeValue();
FloatWritable fv = new FloatWritable(v.getValue());
forwardMapObj[1] = k;
forwardMapObj[2] = fv;
forward(forwardMapObj);
if(returnCovariance()) {
final Object[] forwardMapObj = new Object[4];
for(Map.Entry<Object, OpenHashMap<Object, WeightValue>> label2map : label2FeatureWeight.entrySet()) {
Object label = label2map.getKey();
forwardMapObj[0] = label;
OpenHashMap<Object, WeightValue> fvmap = label2map.getValue();
IMapIterator<Object, WeightValue> fvmapItor = fvmap.entries();
while(fvmapItor.next() != -1) {
Object k = fvmapItor.unsafeGetAndFreeKey();
WeightValueWithCovar v = (WeightValueWithCovar) fvmapItor.unsafeGetAndFreeValue();
FloatWritable fv = new FloatWritable(v.getValue());
FloatWritable cov = new FloatWritable(v.getCovariance());
forwardMapObj[1] = k;
forwardMapObj[2] = fv;
forwardMapObj[3] = cov;
forward(forwardMapObj);
}
}
} else {
final Object[] forwardMapObj = new Object[3];
for(Map.Entry<Object, OpenHashMap<Object, WeightValue>> label2map : label2FeatureWeight.entrySet()) {
Object label = label2map.getKey();
forwardMapObj[0] = label;
OpenHashMap<Object, WeightValue> fvmap = label2map.getValue();
IMapIterator<Object, WeightValue> fvmapItor = fvmap.entries();
while(fvmapItor.next() != -1) {
Object k = fvmapItor.unsafeGetAndFreeKey();
WeightValue v = fvmapItor.unsafeGetAndFreeValue();
FloatWritable fv = new FloatWritable(v.getValue());
forwardMapObj[1] = k;
forwardMapObj[2] = fv;
forward(forwardMapObj);
}
}
}
this.label2FeatureWeight = null;
(changes to another file; file name not shown)
@@ -62,6 +62,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
return super.initialize(argOIs);
}

@Override
protected boolean returnCovariance() {
return true;
}

@Override
protected Options getOptions() {
Options opts = super.getOptions();
5 changes: 5 additions & 0 deletions src/main/hivemall/regression/AROWRegressionUDTF.java
@@ -52,6 +52,11 @@ public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgu
return super.initialize(argOIs);
}

@Override
protected boolean returnCovariance() {
return true;
}

@Override
protected Options getOptions() {
Options opts = super.getOptions();
(remaining changed files in this commit are not shown)
