In [None]:
import lightning.pytorch as pl
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torchvision.transforms as T
import json

class GOT_OCR2_BBox_LitModule(pl.LightningModule):
    def __init__(self, model_name="stepfun-ai/GOT-OCR-2.0-hf", max_new_tokens=1024):
        super().__init__()
        self.model = AutoModelForImageTextToText.from_pretrained(model_name)
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.max_new_tokens = max_new_tokens
        # 출력: bbox만, json array, pixel 좌표. 설명/코드블록/부가텍스트 금지!
        self.processor.tokenizer.add_tokens(["<image>"], special_tokens=True)
        self.model.resize_token_embeddings(len(self.processor.tokenizer))

        self.prompt = (
            "Detect all text regions in the <image> and return ONLY a valid JSON array of bounding boxes. "
            "Format: [{\"points\":[x1(int),y1(int),x2(int),y2(int)]}]. "
            "Use image pixel coordinates as integers. No explanation or code block."
        )


    def predict_step(self, batch, batch_idx):
        device = self.device

        pil_images = []

        # 배치의 이미지를 PIL로 변환
        for img in batch["images"]:
            if isinstance(img, torch.Tensor):
                img = img.detach().cpu()
                if img.dim() == 3 and img.shape[0] in (1, 3):
                    img = T.ToPILImage()(img)
            elif isinstance(img, str):
                img = Image.open(img).convert("RGB")
            pil_images.append(img)
        
        # Processor 호출 (이미지 + 프롬프트)
        inputs = self.processor(
            images=pil_images,
            text=[self.prompt] * len(pil_images),
            return_tensors="pt",
            padding=True,
            multi_page=True,
            format=True
        ).to(device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                **inputs,
                do_sample=False,
                tokenizer=self.processor.tokenizer,
                stop_strings="<|im_end|>",
                max_new_tokens=self.max_new_tokens
            )

        # 결과 디코딩 후 JSON 파싱
        start_idx = inputs["input_ids"].shape[1]
        print(inputs["input_ids"])
        """         # 입력 길이 파악
        input_len = inputs["input_ids"].shape[1] if "input_ids" in inputs else 0
        results = []
        for o in output_ids:
            text_out = self.processor.decode(
                o[input_len:], skip_special_tokens=True
            )
            js = self.extract_json(text_out) or "[]"
            try:
                data = json.loads(js)
                # 단일 dict일 경우 리스트로 감쌈
                if isinstance(data, dict):
                    data = [data]
                # 보정: 숫자/길이 체크, x1<=x2, y1<=y2
                cleaned = []
                for item in data:
                    bbox = item.get("points")
                    if (
                        isinstance(bbox, list) and len(bbox) == 4 and
                        all(isinstance(v, (int, float)) for v in bbox)
                    ):
                        x1, y1, x2, y2 = bbox
                        x1, x2 = (x1, x2) if x1 <= x2 else (x2, x1)
                        y1, y2 = (y1, y2) if y1 <= y2 else (y2, y1)
                        cleaned.append({"points": [int(x1), int(y1), int(x2), int(y2)]})
                results.append(cleaned)
            except Exception:
                results.append([]) """

        """ print(self.processor.decode(output_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
        print("*"*40)
        print(self.processor.tokenizer.decode(inputs["input_ids"][0]))
        print("*"*40)
        print(f"len={output_ids.shape}")
        print(output_ids) """
        pil_images = []

        preds = self.processor.batch_decode(output_ids, skip_special_tokens=True)
        return preds

In [None]:
import torch; print(torch.__version__)
import transformers; print(transformers.__version__)
import huggingface_hub; print(huggingface_hub.__version__)

In [1]:
import lightning.pytorch as pl
import torch
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
import torchvision.transforms as T
from PIL import Image
import json

class Qwen2VL_BBox_OCR_LitModule(pl.LightningModule):
    def __init__(self, model_name="Qwen/Qwen2-VL-2B-Instruct", max_new_tokens=1024):
        super().__init__()
        self.model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, torch_dtype="auto")
        self.processor = AutoProcessor.from_pretrained(model_name)
        self.max_new_tokens = max_new_tokens

        # 프롬프트: 박스만! JSON array, pixel 좌표, 불필요 텍스트 금지
        self.prompt = (
            "Detect all text regions in the image and return ONLY a valid JSON array of bounding boxes. "
            "Format: [{\"points\":[x1(int),y1(int),x2(int),y2(int)]}]. "
            "Use image pixel coordinates as integers. No explanation or code block."
        )

    def predict_step(self, batch, batch_idx):

        pil_images = []

        # 배치의 이미지를 PIL로 변환
        for img in batch["images"]:
            if isinstance(img, torch.Tensor):
                img = img.detach().cpu()
                if img.dim() == 3 and img.shape[0] in (1, 3):
                    img = T.ToPILImage()(img)
            elif isinstance(img, str):
                img = Image.open(img).convert("RGB")
            pil_images.append(img)
        
        # batch['image']에 PIL 이미지가 들어있다고 가정
        #image = batch['image'] if isinstance(batch, dict) else batch

        # Qwen-VL 프롬프트 템플릿 생성 (chat-based API 활용)
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": self.prompt}
                ]
            }
        ]
        text_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        # Processor 호출
        inputs = self.processor(
            text=[text_prompt],
            images=pil_images,
            padding=True,
            return_tensors="pt"
        ).to(self.model.device)

        with torch.no_grad():
            output_ids = self.model.generate(**inputs, max_new_tokens=self.max_new_tokens)
            # 프롬프트 길이만큼 앞부분 잘라냄
            generated_ids = [
                output_id[input_id.size(0):]
                for input_id, output_id in zip(inputs.input_ids, output_ids)
            ]
            output_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)

            print(f"output_text={output_text}")
        # 결과: JSON array로 파싱
        try:
            print("inininini")
            bbox_array = json.loads(output_text[0])
            print(f"bbox_array={bbox_array}", flush=True)
        except Exception:
            print("exexe")
            bbox_array = []
        return bbox_array


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
model_module = GOT_OCR2_BBox_LitModule()
print(isinstance(model_module, pl.LightningModule))  

In [2]:
import os
import sys
import lightning.pytorch as pl
import hydra
from omegaconf import OmegaConf

sys.path.append('/data/ephemeral/home/work/python/gx-ocr')
from ocr.lightning_modules import get_pl_modules_by_cfg  # noqa: E402
#from ocr.lightning_modules.got_ocr2 import GOT_OCR2_BBox_LitModule

CONFIG_DIR = os.environ.get('OP_CONFIG_DIR') or '../configs'


from ocr.models import get_model_by_cfg
from ocr.datasets import get_datasets_by_cfg
from ocr.lightning_modules.ocr_pl import OCRPLModule, OCRDataPLModule


def get_data_modules(config):
    dataset = get_datasets_by_cfg(config.datasets)
    data_modules = OCRDataPLModule(dataset=dataset, config=config)
    return data_modules




#@hydra.main(config_path=CONFIG_DIR, config_name='predict', version_base='1.2')
def predict(config):
    """
    Train a OCR model using the provided configuration.

    Args:
        `config` (dict): A dictionary containing configuration settings for predict.
    """
    pl.seed_everything(config.get("seed", 42), workers=True)

    model_module = Qwen2VL_BBox_OCR_LitModule()

    data_module = get_data_modules(config)

    trainer = pl.Trainer(logger=False)

    #ckpt_path = config.get("checkpoint_path")
    #assert ckpt_path, "checkpoint_path must be provided for prediction"

    trainer.predict(model_module,
                    data_module
                    )

if __name__ == "__main__":

    from hydra import initialize, compose

    with initialize(config_path="../configs"):
        cfg = compose(config_name="predict")
    print(cfg)


    print(cfg)
    predict(cfg)
    

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="../configs"):
Seed set to 42
`torch_dtype` is deprecated! Use `dtype` instead!


{'seed': 42, 'exp_name': 'ocr_training', 'checkpoint_path': None, 'minified_json': False, 'dataset_path': 'ocr.datasets', 'model_path': 'ocr.models', 'encoder_path': 'ocr.models.encoder', 'decoder_path': 'ocr.models.decoder', 'head_path': 'ocr.models.head', 'loss_path': 'ocr.models.loss', 'lightning_path': 'ocr.lightning_modules', 'log_dir': 'outputs/${exp_name}/logs', 'checkpoint_dir': 'outputs/${exp_name}/checkpoints', 'submission_dir': 'outputs/${exp_name}/submissions', 'dataset_base_path': '/data/ephemeral/home/work/python/gx-ocr/data/datasets/', 'datasets': {'train_dataset': {'_target_': '${dataset_path}.OCRDataset', 'image_path': '${dataset_base_path}images/train', 'annotation_path': '${dataset_base_path}jsons/train.json', 'transform': '${transforms.train_transform}'}, 'val_dataset': {'_target_': '${dataset_path}.OCRDataset', 'image_path': '${dataset_base_path}images/val', 'annotation_path': '${dataset_base_path}jsons/val.json', 'transform': '${transforms.val_transform}'}, 'test_

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 11.69it/s]
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_

Predicting DataLoader 0:   0%|          | 0/413 [00:00<?, ?it/s]output_text=['```json\n[\n  {\n    "points":[200,100,250,150]\n  },\n  {\n    "points":[250,100,300,150]\n  },\n  {\n    "points":[300,100,350,150]\n  },\n  {\n    "points":[350,100,400,150]\n  },\n  {\n    "points":[400,100,450,150]\n  },\n  {\n    "points":[450,100,500,150]\n  },\n  {\n    "points":[500,100,550,150]\n  },\n  {\n    "points":[550,100,600,150]\n  },\n  {\n    "points":[600,100,650,150]\n  },\n  {\n    "points":[650,100,700,150]\n  },\n  {\n    "points":[700,100,750,150]\n  },\n  {\n    "points":[750,100,800,150]\n  },\n  {\n    "points":[800,100,850,150]\n  },\n  {\n    "points":[850,100,900,150]\n  },\n  {\n    "points":[900,100,950,150]\n  },\n  {\n    "points":[950,100,1000,150]\n  }\n]\n```']
inininini
exexe
Predicting DataLoader 0:   0%|          | 1/413 [00:08<58:26,  0.12it/s]output_text=['```json\n[\n  {\n    "points":[200,100,250,200]\n  },\n  {\n    "points":[250,100,300,200]\n  },\n  {\n    "poi

/data/ephemeral/home/work/python/gx-ocr/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...
