# README

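This notebook generates captions for an image dataset (`zzsi/afhq64_16k` on the Hugging Face Hub) with the following models, and then pushes the captioned dataset back to the Hub:
- BLIP2 (`Salesforce/blip2-opt-2.7b`)
- PaliGemma-3b (`google/paligemma-3b-mix-224`)
- CogVLM (`THUDM/cogvlm-base-224-hf`; demo only — the recorded run ran out of GPU memory while loading it)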

# Dataset

In [1]:
# Make the project root importable so that `src.*` modules resolve.
# '..' is resolved against the kernel's current working directory
# (typically the folder that contains this notebook).
import os
import sys

module_path = os.path.abspath(os.path.join('..'))
_already_importable = module_path in sys.path
if not _already_importable:
    sys.path.append(module_path)

In [2]:
from src.datasets.hugging_face_dataset import HuggingFaceDataset

# Hub id of the source images; the captioning cells below iterate over the
# 'val' split loaded here.
dataset_name = "zzsi/afhq64_16k"
dataset = HuggingFaceDataset(dataset_name, 'val')

n_samples = len(dataset)
print(f"len(dataset): {n_samples}")

len(dataset): 1500


# BLIP 2

### Demos of BLIP2

In [None]:
import requests
from PIL import Image
import torch
from transformers import Blip2Processor, Blip2ForConditionalGeneration

# Fall back to CPU so the demo still runs (slowly) without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Sanity-check BLIP-2 on the Salesforce demo image before captioning the dataset.
img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
response = requests.get(img_url, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of letting PIL choke on an error page.
response.raise_for_status()
raw_image = Image.open(response.raw).convert('RGB')
display(raw_image)

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("Salesforce/blip2-opt-2.7b").to(device)

# Image only (no text prompt): the model produces a free-form caption.
inputs = processor(raw_image, return_tensors="pt").to(device)

# Inference only: skip autograd bookkeeping to save memory.
with torch.inference_mode():
    out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True).strip())

In [None]:
import matplotlib.pyplot as plt

# Caption the first few validation images with BLIP-2 and show them stacked vertically.
n_preview = 5
# Read each sample once and reuse it for both captioning and plotting.
preview_images = [dataset[i][0] for i in range(n_preview)]

captions = []
with torch.inference_mode():
    for image in preview_images:
        # Use the `device` chosen in the demo cell instead of hard-coding "cuda",
        # so this also works on CPU-only machines.
        inputs = processor(image, return_tensors="pt").to(device)
        out = model.generate(**inputs)
        captions.append(processor.decode(out[0], skip_special_tokens=True).strip())

# squeeze=False keeps `axs` 2-D even when n_preview == 1.
fig, axs = plt.subplots(n_preview, 1, figsize=(20, 10), squeeze=False)
for ax, image, caption in zip(axs[:, 0], preview_images, captions):
    ax.imshow(image)
    ax.axis('off')
    ax.set_title(caption)

plt.show()

## Generate captions for the dataset with BLIP2

In [None]:
from tqdm import tqdm

# Caption every image in the split with BLIP-2. Keys are sample indices, so the
# saved captions can be matched back to `dataset` order.
blip_captions = {}
with torch.inference_mode():
    for i, sample in enumerate(tqdm(dataset)):
        # Reuse the sample produced by the iterator rather than reading dataset[i]
        # a second time; element 0 is the image.
        raw_image = sample[0]
        # `device` comes from the demo cell (CUDA when available, else CPU).
        inputs = processor(raw_image, return_tensors="pt").to(device)
        out = model.generate(**inputs)
        blip_captions[i] = processor.decode(out[0], skip_special_tokens=True).strip()


In [None]:
# Save the BLIP-2 captions. The file name carries the split that was captioned
# ('val') because the Hub-upload cell at the end of the notebook reads
# "val_blip_captions.json". Note that json turns the integer keys into strings.
import json

with open("val_blip_captions.json", "w") as f:
    json.dump(blip_captions, f)

# PaliGemma-3b

In [None]:
import gc

from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
import torch

model_id = "google/paligemma-3b-mix-224"
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Release the BLIP-2 model/processor from the previous section before loading
# PaliGemma. Otherwise the old `model` stays referenced (and resident on the GPU)
# until the reassignment below finishes.
model = None
processor = None
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# device_map already places the weights on `device`, so no extra `.to(device)` is needed.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=dtype,
    device_map=device,
    revision="bfloat16",
).eval()
processor = AutoProcessor.from_pretrained(model_id)


In [None]:
# Single-image PaliGemma smoke test on the first validation image.
image = dataset[0][0]
display(image)
# The image token goes first, followed by the task prefix. This matches the prompt
# used by the preview and full-dataset cells below; the previous
# "caption en <image>" placed the image after the text.
prompt = "<image>caption en"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)
# generate() returns prompt + new tokens; remember the prompt length so only the
# newly generated tokens are decoded.
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)


In [None]:
import matplotlib.pyplot as plt

# Caption the first five validation images with PaliGemma (greedy decoding,
# up to 100 new tokens), printing each caption as it is produced.
captions = []
for idx in range(5):
    img = dataset[idx][0]
    prompt = "<image>caption en"
    batch = processor(text=prompt, images=img, return_tensors="pt").to(model.device)
    # Length of the prompt tokens, stripped from the generated sequence below.
    prompt_len = batch["input_ids"].shape[-1]

    with torch.inference_mode():
        new_tokens = model.generate(**batch, max_new_tokens=100, do_sample=False)[0][prompt_len:]
        text = processor.decode(new_tokens, skip_special_tokens=True)
        print(text)
        captions.append(text)

# plot images with captions
fig, axs = plt.subplots(5, 1, figsize=(20, 10))

for idx, ax in enumerate(axs):
    ax.imshow(dataset[idx][0])
    ax.axis('off')
    ax.set_title(captions[idx])

In [None]:
from tqdm import tqdm

# Caption every image in the split with PaliGemma and collect the results in a
# dict keyed by sample index (written to JSON in the next cell).
captions = {}
# The prompt is the same for every image, so build it once.
prompt = "<image>caption en"
with torch.inference_mode():
    for i, sample in enumerate(tqdm(dataset)):
        # Reuse the sample produced by the iterator rather than reading dataset[i]
        # a second time; element 0 is the image.
        model_inputs = processor(text=prompt, images=sample[0], return_tensors="pt").to(model.device)
        # Strip the prompt tokens so only the generated caption is decoded.
        input_len = model_inputs["input_ids"].shape[-1]
        generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
        captions[i] = processor.decode(generation[0][input_len:], skip_special_tokens=True)

In [None]:
# Save the PaliGemma captions. The file name carries the split that was captioned
# ('val') because the Hub-upload cell at the end of the notebook reads
# "val_paligemma-3b-captions.json". Note that json turns the integer keys into strings.
import json

with open("val_paligemma-3b-captions.json", "w") as f:
    json.dump(captions, f)
    

# CogVLM

In [12]:
import gc
import torch
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

# print gpu usage
print(f"Initial GPU memory usage: {torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB")
# Drop the notebook's references to the previous model and processor first.
# gc.collect()/empty_cache() can only release memory that no live variable still
# points to, and `model` would otherwise stay resident until the reassignment
# below completes, i.e. for the whole time CogVLM is loading.
model = None
processor = None
gc.collect()
torch.cuda.empty_cache()
print(f"GPU memory usage after emptying the cache: {torch.cuda.memory_allocated() / 1024 ** 3:.2f} GB")

# NOTE(review): the recorded run of this cell hit CUDA OOM on a ~24 GB GPU while
# loading these weights. CogVLM is presumably far larger than BLIP-2/PaliGemma
# (TODO confirm the parameter count/memory needs on the model card), so its bf16
# weights may not fit on such a GPU even with nothing else loaded.
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogvlm-base-224-hf',
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).to('cuda').eval()

image_url = 'https://github.com/THUDM/CogVLM/blob/main/examples/1.png?raw=true'
response = requests.get(image_url, stream=True, timeout=30)
# Fail loudly on HTTP errors instead of letting PIL choke on an error page.
response.raise_for_status()
image = Image.open(response.raw).convert('RGB')

# build_conversation_input_ids comes from the model's remote code. Add a leading
# batch dimension (unsqueeze(0)) to each tensor and move it to the GPU. Images are
# passed as a nested per-sample list, cast to bf16 to match the model dtype.
conversation = model.build_conversation_input_ids(tokenizer, query='', images=[image])
inputs = {
    'input_ids': conversation['input_ids'].unsqueeze(0).to('cuda'),
    'token_type_ids': conversation['token_type_ids'].unsqueeze(0).to('cuda'),
    'attention_mask': conversation['attention_mask'].unsqueeze(0).to('cuda'),
    'images': [[conversation['images'][0].to('cuda').to(torch.bfloat16)]],
}
# max_length bounds prompt + generated tokens; greedy decoding.
gen_kwargs = {"max_length": 2048, "do_sample": False}

with torch.no_grad():
    outputs = model.generate(**inputs, **gen_kwargs)
    # Keep only the newly generated tokens (drop the echoed prompt).
    outputs = outputs[:, inputs['input_ids'].shape[1]:]
    print(tokenizer.decode(outputs[0]))

Initial GPU memory usage: 23.28 GB
GPU memory usage after emptying the cache: 0.00 GB


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB. GPU 0 has a total capacity of 23.68 GiB of which 74.75 MiB is free. Including non-PyTorch memory, this process has 23.53 GiB memory in use. Of the allocated memory 23.28 GiB is allocated by PyTorch, and 1.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [2]:
# push to huggingface
import json
from PIL import Image
from src.datasets.hugging_face_dataset import HuggingFaceDataset

# Column names of the two caption sources in the uploaded dataset.
BLIP_COLUMN = "caption_blip2-opt-2.7b"
PALI_COLUMN = "caption_paligemma-3b-mix-224"


def load_caption_list(path):
    """Return the captions stored in a caption JSON file, in file order.

    The captioning cells write ``{sample_index: caption}`` dicts in index order,
    so the value order is the dataset's sample order.
    """
    with open(path, "r") as f:
        return list(json.load(f).values())


def check_caption_coverage(split_name, split_dataset, **caption_lists):
    """Raise ValueError if any caption list does not have one entry per sample."""
    for source, caption_list in caption_lists.items():
        if len(caption_list) != len(split_dataset):
            raise ValueError(
                f"{split_name}: {source} has {len(caption_list)} captions "
                f"but the dataset has {len(split_dataset)} samples"
            )


if __name__ == "__main__":
    user_name = "reese-green"
    resolution = 64

    dataset_name = "zzsi/afhq64_16k"
    train_dataset = HuggingFaceDataset(dataset_name, 'train')
    val_dataset = HuggingFaceDataset(dataset_name, 'val')

    train_blip = "train_blip_captions.json"
    val_blip = "val_blip_captions.json"
    train_pali = "train_paligemma-3b-captions.json"
    val_pali = "val_paligemma-3b-captions.json"

    train_blip_list = load_caption_list(train_blip)
    val_blip_list = load_caption_list(val_blip)
    train_pali_list = load_caption_list(train_pali)
    val_pali_list = load_caption_list(val_pali)

    # Fail fast, before the slow Arrow conversion, if a caption file does not
    # cover its split (the generators below index the lists by sample position).
    check_caption_coverage("train", train_dataset, blip=train_blip_list, pali=train_pali_list)
    check_caption_coverage("val", val_dataset, blip=val_blip_list, pali=val_pali_list)

    # take a torch dataset and push to huggingface
    # https://huggingface.co/docs/huggingface_hub/v0.25.0/en/tutorials/push_to_hub
    from datasets import Dataset, DatasetDict

    # convert torch dataset to huggingface dataset
    def gen_train_ds():
        for i, (img, label) in enumerate(train_dataset):
            yield {"image": img, "label": label, BLIP_COLUMN: train_blip_list[i], PALI_COLUMN: train_pali_list[i]}
    train_ds_hf = Dataset.from_generator(gen_train_ds)

    def gen_val_ds():
        for i, (img, label) in enumerate(val_dataset):
            yield {"image": img, "label": label, BLIP_COLUMN: val_blip_list[i], PALI_COLUMN: val_pali_list[i]}
    val_ds_hf = Dataset.from_generator(gen_val_ds)

    # Add both train and val splits to the same Hub repo.
    dataset_dict = DatasetDict({"train": train_ds_hf, "val": val_ds_hf})
    dataset_dict.push_to_hub(f"{user_name}/afhq{resolution}_16k", create_pr=False)

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/14630 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/147 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/1500 [00:00<?, ? examples/s]

Creating parquet from Arrow format:   0%|          | 0/15 [00:00<?, ?ba/s]