In [None]:
# Multimodal Video and Audio Analysis Pipeline
This Jupyter Notebook demonstrates a complete pipeline for analyzing video frames and their corresponding audio segments with the Gemma 3n multimodal model (`google/gemma-3n-E2B-it`).

# Setup and Model Loading
This cell imports all necessary libraries and loads the pre-trained Gemma 3n model and its processor from Hugging Face. It also handles authentication using a Hugging Face token, which is assumed to be stored in Colab secrets under the name `HF_TOKEN`.

In [None]:
# Import necessary libraries
import os
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
from google.colab import userdata  # For accessing Colab secrets

# Define the model path
MODEL_PATH = "google/gemma-3n-E2B-it"

# Load the Hugging Face token from Colab secrets (secret name: 'HF_TOKEN')
HF_TOKEN = userdata.get('HF_TOKEN')
if not HF_TOKEN:
    # Fail early with a readable message instead of the TypeError that
    # os.environ would raise when assigned a non-string value.
    raise RuntimeError(
        "Hugging Face token not found: add a Colab secret named 'HF_TOKEN' "
        "and grant this notebook access to it."
    )

# Set the Hugging Face token as an environment variable
os.environ['HF_TOKEN'] = HF_TOKEN

# Pick the inference device: use the GPU when one is available and fall back
# to the CPU otherwise, so the cell does not crash on a CPU-only runtime.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and model
processor = AutoProcessor.from_pretrained(MODEL_PATH, token=HF_TOKEN)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    token=HF_TOKEN,
).eval().to(DEVICE)  # Put the model in eval mode and move it to the selected device

# Define System Configuration
Define key constants for the analysis pipeline, such as the target frames per second (FPS) for video sampling, the maximum number of frames to process, and the minimum duration for audio segments.

In [None]:
# Pipeline-wide tuning knobs

# Sampling rate, in frames per second, used when pulling frames from the video
TARGET_FPS = 3

# Upper bound on the number of frames processed per video
MAX_FRAMES = 30

# Shortest audio segment, in seconds, that is paired with a frame
MIN_AUDIO_DURATION_S = 1.0

# Helper Function: Extract Video Frames
This cell contains the `extract_frames_from_video` function. It uses the `av` library to open a video file, decode it, and save individual frames at the specified `TARGET_FPS` rate up to `MAX_FRAMES`.

In [None]:
# Helper Function: Extract Video Frames
import tempfile
import pathlib
import av

def extract_frames_from_video(video_path, target_fps=TARGET_FPS, max_frames=MAX_FRAMES):
    """
    Extract frames from a video file at the specified FPS rate up to a maximum number of frames.

    Frames are written as JPEG files into a freshly created temporary
    directory. The caller is responsible for deleting the returned files
    (and their directory) once they are no longer needed.

    Args:
        video_path (str): Path to the video file.
        target_fps (int | float): Target frames per second for sampling. Must be positive.
        max_frames (int | None): Maximum number of frames to extract, or None for no limit.

    Returns:
        tuple: ``(frame_paths, duration)`` where ``frame_paths`` is a list of
        ``(frame_file_path, timestamp_in_seconds)`` tuples and ``duration`` is
        the video duration in seconds (0.0 when the file does not report one).

    Raises:
        ValueError: If ``target_fps`` is not a positive number.
    """
    import shutil  # Local import: only needed to clean up after a failure

    if target_fps is None or target_fps <= 0:
        raise ValueError("target_fps must be a positive number")

    # Create a temporary directory to store extracted frames
    temp_dir = tempfile.mkdtemp(prefix="frames_")
    interval = 1.0 / target_fps
    half_window = interval / 2
    frame_paths = []

    try:
        # The context manager closes the container even if decoding fails
        with av.open(video_path) as container:
            video_stream = container.streams.video[0]  # Access the video stream
            time_base = video_stream.time_base

            # Some containers do not report a per-stream duration; fall back to
            # the container-level duration, which is expressed in av.time_base units.
            if video_stream.duration is not None:
                duration = float(video_stream.duration * time_base)
            elif container.duration is not None:
                duration = float(container.duration / av.time_base)
            else:
                duration = 0.0

            # Determine the total number of frames to extract
            total_frames = int(duration * target_fps)
            if max_frames is not None:
                total_frames = min(total_frames, max_frames)

            target_times = [i * interval for i in range(total_frames)]
            n_targets = len(target_times)
            target_index = 0

            # Decode video frames and save the ones closest to each target time
            for frame in container.decode(video=0):
                if frame.pts is None:
                    continue

                timestamp = float(frame.pts * time_base)

                # Skip targets the stream has already moved past (timestamp
                # gaps, low source frame rate, non-zero start time). Without
                # this, one missed target would stall extraction for the rest
                # of the video.
                while (target_index < n_targets
                       and timestamp - target_times[target_index] >= half_window):
                    target_index += 1

                if target_index >= n_targets:
                    break

                if abs(timestamp - target_times[target_index]) < half_window:
                    frame_path = pathlib.Path(temp_dir) / f"frame_{target_index:04d}.jpg"
                    frame_img = frame.to_image()
                    frame_img.save(frame_path)  # Save the frame as an image file
                    frame_paths.append((str(frame_path), target_index * interval))
                    target_index += 1

                    # Compare against the number of targets rather than
                    # max_frames, which may legitimately be None.
                    if target_index >= n_targets:
                        break
    except Exception:
        # Do not leave partially written frames behind when extraction fails
        shutil.rmtree(temp_dir, ignore_errors=True)
        raise

    return frame_paths, duration

# Helper Function: Extract Audio Segments
This cell defines the `extract_audio_segment` function. It uses the `pydub` library to load an audio file and extract a specific segment corresponding to a video frame's timestamp.

In [None]:
# Helper Function: Extract Audio Segments
from pydub import AudioSegment

def extract_audio_segment(audio_path, start_time, duration):
    """
    Extract a segment of audio from the given file.

    The segment is written to a new temporary WAV file. The caller is
    responsible for deleting that file when it is no longer needed.

    Args:
        audio_path (str): Path to the audio file.
        start_time (float): Start time of the segment in seconds.
        duration (float): Duration of the segment in seconds.

    Returns:
        str: Path to the temporary file containing the extracted audio segment.
    """
    # Load the audio file using pydub
    audio = AudioSegment.from_file(audio_path)

    # Extract the specified segment (pydub slices are expressed in milliseconds)
    segment = audio[start_time * 1000:(start_time + duration) * 1000]

    # Reserve a temporary file name and close the handle right away: keeping
    # it open would leak a file descriptor on every call and prevents
    # re-opening the file by name on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        segment_path = temp_file.name

    try:
        segment.export(segment_path, format="wav")
    except Exception:
        # Do not leave an orphaned temporary file behind if the export fails
        if os.path.exists(segment_path):
            os.unlink(segment_path)
        raise

    return segment_path

# Core Processing: Analyze Image and Audio
The `process_inputs` function takes a single image and an audio file path. It prepares the inputs in the format required by the multimodal model, generates a textual description, and returns the result.

In [None]:
# Core Processing: Analyze Image and Audio
def process_inputs(image, audio, prompt=None, max_new_tokens=256):
    """
    Process a single image and audio segment using the Gemma 3n model.

    Args:
        image (PIL.Image.Image): The image to be analyzed.
        audio (str): Path to the audio file to be analyzed.
        prompt (str | None): Optional text instruction appended after the
            image and audio. When None or empty, only the image and audio are
            sent, which matches the previous behavior.
        max_new_tokens (int): Maximum number of tokens to generate.

    Returns:
        str: The textual description generated by the model.
    """
    # Build the user turn: image and audio first, then the optional instruction
    content = [
        {"type": "image", "image": image},
        {"type": "audio", "audio": audio},
    ]
    if prompt:
        content.append({"type": "text", "text": prompt})

    messages = [{"role": "user", "content": content}]

    # Apply the processor to prepare inputs for the model
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    # Remember the prompt length so that only newly generated tokens are decoded
    input_len = inputs["input_ids"].shape[-1]

    # Move inputs to the model's device and set the appropriate data type
    inputs = inputs.to(model.device, dtype=model.dtype)

    # Perform inference using the model
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            disable_compile=True
        )

    # Decode only the generated continuation into text
    text = processor.batch_decode(
        outputs[:, input_len:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )
    return text[0]

# Main Pipeline: Process the Entire Video
The `process_video` function orchestrates the entire workflow. It calls the frame and audio extraction functions in a loop, passes each pair to the `process_inputs` function for analysis, and aggregates the results into a single report.

In [None]:
from PIL import Image
import os

def process_video(video_path, audio_path):
    """
    Process an entire video by extracting frames and corresponding audio segments,
    analyzing them using the Gemma 3n model, and aggregating the results.

    Errors are reported through the returned string rather than raised, so the
    result can always be shown in the UI. Temporary frame and audio files are
    removed whether or not processing succeeds.

    Args:
        video_path (str): Path to the video file.
        audio_path (str): Path to the audio file.

    Returns:
        str: Aggregated analysis results for the video, or an error message
        starting with "Error".
    """
    # Missing uploads arrive as None/empty values; report them clearly
    if not video_path:
        return "Error: No video file provided."
    if not audio_path:
        return "Error: No audio file provided."

    frame_paths = []
    try:
        # Extract frames from the video
        frame_paths, duration = extract_frames_from_video(video_path)

        if not frame_paths:
            return "Error: No frames extracted from video."

        # Segment duration is the same for every frame, so compute it once,
        # ensuring it is not shorter than the configured minimum
        inter_frame_duration = duration / len(frame_paths)
        segment_duration = max(inter_frame_duration, MIN_AUDIO_DURATION_S)

        all_results = []

        # Process each frame with its corresponding audio segment
        for frame_path, timestamp in frame_paths:
            # Extract audio for this segment
            audio_segment = extract_audio_segment(audio_path, timestamp, segment_duration)

            try:
                # The context manager releases the image file handle so the
                # frame file can be deleted afterwards
                with Image.open(frame_path) as image:
                    result = process_inputs(image, audio_segment)
            finally:
                # Clean up the temporary audio segment file even if inference fails
                if os.path.exists(audio_segment):
                    os.unlink(audio_segment)

            # Format the result with a timestamp
            time_str = f"[{timestamp:.1f}s - {timestamp + segment_duration:.1f}s]"
            all_results.append(f"{time_str}: {result}")

        # Return the aggregated results
        return "\n\n".join(all_results)

    except Exception as e:
        return f"Error processing video: {str(e)}"

    finally:
        # Clean up the temporary frame files on both success and failure.
        # Cleanup errors are ignored so they never replace the result above.
        frame_dirs = set()
        for path, _ in frame_paths:
            frame_dirs.add(os.path.dirname(path))
            try:
                if os.path.exists(path):
                    os.unlink(path)
            except OSError:
                pass

        # Remove the now-empty temporary frame directories; os.rmdir only
        # deletes empty directories, so unrelated files are never removed.
        for frame_dir in frame_dirs:
            try:
                os.rmdir(frame_dir)
            except OSError:
                pass

# Launch the Interactive Web Interface
This cell uses `gradio` to create a simple web UI. The interface allows a user to upload a video and an audio file, which are then processed by the `process_video` function. The final analysis is displayed in a textbox.

In [None]:
# Launch the Gradio interface for interactive video and audio analysis
import gradio as gr

# Build the web UI: two upload widgets feed process_video, and the returned
# report is rendered in a single textbox.
iface = gr.Interface(
    title="Video Stream Analysis with Audio",
    description=(
        "Upload a video file and its audio. "
        "The system processes the video frames and analyzes them with the Gemma 3n model."
    ),
    fn=process_video,
    inputs=[
        # Video upload
        gr.Video(label="Upload Video"),
        # Audio upload, handed to the callback as a file path
        gr.Audio(label="Upload Audio", type="filepath"),
    ],
    outputs=gr.Textbox(label="Analysis Results"),
)

# Start serving the interface
iface.launch()