Skip to content

Commit

Permalink
Add emission probabilities to t9 code
Browse files Browse the repository at this point in the history
  • Loading branch information
bxt committed May 12, 2015
1 parent 47f0622 commit 91c7411
Showing 1 changed file with 40 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
import java.util.Scanner;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;


import org.jooq.lambda.Seq;
import org.jooq.lambda.tuple.Tuple2;

public class T9 {

/** Separator literal; where it is used is not visible in this excerpt. */
private final static String ARROW = " --> ";
/** Factor by which a natural logarithm is multiplied before being rounded to a long (see log). */
private final static int PRECISION = 100000;

/**
 * Number of occurrences of each token of the sample text. normalizeEmissions later overwrites
 * the entries of the keypad letters with scaled log values.
 */
private Map<String, Long> letterCounts = new HashMap<String, Long>();
// NOTE(review): the two digaramCounts declarations below look like the before/after variants
// of a diff view (Integer values vs. Long values) — confirm that only one exists in the file.
private Map<String, Integer> digaramCounts = new HashMap<String, Integer>();
/**
 * Number of occurrences of each pair of adjacent tokens, keyed by their concatenation.
 * normalizeTransitions later overwrites the counts with scaled log values.
 */
private Map<String, Long> digaramCounts = new HashMap<String, Long>();
/** Maps a letter to the number of the key it belongs to, e.g. "a", "b" and "c" to 2. */
private Map<String, Integer> characterKey = new HashMap<String, Integer>();
{
Stream.of("a","b","c" ).forEach(c -> characterKey.put(c, 2));
Expand All @@ -45,27 +48,40 @@ public static void main(String[] args) {

/**
 * Builds the letter and digram statistics from a sample text resource.
 *
 * The resource is opened with getResourceAsStream on this class, split into tokens by a
 * Scanner, and the tokens are counted both individually and as adjacent pairs. The counts are
 * then turned into scaled log values by normalizeTransitions and normalizeEmissions.
 *
 * NOTE(review): this listing appears to come from a diff view, so the older one-pass pipeline
 * (Seq.seq(sc).duplicate() ... merge, the println calls and normalizeCounts) and the newer
 * token-list version appear side by side — confirm which lines are current before editing.
 *
 * NOTE(review): getResourceAsStream returns null when the resource is not found, in which case
 * new Scanner(stream) fails with a NullPointerException.
 *
 * @param sampleTextFilename name of the sample text resource, resolved relative to this class
 */
public T9(String sampleTextFilename) {
InputStream stream = getClass().getResourceAsStream(sampleTextFilename);
Collection<String> tokens = null;
try(Scanner sc = new Scanner(stream)) {
// The delimiter pattern can match the empty string; presumably this makes the scanner
// hand out the text one character at a time — TODO confirm.
sc.useDelimiter("[0-9]*");
Seq.seq(sc)
.duplicate()
.map1(s -> {
Tuple2<Seq<String>, Seq<String>> t = s.duplicate();
letterCounts = t.v1().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
return t.v2();
})
.map1(s -> Seq.concat(Seq.of(" "),s))
.map((l,r) -> Seq.zip(l, r, String::concat))
.forEach(digram -> digaramCounts.merge(digram, 1, Integer::sum));
tokens = Seq.seq(sc).collect(Collectors.toList());
}
System.out.println(digaramCounts);
normalizeCounts();
System.out.println(digaramCounts);
// Count how often each token occurs.
letterCounts = tokens.stream().collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
digaramCounts = Seq.seq(tokens)
.duplicate()
// Prefix one copy with a space so that zipping pairs every token with its predecessor;
// the first token is paired with " ".
.map1(s -> Seq.concat(Seq.of(" "),s))
.map((l,r) -> Seq.zip(l, r, String::concat))
.collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
// Transitions are normalized first because they read the raw letter counts, which
// normalizeEmissions overwrites.
normalizeTransitions();
normalizeEmissions();
}

/**
 * Replaces every digram count with a scaled log ratio: the digram's count divided by the
 * count of the digram's first character, as stored in letterCounts.
 *
 * NOTE(review): two method headers and two put lines appear here; they look like the
 * before/after variants of a diff view (normalizeCounts with int values vs.
 * normalizeTransitions using log) — confirm which pair is current.
 *
 * NOTE(review): letterCounts.get(...) returns null when the first character has no count,
 * which fails on unboxing in the division. The digram built from the artificial leading " "
 * relies on " " occurring in the sample text — verify.
 */
private void normalizeCounts() {
private void normalizeTransitions() {
digaramCounts.forEach((digram, count) -> {
digaramCounts.put(digram, (int)Math.round(100*Math.log((float)count/letterCounts.get(digram.substring(0, 1)))));
// Overwrites the value of the key currently being visited; no keys are added or removed.
digaramCounts.put(digram, log((double)count/letterCounts.get(digram.substring(0, 1))));
});
}

/**
 * Converts the raw letter counts into emission scores, one key at a time.
 *
 * <p>For key 0 and keys 2 through 9, the counts of the letters returned by
 * {@code lettersFor(key)} are summed, and each letter's count in {@code letterCounts} is
 * replaced by {@code log(count / total)}: the scaled log of that letter's share among the
 * letters of its key.
 *
 * <p>Letters that never occurred in the sample text have no entry in {@code letterCounts}.
 * They contribute nothing to the total and stay absent from the map, so a later
 * {@code letterCounts.get} for such a letter still returns {@code null}. A key none of whose
 * letters were counted is skipped.
 */
private void normalizeEmissions() {
    Seq.concat(Seq.of(0), IntStream.range(2, 10).boxed())
        .map(this::lettersFor)
        // Keep only letters that were actually counted. Without this, a missing letter's
        // null count is unboxed while summing and the whole constructor fails with a
        // NullPointerException; a key without letters failed on the empty reduction.
        .map(s -> s.filter(letterCounts::containsKey).collect(Collectors.toList()))
        .forEach(letters -> {
            // Every retained count is at least 1, so total is positive whenever the
            // loop below has anything to do.
            long total = letters.stream().mapToLong(letterCounts::get).sum();
            letters.forEach(letter ->
                letterCounts.computeIfPresent(letter, (x, count) -> log((double) count / total)));
        });
}

Expand Down Expand Up @@ -97,7 +113,7 @@ public String word(int[] keypresses) {
/**
 * Picks the best predecessor for the given letter and returns a new node linked to it.
 *
 * Every node in prev is scored as its own probability plus the digram score stored for
 * (that node's letter + letter); the second lambda line additionally adds the score stored
 * for the letter in letterCounts. The node with the highest score becomes the predecessor
 * of a new Node that carries the letter and that score.
 *
 * NOTE(review): the two consecutive lambda lines look like the before/after variants of a
 * diff view (default 0 vs. default 0L plus the letterCounts term) — confirm which is current.
 *
 * NOTE(review): a digram missing from digaramCounts scores 0. The stored scores are
 * logarithms of count ratios, so presumably they are zero or negative, which would make an
 * unseen digram rank at least as high as any seen one — verify this is intended.
 *
 * NOTE(review): letterCounts.get(letter) returns null for a letter without an entry, which
 * fails on unboxing; the get() after maxBy presumably fails when prev is empty — confirm
 * callers rule both out.
 *
 * @param prev   candidate predecessor nodes
 * @param letter letter of the node to create
 * @return a new node for letter whose predecessor is the highest-scoring node of prev
 */
private Node maxTransition(Collection<Node> prev, String letter) {
return Seq
.zip(prev.stream(), prev.stream().map(
n -> n.getProbability() + digaramCounts.getOrDefault(n.getLetter() + letter, 0)))
n -> n.getProbability() + digaramCounts.getOrDefault(n.getLetter() + letter, 0L) + letterCounts.get(letter)))
.maxBy(Tuple2::v2).get()
.map((predecessor, probability) -> new Node(letter, predecessor, probability));
}
Expand All @@ -110,12 +126,16 @@ private Stream<String> lettersFor(int key) {
.map(Entry::getKey);
}

/**
 * Returns the natural logarithm of the given value, multiplied by PRECISION and rounded to
 * the nearest long.
 *
 * @param probability value whose logarithm is taken
 * @return the scaled, rounded natural logarithm
 */
private static long log(double probability) {
    final double naturalLog = Math.log(probability);
    final double scaled = PRECISION * naturalLog;
    return Math.round(scaled);
}

private static class Node {
private String letter;
private Node predecessor;
private int probability;
private long probability;

public Node(String letter, Node predecessor, int probability) {
public Node(String letter, Node predecessor, long probability) {
this.letter = letter;
this.predecessor = predecessor;
this.probability = probability;
Expand All @@ -129,7 +149,7 @@ public Node getPredecessor() {
return predecessor;
}

public int getProbability() {
public long getProbability() {
return probability;
}
}
Expand Down

0 comments on commit 91c7411

Please sign in to comment.