# NuExtract 2 Inference

In this notebook we will provide examples of how to use the NuExtract 2.0 models for inference.

First, let's load a model.

In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel

# Checkpoint to load; swap in one of the larger variants below if desired.
model_name = "numind/NuExtract-2-2B"
# model_name = "numind/NuExtract-2-4B"
# model_name = "numind/NuExtract-2-8B"

# Left padding keeps the end of every prompt adjacent to the generated tokens in a batch.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side='left',
    trust_remote_code=True,
)

# Load the weights in bfloat16 with FlashAttention-2, then move the model to the GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
model = model.to("cuda")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
InternLM2ForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Next, we need to include some utility code to preprocess images to be compatible with InternVL-2.5 (base model that NuExtract 2.0 is built on). We will also include a function `nuextract_generate()` to handle our inference process.

In [2]:
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

# Per-channel normalisation statistics applied to every image tile.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    """
    Build the per-tile preprocessing pipeline.

    The returned transform converts the image to RGB when needed, resizes it to
    ``input_size`` x ``input_size`` with bicubic interpolation, turns it into a
    tensor and normalises it with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.

    Args:
        input_size: Side length (in pixels) of the square output.

    Returns:
        A ``torchvision.transforms.Compose`` object.
    """
    def _ensure_rgb(img):
        # Leave images that are already RGB untouched
        if img.mode == 'RGB':
            return img
        return img.convert('RGB')

    steps = [
        T.Lambda(_ensure_rgb),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
    return T.Compose(steps)

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    """
    Pick the tiling grid whose aspect ratio is closest to the image's.

    Args:
        aspect_ratio: Width / height of the source image.
        target_ratios: Iterable of ``(columns, rows)`` candidates, visited in order.
        width: Source image width in pixels.
        height: Source image height in pixels.
        image_size: Side length of one square tile.

    Returns:
        The chosen ``(columns, rows)`` candidate, or ``(1, 1)`` when there are
        no candidates.
    """
    chosen = (1, 1)
    smallest_gap = float('inf')
    pixel_area = width * height
    for candidate in target_ratios:
        gap = abs(aspect_ratio - candidate[0] / candidate[1])
        if gap < smallest_gap:
            smallest_gap, chosen = gap, candidate
            continue
        # On an exact tie, take the later candidate only if the source image
        # covers more than half of that candidate grid's pixel area.
        if gap == smallest_gap and pixel_area > 0.5 * image_size * image_size * candidate[0] * candidate[1]:
            chosen = candidate
    return chosen

def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    """
    Split an image into a grid of square tiles.

    The grid shape is chosen by ``find_closest_aspect_ratio`` among all
    ``(columns, rows)`` pairs whose tile count lies in ``[min_num, max_num]``.
    The image is resized to exactly fill that grid and cut into
    ``image_size`` x ``image_size`` crops in row-major order.

    Args:
        image: Source image exposing ``size``, ``resize`` and ``crop``.
        min_num: Minimum number of tiles in a candidate grid.
        max_num: Maximum number of tiles in a candidate grid.
        image_size: Side length of each tile in pixels.
        use_thumbnail: When true and more than one tile was produced, append a
            copy of the whole image resized to a single tile.

    Returns:
        List of tile images, optionally followed by the thumbnail.
    """
    src_w, src_h = image.size
    src_ratio = src_w / src_h

    # enumerate every candidate grid, then order the candidates by tile count
    candidate_grids = set(
        (cols, rows)
        for n in range(min_num, max_num + 1)
        for cols in range(1, n + 1)
        for rows in range(1, n + 1)
        if min_num <= cols * rows <= max_num)
    grids_by_area = sorted(candidate_grids, key=lambda grid: grid[0] * grid[1])

    # choose the grid that best matches the source aspect ratio
    grid = find_closest_aspect_ratio(src_ratio, grids_by_area, src_w, src_h, image_size)
    grid_cols, grid_rows = grid[0], grid[1]
    tile_count = grid_cols * grid_rows

    # stretch the image so that it exactly fills the chosen grid
    canvas_w = image_size * grid_cols
    canvas_h = image_size * grid_rows
    canvas = image.resize((canvas_w, canvas_h))

    # cut the canvas into tiles, left to right and top to bottom
    tiles_per_row = canvas_w // image_size
    tiles = []
    for index in range(tile_count):
        row, col = divmod(index, tiles_per_row)
        left = col * image_size
        top = row * image_size
        tiles.append(canvas.crop((left, top, left + image_size, top + image_size)))
    assert len(tiles) == tile_count

    if use_thumbnail and len(tiles) != 1:
        tiles.append(image.resize((image_size, image_size)))
    return tiles

def load_image(image_file, input_size=448, max_num=12):
    """
    Load an image from disk and convert it into a stack of normalised tiles.

    Args:
        image_file: Path (or file object) accepted by ``PIL.Image.open``.
        input_size: Side length of each square tile in pixels.
        max_num: Maximum number of tiles the image may be split into
            (the thumbnail appended for multi-tile images is not counted).

    Returns:
        Tensor obtained by stacking the transformed tiles along a new first
        dimension, i.e. one entry per tile (plus the thumbnail, if any).
    """
    # Open inside a context manager so the file handle is released as soon as
    # the pixel data has been copied; convert() returns an independent image
    # that remains valid after the source file is closed.
    with Image.open(image_file) as source:
        image = source.convert('RGB')

    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    return torch.stack([transform(tile) for tile in tiles])

# Special tokens used to mark where image content sits inside a prompt.
IMG_START_TOKEN='<img>'
IMG_END_TOKEN='</img>'
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'

def nuextract_generate(model, tokenizer, prompts, generation_config, pixel_values_list=None, num_patches_list=None):
    """
    Generate responses for a batch of NuExtract inputs.
    Support for multiple and varying numbers of images per prompt.

    For every prompt that comes with image data, each ``<image>`` placeholder is
    replaced (in order of appearance) by ``<img>`` + N * ``<IMG_CONTEXT>`` + ``</img>``,
    where N is ``model.num_image_token`` times the patch count of that image.

    Args:
        model: The vision-language model; must expose ``num_image_token``,
               ``device`` and ``generate()``
        tokenizer: The tokenizer for the model
        prompts: List of text prompts (already chat-template formatted)
        generation_config: Dict of keyword arguments forwarded to ``model.generate()``.
                           It is copied, so the caller's dict is left untouched;
                           ``eos_token_id`` is set to the id of ``<|im_end|>`` in the copy.
        pixel_values_list: List of tensor batches, one per prompt
                           Each batch has shape [num_images, channels, height, width] or None for text-only prompts
        num_patches_list: List of lists, each containing patch counts for images in a prompt

    Returns:
        List of generated responses

    Side effects:
        Sets ``model.img_context_token_id`` and ``tokenizer.padding_side = 'left'``,
        and prints a one-line summary of the batch.
    """
    model.img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)

    # Replace all image placeholders with appropriate tokens
    modified_prompts = []
    image_tensors = []
    total_image_files = 0
    total_patches = 0
    for idx, prompt in enumerate(prompts):
        # a prompt has images only if a non-empty tensor was supplied for it
        pixel_values = pixel_values_list[idx] if (pixel_values_list and idx < len(pixel_values_list)) else None
        has_images = isinstance(pixel_values, torch.Tensor) and pixel_values.shape[0] > 0

        if has_images:
            image_tensors.append(pixel_values)

            patches = num_patches_list[idx] if (num_patches_list and idx < len(num_patches_list)) else []
            total_image_files += len(patches)
            total_patches += sum(patches)

            # replace each <image> placeholder, one at a time, with its image tokens
            for num_patches in patches:
                image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * model.num_image_token * num_patches + IMG_END_TOKEN
                prompt = prompt.replace('<image>', image_tokens, 1)

        modified_prompts.append(prompt)

    # process all prompts in a single batch
    tokenizer.padding_side = 'left'
    model_inputs = tokenizer(modified_prompts, return_tensors='pt', padding=True)
    input_ids = model_inputs['input_ids'].to(model.device)
    attention_mask = model_inputs['attention_mask'].to(model.device)

    # work on a copy so the caller's generation_config is never mutated
    generation_kwargs = dict(generation_config)
    generation_kwargs['eos_token_id'] = tokenizer.convert_tokens_to_ids("<|im_end|>")

    # prepare pixel values: one flat tensor holding the patches of every image-bearing prompt
    flattened_pixel_values = None
    if image_tensors:
        flattened_pixel_values = torch.cat(image_tensors, dim=0)
        print(f"Processing batch with {len(prompts)} prompts, {total_image_files} actual images, and {total_patches} total patches")
    else:
        print(f"Processing text-only batch with {len(prompts)} prompts")

    # generate outputs
    outputs = model.generate(
        pixel_values=flattened_pixel_values,  # will be None for text-only prompts
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_kwargs
    )

    # Decode responses
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

## Preparing Model Inputs

Before using the model, we also need to make sure our prompts are properly formatted to work with NuExtract. NuExtract expects all input information to come as a single user chat prompt, formatted as follows:

```python
f"""
# Template:
{template}
# Context:
{text}
"""
```
and if in-context examples are provided:
```python
f"""
# Template:
{template}
# Examples:
## Input:
{input1}
## Output:
{output1}
## Input:
{input2}
## Output:
{output2}
# Context:
{text}
"""
```

If you are working with image inputs, simply use `"<image>"` placeholders for `text`, `input1`, etc. The code below will then inject tokens representing the actual image content in the location of these placeholders.

The following functions will make this formatting more convenient for us.

In [3]:
def construct_message(text, template, examples=None):
    """
    Build the raw NuExtract user message (before chat-template formatting).

    The message consists of a ``# Template:`` section, an optional
    ``# Examples:`` section with one ``## Input:`` / ``## Output:`` pair per
    in-context example, and a final ``# Context:`` section.

    Args:
        text: The input to extract from (or an ``<image>`` placeholder).
        template: The extraction template.
        examples: Optional sequence of dicts with ``'input'`` and ``'output'`` keys.

    Returns:
        The assembled message string.
    """
    sections = [f"# Template:\n{template}\n"]

    # few-shot examples go between the template and the context
    if examples is not None and len(examples) > 0:
        sections.append("# Examples:\n")
        sections.extend(
            f"## Input:\n{example['input']}\n## Output:\n{example['output']}\n"
            for example in examples
        )

    sections.append(f"# Context:\n{text}")
    return "".join(sections)

def prepare_inputs(messages, image_paths, tokenizer, device='cuda', dtype=torch.bfloat16):
    """
    Prepares multi-modal input components (supports multiple images per prompt).

    Args:
        messages: List of input messages. Either plain strings (each is wrapped
            as a single user turn) or already chat-formatted conversations.
        image_paths: Sequence where each element is either None (for text-only)
            or a list/tuple of image paths. May be None, empty, or shorter than
            ``messages``; missing entries are treated as text-only.
        tokenizer: The tokenizer to use for applying chat templates
        device: Device to place tensors on ('cuda', 'cpu', etc.)
        dtype: Data type for image tensors (default: torch.bfloat16)

    Returns:
        dict: Contains 'prompts', 'pixel_values_list', and 'num_patches_list' ready for the model.
        'pixel_values_list' holds, per prompt, the patches of all its images
        concatenated along the first dimension (or None); 'num_patches_list'
        holds, per prompt, the patch count of each image (or an empty list).
    """
    # Work on a private list: tolerate None/tuples and never touch the caller's object
    image_paths = list(image_paths) if image_paths is not None else []
    if len(image_paths) < len(messages):
        # Pad with None for text-only messages
        image_paths += [None] * (len(messages) - len(image_paths))

    # Load images and collect patch information in a single pass
    pixel_values_list = []
    num_patches_list = []
    for paths in image_paths:
        if not paths or not isinstance(paths, (list, tuple)):
            # Text-only prompt
            pixel_values_list.append(None)
            num_patches_list.append([])
            continue

        prompt_images = []
        for path in paths:
            img = load_image(path).to(dtype=dtype, device=device)
            # Ensure img has shape [patches, C, H, W]
            if img.dim() == 3:  # [C, H, W] -> [1, C, H, W]
                img = img.unsqueeze(0)
            prompt_images.append(img)

        # All images of one prompt are concatenated; the per-image patch counts
        # are kept so the pieces can be matched to <image> placeholders later
        pixel_values_list.append(torch.cat(prompt_images, dim=0))
        num_patches_list.append([img.shape[0] for img in prompt_images])

    # Format messages for the model
    if all(isinstance(m, str) for m in messages):
        # Simple string messages: convert to chat format
        batch_messages = [[{"role": "user", "content": message}] for message in messages]
    else:
        # Assume messages are already in the right format
        batch_messages = messages

    # Apply chat template
    prompts = tokenizer.apply_chat_template(
        batch_messages,
        tokenize=False,
        add_generation_prompt=True
    )

    return {
        'prompts': prompts,
        'pixel_values_list': pixel_values_list,
        'num_patches_list': num_patches_list
    }

## Inference
### Basic Example

Now we are ready to run the model!

Let's start with a basic text-only example, where we want to extract people's names from a short text.

In [4]:
# Extraction schema: a list of names copied verbatim from the text
template = '{"names": ["verbatim-string"]}'
text = "John went to the restaurant with Mary. James went to the cinema."

# Build the raw user message (template + context) for our single example
input_messages = [construct_message(text=text, template=template)]

# Turn the message into model-ready inputs; the image path list is empty
# because this example is text-only
input_content = prepare_inputs(input_messages, [], tokenizer)

Our NuExtract message is now formatted in standard chat template formatting, which is what is given directly to the model.

In [5]:
# Show the chat-formatted prompt produced for the first (and only) message
print(input_content["prompts"][0])

<s><|im_start|>user
# Template:
{"names": ["verbatim-string"]}
# Context:
John went to the restaurant with Mary. James went to the cinema.<|im_end|>
<|im_start|>assistant



The other input items are empty in this case because they are only relevant when we're using images.

In [6]:
# List every component returned by prepare_inputs()
print(input_content.keys())
# No images were supplied, so the image-related entries hold only placeholders
print(input_content["pixel_values_list"])
print(input_content["num_patches_list"])

dict_keys(['prompts', 'pixel_values_list', 'num_patches_list'])
[None]
[[]]


Now let's actually run the model.

In [7]:
# Greedy decoding (no sampling, a single beam) works well for most
# information extraction tasks
generation_config = dict(do_sample=False, num_beams=1, max_new_tokens=2048)

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    result = nuextract_generate(
        model,
        tokenizer,
        input_content['prompts'],  # input prompts containing template, etc.
        generation_config,
        pixel_values_list=input_content['pixel_values_list'],  # image information
        num_patches_list=input_content['num_patches_list'],  # number of patches used per image
    )

# One response per prompt
for response in result:
    print(response)

Processing text-only batch with 1 prompts
{"names": ["John", "Mary", "James"]}


### In-Context Examples

Sometimes the model might not perform as well as we want because our task is challenging or involves some degree of ambiguity. Alternatively, we may want the model to follow some specific formatting, or just give it a bit more help. In cases like this it can be valuable to provide "in-context examples" to help NuExtract better understand the task.

To do so, we can pass a list `examples` to `construct_message`, containing dictionaries of input/output pairs. In the example below, we show the model that we want the extracted names to be in capital letters with +s on either side (for the sake of illustration).

In [8]:
# Schema with a second field, so the example below can demonstrate both outputs
template = '{"names": ["verbatim-string"], "female_names": ["verbatim-string"]}'
text = "John went to the restaurant with Mary. James went to the cinema."

# A single in-context example showing the desired output formatting
examples = [
    dict(
        input="Stephen is the manager at Susan's store.",
        output='{"names": ["+STEPHEN+", "+SUSAN+"], "female_names": ["+SUSAN+"]}',
    ),
]

# Raw user message including the in-context example
input_messages = [construct_message(text=text, template=template, examples=examples)]

# Still text-only, hence the empty image path list
input_content = prepare_inputs(input_messages, [], tokenizer)

We can see below that the in-context example has now been included in the model prompt, specifically between the template and context components.

In [9]:
# Show the chat-formatted prompt, which now includes the in-context example
print(input_content["prompts"][0])

<s><|im_start|>user
# Template:
{"names": ["verbatim-string"], "female_names": ["verbatim-string"]}
# Examples:
## Input:
Stephen is the manager at Susan's store.
## Output:
{"names": ["+STEPHEN+", "+SUSAN+"], "female_names": ["+SUSAN+"]}
# Context:
John went to the restaurant with Mary. James went to the cinema.<|im_end|>
<|im_start|>assistant



In [10]:
# Greedy decoding: no sampling, a single beam
generation_config = dict(do_sample=False, num_beams=1, max_new_tokens=2048)

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    result = nuextract_generate(
        model,
        tokenizer,
        input_content['prompts'],
        generation_config,
        pixel_values_list=input_content['pixel_values_list'],
        num_patches_list=input_content['num_patches_list'],
    )

# One response per prompt
for response in result:
    print(response)

Processing text-only batch with 1 prompts
{"names": ["+JOHN+", "+MARY+", "+JAMES+"], "female_names": ["+MARY+"]}


To get even better performance, add multiple in-context examples to your input.

### Image Inputs

If we want to give image inputs to NuExtract, instead of text, we simply use the placeholder `"<image>"` anywhere we would normally have text (main input and ICL examples). Then we provide a list of image paths to `prepare_inputs()`. This should be a list of lists, so that multiple images can be injected when using ICL examples. The order of the image paths should match their order of appearance in the prompt (i.e. ICL inputs first, with the main input at the end).

In this example, we give an image of a receipt (`1.jpg`) and ask the model to extract the name of the store. We also provide an ICL example of a receipt from Walmart (`0.jpg`).

In [11]:
# Schema: the store name, copied verbatim from the receipt
template = '{"store": "verbatim-string"}'

# The main input is an image, so the text is just a placeholder
text = "<image>"

# In-context example: an image input paired with its expected output
examples = [
    dict(input="<image>", output='{"store": "WALMART"}'),
]

input_messages = [construct_message(text=text, template=template, examples=examples)]

# One inner list per prompt; paths follow the order of the <image>
# placeholders in the prompt (ICL example first, main input last)
images = [["0.jpg", "1.jpg"]]

input_content = prepare_inputs(input_messages, images, tokenizer)

Just like in the text-only case above, our in-context example has been included in the prompt before the main context.

In [12]:
# Show the chat-formatted prompt; the <image> placeholders are still literal text at this stage
print(input_content["prompts"][0])

<s><|im_start|>user
# Template:
{"store": "verbatim-string"}
# Examples:
## Input:
<image>
## Output:
{"store": "WALMART"}
# Context:
<image><|im_end|>
<|im_start|>assistant



Now that we are using images, the values of `pixel_values_list` and `num_patches_list` are actually meaningful. Below we can see that the first image (`0.jpg`) has been broken into 7 patches, and the second (`1.jpg`) into 3. This is confirmed by looking at the size of the tensor in `pixel_values_list` -- 10 patches * 3 (RGB) * 448*448 (patch dimension).

In [13]:
# List every component returned by prepare_inputs()
print(input_content.keys())
# Shape of the stacked patches for the first prompt (both images concatenated)
print(input_content["pixel_values_list"][0].shape)
# Patch count of each image, grouped per prompt
print(input_content["num_patches_list"])

dict_keys(['prompts', 'pixel_values_list', 'num_patches_list'])
torch.Size([10, 3, 448, 448])
[[7, 3]]


In [14]:
# Greedy decoding: no sampling, a single beam
generation_config = dict(do_sample=False, num_beams=1, max_new_tokens=2048)

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    result = nuextract_generate(
        model,
        tokenizer,
        input_content['prompts'],
        generation_config,
        pixel_values_list=input_content['pixel_values_list'],
        num_patches_list=input_content['num_patches_list'],
    )

# One response per prompt
for response in result:
    print(response)

Processing batch with 1 prompts, 2 actual images, and 10 total patches
{"store": "TRADER JOE'S"}


### Multi-Modal Batches

Finally, we can run batched inference over a list of input examples, regardless of whether they contain text, images, and/or ICL examples.

In [15]:
# A mixed batch: two image prompts and two text prompts, with and without ICL examples
inputs = [
    # image input with no ICL examples
    dict(
        text="<image>",
        template='{"store_name": "verbatim-string"}',
        examples=None,
    ),
    # image input with 1 ICL example
    dict(
        text="<image>",
        template='{"store_name": "verbatim-string"}',
        examples=[
            dict(input="<image>", output='{"store_name": "Walmart"}'),
        ],
    ),
    # text input with no ICL examples
    dict(
        text="John went to the restaurant with Mary. James went to the cinema.",
        template='{"names": ["verbatim-string"]}',
        examples=None,
    ),
    # text input with ICL example
    dict(
        text="John went to the restaurant with Mary. James went to the cinema.",
        template='{"names": ["verbatim-string"], "female_names": ["verbatim-string"]}',
        examples=[
            dict(
                input="Stephen is the manager at Susan's store.",
                output='{"names": ["STEPHEN", "SUSAN"], "female_names": ["SUSAN"]}',
            ),
        ],
    ),
]

# One raw user message per batch entry
input_messages = [
    construct_message(text=item["text"], template=item["template"], examples=item["examples"])
    for item in inputs
]

# Image paths aligned with `inputs`; None marks a text-only prompt
images = [["0.jpg"], ["0.jpg", "1.jpg"], None, None]

input_content = prepare_inputs(input_messages, images, tokenizer)

# Greedy decoding: no sampling, a single beam
generation_config = dict(do_sample=False, num_beams=1, max_new_tokens=2048)

# Inference only, so gradient tracking is switched off
with torch.no_grad():
    result = nuextract_generate(
        model,
        tokenizer,
        input_content['prompts'],
        generation_config,
        pixel_values_list=input_content['pixel_values_list'],
        num_patches_list=input_content['num_patches_list'],
    )

# One response per prompt, in input order
for response in result:
    print(response)

Processing batch with 4 prompts, 3 actual images, and 17 total patches
{"store_name": "WAL*MART"}
{"store_name": "Trader Joe's"}
{"names": ["John", "Mary", "James"]}
{"names": ["JOHN", "MARY", "JAMES"], "female_names": ["MARY"]}
