diff --git a/Classification/Core/src/main/java/org/tribuo/classification/ImmutableLabelInfo.java b/Classification/Core/src/main/java/org/tribuo/classification/ImmutableLabelInfo.java index 3d6b6b0f6..6d0dbbfba 100644 --- a/Classification/Core/src/main/java/org/tribuo/classification/ImmutableLabelInfo.java +++ b/Classification/Core/src/main/java/org/tribuo/classification/ImmutableLabelInfo.java @@ -29,9 +29,14 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; import java.util.logging.Level; import java.util.logging.Logger; @@ -61,15 +66,18 @@ private ImmutableLabelInfo(ImmutableLabelInfo info) { ImmutableLabelInfo(LabelInfo info) { super(info); - idLabelMap = new HashMap<>(); - labelIDMap = new HashMap<>(); + idLabelMap = new LinkedHashMap<>(); + labelIDMap = new LinkedHashMap<>(); int counter = 0; - for (Map.Entry e : labelCounts.entrySet()) { - idLabelMap.put(counter,e.getKey()); - labelIDMap.put(e.getKey(),counter); + SortedSet keys = new TreeSet<>(labelCounts.keySet()); + Set