In [None]:
# -*- coding: utf-8 -*-
"""gazelle-gaze-video.ipynb

Video gaze analysis using Gazelle model with memory management and data export.
"""

# Install required packages
# !pip install -q torch torchvision timm scikit-learn matplotlib pandas opencv-python tqdm pillow numpy mediapipe psutil
# !pip install -q retina-face

import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from retinaface import RetinaFace
from tqdm.notebook import tqdm
import os
import gc
import psutil
import json
import pandas as pd
from datetime import datetime
import math

In [None]:
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from retinaface import RetinaFace
from tqdm import tqdm
import os
import gc
import psutil
import json
import pandas as pd
from datetime import datetime

class VideoGazeAnalyzer:
    """Estimate and visualize where people in a video are looking.

    Each frame goes through RetinaFace (face boxes) and the Gazelle
    ``gazelle_dinov2_vitl14_inout`` model loaded from torch.hub (a gaze
    heatmap plus an in-frame score per face). Results can be drawn onto the
    frames and exported as JSON/CSV together with summary statistics.
    """

    # torch.hub location of the Gazelle checkpoint used by this analyzer.
    _HUB_REPO = 'fkryan/gazelle'
    _HUB_MODEL = 'gazelle_dinov2_vitl14_inout'

    # Gaze-related CSV columns; left empty for faces not looking in-frame.
    _GAZE_CSV_COLUMNS = (
        "gaze_target_x", "gaze_target_y",
        "gaze_vector_x", "gaze_vector_y",
        "gaze_distance", "gaze_angle",
        "heatmap_max", "heatmap_mean", "heatmap_std",
    )
    # Full column order of gaze_data.csv (one row per face, or one per face-less frame).
    _CSV_COLUMNS = (
        "frame_number", "timestamp", "face_id", "face_detected",
        "bbox_xmin", "bbox_ymin", "bbox_xmax", "bbox_ymax",
        "face_center_x", "face_center_y",
        "face_width", "face_height", "face_area",
        "inout_score", "looking_in_frame",
    ) + _GAZE_CSV_COLUMNS

    def __init__(self, use_cuda=True):
        """Load Gazelle onto CUDA when requested and available, otherwise the CPU."""
        self.device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # torch.hub returns the model together with its matching input transform.
        self.model, self.transform = torch.hub.load(self._HUB_REPO, self._HUB_MODEL)
        self.model.eval()
        self.model.to(self.device)

        # Colors for visualization, cycled by face index.
        self.colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']

    def clear_memory(self):
        """
        Release cached GPU/MPS memory and report memory usage.

        Garbage collection runs first so tensors kept alive only by reference
        cycles are freed before the caching allocator is asked to return
        unused blocks. Call this between processing different videos.
        """
        print("Clearing memory...")

        # empty_cache() only releases blocks that no live tensor occupies,
        # so unreachable tensors must be collected before it runs.
        gc.collect()

        if torch.cuda.is_available():
            # Let queued kernels finish before their buffers are released.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

            allocated = torch.cuda.memory_allocated() / 1024**3  # GB
            cached = torch.cuda.memory_reserved() / 1024**3  # GB
            print(f"GPU Memory - Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")

        # Apple Silicon; the hasattr(torch, 'mps') guard covers older PyTorch releases.
        if (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
                and hasattr(torch, 'mps')):
            torch.mps.empty_cache()
            print("MPS cache cleared")

        memory = psutil.virtual_memory()
        print(f"System Memory - Used: {memory.used / 1024**3:.2f} GB / {memory.total / 1024**3:.2f} GB ({memory.percent:.1f}%)")
        print("Memory cleanup complete!\n")

    def cleanup_completely(self):
        """
        Move the model to the CPU and release accelerator memory.

        The model object itself is kept, so ``reinitialize_model`` can move
        it back to the accelerator without reloading it.
        """
        print("Performing complete cleanup...")

        if getattr(self, 'model', None) is not None:
            self.model.cpu()
            print("Model moved to CPU")

        self.clear_memory()

        print("Complete cleanup finished!")

    def reinitialize_model(self, use_cuda=True):
        """
        Put the model back on the preferred device (useful after complete cleanup).

        The model is reloaded from torch.hub only if it is missing or None.
        """
        print("Reinitializing model...")

        self.device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        if getattr(self, 'model', None) is None:
            self.model, self.transform = torch.hub.load(self._HUB_REPO, self._HUB_MODEL)

        self.model.eval()
        self.model.to(self.device)

        print("Model reinitialized!")

    @staticmethod
    def _gaze_target_from_heatmap(heatmap_np, width, height):
        """Return ``(x, y, (row, col))``: the heatmap argmax scaled to image pixels."""
        row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
        target_x = col / heatmap_np.shape[1] * width
        target_y = row / heatmap_np.shape[0] * height
        return target_x, target_y, (row, col)

    def _describe_face(self, face_id, bbox, width, height, heatmaps, inout_scores, inout_thresh):
        """Build the exported record for one face from its pixel bbox and model outputs."""
        xmin, ymin, xmax, ymax = bbox
        center_x = float((xmin + xmax) / 2)
        center_y = float((ymin + ymax) / 2)

        face_data = {
            "face_id": face_id,
            "bbox_raw": {"xmin": int(xmin), "ymin": int(ymin), "xmax": int(xmax), "ymax": int(ymax)},
            "bbox_normalized": {"xmin": float(xmin / width), "ymin": float(ymin / height),
                                "xmax": float(xmax / width), "ymax": float(ymax / height)},
            "face_center": {"x": center_x, "y": center_y},
            "face_size": {
                "width": float(xmax - xmin),
                "height": float(ymax - ymin),
                "area": float((xmax - xmin) * (ymax - ymin)),
            },
            # Defaults for faces without a (positive) in-frame prediction.
            "inout_score": None,
            "looking_in_frame": False,
            "gaze_target": None,
            "gaze_vector": None,
            "heatmap_stats": None,
        }

        if inout_scores is None or face_id >= len(inout_scores):
            return face_data

        inout_score = float(inout_scores[face_id])
        face_data["inout_score"] = inout_score
        face_data["looking_in_frame"] = inout_score > inout_thresh
        if not face_data["looking_in_frame"]:
            return face_data

        heatmap_np = heatmaps[face_id].detach().cpu().numpy()
        target_x, target_y, (row, col) = self._gaze_target_from_heatmap(heatmap_np, width, height)
        vector_x = target_x - center_x
        vector_y = target_y - center_y

        face_data["gaze_target"] = {"x": float(target_x), "y": float(target_y)}
        face_data["gaze_vector"] = {
            "x": float(vector_x),
            "y": float(vector_y),
            "magnitude": float(np.hypot(vector_x, vector_y)),
            # Image coordinates: y grows downward, so positive angles point down.
            "angle_degrees": float(np.degrees(np.arctan2(vector_y, vector_x))),
        }
        face_data["heatmap_stats"] = {
            "max_value": float(np.max(heatmap_np)),
            "mean_value": float(np.mean(heatmap_np)),
            "std_value": float(np.std(heatmap_np)),
            "max_position": {"row": int(row), "col": int(col)},
        }
        return face_data

    def process_frame_with_data(self, frame, frame_number, timestamp, inout_thresh=0.5):
        """
        Detect faces in one frame, run Gazelle, and return the visualization and data.

        Args:
            frame: BGR image as returned by ``cv2.VideoCapture.read``.
            frame_number: Index of the frame in the source video.
            timestamp: Position of the frame in seconds.
            inout_thresh: In-frame score above which a face counts as looking
                at something inside the frame (also used for drawing).

        Returns:
            ``(bgr_frame, frame_data)``. When no face is detected, the original
            frame is returned unchanged and ``frame_data["faces"]`` is empty.
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame_rgb)
        width, height = image.size

        frame_data = {
            "frame_number": frame_number,
            "timestamp": timestamp,
            "video_dimensions": {"width": width, "height": height},
            "faces": [],
        }

        # RetinaFace treats numpy input like cv2.imread output (BGR) and swaps
        # channels itself, so it must get the original BGR frame.
        resp = RetinaFace.detect_faces(frame)
        if not isinstance(resp, dict):
            return frame, frame_data  # No faces detected

        bboxes = [face['facial_area'] for face in resp.values()]
        scale = np.array([width, height, width, height])
        # Gazelle expects one list of normalized [xmin, ymin, xmax, ymax] boxes per image.
        norm_bboxes = [[np.array(bbox) / scale for bbox in bboxes]]

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            output = self.model({"images": img_tensor, "bboxes": norm_bboxes})

        heatmaps = output['heatmap'][0]
        inout_scores = output['inout'][0] if output['inout'] is not None else None

        frame_data["faces"] = [
            self._describe_face(face_id, bbox, width, height, heatmaps, inout_scores, inout_thresh)
            for face_id, bbox in enumerate(bboxes)
        ]

        result_image = self.visualize_all(image, heatmaps, norm_bboxes[0], inout_scores, inout_thresh)

        # Convert back to BGR for OpenCV
        return cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR), frame_data

    def process_frame(self, frame):
        """Process a single BGR frame and return only the visualization (backward compatibility)."""
        processed_frame, _ = self.process_frame_with_data(frame, 0, 0.0)
        return processed_frame

    def visualize_all(self, pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
        """
        Draw face boxes, in-frame scores and gaze lines onto a copy of ``pil_image``.

        Args:
            pil_image: RGB ``PIL.Image`` of the frame.
            heatmaps: Per-face gaze heatmap tensors, indexed like ``bboxes``.
            bboxes: Face boxes as normalized ``[xmin, ymin, xmax, ymax]``.
            inout_scores: Per-face in-frame scores, or None to draw boxes only.
            inout_thresh: Score above which the gaze point and line are drawn.

        Returns:
            A new RGB ``PIL.Image``.
        """
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size
        # Stroke widths scale with the frame but never drop to 0 on small frames.
        box_width = max(1, int(min(width, height) * 0.01))
        line_width = max(1, int(0.005 * min(width, height)))

        for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
            color = self.colors[i % len(self.colors)]

            draw.rectangle(
                [xmin * width, ymin * height, xmax * width, ymax * height],
                outline=color,
                width=box_width
            )

            if inout_scores is None:
                continue

            inout_score = inout_scores[i]
            draw.text(
                (xmin * width, ymax * height + int(height * 0.01)),
                f"in-frame: {inout_score:.2f}",
                fill=color,
                font=None  # Using default font
            )

            if inout_score > inout_thresh:
                heatmap_np = heatmaps[i].detach().cpu().numpy()
                gaze_x, gaze_y, _ = self._gaze_target_from_heatmap(heatmap_np, width, height)
                center_x = ((xmin + xmax) / 2) * width
                center_y = ((ymin + ymax) / 2) * height

                draw.ellipse(
                    [(gaze_x - 5, gaze_y - 5), (gaze_x + 5, gaze_y + 5)],
                    fill=color,
                    width=line_width
                )
                draw.line(
                    [(center_x, center_y), (gaze_x, gaze_y)],
                    fill=color,
                    width=line_width
                )

        # Convert to RGB for OpenCV compatibility
        return overlay_image.convert('RGB')

    def _flatten_frame_data(self, all_frame_data):
        """Turn per-frame records into CSV rows: one per face, or one per face-less frame."""
        rows = []
        for frame_data in all_frame_data:
            base = {"frame_number": frame_data["frame_number"], "timestamp": frame_data["timestamp"]}

            if not frame_data["faces"]:
                row = dict.fromkeys(self._CSV_COLUMNS)
                row.update(base, face_detected=False, looking_in_frame=False)
                rows.append(row)
                continue

            for face in frame_data["faces"]:
                bbox, center, size = face["bbox_raw"], face["face_center"], face["face_size"]
                row = dict.fromkeys(self._CSV_COLUMNS)
                row.update(base)
                row.update({
                    "face_id": face["face_id"],
                    "face_detected": True,
                    "bbox_xmin": bbox["xmin"],
                    "bbox_ymin": bbox["ymin"],
                    "bbox_xmax": bbox["xmax"],
                    "bbox_ymax": bbox["ymax"],
                    "face_center_x": center["x"],
                    "face_center_y": center["y"],
                    "face_width": size["width"],
                    "face_height": size["height"],
                    "face_area": size["area"],
                    "inout_score": face["inout_score"],
                    "looking_in_frame": face["looking_in_frame"],
                })
                if face["gaze_target"]:
                    vector, stats = face["gaze_vector"], face["heatmap_stats"]
                    row.update({
                        "gaze_target_x": face["gaze_target"]["x"],
                        "gaze_target_y": face["gaze_target"]["y"],
                        "gaze_vector_x": vector["x"],
                        "gaze_vector_y": vector["y"],
                        "gaze_distance": vector["magnitude"],
                        "gaze_angle": vector["angle_degrees"],
                        "heatmap_max": stats["max_value"],
                        "heatmap_mean": stats["mean_value"],
                        "heatmap_std": stats["std_value"],
                    })
                rows.append(row)
        return rows

    def save_frame_data(self, output_folder, video_metadata, all_frame_data):
        """
        Write gaze_data.json, gaze_data.csv and summary_stats.json into ``output_folder``.

        Args:
            output_folder: Directory to create (if needed) and write into.
            video_metadata: JSON-serializable dict describing the source video.
            all_frame_data: List of per-frame dicts from ``process_frame_with_data``.
        """
        os.makedirs(output_folder, exist_ok=True)

        json_path = os.path.join(output_folder, "gaze_data.json")
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump({"video_metadata": video_metadata, "frames": all_frame_data}, f, indent=2)

        # Explicit columns keep the CSV header and the summary valid even with no frames.
        df = pd.DataFrame(self._flatten_frame_data(all_frame_data), columns=list(self._CSV_COLUMNS))
        df.to_csv(os.path.join(output_folder, "gaze_data.csv"), index=False)

        summary = self.generate_summary_stats(df, video_metadata)
        with open(os.path.join(output_folder, "summary_stats.json"), 'w', encoding='utf-8') as f:
            json.dump(summary, f, indent=2)

        print(f"Data saved to {output_folder}:")
        print(f"  - Complete data: gaze_data.json")
        print(f"  - Tabular data: gaze_data.csv ({len(df)} rows)")
        print(f"  - Summary: summary_stats.json")

    @staticmethod
    def _json_float(value):
        """Convert a pandas/numpy scalar to a JSON-safe float; None/NaN become None."""
        if value is None or pd.isna(value):
            return None
        return float(value)

    def generate_summary_stats(self, df, video_metadata):
        """
        Summarize the flattened per-face table produced by ``save_frame_data``.

        Frame counts are distinct ``frame_number`` values, because a frame with
        several faces contributes several rows. All values are plain Python
        types, so the result is JSON-serializable. NaN becomes None, e.g. the
        std of a single sample.
        """
        face_mask = df["face_detected"] == True  # noqa: E712 - element-wise comparison
        face_df = df[face_mask]
        total_frames = int(df["frame_number"].nunique())
        frames_with_faces = int(face_df["frame_number"].nunique())

        summary = {
            "video_info": video_metadata,
            "processing_timestamp": datetime.now().isoformat(),
            "total_frames": total_frames,
            "frames_with_faces": frames_with_faces,
            "frames_without_faces": total_frames - frames_with_faces,
            "total_faces_detected": int(face_mask.sum()),
        }

        if frames_with_faces > 0:
            looking_mask = face_df["looking_in_frame"] == True  # noqa: E712
            gaze_df = face_df[looking_mask]
            summary.update({
                "face_detection_rate": frames_with_faces / total_frames,
                # Averaged over frames that contain at least one face.
                "avg_faces_per_frame": float(face_df.groupby("frame_number").size().mean()),
                "frames_looking_in": int(gaze_df["frame_number"].nunique()),
                # Fraction of detected faces (not frames) looking in-frame.
                "in_frame_gaze_rate": float(looking_mask.mean()),
                "avg_inout_score": (self._json_float(face_df["inout_score"].mean())
                                    if face_df["inout_score"].notna().any() else None),
                "face_size_stats": {
                    "avg_width": self._json_float(face_df["face_width"].mean()),
                    "avg_height": self._json_float(face_df["face_height"].mean()),
                    "avg_area": self._json_float(face_df["face_area"].mean()),
                },
            })

            if len(gaze_df) > 0:
                summary["gaze_stats"] = {
                    "avg_gaze_distance": self._json_float(gaze_df["gaze_distance"].mean()),
                    "gaze_angle_stats": {
                        "mean": self._json_float(gaze_df["gaze_angle"].mean()),
                        "std": self._json_float(gaze_df["gaze_angle"].std()),
                    },
                    "heatmap_stats": {
                        "avg_max_value": self._json_float(gaze_df["heatmap_max"].mean()),
                        "avg_mean_value": self._json_float(gaze_df["heatmap_mean"].mean()),
                    },
                }

        return summary

    @staticmethod
    def _open_video(input_path):
        """
        Open ``input_path`` and return ``(cap, fps, width, height, frame_count)``.

        Raises:
            OSError: OpenCV could not open the file.
            ValueError: The container reports a non-positive frame rate, so
                timestamps and frame offsets would be undefined.
        """
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            cap.release()
            raise OSError(f"Could not open video: {input_path}")

        # Keep the exact (possibly fractional) rate; int() would turn 29.97 into 29.
        fps = float(cap.get(cv2.CAP_PROP_FPS))
        if fps <= 0:
            cap.release()
            raise ValueError(f"Video reports an invalid frame rate ({fps}): {input_path}")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        return cap, fps, width, height, total_frames

    @staticmethod
    def _compute_frame_range(fps, total_frames, start_time=0, duration=None):
        """
        Convert a (start_time, duration) window in seconds into ``(start_frame, end_frame)``.

        ``end_frame`` is exclusive and is clamped to ``total_frames`` when the
        container reports a positive frame count. ``duration=None`` means "to
        the end"; ``duration=0`` selects no frames.
        """
        start_frame = max(0, int(start_time * fps))
        if duration is not None:
            end_frame = start_frame + int(duration * fps)
        else:
            end_frame = total_frames
        if total_frames > 0:
            end_frame = min(end_frame, total_frames)
        return start_frame, max(end_frame, start_frame)

    def process_video_with_data_export(self, input_path, output_video_path=None,
                                       data_output_folder=None, start_time=0, duration=None):
        """
        Process a video, write the visualization video and export per-frame data.

        Args:
            input_path: Path of the source video.
            output_video_path: Annotated video path; defaults to
                ``<name>_gaze_analyzed.mp4``. Pass "" to skip video output.
            data_output_folder: Folder for JSON/CSV output; defaults to
                ``<name>_gaze_data``.
            start_time: Offset in seconds where processing starts.
            duration: Seconds to process, or None for the rest of the video.

        Returns:
            ``(data_output_folder, output_video_path)``.

        Raises:
            OSError / ValueError: see ``_open_video``.
        """
        video_name = os.path.splitext(os.path.basename(input_path))[0]
        if output_video_path is None:
            output_video_path = f"{video_name}_gaze_analyzed.mp4"
        if data_output_folder is None:
            data_output_folder = f"{video_name}_gaze_data"

        print(f"Processing video: {input_path}")
        print(f"Output video: {output_video_path}")
        print(f"Output data folder: {data_output_folder}")

        cap, fps, frame_width, frame_height, total_frames = self._open_video(input_path)
        start_frame, end_frame = self._compute_frame_range(fps, total_frames, start_time, duration)

        video_metadata = {
            "input_file": input_path,
            "output_video": output_video_path,
            "output_data_folder": data_output_folder,
            "fps": fps,
            "frame_width": frame_width,
            "frame_height": frame_height,
            "total_frames": total_frames,
            "start_frame": start_frame,
            "end_frame": end_frame,
            "start_time": start_time,
            "duration": duration,
            "processing_device": self.device
        }

        out = None
        all_frame_data = []
        try:
            if output_video_path:
                out = cv2.VideoWriter(
                    output_video_path,
                    cv2.VideoWriter_fourcc(*'mp4v'),
                    fps,
                    (frame_width, frame_height)
                )

            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            with tqdm(total=end_frame - start_frame, desc="Processing frames") as pbar:
                for frame_count in range(start_frame, end_frame):
                    ret, frame = cap.read()
                    if not ret:
                        break

                    processed_frame, frame_data = self.process_frame_with_data(
                        frame, frame_count, frame_count / fps
                    )
                    all_frame_data.append(frame_data)

                    if out is not None:
                        out.write(processed_frame)

                    pbar.update(1)

        finally:
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

            # Partial results are still saved if processing stopped early.
            self.save_frame_data(data_output_folder, video_metadata, all_frame_data)

            self.clear_memory()

        print(f"\nProcessing complete!")
        print(f"Video output: {output_video_path}")
        print(f"Data output: {data_output_folder}")
        return data_output_folder, output_video_path

    def process_video(self, input_path, output_path, start_time=0, duration=None):
        """
        Write an annotated copy of a video without exporting data (backward compatibility).

        Args:
            input_path: Path of the source video.
            output_path: Path of the annotated mp4 to write.
            start_time: Offset in seconds where processing starts.
            duration: Seconds to process, or None for the rest of the video.

        Raises:
            OSError / ValueError: see ``_open_video``.
        """
        cap, fps, frame_width, frame_height, total_frames = self._open_video(input_path)
        start_frame, end_frame = self._compute_frame_range(fps, total_frames, start_time, duration)

        out = None
        try:
            out = cv2.VideoWriter(
                output_path,
                cv2.VideoWriter_fourcc(*'mp4v'),
                fps,
                (frame_width, frame_height)
            )
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            with tqdm(total=end_frame - start_frame) as pbar:
                for _ in range(start_frame, end_frame):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    out.write(self.process_frame(frame))
                    pbar.update(1)

        finally:
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

            self.clear_memory()

# Standalone Memory Management Functions
def clear_gpu_memory():
    """
    Release cached GPU/MPS memory and report memory usage.

    Garbage collection runs first so tensors kept alive only by reference
    cycles are freed before the caching allocator is asked to return unused
    blocks. Call this between processing different videos.
    """
    print("Clearing GPU memory...")

    # empty_cache() only releases blocks no live tensor occupies, so collect first.
    gc.collect()

    if torch.cuda.is_available():
        # Let queued kernels finish before their buffers are released.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        cached = torch.cuda.memory_reserved() / 1024**3  # GB
        print(f"GPU Memory - Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")

    # Apple Silicon; the hasattr(torch, 'mps') guard covers older PyTorch releases.
    if (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
            and hasattr(torch, 'mps')):
        torch.mps.empty_cache()
        print("MPS cache cleared")

    memory = psutil.virtual_memory()
    print(f"System Memory - Used: {memory.used / 1024**3:.2f} GB / {memory.total / 1024**3:.2f} GB ({memory.percent:.1f}%)")
    print("Memory cleanup complete!\n")

def cleanup_model_and_memory(model=None, analyzer=None):
    """
    Move a model (or an analyzer's model) off the accelerator and clear caches.

    Use this when switching to a completely different model or finishing work.
    Only one target is handled: if ``analyzer`` is given, ``model`` is ignored.

    Args:
        model: PyTorch module to move to the CPU. The caller's reference keeps
            it alive; drop that reference yourself to free host memory too.
        analyzer: VideoGazeAnalyzer whose ``cleanup_completely`` is called.
    """
    print("Performing complete cleanup...")

    if analyzer is not None:
        analyzer.cleanup_completely()
    elif model is not None:
        # Moving parameters to the CPU is what frees accelerator memory;
        # `del` on this parameter would only unbind the local name.
        model.cpu()
        clear_gpu_memory()
    else:
        clear_gpu_memory()

    print("Complete cleanup finished!")

In [None]:
import gc
import psutil

# Standalone Memory Management Functions
# Use these if you're not using the VideoGazeAnalyzer class

def clear_gpu_memory():
    """
    Release cached GPU/MPS memory and report memory usage.

    Garbage collection runs first so tensors kept alive only by reference
    cycles are freed before the caching allocator is asked to return unused
    blocks. Call this between processing different videos.
    """
    print("Clearing GPU memory...")

    # empty_cache() only releases blocks no live tensor occupies, so collect first.
    gc.collect()

    if torch.cuda.is_available():
        # Let queued kernels finish before their buffers are released.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

        allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        cached = torch.cuda.memory_reserved() / 1024**3  # GB
        print(f"GPU Memory - Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")

    # Apple Silicon; the hasattr(torch, 'mps') guard covers older PyTorch releases.
    if (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
            and hasattr(torch, 'mps')):
        torch.mps.empty_cache()
        print("MPS cache cleared")

    memory = psutil.virtual_memory()
    print(f"System Memory - Used: {memory.used / 1024**3:.2f} GB / {memory.total / 1024**3:.2f} GB ({memory.percent:.1f}%)")
    print("Memory cleanup complete!\n")

def cleanup_model_and_memory(model=None, analyzer=None):
    """
    Move a model (or an analyzer's model) off the accelerator and clear caches.

    Use this when switching to a completely different model or finishing work.
    Only one target is handled: if ``analyzer`` is given, ``model`` is ignored.

    Args:
        model: PyTorch module to move to the CPU. The caller's reference keeps
            it alive; drop that reference yourself to free host memory too.
        analyzer: VideoGazeAnalyzer whose ``cleanup_completely`` is called.
    """
    print("Performing complete cleanup...")

    if analyzer is not None:
        analyzer.cleanup_completely()
    elif model is not None:
        # Moving parameters to the CPU is what frees accelerator memory;
        # `del` on this parameter would only unbind the local name.
        model.cpu()
        clear_gpu_memory()
    else:
        clear_gpu_memory()

    print("Complete cleanup finished!")

class VideoGazeAnalyzer:
    """Estimate where people in video frames are looking, using Gazelle.

    Faces are found with RetinaFace. The Gazelle ``gazelle_dinov2_vitl14_inout``
    model (loaded from torch.hub) predicts, for every face box, a gaze heatmap
    and an in/out-of-frame score, and the predictions are drawn onto the frame.
    """

    # Accepted spellings of each RetinaFace landmark key, preferred spelling first.
    LANDMARK_KEY_ALIASES = {
        "left_eye": ("left_eye", "eye_left"),
        "right_eye": ("right_eye", "eye_right"),
        "nose": ("nose",),
        "mouth_left": ("mouth_left", "left_mouth"),
        "mouth_right": ("mouth_right", "right_mouth"),
    }

    def __init__(self, use_cuda=True):
        self.device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
        print(f"Using device: {self.device}")

        # Load Gazelle model
        self.model, self.transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')
        self.model.eval()
        self.model.to(self.device)

        # Colors for visualization
        self.colors = ['lime', 'tomato', 'cyan', 'fuchsia', 'yellow']

        # The landmark-structure debug message is printed once per analyzer.
        self._landmark_keys_reported = False

    def _run_gazelle(self, frame):
        """Detect faces in an OpenCV BGR frame and run Gazelle on them.

        Returns:
            (image, faces, norm_bboxes, output), where:
            - ``image`` is the frame as an RGB PIL image.
            - ``faces`` is the list of RetinaFace detections (dicts with
              'facial_area' and 'landmarks').
            - ``norm_bboxes`` holds each face box divided by
              [width, height, width, height].
            - ``output`` is the raw Gazelle output dict.
            When no face is detected, ``faces`` is empty and the last two
            items are None.
        """
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame_rgb)
        width, height = image.size

        resp = RetinaFace.detect_faces(frame_rgb)
        # Faces are reported as a dict keyed by face id; anything else
        # (or an empty dict) means nothing was detected.
        if not isinstance(resp, dict) or not resp:
            return image, [], None, None

        faces = list(resp.values())
        scale = np.array([width, height, width, height])
        norm_bboxes = [np.array(face['facial_area']) / scale for face in faces]

        img_tensor = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            # Gazelle takes a batch of images plus, per image, a list of face boxes.
            output = self.model({"images": img_tensor, "bboxes": [norm_bboxes]})
        return image, faces, norm_bboxes, output

    def _render(self, image, norm_bboxes, output, inout_thresh=0.5):
        """Draw Gazelle predictions on ``image`` and return an OpenCV BGR array."""
        inout = output['inout']
        result_image = self.visualize_all(
            image,
            output['heatmap'][0],
            norm_bboxes,
            inout[0] if inout is not None else None,
            inout_thresh=inout_thresh,
        )
        return cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)

    @classmethod
    def _landmarks_to_dict(cls, landmark, fallback_xy):
        """Map a RetinaFace landmark dict to fixed names with integer [x, y] values.

        Each name in ``LANDMARK_KEY_ALIASES`` is looked up under its aliases in
        order. If none is present, a warning is printed and ``fallback_xy``
        (the face centre, for the caller in this class) is used instead.
        """
        result = {}
        for target_key, aliases in cls.LANDMARK_KEY_ALIASES.items():
            source_key = next((key for key in aliases if key in landmark), None)
            if source_key is None:
                print(f"Warning: Could not find a mapping for {target_key}")
                result[target_key] = [int(fallback_xy[0]), int(fallback_xy[1])]
            else:
                point = landmark[source_key]
                result[target_key] = [int(point[0]), int(point[1])]
        return result

    @staticmethod
    def _gaze_metrics(heatmap_np, width, height, center_x, center_y):
        """Derive gaze target, gaze vector and heatmap statistics for one face.

        Args:
            heatmap_np: Heatmap array; its first axis is scaled to image y and
                its second to image x (assumed 2-D).
            width, height: Image size in pixels.
            center_x, center_y: Face-centre pixel coordinates.

        Returns:
            Dict with "gaze_target", "gaze_vector" and "heatmap_stats" entries.
            The angle is ``atan2(dy, dx)`` in image pixel coordinates, where
            y grows downward.
        """
        row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)[:2]
        # Scale the peak cell index to pixels (same mapping as visualize_all).
        gaze_target_x = col / heatmap_np.shape[1] * width
        gaze_target_y = row / heatmap_np.shape[0] * height

        vector_x = gaze_target_x - center_x
        vector_y = gaze_target_y - center_y
        distance = math.hypot(vector_x, vector_y)
        if distance > 0:
            norm_vector_x = vector_x / distance
            norm_vector_y = vector_y / distance
        else:
            norm_vector_x = norm_vector_y = 0

        return {
            "gaze_target": {
                "x": int(gaze_target_x),
                "y": int(gaze_target_y),
            },
            "gaze_vector": {
                "x": float(vector_x),
                "y": float(vector_y),
                "normalized_x": float(norm_vector_x),
                "normalized_y": float(norm_vector_y),
                "distance": float(distance),
                "angle_degrees": float(math.degrees(math.atan2(vector_y, vector_x))),
            },
            "heatmap_stats": {
                "max_value": float(np.max(heatmap_np)),
                "mean_value": float(np.mean(heatmap_np)),
                "std_value": float(np.std(heatmap_np)),
            },
        }

    def process_frame(self, frame):
        """Process a single BGR frame and return the annotated BGR frame.

        The input frame is returned unchanged when no face is detected.
        """
        image, faces, norm_bboxes, output = self._run_gazelle(frame)
        if not faces:
            return frame  # Return original frame if no faces detected
        return self._render(image, norm_bboxes, output)

    def process_frame_with_data(self, frame, frame_number=None, inout_thresh=0.5):
        """
        Process a single frame and return the visualization along with gaze data

        Args:
            frame: OpenCV BGR image
            frame_number: Optional frame number for data tracking
            inout_thresh: A face whose in-frame score is above this value gets
                gaze target/vector/heatmap statistics and a drawn gaze line.

        Returns:
            processed_frame: Processed OpenCV BGR image (the input frame
                unchanged when no face is detected)
            frame_data: Dict with "frame_number", "timestamp" (wall-clock ISO
                string) and "faces", a list with one dict per face holding
                "bbox", "face_center", "landmarks" and, when the model reports
                in/out scores, "inout_score" plus (above ``inout_thresh``)
                "gaze_target", "gaze_vector" and "heatmap_stats"
        """
        frame_data = {
            "frame_number": frame_number,
            "timestamp": datetime.now().isoformat(),
            "faces": []
        }

        image, faces, norm_bboxes, output = self._run_gazelle(frame)
        if not faces:
            return frame, frame_data  # Return original frame if no faces detected

        # Debug aid: show RetinaFace's landmark keys once, to diagnose key mismatches.
        if not getattr(self, "_landmark_keys_reported", False):
            print(f"Landmark structure for first face: {faces[0]['landmarks'].keys()}")
            self._landmark_keys_reported = True

        width, height = image.size
        heatmaps = output['heatmap'][0]
        inout_scores = output['inout'][0] if output['inout'] is not None else None

        for i, face in enumerate(faces):
            bbox = face['facial_area']
            center_x = int((bbox[0] + bbox[2]) / 2)
            center_y = int((bbox[1] + bbox[3]) / 2)

            face_data = {
                "bbox": {
                    "xmin": int(bbox[0]),
                    "ymin": int(bbox[1]),
                    "xmax": int(bbox[2]),
                    "ymax": int(bbox[3]),
                    "width": int(bbox[2] - bbox[0]),
                    "height": int(bbox[3] - bbox[1])
                },
                "face_center": {"x": center_x, "y": center_y},
                # Missing landmarks fall back to the face centre.
                "landmarks": self._landmarks_to_dict(face['landmarks'], (center_x, center_y)),
            }

            if inout_scores is not None:
                inout_score = float(inout_scores[i].item())
                face_data["inout_score"] = inout_score

                # Gaze target only makes sense when the person looks in-frame.
                if inout_score > inout_thresh:
                    heatmap_np = heatmaps[i].detach().cpu().numpy()
                    face_data.update(
                        self._gaze_metrics(heatmap_np, width, height, center_x, center_y)
                    )

            frame_data["faces"].append(face_data)

        processed_frame = self._render(image, norm_bboxes, output, inout_thresh=inout_thresh)
        return processed_frame, frame_data

    def visualize_all(self, pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
        """Draw every face box, its in-frame score and, for faces looking
        in-frame, a dot at the heatmap peak joined to the face centre.

        Args:
            pil_image: RGB PIL image (not modified; a converted copy is drawn on).
            heatmaps: Per-face heatmap tensors, indexed like ``bboxes``.
            bboxes: Per-face [xmin, ymin, xmax, ymax] divided by the image size.
            inout_scores: Per-face in-frame scores, or None if unavailable.
            inout_thresh: Scores above this draw the gaze target and line.

        Returns:
            A new RGB PIL image with the annotations.
        """
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size
        # Pillow draws no rectangle outline when width is 0, which the plain
        # int() scaling produces on frames smaller than 100 px.
        box_width = max(1, int(min(width, height) * 0.01))
        line_width = max(1, int(0.005 * min(width, height)))

        for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
            color = self.colors[i % len(self.colors)]

            # Draw face bounding box
            draw.rectangle(
                [xmin * width, ymin * height, xmax * width, ymax * height],
                outline=color,
                width=box_width
            )

            if inout_scores is None:
                continue

            inout_score = float(inout_scores[i])

            # Draw in-frame score under the box
            draw.text(
                (xmin * width, ymax * height + int(height * 0.01)),
                f"in-frame: {inout_score:.2f}",
                fill=color,
                font=None  # Using default font
            )

            # Draw gaze direction if looking in-frame
            if inout_score > inout_thresh:
                heatmap_np = heatmaps[i].detach().cpu().numpy()
                row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)[:2]

                gaze_target_x = col / heatmap_np.shape[1] * width
                gaze_target_y = row / heatmap_np.shape[0] * height
                bbox_center_x = ((xmin + xmax) / 2) * width
                bbox_center_y = ((ymin + ymax) / 2) * height

                draw.ellipse(
                    [(gaze_target_x - 5, gaze_target_y - 5),
                     (gaze_target_x + 5, gaze_target_y + 5)],
                    fill=color,
                    width=line_width
                )
                draw.line(
                    [(bbox_center_x, bbox_center_y),
                     (gaze_target_x, gaze_target_y)],
                    fill=color,
                    width=line_width
                )

        # Convert to RGB for OpenCV compatibility
        return overlay_image.convert('RGB')

In [None]:
# Example usage with NEW DATA EXPORT functionality

# Process a video file
input_video = "/content/rehab_gait_10seconds.mp4"  # Replace with your video path
output_video = "output_video.mp4"

# Memory Management Usage Examples

# Example 1: Using VideoGazeAnalyzer with memory management
print("=== Example 1: Video Processing with Memory Management ===")

# Initialize the analyzer once: every instance loads its own copy of the
# ViT-L Gazelle weights, so creating a second one doubles memory use.
analyzer = VideoGazeAnalyzer(use_cuda=True)

# Process first video
print("Processing first video...")
# analyzer.process_video('video1.mp4', 'output1.mp4')

# Clear memory between videos (recommended)
analyzer.clear_memory()

# Process second video
print("Processing second video...")
# analyzer.process_video('video2.mp4', 'output2.mp4')

# Complete cleanup when done (optional - frees maximum memory)
analyzer.cleanup_completely()

print("\n=== Example 2: Manual Memory Management ===")

# If you need to reinitialize after complete cleanup
analyzer.reinitialize_model(use_cuda=True)

# Or use standalone functions
clear_gpu_memory()

# Complete cleanup with standalone function
cleanup_model_and_memory(analyzer=analyzer)

print("\n=== NEW FEATURE: Data Export Example ===")

# NEW: Process video with complete data export.
# Release the previous analyzer first so two copies of the model weights are
# never held at the same time.
del analyzer
clear_gpu_memory()
analyzer = VideoGazeAnalyzer(use_cuda=True)

# Process 10 seconds. The method returns the data folder, which is named
# after the output video (here "analyzed_video_data/").
# data_folder = analyzer.process_video_with_data_export(
#     input_path="your_video.mp4",
#     output_path="analyzed_video.mp4",
#     start_time=0,
#     duration=10
# )

# Process the whole video (duration defaults to None = until the end)
# data_folder = analyzer.process_video_with_data_export(
#     input_path="your_video.mp4",
#     output_path="your_video_gaze.mp4",
# )

print("\n=== Data Export Features ===")
print("NEW functionality includes:")
print("1. Per-frame gaze coordinates and eye positions")
print("2. Face bounding boxes and centers")
print("3. Gaze targets and vectors (direction, distance, angle)")
print("4. In-frame looking scores and confidence")
print("5. Heatmap statistics (max, mean, std)")
print("6. Complete JSON export with all data")
print("7. CSV export for easy analysis in Excel/Python")
print("8. Summary statistics (detection rates, gaze patterns)")
print("\nOutput structure:")
print("<output_video_name>_data/")
print("├── gaze_data.json      # Complete per-frame data")
print("├── gaze_data.csv       # Tabular data for analysis")
print("└── summary_stats.json  # Video-level statistics")

print("\n=== Memory Management Tips ===")
print("1. Call analyzer.clear_memory() between videos")
print("2. Call analyzer.cleanup_completely() when completely done")
print("3. Use analyzer.reinitialize_model() to restart after complete cleanup")
print("4. Monitor memory usage with the printed statistics")

def visualize_all(self, pil_image, heatmaps, bboxes, inout_scores, inout_thresh=0.5):
        """Visualize all detected faces and their gaze directions.

        Draws each face box and its in-frame score. For faces whose score
        exceeds ``inout_thresh``, it also draws a dot at the heatmap peak and
        a line from the face centre to that dot.

        Args:
            pil_image: RGB PIL image (not modified; a converted copy is drawn on).
            heatmaps: Per-face heatmap tensors, indexed like ``bboxes``.
            bboxes: Per-face [xmin, ymin, xmax, ymax] divided by the image size.
            inout_scores: Per-face in-frame scores, or None if unavailable.
            inout_thresh: Scores above this draw the gaze target and line.

        Returns:
            A new RGB PIL image with the annotations.
        """
        overlay_image = pil_image.convert("RGBA")
        draw = ImageDraw.Draw(overlay_image)
        width, height = pil_image.size
        # Pillow draws no rectangle outline when width is 0, which the plain
        # int() scaling produces on frames smaller than 100 px.
        box_width = max(1, int(min(width, height) * 0.01))
        line_width = max(1, int(0.005 * min(width, height)))

        for i, (xmin, ymin, xmax, ymax) in enumerate(bboxes):
            color = self.colors[i % len(self.colors)]

            # Draw face bounding box
            draw.rectangle(
                [xmin * width, ymin * height, xmax * width, ymax * height],
                outline=color,
                width=box_width
            )

            if inout_scores is None:
                continue

            # float() accepts both 0-d tensors and plain numbers.
            inout_score = float(inout_scores[i])

            # Draw in-frame score under the box
            draw.text(
                (xmin * width, ymax * height + int(height * 0.01)),
                f"in-frame: {inout_score:.2f}",
                fill=color,
                font=None  # Using default font
            )

            # Draw gaze direction if looking in-frame
            if inout_score > inout_thresh:
                heatmap_np = heatmaps[i].detach().cpu().numpy()
                row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)[:2]

                # Calculate gaze target and face center
                gaze_target_x = col / heatmap_np.shape[1] * width
                gaze_target_y = row / heatmap_np.shape[0] * height
                bbox_center_x = ((xmin + xmax) / 2) * width
                bbox_center_y = ((ymin + ymax) / 2) * height

                # Draw gaze target point and line
                draw.ellipse(
                    [(gaze_target_x - 5, gaze_target_y - 5),
                     (gaze_target_x + 5, gaze_target_y + 5)],
                    fill=color,
                    width=line_width
                )
                draw.line(
                    [(bbox_center_x, bbox_center_y),
                     (gaze_target_x, gaze_target_y)],
                    fill=color,
                    width=line_width
                )

        # Convert to RGB for OpenCV compatibility
        return overlay_image.convert('RGB')

    def process_video(self, input_path, output_path, start_time=0, duration=None):
        """Annotate a video with gaze predictions and save it to ``output_path``.

        Args:
            input_path: Path of the video to read.
            output_path: Path of the video to write (``mp4v`` codec).
            start_time: Offset in seconds of the first processed frame.
            duration: Seconds to process; None processes until the end.

        Raises:
            OSError: If the input cannot be opened or the output cannot be created.
        """
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            cap.release()
            raise OSError(f"Could not open video: {input_path}")

        # Keep FPS as a float: truncating e.g. 29.97 to 29 would change the
        # playback speed of the output and shift the seek positions.
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Calculate start and end frames (duration=0 means "no frames", not "all")
        start_frame = int(start_time * fps)
        if duration is not None:
            end_frame = start_frame + int(duration * fps)
        else:
            end_frame = total_frames
        if total_frames > 0:
            # Do not run past the reported end of the video.
            end_frame = min(end_frame, total_frames)

        # Set up video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            cap.release()
            raise OSError(f"Could not create output video: {output_path}")

        # Seek to start frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        try:
            with tqdm(total=max(0, end_frame - start_frame)) as pbar:
                for _ in range(start_frame, end_frame):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    out.write(self.process_frame(frame))
                    pbar.update(1)
        finally:
            # No window is ever opened here, so destroyAllWindows() is not
            # called: it raises on headless OpenCV builds and would skip the
            # memory cleanup below.
            cap.release()
            out.release()

            # Clear memory after video processing
            self.clear_memory()
            
    def process_video_with_data_export(self, input_path, output_path, start_time=0, duration=None):
        """
        Process a video file, save the result, and export frame-by-frame gaze data

        Args:
            input_path: Path to input video file
            output_path: Path to output video file (will be created)
            start_time: Time in seconds to start processing from
            duration: Duration in seconds to process (None for full video)

        Returns:
            output_folder: ``<output_path without extension>_data``. It holds
                gaze_data.json, gaze_data.csv and summary_stats.json. Each
                frame's "timestamp_seconds" is measured from ``start_time``
                (None if the video reports no FPS).

        Raises:
            OSError: If the input cannot be opened or the output cannot be created.
        """
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            cap.release()
            raise OSError(f"Could not open video: {input_path}")

        # Keep FPS as a float: truncating e.g. 29.97 to 29 would skew the
        # output speed, the seek position and every exported timestamp.
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Calculate start and end frames (duration=0 means "no frames", not "all")
        start_frame = int(start_time * fps)
        if duration is not None:
            end_frame = start_frame + int(duration * fps)
        else:
            end_frame = total_frames
        if total_frames > 0:
            # Do not run past the reported end of the video.
            end_frame = min(end_frame, total_frames)

        # Set up video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        if not out.isOpened():
            cap.release()
            raise OSError(f"Could not create output video: {output_path}")

        # Create output folder based on output video name
        output_base = os.path.splitext(output_path)[0]
        output_folder = f"{output_base}_data"
        os.makedirs(output_folder, exist_ok=True)

        # Prepare for data collection
        all_frame_data = []

        # Seek to start frame
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        try:
            with tqdm(total=max(0, end_frame - start_frame)) as pbar:
                for frame_count in range(start_frame, end_frame):
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # Process frame and get data
                    processed_frame, frame_data = self.process_frame_with_data(frame, frame_number=frame_count)

                    # Seconds since start_time; guard against a missing FPS value.
                    frame_data["timestamp_seconds"] = (
                        (frame_count - start_frame) / fps if fps > 0 else None
                    )

                    out.write(processed_frame)
                    all_frame_data.append(frame_data)
                    pbar.update(1)

            # Save all collected data
            self.save_frame_data(all_frame_data, output_folder)

            # Generate summary statistics
            self.generate_summary_stats(all_frame_data, output_folder)

            print(f"Video processing complete. Data exported to {output_folder}")
            return output_folder

        finally:
            # No window is ever opened here, so destroyAllWindows() is not
            # called: it raises on headless OpenCV builds and would skip the
            # memory cleanup below.
            cap.release()
            out.release()

            # Clear memory after video processing
            self.clear_memory()
            
    def save_frame_data(self, all_frame_data, output_folder):
        """Write per-frame gaze results to ``gaze_data.json`` and ``gaze_data.csv``.

        The JSON file holds ``all_frame_data`` as-is. The CSV is flat:
        - one row per detected face (``has_face`` True, ``face_index`` = its
          position in the frame's "faces" list);
        - one ``has_face`` False row for every frame without faces.
        Nested per-face dicts become prefixed columns (``bbox_xmin``,
        ``gaze_vector_distance``, ``heatmap_max_value`` ...). Each landmark
        becomes ``landmark_<name>_x`` / ``landmark_<name>_y``.

        Args:
            all_frame_data: List of frame dicts, each with a "faces" list and
                optional "frame_number" / "timestamp_seconds".
            output_folder: Existing directory that receives both files.
        """
        json_path = os.path.join(output_folder, "gaze_data.json")
        with open(json_path, 'w') as json_file:
            json.dump(all_frame_data, json_file, indent=2)

        print(f"Saved detailed frame data to {json_path}")

        # (face key, column prefix) pairs, in CSV column order. Landmarks and
        # the in/out score are placed between the two groups.
        leading_sections = (("bbox", "bbox_"), ("face_center", "face_center_"))
        trailing_sections = (
            ("gaze_target", "gaze_target_"),
            ("gaze_vector", "gaze_vector_"),
            ("heatmap_stats", "heatmap_"),
        )

        def add_sections(row, face, sections):
            # Copy each present nested dict into the row under its prefix.
            for section, prefix in sections:
                if section in face:
                    row.update((prefix + key, value) for key, value in face[section].items())

        def flatten_face(frame_info, face_idx, face):
            row = dict(frame_info, has_face=True, face_index=face_idx)
            add_sections(row, face, leading_sections)
            if "landmarks" in face:
                for name, coords in face["landmarks"].items():
                    row[f"landmark_{name}_x"] = coords[0]
                    row[f"landmark_{name}_y"] = coords[1]
            if "inout_score" in face:
                row["inout_score"] = face["inout_score"]
            add_sections(row, face, trailing_sections)
            return row

        rows = []
        for frame in all_frame_data:
            frame_info = {
                "frame_number": frame.get("frame_number", None),
                "timestamp_seconds": frame.get("timestamp_seconds", None),
            }
            faces = frame["faces"]
            if not faces:
                rows.append(dict(frame_info, has_face=False))
            rows.extend(flatten_face(frame_info, idx, face) for idx, face in enumerate(faces))

        csv_path = os.path.join(output_folder, "gaze_data.csv")
        pd.DataFrame(rows).to_csv(csv_path, index=False)

        print(f"Saved flattened frame data to {csv_path}")
    
    def generate_summary_stats(self, all_frame_data, output_folder, inout_thresh=0.5):
        """Compute video-level gaze statistics and save them to ``summary_stats.json``.

        Args:
            all_frame_data: List of frame dicts, each with a "faces" list.
            output_folder: Existing directory that receives the JSON file.
            inout_thresh: A face counts as looking in-frame when its
                "inout_score" is strictly greater than this value (default
                0.5, the same default used for drawing).

        Returns:
            The summary dict that was written. Rates, including
            ``in_frame_percentage``, are fractions in [0, 1]. Averages are
            plain floats, or None when there is nothing to average.
        """
        total_frames = len(all_frame_data)
        frames_with_faces = 0
        total_faces = 0
        faces_looking_in_frame = 0

        # Values to average
        inout_scores = []
        gaze_distances = []
        gaze_angles = []

        for frame in all_frame_data:
            faces = frame["faces"]
            if not faces:
                continue
            frames_with_faces += 1
            total_faces += len(faces)

            for face in faces:
                if "inout_score" not in face:
                    continue
                score = face["inout_score"]
                inout_scores.append(score)

                if score > inout_thresh:
                    faces_looking_in_frame += 1
                    if "gaze_vector" in face:
                        gaze_distances.append(face["gaze_vector"]["distance"])
                        gaze_angles.append(face["gaze_vector"]["angle_degrees"])

        def mean_or_none(values):
            # Plain float keeps the JSON output free of NumPy scalar types.
            return float(np.mean(values)) if values else None

        summary = {
            "video_stats": {
                "total_frames": total_frames,
                "frames_with_faces": frames_with_faces,
                "face_detection_rate": frames_with_faces / total_frames if total_frames > 0 else 0
            },
            "face_stats": {
                "total_faces_detected": total_faces,
                "average_faces_per_frame": total_faces / total_frames if total_frames > 0 else 0
            },
            "gaze_stats": {
                "faces_looking_in_frame": faces_looking_in_frame,
                "in_frame_percentage": faces_looking_in_frame / total_faces if total_faces > 0 else 0,
                "average_inout_score": mean_or_none(inout_scores),
                "average_gaze_distance": mean_or_none(gaze_distances),
                "average_gaze_angle": mean_or_none(gaze_angles)
            }
        }

        summary_path = os.path.join(output_folder, "summary_stats.json")
        with open(summary_path, 'w') as f:
            json.dump(summary, f, indent=2)

        print(f"Saved summary statistics to {summary_path}")

        return summary


In [None]:
# Process 10 seconds starting from 5 seconds into the video
# DEMO: Process video with complete data export
analyzer = VideoGazeAnalyzer(use_cuda=True)

# Example 1: Basic usage. The method returns the data folder, which is named
# after the output video ("<output video name>_data").
print("=== Processing video with data export ===")
video_path = output_video
data_folder = analyzer.process_video_with_data_export(
    input_path=input_video,
    output_path=video_path,
    start_time=5,  # Start 5 seconds in
    duration=10    # Process 10 seconds
)

print(f"\nResults saved to:")
print(f"Video: {video_path}")
print(f"Data: {data_folder}")

# Example 2: Custom output path (data goes to "custom_analyzed_video_data/")
print("\n=== Processing with custom output paths ===")
# custom_data_folder = analyzer.process_video_with_data_export(
#     input_path=input_video,
#     output_path="custom_analyzed_video.mp4",
#     start_time=0,
#     duration=5
# )

# Load and examine the exported data
print("\n=== Examining exported data ===")
import json
import pandas as pd

# gaze_data.json is a list with one dict per processed frame
json_file = os.path.join(data_folder, "gaze_data.json")
if os.path.exists(json_file):
    with open(json_file, 'r') as f:
        gaze_data = json.load(f)

    print(f"Processed {len(gaze_data)} frames")

    # Show sample frame data
    if gaze_data:
        sample_frame = gaze_data[0]
        print(f"Sample frame data keys: {list(sample_frame.keys())}")
        if sample_frame['faces']:
            print(f"Sample face data keys: {list(sample_frame['faces'][0].keys())}")

# gaze_data.csv has one row per detected face, plus one row
# (has_face == False) for each frame without faces
csv_file = os.path.join(data_folder, "gaze_data.csv")
if os.path.exists(csv_file):
    df = pd.read_csv(csv_file)
    print(f"\nCSV data shape: {df.shape}")
    print("CSV columns:", list(df.columns))

    # Basic statistics
    if 'has_face' in df.columns:
        face_df = df[df['has_face']]
        print(f"\nBasic statistics:")
        print(f"Frames with faces: {face_df['frame_number'].nunique()}/{df['frame_number'].nunique()}")
        if not face_df.empty and 'inout_score' in face_df.columns:
            in_frame_df = face_df[face_df['inout_score'] > 0.5]
            print(f"Looking in-frame: {len(in_frame_df)}/{len(face_df)}")
            if not in_frame_df.empty and 'gaze_vector_distance' in in_frame_df.columns:
                print(f"Average gaze distance: {in_frame_df['gaze_vector_distance'].mean():.2f} pixels")

# summary_stats.json groups values under video_stats / face_stats / gaze_stats
summary_file = os.path.join(data_folder, "summary_stats.json")
if os.path.exists(summary_file):
    with open(summary_file, 'r') as f:
        summary = json.load(f)
    video_stats = summary.get('video_stats', {})
    gaze_stats = summary.get('gaze_stats', {})
    print(f"\nSummary statistics:")
    print(f"Face detection rate: {video_stats.get('face_detection_rate', 0):.2%}")
    print(f"In-frame gaze rate: {gaze_stats.get('in_frame_percentage', 0):.2%}")
    avg_gaze_distance = gaze_stats.get('average_gaze_distance')
    if avg_gaze_distance is not None:
        print(f"Average gaze distance: {avg_gaze_distance:.2f} pixels")

print("\n=== Data export complete! ===")
print("You can now analyze the exported data using:")
print("1. JSON file for complete programmatic access")
print("2. CSV file for analysis in Excel, Python pandas, R, etc.")
print("3. Summary file for quick overview statistics")

# Standalone versions of the analyzer's memory-management helpers. Each takes
# the analyzer instance as its first argument, e.g. `cleanup_completely(analyzer)`.
# NOTE(review): these look like copies of VideoGazeAnalyzer methods with the
# same names; consider deleting them if the class already defines them.


def clear_memory(self=None):
    """
    Clear GPU/MPS memory and force garbage collection.
    Call this between processing different videos to free up memory.

    Args:
        self: Unused. Accepted so this can be called as
            ``clear_memory(analyzer)`` as well as with no arguments.
    """
    print("Clearing memory...")

    # Collect unreachable Python objects first: tensors kept alive only by
    # reference cycles are freed here, so the cache flush below can actually
    # return their blocks (empty_cache cannot release memory still owned by
    # live tensors).
    gc.collect()

    # Clear PyTorch CUDA cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

        # Print GPU memory usage
        allocated = torch.cuda.memory_allocated() / 1024**3  # GB
        cached = torch.cuda.memory_reserved() / 1024**3  # GB
        print(f"GPU Memory - Allocated: {allocated:.2f} GB, Cached: {cached:.2f} GB")

    # Clear MPS cache if using Apple Silicon. `torch.mps` is checked as well
    # because some torch builds expose torch.backends.mps without the
    # torch.mps module.
    if (
        hasattr(torch.backends, 'mps')
        and torch.backends.mps.is_available()
        and hasattr(torch, 'mps')
    ):
        torch.mps.empty_cache()
        print("MPS cache cleared")

    # Print system memory usage
    memory = psutil.virtual_memory()
    print(f"System Memory - Used: {memory.used / 1024**3:.2f} GB / {memory.total / 1024**3:.2f} GB ({memory.percent:.1f}%)")
    print("Memory cleanup complete!\n")


def cleanup_completely(self):
    """
    Complete cleanup: move model to CPU and clear all memory.
    Use this when you're done with the analyzer or want to free maximum memory.

    Args:
        self: The analyzer instance. If it has a non-None ``model``, that
            model is moved to CPU and ``self.device`` is set to ``'cpu'``.
    """
    print("Performing complete cleanup...")

    # Move model to CPU to free GPU memory
    if getattr(self, 'model', None) is not None:
        self.model.cpu()
        # Keep the recorded device in sync with where the weights now live,
        # so any code that moves inputs to `self.device` does not create a
        # cuda-input / cpu-model mismatch.
        self.device = 'cpu'
        print("Model moved to CPU")

    # Clear memory
    clear_memory(self)

    print("Complete cleanup finished!")


def reinitialize_model(self, use_cuda=True):
    """
    Reinitialize the model (useful after complete cleanup).

    Args:
        self: The analyzer instance. Its ``device`` is updated, and its
            ``model`` is put in eval mode and moved to that device.
        use_cuda: Use CUDA when available; otherwise fall back to CPU.
    """
    print("Reinitializing model...")

    self.device = 'cuda' if use_cuda and torch.cuda.is_available() else 'cpu'
    print(f"Using device: {self.device}")

    # Reload model (and its transform) only if it is missing or was
    # explicitly dropped; otherwise reuse the weights already in memory.
    if getattr(self, 'model', None) is None:
        self.model, self.transform = torch.hub.load('fkryan/gazelle', 'gazelle_dinov2_vitl14_inout')

    self.model.eval()
    self.model.to(self.device)

    print("Model reinitialized!")

# Example usage
if __name__ == "__main__":
    # Initialize analyzer (uses CUDA only when it is available)
    analyzer = VideoGazeAnalyzer(use_cuda=True)

    # Process a video and export data
    input_video = "/content/gait_speaktask.mp4"  # Replace with your video path
    output_video = "output_video.mp4"

    try:
        # Process video with data export
        output_folder = analyzer.process_video_with_data_export(
            input_video,
            output_video,
            start_time=0,
            duration=None,  # Process entire video
        )

        print(f"Video processing complete! Data exported to {output_folder}")
    finally:
        # Clean up when done, even if processing raised, so a failed run in
        # a long-lived notebook session does not leave the model on the GPU.
        analyzer.cleanup_completely()