Skip to content

Commit

Permalink
[ML] expand allowed NER labels to be any I-O-B tagged labels (#87091)
Browse files Browse the repository at this point in the history
Named entity recognition (NER) is a special form of token classification. The specific kind of labelling we support is Inside-Outside-Beginning (IOB) tagging. Each label indicates whether a token is inside an entity (prefix `I-` or `I_`), at the beginning of an entity (prefix `B-` or `B_`), or outside any entity (`O`).

Each valid token classification label either starts with one of the required prefixes or is the outside label `O`.

Before this commit, we restricted the labels to a specific set:

```
O(Entity.NONE),      // Outside a named entity
B_MISC(Entity.MISC), // Beginning of a miscellaneous entity right after another miscellaneous entity
I_MISC(Entity.MISC), // Miscellaneous entity
B_PER(Entity.PER),   // Beginning of a person's name right after another person's name
I_PER(Entity.PER),   // Person's name
B_ORG(Entity.ORG),   // Beginning of an organization right after another organization
I_ORG(Entity.ORG),   // Organisation
B_LOC(Entity.LOC),   // Beginning of a location right after another location
I_LOC(Entity.LOC);   // Location
```

But now, any entity is allowed, as long as the naming of the labels adheres to IOB tagging rules.
  • Loading branch information
benwtrent committed May 25, 2022
1 parent fd442d3 commit 90d93a9
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 198 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/87091.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 87091
summary: Expand allowed NER labels to be any I-O-B tagged labels
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,20 @@
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;

public class NerConfig implements NlpConfig {

/**
 * Checks whether a classification label follows Inside-Outside-Beginning (IOB) tagging.
 *
 * A label is valid when, ignoring case, it is exactly the outside label {@code O} or it
 * starts with one of the prefixes {@code I-}, {@code I_}, {@code B-} or {@code B_}.
 *
 * @param label the classification label to check; may be {@code null}
 * @return {@code true} if the label is a valid IOB tag, {@code false} otherwise
 *         (including when {@code label} is {@code null})
 */
public static boolean validIOBTag(String label) {
    if (label == null) {
        return false;
    }
    // Upper-case once with a fixed locale so the check is case-insensitive and locale-independent.
    String upper = label.toUpperCase(Locale.ROOT);
    // The outside label must match exactly: a prefix test would accept arbitrary labels such as
    // "ORG" or "other" that merely begin with the letter O and are not IOB tags.
    return upper.equals("O")
        || upper.startsWith("I-")
        || upper.startsWith("B-")
        || upper.startsWith("I_")
        || upper.startsWith("B_");
}

public static final String NAME = "ner";

public static NerConfig fromXContentStrict(XContentParser parser) {
Expand Down Expand Up @@ -80,6 +89,22 @@ public NerConfig(
.orElse(new VocabularyConfig(InferenceIndexConstants.nativeDefinitionStore()));
this.tokenization = tokenization == null ? Tokenization.createDefault() : tokenization;
this.classificationLabels = classificationLabels == null ? Collections.emptyList() : classificationLabels;
if (this.classificationLabels.isEmpty() == false) {
List<String> badLabels = this.classificationLabels.stream().filter(l -> validIOBTag(l) == false).toList();
if (badLabels.isEmpty() == false) {
throw ExceptionsHelper.badRequestException(
"[{}] only allows IOB tokenization tagging for classification labels; provided {}",
NAME,
badLabels
);
}
if (this.classificationLabels.stream().noneMatch(l -> l.toUpperCase(Locale.ROOT).equals("O"))) {
throw ExceptionsHelper.badRequestException(
"[{}] only allows IOB tokenization tagging for classification labels; missing outside label [O]",
NAME
);
}
}
this.resultsField = resultsField;
if (this.tokenization.span != -1) {
throw ExceptionsHelper.badRequestException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
import org.elasticsearch.xpack.core.ml.inference.InferenceConfigItemTestCase;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

public class NerConfigTests extends InferenceConfigItemTestCase<NerConfig> {

Expand Down Expand Up @@ -48,6 +52,12 @@ protected NerConfig mutateInstanceForVersion(NerConfig instance, Version version
}

public static NerConfig createRandom() {
Set<String> randomClassificationLabels = new HashSet<>(
Stream.generate(() -> randomFrom("O", "B_PER", "I_PER", "B_ORG", "I_ORG", "B_LOC", "I_LOC", "B_CUSTOM", "I_CUSTOM"))
.limit(10)
.toList()
);
randomClassificationLabels.add("O");
return new NerConfig(
randomBoolean() ? null : VocabularyConfigTests.createRandom(),
randomBoolean()
Expand All @@ -57,7 +67,7 @@ public static NerConfig createRandom() {
MPNetTokenizationTests.createRandom(),
RobertaTokenizationTests.createRandom()
),
randomBoolean() ? null : randomList(5, () -> randomAlphaOfLength(10)),
randomBoolean() ? null : new ArrayList<>(randomClassificationLabels),
randomBoolean() ? null : randomAlphaOfLength(5)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.core.ml.inference.results.InferenceResults;
import org.elasticsearch.xpack.core.ml.inference.results.NerResults;
Expand All @@ -21,64 +19,58 @@
import org.elasticsearch.xpack.ml.inference.nlp.tokenizers.TokenizationResult;
import org.elasticsearch.xpack.ml.inference.pytorch.results.PyTorchInferenceResult;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.Set;

import static org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfig.DEFAULT_RESULTS_FIELD;

public class NerProcessor extends NlpTask.Processor {

public enum Entity implements Writeable {
NONE,
MISC,
PER,
ORG,
LOC;

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(this);
}

@Override
public String toString() {
return name().toUpperCase(Locale.ROOT);
record IobTag(String tag, String entity) {
static IobTag fromTag(String tag) {
String entity = tag.toUpperCase(Locale.ROOT);
if (entity.startsWith("B-") || entity.startsWith("I-") || entity.startsWith("B_") || entity.startsWith("I_")) {
entity = entity.substring(2);
return new IobTag(tag, entity);
} else if (entity.equals("O")) {
return new IobTag(tag, entity);
} else {
throw new IllegalArgumentException("classification label [" + tag + "] is not an entity I-O-B tag.");
}
}
}

// Inside-Outside-Beginning (IOB) tag
enum IobTag {
O(Entity.NONE), // Outside a named entity
B_MISC(Entity.MISC), // Beginning of a miscellaneous entity right after another miscellaneous entity
I_MISC(Entity.MISC), // Miscellaneous entity
B_PER(Entity.PER), // Beginning of a person's name right after another person's name
I_PER(Entity.PER), // Person's name
B_ORG(Entity.ORG), // Beginning of an organisation right after another organisation
I_ORG(Entity.ORG), // Organisation
B_LOC(Entity.LOC), // Beginning of a location right after another location
I_LOC(Entity.LOC); // Location

private final Entity entity;

IobTag(Entity entity) {
this.entity = entity;
boolean isBeginning() {
return tag.startsWith("b") || tag.startsWith("B");
}

Entity getEntity() {
return entity;
boolean isNone() {
return tag.equals("o") || tag.equals("O");
}

boolean isBeginning() {
return name().toLowerCase(Locale.ROOT).startsWith("b");
@Override
public String toString() {
return tag;
}
}

// Default IOB tags, returned by buildIobMap when no classification labels are supplied
// (null or empty list). The order is significant to callers that index into the array.
// NOTE(review): presumably the array index corresponds to the model's output class index - confirm.
static final IobTag[] DEFAULT_IOB_TAGS = new IobTag[] {
IobTag.fromTag("O"), // Outside a named entity
IobTag.fromTag("B_MISC"), // Beginning of a miscellaneous entity right after another miscellaneous entity
IobTag.fromTag("I_MISC"), // Miscellaneous entity
IobTag.fromTag("B_PER"), // Beginning of a person's name right after another person's name
IobTag.fromTag("I_PER"), // Person's name
IobTag.fromTag("B_ORG"), // Beginning of an organisation right after another organisation
IobTag.fromTag("I_ORG"), // Organisation
IobTag.fromTag("B_LOC"), // Beginning of a location right after another location
IobTag.fromTag("I_LOC") // Location
};

private final NlpTask.RequestBuilder requestBuilder;
private final IobTag[] iobMap;
private final String resultsField;
Expand All @@ -102,10 +94,10 @@ private void validate(List<String> classificationLabels) {
}

ValidationException ve = new ValidationException();
EnumSet<IobTag> tags = EnumSet.noneOf(IobTag.class);
Set<IobTag> tags = new HashSet<>();
for (String label : classificationLabels) {
try {
IobTag iobTag = IobTag.valueOf(label);
IobTag iobTag = IobTag.fromTag(label);
if (tags.contains(iobTag)) {
ve.addValidationError("the classification label [" + label + "] is duplicated in the list " + classificationLabels);
}
Expand All @@ -114,23 +106,20 @@ private void validate(List<String> classificationLabels) {
ve.addValidationError("classification label [" + label + "] is not an entity I-O-B tag.");
}
}

if (ve.validationErrors().isEmpty() == false) {
ve.addValidationError("Valid entity I-O-B tags are " + Arrays.toString(IobTag.values()));
throw ve;
}
}

static IobTag[] buildIobMap(List<String> classificationLabels) {
if (classificationLabels == null || classificationLabels.isEmpty()) {
return IobTag.values();
return DEFAULT_IOB_TAGS;
}

IobTag[] map = new IobTag[classificationLabels.size()];
for (int i = 0; i < classificationLabels.size(); i++) {
map[i] = IobTag.valueOf(classificationLabels.get(i));
map[i] = IobTag.fromTag(classificationLabels.get(i));
}

return map;
}

Expand Down Expand Up @@ -281,15 +270,15 @@ static List<NerResults.EntityGroup> groupTaggedTokens(List<TaggedToken> tokens,
int startTokenIndex = 0;
while (startTokenIndex < tokens.size()) {
TaggedToken token = tokens.get(startTokenIndex);
if (token.tag.getEntity() == Entity.NONE) {
if (token.tag.isNone()) {
startTokenIndex++;
continue;
}
int endTokenIndex = startTokenIndex + 1;
double scoreSum = token.score;
while (endTokenIndex < tokens.size()) {
TaggedToken endToken = tokens.get(endTokenIndex);
if (endToken.tag.isBeginning() || endToken.tag.getEntity() != token.tag.getEntity()) {
if (endToken.tag.isBeginning() || endToken.tag.entity().equals(token.tag.entity()) == false) {
break;
}
scoreSum += endToken.score;
Expand All @@ -300,13 +289,7 @@ static List<NerResults.EntityGroup> groupTaggedTokens(List<TaggedToken> tokens,
int endPos = tokens.get(endTokenIndex - 1).token.endOffset();
String entity = inputSeq.substring(startPos, endPos);
entities.add(
new NerResults.EntityGroup(
entity,
token.tag.getEntity().toString(),
scoreSum / (endTokenIndex - startTokenIndex),
startPos,
endPos
)
new NerResults.EntityGroup(entity, token.tag.entity(), scoreSum / (endTokenIndex - startTokenIndex), startPos, endPos)
);
startTokenIndex = endTokenIndex;
}
Expand All @@ -317,14 +300,7 @@ static List<NerResults.EntityGroup> groupTaggedTokens(List<TaggedToken> tokens,
record TaggedToken(DelimitedToken token, IobTag tag, double score) {
@Override
public String toString() {
return new StringBuilder("{").append("token:")
.append(token)
.append(", ")
.append(tag)
.append(", ")
.append(score)
.append("}")
.toString();
return "{" + "token:" + token + ", " + tag + ", " + score + "}";
}
}
}
Expand Down

0 comments on commit 90d93a9

Please sign in to comment.