Merge pull request #3907 from deeplearning4j/ab_statstorageopt
Optimizations for UI stats storage querying and J7FileStatsStorage
AlexDBlack committed Aug 22, 2017
2 parents 4cab9f8 + b8fdf3d commit 6e4b3f3
Showing 8 changed files with 213 additions and 70 deletions.
@@ -160,6 +160,27 @@ public interface StatsStorage extends StatsStorageRouter {
*/
List<Persistable> getAllUpdatesAfter(String sessionID, String typeID, long timestamp);

/**
* List the times of all updates for the specified sessionID, typeID and workerID
*
* @param sessionID Session ID to get update times for
* @param typeID Type ID to get update times for
* @param workerID Worker ID to get update times for
* @return Times of all updates
*/
long[] getAllUpdateTimes(String sessionID, String typeID, String workerID);

/**
* Get updates for the specified times only
*
* @param sessionID Session ID to get updates for
* @param typeID Type ID to get updates for
* @param workerID Worker ID to get updates for
* @param timestamps Timestamps to get the updates for. Note that if one of the specified times does not exist,
* it will be omitted from the returned results list.
* @return List of updates at the specified times
*/
List<Persistable> getUpdates(String sessionID, String typeID, String workerID, long[] timestamps);
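Taken together, these two methods let a caller fetch the cheap list of timestamps first and then request only the updates it actually needs. A minimal caller-side sketch, assuming some StatsStorage implementation; the helper name lastN and the "most recent n updates" use case are illustrative assumptions, not part of this commit:

import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.api.storage.Persistable;
import org.deeplearning4j.api.storage.StatsStorage;

public class UpdateQueryExample {
    // Hypothetical helper: return only the most recent n updates for one worker.
    // Assumes the timestamps from getAllUpdateTimes are in ascending order.
    public static List<Persistable> lastN(StatsStorage storage, String sessionID,
                                          String typeID, String workerID, int n) {
        long[] times = storage.getAllUpdateTimes(sessionID, typeID, workerID);     // timestamps only
        long[] wanted = Arrays.copyOfRange(times, Math.max(0, times.length - n), times.length);
        return storage.getUpdates(sessionID, typeID, workerID, wanted);            // deserialize just these
    }
}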

/**
* Get the session metadata, if any has been registered via {@link #putStorageMetaData(StorageMetaData)}
17 changes: 12 additions & 5 deletions deeplearning4j-ui-parent/deeplearning4j-play/pom.xml
@@ -104,6 +104,12 @@
</exclusions>
</dependency>

<dependency>
<groupId>org.eclipse.collections</groupId>
<artifactId>eclipse-collections</artifactId>
<version>${eclipse.collections.version}</version>
</dependency>

<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId>
@@ -192,6 +198,12 @@
<version>2.3.13</version>
</dependency>

<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
<version>${jcommander.version}</version>
</dependency>

<!-- Test Scope Dependencies -->
<dependency>
<groupId>junit</groupId>
@@ -204,11 +216,6 @@
<artifactId>logback-classic</artifactId> <!-- Version set by deeplearning4j-parent dependency management -->
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
<version>1.27</version>
</dependency>
</dependencies>


@@ -7,6 +7,7 @@
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.api.storage.StatsStorageEvent;
import org.deeplearning4j.api.storage.StatsStorageListener;
import org.eclipse.collections.impl.list.mutable.primitive.LongArrayList;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
@@ -311,7 +312,7 @@ private Result setWorkerByIdx(String newWorkerIdx) {
try {
currentWorkerIdx = Integer.parseInt(newWorkerIdx);
} catch (NumberFormatException e) {
log.debug("Invaild call to setWorkerByIdx", e);
log.debug("Invalid call to setWorkerByIdx", e);
}
return ok();
}
@@ -373,8 +374,24 @@ private Result getOverviewData() {
result.put("scoresIter", scoresIterCount);

//Get scores info
List<Persistable> updates =
(noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
long[] allTimes = (noData ? null : ss.getAllUpdateTimes(currentSessionID, StatsListener.TYPE_ID, wid));
List<Persistable> updates = null;
if(allTimes != null && allTimes.length > maxChartPoints){
int subsamplingFrequency = allTimes.length / maxChartPoints;
LongArrayList timesToQuery = new LongArrayList(maxChartPoints+2);
int i=0;
for(; i<allTimes.length; i+= subsamplingFrequency){
timesToQuery.add(allTimes[i]);
}
if((i-subsamplingFrequency) != allTimes.length-1){
//Also add final point
timesToQuery.add(allTimes[allTimes.length-1]);
}
updates = ss.getUpdates(currentSessionID, StatsListener.TYPE_ID, wid, timesToQuery.toArray());
} else if(allTimes != null) {
//Don't subsample
updates = ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0);
}
if (updates == null || updates.size() == 0) {
noData = true;
}
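As a standalone restatement of the index selection above (not code from this commit; the class and method names are assumptions), the loop keeps every (allTimes.length / maxChartPoints)-th timestamp and then appends the most recent one if the stride skipped it:

// Illustrative sketch of the subsampling performed in getOverviewData, using only java.util types.
class SubsampleSketch {
    static long[] selectTimesToQuery(long[] allTimes, int maxChartPoints) {
        if (allTimes.length <= maxChartPoints) {
            return allTimes;                                // few enough points: no subsampling needed
        }
        int stride = allTimes.length / maxChartPoints;      // integer division; equals 1 while length < 2*maxChartPoints
        java.util.List<Long> keep = new java.util.ArrayList<>(maxChartPoints + 2);
        int i = 0;
        for (; i < allTimes.length; i += stride) {
            keep.add(allTimes[i]);
        }
        if ((i - stride) != allTimes.length - 1) {
            keep.add(allTimes[allTimes.length - 1]);        // always retain the most recent point
        }
        long[] ret = new long[keep.size()];
        for (int j = 0; j < ret.length; j++) {
            ret[j] = keep.get(j);
        }
        return ret;
    }
}
// Example: 5000 timestamps with maxChartPoints = 512 gives a stride of 9, keeping indices
// 0, 9, ..., 4995 plus the final index 4999, for a total of 557 of the 5000 points.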
@@ -686,57 +703,40 @@ private Result getModelData(String str) {
result.put("layerInfo", layerInfoTable);

//First: get all data, and subsample it if necessary, to avoid returning too many points...
List<Persistable> updates =
(noData ? null : ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0));
long[] allTimes = (noData ? null : ss.getAllUpdateTimes(currentSessionID, StatsListener.TYPE_ID, wid));

List<Persistable> updates = null;
List<Integer> iterationCounts = null;
boolean needToHandleLegacyIterCounts = false;
if (updates != null && updates.size() > maxChartPoints) {
int subsamplingFrequency = updates.size() / maxChartPoints;
List<Persistable> subsampled = new ArrayList<>();
iterationCounts = new ArrayList<>();
int pCount = -1;
int lastUpdateIdx = updates.size() - 1;

int lastIterCount = -1;
for (Persistable p : updates) {
if (!(p instanceof StatsReport))
continue;;
StatsReport sr = (StatsReport) p;
pCount++;

int iterCount = sr.getIterationCount();
if (iterCount <= lastIterCount) {
needToHandleLegacyIterCounts = true;
}
lastIterCount = iterCount;


if (pCount > 0 && subsamplingFrequency > 1 && pCount % subsamplingFrequency != 0) {
//Skip this to subsample the data
if (pCount != lastUpdateIdx)
continue; //Always keep the most recent value
}

subsampled.add(p);
iterationCounts.add(iterCount);
if(allTimes != null && allTimes.length > maxChartPoints){
int subsamplingFrequency = allTimes.length / maxChartPoints;
LongArrayList timesToQuery = new LongArrayList(maxChartPoints+2);
int i=0;
for(; i<allTimes.length; i+= subsamplingFrequency){
timesToQuery.add(allTimes[i]);
}
updates = subsampled;
} else if (updates != null) {
int offset = 0;
iterationCounts = new ArrayList<>(updates.size());
int lastIterCount = -1;
for (Persistable p : updates) {
if (!(p instanceof StatsReport))
continue;;
StatsReport sr = (StatsReport) p;
int iterCount = sr.getIterationCount();

if (iterCount <= lastIterCount) {
needToHandleLegacyIterCounts = true;
}
if((i-subsamplingFrequency) != allTimes.length-1){
//Also add final point
timesToQuery.add(allTimes[allTimes.length-1]);
}
updates = ss.getUpdates(currentSessionID, StatsListener.TYPE_ID, wid, timesToQuery.toArray());
} else if(allTimes != null) {
//Don't subsample
updates = ss.getAllUpdatesAfter(currentSessionID, StatsListener.TYPE_ID, wid, 0);
}

iterationCounts.add(iterCount);
iterationCounts = new ArrayList<>(updates.size());
int lastIterCount = -1;
for (Persistable p : updates) {
if (!(p instanceof StatsReport))
continue;;
StatsReport sr = (StatsReport) p;
int iterCount = sr.getIterationCount();

if (iterCount <= lastIterCount) {
needToHandleLegacyIterCounts = true;
}
iterationCounts.add(iterCount);
}

//Legacy issue - Spark training - iteration counts used to be reset... which means: could go 0,1,2,0,1,2, etc...
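The code that deals with this is collapsed in this view. Purely to illustrate the problem the comment describes (a hypothetical helper, not necessarily what the hidden code does), counts that restart at zero can be made monotonic by accumulating an offset at each detected reset:

// Hypothetical sketch: flatten iteration counts that restart at 0 into an increasing sequence.
class IterationCountSketch {
    static int[] makeMonotonic(java.util.List<Integer> iterationCounts) {
        int[] out = new int[iterationCounts.size()];
        int offset = 0;
        int last = -1;
        for (int idx = 0; idx < out.length; idx++) {
            int c = iterationCounts.get(idx);
            if (c <= last) {
                offset += last + 1;    // reset detected: shift this and later counts past the previous run
            }
            out[idx] = c + offset;
            last = c;
        }
        return out;
    }
}
// makeMonotonic([0, 1, 2, 0, 1, 2]) returns [0, 1, 2, 3, 4, 5].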
@@ -201,7 +201,7 @@ public static GraphInfo buildGraphInfo(NeuralNetConfiguration config) {

Map<String, String> decoderInfo = new LinkedHashMap<>();
inputSize = (i == 0 ? va.getNOut() : decLayerSizes[i - 1]);
outputSize = encLayerSizes[i];
outputSize = decLayerSizes[i];
decoderInfo.put("Input Size", String.valueOf(inputSize));
decoderInfo.put("Layer Size", String.valueOf(outputSize));
decoderInfo.put("Num Parameters", String.valueOf((inputSize + 1) * outputSize));
@@ -308,6 +308,41 @@ public StorageMetaData getStorageMetaData(String sessionID, String typeID) {
return this.storageMetaData.get(new SessionTypeId(sessionID, typeID));
}

@Override
public long[] getAllUpdateTimes(String sessionID, String typeID, String workerID) {
SessionTypeWorkerId stw = new SessionTypeWorkerId(sessionID, typeID, workerID);
Map<Long,Persistable> m = updates.get(stw);
if(m == null){
return new long[0];
}

long[] ret = new long[m.size()];
int i=0;
for(Long l : m.keySet()){
ret[i++] = l;
}
Arrays.sort(ret);
return ret;
}

@Override
public List<Persistable> getUpdates(String sessionID, String typeID, String workerID, long[] timestamps) {
SessionTypeWorkerId stw = new SessionTypeWorkerId(sessionID, typeID, workerID);
Map<Long,Persistable> m = updates.get(stw);
if(m == null){
return Collections.emptyList();
}

List<Persistable> ret = new ArrayList<>(timestamps.length);
for(long l : timestamps){
Persistable p = m.get(l);
if(p != null){
ret.add(p);
}
}
return ret;
}

// ----- Store new info -----

@Override
@@ -2,6 +2,7 @@

import lombok.NonNull;
import org.deeplearning4j.api.storage.*;
import org.eclipse.collections.impl.list.mutable.primitive.LongArrayList;
import org.nd4j.linalg.primitives.Pair;

import java.io.*;
@@ -517,23 +518,26 @@ public Persistable getUpdate(String sessionID, String typeId, String workerID, l

@Override
public List<Persistable> getLatestUpdateAllWorkers(String sessionID, String typeID) {
String sql = "SELECT * FROM " + TABLE_NAME_UPDATES + " t1" + " LEFT JOIN " + TABLE_NAME_UPDATES
+ " t2 ON t1.SessionID = t2.SessionID AND "
+ "t1.TypeID = t2.TypeID AND t1.WorkerID = t2.WorkerID AND t1.Timestamp < t2.Timestamp "
+ "WHERE t2.Timestamp IS NULL AND t1.SessionID = '" + sessionID + "' AND t1.TypeID = '" + typeID
+ "';";
String sql = "SELECT workerId, MAX(Timestamp) FROM " + TABLE_NAME_UPDATES + " WHERE SessionID ='"
+ sessionID + "' AND " + "TypeID = '" + typeID + "' GROUP BY workerId";

Map<String,Long> m = new HashMap<>();
try (Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery(sql);
List<Persistable> out = new ArrayList<>();
while (rs.next()) {
byte[] bytes = rs.getBytes(6);
out.add((Persistable) deserialize(bytes));
String wid = rs.getString(1);
long ts = rs.getLong(2);
m.put(wid, ts);
}
return out;
} catch (SQLException e) {
throw new RuntimeException(e);
}

List<Persistable> out = new ArrayList<>();
for(String s : m.keySet()){
out.add(getUpdate(sessionID, typeID, s, m.get(s)));
}
return out;
}

@Override
@@ -555,13 +559,62 @@ public List<Persistable> getAllUpdatesAfter(String sessionID, String typeID, Str

@Override
public List<Persistable> getAllUpdatesAfter(String sessionID, String typeID, long timestamp) {
String sql = "SELECT * FROM " + TABLE_NAME_UPDATES + " WHERE SessionID = '" + sessionID + "' "
String sql = "SELECT ObjectBytes FROM " + TABLE_NAME_UPDATES + " WHERE SessionID = '" + sessionID + "' "
+ "AND Timestamp > " + timestamp + ";";
return queryUpdates(sql);
}

@Override
public long[] getAllUpdateTimes(String sessionID, String typeID, String workerID) {
/*
statement.executeUpdate("CREATE TABLE " + TABLE_NAME_UPDATES + " (" + "SessionID TEXT NOT NULL, "
+ "TypeID TEXT NOT NULL, " + "WorkerID TEXT NOT NULL, " + "Timestamp INTEGER NOT NULL, "
+ "ObjectClass TEXT NOT NULL, " + "ObjectBytes BLOB NOT NULL, "
+ "PRIMARY KEY ( SessionID, TypeID, WorkerID, Timestamp )" + ");");
*/
String sql = "SELECT Timestamp FROM " + TABLE_NAME_UPDATES + " WHERE SessionID = '" + sessionID + "' "
+ "AND TypeID = '" + typeID + "' AND workerID = '" + workerID + "';";
try (Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery(sql);
LongArrayList list = new LongArrayList();
while (rs.next()) {
list.add(rs.getLong(1));
}
return list.toArray();
} catch (SQLException e) {
throw new RuntimeException(e);
}
}

@Override
public List<Persistable> getUpdates(String sessionID, String typeID, String workerID, long[] timestamps) {
if(timestamps == null || timestamps.length == 0){
return Collections.emptyList();
}

StringBuilder sb = new StringBuilder();
sb.append("SELECT ObjectBytes FROM ").append(TABLE_NAME_UPDATES).append(" WHERE SessionID = '").append(sessionID)
.append("' AND TypeID = '").append(typeID).append("' AND workerID='").append(workerID)
.append("' AND Timestamp IN (");

for( int i=0; i<timestamps.length; i++ ){
if(i > 0){
sb.append(",");
}
sb.append(timestamps[i]);
}
sb.append(");");

String sql = sb.toString();
return queryUpdates(sql);
}

private List<Persistable> queryUpdates(String sql){
try (Statement statement = connection.createStatement()) {
ResultSet rs = statement.executeQuery(sql);
List<Persistable> out = new ArrayList<>();
while (rs.next()) {
byte[] bytes = rs.getBytes(6);
byte[] bytes = rs.getBytes(1);
out.add((Persistable) deserialize(bytes));
}
return out;
