In [None]:
import json
import os
import torch
import logging
import random
import warnings
import argparse
import cProfile
import pstats
import torch.multiprocessing as mp
import time
from typing import Optional, List

from argparse import Namespace
from glob import glob
from PIL import Image, ImageOps
from tqdm import tqdm
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from optimum.bettertransformer import BetterTransformer
from optimum.onnxruntime import ORTModelForCausalLM
from optimum.onnxruntime.configuration import AutoQuantizationConfig
from optimum.onnxruntime import ORTOptimizer

# Suppress FutureWarning messages for the rest of the session.
warnings.simplefilter("ignore", FutureWarning)
# Configure root logging at DEBUG with timestamped records; `ocr` reports its
# per-stage timings through logger.debug.
logging.basicConfig(
    level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Not referenced by any of the visible cells — presumably the number of GPUs
# intended for multi-process inference (TODO confirm).
TOTAL_GPUS = 8
# Read by `ocr` at call time and passed to `generate` as both `num_beams` and
# `num_return_sequences`. NOTE: a later cell reassigns BOOTSTRAPS to 1, and
# that is the value in effect after a top-to-bottom run.
BOOTSTRAPS = 3

# TODO: Optimize prompt length?
# Text prompt that `ocr` pairs with every image. The trailing "[EOS]" is
# literal text inside the string, not a tokenizer special token inserted here.
PROMPT = """Identify the jersey number of the basketball player in the frame. If none, return None. Output only the digits:
<jersey_number>
[EOS]"""

# TODO:
# 1. Optimum -- NO SUPPORT
# 2. to_bettertransformer() -- NO SUPPORT
# 3. torch.backends.cuda.sdp_kernel
# 4. tensorRT (no quantization): https://huggingface.co/docs/optimum/onnxruntime/usage_guides/gpu#accelerated-inference-on-nvidia-gpus
# 5. autocast -- NO SPEEDUP vs. HALF()
# 6. 4/8bit quantization -- NO SUPPORT
# 7. # bootstraps 9 -> 3
# 8. florence large -> florence-base
# 9. export for onnx + onnx rt
# 10. torch dataloader?
# 11. parallelize data pre-processing w/ processor obj.

def load_model_and_tokenizer(device: int = 0, args=None):
    """Load the Florence-2 base-ft checkpoint and its processor.

    The weights are mapped straight onto ``device`` instead of being loaded
    onto the default CUDA device first and moved afterwards, so a worker that
    targets GPU *k* never allocates the model on GPU 0.

    Args:
        device: Target device for the whole model (CUDA index by default; any
            value accepted as a ``device_map`` target works).
        args: Unused; kept so existing call sites stay valid.

    Returns:
        Tuple ``(model, processor)``. The model is in eval mode and wrapped
        with ``torch.compile``.

    Raises:
        Exception: Any failure while loading is logged and re-raised unchanged.
    """
    # Swap in "microsoft/Florence-2-large-ft" to use the larger checkpoint.
    model_id = "microsoft/Florence-2-base-ft"
    try:
        logger.info("Loading model and tokenizer...")
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            trust_remote_code=True,
            # The "" key maps the entire model onto the one requested device.
            device_map={"": device},
        ).eval()

        # Previously tried at this point and left out: optimum's ORTOptimizer
        # with AVX512-VNNI quantization, and torch.compile(mode="max-autotune")
        # gated on an `args.compile_model` flag.
        # NOTE(review): torch.compile wraps the module's forward call; confirm
        # by benchmarking that `generate()` actually benefits from it.
        model = torch.compile(model)
        processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
        )
        return model, processor
    except Exception as e:
        logger.error(f"Failed to load model or tokenizer: {e}")
        raise

In [None]:
# Load the Florence-2 base-ft model and its processor with the defaults
# (device=0). Both names are notebook globals that later cells rely on.
model, processor = load_model_and_tokenizer()

In [None]:
# use half precision
# NOTE: nn.Module.half() casts the parameters in place and returns the same
# module, so `half` and `model` refer to one and the same fp16 model afterwards.
half = model.half()

In [None]:
def is_valid_jersey_number(text):
    """Return True when ``text`` is a digit-only string whose value is 0-99.

    Args:
        text: Candidate jersey number, normally the decoded model output.

    Returns:
        ``True`` if ``text`` is a non-empty string of decimal digits whose
        integer value lies in ``[0, 99]``; ``False`` for anything else,
        including non-string input, signs, whitespace and digit-like symbols.
    """
    # TODO: what about the number "00"? (currently accepted and parsed as 0)
    if not isinstance(text, str):
        # e.g. None from a failed OCR step: not a number, but not an error either.
        return False
    # isdecimal() accepts exactly the characters int() can parse. isdigit()
    # also accepts symbols such as superscript two, on which int() raises
    # ValueError.
    if not text.isdecimal():
        return False
    return 0 <= int(text) <= 99

In [None]:
from transformers import BitsAndBytesConfig
from concurrent.futures import ThreadPoolExecutor, as_completed
from torch.cuda.amp import autocast


# Overrides the BOOTSTRAPS = 3 assigned in the first cell. `ocr` reads this
# global at call time for both `num_beams` and `num_return_sequences`.
BOOTSTRAPS = 1


def preprocess_image(image, prompt, device):
    """Encode one prompt/image pair and move the resulting tensors to ``device``.

    Relies on the notebook-level ``processor`` global.

    Args:
        image: Image (or images) handed to the processor as ``images``.
        prompt: Text handed to the processor as ``text``.
        device: Target device for both tensors.

    Returns:
        Tuple ``(input_ids, pixel_values)``; the pixel values are additionally
        cast to half precision after the transfer.
    """
    encoded = processor(text=prompt, images=image, return_tensors="pt")
    return (
        encoded["input_ids"].to(device, non_blocking=True),
        encoded["pixel_values"].to(device, non_blocking=True).half(),
    )

def ocr(
    image_file_paths: List[str],
    model,
    processor,
    device: int = 0,
    args: Optional[dict] = None,
) -> Optional[List[str]]:
    """Run the jersey-number prompt over a batch of images in one ``generate`` call.

    Images are loaded in a thread pool, encoded together with ``PROMPT`` by
    ``processor``, moved to ``device`` and decoded with beam search using the
    module-level ``BOOTSTRAPS`` for both ``num_beams`` and
    ``num_return_sequences``. Per-stage timings are logged at DEBUG level.

    Args:
        image_file_paths: Paths of the images to process. Files that fail to
            load are logged and skipped, so the output is aligned with the
            successfully loaded images rather than with this list.
        model: Model exposing ``generate``. If it has a floating-point
            ``dtype`` attribute the pixel values are cast to it; otherwise
            they are cast to float16.
        processor: Processor used for encoding, ``batch_decode`` and
            ``post_process_generation``.
        device: Device the input tensors are moved to.
        args: Unused; kept for call-site compatibility.

    Returns:
        Flat list with one entry per generated sequence, grouped by image:
        the ``BOOTSTRAPS`` sequences of the first image come first, then those
        of the second, and so on. Each entry is whatever
        ``processor.post_process_generation`` returns for task ``"<OCR>"``.
        ``None`` if no image could be loaded or nothing was generated.
    """
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    def load_image(fp):
        try:
            # The context manager releases the file handle; convert() returns
            # a fully loaded copy. Normalising to RGB lets grayscale, RGBA or
            # CMYK files enter the processor as 3-channel images.
            with Image.open(fp) as handle:
                return handle.convert("RGB")
        except Exception as e:
            logger.error(f"Failed to load image {fp}: {e}")
            return None

    # Load images in parallel; executor.map preserves the input order.
    tic = time.perf_counter()
    with ThreadPoolExecutor() as executor:
        images = [
            img
            for img in executor.map(load_image, image_file_paths)
            if img is not None
        ]
    logger.debug(f"Images loaded in: {time.perf_counter() - tic:.2f}s")

    if not images:
        logger.error("No valid images loaded.")
        return None

    # TODO: maybe resize images before using processor
    prompts = [PROMPT] * len(images)

    tic = time.perf_counter()
    inputs = processor(text=prompts, images=images, return_tensors="pt")
    logger.debug(f"Preprocessing inputs took: {time.perf_counter() - tic:.2f}s")

    # Match the pixel values to the model's floating-point precision so both
    # fp16 and fp32 models work. Fall back to the previous float16 behaviour
    # when the model does not expose a usable dtype.
    pixel_dtype = getattr(model, "dtype", None)
    if not (isinstance(pixel_dtype, torch.dtype) and pixel_dtype.is_floating_point):
        pixel_dtype = torch.float16

    tic = time.perf_counter()
    input_ids = inputs["input_ids"].to(device, non_blocking=True)
    pixel_values = inputs["pixel_values"].to(device, non_blocking=True).to(pixel_dtype)
    del inputs
    logger.debug(
        f"Copying + deleting inputs took: {time.perf_counter() - tic:.2f}s"
    )

    with torch.no_grad():
        tic = time.perf_counter()
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            max_new_tokens=5,
            do_sample=False,
            early_stopping=False,
            num_beams=BOOTSTRAPS,
            num_return_sequences=BOOTSTRAPS,
        )
        logger.debug(f"Generating ids took: {time.perf_counter() - tic:.2f}s")

    # decode the generated text
    tic = time.perf_counter()
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)
    logger.debug(f"Batch decoding took: {time.perf_counter() - tic:.2f}s")

    # post-process the output
    # generate() yields num_return_sequences consecutive rows per input image,
    # so row i belongs to image i // seqs_per_image. Zipping rows against the
    # image list directly would mis-pair and drop rows when BOOTSTRAPS > 1.
    tic = time.perf_counter()
    seqs_per_image = max(1, len(generated_text) // len(images))
    last_image_idx = len(images) - 1
    bootstraped_results = []
    for row, text in enumerate(generated_text):
        image = images[min(row // seqs_per_image, last_image_idx)]
        bootstraped_results.append(
            processor.post_process_generation(
                text, task="<OCR>", image_size=(image.width, image.height)
            )
        )
    logger.debug(f"Post processing outputs took: {time.perf_counter() - tic:.2f}s")

    return bootstraped_results if bootstraped_results else None

In [None]:
ex_img_6 = "/mnt/opr/levlevi/player-re-id/src/testing/constrastive_matching/clip_reid/data/data_reid/reid_challenge/gallery/00955.jpeg"
batch_size = 96

# Time one end-to-end `ocr` call on a batch built from the same image repeated
# `batch_size` times, then report total and per-image wall-clock time.
start = time.time()
results = ocr([ex_img_6 for _ in range(batch_size)], half, processor)
end = time.time()

elapsed = end - start
print(f"Total inference time: {elapsed}s")
print(f"Inference time per prompt: {elapsed / batch_size}s")
print(results)

| Batch Size | Time/Image (Sec.) |
| :---: | :---: |
| 1 | 0.5399 | 
| 2 | 0.3558 |
| 4 | 0.2159 |
| **6** | **0.2114** |
| 8 | OOM |

--- 

| Batch Size | `.half()` | `autocast()` | Time/Image (Sec.) |
| :---: | :---: |  :---: | :---: |
| 4 | No | No | 0.2159 |
| 4 | Yes | No | 0.0949 |
| 8 | No | No | OOM |
| 8 | Yes | No | 0.0856 |
| 16 | Yes | No | 0.0808 |
| 32 | Yes | No | 0.0785 |
| 64 | Yes | No | 0.0783 |
| 64 | Yes | Yes | 0.1006 |
| **96** | **Yes** | No | **0.0778** |
| 128 | Yes | No | OOM |

---

| `torch.backends.cuda.matmul.allow_tf32` | `torch.backends.cudnn.benchmark` | Time/Image (Sec.) |
| :---: | :---: | :---: |
| No | No | 0.0902 |
| Yes | No | 0.0922 |
| No | Yes | 0.0962 |
| Yes | Yes | 0.0960 |

---

| Tokenizer | Time/Image (Sec.) |
| :---: | :---: |
| Default | **0.06187** | 
| `bart-base` | 0.06196 |

---

| Model Variant | Time/Image (Sec.) |
| :---: | :---: |
| `large-ft` | 0.06187 | 
| `base-ft` | **0.03936** |

In [None]:
# Release cached GPU memory that PyTorch's allocator holds but is not using.
torch.cuda.empty_cache()

# Reciprocal of 0.03650 — presumably a measured seconds-per-image figure, which
# would make this the throughput in images per second (TODO confirm the source).
1 / 0.03650

In [None]:
from transformers import BartForConditionalGeneration, BartTokenizer, BartTokenizerFast


# tokenizer = BartTokenizerFast.from_pretrained("microsoft/Florence-2-large-ft")
# Stock BART tokenizer. Its only consumer is the commented-out `tokenizer=tok`
# argument below, so it is currently loaded but unused.
tok = BartTokenizerFast.from_pretrained("facebook/bart-base")
# NOTE(review): this rebinds the notebook-level `processor` to the *large-ft*
# processor, whereas load_model_and_tokenizer() loads the *base-ft* model.
# Cells run after this one (and `preprocess_image`, which reads the global)
# will use this processor — confirm that the mismatch is intended.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large-ft",
    trust_remote_code=True,
    # tokenizer=tok,
)