In [None]:
# Install essential libraries for fine-tuning from the Granite Vision notebook
# !pip install git+https://github.com/huggingface/transformers.git
# !pip install -U trl datasets bitsandbytes peft accelerate trackio

# Install docling to handle document processing and DocTags conversion
# !pip install docling

# Optional: Install Flash Attention for better performance on compatible GPUs
# !pip install flash-attn --no-build-isolation

In [None]:
# from huggingface_hub import login
# from kaggle_secrets import UserSecretsClient

# user_secrets = UserSecretsClient()
# secret_value_0 = user_secrets.get_secret("HF_TOKEN")
# login(token=secret_value_0)

In [None]:
# Optional: Install Flash Attention for better performance on compatible GPUs
# !pip install -q flash-attn --no-build-isolation
# Probe whether the flash_attn package can be imported and record the outcome.
# NOTE(review): USE_FLASH_ATTENTION is not referenced by the model-loading cells
# of this notebook (no attn_implementation is passed), so it is informational only.
try:
    import flash_attn  # noqa: F401 -- imported only to test availability
except ImportError:
    print("FlashAttention is not installed.")
    USE_FLASH_ATTENTION = False
else:
    print("FlashAttention is installed.")
    USE_FLASH_ATTENTION = True

# Load the dataset and convert it to DocTags

In [None]:
# # Run this command in a Colab cell
# !rm -rf ~/.cache/huggingface/datasets


In [None]:
import torch
from datasets import load_dataset
from docling_core.types.doc import DoclingDocument, BoundingBox, ProvenanceItem, PageItem, PictureItem, ImageRef, Size, DocItemLabel
from PIL import Image
import io
import json
# 1. Load the docling-dpbench dataset.
# Only the "test" split is requested (presumably the only split the dataset
# publishes -- verify on the dataset card), so both the training and the
# evaluation subsets below are carved out of it.
dataset_id = "ds4sd/docling-dpbench"
dataset = load_dataset(dataset_id, name="default", split="test")

# For a real scenario, you would use the full dataset and split it properly.
# For this demonstration we take a small, deterministic subset.
N_TRAIN_SAMPLES = 100
N_TEST_SAMPLES = 20
if len(dataset) >= N_TRAIN_SAMPLES + N_TEST_SAMPLES:
    # Enough rows: first N_TRAIN_SAMPLES for training, the next N_TEST_SAMPLES for testing.
    train_dataset_raw = dataset.select(range(N_TRAIN_SAMPLES))
    test_dataset_raw = dataset.select(range(N_TRAIN_SAMPLES, N_TRAIN_SAMPLES + N_TEST_SAMPLES))
else:
    # Smaller dataset: split whatever is available 90/10 with a fixed seed.
    train_test_split = dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset_raw = train_test_split["train"]
    test_dataset_raw = train_test_split["test"]



In [None]:

# 2. Shared prompt text for every conversation built by format_data below:
#    - system_message fills the system turn,
#    - user_prompt follows the page image in the user turn.
system_message = "A chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
# Instruction paired with each page image. NOTE(review): an earlier comment called this a
# "supported instruction" citing an external reference ("[7]"); confirm against the model card.
user_prompt = "Convert this page to docling."

def convert_to_rgb(image):
    """Return ``image`` in RGB mode.

    RGB images are returned unchanged (same object). Any other mode is first
    converted to RGBA and composited over an opaque white canvas, so transparent
    regions become white instead of black, then converted to RGB.
    """
    if image.mode != "RGB":
        with_alpha = image.convert("RGBA")
        white_canvas = Image.new("RGBA", with_alpha.size, (255, 255, 255))
        return Image.alpha_composite(white_canvas, with_alpha).convert("RGB")
    return image


def reduce_image_size(image, scale=0.5):
    """Resize ``image`` by ``scale`` along both axes.

    Args:
        image: object exposing ``.size`` as ``(width, height)`` and ``.resize((w, h))``
            (e.g. a PIL image).
        scale: positive multiplier applied to width and height; values are
            truncated with ``int`` and clamped to at least 1 pixel so tiny
            images never collapse to a zero-sized result.

    Returns:
        Whatever ``image.resize`` returns for the new size.

    Raises:
        ValueError: if ``scale`` is not positive.
    """
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale!r}")
    original_width, original_height = image.size
    # max(1, ...) guards against int() truncating a 1-pixel side down to 0.
    new_width = max(1, int(original_width * scale))
    new_height = max(1, int(original_height * scale))
    return image.resize((new_width, new_height))

# 3. Function to process a docling-dpbench sample
def process_dpbench_sample(sample):
    """Decode one docling-dpbench record into a page image and its DocTags target.

    Returns a dict with:
        "image": the image opened from the raw bytes of the first entry in
            ``sample["GroundTruthPageImages"]``;
        "target_text": the DocTags string exported from the DoclingDocument
            stored as JSON in ``sample["GroundTruthDocument"]`` (the model's target).
    Any failure (missing keys, bad bytes, invalid JSON/document) is printed and
    the sample is skipped by returning None.
    """
    try:
        # Only the first ground-truth page image is used.
        page_bytes = sample["GroundTruthPageImages"][0]['bytes']
        page_image = Image.open(io.BytesIO(page_bytes))

        ground_truth = DoclingDocument(**json.loads(sample["GroundTruthDocument"]))
        doctags = ground_truth.export_to_doctags()
    except Exception as err:
        # Best-effort: a corrupted sample must not abort the whole preprocessing run.
        print(f"Skipping sample due to error: {err}")
        return None
    return {"image": page_image, "target_text": doctags}

# 4. Format the processed data into the required chat structure
def format_data(processed_sample):
    """Wrap one processed sample in a system/user/assistant chat conversation.

    The page image is converted to RGB and downscaled with ``reduce_image_size``
    (default scale), then placed first in the user turn, followed by
    ``user_prompt``. The DocTags target becomes the assistant reply. Index 1 of
    the returned list is the user turn and its first content item is the image;
    ``collate_fn`` reads the image from exactly that position.
    """
    page = reduce_image_size(convert_to_rgb(processed_sample["image"]))

    def text_part(text):
        return {"type": "text", "text": text}

    system_turn = {"role": "system", "content": [text_part(system_message)]}
    user_turn = {
        "role": "user",
        "content": [{"type": "image", "image": page}, text_part(user_prompt)],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [text_part(processed_sample["target_text"])],
    }
    return [system_turn, user_turn, assistant_turn]

# Decode every raw record first; records that fail come back as None.
processed_train = list(map(process_dpbench_sample, train_dataset_raw))
processed_test = list(map(process_dpbench_sample, test_dataset_raw))

# Keep only the records that decoded cleanly and wrap each in the chat structure.
train_dataset = [format_data(sample) for sample in processed_train if sample is not None]
test_dataset = [format_data(sample) for sample in processed_test if sample is not None]

print(f"Successfully processed {len(train_dataset)} training samples.")

In [None]:
# print(train_dataset[12])

# Load the model and processor and test a generation

In [None]:
# from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
# model_id = "ibm-granite/granite-docling-258M"
# model = AutoModelForVision2Seq.from_pretrained(
#     model_id,
#     device_map='cuda',
#     torch_dtype=torch.float16,
# )
# processor = AutoProcessor.from_pretrained(model_id)

In [None]:
# import time
# import torch
# from statistics import mean

# def measure_generation_speed(model, processor, sample,
#                              device='cuda',
#                              max_new_tokens=128,
#                              warmup_runs=1,
#                              runs=1):
#     """
#     Returns: dict with keys:
#       - ttf_s: time-to-first-token (seconds, generate-only)
#       - avg_gen_time_s: average generate() call time (seconds)
#       - tokens_per_sec: tokens/sec computed as (actual_generated_tokens / avg_gen_time_s)
#       - tokens_generated: actual generated tokens observed (per sample)
#     """
#     # resolve device
#     if device is None:
#         try:
#             device = str(model.device)
#         except Exception:
#             device = "cuda" if torch.cuda.is_available() else "cpu"
#     device = torch.device(device)

#     # helper for accurate timing with CUDA
#     def _sync():
#         if device.type == "cuda":
#             torch.cuda.synchronize()

#     # prepare tokenized inputs (on CPU)
#     proc_out = processor.apply_chat_template(
#         [sample["messages"][1]],
#         add_generation_prompt=True,
#         tokenize=True,
#         return_dict=True,
#         return_tensors="pt",
#     )
#     input_len = proc_out["input_ids"].shape[-1]

#     # move tensors to device
#     batch_on_device = {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
#                        for k, v in proc_out.items()}

#     # warmup
#     # model.to(device)
#     # model.eval()
#     with torch.no_grad():
#         for _ in range(warmup_runs):
#             _sync()
#             _ = model.generate(**batch_on_device, max_new_tokens=min(8, max_new_tokens))
#             _sync()

#     # 1) Time-to-first-token (generate-only): single call with max_new_tokens=1
#     _sync()
#     t0 = time.perf_counter()
#     with torch.no_grad():
#         out = model.generate(**batch_on_device, max_new_tokens=1)
#     _sync()
#     ttf = time.perf_counter() - t0

#     # compute how many new tokens were produced in that call (usually 1)
#     first_generated = out[0].shape[-1] - input_len

#     # 2) Tokens/sec (generation-only): run several full generate calls and average
#     gen_times = []
#     observed_generated = None
#     with torch.no_grad():
#         for _ in range(runs):
#             _sync()
#             tstart = time.perf_counter()
#             out = model.generate(**batch_on_device, max_new_tokens=max_new_tokens)
#             _sync()
#             elapsed = time.perf_counter() - tstart
#             gen_times.append(elapsed)
#             generated = out[0].shape[-1] - input_len
#             observed_generated = generated if observed_generated is None else observed_generated

#     avg_gen = mean(gen_times) if gen_times else float("nan")
#     tokens_per_sec = (observed_generated / avg_gen) if avg_gen and avg_gen > 0 else float("inf")

#     results = {
#         "ttf_s": float(ttf),
#         "ttf_generated_tokens": int(first_generated),
#         "avg_gen_time_s": float(avg_gen),
#         "tokens_generated": int(observed_generated),
#         "tokens_per_sec": float(tokens_per_sec),
#         "gen_times_list": [float(x) for x in gen_times],
#     }

#     # minimal print
#     print(f"TTF (generate-only): {results['ttf_s']:.4f}s (generated {results['ttf_generated_tokens']} token(s))")
#     print(f"Avg generate time: {results['avg_gen_time_s']:.4f}s  |  Tokens/sec (generation-only): {results['tokens_per_sec']:.2f}")

#     return results
def generate_text_from_sample(model, processor, sample, max_new_tokens=40, device="cuda"):
    """Generate the model's reply to the user turn of one chat-formatted sample.

    Args:
        model: model exposing ``.device`` and ``.generate(**inputs, max_new_tokens=...)``.
        processor: processor exposing ``apply_chat_template(...)`` and ``decode(...)``.
        sample: either the message list produced by ``format_data`` (index 1 is the
            user turn) or a dict holding that list under the key ``"messages"``.
        max_new_tokens: generation budget forwarded to ``model.generate``.
        device: unused -- inputs are moved to ``model.device``; kept so existing
            callers that pass it keep working.

    Returns:
        The decoded newly generated tokens only (prompt tokens are sliced off),
        decoded with the processor's default ``decode`` settings.
    """
    # format_data yields plain lists; older callers pass {"messages": [...]}.
    messages = sample["messages"] if isinstance(sample, dict) else sample
    # Only the user turn (image + instruction) is sent.
    # NOTE(review): training conversations also contain a system turn that is not
    # included here -- confirm this matches how the fine-tuned model should be prompted.
    prompt = [messages[1]]

    inputs = processor.apply_chat_template(
        prompt,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    prompt_len = inputs["input_ids"].shape[-1]
    return processor.decode(outputs[0][prompt_len:])

In [None]:
# res = generate_text_from_sample(model, processor, sample=train_dataset[0], max_new_tokens=40)
# res

In [None]:
import gc
import time
def clear_memory(
    names=("inputs", "model", "processor", "trainer", "peft_model", "bnb_config"),
    pause_s=2,
):
    """Delete large notebook globals and release cached GPU memory.

    Args:
        names: global variable names to delete if present (missing names are
            ignored). The default covers the objects this notebook creates for
            training and inference.
        pause_s: seconds to sleep between cleanup stages (2 s, as before);
            pass 0 to skip the pauses.

    Prints allocated/reserved CUDA memory afterwards when CUDA is available.
    NOTE(review): objects still referenced elsewhere (e.g. IPython's output
    history or other globals such as ``batch``) will not actually be freed.
    """
    namespace = globals()
    for name in names:
        namespace.pop(name, None)
    time.sleep(pause_s)

    # Garbage collection and clearing CUDA memory
    cuda_ok = torch.cuda.is_available()
    gc.collect()
    time.sleep(pause_s)
    if cuda_ok:
        # synchronize() raises on machines without CUDA, so guard the GPU calls.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    time.sleep(pause_s)
    gc.collect()
    time.sleep(pause_s)

    if cuda_ok:
        print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
        print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    else:
        print("CUDA not available; no GPU memory to report.")


# clear_memory()

# Setting up the training loop

In [None]:
# Sanity check before loading the model: is CUDA usable, and which GPUs are visible?
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
gpu_total = torch.cuda.device_count()
print(f"GPU count: {gpu_total}")
for gpu_index in range(gpu_total):
    print(f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)}")

In [None]:
from transformers import BitsAndBytesConfig


# 4-bit NF4 quantization with double quantization; matmuls run in fp16.
# llm_int8_skip_modules lists modules (matched against module names) that stay
# unquantized. granite-docling-258M is presumably an Idefics3-style model whose
# vision encoder is registered as "vision_model"; the LLaVA-style "vision_tower"
# name is kept for compatibility but likely matches nothing here -- verify with
# [n for n, _ in model.named_modules()].
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_skip_modules=["vision_tower", "vision_model", "lm_head"],
    llm_int8_enable_fp32_cpu_offload=True,
)

from transformers import AutoModelForVision2Seq, AutoProcessor

model_id = "ibm-granite/granite-docling-258M"

model = AutoModelForVision2Seq.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
from peft import LoraConfig, get_peft_model

# DoRA adapters of rank 8 on every q_proj / v_proj projection, Gaussian-initialized.
lora_settings = dict(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    use_dora=True,
    init_lora_weights="gaussian",
)
peft_config = LoraConfig(**lora_settings)

# Wrap the model exactly once. Do not additionally call model.add_adapter(...) /
# model.enable_adapters(), and do not pass peft_config to SFTTrainer later --
# the adapters are already attached here.
model = get_peft_model(model, peft_config)

# Report how many parameters will actually be trained.
model.print_trainable_parameters()

In [10]:
from trl import SFTConfig
training_args = SFTConfig(
    # Checkpoints and the final adapter land here; the testing section loads the
    # adapter back from a directory with this same name.
    output_dir="granite-final-finetunned",
    # --- schedule / optimizer ---
    num_train_epochs=3,
    # max_steps=30,
    warmup_steps=1,
    learning_rate=1e-4,
    weight_decay=0.01,
    optim="adamw_torch_fused",
    # --- batching / memory: 2 samples x 4 accumulation steps per optimizer step ---
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    # bf16=True,
    # --- logging / checkpointing ---
    logging_steps=1,
    save_strategy="steps",
    save_steps=20,
    save_total_limit=1,
    report_to="none",
    push_to_hub=False,
    # --- data: keep raw columns and skip TRL's preprocessing; collate_fn builds batches ---
    remove_unused_columns=False,
    dataset_text_field="",
    dataset_kwargs={"skip_prepare_dataset": True},
)

In [None]:
# import trackio

# trackio.init(
#     project="granite-docling",
#     name="granite-docling-trl-sft",
#     config=training_args.to_dict(),
#     space_id=training_args.output_dir + "-trackio",
# )

In [18]:
# Peek at how the chat template serializes one training conversation (text only).
# The role markers shown here are what collate_fn searches for when masking labels.
example = next(iter(train_dataset))
texts = processor.apply_chat_template(example, tokenize=False)
preview_chars = 400
texts[:preview_chars]

"<|start_of_role|>system<|end_of_role|>A chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><image>Convert this page to docling.<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|><doctag><page_header><loc_71><loc_24><loc_85><loc_36>314</page_header>\n"

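In this chat template the assistant turn starts with `<|start_of_role|>assistant<|end_of_role|>`, not `<|assistant|>`; the collator below searches for that header to decide which tokens contribute to the loss.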

In [31]:
def collate_fn(examples):
    texts = [processor.apply_chat_template(example, tokenize=False) for example in examples]

    image_inputs = []
    for example in examples:
        image = example[1]["content"][0]["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_inputs.append([image])

    batch = processor(text=texts, images=image_inputs, return_tensors="pt", padding=True)

    labels = batch["input_ids"].clone()
    # fix assistant text
    # assistant_tokens = processor.tokenizer("<|assistant|>", return_tensors="pt")["input_ids"][0]
    assistant_tokens = processor.tokenizer("<|start_of_role|>assistant<|end_of_role|>", return_tensors="pt")["input_ids"][0]
    eos_token = processor.tokenizer("<|end_of_text|>", return_tensors="pt")["input_ids"][0]

    for i in range(batch["input_ids"].shape[0]):
        apply_loss = False
        for j in range(batch["input_ids"].shape[1]):
            if not apply_loss:
                labels[i][j] = -100
            if (j >= len(assistant_tokens) + 1) and torch.all(
                batch["input_ids"][i][j + 1 - len(assistant_tokens) : j + 1] == assistant_tokens
            ):
                apply_loss = True
            if batch["input_ids"][i][j] == eos_token:
                apply_loss = False

    batch["labels"] = labels

    return batch

In [32]:
# Test the collator on a few samples: every example should keep at least one
# supervised (non -100) label, otherwise it contributes nothing to training.
samples = train_dataset[:8]  # slicing also works when fewer than 8 samples exist
batch = collate_fn(samples)

has_targets = (batch["labels"] != -100).any(dim=1)
if not has_targets.any():
    print("Useless samples!")
elif not has_targets.all():
    missing = (~has_targets).nonzero().flatten().tolist()
    print(f"Samples without any supervised tokens: {missing}")

In [33]:
from trl import SFTTrainer

# `model` is already a PeftModel (wrapped with get_peft_model earlier), so
# peft_config is deliberately NOT passed here -- the adapters exist already.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,  # builds model inputs plus assistant-only labels
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=processor.tokenizer,
)

In [34]:
# Run fine-tuning; with save_strategy="steps", checkpoints are written to
# training_args.output_dir every save_steps optimizer steps.
trainer.train()

  return fn(*args, **kwargs)


Step,Training Loss
1,1.0813
2,0.9484
3,1.1411
4,0.9997
5,1.1438
6,1.0642


KeyboardInterrupt: 

In [None]:
# Persist the fine-tuned model to the training output directory. The model is a
# PEFT wrapper, so this is expected to write the adapter that the testing section
# loads back with model.load_adapter -- confirm the directory contents.
trainer.save_model(training_args.output_dir)

# Testing the model

In [None]:
# Drop the training objects (model, processor, trainer, ...) and cached GPU memory
# before reloading the base model for inference.
clear_memory()

In [None]:
from transformers import AutoModelForVision2Seq, AutoProcessor, BitsAndBytesConfig
model_id = "ibm-granite/granite-docling-258M"
# Reload the base model for inference in fp16 (no 4-bit quantization this time);
# the fine-tuned adapter is attached in the next cell.
# NOTE(review): the adapter was trained on a 4-bit quantized base, so outputs may
# differ slightly from what was seen during training.
load_kwargs = dict(device_map='auto', torch_dtype=torch.float16)
model = AutoModelForVision2Seq.from_pretrained(model_id, **load_kwargs)
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
# Load the fine-tuned adapter from the directory the trainer wrote to. When this
# cell runs after a kernel restart (training_args undefined), fall back to the
# Kaggle working-directory path used previously.
try:
    adapter_path = training_args.output_dir
except NameError:
    adapter_path = "/kaggle/working/granite-final-finetunned"
model.load_adapter(adapter_path)

In [None]:
# Each training sample is a chat message list: index 1 is the user turn, whose
# first content item holds the (resized RGB) page image.
train_dataset[0][1]["content"][0]["image"]

In [None]:
# generate_text_from_sample reads sample["messages"], while dataset entries are
# plain message lists, so wrap the entry accordingly.
output = generate_text_from_sample(model, processor, {"messages": train_dataset[0]})
output