# 1. PACKAGE INSTALLATION

In [None]:
# Install the third-party packages this notebook relies on.
# The leading `!` hands each line to the shell from inside IPython/Jupyter,
# so this cell is only valid when run as a notebook cell, not as plain Python.
# NOTE(review): einops, torchvision, accelerate and ipywidgets are not imported
# directly in this notebook; presumably they are needed by the remotely loaded
# model code or by the notebook UI -- confirm before pruning.
!pip install transformers 
!pip install einops 
!pip install torchvision 
!pip install torch
!pip install pillow 
!pip install accelerate 
!pip install ipywidgets

# 2. IMPORTS AND CONFIGURATIONS

In [None]:
# Import necessary libraries
import os
import json
import torch
from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
from PIL import Image
from collections import defaultdict

# Define paths and configurations
ROOT_FOLDER = 'images'                    # Folder that is scanned recursively for images
OUTPUT_FILE = 'image_descriptions.json'   # JSON file mirroring the folder hierarchy
# Lower-case extensions of the raster formats handed to PIL. '.svg' is
# intentionally not listed: Pillow has no SVG decoder, so such files could
# never be opened and would be reported as errors (and retried) on every run.
SUPPORTED_FORMATS = ('.jpg', '.jpeg', '.png', '.gif', '.bmp')
# Substrings that mark Jupyter checkpoint folders/files to skip
IGNORE_PATTERNS = ('.ipynb_checkpoints', '-checkpoint')

# Create the output file with an empty JSON object if it doesn't exist yet,
# so later cells can always open and parse it.
if not os.path.exists(OUTPUT_FILE):
    with open(OUTPUT_FILE, 'w', encoding='utf-8') as f:
        json.dump({}, f)
    print(f"Created empty {OUTPUT_FILE}")

# 3. MODEL INITIALIZATION

In [None]:
# The processor and the model come from the same checkpoint and are loaded
# with the same options, so both are spelled out once and reused.
_MOLMO_CHECKPOINT = 'allenai/Molmo-7B-D-0924'
_LOAD_OPTIONS = dict(
    trust_remote_code=True,
    torch_dtype='auto',
    device_map='auto',
)

# Processor first, then the model.
processor = AutoProcessor.from_pretrained(_MOLMO_CHECKPOINT, **_LOAD_OPTIONS)
model = AutoModelForCausalLM.from_pretrained(_MOLMO_CHECKPOINT, **_LOAD_OPTIONS)

# 4. HELPER FUNCTIONS

In [None]:
def nested_dict():
    """Return an empty, arbitrarily deep tree of defaultdicts.

    Accessing a missing key creates another ``nested_dict`` at that key, so
    paths such as ``tree['a']['b']['c']`` can be assigned without creating
    the intermediate levels by hand.
    """
    tree = defaultdict(nested_dict)
    return tree

def convert_defaultdict_to_dict(d):
    """Return ``d`` with defaultdicts replaced by plain dicts.

    A defaultdict is rebuilt as a regular dict with each of its values
    converted the same way. Any other object, including a plain ``dict``,
    is returned unchanged and is not descended into.
    """
    if not isinstance(d, defaultdict):
        return d

    plain = {}
    for key, child in d.items():
        plain[key] = convert_defaultdict_to_dict(child)
    return plain

def process_image(image_path):
    """Generate a retrieval-oriented text description of one image.

    The image is opened with PIL, converted to RGB, passed through the
    module-level ``processor`` together with a fixed structured prompt, and
    the module-level ``model`` generates the description.

    Args:
        image_path: Path of the image file to describe.

    Returns:
        The generated description as a string: only the newly generated
        tokens are decoded (the prompt tokens are sliced off) and special
        tokens are skipped.

    Raises:
        Whatever ``Image.open``, the processor or the model raise (for
        example when the file is missing or cannot be decoded); nothing is
        caught here so the caller can decide how to report the failure.
    """
    prompt = """Provide a comprehensive and precise description of this image that could be used for future retrieval. Structure your response in the following format:

    1. Image Type and Category:
        - Identify the primary type (diagram, chart, seal, form, table, map, etc.)
        - Note any specific subcategories or variations

    2. Identifier Information:
        - Document numbers, references, or codes visible
        - Any dates or version information shown
    - Page numbers or section markers

    3. Content Description:
        - Main subject matter or topic
        - Key terms and specific language used
        - Numbers, quantities, or measurements shown
        - Any proper nouns or specific terminology

    4. Visual Structure:
        - Overall layout and organization
        - Hierarchical relationships if present
        - Connections between elements (arrows, lines, groupings)
        - Color scheme and visual emphasis points

    5. Distinctive Features:
        - Unique or notable elements
        - Special symbols or markings
        - Unusual formatting or arrangements
        - Key differentiating characteristics

    Please write your description in clear, searchable language, including specific terms and identifiers that would be useful for finding this image later. Focus on accuracy and completeness rather than interpretation."""

    # Open inside a context manager so the file handle is released right away,
    # and convert to RGB: palette, greyscale or alpha images (common for
    # GIF/PNG) are otherwise not in the 3-channel layout the processor expects.
    # convert() returns an independent, fully loaded copy, so it stays usable
    # after the source file is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert('RGB')

    inputs = processor.process(
        images=[image],
        text=prompt
    )
    # Move every tensor to the model's device and add a batch dimension of 1.
    inputs = {k: v.to(model.device).unsqueeze(0) for k, v in inputs.items()}

    # Autocast must target the device the model actually runs on; with
    # device_map='auto' that may be a GPU, where a "cpu" autocast context
    # would have no effect.
    with torch.autocast(device_type=model.device.type, enabled=True, dtype=torch.bfloat16):
        output = model.generate_from_batch(
            inputs,
            GenerationConfig(max_new_tokens=2000, stop_strings="<|endoftext|>"),
            tokenizer=processor.tokenizer
        )

    # Drop the prompt tokens so only the generated continuation is decoded.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    return processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

# 5. MAIN PROCESSING LOGIC

In [None]:
def _as_nested(value):
    """Recursively copy dicts loaded from JSON into nested defaultdicts."""
    if not isinstance(value, dict):
        return value
    node = nested_dict()
    for key, child in value.items():
        node[key] = _as_nested(child)
    return node


def _find_folder(results, rel_path):
    """Return the dict stored for rel_path, or None, without creating entries."""
    node = results
    if rel_path != '.':
        for part in rel_path.split(os.sep):
            # Membership test first: indexing a defaultdict would insert the key.
            if not isinstance(node, dict) or part not in node:
                return None
            node = node[part]
    return node if isinstance(node, dict) else None


def _save_results(results):
    """Write results to OUTPUT_FILE via a temp file so a crash cannot truncate it."""
    tmp_path = f"{OUTPUT_FILE}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(convert_defaultdict_to_dict(results), f, indent=4)
    os.replace(tmp_path, OUTPUT_FILE)


def main():
    """Describe every not-yet-described image under ROOT_FOLDER.

    Steps:
      1. Load previously saved descriptions from OUTPUT_FILE (a JSON object
         whose nesting mirrors the folder hierarchy; leaves map file names to
         descriptions).
      2. Walk ROOT_FOLDER, skipping IGNORE_PATTERNS, and collect files whose
         extension is in SUPPORTED_FORMATS; images already present in the
         loaded results (including images directly in ROOT_FOLDER) are
         treated as processed.
      3. Print a summary and ask for confirmation on stdin.
      4. Run process_image on each remaining image in sorted order and save
         after every success, so an interrupted run can be resumed.

    A failure on one image is reported and does not stop the run.
    Returns None.
    """
    results = nested_dict()

    # Load existing descriptions if any
    try:
        with open(OUTPUT_FILE, 'r', encoding='utf-8') as f:
            existing_results = json.load(f)
    except FileNotFoundError:
        print(f"Starting with empty results as {OUTPUT_FILE} does not exist")
    except json.JSONDecodeError:
        print(f"Starting with empty results as {OUTPUT_FILE} is empty or invalid")
    else:
        if isinstance(existing_results, dict):
            # Convert at every level: leaving nested levels as plain dicts
            # would raise KeyError later when a new subfolder is added
            # beneath an already known folder.
            results = _as_nested(existing_results)
            print(f"Loaded existing results from {OUTPUT_FILE}")
        else:
            print(f"Starting with empty results as {OUTPUT_FILE} is empty or invalid")

    # Keep track of all possible image paths
    all_image_paths = set()
    processed_images = set()

    # First pass: collect all image paths and already processed images
    for dirpath, dirnames, filenames in os.walk(ROOT_FOLDER):
        # Prune checkpoint directories in place so os.walk does not descend into them
        dirnames[:] = [d for d in dirnames if not any(pattern in d for pattern in IGNORE_PATTERNS)]

        # The relative folder and its stored entry are the same for every file here
        rel_path = os.path.relpath(dirpath, ROOT_FOLDER)
        known_files = _find_folder(results, rel_path)

        for filename in filenames:
            if not filename.lower().endswith(SUPPORTED_FORMATS):
                continue
            if any(pattern in filename for pattern in IGNORE_PATTERNS):
                continue

            full_path = os.path.join(dirpath, filename)
            all_image_paths.add(full_path)

            if known_files is not None and filename in known_files:
                processed_images.add(full_path)

    # Calculate images that need processing
    images_to_process = all_image_paths - processed_images

    # Print summary
    print(f"\nProcessing Summary:")
    print(f"Total images found: {len(all_image_paths)}")
    print(f"Already processed: {len(processed_images)}")
    print(f"Remaining to process: {len(images_to_process)}")

    # If no new images to process, exit
    if not images_to_process:
        print("\nNo new images to process. Exiting...")
        return

    # Ask for confirmation before proceeding
    proceed = input(f"\nProceed with processing {len(images_to_process)} images? (y/n): ")
    if proceed.strip().lower() != 'y':
        print("Processing cancelled by user.")
        return

    # Second pass: process only new images
    succeeded = 0
    total = len(images_to_process)

    # Sorted for a consistent, reproducible processing order
    for index, image_path in enumerate(sorted(images_to_process), start=1):
        rel_path = os.path.relpath(os.path.dirname(image_path), ROOT_FOLDER)
        filename = os.path.basename(image_path)

        print(f"\nProcessing image {index}/{total}: {image_path}")

        # Deliberately broad: this is the per-image boundary, and one bad file
        # must not abort the rest of a long batch.
        try:
            description = process_image(image_path)

            # Navigate only after success so failed images leave no empty
            # folder entries behind in the results.
            folder = results
            if rel_path != '.':
                for path_part in rel_path.split(os.sep):
                    folder = folder[path_part]
            folder[filename] = description
            print(f"✓ Successfully processed: {image_path}")

            # Save after each successful processing
            _save_results(results)
            print(f"✓ Progress saved to {OUTPUT_FILE}")

        except Exception as e:
            print(f"✕ Error processing {image_path}: {str(e)}")
            continue

        succeeded += 1

    print(f"\nProcessing complete!")
    print(f"Total images processed in this run: {succeeded}")
    if succeeded != total:
        print(f"Images that failed in this run: {total - succeeded}")
    print(f"Results saved to: {OUTPUT_FILE}")

# 6. EXECUTION

In [None]:
# Start the pipeline only when this code runs as the top-level program
# (not when it is imported as a module).
if __name__ == "__main__":
    main()