# Project Phase 2: Video Dialog 

## Imports

In [None]:
import json
from pprint import pprint

#Open Search
from opensearchpy import OpenSearch

#Embeddings neighborhood
import torch

#Contextual embeddings and self-attention
import numpy as np

# Get the interactive Tools for Matplotlib
import matplotlib.pyplot as plt
plt.style.use('ggplot')

from transformers import CLIPModel, CLIPTokenizer, CLIPProcessor, LlavaForConditionalGeneration, AutoProcessor
import matplotlib.cm as cm
import matplotlib
from matplotlib.colors import Normalize
import seaborn as sns


from PIL import Image
import av
import glob

import os
import yt_dlp

from pathlib import Path
import math

from torchvision import transforms
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

## 2.2 Text-based Search

### Load the video captions

In [None]:
def load_captions_data(file_path):
    """Read a captions JSON file and wrap every video's captions under "segments".

    Args:
        file_path: Path to a JSON file whose top level maps video ids to captions.

    Returns:
        A dict mapping each video id to ``{"segments": ...}``. When an entry
        already contains a ``'segments'`` member, that member is used; otherwise
        the whole entry is stored as the segments value.
    """
    with open(file_path, 'r') as handle:
        raw = json.load(handle)

    return {
        vid: {"segments": entry['segments'] if 'segments' in entry else entry}
        for vid, entry in raw.items()
    }

# Load both validation caption files
val_data1 = load_captions_data('captions/val_1.json')
val_data2 = load_captions_data('captions/val_2.json')

# Merge the two dictionaries keyed by video_id; on a duplicate id the val_2 entry wins
all_captions_data = {**val_data1, **val_data2}

# NOTE: the first line reports the number of videos (dictionary keys), not the number
# of individual captions, and the second line prints the entire merged dictionary.
pprint(f"Number of captions: {len(all_captions_data)}")
pprint(f"Example Captions: {all_captions_data}")

### Load the videos

In [None]:
# Load the ActivityNet metadata file
with open('activity_net.v1-3.min.json', 'r') as json_data:
    data = json.load(json_data)

# Re-key every video entry with a "v_" prefix
# (presumably so the ids match the ids used in the captions files - verify)
database = {}

for video_id in data['database']:
    database["v_" + video_id] = data['database'][video_id]

# Create the list with all data, sorted by the number of annotations (descending)
sorted_database = sorted(
    database.items(),
    key=lambda x: len(x[1]['annotations']),
    reverse=True
)

# Keep the 27 videos with the most annotations
# NOTE(review): an earlier comment said "Top 10"; the slice keeps 27 - confirm the intended size.
top_videos = dict(sorted_database[:27])

pprint(top_videos)

In [None]:
# Sanity check: how many video ids appear in both the metadata and the captions,
# followed by a sample of the ids on each side.
matching_ids = set(database.keys()) & set(all_captions_data.keys())
print(f"Número de IDs correspondentes: {len(matching_ids)}")
print(f"IDs no top_videos: {list(top_videos.keys())[:5]}...")
print(f"IDs em all_captions_data: {list(all_captions_data.keys())[:5]}...")

### Compute the final captions dataset

In [None]:
# Keep the captions of every selected video that actually has a caption entry.
# Videos without captions are skipped explicitly instead of relying on a
# swallowed exception, so unrelated errors are no longer hidden.
final_dataset_captions = {
    video_id: all_captions_data[video_id]
    for video_id in top_videos
    if all_captions_data.get(video_id) is not None
}

final_dataset_captions.pop("v_PJ72Yl0B1rY", None) # This video has no available URL

pprint(final_dataset_captions)
pprint(len(final_dataset_captions))

### Keyframe extraction

In [None]:
def download_video(video_url, output_path):
    """Download one video with yt-dlp.

    Args:
        video_url: URL handed to yt-dlp.
        output_path: Output template/path passed as yt-dlp's ``outtmpl`` option.
    """
    options = dict(format='mp4', outtmpl=output_path, quiet=True)
    with yt_dlp.YoutubeDL(options) as downloader:
        downloader.download([video_url])

In [None]:
def extract_segment_keyframes(video_path, output_dir, t):
    """Save every key frame of a video that falls inside a caption segment.

    Only key frames are decoded. Each frame is assigned to the shortest segment
    that contains its timestamp (the first one on ties) and written as
    ``frame_<start>_<end>_<timestamp>.jpg``; frames outside every segment are
    not saved.

    Args:
        video_path: Path of the video file to decode.
        output_dir: Directory that receives the JPEG files (created if missing).
        t: Sequence of ``[start, end]`` segment bounds in seconds.

    Errors raised while reading the video are reported with ``print`` and end
    the extraction for this video; they are not re-raised.
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        # Convert the segment bounds once instead of once per decoded frame.
        segments = [(float(s[0]), float(s[1])) for s in t]

        with av.open(video_path) as container:
            stream = container.streams.video[0]
            stream.codec_context.skip_frame = "NONKEY"
            time_base = stream.time_base  # Needed to convert pts to seconds

            for frame in container.decode(stream):
                # A frame without a presentation timestamp cannot be placed in a segment.
                if frame.pts is None:
                    continue
                timestamp_sec = float(frame.pts * time_base)

                containing = [
                    seg for seg in segments if seg[0] <= timestamp_sec <= seg[1]
                ]
                if not containing:
                    continue

                # Prefer the tightest enclosing segment when segments overlap.
                start, end = min(containing, key=lambda seg: seg[1] - seg[0])

                out_path = os.path.join(
                    output_dir,
                    f"frame_{start}_{end}_{round(timestamp_sec, 4)}.jpg"
                )
                frame.to_image().save(out_path, quality=80)

    except Exception as e:
        # Best-effort extraction: report the problem and let the caller continue.
        print(f"Error in {video_path}: {e}")

# Base folders
video_dir = "videos"
output_base = "keyframes"
os.makedirs(video_dir, exist_ok=True)
os.makedirs(output_base, exist_ok=True)

processed_count = 0
missing_count = 0

for video_id, metadata in final_dataset_captions.items():
    video_path = os.path.join(video_dir, f"{video_id}.mp4")
    output_dir = os.path.join(output_base, video_id)
    t = metadata['segments']['timestamps']

    if not os.path.exists(video_path):
        video_url = top_videos[video_id]['url']
        print(f"[Download] {video_id} → {video_url}")
        try:
            download_video(video_url, video_path)
        except yt_dlp.utils.DownloadError as e:
            # A failed download must not abort the loop; the video is
            # counted as missing by the check below.
            print(f"[Download failed] {video_id}: {e}")

    if os.path.exists(video_path):
        print(f"[Processing] Extracting keyframes from: {video_id}")
        extract_segment_keyframes(video_path, output_dir, t)
        processed_count += 1
    else:
        print(f"[Missing] Could not find video after download: {video_id}")
        missing_count += 1

print("\nKeyframe extraction completed.")
print(f"    Processed videos: {processed_count}")
print(f"    Missing videos: {missing_count}")
print(f"    Keyframes saved in: {output_base}/<video_id>/")

### OpenSearch connection settings

In [None]:
# Connection settings for the OpenSearch server.
# Each value can be overridden through an environment variable so credentials
# do not have to live in the notebook.
# NOTE(review): the fallback values below keep the notebook working as before,
# but they are real credentials committed in source - move them out and rotate them.
host = os.environ.get('OPENSEARCH_HOST', 'api.novasearch.org')
port = int(os.environ.get('OPENSEARCH_PORT', '443'))

user = os.environ.get('OPENSEARCH_USER', 'user09')
password = os.environ.get('OPENSEARCH_PASSWORD', 'grupo09fct')
# The index is named after the user account.
index_name = user

Test if OpenSearch is up and running

In [None]:
# Create the client with SSL/TLS enabled, but with certificate verification and
# hostname verification disabled (and the related warnings silenced).
# NOTE(review): verify_certs=False accepts any server certificate - confirm this is acceptable.
client = OpenSearch(
    hosts = [{'host': host, 'port': port}],
    http_compress = True, # enables gzip compression for request bodies
    http_auth = (user, password),
    use_ssl = True,
    url_prefix = 'opensearch_v2',
    verify_certs = False,
    ssl_assert_hostname = False,
    ssl_show_warn = False
)

# If the index exists, open it and print its settings, mappings and document count.
if client.indices.exists(index_name):

    resp = client.indices.open(index = index_name)
    print(resp)

    print('\n----------------------------------------------------------------------------------- INDEX SETTINGS')
    settings = client.indices.get_settings(index = index_name)
    pprint(settings)

    print('\n----------------------------------------------------------------------------------- INDEX MAPPINGS')
    mappings = client.indices.get_mapping(index = index_name)
    pprint(mappings)

    print('\n----------------------------------------------------------------------------------- INDEX #DOCs')
    print(client.count(index = index_name))
else:
    print("Index does not exist.")

In [None]:
# WARNING: deletes the whole index and every document in it.
# HTTP 400/404 responses (e.g. the index does not exist) are ignored.
client.indices.delete(index=index_name, ignore=[400, 404])

### Create the index mappings

In [None]:
# Index definition: k-NN enabled, with two 512-dimensional HNSW vector fields
# (caption embedding and image embedding) next to the plain metadata fields.
# NOTE(review): "innerproduct" ranks by raw dot product, so it equals cosine
# similarity only if the stored vectors are unit-length - confirm what is indexed.
index_body = {
    "settings": {
        "index": {
            "knn": True
        }
    },
    "mappings": {
        "properties": {
            "video_id": {"type": "keyword"},
            "frame_timestamp": {"type": "float"},
            "caption": {"type": "text"},
            "caption_vector": {
                "type": "knn_vector",
                "dimension": 512,
                "method": {
                    "name": "hnsw",
                    "space_type": "innerproduct",
                    "engine": "faiss",
                    "parameters": {
                        "ef_construction": 256,
                        "m": 48
                    }
                }
            },
            "image_clip_vector": {
                "type": "knn_vector",
                "dimension": 512,
                "method": {
                    "name": "hnsw",
                    "space_type": "innerproduct",
                    "engine": "faiss",
                    "parameters": {
                        "ef_construction": 256,
                        "m": 48
                    }
                }
            }
        }
    }
}


# Create the index only when it does not exist yet; an existing index keeps its
# current mappings (delete it first to apply new ones).
if client.indices.exists(index=index_name):
    print("Index already existed. You may force the new mappings.")
else:        
    response = client.indices.create(index_name, body=index_body)
    print('\nCreating index:')
    print(response)

### Encode images and text using CLIP

In [None]:
# Run on the GPU when one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the CLIP model and its matching processor.
# encode_image / encode_text read `clip_model`, `processor` and `device`
# from module scope at call time.
model_id = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_id).to(device)
processor = CLIPProcessor.from_pretrained(model_id)

In [None]:
def _l2_normalize(vector):
    """Return `vector` scaled to unit L2 norm (unchanged when its norm is zero)."""
    norm = np.linalg.norm(vector)
    return vector / norm if norm > 0 else vector


def encode_image(image_path, normalize=False):
    """Embed an image file with the module-level CLIP model.

    Args:
        image_path: Path of the image to embed.
        normalize: When True, return a unit-length vector. The default keeps
            the raw CLIP features so existing callers see no change; inner
            product equals cosine similarity only for unit-length vectors.

    Returns:
        A 1-D numpy array with the image features.
    """
    # Close the file handle as soon as the pixels are converted.
    with Image.open(image_path) as image_file:
        image = image_file.convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = clip_model.get_image_features(**inputs)
    vector = outputs[0].cpu().numpy()
    return _l2_normalize(vector) if normalize else vector


def encode_text(text, normalize=False):
    """Embed a text string with the module-level CLIP model.

    Args:
        text: The sentence to embed. Text longer than the tokenizer's maximum
            length is truncated instead of raising inside the model.
        normalize: When True, return a unit-length vector (see `encode_image`).

    Returns:
        A 1-D numpy array with the text features.
    """
    inputs = processor(
        text=[text], return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        outputs = clip_model.get_text_features(**inputs)
    vector = outputs[0].cpu().numpy()
    return _l2_normalize(vector) if normalize else vector

In [None]:
def index_clip_data(video_id, frame_timestamp, caption, image_path):
    """Index one key frame together with its caption.

    The caption and the image at `image_path` are embedded with `encode_text`
    and `encode_image`, and a single document holding the metadata and both
    vectors is sent to the index named by `index_name`.
    """
    client.index(
        index=index_name,
        body={
            "video_id": video_id,
            "frame_timestamp": frame_timestamp,
            "caption": caption,
            "caption_vector": encode_text(caption).tolist(),
            "image_clip_vector": encode_image(image_path).tolist(),
        },
    )

### Index the images and captions

In [None]:
keyframes_root = Path("./keyframes")

for video_folder in keyframes_root.iterdir():
    video_id = video_folder.name

    # Ignore stray files and folders of videos that are not in the dataset.
    if not video_folder.is_dir() or video_id not in final_dataset_captions:
        print(f"Skipped: {video_folder} (not a known video folder)")
        continue

    # These lookups are the same for every frame of the video.
    timestamp_array = final_dataset_captions[video_id]['segments']['timestamps']
    sentences_array = final_dataset_captions[video_id]['segments']['sentences']

    for img in video_folder.glob("*.jpg"):
        # File names follow frame_<segment start>_<segment end>_<frame timestamp>.jpg
        filename_parts = img.stem.split("_")
        try:
            start_ts = float(filename_parts[1])
            end_ts = float(filename_parts[2])
            frame_ts = float(filename_parts[3])
            i = timestamp_array.index([start_ts, end_ts])
        except (IndexError, ValueError):
            print(f"Skipped: {img} (file name does not match a caption segment)")
            continue

        img_path = str(img)
        sentence = sentences_array[i]

        index_clip_data(video_id, frame_ts, sentence, img_path)

        print(f"Indexed: {video_id} {img_path} {timestamp_array[i]} {sentences_array[i]}")

In [None]:
# Refresh the index so the documents indexed above become searchable
client.indices.refresh(index=index_name)

## Queries

In [None]:
keyframes_dir = "keyframes"

def find_closest_frame(video_id, timestamp):
    """Return the saved key frame of a video whose timestamp is nearest to `timestamp`.

    Key frames are looked up in ``<keyframes_dir>/<video_id>/`` and are expected
    to be named ``frame_<start>_<end>_<timestamp>.jpg``; the last field is the
    frame timestamp in seconds. An exact match is simply the candidate with
    distance zero, so no separate exact-name lookup is needed.

    Args:
        video_id: Name of the video's key-frame folder.
        timestamp: Target time in seconds.

    Returns:
        The path of the closest frame, or None when the folder holds no
        parsable frame file.
    """
    folder = os.path.join(glob.escape(keyframes_dir), glob.escape(video_id))

    best_match = None
    min_diff = float("inf")
    for path in glob.glob(os.path.join(folder, "frame_*_*.jpg")):
        stem = os.path.splitext(os.path.basename(path))[0]
        try:
            ts = float(stem.rsplit("_", 1)[-1])
        except ValueError:
            # Not a key-frame file written by this notebook; ignore it.
            continue

        diff = abs(ts - timestamp)
        if diff < min_diff:
            best_match = path
            min_diff = diff

    return best_match

### Query Example (Text → Image)

In [None]:
# Embed the text query with CLIP and search the image vectors with it
query = "a man surfing"
query_embedding = encode_text(query).tolist()

# k-NN search on the image embeddings; only the fields needed to locate the frame are returned
search_query = {
    "size": 5,
    "_source": ["video_id", "frame_timestamp"],
    "query": {
        "knn": {
            "image_clip_vector": {
                "vector": query_embedding,
                "k": 5
            }
        }
    }
}
response = client.search(index=index_name, body=search_query)

# Show the key frame that corresponds to every hit
for hit in response["hits"]["hits"]:
    video_id = hit["_source"]["video_id"]
    timestamp = hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if image_path and os.path.exists(image_path):
        img = Image.open(image_path)
        plt.imshow(img)
        plt.title(f"{video_id} @ {timestamp}s")
        plt.axis("off")
        plt.show()
    else:
        print("Image not found.")

### Query Example (Text → Image and Caption vectors)

In [None]:
# Embed the text query once and match it against both vector fields
query = "a man surfing on a wave"
query_embedding = encode_text(query).tolist()

# A document matches when it is a nearest neighbour on the image vector,
# on the caption vector, or on both ("should" with minimum_should_match = 1).
search_query = {
    "size": 5,
    "_source": ["video_id", "frame_timestamp"],
    "query": {
        "bool": {
            "should": [
                {
                    "knn": {
                        "image_clip_vector": {
                            "vector": query_embedding,
                            "k": 5
                        }
                    }
                },
                {
                    "knn": {
                        "caption_vector": {
                            "vector": query_embedding,
                            "k": 5
                        }
                    }
                }
            ],
            "minimum_should_match": 1
        }
    }
}


response = client.search(index=index_name, body=search_query)

# Show the key frame that corresponds to every hit
for hit in response["hits"]["hits"]:
    video_id = hit["_source"]["video_id"]
    timestamp = hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if image_path and os.path.exists(image_path):
        img = Image.open(image_path)
        plt.imshow(img)
        plt.title(f"{video_id} @ {timestamp}s")
        plt.axis("off")
        plt.show()
    else:
        print("Image not found.")

### Query Example (Image → Image)

In [None]:
# Embed the example key frame; convert to a plain list so the request body is
# JSON-serializable in the same way as the other query cells.
image_embedding = encode_image("./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg").tolist()

# k-NN search on the image embeddings. Only the two fields used below are
# requested, which keeps the stored vectors out of the response.
search_query = {
    "size": 5,
    "_source": ["video_id", "frame_timestamp"],
    "query": {
        "knn": {
            "image_clip_vector": {
                "vector": image_embedding,
                "k": 5
            }
        }
    }
}

response = client.search(index=index_name, body=search_query)

# Show the key frame that corresponds to every hit
for hit in response["hits"]["hits"]:
    video_id = hit["_source"]["video_id"]
    timestamp = hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if image_path and os.path.exists(image_path):
        img = Image.open(image_path)
        plt.imshow(img)
        plt.title(f"{video_id} @ {timestamp}s")
        plt.axis("off")
        plt.show()
    else:
        print("Image not found.")

### Query Example (Image + Text → Image)

In [None]:
def encode_combined_query(image_path, text_query, alpha=0.5):
    """Blend an image embedding and a text embedding into one unit-length query vector.

    Both embeddings are scaled to unit length before they are mixed. Without
    that step the vector with the larger norm dominates the sum and `alpha`
    no longer reflects the intended weighting.

    Args:
        image_path: Path of the query image.
        text_query: Query text.
        alpha: Weight of the image: 0.0 = only text, 1.0 = only image.

    Returns:
        The weighted sum of the two unit vectors, itself scaled to unit length.
    """
    image_vec = encode_image(image_path)
    text_vec = encode_text(text_query)

    image_vec = image_vec / np.linalg.norm(image_vec)
    text_vec = text_vec / np.linalg.norm(text_vec)

    combined_vec = alpha * image_vec + (1 - alpha) * text_vec
    return combined_vec / np.linalg.norm(combined_vec)

In [None]:
# Build one query vector from an example key frame and a text prompt,
# weighting image and text equally (alpha = 0.5).
combined_vec = encode_combined_query(
    "./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg", 
    "Men talking", 
    alpha=0.5).tolist()

# Prepare the OpenSearch query: a document matches when it is a nearest
# neighbour on the image vector, on the caption vector, or on both.
search_query = {
    "size": 5,
    "_source": ["video_id", "frame_timestamp"],
    "query": {
        "bool": {
            "should": [
                {
                    "knn": {
                        "image_clip_vector": {
                            "vector": combined_vec,
                            "k": 5
                        }
                    }
                },
                {
                    "knn": {
                        "caption_vector": {
                            "vector": combined_vec,
                            "k": 5
                        }
                    }
                }
            ],
            "minimum_should_match": 1
        }
    }
}

response = client.search(index=index_name, body=search_query)

# Display the matched frames
for hit in response["hits"]["hits"]:
    video_id = hit["_source"]["video_id"]
    timestamp = hit["_source"]["frame_timestamp"]

    image_path = find_closest_frame(video_id, timestamp)
    print(f"Video: {video_id} — Time: {timestamp}s")
    print(f"Found image: {image_path}")

    if image_path and os.path.exists(image_path):
        img = Image.open(image_path)
        plt.imshow(img)
        plt.title(f"{video_id} @ {timestamp}s")
        plt.axis("off")
        plt.show()
    else:
        print("Image not found.")

# Large Vision and Language Models

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and processor
model_id = "llava-hf/llava-1.5-7b-hf"

processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float32,        
    low_cpu_mem_usage=True
).to(device)

In [None]:
def ask_llava(image_path, question, max_tokens=64):
    image = Image.open(image_path).convert("RGB")
    prompt = f"<|user|>\n<image>\n{question}<|end|>\n<|assistant|>"
    
    # Prepare inputs for CPU and float32
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_tokens)

    response = processor.decode(output_ids[0], skip_special_tokens=True)
    return response


In [None]:
# Ask LLaVA about one example key frame
question = "What is the person doing in this frame?"
frame_path = "./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg"  # output from OpenSearch + CLIP retrieval

response = ask_llava(frame_path, question)
print(response)

# CLIP Interpretability

### Language-Vision temporal similarity

In [None]:


# Score one example key frame against all captions of its video with CLIP.
# NOTE(review): the example frame and its video id are taken from the other
# example cells - confirm that video is part of final_dataset_captions.
example_frame = "./keyframes/v_2ji02dSx1nM/frame_18.71_68.33_62.0621.jpg"
example_captions = final_dataset_captions["v_2ji02dSx1nM"]['segments']['sentences']

# Load the CLIP processor under its own name so this cell does not depend on
# which processor the shared `processor` variable currently holds.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

with Image.open(example_frame) as frame_file:
    inputs = clip_processor(
        text=example_captions,
        images=frame_file.convert("RGB"),
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(clip_model.device)

with torch.no_grad():
    outputs = clip_model(**inputs)
logits_per_image = outputs.logits_per_image # this is the image-text similarity score
probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
probs

In [None]:
def load_frames_from_dir(video_id, frame_dir="./keyframes"):
    """Load the saved key frames of one video, ordered by frame timestamp.

    The timestamp is read from the last underscore-separated field of the file
    name. This covers the ``frame_<start>_<end>_<timestamp>.jpg`` names written
    by the key-frame extraction as well as plain ``frame_<n>.jpg`` names.
    Files whose last field is not a number are skipped.

    Args:
        video_id: Name of the video's sub-folder inside `frame_dir`.
        frame_dir: Root folder that holds one sub-folder per video.

    Returns:
        A list of ``(timestamp_in_seconds, PIL image in RGB)`` tuples sorted by
        timestamp; empty when no frame is found.
    """
    frame_path = os.path.join(frame_dir, video_id)

    frames = []
    for f in glob.glob(os.path.join(frame_path, "frame_*.jpg")):
        stem = os.path.splitext(os.path.basename(f))[0]
        try:
            ts = float(stem.rsplit("_", 1)[-1])
        except ValueError:
            continue
        # convert() returns an independent image, so the file can be closed right away.
        with Image.open(f) as raw:
            img = raw.convert("RGB")
        frames.append((ts, img))

    # Sort numerically: a lexicographic sort of file names is not chronological.
    frames.sort(key=lambda item: item[0])
    return frames

def compute_clip_similarity(image, caption, model, processor, device=None):
    """Return the cosine similarity between an image and a caption under CLIP.

    Args:
        image: Image accepted by `processor` (e.g. a PIL image).
        caption: Caption text; truncated to the tokenizer's maximum length.
        model: CLIP model providing `get_image_features` / `get_text_features`.
        processor: Processor matching `model`.
        device: Device for the inputs. When None (default) the device of the
            model's parameters is used, so the call also works on CPU-only
            machines when the argument is omitted.

    Returns:
        The similarity as a Python float.
    """
    if device is None:
        device = next(model.parameters()).device

    inputs = processor(
        text=[caption], images=image, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        img_feat = model.get_image_features(pixel_values=inputs["pixel_values"])
        txt_feat = model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    # Unit-length features turn the dot product into a cosine similarity.
    img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
    txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)
    similarity = (img_feat @ txt_feat.T).item()
    return similarity

def plot_similarity_curves(video_id, captions, clip_model, processor, frame_dir="./keyframes"):
    """Plot, for every caption, its CLIP similarity to each key frame of a video.

    One curve is drawn per caption; the x axis holds the values returned with
    the frames by `load_frames_from_dir`, the y axis the value returned by
    `compute_clip_similarity`.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    frames = load_frames_from_dir(video_id, frame_dir)
    times = [frame_ts for frame_ts, _ in frames]

    plt.figure(figsize=(10, 5))
    for idx, text in enumerate(captions):
        curve = [
            compute_clip_similarity(frame_img, text, clip_model, processor, device)
            for _, frame_img in frames
        ]
        plt.plot(times, curve, label=f"Caption {idx}")

    plt.xlabel("Time (s)")
    plt.ylabel("CLIP Similarity")
    plt.title(f"CLIP Similarity Over Time - {video_id}")
    plt.legend()
    plt.grid(True)
    plt.show()

# Example usage: similarity curves for all captions of one video
# NOTE(review): `processor` must be the CLIP processor here - confirm it has not
# been rebound by another cell before this one runs.
video_id = "v_94wjthSzsSQ"  # replace with your video ID
captions = final_dataset_captions[video_id]['segments']['sentences']
plot_similarity_curves(video_id, captions, clip_model, processor)

In [None]:
class _ClipCaptionSimilarity(torch.nn.Module):
    """Expose the CLIP image-to-caption cosine similarity as a (batch, 1) output for Grad-CAM."""

    def __init__(self, clip, text_features):
        super().__init__()
        self.clip = clip
        self.text_features = text_features

    def forward(self, pixel_values):
        img_feat = self.clip.get_image_features(pixel_values=pixel_values)
        img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
        return img_feat @ self.text_features.T


def _vit_tokens_to_grid(tokens):
    """Reshape ViT tokens (batch, 1 + side*side, hidden) to (batch, hidden, side, side), dropping the class token."""
    patches = tokens[:, 1:, :]
    side = math.isqrt(patches.shape[1])
    grid = patches.reshape(patches.shape[0], side, side, patches.shape[2])
    return grid.permute(0, 3, 1, 2)


def generate_gradcam_for_caption(image, caption, model, processor, device="cuda"):
    """Overlay a Grad-CAM heat map for a caption on the image as seen by CLIP.

    Grad-CAM needs a module that returns a tensor and an integer target index.
    The CLIP model is therefore wrapped so that its single output is the
    image-caption cosine similarity, and output index 0 is explained. The
    activations come from the first layer norm of the last vision encoder
    layer, reshaped from a token sequence into a 2-D grid.

    Args:
        image: Image accepted by `processor` (e.g. a PIL image).
        caption: Caption whose similarity to the image is explained.
        model: CLIP model; it must be on `device`.
        processor: CLIP processor matching `model`.
            NOTE(review): assumes it exposes `image_processor` with
            `image_mean` / `image_std` - confirm for the installed version.
        device: Device for the inputs.

    Returns:
        An RGB uint8 array with the heat map blended over the pre-processed
        (resized and cropped) image, i.e. at the model's input resolution.
    """
    inputs = processor(text=[caption], images=image, return_tensors="pt", padding=True).to(device)
    pixel_values = inputs["pixel_values"]

    # The caption embedding is a constant for the explanation.
    with torch.no_grad():
        txt_feat = model.get_text_features(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    txt_feat = txt_feat / txt_feat.norm(dim=-1, keepdim=True)

    scorer = _ClipCaptionSimilarity(model, txt_feat)
    target_layers = [model.vision_model.encoder.layers[-1].layer_norm1]

    # The context manager removes the hooks Grad-CAM registers on the model.
    with GradCAM(model=scorer, target_layers=target_layers, reshape_transform=_vit_tokens_to_grid) as cam:
        grayscale_cam = cam(input_tensor=pixel_values, targets=[ClassifierOutputTarget(0)])[0]

    # Undo the processor's normalization so the overlay uses exactly the pixels
    # the model saw; this keeps the image and the heat map the same size.
    image_processor = processor.image_processor
    mean = torch.tensor(image_processor.image_mean).view(3, 1, 1)
    std = torch.tensor(image_processor.image_std).view(3, 1, 1)
    shown = (pixel_values[0].detach().cpu() * std + mean).clamp(0, 1)
    img_np = shown.permute(1, 2, 0).numpy()

    cam_image = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)
    return cam_image

def plot_attention_grid(video_id, captions, clip_model, processor, max_frames=10, frame_dir="./keyframes"):
    """Show a grid of Grad-CAM overlays: one row per caption, one column per key frame.

    Args:
        video_id: Video whose key frames are loaded with `load_frames_from_dir`.
        captions: Captions to explain, one grid row each.
        clip_model: CLIP model passed to `generate_gradcam_for_caption`.
        processor: CLIP processor matching `clip_model`.
        max_frames: Maximum number of frames (columns) to show.
        frame_dir: Root folder that holds one sub-folder per video.
    """
    # Use the device the model actually lives on.
    device = next(clip_model.parameters()).device
    frames = load_frames_from_dir(video_id, frame_dir)[:max_frames]

    if not frames or not captions:
        print(f"Nothing to plot for {video_id}: {len(frames)} frames, {len(captions)} captions.")
        return

    # squeeze=False keeps `axes` two-dimensional even with a single row or column.
    fig, axes = plt.subplots(
        len(captions), len(frames),
        figsize=(2 * len(frames), 2.5 * len(captions)),
        squeeze=False,
    )
    for i, caption in enumerate(captions):
        for j, (ts, img) in enumerate(frames):
            heatmap_img = generate_gradcam_for_caption(img, caption, clip_model, processor, device)
            ax = axes[i, j]
            ax.imshow(heatmap_img)
            # Hide ticks and grid but keep the axis itself, so the row label stays visible.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(False)
            if i == 0:
                ax.set_title(f"{ts}s", fontsize=8)
            if j == 0:
                ax.set_ylabel(f"Caption {i}", fontsize=8)
    plt.suptitle(f"Grad-CAM Attention Maps for {video_id}", fontsize=14)
    plt.tight_layout()
    plt.show()

# Example usage: reuses `video_id` and `captions` defined in an earlier cell
plot_attention_grid(video_id, captions, clip_model, processor)

### Language-Vision contrastive moments

In [None]:
def compute_contrastive_similarity_matrix(video_id, captions, clip_model, processor, frame_dir="./keyframes"):
    """Plot the caption-by-key-frame CLIP similarity matrix of a video.

    The first ``len(captions)`` frames returned by `load_frames_from_dir` are
    compared with every caption; cell (i, j) holds the similarity between
    caption i and key frame j.

    Args:
        video_id: Video whose key frames are loaded.
        captions: Caption texts (matrix rows).
        clip_model: CLIP model used for the similarity.
        processor: CLIP processor matching `clip_model`.
        frame_dir: Root folder that holds one sub-folder per video.

    Returns:
        None. A message is printed instead of a plot when there are fewer key
        frames than captions.
    """
    keyframes = load_frames_from_dir(video_id, frame_dir)
    if len(keyframes) < len(captions):
        print("Not enough keyframes for contrastive matrix!")
        return

    # Pass the device explicitly, like the other plotting helpers, instead of
    # depending on the similarity helper's default.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    M = len(captions)
    sim_matrix = np.zeros((M, M))
    for i in range(M):  # captions
        for j in range(M):  # keyframes
            sim_matrix[i, j] = compute_clip_similarity(keyframes[j][1], captions[i], clip_model, processor, device)

    plt.figure(figsize=(6, 5))
    im = plt.imshow(sim_matrix, cmap='viridis', interpolation='nearest')
    plt.title(f"CLIP Contrastive Similarity - {video_id}")
    plt.xlabel("Keyframe Index")
    plt.ylabel("Caption Index")
    plt.colorbar(im, label="Similarity")

    # Annotate each cell
    for i in range(M):
        for j in range(M):
            plt.text(j, i, f"{sim_matrix[i, j]:.2f}", ha='center', va='center', color='w', fontsize=8)

    plt.xticks(range(M))
    plt.yticks(range(M))
    plt.tight_layout()
    plt.show()

# Example usage
# Pass the key-frame root folder: load_frames_from_dir appends the video id
# itself, so a path that already ends in the video id would point to a
# non-existent folder and no frames would be found.
compute_contrastive_similarity_matrix(video_id, captions, clip_model, processor, "./keyframes")