In [1]:
# !pip3.11 install huggingface_hub --break-system-packages
# !pip3.11 install git+https://github.com/huggingface/transformers accelerate --break-system-packages
# !pip3.11 install qwen-vl-utils --break-system-packages
# !pip3.11 install --pre torch==2.6.0 torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
# !pip3 install -U flash-attn --no-build-isolation --break-system-packages #- not used for macbooks

In [2]:
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# Set the device to MPS if available, else fallback to CPU
device = "mps" if torch.backends.mps.is_available() else "cpu"
torch.mps.empty_cache()

# Load model (using available device(s))
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", 
    torch_dtype="auto", 
    device_map="auto",
    # attn_implementation="flash_attention_2", # not avaialble for macbooks
)

# Move the model explicitly to the selected device
model.to(device)

# Load the processor for both text and vision modalities
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:12<00:00,  6.14s/it]


# image 2 text inference

In [3]:

# Define messages with an image and a text request
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Prepare the input text using the processor's chat template
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Process the vision information (images and/or videos)
image_inputs, video_inputs = process_vision_info(messages)

# Prepare the full inputs for the model
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)

# Move the inputs to the selected device
inputs = inputs.to(device)

# Run inference (generate a response)
# generated_ids = model.generate(**inputs, max_new_tokens=128)

# # Trim the generated IDs to exclude the prompt part
# generated_ids_trimmed = [
#     out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
# ]

# Decode the generated tokens into text
# output_text = processor.batch_decode(
#     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
# )

# print(output_text)



# embedding extractions

In [4]:
def extract_image_embeddings(image_path, question="", min_pixels=350000, max_pixels=500000):
    # Create messages format similar to previous example
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image_path,
                    "max_pixels": max_pixels,
                    "min_pixels": min_pixels,
                },
                {"type": "text", "text": question},
            ],
        }
    ]

    # Process the vision information (images and/or videos)
    image_inputs, video_inputs = process_vision_info(messages)

    # Prepare the inputs
    inputs = processor(
        text=[""],  # Empty text since we only need image embeddings
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    
    # Move inputs to the correct device
    inputs = inputs.to(device)
    pixel_values = inputs["pixel_values"].type(torch.bfloat16)
    
    # Extract the visual embeddings using the visual component
    with torch.no_grad():
        # Access the visual component directly
        vision_model = model.visual
        image_embeds = vision_model(pixel_values, grid_thw=inputs["image_grid_thw"])
    
    return inputs["image_grid_thw"], image_embeds

# Example usage
image_path = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
grid_thw, img_embeddings = extract_image_embeddings(image_path)
print(f"Grid THW: {grid_thw}")
# print(f"Embedding shape: {embeddings.last_hidden_state.shape}")

Grid THW: tensor([[ 1, 40, 60]], device='mps:0')


In [5]:
img_embeddings

tensor([[-4.9688, -2.0312,  1.2578,  ..., -2.0625,  0.9414,  0.1787],
        [ 1.2656, -3.2656,  1.5156,  ..., -2.3906,  1.5469, -0.2188],
        [ 1.2031,  0.0474,  1.0078,  ..., -1.4141,  1.6328,  0.2012],
        ...,
        [-1.4844, -1.2266, -0.5586,  ...,  1.3750, -0.2012, -0.9219],
        [ 0.0425, -0.9922, -1.8984,  ...,  1.3203,  0.1289, -1.8516],
        [ 0.3945, -4.0312,  0.1670,  ...,  0.5117,  1.3672, -1.5391]],
       device='mps:0', dtype=torch.bfloat16)

# text embedding extraction

In [6]:
# Get all the methods from model
def extract_text_embeddings(text):
    # Prepare the input text using the processor
    inputs = processor(
        text=[text],
        padding=True,
        return_tensors="pt"
    )
    
    # Move inputs to the correct device
    inputs = inputs.to(device)
    
    # Extract the text embeddings using the decoder model
    with torch.no_grad():
        # Access the decoder model directly
        decoder_model = model.model
        outputs = decoder_model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )
        
    return outputs.last_hidden_state

# Example usage
text = """The image depicts a serene beach scene with a person and a dog. The person is sitting on the sandy beach, facing the ocean. They are wearing a plaid shirt and black pants, and they have long hair. The dog, which appears to be a Labrador Retriever, is sitting on the sand and is interacting with the person by placing its paw on their hand. The dog is wearing a harness with a colorful collar. The background shows the ocean with gentle waves and a clear sky, suggesting it might be early morning or late afternoon due to the soft lighting. The overall atmosphere of the image is peaceful and joyful"""

text_embeddings = extract_text_embeddings(text)
print(f"Text embedding shape: {text_embeddings.shape}")

Text embedding shape: torch.Size([1, 125, 2048])


In [7]:
text_embeddings

tensor([[[ 1.4922, -1.7578, -2.3906,  ..., -1.1328,  9.9375, -0.1572],
         [ 4.2188, -0.9023, -2.0938,  ..., -1.0938,  2.9219, -1.9062],
         [-0.1592, -1.4531, -2.2031,  ..., -0.0221,  6.3438, -0.4980],
         ...,
         [ 0.4727, -4.0938,  2.3594,  ...,  1.3516,  7.1875, -3.1406],
         [ 6.3125, -0.9336, -3.6562,  ..., -1.9219,  8.9375, -1.9297],
         [-0.0109, -4.0000,  1.8984,  ..., -0.2676,  4.0938, -2.2344]]],
       device='mps:0', dtype=torch.bfloat16)

# compare text and image embeddings

In [8]:
# torch multiply both embeddings
print(text_embeddings.shape)
print(img_embeddings.shape)

torch.Size([1, 125, 2048])
torch.Size([600, 2048])


In [9]:
import torch.nn.functional as F

# Assuming text_embeddings has shape [1, 125, 2048] and img_embeddings has shape [600, 2048]

# Aggregate text embeddings by mean pooling along the token dimension (dim=1)
pooled_text_embedding = text_embeddings.mean(dim=1)  # Resulting shape: [1, 2048]

# Aggregate image embeddings by mean pooling along the token dimension (dim=0)
pooled_image_embedding = img_embeddings.mean(dim=0, keepdim=True)  # Resulting shape: [1, 2048]

# Normalize both embeddings to unit vectors
norm_text_embedding = F.normalize(pooled_text_embedding, p=2, dim=-1)
norm_image_embedding = F.normalize(pooled_image_embedding, p=2, dim=-1)

# Compute cosine similarity between the text and image embeddings
cosine_similarity = (norm_text_embedding * norm_image_embedding).sum(dim=-1)
print("Cosine similarity:", cosine_similarity.item())


Cosine similarity: 0.07763671875


# image to image comparison

In [10]:
# Calculate img1_embeddings and img2_embeddings from ./sim# Calculate embeddings for both images
img1_path = "./similarity_samples/corsa1.png"
img2_path = "./similarity_samples/corsa2.webp"

# Get embeddings for first image
grid_thw1, img1_embeddings = extract_image_embeddings(img1_path)

# Get embeddings for second image
grid_thw2, img2_embeddings = extract_image_embeddings(img2_path)


In [11]:
import torch.nn.functional as F

pooled_img1_embedding = img1_embeddings.mean(dim=0, keepdim=True)
pooled_img2_embedding = img2_embeddings.mean(dim=0, keepdim=True)

norm_img1_embedding = F.normalize(pooled_img1_embedding, p=2, dim=-1)
norm_img2_embedding = F.normalize(pooled_img2_embedding, p=2, dim=-1)


cosine_similarity = (norm_img1_embedding * norm_img2_embedding).sum(dim=-1)
print("Cosine similarity:", cosine_similarity.item())

Cosine similarity: 0.9140625


# 3d Viz

In [12]:
# !pip3.11 install umap-learn plotly --break-system-packages
# !pip3.11 install nbformat --break-system-packages

In [16]:

import os
import numpy as np
from umap import UMAP
import plotly.graph_objects as go
import torch
import torch.nn.functional as F  # Ensure F is imported

# --- Configuration ---
# Directory containing image files
image_dir = "./similarity_samples"
# Only process files with these extensions
valid_extensions = ('.png', '.jpg', '.jpeg', '.webp')

# --- Get list of image files ---
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(valid_extensions)]
if not image_files:
    raise ValueError(f"No image files found in {image_dir} with extensions {valid_extensions}")

# --- Process each image ---
all_embeddings = []
file_names = []

for image_file in image_files:
    image_path = os.path.join(image_dir, image_file)
    try:
        # extract_image_embeddings is assumed to be defined in your environment
        grid_thw, embeddings = extract_image_embeddings(image_path)
        # Pool the token-level embeddings by taking the mean over the token dimension (no extra dim)
        pooled_embedding = embeddings.mean(dim=0)  # shape: [2048]
        normalized_embedding = F.normalize(pooled_embedding, p=2, dim=-1).cpu().float().numpy()
        all_embeddings.append(normalized_embedding)
        file_names.append(image_file)
        print(f"Processed {image_file}")
    except Exception as e:
        print(f"Error processing {image_file}: {e}")

if len(all_embeddings) == 0:
    raise ValueError("No embeddings were extracted. Check your extract_image_embeddings function and inputs.")


Processed tortoise1.jpeg
Processed opra1.jpg
Processed beyonce1.jpg
Processed llama_faces.webp
Processed sam2.jpeg
Processed elon_suit2.jpg
Processed German-Shepherd-dog-Alsatian.webp
Processed corsa2.webp
Processed opra2.jpg
Processed dude_white_tshirt2.jpeg
Processed corsa1.png
Processed sam_pre_surgery.webp
Processed elon_suit.jpg
Processed dude_w_jacket.jpg
Processed llama.webp
Processed dude_white_tshirt.jpg
Processed black_ferrari.webp
Processed mark_lizzard.jpg
Processed elon_bald.webp
Processed fordka1.jpg
Processed deude_green_tshirt2.jpg
Processed tortoise2.jpeg
Processed dude_green_tshirt.jpg
Processed mark1.jpg


In [17]:
all_embeddings = np.array(all_embeddings)
print("Embeddings array shape:", all_embeddings.shape)

# --- UMAP Projection in 3D ---
umap_3d = UMAP(n_components=3, random_state=42)
embeddings_3d = umap_3d.fit_transform(all_embeddings)
print("UMAP 3D projection shape:", embeddings_3d.shape)

# --- Create a 3D Scatter Plot using Plotly ---
fig = go.Figure(data=[go.Scatter3d(
    x=embeddings_3d[:, 0],
    y=embeddings_3d[:, 1],
    z=embeddings_3d[:, 2],
    mode='markers+text',
    text=file_names,
    hovertext=file_names,
    marker=dict(
        size=10,
        color=np.arange(len(file_names)),
        colorscale='Viridis',
        opacity=0.8
    )
)])

fig.update_layout(
    title='3D UMAP Projection of Image Embeddings',
    scene=dict(
        xaxis_title='UMAP Dimension 1',
        yaxis_title='UMAP Dimension 2',
        zaxis_title='UMAP Dimension 3'
    ),
    width=900,
    height=900,
    showlegend=False
)

# Show the plot (use renderer="browser" if needed)
fig.show()

Embeddings array shape: (24, 2048)
UMAP 3D projection shape: (24, 3)



n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.

