In [3]:
import os
import csv
import zipfile
from PIL import Image
from tqdm import tqdm
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration


# Initialize model
try:
    MODEL_NAME = "fancyfeast/llama-joycaption-alpha-two-hf-llava"
    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    llava_model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
except Exception as e:
    raise RuntimeError(f"Model loading failed: {str(e)}")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [1]:
!pip install transformers accelerate



In [4]:
import torch
print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU Memory Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

GPU Memory Allocated: 15.80 GB
GPU Memory Cached: 15.86 GB


In [5]:
import gc
# ======================
# 2. CONFIGURATION
# ======================
DRIVE_DATA_ROOT = "monuments_6"
OUTPUT_CSV = "monuments_captions.csv"  # Colab temporary storage
ZIP_PATH = "monuments_captions.zip"  # Final Drive location

PROMPT = """Analyze this monument image. Respond ONLY in this exact pipe-separated format (don't be too specific):
Monument Type | Architecture Type | Material | Texture | Construction Period(century) | Key Features | Lighting
Example: Basilica | Baroque | Marble | Smooth | 17th century | Ornate facade, central arch, tower | Daylight, soft shadows"""

# ======================
# 3. IMAGE DISCOVERY
# ======================
def find_images(root_dir):
    """Recursively find all image files in directory"""
    image_exts = ('.png', '.jpg', '.jpeg', '.webp','.JPG')
    image_paths = []

    for dirpath, _, filenames in os.walk(root_dir):
        for f in filenames:
            if f.lower().endswith(image_exts):
                image_paths.append(os.path.join(dirpath, f))

    print(f"Found {len(image_paths)} images in {root_dir}")
    return image_paths

image_paths = find_images(DRIVE_DATA_ROOT)

# ======================
# 4. BATCH CAPTION GENERATION (FIXED)
# ======================
def generate_captions_batch(image_paths, batch_size=2):
    """Process multiple images in a batch with proper autocast handling"""
    try:
        # Load all images in batch
        images = [Image.open(path) for path in image_paths]

        # Prepare conversations
        convos = [
            [{"role": "system", "content": "You are a precise architectural image captioner."},
             {"role": "user", "content": PROMPT}]
            for _ in range(len(images))
        ]
        convo_strings = [processor.apply_chat_template(c, tokenize=False) for c in convos]

        # Process batch with fixed autocast
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):  # Explicit device_type
            inputs = processor(
                text=convo_strings,
                images=images,
                return_tensors="pt",
                padding=True
            ).to('cuda')

            generate_ids = llava_model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                top_p=0.9,
                num_beams=1,
                early_stopping=True
            )

        # Decode captions
        captions = []
        for ids in generate_ids:
            caption = processor.tokenizer.decode(
                ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )
            last_line = next(
                (line.strip() for line in reversed(caption.split("[/INST]")[-1].splitlines()) if line.strip()),
                "CAPTION_ERROR"
            )
            captions.append(last_line)

        return captions

    except Exception as e:
        print(f"Batch failed: {str(e)}")
        return ["CAPTION_ERROR"] * len(image_paths)
# ======================
# 5. MAIN PROCESSING (WITH WORKERS)
# ======================
from torch.utils.data import Dataset, DataLoader
import gc

class MonumentDataset(Dataset):
    def __init__(self, image_paths):
        self.image_paths = image_paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        try:
            image = Image.open(self.image_paths[idx])
            return image, self.image_paths[idx]
        except Exception as e:
            print(f"Error loading {self.image_paths[idx]}: {str(e)}")
            return None, None

# Initialize DataLoader with 2 workers
dataset = MonumentDataset(image_paths)
dataloader = DataLoader(
    dataset,
    batch_size=64,  # Match your GPU capacity  # Critical change
    pin_memory=True,
    collate_fn=lambda x: [item for item in x if item[0] is not None]  # Filter failed loads
)

# Initialize CSV
with open(OUTPUT_CSV, 'w', newline='', encoding='utf-8') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["image_path", "caption"])

    for batch_images_paths in tqdm(dataloader, desc="Generating captions"):
        batch_images, batch_paths = zip(*batch_images_paths)

        try:
            captions = generate_captions_batch(batch_paths, len(batch_paths))

            for path, caption in zip(batch_paths, captions):
                rel_path = os.path.relpath(path, DRIVE_DATA_ROOT)
                writer.writerow([rel_path, caption])

            if len(batch_paths) % 20 == 0:
                csvfile.flush()
                torch.cuda.empty_cache()
                gc.collect()

        except Exception as e:
            print(f"Batch failed: {str(e)}")
            for path in batch_paths:
                rel_path = os.path.relpath(path, DRIVE_DATA_ROOT)
                writer.writerow([rel_path, "CAPTION_ERROR"])
# ======================
# 6. COMPRESS & STORE
# ======================
print("\nCompressing results...")
with zipfile.ZipFile(ZIP_PATH, 'w', zipfile.ZIP_DEFLATED) as zipf:
    zipf.write(OUTPUT_CSV, arcname="monuments_captions.csv")

# Verify
print(f"\n✅ Done! Results saved to:")
print(f"- Temporary CSV: {OUTPUT_CSV}")
print(f"- Drive ZIP: {ZIP_PATH}")
print(f"Total captions generated: {len(image_paths)}")


Found 6442 images in monuments_6


Generating captions: 100%|██████████| 101/101 [28:46<00:00, 17.10s/it]


Compressing results...

✅ Done! Results saved to:
- Temporary CSV: monuments_captions.csv
- Drive ZIP: monuments_captions.zip
Total captions generated: 6442



