In [1]:
import io
import torch
import numpy as np
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
from PIL import Image
from datasets import load_dataset
from datasets import Image as HFImage
from anls_star import anls_score
from fastprogress.fastprogress import progress_bar
from diffusers.models.autoencoders import AutoencoderKL
from torchvision.transforms import ToPILImage, PILToTensor
from walloc import walloc
class Config:
    """Empty placeholder class; it defines no attributes or methods of its own."""

In [2]:
def sd3_compress(sample):
    """Round-trip a sample's image through `codec` and return it as lossless WebP bytes.

    The image is resized to 896x896, converted to RGB, scaled to [-0.5, 0.5],
    encoded with the module-level `codec`, and the mode of the latent
    distribution is cast to float16 on the CPU. That float16 latent is then
    moved back to `device`, decoded, clamped to [-0.5, 0.5], and saved as a
    lossless WebP.

    Relies on the module-level globals `codec` and `device`.

    Args:
        sample: Mapping with an 'image' entry (a PIL image) plus 'question'
            and 'questionId' entries, which are passed through unchanged.

    Returns:
        dict with keys 'image' (the WebP-encoded reconstruction as `bytes`),
        'question' and 'questionId'.
    """
    with torch.no_grad():
        img = sample['image'].resize((896, 896), resample=Image.Resampling.LANCZOS).convert("RGB")
        x = PILToTensor()(img).to(torch.float)
        # Map 8-bit pixel values to [-0.5, 0.5] and add a batch dimension.
        x = (x / 255 - 0.5).unsqueeze(0).to(device)
        H, W = x.size(2), x.size(3)

        # The latent is deliberately cast to float16 and moved to the CPU before
        # decoding, so the reconstruction is produced from the float16 values.
        Y = codec.encode(x).latent_dist.mode().to(torch.float16).to("cpu")
        x_hat = codec.decode(Y.to(device).to(torch.float)).sample
        x_hat = x_hat.clamp(-0.5, 0.5)

        # Crop back to the size of the tensor that was fed to the encoder.
        x_hat = walloc.crop(x_hat, (H, W))
        # Shift back to [0, 1] before converting to a PIL image.
        rec = ToPILImage()(x_hat[0] + 0.5)

    buff = io.BytesIO()
    rec.save(buff, format='WEBP', lossless=True)
    # getvalue() returns an independent `bytes` copy; getbuffer() would hand out
    # a memoryview that keeps the BytesIO buffer exported.
    rec_webp_bytes = buff.getvalue()
    return {
        'image': rec_webp_bytes,
        'question': sample['question'],
        'questionId': sample['questionId'],
    }

In [3]:
# Load the 'test' split of the danjacobellis/docvqa dataset.
ds = load_dataset(
    "danjacobellis/docvqa",
    split="test",
)

Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/57 [00:00<?, ?it/s]

In [4]:
# Load the VAE sub-model of the Stable Diffusion 3 medium checkpoint, put it in
# eval mode, and move it to the target device. `codec` and `device` are read as
# globals by sd3_compress.
device = "cuda"
codec = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="vae",
)
codec = codec.eval().to(device)

In [5]:
# Apply sd3_compress to every sample of the dataset.
sd_ds = ds.map(function=sd3_compress)

In [6]:
# Load the PaliGemma checkpoint (bfloat16 revision, bfloat16 weights) onto the
# GPU, together with the processor for the same model id.
model_id = "google/paligemma-3b-ft-docvqa-896"
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    revision="bfloat16",
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
processor = AutoProcessor.from_pretrained(model_id)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
def compute_score(item):
    """Generate an answer for one dataset item and score it with `anls_score`.

    Uses the module-level `processor` and `model`. Generation does not sample
    (`do_sample=False`) and is capped at 100 new tokens; only the tokens
    produced after the input sequence are decoded into the prediction.

    Args:
        item: Mapping with 'question', 'image' and 'answer' entries.

    Returns:
        Tuple `(score, pred)`: the first value returned by `anls_score` for
        the item's answer versus the prediction, and the decoded prediction
        string.
    """
    question = item['question']
    picture = item['image']
    inputs = processor(text=question, images=picture, return_tensors="pt").to(model.device)
    # Length of the input sequence, used to strip it from the generated ids.
    prompt_len = inputs["input_ids"].shape[-1]

    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        answer_ids = output_ids[0][prompt_len:]
        pred = processor.decode(answer_ids, skip_special_tokens=True)

    score, _gt = anls_score(item['answer'], pred, return_gt=True)
    return score, pred

In [9]:
# Run the model over every reconstructed sample, collecting the score, the
# predicted answer, and the question id in parallel lists.
scores, preds, qid = [], [], []
for sample in progress_bar(sd_ds):
    sample_score, sample_pred = compute_score(sample)
    scores.append(sample_score)
    preds.append(sample_pred)
    qid.append(sample['questionId'])

13324 MiB
5188 / 25:47

In [11]:
# Attach the generated answers as the 'preds_12x' column. An existing column of
# that name is dropped first, so re-running this cell replaces the column
# instead of trying to add a duplicate.
if 'preds_12x' in sd_ds.column_names:
    sd_ds = sd_ds.remove_columns('preds_12x')
sd_ds = sd_ds.add_column('preds_12x', preds)

In [19]:
# Upload the dataset (reconstructed images plus predictions) as the 'test' split.
sd_ds.push_to_hub(
    repo_id="danjacobellis/docvqa_stable_diffusion_3",
    split="test",
)

Uploading the dataset shards:   0%|          | 0/7 [00:00<?, ?it/s]

Map:   0%|          | 0/742 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Map:   0%|          | 0/741 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Map:   0%|          | 0/741 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Map:   0%|          | 0/741 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Map:   0%|          | 0/741 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Map:   0%|          | 0/741 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Map:   0%|          | 0/741 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/danjacobellis/docvqa_stable_diffusion_3/commit/a8623b7697c991a08e8bfab9f8b9ec360d403696', commit_message='Upload dataset', commit_description='', oid='a8623b7697c991a08e8bfab9f8b9ec360d403696', pr_url=None, pr_revision=None, pr_num=None)