# In [None]:
import os
import re
import json
import csv
import numpy as np
from collections import deque
from itertools import combinations
import cv2
import torch
from tqdm import tqdm
from ultralytics import YOLO
import mediapipe as mp
from pytorchvideo.models.hub import c2d_r50

# -------------------- Annotation Processing -------------------- #
class AnnotationParser:
    """Load temporal anomaly annotations into a per-video dictionary.

    After construction, ``self.annotations`` maps each video name (the first
    field of an annotation line) to a dict with the keys ``'label'``,
    ``'segments'`` and ``'event_type'``.
    """

    def __init__(self, annotation_path):
        self.annotations = self._parse_annotations(annotation_path)

    def _parse_annotations(self, path):
        """Parse annotations in Temporal_Anomaly_Annotation.txt format.

        Each usable line is whitespace separated: video name, event type,
        then start/end frame pairs. Lines with fewer than six fields are
        skipped. A pair in which either value is ``-1`` is ignored, and a
        trailing frame number without a partner is dropped instead of
        raising ``IndexError``.

        Args:
            path: Path of the annotation text file.

        Returns:
            Dict mapping video name to
            ``{'label': str, 'segments': list[dict], 'event_type': str}``,
            where every segment is ``{'start_frame': int, 'end_frame': int}``.

        Raises:
            ValueError: If a frame field is not an integer.
        """
        annotations = {}
        with open(path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) < 6:
                    continue

                video_name, event_type = parts[0], parts[1]
                frame_tokens = parts[2:]

                segments = []
                # zip() pairs even/odd positions and silently stops at the
                # shorter side, so an unpaired trailing token is ignored.
                for start_tok, end_tok in zip(frame_tokens[0::2],
                                              frame_tokens[1::2]):
                    start, end = int(start_tok), int(end_tok)
                    if start == -1 or end == -1:
                        continue
                    segments.append({'start_frame': start, 'end_frame': end})

                annotations[video_name] = {
                    'label': self._extract_label(video_name),
                    'segments': segments,
                    'event_type': event_type
                }
        return annotations

    def _extract_label(self, name):
        """Return the part of *name* before its first digit.

        Trailing underscores are stripped. When nothing is left (for example
        the name starts with a digit) ``'unknown'`` is returned.
        """
        # The pattern can match the empty string, so the match itself is
        # never None; the fallback has to test the extracted text instead.
        prefix = re.match(r'^[^\d]*', name).group(0).rstrip('_')
        return prefix or 'unknown'

# -------------------- Interaction Tracking -------------------- #
class InteractionTracker:
    """Sliding-window tracker for per-frame features plus group geometry.

    ``history`` keeps the most recent feature vectors (at most
    ``window_size`` of them) so that frame-to-frame differences can be
    derived in :meth:`update`.
    """

    def __init__(self, window_size=30):
        self.history = deque(maxlen=window_size)

    def update(self, current_features):
        """Record the current features and derive their temporal differences.

        Args:
            current_features: Numeric array-like, or a nested structure of
                dicts/lists/tuples with numeric leaves (such as a list of
                per-pair interaction dicts). Nested structures are flattened
                into a 1-D vector in iteration order.

        Returns:
            Dict with ``'current'`` (float32 array), ``'velocity'`` (current
            minus previous entry), ``'acceleration'`` (change in velocity)
            and ``'duration'`` (number of entries stored before this call).
            Velocity and acceleration are zero when there is not enough
            history or when the stored entries have a different shape, e.g.
            because the amount of input data changed between frames.
        """
        current_arr = self._as_vector(current_features)
        features = {
            'current': current_arr,
            'velocity': np.zeros_like(current_arr),
            'acceleration': np.zeros_like(current_arr),
            'duration': len(self.history)
        }

        # Differences are only meaningful between same-shaped vectors.
        if self.history and self.history[-1].shape == current_arr.shape:
            features['velocity'] = current_arr - self.history[-1]

            if (len(self.history) >= 2
                    and self.history[-2].shape == current_arr.shape):
                prev_velocity = self.history[-1] - self.history[-2]
                features['acceleration'] = features['velocity'] - prev_velocity

        self.history.append(current_arr)
        return features

    @staticmethod
    def _as_vector(features):
        """Convert *features* to a float32 array, flattening nested containers."""
        try:
            # Plain numeric input keeps its original shape.
            return np.array(features, dtype=np.float32)
        except (TypeError, ValueError):
            pass

        def _leaves(node):
            if isinstance(node, dict):
                for value in node.values():
                    yield from _leaves(value)
            elif isinstance(node, np.ndarray):
                yield from node.astype(np.float32).ravel()
            elif isinstance(node, (list, tuple)):
                for item in node:
                    yield from _leaves(item)
            else:
                yield float(node)

        return np.fromiter(_leaves(features), dtype=np.float32)

    def group_analysis(self, people):
        """Analyze group dynamics between 3+ people.

        Args:
            people: Sequence of dicts, each providing a ``'bbox_center'``.

        Returns:
            Dict with ``'social_force'``, ``'dominance'`` and ``'formation'``.
            With fewer than three people the sentinel values ``-1.0``, ``-1``
            and ``-1.0`` are returned.
        """
        if len(people) < 3:
            return {
                'social_force': -1.0,
                'dominance': -1,
                'formation': -1.0
            }

        positions = np.array([p['bbox_center'] for p in people])
        centroid = np.mean(positions, axis=0)

        return {
            'social_force': self._calc_social_force(positions),
            'dominance': self._calc_dominance(positions, centroid),
            'formation': self._detect_formation(positions)
        }

    def _calc_social_force(self, positions):
        """Return the mean, over all people, of summed inverse distances to the others."""
        forces = []
        for i in range(len(positions)):
            others = np.delete(positions, i, axis=0)
            dists = np.linalg.norm(others - positions[i], axis=1)
            # The epsilon keeps coincident positions from dividing by zero.
            forces.append(np.sum(1 / (dists + 1e-9)))
        return float(np.mean(forces))

    def _calc_dominance(self, positions, centroid):
        """Return the index of the position closest to *centroid*."""
        # Cast to a built-in int so the value survives json serialisation.
        return int(np.linalg.norm(positions - centroid, axis=1).argmin())

    def _detect_formation(self, positions):
        """Return the mean angle (radians) at the first point of every trio of positions."""
        angles = []
        for trio in combinations(positions, 3):
            v1 = trio[1] - trio[0]
            v2 = trio[2] - trio[0]
            cosine = np.dot(v1, v2) / (
                np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-9)
            # Rounding can push the cosine marginally outside [-1, 1],
            # which would make arccos return NaN.
            angles.append(np.arccos(np.clip(cosine, -1.0, 1.0)))
        return float(np.mean(angles))
# -------------------- Main Processing Pipeline -------------------- #
class EnhancedInteractionPipeline:
    """Extract per-frame interaction features from annotated video segments.

    For every annotated segment the pipeline detects people, estimates their
    pose, derives pairwise interaction measures, runs a video model over a
    sliding window of frames, and writes frames, features (CSV + NPZ) and
    metadata (JSON) to the directories named in ``config``.

    ``config`` must provide ``'video_dir'``, ``'frame_dir'``,
    ``'feature_dir'`` and ``'metadata_dir'``.
    """

    CLIP_LENGTH = 16            # frames per temporal window
    CLIP_SIZE = (112, 112)      # spatial size fed to the video model
    # NOTE(review): the zero placeholder length assumes the video model
    # emits a 2048-d vector - confirm against the c2d_r50 head.
    TEMPORAL_FEATURE_DIM = 2048
    NUM_POSE_LANDMARKS = 33     # rows of the zero array used when no pose is found

    def __init__(self, config):
        self.config = config
        self.yolo = YOLO('yolov8x-seg.pt')
        self.mp_pose = mp.solutions.pose
        self.c3d_model = c2d_r50(pretrained=True).eval()
        self.tracker = InteractionTracker()
        self._setup_directories()
        self.frame_buffer = deque(maxlen=self.CLIP_LENGTH)  # Temporal window

    def _setup_directories(self):
        """Create the frame, feature and metadata output directories."""
        for key in ('frame_dir', 'feature_dir', 'metadata_dir'):
            os.makedirs(self.config[key], exist_ok=True)

    def process_videos(self, annotation_path):
        """Process every annotated video that exists under ``config['video_dir']``.

        Videos named in the annotation file but missing on disk are skipped.
        """
        parser = AnnotationParser(annotation_path)
        for video_name, data in parser.annotations.items():
            video_path = os.path.join(self.config['video_dir'], video_name)
            if not os.path.exists(video_path):
                continue

            self._process_video(video_path, data)

    def _process_video(self, video_path, video_data):
        """Extract and save features for each segment of one video.

        Args:
            video_path: Path of the video file.
            video_data: Dict with ``'segments'``, ``'label'`` and
                ``'event_type'``. When ``'segments'`` is empty (or a segment
                starts at ``-1``) the whole video is treated as one segment.
        """
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()

        video_name = os.path.splitext(os.path.basename(video_path))[0]

        # Videos without an annotated anomaly (normal videos) carry no
        # segment, so fall back to a single whole-video segment.
        segments = video_data['segments'] or [
            {'start_frame': -1, 'end_frame': -1}]

        for seg_idx, segment in enumerate(segments):
            if segment['start_frame'] == -1:
                start_frame, end_frame = 0, total_frames - 1
            else:
                start_frame = segment['start_frame']
                end_frame = segment['end_frame']

            # Segments are not contiguous, so temporal state must not carry
            # over from the previous segment or video.
            self.frame_buffer.clear()
            self.tracker.history.clear()

            features = []
            people_counts = []
            resolution = None
            expected_frames = max(end_frame - start_frame + 1, 0)
            frame_iter = self._iter_frames(video_path, start_frame, end_frame)

            for frame_idx, frame in enumerate(
                    tqdm(frame_iter, total=expected_frames)):
                resolution = frame.shape[:2]

                # Detect people
                detections = self._detect_people(frame)

                # Process keypoints
                keypoints = []
                for person in detections:
                    person['keypoints'] = self._get_pose_keypoints(
                        person['cropped_frame'])
                    keypoints.append(person['keypoints'])

                # Calculate interactions
                interactions = self._calculate_interactions(
                    detections, keypoints)

                # Extract temporal features
                temporal_features = self._get_temporal_features(frame)

                # Track interactions
                tracked_features = self.tracker.update(interactions)
                group_features = self.tracker.group_analysis(detections)

                # Save data
                self._save_frame(frame, video_name, seg_idx, frame_idx)
                features.append({
                    'spatial': interactions,
                    'temporal': temporal_features,
                    'tracked': tracked_features,
                    'group': group_features
                })
                people_counts.append(len(detections))

            if resolution is None:
                # No frame could be read for this segment; nothing to save.
                continue

            # Save features and metadata
            self._save_features(
                features, video_name, seg_idx, video_data['label'])
            self._save_metadata(video_name, seg_idx, {
                'resolution': resolution,
                'fps': fps,
                'num_people': people_counts,
                'event_type': video_data['event_type']
            })

    def _iter_frames(self, video_path, start_frame, end_frame):
        """Yield RGB frames ``start_frame``..``end_frame`` (inclusive), one at a time.

        Stops early when the video cannot be read any further. The capture
        is released when the generator finishes or is closed.
        """
        cap = cv2.VideoCapture(video_path)
        try:
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            for _ in range(max(end_frame - start_frame + 1, 0)):
                ret, frame = cap.read()
                if not ret:
                    break
                yield cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        finally:
            cap.release()

    def _extract_frames(self, video_path, start_frame, end_frame):
        """Return the RGB frames ``start_frame``..``end_frame`` (inclusive) as a list."""
        return list(self._iter_frames(video_path, start_frame, end_frame))

    def _detect_people(self, frame):
        """Detect people in an RGB frame.

        Returns:
            List of dicts with ``'bbox'`` ([x1, y1, x2, y2], clamped to the
            frame), ``'bbox_center'``, ``'cropped_frame'`` (RGB view into
            *frame*), ``'mask'`` (array or None) and ``'confidence'``.
        """
        # Frames are kept in RGB, but Ultralytics documents raw numpy input
        # as BGR, so convert back for detection only.
        results = self.yolo(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))[0]
        people = []

        if results.boxes is None or len(results.boxes) == 0:
            return people

        frame_h, frame_w = frame.shape[:2]

        for idx, cls in enumerate(results.boxes.cls):
            if int(cls) != 0:  # Person class only
                continue

            box = results.boxes[idx]
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            # A negative index would wrap around when slicing the frame.
            x1, y1 = max(x1, 0), max(y1, 0)
            x2, y2 = min(x2, frame_w), min(y2, frame_h)

            mask = None
            if results.masks is not None and idx < len(results.masks):
                mask = results.masks[idx].data[0].cpu().numpy()

            people.append({
                'bbox': [x1, y1, x2, y2],
                'bbox_center': [(x1 + x2) // 2, (y1 + y2) // 2],
                'cropped_frame': frame[y1:y2, x1:x2],
                'mask': mask,
                'confidence': float(box.conf)
            })

        return people

    def _get_pose_keypoints(self, cropped_frame):
        """Get pose landmarks using MediaPipe.

        Args:
            cropped_frame: RGB crop of one person.

        Returns:
            Array with one ``[x, y, z, visibility]`` row per landmark, or a
            ``(33, 4)`` zero array when the crop is empty or no pose is found.
        """
        if cropped_frame is None or cropped_frame.size == 0:
            return np.zeros((self.NUM_POSE_LANDMARKS, 4))

        # NOTE(review): a Pose object is constructed for every crop, which
        # is presumably expensive - consider reusing a single instance.
        with self.mp_pose.Pose(
            static_image_mode=True,
            min_detection_confidence=0.5,
            model_complexity=2) as pose:

            # The crop is already RGB (frames are converted on extraction),
            # so no colour conversion is applied here; the crop is a
            # non-contiguous view, hence the contiguous copy.
            results = pose.process(np.ascontiguousarray(cropped_frame))
            if not results.pose_landmarks:
                return np.zeros((self.NUM_POSE_LANDMARKS, 4))

            return np.array([[lm.x, lm.y, lm.z, lm.visibility]
                for lm in results.pose_landmarks.landmark])

    def _calculate_interactions(self, detections, keypoints):
        """Compute pairwise measures for every pair of detected people.

        NOTE(review): keypoints are estimated on each person's own crop, so
        presumably their coordinates are relative to that crop rather than
        to the full frame - confirm that cross-person distances are meant
        to be computed this way.

        Returns:
            One dict per pair with ``'pair'``, ``'distance'``, ``'angle'``
            and ``'visibility'``.
        """
        interactions = []
        for i, j in combinations(range(len(detections)), 2):
            kp1 = keypoints[i]
            kp2 = keypoints[j]

            interaction = {
                'pair': (i, j),
                'distance': {
                    'wrist_nose': np.linalg.norm(kp1[15][:3] - kp2[0][:3]),
                    'ankles': np.linalg.norm(kp1[27][:3] - kp2[27][:3])
                },
                'angle': {
                    'shoulder_hip': self._vector_angle(
                        kp1[11][:3] - kp1[23][:3],
                        kp2[12][:3] - kp2[24][:3])
                },
                'visibility': (kp1[:, 3].mean() + kp2[:, 3].mean()) / 2
            }
            interactions.append(interaction)
        return interactions

    def _vector_angle(self, v1, v2):
        """Return the angle in radians between *v1* and *v2*."""
        cosine = np.dot(v1, v2) / (
            np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-9)
        # Rounding can push the cosine marginally outside [-1, 1], which
        # would make arccos return NaN.
        return np.arccos(np.clip(cosine, -1.0, 1.0))

    def _get_temporal_features(self, frame):
        """Run the video model over the most recent ``CLIP_LENGTH`` frames.

        Args:
            frame: RGB frame as an ``(H, W, C)`` array.

        Returns:
            The squeezed model output as a numpy array, or a zero vector of
            length ``TEMPORAL_FEATURE_DIM`` while the buffer is still filling.
        """
        # Imported locally because the module does not import it at the top.
        import torch.nn.functional as F

        frame_tensor = torch.from_numpy(frame).permute(2, 0, 1)  # [C, H, W]
        self.frame_buffer.append(frame_tensor)

        # Wait until we have enough frames
        if len(self.frame_buffer) < self.CLIP_LENGTH:
            return np.zeros(self.TEMPORAL_FEATURE_DIM)

        # Stack along time first so the clip can be resized as a batch of
        # 2-D images; drop any channel beyond the first three and convert
        # to float because frames arrive as uint8.
        clip = torch.stack(list(self.frame_buffer), dim=0)  # [T, C, H, W]
        clip = clip[:, :3].float()
        if tuple(clip.shape[-2:]) != self.CLIP_SIZE:
            clip = F.interpolate(
                clip, size=self.CLIP_SIZE, mode='bilinear',
                align_corners=False)

        # NOTE(review): pixel values are passed in the 0-255 range without
        # normalisation - confirm what the pretrained model expects.
        sequence = clip.permute(1, 0, 2, 3).unsqueeze(0)  # [1, C, T, H, W]

        with torch.no_grad():
            output = self.c3d_model(sequence)
        return output.squeeze().detach().cpu().numpy()

    def _save_frame(self, frame, video_name, segment_idx, frame_idx):
        """Write an RGB frame as a JPEG under ``frame_dir/<video>_seg<idx>/``."""
        frame_dir = os.path.join(
            self.config['frame_dir'],
            f"{video_name}_seg{segment_idx}"
        )
        os.makedirs(frame_dir, exist_ok=True)
        cv2.imwrite(
            os.path.join(frame_dir, f"frame_{frame_idx:04d}.jpg"),
            cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        )

    @staticmethod
    def _json_default(value):
        """Convert numpy arrays and scalars to JSON-serialisable built-ins."""
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable")

    def _save_features(self, features, video_name, segment_idx, label):
        """Write the per-frame features of one segment as CSV and NPZ.

        The CSV has one row per frame with each feature group JSON-encoded;
        the NPZ stores the same groups as arrays together with the label.
        """
        # CSV Format
        csv_path = os.path.join(
            self.config['feature_dir'],
            f"{video_name}_seg{segment_idx}_features.csv"
        )
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerow(
                ['frame', 'label', 'spatial', 'temporal', 'tracked', 'group'])
            for idx, feat in enumerate(features):
                # The feature dicts hold numpy arrays and scalars, which the
                # json module cannot encode without a converter.
                writer.writerow([idx, label] + [
                    json.dumps(feat[key], default=self._json_default)
                    for key in ('spatial', 'temporal', 'tracked', 'group')
                ])

        # NPZ Format
        npz_path = os.path.join(
            self.config['feature_dir'],
            f"{video_name}_seg{segment_idx}_features.npz"
        )
        np.savez_compressed(
            npz_path,
            spatial=np.array([f['spatial'] for f in features], dtype=object),
            temporal=np.array([f['temporal'] for f in features]),
            tracked=np.array([f['tracked'] for f in features], dtype=object),
            group=np.array([f['group'] for f in features], dtype=object),
            label=label
        )

    def _save_metadata(self, video_name, segment_idx, metadata):
        """Write *metadata* as JSON to ``metadata_dir/<video>_seg<idx>_meta.json``."""
        meta_path = os.path.join(
            self.config['metadata_dir'],
            f"{video_name}_seg{segment_idx}_meta.json"
        )
        with open(meta_path, 'w') as f:
            json.dump(metadata, f, default=self._json_default)

# -------------------- Execution -------------------- #
if __name__ == "__main__":
    # Input and output locations read by the pipeline. The paths are
    # machine specific; 'metadata_dir' is relative to the working directory.
    config = dict(
        video_dir=r'D:\CLASS NOTES\EPICS\Model Dataset\Anomaly-Videos',
        frame_dir=r'D:\CLASS NOTES\EPICS\Model Dataset\Extracted_Frames',
        feature_dir=r'D:\CLASS NOTES\EPICS\Model Dataset\Extracted_feautures',
        metadata_dir='./processed/metadata',
        annotation_path=r'D:\CLASS NOTES\EPICS\Model Dataset\Temporal_Anomaly_Annotation_edited.txt',
    )

    # Constructing the pipeline loads the models and creates the output
    # directories; the run then walks every video in the annotation file.
    pipeline = EnhancedInteractionPipeline(config)
    pipeline.process_videos(config['annotation_path'])