<a href="https://colab.research.google.com/github/mit1280/Document-AI/blob/main/Fine_tune_KOSMOS_2_for_multimodal_grounding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Inference with KOSMOS-2 for multimodal grounding and referral

In this notebook, we'll perform inference with Microsoft's new impressive multimodal large language model (LLM) called [KOSMOS-2](https://huggingface.co/docs/transformers/main/en/model_doc/kosmos-2).


https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L619

https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L2924

https://discuss.huggingface.co/t/how-is-the-data-shifted-by-one-token-during-causallm-fine-tuning/36386

https://github.com/huggingface/transformers/blob/b2748a6efd045dd771f8fd48e8b309cbc061c618/src/transformers/models/kosmos2/__init__.py

https://github.com/microsoft/unilm/blob/master/kosmos-2/fairseq/fairseq/logging/metrics.py

https://github.com/huggingface/transformers/blob/main/src/transformers/trainer_pt_utils.py#L482

## Set-up environment

Let's start by installing 🤗 Transformers. We install from main here since the model is brand new at the time of writing. We also install Accelerate and Bitsandbytes since those will provide [4-bit inference](https://huggingface.co/blog/4bit-transformers-bitsandbytes), greatly reducing the memory requirements to load the model (without those I wouldn't be able to load the model in Google Colab).

In [None]:
# install required libaries
!pip install -q -U transformers accelerate bitsandbytes seqeval evaluate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m29.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m290.1/290.1 kB[0m [31m30.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m102.2/102.2 MB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m41.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m14.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m20

In [None]:
from transformers import AutoProcessor, AutoModelForVision2Seq
import requests
from datasets import load_dataset
from datasets.features import ClassLabel
import re
from PIL import Image, ImageDraw, ImageFont
import math
import random
from transformers import Kosmos2Config, Kosmos2Model



> The image resolution is set to 1280×1280 and the patch size is 10×10. We divide the width and height of the image into 256 bins, with each bin consisting of 5×5 pixels. A total of 256×256 location tokens are added to the vocabulary.



In [None]:
## Config of Kosmos2 changed just to demonstrate fine tuning
# configuration For to test fine tuning code
configuration = Kosmos2Config(
      text_config = {"layers" : 4},
      vision_config = {"num_hidden_layers" : 4}
)
'''
# configuration for actual fine-tuning
configuration = Kosmos2Config()
'''

In [None]:
from transformers import Kosmos2ForConditionalGeneration
from transformers import Kosmos2Config, Kosmos2Model, AutoProcessor

model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224", device_map="auto",  config = configuration)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [None]:
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224", add_eos_token=True, device_map="auto")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Load model

Next, let's load the model along with its processor. We specify `load_in_4bit=True` to reduce the size of the weights to be able to load the model in Google Colab. This is all thanks to the magic of bitsandbytes' integration in the Transformers library (see [this blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes) for all info). We also specify to place the model on the GPU (with id=0, meaning the first GPU on our system).

In [None]:
# dataset_id ="pierreguillou/DocLayNet-small"
# This dataset is takes from DocLayNet dataset
## This finetuning was done for the layout detection in the any image. Task was to find table/ header/ footer... from the any given image.
dataset_id = "Mit1208/test_dataset"

dataset = load_dataset(dataset_id, trust_remote_code=True)

print(f"Train dataset size: {len(dataset['train'])}")
# print(f"Test dataset size: {len(dataset['test'])}")

Train dataset size: 4


In [None]:
# Remove data which has no text
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/kosmos2/processing_kosmos2.py#L154
dataset = dataset.filter(lambda example: len(example['texts']) > 0)

## Define variables

below part is to defind id2label and label2id, some of the code is for creating visualization of layouts on the image (you can ignore color part).

In [None]:
features = dataset["train"].features
column_names = dataset["train"].column_names
image_column_name = "image"
text_column_name = "texts"
boxes_column_name = "bboxes_block"
label_column_name = "categories"

# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
# unique labels.
def get_label_list(labels):
    unique_labels = set()
    for label in labels:
        unique_labels = unique_labels | set(label)
    label_list = list(unique_labels)
    label_list.sort()
    return label_list

if isinstance(features[label_column_name].feature, ClassLabel):
    label_list = features[label_column_name].feature.names
    # No need to convert the labels since they are already ints.
    id2label = {k: v for k,v in enumerate(label_list)}
    label2id = {v: k for k,v in enumerate(label_list)}
else:
    label_list = get_label_list(dataset["train"][label_column_name])
    id2label = {k: v for k,v in enumerate(label_list)}
    label2id = {v: k for k,v in enumerate(label_list)}
num_labels = len(label_list)

In [None]:
id2label

{0: 'Caption',
 1: 'Footnote',
 2: 'Formula',
 3: 'List-item',
 4: 'Page-footer',
 5: 'Page-header',
 6: 'Picture',
 7: 'Section-header',
 8: 'Table',
 9: 'Text',
 10: 'Title'}

In [None]:
# Define colors for all labels
get_colors = lambda n: list(map(lambda i: "#" + "%06x" % random.randint(0, 0xFFFFFF),range(n)))
colors = get_colors(len(label_list))
font = ImageFont.load_default()
label2color = {label: colors[idx] for idx, label in enumerate(label_list)}

In [None]:
# Normalize box diamentions to range 0 to 1000
def normalized_box(box, image_width=1025, image_height=1025):
    return [
        round(float(box[0] / image_width), 6),
        round(float(box[1] / image_height), 6),
        round(float(box[2] / image_width), 6),
        round(float(box[3] / image_height), 6),
    ]

def convert_box(bbox):
    x, y, w, h = tuple(bbox) # Box coordinates are in (left, top, width, height) format
    return [x, y, x+w, y+h] # we need to convert it into (x1, y1, x2, y2) which is (left, top, left+widght, top+height)

In [None]:
example = dataset["train"][0]
# This function remove duplicate entries from the dataset
def set_cat_box(example):
    list1_tuples = [tuple(inner_list) for inner_list in example['bboxes_block']]

    # Create unique pairs
    unique_pairs = set(zip(list1_tuples, example['categories']))

    # Separate the unique pairs back into lists
    result_list1, result_list2 = zip(*unique_pairs)
    return result_list1, result_list2

# set_boxs, set_categories = set_cat_box(example)

In [None]:
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

train_df = pd.DataFrame(dataset['train'])

In [None]:
train_df['type'] = 'train'
all_df = train_df

In [None]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'texts', 'bboxes_block', 'bboxes_line', 'categories', 'image', 'page_hash', 'original_filename', 'page_no', 'num_pages', 'original_width', 'original_height', 'coco_width', 'coco_height', 'collection', 'doc_category'],
        num_rows: 4
    })
})

In [None]:
## Create proper prompt which has grounding labels and it's location.
def pre_process_data(example_df):

    set_boxs, set_categories = set_cat_box(example_df)
    example_df['float_val'] = [tuple(normalized_box(convert_box(i))) for i in set_boxs]
    example_df['text'] = '<grounding> This image is type of ' + example_df['doc_category'] + '. It has multiple page layouts ' + ", ".join(["<phrase>" + id2label[i] +"</phrase>" for i in set_categories]) + 'in it.'

    # print(encoding)
    return example_df

In [None]:
all_df = all_df.progress_apply(pre_process_data, axis=1)

100%|██████████| 4/4 [00:00<00:00, 373.47it/s]


In [None]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
from datasets import Dataset
## process prompt. Note: this will convert bounding box to required text and then convert it to number
inputs = processor(images = all_df['image'].to_list(), text = all_df['text'].to_list(), bboxes = all_df['float_val'].to_list(), padding=True, truncation= True, return_tensors="pt")
labels = inputs['input_ids'].clone()
labels[inputs['input_ids'] == 1] = -100
inputs['labels'] = labels

dataset = Dataset.from_dict(inputs)
train_test_split = dataset.train_test_split(test_size=0.3)

In [None]:
train_test_split

DatasetDict({
    train: Dataset({
        features: ['pixel_values', 'input_ids', 'attention_mask', 'image_embeds_position_mask', 'labels'],
        num_rows: 2
    })
    test: Dataset({
        features: ['pixel_values', 'input_ids', 'attention_mask', 'image_embeds_position_mask', 'labels'],
        num_rows: 2
    })
})

In [None]:
train_dataset = train_test_split['train']
test_dataset = train_test_split['test']

In [None]:
train_dataset.set_format("torch")
test_dataset.set_format("torch")

In [None]:
import torch

example = train_test_split['train'][0]
for k,v in example.items():
    print(k,v.shape)

pixel_values torch.Size([3, 224, 224])
input_ids torch.Size([221])
attention_mask torch.Size([221])
image_embeds_position_mask torch.Size([221])
labels torch.Size([221])


In [None]:
from huggingface_hub import notebook_login

In [None]:
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
                                    output_dir="./kosmos-finetuned-DocLayNet",# Output directory
                                    max_steps=1000,                              # Maximum number of training steps
                                    per_device_train_batch_size=1,               # Batch size for training
                                    per_device_eval_batch_size=1,                # Batch size for evaluation
                                    gradient_accumulation_steps=2,
                                    eval_accumulation_steps=2,
                                    learning_rate=1e-5,                          # Learning rate for the optimizer
                                    evaluation_strategy="steps",                 # Evaluate every "eval_steps" steps
                                    eval_steps=100,                              # Evaluate every 250 steps
                                    save_strategy="steps",                       # Save checkpoints every "save_steps" steps
                                    save_steps=100,                             # Save checkpoints every 1000 steps
                                    logging_dir='./logs',                        # Directory for storing logs
                                    logging_steps=250,                           # Log every "logging_steps" steps
                                    load_best_model_at_end=True,                 # Load the best model when finished training
                                    # metric_for_best_model="accuracy",            # Use accuracy as the metric to compare models
                                    # greater_is_better = True,                     # Indicate whether the metric is to be maximized or minimized
                                    warmup_ratio=0.1, # we warmup a bit
                                    # fp16=True, # we use mixed precision (less memory consumption)
                                    push_to_hub=True, # after training, we'd like to push our model to the hub
                                    push_to_hub_model_id=f"kosmos-finetuned-DocLayNet", # this is the name we'll use for our model on the hub
)



In [None]:
from transformers.data.data_collator import default_data_collator


# Initialize our Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=processor,
    data_collator=default_data_collator,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [None]:
trainer.train()

Step,Training Loss,Validation Loss
100,No log,1.61236
200,No log,1.566668
300,2.393000,1.660224
400,2.393000,1.736173
500,0.004200,1.764009
600,0.004200,1.800707
700,0.004200,1.779289
800,0.001600,1.784427
900,0.001600,1.802788
1000,0.001000,1.804111


SafetensorError: Error while serializing: IoError(Os { code: 28, kind: StorageFull, message: "No space left on device" })

## Only useful if patch size is different then 224

In [None]:
'''
# Initializing a Kosmos-2 kosmos-2-patch14-224 style configuration
configuration = Kosmos2Config(
      text_config = {"max_position_embeddings" : 2048*2, "attention_heads" : 32*4},
      vision_config = {"image_size" : 1280, "patch_size" : 256}
    )
# configuration = Kosmos2Config(latent_query_num = 64 * 4)
# model = Kosmos2ForConditionalGeneration.from_pretrained("microsoft/kosmos-2-patch14-224", config = configuration, ignore_mismatched_sizes=True)
# num_patches_per_side = 32*math.sqrt(total_tokens_increase_by)
# total_tokens_increase_by = 64
# # , num_patch_index_tokens = 1024 * total_tokens_increase_by
'''

In [1]:
# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
# (with format modifications)
def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
    # Compute the size of each cell in the grid
    cell_size = 1.0 / num_patches_per_side

    # Compute the x and y indices of the upper-left and lower-right corners of the bounding box
    ul_x = ul_idx % num_patches_per_side
    ul_y = ul_idx // num_patches_per_side

    lr_x = lr_idx % num_patches_per_side
    lr_y = lr_idx // num_patches_per_side

    # Compute the normalized coordinates of the bounding box
    if ul_idx == lr_idx:
        x1 = ul_x * cell_size
        y1 = ul_y * cell_size
        x2 = lr_x * cell_size + cell_size
        y2 = lr_y * cell_size + cell_size
    elif ul_x == lr_x or ul_y == lr_y:
        x1 = ul_x * cell_size
        y1 = ul_y * cell_size
        x2 = lr_x * cell_size + cell_size
        y2 = lr_y * cell_size + cell_size
    else:
        x1 = ul_x * cell_size + cell_size / 2
        y1 = ul_y * cell_size + cell_size / 2
        x2 = lr_x * cell_size + cell_size / 2
        y2 = lr_y * cell_size + cell_size / 2

    return x1, y1, x2, y2


# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L4-L33
# (with format modifications)
def extract_entities_with_patch_indices(text):
    # The regular expression pattern for matching the required formats
    pattern = r"(?:(<phrase>([^<]+)</phrase>))?<object>((?:<patch_index_\d+><patch_index_\d+></delimiter_of_multi_objects/>)*<patch_index_\d+><patch_index_\d+>)</object>"

    # Find all matches in the given string
    matches = re.finditer(pattern, text)

    # Initialize an empty list to store the valid patch_index combinations
    entities_with_patch_indices = []

    for match in matches:
        # span of a `phrase` that is between <phrase> and </phrase>
        span = match.span(2)
        phrase_tag, phrase, match_content = match.groups()
        if not phrase_tag:
            phrase = None
            # We take the starting position of `<object>`
            span = (match.span(0)[0], match.span(0)[0])

        # Split the match_content by the delimiter to get individual patch_index pairs
        patch_index_pairs = match_content.split("</delimiter_of_multi_objects/>")

        entity_bboxes = []
        for pair in patch_index_pairs:
            # Extract the xxxx and yyyy values from the patch_index pair
            x = re.search(r"<patch_index_(\d+)>", pair)
            y = re.search(r"<patch_index_(\d+)>", pair[1:])

            if x and y:
                if phrase:
                    entity_bboxes.append((int(x.group(1)), int(y.group(1))))
                else:
                    entity_bboxes.append((int(x.group(1)), int(y.group(1))))

        if phrase:
            entities_with_patch_indices.append((phrase, span, entity_bboxes))
        else:
            for bbox in entity_bboxes:
                # fake entity name
                entity = f"<patch_index_{bbox[0]}><patch_index_{bbox[1]}>"
                entities_with_patch_indices.append((entity, span, [bbox]))

    return entities_with_patch_indices


def adjust_entity_positions(entity, text):
    """Adjust the positions of the entities in `text` to be relative to the text with special fields removed."""
    entity_name, (start, end) = entity
    # computed the length of strings with special fields (tag tokens, patch index tokens, etc.) removed
    adjusted_start = len(re.sub("<.*?>", "", text[:start]))
    adjusted_end = len(re.sub("<.*?>", "", text[:end]))
    adjusted_entity = (entity_name, (adjusted_start, adjusted_end))
    return adjusted_entity


def _cleanup_spaces(text, entities):
    """Remove the spaces around the text and the entities in it."""
    new_text = text.strip()
    leading_spaces = len(text) - len(text.lstrip())

    new_entities = []
    for entity_name, (start, end), bboxes in entities:
        entity_name_leading_spaces = len(entity_name) - len(entity_name.lstrip())
        entity_name_trailing_spaces = len(entity_name) - len(entity_name.rstrip())

        start = start - leading_spaces + entity_name_leading_spaces
        end = end - leading_spaces - entity_name_trailing_spaces
        entity_name = entity_name.strip()

        new_entities.append((entity_name, (start, end), bboxes))

    return new_text, new_entities


# copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L77-L87
# (with format modifications)
def clean_text_and_extract_entities_with_bboxes(text, num_patches_per_side=32):
    # remove special fields (tag tokens, patch index tokens, etc.)
    processed_text = re.sub("<.*?>", "", text)

    entities_with_patch_indices = extract_entities_with_patch_indices(text)
    entities = []
    for item in entities_with_patch_indices:
        entity, bboxes = item[0:2], item[2]
        adjusted_entity = adjust_entity_positions(entity, text)
        bboxes_in_coords = [patch_index_to_coordinate(bbox[0], bbox[1], num_patches_per_side) for bbox in bboxes]

        entities.append(adjusted_entity + (bboxes_in_coords,))

    return _cleanup_spaces(processed_text, entities)