In [2]:
import os
import cv2
import torch
import numpy as np
import joblib
from collections import deque
from itertools import combinations
from torch import nn
from ultralytics import YOLO
import mediapipe as mp
from pytorchvideo.models.hub import slow_r50
import logging
from scipy.stats import mode

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class InteractionAwareModel(nn.Module):
    def __init__(self, input_size, num_classes):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.main(x)

class VideoAnomalyPredictor:
    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.half_precision = torch.cuda.is_available()
        self.load_models()
        self.initialize_components()
        
        self.feature_config = {
            "feature_dims": {
                "spatial": 36,
                "temporal": 10,
                "tracked": 4
            },
            "interaction_weights": {
                "tracked": 2.5,
                "group": 3.0
            }
        }

    def load_models(self):
        """Load models with dtype consistency checks"""
        try:
            # Load preprocessing components
            self.scaler = joblib.load(self.config['scaler_save_path'])
            self.label_mapping = joblib.load(self.config['label_mapping_path'])
            self.reverse_mapping = {v: k for k, v in self.label_mapping.items()}

            # Initialize main model
            self.model = InteractionAwareModel(
                input_size=self.scaler.n_features_in_,
                num_classes=len(self.label_mapping)
            ).to(self.device)
            
            state_dict = torch.load(self.config['model_save_path'], map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.eval()

            # Initialize detection models
            self.yolo = YOLO('yolov8x-seg.pt').to(self.device)
            self.mp_pose = mp.solutions.pose.Pose(
                static_image_mode=False,
                min_detection_confidence=0.5,
                model_complexity=2)

            # Initialize 3D CNN
            self.c3d_model = slow_r50(pretrained=True).eval().to(self.device)

            # Handle half-precision
            if self.half_precision:
                self.model.half()
                self.c3d_model.half()
                self.yolo.model.half()
                logger.info("Converted models to half-precision")
                
                # Convert scaler parameters to float16
                self.scaler.mean_ = self.scaler.mean_.astype(np.float16)
                self.scaler.var_ = self.scaler.var_.astype(np.float16)
                self.scaler.scale_ = self.scaler.scale_.astype(np.float16)

            logger.debug(f"Model dtype: {next(self.model.parameters()).dtype}")
            logger.debug(f"YOLO dtype: {next(self.yolo.model.parameters()).dtype}")

        except Exception as e:
            logger.error(f"Model loading failed: {str(e)}")
            raise

    def initialize_components(self):
        """Initialize buffers and trackers"""
        self.frame_buffer = deque(maxlen=32)
        self.confidence_history = deque(maxlen=16)
        self.tracking_history = deque(maxlen=5)

    def preprocess_frame(self, frame):
        """Process frame with GPU acceleration support"""
        try:
            # Use CUDA-accelerated resize if available
            if cv2.cuda.getCudaEnabledDeviceCount() > 0:
                gpu_frame = cv2.cuda_GpuMat()
                gpu_frame.upload(frame)
                resized = cv2.cuda.resize(gpu_frame, (224, 224)).download()
            else:
                resized = cv2.resize(frame, (224, 224))
                
            frame = cv2.cvtColor(resized, cv2.COLOR_BGR2RGB)
            tensor = torch.from_numpy(frame).permute(2, 0, 1)
            tensor = tensor.to(self.device).float() / 255.0
            
            if self.half_precision:
                tensor = tensor.half()
                
            return tensor
        except Exception as e:
            logger.error(f"Frame preprocessing failed: {str(e)}")
            return None

    def vector_angle(self, v1, v2):
        """Calculate angle between vectors with stability"""
        cos_theta = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2) + 1e-9)
        return np.arccos(np.clip(cos_theta, -1.0, 1.0))

    def extract_spatial_features(self, people):
        """Spatial features with kinematic validation"""
        features = []
        try:
            if len(people) >= 2:
                logger.debug(f"Processing spatial features for {len(people)} people")
                p1, p2 = people[:2]
                
                # Joint position features
                joint_pairs = [(15, 0), (27, 27), (11, 23), (12, 24)]
                for j1, j2 in joint_pairs:
                    kp1 = p1['keypoints'][j1][:3]
                    kp2 = p2['keypoints'][j2][:3]
                    features.append(np.linalg.norm(kp1 - kp2))
                
                # Velocity features
                if len(self.tracking_history) > 1:
                    prev_pos = self.tracking_history[-1]
                    curr_pos = [p['bbox_center'] for p in [p1, p2]]
                    dt = 1/30  # Assuming 30 FPS
                    velocities = [
                        np.linalg.norm(np.array(curr) - np.array(prev)) / dt
                        for curr, prev in zip(curr_pos, prev_pos)
                    ]
                    features.extend(velocities)
                
                self.tracking_history.append([p['bbox_center'] for p in [p1, p2]])
            
            # Pad to 36 features if needed
            features += [0.0] * (36 - len(features))
            return features[:36]
        except Exception as e:
            logger.error(f"Spatial feature error: {str(e)}")
            return [0.0] * 36

    def extract_temporal_features(self):
        """Temporal features with buffer validation"""
        try:
            if len(self.frame_buffer) < 16:
                logger.warning("Insufficient frames for temporal features")
                return np.zeros(10)
            
            # Process sequence through 3D CNN
            sequence = torch.stack(list(self.frame_buffer))  # [T, C, H, W]
            sequence = sequence.permute(1, 0, 2, 3)        # [C, T, H, W]
            sequence = sequence.unsqueeze(0)                # [1, C, T, H, W]
            
            if self.half_precision:
                sequence = sequence.half()
                
            features = self.c3d_model(sequence).squeeze().cpu().numpy()
            return features[:10]
        except Exception as e:
            logger.error(f"Temporal feature error: {str(e)}")
            return np.zeros(10)

    def process_interaction_features(self, spatial_features):
        """Interaction features with dynamic weighting"""
        try:
            # Dynamic weight adjustment based on feature confidence
            base_weights = np.array([2.5, 1.0, 1.0, 3.0])
            confidence = np.clip(np.mean(spatial_features[:3]), 0, 1)
            weights = base_weights * (1 + confidence)
            
            return [
                spatial_features[0] * weights[0],
                spatial_features[1] * weights[1],
                spatial_features[2] * weights[2],
                weights[3]  # Group weight
            ]
        except Exception as e:
            logger.error(f"Interaction feature error: {str(e)}")
            return [0.0] * 4

    def extract_features(self, frame_tensor):
        """Complete feature pipeline with validation"""
        try:
            people = []
            with torch.no_grad():
                results = self.yolo(frame_tensor[None])[0]
                if results.boxes is not None:
                    for idx in (i for i, cls in enumerate(results.boxes.cls) if int(cls) == 0):
                        box = results.boxes[idx]
                        x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
                        cropped = frame_tensor[:, y1:y2, x1:x2].cpu().numpy().transpose(1, 2, 0)
                        
                        results_pose = self.mp_pose.process((cropped * 255).astype(np.uint8))
                        kps = np.zeros((33, 4)) if not results_pose.pose_landmarks else np.array([
                            [lm.x, lm.y, lm.z, lm.visibility] 
                            for lm in results_pose.pose_landmarks.landmark
                        ])
                        people.append({
                            'keypoints': kps,
                            'bbox_center': [(x1+x2)//2, (y1+y2)//2],
                            'confidence': float(box.conf)
                        })

            # Feature extraction and validation
            spatial = self.extract_spatial_features(people)
            temporal = self.extract_temporal_features()
            interaction = self.process_interaction_features(spatial)
            
            combined = np.concatenate([
                spatial[:36],
                temporal[:10],
                interaction[:4]
            ])
            
            # Validate features
            if len(combined) != self.scaler.n_features_in_:
                logger.error(f"Feature dimension mismatch: {len(combined)} vs {self.scaler.n_features_in_}")
                return np.zeros(self.scaler.n_features_in_)
            
            if np.all(np.abs(combined) < 1e-6):
                logger.error("All features are zero!")
                return np.zeros(self.scaler.n_features_in_)
                
            if np.any(np.isnan(combined)):
                logger.error("NaN values detected in features!")
                return np.zeros(self.scaler.n_features_in_)
            
            return combined
            
        except Exception as e:
            logger.error(f"Feature extraction failed: {str(e)}")
            return np.zeros(self.scaler.n_features_in_)

    def predict_video(self, video_path):
        """Enhanced prediction pipeline with resource management"""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.error(f"Failed to open video: {video_path}")
            return {'class': 'unknown', 'confidence': 0.0}
        
        predictions = []
        class_confidence = {cls: [] for cls in self.label_mapping.keys()}
        
        try:
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                frame_tensor = self.preprocess_frame(frame)
                if frame_tensor is None:
                    continue
                
                self.frame_buffer.append(frame_tensor)
                
                try:
                    with torch.no_grad(), torch.amp.autocast(device_type='cuda', enabled=self.half_precision):
                        features = self.extract_features(frame_tensor)
                        
                        # Skip frame if features are invalid
                        if np.all(features == 0):
                            logger.warning("Skipping frame with invalid features")
                            continue
                            
                        # Normalization and prediction
                        features = self.scaler.transform([features])
                        tensor = torch.tensor(features, 
                                            dtype=torch.float16 if self.half_precision else torch.float32)
                        tensor = tensor.to(self.device)
                        
                        outputs = self.model(tensor)
                        probs = torch.nn.functional.softmax(outputs, dim=1).cpu().numpy()[0]
                        
                        # Track confidence for all classes
                        for i, prob in enumerate(probs):
                            class_name = self.reverse_mapping[i]
                            class_confidence[class_name].append(prob)
                            
                        # Adaptive confidence threshold
                        current_max = np.max(probs)
                        threshold = 0.4 if current_max > 0.7 else 0.25
                        
                        if current_max > threshold:
                            pred = np.argmax(probs)
                            predictions.append(pred)

                        # Visualization
                        display_frame = cv2.resize(frame, (1280, 720))
                        avg_confidence = np.mean(self.confidence_history) if self.confidence_history else 0.0
                        label = self.reverse_mapping.get(np.argmax(probs), "unknown")
                        cv2.putText(display_frame, f"{label} ({avg_confidence:.2f})", 
                                  (20, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 3)
                        cv2.imshow('Prediction', display_frame)
                        if cv2.waitKey(1) & 0xFF == ord('q'):
                            break
                except Exception as e:
                    logger.error(f"Frame processing error: {str(e)}")
        finally:
            cap.release()
            cv2.destroyAllWindows()
            # Cleanup models
            self.yolo = None
            self.mp_pose.close()
            return self.aggregate_predictions(class_confidence)

    def aggregate_predictions(self, class_confidence):
        """Robust aggregation with confidence analysis"""
        try:
            # Calculate mean confidence per class
            mean_conf = {}
            for cls, confs in class_confidence.items():
                mean_conf[cls] = np.mean(confs) if confs else 0.0
                
            # Sort classes by confidence
            sorted_classes = sorted(mean_conf.items(), key=lambda x: x[1], reverse=True)
            
            # Create distribution dictionary
            distribution = {}
            for cls, confs in class_confidence.items():
                count = len(confs)
                avg = np.mean(confs) if count > 0 else 0.0
                distribution[cls] = {'count': count, 'average_confidence': avg}
            
            return {
                'predicted_class': sorted_classes[0][0],
                'confidence': float(sorted_classes[0][1]),
                'top_candidates': sorted_classes[:3],
                'class_distribution': distribution
            }
        except Exception as e:
            logger.error(f"Aggregation failed: {str(e)}")
            return {
                'predicted_class': 'unknown',
                'confidence': 0.0,
                'class_distribution': {}
            }

if __name__ == "__main__":
    config = {
        "model_save_path": r"D:\CLASS NOTES\EPICS\Model\final_model.pth",
        "scaler_save_path": r"D:\CLASS NOTES\EPICS\Model\final_scaler.joblib",
        "label_mapping_path": r"D:\CLASS NOTES\EPICS\Model\label_mapping.joblib"
    }
    
    try:
        predictor = VideoAnomalyPredictor(config)
        result = predictor.predict_video(r"D:\CLASS NOTES\EPICS\Model Dataset\Anomaly-Videos\Shooting002_x264.mp4")
        print(f"Final Prediction: {result}")
    except Exception as e:
        print(f"Critical error: {str(e)}")
        cv2.destroyAllWindows()

INFO:__main__:Converted models to half-precision
  self.scaler.var_ = self.scaler.var_.astype(np.float16)



0: 224x224 1 person, 3 cars, 1 truck, 20.0ms
Speed: 1.0ms preprocess, 20.0ms inference, 4.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 20.0ms
Speed: 0.0ms preprocess, 20.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 19.4ms
Speed: 0.0ms preprocess, 19.4ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 40.0ms
Speed: 0.0ms preprocess, 40.0ms inference, 5.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 18.0ms
Speed: 0.0ms preprocess, 18.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 19.0ms
Speed: 0.0ms preprocess, 19.0ms inference, 2.5ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 18.0ms
Speed: 0.0ms preprocess, 18.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 18.0ms
Speed: 0.0ms preprocess, 18.0ms inference, 4.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 17.5ms
Speed: 0.0ms preprocess, 17.5ms inference, 3.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 18.0ms
Speed: 0.0ms preprocess, 18.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 18.0ms
Speed: 0.0ms preprocess, 18.0ms inference, 3.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 18.0ms
Speed: 0.0ms preprocess, 18.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 18.0ms
Speed: 0.0ms preprocess, 18.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 19.0ms
Speed: 0.0ms preprocess, 19.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 24.0ms
Speed: 0.0ms preprocess, 24.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)





0: 224x224 1 person, 3 cars, 1 truck, 25.0ms
Speed: 0.0ms preprocess, 25.0ms inference, 3.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 1 person, 3 cars, 1 truck, 25.0ms
Speed: 0.0ms preprocess, 25.0ms inference, 3.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 1 person, 3 cars, 1 truck, 35.0ms
Speed: 1.0ms preprocess, 35.0ms inference, 5.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 1 person, 3 cars, 35.0ms
Speed: 0.0ms preprocess, 35.0ms inference, 5.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 1 person, 3 cars, 21.0ms
Speed: 0.0ms preprocess, 21.0ms inference, 3.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 1 person, 3 cars, 21.0ms
Speed: 0.0ms preprocess, 21.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 1 person, 3 cars, 19.0ms
Speed: 0.0ms preprocess, 19.0ms inference, 2.0ms postprocess per image at shape (1, 3, 224, 224)

0: 224x224 1 person, 3 cars, 17.5m