In [1]:
from datasets import load_dataset
import random
import numpy as np
import torch
from gliner import GLiNER
from gliner.training import Trainer, TrainingArguments
from gliner.data_processing.collator import DataCollator

In [2]:
def ner_tags_to_spans(samples, tag_to_id):
    """
    Converts NER tags in the dataset samples to spans (start, end, entity type).

    Args:
        samples (dict): A dictionary containing the tokens and NER tags.
        tag_to_id (dict): A dictionary mapping NER tags to IDs.

    Returns:
        dict: A dictionary containing tokenized text and corresponding NER spans.
    """
    ner_tags = samples["ner_tags"]
    id_to_tag = {v: k for k, v in tag_to_id.items()}
    spans = []
    start_pos = None
    entity_name = None

    for i, tag in enumerate(ner_tags):
        if tag_to_id[tag] == 0:  # 'O' tag
            if entity_name is not None:
                spans.append((start_pos, i - 1, entity_name))
                entity_name = None
                start_pos = None
        else:
            tag_name = tag
            if tag_name.startswith('B-'):
                if entity_name is not None:
                    spans.append((start_pos, i - 1, entity_name))
                entity_name = tag_name[2:]
                start_pos = i
            elif tag_name.startswith('I-'):
                continue

    # Handle the last entity if the sentence ends with an entity
    if entity_name is not None:
        spans.append((start_pos, len(samples["tokens"]) - 1, entity_name))

    return {"tokenized_text": samples["tokens"], "ner": spans}

In [3]:
DATASET_NAME = 'adsabs/WIESP2022-NER'

In [4]:
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [5]:
trainset = load_dataset(DATASET_NAME, split='train')
testset = load_dataset(DATASET_NAME, split='test')

In [13]:
unique_tags = set(tag for example in trainset["ner_tags"] for tag in example)
sorted_tags = sorted(unique_tags - {"O"})  # Exclude 'O' from sorted tags
list_tags = list(unique_tags)

clear = [tag[2:] for tag in list_tags]
labels_list = list(dict.fromkeys(clear))

tag_to_id = {"O": 0, **{tag: idx + 1 for idx, tag in enumerate(sorted_tags)}}
id_to_tag = {idx: tag for tag, idx in tag_to_id.items()}

In [7]:
train_dataset = [ner_tags_to_spans(i, tag_to_id) for i in trainset]
test_dataset = [ner_tags_to_spans(i, tag_to_id) for i in testset]

In [9]:
model = GLiNER.from_pretrained("urchade/gliner_small")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
model = model.to(device)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]



In [48]:
data_collator = DataCollator(model.config, data_processor=model.data_processor, prepare_labels=True)

In [49]:
batch_size = 4
tokenizers = model.data_processor.transformer_tokenizer
tokenizers.model_max_length = 800
model.data_processor.config.max_len = 800

In [51]:
training_args = TrainingArguments(
    output_dir="E:/tmp/models",
    learning_rate=5e-6,
    weight_decay=0.01,
    others_lr=1e-5,
    others_weight_decay=0.01,
    lr_scheduler_type="linear",  # cosine
    warmup_ratio=0.1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    focal_loss_alpha=0.75,
    focal_loss_gamma=2,
    save_strategy="epoch",
    save_total_limit=10,
    num_train_epochs=10,
    logging_strategy='epoch',
    dataloader_num_workers=0,
    use_cpu=False,
    report_to="none",
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    tokenizer=tokenizers,
    data_collator=data_collator,
)

  trainer = Trainer(


In [37]:
trainer.train()

  trainer = Trainer(


In [10]:
model = GLiNER.from_pretrained("E:/tmp/models/checkpoint-2634")
model = model.to(device)

config.json not found in E:\tmp\models\checkpoint-2634


In [11]:
evaluation_results = model.evaluate(
    test_dataset, flat_ner=True, entity_types=labels_list, batch_size=4
)



In [12]:
print(evaluation_results)

('P: 74.09%\tR: 74.21%\tF1: 74.15%\n', 0.7415358671682448)
