In [2]:
## Initialize device for PyTorch computations
import torch

RICO_PATH = "/Users/matheus/Downloads/mud-dataset"

device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print ("MPS device not found, using CPU instead.")

In [None]:
from transformers import BlipProcessor, BlipForConditionalGeneration, Blip2Processor, Blip2ForConditionalGeneration
from PIL import Image

# --- For BLIP (Image Captioning) ---
model_name_blip = "Salesforce/blip-image-captioning-base"
processor_blip = BlipProcessor.from_pretrained(model_name_blip)
model_blip = BlipForConditionalGeneration.from_pretrained(model_name_blip).to(device)

# --- For BLIP-2 (Visual Question Answering or Captioning) ---
model_name_blip2 = "Salesforce/blip2-opt-2.7b"
processor_blip2 = Blip2Processor.from_pretrained(model_name_blip2)
model_blip2 = Blip2ForConditionalGeneration.from_pretrained(model_name_blip2, torch_dtype=torch.float16).to(device)


  from .autonotebook import tqdm as notebook_tqdm
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!
Loading checkpoint shards:  50%|█████     | 1/2 [00:12<00:12, 12.71s/it]

In [None]:
# Load an image
raw_image = Image.open(f"{RICO_PATH}/6703.png")

# Captioning using BLIP
inputs_blip = processor_blip(raw_image, return_tensors="pt").to(device)
out_blip = model_blip.generate(**inputs_blip)
print("BLIP Caption:", processor_blip.decode(out_blip[0], skip_special_tokens=True).strip())

# VQA using BLIP-2
question = "Question: What is this app about? Answer:"
inputs_blip2 = processor_blip2(raw_image, text=question, return_tensors="pt").to(device, torch.float16)
out_blip2 = model_blip2.generate(**inputs_blip2)
print("BLIP-2 Answer:", processor_blip2.decode(out_blip2[0], skip_special_tokens=True).strip())

question = "Question: What is the content displayed on the app? Answer:"
inputs_blip2 = processor_blip2(raw_image, text=question, return_tensors="pt").to(device, torch.float16)
out_blip2 = model_blip2.generate(**inputs_blip2)

print("BLIP-2 Answer:", processor_blip2.decode(out_blip2[0], skip_special_tokens=True).strip())

inputs_blip2_cap = processor_blip2(raw_image, return_tensors="pt").to(device, torch.float16)
out_blip2_cap = model_blip2.generate(**inputs_blip2_cap)

print("BLIP-2 Caption:", processor_blip2.decode(out_blip2_cap[0], skip_special_tokens=True).strip())

BLIP Caption: the app for the app is displayed on the screen
BLIP-2 Answer: Question: What is this app about? Answer: It is a video converter app for android
BLIP-2 Answer: Question: What is the content displayed on the app? Answer:
BLIP-2 Caption: vtrm converter for android screenshot


In [6]:
from PIL import Image
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration
import torch

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b", device_map={"": 0}, dtype=torch.float16
)  # doctest: +IGNORE_RESULT

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

Loading checkpoint shards: 100%|██████████| 2/2 [00:50<00:00, 25.00s/it]


In [7]:
inputs = processor(images=image, return_tensors="pt").to(device, torch.float16)

generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

two cats laying on a couch


In [None]:
prompt = "Question: how many cats are there? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, dtype=torch.float16)

generated_ids = model.generate(**inputs)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

Question: how many cats are there? Answer: two
