## Windows omni_env

## 00 Install & Import Libraries

In [None]:
import torch
import os
import requests
import re

from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoTokenizer,
    AutoProcessor,
    BitsAndBytesConfig,
    TextStreamer
)
from qwen_vl_utils import process_vision_info
from PIL import Image, ImageDraw, ImageFont
from io import BytesIO
from bs4 import BeautifulSoup, Tag
from IPython.display import display

## 01 Import Model

In [None]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 4-bit NF4 weight quantization with nested ("double") quantization of the
# quantization constants; matmuls are computed in float16.
# NOTE(review): bitsandbytes 4-bit loading generally needs a CUDA GPU, so the CPU
# fallback above is unlikely to work together with this config — confirm on target.
bnb_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type = 'nf4',
    bnb_4bit_use_double_quant = True,
    bnb_4bit_compute_dtype = torch.float16,
)

In [None]:
# Local checkpoint directory for Qwen2.5-VL-3B-Instruct.
model_path = './00_Model/Qwen2.5-VL-3B-Instruct'

# Load the model quantized to 4 bit. `device_map='auto'` already lets accelerate
# place the weights, so the model is deliberately NOT moved afterwards with
# `.to(...)`: depending on the installed bitsandbytes version transformers raises
# on (or warns about) `.to()` for 4-bit models, and moving a dispatched model
# would override the placement chosen by `device_map`.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    quantization_config = bnb_config,
    device_map = 'auto',
)
# The processor bundles the tokenizer and the image/video preprocessing.
processor = AutoProcessor.from_pretrained(model_path)

## 02 Define Inference Function

In [None]:
def _parse_bbox(bbox_str):
    """Parse a ``data-bbox`` value ``"x1 y1 x2 y2"`` into four ints; ``None`` if malformed."""
    parts = bbox_str.split()
    if len(parts) != 4:
        return None
    try:
        return tuple(int(part) for part in parts)
    except ValueError:
        return None


def draw_bbox(image_path, resized_width, resized_height, full_predict,
              font_path = './00_Dataset/NotoSansCJK-Regular.ttc', font_size = 20):
    """Draw the model's predicted boxes and their text onto the original image.

    The model emits HTML whose elements carry ``data-bbox="x1 y1 x2 y2"`` in the
    coordinate space of the resized image it was given. Each box is mapped back to
    the original image's pixel space, drawn in red, and labelled with the element's
    text. The annotated image is shown with ``IPython.display.display``.

    Args:
        image_path: Local path, or ``http://`` / ``https://`` URL, of the original image.
        resized_width: Width in pixels of the image as seen by the model.
        resized_height: Height in pixels of the image as seen by the model.
        full_predict: Raw HTML string produced by the model.
        font_path: Font file used for the labels (defaults to a CJK font so Chinese
            text renders).
        font_size: Label font size.

    Elements whose ``data-bbox`` is not exactly four integers are skipped instead of
    aborting the whole drawing, because the attribute comes from free-form model output.
    ``<ol>`` containers are not drawn; their ``<li>`` items (and all other elements)
    are.
    """
    if image_path.startswith(('http://', 'https://')):
        # Fail fast on network problems / HTTP errors instead of hanging forever
        # or handing an error page to PIL.
        response = requests.get(image_path, timeout = 30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.open(image_path)

    # Grayscale / palette images cannot show the red outline correctly; draw in RGB.
    if image.mode not in ('RGB', 'RGBA'):
        image = image.convert('RGB')
    original_width, original_height = image.size

    # Ratio between model-space and original-space coordinates (loop-invariant).
    # float() also unwraps 0-d tensors passed in for the resized size.
    scale_x = float(resized_width) / original_width
    scale_y = float(resized_height) / original_height

    soup = BeautifulSoup(full_predict, 'html.parser')
    elements = [
        el for el in soup.find_all(attrs = {'data-bbox': True})
        if el.name != 'ol'
    ]

    font = ImageFont.truetype(font_path, font_size)
    draw = ImageDraw.Draw(image)

    for element in elements:
        coords = _parse_bbox(element['data-bbox'])
        if coords is None:
            continue
        x1, y1, x2, y2 = coords

        # Map back to original-image pixels and normalise so left <= right, top <= bottom.
        left, right = sorted((int(x1 / scale_x), int(x2 / scale_x)))
        top, bottom = sorted((int(y1 / scale_y), int(y2 / scale_y)))

        draw.rectangle([left, top, right, bottom], outline = 'red', width = 2)
        # The label starts at the box's bottom-left corner.
        draw.text((left, bottom), element.get_text(strip = True), fill = 'black', font = font)

    display(image)

def clean_and_format_html(full_predict):
    """Convert QwenVL-HTML model output into plain, presentation-only HTML.

    The following transformations are applied:
      * ``color`` declarations are removed from inline ``style`` attributes; other
        declarations, including ``background-color``, are kept. Empty styles are dropped.
      * The layout attributes ``data-bbox`` and ``data-polygon`` are removed.
      * The classes ``formula.machine_printed`` and ``formula.handwritten`` are
        collapsed to ``formula``, and duplicate classes are removed.
      * ``<div class="image caption">`` elements are emptied and renamed to ``image``.
      * Elements classed ``music sheet``, ``chemical formula`` or ``chart`` are
        emptied, and their ``format`` attribute is dropped.
      * The top-level elements are re-emitted one per line inside
        ``<html><body>…</body></html>``. The result is wrapped in a Markdown ``html``
        code fence.

    Args:
        full_predict: Raw HTML string produced by the model.

    Returns:
        The cleaned HTML as a fenced Markdown code-block string.
    """
    soup = BeautifulSoup(full_predict, 'html.parser')

    # Match a standalone `color:` declaration. The look-behind keeps the pattern
    # from matching the tail of `background-color:`; a plain `\b` would match
    # there, because `-` forms a word boundary, and leave `background-` behind.
    color_pattern = re.compile(r'(?<![\w-])color\s*:[^;]*;?', re.IGNORECASE)
    for tag in soup.find_all(style = True):
        new_style = color_pattern.sub('', tag.get('style', ''))
        # Trim whitespace as well as the trailing separator: "a:1; color:red" -> "a:1".
        new_style = new_style.strip().rstrip(';').strip()
        if new_style:
            tag['style'] = new_style
        else:
            del tag['style']

    for attr in ('data-bbox', 'data-polygon'):
        for tag in soup.find_all(attrs = {attr: True}):
            del tag[attr]

    formula_variants = {'formula.machine_printed', 'formula.handwritten'}
    for tag in soup.find_all(class_ = True):
        new_classes = ['formula' if cls in formula_variants else cls
                       for cls in tag.get('class', [])]
        tag['class'] = list(dict.fromkeys(new_classes))  # de-duplicate, keep order

    for div in soup.find_all('div', class_ = 'image caption'):
        div.clear()
        div['class'] = ['image']

    for class_name in ('music sheet', 'chemical formula', 'chart'):
        for tag in soup.find_all(class_ = class_name):
            tag.clear()
            tag.attrs.pop('format', None)

    # Model output may lack <body> (e.g. a bare fragment). Fall back to <html> or to
    # the document root instead of crashing on `soup.body` being None.
    root = soup.body
    if root is None:
        root = soup.html if soup.html is not None else soup

    # Keep only top-level elements (bare top-level text nodes are discarded), one per line.
    body = ''.join(f'{child}\n' for child in root.children if isinstance(child, Tag))
    return f"```html\n<html><body>\n{body}</body></html>\n```"

In [None]:
def inference(
    prompt,
    image_path,
    system_prompt = 'You are a helpful assistant',
    max_new_tokens = 32000,
    min_pixels = 512 * 28 * 28,
    max_pixels = 2048 * 28 * 28
):
    """Run the module-level Qwen2.5-VL ``model``/``processor`` on one image + prompt.

    The chat-formatted input is printed first. The answer is then streamed to stdout
    as it is generated, and printed again once decoding finishes.

    Args:
        prompt: User instruction sent alongside the image.
        image_path: Image location handed to ``qwen_vl_utils.process_vision_info``.
        system_prompt: System message placed before the user turn. Pass ``''`` or
            ``None`` to omit it, so the chat template's built-in default applies.
        max_new_tokens: Generation budget.
        min_pixels: Lower bound on the resized image area, used by the processor.
        max_pixels: Upper bound on the resized image area, used by the processor.

    Returns:
        ``(output_text, input_height, input_width)``. ``output_text`` is the decoded
        answer. The two ints are the height and width, in pixels, of the resized
        image the model actually saw.
    """
    messages = []
    # Honour the caller's system prompt (it used to be accepted but ignored).
    if system_prompt:
        messages.append({'role' : 'system', 'content' : system_prompt})
    messages.append({
        'role' : 'user',
        'content' : [
            {
                'type' : 'image',
                'image' : image_path,
                'min_pixels' : min_pixels,
                'max_pixels' : max_pixels,
            },
            {'type' : 'text', 'text' : prompt},
        ],
    })

    # Preparation for inference
    text = processor.apply_chat_template(
        messages, tokenize = False, add_generation_prompt = True
    )
    print('input:\n', text)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(
        text = [text],
        images = image_inputs,
        videos = video_inputs,
        padding = True,
        return_tensors = 'pt',
    )
    # Follow the model's device rather than assuming CUDA.
    inputs = inputs.to(model.device)

    streamer = TextStreamer(processor.tokenizer, skip_special_tokens = True, skip_prompt = True)

    # Inference: generate, then drop the prompt tokens from each sequence.
    generated_ids = model.generate(**inputs, max_new_tokens = max_new_tokens, streamer = streamer)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens = True, clean_up_tokenization_spaces = False
    )

    print('output:\n', output_text[0])

    # image_grid_thw holds (t, h, w) in vision-patch units; scale by the patch size
    # to get pixels. Fall back to 14 if the image processor does not expose it.
    patch_size = getattr(processor.image_processor, 'patch_size', 14)
    _, grid_h, grid_w = inputs['image_grid_thw'][0].tolist()
    input_height = int(grid_h * patch_size)
    input_width = int(grid_w * patch_size)

    return output_text[0], input_height, input_width

## 03 Run Inference

### 03.01 Document Parsing in QwenVL HTML Format

In [None]:
%%time

image_path = './00_Dataset/docparsing_example6.png'
# Document-parsing system prompt; passed explicitly to `inference` below.
system_prompt = 'You are an AI specialized in recognizing and extracting text from images. Your mission is to analyze the image document and generate the result in QwenVL Document Parser HTML format using specified tags while maintaining user privacy and data integrity.'
prompt = 'QwenVL HTML'

output, input_height, input_width = inference(prompt, image_path, system_prompt = system_prompt)
print(input_height, input_width)
# Overlay the predicted boxes on the original image.
draw_bbox(image_path, input_width, input_height, output)
# Strip layout-only markup (bboxes, colours, …) to obtain ordinary HTML.
ordinary_html = clean_and_format_html(output)
print(ordinary_html)

### 03.02 Generate Ordinary HTML with Qwen2.5-VL

In [None]:
image_path = './00_Dataset/docparsing_example5.png'
# NOTE(review): `image` is not used in the visible part of this cell — confirm a later line needs it.
image = Image.open(image_path)
# Prompt (Chinese): "parse the image into HTML" — requests ordinary HTML rather
# than the QwenVL HTML layout format used in the previous section.
prompt = '图片解析成html'

output, input_height, input_width = inference(prompt, image_path)