# NuExtract 2.0 Supervised Fine-tuning (SFT)

This notebook shows a basic example of supervised fine-tuning (SFT) on top of the base NuExtract 2.0 models, using your own data.

## Prepare Model
First, load the model you want to fine-tune, along with the processor.

In [1]:
import torch
from transformers import AutoProcessor, AutoModelForVision2Seq
from qwen_vl_utils import process_vision_info

model_name = "numind/NuExtract-2.0-2B"
# model_name = "numind/NuExtract-2.0-8B"

model = AutoModelForVision2Seq.from_pretrained(model_name, 
                                               trust_remote_code=True, 
                                               torch_dtype=torch.bfloat16,
                                               attn_implementation="flash_attention_2",
                                               device_map="auto",
                                               use_cache=False, # for training
                                              )

processor = AutoProcessor.from_pretrained(model_name, 
                                          trust_remote_code=True, 
                                          padding_side='right', # make sure to set padding to right for training
                                          use_fast=True,
                                         )
processor.eos_token = processor.tokenizer.eos_token
processor.eos_token_id = processor.tokenizer.eos_token_id
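
The `padding_side` setting decides which end of a shorter sequence receives padding. A minimal sketch with plain Python lists illustrates the difference; the token IDs and the `pad` helper are made up for illustration and are not the processor's actual implementation:

```python
PAD = 0  # toy padding ID, chosen for illustration only

def pad(sequences, side):
    """Pad toy token-ID lists to a common length on the given side."""
    width = max(len(seq) for seq in sequences)
    padded = []
    for seq in sequences:
        filler = [PAD] * (width - len(seq))
        padded.append(seq + filler if side == "right" else filler + seq)
    return padded

batch = [[5, 6, 7], [8, 9]]
right_padded = pad(batch, "right")
left_padded = pad(batch, "left")
print(right_padded)  # [[5, 6, 7], [8, 9, 0]]
print(left_padded)   # [[5, 6, 7], [0, 8, 9]]
```

With right padding, every sequence starts at position 0, which keeps prompt-length masking during training simple. With left padding, every prompt ends at the last position, so newly generated tokens follow the prompt directly.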

## Prepare Data

The `construct_messages()` function below formats the messages to be fed into the model.

In [2]:
def construct_messages(document, template, label=None, examples=None, image_placeholder="<|vision_start|><|image_pad|><|vision_end|>"):
    """
    Construct the individual NuExtract message texts, prior to chat template formatting.
    """
    images = []
    
    # add few-shot examples if needed
    icl = ""
    if examples is not None and len(examples) > 0:
        icl = "# Examples:\n"
        for row in examples:
            example_input = row['input']
            
            if not isinstance(row['input'], str):
                example_input = image_placeholder
                images.append(row['input'])
                
            icl += f"## Input:\n{example_input}\n## Output:\n{row['output']}\n"
        
    # if input document is an image, set text to an image placeholder
    text = document
    if not isinstance(document, str):
        text = image_placeholder
        images.append(document)
    text = f"""# Template:\n{template}\n{icl}# Context:\n{text}"""
    
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": text}] + images,
        }
    ]
    if label is not None:
        messages.append({
            "role": "assistant",
            "content": [{"type": "text", "text": label}],
            
        })
    return messages

For illustration purposes, we will use a small dataset of manually created examples. You should prepare your own data in a similar way before fine-tuning your own model.

In the custom data below, every label uses fully lowercase strings unless the ICL examples suggest otherwise. Fine-tuning on this data should therefore bias the model toward returning lowercase strings by default.

*Note: training on a very small dataset like this is for illustration purposes only and will almost always result in a poorly performing model in real use-cases.*

In [3]:
inputs = [
    # image input with no ICL examples
    {
        "document": {"type": "image", "image": "file://data/0.jpg"},
        "template": """{"store_name": "verbatim-string"}""",
        "label": """{"store_name": "walmart"}""", # lowercase result
    },
    # image input with 1 ICL example
    {
        "document": {"type": "image", "image": "file://data/1.jpg"},
        "template": """{"store_name": "verbatim-string"}""",
        "examples": [
            {
                "input": {"type": "image", "image": "file://data/0.jpg"},
                "output": """{"store_name": "Walmart"}""",
            }
        ],
        "label": """{"store_name": "Trader Joe's"}""",
    },
    # text input with no ICL examples
    {
        "document": "John went to the restaurant with Mary. James went to the cinema.",
        "template": """{"names": ["verbatim-string"]}""",
        "label": """{"names": ["john", "mary", "james"]}""", # lowercase result
    },
    # text input with ICL example
    {
        "document": "John went to the restaurant with Mary. James went to the cinema.",
        "template": """{"names": ["verbatim-string"]}""",
        "examples": [
            {
                "input": "Stephen is the manager at Susan's store.",
                "output": """{"names": ["STEPHEN", "SUSAN"]}"""
            }
        ],
        "label": """{"names": ["JOHN", "MARY", "JAMES"]}""",
    },
] * 2 # double examples to have dataset of size 8

messages = [
    construct_messages(
        x["document"], 
        x["template"], 
        x["label"],
        x["examples"] if "examples" in x else None
    ) for x in inputs
]

In [4]:
messages[0]

[{'role': 'user',
  'content': [{'type': 'text',
    'text': '# Template:\n{"store_name": "verbatim-string"}\n# Context:\n<|vision_start|><|image_pad|><|vision_end|>'},
   {'type': 'image', 'image': 'file://data/0.jpg'}]},
 {'role': 'assistant',
  'content': [{'type': 'text', 'text': '{"store_name": "walmart"}'}]}]
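
To see how the few-shot (ICL) portion ends up in the prompt, here is a text-only sketch that mirrors the string assembly in `construct_messages()`. The helper name `build_prompt_text` is hypothetical and image handling is left out on purpose:

```python
def build_prompt_text(document, template, examples=None):
    """Text-only mirror of the prompt assembly in construct_messages()."""
    icl = ""
    if examples:
        icl = "# Examples:\n"
        for row in examples:
            icl += f"## Input:\n{row['input']}\n## Output:\n{row['output']}\n"
    return f"# Template:\n{template}\n{icl}# Context:\n{document}"

prompt = build_prompt_text(
    "John went to the restaurant with Mary. James went to the cinema.",
    '{"names": ["verbatim-string"]}',
    examples=[
        {
            "input": "Stephen is the manager at Susan's store.",
            "output": '{"names": ["STEPHEN", "SUSAN"]}',
        }
    ],
)
print(prompt)
```

The printed prompt lists the template first, then each example as an `## Input:` / `## Output:` pair, and finally the document under `# Context:`.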

Let's also add a couple of validation examples that we can use to confirm our model is generalizing to unseen data.

In [5]:
val_inputs = [
    {
        "document": "Jack went to the hill with Jill. Rupert went to the diner.",
        "template": """{"names": ["verbatim-string"]}""",
        "label": """{"names": ["jack", "jill", "rupert"]}""", # lowercase result
    },
    {
        "document": "My dog Clifford likes to play fetch with Emily and Peter.",
        "template": """{"names": ["verbatim-string"]}""",
        "label": """{"names": ["clifford", "emily", "peter"]}""", # lowercase result
    },
]

val_messages = [
    construct_messages(
        x["document"], 
        x["template"], 
        x["label"],
        x["examples"] if "examples" in x else None
    ) for x in val_inputs
]

The data is now structured in a message format that the processor can reformat via the chat template before tokenization. We will do that on the fly during training via the collate function below.

In [6]:
def collate_fn(examples):
    # process input/prompt part of conversations
    user_texts = [processor.apply_chat_template(example[:1], tokenize=False) for example in examples]
    
    # process full conversations (user + assistant)
    full_texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]
    
    # process images
    images = process_vision_info(examples)[0]
    
    # tokenize sequences
    user_batch = processor(text=user_texts, images=images, return_tensors="pt", padding=True)
    full_batch = processor(text=full_texts, images=images, return_tensors="pt", padding=True)
    
    # mask padding tokens
    labels = full_batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    
    # mask user message tokens for each example in the batch
    for i in range(len(examples)):
        # length of prompt message (accounting for possible padding)
        user_len = user_batch["attention_mask"][i].sum().item()
        
        # mask prompt part of label
        labels[i, :user_len - 1] = -100
    
    full_batch["labels"] = labels
    return full_batch
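
The label masking in `collate_fn` can be traced on a single toy sequence. The sketch below uses plain lists instead of tensors and made-up token IDs; `mask_labels` is a hypothetical helper that mirrors the two masking steps above, not part of any library:

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss

def mask_labels(full_ids, user_attention_mask, pad_token_id):
    """Return labels where padding and prompt positions are set to IGNORE_INDEX."""
    # mask padding tokens
    labels = [IGNORE_INDEX if tok == pad_token_id else tok for tok in full_ids]
    # number of real (non-padding) prompt tokens
    user_len = sum(user_attention_mask)
    # mask the prompt, mirroring the `:user_len - 1` slice in collate_fn
    for pos in range(max(user_len - 1, 0)):
        labels[pos] = IGNORE_INDEX
    return labels

# toy example: 4 prompt tokens, 3 answer tokens, 2 padding tokens (pad ID 0)
full_ids = [11, 12, 13, 14, 21, 22, 23, 0, 0]
user_attention_mask = [1, 1, 1, 1, 0]
labels = mask_labels(full_ids, user_attention_mask, pad_token_id=0)
print(labels)  # [-100, -100, -100, 14, 21, 22, 23, -100, -100]
```

Only the positions that keep their token ID contribute to the loss. Note that, as in `collate_fn`, the final prompt token stays unmasked because the slice stops at `user_len - 1`.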

In [7]:
collate_fn(messages[:1])['pixel_values'].shape

torch.Size([2688, 1176])
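
The second dimension of this shape can be reproduced by hand. The values below are assumptions about the underlying image processor (Qwen2-VL-style defaults: 3 colour channels, a temporal patch size of 2, 14×14-pixel patches, and 2×2 patch merging); they are not stated in this notebook:

```python
channels = 3             # assumption: RGB input
temporal_patch_size = 2  # assumption: Qwen2-VL-style default
patch_size = 14          # assumption: Qwen2-VL-style default
merge_size = 2           # assumption: 2x2 patches merged into one image token

features_per_patch = channels * temporal_patch_size * patch_size * patch_size
print(features_per_patch)  # 1176

num_patches = 2688  # first dimension of the shape reported above
image_tokens = num_patches // (merge_size ** 2)
print(image_tokens)  # 672
```

Under these assumptions, each row of `pixel_values` is one flattened patch, and the first dimension is the number of patches extracted from the image.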

## Fine-Tune the Model

We will use the `SFTTrainer` from the `trl` library, which abstracts a lot of the complexities of training for us. For your own use-case you should adjust various hyper-parameters like learning rate, epochs, etc. according to your problem.

In [8]:
from trl import SFTConfig, SFTTrainer

# Configure training arguments
training_args = SFTConfig(
    output_dir="test_finetune",  # Directory to save the model
    num_train_epochs=5,  # Number of training epochs
    per_device_train_batch_size=1,  # Batch size for training
    per_device_eval_batch_size=1,  # Batch size for evaluation
    gradient_accumulation_steps=4,  # Steps to accumulate gradients
    learning_rate=1e-5,  # Learning rate for training
    lr_scheduler_type="constant",  # Type of learning rate scheduler
    logging_steps=1,  # Steps interval for logging
    eval_steps=2,  # Steps interval for evaluation
    eval_strategy="steps",  # Strategy for evaluation
    # save_strategy="steps",  # Strategy for saving the model
    # save_steps=20,  # Steps interval for saving
    bf16=True,  # Use bfloat16 precision
    max_grad_norm=0.3,  # Maximum norm for gradient clipping
    warmup_ratio=0.03,  # Ratio of total steps for warmup (no effect with the "constant" scheduler; use "constant_with_warmup" to apply it)
    report_to="none",  # Reporting tool for tracking metrics
    gradient_checkpointing=True,  # Enable gradient checkpointing for memory efficiency
    gradient_checkpointing_kwargs={"use_reentrant": False},  # Options for gradient checkpointing
    # max_seq_length=1024  # Maximum sequence length for input
)

# allow for proper loading of images during collation
training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=messages,
    eval_dataset=val_messages,
    processing_class=processor.tokenizer,
)

In [9]:
trainer.train()

| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 2    | 0.1704        | 0.269275        |
| 4    | 0.0183        | 0.042456        |
| 6    | 0.0018        | 0.012193        |
| 8    | 0.0003        | 0.011337        |
| 10   | 0.0003        | 0.011667        |


TrainOutput(global_step=10, training_loss=0.057332569236314156, metrics={'train_runtime': 96.9988, 'train_samples_per_second': 0.412, 'train_steps_per_second': 0.103, 'total_flos': 261136381470720.0, 'train_loss': 0.057332569236314156})
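
The step count reported above follows from the dataset size and the training arguments. A simplified calculation, assuming a single GPU (the `num_devices` value is an assumption, not something shown in the notebook):

```python
import math

num_examples = 8                 # 4 hand-written examples, doubled
per_device_batch_size = 1
gradient_accumulation_steps = 4
num_epochs = 5
eval_steps = 2
num_devices = 1                  # assumption: single GPU

effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_devices
steps_per_epoch = math.ceil(num_examples / effective_batch_size)
total_steps = steps_per_epoch * num_epochs
eval_points = list(range(eval_steps, total_steps + 1, eval_steps))

print(effective_batch_size)  # 4
print(steps_per_epoch)       # 2
print(total_steps)           # 10
print(eval_points)           # [2, 4, 6, 8, 10]
```

This matches `global_step=10` and the five evaluation rows in the training log.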

In [None]:
trainer.save_model(training_args.output_dir)

## Test Generation

Now, let's generate outputs for our validation examples to see what the model has learned.

In [None]:
# reload processor with left padding (for generation)
processor = AutoProcessor.from_pretrained(model_name, 
                                          trust_remote_code=True, 
                                          padding_side='left',
                                          use_fast=True)

# reconstruct validation messages without labels
test_messages = [
    construct_messages(
        x["document"], 
        x["template"], 
    ) for x in val_inputs
]

texts = processor.tokenizer.apply_chat_template(
    test_messages,
    tokenize=False,
    add_generation_prompt=True,
)

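# collect image inputs from the test messages (None here, since the validation examples are text-only)
image_inputs = process_vision_info(test_messages)[0]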
inputs = processor(
    text=texts,
    images=image_inputs,
    padding=True,
    return_tensors="pt",
).to("cuda")

# we choose greedy decoding here, which works well for most information extraction tasks
generation_config = {"do_sample": False, "num_beams": 1, "max_new_tokens": 2048}

# Inference: Generation of the output
generated_ids = model.generate(
    **inputs,
    **generation_config
)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_texts = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
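
The trimming step works because `generate()` returns the prompt tokens followed by the new tokens, so slicing off the first `len(in_ids)` positions leaves only the generated part. A small sketch with toy token IDs (made up for illustration) shows the effect:

```python
def trim_prompts(input_ids, generated_ids):
    """Drop the prompt prefix from each generated sequence."""
    return [out[len(inp):] for inp, out in zip(input_ids, generated_ids)]

# left-padded toy prompts (pad ID 0), each followed by two new tokens
input_ids = [[0, 0, 5, 6], [7, 8, 9, 10]]
generated_ids = [[0, 0, 5, 6, 41, 42], [7, 8, 9, 10, 43, 44]]

trimmed = trim_prompts(input_ids, generated_ids)
print(trimmed)  # [[41, 42], [43, 44]]
```

Because the batch is left-padded, all prompts share the same length and the padding is removed together with the prompt.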

As can be seen below, the model is now generating extractions with lowercase strings by default.

In [None]:
for i in range(len(texts)):
    print(f"=== Prompt ===\n{texts[i]}")
    print(f"=== Output ===\n{output_texts[i]}\n")
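
Rather than checking the outputs by eye, the lowercase behaviour can also be verified programmatically. The sketch below uses a hypothetical helper, `all_strings_lowercase`, and a stand-in string copied from the first validation label instead of real model output:

```python
import json

def all_strings_lowercase(value):
    """Recursively check that every string in a parsed JSON value is lowercase."""
    if isinstance(value, str):
        return value == value.lower()
    if isinstance(value, list):
        return all(all_strings_lowercase(v) for v in value)
    if isinstance(value, dict):
        return all(all_strings_lowercase(v) for v in value.values())
    return True  # numbers, booleans, and null have no case

# stand-in for one entry of output_texts (not real model output)
sample_output = '{"names": ["jack", "jill", "rupert"]}'
print(all_strings_lowercase(json.loads(sample_output)))          # True
print(all_strings_lowercase(json.loads('{"names": ["Jack"]}')))  # False
```

In practice, the same check could be applied to each entry of `output_texts`, with a `try`/`except json.JSONDecodeError` around `json.loads` in case the model returns malformed JSON.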