In [1]:
import os
import sys

# Location of the local chisel checkout that the later `chisel.*` imports are resolved from.
# Set the CHISEL_ROOT environment variable to point somewhere else; the default
# is the path this notebook was originally written against.
CHISEL_ROOT = os.environ.get("CHISEL_ROOT", "/workspaces/chisel/")

# Re-running this cell must not pile up duplicate sys.path entries.
if CHISEL_ROOT not in sys.path:
    sys.path.append(CHISEL_ROOT)

# 👗 Example: Processing Fashion Brand NER (JSON Format) with Chisel

This example shows how to preprocess the `explosion/ner-fashion-brands` dataset into `ChiselRecord` objects for training transformer-based NER models with BILO labeling.

## 📥 Step 1: Load the Dataset

In [2]:
from datasets import load_dataset

# Fetch the fashion-brand NER corpus; the loop further down reads its "train" split.
ds = load_dataset(path="explosion/ner-fashion-brands")

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Peek at the raw text of the first training example.
ds['train'][0]['text']

"It's all preference for which looks better, personally I feel that the more natural the hair looks the better the style, which for me means going with a matte finish which leaves the hair looking as natural as possible while still holding it in place"

## 🧩 Step 2: Implement a JSON Span Parser
The dataset provides character-level entity annotations in a `spans` field. We write a parser that converts each span into Chisel's `EntitySpan` format.

In [5]:
from typing import Tuple, List
from chisel.extraction.base.protocols import Parser
from chisel.extraction.models.models import EntitySpan

class JSONSpanParser(Parser):
    """Parser for records whose entities are given as character-offset spans."""

    def parse(self, doc: dict) -> Tuple[str, List[EntitySpan]]:
        """Split one record into its text and its entity spans.

        Args:
            doc: Mapping with a ``"text"`` entry and, optionally, a ``"spans"``
                entry holding items with ``"start"``, ``"end"`` and ``"label"``.

        Returns:
            The text together with one ``EntitySpan`` per item in ``"spans"``.
            The list is empty when ``"spans"`` is missing, ``None`` or empty.

        Raises:
            KeyError: If ``"text"`` is missing, or a span lacks one of its keys.
        """
        text = doc["text"]
        # `or []` covers a missing key as well as an explicit None value;
        # `doc.get("spans", [])` would hand back None in the second case.
        raw_spans = doc.get("spans") or []
        entities = [
            EntitySpan(
                text=text[raw["start"]:raw["end"]],
                start=raw["start"],
                end=raw["end"],
                label=raw["label"],
            )
            for raw in raw_spans
        ]
        return text, entities

## 🔧 Step 3: Initialize Chisel Components

In [6]:
from chisel.extraction.tokenizers.hf_tokenizer import HFTokenizer
from chisel.extraction.span_aligners.token_span_aligner import TokenSpanAligner
from chisel.extraction.labelers.bilo_labeler import BILOLabeler
from chisel.extraction.labelers.label_encoder import SimpleLabelEncoder
from chisel.extraction.validators.validators import DefaultParseValidator, HFTokenAlignmentValidator
from chisel.extraction.formatters.torch_formatter import TorchDatasetFormatter
from chisel.extraction.models.models import ChiselRecord

## Component Setup

In [7]:
# Pipeline stages, listed in the order the processing loop uses them.
parser = JSONSpanParser()
tokenizer = HFTokenizer(model_name="bert-base-cased")
aligner = TokenSpanAligner()
labeler = BILOLabeler()
formatter = TorchDatasetFormatter()

# Integer id for every tag string the encoder is expected to handle.
label_to_id = {
    "O": 0,
    "B-FASHION_BRAND": 1,
    "I-FASHION_BRAND": 2,
    "L-FASHION_BRAND": 3,
    "U-FASHION_BRAND": 4,
}
label_encoder = SimpleLabelEncoder(label_to_id=label_to_id)

# Both validator groups are configured to raise, so the caller decides what to drop.
parse_validators = [
    DefaultParseValidator(on_error="raise"),
]
label_validators = [
    HFTokenAlignmentValidator(tokenizer=tokenizer.tokenizer, on_error="raise"),
]

## 🔄 Step 4: Run the Preprocessing Pipeline

In [8]:
def _filter_valid_spans(context, spans, validators):
    """Keep the spans that every validator accepts.

    Args:
        context: First argument passed to each ``validator.validate`` call
            (the raw text for parse validators, the token list for label validators).
        spans: Candidate spans to check.
        validators: Objects exposing ``validate(context, span)``. A ``ValueError``
            raised by any of them rejects the span.

    Returns:
        ``(kept, dropped)``: the accepted spans in their original order and the
        number of rejected spans.
    """
    kept = []
    dropped = 0
    for span in spans:
        try:
            for validator in validators:
                validator.validate(context, span)
        except ValueError:
            dropped += 1
            continue
        kept.append(span)
    return kept, dropped


processed_data = []
# How many spans each validation stage rejected; inspect after the loop.
dropped_spans = {"parse": 0, "alignment": 0}

# NOTE(review): the tokenizer warns that at least one text exceeds the model's
# maximum sequence length (see the cell output). Nothing here truncates or
# chunks such records (chunk_id is always 0) - confirm this is handled before training.
for idx, example in enumerate(ds["train"]):
    text, entities = parser.parse(example)

    # Character-level validation: only spans that pass are aligned to tokens.
    valid_spans, n_dropped = _filter_valid_spans(text, entities, parse_validators)
    dropped_spans["parse"] += n_dropped

    tokens = tokenizer.tokenize(text)
    token_entity_spans = aligner.align(valid_spans, tokens)

    # Token-level validation runs before labeling so that the labels and the
    # record's `entities` are derived from the same set of spans.
    valid_token_spans, n_dropped = _filter_valid_spans(
        tokens, token_entity_spans, label_validators
    )
    dropped_spans["alignment"] += n_dropped

    labels = labeler.label(tokens, valid_token_spans)
    encoded_labels = label_encoder.encode(labels)

    input_ids = [token.id for token in tokens]
    record = ChiselRecord(
        id=str(idx),
        chunk_id=0,
        text=tokenizer.tokenizer.decode(input_ids),
        tokens=tokens,
        input_ids=input_ids,
        attention_mask=[1] * len(tokens),
        entities=[tes.entity for tes in valid_token_spans],
        bio_labels=labels,
        labels=encoded_labels,
    )
    processed_data.append(record)

data = formatter.format(processed_data)

Token indices sequence length is longer than the specified maximum sequence length for this model (525 > 512). Running this sequence through the model will result in indexing errors


### ✅ Output
You now have a torch dataset ready for training!

In [9]:
# Display the first formatted record.
data[0]

{'input_ids': tensor([ 1135,   112,   188,  1155, 12629,  1111,  1134,  2736,  1618,   117,
          7572,   146,  1631,  1115,  1103,  1167,  2379,  1103,  1716,  2736,
          1103,  1618,  1103,  1947,   117,  1134,  1111,  1143,  2086,  1280,
          1114,   170, 22591,  1566,  3146,  1134,  2972,  1103,  1716,  1702,
          1112,  2379,  1112,  1936,  1229,  1253,  2355,  1122,  1107,  1282]),
 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1]),
 'labels': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0])}

In [10]:
# Print the first 20 positions of the second record as "(input id): attention mask, label id".
columns = [
    data[1][key][0:20]
    for key in ("input_ids", "attention_mask", "labels")
]
for idx, mask, label in zip(*columns):
    print(f"({idx}): {mask}, {label}")

(23669): 1, 0
(2298): 1, 0
(1303): 1, 0
(117): 1, 0
(3983): 1, 0
(787): 1, 0
(189): 1, 0
(1899): 1, 0
(1330): 1, 0
(1141): 1, 0
(2589): 1, 0
(1111): 1, 0
(1103): 1, 0
(3813): 1, 0
(1107): 1, 0
(1103): 1, 0
(27103): 1, 1
(2101): 1, 3
(2984): 1, 0
(119): 1, 0
