# Project Phase 2: Video Dialog 

## Imports

In [1]:
import json
from pprint import pprint

#Open Search
from opensearchpy import OpenSearch

#Embeddings neighborhood
import torch

#Contextual embeddings and self-attention
import numpy as np

# Get the interactive Tools for Matplotlib
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor, LlavaForConditionalGeneration, AutoProcessor
import matplotlib.cm as cm
import matplotlib
from matplotlib.colors import Normalize


from PIL import Image
import av
import glob

import os
import yt_dlp

from pathlib import Path
import math

## 2.2 Text-based Search

### Load the video captions

In [None]:
def load_captions_data(file_path):
    """Load a captions JSON file and wrap every video's entry under a "segments" key.

    Args:
        file_path: Path of a JSON file whose top level maps video ids to
            caption entries.

    Returns:
        dict: ``{video_id: {"segments": ...}}``. When an entry already has a
        ``'segments'`` member, that member is stored; otherwise the entry
        itself is stored.
    """
    # Read explicitly as UTF-8 so the result does not depend on the platform's
    # default text encoding.
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    return {
        video_id: {
            "segments": captions['segments'] if 'segments' in captions else captions,
        }
        for video_id, captions in data.items()
    }

# Read both validation caption files.
val_data1 = load_captions_data('captions/val_1.json')
val_data2 = load_captions_data('captions/val_2.json')

# Merge into a single mapping keyed by video_id; if an id appears in both
# files, the val_2 entry replaces the val_1 entry.
all_captions_data = dict(val_data1)
all_captions_data.update(val_data2)

pprint(f"Number of captions: {len(all_captions_data)}")
pprint(f"Example Captions: {all_captions_data}")

### Load the videos

In [None]:
# Load the ActivityNet metadata file (explicit UTF-8 so the read does not
# depend on the platform's default encoding).
with open('activity_net.v1-3.min.json', 'r', encoding='utf-8') as json_data:
    data = json.load(json_data)

# Re-key every entry with a "v_" prefix; later cells intersect these keys
# with the caption ids.
database = {
    "v_" + video_id: entry
    for video_id, entry in data['database'].items()
}

# All (video_id, entry) pairs, most-annotated first.
sorted_database = sorted(
    database.items(),
    key=lambda x: len(x[1]['annotations']),
    reverse=True
)

# Keep the 27 videos with the most annotations.
top_videos = dict(sorted_database[:27])

pprint(top_videos)

In [None]:
# Video ids present both in the metadata database and in the caption files
# (dict key views support set intersection directly and yield a set).
matching_ids = database.keys() & all_captions_data.keys()
print(f"Número de IDs correspondentes: {len(matching_ids)}")

# Show a small sample of ids from each side for a visual sanity check.
print(f"IDs no top_videos: {list(top_videos.keys())[:5]}...")
print(f"IDs em all_captions_data: {list(all_captions_data.keys())[:5]}...")

### Compute the final captions dataset

In [None]:
final_dataset_captions = {}

# Keep the captions of every selected top video that has a caption entry.
# A plain lookup with a default replaces the previous blanket try/except,
# which would have hidden any unrelated error as well.
for video_id in top_videos:
    captions = all_captions_data.get(video_id)
    if captions is not None:
        final_dataset_captions[video_id] = captions

final_dataset_captions.pop("v_PJ72Yl0B1rY", None)  # This video has no available URL

pprint(final_dataset_captions)
pprint(len(final_dataset_captions))

### Keyframe extraction

In [None]:
def download_video(video_url, output_path):
    """Download ``video_url`` with yt-dlp, using ``output_path`` as the output template.

    The download is requested in the 'mp4' format with yt-dlp's quiet mode on.
    """
    options = dict(format='mp4', outtmpl=output_path, quiet=True)
    with yt_dlp.YoutubeDL(options) as downloader:
        downloader.download([video_url])

In [None]:
def _find_enclosing_segment(timestamp, segments):
    """Return the index of the shortest segment containing ``timestamp``, or None."""
    best_index = None
    best_value = math.inf
    for index, segment in enumerate(segments):
        start = float(segment[0])
        end = float(segment[1])
        if start <= timestamp <= end:
            # For a timestamp inside [start, end] this sum equals the segment
            # length, so the smallest value picks the tightest segment.
            value = abs(timestamp - start) + abs(timestamp - end)
            if value < best_value:
                best_value = value
                best_index = index
    return best_index


def extract_segment_keyframes(video_path, output_dir, t):
    """Save the key frames of a video that fall inside one of the given segments.

    Only key frames are decoded. Each frame whose timestamp lies inside a
    segment is written to ``output_dir`` as
    ``frame_<start>_<end>_<timestamp rounded to 4 decimals>.jpg``, where
    ``<start>``/``<end>`` belong to the shortest segment containing it.

    Args:
        video_path: Path of the video file to open.
        output_dir: Directory for the JPEG files (created if missing).
        t: Sequence of ``[start, end]`` segment boundaries in seconds.

    Any error raised while reading or writing is printed, not propagated, so
    one broken video does not stop the caller.
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        with av.open(video_path) as container:
            stream = container.streams.video[0]
            stream.codec_context.skip_frame = "NONKEY"
            time_base = stream.time_base  # Needed to convert pts to seconds

            for frame in container.decode(stream):
                # A frame without a presentation timestamp cannot be placed in
                # a segment; skip it instead of aborting the whole video.
                if frame.pts is None:
                    continue

                timestamp_sec = float(frame.pts * time_base)
                segment_index = _find_enclosing_segment(timestamp_sec, t)
                if segment_index is None:
                    continue

                start, end = t[segment_index][0], t[segment_index][1]
                out_path = os.path.join(
                    output_dir,
                    f"frame_{float(start)}_{float(end)}_{round(timestamp_sec, 4)}.jpg"
                )
                frame.to_image().save(out_path, quality=80)

    except Exception as e:
        print(f"Error in {video_path}: {e}")

# Base folders
video_dir = "videos"
output_base = "keyframes"
os.makedirs(video_dir, exist_ok=True)
os.makedirs(output_base, exist_ok=True)

processed_count = 0
missing_count = 0

for video_id, metadata in final_dataset_captions.items():
    video_path = os.path.join(video_dir, f"{video_id}.mp4")
    output_dir = os.path.join(output_base, video_id)
    t = metadata['segments']['timestamps']

    if not os.path.exists(video_path):
        video_url = top_videos[video_id]['url']
        print(f"[Download] {video_id} → {video_url}")
        try:
            download_video(video_url, video_path)
        except yt_dlp.utils.DownloadError as err:
            # One unavailable video must not stop the remaining ones; the
            # existence check below reports it as missing.
            print(f"[Download failed] {video_id}: {err}")

    if os.path.exists(video_path):
        print(f"[Processing] Extracting keyframes from: {video_id}")
        extract_segment_keyframes(video_path, output_dir, t)
        processed_count += 1
    else:
        print(f"[Missing] Could not find video after download: {video_id}")
        missing_count += 1

print("\nKeyframe extraction completed.")
print(f"    Processed videos: {processed_count}")
print(f"    Missing videos: {missing_count}")
print(f"    Keyframes saved in: {output_base}/<video_id>/")

### OpenSearch connection settings

In [None]:
# Connection settings for the OpenSearch server.
# Each value can be overridden through an environment variable so that
# credentials do not have to be edited into the notebook; the defaults keep
# the previous behaviour.
# NOTE(review): the default password is still stored in the source — consider
# removing it once every user sets OPENSEARCH_PASSWORD.
host = os.environ.get('OPENSEARCH_HOST', 'api.novasearch.org')
port = int(os.environ.get('OPENSEARCH_PORT', 443))

user = os.environ.get('OPENSEARCH_USER', 'user09')
password = os.environ.get('OPENSEARCH_PASSWORD', 'grupo09fct')
# The index is named after the user.
index_name = user

Test if OpenSearch is up and running

In [None]:
# Create the client with SSL/TLS enabled.
# NOTE(review): verify_certs=False together with ssl_assert_hostname=False
# switches off certificate verification as well as hostname checking, and
# ssl_show_warn=False hides the warning about it. Confirm this is acceptable
# for the target server.
client = OpenSearch(
    hosts = [{'host': host, 'port': port}],
    http_compress = True, # enables gzip compression for request bodies
    http_auth = (user, password),
    use_ssl = True,
    url_prefix = 'opensearch_v2',
    verify_certs = False,
    ssl_assert_hostname = False,
    ssl_show_warn = False
)

# The index is passed by keyword, matching every other index call in this file.
if client.indices.exists(index=index_name):

    # Open the index, then dump its settings, mappings and document count.
    resp = client.indices.open(index = index_name)
    print(resp)

    print('\n----------------------------------------------------------------------------------- INDEX SETTINGS')
    settings = client.indices.get_settings(index = index_name)
    pprint(settings)

    print('\n----------------------------------------------------------------------------------- INDEX MAPPINGS')
    mappings = client.indices.get_mapping(index = index_name)
    pprint(mappings)

    print('\n----------------------------------------------------------------------------------- INDEX #DOCs')
    print(client.count(index = index_name))
else:
    print("Index does not exist.")

In [None]:
# Destructive: drops the whole index named `index_name`.
# The `ignore` list is handed to the client with status codes 400 and 404 —
# presumably so that deleting a missing index is not an error; confirm
# against the opensearch-py documentation.
client.indices.delete(index=index_name, ignore=[400, 404])

### Create the index mappings

In [None]:
def _clip_knn_vector_field():
    """Return a fresh mapping for one 512-dimensional HNSW (faiss, inner product) vector field."""
    return {
        "type": "knn_vector",
        "dimension": 512,
        "method": {
            "name": "hnsw",
            "space_type": "innerproduct",
            "engine": "faiss",
            "parameters": {
                "ef_construction": 256,
                "m": 48
            }
        }
    }


# Index definition: k-NN enabled, one document per keyframe holding its video
# id, timestamp, caption text and two vector fields (caption and image).
# Both vector fields share the same mapping, built by the helper above.
index_body = {
    "settings": {
        "index": {
            "knn": True
        }
    },
    "mappings": {
        "properties": {
            "video_id": {"type": "keyword"},
            "frame_timestamp": {"type": "float"},
            "caption": {"type": "text"},
            "caption_vector": _clip_knn_vector_field(),
            "image_clip_vector": _clip_knn_vector_field()
        }
    }
}


if client.indices.exists(index=index_name):
    print("Index already existed. You may force the new mappings.")
else:
    # The index is passed by keyword, like the other index calls in this file.
    response = client.indices.create(index=index_name, body=index_body)
    print('\nCreating index:')
    print(response)

### Encode images and text using CLIP

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP checkpoint and its processor; encode_image / encode_text read
# `clip_model`, `processor` and `device` as globals.
# NOTE(review): `processor` is rebound by the LLaVA cell and again by the
# interpretability cell further down; re-run this cell before calling
# encode_image / encode_text after either of them.
model_id = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_id).to(device)
processor = CLIPProcessor.from_pretrained(model_id)

In [None]:
def encode_image(image_path):
    """Return the CLIP image features of the file at ``image_path`` as a NumPy array."""
    rgb_image = Image.open(image_path).convert("RGB")
    batch = processor(images=rgb_image, return_tensors="pt").to(device)

    with torch.no_grad():
        features = clip_model.get_image_features(**batch)

    # The batch holds a single image: take its row and move it off the device.
    return features[0].cpu().numpy()

    
def encode_text(text):
    """Return the CLIP text features of ``text`` as a NumPy array."""
    batch = processor(text=[text], return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        features = clip_model.get_text_features(**batch)

    # The batch holds a single sentence: take its row and move it off the device.
    return features[0].cpu().numpy()

In [None]:
def index_clip_data(video_id, frame_timestamp, caption, image_path):
    """Encode one caption/keyframe pair with CLIP and index it as a single document.

    The document stores the video id, the frame timestamp, the caption text and
    the two embeddings (caption first, then image) as plain lists.
    """
    client.index(
        index=index_name,
        body={
            "video_id": video_id,
            "frame_timestamp": frame_timestamp,
            "caption": caption,
            "caption_vector": encode_text(caption).tolist(),
            "image_clip_vector": encode_image(image_path).tolist(),
        },
    )

### Index the images and captions

In [None]:
keyframes_root = Path("./keyframes")

for video_folder in keyframes_root.iterdir():
    video_id = video_folder.name

    # Skip stray files and folders of videos that are not in the final
    # dataset (e.g. left over from an earlier run); they have no captions.
    if not video_folder.is_dir() or video_id not in final_dataset_captions:
        continue

    # Per-video lookups are done once, not once per image.
    timestamp_array = final_dataset_captions[video_id]['segments']['timestamps']
    sentences_array = final_dataset_captions[video_id]['segments']['sentences']

    for img in video_folder.glob("*.jpg"):
        # File names follow frame_<start>_<end>_<frame timestamp>.jpg
        filename_parts = img.stem.split("_")
        start_ts = float(filename_parts[1])
        end_ts = float(filename_parts[2])
        frame_ts = float(filename_parts[3])

        img_path = str(img)

        # The segment boundaries in the file name identify the caption.
        try:
            i = timestamp_array.index([start_ts, end_ts])
        except ValueError:
            print(f"Skipped (no matching segment): {video_id} {img_path}")
            continue

        sentence = sentences_array[i]

        index_clip_data(video_id, frame_ts, sentence, img_path)

        print(f"Indexed: {video_id} {img_path} {timestamp_array[i]} {sentences_array[i]}")

In [None]:
# Refresh the index after the indexing loop — presumably so the documents
# just written are visible to the searches below; confirm against the
# OpenSearch refresh documentation.
client.indices.refresh(index=index_name)

## Queries

In [None]:
keyframes_dir = "keyframes"

def find_closest_frame(video_id, timestamp):
    """Return the path of the stored keyframe of ``video_id`` nearest to ``timestamp``.

    First looks for a file whose name ends with the timestamp formatted to four
    decimals; if none exists, every ``frame_*_*.jpg`` file of the video is
    scanned and the one with the smallest absolute time difference is returned.

    Args:
        video_id: Name of the video's sub-folder inside ``keyframes_dir``.
        timestamp: Target time in seconds.

    Returns:
        The matching file path, or ``None`` when the folder holds no usable frame.
    """
    folder = os.path.join(keyframes_dir, video_id)

    matches = glob.glob(os.path.join(folder, f"frame_*_{timestamp:.4f}.jpg"))
    if matches:
        return matches[0]  # take first match

    # No exact match: fall back to the closest timestamp.
    best_match = None
    min_diff = float("inf")
    for frame_path in glob.glob(os.path.join(folder, "frame_*_*.jpg")):
        # The timestamp is the last "_"-separated field of the file name
        # (parsed from the base name so folder names cannot interfere).
        stem = os.path.splitext(os.path.basename(frame_path))[0]
        try:
            ts = float(stem.rsplit("_", 1)[-1])
        except ValueError:
            # Not a keyframe file written with the expected naming scheme.
            continue

        diff = abs(ts - timestamp)
        if diff < min_diff:
            best_match = frame_path
            min_diff = diff

    return best_match

### Query Example (Text → Image)

In [None]:
# Text → image search: embed the sentence with CLIP and retrieve the five
# keyframes whose image vectors are nearest to it.
query = "a man surfing"
query_embedding = encode_text(query).tolist()

search_query = {
    "size": 5,
    "_source": ["video_id", "frame_timestamp"],
    "query": {"knn": {"image_clip_vector": {"vector": query_embedding, "k": 5}}},
}
response = client.search(index=index_name, body=search_query)

# Show every hit: resolve its keyframe file on disk and plot it.
for hit in response["hits"]["hits"]:
    video_id, timestamp = hit["_source"]["video_id"], hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if not (image_path and os.path.exists(image_path)):
        print("Image not found.")
        continue

    img = Image.open(image_path)
    plt.imshow(img)
    plt.title(f"{video_id} @ {timestamp}s")
    plt.axis("off")
    plt.show()

### Query Example (Text → Image and Caption vectors)

In [None]:
# Text query matched against both vector fields: a document qualifies when it
# is a nearest neighbour on its image vector or on its caption vector.
query = "a man surfing on a wave"
query_embedding = encode_text(query).tolist()

search_query = {
    "size": 5,
    "_source": ["video_id", "frame_timestamp"],
    "query": {
        "bool": {
            # One k-NN clause per vector field, image field first.
            "should": [
                {"knn": {vector_field: {"vector": query_embedding, "k": 5}}}
                for vector_field in ("image_clip_vector", "caption_vector")
            ],
            "minimum_should_match": 1,
        }
    },
}


response = client.search(index=index_name, body=search_query)

# Show every hit: resolve its keyframe file on disk and plot it.
for hit in response["hits"]["hits"]:
    video_id, timestamp = hit["_source"]["video_id"], hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if not (image_path and os.path.exists(image_path)):
        print("Image not found.")
        continue

    img = Image.open(image_path)
    plt.imshow(img)
    plt.title(f"{video_id} @ {timestamp}s")
    plt.axis("off")
    plt.show()

### Query Example (Image → Image)

In [None]:
# Image → image search: embed a reference keyframe and retrieve the five
# keyframes with the nearest image vectors.
# The embedding is converted to a plain list, as in the other query cells, so
# the request body does not depend on the client serialising NumPy arrays.
image_embedding = encode_image("./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg").tolist()

search_query = {
    "size": 5,
    # Only the fields read below are requested; without this filter each hit
    # would also carry both stored vectors.
    "_source": ["video_id", "frame_timestamp"],
    "query": {
        "knn": {
            "image_clip_vector": {
                "vector": image_embedding,
                "k": 5
            }
        }
    }
}

response = client.search(index=index_name, body=search_query)

# Show every hit: resolve its keyframe file on disk and plot it.
for hit in response["hits"]["hits"]:
    video_id = hit["_source"]["video_id"]
    timestamp = hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if image_path and os.path.exists(image_path):
        img = Image.open(image_path)
        plt.imshow(img)
        plt.title(f"{video_id} @ {timestamp}s")
        plt.axis("off")
        plt.show()
    else:
        print("Image not found.")

### Query Example (Image + Text → Image)

In [None]:
def encode_combined_query(image_path, text_query, alpha=0.5):
    """Blend an image embedding and a text embedding into one unit-length query vector.

    Args:
        image_path: Path of the image to embed with ``encode_image``.
        text_query: Sentence to embed with ``encode_text``.
        alpha: Weight of the image part; 0.0 = only text, 1.0 = only image.

    Returns:
        numpy.ndarray: ``alpha * image + (1 - alpha) * text`` computed on the
        unit-normalised embeddings, then scaled to unit length.
    """
    def _unit(vec):
        # Leave an all-zero vector untouched instead of dividing by zero.
        norm = np.linalg.norm(vec)
        return vec / norm if norm > 0 else vec

    # Normalise each modality first: the raw embeddings are not unit length,
    # so mixing them directly would let the larger one dominate regardless of
    # alpha.
    image_vec = _unit(encode_image(image_path))
    text_vec = _unit(encode_text(text_query))

    combined_vec = alpha * image_vec + (1 - alpha) * text_vec
    return _unit(combined_vec)

In [None]:
# Image + text → image search: blend a reference keyframe and a sentence into
# one query vector (equal weights) and match it against both vector fields.
combined_vec = encode_combined_query(
    "./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg",
    "Men yappin some shit",
    alpha=0.5,
).tolist()

# Prepare the OpenSearch query: one k-NN clause per vector field, image first.
search_query = {
    "size": 5,
    "_source": ["video_id", "frame_timestamp"],
    "query": {
        "bool": {
            "should": [
                {"knn": {vector_field: {"vector": combined_vec, "k": 5}}}
                for vector_field in ("image_clip_vector", "caption_vector")
            ],
            "minimum_should_match": 1,
        }
    },
}

response = client.search(index=index_name, body=search_query)

# Display the matched frames
for hit in response["hits"]["hits"]:
    video_id, timestamp = hit["_source"]["video_id"], hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if not (image_path and os.path.exists(image_path)):
        print("Image not found.")
        continue

    img = Image.open(image_path)
    plt.imshow(img)
    plt.title(f"{video_id} @ {timestamp}s")
    plt.axis("off")
    plt.show()

# Large Vision and Language Models

In [None]:
# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the LLaVA model and its processor.
# NOTE(review): this rebinds the global `processor` that encode_image /
# encode_text read; after this cell they will receive the LLaVA processor
# until the CLIP cell is run again.
model_id = "llava-hf/llava-1.5-7b-hf"

processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,  # weights are loaded in 32-bit floats
    low_cpu_mem_usage=True
).to(device)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
def ask_llava(image_path, question, max_tokens=64):
    """Ask the loaded LLaVA model a question about one image.

    Args:
        image_path: Path of the image file; it is opened and converted to RGB.
        question: The question to ask about the image.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        str: The decoded text of the first returned sequence, special tokens
        removed. The whole sequence is decoded, so the text presumably also
        contains the prompt — verify before parsing the answer.
    """
    image = Image.open(image_path).convert("RGB")

    # llava-1.5 checkpoints expect the "USER: <image>\n... ASSISTANT:"
    # conversation template; the "<|user|>"/"<|assistant|>" markers used
    # before are not part of that format.
    prompt = f"USER: <image>\n{question} ASSISTANT:"

    # Tensors are moved to whichever device the model was loaded on.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_tokens)

    return processor.decode(output_ids[0], skip_special_tokens=True)


In [None]:
# Example: ask LLaVA about one stored keyframe.
question = "What is the person doing in this frame?"
frame_path = "./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg"  # a keyframe written by the extraction step (hard-coded path)

response = ask_llava(frame_path, question)
print(response)

# Interpretability

In [None]:
# Inputs for the interpretability cells below (same values as in the LLaVA example cell).
question = "What is the person doing in this frame?"
frame_path = "./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg"  # a keyframe written by the extraction step (hard-coded path)

### Attention Layers Visualization

In [None]:
# Reload CLIP with hidden states enabled for the visualizations below.
# NOTE(review): this rebinds the globals `model` and `processor`; ask_llava
# reads both, so it will no longer use LLaVA after this cell unless the LLaVA
# cell is run again.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", output_hidden_states=True).to(device)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def visualize_token_patch_similarity(image: Image.Image, text: str):
    """Plot, for every text token, its cosine similarity to the image's CLS embedding.

    Args:
        image: The image to compare against.
        text: The sentence whose tokens are scored.

    Draws a horizontal bar chart (one bar per non-special token) and returns
    nothing. If the text has no non-special tokens, a message is printed and
    no figure is drawn.
    """
    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        output = model(**inputs, output_hidden_states=True)

    # Image CLS embedding: last hidden state -> post layer norm -> projection.
    image_cls = output.vision_model_output.hidden_states[-1][0][0]
    image_embed = model.vision_model.post_layernorm(image_cls.unsqueeze(0))
    image_embed = model.visual_projection(image_embed)
    image_embed = torch.nn.functional.normalize(image_embed, dim=-1)

    # Text token embeddings. hidden_states[-1] is taken before the text
    # tower's final layer norm, so apply it before projecting — the same
    # order the vision branch above follows with post_layernorm.
    text_hidden = output.text_model_output.hidden_states[-1][0]  # (tokens, text_hidden)
    text_hidden = model.text_model.final_layer_norm(text_hidden)
    text_embed = model.text_projection(text_hidden)              # (tokens, projection_dim)
    text_embed = torch.nn.functional.normalize(text_embed, dim=-1)

    # Filter out special tokens
    input_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    valid_indices = [
        i for i, tok in enumerate(tokens)
        if tok not in tokenizer.all_special_tokens
    ]
    if not valid_indices:
        print("No non-special tokens to visualize.")
        return

    filtered_tokens = [tokens[i] for i in valid_indices]
    filtered_embed = text_embed[valid_indices]

    # Compute similarity; squeeze only the image axis so a single-token text
    # still yields a 1-D array.
    similarity = torch.matmul(filtered_embed, image_embed.T).squeeze(-1).detach().cpu().numpy()

    # Normalize for color mapping
    norm = Normalize(vmin=similarity.min(), vmax=similarity.max())
    colors = cm.plasma(norm(similarity))

    # Plotting
    fig, ax = plt.subplots(figsize=(10, len(filtered_tokens) * 0.5 + 1))
    ax.barh(range(len(filtered_tokens)), similarity, color=colors)
    ax.set_yticks(range(len(filtered_tokens)))
    ax.set_yticklabels(filtered_tokens)
    ax.set_xlabel("Similarity to Image")
    ax.set_title("Token Relevance to Image")
    ax.invert_yaxis()

    # Colorbar
    sm = cm.ScalarMappable(cmap='plasma', norm=norm)
    sm.set_array([])
    fig.colorbar(sm, ax=ax, label="Cosine Similarity")

    plt.tight_layout()
    plt.show()

In [None]:
def visualize_token_similarity(image: Image.Image, text: str):
    """Plot one patch-level similarity heatmap per text token next to the image.

    Args:
        image: The image whose patches are scored.
        text: The sentence whose tokens are scored.

    Draws the original image followed by one heatmap per non-special token,
    all sharing one colour scale, and returns nothing. If the text has no
    non-special tokens, a message is printed and no figure is drawn.

    Raises:
        ValueError: If the number of image patches is not a perfect square.
    """
    # Preprocess inputs
    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        output = model(**inputs, output_hidden_states=True)

    # Image patch embeddings (remove CLS token)
    image_hidden = output.vision_model_output.hidden_states[-1][0][1:]  # (patches, vision_hidden)
    image_embeds = model.vision_model.post_layernorm(image_hidden)
    image_embeds = model.visual_projection(image_embeds)                # (patches, projection_dim)
    image_embeds = torch.nn.functional.normalize(image_embeds, dim=-1)

    # Text token embeddings. hidden_states[-1] is taken before the text
    # tower's final layer norm, so apply it before projecting — the same
    # order the vision branch above follows with post_layernorm.
    text_hidden = output.text_model_output.hidden_states[-1][0]         # (tokens, text_hidden)
    text_hidden = model.text_model.final_layer_norm(text_hidden)
    text_embeds = model.text_projection(text_hidden)                    # (tokens, projection_dim)
    text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)

    # Filter special tokens
    token_ids = inputs["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    valid_indices = [
        i for i, tok in enumerate(tokens)
        if tok not in tokenizer.all_special_tokens
    ]
    if not valid_indices:
        print("No non-special tokens to visualize.")
        return
    tokens = [tokens[i] for i in valid_indices]
    text_embeds = text_embeds[valid_indices]

    # Derive the patch grid from the number of patches instead of assuming 7x7.
    num_patches = image_embeds.shape[0]
    grid_side = int(round(num_patches ** 0.5))
    if grid_side * grid_side != num_patches:
        raise ValueError(f"Cannot arrange {num_patches} patches in a square grid")

    # Compute similarity and reshape to a square grid per token
    similarity = torch.matmul(image_embeds, text_embeds.T).detach().cpu().numpy()  # (patches, valid_tokens)
    patch_grid = similarity.reshape(grid_side, grid_side, -1)

    # Plot image and heatmaps
    num_tokens = len(tokens)
    fig, axs = plt.subplots(1, num_tokens + 1, figsize=(3.5 * (num_tokens + 1), 6))
    norm = Normalize(vmin=similarity.min(), vmax=similarity.max())
    cmap = matplotlib.colormaps["viridis"]

    # Original image
    axs[0].imshow(image)
    axs[0].set_title("Original Image")
    axs[0].axis("off")

    # Token heatmaps with shared colorbar
    for i, token in enumerate(tokens):
        axs[i + 1].imshow(patch_grid[:, :, i], cmap=cmap, norm=norm)
        axs[i + 1].set_title(f"'{token}'")
        axs[i + 1].axis("off")

    # Add colorbar to the right
    cbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=axs[-1], fraction=0.046, pad=0.04)
    cbar.set_label("Similarity (Cosine)")

    plt.tight_layout()
    plt.show()


In [None]:
# Load the keyframe selected above and run both token-level visualizations on it.
image = Image.open(frame_path).convert("RGB")  # `frame_path` comes from an earlier cell
text = "A man talking"
visualize_token_patch_similarity(image, text)
visualize_token_similarity(image, text)

### Relevancy Map

In [None]:
def visualize_relevancy_map(image: Image.Image, text: str):
    """Plot a heatmap of each image patch's cosine similarity to the whole sentence.

    The sentence is represented by the mean of its projected token embeddings
    (all tokens returned by the processor are averaged), re-normalised to unit
    length.

    Args:
        image: The image whose patches are scored.
        text: The sentence to compare the patches with.

    Raises:
        ValueError: If the number of image patches is not a perfect square.
    """
    # Preprocess and forward pass
    inputs = processor(text=[text], images=image, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        output = model(**inputs, output_hidden_states=True)

    # Patch embeddings
    image_hidden = output.vision_model_output.hidden_states[-1][0][1:]  # Remove CLS
    image_embeds = model.vision_model.post_layernorm(image_hidden)
    image_embeds = model.visual_projection(image_embeds)                # (patches, projection_dim)
    image_embeds = torch.nn.functional.normalize(image_embeds, dim=-1)

    # Token embeddings. hidden_states[-1] is taken before the text tower's
    # final layer norm, so apply it before projecting — the same order the
    # vision branch above follows with post_layernorm.
    text_hidden = output.text_model_output.hidden_states[-1][0]         # (tokens, text_hidden)
    text_hidden = model.text_model.final_layer_norm(text_hidden)
    text_embeds = model.text_projection(text_hidden)
    text_embeds = torch.nn.functional.normalize(text_embeds, dim=-1)

    # Mean text embedding (or use a specific token). The mean of unit vectors
    # is not unit length, so normalise again to get a true cosine below.
    sentence_embed = text_embeds.mean(dim=0)
    sentence_embed = torch.nn.functional.normalize(sentence_embed, dim=-1)

    # Derive the patch grid from the number of patches instead of assuming 7x7.
    num_patches = image_embeds.shape[0]
    grid_side = int(round(num_patches ** 0.5))
    if grid_side * grid_side != num_patches:
        raise ValueError(f"Cannot arrange {num_patches} patches in a square grid")

    # Cosine similarity (relevance)
    relevance = torch.matmul(image_embeds, sentence_embed).detach().cpu().numpy()  # (patches,)
    heatmap = relevance.reshape(grid_side, grid_side)

    # Plot
    plt.figure(figsize=(6, 6))
    plt.imshow(heatmap, cmap="plasma")
    plt.colorbar(label="Relevance to text")
    plt.title(f"Relevancy Map: '{text}'")
    plt.axis("off")
    plt.tight_layout()
    plt.show()

In [None]:
# Uses the `image` and `text` globals defined in the token-visualization cell above.
visualize_relevancy_map(image, text)

### Causal Token Graph

In [None]:
def causal_token_graph(image: Image.Image, text: str):
    """Plot how much the image-text score changes when each token is removed.

    The sentence is tokenized, scored against the image, and then re-scored
    once per token with that token left out. The bar for a token shows
    ``original score - score without the token``.

    Args:
        image: The image to score against.
        text: The sentence to ablate token by token.
    """
    tokens = tokenizer.tokenize(text)

    def _score(token_subset):
        # Turn the tokens back into plain text with the tokenizer itself;
        # joining them with spaces would feed the tokenizer's internal
        # end-of-word markers to the model as literal characters.
        caption = tokenizer.convert_tokens_to_string(token_subset)
        inputs = processor(text=[caption], images=image, return_tensors="pt").to(device)
        with torch.no_grad():
            return model(**inputs).logits_per_text[0, 0].item()

    # Original similarity
    orig_score = _score(tokens)

    # Leave-one-out effect of every token
    effects = [
        orig_score - _score(tokens[:i] + tokens[i + 1:])
        for i in range(len(tokens))
    ]

    # Plot causal influence. Bars are placed by position so that a token
    # occurring twice gets two bars instead of overlapping on one label.
    positions = range(len(tokens))
    labels = [tok.replace("</w>", "") for tok in tokens]
    plt.figure(figsize=(10, 4))
    plt.bar(positions, effects, color="tomato")
    plt.title("Causal Influence of Each Token on Image Similarity")
    plt.ylabel("Change in Similarity")
    plt.xticks(positions, labels, rotation=45)
    plt.tight_layout()
    plt.show()

In [None]:
# Uses the `image` and `text` globals defined in the token-visualization cell above.
causal_token_graph(image, text)