In [1]:
import os
import torch
import pandas as pd
from PIL import Image
import numpy as np
from transformers import AutoProcessor, AutoModelForCausalLM
from concurrent.futures import ThreadPoolExecutor

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f'Using device: {device}')

# Possible paths for train, val, and test directories
possible_paths = {
    "train": [
        r"../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Train",
        r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Train"
    ],
    "val": [
        r"../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Val",
        r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Val"
    ],
    "test": [
        r"../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Test",
        r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Test"
    ]
}

# Function to automatically detect and return the correct directory path
def get_data_directory(data_type):
    for path in possible_paths[data_type]:
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"None of the paths for {data_type} directory exist!")

# Get the correct paths
train_dir = get_data_directory("train")
val_dir = get_data_directory("val")
test_dir = get_data_directory("test")

# List all images in a directory
def list_images_in_dir(directory, valid_extensions=(".png", ".jpg", ".jpeg")):
    image_paths = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(valid_extensions):
                image_paths.append(os.path.join(root, file))
    return image_paths

# Load the model and processor
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

# Function to get captions for a batch of images
def get_batch_captions(image_paths):
    images = []
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB")
        image_np = np.array(image)
        images.append(image_np)

    prompt = "<CAPTION>"
    inputs = processor(text=[prompt]*len(images), images=images, return_tensors="pt", padding=True).to(device, torch_dtype)
    
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False
    )
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    
    return generated_texts


Using device: mps


Florence2LanguageForConditionalGeneration has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


In [2]:

# Function to process images from a directory and save results in chunks
def process_and_save_in_batches(directory, output_file, batch_size=4):
    image_paths = list_images_in_dir(directory)
    
    # Initialize the CSV if it doesn't exist
    if not os.path.exists(output_file):
        df = pd.DataFrame(columns=["image", "description"])
        df.to_csv(output_file, index=False)
    
    # Loop through images in batches
    for i in range(0, len(image_paths), batch_size):
        batch_paths = image_paths[i:i+batch_size]
        try:
            captions = get_batch_captions(batch_paths)
            print(f"Processed batch {i // batch_size + 1}/{(len(image_paths) + batch_size - 1) // batch_size}")
            
            # Prepare new entries to be saved
            new_entries = [{"image": path, "description": caption} for path, caption in zip(batch_paths, captions)]
            
            # Save batch to CSV
            new_df = pd.DataFrame(new_entries)
            new_df.to_csv(output_file, mode='a', header=False, index=False)
            
        except Exception as e:
            print(f"Error processing batch starting with {batch_paths[0]}: {e}")

# Define output files
train_output = 'train_image_descriptions.csv'
val_output = 'val_image_descriptions.csv'
test_output = 'test_image_descriptions.csv'

# Process each directory and save the results in different CSV files
process_and_save_in_batches(train_dir, train_output)
process_and_save_in_batches(val_dir, val_output)
process_and_save_in_batches(test_dir, test_output)


Error processing batch starting with ../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Train/Green/popsicle_1750.png: MPS backend out of memory (MPS allocated: 26.07 GB, other allocations: 1.13 GB, max allowed: 27.20 GB). Tried to allocate 162.00 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).


KeyboardInterrupt: 

In [9]:
import re
import string
import spacy

# Initialize necessary components
nlp = spacy.load("en_core_web_sm")

# List of irrelevant phrases
irrelevant_phrases = [
    "on the table", 
    "sitting on", 
    "in the background", 
    "on the floor", 
    "next to", 
    "on top of",
    "on top of a table",
    "top of a table"
]

def clean_generated_text(text):
    # Step 1: Remove special tokens
    text = text.replace('<PAD>', '').replace('<SEP>', '').replace('<CLS>', '').strip()
    
    # Step 2: Convert to lowercase
    text = text.lower()
    
    # Step 3: Remove punctuation
    text = text.translate(str.maketrans('', '', string.punctuation))
    
    # Step 4: Remove numbers
    text = re.sub(r'\d+', '', text)
    
    # Step 5: Remove extra spaces
    text = " ".join(text.split())
    
    # Step 6: Remove irrelevant phrases
    for phrase in irrelevant_phrases:
        if phrase in text:
            text = text.replace(phrase, "")
    
    # Step 7: Remove location entities (NER filtering)
    doc = nlp(text)
    print(doc)
    filtered_text = " ".join([token.text for token in doc if token.ent_type_ not in ['LOC', 'GPE']])
    
    # Step 8: Final cleanup of extra spaces after removal of phrases
    filtered_text = " ".join(filtered_text.split())
    
    return filtered_text.strip()

# Example usage
raw_caption = "A can of beer sitting on top of a table."
clean_caption = clean_generated_text(raw_caption)
print(f"Cleaned caption: '{clean_caption}'")


a can of beer  
Cleaned caption: 'a can of beer'


In [13]:
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
print(f'Using device: {device}')

# Function to get captions for a batch of images
def get_batch_captions(image_paths):
    images = []
    for image_path in image_paths:
        image = Image.open(image_path).convert("RGB")
        image_np = np.array(image)
        images.append(image_np)

    prompt = "<CAPTION>"
    inputs = processor(text=[prompt]*len(images), images=images, return_tensors="pt", padding=True).to(device, torch_dtype)
    
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False
    )
    generated_texts = processor.batch_decode(generated_ids, skip_special_tokens=True)
    
    return generated_texts

# Possible paths for train, val, and test directories
possible_paths = {
    "train": [
        r"../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Train",
        r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Train"
    ],
    "val": [
        r"../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Val",
        r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Val"
    ],
    "test": [
        r"../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Test",
        r"/work/TALC/enel645_2024f/garbage_data/CVPR_2024_dataset_Test"
    ]
}

# Function to automatically detect and return the correct directory path
def get_data_directory(data_type):
    for path in possible_paths[data_type]:
        if os.path.exists(path):
            return path
    raise FileNotFoundError(f"None of the paths for {data_type} directory exist!")

# Get the correct paths
train_dir = get_data_directory("train")
val_dir = get_data_directory("val")
test_dir = get_data_directory("test")

# List all images in a directory
def list_images_in_dir(directory, valid_extensions=(".png", ".jpg", ".jpeg")):
    image_paths = []
    for root, _, files in os.walk(directory):
        for file in files:
            if file.lower().endswith(valid_extensions):
                image_paths.append(os.path.join(root, file))
    return image_paths

image_paths = list_images_in_dir(train_dir)


image_path = image_paths[:1]
print(image_path)

raw_caption = get_batch_captions(image_path)
print(raw_caption)
clean_caption = clean_generated_text(raw_caption)
print(f"Cleaned Caption: {clean_caption}")

Using device: mps
['../../data/enel645_2024f/garbage_data/CVPR_2024_dataset_Train/Green/popsicle_1750.png']


KeyboardInterrupt: 