In [None]:
import os
from datetime import datetime

os.environ["TOKENIZERS_PARALLELISM"] = "true"
import torch
from datasets import load_dataset

import pandas as pd

from nervaluate import Evaluator
from transformers import TrainerCallback
from ray import train, tune
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.train.huggingface.transformers import prepare_trainer
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch

from gliner import GLiNER
from gliner.training import Trainer, TrainingArguments
from gliner.data_processing.collator import DataCollatorWithPadding
from gliner.data_processing import GLiNERDataset

from utils import formatting_prompts_func, convert_to_gliner_dataset, combine_entities

In [None]:
# Initial load of SkillSpan. NOTE(review): redundant — the next cell reloads the dataset before
# formatting it, so this cell can be removed.
ds = load_dataset(path="jjzha/skillspan")

In [None]:
# Reload from scratch and format in one step, so re-running this cell never applies
# `formatting_prompts_func` to already-formatted data.
ds = load_dataset("jjzha/skillspan").map(formatting_prompts_func)

In [None]:
# Convert the training split into the format consumed by GLiNERDataset in `train_gliner`
# (format defined by utils.convert_to_gliner_dataset).
train_split = ds["train"]
data = convert_to_gliner_dataset(train_split)

In [None]:
# Global read by `train_gliner` when it builds the GLiNERDataset for each trial.
train_data = data

In [None]:
# Add the `knowledge_and_skill` gold-entity column (via utils.combine_entities) to the
# evaluation splits, one example at a time.
for split_name in ("validation", "test"):
    ds[split_name] = ds[split_name].map(function=combine_entities, batched=False)

In [None]:
class MyCallback(TrainerCallback):
    """Evaluate on the validation split at every Trainer log event and report F1 to Ray Tune.

    Predictions from the model being trained are scored with nervaluate, and the strict F1 is
    sent through ``train.report`` under the key ``"f1"`` (the metric the tuner optimises).
    """

    def __init__(self, labels=("Skill", "Knowledge"), threshold=0.82, batch_size=64):
        # labels: entity types passed to GLiNER; also used as the nervaluate tags.
        # threshold: GLiNER score cut-off used for predictions during tuning.
        # batch_size: number of validation sentences per `batch_predict_entities` call.
        self.labels = list(labels)
        self.threshold = threshold
        self.batch_size = batch_size

    def on_log(self, args, state, control, model, tokenizer=None, **kwargs):
        """Predict the validation split with ``model`` and report strict F1 via ``train.report``.

        ``tokenizer`` is unused and optional: newer transformers versions pass
        ``processing_class`` rather than ``tokenizer`` to callbacks, so a required ``tokenizer``
        parameter would raise ``TypeError``. Extra keyword arguments land in ``kwargs``.
        """
        sentences = ds['validation']['sentence'][:]
        pred_gli = []
        # Predict in fixed-size batches to bound memory use.
        for start in range(0, len(sentences), self.batch_size):
            batch = sentences[start:start + self.batch_size]
            pred_gli.extend(
                model.batch_predict_entities(batch, self.labels, threshold=self.threshold)
            )

        evaluator = Evaluator(ds['validation']['knowledge_and_skill'][:], pred_gli, tags=self.labels)
        # evaluate() returns overall metrics first, then per-tag metrics and index details.
        results, _, _, _ = evaluator.evaluate()
        train.report(metrics={"f1": results['strict']['f1']})


In [None]:
def train_gliner(config):
    """Fine-tune GLiNER for a single Ray Tune trial.

    Required ``config`` keys (sampled from the search space): ``lr``, ``weight_decay``,
    ``others_lr``, ``scheduler``, ``optim``, ``max_grad_norm``, ``batch_size``, ``max_steps``.

    Optional ``config`` keys; the defaults reproduce the previously hard-coded values:
    ``logging_steps`` (10), ``warmup_ratio`` (0.1), ``others_weight_decay`` (0.01),
    ``dataloader_num_workers`` (8).

    Validation F1 is reported to Ray Tune by ``MyCallback`` at each log event, so
    ``logging_steps`` also controls the reporting frequency. Checkpoints are written to the
    trial directory every ``2 * logging_steps`` steps, and at most two are kept.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Model and data setup.
    model = GLiNER.from_pretrained("EmergentMethods/gliner_small_news-v2.1", model_max_length=512)
    train_dataset = GLiNERDataset(train_data, model.config, data_processor=model.data_processor)
    data_collator = DataCollatorWithPadding(model.config)
    model.to(device)

    max_steps = config['max_steps']
    logging_steps = config.get('logging_steps', 10)

    training_args = TrainingArguments(
        output_dir=train.get_context().get_trial_dir(),
        do_eval=False,  # validation is done by MyCallback, not by the Trainer
        learning_rate=config['lr'],
        weight_decay=config['weight_decay'],
        others_lr=config['others_lr'],
        others_weight_decay=config.get('others_weight_decay', 0.01),
        lr_scheduler_type=config['scheduler'],
        optim=config['optim'],
        warmup_ratio=config.get('warmup_ratio', 0.1),
        max_grad_norm=config['max_grad_norm'],
        per_device_train_batch_size=config['batch_size'],
        per_device_eval_batch_size=8,
        max_steps=max_steps,
        logging_steps=logging_steps,
        save_strategy="steps",
        save_steps=logging_steps * 2,
        save_total_limit=2,
        # NOTE(review): the tuner reserves 7 CPUs per trial; 8 loader workers may oversubscribe.
        dataloader_num_workers=config.get('dataloader_num_workers', 8),
        use_cpu=False,
        # Metrics go to Ray Tune through MyCallback (and to W&B via the tuner's logger callback).
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        tokenizer=model.data_processor.transformer_tokenizer,
        data_collator=data_collator,
        callbacks=[MyCallback()],
    )
    trainer = prepare_trainer(trainer)
    trainer.train()

In [None]:
# Timestamped W&B project so each tuning run is logged separately.
current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime("%m_%d_%H_%M")
wb_project = f"raytune_gliner_{formatted_datetime}"

# Every trial reserves 7 CPUs and 1 GPU; at most one trial runs per visible GPU.
trainable_with_resources = tune.with_resources(train_gliner, {"cpu": 7, "gpu": 1})
max_concurrent_trials = torch.cuda.device_count()

# Optuna proposes the hyperparameters for each trial.
algo = OptunaSearch()

# Hyperparameter search space. Single-element `tune.choice` entries pin a value while still
# recording it in each trial's config.
search_space = {
    "lr": tune.quniform(1e-4, 3e-4, 2e-5),
    "others_lr": tune.quniform(3e-5, 1e-4, 1e-5),
    "max_steps": tune.choice([400]),
    "weight_decay": tune.quniform(0.03, 0.07, 0.01),
    "scheduler": tune.choice(["linear"]),
    "batch_size": tune.choice([32]),
    "max_grad_norm": tune.quniform(0.3, 0.8, 0.1),
    "optim": tune.choice(["adamw_torch_fused"]),
}

tune_config = tune.TuneConfig(
    num_samples=10,
    search_alg=algo,
    metric="f1",
    mode="max",
    max_concurrent_trials=max_concurrent_trials,
    # With ASHA's default time attribute, grace_period counts `train.report` calls
    # (one per Trainer log event), not optimizer steps.
    scheduler=ASHAScheduler(grace_period=30),
)
run_config = train.RunConfig(
    callbacks=[WandbLoggerCallback(project=wb_project)],
    storage_path="~/raytune/checkpoint",
)

tuner = tune.Tuner(
    trainable_with_resources,
    param_space=search_space,
    tune_config=tune_config,
    run_config=run_config,
)
results = tuner.fit()

In [None]:
# Validation F1 curve of every trial on one chart. The x-axis is the report index (one row per
# `train.report` call from MyCallback), not epochs.
dfs = {result.path: result.metrics_dataframe for result in results}
ax = None  # pandas creates the Axes on the first plot; later curves reuse it
for trial_df in dfs.values():
    # A trial that failed before its first report has no dataframe or no "f1" column.
    if trial_df is None or "f1" not in trial_df.columns:
        continue
    ax = trial_df["f1"].plot(ax=ax, legend=False)
if ax is not None:
    ax.set(title="Validation strict F1 per trial", xlabel="report index", ylabel="strict F1")

In [None]:
# Trial with the highest *last* reported "f1" (scope="last" is Ray's default), i.e. ranked by
# end-of-training score rather than by its best intermediate report.
best_result = results.get_best_result(metric="f1", mode="max", scope="last")

In [None]:
# Display the best trial's Result object for inspection.
best_result

In [None]:
# Directory of the best trial; expected to hold the Trainer's checkpoint-* folders (the
# Trainer's output_dir is the trial dir) — confirm before loading from it.
best_result.path

In [None]:
# NOTE: With the current Ray Tune version and evaluation setup, metrics are reported from a
# callback without an attached checkpoint, so Ray Tune does not know which checkpoint of a trial
# scored best. Check the F1 curves in your reporting tool (e.g. WandB or TensorBoard) and
# override `best_checkpoint_step` if an earlier checkpoint performs better.
# Defaults to the trial's final step (taken from its config instead of hard-coding 400);
# assumes a checkpoint was saved at that step.
best_checkpoint_step = best_result.config["max_steps"]
md = GLiNER.from_pretrained(
    os.path.join(best_result.path, f"checkpoint-{best_checkpoint_step}"),
    load_tokenizer=True,
    local_files_only=True,
)

In [None]:
# Predict the validation split once with a LOW base threshold; `proc_threshold` then sweeps
# higher thresholds by filtering on entity score. The base threshold must not exceed the
# smallest swept value (0.05). Otherwise every lower threshold collapses to the same
# predictions, and the sweep can "select" a threshold it never actually evaluated.
# NOTE(review): this assumes GLiNER decodes greedily by descending score, so score-filtering
# afterwards matches predicting directly at the higher threshold — confirm for the installed
# gliner version.
sentences = ds['validation']['sentence'][:]
labels = ['Skill', 'Knowledge']
y = ds['validation']['knowledge_and_skill'][:]
base_threshold = 0.05

pred_gli = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
md.to(device)
batch_size = 64
for i in range(0, len(sentences), batch_size):
    # Predict one batch of `batch_size` sentences at a time.
    text = sentences[i:i + batch_size]
    entities = md.batch_predict_entities(text, labels, threshold=base_threshold)
    pred_gli.extend(entities)

In [None]:
def proc_threshold(threshold, predictions=None, gold=None):
    """Evaluate predictions after dropping entities whose score does not exceed ``threshold``.

    Args:
        threshold: Entities with ``score <= threshold`` are removed before evaluation.
        predictions: Per-sentence lists of GLiNER entity dicts, each with a ``'score'`` key.
            Defaults to the global ``pred_gli``.
        gold: Gold entities aligned with ``predictions``. Defaults to the global ``y``.

    Returns:
        ``{"threshold": threshold, "f1": f1}``, where ``f1`` is nervaluate's strict F1 over the
        ``Skill`` and ``Knowledge`` tags.
    """
    if predictions is None:
        predictions = pred_gli
    if gold is None:
        gold = y
    # Filtering can only remove entities. A threshold below the one used to produce
    # `predictions` cannot recover entities that were already dropped.
    pred_gli_threshold = [
        [entity for entity in sentence_entities if entity['score'] > threshold]
        for sentence_entities in predictions
    ]
    evaluator = Evaluator(gold, pred_gli_threshold, tags=['Skill', 'Knowledge'])
    results, _, _, _ = evaluator.evaluate()
    return {"threshold": threshold, 'f1': results['strict']['f1']}

In [None]:
from joblib import Parallel, delayed

# Score thresholds 0.05, 0.06, ..., 0.99 in parallel across all available cores.
candidate_thresholds = [step / 100 for step in range(5, 100)]
res = Parallel(n_jobs=-1)(delayed(proc_threshold)(t) for t in candidate_thresholds)
res

In [None]:
# Select the GLiNER threshold with the best strict F1 on the dev set. `idxmax` returns the first
# maximum, so ties resolve deterministically to the lowest threshold. The previous unstable
# `sort_values` (quicksort) could return any of the tied rows.
threshold_scores = pd.DataFrame(res)
threshold = threshold_scores.loc[threshold_scores['f1'].idxmax(), 'threshold']

In [None]:
# Re-predict the validation split at the selected threshold and evaluate it.
sentences = ds['validation']['sentence'][:]
labels = ['Skill', 'Knowledge']
y = ds['validation']['knowledge_and_skill'][:]

pred_gli = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
md.to(device)
batch_size = 64
for i in range(0, len(sentences), batch_size):
    # Predict one batch of `batch_size` sentences at a time.
    text = sentences[i:i + batch_size]
    entities = md.batch_predict_entities(text, labels, threshold=threshold)
    pred_gli.extend(entities)
evaluator = Evaluator(y, pred_gli, tags=labels)
# Named `val_results` (not `results`) so the Ray Tune ResultGrid in `results` stays usable if
# the tuning-analysis cells are re-run. `results_per_tag` keeps its name for the next cell.
val_results, results_per_tag, result_indices, result_indices_by_tag = evaluator.evaluate()
val_results['strict']

In [None]:
# Strict-match metrics for each entity type from the most recent evaluation.
dict((tag, tag_metrics['strict']) for tag, tag_metrics in results_per_tag.items())

In [None]:
# Final evaluation on the held-out test split at the threshold chosen on the dev set.
sentences = ds['test']['sentence'][:]
labels = ['Skill', 'Knowledge']
y = ds['test']['knowledge_and_skill'][:]

pred_gli = []
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
md.to(device)
batch_size = 64
for i in range(0, len(sentences), batch_size):
    # Predict one batch of `batch_size` sentences at a time.
    text = sentences[i:i + batch_size]
    entities = md.batch_predict_entities(text, labels, threshold=threshold)
    pred_gli.extend(entities)
evaluator = Evaluator(y, pred_gli, tags=labels)
# Named `test_results` (not `results`) so the Ray Tune ResultGrid in `results` is not
# clobbered. `results_per_tag` keeps its name for the next cell.
test_results, results_per_tag, result_indices, result_indices_by_tag = evaluator.evaluate()
test_results['strict']

In [None]:
# Strict-match metrics for each entity type from the most recent evaluation.
dict((tag, tag_metrics['strict']) for tag, tag_metrics in results_per_tag.items())