Skip to content

Commit

Permalink
#36 - All recommendations disappear during training
Browse files Browse the repository at this point in the history
- An empty "incoming" predictions object was created when the training task was scheduled. This caused the bug
- Instead, create the predictions object when training is completed and also at the point submit it as the "incoming" predictions
  • Loading branch information
reckart committed Apr 11, 2018
1 parent a5a9174 commit 8d3ecdd
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 48 deletions.
8 changes: 4 additions & 4 deletions inception-active-learning/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@
<artifactId>webanno-support</artifactId>
</dependency>

<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-ui-core</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-recommendation</artifactId>
Expand Down Expand Up @@ -88,6 +84,10 @@
<groupId>org.apache.wicket</groupId>
<artifactId>wicket-core</artifactId>
</dependency>
<dependency>
<groupId>org.apache.wicket</groupId>
<artifactId>wicket-util</artifactId>
</dependency>
<dependency>
<groupId>org.apache.wicket</groupId>
<artifactId>wicket-spring</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import org.apache.uima.jcas.JCas;
import org.apache.wicket.ajax.AjaxRequestTarget;
import org.apache.wicket.event.Broadcast;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.stereotype.Component;
Expand Down Expand Up @@ -63,6 +65,8 @@ public class RecommendationEditorExtension
{
public static final String BEAN_NAME = "recommendationEditorExtension";

private final Logger log = LoggerFactory.getLogger(getClass());

private @Autowired AnnotationSchemaService annotationService;
private @Autowired RecommendationService recommendationService;
private @Autowired LearningRecordService learningRecordService;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,11 @@
import org.springframework.context.ApplicationContext;
import org.springframework.stereotype.Component;

import de.tudarmstadt.ukp.clarin.webanno.api.AnnotationSchemaService;
import de.tudarmstadt.ukp.clarin.webanno.api.DocumentService;
import de.tudarmstadt.ukp.clarin.webanno.model.Project;
import de.tudarmstadt.ukp.clarin.webanno.security.model.User;
import de.tudarmstadt.ukp.inception.recommendation.model.Predictions;
import de.tudarmstadt.ukp.inception.recommendation.scheduling.tasks.SelectionTask;
import de.tudarmstadt.ukp.inception.recommendation.scheduling.tasks.Task;
import de.tudarmstadt.ukp.inception.recommendation.scheduling.tasks.TrainingTask;
import de.tudarmstadt.ukp.inception.recommendation.service.RecommendationService;

/**
* Used to run the selection, training and prediction task concurrently.
Expand All @@ -50,10 +46,6 @@ public class RecommendationScheduler

private @Autowired ApplicationContext applicationContext;

private @Autowired RecommendationService recService;
private @Autowired DocumentService docService;
private @Autowired AnnotationSchemaService annoService;

private Thread consumer;
private BlockingQueue<Task> queue = new ArrayBlockingQueue<Task>(100);
private int counter = 0;
Expand All @@ -74,15 +66,15 @@ public void destroy()
consumer.interrupt();
}

public void enqueueTask(User user, Project project, Predictions model)
public void enqueueTask(User user, Project project)
{
// Add Selection Task
if (counter % 2 == 0) {
enqueue(new SelectionTask(user, project));
}

// Add Training (which in turn will later enqueue the prediction Task)
enqueue(new TrainingTask(user, project, model));
enqueue(new TrainingTask(user, project));

counter++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.scheduling.tasks;

import static org.apache.commons.lang3.Validate.notNull;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -52,19 +50,13 @@ public class PredictionTask
{
private Logger log = LoggerFactory.getLogger(getClass());

private Predictions model;

private @Autowired AnnotationSchemaService annoService;
private @Autowired RecommendationService recommendationService;
private @Autowired DocumentService documentService;

public PredictionTask(User aUser, Project aProject, Predictions aPredictions)
public PredictionTask(User aUser, Project aProject)
{
super(aProject, aUser);

notNull(aPredictions);

model = aPredictions;
}

@Override
Expand Down Expand Up @@ -120,7 +112,11 @@ public void run()
recommender.getName());
List<AnnotationObject> predictions = classifier.predict(tokens, layer);
predictions.forEach(token -> token.setRecommenderId(ct.getId()));

Predictions model = new Predictions(getProject(), getUser());
model.putPredictions(layer.getId(), predictions);
recommendationService.putIncomingPredictions(getUser(), getProject(), model);

log.info("[{}][{}]: Prediction complete ({} ms)", user.getUsername(),
recommender.getName(), (System.currentTimeMillis() - startTime));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.scheduling.tasks;

import static org.apache.commons.lang3.Validate.notNull;

import java.io.IOException;
import java.util.LinkedList;
import java.util.List;
Expand All @@ -39,7 +37,6 @@
import de.tudarmstadt.ukp.inception.recommendation.imls.core.dataobjects.AnnotationObject;
import de.tudarmstadt.ukp.inception.recommendation.imls.core.loader.AnnotationObjectLoader;
import de.tudarmstadt.ukp.inception.recommendation.imls.core.trainer.Trainer;
import de.tudarmstadt.ukp.inception.recommendation.model.Predictions;
import de.tudarmstadt.ukp.inception.recommendation.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.scheduling.RecommendationScheduler;
import de.tudarmstadt.ukp.inception.recommendation.service.RecommendationService;
Expand All @@ -52,20 +49,14 @@ public class TrainingTask
{
private final Logger log = LoggerFactory.getLogger(getClass());

private final Predictions predictions;

private @Autowired AnnotationSchemaService annoService;
private @Autowired DocumentService documentService;
private @Autowired RecommendationService recommendationService;
private @Autowired RecommendationScheduler recommendationScheduler;

public TrainingTask(User aUser, Project aProject, Predictions aPredictions)
public TrainingTask(User aUser, Project aProject)
{
super(aProject, aUser);

notNull(aPredictions);

predictions = aPredictions;
}

@Override
Expand Down Expand Up @@ -97,8 +88,7 @@ public void run()

log.info("[{}][{}]: Extracting training data...", user.getUsername(),
recommender.getName());
List<List<AnnotationObject>> trainingData = getTrainingData(classificationTool,
predictions);
List<List<AnnotationObject>> trainingData = getTrainingData(classificationTool);

if (trainingData == null || trainingData.isEmpty()) {
log.info("[{}][{}]: No training data.", user.getUsername(),
Expand All @@ -115,11 +105,10 @@ public void run()
}
}

recommendationScheduler.enqueue(new PredictionTask(user, getProject(), predictions));
recommendationScheduler.enqueue(new PredictionTask(user, getProject()));
}

private List<List<AnnotationObject>> getTrainingData(ClassificationTool<?> tool,
Predictions model)
private List<List<AnnotationObject>> getTrainingData(ClassificationTool<?> tool)
{
List<List<AnnotationObject>> result = new LinkedList<>();

Expand All @@ -128,7 +117,7 @@ private List<List<AnnotationObject>> getTrainingData(ClassificationTool<?> tool,
return result;
}

Project p = model.getProject();
Project p = getProject();
List<SourceDocument> docs = documentService.listSourceDocuments(p);

for (SourceDocument doc : docs) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,5 +76,7 @@ void setActiveRecommenders(User aUser, AnnotationLayer layer,

Predictions getIncomingPredictions(User aUser, Project aProject);

void putIncomingPredictions(User aUser, Project aProject, Predictions aPredictions);

void switchPredictions(User aUser, Project aProject);
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,19 @@ public Predictions getIncomingPredictions(User aUser, Project aProject)
Predictions predictions;
synchronized (state) {
predictions = state.getIncomingPredictions(aProject);
if (predictions == null) {
predictions = new Predictions(aProject, aUser);
state.putIncomingPredictions(aProject, predictions);
}
}
return predictions;
}

@Override
public void putIncomingPredictions(User aUser, Project aProject, Predictions aPredictions)
{
RecommendationUserState state = getState(aUser.getUsername());
synchronized (state) {
state.putIncomingPredictions(aProject, aPredictions);
}
}

@Override
public void setActiveRecommenders(User aUser, AnnotationLayer aLayer,
List<Recommender> aRecommenders)
Expand Down Expand Up @@ -211,11 +216,7 @@ public void onDocumentOpen(DocumentOpenedEvent aEvent) throws Exception
private void triggerTrainingAndClassification(String aUser, Project aProject)
{
User user = userRepository.get(aUser);
Predictions model = getIncomingPredictions(user, aProject);
if (model == null) {
return;
}
scheduler.enqueueTask(user, aProject, model);
scheduler.enqueueTask(user, aProject);
}

@EventListener
Expand Down

0 comments on commit 8d3ecdd

Please sign in to comment.