diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java index 4ec0acbdb673..0a9c99cbcd16 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/module/train/TrainModule.java @@ -56,6 +56,7 @@ import java.text.DecimalFormat; import java.text.SimpleDateFormat; import java.util.*; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicInteger; import static play.mvc.Results.ok; @@ -85,12 +86,12 @@ private enum ModelType { }; private final int maxChartPoints; //Technically, the way it's set up: won't exceed 2*maxChartPoints - private Map knownSessionIDs = Collections.synchronizedMap(new LinkedHashMap<>()); + private Map knownSessionIDs = Collections.synchronizedMap(new WeakHashMap<>()); private String currentSessionID; private int currentWorkerIdx; - private Map workerIdxCount = Collections.synchronizedMap(new HashMap<>()); //Key: session ID - private Map> workerIdxToName = Collections.synchronizedMap(new HashMap<>()); //Key: session ID - private Map lastUpdateForSession = Collections.synchronizedMap(new HashMap<>()); + private Map workerIdxCount = new ConcurrentHashMap<>(); //Key: session ID + private Map> workerIdxToName = new ConcurrentHashMap<>(); //Key: session ID + private Map lastUpdateForSession = new ConcurrentHashMap<>(); public TrainModule() { String maxChartPointsProp = System.getProperty(DL4JSystemProperties.CHART_MAX_POINTS_PROPERTY); @@ -180,19 +181,25 @@ public synchronized void onAttach(StatsStorage statsStorage) { } @Override - public void onDetach(StatsStorage statsStorage) { + public synchronized void onDetach(StatsStorage statsStorage) { + Set toRemove = new HashSet<>(); for (String s : knownSessionIDs.keySet()) { if (knownSessionIDs.get(s) == statsStorage) { - knownSessionIDs.remove(s); +// knownSessionIDs.remove(s); + toRemove.add(s); workerIdxCount.remove(s); workerIdxToName.remove(s); currentSessionID = null; getDefaultSession(); } } + for(String s : toRemove) { +// knownSessionIDs.put(s, null); + knownSessionIDs.remove(s); + } } - private void getDefaultSession() { + private synchronized void getDefaultSession() { if (currentSessionID != null) return; @@ -376,7 +383,7 @@ private static void cleanLegacyIterationCounts(List iterationCounts) { } private Result getOverviewData() { - Long lastUpdate = lastUpdateForSession.get(currentSessionID); + Long lastUpdate = (currentSessionID == null ? null : lastUpdateForSession.get(currentSessionID)); if (lastUpdate == null) lastUpdate = -1L; I18N i18N = I18NProvider.getInstance(); diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java index 408bb3697eb0..c9e2a51895f1 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-play/src/main/java/org/deeplearning4j/ui/play/PlayUIServer.java @@ -420,6 +420,8 @@ private void runHelper() throws Exception { m.reportStorageEvents(out); } + events.clear(); + try { Thread.sleep(uiProcessingDelay); } catch (InterruptedException e) {