In [2]:
!pip install -q torch torchvision transformers

In [3]:
# Import necessary libraries
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Run on CUDA when torch reports it as available, otherwise on the CPU,
# then print the installed torch version
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f'Torch version: {torch.__version__}')

Torch version: 2.5.1+cu124


In [12]:
# Define transformation for image
def build_transform(input_size=224):
    """Build the image pipeline: square resize, tensor conversion, normalization.

    Args:
        input_size: Side length used for both dimensions of the resize.

    Returns:
        A ``T.Compose`` that applies ``Resize``, ``ToTensor`` and ``Normalize``
        in that order.
    """
    # NOTE(review): these mean/std values look like the usual ImageNet channel
    # statistics rather than CLIP's own -- confirm before feeding the result to CLIP.
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        T.Resize((input_size, input_size)),
        T.ToTensor(),
        T.Normalize(mean=channel_mean, std=channel_std),
    ]
    return T.Compose(steps)

# Fetch the pretrained CLIP checkpoint: one processor to prepare inputs and
# one model, which is moved onto the selected device
model_id = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)
model = model.to(device)

# Simplified stand-in for InternVL-style dynamic preprocessing: the tile grid is fixed at 1x1, so the image is only resized to a square (no segmentation)
def dynamic_preprocess(image, image_size=224, use_thumbnail=False):
    """Resize ``image`` to a square and optionally append a thumbnail.

    The tile grid is fixed at 1x1, so the image is never split: the result
    always starts with one ``image_size`` x ``image_size`` resize of the input.

    Args:
        image: Object exposing ``resize((width, height))`` (a PIL image in
            this notebook).
        image_size: Side length of every returned image.
        use_thumbnail: When true, a second ``image_size`` square resize of
            the input is appended to the result.

    Returns:
        A list holding one resized image, or two when ``use_thumbnail`` is set.
    """
    # Single 1x1 tile: the target canvas is one square of side image_size.
    tile_grid = (1, 1)
    target_size = (image_size * tile_grid[0], image_size * tile_grid[1])

    processed_images = [image.resize(target_size)]
    if use_thumbnail:
        # NOTE(review): with a 1x1 grid the thumbnail has the same size as the
        # tile above, so callers receive two equally sized resizes of the input.
        processed_images.append(image.resize((image_size, image_size)))
    return processed_images

# Function to load and preprocess the image
def load_image(image_file, input_size=224):
    """Open an image, convert it to RGB and run ``dynamic_preprocess`` on it.

    Args:
        image_file: Passed to ``Image.open``.
        input_size: Forwarded to ``dynamic_preprocess`` as ``image_size``.

    Returns:
        The list returned by ``dynamic_preprocess`` with ``use_thumbnail=True``.
    """
    # The context manager releases the underlying file as soon as the RGB copy
    # exists, instead of leaving the handle to garbage collection.
    with Image.open(image_file) as source:
        image = source.convert('RGB')
    return dynamic_preprocess(image, image_size=input_size, use_thumbnail=True)

def analyze_image(image_file):
    """Print the best-matching chart description for each preprocessed image.

    The file is loaded through ``load_image``; each returned image is scored by
    the module-level CLIP ``model`` against a fixed list of chart descriptions,
    and the description with the highest probability is printed.

    Args:
        image_file: Passed straight to ``load_image``.
    """
    images = load_image(image_file)
    prompt_texts = [
        "A graph with trends and patterns",
        "A chart showing analytical data",
        "A graph illustrating financial data",
        "A pie chart with different segments",
        "A bar chart with several columns",
        "A line chart indicating changes over time",
        "A line showing upward and downward movement in a line graph"
    ]

    # Tokenize the prompts once up front: they are identical for every image.
    # padding=True pads them to a common length so they form a single batch.
    text_inputs = processor(text=prompt_texts, return_tensors="pt", padding=True).to(device)

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        for img in images:
            image_input = processor(images=img, return_tensors="pt").to(device)
            outputs = model(**image_input, **text_inputs)

            # Softmax over the image-text similarity scores gives a
            # probability for each text description.
            probs = outputs.logits_per_image.softmax(dim=1)

            # Print out the text description with the highest probability for each processed image
            best_text_idx = probs.argmax().item()
            print(f"Best match for the image: {prompt_texts[best_text_idx]}")

# Demo run: classify one sample chart image
analyze_image("/mnt/code/test_l1.png")

Best match for the image: A line showing a line graph
Best match for the image: A line showing a line graph
