# Fine-tune a GLiNER model with contrastive learning (NCELoss)

In [None]:
!pip install gliner



In [None]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
import json

import torch

from gliner.training import Trainer, TrainingArguments
from model_finetuning.model import ContrastiveGLiNER
from model_finetuning.data_collator import ContrastiveDataCollator

In [None]:
# Read the preprocessed, training-ready examples from the JSON dump on disk.
with open("../data/processed_output.json", mode="r", encoding="utf-8") as file:
    processed_output = json.loads(file.read())

# Sanity check: how many examples were loaded, and what one looks like.
print(len(processed_output))
print(processed_output[0])



In [None]:
# Sequential 90/10 split: the leading 90% of examples train the model and the
# trailing 10% are held out for evaluation. The order of examples is preserved.
split_index = int(len(processed_output) * 0.9)
train_dataset, test_dataset = processed_output[:split_index], processed_output[split_index:]

In [None]:
# Prefer the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Initialise the contrastive wrapper from the pretrained GLiNER-small weights.
base_model_name = "urchade/gliner_small"
gliner_model = ContrastiveGLiNER.from_pretrained(base_model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]



In [None]:
# Display the configuration of the loaded model as the cell output.
gliner_model.config

GLiNERConfig {
  "class_token_index": 128002,
  "decoder_mode": null,
  "dropout": 0.4,
  "embed_ent_token": true,
  "encoder_config": {
    "_attn_implementation_autoset": true,
    "_name_or_path": "microsoft/deberta-v3-small",
    "add_cross_attention": false,
    "architectures": null,
    "attention_probs_dropout_prob": 0.1,
    "bad_words_ids": null,
    "begin_suppress_tokens": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "gelu",
    "hidden_dropout_prob": 0.1,
    "hidden_size": 768,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "initializer_range": 0.02,

In [None]:
# Move the model to the selected device; the value returned by `.to(...)`
# is shown as the cell output.
gliner_model.to(device)

ContrastiveGLiNER(
  (model): SpanModel(
    (token_rep_layer): Encoder(
      (bert_layer): Transformer(
        (model): DebertaV2Model(
          (embeddings): DebertaV2Embeddings(
            (word_embeddings): Embedding(128004, 768, padding_idx=0)
            (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): DebertaV2Encoder(
            (layer): ModuleList(
              (0-5): 6 x DebertaV2Layer(
                (attention): DebertaV2Attention(
                  (self): DisentangledSelfAttention(
                    (query_proj): Linear(in_features=768, out_features=768, bias=True)
                    (key_proj): Linear(in_features=768, out_features=768, bias=True)
                    (value_proj): Linear(in_features=768, out_features=768, bias=True)
                    (pos_dropout): Dropout(p=0.1, inplace=False)
                    (dropout): Dropout(p=0.1, inplace=False)
   

In [None]:
# Build the batch collator from the model's own config and data processor,
# asking it to prepare labels as part of each batch.
data_collator = ContrastiveDataCollator(
    gliner_model.config,
    data_processor=gliner_model.data_processor,
    prepare_labels=True,
)

In [None]:
# Batch size shared by the training and evaluation loops.
batch_size = 8

# Dataset-derived bookkeeping. The trailing number is the author's original
# note on the expected size — TODO confirm it still matches the data.
data_size = len(train_dataset)  # 2700
num_batches = data_size // batch_size

# Both values are hard-coded for a short run; the trailing notes preserve the
# alternatives the author left behind (a larger step budget and a derived
# epoch count).
num_steps = 5  # 3380
num_epochs = 5  # max(1, num_steps // num_batches)

In [None]:
# Display the number of epochs that will be passed to the trainer.
num_epochs

5

In [None]:
# --- Alternative configuration, kept for reference only (not executed) ---
# Step-based evaluation/saving with best-model reloading and focal-loss settings.
# training_args = TrainingArguments(
#     output_dir="models",
#     learning_rate=5e-6,
#     weight_decay=0.01,
#     others_lr=1e-5,
#     others_weight_decay=0.01,
#     lr_scheduler_type="linear",  # or "cosine"
#     warmup_ratio=0.1,
#     per_device_train_batch_size=batch_size,
#     per_device_eval_batch_size=batch_size,
#     focal_loss_alpha=0.75,
#     focal_loss_gamma=2,
#     num_train_epochs=num_epochs,
#     dataloader_num_workers=0,
#     use_cpu=False,
#     report_to="none",
#     save_strategy="steps",
#     eval_strategy="steps",
#     eval_steps=500,     # choose as appropriate—must align with save_steps
#     save_steps=500,
#     load_best_model_at_end=True,
#     metric_for_best_model="eval_loss",
#     greater_is_better=False,
#     save_total_limit=3,  # optional: keep only latest + best
# )

# --- Active configuration: evaluate and checkpoint once per epoch ---
# All values below are tunable.
training_args = TrainingArguments(
    # Where checkpoints are written; no external experiment tracker is used.
    output_dir="contrastive_gliner_model",
    report_to="none",
    # Optimisation settings.
    learning_rate=5e-5,
    weight_decay=0.01,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Evaluation / checkpoint cadence.
    eval_strategy="epoch",
    save_strategy="epoch",
    # The best checkpoint is ranked by validation loss but is NOT reloaded
    # into the model when training finishes.
    metric_for_best_model="eval_loss",  # optionally use custom metric (e.g. F1)
    load_best_model_at_end=False,
)

In [None]:
# Assemble the trainer from the model, the training arguments, the 90/10
# dataset split and the contrastive collator.
# The tokenizer is passed as `processing_class`: the `tokenizer=` keyword of
# the Hugging Face Trainer is deprecated in favour of `processing_class=`
# and emits a FutureWarning when used.
trainer = Trainer(
    model=gliner_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=gliner_model.data_processor.transformer_tokenizer,
    data_collator=data_collator,
)

  trainer = Trainer(


In [None]:
# Run the training loop configured by `training_args`; the returned
# TrainOutput is displayed as the cell output.
trainer.train()

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Epoch,Training Loss,Validation Loss
1,No log,59862.414062
2,No log,46809.511719
3,No log,80830.078125
4,No log,86430.46875
5,No log,60625.832031


TrainOutput(global_step=60, training_loss=181409.9, metrics={'train_runtime': 274.748, 'train_samples_per_second': 1.638, 'train_steps_per_second': 0.218, 'total_flos': 0.0, 'train_loss': 181409.9, 'epoch': 5.0})

In [None]:
# Persist the trainer's current model to a dedicated directory.
# NOTE(review): `load_best_model_at_end` is False in `training_args`, so this
# is presumably the state after the last epoch rather than the best
# checkpoint — confirm this is intended.
trainer.save_model("finetuned_gliner_model_NCELoss")

In [None]:
# Look up the checkpoint path the trainer recorded as best and print it.
# NOTE(review): this can be None if no best checkpoint was tracked — verify before use.
best_ckpt = trainer.state.best_model_checkpoint
print(best_ckpt)

In [None]:
# Reload a model for inference. Prefer the best checkpoint recorded by the
# trainer; `best_ckpt` can be None when no best checkpoint was tracked, in
# which case fall back to the directory written by `trainer.save_model(...)`
# instead of failing inside `from_pretrained(None)`.
if best_ckpt is not None:
    checkpoint_dir = best_ckpt
else:
    checkpoint_dir = "finetuned_gliner_model_NCELoss"
    print(f"No best checkpoint recorded; loading {checkpoint_dir} instead.")

trained_model = ContrastiveGLiNER.from_pretrained(checkpoint_dir, load_tokenizer=True)

config.json not found in /content/contrastive_gliner_model/checkpoint-24


In [None]:
# Smoke-test the reloaded model on one medical passage.
# Candidate entity types the model is asked to tag.
labels = ["medical_condition", "body_part", "cause", "symptom", "treatment"]

# Input passage (kept verbatim, including its irregular spacing).
text = "Any part of your neck  muscles, bones, joints, tendons, ligaments, or nerves  can cause neck problems. Neck pain is very common. Pain may also come from your shoulder, jaw, head, or upper arms. Muscle strain or tension often causes neck pain. The problem is usually overuse, such as from sitting at a computer for too long. Sometimes you can strain your neck muscles from sleeping in an awkward position or overdoing it during exercise. Falls or accidents, including car accidents, are another common cause of neck pain. Whiplash, a soft tissue injury to the neck, is also called neck sprain or strain. Treatment depends on the cause, but may include applying ice, taking pain relievers, getting physical therapy or wearing a cervical collar. You rarely need surgery."

# Ask the model for entities at a threshold of 0.5.
entities = trained_model.predict_entities(text, labels, threshold=0.5)

# One "<span text> => <label>" line per predicted entity.
for entity in entities:
    print(f'{entity["text"]} => {entity["label"]}')

Any => symptom
part => symptom
of => symptom
your => symptom
neck => symptom
muscles => symptom
, => symptom
bones => symptom
, => symptom
joints => symptom
, => symptom
tendons => symptom
, => symptom
ligaments, or => symptom
nerves  can => symptom
cause => symptom
neck => symptom
problems => symptom
. => symptom
Neck => symptom
pain => symptom
is => symptom
very => symptom
common => symptom
. => symptom
Pain => symptom
may also come from your shoulder => symptom
, => symptom
jaw => symptom
, => symptom
head, or upper => symptom
arms => symptom
. => symptom
Muscle => symptom
strain => symptom
or => symptom
tension => symptom
often causes => symptom
neck pain. The problem is usually overuse, such as from => symptom
sitting at a computer for too long => symptom
. => symptom
Sometimes => symptom
you can strain your neck muscles from sleeping => symptom
in => symptom
an => symptom
awkward => symptom
position or overdoing it during => symptom
exercise => symptom
. => symptom
Falls => sympt