In [1]:
# !pip install -q transnetv2_pytorch
# !pip install -Uq transformers==4.46.3 tokenizers==0.20.3
# !pip install -q open_clip_torch
# !pip install -q addict

In [2]:
!rm -rf **

In [3]:
#!/usr/bin/env python3
"""
Unified Video Processing Library
Combines visual (slides) and audio (transcription) processing into context units
"""

import os
import gc
import json
import warnings
import subprocess
import shutil
from dataclasses import dataclass, asdict
from typing import List, Tuple, Dict, Optional, Any

import numpy as np
import cv2
from PIL import Image
from pydantic import BaseModel

import torch
import torch.nn.functional as F

from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter1d
from scipy import signal

from pydub import AudioSegment
from pydub.utils import make_chunks
from pydub import effects
from pydub.silence import detect_silence

from tqdm.auto import tqdm

# Silence every Python warning for the whole process. Note that this also
# hides deprecation notices raised by the libraries imported above.
warnings.filterwarnings("ignore")
# NOTE(review): presumably meant to quiet TensorFlow's native logging if a
# dependency imports it -- nothing in this file imports TensorFlow directly.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# Device setup
# Shared torch device for the models below: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ============================================================================
# Data Structures
# ============================================================================


class ContextUnitData(BaseModel):
    """Final output data structure for context units.

    A pydantic model holding one text field, the time span it covers, and the
    optional visual (OCR) and audio (speech) texts.
    """

    # NOTE(review): how `text` is assembled from the visual/audio parts is not
    # visible in this part of the file -- confirm against the code that builds it.
    text: str
    # Start/end of the unit; presumably seconds from the start of the video,
    # like the frame timestamps elsewhere in this file -- TODO confirm.
    start_time: float
    end_time: float
    visual_text: str = ""  # OCR text from slides
    audio_text: str = ""  # Transcribed speech


@dataclass
class Config:
    """Unified configuration for the pipeline.

    Settings are grouped by pipeline phase: output paths, frame sampling,
    shot detection, CLIP slide detection, keyframe selection, OCR, audio
    transcription and checkpointing. Every field has a default, so
    ``Config()`` is a usable configuration.
    """

    # Paths
    # Directory that receives the JSON/tensor checkpoints and the OCR output.
    output_dir: str = "output"
    # Directory that receives the extracted JPEG frames.
    frames_dir: str = "output/frames"

    # Phase 1: Frame Extraction
    # Sampling rate for extracted frames (frames kept per second of video).
    target_fps: float = 1.0
    # Size handed to cv2.resize for the copy that is fed to CLIP.
    small_frame_size: Tuple[int, int] = (224, 224)
    # Size handed to cv2.resize for the copy used for quality scoring and OCR.
    full_frame_size: Tuple[int, int] = (960, 540)

    # Phase 1: Shot Detection (TransNetV2)
    # A per-frame prediction above this value is treated as a shot boundary.
    shot_confidence_threshold: float = 0.5

    # Phase 1: Slide Detection (CLIP ViT-L/14)
    clip_model: str = "ViT-L-14"
    # Number of frames encoded by CLIP per batch.
    clip_batch_size: int = 16
    # Slide boundary threshold is mean + lambda * std of the smoothed
    # frame-to-frame dissimilarity within a shot.
    dynamic_threshold_lambda: float = 0.7
    # Sigma of the Gaussian filter applied to the dissimilarity curve.
    smoothing_sigma: float = 2.0
    # Segments whose time range is shorter than this are merged with a
    # neighbour (same unit as the frame timestamps).
    min_segment_duration: float = 4.0

    # Phase 2: Keyframe Selection
    # Upper bound on keyframes picked per segment.
    keyframes_per_segment: int = 2
    # Weights of the greedy score: quality_weight * quality
    # - diversity_penalty * (max similarity to already selected frames).
    quality_weight: float = 0.6
    diversity_penalty: float = 0.4

    # Phase 3: OCR (DeepSeek-OCR)
    # Model id passed to AutoTokenizer/AutoModel.from_pretrained.
    ocr_model: str = "deepseek-ai/DeepSeek-OCR"
    # The next three values are forwarded unchanged to the model's infer()
    # call as base_size, image_size and crop_mode.
    ocr_base_size: int = 1024
    ocr_image_size: int = 640
    ocr_crop_mode: bool = True

    # Audio Processing
    # Model id passed to the transformers speech-recognition pipeline.
    whisper_model: str = "vinai/PhoWhisper-medium"
    # Length of each audio chunk handed to the transcriber, in milliseconds.
    audio_chunk_length_ms: int = 30000
    # When True the audio is high-pass filtered and spectrally denoised first.
    noise_reduction: bool = False
    # Scaled by 3.0 to obtain the spectral-subtraction noise factor.
    reduction_strength: float = 0.5

    # Memory Management
    # Gates both writing checkpoints and re-using existing ones on later runs.
    save_intermediate: bool = True


@dataclass
class FrameInfo:
    """One sampled video frame and the two resized copies saved for it."""

    # 0-based position of the frame in the list of sampled frames.
    frame_index: int
    # Source frame number divided by the video FPS (seconds into the video).
    timestamp: float
    # Path of the JPEG resized to Config.small_frame_size.
    small_path: str
    # Path of the JPEG resized to Config.full_frame_size.
    full_path: str


@dataclass
class ShotInfo:
    """A shot: a range of sampled frames between two detected boundaries."""

    # Sequential id assigned in detection order.
    shot_id: int
    # First and last sampled-frame indices of the shot (indices into the
    # FrameInfo list, not raw video frame numbers).
    start_frame_idx: int
    end_frame_idx: int
    # Model prediction looked up for the shot's boundary, or 0.0 when no
    # prediction is associated with it.
    confidence: float


@dataclass
class SegmentInfo:
    """A slide segment inside a shot, with its selected keyframes."""

    # Sequential id; renumbered after short segments are merged.
    segment_id: int
    # Id of the shot the segment was cut from.
    shot_id: int
    # First and last sampled-frame indices covered by the segment.
    start_frame_idx: int
    end_frame_idx: int
    # Indices into the FrameInfo list of the chosen keyframes (sorted); empty
    # until keyframe selection has run.
    keyframe_indices: List[int]
    # (timestamp of the start frame, timestamp of the end frame).
    time_range: Tuple[float, float]


@dataclass
class TranscriptSegment:
    """A piece of transcribed speech with its time span."""

    # NOTE(review): presumably seconds relative to the start of the full audio
    # (the transcriber tracks a per-chunk offset) -- confirm against the
    # transcription code.
    start_time: float
    end_time: float
    # Transcribed text for the span.
    text: str


# ============================================================================
# Utilities
# ============================================================================


def clear_gpu_memory():
    """Run the garbage collector, then release cached CUDA memory if a GPU exists."""
    gc.collect()
    if not torch.cuda.is_available():
        # Nothing CUDA-related to release on a CPU-only machine.
        return
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()


def compute_sharpness(image: np.ndarray) -> float:
    """Return the variance of the image's Laplacian as a sharpness score.

    A three-dimensional array is converted from BGR to grayscale first; any
    other array is used unchanged.
    """
    is_color = len(image.shape) == 3
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_color else image
    laplacian = cv2.Laplacian(gray, cv2.CV_64F)
    return laplacian.var()


def compute_entropy(image: np.ndarray) -> float:
    """Return the Shannon entropy (in bits) of the image's 256-bin histogram.

    A three-dimensional array is converted from BGR to grayscale first; any
    other array is used unchanged.
    """
    gray = image if len(image.shape) != 3 else cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    counts = cv2.calcHist([gray], [0], None, [256], [0, 256])
    probabilities = counts.flatten() / counts.sum()
    # Empty bins are dropped so log2 is never evaluated at zero.
    occupied = probabilities[probabilities > 0]
    return -np.sum(occupied * np.log2(occupied))


def convert_to_serializable(obj):
    """Recursively convert numpy/torch values to JSON-serializable Python types.

    Conversions performed:
      * numpy booleans, integers and floats -> ``bool`` / ``int`` / ``float``
      * numpy arrays and torch tensors -> nested Python lists
      * dicts -> dicts with converted values; numpy scalar keys are unwrapped
        to their Python equivalent (``json.dump`` rejects numpy keys)
      * lists and tuples -> lists with converted items

    Anything else is returned unchanged.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): convert_to_serializable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    return obj


def save_checkpoint(data: Any, filepath: str):
    """Save an intermediate checkpoint to ``filepath``.

    Lists and dicts are converted with ``convert_to_serializable`` and written
    as UTF-8 JSON; every other object is written with ``torch.save``. The
    parent directory is created when the path has one.
    """
    parent_dir = os.path.dirname(filepath)
    # A bare filename has an empty dirname, and os.makedirs("") raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    if isinstance(data, (list, dict)):
        data = convert_to_serializable(data)
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
    else:
        torch.save(data, filepath)


def load_checkpoint(filepath: str) -> Any:
    """Return the checkpoint stored at ``filepath``, or ``None`` if it is missing.

    Paths ending in ``.json`` are parsed as UTF-8 JSON; anything else is read
    with ``torch.load`` and mapped onto the CPU.
    """
    if os.path.exists(filepath):
        if not filepath.endswith(".json"):
            return torch.load(filepath, map_location="cpu")
        with open(filepath, "r", encoding="utf-8") as handle:
            return json.load(handle)
    return None


# ============================================================================
# Audio Extraction & Preprocessing
# ============================================================================


def extract_audio_from_video(video_path: str, output_folder: str = "audio") -> str:
    """Extract the audio track of a video into a 16 kHz mono WAV using ffmpeg.

    The WAV is written to ``output_folder`` under the video's base name. If
    that file already exists it is returned without running ffmpeg again.

    Args:
        video_path: Path of the input video.
        output_folder: Directory for the WAV file; created if missing.

    Returns:
        Path of the WAV file.

    Raises:
        RuntimeError: On any failure (missing video, ffmpeg error, ...). The
            original exception is attached as ``__cause__``.
    """
    audio_path = None
    ffmpeg_started = False
    try:
        os.makedirs(output_folder, exist_ok=True)

        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        base_name = os.path.splitext(os.path.basename(video_path))[0]
        audio_path = os.path.join(output_folder, f"{base_name}.wav")

        if os.path.exists(audio_path):
            return audio_path

        ffmpeg_cmd = [
            "ffmpeg",
            "-i",
            video_path,
            "-vn",
            "-acodec",
            "pcm_s16le",
            "-ar",
            "16000",
            "-ac",
            "1",
            "-y",
            audio_path,
        ]

        # From here on, any file at audio_path was produced by this call.
        ffmpeg_started = True
        subprocess.run(
            ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True
        )

        return audio_path

    except Exception as e:
        # Remove a partial output so the next call does not mistake it for a
        # finished, cached extraction.
        if ffmpeg_started and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError:
                pass
        raise RuntimeError(f"Audio extraction failed: {str(e)}") from e


def reduce_noise_spectral_subtraction(audio_segment, noise_factor=2.0):
    """Safe spectral subtraction that prevents removing entire speech.

    The samples of ``audio_segment`` are denoised per channel (two-channel
    audio is de-interleaved first), clipped to the int16 range and wrapped in
    a new segment via ``audio_segment._spawn``.

    Args:
        audio_segment: Object exposing ``get_array_of_samples()``,
            ``frame_rate``, ``channels`` and ``_spawn(bytes)``.
        noise_factor: Multiplier applied to the estimated noise floor.

    Returns:
        The denoised segment, or ``audio_segment`` unchanged if anything
        fails (best-effort processing).
    """
    try:
        samples = np.array(audio_segment.get_array_of_samples())
        sample_rate = audio_segment.frame_rate

        if audio_segment.channels == 2:
            # Interleaved L/R samples -> one column per channel.
            samples = samples.reshape((-1, 2))
            left = _spectral_subtract_channel_safe(
                samples[:, 0], sample_rate, noise_factor
            )
            right = _spectral_subtract_channel_safe(
                samples[:, 1], sample_rate, noise_factor
            )
            cleaned = np.column_stack((left, right)).flatten()
        else:
            cleaned = _spectral_subtract_channel_safe(
                samples, sample_rate, noise_factor
            )

        cleaned = np.clip(cleaned, -32768, 32767).astype(np.int16)

        return audio_segment._spawn(cleaned.tobytes())

    except Exception:
        # Deliberate best-effort: fall back to the untouched audio.
        return audio_segment


def _spectral_subtract_channel_safe(samples, sr, noise_factor):
    """Safe spectral subtraction for one channel: never zeroes out full speech.

    The noise floor is estimated as the 10th percentile of the spectrum's
    magnitudes; ``noise_factor`` times that floor is subtracted from every
    bin, but no bin is reduced below 20% of its original magnitude. The
    original phase is kept. ``sr`` is accepted for signature symmetry and is
    not used.

    Returns:
        Float array with the same number of samples as the input.
    """
    fft = np.fft.rfft(samples)
    mag = np.abs(fft)
    phase = np.angle(fft)

    noise_est = np.percentile(mag, 10)
    clean_mag = mag - noise_factor * noise_est
    # Floor at 20% of the original magnitude so speech is never fully removed.
    clean_mag = np.maximum(clean_mag, mag * 0.20)

    clean_fft = clean_mag * np.exp(1j * phase)
    # Pass n explicitly: without it irfft assumes an even length and drops the
    # last sample of odd-length input.
    cleaned = np.fft.irfft(clean_fft, n=len(samples))

    return cleaned


def apply_high_pass_filter(audio_segment, cutoff_freq=80):
    """Run a high-pass filter over a segment to remove low-frequency noise.

    Two-channel audio is de-interleaved and filtered per channel. The result
    is clipped to the int16 range and wrapped via ``audio_segment._spawn``.
    If anything fails, the input segment is returned unchanged.
    """
    try:
        pcm = np.array(audio_segment.get_array_of_samples())
        rate = audio_segment.frame_rate

        if audio_segment.channels != 2:
            result = _apply_highpass_channel(pcm, rate, cutoff_freq)
        else:
            # Interleaved L/R samples -> one column per channel.
            stereo = pcm.reshape((-1, 2))
            first = _apply_highpass_channel(stereo[:, 0], rate, cutoff_freq)
            second = _apply_highpass_channel(stereo[:, 1], rate, cutoff_freq)
            result = np.column_stack((first, second)).flatten()

        as_int16 = np.clip(result, -32768, 32767).astype(np.int16)
        return audio_segment._spawn(as_int16.tobytes())

    except Exception:
        return audio_segment


def _apply_highpass_channel(samples, sample_rate, cutoff_freq):
    """High-pass one channel with a 5th-order Butterworth filter.

    The cutoff is normalised by the Nyquist frequency and the filter is
    applied forward and backward (``filtfilt``). On any error the input
    samples are returned untouched.
    """
    try:
        as_float = samples.astype(np.float32)
        normalized_cutoff = cutoff_freq / (sample_rate / 2.0)
        numerator, denominator = signal.butter(
            5, normalized_cutoff, btype="high", analog=False
        )
        return signal.filtfilt(numerator, denominator, as_float)
    except Exception:
        return samples


def preprocess_audio_with_noise_reduction(
    audio_path, noise_reduction=True, reduction_strength=0.5
):
    """Load an audio file and prepare it for transcription.

    Steps: optional high-pass filtering plus spectral subtraction, a gain
    boost for quiet recordings (below -30 dBFS, raised towards -20 dBFS but by
    at most 15 dB), then conversion to 16 kHz mono. If any step fails, the
    file is reloaded and returned without processing.
    """
    try:
        audio = AudioSegment.from_file(audio_path)

        if noise_reduction:
            filtered = apply_high_pass_filter(audio, cutoff_freq=80)
            audio = reduce_noise_spectral_subtraction(
                filtered, noise_factor=reduction_strength * 3.0
            )

        loudness = audio.dBFS
        if loudness < -30.0:
            audio = audio.apply_gain(min(-20.0 - loudness, 15.0))

        return audio.set_frame_rate(16000).set_channels(1)

    except Exception:
        return AudioSegment.from_file(audio_path)


def split_audio(audio_path: str, chunk_length_ms: int = 30000) -> List[str]:
    """Cut an audio file into fixed-length WAV chunks and return their paths.

    Chunks are exported as ``temp_chunk_<i>.wav`` in the current working
    directory with mono / 16 kHz export parameters. If anything fails, a
    single-element list holding the original path is returned.
    """
    try:
        pieces = make_chunks(AudioSegment.from_file(audio_path), chunk_length_ms)

        written = []
        for number, piece in enumerate(pieces):
            target = f"temp_chunk_{number}.wav"
            piece.export(target, format="wav", parameters=["-ac", "1", "-ar", "16000"])
            written.append(target)

        return written

    except Exception:
        return [audio_path]


# ============================================================================
# Phase 1: Frame Extraction
# ============================================================================


class VideoFrameExtractor:
    """Extract and preprocess frames from video.

    Frames are sampled at roughly ``config.target_fps`` and each sampled frame
    is saved twice as JPEG in ``config.frames_dir``: a small copy and a full
    copy, resized to ``config.small_frame_size`` / ``config.full_frame_size``.
    """

    def __init__(self, config: Config):
        """Store the configuration and make sure the frames directory exists."""
        self.config = config
        os.makedirs(config.frames_dir, exist_ok=True)

    def extract_frames(self, video_path: str) -> List[FrameInfo]:
        """Extract frames at target FPS.

        When ``save_intermediate`` is set and ``frames_info.json`` already
        exists in the output directory, the stored frame list is returned
        without reading the video.

        Args:
            video_path: Path of the video to sample.

        Returns:
            One ``FrameInfo`` per sampled frame, in video order.

        Raises:
            ValueError: If the video cannot be opened or reports no usable
                frame rate.
        """
        checkpoint_path = os.path.join(self.config.output_dir, "frames_info.json")
        if self.config.save_intermediate and os.path.exists(checkpoint_path):
            frames_data = load_checkpoint(checkpoint_path)
            return [FrameInfo(**f) for f in frames_data]

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Cannot open video: {video_path}")

        frames_info = []
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            # Timestamps are frame_count / fps, so a missing or zero frame
            # rate must be rejected up front instead of dividing by zero.
            if not fps > 0:
                raise ValueError(f"Cannot determine FPS for video: {video_path}")

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Keep every frame when the video is slower than target_fps; an
            # interval of 0 would make the modulo below raise.
            frame_interval = max(1, int(fps / self.config.target_fps))

            frame_count = 0
            pbar = tqdm(total=total_frames, desc="Extracting frames")
            try:
                while True:
                    ret, frame = cap.read()
                    if not ret:
                        break

                    if frame_count % frame_interval == 0:
                        frame_index = len(frames_info)
                        timestamp = frame_count / fps

                        small_frame = cv2.resize(frame, self.config.small_frame_size)
                        small_path = os.path.join(
                            self.config.frames_dir,
                            f"frame_{frame_index:05d}_small.jpg",
                        )
                        cv2.imwrite(small_path, small_frame)

                        full_frame = cv2.resize(frame, self.config.full_frame_size)
                        full_path = os.path.join(
                            self.config.frames_dir,
                            f"frame_{frame_index:05d}_full.jpg",
                        )
                        cv2.imwrite(full_path, full_frame)

                        frames_info.append(
                            FrameInfo(
                                frame_index=frame_index,
                                timestamp=timestamp,
                                small_path=small_path,
                                full_path=full_path,
                            )
                        )

                    frame_count += 1
                    pbar.update(1)
            finally:
                pbar.close()
        finally:
            # Release the capture even when reading or writing a frame fails.
            cap.release()

        if self.config.save_intermediate:
            frames_data = [asdict(f) for f in frames_info]
            save_checkpoint(frames_data, checkpoint_path)

        return frames_info


# ============================================================================
# Phase 1b: Shot Boundary Detection
# ============================================================================


class ShotDetector:
    """Detect shot boundaries using TransNetV2.

    Per-frame predictions on the original video are mapped onto the sampled
    frame list, and consecutive boundaries are turned into ``ShotInfo``
    ranges.
    """

    def __init__(self, config: Config):
        """Store the configuration; the model is loaded lazily."""
        self.config = config
        self.model = None

    def load_model(self):
        """Load TransNetV2 model"""
        from transnetv2_pytorch import TransNetV2

        self.model = TransNetV2()
        if torch.cuda.is_available():
            self.model = self.model.to(device)
        self.model.eval()

    def detect_shots(
        self, video_path: str, frames_info: List[FrameInfo]
    ) -> List[ShotInfo]:
        """Detect shot boundaries from frames.

        When ``save_intermediate`` is set and ``shots_info.json`` already
        exists in the output directory, the stored shots are returned
        without running the model.

        Args:
            video_path: Path of the video handed to the model.
            frames_info: Sampled frames; shot indices refer to this list.

        Returns:
            Shots covering index 0 through ``len(frames_info) - 1``.
        """
        checkpoint_path = os.path.join(self.config.output_dir, "shots_info.json")
        if self.config.save_intermediate and os.path.exists(checkpoint_path):
            shots_data = load_checkpoint(checkpoint_path)
            return [ShotInfo(**s) for s in shots_data]

        if self.model is None:
            self.load_model()

        _, single_predictions, _ = self.model.predict_video(video_path)
        predictions = single_predictions.cpu().numpy()

        cap = cv2.VideoCapture(video_path)
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()

        # Original-frame -> sampled-frame ratio. Clamped to 1 because a video
        # slower than target_fps (or one reporting 0 FPS) would otherwise
        # yield 0 and make the integer division below raise.
        frame_interval = max(1, int(original_fps / self.config.target_fps))

        shot_boundaries = []
        for i, score in enumerate(predictions):
            if score > self.config.shot_confidence_threshold:
                sampled_idx = i // frame_interval
                # sampled_idx never decreases, so a duplicate can only equal
                # the most recently stored boundary.
                is_new = not shot_boundaries or shot_boundaries[-1] != sampled_idx
                if sampled_idx < len(frames_info) and is_new:
                    shot_boundaries.append(sampled_idx)

        boundaries = [0] + shot_boundaries + [len(frames_info) - 1]
        shots = []

        for i in range(len(boundaries) - 1):
            start_idx = boundaries[i]
            end_idx = boundaries[i + 1]

            # Confidence is the prediction at the original frame matching the
            # shot's closing boundary; the final shot has none.
            if i < len(shot_boundaries):
                original_idx = shot_boundaries[i] * frame_interval
                if original_idx < len(predictions):
                    confidence = float(predictions[original_idx])
                else:
                    confidence = 0.0
            else:
                confidence = 0.0

            shots.append(
                ShotInfo(
                    shot_id=i,
                    start_frame_idx=start_idx,
                    end_frame_idx=end_idx,
                    confidence=confidence,
                )
            )

        if self.config.save_intermediate:
            shots_data = [asdict(s) for s in shots]
            save_checkpoint(shots_data, checkpoint_path)

        return shots

    def free_model(self):
        """Free TransNetV2 model from memory"""
        if self.model is not None:
            del self.model
            self.model = None
            clear_gpu_memory()


# ============================================================================
# Phase 1c: CLIP Slide Detection
# ============================================================================


class CLIPSlideDetector:
    """Detect slide boundaries within shots using CLIP image embeddings.

    The detector encodes every sampled frame with CLIP, looks for peaks in
    the frame-to-frame dissimilarity inside each shot, merges segments that
    are too short, and picks keyframes per segment.
    """

    def __init__(self, config: Config):
        """Store the configuration; the model is loaded lazily."""
        self.config = config
        self.model = None
        self.preprocess = None

    def load_model(self):
        """Load the CLIP model named by ``config.clip_model`` (OpenAI weights)."""
        import open_clip

        # Use the configured architecture instead of a hard-coded name so
        # Config.clip_model actually takes effect.
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.config.clip_model, pretrained="openai"
        )
        self.model = self.model.to(device)
        self.model.eval()

    def extract_clip_features(self, frames_info: List[FrameInfo]) -> torch.Tensor:
        """Extract L2-normalised CLIP features for all frames.

        When ``save_intermediate`` is set and ``clip_features.pt`` already
        exists in the output directory, the stored tensor is returned.

        Returns:
            CPU tensor with one feature row per entry of ``frames_info``.
        """
        checkpoint_path = os.path.join(self.config.output_dir, "clip_features.pt")
        if self.config.save_intermediate and os.path.exists(checkpoint_path):
            return load_checkpoint(checkpoint_path)

        if self.model is None:
            self.load_model()

        features_list = []

        for i in tqdm(
            range(0, len(frames_info), self.config.clip_batch_size),
            desc="CLIP encoding",
        ):
            batch_frames = frames_info[i : i + self.config.clip_batch_size]

            images = [
                self.preprocess(Image.open(frame_info.small_path).convert("RGB"))
                for frame_info in batch_frames
            ]
            images_tensor = torch.stack(images).to(device)

            with torch.no_grad():
                features = self.model.encode_image(images_tensor)
                features = F.normalize(features, dim=-1)

            # Move each batch off the device right away to keep GPU use flat.
            features_list.append(features.cpu())

        all_features = torch.cat(features_list, dim=0)

        if self.config.save_intermediate:
            save_checkpoint(all_features, checkpoint_path)

        return all_features

    def detect_slide_boundaries_in_shot(
        self, features: torch.Tensor, shot_start: int, shot_end: int
    ) -> List[int]:
        """Detect slide boundaries within a shot using dynamic thresholding.

        The cosine dissimilarity between consecutive frames is smoothed with
        a Gaussian filter; peaks above ``mean + lambda * std`` of the smoothed
        curve are reported.

        Args:
            features: Feature rows indexed like the sampled frame list.
            shot_start: First frame index of the shot (inclusive).
            shot_end: Last frame index of the shot (inclusive).

        Returns:
            Frame indices (absolute, as plain ints) of the detected
            boundaries; empty when the shot has fewer than two frames.
        """
        shot_features = features[shot_start : shot_end + 1]

        if len(shot_features) < 2:
            return []

        # Similarity of every frame with its successor, computed in one call.
        similarities = F.cosine_similarity(
            shot_features[:-1], shot_features[1:], dim=1
        )
        dissimilarity = 1 - similarities.detach().cpu().numpy().astype(np.float64)

        dissimilarity_smooth = gaussian_filter1d(
            dissimilarity, sigma=self.config.smoothing_sigma
        )

        mean_dissim = dissimilarity_smooth.mean()
        std_dissim = dissimilarity_smooth.std()
        threshold = mean_dissim + self.config.dynamic_threshold_lambda * std_dissim

        peaks, _ = find_peaks(dissimilarity_smooth, height=threshold)
        # Plain ints keep downstream dataclasses free of numpy scalars.
        return [shot_start + int(p) for p in peaks]

    def merge_short_segments(
        self, segments: List[SegmentInfo], frames_info: List[FrameInfo]
    ) -> List[SegmentInfo]:
        """Merge segments shorter than ``config.min_segment_duration``.

        A short segment is joined with the segment that follows it; a short
        final segment is joined with the previously kept one. Keyframes of
        merged segments are reset, and ids are renumbered from 0 (in place).
        ``frames_info`` is accepted for interface symmetry and is not used.
        """
        if len(segments) <= 1:
            return segments

        merged = []
        i = 0

        while i < len(segments):
            current = segments[i]
            duration = current.time_range[1] - current.time_range[0]

            if duration >= self.config.min_segment_duration:
                merged.append(current)
                i += 1
            elif i + 1 < len(segments):
                # Absorb the following segment and skip past it.
                next_seg = segments[i + 1]
                merged.append(
                    SegmentInfo(
                        segment_id=current.segment_id,
                        shot_id=current.shot_id,
                        start_frame_idx=current.start_frame_idx,
                        end_frame_idx=next_seg.end_frame_idx,
                        keyframe_indices=[],
                        time_range=(current.time_range[0], next_seg.time_range[1]),
                    )
                )
                i += 2
            elif merged:
                # Short last segment: extend the previously kept segment.
                prev_seg = merged.pop()
                merged.append(
                    SegmentInfo(
                        segment_id=prev_seg.segment_id,
                        shot_id=prev_seg.shot_id,
                        start_frame_idx=prev_seg.start_frame_idx,
                        end_frame_idx=current.end_frame_idx,
                        keyframe_indices=[],
                        time_range=(prev_seg.time_range[0], current.time_range[1]),
                    )
                )
                i += 1
            else:
                merged.append(current)
                i += 1

        for idx, seg in enumerate(merged):
            seg.segment_id = idx

        return merged

    def select_keyframes_greedy(
        self,
        frames_info: List[FrameInfo],
        features: torch.Tensor,
        start_idx: int,
        end_idx: int,
    ) -> List[int]:
        """Select keyframes using greedy quality + diversity algorithm.

        Each frame gets a quality score (0.6 * sharpness + 0.4 * entropy of
        the full-size image, normalised to [0, 1] within the segment). The
        best-quality frame is taken first; further frames maximise
        ``quality_weight * quality - diversity_penalty * max_similarity`` to
        the frames already chosen.

        Returns:
            Sorted absolute frame indices of the selected keyframes.
        """
        segment_frames = frames_info[start_idx : end_idx + 1]
        segment_features = features[start_idx : end_idx + 1]

        if len(segment_frames) == 0:
            return []
        if len(segment_frames) == 1:
            return [start_idx]

        quality_scores = []
        for frame_info in segment_frames:
            img = cv2.imread(frame_info.full_path)
            if img is None:
                # cv2.imread returns None for a missing or unreadable file;
                # rank such frames lowest instead of crashing.
                quality_scores.append(0.0)
                continue
            sharpness = compute_sharpness(img)
            entropy = compute_entropy(img)
            quality_scores.append(0.6 * sharpness + 0.4 * entropy)

        quality_scores = np.array(quality_scores)

        if quality_scores.std() > 0:
            quality_scores = (
                quality_scores - quality_scores.mean()
            ) / quality_scores.std()
            quality_scores = (quality_scores - quality_scores.min()) / (
                quality_scores.max() - quality_scores.min()
            )
        else:
            quality_scores = np.ones_like(quality_scores)

        selected_indices = []
        k = min(self.config.keyframes_per_segment, len(segment_frames))

        first_idx = int(np.argmax(quality_scores))
        selected_indices.append(first_idx)

        for _ in range(k - 1):
            best_score = -float("inf")
            best_idx = -1

            for idx in range(len(segment_frames)):
                if idx in selected_indices:
                    continue

                # Penalise candidates that look like an already chosen frame.
                max_similarity = 0.0
                for sel_idx in selected_indices:
                    sim = F.cosine_similarity(
                        segment_features[idx : idx + 1],
                        segment_features[sel_idx : sel_idx + 1],
                        dim=1,
                    ).item()
                    max_similarity = max(max_similarity, sim)

                score = (
                    self.config.quality_weight * quality_scores[idx]
                    - self.config.diversity_penalty * max_similarity
                )

                if score > best_score:
                    best_score = score
                    best_idx = idx

            if best_idx >= 0:
                selected_indices.append(best_idx)

        return sorted(start_idx + idx for idx in selected_indices)

    def detect_segments(
        self, frames_info: List[FrameInfo], shots: List[ShotInfo]
    ) -> List[SegmentInfo]:
        """Main pipeline for slide segment detection within shots.

        Steps: extract CLIP features, split every shot at its detected slide
        boundaries, merge short segments, choose keyframes, and (when
        ``save_intermediate`` is set) write ``segments_info.json``.
        """
        features = self.extract_clip_features(frames_info)

        segments = []

        for shot in tqdm(shots, desc="Processing shots"):
            boundaries = self.detect_slide_boundaries_in_shot(
                features, shot.start_frame_idx, shot.end_frame_idx
            )

            shot_boundaries = [shot.start_frame_idx] + boundaries + [shot.end_frame_idx]

            for i in range(len(shot_boundaries) - 1):
                start_idx = shot_boundaries[i]
                end_idx = shot_boundaries[i + 1]

                segments.append(
                    SegmentInfo(
                        segment_id=len(segments),
                        shot_id=shot.shot_id,
                        start_frame_idx=start_idx,
                        end_frame_idx=end_idx,
                        keyframe_indices=[],
                        time_range=(
                            frames_info[start_idx].timestamp,
                            frames_info[end_idx].timestamp,
                        ),
                    )
                )

        segments = self.merge_short_segments(segments, frames_info)

        for segment in tqdm(segments, desc="Keyframe selection"):
            segment.keyframe_indices = self.select_keyframes_greedy(
                frames_info, features, segment.start_frame_idx, segment.end_frame_idx
            )

        if self.config.save_intermediate:
            segments_data = [asdict(s) for s in segments]
            save_checkpoint(
                segments_data,
                os.path.join(self.config.output_dir, "segments_info.json"),
            )

        return segments

    def free_model(self):
        """Free CLIP model from memory"""
        if self.model is not None:
            del self.model
            del self.preprocess
            self.model = None
            self.preprocess = None
            clear_gpu_memory()


# ============================================================================
# Phase 3: OCR (DeepSeek-OCR)
# ============================================================================


class DeepSeekProcessor:
    """Extract text using DeepSeek-OCR.

    The model writes its output to ``result.mmd`` in ``config.output_dir``;
    that file is read back after every inference call.
    """

    def __init__(self, config: Config):
        """Store the configuration; model and tokenizer are loaded lazily."""
        self.config = config
        self.model = None
        self.tokenizer = None

    def load_model(self):
        """Load DeepSeek-OCR model and tokenizer (model goes to CUDA in bfloat16)."""
        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.ocr_model, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            self.config.ocr_model, trust_remote_code=True, use_safetensors=True
        )
        self.model = self.model.eval().cuda().to(torch.bfloat16)

    def process_image_unified(self, image_path: str) -> str:
        """Run the model on one image and return the text it produced.

        Args:
            image_path: Path of the image to process.

        Returns:
            Stripped contents of the ``result.mmd`` written by the model, or
            an empty string when no result was written or any step failed.
        """
        if self.model is None:
            self.load_model()

        try:
            prompt = "<image>\nDescribe this image in detail."

            result_file = os.path.join(self.config.output_dir, "result.mmd")
            # Remove the previous image's result first: if this inference
            # writes nothing, the old file must not be returned as this
            # image's text.
            if os.path.exists(result_file):
                os.remove(result_file)

            self.model.infer(
                self.tokenizer,
                prompt=prompt,
                image_file=image_path,
                output_path=self.config.output_dir,
                base_size=self.config.ocr_base_size,
                image_size=self.config.ocr_image_size,
                crop_mode=self.config.ocr_crop_mode,
                save_results=True,
                test_compress=True,
            )

            if os.path.exists(result_file):
                with open(result_file, "r", encoding="utf-8") as f:
                    result_text = f.read().strip()
            else:
                result_text = ""

            return result_text

        except Exception:
            # Best-effort: a failing frame yields empty text instead of
            # aborting the whole run.
            return ""

    def process_segments_unified(
        self, frames_info: List[FrameInfo], segments: List[SegmentInfo]
    ) -> Dict[int, str]:
        """Process OCR - extract text from first keyframe of each segment.

        When ``save_intermediate`` is set and ``segment_results.json`` already
        exists in the output directory, the stored mapping is returned.

        Returns:
            Mapping of segment id to extracted text; segments without
            keyframes map to an empty string.
        """
        checkpoint = os.path.join(self.config.output_dir, "segment_results.json")

        if self.config.save_intermediate and os.path.exists(checkpoint):
            segment_results = load_checkpoint(checkpoint)
            # JSON object keys are strings; restore the integer segment ids.
            return {int(k): v for k, v in segment_results.items()}

        if self.model is None:
            self.load_model()

        segment_results = {}

        for segment in tqdm(segments, desc="OCR processing"):
            if segment.keyframe_indices:
                frame_info = frames_info[segment.keyframe_indices[0]]
                text = self.process_image_unified(frame_info.full_path)
                segment_results[segment.segment_id] = text
            else:
                segment_results[segment.segment_id] = ""

        if self.config.save_intermediate:
            save_checkpoint(segment_results, checkpoint)

        return segment_results

    def free_model(self):
        """Free DeepSeek-OCR model from memory"""
        if self.model is not None:
            del self.model
            self.model = None
        if self.tokenizer is not None:
            del self.tokenizer
            self.tokenizer = None
        clear_gpu_memory()


# ============================================================================
# Audio Transcription
# ============================================================================


class AudioTranscriber:
    """Transcribe audio into timestamped text segments with a Whisper pipeline.

    The pipeline is created lazily on the first ``transcribe_audio`` call and
    can be released again with ``free_model``.
    """

    # Scratch file for the noise-reduced audio; written to the working directory.
    _PROCESSED_AUDIO_PATH = "temp_processed_audio.wav"

    def __init__(self, config: Config):
        """Store the pipeline configuration; no model is loaded yet."""
        self.config = config
        self.transcriber = None

    def load_model(self):
        """Create the Hugging Face ASR pipeline for ``config.whisper_model``.

        Half precision is used when the module-level ``device`` is CUDA,
        full precision otherwise.
        """
        from transformers import pipeline

        self.transcriber = pipeline(
            "automatic-speech-recognition",
            model=self.config.whisper_model,
            device=device.type,
            return_timestamps=True,
            framework="pt",
            torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
            model_kwargs={"use_cache": True},
        )

    def transcribe_audio(self, audio_path: str) -> List[TranscriptSegment]:
        """Transcribe an audio file and return segments with absolute timestamps.

        The audio is optionally noise-reduced, split into chunks of
        ``config.audio_chunk_length_ms`` and each chunk is transcribed on its
        own. Timestamps reported per chunk are shifted by the chunk's position
        so that they refer to the whole recording.

        Args:
            audio_path: Path to the audio file to transcribe.

        Returns:
            Transcript segments in chunk order. A chunk that fails to
            transcribe is skipped. If preprocessing or splitting fails, an
            empty list is returned. Errors raised while loading the model
            propagate to the caller.
        """
        if self.transcriber is None:
            self.load_model()

        print(f"[DEBUG] Starting audio transcription from: {audio_path}")

        chunk_paths = []
        processed_path = None

        try:
            if self.config.noise_reduction:
                print("[DEBUG] Applying noise reduction...")
                audio = preprocess_audio_with_noise_reduction(
                    audio_path,
                    noise_reduction=True,
                    reduction_strength=self.config.reduction_strength,
                )
                # Record the path before exporting so a partially written
                # file is still cleaned up if the export fails.
                processed_path = self._PROCESSED_AUDIO_PATH
                audio.export(processed_path, format="wav")
                audio_path = processed_path

            chunk_paths = split_audio(audio_path, self.config.audio_chunk_length_ms)
            print(f"[DEBUG] Split audio into {len(chunk_paths)} chunks")

            chunk_seconds = self.config.audio_chunk_length_ms / 1000.0
            all_segments = []

            for idx, chunk_path in enumerate(
                tqdm(chunk_paths, desc="Transcribing audio")
            ):
                # Derive the offset from the chunk index so a failed chunk
                # cannot desynchronise the timestamps of the following ones.
                chunk_offset = idx * chunk_seconds
                try:
                    print(f"[DEBUG] Processing chunk {idx+1}/{len(chunk_paths)}...")
                    result = self.transcriber(
                        chunk_path,
                        return_timestamps=True,
                        generate_kwargs={
                            "language": "vietnamese",
                            "task": "transcribe",
                        },
                    )

                    print(f"[DEBUG] Chunk {idx+1} result keys: {result.keys()}")
                    all_segments.extend(
                        self._segments_from_result(
                            result, idx + 1, chunk_offset, chunk_seconds
                        )
                    )
                except Exception as e:
                    # Best effort: one bad chunk must not abort the whole run.
                    print(f"[DEBUG] Error processing chunk {idx+1}: {str(e)}")
                    continue

            print(f"[DEBUG] Total audio segments extracted: {len(all_segments)}")
            return all_segments

        except Exception as e:
            print(f"[DEBUG] Audio transcription failed: {str(e)}")
            return []

        finally:
            # Runs on success and on failure, so temp files never leak.
            self._remove_temp_files(chunk_paths, processed_path)

    def _segments_from_result(
        self, result, chunk_number, chunk_offset, chunk_seconds
    ):
        """Convert one pipeline result into absolute-time transcript segments.

        Args:
            result: Mapping returned by the pipeline for a single audio chunk.
            chunk_number: 1-based chunk number, used only in debug output.
            chunk_offset: Start of this chunk within the full audio, in seconds.
            chunk_seconds: Nominal chunk length in seconds; used as the end
                time when the pipeline reports no end timestamp.

        Returns:
            List of ``TranscriptSegment``; entries with empty text are dropped.
        """
        if "chunks" not in result:
            print(f"[DEBUG] No 'chunks' key in result for chunk {chunk_number}")
            return []

        print(
            f"[DEBUG] Found {len(result['chunks'])} text chunks in audio chunk {chunk_number}"
        )

        segments = []
        for chunk in result["chunks"]:
            # The key may be absent or explicitly None.
            rel_start, rel_end = chunk.get("timestamp") or (None, None)

            start_time = chunk_offset + (rel_start if rel_start is not None else 0)
            if rel_end is not None:
                end_time = chunk_offset + rel_end
            else:
                # A missing end timestamp used to collapse to the chunk start,
                # producing end < start and losing the segment during merging.
                end_time = chunk_offset + chunk_seconds
            # Never emit a segment that ends before it starts.
            end_time = max(end_time, start_time)

            text = (chunk.get("text") or "").strip()
            if not text:
                continue

            print(
                f"[DEBUG] Audio segment: [{start_time:.1f}s - {end_time:.1f}s] {text[:50]}..."
            )
            segments.append(
                TranscriptSegment(
                    start_time=start_time,
                    end_time=end_time,
                    text=text,
                )
            )
        return segments

    def _remove_temp_files(self, chunk_paths, processed_path):
        """Best-effort removal of temporary chunk files and the scratch file.

        Only chunk paths starting with ``temp_`` are touched, so an input file
        returned unchanged by the splitter is never deleted.
        """
        stale_paths = [path for path in chunk_paths if path.startswith("temp_")]
        if processed_path is not None:
            stale_paths.append(processed_path)

        for path in stale_paths:
            try:
                if os.path.exists(path):
                    os.remove(path)
            except OSError as e:
                # Cleanup problems must not discard an otherwise good transcript.
                print(f"[DEBUG] Could not remove temp file {path}: {e}")

    def free_model(self):
        """Free Whisper model from memory"""
        if self.transcriber is not None:
            del self.transcriber
            self.transcriber = None
            clear_gpu_memory()


# ============================================================================
# Merging Logic
# ============================================================================


def merge_visual_audio(
    segments: List[SegmentInfo],
    visual_texts: Dict[int, str],
    audio_segments: List[TranscriptSegment],
) -> List[ContextUnitData]:
    """Build one context unit per visual segment, attaching overlapping speech.

    Args:
        segments: Visual segments; each supplies ``segment_id`` and a
            ``(start, end)`` ``time_range``.
        visual_texts: OCR text keyed by ``segment_id``; missing ids count as "".
        audio_segments: Transcript segments with ``start_time``, ``end_time``
            and ``text``.

    Returns:
        One ``ContextUnitData`` per visual segment, in input order. ``text``
        holds the non-empty parts under ``[VISUAL]`` / ``[AUDIO]`` headers,
        separated by a blank line.
    """
    print(
        f"\n[DEBUG] Merging {len(segments)} visual segments with {len(audio_segments)} audio segments"
    )

    units = []

    for seg in segments:
        seg_start, seg_end = seg.time_range
        slide_text = visual_texts.get(seg.segment_id, "")

        # Strict inequalities: clips that merely touch the boundary are excluded.
        spoken_parts = [
            clip.text
            for clip in audio_segments
            if clip.start_time < seg_end and clip.end_time > seg_start
        ]
        spoken_text = " ".join(spoken_parts)

        print(f"[DEBUG] Segment {seg.segment_id} [{seg_start:.1f}s - {seg_end:.1f}s]:")
        print(f"  - Visual text length: {len(slide_text)}")
        print(f"  - Audio segments matched: {len(spoken_parts)}")
        print(f"  - Audio text length: {len(spoken_text)}")

        # Label each non-empty source so the two stay distinguishable downstream.
        labelled_sources = (("[VISUAL]", slide_text), ("[AUDIO]", spoken_text))
        merged_text = "\n\n".join(
            f"{label}\n{body}" for label, body in labelled_sources if body
        )

        unit = ContextUnitData(
            text=merged_text,
            start_time=seg_start,
            end_time=seg_end,
            visual_text=slide_text,
            audio_text=spoken_text,
        )
        units.append(unit)

    print(f"[DEBUG] Created {len(units)} context units")
    return units


# ============================================================================
# Main Processing Function
# ============================================================================


def process_video(
    video_path: str, config: Optional[Config] = None
) -> List[ContextUnitData]:
    """Run the full pipeline on one video and return its context units.

    Stages, in order: audio extraction, frame extraction, shot detection,
    slide/keyframe detection, OCR, audio transcription, and the final merge.
    Each model-backed stage frees its model before the next stage starts.

    Args:
        video_path: Path to the video file.
        config: Pipeline configuration; a default ``Config()`` is used when
            omitted.

    Returns:
        List of ContextUnitData with merged visual and audio information.
    """
    config = Config() if config is None else config
    banner = "=" * 60

    for directory in (config.output_dir, config.frames_dir):
        os.makedirs(directory, exist_ok=True)

    # Pull the audio track out first; it is consumed in Phase 4.
    audio_target = os.path.join(config.output_dir, "audio")
    audio_path = extract_audio_from_video(video_path, audio_target)

    # Phase 1: Frame Extraction
    frames_info = VideoFrameExtractor(config).extract_frames(video_path)

    # Phase 1b: Shot Detection
    shot_stage = ShotDetector(config)
    shots = shot_stage.detect_shots(video_path, frames_info)
    shot_stage.free_model()

    # Phase 1c: Slide Detection & Keyframe Selection
    slide_stage = CLIPSlideDetector(config)
    segments = slide_stage.detect_segments(frames_info, shots)
    slide_stage.free_model()

    # Phase 3: OCR Processing
    ocr_stage = DeepSeekProcessor(config)
    visual_texts = ocr_stage.process_segments_unified(frames_info, segments)
    ocr_stage.free_model()

    # Phase 4: Audio Transcription
    print("\n" + banner)
    print("Phase 4: Audio Transcription")
    print(banner)
    speech_stage = AudioTranscriber(config)
    audio_segments = speech_stage.transcribe_audio(audio_path)
    print(f"[INFO] Extracted {len(audio_segments)} audio segments")
    speech_stage.free_model()

    # Phase 5: Merge Visual and Audio
    print("\n" + banner)
    print("Phase 5: Merging Visual and Audio")
    print(banner)
    return merge_visual_audio(segments, visual_texts, audio_segments)


# ============================================================================
# Export Functions
# ============================================================================


def save_context_units(context_units: List[ContextUnitData], output_path: str):
    """Save context units to a UTF-8 JSON file.

    Args:
        context_units: Units to serialise; each is converted to a plain dict.
        output_path: Destination file. Its parent directory is created when it
            does not exist yet.
    """
    data = []
    for unit in context_units:
        # Prefer pydantic v2's ``model_dump``; ``dict`` is deprecated there but
        # is the only option on pydantic v1 models.
        to_dict = getattr(unit, "model_dump", None) or unit.dict
        data.append(to_dict())

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII text human-readable in the file.
        json.dump(data, f, indent=2, ensure_ascii=False)


def load_context_units(input_path: str) -> List[ContextUnitData]:
    """Read a JSON file of context units and rebuild the model objects.

    Args:
        input_path: Path to a UTF-8 JSON file containing a list of objects
            whose keys match the ``ContextUnitData`` fields.

    Returns:
        One ``ContextUnitData`` per entry, in file order.
    """
    with open(input_path, "r", encoding="utf-8") as handle:
        records = json.load(handle)

    units = []
    for record in records:
        units.append(ContextUnitData(**record))
    return units



In [None]:
%%time
# ============================================================================
# Constants
# ============================================================================

# Input video and output location -- adjust both before running this cell.
VIDEO_PATH = "/kaggle/input/cs431videos/6.3. CS431 - Chuong 6 Part 3 Bieu dien tu bang Vector O57P9YHZOE0.mp4"
OUTPUT_DIR = "output"


# ============================================================================
# Main Execution
# ============================================================================


if __name__ == "__main__":
    divider = "=" * 60

    # Pipeline configuration: noise reduction on, intermediate results saved.
    config = Config(
        output_dir=OUTPUT_DIR,
        frames_dir=OUTPUT_DIR + "/frames",
        noise_reduction=True,
        reduction_strength=0.5,
        save_intermediate=True,
    )

    print(f"Processing video: {VIDEO_PATH}")
    print(f"Output directory: {OUTPUT_DIR}")
    print(divider)

    # Run the whole pipeline, then persist the result next to the other outputs.
    context_units = process_video(VIDEO_PATH, config)

    output_path = os.path.join(OUTPUT_DIR, "context_units.json")
    save_context_units(context_units, output_path)

    print(divider)
    print("Processing complete!")
    print(f"Total context units: {len(context_units)}")
    print(f"Results saved to: {output_path}")

    # Coverage statistics: how many units carry each kind of text.
    units_with_visual = len([u for u in context_units if u.visual_text])
    units_with_audio = len([u for u in context_units if u.audio_text])
    units_with_both = len(
        [u for u in context_units if u.visual_text and u.audio_text]
    )

    print("\nStatistics:")
    print(f"  - Units with visual text: {units_with_visual}")
    print(f"  - Units with audio text: {units_with_audio}")
    print(f"  - Units with both: {units_with_both}")

    # Show the first few units as a quick sanity check.
    print("\nSample units:")
    for i, unit in enumerate(context_units[:3]):
        print(
            f"\nUnit {i+1} [{unit.start_time:.1f}s - {unit.end_time:.1f}s]:",
            f"  Visual: {unit.visual_text}",
            f"  Audio: {unit.audio_text}",
            sep="\n",
        )


Processing video: /kaggle/input/cs431videos/6.3. CS431 - Chuong 6 Part 3 Bieu dien tu bang Vector O57P9YHZOE0.mp4
Output directory: output


Extracting frames:   0%|          | 0/34403 [00:00<?, ?it/s]

Extracting frames from /kaggle/input/cs431videos/6.3. CS431 - Chuong 6 Part 3 Bieu dien tu bang Vector O57P9YHZOE0.mp4


Processing frames: 100%|██████████| 34403/34403 [00:21<00:00, 1573.06frame/s]


CLIP encoding:   0%|          | 0/36 [00:00<?, ?it/s]

Processing shots:   0%|          | 0/3 [00:00<?, ?it/s]

Keyframe selection:   0%|          | 0/8 [00:00<?, ?it/s]

E0000 00:00:1763308528.078240    8519 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763308528.084893    8519 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

You are using a model of type deepseek_vl_v2 to instantiate a model of type DeepseekOCR. This is not supported for all configurations of models and can yield errors.
Some weights of DeepseekOCRForCausalLM were not initialized from the model checkpoint at deepseek-ai/DeepSeek-OCR and are newly initialized: ['model.vision_model.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


OCR processing:   0%|          | 0/8 [00:00<?, ?it/s]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. Calling `get_max_cache()` will raise error from v4.48
The attention layers in this model are transitioning from computing the RoPE embeddings internally through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed `position_embeddings` (Tuple of tensors, containing cos and si

BASE:  torch.Size([1, 256, 1280])
PATCHES:  torch.Size([2, 100, 1280])
The image is a digital presentation slide with a white background and a combination of text and graphics. At the top, there is a blue banner with the text "NỘI DUNG" in white capital letters. Below the banner, there are three numbered points in black text:

1. "XỬ LÝ NGÔN NGỮ TỰ NHIÊN"
2. "HỌC SÂU TRONG XỬ LÝ NGÔN NGỮ TỰ NHIÊN"
3. "BIỂU DIỄN TỪ VỚI VECTOR"

In the center of the slide, there is a graphic of a blue banner with the text "TS. Nguyễn Vinh Tiệp" in white, followed by "Giảng viên Khoa Khoa học Máy tính" in smaller font size. The banner also features a small graphic of a globe with a blue and white color scheme.

On the right side of the slide, there is a photograph of a man wearing a light-colored shirt with a dark tie. He is standing in front of a plain background and appears to be speaking or presenting. The man has short black hair and is looking slightly to his left with a neutral expression.

The slid


image: 0it [00:00, ?it/s][A

other: 0it [00:00, ?it/s][A
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


BASE:  torch.Size([1, 256, 1280])
PATCHES:  torch.Size([2, 100, 1280])
The image displays a presentation slide with a blue and white color scheme. At the top, there is a logo consisting of a stylized letter 'B' with a circular element, followed by the text "Biểu diễn từ bằng vector" in bold, dark blue font. Below this title, there are two bullet points in a lighter blue font. The first bullet point states "Biểu diễn từ bằng vector rất quan trọng khi áp dụng vào các mô hình máy học," which translates to "Vector representation is very important when applying it to various machine learning models." The second bullet point reads "Các kỹ thuật thường được sử dụng: One-hot vector, Bag-of-words hay BOW," meaning "Common techniques used include One-hot vector, Bag-of-words, and BOW."

On the right side of the slide, there is a photograph of a man wearing a light blue shirt and a dark blue tie. He has short black hair, is smiling, and appears to be speaking or presenting. The man is standing in


image: 0it [00:00, ?it/s][A

other: 0it [00:00, ?it/s][A
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


BASE:  torch.Size([1, 256, 1280])
PATCHES:  torch.Size([2, 100, 1280])
The image displays a presentation slide with a blue and white color scheme. At the top, there is a logo consisting of a blue circle with a white swirl design inside it. Below the logo, the slide is titled "Biểu diễn từ bằng vector" in bold, black font.

The slide contains bullet points in Vietnamese, which translate to "1. Biểu diễn từ bằng vector rất quan trọng khi áp dụng vào các mô hình máy học" and "2. Các kỹ thuật thường được sử dụng: One-hot vector, Bag-of-words hay BOW. Biểu diễn bằng ngữ cảnh."

On the right side of the slide, there is a photograph of a man wearing a light blue shirt with a collar. He has short black hair, is looking slightly to his left with a neutral expression, and is wearing a dark tie. The man is standing in front of a blurred background that does not provide any additional context.

At the bottom of the slide, there is a footer in blue text that reads "Thực hiện bởi Trường Đại học Công


image: 0it [00:00, ?it/s][A

other: 0it [00:00, ?it/s][A
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


BASE:  torch.Size([1, 256, 1280])
PATCHES:  torch.Size([2, 100, 1280])
The image displays a presentation slide with a blue and white color scheme. At the top, there is a logo consisting of a blue circle with a white swirl design inside it. Below the logo, the slide is titled "Biểu diễn từ bằng vector" in bold, black font.

The slide contains bullet points in Vietnamese, which translate to "1. Biểu diễn từ bằng vector rất quan trọng khi áp dụng vào các mô hình máy học" and "2. Các kỹ thuật thường được sử dụng: One-hot vector, Bag-of-words hay BOW. Biểu diễn bằng ngữ cảnh."

On the right side of the slide, there is a photograph of a man wearing a light blue shirt with a collar. He has short black hair, is smiling, and looking directly at the camera. The man appears to be standing in an indoor setting with a blurred background that does not provide any additional context.

The bottom of the slide includes a footer in blue text that reads "Thực hiện bởi Trường Đại học Công nghệ Thông tin, 


image: 0it [00:00, ?it/s][A

other: 0it [00:00, ?it/s][A
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


BASE:  torch.Size([1, 256, 1280])
PATCHES:  torch.Size([2, 100, 1280])
The image displays a presentation slide with a title "Cách 1: Biểu diễn với One-hot vector" which translates to "Method 1: Representing with One-hot vector". The slide is divided into two main sections.

On the left side, there is a bullet-pointed list with the following text: "Trước đây, từ được xem là một phần tử trong một trường hợp đặc biệt." This translates to "Previously, we have seen that a vector is a part of a vector space. In this special case."

Below this list, there is a mathematical expression: "Một số 1, bởi 1, bởi 0 → y nghĩa: vị trí của từ trong trường hợp đặc biệt." This translates to "One vector, by 1, by 0 → the meaning of the vector is: the position of the vector in the special case."

On the right side, there is a diagram with a red arrow pointing from a vector labeled "One-hot vector" to a mathematical expression "Một số 1, bởi 1, bởi 0 → vị trí của từ trong trường hợp đặc biệt." This translat