In [None]:
import torch
import torchvision
import torchvision.transforms.v2 as T
import textwrap
import itertools
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM
from tqdm import tqdm

# Nothing in this notebook trains a model, so turn autograd off globally;
# generation then does not record a computation graph.
torch.set_grad_enabled(mode=False)

In [2]:
# One stateless transform shared by both splits. ImageFolder applies it only
# when a sample is indexed (`dataset[i]`); it does not touch the file paths
# listed in `dataset.imgs`.
resize_transform = T.Compose([T.Resize(16)])
train_dataset = torchvision.datasets.ImageFolder("../data/train", transform=resize_transform)
test_dataset = torchvision.datasets.ImageFolder("../data/test", transform=resize_transform)

In [None]:
# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so the
# notebook also runs on machines without MPS instead of failing at `.to(device)`.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# float16 is poorly supported / slow on CPU, so only use half precision on an
# accelerator.
torch_dtype = torch.float32 if device == "cpu" else torch.float16

# Alternative checkpoint that was tried, loaded the same way:
#   "gokaygokay/Florence-2-SD3-Captioner"
MODEL_ID = "microsoft/Florence-2-base"

# NOTE(review): trust_remote_code=True executes Python shipped in the model
# repository; consider pinning `revision=` for safety and reproducibility.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch_dtype, trust_remote_code=True
).to(device).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

In [4]:
def caption_image(image, task="<DETAILED_CAPTION>", max_new_tokens=1024, num_beams=3):
    """Run one Florence-2 task prompt over a batch of PIL images.

    Args:
        image: sequence of PIL images processed as a single batch (the name is
            singular for backward compatibility; a sequence is required).
        task: Florence-2 task token, e.g. "<DETAILED_CAPTION>". The same prompt
            is used for every image in the batch.
        max_new_tokens: generation length cap forwarded to `model.generate`.
        num_beams: beam width forwarded to `model.generate`.

    Returns:
        A list with one entry per input image, in input order. Each entry is
        whatever `processor.post_process_generation` returns for that image.
        An empty input returns an empty list.

    Relies on the module-level `model`, `processor`, `device` and `torch_dtype`.
    """
    images = list(image)
    if not images:
        # Avoid calling the processor/model with an empty batch.
        return []

    prompts = [task] * len(images)
    # `.to(device, dtype)` on the processor output; presumably only the
    # floating-point tensors (pixel_values) are cast to `torch_dtype`.
    inputs = processor(text=prompts, images=images, return_tensors="pt").to(device, torch_dtype)

    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )

    # Special tokens are kept in the decoded text; presumably
    # post_process_generation needs them for structured (e.g. location) tasks.
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)

    # Post-process each answer against ITS OWN image size. The size only matters
    # for tasks whose output refers to image coordinates, but using the first
    # image's size for the whole batch is wrong whenever sizes differ.
    return [
        processor.post_process_generation(text, task=task, image_size=(img.width, img.height))
        for text, img in zip(generated_text, images)
    ]

In [None]:
# `ImageFolder.imgs` holds (file_path, class_index) pairs; split them into
# parallel lists of paths and integer labels, keeping dataset order.
test_vals = test_dataset.imgs
test_files = [path for path, _ in test_vals]
test_classes = [label for _, label in test_vals]

# Summarize instead of dumping every path and label into the notebook output.
print(f"{len(test_files)} test images, {len(set(test_classes))} distinct classes")
print(list(zip(test_files[:5], test_classes[:5])))

In [None]:
BATCH_SIZE = 8


def _load_rgb(path):
    """Decode the image at `path` as RGB and close its file handle.

    torchvision's default ImageFolder loader also converts to RGB; doing the
    same here means grayscale / RGBA / palette files reach the processor as
    3-channel images.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


# Caption the test set batch by batch. Order is preserved, so captions[i]
# belongs to test_files[i]. After the loop, `images` holds the last batch.
captions = []
for batch_paths in itertools.batched(tqdm(test_files), BATCH_SIZE):
    images = [_load_rgb(p) for p in batch_paths]
    captions.extend(caption_image(images))

In [14]:
# Unwrap the {"<DETAILED_CAPTION>": text} dicts produced by caption_image and
# strip "<pad>" tokens. caption_image decodes with skip_special_tokens=False, so
# padding from batched generation can remain in the text. Entries that are
# already plain strings are passed through unchanged, so re-running this cell
# does not crash on its own output.
captions = [
    c["<DETAILED_CAPTION>"].replace("<pad>", "") if isinstance(c, dict) else c
    for c in captions
]

In [None]:
from IPython.display import display

# Show the first few captions next to the image each was generated from.
# Pair with `test_files`, which is the order the captioning loop consumed them.
# The loop's leftover `images` only holds the LAST batch, so zipping captions
# with it paired captions with the wrong pictures.
N_PREVIEW = 8
for path, cap in itertools.islice(zip(test_files, captions), N_PREVIEW):
    print("\n".join(textwrap.wrap(cap)))
    with Image.open(path) as img:
        display(img)