In [1]:
from transformers import AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
import requests

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Relative path passed to `from_pretrained`; expected to contain the checkpoint
# shards plus the processor/tokenizer files.
model_id = "./models-phi-35-vision"
device = "cuda" if torch.cuda.is_available() else "cpu"

# FlashAttention-2 kernels are CUDA-only. Requesting them on the CPU fallback
# makes `from_pretrained` fail, so use the plain eager attention there.
attn_implementation = "flash_attention_2" if device == "cuda" else "eager"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=device,
    # Runs Python code shipped alongside the checkpoint: only point this at
    # a model directory you trust.
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    _attn_implementation=attn_implementation,
)
# NOTE(review): num_crops=16 presumably follows the model card's single-image
# recommendation -- confirm before reusing this cell for multi-frame input.
processor = AutoProcessor.from_pretrained(
    model_id, trust_remote_code=True, num_crops=16
)

# NOTE(review): `generate` is resolved on the wrapped module, so it is worth
# verifying that the compiled graph is actually exercised during generation.
model = torch.compile(model, mode="max-autotune")

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.20s/it]


In [3]:
def describe_image(image_url, model, processor, max_new_tokens=20, temperature=0.0,
                   request_timeout=30.0):
    """Download an image and return a short model-generated description of it.

    Args:
        image_url: HTTP(S) URL of the image to describe.
        model: Vision-language causal LM exposing ``generate`` and ``device``.
        processor: Processor matching ``model``; must expose ``tokenizer`` and
            ``batch_decode`` and be callable as ``processor(prompt, image, ...)``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. ``0`` (the default) or ``None``
            selects greedy decoding; any positive value enables sampling.
        request_timeout: Seconds to wait on the image server before giving up.

    Returns:
        The decoded description with surrounding whitespace stripped.

    Raises:
        requests.RequestException: If the download times out, fails, or the
            server answers with an error status.
    """
    from io import BytesIO  # local import keeps this cell self-contained

    # Fail fast on dead hosts / error pages instead of hanging forever or
    # handing an HTML error body to PIL.
    response = requests.get(image_url, timeout=request_timeout)
    response.raise_for_status()
    # Normalise palette / greyscale / CMYK inputs to 3-channel RGB.
    image = Image.open(BytesIO(response.content)).convert("RGB")

    placeholder = "<|image_1|>\n"
    prompt_text = "Describe the image in concise, focusing on the main subjects, their actions, and the overall setting. Include information about colors, textures, and any notable objects or elements in the background. eliminate filler words, adverbs, and any unnecessary phrases, focusing solely on the core meaning and essential information."

    messages = [
        {"role": "user", "content": placeholder + prompt_text},
    ]
    prompt = processor.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Only sample when a positive temperature was requested; passing a
    # temperature together with do_sample=False is ignored (with a warning).
    sampling = temperature is not None and temperature > 0
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
    }
    if sampling:
        generation_args["temperature"] = temperature

    # Follow the model that was passed in rather than a notebook-level global.
    target_device = model.device

    # torch.autocast(device_type=...) replaces the deprecated, CUDA-only
    # torch.cuda.amp.autocast.
    with torch.inference_mode(), torch.autocast(device_type=target_device.type, dtype=torch.bfloat16):
        inputs = processor(prompt, image, return_tensors="pt").to(target_device, dtype=torch.bfloat16)

        generate_ids = model.generate(
            **inputs,
            eos_token_id=processor.tokenizer.eos_token_id,
            **generation_args
        )

        # Drop the prompt tokens so only the newly generated text is decoded.
        generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
        response_text = processor.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

    return response_text.strip()

In [4]:
%%time
# Smoke test / timing of one end-to-end call: the measured time covers the
# image download as well as generation. `%%time` must stay the first line of
# the cell for the magic to apply.
url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg"
# max_new_tokens=20 caps the output length, so the printed description can
# stop mid-sentence.
description = describe_image(url, model, processor, max_new_tokens=20)
print(description)

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


The image features a blue background with a geometric pattern of overlapping hexagons. In the
CPU times: user 2.77 s, sys: 233 ms, total: 3 s
Wall time: 1.46 s
