Skip to content

Commit

Permalink
#7561 Add NDArrayText(De)Serializer to replace RowVector(De)Serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDBlack committed Apr 19, 2019
1 parent e656c16 commit 68cedd4
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 1 deletion.
Expand Up @@ -34,7 +34,7 @@ public class NDArrayDeSerializer extends JsonDeserializer<INDArray> {
public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException {
JsonNode node = jp.getCodec().readTree(jp);
String field = node.get("array").asText();
INDArray ret = Nd4jBase64.fromBase64(field.toString());
INDArray ret = Nd4jBase64.fromBase64(field);
return ret;

}
Expand Down
@@ -0,0 +1,98 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.nd4j.serde.jackson.shaded;

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.core.JsonParser;
import org.nd4j.shade.jackson.databind.DeserializationContext;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonNode;
import org.nd4j.shade.jackson.databind.node.ArrayNode;

import java.io.IOException;
import java.util.Iterator;

/**
* @author Adam Gibson
*/

public class NDArrayTextDeSerializer extends JsonDeserializer<INDArray> {
@Override
public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException {
JsonNode n = jp.getCodec().readTree(jp);
String dtype = n.get("dataType").asText();
DataType dt = DataType.valueOf(dtype);
ArrayNode shapeNode = (ArrayNode)n.get("shape");
long[] shape = new long[shapeNode.size()];
for( int i=0; i<shape.length; i++ ){
shape[i] = shapeNode.get(i).asLong();
}
ArrayNode dataNode = (ArrayNode)n.get("data");
Iterator<JsonNode> iter = dataNode.elements();
int i=0;
INDArray arr;
switch (dt){
case DOUBLE:
double[] d = new double[dataNode.size()];
while(iter.hasNext())
d[i++] = iter.next().asDouble();
arr = Nd4j.create(d, shape);
break;
case FLOAT:
case HALF:
float[] f = new float[dataNode.size()];
while(iter.hasNext())
f[i++] = iter.next().floatValue();
arr = Nd4j.create(f, shape).castTo(dt);
break;
case LONG:
long[] l = new long[dataNode.size()];
while(iter.hasNext())
l[i++] = iter.next().longValue();
arr = Nd4j.createFromArray(l).reshape('c', shape);
break;
case INT:
case SHORT:
case UBYTE:
int[] a = new int[dataNode.size()];
while(iter.hasNext())
a[i++] = iter.next().intValue();
arr = Nd4j.createFromArray(a).reshape('c', shape).castTo(dt);
break;
case BYTE:
case BOOL:
byte[] b = new byte[dataNode.size()];
while(iter.hasNext())
b[i++] = (byte)iter.next().intValue();
arr = Nd4j.createFromArray(b).reshape('c', shape).castTo(dt);
break;
case UTF8:
String[] s = new String[dataNode.size()];
while(iter.hasNext())
s[i++] = iter.next().asText();
arr = Nd4j.create(s);
break;
case COMPRESSED:
case UNKNOWN:
default:
throw new RuntimeException("Unknown datatype: " + dt);
}
return arr;
}
}
@@ -0,0 +1,89 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/

package org.nd4j.serde.jackson.shaded;


import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.shade.jackson.core.JsonGenerator;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.SerializerProvider;

import java.io.IOException;

/**
* @author Alex Black
*/
public class NDArrayTextSerializer extends JsonSerializer<INDArray> {
@Override
public void serialize(INDArray arr, JsonGenerator jg, SerializerProvider serializerProvider)
throws IOException {
jg.writeStartObject();
jg.writeStringField("dataType", arr.dataType().toString());
jg.writeArrayFieldStart("shape");
for( int i=0; i<arr.rank(); i++ ){
jg.writeNumber(arr.size(i));
}
jg.writeEndArray();
jg.writeArrayFieldStart("data");

if(arr.isView() || arr.ordering() != 'c' || !Shape.hasDefaultStridesForShape(arr) || arr.isCompressed())
arr = arr.dup('c');

switch (arr.dataType()){
case DOUBLE:
double[] d = arr.data().asDouble();
for( double v : d )
jg.writeNumber(v);
break;
case FLOAT:
case HALF:
float[] f = arr.data().asFloat();
for( float v : f )
jg.writeNumber(v);
break;
case LONG:
long[] l = arr.data().asLong();
for( long v : l )
jg.writeNumber(v);
break;
case INT:
case SHORT:
case UBYTE:
int[] i = arr.data().asInt();
for( int v : i )
jg.writeNumber(v);
break;
case BYTE:
case BOOL:
byte[] b = arr.data().asBytes();
for( byte v : b )
jg.writeNumber(v);
break;
case UTF8:
String[] str = new String[(int)arr.length()];
for( int j=0; j<str.length; j++ )
jg.writeString(arr.getStringUnsafe(j));
break;
case COMPRESSED:
case UNKNOWN:
throw new UnsupportedOperationException("Cannot JSON serialize array with datatype: " + arr.dataType());
}
jg.writeEndArray();
}
}
@@ -0,0 +1,54 @@
package org.nd4j.linalg.serde;

import lombok.AllArgsConstructor;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

import static org.junit.Assert.assertEquals;

public class JsonSerdeTests {


@Test
public void testNDArrayTextSerializer() throws Exception {

Nd4j.getRandom().setSeed(12345);
INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4).muli(20).subi(10);

ObjectMapper om = new ObjectMapper();

for(DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT,
DataType.BYTE, DataType.UBYTE, DataType.BOOL}){

INDArray arr = in.castTo(dt);

TestClass tc = new TestClass(arr);

String s = om.writeValueAsString(tc);
System.out.println(dt);
System.out.println(s);
System.out.println("\n\n\n");

INDArray deserialized = om.readValue(s, INDArray.class);
assertEquals(dt.toString(), arr, deserialized);
}

}

@AllArgsConstructor
public static class TestClass {

@JsonDeserialize(using = NDArrayTextDeSerializer.class)
@JsonSerialize(using = NDArrayTextSerializer.class)
public INDArray arr;

}

}

0 comments on commit 68cedd4

Please sign in to comment.