In [3]:
!pip install datasets transformers

Collecting datasets
  Downloading datasets-2.14.5-py3-none-any.whl (519 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.6/519.6 kB[0m [31m7.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers
  Downloading transformers-4.33.3-py3-none-any.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m85.0 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m14.3 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━

In [4]:
import requests, random

import pandas as pd
import numpy as np

import torch

from transformers import Pix2StructForConditionalGeneration, AutoProcessor
from datasets import load_dataset
from tqdm.auto import tqdm
from PIL import Image

torch.manual_seed(420)
torch.random.manual_seed(420)
random.seed(420)

device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
from google.colab import drive
drive.mount("/content/gdrive")

Mounted at /content/gdrive


# Basic Example with a Single Input

### Load the Model and Input Processor

In [6]:
processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base", output_hidden_states=True).to(device)

Downloading (…)rocessor_config.json:   0%|          | 0.00/217 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

Downloading spiece.model:   0%|          | 0.00/851k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/3.27M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/5.01k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

### Load the Input Image and Question

In [7]:
image = Image.open(requests.get("https://www.ilankelman.org/stopsigns/australia.jpg", stream=True).raw)
question = "What color is the car?"

### Process the Input Image and Question

In [8]:
inputs = processor(text=question, images=image, return_tensors="pt", add_special_tokens=False).to(device)

### Generate an Answer

In [9]:
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)

What color is the car?


### Get Hidden States

In [10]:
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

outputs.keys()

odict_keys(['logits', 'decoder_attentions', 'cross_attentions', 'encoder_last_hidden_state', 'encoder_attentions'])

# Running the Model on Multiple Images

### Load the Data

In [11]:
dataset = load_dataset("textvqa")

Downloading builder script:   0%|          | 0.00/5.02k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/13.3k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/13.2k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/21.6M [00:00<?, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data: 0.00B [00:00, ?B/s]

Downloading data:   0%|          | 0.00/7.07G [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/970M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/5 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/34602 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5734 [00:00<?, ? examples/s]

### Example VQA With Image-Question Pair from Dataset

In [12]:
row_0 = dataset["test"][6]
row_0

{'image_id': 'de39ea9f2ac0f665',
 'question_id': 39608,
 'question': 'what website is this?',
 'question_tokens': ['what', 'website', 'is', 'this'],
 'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1024x819>,
 'image_width': 1024,
 'image_height': 819,
 'flickr_original_url': 'https://farm6.staticflickr.com/3052/2549481146_29850e11e8_o.jpg',
 'flickr_300k_url': 'https://farm6.staticflickr.com/3052/2549481146_29850e11e8_o.jpg',
 'answers': ['', '', '', '', '', '', '', '', '', ''],
 'image_classes': ['Computer monitor', 'Mobile phone'],
 'set_name': 'test'}

In [13]:
inputs = processor(
    text=f"Question: '{row_0['question']}'; Answer:",
    images=row_0["image"],
    return_tensors="pt",
    add_special_tokens=False
).to(device)

processor.batch_decode(model.generate(**inputs, max_new_tokens=256), skip_special_tokens=True)[0]

"Question: 'what website is this?'; Answer: 'yahoo! news: top stories'"

## Generate Results for 320 Random Image-Question Pairs

### Select Random Sample from the Test Set

In [14]:
sample_data = dataset["test"][torch.randint(0, len(dataset["test"]), (320,))]

IMAGES = sample_data["image"]
QUESTIONS = [f"Question: '{q}'; Answer:" for q in sample_data["question"]]
IMAGE_CLASSES = sample_data["image_classes"]

### Generate Ansewrs for Sample

In [15]:
sample_generated_answers = []

b_size = 8
for i in tqdm(range(0, 320, b_size)):
  sample_inputs = processor(
      text=QUESTIONS[i:i+b_size],
      images=IMAGES[i:i+b_size],
      return_tensors="pt",
      add_special_tokens=False,
      padding=True
  ).to(device)

  sample_generated_outputs = model.generate(**sample_inputs, max_new_tokens=256)
  sample_generated_answers += processor.batch_decode(sample_generated_outputs, skip_special_tokens=True)

  0%|          | 0/40 [00:00<?, ?it/s]

In [16]:
clean_answers = [a[a.index("'; Answer:") + 10:].strip() for a in sample_generated_answers]

In [31]:
hidden_states = []

b_size = 1
for i in tqdm(range(0, 320, b_size)):
  sample_inputs = processor(
      text=QUESTIONS[i:i+b_size],
      images=IMAGES[i:i+b_size],
      return_tensors="pt",
      add_special_tokens=False,
      padding=True
  ).to(device)

  with torch.no_grad():
    hidden_states.append(model(**sample_inputs)["encoder_last_hidden_state"])

  0%|          | 0/320 [00:00<?, ?it/s]

### Save the Data

In [None]:
# im_list = [np.asarray(im).tolist() for im in IMAGES]

In [35]:
pd.DataFrame({
    # "images": im_list,
    "questions": QUESTIONS,
    "answers": clean_answers,
    "image_classes": IMAGE_CLASSES,
    "hidden_states": torch.stack(hidden_states).squeeze(1).norm(dim=1).tolist()
}).to_pickle("./sample_data.pkl")
