Skip to content

Commit

Permalink
Add RecordConverter.toRecord(Schema, List<Object>) (#5849)
Browse files Browse the repository at this point in the history
  • Loading branch information
treo authored and Adam Gibson committed Jul 8, 2018
1 parent d73e036 commit 6f4b765
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 9 deletions.
Expand Up @@ -20,17 +20,19 @@
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import lombok.NonNull;
import org.datavec.api.timeseries.util.TimeSeriesWritableUtils;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

/**
* @author Adam Gibson
Expand Down Expand Up @@ -207,6 +209,78 @@ public static List<Writable> toRecord(INDArray array) {
return writables;
}

/**
 * Builds a record, i.e. a {@code List<Writable>} usable with other datavec methods, from a list
 * of plain Java objects. The column meta data of the schema decides which kind of writable wraps
 * each element.
 *
 * @param schema schema whose columns correspond, position by position, to the elements of {@code source}
 * @param source values to wrap; must hold exactly one element per schema column
 * @return a record containing one writable per column, in column order
 * @throws IllegalArgumentException if schema and source differ in length, if an element is rejected
 *         by its column's meta data, if the column's writable type is not handled here, or if an
 *         element cannot be cast to the Java type expected for its column
 */
public static List<Writable> toRecord(Schema schema, List<Object> source){
    final int elementCount = source.size();
    final List<Writable> result = new ArrayList<>(elementCount);
    final List<ColumnMetaData> columns = schema.getColumnMetaData();

    if(columns.size() != elementCount){
        throw new IllegalArgumentException("Schema and source list don't have the same length!");
    }

    for (int position = 0; position < columns.size(); position++) {
        final ColumnMetaData column = columns.get(position);
        final Object value = source.get(position);

        // Let the column's own meta data reject the value before any cast is attempted.
        if(!column.isValid(value)){
            throw new IllegalArgumentException("Element "+position+": "+value+" is not valid for Column \""+column.getName()+"\" ("+column.getColumnType()+")");
        }

        try {
            result.add(wrapForColumn(column, value, position));
        } catch (ClassCastException e) {
            // The value passed validation but is not of the Java type the writable constructor needs.
            throw new IllegalArgumentException(notUsableMessage(position, value, column), e);
        }
    }

    return result;
}

/**
 * Wraps a single value in the writable implementation matching the column's writable type.
 * A failing cast surfaces as a {@link ClassCastException} for the caller to translate.
 */
private static Writable wrapForColumn(ColumnMetaData column, Object value, int position){
    switch (column.getColumnType().getWritableType()){
        case Float:
            return new FloatWritable((Float) value);
        case Double:
            return new DoubleWritable((Double) value);
        case Int:
            return new IntWritable((Integer) value);
        case Byte:
            return new ByteWritable((Byte) value);
        case Boolean:
            return new BooleanWritable((Boolean) value);
        case Long:
            return new LongWritable((Long) value);
        case Null:
            return new NullWritable();
        case Bytes:
            return new BytesWritable((byte[]) value);
        case NDArray:
            return new NDArrayWritable((INDArray) value);
        case Text:
            // Text columns accept three source representations.
            if(value instanceof String) {
                return new Text((String) value);
            }
            if(value instanceof Text) {
                return new Text((Text) value);
            }
            if(value instanceof byte[]) {
                return new Text((byte[]) value);
            }
            throw new IllegalArgumentException(notUsableMessage(position, value, column));
        default:
            throw new IllegalArgumentException(notUsableMessage(position, value, column));
    }
}

/** Formats the error text reported when a value cannot be turned into a writable for its column. */
private static String notUsableMessage(int position, Object value, ColumnMetaData column){
    return "Element "+position+": "+value+" is not usable for Column \""+column.getName()+"\" ("+column.getColumnType()+")";
}

/**
* Convert a DataSet to a matrix
Expand Down
Expand Up @@ -17,18 +17,16 @@
package org.datavec.api.writable;

import com.google.common.collect.Lists;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.collection.CollectionRecordReader;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.util.ndarray.RecordConverter;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.TimeZone;

import static org.junit.Assert.assertEquals;

Expand Down Expand Up @@ -106,4 +104,34 @@ public void testNDArrayWritableConcatToMatrix(){

assertEquals(exp, act);
}

/**
 * Checks that {@code RecordConverter.toRecord(Schema, List<Object>)} produces one writable per
 * schema column and that each writable carries the value of the corresponding source element.
 */
@Test
public void testToRecordWithListOfObject(){
    // One source element per column declared below, in the same order.
    final List<Object> list = Arrays.asList((Object)3, 7.0f, "Foo", "Bar", 1.0, 3f, 3L, 7, 0L);
    final Schema schema = new Schema.Builder()
            .addColumnInteger("a")
            .addColumnFloat("b")
            .addColumnString("c")
            .addColumnCategorical("d", "Bar", "Baz")
            .addColumnDouble("e")
            .addColumnFloat("f")
            .addColumnLong("g")
            .addColumnInteger("h")
            .addColumnTime("i", TimeZone.getDefault())
            .build();

    final List<Writable> record = RecordConverter.toRecord(schema, list);

    // Guard against dropped or extra elements before inspecting individual positions.
    assertEquals(list.size(), record.size());

    // JUnit's contract is assertEquals(expected, actual); keeping that order makes
    // failure messages report the right value as "expected".
    assertEquals(3, record.get(0).toInt());
    assertEquals(7f, record.get(1).toFloat(), 1e-6);
    assertEquals("Foo", record.get(2).toString());
    assertEquals("Bar", record.get(3).toString());
    assertEquals(1.0, record.get(4).toDouble(), 1e-6);
    assertEquals(3f, record.get(5).toFloat(), 1e-6);
    assertEquals(3L, record.get(6).toLong());
    assertEquals(7, record.get(7).toInt());
    assertEquals(0L, record.get(8).toLong());
}
}

0 comments on commit 6f4b765

Please sign in to comment.