Skip to content

Commit

Permalink
wip ml-based suggestions ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
oserikov committed Jun 11, 2018
1 parent 058e51c commit 1e9d027
Show file tree
Hide file tree
Showing 6 changed files with 336 additions and 2 deletions.
7 changes: 7 additions & 0 deletions languagetool-core/pom.xml
Expand Up @@ -189,6 +189,13 @@
<scope>test</scope>
</dependency>

<!-- XGBoost JVM bindings (xgboost4j); MorfologikSpellerRule imports ml.dmlc.xgboost4j.java.* to load and query its suggestion-ranking model. -->
<!-- NOTE(review): this is a SNAPSHOT version; confirm it is resolvable from the repositories configured for this build. -->
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.72-SNAPSHOT</version>
</dependency>

</dependencies>

</project>
@@ -0,0 +1,17 @@
package org.languagetool.rules.ngrams;

import org.languagetool.Language;

import java.util.*;

/**
 * Convenience helpers around {@link GoogleToken} for callers that only need the
 * plain token strings.
 */
public class GoogleTokenUtil {

  /**
   * Tokenizes {@code sentence} with the word tokenizer of {@code language} and returns
   * the text of every resulting {@link GoogleToken}, in order.
   *
   * @param sentence      text to tokenize
   * @param addStartToken forwarded unchanged to {@code GoogleToken.getGoogleTokens}
   * @param language      language whose word tokenizer is used
   * @return a new mutable list with one string per token
   */
  public static List<String> getGoogleTokensForString(String sentence, boolean addStartToken, Language language) {
    // ArrayList rather than LinkedList: consumers query size() and take subList() views,
    // both of which are cheaper on an array-backed list.
    List<String> tokens = new ArrayList<>();
    for (GoogleToken token : GoogleToken.getGoogleTokens(sentence, addStartToken, language.getWordTokenizer())) {
      tokens.add(token.token);
    }
    return tokens;
  }
}
Expand Up @@ -19,23 +19,29 @@

package org.languagetool.rules.spelling.morfologik;

import ml.dmlc.xgboost4j.java.*;
import org.apache.commons.lang3.tuple.Pair;
import org.jetbrains.annotations.Nullable;
import org.languagetool.AnalyzedSentence;
import org.languagetool.AnalyzedTokenReadings;
import org.languagetool.JLanguageTool;
import org.languagetool.Language;
import org.languagetool.languagemodel.LanguageModel;
import org.languagetool.rules.Categories;
import org.languagetool.rules.ITSIssueType;
import org.languagetool.rules.RuleMatch;
import org.languagetool.rules.spelling.SpellingCheckRule;
import org.languagetool.rules.ngrams.GoogleTokenUtil;

import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Paths;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public abstract class MorfologikSpellerRule extends SpellingCheckRule {

// Context length passed to processRow when scoring a suggestion: the maximum number of
// neighbouring words requested on each side of the misspelled word.
private static final Integer DEFAULT_CONTEXT_LENGTH = 2;
protected MorfologikMultiSpeller speller1;
protected MorfologikMultiSpeller speller2;
protected MorfologikMultiSpeller speller3;
Expand All @@ -45,6 +51,10 @@ public abstract class MorfologikSpellerRule extends SpellingCheckRule {
private boolean checkCompound = false;
private Pattern compoundRegex = Pattern.compile("-");

// Classpath prefix of the XGBoost models; the constructor appends "<rule id>/spc.model" to it.
private static final String XGBOOST_MODEL_BASE_PATH = "org/languagetool/resource/speller_rule/models/";
// NOTE(review): machine-specific path that is not referenced anywhere in the visible code;
// NGramUtil reads the "ngram.path" system property instead. Confirm whether this is still needed.
private static final String DEFAULT_PATH_TO_NGRAMS = "/home/ec2-user/ngram"; //TODO
// Both fields are static but are (re)assigned by every constructor call, so all rule
// instances share whichever n-gram helper and model were loaded last.
private static NGramUtil nGramUtil;
private static Booster booster;
/**
* Get the filename, e.g., <tt>/resource/pl/spelling.dict</tt>.
*/
Expand All @@ -59,6 +69,12 @@ public MorfologikSpellerRule(ResourceBundle messages, Language language) throws
this.conversionLocale = conversionLocale != null ? conversionLocale : Locale.getDefault();
init();
setLocQualityIssueType(ITSIssueType.Misspelling);
nGramUtil = new NGramUtil(language);
try (InputStream models_path = this.getClass().getClassLoader().getResourceAsStream(XGBOOST_MODEL_BASE_PATH + this.getId() + "/spc.model")) {
booster = XGBoost.loadModel(models_path);
} catch (XGBoostError xgBoostError) {
throw new RuntimeException("error when loading xgboost model for " + this.getId());
}
}

@Override
Expand Down Expand Up @@ -193,7 +209,7 @@ protected List<RuleMatch> getRuleMatches(String word, int startPos, AnalyzedSent
suggestions.addAll(getAdditionalSuggestions(suggestions, word));
if (!suggestions.isEmpty()) {
filterSuggestions(suggestions);
ruleMatch.setSuggestedReplacements(orderSuggestions(suggestions, word));
ruleMatch.setSuggestedReplacements(orderSuggestions(suggestions, word, sentence, startPos, word.length()));
}
ruleMatches.add(ruleMatch);
}
Expand All @@ -216,6 +232,96 @@ protected List<String> orderSuggestions(List<String> suggestions, String word) {
return suggestions;
}

/**
 * Orders spelling suggestions by the score the model assigns to each of them in the
 * context of the sentence, best (highest score) first.
 * <p>
 * For every suggestion the misspelled word is replaced inside the sentence text and the
 * original/corrected pair is scored by {@code processRow}. Suggestions with equal scores
 * keep their incoming relative order because the sort is stable.
 *
 * @param suggestions candidate replacements for {@code word}
 * @param word        the misspelled word as it appears in the sentence
 * @param sentence    sentence containing the misspelled word
 * @param startPos    offset of {@code word} inside {@code sentence.getText()}
 * @param wordLength  length of the misspelled span starting at {@code startPos}
 * @return a new list holding the same suggestions, sorted by descending score
 */
protected List<String> orderSuggestions(List<String> suggestions, String word, AnalyzedSentence sentence, int startPos, int wordLength) {
  if (suggestions.isEmpty()) {
    // Nothing to score; also avoids touching the sentence text at all, as before.
    return new ArrayList<>();
  }

  // The text around the misspelled word is the same for every suggestion, so cut it once.
  String text = sentence.getText();
  String beforeWord = text.substring(0, startPos);
  String afterWord = text.substring(startPos + wordLength);

  List<Pair<String, Float>> scoredSuggestions = new ArrayList<>(suggestions.size());
  for (String suggestion : suggestions) {
    String correctedSentence = beforeWord + suggestion + afterWord;
    // Best effort: a suggestion that cannot be scored keeps the neutral score 0
    // instead of aborting the whole spell check.
    float score = 0;
    try {
      score = processRow(text, correctedSentence, word, suggestion, startPos, DEFAULT_CONTEXT_LENGTH);
    } catch (IOException e) {
      e.printStackTrace();
    }
    scoredSuggestions.add(Pair.of(suggestion, score));
  }

  Comparator<Pair<String, Float>> byScore = Comparator.comparing(Pair::getValue);
  scoredSuggestions.sort(byScore.reversed());

  List<String> ordered = new ArrayList<>(scoredSuggestions.size());
  for (Pair<String, Float> scored : scoredSuggestions) {
    ordered.add(scored.getKey());
  }
  return ordered;
}


/**
 * Builds the feature vector for one (misspelled word, replacement) pair and returns the
 * score predicted for it by the loaded XGBoost model.
 * <p>
 * The features, in the order they are passed to the model, are: token count and n-gram
 * pseudo-probability of the left and right context around the misspelled word, the same
 * four values with the replacement substituted in, whether the first letters match
 * (1 or 0), and the edit distance between the two words. When no context can be
 * extracted, the bare word (or bare replacement) is used in place of the context.
 *
 * @param sentence          original sentence text
 * @param correctedSentence sentence text with the replacement applied
 * @param covered           the misspelled word
 * @param replacement       the suggested replacement
 * @param suggestionPos     offset of the misspelled word (currently not read, see note below)
 * @param contextLength     maximum number of context words to take on each side
 * @return the model's prediction, or {@code -1} if XGBoost reports an error
 * @throws IOException declared for callers; nothing in the visible body throws it directly
 */
private static float processRow(String sentence, String correctedSentence, String covered, String replacement,
                                Integer suggestionPos, Integer contextLength) throws IOException {
  // NOTE(review): suggestionPos is ignored; the error offset is re-derived by diffing the two
  // sentences. Confirm whether the feature extraction is meant to use the known offset instead.
  Pair<String, String> context = Pair.of("", "");
  int sentencesDifferenceCharIdx = Utils.firstDifferencePosition(sentence, correctedSentence);
  if (sentencesDifferenceCharIdx != -1) {
    int errorStartIdx = Utils.startOfErrorString(sentence, covered, sentencesDifferenceCharIdx);
    if (errorStartIdx != -1) {
      context = Utils.extractContext(sentence, covered, errorStartIdx, contextLength);
    }
  }

  String leftContextCovered = context.getKey();
  String rightContextCovered = context.getValue();

  // The left context ends with the misspelled word and the right context starts with it,
  // so the corrected variants are obtained by swapping that word for the replacement.
  String leftContextCorrection = leftContextCovered.isEmpty()
      ? ""
      : leftContextCovered.substring(0, leftContextCovered.length() - covered.length()) + replacement;
  String rightContextCorrection = rightContextCovered.isEmpty()
      ? ""
      : replacement + rightContextCovered.substring(covered.length());

  boolean firstLetterMatches = Utils.longestCommonPrefix(new String[]{replacement, covered}).length() != 0;
  int editDistance = Utils.editDisctance(covered, replacement);

  List<String> leftCoveredTokens = nGramUtil.tokenizeString(leftContextCovered.isEmpty() ? covered : leftContextCovered);
  float leftCoveredProba = ngramProbabilityOrZero(leftCoveredTokens);
  List<String> rightCoveredTokens = nGramUtil.tokenizeString(rightContextCovered.isEmpty() ? covered : rightContextCovered);
  float rightCoveredProba = ngramProbabilityOrZero(rightCoveredTokens);

  List<String> leftCorrectionTokens = nGramUtil.tokenizeString(leftContextCorrection.isEmpty() ? replacement : leftContextCorrection);
  float leftCorrectionProba = ngramProbabilityOrZero(leftCorrectionTokens);
  List<String> rightCorrectionTokens = nGramUtil.tokenizeString(rightContextCorrection.isEmpty() ? replacement : rightContextCorrection);
  float rightCorrectionProba = ngramProbabilityOrZero(rightCorrectionTokens);

  // Feature order must stay exactly as the model expects it.
  float[] data = {
      leftCoveredTokens.size(), leftCoveredProba,
      rightCoveredTokens.size(), rightCoveredProba,
      leftCorrectionTokens.size(), leftCorrectionProba,
      rightCorrectionTokens.size(), rightCorrectionProba,
      firstLetterMatches ? 1f : 0f,
      editDistance
  };

  DMatrix features = null;
  try {
    features = new DMatrix(data, 1, data.length);
    return booster.predict(features)[0][0];
  } catch (XGBoostError xgBoostError) {
    // Best effort: report the failure and fall back to the sentinel score.
    xgBoostError.printStackTrace();
    return -1;
  } finally {
    // Release the native matrix right away instead of waiting for finalization.
    if (features != null) {
      features.dispose();
    }
  }
}

/**
 * Returns the trigram-limited pseudo-probability of {@code tokens} as a float, mapping the
 * {@code null} that {@code NGramUtil.stringProbability} yields for an empty token list to
 * {@code 0} so that auto-unboxing cannot throw a NullPointerException.
 */
private static float ngramProbabilityOrZero(List<String> tokens) {
  Double probability = nGramUtil.stringProbability(tokens, 3);
  return probability == null ? 0f : probability.floatValue();
}


/**
* @param checkCompound If true and the word is not in the dictionary
* it will be split (see {@link #setCompoundRegex(String)})
Expand Down Expand Up @@ -253,3 +359,199 @@ protected boolean isSurrogatePairCombination (String word) {
return false;
}
}

/**
 * String helpers used to extract the textual context of a misspelled word and to compute
 * simple similarity features between the word and a replacement candidate.
 */
class Utils {

  /**
   * Returns the misspelled word together with the {@code contextLength} words preceding it.
   *
   * @param originalSentence sentence containing the error
   * @param errorStartIdx    offset of {@code errorString} in the sentence
   * @param errorString      the misspelled word, matched literally
   * @param contextLength    exact number of preceding words required
   * @return the matched text ending with {@code errorString}, or {@code ""} if the sentence
   *         does not have that many words before the error
   */
  public static String leftContext(String originalSentence, int errorStartIdx, String errorString, int contextLength) {
    // Quote the word: it is user text and may contain regex metacharacters such as '+' or '('.
    // NOTE(review): \w and \W are ASCII-only without UNICODE_CHARACTER_CLASS; confirm this is intended.
    String regex = repeat(contextLength, "\\w+\\W+") + Pattern.quote(errorString) + "$";
    String stringToSearch = originalSentence.substring(0, errorStartIdx + errorString.length());

    return findFirstRegexMatch(regex, stringToSearch);
  }

  /**
   * Returns the misspelled word together with the {@code contextLength} words following it.
   *
   * @param originalSentence sentence containing the error
   * @param errorStartIdx    offset of {@code errorString} in the sentence
   * @param errorString      the misspelled word, matched literally
   * @param contextLength    exact number of following words required
   * @return the matched text starting with {@code errorString}, or {@code ""} if the sentence
   *         does not have that many words after the error
   */
  public static String rightContext(String originalSentence, int errorStartIdx, String errorString, int contextLength) {
    // Quote the word for the same reason as in leftContext.
    String regex = "^" + Pattern.quote(errorString) + repeat(contextLength, "\\W+\\w+");
    String stringToSearch = originalSentence.substring(errorStartIdx);

    return findFirstRegexMatch(regex, stringToSearch);
  }

  /**
   * Returns the first index at which {@code sentence1} differs from {@code sentence2}.
   *
   * @return the index of the first differing character of {@code sentence1}, or {@code -1}
   *         when {@code sentence1} is equal to, or a prefix of, {@code sentence2}
   */
  public static int firstDifferencePosition(String sentence1, String sentence2) {
    for (int i = 0; i < sentence1.length(); i++) {
      if (i >= sentence2.length() || sentence1.charAt(i) != sentence2.charAt(i)) {
        return i;
      }
    }
    return -1;
  }

  /**
   * Locates the start of {@code errorString} in {@code sentence}, given an index that is
   * known to lie inside the error. Every occurrence of the character found at that index
   * is tried as the alignment point, in order of its position within {@code errorString}.
   *
   * @return the start offset of the first alignment whose text equals {@code errorString},
   *         or {@code -1} if none does
   */
  public static int startOfErrorString(String sentence, String errorString, int sentencesDifferenceCharIdx) {
    char pivot = sentence.charAt(sentencesDifferenceCharIdx);
    for (int offsetInError : allIndexesOf(pivot, errorString)) {
      int candidateStart = sentencesDifferenceCharIdx - offsetInError;
      int candidateEnd = candidateStart + errorString.length();
      if (candidateStart < 0 || candidateEnd > sentence.length()) {
        continue;
      }
      if (sentence.substring(candidateStart, candidateEnd).equals(errorString)) {
        return candidateStart;
      }
    }
    return -1;
  }

  /**
   * Returns the longest available right context, trying {@code startingContextLength}
   * words first and shrinking by one word until a match is found.
   *
   * @return the right context, or {@code ""} if not even one following word exists
   */
  public static String getMaximalPossibleRightContext(String sentence, int errorStartIdx, String errorString,
                                                      int startingContextLength) {
    String rightContext = "";
    for (int contextLength = startingContextLength; contextLength > 0; contextLength--) {
      rightContext = rightContext(sentence, errorStartIdx, errorString, contextLength);
      if (!rightContext.isEmpty()) {
        break;
      }
    }
    return rightContext;
  }

  /**
   * Returns the longest available left context, trying {@code startingContextLength}
   * words first and shrinking by one word until a match is found.
   *
   * @return the left context, or {@code ""} if not even one preceding word exists
   */
  public static String getMaximalPossibleLeftContext(String sentence, int errorStartIdx, String errorString,
                                                     int startingContextLength) {
    String leftContext = "";
    for (int contextLength = startingContextLength; contextLength > 0; contextLength--) {
      leftContext = leftContext(sentence, errorStartIdx, errorString, contextLength);
      if (!leftContext.isEmpty()) {
        break;
      }
    }
    return leftContext;
  }

  /**
   * Extracts the left and right context of the error that starts at {@code errorStartIdx}
   * and is as long as {@code covered}.
   *
   * @return a pair of (left context, right context); either side may be {@code ""}
   */
  public static Pair<String, String> extractContext(String sentence, String covered, int errorStartIdx, int contextLength) {
    int errorEndIdx = errorStartIdx + covered.length();
    String errorString = sentence.substring(errorStartIdx, errorEndIdx);

    String leftContext = getMaximalPossibleLeftContext(sentence, errorStartIdx, errorString, contextLength);
    String rightContext = getMaximalPossibleRightContext(sentence, errorStartIdx, errorString, contextLength);

    return Pair.of(leftContext, rightContext);
  }

  /**
   * Returns the longest prefix shared by all strings in {@code strs}.
   *
   * @return the common prefix; {@code ""} for a {@code null} or empty array, and the
   *         string itself for a single-element array
   */
  public static String longestCommonPrefix(String[] strs) {
    if (strs == null || strs.length == 0) {
      return "";
    }

    // Shrink the candidate prefix against each further string. The previous implementation
    // seeded its length bound with the array size, which truncated long common prefixes.
    String prefix = strs[0];
    for (int k = 1; k < strs.length; k++) {
      String other = strs[k];
      int limit = Math.min(prefix.length(), other.length());
      int common = 0;
      while (common < limit && prefix.charAt(common) == other.charAt(common)) {
        common++;
      }
      prefix = prefix.substring(0, common);
    }
    return prefix;
  }

  /**
   * Computes the Levenshtein distance (insertions, deletions and substitutions, each at
   * cost 1) between {@code x} and {@code y}.
   */
  public static int editDisctance(String x, String y) {
    int[][] dp = new int[x.length() + 1][y.length() + 1];

    for (int i = 0; i <= x.length(); i++) {
      for (int j = 0; j <= y.length(); j++) {
        if (i == 0) {
          dp[i][j] = j;
        } else if (j == 0) {
          dp[i][j] = i;
        } else {
          int substitution = dp[i - 1][j - 1] + costOfSubstitution(x.charAt(i - 1), y.charAt(j - 1));
          int deletion = dp[i - 1][j] + 1;
          int insertion = dp[i][j - 1] + 1;
          dp[i][j] = Math.min(substitution, Math.min(deletion, insertion));
        }
      }
    }

    return dp[x.length()][y.length()];
  }

  /** Returns 0 when the characters are equal and 1 otherwise. */
  private static int costOfSubstitution(char a, char b) {
    return a == b ? 0 : 1;
  }

  /** Returns the text of the first match of {@code regex} in {@code stringToSearch}, or {@code ""}. */
  private static String findFirstRegexMatch(String regex, String stringToSearch) {
    Matcher matcher = Pattern.compile(regex).matcher(stringToSearch);
    return matcher.find() ? matcher.group() : "";
  }

  /** Returns {@code with} concatenated {@code count} times. */
  private static String repeat(int count, String with) {
    StringBuilder repeated = new StringBuilder(count * with.length());
    for (int i = 0; i < count; i++) {
      repeated.append(with);
    }
    return repeated.toString();
  }

  /** Returns every index at which {@code character} occurs in {@code string}, ascending. */
  private static List<Integer> allIndexesOf(char character, String string) {
    List<Integer> indexes = new ArrayList<>();
    for (int index = string.indexOf(character); index >= 0; index = string.indexOf(character, index + 1)) {
      indexes.add(index);
    }
    return indexes;
  }
}

/**
 * Thin wrapper around a language's n-gram {@link LanguageModel}: tokenizes text the same
 * way the n-gram data was tokenized and looks up pseudo-probabilities of token sequences.
 * <p>
 * The n-gram data directory is taken from the {@code ngram.path} system property.
 */
class NGramUtil {

  // Instance state (previously static): each helper keeps the language and model it was
  // created with instead of overwriting state shared by every other instance.
  private final Language language;
  private final LanguageModel languageModel;

  /**
   * Loads the language model of {@code language} from the directory named by the
   * {@code ngram.path} system property.
   *
   * @param language language whose tokenizer and n-gram model are used
   * @throws IllegalStateException if {@code ngram.path} is not set
   * @throws RuntimeException      if the language model cannot be loaded
   */
  public NGramUtil(Language language) {
    String ngramPath = System.getProperty("ngram.path");
    if (ngramPath == null) {
      // Fail with an explanation rather than a bare NullPointerException from Paths.get.
      throw new IllegalStateException("System property 'ngram.path' is not set");
    }
    LanguageModel model;
    try {
      model = language.getLanguageModel(Paths.get(ngramPath).toFile());
    } catch (IOException e) {
      // Keep the cause so the underlying I/O problem is visible in the stack trace.
      throw new RuntimeException("NGram file not found", e);
    }
    this.language = language;
    this.languageModel = model;
  }

  /**
   * Splits {@code s} into tokens using the word tokenizer of this helper's language,
   * without adding a sentence-start token.
   */
  public List<String> tokenizeString(String s) {
    return GoogleTokenUtil.getGoogleTokensForString(s, false, language);
  }

  /**
   * Returns the pseudo-probability of the last {@code length} tokens of {@code sTokenized}
   * (or of all tokens if there are fewer).
   *
   * @param sTokenized token sequence to look up
   * @param length     maximum number of trailing tokens to consider
   * @return the pseudo-probability, or {@code null} if {@code sTokenized} is empty
   */
  public Double stringProbability(List<String> sTokenized, int length) {
    List<String> tail = sTokenized;
    if (tail.size() > length) {
      tail = tail.subList(tail.size() - length, tail.size());
    }
    return tail.isEmpty() ? null : languageModel.getPseudoProbability(tail).getProb();
  }
}
Binary file not shown.

0 comments on commit 1e9d027

Please sign in to comment.