Merge pull request #4870 from deeplearning4j/sa_kerasfix
Fix memory corruption when using Hdf5Archive
AlexDBlack committed Mar 29, 2018
2 parents 5759458 + 0a0c419 commit e2f8b09
Showing 3 changed files with 93 additions and 45 deletions.
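The upshot for callers: Hdf5Archive now implements java.io.Closeable, so the native HDF5 handle can be released deterministically. A minimal usage sketch of the pattern this change enables (the file name and attribute name are hypothetical, and the package in the import is assumed from the module layout):

    import org.deeplearning4j.nn.modelimport.keras.Hdf5Archive; // package assumed

    public class Hdf5ArchiveExample {
        public static void main(String[] args) throws Exception {
            // try-with-resources guarantees close() runs, which deallocates the
            // underlying H5File before HDF5 gets a chance to crash the JVM.
            try (Hdf5Archive archive = new Hdf5Archive("model.h5")) { // hypothetical file
                String json = archive.readAttributeAsJson("model_config"); // hypothetical attribute
                System.out.println(json);
            }
        }
    }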
File 1 of 3: Hdf5Archive.java
@@ -29,21 +29,23 @@
 import org.nd4j.shade.jackson.databind.DeserializationFeature;
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 
+import java.io.Closeable;
 import java.io.IOException;
 import java.lang.Exception;
 import java.util.ArrayList;
 import java.util.List;
 
 import static org.bytedeco.javacpp.hdf5.*;
 
 /**
- * Class for reading ND4J arrays and JSON strings from HDF5
- * achive files.
+ * Class for reading ND4J arrays and JSON strings from HDF5 archive files.
+ *
+ * HDF5 is <i>really</i> sensitive about the order its resources are deallocated in.
+ * Make sure to <b>ALWAYS</b> call {@link #close()} explicitly or with try-with-resources,
+ * or it might decide to crash the JVM.
  *
  * @author dave@skymind.io
  */
 @Slf4j
-public class Hdf5Archive {
+public class Hdf5Archive implements Closeable {
 
     static {
         try {
@@ -55,10 +57,14 @@ public class Hdf5Archive {
     }
 
     private hdf5.H5File file;
-    private hdf5.DataType dataType = new hdf5.DataType(hdf5.PredType.NATIVE_FLOAT());
+    private static hdf5.DataType dataType = new hdf5.DataType(hdf5.PredType.NATIVE_FLOAT());
 
     public Hdf5Archive(String archiveFilename) {
-        this.file = new hdf5.H5File(archiveFilename, H5F_ACC_RDONLY());
+        this.file = new hdf5.H5File(archiveFilename, hdf5.H5F_ACC_RDONLY());
     }
 
+    @Override public void close() {
+        file.deallocate();
+    }
+
     private hdf5.Group[] openGroups(String... groups) {
@@ -103,10 +109,16 @@ public INDArray readDataSet(String datasetName, String... groups) throws Unsuppo
      */
     public String readAttributeAsJson(String attributeName, String... groups)
             throws UnsupportedKerasConfigurationException {
-        if (groups.length == 0)
-            return readAttributeAsJson(this.file.openAttribute(attributeName));
+        if (groups.length == 0) {
+            hdf5.Attribute a = this.file.openAttribute(attributeName);
+            String s = readAttributeAsJson(a);
+            a.deallocate();
+            return s;
+        }
         hdf5.Group[] groupArray = openGroups(groups);
-        String s = readAttributeAsJson(groupArray[groups.length - 1].openAttribute(attributeName));
+        hdf5.Attribute a = groupArray[groups.length - 1].openAttribute(attributeName);
+        String s = readAttributeAsJson(a);
+        a.deallocate();
         closeGroups(groupArray);
         return s;
     }
@@ -121,10 +133,16 @@ public String readAttributeAsJson(String attributeName, String... groups)
      */
     public String readAttributeAsString(String attributeName, String... groups)
             throws UnsupportedKerasConfigurationException {
-        if (groups.length == 0)
-            return readAttributeAsString(this.file.openAttribute(attributeName));
+        if (groups.length == 0) {
+            hdf5.Attribute a = this.file.openAttribute(attributeName);
+            String s = readAttributeAsString(a);
+            a.deallocate();
+            return s;
+        }
         hdf5.Group[] groupArray = openGroups(groups);
-        String s = readAttributeAsString(groupArray[groupArray.length - 1].openAttribute(attributeName));
+        hdf5.Attribute a = groupArray[groups.length - 1].openAttribute(attributeName);
+        String s = readAttributeAsString(a);
+        a.deallocate();
         closeGroups(groupArray);
         return s;
     }
@@ -153,9 +171,9 @@ public boolean hasAttribute(String attributeName, String... groups) {
      */
     public List<String> getDataSets(String... groups) {
         if (groups.length == 0)
-            return getObjects(this.file, H5O_TYPE_DATASET);
+            return getObjects(this.file, hdf5.H5O_TYPE_DATASET);
         hdf5.Group[] groupArray = openGroups(groups);
-        List<String> ls = getObjects(groupArray[groupArray.length - 1], H5O_TYPE_DATASET);
+        List<String> ls = getObjects(groupArray[groupArray.length - 1], hdf5.H5O_TYPE_DATASET);
         closeGroups(groupArray);
         return ls;
     }
@@ -168,9 +186,9 @@ public List<String> getDataSets(String... groups) {
      */
     public List<String> getGroups(String... groups) {
         if (groups.length == 0)
-            return getObjects(this.file, H5O_TYPE_GROUP);
+            return getObjects(this.file, hdf5.H5O_TYPE_GROUP);
         hdf5.Group[] groupArray = openGroups(groups);
-        List<String> ls = getObjects(groupArray[groupArray.length - 1], H5O_TYPE_GROUP);
+        List<String> ls = getObjects(groupArray[groupArray.length - 1], hdf5.H5O_TYPE_GROUP);
         closeGroups(groupArray);
         return ls;
     }
@@ -304,6 +322,7 @@ private String readAttributeAsJson(hdf5.Attribute attribute) throws UnsupportedK
                 throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute");
             }
         }
+        vl.deallocate();
         return s;
     }
 
@@ -342,7 +361,7 @@ private String readAttributeAsString(hdf5.Attribute attribute) throws Unsupporte
                 throw new UnsupportedKerasConfigurationException("Could not read abnormally long HDF5 attribute");
             }
         }
-
+        vl.deallocate();
         return s;
     }
 
@@ -356,7 +375,10 @@ private String readAttributeAsString(hdf5.Attribute attribute) throws Unsupporte
      */
     public String readAttributeAsFixedLengthString(String attributeName, int bufferSize)
             throws UnsupportedKerasConfigurationException {
-        return readAttributeAsFixedLengthString(this.file.openAttribute(attributeName), bufferSize);
+        hdf5.Attribute a = this.file.openAttribute(attributeName);
+        String s = readAttributeAsFixedLengthString(a, bufferSize);
+        a.deallocate();
+        return s;
     }
 
     /**
@@ -373,6 +395,7 @@ private String readAttributeAsFixedLengthString(hdf5.Attribute attribute, int bu
         BytePointer attrPointer = new BytePointer(attrBuffer);
         attribute.read(vl, attrPointer);
         attrPointer.get(attrBuffer);
+        vl.deallocate();
         return new String(attrBuffer);
     }
 }
File 2 of 3: KerasModelBuilder.java
@@ -11,14 +11,15 @@
 import org.nd4j.shade.jackson.databind.ObjectMapper;
 
 import java.io.ByteArrayOutputStream;
+import java.io.Closeable;
 import java.io.IOException;
 import java.io.InputStream;
 import java.nio.file.Files;
 import java.nio.file.Paths;
 import java.util.Map;
 
 @Data
-public class KerasModelBuilder implements Cloneable {
+public class KerasModelBuilder implements Cloneable, Closeable {
     protected String modelJson = null;
     protected String modelYaml = null;
     protected String trainingJson = null;
@@ -81,28 +82,33 @@ public KerasModelBuilder trainingJsonInputStream(InputStream trainingJsonInputSt
 
     public KerasModelBuilder modelHdf5Filename(String modelHdf5Filename)
             throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException, IOException {
-        this.weightsArchive = this.trainingArchive = new Hdf5Archive(modelHdf5Filename);
-        this.weightsRoot = config.getTrainingWeightsRoot();
-        if (!this.weightsArchive.hasAttribute(config.getTrainingModelConfigAttribute()))
-            throw new InvalidKerasConfigurationException(
-                    "Model configuration attribute missing from " + modelHdf5Filename + " archive.");
-        String initialModelJson = this.weightsArchive.readAttributeAsJson(
-                config.getTrainingModelConfigAttribute());
-
-        String kerasVersion = this.weightsArchive.readAttributeAsFixedLengthString(
-                config.getFieldKerasVersion(), 5);
-        Map<String, Object> modelMapper = KerasModelUtils.parseJsonString(initialModelJson);
-        modelMapper.put(config.getFieldKerasVersion(), kerasVersion);
-
-        int majorKerasVersion = Character.getNumericValue(kerasVersion.charAt(0));
-        if (majorKerasVersion == 2) {
-            String backend = this.weightsArchive.readAttributeAsString(config.getFieldBackend());
-            modelMapper.put(config.getFieldBackend(), backend);
+        try {
+            this.weightsArchive = this.trainingArchive = new Hdf5Archive(modelHdf5Filename);
+            this.weightsRoot = config.getTrainingWeightsRoot();
+            if (!this.weightsArchive.hasAttribute(config.getTrainingModelConfigAttribute()))
+                throw new InvalidKerasConfigurationException(
+                        "Model configuration attribute missing from " + modelHdf5Filename + " archive.");
+            String initialModelJson = this.weightsArchive.readAttributeAsJson(
+                    config.getTrainingModelConfigAttribute());
+
+            String kerasVersion = this.weightsArchive.readAttributeAsFixedLengthString(
+                    config.getFieldKerasVersion(), 5);
+            Map<String, Object> modelMapper = KerasModelUtils.parseJsonString(initialModelJson);
+            modelMapper.put(config.getFieldKerasVersion(), kerasVersion);
+
+            int majorKerasVersion = Character.getNumericValue(kerasVersion.charAt(0));
+            if (majorKerasVersion == 2) {
+                String backend = this.weightsArchive.readAttributeAsString(config.getFieldBackend());
+                modelMapper.put(config.getFieldBackend(), backend);
+            }
+
+            this.modelJson = new ObjectMapper().writeValueAsString(modelMapper);
+            if (this.trainingArchive.hasAttribute(config.getTrainingTrainingConfigAttribute()))
+                this.trainingJson = this.trainingArchive.readAttributeAsJson(config.getTrainingTrainingConfigAttribute());
+        } catch (Throwable t) {
+            close();
+            throw t;
         }
-
-        this.modelJson = new ObjectMapper().writeValueAsString(modelMapper);
-        if (this.trainingArchive.hasAttribute(config.getTrainingTrainingConfigAttribute()))
-            this.trainingJson = this.trainingArchive.readAttributeAsJson(config.getTrainingTrainingConfigAttribute());
         return this;
     }

@@ -119,11 +125,26 @@ public KerasModelBuilder enforceTrainingConfig(boolean enforceTrainingConfig) {
 
     public KerasModel buildModel()
             throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
-        return new KerasModel(this);
+        KerasModel m = new KerasModel(this);
+        close();
+        return m;
     }
 
     public KerasSequentialModel buildSequential()
             throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
-        return new KerasSequentialModel(this);
+        KerasSequentialModel m = new KerasSequentialModel(this);
+        close();
+        return m;
     }
+
+    @Override public void close() {
+        if (trainingArchive != null && trainingArchive != weightsArchive) {
+            trainingArchive.close();
+            trainingArchive = null;
+        }
+        if (weightsArchive != null) {
+            weightsArchive.close();
+            weightsArchive = null;
+        }
+    }
 }
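Caller side, the builder now owns both archives until a model is actually built: buildModel()/buildSequential() close them after construction succeeds, and modelHdf5Filename() closes them on failure. A lifecycle sketch (the builder's construction is hypothetical; the rest uses methods shown in this diff):

    KerasModelBuilder builder = new KerasModelBuilder(config) // hypothetical constructor
            .modelHdf5Filename("model.h5");                   // hypothetical path
    try {
        KerasModel model = builder.buildModel(); // closes weightsArchive/trainingArchive
        // ... use model ...
    } finally {
        builder.close(); // safe to repeat: close() nulls both archive fields
    }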
File 3 of 3: KerasModelEndToEndTest.java
@@ -249,7 +249,10 @@ public void importDcganGenerator() throws Exception {
      */
     @Test
     public void importWganDiscriminator() throws Exception {
-        importSequentialModelH5Test("modelimport/keras/examples/gans/wgan_discriminator.h5");
+        for (int i = 0; i < 100; i++) {
+            // run a few times to make sure HDF5 doesn't crash
+            importSequentialModelH5Test("modelimport/keras/examples/gans/wgan_discriminator.h5");
+        }
     }
 
     @Test
@@ -298,7 +301,7 @@ private void importEndModelTest(String modelPath, String inputsOutputsPath, bool
                 KerasModelEndToEndTest.class.getClassLoader());
         File outputsFile = File.createTempFile(TEMP_OUTPUTS_FILENAME, H5_EXTENSION);
         Files.copy(outputsResource.getInputStream(), outputsFile.toPath(), StandardCopyOption.REPLACE_EXISTING);
-        Hdf5Archive outputsArchive = new Hdf5Archive(outputsFile.getAbsolutePath());
+        try (Hdf5Archive outputsArchive = new Hdf5Archive(outputsFile.getAbsolutePath())) {
 
         if (checkPredictions) {
             INDArray input = getInputs(outputsArchive, tfOrdering)[0];
@@ -344,6 +347,7 @@ private void importEndModelTest(String modelPath, String inputsOutputsPath, bool
             }
             checkGradients(model, input, testLabels);
         }
+        }
     }
 
     private static INDArray[] getInputs(Hdf5Archive archive, boolean tensorFlowImageDimOrdering) throws Exception {