In [None]:
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import FlavaProcessor, FlavaForPreTraining
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import cosine_similarity, cosine_distances
from PIL import Image
import seaborn as sns

torch.manual_seed(0)
np.random.seed(0)
torch.cuda.manual_seed_all(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# Set seaborn style
sns.set_style("whitegrid")
# Update font sizes globally
plt.rcParams.update({
    'font.size': 15,
    'axes.titlesize': 17,
    'axes.labelsize': 14,
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'legend.fontsize': 13,
    'legend.title_fontsize': 14,
})

# -- Prediction function --
def flava_predict_label(image_path, texts, model, processor, device):
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')

    # Replicate the image to match the number of text inputs (required for FLAVA)
    images = [image] * len(texts)

    # Process the inputs
    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        max_length=77,
        return_codebook_pixels=True,
        return_image_mask=True,
        return_attention_mask=True
    ).to(device)

    # Get model outputs without gradient tracking for efficiency
    with torch.no_grad():
        output = model(**inputs)

    # Extract the image-text similarity scores
    logits_per_image = output.contrastive_logits_per_image
    # Convert logits to probabilities
    probs = logits_per_image.softmax(dim=1)[0].unsqueeze(0)

    return probs.cpu().numpy()

# -- User-configurable labels --

##### IMAGE 1 #####
#labels = [
#    "A photo of a cube right of a cone",
#    "A photo of a cube left of a cone",
#    "A photo of a cone right of a cube",
#    "A photo of a cube left of a sphere",
#    "A photo of a cylinder right of a cone",]
#correct_label = "A photo of a cube right of a cone"
#image_path = "/home/bboulbarss/large_dataset/relational/ood_val/cube_right_cone/cube right cone/CLEVR_rel_000025.png"


################################################################################################
## Image 1 in final_gradcam
## Relational
labels = [
    "A photo of a cylinder left of a cone",
    "A photo of a cylinder right of a cone",
    "A photo of a cone left of a cylinder",
    "A photo of a cube right of a cylinder",
    "A photo of a sphere right of a cone",]
correct_label = "A photo of a cylinder left of a cone"
image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000020.png"
#
# Two object
#labels = [
#    "A photo of a purple cylinder",
#    "A photo of a green cylinder",
#    "A photo of a purple cone",
#    "A photo of a red sphere",
#    "A photo of a blue cube"
#]
#correct_label = "A photo of a purple cylinder"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_cone/CLEVR_rel_000020.png"

################################################################################################
# Image 2 in final_gradcam
# Relational
#labels = [
#    "A photo of a cylinder left of a sphere",
#    "A photo of a cylinder right of a sphere",
#    "A photo of a sphere left of a cylinder",
#    "A photo of a cube right a cone",
#    "A photo of a sphere right of a cone",]
#correct_label = "A photo of a cylinder left of a sphere"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_sphere/CLEVR_rel_000031.png"
#
# Two object
#labels = [
#    "A photo of a green sphere",
#    "A photo of a blue sphere",
#    "A photo of a green cylinder",
#    "A photo of a red cone",
#    "A photo of a purple cube",
#]
#correct_label = "A photo of a green sphere"
#image_path = "/home/bboulbarss/large_dataset/relational/train/cylinder_left_sphere/CLEVR_rel_000031.png"

################################################################################################



# Directory to save PCA plot
save_dir = '/home/bboulbarss/pca_plots/flava'
os.makedirs(save_dir, exist_ok=True)
plot_path = os.path.join(save_dir, 'text_pca_flava_plot.png')

# -- Load FLAVA model and processor --
model_name = 'facebook/flava-full'
model = FlavaForPreTraining.from_pretrained(model_name)
processor = FlavaProcessor.from_pretrained(model_name)
model.eval()  # Set model to evaluation mode

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# -- Load and preprocess the image --
try:
    image = Image.open(image_path).convert('RGB')
except Exception as e:
    raise ValueError(f"Failed to load or process image at {image_path}: {str(e)}")

# -- Create a list of images (same image for each label) --
images = [image] * len(labels)

# -- Prepare text and image inputs for embeddings --
inputs = processor(text=labels, images=images, return_tensors='pt', padding=True, max_length=77, return_codebook_pixels=True, return_image_mask=True, return_attention_mask=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# -- Compute embeddings --
with torch.no_grad():
    outputs = model(**inputs)
    # Extract the [CLS] token from text embeddings as sentence representation
    text_embeddings = outputs.text_embeddings[:, 0, :]

# -- Normalize embeddings for cosine-based analysis --
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
embeddings = text_embeddings.cpu().numpy()

# -- Compute predictions using flava_predict_label --
probs = flava_predict_label(image_path, labels, model, processor, device)
#print("Prediction Probabilities:")
#for label, prob in zip(labels, probs[0]):
#    print(f"{label}: {prob:.4f}")

# sim_matrix = cosine_similarity(embeddings)
# dist_matrix = cosine_distances(embeddings)
# print("\nCosine Similarity Matrix:")
# print(np.round(sim_matrix, 4))
# print("\nCosine Distance Matrix:")
# print(np.round(dist_matrix, 4))

# -- PCA projection to 2 dimensions --
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(embeddings)

# -- Plot PCA result --

# Define colorblind palette for "Correct" and explicit red for "Others"
colorblind_colors = sns.color_palette("colorblind")
palette = {"original": colorblind_colors[2], "ft": (1, 0, 0)}  # Explicit red for Others

# -- Plot PCA result --
plt.figure(figsize=(12, 10))

# Separate points into "correct" and "others"
correct_indices = [i for i, label in enumerate(labels) if label ==  correct_label]
other_indices = [i for i, label in enumerate(labels) if label !=  correct_label]

# Plot "correct" points
if correct_indices:
    plt.scatter(
        embeddings_2d[correct_indices, 0], embeddings_2d[correct_indices, 1],
        c=[palette["original"]],  # Color for correct label
        s=200,  # Larger circle for correct label
        marker='o',  # Circle marker
        label="Correct Label"  # Legend entry
    )

# Plot "others" points
if other_indices:
    plt.scatter(
        embeddings_2d[other_indices, 0], embeddings_2d[other_indices, 1],
        c=[palette["ft"]],  # Pure red for other labels
        s=100,  # Smaller circle for others
        marker='o',  # Circle marker
        label="Wrong Label"  # Legend entry
    )

# Equalize axis scales to avoid distortion
plt.axis('equal')

# Compute offsets only once based on axis limits
x_min, x_max = plt.xlim()
y_min, y_max = plt.ylim()
x_offset = 0.01 * (x_max - x_min)
y_offset = 0.01 * (y_max - y_min)

for i, label in enumerate(labels):
    plt.annotate(
        label,
        (embeddings_2d[i, 0] + x_offset, embeddings_2d[i, 1] + y_offset),
        fontsize=13,
        alpha=0.8,
        ha='left',
        va='bottom',
        wrap=True
    )

# Add legend
plt.legend()

plt.title("PCA Visualization of FLAVA Text Embeddings")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True)
plt.tight_layout()

# -- Save and close plot --
plt.savefig(plot_path, bbox_inches='tight', dpi=500)
plt.close()

print(f"PCA plot saved to: {plot_path}")



# -- Create bar plot for probabilities --
# Sort labels and probabilities in descending order
probs = probs[0]  # assuming shape (1, N)
sorted_indices = np.argsort(probs)[::-1]
sorted_labels = [labels[i] for i in sorted_indices]
sorted_probs = [probs[i] * 100 for i in sorted_indices]  # convert to percentages

# Plot
plt.figure(figsize=(10, 7))
# Create color list: green for correct_label, red for others
colors = ['green' if label == correct_label else 'red' for label in sorted_labels]
bars = plt.bar(sorted_labels, sorted_probs, color=colors)

# Labels and title
plt.xlabel("Labels")
plt.ylabel("Probability (%)", rotation=0, labelpad=40)
plt.title("FLAVA Label Probabilities for Image")

# Ticks
plt.xticks(rotation=45, ha="right")
plt.yticks(np.linspace(0, 100, 6))

# Add legend
legend_handles = [plt.Rectangle((0,0),1,1, color='green'), plt.Rectangle((0,0),1,1, color='red')] if correct_label in sorted_labels else [plt.Rectangle((0,0),1,1, color='red')]
legend_labels = ['Correct Label', 'Wrong Labels'] if correct_label in sorted_labels else ['Other Labels']
plt.legend(handles=legend_handles, labels=legend_labels, loc='upper right')

# Layout and save
plt.tight_layout()
bar_plot_path = os.path.join(save_dir, 'flava_probabilities_bar_plot.png')
plt.savefig(bar_plot_path, bbox_inches='tight', dpi=500)
plt.close()

print(f"Bar plot saved to: {bar_plot_path}")


`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if you are doing inference on unmasked text...
`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if you are doing inference on unmasked text...


PCA plot saved to: /home/bboulbarss/pca_plots/flava/text_pca_flava_plot.png
Bar plot saved to: /home/bboulbarss/pca_plots/flava/flava_probabilities_bar_plot.png


# PCA plot

In [2]:
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.decomposition import PCA
import numpy as np

# Define base directory and set directories
base_dir = "/home/bboulbarss/large_dataset/relational"
sets = {
    "train": os.path.join(base_dir, "train"),
    "val": os.path.join(base_dir, "ood_val"),
    "test": os.path.join(base_dir, "ood_test"),
}

# Define markers for each set
markers = {"train": "o", "val": "s", "test": "^"}

# Function to get image paths based on set and directory structure
def get_image_paths(set_name, set_dir):
    image_data = []
    # Get classes, ignoring dot files
    classes = [c for c in os.listdir(set_dir) if not c.startswith('.')]
    if set_name == "train":
        for cls in classes:
            cls_dir = os.path.join(set_dir, cls)
            # Get images, ignoring dot files
            images = [img for img in os.listdir(cls_dir) if not img.startswith('.')]
            for img in images:
                path = os.path.join(cls_dir, img)
                image_data.append((path, cls, set_name))
    else:  # val or test
        for cls in classes:
            cls_dir = os.path.join(set_dir, cls)
            # Get the intermediate directory (assume there's only one, ignoring dot files)
            intermediate_dirs = [d for d in os.listdir(cls_dir) if not d.startswith('.') and os.path.isdir(os.path.join(cls_dir, d))]
            if intermediate_dirs:  # Ensure there's at least one intermediate directory
                intermediate_dir = os.path.join(cls_dir, intermediate_dirs[0])
                # Get images, ignoring dot files
                images = [img for img in os.listdir(intermediate_dir) if not img.startswith('.')]
                for img in images:
                    path = os.path.join(intermediate_dir, img)
                    image_data.append((path, cls, set_name))
    return image_data

# Collect all image data
all_image_data = []
for set_name, set_dir in sets.items():
    image_data = get_image_paths(set_name, set_dir)
    all_image_data.extend(image_data)

# Extract image paths, classes, and sets
image_paths, classes, sets_list = zip(*all_image_data)

# Load images
images = [Image.open(path).convert('RGB') for path in image_paths]

# Set dummy texts
texts = [""] * len(images)

# Process inputs
inputs = processor(text=texts, images=images, return_tensors='pt', padding=True, max_length=77, return_codebook_pixels=True, return_image_mask=True, return_attention_mask=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get model outputs
with torch.no_grad():
    outputs = model(**inputs)

# Extract image embeddings (assuming [CLS] token is at position 0)
image_embeddings = outputs.image_embeddings[:, 0, :]

# Normalize embeddings
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
embeddings = image_embeddings.cpu().numpy()

# Apply PCA
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(embeddings)

# Get unique classes and assign colors
unique_classes = sorted(set(classes))
# Combine tab20, tab20b, and tab20c for up to 60 distinct colors
colors = (plt.cm.tab20(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20b(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20c(np.linspace(0, 1, 20))[:, :3].tolist())
class_colors = {cls: colors[i % len(colors)] for i, cls in enumerate(unique_classes)}

# Map each class to its set
class_to_set = {cls: set_name for _, cls, set_name in all_image_data}

# Plotting
plt.figure(figsize=(10, 8))
for cls in unique_classes:
    indices = [i for i, c in enumerate(classes) if c == cls]
    x = embeddings_2d[indices, 0]
    y = embeddings_2d[indices, 1]
    color = class_colors[cls]
    set_name = class_to_set[cls]
    marker = markers[set_name]
    plt.scatter(x, y, color=color, marker=marker, s=25)

# Add legend for sets
for set_name, marker in markers.items():
    plt.scatter([], [], color='gray', marker=marker, label=set_name)
plt.legend(title="Sets")
plt.title("PCA Visualization of FLAVA Image Embeddings\n(Colors represent classes)")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True)
plt.tight_layout()

# Save plot
image_pca_plot_path = os.path.join(save_dir, "image_pca_flava_plot_all.png")
plt.savefig(image_pca_plot_path, bbox_inches="tight", dpi=500)
plt.close()
print(f"Image PCA plot saved to: {image_pca_plot_path}")

`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if you are doing inference on unmasked text...


Image PCA plot saved to: /home/bboulbarss/pca_plots/flava/image_pca_flava_plot_all.png


# PCA plot, classes merged, legend

In [3]:
import os
from PIL import Image
import torch
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from sklearn.decomposition import PCA
import numpy as np

# Define base directory and set directories
base_dir = "/home/bboulbarss/large_dataset/relational"
sets = {
    "train": os.path.join(base_dir, "train"),
    "val": os.path.join(base_dir, "ood_val"),
    "test": os.path.join(base_dir, "ood_test"),
}

# Define markers for each set
markers = {"train": "o", "val": "s", "test": "^"}

# Function to compute canonical class name
def get_canonical_class(cls):
    parts = cls.split('_')
    if len(parts) != 3:
        raise ValueError(f"Invalid class name format: {cls}")
    shape1, relation, shape2 = parts
    if relation not in ['left', 'right']:
        raise ValueError(f"Invalid relation in class name: {cls}")
    if shape1 < shape2:
        return cls
    else:
        inverted_relation = 'right' if relation == 'left' else 'left'
        return shape2 + '_' + inverted_relation + '_' + shape1

# Function to get image paths based on set and directory structure
def get_image_paths(set_name, set_dir):
    image_data = []
    # Get classes, ignoring dot files
    classes = [c for c in os.listdir(set_dir) if not c.startswith('.')]
    if set_name == "train":
        for cls in classes:
            canonical_cls = get_canonical_class(cls)
            cls_dir = os.path.join(set_dir, cls)
            # Get images, ignoring dot files
            images = [img for img in os.listdir(cls_dir) if not img.startswith('.')]
            for img in images:
                path = os.path.join(cls_dir, img)
                image_data.append((path, canonical_cls, set_name))
    else:  # val or test
        for cls in classes:
            canonical_cls = get_canonical_class(cls)
            cls_dir = os.path.join(set_dir, cls)
            # Get the intermediate directory (assume there's only one, ignoring dot files)
            intermediate_dirs = [d for d in os.listdir(cls_dir) if not d.startswith('.') and os.path.isdir(os.path.join(cls_dir, d))]
            if intermediate_dirs:  # Ensure there's at least one intermediate directory
                intermediate_dir = os.path.join(cls_dir, intermediate_dirs[0])
                # Get images, ignoring dot files
                images = [img for img in os.listdir(intermediate_dir) if not img.startswith('.')]
                for img in images:
                    path = os.path.join(intermediate_dir, img)
                    image_data.append((path, canonical_cls, set_name))
    return image_data

# Collect all image data
all_image_data = []
for set_name, set_dir in sets.items():
    image_data = get_image_paths(set_name, set_dir)
    all_image_data.extend(image_data)

# Extract image paths, canonical classes, and sets
image_paths, classes, sets_list = zip(*all_image_data)

# Load images
images = [Image.open(path).convert('RGB') for path in image_paths]

# Set dummy texts
texts = [""] * len(images)

# Process inputs
inputs = processor(text=texts, images=images, return_tensors='pt', padding=True, max_length=77, return_codebook_pixels=True, return_image_mask=True, return_attention_mask=True)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get model outputs
with torch.no_grad():
    outputs = model(**inputs)

# Extract image embeddings (assuming [CLS] token is at position 0)
image_embeddings = outputs.image_embeddings[:, 0, :]

# Normalize embeddings
image_embeddings = image_embeddings / image_embeddings.norm(dim=-1, keepdim=True)
embeddings = image_embeddings.cpu().numpy()

# Apply PCA
pca = PCA(n_components=2, random_state=42)
embeddings_2d = pca.fit_transform(embeddings)

# Get unique canonical classes and assign colors
unique_classes = sorted(set(classes))
# Combine tab20, tab20b, and tab20c for up to 60 distinct colors
colors = (plt.cm.tab20(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20b(np.linspace(0, 1, 20))[:, :3].tolist() + 
          plt.cm.tab20c(np.linspace(0, 1, 20))[:, :3].tolist())
class_colors = {cls: colors[i % len(colors)] for i, cls in enumerate(unique_classes)}

# Map each canonical class to its set
class_to_set = {cls: set_name for _, cls, set_name in all_image_data}

# Plotting
plt.figure(figsize=(12, 8))  # Slightly larger figure to accommodate two legends
for cls in unique_classes:
    indices = [i for i, c in enumerate(classes) if c == cls]
    x = embeddings_2d[indices, 0]
    y = embeddings_2d[indices, 1]
    color = class_colors[cls]
    set_name = class_to_set[cls]
    marker = markers[set_name]
    plt.scatter(x, y, color=color, marker=marker, s=25)

# Add legend for sets
for set_name, marker in markers.items():
    plt.scatter([], [], color='gray', marker=marker, label=set_name)
set_legend = plt.legend(title="Sets", loc='upper left', bbox_to_anchor=(1.02, 1.0))

# Add legend for classes
class_handles = [plt.scatter([], [], color=class_colors[cls], marker='o', label=cls) for cls in unique_classes]
class_legend = plt.legend(handles=class_handles, title="Classes", loc='upper left', bbox_to_anchor=(1.02, 0.7))

# Add both legends to the plot
plt.gca().add_artist(set_legend)
plt.axis('equal')
plt.title("PCA Visualization of FLAVA Image Embeddings\n(Colors: Classes, Markers: Sets)")
plt.xlabel("Principal Component 1")
plt.ylabel("Principal Component 2")
plt.grid(True)
plt.tight_layout()

# Save plot
image_pca_plot_path = os.path.join(save_dir, "image_pca_flava_plot_all.png")
plt.savefig(image_pca_plot_path, bbox_inches="tight", dpi=500)
plt.close()
print(f"Image PCA plot saved to: {image_pca_plot_path}")

`input_ids_masked` isn't passed which means MLM loss won't be calculated correctlySetting it to `input_ids` so that model can work. Please pass it if this is unintentional. This is usually OKAY if you are doing inference on unmasked text...


OutOfMemoryError: CUDA out of memory. Tried to allocate 13.04 GiB. GPU 0 has a total capacity of 93.11 GiB of which 5.54 GiB is free. Process 2980266 has 1.96 GiB memory in use. Process 2996152 has 1.84 GiB memory in use. Including non-PyTorch memory, this process has 83.74 GiB memory in use. Of the allocated memory 67.00 GiB is allocated by PyTorch, and 16.07 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)