In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoConfig, AutoModelForCausalLM
from sklearn.decomposition import PCA
import seaborn as sns

torch.manual_seed(0)
np.random.seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set seaborn style
sns.set_style("whitegrid")
# Update font sizes globally
plt.rcParams.update({
    'font.size': 15,
    'axes.titlesize': 17,
    'axes.labelsize': 14,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'legend.fontsize': 13,
    'legend.title_fontsize': 14,
})

# --- configurable inputs ---

##### IMAGE 1 #####
#labels1 = [
#    "A photo of a cube right of a cone",
#    "A photo of a cube left of a cone",
#    "A photo of a cone right of a cube",
#    "A photo of a cube left of a sphere",
#    "A photo of a cylinder right of a cone",]
#correct_label1 = "A photo of a cube right of a cone"
#image_path1 = "/home/bboulbarss/large_dataset/relational/ood_val/cube_right_cone/cube right cone/CLEVR_rel_000025.png"


################################################################################################
## Image 1 in final_gradcam
## Relational
#labels = [
#    "A photo of a cylinder left of a cone",
#    "A photo of a cylinder right of a cone",
#    "A photo of a cone left of a cylinder",
#    "A photo of a cube right a cylinder",
#    "A photo of a sphere right of a cone",]
#correct_label = "A photo of a cylinder left of a cone"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000020.png"
#
# Two object
#labels = [
#    "A photo of a purple cylinder",
#    "A photo of a green cylinder",
#    "A photo of a purple cone",
#    "A photo of a red sphere",
#    "A photo of a blue cube"
#]
#correct_label = "A photo of a purple cylinder"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000020.png"

################################################################################################
# Image 2 in final_gradcam
# Relational
#labels = [
#    "A photo of a cylinder left of a sphere",
#    "A photo of a cylinder right of a sphere",
#    "A photo of a sphere left of a cylinder",
#    "A photo of a cube right a cone",
#    "A photo of a sphere right of a cone",]
#correct_label = "A photo of a cylinder left of a sphere"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_sphere/CLEVR_rel_000031.png"

# Two object
#labels = [
#    "A photo of a green sphere",
#    "A photo of a blue sphere",
#    "A photo of a green cylinder",
#    "A photo of a red cone",
#    "A photo of a purple cube",
#]
#correct_label = "A photo of a green sphere"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_sphere/CLEVR_rel_000031.png"

################################################################################################
# Image 3 in final_gradcam
# Relational
#labels = [
#    "A photo of a cone left of a cylinder",
#    "A photo of a cylinder left of a cone",
#    "A photo of a cone right of a cylinder",
#    "A photo of a cone right of a sphere",
#    "A photo of a cube left of a cylinder"
#]
#correct_label = "A photo of a cone left of a cylinder"
#image_path = "/home/bboulbarss/large_dataset/relational/ood_test/cone_left_cylinder/cone left cylinder/CLEVR_rel_000634.png"

# Two object
#labels = [
#    "A photo of a yellow cone",
#    "A photo of a blue cone",
#    "A photo of a yellow cylinder",
#    "A photo of a red sphere",
#    "A photo of a brown cube",
#]
#correct_label = "A photo of a yellow cone"
#image_path = "/home/bboulbarss/large_dataset/relational/ood_test/cone_left_cylinder/cone left cylinder/CLEVR_rel_000634.png"

################################################################################################
# Image 4 in final_gradcam
# Relational
#labels = [
#    "A photo of a cone left of a cylinder",
#    "A photo of a cone right of a cylinder",
#    "A photo of a cylinder right of a cone",
#    "A photo of a cube left of a sphere",
#    "A photo of a cone left of a sphere",
#]
#correct_label = "A photo of a cone right of a cylinder"
#image_path = "/home/bboulbarss/large_dataset/two_object/ood_test/yellow_cylinder_blue_cone/yellow cylinder/CLEVR_yellow_cylinder_blue_cone_015448.png"

# Two object
#labels=[
#    "A photo of a yellow cylinder",
#    "A photo of a yellow cube",
#    "A photo of a blue cylinder",
#    "A photo of a yellow cone",
#    "A photo of a purple cylinder"
#]
#correct_label = "A photo of a yellow cylinder"
#image_path = "/home/bboulbarss/large_dataset/two_object/ood_test/yellow_cylinder_blue_cone/yellow cylinder/CLEVR_yellow_cylinder_blue_cone_015448.png"

################################################################################################
# Image 5 in final_gradcam
# Relational
#labels = [
#    "A photo of a cone left of a cube",
#    "A photo of a cone right of a cube",
#    "A photo of a cube left of a cone",
#    "A photo of a sphere right of a cylinder",
#    "A photo of a cube left of a cylinder",
#]
#correct_label = "A photo of a cone left of a cube"
#image_path = "/home/bboulbarss/large_dataset/relational/ood_test/cone_left_cube/cone left cube/CLEVR_rel_000466.png"

# Two object
#labels=[
#    "A photo of a purple cube",
#    "A photo of a cyan cube",
#    "A photo of a purple cone",
#    "A photo of a gray sphere",
#    "A photo of a brown cylinder"
#]
#correct_label = "A photo of a purple cube"
#image_path = "/home/bboulbarss/large_dataset/relational/ood_test/cone_left_cube/cone left cube/CLEVR_rel_000466.png"

################################################################################################

# Image 5 in final_gradcam
# Relational
#labels = [
#    "A photo of a cylinder left of a cone",
#    "A photo of a cone right of a cylinder",
#    "A photo of a cylinder right of a cone",
#    "A photo of a cone left of a cylinder",
#    "A photo of a sphere left of a cylinder"
#]
#correct_label = "A photo of a cone left of a cylinder"
#image_path = "/home/bboulbarss/large_dataset/relational/ood_test/cone_left_cylinder/cone left cylinder/CLEVR_rel_000498.png"


save_dir = "/home/bboulbarss/gradcam_results/"
os.makedirs(save_dir, exist_ok=True)

# --- Load model and processor ---
config = AutoConfig.from_pretrained("AIDC-AI/Ovis2-8B", trust_remote_code=True)
config.llm_attn_implementation = "eager"
model = AutoModelForCausalLM.from_pretrained(
    "AIDC-AI/Ovis2-8B",
    config=config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="eager"
).cuda()
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()

# --- 1. Generate multimodal embeddings for PCA plot ---
embeddings = []
for label in labels:
    query = f"<image>\n{label}"
    images = [Image.open(image_path)]
    prompt, input_ids, pixel_values = model.preprocess_inputs(query, images, max_partition=9)
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(model.device)
    input_ids = input_ids.unsqueeze(0).to(model.device)
    attention_mask = attention_mask.unsqueeze(0).to(model.device)
    pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]  # Fix: Correct dtype and device
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=None,
            output_hidden_states=True,
            return_dict=True
        )
        last_hidden_state = outputs.hidden_states[-1]
        
        # Use merged attention mask if available, else create one
        if 'attention_mask' in outputs:
            merged_attention_mask = outputs['attention_mask']
        else:
            merged_attention_mask = torch.ones(last_hidden_state.shape[:2], device=last_hidden_state.device)
        
        merged_attention_mask = merged_attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        
        summed = torch.sum(last_hidden_state * merged_attention_mask, dim=1)
        count = torch.clamp(merged_attention_mask.sum(dim=1), min=1e-9)
        embedding = summed / count
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        embeddings.append(embedding.to(dtype=torch.float32).cpu().numpy())
embeddings = np.vstack(embeddings)

# --- Apply PCA ---
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(embeddings)

# --- Plot PCA ---
# Define colorblind palette for "Correct" and explicit red for "Others"
colorblind_colors = sns.color_palette("colorblind")
palette = {"original": colorblind_colors[2], "ft": (1, 0, 0)}  # Explicit red for Others

# -- Plot PCA result --
plt.figure(figsize=(12, 10))

# Separate points into "correct" and "others"
correct_indices = [i for i, label in enumerate(labels) if label == correct_label]
other_indices = [i for i, label in enumerate(labels) if label != correct_label]

# Plot "correct" points
if correct_indices:
    plt.scatter(
        embeddings_2d[correct_indices, 0], embeddings_2d[correct_indices, 1],
        c=[palette["original"]],  # Color for correct label
        s=200,  # Larger circle for correct label
        marker='o',  # Circle marker
        label="Correct Label"  # Legend entry
    )

# Plot "others" points
if other_indices:
    plt.scatter(
        embeddings_2d[other_indices, 0], embeddings_2d[other_indices, 1],
        c=[palette["ft"]],  # Pure red for other labels
        s=100,  # Smaller circle for others
        marker='o',  # Circle marker
        label="Wrong Label"  # Legend entry
    )

# Equalize axis scales to avoid distortion
plt.axis('equal')

# Compute offsets only once based on axis limits
x_min, x_max = plt.xlim()
y_min, y_max = plt.ylim()
x_offset = 0.01 * (x_max - x_min)
y_offset = 0.01 * (y_max - y_min)

for i, label in enumerate(labels):
    plt.annotate(
        label,
        (embeddings_2d[i, 0] + x_offset, embeddings_2d[i, 1] + y_offset),
        fontsize=13,
        alpha=0.8,
        ha='left',
        va='bottom',
        wrap=True
    )

# Add legend
plt.legend()


plt.title("PCA Visualization of Ovis Text Embeddings")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True)
plt.tight_layout()
plt.savefig(os.path.join(save_dir, "text_pca_ovis_plot.png"), bbox_inches="tight", dpi=500)
plt.close()

# --- 2. Generate output probabilities for bar chart ---
# Shuffle labels to match normal inference
labels = labels.copy()
#np.random.shuffle(labels)
num_labels = len(labels)

# Construct MCQ prompt
mcq_prompt = "Task: Identify the correct label for this image from the following choices:\n" + "\n".join(
    [f"{chr(65+i)}. {labels[i]}" for i in range(num_labels)]
) + "\nAnswer with the letter of the correct choice."
full_query = f"<image>\n{mcq_prompt}"

#print(labels)
#print(full_query)

# Preprocess inputs
images = [Image.open(image_path).convert("RGB")]
prompt, input_ids, pixel_values = model.preprocess_inputs(full_query, images, max_partition=9)
attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(model.device)
input_ids = input_ids.unsqueeze(0).to(model.device)
attention_mask = attention_mask.unsqueeze(0).to(model.device)
if pixel_values is not None:
    pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

# Debug input_ids
#print(f"Input IDs shape: {input_ids.shape}, device: {input_ids.device}")
#print(f"Input IDs: {input_ids[0].cpu().numpy()}")
#print(f"Attention mask shape: {attention_mask.shape}, device: {attention_mask.device}")

# Try partial decoding to avoid OverflowError
try:
    # Decode only the last few tokens to check the end of the prompt
    last_tokens = input_ids[0, -10:].cpu().numpy()  # Last 10 tokens
    decoded_last = [text_tokenizer.decode([int(token)], skip_special_tokens=False) for token in last_tokens]
    #print(f"Last 10 input tokens: {last_tokens}")
    #print(f"Decoded last 10 tokens: {decoded_last}")
except Exception as e:
    print(f"Error decoding input_ids: {e}")

# Define answer letters
answer_letters = ["A", "B", "C", "D", "E"]
answer_token_ids = []
for letter in answer_letters:
    encoded = text_tokenizer.encode(letter, add_special_tokens=False)
    if len(encoded) == 1:
        answer_token_ids.append(encoded[0])
    else:
        print(f"Warning: Letter '{letter}' encoded to multiple tokens: {encoded}")
        answer_token_ids.append(encoded[0])  # Fallback to first token
#print(f"Answer letters: {answer_letters}")
#print(f"Answer token IDs: {answer_token_ids}")
#print(f"Decoded answer tokens: {[text_tokenizer.decode([id], skip_special_tokens=True) for id in answer_token_ids]}")

# Forward pass to get logits
with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        pixel_values=pixel_values,
        attention_mask=attention_mask,
        labels=None
    )
    logits = outputs.logits[:, -1, :]  # [1, vocab_size], logits for next token
    #print(f"Logits shape: {logits.shape}")

# Compute probabilities
probs_all = torch.softmax(logits, dim=-1)[0]  # [vocab_size]
probs_all = probs_all.to(dtype=torch.float32)  # Convert to float32 to avoid bfloat16 issues
probs_for_answers = probs_all[answer_token_ids].to(dtype=torch.float32)  # Convert to float32
probs = (probs_for_answers / probs_for_answers.sum()).cpu().numpy()  # Normalize

# Print probabilities
print("Probabilities from logits:")
for letter, prob, label in zip(answer_letters, probs, labels):
    print(f"{letter} ({label}): {prob*100:.2f}%")

# Debug top logits
top_k = 5
top_token_ids = torch.topk(logits[0], k=top_k).indices.cpu().numpy()
top_probs = torch.topk(probs_all, k=top_k).values.cpu().numpy()  # Now works with float32
#print(f"Top {top_k} tokens by logit:")
#for tid, prob in zip(top_token_ids, top_probs):
    #decoded = text_tokenizer.decode([tid], skip_special_tokens=True)
    #print(f"Token ID={tid}, Decoded='{decoded}', Probability={prob*100:.2f}%")

# Plotting
sorted_indices = np.argsort(probs)[::-1]
sorted_labels = [labels[i] for i in sorted_indices]
sorted_probs = [probs[i] * 100 for i in sorted_indices]
plt.figure(figsize=(10, 7))
colors = ['green' if label == correct_label else 'red' for label in sorted_labels]
plt.bar(sorted_labels, sorted_probs, color=colors)
plt.title(f"Ovis Label Probabilities for Image")
plt.xlabel("Answer Choices")
plt.ylabel("Probability (%)", rotation=0, labelpad=40)
plt.xticks(rotation=45, ha="right")
plt.yticks(np.linspace(0, 100, 6))
legend_handles = [plt.Rectangle((0,0),1,1, color='green'), plt.Rectangle((0,0),1,1, color='red')] if correct_label in sorted_labels else [plt.Rectangle((0,0),1,1, color='red')]
legend_labels = ['Correct Label', 'Other Labels'] if correct_label in sorted_labels else ['Other Labels']
plt.legend(handles=legend_handles, labels=legend_labels, loc='upper right')
plt.tight_layout()
os.makedirs(save_dir, exist_ok=True)
plt.savefig(os.path.join(save_dir, "ovis_probabilities_bar_plot.png"), bbox_inches="tight", dpi=500)
plt.close()
print('plots saved!')

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Probabilities from logits:
A (A photo of a cylinder left of a cone): 0.00%
B (A photo of a cone right of a cylinder): 0.00%
C (A photo of a cylinder right of a cone): 98.20%
D (A photo of a cone left of a cylinder): 1.80%
E (A photo of a sphere left of a cylinder): 0.00%
plots saved!


# PCA plot

In [2]:
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.decomposition import PCA
import numpy as np

# Define base directory and set directories
base_dir = "/home/bboulbarss/large_dataset/relational"
sets = {
    "train": os.path.join(base_dir, "train"),
    "val": os.path.join(base_dir, "ood_val"),
    "test": os.path.join(base_dir, "ood_test"),
}

# Define markers for each set
markers = {"train": "o", "val": "s", "test": "^"}

# Function to get image paths based on set and directory structure
def get_image_paths(set_name, set_dir):
    image_data = []
    # Get classes, ignoring dot files
    classes = [c for c in os.listdir(set_dir) if not c.startswith('.')]
    if set_name == "train":
        for cls in classes:
            cls_dir = os.path.join(set_dir, cls)
            # Get images, ignoring dot files
            images = [img for img in os.listdir(cls_dir) if not img.startswith('.')]
            for img in images:
                path = os.path.join(cls_dir, img)
                image_data.append((path, cls, set_name))
    else:  # val or test
        for cls in classes:
            cls_dir = os.path.join(set_dir, cls)
            # Get the intermediate directory (assume there's only one, ignoring dot files)
            intermediate_dirs = [d for d in os.listdir(cls_dir) if not d.startswith('.') and os.path.isdir(os.path.join(cls_dir, d))]
            if intermediate_dirs:  # Ensure there's at least one intermediate directory
                intermediate_dir = os.path.join(cls_dir, intermediate_dirs[0])
                # Get images, ignoring dot files
                images = [img for img in os.listdir(intermediate_dir) if not img.startswith('.')]
                for img in images:
                    path = os.path.join(intermediate_dir, img)
                    image_data.append((path, cls, set_name))
    return image_data

# Collect all image data
all_image_data = []
for set_name, set_dir in sets.items():
    image_data = get_image_paths(set_name, set_dir)
    all_image_data.extend(image_data)

# Extract image paths, classes, and sets
image_paths, classes, sets_list = zip(*all_image_data)

# Initialize a list to store embeddings for each image
image_embeddings = []

# Process each image with a minimal text prompt
for img_path in image_paths:
    query = "<image>"  # Minimal text prompt to focus on visual content
    images = [Image.open(img_path)]
    prompt, input_ids, pixel_values = model.preprocess_inputs(query, images, max_partition=9)
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(model.device)
    input_ids = input_ids.unsqueeze(0).to(model.device)
    attention_mask = attention_mask.unsqueeze(0).to(model.device)
    pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=None,
            output_hidden_states=True,
            return_dict=True
        )
        last_hidden_state = outputs.hidden_states[-1]
        
        # Use merged attention mask if available, else create one
        if 'attention_mask' in outputs:
            merged_attention_mask = outputs['attention_mask']
        else:
            merged_attention_mask = torch.ones(last_hidden_state.shape[:2], device=last_hidden_state.device)
        
        merged_attention_mask = merged_attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        
        # Compute the mean embedding
        summed = torch.sum(last_hidden_state * merged_attention_mask, dim=1)
        count = torch.clamp(merged_attention_mask.sum(dim=1), min=1e-9)
        embedding = summed / count
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        image_embeddings.append(embedding.to(dtype=torch.float32).cpu().numpy())

# Stack the embeddings into a single array
image_embeddings = np.vstack(image_embeddings)

# Apply PCA to reduce embeddings to 2D
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(image_embeddings)

# Get unique classes and assign colors
unique_classes = sorted(set(classes))
# Combine tab20, tab20b, and tab20c for up to 60 distinct colors
colors = (plt.cm.tab20(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20b(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20c(np.linspace(0, 1, 20))[:, :3].tolist())
class_colors = {cls: colors[i % len(colors)] for i, cls in enumerate(unique_classes)}

# Map each class to its set
class_to_set = {cls: set_name for _, cls, set_name in all_image_data}

# Plotting
plt.figure(figsize=(10, 8))
for cls in unique_classes:
    indices = [i for i, c in enumerate(classes) if c == cls]
    x = embeddings_2d[indices, 0]
    y = embeddings_2d[indices, 1]
    color = class_colors[cls]
    set_name = class_to_set[cls]
    marker = markers[set_name]
    plt.scatter(x, y, color=color, marker=marker, s=50)

# Add legend for sets
for set_name, marker in markers.items():
    plt.scatter([], [], color='gray', marker=marker, label=set_name)
plt.legend(title="Sets")
plt.axis('equal')
plt.title("PCA Visualization of Ovis Image Embeddings\n(Colors represent classes)")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True)
plt.tight_layout()

# Save plot
image_pca_plot_path = os.path.join(save_dir, "image_pca_ovis_plot_all.png")
plt.savefig(image_pca_plot_path, bbox_inches="tight", dpi=500)
plt.close()
print(f"Image PCA plot saved to: {image_pca_plot_path}")

Image PCA plot saved to: /home/bboulbarss/pca_plots/ovis/image_pca_ovis_plot_all.png


# PCA plot, classes merged.

In [3]:
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.decomposition import PCA
import numpy as np

# Define base directory and set directories
base_dir = "/home/bboulbarss/large_dataset/relational"
sets = {
    "train": os.path.join(base_dir, "train"),
    "val": os.path.join(base_dir, "ood_val"),
    "test": os.path.join(base_dir, "ood_test"),
}

# Define markers for each set
markers = {"train": "o", "val": "s", "test": "^"}

# Function to compute canonical class name
def get_canonical_class(cls):
    parts = cls.split('_')
    if len(parts) != 3:
        raise ValueError(f"Invalid class name format: {cls}")
    shape1, relation, shape2 = parts
    if relation not in ['left', 'right']:
        raise ValueError(f"Invalid relation in class name: {cls}")
    if shape1 < shape2:
        return cls
    else:
        inverted_relation = 'right' if relation == 'left' else 'left'
        return shape2 + '_' + inverted_relation + '_' + shape1

# Function to get image paths based on set and directory structure
def get_image_paths(set_name, set_dir):
    image_data = []
    # Get classes, ignoring dot files
    classes = [c for c in os.listdir(set_dir) if not c.startswith('.')]
    if set_name == "train":
        for cls in classes:
            canonical_cls = get_canonical_class(cls)
            cls_dir = os.path.join(set_dir, cls)
            # Get images, ignoring dot files
            images = [img for img in os.listdir(cls_dir) if not img.startswith('.')]
            for img in images:
                path = os.path.join(cls_dir, img)
                image_data.append((path, canonical_cls, set_name))
    else:  # val or test
        for cls in classes:
            canonical_cls = get_canonical_class(cls)
            cls_dir = os.path.join(set_dir, cls)
            # Get the intermediate directory (assume there's only one, ignoring dot files)
            intermediate_dirs = [d for d in os.listdir(cls_dir) if not d.startswith('.') and os.path.isdir(os.path.join(cls_dir, d))]
            if intermediate_dirs:  # Ensure there's at least one intermediate directory
                intermediate_dir = os.path.join(cls_dir, intermediate_dirs[0])
                # Get images, ignoring dot files
                images = [img for img in os.listdir(intermediate_dir) if not img.startswith('.')]
                for img in images:
                    path = os.path.join(intermediate_dir, img)
                    image_data.append((path, canonical_cls, set_name))
    return image_data

# Collect all image data
all_image_data = []
for set_name, set_dir in sets.items():
    image_data = get_image_paths(set_name, set_dir)
    all_image_data.extend(image_data)

# Extract image paths, canonical classes, and sets
image_paths, classes, sets_list = zip(*all_image_data)
# Initialize a list to store embeddings for each image
image_embeddings = []

# Process each image with a minimal text prompt
for img_path in image_paths:
    query = "<image>"  # Minimal text prompt to focus on visual content
    images = [Image.open(img_path)]
    prompt, input_ids, pixel_values = model.preprocess_inputs(query, images, max_partition=9)
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(model.device)
    input_ids = input_ids.unsqueeze(0).to(model.device)
    attention_mask = attention_mask.unsqueeze(0).to(model.device)
    pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=None,
            output_hidden_states=True,
            return_dict=True
        )
        last_hidden_state = outputs.hidden_states[-1]
        
        # Use merged attention mask if available, else create one
        if 'attention_mask' in outputs:
            merged_attention_mask = outputs['attention_mask']
        else:
            merged_attention_mask = torch.ones(last_hidden_state.shape[:2], device=last_hidden_state.device)
        
        merged_attention_mask = merged_attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        
        # Compute the mean embedding
        summed = torch.sum(last_hidden_state * merged_attention_mask, dim=1)
        count = torch.clamp(merged_attention_mask.sum(dim=1), min=1e-9)
        embedding = summed / count
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        image_embeddings.append(embedding.to(dtype=torch.float32).cpu().numpy())

# Stack the embeddings into a single array
image_embeddings = np.vstack(image_embeddings)

# Apply PCA to reduce embeddings to 2D
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(image_embeddings)

# Get unique canonical classes and assign colors
unique_classes = sorted(set(classes))
# Combine tab20, tab20b, and tab20c for up to 60 distinct colors
colors = (plt.cm.tab20(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20b(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20c(np.linspace(0, 1, 20))[:, :3].tolist())
class_colors = {cls: colors[i % len(colors)] for i, cls in enumerate(unique_classes)}

# Map each canonical class to its set
class_to_set = {cls: set_name for _, cls, set_name in all_image_data}

# Plotting
plt.figure(figsize=(10, 8))
for cls in unique_classes:
    indices = [i for i, c in enumerate(classes) if c == cls]
    x = embeddings_2d[indices, 0]
    y = embeddings_2d[indices, 1]
    color = class_colors[cls]
    set_name = class_to_set[cls]
    marker = markers[set_name]
    plt.scatter(x, y, color=color, marker=marker, s=50)

# Add legend for sets
for set_name, marker in markers.items():
    plt.scatter([], [], color='gray', marker=marker, label=set_name)
plt.legend(title="Sets")
plt.axis('equal')
plt.title("PCA Visualization of Ovis Image Embeddings\n(Colors represent canonical classes)")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True)
plt.tight_layout()

image_pca_plot_path = os.path.join(save_dir, "image_pca_ovis_plot.png")
plt.savefig(image_pca_plot_path, bbox_inches="tight", dpi=500)
plt.close()
print(f"Image PCA plot saved to: {image_pca_plot_path}")

Image PCA plot saved to: /home/bboulbarss/pca_plots/ovis/image_pca_ovis_plot.png


# PCA plot, classes merged, legend

In [4]:
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.decomposition import PCA
import numpy as np

# Define base directory and set directories
base_dir = "/home/bboulbarss/large_dataset/relational"
sets = {
    "train": os.path.join(base_dir, "train"),
    "val": os.path.join(base_dir, "ood_val"),
    "test": os.path.join(base_dir, "ood_test"),
}

# Define markers for each set
markers = {"train": "o", "val": "s", "test": "^"}

# Function to compute canonical class name
def get_canonical_class(cls):
    parts = cls.split('_')
    if len(parts) != 3:
        raise ValueError(f"Invalid class name format: {cls}")
    shape1, relation, shape2 = parts
    if relation not in ['left', 'right']:
        raise ValueError(f"Invalid relation in class name: {cls}")
    if shape1 < shape2:
        return cls
    else:
        inverted_relation = 'right' if relation == 'left' else 'left'
        return shape2 + '_' + inverted_relation + '_' + shape1

# Function to get image paths based on set and directory structure
def get_image_paths(set_name, set_dir):
    image_data = []
    # Get classes, ignoring dot files
    classes = [c for c in os.listdir(set_dir) if not c.startswith('.')]
    if set_name == "train":
        for cls in classes:
            canonical_cls = get_canonical_class(cls)
            cls_dir = os.path.join(set_dir, cls)
            # Get images, ignoring dot files
            images = [img for img in os.listdir(cls_dir) if not img.startswith('.')]
            for img in images:
                path = os.path.join(cls_dir, img)
                image_data.append((path, canonical_cls, set_name))
    else:  # val or test
        for cls in classes:
            canonical_cls = get_canonical_class(cls)
            cls_dir = os.path.join(set_dir, cls)
            # Get the intermediate directory (assume there's only one, ignoring dot files)
            intermediate_dirs = [d for d in os.listdir(cls_dir) if not d.startswith('.') and os.path.isdir(os.path.join(cls_dir, d))]
            if intermediate_dirs:  # Ensure there's at least one intermediate directory
                intermediate_dir = os.path.join(cls_dir, intermediate_dirs[0])
                # Get images, ignoring dot files
                images = [img for img in os.listdir(intermediate_dir) if not img.startswith('.')]
                for img in images:
                    path = os.path.join(intermediate_dir, img)
                    image_data.append((path, canonical_cls, set_name))
    return image_data

# Collect all image data
all_image_data = []
for set_name, set_dir in sets.items():
    image_data = get_image_paths(set_name, set_dir)
    all_image_data.extend(image_data)

# Extract image paths, canonical classes, and sets
image_paths, classes, sets_list = zip(*all_image_data)

# Initialize a list to store embeddings for each image
image_embeddings = []

# Process each image with a minimal text prompt
for img_path in image_paths:
    query = "<image>"  # Minimal text prompt to focus on visual content
    images = [Image.open(img_path)]
    prompt, input_ids, pixel_values = model.preprocess_inputs(query, images, max_partition=9)
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(model.device)
    input_ids = input_ids.unsqueeze(0).to(model.device)
    attention_mask = attention_mask.unsqueeze(0).to(model.device)
    pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]
    
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=None,
            output_hidden_states=True,
            return_dict=True
        )
        last_hidden_state = outputs.hidden_states[-1]
        
        # Use merged attention mask if available, else create one
        if 'attention_mask' in outputs:
            merged_attention_mask = outputs['attention_mask']
        else:
            merged_attention_mask = torch.ones(last_hidden_state.shape[:2], device=last_hidden_state.device)
        
        merged_attention_mask = merged_attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
        
        # Compute the mean embedding
        summed = torch.sum(last_hidden_state * merged_attention_mask, dim=1)
        count = torch.clamp(merged_attention_mask.sum(dim=1), min=1e-9)
        embedding = summed / count
        embedding = embedding / embedding.norm(dim=-1, keepdim=True)
        image_embeddings.append(embedding.to(dtype=torch.float32).cpu().numpy())

# Stack the embeddings into a single array
image_embeddings = np.vstack(image_embeddings)

# Apply PCA to reduce embeddings to 2D
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(image_embeddings)

# Get unique canonical classes and assign colors
unique_classes = sorted(set(classes))
# Combine tab20, tab20b, and tab20c for up to 60 distinct colors
colors = (plt.cm.tab20(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20b(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20c(np.linspace(0, 1, 20))[:, :3].tolist())
class_colors = {cls: colors[i % len(colors)] for i, cls in enumerate(unique_classes)}

# Map each canonical class to its set
class_to_set = {cls: set_name for _, cls, set_name in all_image_data}

# Plotting
plt.figure(figsize=(12, 8))  # Slightly larger figure to accommodate two legends
for cls in unique_classes:
    indices = [i for i, c in enumerate(classes) if c == cls]
    x = embeddings_2d[indices, 0]
    y = embeddings_2d[indices, 1]
    color = class_colors[cls]
    set_name = class_to_set[cls]
    marker = markers[set_name]
    plt.scatter(x, y, color=color, marker=marker, s=50)

# Add legend for sets
for set_name, marker in markers.items():
    plt.scatter([], [], color='gray', marker=marker, label=set_name)
set_legend = plt.legend(title="Sets", loc='upper left', bbox_to_anchor=(1.02, 1.0))

# Add legend for classes
class_handles = [plt.scatter([], [], color=class_colors[cls], marker='o', label=cls) for cls in unique_classes]
class_legend = plt.legend(handles=class_handles, title="Classes", loc='upper left', bbox_to_anchor=(1.02, 0.7))

# Add both legends to the plot
plt.gca().add_artist(set_legend)
plt.axis('equal')
plt.title("PCA Visualization of Ovis Image Embeddings\n(Colors: Classes, Markers: Sets)")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True)
plt.tight_layout()

# Save plot
image_pca_plot_path = os.path.join(save_dir, "image_pca_ovis_plot_all.png")
plt.savefig(image_pca_plot_path, bbox_inches="tight", dpi=500)
plt.close()
print(f"Image PCA plot saved to: {image_pca_plot_path}")

Image PCA plot saved to: /home/bboulbarss/pca_plots/ovis/image_pca_ovis_plot_all.png


In [3]:
# OLD TEXT PCA BAR PLOT

#num_labels = len(labels)
#mcq_prompt = "Task: Identify the correct label for this image from the following choices:\n" + "\n".join(
#    [f"{chr(65+i)}. {label}" for i in range(num_labels)]
#) + "\nAnswer with the letter of the correct choice."
#full_query = f"<image>\n{mcq_prompt}"
#images = [Image.open(image_path)]
#prompt, input_ids, pixel_values = model.preprocess_inputs(full_query, images, max_partition=9)
#attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id).to(model.device)
#input_ids = input_ids.unsqueeze(0).to(model.device)
#attention_mask = attention_mask.unsqueeze(0).to(model.device)
#if pixel_values is not None:
#    pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]  # Fix: Correct dtype and device
#with torch.no_grad():
#    outputs = model(
#        input_ids=input_ids,
#        pixel_values=pixel_values,
#        attention_mask=attention_mask,
#        labels=None
#    )
#    answer_logits = outputs.logits[:, -1, :]  # Logits for the next token
#    answer_token_ids = [text_tokenizer.convert_tokens_to_ids(letter) for letter in answer_letters]
#    logits_for_answers = answer_logits[0, answer_token_ids]
#    probs = torch.softmax(logits_for_answers.to(dtype=torch.float32), dim=0).cpu().numpy()
#    print(f"Probabilities: {dict(zip(answer_letters, probs*100))}")
#
## Get token IDs for answer letters dynamically
#answer_letters = [chr(65 + i) for i in range(num_labels)]
#answer_token_ids = [text_tokenizer.convert_tokens_to_ids(letter) for letter in answer_letters]
#logits_for_answers = logits[0, answer_token_ids]
#probs = torch.softmax(logits_for_answers.to(dtype=torch.float32), dim=0).cpu().numpy()

## --- Plot bar chart ---
## Sort labels and probs by descending probability
#sorted_indices = np.argsort(probs)[::-1]
#sorted_labels = [labels[i] for i in sorted_indices]
#sorted_probs = [probs[i] * 100 for i in sorted_indices]  # convert to percentages
#
## Plotting
#plt.figure(figsize=(10, 7))
## Create color list: green for correct_label, red for others
#colors = ['green' if label == correct_label else 'red' for label in sorted_labels]
#bars = plt.bar(sorted_labels, sorted_probs, color=colors)
#
## Titles and labels
#plt.title("Ovis Label Probabilities for Image")
#plt.xlabel("Answer Choices")
#plt.ylabel("Probability (%)", rotation=0, labelpad=40)
#
## Tick formatting
#plt.xticks(rotation=45, ha="right")
#plt.yticks(np.linspace(0, 100, 6))
#
## Add legend
#legend_handles = [plt.Rectangle((0,0),1,1, color='green'), plt.Rectangle((0,0),1,1, color='red')] if correct_label in sorted_labels else [plt.Rectangle((0,0),1,1, color='red')]
#legend_labels = ['Correct Label', 'Other Labels'] if correct_label in sorted_labels else ['Other Labels']
#plt.legend(handles=legend_handles, labels=legend_labels, loc='upper right')
#
## Layout and save
#plt.tight_layout()
#plt.savefig(os.path.join(save_dir, "ovis_probabilities_bar_plot.png"), bbox_inches="tight", dpi=500)
#plt.close()