Skip to content

Commit

Permalink
Improved tests and fixes for UI
Browse files Browse the repository at this point in the history
Fixes:
- only log attached and detached sessions with URL in multi-session mode
- prevent PlayUIServer from propagating event of a StatsStorage that is not attached yet.
- check for null value of a collection before foreach in TrainModule.getMemory(...)
  • Loading branch information
printomi committed Feb 18, 2019
1 parent 13c70db commit f95a9a9
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 73 deletions.
Expand Up @@ -214,6 +214,10 @@ public synchronized void reportStorageEvents(Collection<StatsStorageEvent> event
&& StatsListener.TYPE_ID.equals(sse.getTypeID())
&& !knownSessionIDs.containsKey(sse.getSessionID())) {
knownSessionIDs.put(sse.getSessionID(), sse.getStatsStorage());
if (multiSession) {
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
addressSupplier.get(), sse.getSessionID(), sse.getStatsStorage());
}
}

Long lastUpdate = lastUpdateForSession.get(sse.getSessionID());
Expand All @@ -236,7 +240,11 @@ public synchronized void onAttach(StatsStorage statsStorage) {
if (!StatsListener.TYPE_ID.equals(typeID))
continue;
knownSessionIDs.put(sessionID, statsStorage);
log.info("Training session attached (onAttach), available at {}/train/{}", addressSupplier.get(), sessionID);
if (multiSession) {
log.info("Adding training session {}/train/{} of StatsStorage instance {}",
addressSupplier.get(), sessionID, statsStorage);
}

List<Persistable> latestUpdates = statsStorage.getLatestUpdateAllWorkers(sessionID, typeID);
for (Persistable update: latestUpdates) {
long updateTime = update.getTimeStamp();
Expand Down Expand Up @@ -264,7 +272,10 @@ public synchronized void onDetach(StatsStorage statsStorage) {
}
for(String s : toRemove) {
knownSessionIDs.remove(s);
log.info("Training session detached, not available any more at {}/train/{}", addressSupplier.get(), s);
if (multiSession) {
log.info("Removing training session {}/train/{} of StatsStorage instance {}.",
addressSupplier.get(), s, statsStorage);
}
lastUpdateForSession.remove(s);
I18NProvider.removeInstance(s);
}
Expand Down Expand Up @@ -1489,51 +1500,54 @@ private static Map<String, Object> getMemory(List<Persistable> staticInfoAllWork
}
}

for (Persistable p : updatesLastNMinutes) {
//TODO single pass
if (!p.getWorkerID().equals(wid))
continue;
if (!(p instanceof StatsReport))
continue;

StatsReport sp = (StatsReport) p;

timestamps.add(sp.getTimeStamp());

long jvmCurrentBytes = sp.getJvmCurrentBytes();
long jvmMaxBytes = sp.getJvmMaxBytes();
long ohCurrentBytes = sp.getOffHeapCurrentBytes();
long ohMaxBytes = sp.getOffHeapMaxBytes();

double jvmFrac = jvmCurrentBytes / ((double) jvmMaxBytes);
double offheapFrac = ohCurrentBytes / ((double) ohMaxBytes);
if (Double.isNaN(jvmFrac))
jvmFrac = 0.0;
if (Double.isNaN(offheapFrac))
offheapFrac = 0.0;
fracJvm.add((float) jvmFrac);
fracOffHeap.add((float) offheapFrac);

lastBytes[0] = jvmCurrentBytes;
lastBytes[1] = ohCurrentBytes;

lastMaxBytes[0] = jvmMaxBytes;
lastMaxBytes[1] = ohMaxBytes;

if (numDevices > 0) {
long[] devBytes = sp.getDeviceCurrentBytes();
long[] devMaxBytes = sp.getDeviceMaxBytes();
for (int i = 0; i < numDevices; i++) {
double frac = devBytes[i] / ((double) devMaxBytes[i]);
if (Double.isNaN(frac))
frac = 0.0;
fracDeviceMem.get(i).add((float) frac);
lastBytes[2 + i] = devBytes[i];
lastMaxBytes[2 + i] = devMaxBytes[i];
if (updatesLastNMinutes != null) {
for (Persistable p : updatesLastNMinutes) {
//TODO single pass
if (!p.getWorkerID().equals(wid))
continue;
if (!(p instanceof StatsReport))
continue;

StatsReport sp = (StatsReport) p;

timestamps.add(sp.getTimeStamp());

long jvmCurrentBytes = sp.getJvmCurrentBytes();
long jvmMaxBytes = sp.getJvmMaxBytes();
long ohCurrentBytes = sp.getOffHeapCurrentBytes();
long ohMaxBytes = sp.getOffHeapMaxBytes();

double jvmFrac = jvmCurrentBytes / ((double) jvmMaxBytes);
double offheapFrac = ohCurrentBytes / ((double) ohMaxBytes);
if (Double.isNaN(jvmFrac))
jvmFrac = 0.0;
if (Double.isNaN(offheapFrac))
offheapFrac = 0.0;
fracJvm.add((float) jvmFrac);
fracOffHeap.add((float) offheapFrac);

lastBytes[0] = jvmCurrentBytes;
lastBytes[1] = ohCurrentBytes;

lastMaxBytes[0] = jvmMaxBytes;
lastMaxBytes[1] = ohMaxBytes;

if (numDevices > 0) {
long[] devBytes = sp.getDeviceCurrentBytes();
long[] devMaxBytes = sp.getDeviceMaxBytes();
for (int i = 0; i < numDevices; i++) {
double frac = devBytes[i] / ((double) devMaxBytes[i]);
if (Double.isNaN(frac))
frac = 0.0;
fracDeviceMem.get(i).add((float) frac);
lastBytes[2 + i] = devBytes[i];
lastMaxBytes[2 + i] = devMaxBytes[i];
}
}
}
}


List<List<Float>> fracUtilized = new ArrayList<>();
fracUtilized.add(fracJvm);
fracUtilized.add(fracOffHeap);
Expand Down
Expand Up @@ -449,7 +449,8 @@ private void runHelper() throws Exception {
List<String> callbackTypes = m.getCallbackTypeIDs();
List<StatsStorageEvent> out = new ArrayList<>();
for (StatsStorageEvent e : events) {
if (callbackTypes.contains(e.getTypeID())) {
if (callbackTypes.contains(e.getTypeID())
&& statsStorageInstances.contains(e.getStatsStorage())) {
out.add(e);
}
}
Expand Down
Expand Up @@ -34,7 +34,12 @@
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import play.mvc.Http;

import java.io.UnsupportedEncodingException;
import java.net.HttpURLConnection;
import java.net.URL;
import java.net.URLEncoder;
import java.util.HashMap;

import static org.junit.Assert.*;
Expand All @@ -47,16 +52,19 @@ public class TestPlayUIMultiSession {

@Test
@Ignore
public void testUIMultiSession() throws Exception {
public void testUIMultiSession() {

UIServer uiServer = UIServer.getInstance(true, null);
HashMap<Thread, StatsStorage> statStorageForThread = new HashMap<>();

for (int session = 0; session < 3; session++) {
for (int session = 0; session < 300; session++) {

StatsStorage ss = new InMemoryStatsStorage();

final int sid = session;
new Thread(() -> {
final String sessionId = Integer.toString(sid);

Thread training = new Thread(() -> {
int layerSize = sid + 4;
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
Expand All @@ -69,7 +77,7 @@ public void testUIMultiSession() throws Exception {
net.init();

StatsListener statsListener = new StatsListener(ss);
String sessionId = Integer.toString(sid);

statsListener.setSessionID(sessionId);
net.setListeners(statsListener, new ScoreIterationListener(1));
uiServer.attach(ss);
Expand All @@ -79,33 +87,42 @@ public void testUIMultiSession() throws Exception {
for (int i = 0; i < 20; i++) {
net.fit(iter);
}
try {
Thread.sleep(600_000);
} catch (InterruptedException e) {
e.printStackTrace();
fail(e.getMessage());
} finally {
uiServer.detach(ss);
}
}).start();
});

training.start();
statStorageForThread.put(training, ss);
}

for (Thread thread: statStorageForThread.keySet()) {
StatsStorage ss = statStorageForThread.get(thread);
try {
thread.join();
assertTrue(uiServer.isAttached(ss));
} catch (InterruptedException e) {
e.printStackTrace();
fail(e.getMessage());
} finally {
uiServer.detach(ss);
assertFalse(uiServer.isAttached(ss));
}
}

Thread.sleep(1_000_000);
}

@Test
@Ignore
public void testUIStatsStorageProvider() throws Exception {
public void testUIAutoAttach() throws Exception {
HashMap<String, StatsStorage> statsStorageForSession = new HashMap<>();

AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider();
UIServer playUIServer = UIServer.getInstance(true, statsProvider);
statsProvider.setUIServer(playUIServer);
Function<String, StatsStorage> statsStorageProvider = statsStorageForSession::get;
UIServer uIServer = UIServer.getInstance(true, statsStorageProvider);

for (int session = 0; session < 3; session++) {
int layerSize = session + 4;

InMemoryStatsStorage ss = new InMemoryStatsStorage();
String sessionId = Integer.toString(session);
statsProvider.put(sessionId, ss);
statsStorageForSession.put(sessionId, ss);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
Expand All @@ -119,27 +136,96 @@ public void testUIStatsStorageProvider() throws Exception {
StatsListener statsListener = new StatsListener(ss, 1);
statsListener.setSessionID(sessionId);
net.setListeners(statsListener, new ScoreIterationListener(1));
playUIServer.attach(ss);
uIServer.attach(ss);

DataSetIterator iter = new IrisDataSetIterator(150, 150);

for (int i = 0; i < 20; i++) {
net.fit(iter);
}

assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId)));
uIServer.detach(ss);
assertFalse(uIServer.isAttached(statsStorageForSession.get(sessionId)));

/*
* Wait for the first update (containing session ID) to effectively attach StatsStorage in PlayUIServer.
*/
Thread.sleep(1000);
* Visiting /train/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();

assertEquals(Http.Status.OK, conn.getResponseCode());
assertTrue(uIServer.isAttached(statsStorageForSession.get(sessionId)));
}

playUIServer.detach(ss);
System.out.println("To re-attach StatsStorage of training session, visit "
+ playUIServer.getAddress() + "/train/" + sessionId);
}

@Test
@Ignore
public void testUIAutoAttachDetach() throws Exception {

long autoDetachTimeoutMillis = 30_000;
AutoDetachingStatsStorageProvider statsProvider = new AutoDetachingStatsStorageProvider(autoDetachTimeoutMillis);
UIServer uIServer = UIServer.getInstance(true, statsProvider);
statsProvider.setUIServer(uIServer);

for (int session = 0; session < 3; session++) {
int layerSize = session + 4;

InMemoryStatsStorage ss = new InMemoryStatsStorage();
String sessionId = Integer.toString(session);
statsProvider.put(sessionId, ss);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).list()
.layer(0, new DenseLayer.Builder().activation(Activation.TANH).nIn(4).nOut(layerSize).build())
.layer(1, new OutputLayer.Builder().lossFunction(LossFunctions.LossFunction.MCXENT)
.activation(Activation.SOFTMAX).nIn(layerSize).nOut(3).build())
.build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();

StatsListener statsListener = new StatsListener(ss, 1);
statsListener.setSessionID(sessionId);
net.setListeners(statsListener, new ScoreIterationListener(1));
uIServer.attach(ss);

DataSetIterator iter = new IrisDataSetIterator(150, 150);

for (int i = 0; i < 20; i++) {
net.fit(iter);
}

assertTrue(uIServer.isAttached(ss));
uIServer.detach(ss);
assertFalse(uIServer.isAttached(ss));

/*
* Visiting /train/:sessionId to auto-attach StatsStorage
*/
String sessionUrl = trainingSessionUrl(uIServer.getAddress(), sessionId);
HttpURLConnection conn = (HttpURLConnection) new URL(sessionUrl).openConnection();
conn.connect();

assertEquals(Http.Status.OK, conn.getResponseCode());
assertTrue(uIServer.isAttached(ss));
}

Thread.sleep(1_000_000);
}

/**
* Get URL-encoded URL for training session on given server address
* @param serverAddress server address
* @param sessionId session ID
* @return URL
* @throws UnsupportedEncodingException if the used encoding is not supported
*/
private static String trainingSessionUrl(String serverAddress, String sessionId) throws UnsupportedEncodingException {
return URLEncoder.encode(String.format("%s/train/%s", serverAddress, sessionId), "UTF-8");
}

/**
* StatsStorage provider with automatic detaching of StatsStorage after a timeout
* @author fenyvesit
Expand All @@ -149,6 +235,11 @@ private static class AutoDetachingStatsStorageProvider implements Function<Strin

HashMap<String, InMemoryStatsStorage> storageForSession = new HashMap<>();
UIServer uIServer;
long autoDetachTimeoutMillis;

public AutoDetachingStatsStorageProvider(long autoDetachTimeoutMillis) {
this.autoDetachTimeoutMillis = autoDetachTimeoutMillis;
}

public void put(String sessionId, InMemoryStatsStorage statsStorage) {
storageForSession.put(sessionId, statsStorage);
Expand All @@ -163,8 +254,6 @@ public StatsStorage apply(String sessionId) {
StatsStorage statsStorage = storageForSession.get(sessionId);

if (statsStorage != null) {
// auto-detach StatsStorage instances that will be attached via this provider
long autoDetachTimeoutMillis = 1000*30;
new Thread(() -> {
try {
System.out.println("Waiting to detach StatsStorage (session ID: " + sessionId + ")" +
Expand All @@ -173,7 +262,7 @@ public StatsStorage apply(String sessionId) {
} catch (InterruptedException e) {
e.printStackTrace();
} finally {
System.out.println("Auto-detaching StatsStorage (session ID:" + sessionId + ") after " +
System.out.println("Auto-detaching StatsStorage (session ID: " + sessionId + ") after " +
autoDetachTimeoutMillis + " ms.");
uIServer.detach(statsStorage);
System.out.println(" To re-attach StatsStorage of training session, visit " +
Expand Down

0 comments on commit f95a9a9

Please sign in to comment.