Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes and SameDiff functionality (#7807)
* #6992 SameDiff mixed precision training support * Placeholder shape validation * Checkpoint listener * SameDiff checkpoint listener * SameDiff: Remove no longer required trainable params config from TrainingConfig * SameDiff: add name scopes * SameDiff name scopes - javadoc and tests * #7802 Evaluation class - report single class not macro avg in stats() for binary case * 7804 Arbiter - update score functions to use ND4J evaluation metric enums * SameDiff flatbuffers export: don't export arrays for array type variables (not required)
- Loading branch information
1 parent
ce51666
commit 1e2bcc1
Showing
20 changed files
with
1,492 additions
and
236 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
61 changes: 61 additions & 0 deletions
61
...-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/listeners/checkpoint/Checkpoint.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/******************************************************************************* | ||
* Copyright (c) 2015-2018 Skymind, Inc. | ||
* | ||
* This program and the accompanying materials are made available under the | ||
* terms of the Apache License, Version 2.0 which is available at | ||
* https://www.apache.org/licenses/LICENSE-2.0. | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations | ||
* under the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
******************************************************************************/ | ||
|
||
package org.nd4j.autodiff.listeners.checkpoint; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Data; | ||
|
||
import java.io.Serializable; | ||
import java.util.Arrays; | ||
|
||
/** | ||
* A model checkpoint, used with {@link CheckpointListener} | ||
* | ||
* @author Alex Black | ||
*/ | ||
@AllArgsConstructor | ||
@Data | ||
public class Checkpoint implements Serializable { | ||
|
||
private int checkpointNum; | ||
private long timestamp; | ||
private int iteration; | ||
private int epoch; | ||
private String filename; | ||
|
||
public static String getFileHeader(){ | ||
return "checkpointNum,timestamp,iteration,epoch,filename"; | ||
} | ||
|
||
public static Checkpoint fromFileString(String str){ | ||
String[] split = str.split(","); | ||
if(split.length != 5){ | ||
throw new IllegalStateException("Cannot parse checkpoint entry: expected 5 entries, got " + split.length | ||
+ " - values = " + Arrays.toString(split)); | ||
} | ||
return new Checkpoint( | ||
Integer.parseInt(split[0]), | ||
Long.parseLong(split[1]), | ||
Integer.parseInt(split[2]), | ||
Integer.parseInt(split[3]), | ||
split[4]); | ||
} | ||
|
||
public String toFileString(){ | ||
return checkpointNum + "," + timestamp + "," + iteration + "," + epoch + "," + filename; | ||
} | ||
} |
Oops, something went wrong.