In [3]:
import torch

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version (PyTorch compiled):", torch.version.cuda)
if torch.cuda.is_available():
    print("GPU name:", torch.cuda.get_device_name(0))
    print("GPU capability:", torch.cuda.get_device_capability(0))


PyTorch version: 2.6.0+cu124
CUDA available: True
CUDA version (PyTorch compiled): 12.4
GPU name: NVIDIA GeForce RTX 4060 Ti
GPU capability: (8, 9)


In [2]:
import torch
from transformers import (
    AutoProcessor,
    AutoModelForCausalLM,
    BlipProcessor,
    BlipForConditionalGeneration,
    AutoConfig
)
from datasets import load_dataset
import evaluate
import nltk
nltk.download("punkt")

# ==== 環境設定 ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ==== 模型載入 ====
print("Loading BLIP...")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

print("Loading Phi-4...")
# 手動禁用 flash_attention_2
phi_config = AutoConfig.from_pretrained("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True)
phi_config.attn_implementation = "eager"

phi_processor = AutoProcessor.from_pretrained("microsoft/Phi-4-multimodal-instruct", trust_remote_code=True)
phi_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-4-multimodal-instruct",
    config=phi_config,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
    trust_remote_code=True,
    low_cpu_mem_usage=True
)

# ==== 評估指標 ====
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")

# ==== 載入資料集 ====
def load_data(dataset_name):
    print(f"Loading dataset: {dataset_name}")
    if dataset_name == "mscoco":
        dataset = load_dataset("nlphuji/mscoco_2014_5k_test_image_text_retrieval", split="test[:20]")
    elif dataset_name == "flickr30k":
        dataset = load_dataset("nlphuji/flickr30k", split="test[:20]")
    else:
        raise ValueError("Unsupported dataset")
    return dataset

# ==== Caption Generation ====
def caption_blip(image):
    inputs = blip_processor(images=image, return_tensors="pt").to(device)
    out = blip_model.generate(**inputs)
    return blip_processor.decode(out[0], skip_special_tokens=True)

def caption_phi4(image):
    prompt = "Describe this image."
    inputs = phi_processor(prompt, images=image, return_tensors="pt").to(device, torch.float16 if torch.cuda.is_available() else torch.float32)
    out = phi_model.generate(**inputs, max_new_tokens=50)
    caption = phi_processor.decode(out[0], skip_special_tokens=True).replace(prompt, "").strip()
    return caption

# ==== 評估 ====
def evaluate_results(preds, refs):
    refs_nested = [[r] for r in refs]  # BLEU 用
    return {
        "BLEU": bleu.compute(predictions=preds, references=refs_nested)["bleu"],
        "ROUGE-1": rouge.compute(predictions=preds, references=refs)["rouge1"],
        "ROUGE-2": rouge.compute(predictions=preds, references=refs)["rouge2"],
        "METEOR": meteor.compute(predictions=preds, references=refs)["meteor"]
    }

# ==== 主程式 ====
def run_captioning(dataset_name):
    dataset = load_data(dataset_name)
    refs = []
    blip_preds, phi_preds = [], []

    for item in dataset:
        image = item["image"]
        caption = item["caption"]
        refs.append(caption)

        try:
            blip_caption = caption_blip(image)
            phi_caption = caption_phi4(image)
        except Exception as e:
            blip_caption, phi_caption = "", ""
            print("Error on image:", e)

        blip_preds.append(blip_caption)
        phi_preds.append(phi_caption)

    print(f"\n== Evaluation for {dataset_name.upper()} ==")
    print("-- BLIP --")
    print(evaluate_results(blip_preds, refs))
    print("-- Phi-4 --")
    print(evaluate_results(phi_preds, refs))

# ==== 執行兩個資料集 ====
run_captioning("mscoco")
run_captioning("flickr30k")


[nltk_data] Downloading package punkt to /home/yoyo/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Loading BLIP...
Loading Phi-4...


ImportError: FlashAttention2 has been toggled on, but it cannot be used due to the following error: the package flash_attn seems to be not installed. Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2.