# **Using LLMs to Generate Synthetic Data for Fine-Tuning GLiNER**

In this notebook, we'll explore a simple way to generate synthetic data for fine-tuning GLiNER. I have used a similar approach to generate training data for [**PII extraction**](https://huggingface.co/urchade/gliner_multi_pii-v1). We will be using `Mistral-7B-Instruct-v0.2`, though stronger open models (such as Llama 3) are now available and would likely give better results.

Additionally, the prompt used in this example is far from optimal, so you should adapt it to your specific use case or domain. This notebook is only meant as an example for practitioners, since several people have asked for one.

In this notebook, we generate **fully synthetic data**, including both text and entity annotations, but if you have quality data from your target domain, *you can alternatively have the LLM annotate your existing data*. 📊📝
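If you go the annotation route, the prompt needs to supply your text and ask the model to label it rather than invent a passage. The sketch below is one hypothetical way to do that (`create_annotation_prompt` is an illustrative name, not part of GLiNER or vLLM); it reuses the JSON schema and `<end>` marker of the generation prompt defined later in this notebook:

```python
def create_annotation_prompt(text, entity_types=None):
    """Hypothetical variant of the generation prompt: ask the LLM to annotate
    an existing passage instead of inventing a new one."""
    prompt = (
        "**Objective:**\n"
        "Identify the named entities in the text below and label each one with its type(s).\n\n"
        "**Annotation Details:**\n"
        "- Copy each entity exactly as it appears in the text (same spelling and punctuation).\n"
        "- All entity types must be in lowercase; multi-word types are separated by spaces.\n"
        "- A single entity may be associated with multiple types; list them in the key \"types\".\n"
    )
    if entity_types:
        # Optionally restrict the label set to the types you care about
        prompt += "- Only use these entity types: " + ", ".join(entity_types) + "\n"
    prompt += (
        "\n**Output Schema:**\n"
        '{"text": "{text content}", "entities": [{"entity": "entity name", "types": ["type 1", ...]}]}\n'
        "<end>\n\n"
        "**Text:**\n" + text + "\n\n<start>\n"
    )
    return prompt
```

Asking for verbatim copies matters here: the post-processing step only keeps entities whose tokens can be found in the text (ignoring case), so paraphrased entities would simply be dropped.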

Feel free to experiment and tailor the approach to better suit your needs! *Happy fine-tuning!* 🌟

In [2]:
# install vllm (https://github.com/vllm-project/vllm)
# !pip install vllm

In [2]:
from vllm import LLM, SamplingParams

## Load large language model

In [3]:
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.2" # you can use a better model
NUM_GPUs = 4

In [4]:
llm = LLM(model=LLM_MODEL, tensor_parallel_size=NUM_GPUs, dtype="half")



2024-05-10 21:18:20,760	INFO worker.py:1724 -- Started a local Ray instance.


INFO 05-10 21:18:25 llm_engine.py:72] Initializing an LLM engine with config: model='/gpfsdswork/dataset/HuggingFace_Models/mistralai/Mistral-7B-Instruct-v0.2', tokenizer='/gpfsdswork/dataset/HuggingFace_Models/mistralai/Mistral-7B-Instruct-v0.2', tokenizer_mode=auto, revision=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=auto, tensor_parallel_size=4, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto, seed=0)
INFO 05-10 21:19:02 llm_engine.py:322] # GPU blocks: 42735, # CPU blocks: 8192
INFO 05-10 21:19:05 model_runner.py:632] Capturing the model for CUDA graphs. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI.
INFO 05-10 21:19:05 model_runner.py:636] CUDA graphs can take additional 1~3 GiB memory per GPU. If you are running out of memory, consider dec

In [5]:
# sampling parameters; "<end>" is the closing marker of the prompt's output schema, so generation stops there
sampling_params = SamplingParams(top_k=100, max_tokens=1000, top_p=0.8, stop="<end>")

## Prompting function

In [6]:
def create_json_prompt_for_synthetic_data(**kwargs):
    
    # Use dictionary comprehension to filter out 'n/a' values and to keep the code flexible
    attributes = {key: value for key, value in kwargs.items() if value != "n/a"}
    
    # Building the initial part of the prompt
    prompt = """
**Objective:**
Produce realistic text passages that include clearly identified named entities. Each entity should be meticulously labeled according to its type for straightforward extraction.

**Format Requirements:**
- The output should be formatted in JSON, containing the text and the corresponding entities list.
- Each entity in the text should be accurately marked and annotated in the 'entities' list.
- Meticulously follow all the listed attributes.

**Entity Annotation Details:**
- All entity types must be in lowercase. For example, use "type" not "TYPE".
- Entity types can be multi-word, separated by spaces. For instance, use "entity type" rather than "entity_type".
- Entity spans can be nested within other entities.
- A single entity may be associated with multiple types. List them under the key "types".

**Output Schema:**

<start attribute_1="value1" attribute_2="value2" ...>
{
  "text": "{text content}",
  "entities": [
    {"entity": "entity name", "types": ["type 1", "type 2", ...]},
    ...
  ]
}
<end>

**Here are some real world examples**:"""

    # Create a string of attributes for the <start> tag, excluding any 'n/a' values
    attributes_string = " ".join([f'{key}="{value}"' for key, value in attributes.items()])

    # Adding the dynamically created attributes string to the prompt
    prompt += f"""
<start {attributes_string}>
"""

    return prompt
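The prompt announces "real world examples" but then ends with an open `<start ...>` tag and no examples, so the model effectively writes the first example itself. If you have even a few hand-checked annotations, one option is to splice them in as few-shot examples. The helpers below (`format_example` and `add_few_shot_examples` are hypothetical names) render an example in the same `<start ...> ... <end>` format and insert it just before the final, still-open `<start` tag:

```python
import json

def format_example(example, **attributes):
    """Render one annotated example in the <start ...> ... <end> format used by the prompt."""
    # Same 'n/a' filtering as create_json_prompt_for_synthetic_data
    attrs = " ".join(f'{k}="{v}"' for k, v in attributes.items() if v != "n/a")
    body = json.dumps(example, ensure_ascii=False, indent=2)
    return f"<start {attrs}>\n{body}\n<end>\n"

def add_few_shot_examples(prompt, examples):
    """Insert pre-formatted examples before the last '<start ' tag (the one the model completes)."""
    head, sep, tail = prompt.rpartition("<start ")
    return head + "".join(examples) + sep + tail
```

Because `rpartition` splits on the last occurrence, the `<start attribute_1=...>` line inside the schema section is left untouched. Keep in mind that few-shot examples also steer the model's style and entity types, so varying them across prompts may help diversity.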

## Example of generation

In [7]:
import json

def generate(**kwargs):
    """Generate a single synthetic example and parse the model's JSON output."""
    outputs = llm.generate([create_json_prompt_for_synthetic_data(**kwargs)], sampling_params)
    return json.loads(outputs[0].outputs[0].text.strip())
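`json.loads` raises as soon as the model adds anything around the JSON object (a short preamble, or trailing text when it forgets the `<end>` marker). A more tolerant parser is sketched below; it uses the standard library's `json.JSONDecoder.raw_decode`, which parses one value starting at a given index and ignores whatever follows. `parse_annotation_json` is a hypothetical helper, and requiring both `"text"` and `"entities"` keys is an assumption meant to avoid returning a nested entity dict from a truncated output:

```python
import json

def parse_annotation_json(raw):
    """Return the first JSON object in `raw` that has 'text' and 'entities' keys, else None."""
    decoder = json.JSONDecoder()
    start = raw.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(raw, start)
            if isinstance(obj, dict) and "text" in obj and "entities" in obj:
                return obj
        except json.JSONDecodeError:
            pass  # not a complete JSON value at this position; try the next '{'
        start = raw.find("{", start + 1)
    return None
```

It could stand in for `json.loads(...)` in `generate`, returning `None` instead of raising for unusable outputs.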

In [8]:
generate(language="french", types_of_text="detailed job ads", sector="machine learning", country="france")

Processed prompts: 100%|██████████| 1/1 [00:06<00:00,  6.40s/it]


{'text': "Nous recherchons un Data Scientist expérimenté pour notre équipe de Paris. Votre mission consistera à concevoir et à mettre en œuvre des modèles statistiques et machine learning. Les candidats doivent posséder une solide expérience en Python et en R. Un diplôme universitaire dans le domaine des mathématiques ou de l'informatique est requis. Les meilleurs candidats auront également une bonne connaissance de TensorFlow et Scikit-Learn.",
 'entities': [{'entity': 'Nous', 'types': ['organization']},
  {'entity': 'notre équipe', 'types': ['organization']},
  {'entity': 'Paris', 'types': ['location']},
  {'entity': 'Data Scientist', 'types': ['jobtitle']},
  {'entity': 'votre mission', 'types': ['jobdescription']},
  {'entity': 'concevoir et mettre en œuvre', 'types': ['jobresponsibility']},
  {'entity': 'des modèles statistiques et machine learning',
   'types': ['jobresponsibility']},
  {'entity': 'Les candidats', 'types': ['person']},
  {'entity': 'doivent posséder', 'types': ['

## Functions

In [9]:
# post-processing functions

import json
import re

def tokenize_text(text):
    """Tokenize the input text into a list of tokens."""
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

def extract_entities(data):
    """Convert parsed LLM outputs into GLiNER examples: {"tokenized_text": [...], "ner": [(start, end, type), ...]}."""
    all_examples = []

    for dt in data:

        # Attempt to extract text and entities; skip malformed records
        try:
            tokens = tokenize_text(dt['text'])
            ents = [(k["entity"], k["types"]) for k in dt['entities']]
        except (KeyError, TypeError):
            continue

        spans = []
        for entity, types in ents:
            entity_tokens = tokenize_text(str(entity))
            if not entity_tokens:
                continue  # an empty entity would otherwise "match" at every position
            if isinstance(types, str):
                types = [types]  # tolerate a single type given as a plain string
            elif not isinstance(types, list):
                continue

            # Find every (case-insensitive) occurrence of the entity; the end index is inclusive
            target = " ".join(entity_tokens).lower()
            for i in range(len(tokens) - len(entity_tokens) + 1):
                if " ".join(tokens[i:i + len(entity_tokens)]).lower() == target:
                    for el in types:
                        spans.append((i, i + len(entity_tokens) - 1, str(el).lower().replace('_', ' ')))

        # Append the tokenized text and its corresponding named entity recognition data
        all_examples.append({"tokenized_text": tokens, "ner": spans})

    return all_examples

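# generation functions
def generate_from_prompts(prompts, llm, sampling_params):
    """Run the LLM on all prompts, keep outputs that parse as JSON, and convert them to GLiNER format."""
    outputs = llm.generate(prompts, sampling_params)

    all_outs = []

    for output in outputs:
        try:
            js = json.loads(output.outputs[0].text.strip())
        except json.JSONDecodeError:
            continue  # skip malformed generations

        all_outs.append(js)

    return all_outs, extract_entities(all_outs)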
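To see what `extract_entities` produces, here is a stripped-down sketch of its matching step on the opening of the first generated job ad. The tokenizer keeps hyphenated or underscored words whole and turns every other non-space character into its own token; entities are re-tokenized with the same regex and matched case-insensitively, with an inclusive end index as in `extract_entities`. `find_spans` is an illustrative helper, not part of GLiNER, and it also skips empty entity strings, which would otherwise match at every position:

```python
import re

def tokenize_text(text):
    """Same regex as tokenize_text above."""
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

def find_spans(tokens, entity, label):
    """Return every (start, end, label) where `entity` occurs in `tokens`, case-insensitively; `end` is inclusive."""
    entity_tokens = [t.lower() for t in tokenize_text(entity)]
    n = len(entity_tokens)
    if n == 0:
        return []
    return [
        (i, i + n - 1, label)
        for i in range(len(tokens) - n + 1)
        if [t.lower() for t in tokens[i:i + n]] == entity_tokens
    ]

tokens = tokenize_text("Wanted: E-commerce Strategist in Lima, Peru.")
print(tokens)
# ['Wanted', ':', 'E-commerce', 'Strategist', 'in', 'Lima', ',', 'Peru', '.']
print(find_spans(tokens, "Lima, Peru", "location"))
# [(5, 7, 'location')]
```

Every occurrence is labeled, so an entity that appears twice in the text produces two spans.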

## Use case: synthetic data for job ads

In [10]:
# I have used GPT-4 to generate these

# List of countries
countries = [
    "Madagascar", "Taiwan", "USA", "Germany", "France", "Spain", "Russia", "China", 
    "Japan", "Brazil", "India", "Egypt", "South Africa", "Australia", "Canada", 
    "Mexico", "Indonesia", "Nigeria", "Turkey", "United Kingdom", "Italy", "Poland", 
    "Argentina", "Netherlands", "Belgium", "Switzerland", "Sweden", "Norway", "Finland",
    "Denmark", "Portugal", "Greece", "Iran", "Thailand", "Philippines", "Vietnam", 
    "South Korea", "Saudi Arabia", "Israel", "UAE", "New Zealand", "Ireland", "Malaysia",
    "Singapore", "Hong Kong", "Czech Republic", "Hungary", "Romania", "Colombia", 
    "Peru", "Venezuela", "Chile", "Morocco", "Algeria", "Tunisia", "Nepal", "Pakistan", "Bangladesh", 
    "Kazakhstan", "Ukraine", "Austria", "Croatia", "Serbia", "Kenya", "Ghana", "Zimbabwe",
    "Cuba", "Panama", "Fiji", "Mongolia", "North Korea", "Myanmar", "Ethiopia", "Tanzania",
    "Libya", "Jordan", "Qatar", "Oman", "Kuwait", "Lebanon", "Bulgaria", "Slovakia",
    "Lithuania", "Latvia", "Estonia", "Cyprus", "Luxembourg", "Macao", "Bhutan", "Maldives",
    "Angola", "Cameroon", "Senegal", "Mali", "Zambia", "Uganda", "Namibia", "Botswana",
    "Mozambique", "Ivory Coast", "Burkina Faso", "Malawi", "Gabon", "Lesotho", "Gambia",
    "Guinea", "Cape Verde", "Rwanda", "Benin", "Burundi", "Somalia", "Eritrea", "Djibouti",
    "Togo", "Seychelles", "Chad", "Central African Republic", "Liberia", "Mauritania", "Sri Lanka",
    "Sierra Leone", "Equatorial Guinea", "Swaziland", "Congo (Kinshasa)", "Congo (Brazzaville)"
]

# job sectors
job_sectors = [
    # Finance Sector Specializations
    "Investment Banking",
    "Corporate Finance",
    "Asset Management",
    "Risk Management",
    "Quantitative Analysis",
    "Financial Planning",
    
    # Machine Learning and AI Specializations
    "Natural Language Processing",
    "Computer Vision",
    "Deep Learning",
    "Reinforcement Learning",
    "Predictive Analytics",
    "Algorithm Development",
    
    # Healthcare Sector Specializations
    "Medical Research",
    "Clinical Trials",
    "Health Informatics",
    "Biomedical Engineering",
    "Public Health Administration",
    "Pharmaceuticals",
    
    # Education Sector Specializations
    "Curriculum Development",
    "Educational Technology",
    "Special Education",
    "Higher Education Administration",
    "Educational Policy",
    "Language Instruction",
    
    # Manufacturing Sector Specializations
    "Process Engineering",
    "Quality Control",
    "Industrial Design",
    "Supply Chain Optimization",
    "Robotics Manufacturing",
    "Lean Manufacturing",
    
    # Energy Sector Specializations
    "Renewable Energy Systems",
    "Oil and Gas Exploration",
    "Energy Efficiency Consulting",
    "Nuclear Engineering",
    "Smart Grid Technology",
    "Energy Policy",
    
    # Environmental Sector Specializations
    "Wildlife Conservation",
    "Environmental Science",
    "Water Resource Management",
    "Sustainability Strategy",
    "Climate Change Analysis",
    "Environmental Law",
    
    # Media and Communications Specializations
    "Digital Marketing",
    "Journalism",
    "Public Relations",
    "Film Production",
    "Broadcasting",
    "Content Strategy",
    
    # Legal Sector Specializations
    "Corporate Law",
    "International Law",
    "Intellectual Property",
    "Civil Litigation",
    "Criminal Defense",
    
    # Retail Sector Specializations
    "E-commerce Strategy",
    "Store Management",
    "Merchandise Planning",
    "Customer Experience Management",
    "Retail Analytics",
    "Supply Chain Logistics"
]
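Because each prompt draws from these lists with `random.choice`, any entry listed twice is sampled twice as often. A small check like the following can catch repeats before generating; `report_duplicates` and `dedupe` are illustrative names:

```python
from collections import Counter

def report_duplicates(name, items):
    """Print and return the entries that appear more than once in a list."""
    dupes = [item for item, count in Counter(items).items() if count > 1]
    if dupes:
        print(f"{name}: duplicated entries {dupes}")
    return dupes

def dedupe(items):
    """Drop repeated entries, keeping the position of each first occurrence."""
    return list(dict.fromkeys(items))
```

For example, `countries = dedupe(countries)` and `job_sectors = dedupe(job_sectors)` before building the prompts.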

### Generate prompts

In [11]:
# create prompts
import random

NUM_SAMPLES = 100

all_prompts = []

for _ in range(NUM_SAMPLES):
    # randomly sample the attributes for this prompt
    job_sector = random.choice(job_sectors)
    country = random.choice(countries)

    prompt = create_json_prompt_for_synthetic_data(language="english",
                                                   types_of_text="detailed job ads",
                                                   sector=job_sector,
                                                   country=country)
    all_prompts.append(prompt)
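The loop above uses the global `random` state, so each run samples a different set of attributes. If you want to rebuild the same prompt set (for example, to compare two LLMs on identical attributes), one option is a dedicated seeded RNG that also keeps the sampled attributes alongside each prompt. `sample_attributes` is a hypothetical helper; note that this only fixes the prompts, not the LLM's own sampling:

```python
import random

def sample_attributes(job_sectors, countries, num_samples, seed=0):
    """Draw attribute dicts for the prompts using a private, seeded RNG."""
    rng = random.Random(seed)  # independent of the global random state
    return [
        {"language": "english",
         "types_of_text": "detailed job ads",
         "sector": rng.choice(job_sectors),
         "country": rng.choice(countries)}
        for _ in range(num_samples)
    ]
```

Usage: `attrs = sample_attributes(job_sectors, countries, NUM_SAMPLES)` followed by `all_prompts = [create_json_prompt_for_synthetic_data(**a) for a in attrs]`.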

### Generate outputs

In [12]:
output, processed_output = generate_from_prompts(all_prompts, llm, sampling_params)

Processed prompts: 100%|██████████| 100/100 [00:17<00:00,  5.60it/s]


In [13]:
output[0]

{'text': 'Wanted: E-commerce Strategist in Lima, Peru. 5+ years of experience in digital marketing required. B2C or B2B projects preferred. Salary range: S/ 3000 to S/ 5000. Apply with resume and cover letter.',
 'entities': [{'entity': 'E-commerce Strategist',
   'types': ['person', 'jobtitle']},
  {'entity': 'Lima, Peru', 'types': ['location']},
  {'entity': '5+ years', 'types': ['quantity', 'duration']},
  {'entity': 'digital marketing', 'types': ['skill']},
  {'entity': 'B2C or B2B', 'types': ['business_model']},
  {'entity': 'Salary range', 'types': ['salary']},
  {'entity': 'S/ 3000 to S/ 5000', 'types': ['amount', 'currency']}]}

### Some statistics

In [26]:
lengths = []

for d in processed_output:
    lengths.append(len(d["tokenized_text"]))

print("Avg num tokens:", sum(lengths) / len(lengths))

Avg num tokens: 76.82291666666667


In [27]:
len_ner = []

for d in processed_output:
    len_ner.append(len(d["ner"]))
        
print("Avg num of entities:", sum(len_ner) / len(len_ner))

Avg num of entities: 11.875


In [28]:
# One entry per annotated (span, type) pair; this list is NOT de-duplicated
unique_entities = []

for d in processed_output:
    for n in d["ner"]:
        unique_entities.append(str(n[2]).lower())

print("Total entity type annotations:", len(unique_entities))

Total entity type annotations: 1140
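The list built above holds one label per annotated span, so its length (1140) is the total number of annotations rather than the number of distinct entity types: 11.875 entities per example × 96 examples = 1140, which also implies that 96 of the 100 generations were parsed successfully. To report both figures, a helper along these lines could be used (`entity_type_stats` is an illustrative name):

```python
from collections import Counter

def entity_type_stats(examples):
    """Return (total span annotations, number of distinct type labels) for GLiNER-style examples."""
    type_counts = Counter(str(label).lower() for ex in examples for _, _, label in ex["ner"])
    return sum(type_counts.values()), len(type_counts)
```

Calling `entity_type_stats(processed_output)` returns both counts in one pass.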


In [21]:
# Top 10 entity types

from collections import Counter
Counter(unique_entities).most_common(10)

[('organization', 106),
 ('location', 86),
 ('job title', 83),
 ('person', 71),
 ('country', 41),
 ('technology', 40),
 ('field of study', 38),
 ('education', 29),
 ('degree', 24),
 ('quantity', 23)]
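The ten most frequent types cover 541 of the 1140 annotations; the rest is spread over less frequent labels. Whether that tail helps depends on your use case (GLiNER is built to handle arbitrary entity types, so varied labels are not necessarily a problem), but if you want a tighter label set, one option is to drop spans whose type is rare across the dataset. `drop_rare_types` and its `min_count` threshold are illustrative choices, not values tuned on this data:

```python
from collections import Counter

def drop_rare_types(examples, min_count=2):
    """Return new examples without spans whose type label occurs fewer than `min_count` times."""
    counts = Counter(span[2] for ex in examples for span in ex["ner"])
    return [
        {"tokenized_text": ex["tokenized_text"],
         "ner": [span for span in ex["ner"] if counts[span[2]] >= min_count]}
        for ex in examples
    ]
```

The input list is left unchanged, so you can compare `processed_output` with the filtered version before deciding which one to save.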

### Save for training

In [22]:
# Save to JSON
def save_data_to_file(data, filepath):
    """Saves the processed data to a JSON file."""
    with open(filepath, 'w') as f:
        json.dump(data, f)

In [23]:
output_file = "job_ads_data_gliner.json"

save_data_to_file(processed_output, output_file)
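Before fine-tuning, a cheap sanity check is to confirm that every span points inside its token list. With the matching logic in `extract_entities` this should always hold, but it is worth re-checking if you edit or merge files by hand. The sketch also counts examples with no spans at all, since `extract_entities` keeps a record even when none of its entities could be matched. `validate_gliner_examples` is a hypothetical helper:

```python
def validate_gliner_examples(examples):
    """Check GLiNER-style examples; return (indices with invalid spans, number of examples with no spans)."""
    invalid, empty = [], 0
    for idx, ex in enumerate(examples):
        n_tokens = len(ex["tokenized_text"])
        if not ex["ner"]:
            empty += 1
        # Spans use inclusive end indices, so a valid span satisfies 0 <= start <= end < n_tokens
        if any(not (0 <= start <= end < n_tokens) for start, end, _ in ex["ner"]):
            invalid.append(idx)
    return invalid, empty
```

It works on `processed_output` directly or on the reloaded `job_ads_data_gliner.json`, where the span tuples have become JSON lists.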