diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 5bf2f6594daf..db9ac02f7eb9 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -214,6 +214,10 @@ public synchronized void reportStorageEvents(Collection event && StatsListener.TYPE_ID.equals(sse.getTypeID()) && !knownSessionIDs.containsKey(sse.getSessionID())) { knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage()); + if (multiSession) { + log.info("Adding training session {}/train/{} of StatsStorage instance {}", + addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage()); + } } Long lastUpdate = lastUpdateForSession.get(sse.getSessionID()); @@ -236,7 +240,11 @@ public synchronized void onAttach(StatsStorage statsStorage) { if (!StatsListener.TYPE_ID.equals(typeID)) continue; knownSessionIDs.put(sessionID, statsStorage); - log.info("Training session attached (onAttach), available at {}/train/{}", addressSupplier.get(), sessionID); + if (multiSession) { + log.info("Adding training session {}/train/{} of StatsStorage instance {}", + addressSupplier.get(), sessionID, statsStorage); + } + List latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID); for (Persistable update: latestUpdates) { long updateTime = update.getTimeStamp(); @@ -264,7 +272,10 @@ public synchronized void onDetach(StatsStorage statsStorage) { } for(String s : toRemove) { knownSessionIDs.remove(s); - log.info("Training session detached, not available any more at {}/train/{}", addressSupplier.get(), s); + if (multiSession) { + log.info("Removing training session {}/train/{} of 
StatsStorage instance {}.", + addressSupplier.get(), s, statsStorage); + } lastUpdateForSession.remove(s); I18NProvider.removeInstance(s); } @@ -1489,51 +1500,54 @@ private static Map getMemory(List staticInfoAllWork } } - for (Persistable p : updatesLastNMinutes) { - //TODO single pass - if (!p.getWorkerID().equals(wid)) - continue; - if (!(p instanceof StatsReport)) - continue; - - StatsReport sp = (StatsReport) p; - - timestamps.add(sp.getTimeStamp()); - - long jvmCurrentBytes = sp.getJvmCurrentBytes(); - long jvmMaxBytes = sp.getJvmMaxBytes(); - long ohCurrentBytes = sp.getOffHeapCurrentBytes(); - long ohMaxBytes = sp.getOffHeapMaxBytes(); - - double jvmFrac = jvmCurrentBytes / ((double) jvmMaxBytes); - double offheapFrac = ohCurrentBytes / ((double) ohMaxBytes); - if (Double.isNaN(jvmFrac)) - jvmFrac = 0.0; - if (Double.isNaN(offheapFrac)) - offheapFrac = 0.0; - fracJvm.add((float) jvmFrac); - fracOffHeap.add((float) offheapFrac); - - lastBytes[0] = jvmCurrentBytes; - lastBytes[1] = ohCurrentBytes; - - lastMaxBytes[0] = jvmMaxBytes; - lastMaxBytes[1] = ohMaxBytes; - - if (numDevices > 0) { - long[] devBytes = sp.getDeviceCurrentBytes(); - long[] devMaxBytes = sp.getDeviceMaxBytes(); - for (int i = 0; i < numDevices; i++) { - double frac = devBytes[i] / ((double) devMaxBytes[i]); - if (Double.isNaN(frac)) - frac = 0.0; - fracDeviceMem.get(i).add((float) frac); - lastBytes[2 + i] = devBytes[i]; - lastMaxBytes[2 + i] = devMaxBytes[i]; + if (updatesLastNMinutes != null) { + for (Persistable p : updatesLastNMinutes) { + //TODO single pass + if (!p.getWorkerID().equals(wid)) + continue; + if (!(p instanceof StatsReport)) + continue; + + StatsReport sp = (StatsReport) p; + + timestamps.add(sp.getTimeStamp()); + + long jvmCurrentBytes = sp.getJvmCurrentBytes(); + long jvmMaxBytes = sp.getJvmMaxBytes(); + long ohCurrentBytes = sp.getOffHeapCurrentBytes(); + long ohMaxBytes = sp.getOffHeapMaxBytes(); + + double jvmFrac = jvmCurrentBytes / ((double) jvmMaxBytes); + 
double offheapFrac = ohCurrentBytes / ((double) ohMaxBytes); + if (Double.isNaN(jvmFrac)) + jvmFrac = 0.0; + if (Double.isNaN(offheapFrac)) + offheapFrac = 0.0; + fracJvm.add((float) jvmFrac); + fracOffHeap.add((float) offheapFrac); + + lastBytes[0] = jvmCurrentBytes; + lastBytes[1] = ohCurrentBytes; + + lastMaxBytes[0] = jvmMaxBytes; + lastMaxBytes[1] = ohMaxBytes; + + if (numDevices > 0) { + long[] devBytes = sp.getDeviceCurrentBytes(); + long[] devMaxBytes = sp.getDeviceMaxBytes(); + for (int i = 0; i < numDevices; i++) { + double frac = devBytes[i] / ((double) devMaxBytes[i]); + if (Double.isNaN(frac)) + frac = 0.0; + fracDeviceMem.get(i).add((float) frac); + lastBytes[2 + i] = devBytes[i]; + lastMaxBytes[2 + i] = devMaxBytes[i]; + } } } } + List> fracUtilized = new ArrayList<>(); fracUtilized.add(fracJvm); fracUtilized.add(fracOffHeap); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java index 29890d40fd29..e06b1143784d 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java @@ -449,7 +449,8 @@ private void runHelper() throws Exception { List callbackTypes = m.getCallbackTypeIDs(); List out = new ArrayList<>(); for (StatsStorageEvent e : events) { - if (callbackTypes.contains(e.getTypeID())) { + if (callbackTypes.contains(e.getTypeID()) + && statsStorageInstances.contains(e.getStatsStorage())) { out.add(e); } } diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUIMultiSession.java 
b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUIMultiSession.java index 493dc0d17b3a..d86e1a055490 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUIMultiSession.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/test/java/org/deeplearning4j/ui/play/TestPlayUIMultiSession.java @@ -34,7 +34,12 @@ import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.function.Function; import org.nd4j.linalg.lossfunctions.LossFunctions; +import play.mvc.Http; +import java.io.UnsupportedEncodingException; +import java.net.HttpURLConnection; +import java.net.URL; +import java.net.URLEncoder; import java.util.HashMap; import static org.junit.Assert.*; @@ -47,16 +52,19 @@ public class TestPlayUIMultiSession { @Test @Ignore - public void testUIMultiSession() throws Exception { + public void testUIMultiSession() { UIServer uiServer = UIServer.getInstance(true, null); + HashMap statStorageForThread = new HashMap<>(); - for (int session = 0; session < 3; session++) { + for (int session = 0; session < 300; session++) { StatsStorage ss = new InMemoryStatsStorage(); final int sid = session; - new Thread(() -> { + final String sessionId = Integer.toString(sid); + + Thread training = new Thread(() -> { int layerSize = sid + 4; MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() @@ -69,7 +77,7 @@ public void testUIMultiSession() throws Exception { net.init(); StatsListener statsListener = new StatsListener(ss); - String sessionId = Integer.toString(sid); + statsListener.setSessionID(sessionId); net.setListeners(statsListener, new ScoreIterationListener(1)); uiServer.attach(ss); @@ -79,33 +87,42 @@ public void testUIMultiSession() throws Exception { for (int i = 0; i < 20; i++) { net.fit(iter); } - try { - 
Thread.sleep(600_000); - } catch (InterruptedException e) { - e.printStackTrace(); - fail(e.getMessage()); - } finally { - uiServer.detach(ss); - } - }).start(); + }); + + training.start(); + statStorageForThread.put(training, ss); + } + + for (Thread thread: statStorageForThread.keySet()) { + StatsStorage ss = statStorageForThread.get(thread); + try { + thread.join(); + assertTrue(uiServer.isAttached(ss)); + } catch (InterruptedException e) { + e.printStackTrace(); + fail(e.getMessage()); + } finally { + uiServer.detach(ss); + assertFalse(uiServer.isAttached(ss)); + } } - Thread.sleep(1_000_000); } @Test @Ignore - public void testUIStatsStorageProvider() throws Exception { + public void testUIAutoAttach() throws Exception { + HashMap statsStorageForSession = new HashMap<>(); - AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(); - UIServer playUIServer = UIServer.getInstance(true, statsProvider); - statsProvider.setUIServer(playUIServer); + Function statsStorageProvider = statsStorageForSession::get; + UIServer uIServer = UIServer.getInstance(true, statsStorageProvider); for (int session = 0; session < 3; session++) { int layerSize = session + 4; + InMemoryStatsStorage ss = new InMemoryStatsStorage(); String sessionId = Integer.toString(session); - statsProvider.put(sessionId, ss); + statsStorageForSession.put(sessionId, ss); MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) @@ -119,7 +136,7 @@ public void testUIStatsStorageProvider() throws Exception { StatsListener statsListener = new StatsListener(ss, 1); statsListener.setSessionID(sessionId); net.setListeners(statsListener, new ScoreIterationListener(1)); - playUIServer.attach(ss); + uIServer.attach(ss); DataSetIterator iter = new IrisDataSetIterator(150, 150); @@ -127,19 +144,88 @@ 
public void testUIStatsStorageProvider() throws Exception { net.fit(iter); } + assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId))); + uIServer.detach(ss); + assertFalse(uIServer.isAttached(statsStorageForSession.get(sessionId))); + /* - * Wait for the first update (containing session ID) to effectively attach StatsStorage in PlayUIServer. - */ - Thread.sleep(1000); + * Visiting /train/:sessionId to auto-attach StatsStorage + */ + String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId); + HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); + conn.connect(); + + assertEquals(Http.Status.OK, conn.getResponseCode()); + assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId))); + } - playUIServer.detach(ss); - System.out.println("To re-attach StatsStorage of training session, visit " - + playUIServer.getAddress() + "/train/" + sessionId); + } + + @Test + @Ignore + public void testUIAutoAttachDetach() throws Exception { + + long autoDetachTimeoutMillis = 30_000; + AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(autoDetachTimeoutMillis); + UIServer uIServer = UIServer.getInstance(true, statsProvider); + statsProvider.setUIServer(uIServer); + + for (int session = 0; session < 3; session++) { + int layerSize = session + 4; + + InMemoryStatsStorage ss = new InMemoryStatsStorage(); + String sessionId = Integer.toString(session); + statsProvider.put(sessionId, ss); + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list() + .layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build()) + .layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT) + .activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build()) + .build(); + + MultiLayerNetwork net = new MultiLayerNetwork(conf); + net.init(); + + StatsListener 
statsListener = new StatsListener(ss, 1); + statsListener.setSessionID(sessionId); + net.setListeners(statsListener, new ScoreIterationListener(1)); + uIServer.attach(ss); + + DataSetIterator iter = new IrisDataSetIterator(150, 150); + + for (int i = 0; i < 20; i++) { + net.fit(iter); + } + + assertTrue(uIServer.isAttached(ss)); + uIServer.detach(ss); + assertFalse(uIServer.isAttached(ss)); + + /* + * Visiting /train/:sessionId to auto-attach StatsStorage + */ + String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId); + HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection(); + conn.connect(); + + assertEquals(Http.Status.OK, conn.getResponseCode()); + assertTrue(uIServer.isAttached(ss)); } Thread.sleep(1_000_000); } + /** + * Get the URL of a training session on the given server address; only the session ID path segment is URL-encoded, so the scheme and host stay parseable by java.net.URL + * @param serverAddress server address, used verbatim (not encoded) + * @param sessionId session ID, URL-encoded as UTF-8 before being appended + * @return URL of the form serverAddress/train/encodedSessionId + * @throws UnsupportedEncodingException if the used encoding is not supported + */ + private static String trainingSessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException { + return String.format("%s/train/%s", serverAddress, URLEncoder.encode(sessionId, "UTF-8")); + } + /** * StatsStorage provider with automatic detaching of StatsStorage after a timeout * @author fenyvesit @@ -149,6 +235,11 @@ private static class AutoDetachingStatsStorageProvider implements Function storageForSession = new HashMap<>(); UIServer uIServer; + long autoDetachTimeoutMillis; + + public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) { + this.autoDetachTimeoutMillis = autoDetachTimeoutMillis; + } public void put(String sessionId, InMemoryStatsStorage statsStorage) { storageForSession.put(sessionId, statsStorage); @@ -163,8 +254,6 @@ public StatsStorage apply(String sessionId) { StatsStorage statsStorage = storageForSession.get(sessionId); if (statsStorage != null) { - // auto-detach StatsStorage instances that will be
attached via this provider - long autoDetachTimeoutMillis = 1000*30; new Thread(() -> { try { System.out.println("Waiting to detach StatsStorage (session ID: " + sessionId + ")" + @@ -173,7 +262,7 @@ public StatsStorage apply(String sessionId) { } catch (InterruptedException e) { e.printStackTrace(); } finally { - System.out.println("Auto-detaching StatsStorage (session ID:" + sessionId + ") after " + + System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " + autoDetachTimeoutMillis + " ms."); uIServer.detach(statsStorage); System.out.println(" To re-attach StatsStorage of training session, visit " +