## 1. TRY 1 

In [None]:
!pip install -q face-recognition==1.3.0 albumentations==1.3.0 decord==0.6.0 timm==0.6.5 opencv-python

In [None]:
import os
import torch
from models.pred_func import load_genconvit, pred_vid, face_rec, preprocess_frame, is_video, extract_frames
from models.config import load_config
from typing import List, Dict, Union
import csv
import numpy as np
from PIL import Image
from tqdm import tqdm
import cv2
import multiprocessing
from pathlib import Path

# Project configuration returned by models.config.load_config(); handed to
# load_genconvit() further down when the model is built.
config = load_config()

def save_results_to_csv(
    results: List[Dict[str, Union[str, int, List[float]]]],
    filepath: str,
    for_submit=False,
) -> None:
    """Write prediction records to ``filepath`` as CSV, ordered by file name.

    Args:
        results: Records with at least the keys ``name`` and ``pred``; the
            detailed layout additionally reads ``pred_proba``.
        filepath: Destination file; it is created or overwritten.
        for_submit: When true, write the two-column submission layout
            (``filename,label``). Otherwise write ``name,pred,pred_proba``.
    """
    if for_submit:
        header, fields = ("filename", "label"), ("name", "pred")
    else:
        header = fields = ("name", "pred", "pred_proba")

    # Sort before touching the file so a malformed record cannot truncate it.
    ordered = sorted(results, key=lambda record: record['name'])

    with open(filepath, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        for record in ordered:
            writer.writerow(tuple(record[field] for field in fields))


def get_files_paths(main_path: str, exs: List[str]) -> List[str]:
    """Recursively list files under ``main_path`` whose extension is in ``exs``.

    Args:
        main_path: Directory to walk.
        exs: Extensions with or without a leading dot; matching ignores case.

    Returns:
        Absolute paths of the matching files, in ``os.walk`` order.
    """
    wanted = set()
    for raw in exs:
        # Normalise "png", ".png" and ".PNG" to the same ".png" form.
        wanted.add(("." + raw.lstrip(".")).lower())

    matches = []
    for folder, _subdirs, entries in os.walk(main_path):
        matches.extend(
            os.path.abspath(os.path.join(folder, entry))
            for entry in entries
            if os.path.splitext(entry)[1].lower() in wanted
        )
    return matches

def calculate_laplacian_variance(frames):
    """Return the mean Laplacian variance over ``frames`` (a sharpness measure).

    Frames with three channels are converted RGB -> gray first; anything else
    is fed to the Laplacian unchanged. An empty input yields ``0.0``.
    """
    if len(frames) == 0:
        return 0.0

    def _sharpness(frame):
        # Only 3-channel frames need a colour conversion.
        is_rgb = len(frame.shape) == 3 and frame.shape[2] == 3
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) if is_rgb else frame
        return cv2.Laplacian(gray, cv2.CV_64F).var()

    return np.mean([_sharpness(frame) for frame in frames])

def check_frequency_artifacts(image):
    """Return the share of 2-D FFT magnitude lying outside a central disc.

    A 3-D input is converted RGB -> gray first. The disc is centred on the
    shifted spectrum and has radius ``min(h, w) // 4``; everything inside it
    counts as low frequency and is excluded from the numerator.
    """
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) if len(image.shape) == 3 else image

    spectrum = np.abs(np.fft.fftshift(np.fft.fft2(gray)))
    rows, cols = spectrum.shape
    centre_y, centre_x = rows // 2, cols // 2

    radius = min(rows, cols) // 4
    yy, xx = np.ogrid[:rows, :cols]
    inside_disc = (xx - centre_x)**2 + (yy - centre_y)**2 <= radius**2

    # Zero the low-frequency disc, keep the rest of the spectrum as is.
    outer_energy = np.sum(np.where(inside_disc, 0, spectrum))
    # The epsilon keeps an all-zero spectrum from dividing by zero.
    return outer_energy / (np.sum(spectrum) + 1e-10)


def check_face_boundary_artifacts(image):
    """Compare pixel spread in the outer border with the spread in the centre.

    The border is a band 10% of the shorter side wide (at least one pixel).
    The score is ``|mean(border stds) - centre std| / centre std``.

    Args:
        image: Array whose first two axes are height and width.

    Returns:
        The relative difference described above, or ``0.0`` when the image is
        too small to have both a border and a non-empty centre.
    """
    h, w = image.shape[:2]
    # int(min(h, w) * 0.1) is 0 for sides under 10 px; with a zero width
    # `image[-0:]` is the whole image and the centre slice is empty, which
    # used to produce NaN. Clamp to one pixel instead.
    border_width = max(1, int(min(h, w) * 0.1))

    # Nothing left between the borders: there is no centre to compare with.
    if h <= 2 * border_width or w <= 2 * border_width:
        return 0.0

    border_regions = [
        image[:border_width, :],    # top
        image[-border_width:, :],   # bottom
        image[:, :border_width],    # left
        image[:, -border_width:],   # right
    ]
    avg_border_std = np.mean([np.std(region) for region in border_regions])

    center_region = image[border_width:-border_width, border_width:-border_width]
    center_std = np.std(center_region)

    # The epsilon guards against a perfectly flat centre.
    return abs(avg_border_std - center_std) / (center_std + 1e-10)


def check_color_consistency(image):
    """Return the coefficient of variation of the three channel means.

    Reads channels 0, 1 and 2 of ``image`` along its last axis; the result is
    ``std(means) / mean(means)``, so equal channel means score ``0``.
    """
    channel_means = [np.mean(image[:, :, channel]) for channel in range(3)]
    spread = np.std(channel_means)
    level = np.mean(channel_means)
    # The epsilon avoids a division by zero on an all-black image.
    return spread / (level + 1e-10)


def perform_secondary_checks(image, laplacian_var):
    """Blend three image heuristics into one suspicion score.

    Args:
        image: Image passed on to the frequency, boundary and colour checks.
        laplacian_var: Accepted for interface compatibility; not used here.

    Returns:
        ``0.5 * freq + 0.2 * boundary + 0.3 * colour``, each term capped at 1.
    """
    freq_ratio = check_frequency_artifacts(image)
    boundary_anomaly = check_face_boundary_artifacts(image)
    channel_imbalance = check_color_consistency(image)

    # Only ratios outside the [0.01, 0.12] band contribute; inside it the
    # frequency term stays at zero.
    outside_band = freq_ratio < 0.01 or freq_ratio > 0.12
    freq_score = min(abs(freq_ratio - 0.05) / 0.05, 1.0) if outside_band else 0.0

    boundary_score = min(boundary_anomaly / 1.0, 1.0)
    color_score = min(channel_imbalance / 0.5, 1.0)

    return 0.5 * freq_score + 0.2 * boundary_score + 0.3 * color_score

# ===== CPU WORKER FUNCTION (runs in parallel) =====
def process_single_file(args):
    """Worker function: CPU tasks (face detection, Laplacian) for one file.

    Args:
        args: ``(file_path, num_frames, consecutive)``; packed in one tuple so
            the function can be mapped over a multiprocessing pool.

    Returns:
        ``(filename, face_crops, laplacian_var, is_video, raw_image)`` where:

        * faces found in a video: ``(name, crops, variance, True, None)``
        * faces found in an image: ``(name, crops, variance, False, first_crop)``
        * no face: ``(name, None, 0.0, True | False, None)``
        * any exception: ``(name, None, 0.0, None, str(error))`` - the last
          slot carries the error text, which callers detect with
          ``isinstance(raw_image, str)``.
    """
    file_path, num_frames, consecutive = args
    filename = os.path.basename(file_path)

    try:
        if is_video(file_path):
            # Extract raw frames, then detect faces in them.
            raw_frames = extract_frames(file_path, num_frames, consecutive=consecutive)
            face_crops, count = face_rec(raw_frames)

            if count > 0:
                # Sharpness is measured on the face regions only.
                faces = face_crops[:count]
                return (filename, faces, calculate_laplacian_variance(faces), True, None)
            return (filename, None, 0.0, True, None)

        # Still image. Image.open() keeps the file handle open until the image
        # object is closed; pool workers handle many files, so close it
        # deterministically instead of waiting for garbage collection.
        with Image.open(file_path) as im:
            arr = np.asarray(im.convert('RGB'))

        face, count = face_rec([arr])

        if count > 0:
            faces = face[:count]
            # The first crop doubles as the image for the secondary checks.
            return (filename, faces, calculate_laplacian_variance(faces), False, face[0])
        return (filename, None, 0.0, False, None)

    except Exception as e:
        # Best effort: report the failure to the parent instead of killing the pool.
        return (filename, None, 0.0, None, str(e))


def extract_features(model, df, fp16_mode=False):
    """Run ``model.extract_features`` on ``df`` and summarise the result.

    Args:
        model: Object exposing ``extract_features(batch)``.
        df: Batch whose first axis indexes frames.
        fp16_mode: Cast the batch to half precision before the forward pass.

    Returns:
        ``None`` for an empty batch, otherwise a dict with ``mean``, ``std``,
        ``num_frames``, ``temporal_variance``, ``per_frame_std`` and
        ``outlier_ratio``.
    """
    if df.shape[0] == 0:
        return None

    with torch.no_grad():
        batch = df.half() if fp16_mode else df
        features_np = model.extract_features(batch).cpu().numpy()

        frame_count = features_np.shape[0]
        feat_stats = {
            'mean': features_np.mean(axis=0),
            'std': features_np.std(axis=0).mean(),
            'num_frames': frame_count,
        }

        # A single frame has no temporal structure to measure.
        if frame_count <= 1:
            feat_stats['temporal_variance'] = 0.0
            feat_stats['per_frame_std'] = np.std(features_np)
            feat_stats['outlier_ratio'] = 0.0
            return feat_stats

        # Variance of frame-to-frame feature changes, averaged over dimensions.
        deltas = np.diff(features_np, axis=0)
        feat_stats['temporal_variance'] = np.mean(np.var(deltas, axis=0))
        feat_stats['per_frame_std'] = np.std(features_np, axis=1).mean()

        # Frames further than median + 2*std from the mean feature are outliers.
        centroid_dist = np.linalg.norm(features_np - features_np.mean(axis=0), axis=1)
        cutoff = np.median(centroid_dist) + 2 * np.std(centroid_dist)
        feat_stats['outlier_ratio'] = np.sum(centroid_dist > cutoff) / frame_count

        return feat_stats


# ===== CONFIGURATION =====
# Run-time settings for this attempt; every name below is read by the main
# processing loop or by the worker-argument list at the end of this section.
DATA_PATH = "data"  # Change this to your test data path
NUM_FRAMES = 25  # forwarded to extract_frames() as the per-video frame count
FP16_MODE = False  # when True, batches are cast with .half() before inference
CONSECUTIVE_FRAMES = True  # forwarded to extract_frames(consecutive=...)
NUM_WORKERS = None  # None = auto-detect (CPU count - 1, at least 1, capped at 8)

# Model setup
net = 'genconvit'
ed_weight = 'genconvit_ed_inference'
vae_weight = 'genconvit_vae_inference'

# Load model
print("Loading model...")
model = load_genconvit(config, net, ed_weight, vae_weight, FP16_MODE)
print(f"✓ Loaded {net} network")

# Collect files
# Extensions are matched case-insensitively by get_files_paths(); only these
# four are scanned, so other formats under DATA_PATH are ignored.
all_files = get_files_paths(DATA_PATH, exs=['png', 'jpg', 'jpeg', 'mp4'])
print(f"✓ Found {len(all_files)} files")

# Setup multiprocessing
# Any falsy NUM_WORKERS (None or 0) selects the auto-detected worker count.
num_workers = NUM_WORKERS if NUM_WORKERS else min(max(1, multiprocessing.cpu_count() - 1), 8)
print(f"✓ Using {num_workers} worker processes")
print(f"✓ Frame extraction: {'consecutive' if CONSECUTIVE_FRAMES else 'evenly spaced'}\n")

# Prepare worker arguments
# One (path, num_frames, consecutive) tuple per file, matching the tuple that
# process_single_file() unpacks.
worker_args = [(file_path, NUM_FRAMES, CONSECUTIVE_FRAMES) for file_path in all_files]


# ===== MAIN PROCESSING LOOP =====
# Worker processes do the CPU-side work (frame extraction, face detection,
# sharpness); this process consumes their results as they arrive, runs the
# model, applies the heuristics below and collects one record per file.
# Each record is {"name", "pred", "pred_proba"}; pred 0 is counted as "Real"
# and pred 1 as "Fake" by the summary printed after the loop.
results = []

print("Start Evaluating...")
with multiprocessing.Pool(processes=num_workers) as pool:
    with tqdm(total=len(all_files), desc="Processing files") as pbar:
        # imap_unordered yields results in completion order, not input order;
        # save_results_to_csv() sorts by name afterwards.
        for result in pool.imap_unordered(process_single_file, worker_args):
            filename, face_crops, laplacian_var, is_video_file, raw_image = result

            # Check for errors
            # On failure the worker puts str(exception) in the raw_image slot.
            if isinstance(raw_image, str):  # Error message
                print(f"Error processing {filename}: {raw_image}")
                # Fallback record: labelled real (0) with an undecided 0.5/0.5.
                results.append({
                    "name": filename,
                    "pred": 0,
                    "pred_proba": [0.5, 0.5]
                })
                pbar.update(1)
                continue

            # No face detected
            # Same fallback record as the error case.
            if face_crops is None:
                results.append({
                    "name": filename,
                    "pred": 0,
                    "pred_proba": [0.5, 0.5]
                })
                pbar.update(1)
                continue

            # ===== GPU INFERENCE =====
            df = preprocess_frame(face_crops)
            if FP16_MODE and df.shape[0] > 0:
                df = df.half()

            # NOTE(review): pred_vid is assumed to return (real_proba, fake_proba);
            # that is inferred from the variable names only - confirm in pred_func.
            # NOTE(review): the (0.0, 0.5) fallback would end up as pred=1 below
            # (fake >= real), unlike the pred=0 fallbacks above - confirm intent.
            real_proba, fake_proba = (
                pred_vid(df, model)
                if df.shape[0] > 0
                else (0.0, 0.5)
            )

            # extract_features() returns None for an empty batch and applies
            # .half() itself when FP16_MODE is set.
            feat_stats = extract_features(model, df, FP16_MODE)
            if feat_stats:
                # Stored alongside the other statistics; not read again in this loop.
                feat_stats['laplacian_var'] = laplacian_var



            # ===== MULTI-LAYER DETECTION =====
            temporal_var = feat_stats['temporal_variance'] if feat_stats else 0.0
            # force_fake overrides the probability comparison when set.
            force_fake = False

            # Hand-picked thresholds (re-assigned on every iteration).
            HIGH_LAPLACIAN = 180
            HIGH_TEMPORAL_VAR = 0.02
            EXTREME_TEMPORAL_VAR = 0.04

            if is_video_file:
                # The three branches below are mutually exclusive. A sharp and
                # temporally smooth video, or a value exactly equal to a
                # threshold, matches none of them and keeps the model output.
                # VEO-style: sharp + jerky
                if laplacian_var > HIGH_LAPLACIAN and temporal_var > HIGH_TEMPORAL_VAR:
                    # Boost grows with temporal variance; result is capped at 1.0.
                    veo_boost = 1.0 + (temporal_var * 15)
                    fake_proba = min(1.0, fake_proba * veo_boost)
                    if temporal_var > EXTREME_TEMPORAL_VAR:
                        force_fake = True

                # SORA-style: smooth + blurry
                elif laplacian_var < HIGH_LAPLACIAN and temporal_var < HIGH_TEMPORAL_VAR:
                    # Only clearly blurry clips (< 50) are touched; very blurry
                    # ones (< 30) are forced to fake.
                    if laplacian_var < 50:
                        fake_proba = min(1.0, fake_proba * 1.3)
                        if laplacian_var < 30:
                            force_fake = True

                # Low-quality: blurry + jerky
                elif laplacian_var < HIGH_LAPLACIAN and temporal_var > HIGH_TEMPORAL_VAR:
                    fake_proba = min(1.0, fake_proba * 1.5)
                    force_fake = True

                # Spatial anomaly
                # Applied independently of the branch taken above.
                if feat_stats and feat_stats.get('outlier_ratio', 0) > 0.3:
                    force_fake = True

            else:
                # Image detection
                SHARPNESS_THRESHOLD = 180.0
                SUSPICION_THRESHOLD = 0.60

                # Disabled sharpness / per-frame-std heuristics, kept for reference:
                # if laplacian_var > 200:
                #     fake_proba = min(1.0, fake_proba * 1.3)
                #     if laplacian_var > 300:
                #         force_fake = True
                # elif laplacian_var < 40:
                #     fake_proba = min(1.0, fake_proba * 1.2)
                #     if laplacian_var < 20:
                #         force_fake = True

                # if feat_stats and feat_stats.get('per_frame_std', 0) > 8.0:
                #     force_fake = True

                # Secondary check
                # Runs only for sharp images the model currently calls real;
                # raw_image is the first face crop returned by the worker.
                if laplacian_var > SHARPNESS_THRESHOLD and fake_proba < real_proba and raw_image is not None:
                    suspicion_score = perform_secondary_checks(raw_image, laplacian_var)
                    if suspicion_score > SUSPICION_THRESHOLD:
                        force_fake = True

            # Ties (fake_proba == real_proba) are labelled fake.
            y = 1 if (force_fake or fake_proba >= real_proba) else 0

            results.append({
                "name": filename,
                "pred": y,
                "pred_proba": [real_proba, fake_proba]
            })

            pbar.update(1)

print("\n✓ Processing complete!")

# Save results for submission
save_results_to_csv(results, "submission.csv", for_submit=True)
print("✓ Saved submission to submission.csv")

# Print statistics
# pred 0 is counted as real, pred 1 as fake.
num_real = sum(1 for r in results if r['pred'] == 0)
num_fake = sum(1 for r in results if r['pred'] == 1)
total = len(results)
print("\nPrediction Summary:")
if total:
    print(f"  Real: {num_real} ({num_real/total*100:.1f}%)")
    print(f"  Fake: {num_fake} ({num_fake/total*100:.1f}%)")
else:
    # An empty data directory used to crash here with ZeroDivisionError.
    print("  No files were processed.")

In [1]:
import os

import aifactory.score as aif

# SECURITY: the submission key used to be hard-coded in this cell. Anyone with
# access to the notebook could submit under this account, so that key should
# be treated as compromised and rotated. The key is now read from the
# environment and never stored in the notebook.
submit_key = os.environ.get("AIFACTORY_SUBMIT_KEY")
if not submit_key:
    raise RuntimeError(
        "Set the AIFACTORY_SUBMIT_KEY environment variable before submitting."
    )

aif.submit(model_name="GenConViT-v1-pad24",
           key=submit_key
           )

file : task
jupyter notebook
제출 완료


## 2. TRY 2 : 11/12 17:43 -> 0.53? 

In [None]:
import os
import torch
from models.pred_func import load_genconvit, pred_vid, face_rec, preprocess_frame, is_video, extract_frames
from models.config import load_config
from typing import List, Dict, Union, Tuple
import csv
import numpy as np
from PIL import Image
from tqdm import tqdm
import cv2
import multiprocessing
from pathlib import Path
import mediapipe as mp

# Project configuration returned by models.config.load_config().
# NOTE(review): presumably passed to load_genconvit() when the model is built
# later in this cell - that call is not visible in this excerpt.
config = load_config()

# =====================================================================
# FACIAL COMPONENT GUIDANCE (FCG) IMPLEMENTATION
# Based on DFD-FCG CVPR'25 paper
# =====================================================================

class FacialComponentAnalyzer:
    """
    Analyzes key facial components (eyes, nose, lips, skin) to detect deepfakes.
    Inspired by DFD-FCG's approach of focusing on critical facial regions.

    Images are expected as RGB arrays, matching how the rest of this file
    handles frames (PIL ``convert('RGB')`` and ``cv2.COLOR_RGB2GRAY``).
    """

    # Score bucket each landmark group contributes to; both eyes share one.
    _SCORE_KEY = {
        'left_eye': 'eyes_score',
        'right_eye': 'eyes_score',
        'nose': 'nose_score',
        'lips': 'lips_score',
        'skin': 'skin_score',
    }

    def __init__(self):
        """Create the MediaPipe face mesh and the per-component landmark lists."""
        self.mp_face_mesh = mp.solutions.face_mesh
        # static_image_mode=True: every image is processed on its own, with no
        # tracking state carried over from the previous call.
        self.face_mesh = self.mp_face_mesh.FaceMesh(
            static_image_mode=True,
            max_num_faces=1,
            refine_landmarks=True,
            min_detection_confidence=0.5
        )

        # Facial component landmark indices (MediaPipe)
        # NOTE(review): the index ranges are kept as found; verify them against
        # the MediaPipe face-mesh landmark map.
        self.component_indices = {
            'left_eye': list(range(33, 42)) + list(range(133, 145)),
            'right_eye': list(range(263, 272)) + list(range(362, 374)),
            'nose': list(range(1, 9)) + [168, 197, 195, 5],
            'lips': list(range(61, 68)) + list(range(291, 298)) + list(range(375, 380)) + list(range(146, 151)),
            'skin': [10, 338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
                     397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
                     172, 58, 132, 93, 234, 127, 162, 21, 54, 103, 67, 109]
        }

    def extract_component_regions(self, image: np.ndarray) -> Dict[str, np.ndarray]:
        """Extract a padded bounding-box crop for each facial component.

        Args:
            image: RGB frame; the first two axes are height and width.

        Returns:
            Mapping of component name to a crop of ``image``. The mapping is
            empty when no face mesh is found; individual crops may be empty
            when the landmarks fall outside the image.
        """
        h, w = image.shape[:2]

        # FaceMesh.process() takes RGB input and the frames in this file are
        # already RGB, so no channel swap is applied. The previous
        # COLOR_BGR2RGB conversion turned them into BGR before detection.
        results = self.face_mesh.process(np.ascontiguousarray(image))
        component_regions = {}

        if not results.multi_face_landmarks:
            return component_regions

        points = results.multi_face_landmarks[0].landmark
        padding = 5

        for component_name, indices in self.component_indices.items():
            # Landmarks are normalised; scale them to pixel coordinates.
            x_coords = [points[i].x * w for i in indices]
            y_coords = [points[i].y * h for i in indices]

            # Padded bounding box, clamped to the image.
            x_min = max(0, int(min(x_coords)) - padding)
            y_min = max(0, int(min(y_coords)) - padding)
            x_max = min(w, int(max(x_coords)) + padding)
            y_max = min(h, int(max(y_coords)) + padding)

            component_regions[component_name] = image[y_min:y_max, x_min:x_max]

        return component_regions

    def analyze_component_consistency(self, frames: List[np.ndarray]) -> Dict[str, float]:
        """
        Analyze temporal consistency of facial components across frames.
        Key insight from FCG: Eyes are most critical for generalization.

        For up to 10 consecutive frame pairs, each component crop is resized
        to 64x64 and the mean absolute pixel difference between the two frames
        is recorded. Components missing or empty in either frame are skipped.

        Returns:
            ``eyes_score``, ``nose_score``, ``lips_score`` and ``skin_score``:
            the average difference per bucket, or 0.0 when nothing was measured
            (including fewer than two frames).
        """
        component_scores = {
            'eyes_score': [],
            'nose_score': [],
            'lips_score': [],
            'skin_score': []
        }

        if len(frames) < 2:
            return {key: 0.0 for key in component_scores}

        pair_count = min(len(frames) - 1, 10)  # Sample up to 10 frame pairs
        # Every interior frame belongs to two pairs; run the face mesh once per
        # frame instead of once per pair member.
        regions_per_frame = [
            self.extract_component_regions(frame) for frame in frames[:pair_count + 1]
        ]

        for regions1, regions2 in zip(regions_per_frame, regions_per_frame[1:]):
            for component, score_key in self._SCORE_KEY.items():
                if component not in regions1 or component not in regions2:
                    continue

                crop1, crop2 = regions1[component], regions2[component]
                if crop1.size == 0 or crop2.size == 0:
                    continue

                # Bring both crops to a common size so they can be compared.
                region1 = cv2.resize(crop1, (64, 64))
                region2 = cv2.resize(crop2, (64, 64))

                # Mean absolute pixel difference between the two frames.
                diff = np.mean(np.abs(region1.astype(float) - region2.astype(float)))
                component_scores[score_key].append(diff)

        # Average scores
        return {
            k: np.mean(v) if v else 0.0
            for k, v in component_scores.items()
        }

    @staticmethod
    def _high_frequency_ratio(gray: np.ndarray) -> float:
        """Share of 2-D FFT magnitude outside a central disc of radius min(h, w) // 4."""
        magnitude = np.abs(np.fft.fftshift(np.fft.fft2(gray)))

        h, w = magnitude.shape
        center_h, center_w = h // 2, w // 2
        mask_radius = min(h, w) // 4
        y, x = np.ogrid[:h, :w]
        mask = (x - center_w)**2 + (y - center_h)**2 <= mask_radius**2

        high_freq = magnitude.copy()
        high_freq[mask] = 0

        return np.sum(high_freq) / (np.sum(magnitude) + 1e-10)

    def detect_component_artifacts(self, image: np.ndarray) -> Dict[str, Dict[str, float]]:
        """Detect artifacts in individual components (for images).

        Returns:
            Mapping of component name to ``{'freq_anomaly', 'texture_var'}``.
            Empty regions report zeros for both metrics so every entry has the
            same shape (they used to be a bare ``0.0``).
        """
        artifact_scores = {}

        for component_name, region in self.extract_component_regions(image).items():
            if region.size == 0:
                artifact_scores[component_name] = {'freq_anomaly': 0.0, 'texture_var': 0.0}
                continue

            if len(region.shape) == 3:
                gray = cv2.cvtColor(region, cv2.COLOR_RGB2GRAY)
            else:
                gray = region

            artifact_scores[component_name] = {
                # 1. Frequency domain analysis: high-frequency energy share.
                'freq_anomaly': self._high_frequency_ratio(gray),
                # 2. Texture consistency: variance of the Laplacian.
                'texture_var': cv2.Laplacian(gray, cv2.CV_64F).var(),
            }

        return artifact_scores


# =====================================================================
# ENHANCED TEMPORAL ANALYSIS
# =====================================================================

def analyze_optical_flow(frames: List[np.ndarray]) -> Dict[str, float]:
    """
    Optical flow analysis for motion consistency.
    Detects unnatural motion patterns in deepfakes.

    Farneback dense flow is computed for up to 15 consecutive frame pairs and
    summarised by its mean magnitude per pair.

    Args:
        frames: RGB frames, all of the same size.

    Returns:
        ``flow_mean`` / ``flow_variance``: mean and variance of the per-pair
        magnitudes; ``flow_inconsistency``: mean absolute change between
        consecutive magnitudes; ``flow_spikes``: number of changes above
        mean + 1.5 * std. All four keys are present for every input,
        including fewer than two frames.
    """
    if len(frames) < 2:
        # Same key set as the regular result so callers can index uniformly.
        return {'flow_mean': 0.0, 'flow_variance': 0.0,
                'flow_inconsistency': 0.0, 'flow_spikes': 0}

    pair_count = min(len(frames) - 1, 15)  # Analyze up to 15 frame pairs
    # Convert each frame once; every interior frame is used by two pairs.
    grays = [cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) for frame in frames[:pair_count + 1]]

    flow_magnitudes = []
    for prev_gray, next_gray in zip(grays, grays[1:]):
        # Farneback optical flow
        flow = cv2.calcOpticalFlowFarneback(
            prev_gray, next_gray, None,
            pyr_scale=0.5, levels=3, winsize=15,
            iterations=3, poly_n=5, poly_sigma=1.2, flags=0
        )
        magnitude = np.sqrt(flow[..., 0]**2 + flow[..., 1]**2)
        flow_magnitudes.append(np.mean(magnitude))

    # Detect sudden changes (inconsistency) between consecutive pairs.
    flow_inconsistencies = [
        abs(current - previous)
        for previous, current in zip(flow_magnitudes, flow_magnitudes[1:])
    ]

    # Count spikes (sudden motion changes - typical of AI video)
    if flow_inconsistencies:
        threshold = np.mean(flow_inconsistencies) + 1.5 * np.std(flow_inconsistencies)
        num_spikes = sum(1 for x in flow_inconsistencies if x > threshold)
    else:
        num_spikes = 0

    return {
        'flow_mean': np.mean(flow_magnitudes),
        'flow_variance': np.var(flow_magnitudes),
        'flow_inconsistency': np.mean(flow_inconsistencies) if flow_inconsistencies else 0.0,
        'flow_spikes': num_spikes
    }


def analyze_temporal_frequency(features_sequence: np.ndarray) -> Dict[str, float]:
    """
    FFT analysis of feature sequences to detect unnatural temporal patterns.
    AI-generated videos often have specific frequency signatures.

    Args:
        features_sequence: 2-D array, frames along axis 0 and feature
            dimensions along axis 1. Only the first 100 dimensions are used.

    Returns:
        ``freq_ratio``: mean high/low frequency power ratio over the sampled
        dimensions; ``freq_suspicious``: True when that ratio is above 2.5 or
        below 0.2; ``freq_variance``: variance of the per-dimension ratios.
        Sequences shorter than 8 frames return the neutral values
        ``1.0 / False / 0.0`` with the same three keys.
    """
    if features_sequence.shape[0] < 8:
        # Too short for a meaningful spectrum; keep the key set complete.
        return {'freq_ratio': 1.0, 'freq_suspicious': False, 'freq_variance': 0.0}

    # One FFT along the time axis covers all sampled dimensions at once.
    sampled = features_sequence[:, :100]
    power_spectrum = np.abs(np.fft.fft(sampled, axis=0)) ** 2

    # Only the first half of the spectrum is used. Its first quarter
    # (DC included) counts as low frequency, the remainder as high frequency.
    mid_point = power_spectrum.shape[0] // 2
    split = mid_point // 4
    low_freq_energy = power_spectrum[:split].sum(axis=0)
    high_freq_energy = power_spectrum[split:mid_point].sum(axis=0)

    freq_ratios = high_freq_energy / (low_freq_energy + 1e-10)
    avg_ratio = np.mean(freq_ratios)

    # AI videos typically have abnormal frequency patterns
    is_suspicious = bool((avg_ratio > 2.5) or (avg_ratio < 0.2))

    return {
        'freq_ratio': avg_ratio,
        'freq_suspicious': is_suspicious,
        'freq_variance': np.var(freq_ratios)
    }


def calculate_laplacian_variance(frames):
    """Average the variance of the Laplacian over all frames (sharpness measure).

    Three-channel frames are converted RGB -> gray before filtering; other
    frames are used as they are. Returns ``0.0`` for an empty input.
    """
    if len(frames) == 0:
        return 0.0

    def _to_gray(frame):
        # Anything that is not HxWx3 is passed through untouched.
        if len(frame.shape) == 3 and frame.shape[2] == 3:
            return cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        return frame

    per_frame = [cv2.Laplacian(_to_gray(frame), cv2.CV_64F).var() for frame in frames]
    return np.mean(per_frame)


# =====================================================================
# FEATURE EXTRACTION WITH ENHANCED STATISTICS
# =====================================================================

def extract_features(model, df, fp16_mode=False):
    """Extract features and calculate comprehensive statistics.

    Args:
        model: Object exposing ``extract_features(batch)``.
        df: Batch whose first axis indexes frames.
        fp16_mode: Cast the batch to half precision before the forward pass.

    Returns:
        ``None`` for an empty batch. Otherwise a dict that always contains
        ``mean``, ``std``, ``num_frames``, ``temporal_variance``,
        ``per_frame_std``, ``outlier_ratio``, ``mean_correlation``,
        ``correlation_variance`` and ``max_correlation_drop``. The three
        correlation entries fall back to ``1.0 / 0.0 / 0.0`` whenever they
        cannot be computed (single frame, too few frames, or only NaN
        correlations); previously they were missing in those cases.
    """
    if df.shape[0] == 0:
        return None

    with torch.no_grad():
        if fp16_mode:
            df = df.half()
        features_np = model.extract_features(df).detach().cpu().numpy()

    num_frames = features_np.shape[0]
    feat_stats = {
        'mean': features_np.mean(axis=0),
        'std': features_np.std(axis=0).mean(),
        'num_frames': num_frames,
        # Single-frame defaults; overwritten below when more data exists.
        'temporal_variance': 0.0,
        'per_frame_std': np.std(features_np),
        'outlier_ratio': 0.0,
        'mean_correlation': 1.0,
        'correlation_variance': 0.0,
        'max_correlation_drop': 0.0,
    }

    if num_frames <= 1:
        return feat_stats

    # Temporal variance of the frame-to-frame feature changes.
    frame_diffs = np.diff(features_np, axis=0)
    feat_stats['temporal_variance'] = np.mean(np.var(frame_diffs, axis=0))

    # Per-frame consistency
    feat_stats['per_frame_std'] = np.std(features_np, axis=1).mean()

    # Outlier detection: distance to the mean feature above median + 2*std.
    frame_distances = np.linalg.norm(features_np - features_np.mean(axis=0), axis=1)
    threshold = np.median(frame_distances) + 2 * np.std(frame_distances)
    feat_stats['outlier_ratio'] = np.sum(frame_distances > threshold) / num_frames

    # Feature correlation between consecutive frames; NaN results (constant
    # feature vectors) are dropped.
    correlations = []
    for current, following in zip(features_np, features_np[1:]):
        corr = np.corrcoef(current, following)[0, 1]
        if not np.isnan(corr):
            correlations.append(corr)

    if correlations:
        feat_stats['mean_correlation'] = np.mean(correlations)
        feat_stats['correlation_variance'] = np.var(correlations)

        # Detect sudden correlation drops between neighbouring valid values.
        drops = [abs(a - b) for a, b in zip(correlations, correlations[1:])]
        if drops:
            feat_stats['max_correlation_drop'] = max(drops)

    return feat_stats


# =====================================================================
# CSV UTILITIES
# =====================================================================

def save_results_to_csv(
    results: List[Dict[str, Union[str, int, List[float]]]],
    filepath: str,
    for_submit=False,
) -> None:
    """Dump prediction records to a CSV file, sorted by their ``name`` field.

    With ``for_submit`` the file holds ``filename,label`` rows taken from each
    record's ``name`` and ``pred``. Otherwise it holds ``name,pred,pred_proba``.
    An existing file at ``filepath`` is overwritten.
    """
    by_name = sorted(results, key=lambda entry: entry['name'])

    if for_submit:
        header = ("filename", "label")
        columns = ("name", "pred")
    else:
        header = ("name", "pred", "pred_proba")
        columns = header

    with open(filepath, "w", newline="", encoding="utf-8") as out:
        sheet = csv.writer(out)
        sheet.writerow(header)
        # Rows are produced lazily, one record at a time, in sorted order.
        sheet.writerows(tuple(entry[col] for col in columns) for entry in by_name)


def get_files_paths(main_path: str, exs: List[str]) -> List[str]:
    """Walk ``main_path`` and return absolute paths of files with a listed extension.

    Extensions in ``exs`` may be given with or without the leading dot and are
    compared case-insensitively. Results follow ``os.walk`` traversal order.
    """
    accepted = {("." + item.lstrip(".")).lower() for item in exs}

    def _has_accepted_suffix(file_name):
        suffix = os.path.splitext(file_name)[1]
        return suffix.lower() in accepted

    collected = []
    for directory, _children, file_names in os.walk(main_path):
        for file_name in filter(_has_accepted_suffix, file_names):
            collected.append(os.path.abspath(os.path.join(directory, file_name)))
    return collected


# =====================================================================
# CPU WORKER FUNCTION
# =====================================================================

def process_single_file(args):
    """Run the CPU-side preprocessing for one input file (pool worker).

    Args:
        args: ``(file_path, num_frames, consecutive)`` tuple; the last two
            are forwarded to ``extract_frames`` for video inputs.

    Returns:
        A 7-tuple ``(filename, face_crops, raw_frames, laplacian_var,
        flow_stats, is_video, extra)``:

        * video with faces: crops, the first ``count`` raw frames, the
          Laplacian variance, the optical-flow stats, ``True``, ``None``;
        * image with a face: crops, ``[image_array]``, the Laplacian
          variance, ``{}``, ``False`` and the first face crop as ``extra``;
        * no face: ``None`` for crops/frames, ``0.0``, ``{}``, the video
          flag and ``None``;
        * any exception: ``None`` crops/frames, ``0.0``, ``{}``, ``None``
          as the video flag and ``str(exception)`` as ``extra``.
    """
    file_path, num_frames, consecutive = args
    filename = os.path.basename(file_path)

    try:
        if not is_video(file_path):
            # Still image: decode to an RGB array and look for a face.
            pixels = np.asarray(Image.open(file_path).convert('RGB'))
            faces, found = face_rec([pixels])
            if not (found > 0):
                return (filename, None, None, 0.0, {}, False, None)

            sharpness = calculate_laplacian_variance(faces[:found])
            return (filename, faces[:found], [pixels], sharpness, {}, False, faces[0])

        # Video: sample frames, crop faces, then compute the frame metrics.
        frames = extract_frames(file_path, num_frames, consecutive=consecutive)
        crops, found = face_rec(frames)
        if not (found > 0):
            return (filename, None, None, 0.0, {}, True, None)

        sharpness = calculate_laplacian_variance(crops[:found])
        motion = analyze_optical_flow(crops[:found])
        return (filename, crops[:found], frames[:found], sharpness, motion, True, None)

    except Exception as exc:
        # The caller recognises a failure by the string in the last slot.
        return (filename, None, None, 0.0, {}, None, str(exc))


# =====================================================================
# MAIN CONFIGURATION AND EXECUTION
# =====================================================================

# Run-time settings for this evaluation cell.
DATA_PATH = "data"  # Root directory that is searched recursively for inputs
NUM_FRAMES = 25  # Forwarded to each worker as its num_frames argument
FP16_MODE = False  # Passed to load_genconvit; when True, inputs are cast with .half()
CONSECUTIVE_FRAMES = True  # Forwarded to extract_frames(consecutive=...)
NUM_WORKERS = None  # None (or 0) = derive the pool size from the CPU count

# Model setup: network name and the two weight identifiers given to load_genconvit
net = 'genconvit'
ed_weight = 'genconvit_ed_inference'
vae_weight = 'genconvit_vae_inference'

# Load model
print("Loading model...")
model = load_genconvit(config, net, ed_weight, vae_weight, FP16_MODE)
print(f"✓ Loaded {net} network")

# Initialize Facial Component Analyzer (used later for the FCG scores)
print("Initializing Facial Component Analyzer...")
fcg_analyzer = FacialComponentAnalyzer()
print("✓ FCG Analyzer ready")

# Collect every png/jpg/jpeg/mp4 file below DATA_PATH
all_files = get_files_paths(DATA_PATH, exs=['png', 'jpg', 'jpeg', 'mp4'])
print(f"✓ Found {len(all_files)} files")

# Setup multiprocessing: unless NUM_WORKERS is set, use (CPU count - 1),
# but never fewer than 1 and never more than 8 processes.
num_workers = NUM_WORKERS if NUM_WORKERS else min(max(1, multiprocessing.cpu_count() - 1), 8)
print(f"✓ Using {num_workers} worker processes")
print(f"✓ Frame extraction: {'consecutive' if CONSECUTIVE_FRAMES else 'evenly spaced'}\n")

# Prepare worker arguments: one (path, num_frames, consecutive) tuple per file,
# matching the tuple that process_single_file unpacks.
worker_args = [(file_path, NUM_FRAMES, CONSECUTIVE_FRAMES) for file_path in all_files]

# =====================================================================
# MAIN PROCESSING LOOP
# =====================================================================

# One {"name", "pred", "pred_proba"} record is appended per input file.
results = []

print("Start Evaluating with FCG Strategy...")
# CPU preprocessing runs in the worker pool; model inference and all scoring
# below run in this (parent) process as each worker result arrives.
# NOTE(review): the pool is created at module level without a
# `if __name__ == "__main__"` guard — confirm the start method in use
# tolerates that.
with multiprocessing.Pool(processes=num_workers) as pool:
    with tqdm(total=len(all_files), desc="Processing files") as pbar:
        # imap_unordered yields results in completion order, not input order.
        for result in pool.imap_unordered(process_single_file, worker_args):
            # Tuple layout produced by process_single_file; the last slot is
            # the raw face image for stills, None for videos, or an error
            # message string when the worker caught an exception.
            filename, face_crops, raw_frames, laplacian_var, flow_stats, is_video_file, raw_image = result
            
            # Check for errors
            if isinstance(raw_image, str):  # Error message
                print(f"Error processing {filename}: {raw_image}")
                # Failed files fall back to label 0 with a neutral probability.
                results.append({
                    "name": filename,
                    "pred": 0,
                    "pred_proba": [0.5, 0.5]
                })
                pbar.update(1)
                continue
            
            # No face detected: same neutral fallback as the error case
            if face_crops is None:
                results.append({
                    "name": filename,
                    "pred": 0,
                    "pred_proba": [0.5, 0.5]
                })
                pbar.update(1)
                continue
            
            # ===== GPU INFERENCE =====
            df = preprocess_frame(face_crops)
            if FP16_MODE and df.shape[0] > 0:
                df = df.half()
            
            # Base model probabilities; an empty batch yields the neutral pair.
            real_proba, fake_proba = (
                pred_vid(df, model)
                if df.shape[0] > 0
                else (0.5, 0.5)
            )
            
            # Extract features with enhanced statistics
            feat_stats = extract_features(model, df, FP16_MODE)
            if feat_stats:
                feat_stats['laplacian_var'] = laplacian_var
            
            # ===== FACIAL COMPONENT GUIDANCE (FCG) ANALYSIS =====
            fcg_scores = {}
            
            if is_video_file and raw_frames is not None and len(raw_frames) > 1:
                # Temporal component consistency (FCG for videos)
                fcg_scores = fcg_analyzer.analyze_component_consistency(raw_frames)
                
                # Frequency analysis: the features are extracted again here and
                # the resulting stats are merged into fcg_scores (the
                # 'freq_suspicious' key is read further below).
                if feat_stats and df.shape[0] > 0:
                    with torch.no_grad():
                        features_tensor = model.extract_features(df)
                        features_np = features_tensor.detach().cpu().numpy()
                    freq_stats = analyze_temporal_frequency(features_np)
                    fcg_scores.update(freq_stats)
            
            elif not is_video_file and raw_image is not None:
                # Spatial component artifacts (FCG for images)
                component_artifacts = fcg_analyzer.detect_component_artifacts(raw_image)
                
                # Aggregate component scores: flatten the nested mapping into
                # '<component>_<metric>' keys; non-dict entries are ignored.
                for comp_name, metrics in component_artifacts.items():
                    if isinstance(metrics, dict):
                        for metric_name, value in metrics.items():
                            fcg_scores[f'{comp_name}_{metric_name}'] = value
            
            # ===== MULTI-LAYER DETECTION WITH FCG =====
            # Every rule below either multiplies fake_proba (clamped to 1.0) or
            # sets force_fake; real_proba is never rescaled, so the pair written
            # to the results may no longer sum to 1.
            temporal_var = feat_stats.get('temporal_variance', 0.0) if feat_stats else 0.0
            force_fake = False
            
            # Thresholds (re-bound on every iteration)
            HIGH_LAPLACIAN = 180
            HIGH_TEMPORAL_VAR = 0.02
            EXTREME_TEMPORAL_VAR = 0.04
            
            if is_video_file:
                # Get optical flow metrics (missing keys default to 0)
                flow_inconsistency = flow_stats.get('flow_inconsistency', 0.0)
                flow_spikes = flow_stats.get('flow_spikes', 0)
                
                # FCG: Eyes are most critical (per paper findings)
                eyes_inconsistency = fcg_scores.get('eyes_score', 0.0)
                
                # Rules 1-3 are mutually exclusive (if/elif). Combinations that
                # match none of them — e.g. sharp but temporally smooth, or a
                # value exactly equal to a threshold — get no adjustment here.

                # 1. VEO-style: sharp + jerky motion
                if laplacian_var > HIGH_LAPLACIAN and temporal_var > HIGH_TEMPORAL_VAR:
                    # The boost grows linearly with the temporal variance.
                    veo_boost = 1.0 + (temporal_var * 15)
                    fake_proba = min(1.0, fake_proba * veo_boost)
                    
                    if temporal_var > EXTREME_TEMPORAL_VAR or flow_spikes > 3:
                        force_fake = True
                
                # 2. SORA-style: smooth + blurry
                elif laplacian_var < HIGH_LAPLACIAN and temporal_var < HIGH_TEMPORAL_VAR:
                    if laplacian_var < 50:
                        fake_proba = min(1.0, fake_proba * 1.3)
                        if laplacian_var < 30:
                            force_fake = True
                
                # 3. Low-quality: blurry + jerky
                elif laplacian_var < HIGH_LAPLACIAN and temporal_var > HIGH_TEMPORAL_VAR:
                    fake_proba = min(1.0, fake_proba * 1.5)
                    force_fake = True
                
                # Rules 4-8 are independent checks; their boosts accumulate.

                # 4. FCG-based detection: Eye inconsistency (most reliable per paper)
                if eyes_inconsistency > 15.0:  # High eye region changes
                    fake_proba = min(1.0, fake_proba * 1.4)
                    if eyes_inconsistency > 25.0:
                        force_fake = True
                
                # 5. Frequency domain anomaly
                if fcg_scores.get('freq_suspicious', False):
                    fake_proba = min(1.0, fake_proba * 1.3)
                
                # 6. Optical flow spikes
                if flow_inconsistency > 5.0 or flow_spikes > 5:
                    fake_proba = min(1.0, fake_proba * 1.2)
                
                # 7. Spatial anomaly
                if feat_stats and feat_stats.get('outlier_ratio', 0) > 0.3:
                    force_fake = True
                
                # 8. Feature correlation drops (sudden inconsistency)
                if feat_stats and feat_stats.get('max_correlation_drop', 0) > 0.5:
                    fake_proba = min(1.0, fake_proba * 1.2)
            
            else:
                # ===== IMAGE DETECTION WITH FCG =====
                # NOTE(review): SHARPNESS_THRESHOLD is assigned but not read
                # anywhere in this loop; the sharpness checks below use the
                # literals 250/350/30/15 instead.
                SHARPNESS_THRESHOLD = 180.0
                
                # Analyze component-specific artifacts: count the components
                # whose frequency score falls outside the [0.005, 0.15] band.
                high_freq_components = 0
                for comp_name in ['left_eye', 'right_eye', 'nose', 'lips']:
                    freq_key = f'{comp_name}_freq_anomaly'
                    if freq_key in fcg_scores:
                        freq_val = fcg_scores[freq_key]
                        # Abnormal frequency patterns
                        if freq_val > 0.15 or freq_val < 0.005:
                            high_freq_components += 1
                
                # If multiple components show artifacts
                if high_freq_components >= 2:
                    fake_proba = min(1.0, fake_proba * 1.5)
                    if high_freq_components >= 3:
                        force_fake = True
                
                # Eye-specific check (most critical per FCG paper).
                # A missing key defaults to 0, which itself satisfies the
                # "< 0.01" test below.
                left_eye_freq = fcg_scores.get('left_eye_freq_anomaly', 0)
                right_eye_freq = fcg_scores.get('right_eye_freq_anomaly', 0)
                
                if (left_eye_freq > 0.12 or left_eye_freq < 0.01) and \
                   (right_eye_freq > 0.12 or right_eye_freq < 0.01):
                    fake_proba = min(1.0, fake_proba * 1.4)
                
                # Extreme sharpness anomalies (very sharp or very blurry)
                if laplacian_var > 250:
                    fake_proba = min(1.0, fake_proba * 1.3)
                    if laplacian_var > 350:
                        force_fake = True
                elif laplacian_var < 30:
                    fake_proba = min(1.0, fake_proba * 1.2)
                    if laplacian_var < 15:
                        force_fake = True
            
            # Final prediction: label 1 (fake) when forced or when the adjusted
            # fake probability is at least the real probability.
            y = 1 if (force_fake or fake_proba >= real_proba) else 0
            
            results.append({
                "name": filename,
                "pred": y,
                "pred_proba": [real_proba, fake_proba]
            })
            
            pbar.update(1)

print("\n✓ Processing complete!")

# Save results in the two-column submission layout (filename, label)
save_results_to_csv(results, "submission.csv", for_submit=True)
print(f"✓ Saved submission to submission.csv")

# Print statistics: label 0 is counted as real, label 1 as fake
num_real = sum(1 for r in results if r['pred'] == 0)
num_fake = sum(1 for r in results if r['pred'] == 1)
# When no input files were found `results` is empty; dividing by
# len(results) would raise ZeroDivisionError, so fall back to 1 and
# report 0.0% for both classes.
summary_denominator = max(len(results), 1)
print(f"\nPrediction Summary:")
print(f"  Real: {num_real} ({num_real/summary_denominator*100:.1f}%)")
print(f"  Fake: {num_fake} ({num_fake/summary_denominator*100:.1f}%)")

## TRY 3 : 11/13 -> 0.43

일단, 예전에 했던 것처럼 확률값의 결과를 한번 알아보았다. 

In [None]:
import os
import argparse
import torch
from models.pred_func import load_genconvit, face_rec, preprocess_frame, is_video, extract_frames
from models.config import load_config
from typing import List, Dict, Union
import csv
import numpy as np
from PIL import Image
from tqdm import tqdm
import json

config = load_config()


def save_frame_probabilities_to_csv(
    results: List[Dict[str, Union[str, int, float]]],
    filepath: str
) -> None:
    """Write one CSV row per frame record, in the order given.

    Args:
        results: Records providing ``filename``, ``frame_idx``,
            ``real_proba``, ``fake_proba`` and ``pred``.
        filepath: Destination path; an existing file is overwritten.
    """
    # Record keys in column order; note the last column is titled
    # "prediction" while the record key is "pred".
    fields = ('filename', 'frame_idx', 'real_proba', 'fake_proba', 'pred')

    with open(filepath, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["filename", "frame_idx", "real_proba", "fake_proba", "prediction"])
        writer.writerows([record[field] for field in fields] for record in results)


def save_aggregated_to_json(
    results: Dict[str, Dict],
    filepath: str
) -> None:
    """Serialise the per-file aggregate mapping to ``filepath`` as JSON.

    The output is indented by two spaces and non-ASCII characters are
    written as-is (UTF-8) rather than escaped.

    Args:
        results: Mapping to serialise.
        filepath: Destination path; an existing file is overwritten.
    """
    with open(filepath, "w", encoding="utf-8") as sink:
        json.dump(
            results,
            sink,
            ensure_ascii=False,
            indent=2,
        )


def save_submission_csv(
    results: List[Dict[str, Union[str, int]]],
    filepath: str
) -> None:
    """Write the final ``filename,label`` submission CSV, sorted by name.

    Args:
        results: Records providing ``name`` and ``pred``.
        filepath: Destination path; an existing file is overwritten.
    """
    by_name = sorted(results, key=lambda record: record['name'])

    with open(filepath, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["filename", "label"])
        writer.writerows([record['name'], record['pred']] for record in by_name)


def get_files_paths(main_path: str, exs: List[str]) -> List[str]:
    """Collect absolute paths of files under ``main_path`` by extension.

    Files whose extension is not wanted are reported on stdout and skipped.

    Args:
        main_path: Directory that is walked recursively.
        exs: Extensions to keep, with or without a leading dot; matching is
            case-insensitive.

    Returns:
        Absolute paths in the order ``os.walk`` yields them.
    """
    # Normalise to the ".ext" lowercase form; the set is also what the
    # skip message below prints.
    exs = set(map(lambda raw: ("." + raw.lstrip(".")).lower(), exs))
    found = []

    for folder, _subdirs, entries in os.walk(main_path):
        for fname in entries:
            suffix = os.path.splitext(fname)[1].lower()
            if suffix not in exs:
                print(f"{fname} is not in [{exs}]. passing...")
                continue
            found.append(os.path.abspath(os.path.join(folder, fname)))

    return found


def extract_frame_probabilities_video(
    file_path: str,
    model,
    fp16_mode: bool = False,
    num_frames: int = 15,
):
    """Compute per-frame real/fake probabilities for one video.

    Frames are sampled with ``extract_frames(..., consecutive=False)``,
    faces are cropped with ``face_rec`` and the model output is turned
    into probabilities with a softmax over dimension 1 (column 0 is read
    as "real", column 1 as "fake").

    Args:
        file_path: Path of the video file.
        model: Callable model applied to the preprocessed face batch.
        fp16_mode: Cast the batch to half precision before inference.
        num_frames: Number of frames requested from ``extract_frames``.

    Returns:
        ``(frame_results, aggregated)`` where ``frame_results`` is a list
        of per-frame records (``frame_idx`` counts detected faces in
        order) and ``aggregated`` holds the mean/std/min/max statistics,
        the final label and the per-frame predictions. When no face or no
        preprocessed frame is available, ``frame_results`` is empty and
        ``aggregated`` is a neutral record (0.5 / 0.5, label 0).
    """
    def neutral():
        # Result used when there is nothing to run the model on.
        return [], {
            'num_frames': 0,
            'avg_real_proba': 0.5,
            'avg_fake_proba': 0.5,
            'final_prediction': 0,
            'frame_predictions': []
        }

    filename = os.path.basename(file_path)

    sampled = extract_frames(file_path, num_frames, consecutive=False)
    crops, n_faces = face_rec(sampled)
    if not (n_faces > 0):
        return neutral()

    batch = preprocess_frame(crops[:n_faces])
    if not (batch.shape[0] > 0):
        return neutral()
    if fp16_mode:
        batch = batch.half()

    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
        real_probas = probs[:, 0].cpu().numpy()
        fake_probas = probs[:, 1].cpu().numpy()

    # A frame is labelled fake (1) when its fake probability is at least
    # its real probability.
    per_frame = [
        {
            'frame_idx': idx,
            'real_proba': float(real_p),
            'fake_proba': float(fake_p),
            'pred': 1 if fake_p >= real_p else 0
        }
        for idx, (real_p, fake_p) in enumerate(zip(real_probas, fake_probas))
    ]
    frame_results = [{'filename': filename, **entry} for entry in per_frame]

    # The video-level label compares the mean probabilities.
    mean_real = np.mean(real_probas)
    mean_fake = np.mean(fake_probas)

    aggregated = {
        'num_frames': len(real_probas),
        'avg_real_proba': float(mean_real),
        'avg_fake_proba': float(mean_fake),
        'std_real_proba': float(np.std(real_probas)),
        'std_fake_proba': float(np.std(fake_probas)),
        'min_fake_proba': float(np.min(fake_probas)),
        'max_fake_proba': float(np.max(fake_probas)),
        'final_prediction': 1 if mean_fake >= mean_real else 0,
        'frame_predictions': per_frame
    }

    return frame_results, aggregated


def extract_frame_probabilities_image(
    file_path: str,
    model,
    fp16_mode: bool = False,
):
    """Compute real/fake probabilities for a single image.

    The image is loaded as RGB, a face is cropped with ``face_rec`` and
    the model output is turned into probabilities with a softmax over
    dimension 1 (column 0 is read as "real", column 1 as "fake"). Only
    the first row of the batch is used.

    Args:
        file_path: Path of the image file.
        model: Callable model applied to the preprocessed face batch.
        fp16_mode: Cast the batch to half precision before inference.

    Returns:
        ``(frame_results, aggregated)``: a one-element list with the
        frame record plus the matching aggregate. If the image cannot be
        opened (a message is printed), no face is found, or preprocessing
        yields no frame, the list is empty and the aggregate is a neutral
        record (0.5 / 0.5, label 0).
    """
    def neutral():
        # Result used when there is nothing to run the model on.
        return [], {
            'num_frames': 0,
            'avg_real_proba': 0.5,
            'avg_fake_proba': 0.5,
            'final_prediction': 0,
            'frame_predictions': []
        }

    filename = os.path.basename(file_path)

    try:
        arr = np.asarray(Image.open(file_path).convert('RGB'))
    except Exception as e:
        print(f"Failed to open image {file_path}: {e}")
        return neutral()

    faces, n_faces = face_rec([arr])
    if not (n_faces > 0):
        return neutral()

    batch = preprocess_frame(faces[:n_faces])
    if not (batch.shape[0] > 0):
        return neutral()
    if fp16_mode:
        batch = batch.half()

    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
        real_proba = probs[0, 0].cpu().item()
        fake_proba = probs[0, 1].cpu().item()

    # Labelled fake (1) when the fake probability is at least the real one.
    pred = 1 if fake_proba >= real_proba else 0

    frame_entry = {
        'frame_idx': 0,
        'real_proba': float(real_proba),
        'fake_proba': float(fake_proba),
        'pred': pred
    }
    frame_results = [{'filename': filename, **frame_entry}]
    aggregated = {
        'num_frames': 1,
        'avg_real_proba': float(real_proba),
        'avg_fake_proba': float(fake_proba),
        'final_prediction': pred,
        'frame_predictions': [frame_entry]
    }

    return frame_results, aggregated


def extract_frame_probabilities_path(
    file_path: str,
    model,
    fp16_mode: bool = False,
    num_frames: int = 15,
):
    """Dispatch a path to the video or image probability extractor.

    ``is_video`` decides which extractor runs; ``num_frames`` is only
    forwarded for videos.

    Returns:
        The ``(frame_results, aggregated)`` pair of the chosen extractor.
    """
    if not is_video(file_path):
        return extract_frame_probabilities_image(file_path, model, fp16_mode)
    return extract_frame_probabilities_video(file_path, model, fp16_mode, num_frames)


def gen_parser():
    """Parse the command line and return the run settings.

    Recognised options: ``--p`` (input path, default ``"data"``), ``--f``
    (number of frames, default 30), ``--s`` (model size) and ``--fp16``.
    When ``--s`` is ``tiny`` or ``large`` the module-level ``config`` is
    updated in place with the matching backbone, embedder and type; any
    other value is ignored.

    Returns:
        ``(path, num_frames, net, fp16, ed_weight, vae_weight)``.
    """
    cli = argparse.ArgumentParser("GenConViT frame probability extraction")
    cli.add_argument("--p", type=str, help="video or image path", default="data")
    cli.add_argument(
        "--f", type=int, help="number of frames to process for prediction", default=30
    )
    cli.add_argument("--s", help="model size type: tiny, large.")
    cli.add_argument("--fp16", action="store_true", help="half precision support")
    opts = cli.parse_args()

    size = opts.s
    if size and size in ('tiny', 'large'):
        overrides = (
            ("backbone", f"convnext_{size}"),
            ("embedder", f"swin_{size}_patch4_window7_224"),
            ("type", size),
        )
        for key, value in overrides:
            config["model"][key] = value

    # The network name and weight identifiers are fixed for this script.
    return (
        opts.p,
        opts.f,
        'genconvit',
        opts.fp16,
        'genconvit_ed_inference',
        'genconvit_vae_inference',
    )


def main():
    """Run per-frame probability extraction over every input file.

    Loads the model, collects png/jpg/jpeg/mp4 files below the path given
    on the command line, runs the extractor on each file and writes three
    outputs to the working directory: ``frame_probabilities.csv``,
    ``aggregated_probabilities.json`` and ``submission.csv``. Summary
    statistics are printed at the end.
    """
    # Run settings from the command line
    path, num_frames, net, fp16, ed_weight, vae_weight = gen_parser()

    # Load the model
    print("Loading model...")
    model = load_genconvit(config, net, ed_weight, vae_weight, fp16)
    print(f"✓ Loaded {net} network")

    # Collect images/videos below the input path
    root_dir = path
    all_files = get_files_paths(root_dir, exs=['png', 'jpg', 'jpeg', 'mp4'])
    print(f"✓ Found {len(all_files)} files")

    # Result containers
    all_frame_results = []  # every per-frame record
    all_aggregated = {}     # per-file aggregate, keyed by base filename
    submission_results = [] # records for the final submission

    print(f"Start extracting per-frame probabilities...")
    for file_path in tqdm(all_files, total=len(all_files)):
        filename = os.path.basename(file_path)

        # Per-frame probabilities for this file
        frame_results, aggregated = extract_frame_probabilities_path(
            file_path=file_path,
            model=model,
            fp16_mode=fp16,
            num_frames=num_frames,
        )

        # Store the results.
        # NOTE(review): files with the same base name in different
        # sub-directories overwrite each other in all_aggregated, while
        # submission_results keeps one record per file.
        all_frame_results.extend(frame_results)
        all_aggregated[filename] = aggregated

        # Record for the final submission
        submission_results.append({
            'name': filename,
            'pred': aggregated['final_prediction']
        })

    print("\n✓ Processing complete!")

    # ===== Save results =====

    # 1. Per-frame probabilities as CSV
    csv_output = "frame_probabilities.csv"
    save_frame_probabilities_to_csv(all_frame_results, csv_output)
    print(f"✓ Saved per-frame probabilities to {csv_output}")

    # 2. Per-file aggregate statistics as JSON
    json_output = "aggregated_probabilities.json"
    save_aggregated_to_json(all_aggregated, json_output)
    print(f"✓ Saved aggregated statistics to {json_output}")

    # 3. Final submission CSV
    submission_output = "submission.csv"
    save_submission_csv(submission_results, submission_output)
    print(f"✓ Saved submission to {submission_output}")

    # ===== Print statistics =====
    print(f"\n=== Statistics ===")
    print(f"Total files processed: {len(all_aggregated)}")
    print(f"Total frames analyzed: {len(all_frame_results)}")

    # File-level prediction statistics
    num_fake_files = sum(1 for v in all_aggregated.values() if v['final_prediction'] == 1)
    num_real_files = len(all_aggregated) - num_fake_files
    # When no input file was found the mapping is empty; dividing by its
    # length would raise ZeroDivisionError, so fall back to 1 and report
    # 0.0% for both classes.
    file_denominator = max(len(all_aggregated), 1)

    print(f"\nFile-level predictions:")
    print(f"  Real: {num_real_files} ({num_real_files/file_denominator*100:.1f}%)")
    print(f"  Fake: {num_fake_files} ({num_fake_files/file_denominator*100:.1f}%)")

    # Frame-level prediction statistics (skipped when no frame was scored)
    if all_frame_results:
        num_fake_frames = sum(1 for r in all_frame_results if r['pred'] == 1)
        num_real_frames = len(all_frame_results) - num_fake_frames

        print(f"\nFrame-level predictions:")
        print(f"  Real: {num_real_frames} ({num_real_frames/len(all_frame_results)*100:.1f}%)")
        print(f"  Fake: {num_fake_frames} ({num_fake_frames/len(all_frame_results)*100:.1f}%)")

        # Mean fake probability over all frames
        avg_fake_proba_all = np.mean([r['fake_proba'] for r in all_frame_results])
        print(f"\nAverage fake probability across all frames: {avg_fake_proba_all:.4f}")


if __name__ == "__main__":
    main()

filename,frame_idx,real_proba,fake_proba,prediction
sample_image_7.png,0,0.003400923451408744,0.9965991377830505,1
sample_image_6.png,0,0.002246193354949355,0.9977537989616394,1
sample_image_4.png,0,0.1387469321489334,0.861253023147583,1
sample_image_5.png,0,0.0009539324673824012,0.9990460276603699,1
sample_image_1.png,0,0.9998745918273926,0.00012533320114016533,0
sample_image_2.png,0,0.1822851002216339,0.8177149295806885,1
sample_image_3.png,0,0.2030087560415268,0.7969911694526672,1

sample_video_1.mp4,0,0.9790687561035156,0.020931314677000046,0
sample_video_1.mp4,1,0.9026157855987549,0.0973842442035675,0
sample_video_1.mp4,2,0.40101468563079834,0.5989852547645569,1
sample_video_1.mp4,3,0.9965033531188965,0.0034966946113854647,0
sample_video_1.mp4,4,0.99832683801651,0.001673210645094514,0
sample_video_1.mp4,5,0.9953550100326538,0.004644962027668953,0
sample_video_1.mp4,6,0.9287091493606567,0.07129091769456863,0
sample_video_1.mp4,7,0.29610496759414673,0.7038949728012085,1
sample_video_1.mp4,8,0.11226729303598404,0.8877326846122742,1
sample_video_1.mp4,9,0.5342972874641418,0.46570277214050293,0
sample_video_1.mp4,10,0.11400090903043747,0.8859990239143372,1
sample_video_1.mp4,11,0.9997707009315491,0.0002292477583978325,0
sample_video_1.mp4,12,0.9995256662368774,0.0004744052712339908,0
sample_video_1.mp4,13,0.9538599848747253,0.04614001885056496,0
sample_video_1.mp4,14,0.9768911004066467,0.02310883440077305,0
sample_video_1.mp4,15,0.0006715818308293819,0.9993284940719604,1
sample_video_1.mp4,16,8.202696335501969e-05,0.9999179840087891,1
sample_video_1.mp4,17,7.666852616239339e-05,0.9999233484268188,1
sample_video_1.mp4,18,0.168125718832016,0.8318743109703064,1
sample_video_1.mp4,19,0.9399539828300476,0.060045965015888214,0
sample_video_1.mp4,20,0.887516975402832,0.11248297989368439,0
sample_video_1.mp4,21,0.05511385574936867,0.9448862075805664,1
sample_video_1.mp4,22,0.2918016314506531,0.7081984281539917,1
sample_video_1.mp4,23,0.2077609747648239,0.7922390103340149,1
sample_video_1.mp4,24,0.02328997664153576,0.9767100214958191,1
sample_video_1.mp4,25,0.001251019537448883,0.9987490177154541,1
sample_video_1.mp4,26,0.7551747560501099,0.24482528865337372,0
sample_video_1.mp4,27,0.7623976469039917,0.2376023679971695,0
sample_video_1.mp4,28,1.4420808838622179e-05,0.9999855756759644,1
sample_video_1.mp4,29,0.0018300967058166862,0.9981698989868164,1


sample_video_2.mp4,0,0.552285373210907,0.4477146565914154,0
sample_video_2.mp4,1,0.6470654010772705,0.3529345989227295,0
sample_video_2.mp4,2,0.0025786624755710363,0.997421383857727,1
sample_video_2.mp4,3,0.9894221425056458,0.01057778112590313,0
sample_video_2.mp4,4,0.6920812726020813,0.3079186677932739,0
sample_video_2.mp4,5,0.9251995086669922,0.0748005285859108,0
sample_video_2.mp4,6,0.9979808926582336,0.002019119681790471,0
sample_video_2.mp4,7,0.9997732043266296,0.00022678310051560402,0
sample_video_2.mp4,8,0.9997422099113464,0.0002578355197329074,0
sample_video_2.mp4,9,0.9299098253250122,0.07009012997150421,0
sample_video_2.mp4,10,0.2362697720527649,0.7637302279472351,1
sample_video_2.mp4,11,0.3207833170890808,0.6792166829109192,1
sample_video_2.mp4,12,0.689947783946991,0.3100522458553314,0
sample_video_2.mp4,13,0.009349594824016094,0.9906504154205322,1
sample_video_2.mp4,14,0.9999815225601196,1.853395951911807e-05,0
sample_video_2.mp4,15,0.9996504783630371,0.0003495640412438661,0
sample_video_2.mp4,16,0.9984470009803772,0.0015530155505985022,0
sample_video_2.mp4,17,0.9999986886978149,1.3264098015497439e-06,0
sample_video_2.mp4,18,0.9976486563682556,0.0023513250052928925,0
sample_video_2.mp4,19,0.9992438554763794,0.0007561193197034299,0
sample_video_2.mp4,20,0.9577920436859131,0.04220796748995781,0
sample_video_2.mp4,21,0.9998151659965515,0.00018483144231140614,0
sample_video_2.mp4,22,0.9783141016960144,0.021685883402824402,0
sample_video_2.mp4,23,1.0,3.3712160840693173e-10,0
sample_video_2.mp4,24,0.9816708564758301,0.018329113721847534,0
sample_video_2.mp4,25,0.9990838766098022,0.0009161198977380991,0
sample_video_2.mp4,26,0.9988237023353577,0.0011763233924284577,0
sample_video_2.mp4,27,0.9939069151878357,0.006093055475503206,0


sample_video_3.mp4,0,2.231787766504567e-05,0.999977707862854,1
sample_video_3.mp4,1,2.8047297746525146e-05,0.9999719858169556,1
sample_video_3.mp4,2,1.0130503142136149e-05,0.9999898672103882,1
sample_video_3.mp4,3,3.5431723517831415e-05,0.9999645948410034,1
sample_video_3.mp4,4,0.0003543641942087561,0.9996455907821655,1
sample_video_3.mp4,5,0.10812240093946457,0.8918775916099548,1
sample_video_3.mp4,6,0.00042190373642370105,0.999578058719635,1
sample_video_3.mp4,7,0.024075962603092194,0.9759240746498108,1
sample_video_3.mp4,8,0.035062022507190704,0.9649379253387451,1
sample_video_3.mp4,9,0.01316155306994915,0.9868384003639221,1
sample_video_3.mp4,10,0.0010301542934030294,0.9989699125289917,1
sample_video_3.mp4,11,0.001404339331202209,0.998595654964447,1
sample_video_3.mp4,12,0.003312255721539259,0.9966877102851868,1
sample_video_3.mp4,13,0.02376420609652996,0.9762358069419861,1
sample_video_3.mp4,14,0.01845422014594078,0.9815457463264465,1
sample_video_3.mp4,15,0.0009211369324475527,0.9990788698196411,1
sample_video_3.mp4,16,4.310508302296512e-05,0.9999568462371826,1
sample_video_3.mp4,17,0.007463885936886072,0.992536187171936,1
sample_video_3.mp4,18,0.9999988079071045,1.14704209863703e-06,0
sample_video_3.mp4,19,0.9965841770172119,0.0034157640766352415,0
sample_video_3.mp4,20,6.028825794146542e-08,0.9999998807907104,1
sample_video_3.mp4,21,0.09927377104759216,0.9007261991500854,1
sample_video_3.mp4,22,0.29918229579925537,0.7008176445960999,1
sample_video_3.mp4,23,0.5502461791038513,0.4497537314891815,0
sample_video_3.mp4,24,0.01758238673210144,0.982417643070221,1
sample_video_3.mp4,25,0.007076663430780172,0.992923378944397,1
sample_video_3.mp4,26,0.04972818121314049,0.9502717852592468,1
sample_video_3.mp4,27,0.048936303704977036,0.9510636925697327,1


sample_video_4.mp4,0,0.00011951888154726475,0.9998804330825806,1
sample_video_4.mp4,1,3.800360354944132e-05,0.9999619722366333,1
sample_video_4.mp4,2,9.566800144966692e-05,0.9999042749404907,1
sample_video_4.mp4,3,0.004657244775444269,0.9953427314758301,1
sample_video_4.mp4,4,0.00039227845263667405,0.9996077418327332,1
sample_video_4.mp4,5,0.0008691847324371338,0.9991307854652405,1
sample_video_4.mp4,6,0.002691729459911585,0.9973082542419434,1
sample_video_4.mp4,7,0.0007344328914768994,0.9992656111717224,1
sample_video_4.mp4,8,0.24086201190948486,0.7591379284858704,1
sample_video_4.mp4,9,0.03669920936226845,0.9633007645606995,1
sample_video_4.mp4,10,0.8813573122024536,0.11864274740219116,0
sample_video_4.mp4,11,2.5278695829911157e-06,0.9999974966049194,1
sample_video_4.mp4,12,0.0002221961331088096,0.9997778534889221,1
sample_video_4.mp4,13,0.009309524670243263,0.9906904697418213,1
sample_video_4.mp4,14,0.015877963975071907,0.9841220378875732,1
sample_video_4.mp4,15,1.5118233775979206e-08,1.0,1
sample_video_5.mp4,0,0.00012067541683791205,0.9998793601989746,1
sample_video_5.mp4,1,0.9647229909896851,0.035277046263217926,0
sample_video_5.mp4,2,0.9860543608665466,0.013945622369647026,0
sample_video_5.mp4,3,0.9991369843482971,0.0008630394004285336,0
sample_video_5.mp4,4,0.9998331069946289,0.0001668316253926605,0
sample_video_5.mp4,5,0.9976425766944885,0.002357346937060356,0
sample_video_5.mp4,6,0.9635066986083984,0.03649332374334335,0
sample_video_5.mp4,7,0.9907829165458679,0.009217056445777416,0
sample_video_5.mp4,8,0.6926249861717224,0.3073749840259552,0
sample_video_5.mp4,9,0.9804189801216125,0.019580954685807228,0
sample_video_5.mp4,10,0.47934260964393616,0.5206574201583862,1
sample_video_5.mp4,11,0.977471649646759,0.022528300061821938,0
sample_video_5.mp4,12,0.40261024236679077,0.5973897576332092,1
sample_video_5.mp4,13,0.9825249910354614,0.01747497357428074,0
sample_video_5.mp4,14,0.37024402618408203,0.629755973815918,1
sample_video_5.mp4,15,6.316122380667366e-06,0.9999936819076538,1
sample_video_5.mp4,16,0.9905055165290833,0.009494493715465069,0
sample_video_5.mp4,17,0.9337236881256104,0.06627634912729263,0
sample_video_5.mp4,18,0.9441015124320984,0.0558985136449337,0
sample_video_5.mp4,19,0.9998006224632263,0.00019945681560784578,0
sample_video_5.mp4,20,0.7776525616645813,0.22234748303890228,0
sample_video_5.mp4,21,0.9942375421524048,0.005762387067079544,0
sample_video_5.mp4,22,0.9925923347473145,0.007407746277749538,0
sample_video_5.mp4,23,0.8768342733383179,0.12316570430994034,0
sample_video_5.mp4,24,0.5414329767227173,0.4585671126842499,0
sample_video_5.mp4,25,0.7917376160621643,0.2082623690366745,0
sample_video_5.mp4,26,0.9312676787376404,0.06873229891061783,0
sample_video_5.mp4,27,0.6853089928627014,0.3146910071372986,0
sample_video_5.mp4,28,0.9970850348472595,0.002914972370490432,0
sample_video_5.mp4,29,0.9671662449836731,0.0328337661921978,0

병렬 처리 + frame 확률 계산 ! 

In [None]:
import os
import argparse
import torch
from models.pred_func import load_genconvit, face_rec, preprocess_frame, is_video, extract_frames
from models.config import load_config
from typing import List, Dict, Union
import csv
import numpy as np
from PIL import Image
from tqdm import tqdm
import json
import cv2
import multiprocessing

# Shared configuration, loaded once when the cell/module is executed.
# gen_parser() overwrites entries under config["model"] when --s is given,
# and main() passes this object to load_genconvit().
config = load_config()


# =====================================================================
# CPU WORKER FUNCTION FOR PARALLEL PROCESSING
# =====================================================================

def process_single_file_cpu(args):
    """CPU-side worker: read one file, crop faces and measure their sharpness.

    Intended to be mapped over a ``multiprocessing.Pool``; no model inference
    happens here.

    Args:
        args: ``(file_path, num_frames)``. ``num_frames`` is only used for
            video files.

    Returns:
        A 6-tuple ``(filename, face_crops, raw_frames, laplacian_vars,
        is_video_file, error)``:

        * faces found: the first ``count`` face crops, the frames (for video
          the first ``count`` raw frames, for an image ``[array]``), the
          Laplacian variances (one per crop for video, a single value for an
          image), ``True``/``False`` for video/image, and ``None``.
        * no face found: ``(filename, None, None, [], True/False, None)``.
        * any exception: ``(filename, None, None, [], None, str(exc))``.
    """
    file_path, num_frames = args
    filename = os.path.basename(file_path)

    def sharpness(face):
        # Variance of the Laplacian response; colour crops are converted to
        # grayscale first, anything else is used as-is.
        gray = cv2.cvtColor(face, cv2.COLOR_RGB2GRAY) if len(face.shape) == 3 else face
        return cv2.Laplacian(gray, cv2.CV_64F).var()

    try:
        if not is_video(file_path):
            # Still image: decode to an RGB array and look for a face in it.
            pixels = np.asarray(Image.open(file_path).convert('RGB'))
            faces, found = face_rec([pixels])
            if not found > 0:
                return (filename, None, None, [], False, None)
            # Only the first detected face is scored for sharpness.
            return (filename, faces[:found], [pixels], [sharpness(faces[0])], False, None)

        # Video: sample frames, then crop faces from them.
        frames = extract_frames(file_path, num_frames, consecutive=False)
        crops, found = face_rec(frames)
        if not found > 0:
            return (filename, None, None, [], True, None)

        kept = crops[:found]
        per_face_sharpness = [sharpness(face) for face in kept]
        return (filename, kept, frames[:found], per_face_sharpness, True, None)

    except Exception as e:
        # Best effort: report the failure to the parent instead of killing the pool.
        return (filename, None, None, [], None, str(e))


# =====================================================================
# TEMPORAL ANALYSIS
# =====================================================================

def calculate_temporal_variance(frames_np):
    """Return the variance of the mean optical-flow magnitude across frames.

    Every frame is converted to grayscale (when it has three dimensions) and
    resized to 224x224, then dense Farneback optical flow is computed between
    each frame and the current reference frame. The mean flow magnitude of
    each pair is collected and the variance of those means is returned.

    Args:
        frames_np: Sequence of frames; 3-D frames are treated as RGB.

    Returns:
        ``0.0`` when fewer than two frames are given or no flow could be
        computed, otherwise ``np.var`` of the per-pair mean magnitudes.
    """
    if len(frames_np) < 2:
        return 0.0

    side = (224, 224)

    def to_fixed_gray(frame):
        gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY) if len(frame.shape) == 3 else frame
        if gray.shape[:2] != side:
            gray = cv2.resize(gray, side, interpolation=cv2.INTER_AREA)
        return gray

    grays = [to_fixed_gray(frame) for frame in frames_np]

    mean_magnitudes = []
    reference = grays[0]
    for current in grays[1:]:
        try:
            flow = cv2.calcOpticalFlowFarneback(
                reference, current, None,
                pyr_scale=0.5, levels=3, winsize=15,
                iterations=3, poly_n=5, poly_sigma=1.2, flags=0
            )
            magnitude, _ = cv2.cartToPolar(flow[..., 0], flow[..., 1])
            mean_magnitudes.append(np.mean(magnitude))
        except cv2.error:
            # On failure the reference frame is deliberately left unchanged,
            # so the next frame is compared against the last good reference.
            continue
        reference = current

    if not mean_magnitudes:
        return 0.0
    return np.var(mean_magnitudes)


def calculate_detail_stability(laplacian_vars):
    """Measure how much per-frame sharpness fluctuates.

    Args:
        laplacian_vars: Sequence of per-frame Laplacian variances.

    Returns:
        The standard deviation (``np.std``) of the values, or ``0.0`` when
        fewer than two values are supplied.
    """
    return np.std(laplacian_vars) if len(laplacian_vars) >= 2 else 0.0


def classify_video_type(temporal_var, detail_stability, avg_sharpness):
    """Apply hand-tuned threshold rules to three video statistics.

    The rules are checked in order and the first match wins.

    Args:
        temporal_var: Variance of optical-flow magnitude between frames.
        detail_stability: Standard deviation of per-frame sharpness.
        avg_sharpness: Mean per-frame sharpness (Laplacian variance).

    Returns:
        ``(is_fake, confidence, label)`` where ``is_fake`` is ``True``,
        ``False`` or ``None`` (undecided), ``confidence`` is a fixed constant
        attached to the matched rule, and ``label`` is the rule name.
    """
    # NOTE(review): the thresholds below are magic constants; how they were
    # derived is not visible in this file.
    motion_high, motion_low = 0.03, 0.01
    detail_high, detail_low = 50.0, 30.0
    sharp_cutoff = 150.0

    if avg_sharpness > sharp_cutoff and temporal_var > motion_high:
        # Sharp frames together with strongly varying motion.
        verdict = (True, 0.85, "VEO3-type")
    elif avg_sharpness < sharp_cutoff and temporal_var < motion_low and detail_stability > detail_high:
        # Softer frames, nearly constant motion, unstable detail.
        verdict = (True, 0.80, "SORA2-type")
    elif detail_stability > detail_high and temporal_var > motion_high:
        verdict = (True, 0.75, "Hybrid-type")
    elif detail_stability < detail_low and motion_low < temporal_var < motion_high:
        # Stable detail with moderate motion variance is treated as real.
        verdict = (False, 0.90, "Real-type")
    else:
        verdict = (None, 0.5, "Uncertain-type")

    return verdict


# =====================================================================
# ARTIFACT DETECTION
# =====================================================================

def check_frequency_artifacts(image):
    """Return the share of spectral magnitude outside the low-frequency disc.

    The image is reduced to a single luminance plane, transformed with a 2-D
    FFT (zero frequency shifted to the centre), and the magnitudes outside a
    centred disc of radius ``min(h, w) // 4`` are summed and divided by the
    total magnitude.

    Args:
        image: A ``PIL.Image.Image``, or an array-like of shape ``(H, W)`` or
            ``(H, W, C)``. Colour arrays are collapsed to luminance first;
            without this step ``np.fft.fft2`` would transform the
            (width, channel) axes and the 2-D shape unpacking would fail.

    Returns:
        High-frequency magnitude ratio as a float in ``[0, 1]``.

    Raises:
        ValueError: If the input is neither 2-D nor 3-D.
    """
    if isinstance(image, np.ndarray):
        pixels = image
    elif isinstance(image, Image.Image):
        pixels = np.array(image.convert('L'))
    else:
        pixels = np.asarray(image)

    if pixels.ndim == 3:
        if pixels.shape[2] >= 3:
            # Same ITU-R 601 luma weights that PIL's 'L' conversion uses,
            # so arrays and PIL images are treated consistently.
            luma = np.array([0.299, 0.587, 0.114])
            pixels = pixels[..., :3].astype(np.float64) @ luma
        else:
            pixels = pixels.mean(axis=2)
    elif pixels.ndim != 2:
        raise ValueError(
            f"expected a 2-D or 3-D image, got {pixels.ndim} dimension(s)"
        )

    magnitude_spectrum = np.abs(np.fft.fftshift(np.fft.fft2(pixels)))

    h, w = magnitude_spectrum.shape
    center_h, center_w = h // 2, w // 2

    # Zero out the centred low-frequency disc; what remains is "high frequency".
    mask_radius = min(h, w) // 4
    y, x = np.ogrid[:h, :w]
    low_freq_mask = (x - center_w) ** 2 + (y - center_h) ** 2 <= mask_radius ** 2
    high_freq_region = np.where(low_freq_mask, 0, magnitude_spectrum)

    total_energy = np.sum(magnitude_spectrum)
    high_freq_energy = np.sum(high_freq_region)
    # The epsilon keeps an all-zero image from dividing by zero.
    return high_freq_energy / (total_energy + 1e-10)


def check_face_boundary_artifacts(image):
    """Compare pixel variation in the outer border with the image centre.

    The border is the outer 10% of the shorter side (at least one pixel) on
    each edge. The standard deviations of the four border strips are averaged
    and compared with the standard deviation of the remaining centre region.

    Args:
        image: A ``PIL.Image.Image`` or an array-like with at least two
            dimensions (height, width[, channels]).

    Returns:
        ``abs(avg_border_std - center_std) / center_std`` (with a small
        epsilon in the denominator). ``0.0`` is returned when the image is
        too small to have a centre region, instead of propagating NaN from
        empty slices.
    """
    if isinstance(image, np.ndarray):
        pixels = image
    elif isinstance(image, Image.Image):
        pixels = np.array(image)
    else:
        pixels = np.asarray(image)

    h, w = pixels.shape[:2]
    # int(min(h, w) * 0.1) is 0 for sides shorter than 10 px, which would make
    # every border slice empty or whole-image; clamp to one pixel.
    border_width = max(1, int(min(h, w) * 0.1))
    if h <= 2 * border_width or w <= 2 * border_width:
        # No interior left to compare against.
        return 0.0

    border_regions = (
        pixels[:border_width, :],    # top
        pixels[-border_width:, :],   # bottom
        pixels[:, :border_width],    # left
        pixels[:, -border_width:],   # right
    )
    avg_border_std = np.mean([np.std(region) for region in border_regions])

    center_region = pixels[border_width:-border_width, border_width:-border_width]
    center_std = np.std(center_region)

    return abs(avg_border_std - center_std) / (center_std + 1e-10)


def check_color_consistency(image):
    """Return the relative spread between the mean values of the first three channels.

    Computes the mean of channels 0, 1 and 2 and returns the standard
    deviation of those three means divided by their average (coefficient of
    variation).

    Args:
        image: A ``PIL.Image.Image`` or an array-like of shape ``(H, W, C)``.

    Returns:
        The channel imbalance as a float. ``0.0`` is returned for empty
        input and for images with fewer than three channels (e.g.
        grayscale), which previously raised ``IndexError``.
    """
    if isinstance(image, np.ndarray):
        pixels = image
    elif isinstance(image, Image.Image):
        pixels = np.array(image)
    else:
        pixels = np.asarray(image)

    # A single-channel image has no channels to be imbalanced.
    if pixels.ndim < 3 or pixels.shape[2] < 3 or pixels.size == 0:
        return 0.0

    channel_means = [np.mean(pixels[:, :, channel]) for channel in range(3)]
    # The epsilon guards against an all-black image.
    return np.std(channel_means) / (np.mean(channel_means) + 1e-10)


def perform_secondary_checks(image, laplacian_var):
    """Combine three artifact measures into one suspicion score.

    Args:
        image: Face crop to inspect; forwarded unchanged to the three
            ``check_*`` helpers.
        laplacian_var: Sharpness of the crop. Accepted for the caller's
            convenience but not used in the computation.

    Returns:
        Weighted sum ``0.5 * freq + 0.2 * boundary + 0.3 * color`` where each
        component score is capped at 1.0.
    """
    freq_artifact = check_frequency_artifacts(image)
    boundary_artifact = check_face_boundary_artifacts(image)
    color_inconsistency = check_color_consistency(image)

    # The frequency ratio only contributes when it falls outside the
    # [0.01, 0.12] band; its score is the capped relative distance from 0.05.
    if freq_artifact < 0.01 or freq_artifact > 0.12:
        freq_component = min(abs(freq_artifact - 0.05) / 0.05, 1.0)
    else:
        freq_component = 0.0

    boundary_component = min(boundary_artifact / 1.0, 1.0)
    color_component = min(color_inconsistency / 0.5, 1.0)

    freq_weight, boundary_weight, color_weight = 0.5, 0.2, 0.3
    return (
        freq_weight * freq_component +
        boundary_weight * boundary_component +
        color_weight * color_component
    )


# =====================================================================
# CSV SAVE FUNCTIONS
# =====================================================================

def save_results_to_csv(results: List[Dict], filepath: str, for_submit=False) -> None:
    """Write per-file predictions to a CSV file, ordered by file name.

    Args:
        results: Dicts with at least ``name`` and ``pred`` keys;
            ``pred_proba`` is also required when ``for_submit`` is false.
        filepath: Destination path; an existing file is overwritten.
        for_submit: When true, write the two-column ``filename,label``
            layout; otherwise write ``name,pred,pred_proba``.
    """
    if for_submit:
        header, keys = ("filename", "label"), ("name", "pred")
    else:
        header, keys = ("name", "pred", "pred_proba"), ("name", "pred", "pred_proba")

    # Sort before opening so a malformed entry cannot truncate an existing file.
    ordered = sorted(results, key=lambda entry: entry['name'])

    with open(filepath, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(tuple(entry[key] for key in keys) for entry in ordered)


def save_frame_probabilities_to_csv(results: List[Dict], filepath: str) -> None:
    """Write per-frame probabilities to a CSV file in the order given.

    Args:
        results: Dicts with ``filename``, ``frame_idx``, ``real_proba``,
            ``fake_proba`` and ``pred`` keys.
        filepath: Destination path; an existing file is overwritten.
    """
    # The last column is headed "prediction" but is read from the "pred" key.
    header = ["filename", "frame_idx", "real_proba", "fake_proba", "prediction"]
    keys = ('filename', 'frame_idx', 'real_proba', 'fake_proba', 'pred')

    with open(filepath, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows([entry[key] for key in keys] for entry in results)


def get_files_paths(main_path: str, exs: List[str]) -> List[str]:
    """Recursively collect files under ``main_path`` with the given extensions.

    Args:
        main_path: Directory to walk.
        exs: Extensions to keep, with or without a leading dot; matching is
            case-insensitive.

    Returns:
        Absolute paths of the matching files, in ``os.walk`` order.
    """
    # Normalise every extension to the lowercase ".ext" form.
    wanted = {("." + ext.lstrip(".")).lower() for ext in exs}

    return [
        os.path.abspath(os.path.join(folder, entry))
        for folder, _, entries in os.walk(main_path)
        for entry in entries
        if os.path.splitext(entry)[1].lower() in wanted
    ]


# =====================================================================
# PREDICTION FUNCTIONS WITH POSTPROCESSING
# =====================================================================

def predict_video(
    file_path: str,
    model,
    fp16_mode: bool = False,
    num_frames: int = 15,
):
    """Classify one video as real (0) or fake (1) with aggressive post-processing.

    Frames are sampled, faces are cropped, and the model scores every crop.
    The averaged prediction can then be overridden to "fake" by three
    heuristics (moving-average sequence analysis, share of fake frames, and
    temporal/sharpness statistics). Overrides only ever switch the label
    to 1.

    Args:
        file_path: Path of the video file.
        model: Callable returning two-class logits for a batch of crops.
        fp16_mode: Convert the input batch to half precision before inference.
        num_frames: Number of frames to sample from the video.

    Returns:
        Dict with ``name``, ``pred``, ``pred_proba`` (``[avg_real, avg_fake]``)
        and ``frame_results`` (raw per-frame probabilities). When no face is
        found or the batch is empty, ``pred`` is 0, ``pred_proba`` is
        ``[0.5, 0.5]`` and ``frame_results`` is empty.
    """
    filename = os.path.basename(file_path)
    undecided = {
        "name": filename,
        "pred": 0,
        "pred_proba": [0.5, 0.5],
        "frame_results": []
    }

    raw_frames = extract_frames(file_path, num_frames, consecutive=False)
    face_crops, count = face_rec(raw_frames)
    if not count > 0:
        return undecided

    faces = face_crops[:count]
    df = preprocess_frame(faces)
    if not df.shape[0] > 0:
        return undecided
    if fp16_mode:
        df = df.half()

    with torch.no_grad():
        probs = torch.softmax(model(df), dim=1)
        real_probas = probs[:, 0].cpu().numpy()
        fake_probas = probs[:, 1].cpu().numpy()

    # Per-frame records keep the raw model decision, before any override.
    frame_results = [
        {
            'filename': filename,
            'frame_idx': frame_idx,
            'real_proba': float(real_p),
            'fake_proba': float(fake_p),
            'pred': 1 if fake_p >= real_p else 0
        }
        for frame_idx, (real_p, fake_p) in enumerate(zip(real_probas, fake_probas))
    ]

    avg_real = np.mean(real_probas)
    avg_fake = np.mean(fake_probas)
    predicted_class = 1 if avg_fake >= avg_real else 0

    # ===== Post-processing: deliberately biased towards flagging fakes =====

    # 1. Sequence analysis over a 5-frame moving average (lowered thresholds).
    if df.shape[0] > 5:
        window_size = 5
        moving_avg = np.convolve(fake_probas, np.ones(window_size)/window_size, mode="valid")

        high_fake_ratio = np.mean(moving_avg > 0.45)
        mid_fake_ratio = np.mean(moving_avg > 0.35)
        max_moving_avg = np.max(moving_avg)
        avg_fake_prob = np.mean(fake_probas)

        sequence_is_suspicious = (
            high_fake_ratio > 0.12 or mid_fake_ratio > 0.30 or
            max_moving_avg > 0.55 or avg_fake_prob > 0.30
        )
        if sequence_is_suspicious:
            predicted_class = 1
            print(f"[SEQ-OVERRIDE] {filename}: high={high_fake_ratio:.2f}, mid={mid_fake_ratio:.2f}, max={max_moving_avg:.2f}, avg={avg_fake_prob:.2f}")

    # 2. Frame ratio: more than 25% fake-looking frames marks the whole video fake.
    fake_frame_ratio = np.sum(fake_probas > 0.5) / len(fake_probas)
    if fake_frame_ratio > 0.25:
        predicted_class = 1
        print(f"[FRAME-RATIO-OVERRIDE] {filename}: {fake_frame_ratio:.2%} frames are fake")

    # 3. Temporal analysis from optical flow and per-face sharpness.
    if len(raw_frames) >= 5:
        laplacian_vars = []
        for face in faces:
            gray = cv2.cvtColor(face, cv2.COLOR_RGB2GRAY) if len(face.shape) == 3 else face
            laplacian_vars.append(cv2.Laplacian(gray, cv2.CV_64F).var())

        temporal_var = calculate_temporal_variance(raw_frames)
        detail_stability = calculate_detail_stability(laplacian_vars)
        avg_sharpness = np.mean(laplacian_vars)

        # The confidence value returned by the classifier is not used here.
        is_fake_temporal, _, video_type = classify_video_type(
            temporal_var, detail_stability, avg_sharpness
        )
        if is_fake_temporal:
            predicted_class = 1
            print(f"[TEMPORAL-OVERRIDE] {filename}: {video_type}")

    return {
        "name": filename,
        "pred": predicted_class,
        "pred_proba": [float(avg_real), float(avg_fake)],
        "frame_results": frame_results
    }


def predict_image(
    file_path: str,
    model,
    fp16_mode: bool = False,
):
    """Classify one still image as real (0) or fake (1) with a sharpness re-check.

    The first detected face is scored by the model. If the model says "real"
    but the crop is very sharp (Laplacian variance above 180), secondary
    artifact checks run and may flip the label to fake.

    Args:
        file_path: Path of the image file.
        model: Callable returning two-class logits for a batch of crops.
        fp16_mode: Convert the input batch to half precision before inference.

    Returns:
        Dict with ``name``, ``pred``, ``pred_proba`` (``[real, fake]``) and
        ``frame_results`` (a single entry carrying the final label). If the
        image cannot be opened, no face is found, or the batch is empty,
        ``pred`` is 0, ``pred_proba`` is ``[0.5, 0.5]`` and ``frame_results``
        is empty.
    """
    filename = os.path.basename(file_path)

    def undecided():
        return {
            "name": filename,
            "pred": 0,
            "pred_proba": [0.5, 0.5],
            "frame_results": []
        }

    try:
        arr = np.asarray(Image.open(file_path).convert('RGB'))
    except Exception as e:
        print(f"Failed to open image {file_path}: {e}")
        return undecided()

    face, count = face_rec([arr])
    if not count > 0:
        return undecided()

    df = preprocess_frame(face[:count])
    if not df.shape[0] > 0:
        return undecided()
    if fp16_mode:
        df = df.half()

    with torch.no_grad():
        probs = torch.softmax(model(df), dim=1)
        # Only the first face in the batch decides the label.
        real_proba = probs[0, 0].cpu().item()
        fake_proba = probs[0, 1].cpu().item()

    predicted_class = 1 if fake_proba >= real_proba else 0

    # ===== Post-processing: secondary checks for very sharp "real" faces =====
    face_crop = face[0]
    gray = cv2.cvtColor(face_crop, cv2.COLOR_RGB2GRAY) if len(face_crop.shape) == 3 else face_crop
    lap_var = cv2.Laplacian(gray, cv2.CV_64F).var()

    sharpness_threshold = 180.0
    suspicion_threshold = 0.60
    if lap_var > sharpness_threshold and predicted_class == 0:
        suspicion_score = perform_secondary_checks(face_crop, lap_var)
        if suspicion_score > suspicion_threshold:
            predicted_class = 1
            print(f"[OVERRIDE] {filename}: Sharp but suspicious (susp={suspicion_score:.3f}, lap={lap_var:.1f})")

    return {
        "name": filename,
        "pred": predicted_class,
        "pred_proba": [float(real_proba), float(fake_proba)],
        "frame_results": [{
            'filename': filename,
            'frame_idx': 0,
            'real_proba': float(real_proba),
            'fake_proba': float(fake_proba),
            'pred': predicted_class
        }]
    }


def predict_path(
    file_path: str,
    model,
    fp16_mode: bool = False,
    num_frames: int = 15,
):
    """Dispatch to the video or image predictor depending on the file type.

    Args:
        file_path: Path of the file to classify.
        model: Model forwarded to the predictor.
        fp16_mode: Forwarded to the predictor.
        num_frames: Number of frames to sample; only used for videos.

    Returns:
        The result dict produced by ``predict_video`` or ``predict_image``.
    """
    if not is_video(file_path):
        return predict_image(file_path, model, fp16_mode=fp16_mode)
    return predict_video(file_path, model, fp16_mode=fp16_mode, num_frames=num_frames)


def gen_parser():
    """Parse command-line options and apply the model-size override.

    Options:
        --p     video or image path (default ``"data"``)
        --f     number of frames to process (default ``15``)
        --s     model size, ``tiny`` or ``large``; any other value is ignored
        --fp16  enable half precision; an explicit "off" value such as
                ``false``, ``0``, ``no``, ``off`` or ``none`` disables it

    Side effects:
        When ``--s`` is ``tiny`` or ``large``, the module-level ``config`` is
        updated in place (``backbone``, ``embedder`` and ``type`` under
        ``config["model"]``).

    Returns:
        ``(path, num_frames, net, fp16, ed_weight, vae_weight)``.
    """
    parser = argparse.ArgumentParser("GenConViT prediction with postprocessing")
    parser.add_argument("--p", type=str, help="video or image path", default="data")
    parser.add_argument("--f", type=int, help="number of frames to process", default=15)
    parser.add_argument("--s", help="model size type: tiny, large.")
    parser.add_argument("--fp16", type=str, help="half precision support")

    args = parser.parse_args()
    path = args.p
    num_frames = args.f

    # The flag takes a string value, so plain truthiness would turn
    # "--fp16 false" into True. Every value except an explicit "off" token
    # still enables half precision, which keeps existing invocations working.
    disabled_tokens = {"", "0", "false", "no", "off", "none"}
    fp16 = args.fp16 is not None and args.fp16.strip().lower() not in disabled_tokens

    net = 'genconvit'
    ed_weight = 'genconvit_ed_inference'
    vae_weight = 'genconvit_vae_inference'

    if args.s in ('tiny', 'large'):
        config["model"]["backbone"] = f"convnext_{args.s}"
        config["model"]["embedder"] = f"swin_{args.s}_patch4_window7_224"
        config["model"]["type"] = args.s

    return path, num_frames, net, fp16, ed_weight, vae_weight


def main():
    """Run the full pipeline: parallel CPU preprocessing, model inference, CSV output.

    Steps:
        1. Parse CLI options and load the model.
        2. Collect png/jpg/jpeg/mp4 files under the given path.
        3. Extract frames and crop faces in a worker pool
           (``process_single_file_cpu``), running model inference and
           post-processing in this process as worker results arrive.
        4. Write ``submission.csv`` and print override and prediction statistics.

    Files that fail to process, have no detectable face, or yield an empty
    batch are recorded as real (``pred`` 0) with probabilities ``[0.5, 0.5]``.
    """
    def undecided(name):
        # Neutral fallback entry used for every file that cannot be scored.
        return {
            "name": name,
            "pred": 0,
            "pred_proba": [0.5, 0.5]
        }

    # Load parameters
    path, num_frames, net, fp16, ed_weight, vae_weight = gen_parser()

    # Load model
    print("Loading model...")
    model = load_genconvit(config, net, ed_weight, vae_weight, fp16)
    print(f"✓ Loaded {net} network\n")

    # Collect files
    root_dir = path
    all_files = get_files_paths(root_dir, exs=['png', 'jpg', 'jpeg', 'mp4'])
    print(f"✓ Found {len(all_files)} files")

    # Setup multiprocessing: leave one core free, cap at 8 workers
    num_workers = min(max(1, multiprocessing.cpu_count() - 1), 8)
    print(f"✓ Using {num_workers} worker processes for preprocessing\n")

    # Prepare worker arguments
    worker_args = [(file_path, num_frames) for file_path in all_files]

    results = []
    all_frame_results = []

    # Statistics: each counter counts files whose label an override flipped to fake
    sequence_override_count = 0
    frame_ratio_override_count = 0
    temporal_override_count = 0
    secondary_check_count = 0

    print("Start Evaluating with Postprocessing...\n")

    # Parallel preprocessing (CPU); results arrive in completion order
    with multiprocessing.Pool(processes=num_workers) as pool:
        with tqdm(total=len(all_files), desc="Processing files") as pbar:
            for result in pool.imap_unordered(process_single_file_cpu, worker_args):
                filename, face_crops, raw_frames, laplacian_vars, is_video_file, error = result

                # Check for errors. Compare against None so that an exception
                # with an empty message is still reported.
                if error is not None:
                    print(f"Error processing {filename}: {error}")
                    results.append(undecided(filename))
                    pbar.update(1)
                    continue

                # No face detected
                if face_crops is None:
                    results.append(undecided(filename))
                    pbar.update(1)
                    continue

                # ===== MODEL INFERENCE =====
                df = preprocess_frame(face_crops)
                if fp16 and df.shape[0] > 0:
                    df = df.half()

                if df.shape[0] > 0:
                    with torch.no_grad():
                        outputs = model(df)
                        probs = torch.softmax(outputs, dim=1)

                        real_probas = probs[:, 0].cpu().numpy()
                        fake_probas = probs[:, 1].cpu().numpy()

                        # Store per-frame results (raw model decision, before overrides)
                        for frame_idx, (real_p, fake_p) in enumerate(zip(real_probas, fake_probas)):
                            pred = 1 if fake_p >= real_p else 0
                            all_frame_results.append({
                                'filename': filename,
                                'frame_idx': frame_idx,
                                'real_proba': float(real_p),
                                'fake_proba': float(fake_p),
                                'pred': pred
                            })

                    # Calculate average
                    avg_real = np.mean(real_probas)
                    avg_fake = np.mean(fake_probas)
                    predicted_class = 1 if avg_fake >= avg_real else 0

                    # ===== POSTPROCESSING =====

                    if is_video_file:
                        # 1. Sequence analysis over a 5-frame moving average
                        if df.shape[0] > 5:
                            window_size = 5
                            moving_avg = np.convolve(fake_probas, np.ones(window_size)/window_size, mode="valid")

                            high_fake_ratio = np.mean(moving_avg > 0.45)
                            mid_fake_ratio = np.mean(moving_avg > 0.35)
                            max_moving_avg = np.max(moving_avg)
                            avg_fake_prob = np.mean(fake_probas)

                            if (high_fake_ratio > 0.12 or mid_fake_ratio > 0.30 or
                                    max_moving_avg > 0.55 or avg_fake_prob > 0.30):
                                if predicted_class == 0:
                                    sequence_override_count += 1
                                predicted_class = 1
                                # print(f"[SEQ-OVERRIDE] {filename}")

                        # 2. Frame ratio check
                        fake_frame_count = np.sum(fake_probas > 0.5)
                        fake_frame_ratio = fake_frame_count / len(fake_probas)

                        if fake_frame_ratio > 0.25:
                            if predicted_class == 0:
                                frame_ratio_override_count += 1
                            predicted_class = 1
                            # print(f"[FRAME-RATIO-OVERRIDE] {filename}: {fake_frame_ratio:.2%}")

                        # 3. Temporal analysis
                        if len(raw_frames) >= 5 and laplacian_vars:
                            temporal_var = calculate_temporal_variance(raw_frames)
                            detail_stability = calculate_detail_stability(laplacian_vars)
                            avg_sharpness = np.mean(laplacian_vars)

                            is_fake_temporal, temp_confidence, video_type = classify_video_type(
                                temporal_var, detail_stability, avg_sharpness
                            )

                            if is_fake_temporal is not None and is_fake_temporal:
                                if predicted_class == 0:
                                    temporal_override_count += 1
                                predicted_class = 1
                                # print(f"[TEMPORAL-OVERRIDE] {filename}: {video_type}")

                    else:
                        # Image postprocessing: re-check very sharp "real" faces
                        if laplacian_vars:
                            lap_var = laplacian_vars[0]
                            SHARPNESS_THRESHOLD = 180.0

                            if lap_var > SHARPNESS_THRESHOLD and predicted_class == 0:
                                suspicion_score = perform_secondary_checks(face_crops[0], lap_var)
                                SUSPICION_THRESHOLD = 0.60

                                if suspicion_score > SUSPICION_THRESHOLD:
                                    predicted_class = 1
                                    secondary_check_count += 1
                                    # print(f"[OVERRIDE] {filename}: Sharp but suspicious")

                    results.append({
                        "name": filename,
                        "pred": predicted_class,
                        "pred_proba": [float(avg_real), float(avg_fake)]
                    })
                else:
                    results.append(undecided(filename))

                pbar.update(1)

    # Save results
    print("\n✓ Processing complete!\n")

    # 1. Final submission CSV
    save_results_to_csv(results, "submission.csv", for_submit=True)
    print(f"✓ Saved submission to submission.csv")

    # # 2. Per-frame probability CSV (currently disabled)
    # if all_frame_results:
    #     save_frame_probabilities_to_csv(all_frame_results, "frame_probabilities.csv")
    #     print(f"✓ Saved frame probabilities to frame_probabilities.csv")

    # Print statistics
    print(f"\n=== Detection Statistics ===")
    print(f"Sequence analysis overrides: {sequence_override_count}")
    print(f"Frame ratio overrides: {frame_ratio_override_count}")
    print(f"Temporal analysis overrides: {temporal_override_count}")
    print(f"Secondary check overrides: {secondary_check_count}")

    num_fake = sum(1 for r in results if r['pred'] == 1)
    num_real = len(results) - num_fake

    print(f"\n=== Prediction Summary ===")
    print(f"Total files: {len(results)}")
    # Percentages are undefined when nothing was processed (e.g. an empty
    # input directory); skip them instead of dividing by zero.
    if results:
        print(f"  Real: {num_real} ({num_real/len(results)*100:.1f}%)")
        print(f"  Fake: {num_fake} ({num_fake/len(results)*100:.1f}%)")

    if all_frame_results:
        num_fake_frames = sum(1 for r in all_frame_results if r['pred'] == 1)
        print(f"\nTotal frames: {len(all_frame_results)}")
        print(f"  Fake frames: {num_fake_frames} ({num_fake_frames/len(all_frame_results)*100:.1f}%)")


# Run the pipeline only when executed as a script. main() creates a
# multiprocessing.Pool, so this guard also keeps worker processes that
# re-import this module from starting the pipeline again.
if __name__ == "__main__":
    main()