### Installations

In [1]:
!pip install -q --upgrade typing-extensions
!pip install -q openai==0.28

In [2]:
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import openai
import time
import pandas as pd
import os
import torch
from datasets import load_dataset
import random
from PIL import Image
import requests
from peft import get_peft_model, LoraConfig, TaskType
import numpy as np
import traceback
from openai.error import InvalidRequestError

  from .autonotebook import tqdm as notebook_tqdm


### GPT Model

In [3]:
def _ms_since_epoch():
    return time.perf_counter_ns() // 1000000


def set_openai_parameters(engine, max_tokens):
    """Configure the OpenAI client and build the shared request parameters.

    The API key is taken from the ``OPENAI_API_KEY`` environment variable so
    that no credential has to be stored in the notebook. When the variable is
    unset or empty, whatever key is already configured on the ``openai``
    module is left untouched.

    Args:
        engine: Name of the OpenAI engine/model to query.
        max_tokens: Value stored under ``"max_tokens"`` in the parameters.

    Returns:
        A ``(parameters, time_of_last_api_call)`` tuple: the keyword
        arguments for the API call and the millisecond timestamp taken when
        this function ran.
    """
    # openai API setup and parameters
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        openai.api_key = api_key

    parameters = {
        "max_tokens": max_tokens,
        "top_p": 0,  # greedy
        "temperature": 0.5,
        "logprobs": 5,
        "engine": engine,
    }
    time_of_last_api_call = _ms_since_epoch()

    return parameters, time_of_last_api_call


def wait_between_predictions(time_of_last_api_call, min_ms_between_api_calls):
    """Sleep until enough time has passed since the previous API call.

    Args:
        time_of_last_api_call: Millisecond timestamp (from
            ``_ms_since_epoch``) of the previous call.
        min_ms_between_api_calls: Minimum spacing between two calls, in
            milliseconds.

    Returns:
        The millisecond timestamp taken after any sleeping. Callers can pass
        it back as ``time_of_last_api_call`` on the next invocation; before
        this fix the value was assigned to a local and thrown away.
    """
    cur_time = _ms_since_epoch()
    elapsed_ms = cur_time - time_of_last_api_call
    if elapsed_ms <= min_ms_between_api_calls:
        # Sleep only for the remainder of the minimum interval.
        time.sleep((min_ms_between_api_calls - elapsed_ms) / 1000)
    return _ms_since_epoch()


def predict_sample_openai_gpt(
    example,
    prompt,
    min_ms_between_api_calls: int = 5000,
    engine: str = "text-davinci-003",
    max_tokens: int = 100,
):
    """Send ``prompt`` to the OpenAI completion endpoint and package the reply.

    Args:
        example: Mapping describing the sample; a shallow copy of it becomes
            the base of the returned metadata.
        prompt: Text sent as the completion prompt.
        min_ms_between_api_calls: Minimum spacing handed to
            ``wait_between_predictions``.
        engine: Engine name; names containing ``"opt"`` take their
            ``finish_reason`` from inside ``logprobs`` and get no ``index``.
        max_tokens: Completion length limit passed to
            ``set_openai_parameters``.

    Returns:
        A dict with ``"input"``, ``"prediction"`` (stripped of surrounding
        whitespace and periods) and ``"metadata"``.

    Raises:
        Exception: If the API returns ``None``.
    """
    request_kwargs, last_call_ms = set_openai_parameters(engine, max_tokens)
    request_kwargs["prompt"] = prompt

    wait_between_predictions(last_call_ms, min_ms_between_api_calls)

    response = openai.Completion.create(**request_kwargs)
    if response is None:
        raise Exception("Response from OpenAI API is None.")

    top_choice = response.choices[0]  # type:ignore
    completion_text = top_choice.text.strip().strip(".")

    # Metadata starts from a copy of the example so the caller's dict is untouched.
    metadata = example.copy()
    metadata["logprobs"] = top_choice["logprobs"]
    if "opt" in engine:
        # "finish_reason" is located in a slightly different location in opt
        metadata["finish_reason"] = top_choice["logprobs"]["finish_reason"]
    else:
        metadata["finish_reason"] = top_choice["finish_reason"]
        # From the OpenAI API documentation it's not clear what "index" is, but let's keep it as well
        metadata["index"] = top_choice["index"]

    return {
        "input": prompt,
        "prediction": completion_text,
        "metadata": metadata,
    }

def predict_sample_openai_chatgpt(
    prompt,
    img_url,
    min_ms_between_api_calls: int = 10000,
    engine: str = "gpt-4o",
    max_tokens: int = 100,
):
    """Ask an OpenAI chat model to answer ``prompt`` about one image.

    Args:
        prompt: Instruction text sent alongside the image.
        img_url: Base64-encoded image bytes (not an actual URL). They are
            embedded in a ``data:image/jpeg;base64,...`` URL.
        min_ms_between_api_calls: Pause, in milliseconds, applied before the
            request and again before the retry after a rate-limit error. The
            default of 10000 equals the 10 seconds that used to be hard-coded.
        engine: Chat model name.
        max_tokens: Passed to ``set_openai_parameters`` only. It is not sent
            with the chat request, so replies are not truncated to it.

    Returns:
        A dict with ``"input"`` (the prompt) and ``"prediction"`` (the
        message content of the first choice).

    Raises:
        Exception: If the API returns ``None``.
        openai.error.RateLimitError: If the single retry is rate-limited too.
    """
    # Also configures the client's API key as a side effect.
    parameters, _ = set_openai_parameters(engine, max_tokens)

    # Build the request once so the retry sends exactly the same payload.
    request_kwargs = {
        "model": engine,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/jpeg;base64,{img_url}"},
                    },
                ],
            }
        ],
        "temperature": parameters["temperature"],
        "top_p": parameters["top_p"],
    }

    wait_seconds = max(min_ms_between_api_calls, 0) / 1000
    time.sleep(wait_seconds)
    try:
        response = openai.ChatCompletion.create(**request_kwargs)
    except openai.error.RateLimitError:
        print(f"Rate limit reached. Waiting {wait_seconds:g} seconds.")
        time.sleep(wait_seconds)
        # Retry with the image still attached. The previous retry sent the
        # text prompt alone, so the model was asked to caption nothing.
        response = openai.ChatCompletion.create(**request_kwargs)

    if response is None:
        raise Exception("Response from OpenAI API is None.")

    # build output data
    prediction = dict()
    prediction["input"] = prompt
    prediction["prediction"] = response.choices[0].message['content']  # type:ignore

    return prediction

def gpt4_estimetion(url):
  """Caption one image with the GPT chat model.

  ``url`` is forwarded unchanged as the image argument of
  ``predict_sample_openai_chatgpt``. The ``'prediction'`` entry of that
  call's result is returned.
  """
  # best prompt:
  instruction = (
      "\n"
      "Generate a caption for the provided image. If the image contains any nonsensical or uncommon elements, make sure to highlight them.\n"
      "  "
  )
  return predict_sample_openai_chatgpt(instruction, url)['prediction']

### InstructBLIP Model

In [None]:
# Single source of truth for the checkpoint, so the model and its processor
# cannot accidentally be loaded from two different repositories.
INSTRUCTBLIP_CHECKPOINT = "Salesforce/instructblip-vicuna-7b"

# Weights are requested in half precision (torch.float16).
model = InstructBlipForConditionalGeneration.from_pretrained(INSTRUCTBLIP_CHECKPOINT, torch_dtype=torch.float16)
print("finish from_pretrained model")

# The processor prepares the image + text inputs and decodes generated ids.
processor = InstructBlipProcessor.from_pretrained(INSTRUCTBLIP_CHECKPOINT)
print("finish from_pretrained processor")

In [None]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("device:", device)

In [None]:
# Define LoRA configuration
# Adapters are attached only to modules whose names match `target_modules`.
# NOTE(review): TaskType.SEQ_2_SEQ_LM is used with a Vicuna-based checkpoint —
# confirm this task type is the intended one.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=2,                              # rank of the low-rank update matrices
    lora_alpha=8,                     # LoRA scaling factor
    lora_dropout=0.2,                 # dropout probability of the LoRA layers
    target_modules=["q_proj", "v_proj"]
)

print("LoRA config created")

# Apply LoRA to the model
# This rebinds the global `model` to the PEFT wrapper, so running the cell a
# second time would wrap the already-wrapped model again.
# NOTE(review): no training or adapter-loading step is visible in this
# notebook — confirm the freshly initialised adapter is what should be used.
model = get_peft_model(model, lora_config)
print("LoRA applied to model")

# Move the wrapped model to the device selected earlier ("cuda" or "cpu").
model.to(device)
print(f"model moved to {device}")

In [None]:
def blip_estimetion(url):
  """Caption one base64-encoded image with the InstructBLIP model.

  Args:
    url: Base64 string holding the encoded image bytes (despite the
      parameter name, this is not a URL).

  Returns:
    The first decoded sequence, with special tokens skipped and surrounding
    whitespace stripped.
  """
  # Imported locally so the function does not depend on a later notebook cell
  # having already imported these names.
  import base64
  from io import BytesIO

  image_data = base64.b64decode(url)
  image = Image.open(BytesIO(image_data)).convert("RGB")

  prompt = "Generate a caption for the provided image. If the image contains any nonsensical or uncommon elements, make sure to highlight them."

  # Process the image and text together
  inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

  # Deterministic beam search. `top_p` and `temperature` were dropped: they
  # only apply to sampling-based generation, so with do_sample=False they had
  # no effect on the output and merely triggered configuration warnings.
  outputs = model.generate(
          **inputs,
          do_sample=False,
          num_beams=5,
          max_length=150,
          min_length=5,
          repetition_penalty=1.5,
          # Exponent applied to sequence length when scoring beams: values
          # above 0.0 favour longer sequences, values below 0.0 shorter ones.
          length_penalty=0.6,
  )

  generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()

  return generated_text

### Choose the captioning model

In [None]:
# Select the captioning backend that the caption-generation loop calls.
# To use GPT instead, disable the two `blip` assignments and enable these:
# model_estimetion = gpt4_estimetion
# model_name = 'gpt'
# `model_name` is also used as the prefix of the output CSV file name.
model_estimetion = blip_estimetion
model_name = 'blip'

### Prepare the WHOOPS! dataset

In [None]:
!pip install -q git-lfs
!git clone https://huggingface.co/spaces/nlphuji/whoops-explorer-analysis
!pip install -q datasets

wmtis = load_dataset("nlphuji/wmtis-identify")['test']

In [None]:
# Number of records before trimming.
print(len(wmtis))
# Slice the dataset to exclude the last index
# NOTE(review): the reason for dropping the final record is not stated — confirm.
# The assignment rebinds `wmtis`, so every re-run of this cell removes one
# more record; restart from the loading cell to get the full split back.
wmtis = wmtis.select(range(len(wmtis) - 1))
# Number of records after trimming.
print(len(wmtis))

### Run the caption-generation task on strange and normal images

In [None]:
import base64
from io import BytesIO
from openai.error import InvalidRequestError

def _encode_image_as_base64_png(image):
  """Save ``image`` as PNG in memory and return the bytes as a base64 string.

  ``image`` must expose ``save(fp, format=...)`` (presumably the PIL images
  stored in the dataset records).
  """
  buffered = BytesIO()
  image.save(buffered, format="PNG")
  return base64.b64encode(buffered.getvalue()).decode("utf-8")


def _caption_or_error(image):
  """Caption ``image`` with the selected model, or return an error marker.

  When the API rejects the request (``InvalidRequestError``) the failure is
  printed and ``'error: <message>'`` is returned, so both caption lists keep
  one entry per record. Any other exception propagates.
  """
  # NOTE(review): images are encoded as PNG here, while the GPT helper labels
  # the data URL as image/jpeg — confirm the API tolerates the mismatch.
  encoded_image = _encode_image_as_base64_png(image)
  try:
    caption = model_estimetion(encoded_image)
  except InvalidRequestError as e:
    print(f"Failed to get caption: {e}")
    # Record the error itself. The previous code interpolated the caption
    # variable, which is undefined on a first failure (NameError) and
    # otherwise holds the previous record's caption.
    return f'error: {e}'
  print(caption)
  return caption


model_captions = {'normal_caption': [], 'strange_caption': []}
for record in wmtis:
  # Each record provides a normal image and its "strange" counterpart.
  model_captions['normal_caption'].append(_caption_or_error(record['normal_image']))
  model_captions['strange_caption'].append(_caption_or_error(record['strange_image']))


In [None]:
# save the outputs to a csv files
generated_df = pd.DataFrame({
    'strange': model_captions['strange_caption'],
    'normal': model_captions['normal_caption']
})

generated_df.to_csv(f'{model_name}-generated_captions.csv', index=False)  # index=False to avoid writing row numbers