In [None]:
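# Qualitative check of the fine-tuned model: load the config, base model and training data, then compare base vs. LoRA-adapted generations below.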
import os
from datasets import load_dataset
from omegaconf import DictConfig, OmegaConf
from ft_src.sft_trainer import CustomSFTTrainer
from ft_src.model import Model
import torch

cfg = OmegaConf.load("cfg/qwen2_5-vl_train.yaml")

# 1. Base vision-language model built from the `model` section of the config.
#    The wrapper exposes `.model` and `.processor`, which the generation cell uses.
vlm_model = Model(cfg.model)

# 2. Training split of the dataset whose id is given by `dataset.dataset_id`.
dataset = load_dataset(cfg.dataset.dataset_id, split="train")

In [None]:
from ft_src.sft_dataset import generate_description
SAMPLE_INDEX = 50  # dataset row used for the side-by-side comparison
data_sample = dataset[SAMPLE_INDEX]

# `load_adapter` attaches the LoRA weights to `vlm_model.model` in place, so when
# this cell is re-run the adapter is already attached: the "base" generation would
# silently include it, and reloading it is expected to fail on the duplicate
# adapter name. Detect that state and toggle the adapter instead of reloading it.
# NOTE(review): relies on transformers' PeftAdapterMixin (`_hf_peft_config_loaded`,
# `disable_adapters`, `enable_adapters`); if `vlm_model.model` is not such a model
# the flag is absent and the cell behaves exactly as before — confirm.
adapter_attached = getattr(vlm_model.model, "_hf_peft_config_loaded", False)

# Base-model generation (adapter switched off if a previous run attached it).
if adapter_attached:
    vlm_model.model.disable_adapters()
original_hddl = generate_description(data_sample, vlm_model.model, vlm_model.processor)

# Fine-tuned generation: load the LoRA adapter saved to the trainer's output dir,
# or re-enable it if it is already attached.
if adapter_attached:
    vlm_model.model.enable_adapters()
else:
    vlm_model.model.load_adapter(cfg.trainer.output_dir)
ft_hddl = generate_description(data_sample, vlm_model.model, vlm_model.processor)

In [None]:
import pandas as pd
from IPython.display import display, HTML

def _decode_escapes(text):
    """Return ``text`` with literal escape sequences decoded into real characters.

    For example, the two characters backslash + ``n`` become a newline.
    Encoding as Latin-1 with ``backslashreplace`` keeps every non-ASCII
    character intact through the ``unicode_escape`` round trip; a plain UTF-8
    ``encode()`` would turn e.g. "é" into "Ã©". Text containing a malformed
    escape (such as a trailing backslash) is returned unchanged instead of
    raising ``UnicodeDecodeError``.
    """
    try:
        return text.encode("latin-1", "backslashreplace").decode("unicode_escape")
    except UnicodeDecodeError:
        return text


def compare_generations(base_gen, ft_gen, column_width="500px"):
    """Display base and fine-tuned generations side by side in the notebook.

    Builds a one-row table (one column per model) whose cells use a monospace
    font and ``white-space: pre-wrap`` so multi-line generations keep their
    line breaks.

    Args:
        base_gen: Text produced by the base model.
        ft_gen: Text produced by the fine-tuned model.
        column_width: CSS width applied to each data cell (default "500px").

    Returns:
        None; the table is rendered with ``display(HTML(...))``.
    """
    # Ensure literal escape sequences in the generations render as real line breaks.
    base_text = _decode_escapes(base_gen)
    ft_text = _decode_escapes(ft_gen)

    df = pd.DataFrame({
        'Base Generation': [base_text],
        'Fine-tuned Generation': [ft_text]
    })

    # Model output is arbitrary text: HTML-escape the cells so anything like
    # "<...>" is shown literally instead of being parsed as markup and vanishing.
    styled_df = df.style.format(escape="html").set_table_styles([
        {
            'selector': 'td',
            'props': [
                ('text-align', 'left'),
                ('white-space', 'pre-wrap'),
                ('font-family', '"Courier New", monospace'),
                ('border', '1px solid black'),
                ('padding', '10px'),
                ('vertical-align', 'top'),
                ('width', column_width),
                ('overflow-wrap', 'break-word')
            ]
        },
        {
            'selector': 'th',
            'props': [
                ('text-align', 'left'),
                ('font-family', '"Courier New", monospace'),
                ('border', '1px solid black'),
                ('padding', '10px')
            ]
        }
    ])

    # Display in notebook or IPython
    display(HTML(styled_df.to_html()))

# Side-by-side view of the first element of each model's generation output.
compare_generations(base_gen=original_hddl[0], ft_gen=ft_hddl[0])
