##  1.环境设置和数据下载

In [None]:
# =============================================================================
# 第1段：环境设置和导入
# =============================================================================

import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

# PyTorch相关
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.models as models

# 机器学习指标
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report,
    roc_curve, auc
)

# 设置随机种子
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also seeds the CUDA generators when a GPU is available and switches
    cuDNN into deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        # Seed the current device and then every visible device.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(42)

# Pick the compute device: CUDA when available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
if torch.cuda.is_available():
    print(f"GPU型号: {torch.cuda.get_device_name(0)}")
    print(f"GPU内存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Create the working directories used by the rest of the notebook
for _workdir in ('./data', './models', './logs', './results'):
    os.makedirs(_workdir, exist_ok=True)

print("✅ 环境设置完成")

##  2.数据下载与预处理


In [None]:
# =============================================================================
# 第2段：轻量级优化的数据下载和预处理（Kaggle友好版本）
# =============================================================================

import os
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')
import gc
import json
from pathlib import Path
from sklearn.model_selection import train_test_split

# 检查是否在Kaggle环境中
IS_KAGGLE = os.path.exists('/kaggle')

# The dataset root depends on whether the /kaggle directory is present
BASE_DATA_DIR = (
    '/kaggle/input/ff-c23/FaceForensics++_C23'
    if IS_KAGGLE
    else './FaceForensics++_C23'
)

if not IS_KAGGLE:
    print("本地环境")
else:
    print("检测到Kaggle环境")
    print(f"数据基础路径: {BASE_DATA_DIR}")

# 内存友好的帧提取函数
def extract_frames_memory_efficient(video_path, max_frames=24, target_size=(160, 160), 
                                   quality_threshold=30, skip_frames=2):
    """
    Extract up to ``max_frames`` sharp, resized RGB frames from a video.

    Frames are sampled at a fixed stride, converted from BGR to RGB, and kept
    only when the variance of their Laplacian exceeds ``quality_threshold``.
    Kept frames are resized to ``target_size``. If at least one frame was kept
    but fewer than ``max_frames``, the last kept frame is repeated as padding.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        max_frames: Maximum (and, when any frame is kept, exact) number of
            frames returned.
        target_size: Size passed to ``cv2.resize``.
        quality_threshold: Minimum Laplacian variance for a frame to be kept.
        skip_frames: Sampling stride used when the video has no more than
            ``max_frames`` frames; values below 1 are treated as 1.

    Returns:
        A list of ``uint8`` RGB arrays. The list is empty when the video
        cannot be opened, reports no frames, or no frame passes the filter.
    """
    frames = []
    cap = cv2.VideoCapture(video_path)

    # try/finally guarantees the capture handle is released even if decoding
    # or colour conversion raises part-way through.
    try:
        if not cap.isOpened():
            print(f"无法打开视频: {video_path}")
            return frames

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Treat a non-positive frame count (or frame budget) as "nothing to do"
        # instead of building an empty/invalid sampling range below.
        if total_frames <= 0 or max_frames <= 0:
            return frames

        # Uniform sampling: short videos use the skip stride, long videos use
        # a stride that spreads max_frames samples over the whole clip.
        if total_frames <= max_frames:
            frame_indices = range(0, total_frames, max(1, skip_frames))
        else:
            step = max(1, total_frames // max_frames)
            frame_indices = range(0, total_frames, step)[:max_frames]

        for frame_idx in frame_indices:
            if len(frames) >= max_frames:
                break

            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, frame = cap.read()
            if not ret:
                continue

            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            # Cheap sharpness estimate: variance of the Laplacian
            gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            quality = cv2.Laplacian(gray, cv2.CV_64F).var()

            if quality > quality_threshold:
                frames.append(cv2.resize(frame, target_size))
    finally:
        cap.release()

    # Pad by repeating the last kept frame; an empty result stays empty
    if frames:
        last_frame = frames[-1]
        frames.extend(last_frame.copy() for _ in range(max_frames - len(frames)))

    return frames[:max_frames]

# 简化的视频处理函数
def _collect_video_records(source_dir, save_subdir, limit, max_frames, label,
                           category, method, name_prefix, found_template,
                           progress_desc, gc_interval):
    """Extract and save frames for up to ``limit`` videos of one folder; return their metadata rows."""
    if not os.path.exists(source_dir):
        return []

    # os.listdir order is arbitrary; sorting makes the seeded random.sample
    # below pick the same videos on every run.
    video_files = sorted(f for f in os.listdir(source_dir)
                         if f.endswith(('.mp4', '.avi', '.mov')))

    if len(video_files) > limit:
        video_files = random.sample(video_files, limit)

    print(found_template.format(count=len(video_files)))

    frame_save_dir = os.path.join('./data', 'frames', save_subdir)
    records = []

    for i, video_file in enumerate(tqdm(video_files, desc=progress_desc)):
        try:
            video_path = os.path.join(source_dir, video_file)
            frames = extract_frames_memory_efficient(video_path, max_frames)

            if len(frames) > 0:
                os.makedirs(frame_save_dir, exist_ok=True)

                video_name = os.path.splitext(video_file)[0]
                frame_save_path = os.path.join(
                    frame_save_dir, f"{name_prefix}{video_name}.npy"
                )

                # Persist the clip as a single uint8 array
                np.save(frame_save_path, np.array(frames, dtype=np.uint8))

                records.append({
                    'video_path': video_path,
                    'frame_path': frame_save_path,
                    'label': label,
                    'category': category,
                    'method': method,
                    'num_frames': len(frames),
                    'video_name': video_name
                })

            # Periodically force a garbage-collection pass to cap memory use
            if (i + 1) % gc_interval == 0:
                gc.collect()

        except Exception as e:
            # Best effort: one broken video must not abort the whole run
            print(f"处理视频 {video_file} 时出错: {e}")
            continue

    return records


def process_videos_simple(base_data_dir, max_videos_per_class=80, max_frames=24):
    """
    Extract frames for real and fake videos sequentially and describe them.

    Real videos are read from ``<base_data_dir>/original`` (at most
    ``max_videos_per_class``). Fake videos are read from one sub-folder per
    manipulation method, each limited to ``max_videos_per_class`` divided
    evenly across the methods. Frames of every usable video are saved as a
    ``.npy`` file under ``./data/frames/real`` or ``./data/frames/fake``.
    Missing folders are skipped silently.

    Args:
        base_data_dir: Dataset root containing ``original`` and the method folders.
        max_videos_per_class: Upper bound of videos for the real class and,
            in total, for the fake class.
        max_frames: Number of frames requested per video.

    Returns:
        A list of dicts with keys ``video_path``, ``frame_path``, ``label``
        (0 real / 1 fake), ``category``, ``method``, ``num_frames`` and
        ``video_name``.
    """
    data_list = []

    fake_methods = ['Deepfakes', 'Face2Face', 'FaceShifter', 'FaceSwap', 'NeuralTextures']

    print("开始处理真实视频...")
    data_list.extend(_collect_video_records(
        source_dir=os.path.join(base_data_dir, 'original'),
        save_subdir='real',
        limit=max_videos_per_class,
        max_frames=max_frames,
        label=0,
        category='real',
        method='original',
        name_prefix='',
        found_template="找到 {count} 个真实视频",
        progress_desc="处理真实视频",
        gc_interval=10,
    ))

    print("开始处理伪造视频...")
    # Split the fake-class budget evenly between the manipulation methods
    method_limit = max_videos_per_class // len(fake_methods)
    for method in fake_methods:
        data_list.extend(_collect_video_records(
            source_dir=os.path.join(base_data_dir, method),
            save_subdir='fake',
            limit=method_limit,
            max_frames=max_frames,
            label=1,
            category='fake',
            method=method,
            # Prefix with the method so equal file names cannot collide
            name_prefix=f"{method}_",
            found_template=f"处理 {method}: {{count}} 个视频",
            progress_desc=f"处理{method}",
            gc_interval=5,
        ))

    print(f"总共成功处理了 {len(data_list)} 个视频")
    print(f"真实视频: {len([d for d in data_list if d['label'] == 0])} 个")
    print(f"伪造视频: {len([d for d in data_list if d['label'] == 1])} 个")

    return data_list

# 简化的数据集划分
def create_simple_dataset_split(data_list, test_size=0.2, val_size=0.1):
    """
    Split the processed video records into train / validation / test frames.

    Both splits use ``random_state=42`` and are stratified on ``label`` only
    when the frame being split has more than 10 rows.

    Args:
        data_list: List of record dicts; each must contain a ``label`` key.
        test_size: Fraction of all samples reserved for the test set.
        val_size: Fraction of all samples reserved for validation. A value
            ``<= 0`` disables the validation split.

    Returns:
        ``(train_df, val_df, test_df)`` whenever ``val_size > 0``. This holds
        even for very small datasets, so callers can always unpack three
        values; ``val_df`` is empty when the training part has fewer than
        two rows. ``(train_df, test_df)`` when ``val_size <= 0``.
    """
    df = pd.DataFrame(data_list)

    train_df, test_df = train_test_split(
        df, test_size=test_size, random_state=42,
        stratify=df['label'] if len(df) > 10 else None
    )

    if val_size <= 0:
        return train_df, test_df

    if len(train_df) < 2:
        # Nothing left to split: hand back an empty validation frame with the
        # same columns rather than changing the shape of the return value.
        return train_df, train_df.iloc[0:0], test_df

    # val_size is relative to the full dataset, so rescale it to the share of
    # what remains after the test split.
    train_df, val_df = train_test_split(
        train_df, test_size=val_size / (1 - test_size), random_state=42,
        stratify=train_df['label'] if len(train_df) > 10 else None
    )
    return train_df, val_df, test_df

# 数据质量检查
def check_data_quality(data_list):
    """
    Print summary statistics for the processed records and sanity-check them.

    Args:
        data_list: List of record dicts with at least ``label`` and ``method``.

    Returns:
        ``False`` when the list is empty or one of the two classes is missing,
        ``True`` otherwise (a low class ratio only produces a warning).
    """
    if not data_list:
        print("❌ 没有有效的数据")
        return False

    df = pd.DataFrame(data_list)
    labels = df['label']
    real_count = int((labels == 0).sum())
    fake_count = int((labels == 1).sum())

    print("\n=== 数据统计 ===")
    print(f"总样本数: {len(df)}")
    print(f"真实视频: {real_count} 个")
    print(f"伪造视频: {fake_count} 个")

    print("\n各方法分布:")
    for method, count in df['method'].value_counts().items():
        print(f"  {method}: {count} 个")

    # A class with no samples makes the dataset unusable
    if 0 in (real_count, fake_count):
        print("⚠️ 数据严重不平衡，缺少某一类别")
        return False

    # Minority-to-majority ratio; below 0.3 is reported as imbalanced
    smaller, larger = sorted((real_count, fake_count))
    ratio = smaller / larger
    if ratio >= 0.3:
        print(f"✅ 数据平衡性良好，比例: {ratio:.2f}")
    else:
        print(f"⚠️ 数据不平衡，比例: {ratio:.2f}")

    return True

# 主处理流程
print("=== 开始数据预处理 ===")

# On Kaggle, list the dataset sub-directories and how many videos each holds
if IS_KAGGLE and os.path.exists(BASE_DATA_DIR):
    print("检查数据目录结构...")
    subdirs = [d for d in os.listdir(BASE_DATA_DIR) 
              if os.path.isdir(os.path.join(BASE_DATA_DIR, d))]
    print(f"找到子目录: {subdirs}")
    
    for subdir in subdirs[:6]:  # show only the first 6 to keep the output short
        subdir_path = os.path.join(BASE_DATA_DIR, subdir)
        try:
            video_files = [f for f in os.listdir(subdir_path) 
                          if f.endswith(('.mp4', '.avi', '.mov'))]
        except OSError:
            # Only filesystem errors mean "cannot access"; a bare except would
            # also swallow KeyboardInterrupt and genuine programming errors.
            print(f"  {subdir}: 无法访问")
        else:
            print(f"  {subdir}: {len(video_files)} 个视频文件")

# Reuse previously generated splits when all three CSV files already exist
if (os.path.exists('./data/train.csv') and 
    os.path.exists('./data/val.csv') and 
    os.path.exists('./data/test.csv')):
    
    print("✅ 发现已有预处理数据")
    train_df = pd.read_csv('./data/train.csv')
    val_df = pd.read_csv('./data/val.csv')
    test_df = pd.read_csv('./data/test.csv')
    
    print(f"训练集: {len(train_df)} 个样本")
    print(f"验证集: {len(val_df)} 个样本")
    print(f"测试集: {len(test_df)} 个样本")
else:
    print("开始处理视频数据...")
    
    # Single source of truth for the extraction settings, so the values
    # written to process_info.json always match what was actually used.
    max_videos_per_class = 200
    max_frames_per_video = 30
    
    try:
        # Extract and save frames for every selected video
        data_list = process_videos_simple(
            BASE_DATA_DIR, 
            max_videos_per_class=max_videos_per_class,
            max_frames=max_frames_per_video
        )
        
        # Abort the split when the data is empty or lacks one class
        if not check_data_quality(data_list):
            print("❌ 数据质量检查失败")
        else:
            print("\n创建数据集划分...")
            train_df, val_df, test_df = create_simple_dataset_split(
                data_list, test_size=0.15, val_size=0.15
            )
            
            print("保存数据集文件...")
            train_df.to_csv('./data/train.csv', index=False)
            val_df.to_csv('./data/val.csv', index=False)
            test_df.to_csv('./data/test.csv', index=False)
            
            # Record how the data was produced
            process_info = {
                'total_samples': len(data_list),
                'train_samples': len(train_df),
                'val_samples': len(val_df),
                'test_samples': len(test_df),
                'processed_time': pd.Timestamp.now().isoformat(),
                'max_frames': max_frames_per_video,
                # Matches the default target_size of the frame extractor,
                # which process_videos_simple does not override
                'target_size': [160, 160]
            }
            
            with open('./data/process_info.json', 'w', encoding='utf-8') as f:
                json.dump(process_info, f, indent=2)
            
            print(f"\n✅ 数据处理完成")
            print(f"训练集: {len(train_df)} 个样本")
            print(f"验证集: {len(val_df)} 个样本")
            print(f"测试集: {len(test_df)} 个样本")
            
            # Release the in-memory record list once everything is on disk
            del data_list
            gc.collect()
            
    except Exception as e:
        # Top-level boundary of the notebook cell: report and keep the kernel alive
        print(f"❌ 处理过程中出现错误: {e}")
        print("建议检查数据路径和可用内存")

## 3.数据集类定义 

In [None]:
# =============================================================================
# 第3段：优化的数据集类定义
# =============================================================================

import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
import albumentations as A
from albumentations.pytorch import ToTensorV2
import random
from collections import Counter

class AdvancedDeepfakeDataset(Dataset):
    """
    Deepfake-detection dataset built on pre-extracted frame arrays.

    Each CSV row must provide ``frame_path`` (a ``.npy`` file loadable with
    ``np.load``) and ``label``. Clips are handled as arrays indexed
    ``(frame, height, width, channel)``. In ``'train'`` mode temporal,
    spatial and CutMix augmentations may be applied; other modes only load
    and resample the clip to ``max_frames`` frames.
    """
    def __init__(self, csv_file, transform=None, max_frames=32, 
                 augment_prob=0.5, temporal_augment=True, 
                 mixup_alpha=0.2, cutmix_alpha=1.0, mode='train'):
        """
        Args:
            csv_file: Path of the split CSV.
            transform: Optional callable applied to every ``uint8`` frame; it
                must return a tensor. Without it frames are scaled to [0, 1].
            max_frames: Number of frames every returned clip contains.
            augment_prob: Probability of applying spatial augmentation to a clip.
            temporal_augment: Enables temporal augmentation in train mode.
            mixup_alpha: Beta parameter for ``_apply_mixup`` (``<= 0`` disables).
            cutmix_alpha: Beta parameter for ``_apply_cutmix`` (``<= 0`` disables).
            mode: ``'train'`` enables augmentation; any other value disables it.
        """
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.max_frames = max_frames
        self.augment_prob = augment_prob
        self.temporal_augment = temporal_augment
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.mode = mode
        
        # Inverse-frequency class weights, available to callers for loss weighting
        self.class_weights = self._calculate_class_weights()
        
        self._init_augmentations()
        
        print(f"数据集初始化完成: {len(self.data)} 个样本 ({mode} 模式)")
        print(f"真实视频: {len(self.data[self.data['label']==0])} 个")
        print(f"伪造视频: {len(self.data[self.data['label']==1])} 个")
    
    def _calculate_class_weights(self):
        """Return ``total / (num_classes * class_count)`` per class as a FloatTensor."""
        class_counts = self.data['label'].value_counts().sort_index()
        total_samples = len(self.data)
        weights = total_samples / (len(class_counts) * class_counts.values)
        return torch.FloatTensor(weights)
    
    def _init_augmentations(self):
        """Create the torchvision spatial transforms and the temporal augmentation probabilities."""
        # Spatial augmentations; one is picked at random per augmented frame
        self.spatial_transforms = [
            T.GaussianBlur(kernel_size=3, sigma=(0.1, 0.5)),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.RandomAdjustSharpness(sharpness_factor=2, p=0.3),
            T.RandomAutocontrast(p=0.2),
            T.RandomEqualize(p=0.2),
        ]
        
        # Probability of each temporal augmentation
        self.temporal_augs = {
            'frame_drop': 0.1,      # randomly drop frames
            'frame_repeat': 0.1,    # randomly duplicate one frame
            'temporal_shift': 0.2,  # cyclic shift along time
            'reverse': 0.05,        # play the clip backwards
        }
    
    def __len__(self):
        return len(self.data)
    
    def _load_and_preprocess_frames(self, frame_path):
        """
        Load a clip and resample it to exactly ``max_frames`` frames.

        Shorter clips are stretched with linear interpolation between
        neighbouring frames; longer clips are reduced with
        ``_select_key_frames``. On any failure a random ``uint8`` clip of
        shape ``(max_frames, 224, 224, 3)`` is returned so that training can
        continue.
        """
        try:
            frames = np.load(frame_path)
            
            if len(frames) < self.max_frames:
                # Stretch by interpolating at evenly spaced fractional positions
                indices = np.linspace(0, len(frames)-1, self.max_frames)
                new_frames = []
                for i in indices:
                    if i == int(i):
                        new_frames.append(frames[int(i)])
                    else:
                        i1, i2 = int(i), min(int(i)+1, len(frames)-1)
                        alpha = i - i1
                        frame = (1-alpha) * frames[i1] + alpha * frames[i2]
                        new_frames.append(frame.astype(np.uint8))
                frames = np.array(new_frames)
            elif len(frames) > self.max_frames:
                indices = self._select_key_frames(frames, self.max_frames)
                frames = frames[indices]
            
            return frames
        except Exception as e:
            print(f"加载帧数据失败 {frame_path}: {e}")
            # Fallback: random noise clip (note the fixed 224x224 size)
            return np.random.randint(0, 255, (self.max_frames, 224, 224, 3), dtype=np.uint8)
    
    def _select_key_frames(self, frames, target_count):
        """
        Pick ``target_count`` frame indices, favouring frames with large change.

        The first and last frame are always kept. The remaining slots go to
        the interior frames that differ most (mean absolute difference) from
        their predecessor. Only interior frames compete for those slots, so
        the last frame can never be selected twice and the result has exactly
        ``target_count`` indices whenever ``target_count >= 2``.

        Returns:
            ``np.arange(len(frames))`` when the clip is not longer than
            ``target_count``; otherwise a sorted list of int indices.
        """
        num_frames = len(frames)
        if num_frames <= target_count:
            return np.arange(num_frames)
        
        key_indices = {0, num_frames - 1}
        
        interior_needed = target_count - 2  # slots left after first and last frame
        if interior_needed > 0:
            # interior_diffs[j] scores interior frame j + 1
            interior_diffs = [
                np.mean(np.abs(frames[i].astype(float) - frames[i - 1].astype(float)))
                for i in range(1, num_frames - 1)
            ]
            top = np.argsort(interior_diffs)[-interior_needed:] + 1
            key_indices.update(int(i) for i in top)
        
        return sorted(key_indices)[:target_count]
    
    def _apply_temporal_augmentation(self, frames):
        """
        Randomly drop, repeat, shift or reverse frames (train mode only).

        The returned clip always has ``max_frames`` frames (unless it is
        empty) and is never a negative-stride view of the input.
        """
        if not self.temporal_augment or self.mode != 'train':
            return frames
        
        frames = frames.copy()
        
        # Random frame drop; skipped for clips too short to drop at least one
        # frame (random.randint(1, 0) would raise ValueError).
        if random.random() < self.temporal_augs['frame_drop']:
            max_drop = min(3, len(frames)//4)
            if max_drop >= 1:
                drop_count = random.randint(1, max_drop)
                drop_indices = random.sample(range(len(frames)), drop_count)
                for idx in sorted(drop_indices, reverse=True):
                    if len(frames) > self.max_frames // 2:  # never drop below half
                        frames = np.delete(frames, idx, axis=0)
        
        # Random duplication of one frame
        if random.random() < self.temporal_augs['frame_repeat']:
            repeat_idx = random.randint(0, len(frames)-1)
            frames = np.insert(frames, repeat_idx, frames[repeat_idx], axis=0)
        
        # Cyclic temporal shift
        if random.random() < self.temporal_augs['temporal_shift']:
            shift = random.randint(-2, 2)
            if shift != 0:
                frames = np.roll(frames, shift, axis=0)
        
        # Temporal reversal; copy so later tensor conversion never receives a
        # negative-stride view.
        if random.random() < self.temporal_augs['reverse']:
            frames = frames[::-1].copy()
        
        # Restore the expected length: pad with trailing frames (repeatedly if
        # the clip is shorter than the gap) or truncate.
        while 0 < len(frames) < self.max_frames:
            missing = self.max_frames - len(frames)
            frames = np.concatenate([frames, frames[-missing:]], axis=0)
        if len(frames) > self.max_frames:
            frames = frames[:self.max_frames]
        
        return frames
    
    def _apply_spatial_augmentation(self, frames):
        """
        Apply one random torchvision transform to a random subset of frames.

        The whole clip is left untouched outside train mode or with
        probability ``1 - augment_prob``; otherwise each frame is augmented
        with probability 0.7.
        """
        if self.mode != 'train' or random.random() > self.augment_prob:
            return frames
        
        augmented_frames = []
        for frame in frames:
            if random.random() < 0.7:
                # HWC array -> CHW tensor -> PIL image
                frame_pil = T.ToPILImage()(torch.tensor(frame).permute(2, 0, 1))
                
                if self.spatial_transforms:
                    aug = random.choice(self.spatial_transforms)
                    frame_pil = aug(frame_pil)
                
                frame = np.array(frame_pil)
            
            augmented_frames.append(frame)
        
        return np.array(augmented_frames)
    
    def _apply_mixup(self, frames1, label1, frames2, label2):
        """Blend two clips and their labels with a Beta(mixup_alpha, mixup_alpha) weight."""
        if self.mixup_alpha <= 0:
            return frames1, label1
        
        lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
        mixed_frames = lam * frames1 + (1 - lam) * frames2
        mixed_label = lam * label1 + (1 - lam) * label2
        
        return mixed_frames, mixed_label
    
    def _apply_cutmix(self, frames, label):
        """
        Paste a random rectangle from another sample into every frame.

        ``frames`` is modified in place. The label becomes the area-weighted
        mix of both labels. The input is returned unchanged when CutMix is
        disabled, outside train mode, or when the partner clip has a
        different shape (for example the load-failure fallback clip).
        """
        if self.cutmix_alpha <= 0 or self.mode != 'train':
            return frames, label
        
        # Pick a random partner sample
        mix_idx = random.randint(0, len(self.data) - 1)
        mix_row = self.data.iloc[mix_idx]
        mix_frames = self._load_and_preprocess_frames(mix_row['frame_path'])
        mix_label = mix_row['label']
        
        # The region copy below needs identical clip geometry
        if mix_frames.shape != frames.shape:
            return frames, label
        
        lam = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
        
        H, W = frames.shape[1], frames.shape[2]
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
        
        cx = np.random.randint(W)
        cy = np.random.randint(H)
        
        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)
        
        frames[:, bby1:bby2, bbx1:bbx2, :] = mix_frames[:, bby1:bby2, bbx1:bbx2, :]
        
        # Recompute lambda from the clipped box actually pasted
        lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (W * H))
        mixed_label = lam * label + (1 - lam) * mix_label
        
        return frames, mixed_label
    
    def __getitem__(self, idx):
        """
        Return ``(frames, label, metadata)`` for sample ``idx``.

        ``frames`` is the stacked output of ``transform`` per frame, or a
        float tensor permuted to ``(frame, channel, height, width)`` and
        scaled to [0, 1] when no transform is set. ``label`` is a float32
        scalar tensor (fractional after CutMix). ``metadata`` is a dict with
        ``video_name``, ``method`` and ``avg_quality`` taken from the row.
        """
        row = self.data.iloc[idx]
        
        frames = self._load_and_preprocess_frames(row['frame_path'])
        label = float(row['label'])
        
        frames = self._apply_temporal_augmentation(frames)
        frames = self._apply_spatial_augmentation(frames)
        
        # CutMix on 10% of the training samples
        if self.mode == 'train' and random.random() < 0.1:
            frames, label = self._apply_cutmix(frames, label)
        
        if self.transform:
            transformed_frames = []
            for frame in frames:
                # The transform expects uint8 input
                if frame.dtype != np.uint8:
                    frame = np.clip(frame, 0, 255).astype(np.uint8)
                transformed_frame = self.transform(frame)
                transformed_frames.append(transformed_frame)
            frames = torch.stack(transformed_frames)
        else:
            frames = torch.tensor(frames, dtype=torch.float32).permute(0, 3, 1, 2) / 255.0
        
        # Extra per-sample metadata; missing columns fall back to defaults
        metadata = {
            'video_name': row.get('video_name', ''),
            'method': row.get('method', ''),
            'avg_quality': row.get('avg_quality', 0.0)
        }
        
        return frames, torch.tensor(label, dtype=torch.float32), metadata

# 创建加权采样器
def create_weighted_sampler(dataset):
    """
    Build a ``WeightedRandomSampler`` that counteracts class imbalance.

    Every sample is weighted by the inverse of its class frequency, and the
    sampler draws ``len(dataset.data)`` samples with replacement.

    Args:
        dataset: Object exposing a ``data`` DataFrame with a ``label`` column.
    """
    # Read the label column once instead of materialising a row Series per sample
    labels = dataset.data['label'].tolist()
    class_counts = Counter(labels)
    
    weights = [1.0 / class_counts[label] for label in labels]
    
    return WeightedRandomSampler(weights, len(weights), replacement=True)

# 优化的数据变换
def get_optimized_transforms(mode='train', image_size=224):
    """
    Build the per-frame albumentations pipeline for the given mode.

    Every mode resizes to ``image_size`` x ``image_size``, normalises with the
    ImageNet mean/std and converts to a tensor. ``'train'`` additionally
    inserts flip, colour, blur, noise, affine and dropout augmentations
    between the resize and the normalisation.

    Returns:
        A callable taking one image array ``x`` and returning the
        transformed ``'image'`` entry of the pipeline output.
    """
    steps = [A.Resize(image_size, image_size)]

    if mode == 'train':
        # Blur variants: at most one is applied
        blur_choice = A.OneOf([
            A.MotionBlur(blur_limit=3, p=0.2),
            A.GaussianBlur(blur_limit=3, p=0.2),
            A.MedianBlur(blur_limit=3, p=0.2)
        ], p=0.2)
        # Noise variants: at most one is applied
        noise_choice = A.OneOf([
            A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
            A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.2),
        ], p=0.2)
        steps += [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=10, p=0.3),
            blur_choice,
            noise_choice,
            A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, p=0.3),
            A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.2),
        ]

    steps += [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    pipeline = A.Compose(steps)

    def apply_pipeline(x):
        return pipeline(image=x)['image']

    return apply_pipeline

# 创建数据加载器
def create_optimized_dataloaders(train_csv, val_csv, test_csv=None, 
                               batch_size=16, num_workers=4, 
                               max_frames=32, image_size=224):
    """
    Create the train / val (and optionally test) data loaders.

    The training loader draws samples through a class-balancing weighted
    sampler and drops the last incomplete batch; evaluation loaders iterate
    in order without augmentation.

    Returns:
        Dict with keys ``'train'`` and ``'val'``, plus ``'test'`` when
        ``test_csv`` is given and exists on disk.
    """
    train_transform = get_optimized_transforms('train', image_size)
    val_transform = get_optimized_transforms('val', image_size)

    # Settings shared by the augmentation-free evaluation datasets
    eval_dataset_kwargs = dict(
        transform=val_transform, max_frames=max_frames,
        augment_prob=0.0, temporal_augment=False,
    )
    # Settings shared by every loader
    common_loader_kwargs = dict(
        batch_size=batch_size, num_workers=num_workers, pin_memory=True,
    )

    train_dataset = AdvancedDeepfakeDataset(
        train_csv, transform=train_transform, max_frames=max_frames,
        augment_prob=0.6, temporal_augment=True, mode='train'
    )
    val_dataset = AdvancedDeepfakeDataset(val_csv, mode='val', **eval_dataset_kwargs)

    loaders = {
        'train': DataLoader(
            train_dataset, sampler=create_weighted_sampler(train_dataset),
            drop_last=True, **common_loader_kwargs
        ),
        'val': DataLoader(val_dataset, shuffle=False, **common_loader_kwargs),
    }

    # The test loader is optional
    if test_csv and os.path.exists(test_csv):
        test_dataset = AdvancedDeepfakeDataset(test_csv, mode='test', **eval_dataset_kwargs)
        loaders['test'] = DataLoader(test_dataset, shuffle=False, **common_loader_kwargs)

    return loaders

# Signal that the dataset cell finished defining its classes and helpers
print("✅ 数据集定义完成")

## 4.模型定义

In [None]:
# =============================================================================
# 第4段：模型定义
# =============================================================================

# 改进的CNN特征提取器
class ImprovedCNNFeatureExtractor(nn.Module):
    """
    CNN backbone without its classification head, followed by a 512-d projection.

    ``backbone`` selects ``'resnet50'`` (2048 features), ``'efficientnet'``
    (EfficientNet-B0, 1280 features) or, for any other value, ResNet-18
    (512 features). ``output_dim`` is always 512.
    """

    def __init__(self, pretrained=True, backbone='resnet50'):
        super().__init__()

        # backbone key -> (torchvision constructor name, feature width)
        known_backbones = {
            'resnet50': ('resnet50', 2048),
            'efficientnet': ('efficientnet_b0', 1280),
        }
        constructor_name, feature_width = known_backbones.get(backbone, ('resnet18', 512))
        self.backbone = getattr(models, constructor_name)(pretrained=pretrained)
        self.feature_dim = feature_width

        # Strip the classification head: models exposing `classifier` get an
        # identity in its place, the others lose their last child module.
        if not hasattr(self.backbone, 'classifier'):
            trunk_layers = list(self.backbone.children())[:-1]
            self.backbone = nn.Sequential(*trunk_layers)
        else:
            self.backbone.classifier = nn.Identity()

        # Project the backbone features down to 512 dimensions
        self.feature_reducer = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.output_dim = 512

    def forward(self, x):
        """Return one 512-d feature vector per input image."""
        pooled = self.backbone(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.feature_reducer(flat)

# 改进的注意力层
class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention over a sequence, pooled to one vector.

    ``forward`` takes ``(batch, seq_len, hidden_dim)`` and returns the
    attended sequence averaged over time, ``(batch, hidden_dim)``, together
    with the attention map averaged over heads, ``(batch, seq_len, seq_len)``.
    """

    def __init__(self, hidden_dim, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.head_dim = hidden_dim // num_heads

        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.output = nn.Linear(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(0.1)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq, hidden) into (batch, heads, seq, head_dim)."""
        per_head = projected.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        batch_size, seq_len, hidden_dim = x.size()

        # Project the input and split each projection into heads
        q = self._split_heads(self.query(x), batch_size, seq_len)
        k = self._split_heads(self.key(x), batch_size, seq_len)
        v = self._split_heads(self.value(x), batch_size, seq_len)

        # Scaled dot-product attention with dropout on the weights
        scale = self.head_dim ** 0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale
        attention_weights = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values, then merge the heads back together
        context = torch.matmul(attention_weights, v)
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_dim)

        projected = self.output(merged)

        # Average over time for the pooled vector, over heads for the map
        return torch.mean(projected, dim=1), attention_weights.mean(dim=1)

# 优化的深度伪造检测模型
class OptimizedDeepfakeDetector(nn.Module):
    """
    CNN + bidirectional LSTM + multi-head attention video classifier.

    ``forward`` expects ``(batch, seq_len, channels, height, width)`` and
    returns ``(probabilities, attention_weights)``. The probabilities already
    went through a sigmoid; with ``num_classes == 1`` they have shape
    ``(batch,)`` for every batch size, including 1.
    """

    def __init__(self, num_classes=1, hidden_dim=512, num_layers=3, dropout=0.3, backbone='resnet50'):
        """
        Args:
            num_classes: Width of the final linear layer.
            hidden_dim: LSTM hidden size (the bidirectional output is twice that).
            num_layers: Number of stacked LSTM layers.
            dropout: Base dropout rate; deeper classifier layers use
                ``dropout / 2`` and ``dropout / 4``.
            backbone: Backbone name forwarded to ``ImprovedCNNFeatureExtractor``.
        """
        super(OptimizedDeepfakeDetector, self).__init__()
        
        # Per-frame feature extractor
        self.cnn = ImprovedCNNFeatureExtractor(pretrained=True, backbone=backbone)
        
        # Bidirectional LSTM over the frame features
        self.lstm = nn.LSTM(
            input_size=self.cnn.output_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        
        # Attention pooling over the LSTM outputs
        self.attention = MultiHeadAttention(hidden_dim * 2, num_heads=8)
        
        # Classifier head. The reduced dropout rates use true division: floor
        # division (0.3 // 2) evaluates to 0.0 and would disable dropout.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout / 2),
            nn.Linear(hidden_dim // 2, hidden_dim // 4),
            nn.ReLU(),
            nn.Dropout(dropout / 4),
            nn.Linear(hidden_dim // 4, num_classes)
        )
        
        self._initialize_weights()
    
    def _initialize_weights(self):
        """Xavier-initialise every Linear and LSTM weight and zero their biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.xavier_uniform_(param)
                    elif 'bias' in name:
                        nn.init.constant_(param, 0)
    
    def forward(self, x):
        batch_size, seq_len, channels, height, width = x.size()
        
        # Fold time into the batch so the CNN sees individual frames
        x = x.view(batch_size * seq_len, channels, height, width)
        features = self.cnn(x)
        features = features.view(batch_size, seq_len, -1)
        
        lstm_out, _ = self.lstm(features)
        
        attended, attention_weights = self.attention(lstm_out)
        
        output = self.classifier(attended)
        
        # Squeeze only the class dimension: a bare squeeze() would also drop
        # the batch dimension when batch_size == 1.
        return torch.sigmoid(output.squeeze(-1)), attention_weights

# 集成模型
class EnsembleDeepfakeDetector(nn.Module):
    """Two detectors whose per-sample scores are fused by a small MLP.

    ``model1`` and ``model2`` are ``OptimizedDeepfakeDetector`` instances built
    with different backbones/hidden sizes; ``fusion`` maps the pair of scores
    to a single probability.

    NOTE(review): ``model2`` requests ``backbone='efficientnet'``; whether the
    detector class supports that name is not visible from here.
    """

    def __init__(self, num_classes=1, hidden_dim=512):
        super(EnsembleDeepfakeDetector, self).__init__()

        # Two differently configured base models.
        self.model1 = OptimizedDeepfakeDetector(num_classes, hidden_dim, backbone='resnet50')
        self.model2 = OptimizedDeepfakeDetector(num_classes, hidden_dim//2, backbone='efficientnet')

        # Fusion head: consumes exactly two scores per sample.
        self.fusion = nn.Sequential(
            nn.Linear(2, 16),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(16, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``(fused_probabilities, None)`` for a batch ``x``.

        The second element is always ``None`` so the return shape matches the
        ``(output, attention_weights)`` pair of the base models.
        """
        out1, _ = self.model1(x)
        out2, _ = self.model2(x)

        # Flatten each score vector to 1-D before stacking: a sub-model may hand
        # back a 0-d tensor for a single-sample batch, which cannot be stacked
        # along dim=1.
        combined = torch.stack([out1.reshape(-1), out2.reshape(-1)], dim=1)
        final_output = self.fusion(combined)

        # squeeze(-1) keeps the batch axis even when the batch holds one sample.
        return final_output.squeeze(-1), None

# 模型创建函数
def create_optimized_model(model_type='optimized', hidden_dim=512, backbone='resnet50'):
    """
    Build one of the detector models defined in this notebook.

    Args:
        model_type: 'ensemble' selects ``EnsembleDeepfakeDetector``; any other
            value selects ``OptimizedDeepfakeDetector``.
        hidden_dim: hidden size forwarded to the chosen model.
        backbone: CNN backbone name; only used for the non-ensemble model.

    Returns:
        The constructed ``nn.Module``.
    """
    if model_type != 'ensemble':
        # Single detector: three recurrent layers and dropout 0.3.
        return OptimizedDeepfakeDetector(
            num_classes=1,
            hidden_dim=hidden_dim,
            num_layers=3,
            dropout=0.3,
            backbone=backbone
        )

    return EnsembleDeepfakeDetector(num_classes=1, hidden_dim=hidden_dim)

# Progress marker printed once the model definitions of this cell have run.
print("✅ 模型定义完成")

## 5.训练和验证函数

In [None]:
# =============================================================================
# 第5段：训练和验证函数
# =============================================================================

# 改进的数据增强
def get_enhanced_transforms(is_training=True):
    """Return the per-frame torchvision pipeline for training or evaluation.

    Training: resize to 144x144, random 128x128 crop, horizontal flip, colour
    jitter, small rotation, tensor conversion, ImageNet normalisation and
    random erasing. Evaluation: resize to 128x128, tensor conversion and the
    same normalisation.
    """
    # Shared tail: tensor conversion followed by ImageNet mean/std normalisation.
    tensor_and_normalize = [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]

    if not is_training:
        return transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((128, 128)),
            *tensor_and_normalize,
        ])

    augmentations = [
        transforms.ToPILImage(),
        transforms.Resize((144, 144)),  # slightly larger so the crop has room
        transforms.RandomCrop((128, 128)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(degrees=5),
    ]
    return transforms.Compose([
        *augmentations,
        *tensor_and_normalize,
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),
    ])

# 焦点损失函数
class FocalLoss(nn.Module):
    """Focal loss computed on probabilities (not logits).

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - exp(-bce)) ** gamma`` and then reduced according to
    ``reduction`` ('mean', 'sum', or anything else for no reduction).
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (probabilities) against ``targets``."""
        per_element_bce = nn.functional.binary_cross_entropy(
            inputs, targets, reduction='none'
        )
        # exp(-bce) recovers the probability assigned to the true class.
        true_class_prob = torch.exp(-per_element_bce)
        weighted = self.alpha * (1 - true_class_prob) ** self.gamma * per_element_bce

        if self.reduction == 'mean':
            return weighted.mean()
        if self.reduction == 'sum':
            return weighted.sum()
        return weighted

# 改进的训练函数
def train_epoch(model, train_loader, criterion, optimizer, device, epoch, total_epochs):
    """Run one optimisation pass over ``train_loader``.

    The model is expected to return an ``(output, extra)`` pair; ``output`` is
    thresholded at 0.5 for the accuracy count. Gradients are clipped to a
    global norm of 1.0 before every optimizer step.

    Returns:
        ``(mean batch loss, accuracy in percent)``.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # tqdm progress bar labelled with the 1-based epoch number.
    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{total_epochs} [Train]')

    for step, (data, target) in enumerate(progress, start=1):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()

        # Forward pass and loss.
        output, _ = model(data)
        loss = criterion(output, target)

        # Backward pass with gradient clipping.
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Running statistics.
        loss_sum += loss.item()
        hits = ((output > 0.5).float() == target).sum().item()
        n_seen += target.size(0)
        n_correct += hits

        progress.set_postfix({
            'Loss': f'{loss_sum/step:.4f}',
            'Acc': f'{100.*n_correct/n_seen:.2f}%'
        })

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

# 改进的验证函数
def validate_epoch(model, val_loader, criterion, device, epoch, total_epochs):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    The model is expected to return an ``(output, extra)`` pair; ``output`` is
    thresholded at 0.5 for the accuracy count and used as the score for AUC.

    Returns:
        ``(mean batch loss, accuracy in percent, ROC-AUC)``. The AUC falls back
        to 0.0 when ``roc_auc_score`` rejects the inputs with ``ValueError``
        (for example when only one class is present).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_targets = []

    pbar = tqdm(val_loader, desc=f'Epoch {epoch+1}/{total_epochs} [Val]')

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(pbar):
            data, target = data.to(device), target.to(device)

            output, _ = model(data)
            loss = criterion(output, target)

            running_loss += loss.item()
            predicted = (output > 0.5).float()
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # Keep raw scores and labels for the epoch-level AUC.
            all_predictions.extend(output.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

            pbar.set_postfix({
                'Loss': f'{running_loss/(batch_idx+1):.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    epoch_loss = running_loss / len(val_loader)
    epoch_acc = 100. * correct / total

    # Catch only the metric's own input error; a bare except would also
    # swallow KeyboardInterrupt and hide unrelated bugs.
    try:
        auc_score = roc_auc_score(all_targets, all_predictions)
    except ValueError:
        auc_score = 0.0

    return epoch_loss, epoch_acc, auc_score

# 学习率调度器
class CosineAnnealingWarmRestarts(optim.lr_scheduler._LRScheduler):
    """Cosine-annealed learning rate with periodic warm restarts.

    The first cycle lasts ``T_0`` epochs; every following cycle is ``T_mult``
    times longer than the previous one. Inside a cycle the rate decays from the
    optimizer's base rate towards ``eta_min`` along half a cosine wave.

    Raises:
        ValueError: if ``T_0`` is not positive or ``T_mult`` is below 1.
    """

    def __init__(self, optimizer, T_0, T_mult=1, eta_min=0, last_epoch=-1):
        if T_0 <= 0:
            raise ValueError(f"T_0 must be positive, got {T_0}")
        if T_mult < 1:
            raise ValueError(f"T_mult must be >= 1, got {T_mult}")
        self.T_0 = T_0
        self.T_i = T_0          # length of the current cycle
        self.T_mult = T_mult
        self.eta_min = eta_min
        self.T_cur = last_epoch  # position inside the current cycle
        super(CosineAnnealingWarmRestarts, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for the current position."""
        return [self.eta_min + (base_lr - self.eta_min) *
                (1 + np.cos(np.pi * self.T_cur / self.T_i)) / 2
                for base_lr in self.base_lrs]

    def step(self, epoch=None):
        """Advance to ``epoch`` (default: the next epoch) and update the rates."""
        explicit_epoch = epoch is not None
        if not explicit_epoch:
            epoch = self.last_epoch + 1

        # Locate the epoch inside its cycle by peeling off completed cycles.
        # The previous code kept the absolute epoch in T_cur, so the cosine
        # phase was wrong after the first restart.
        position, cycle_length = epoch, self.T_0
        while position >= cycle_length:
            position -= cycle_length
            cycle_length *= self.T_mult
        self.T_cur, self.T_i = position, cycle_length

        if explicit_epoch:
            super(CosineAnnealingWarmRestarts, self).step(epoch)
        else:
            # Let the base class do the increment; this avoids its deprecated
            # explicit-epoch code path.
            super(CosineAnnealingWarmRestarts, self).step()

# 早停机制
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    A loss counts as an improvement when it is lower than the best loss so far
    by more than ``min_delta``. After ``patience`` calls without improvement
    the instance returns ``True`` and, if ``restore_best_weights`` is set,
    loads the best snapshot back into the model.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record ``val_loss``; return ``True`` when training should stop."""
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        """Store an independent snapshot of the model's current weights."""
        import copy

        # dict.copy() is shallow: the copied entries are still the live tensors
        # the optimizer keeps updating, so the snapshot would silently follow
        # the current weights. A deep copy freezes the values.
        self.best_weights = copy.deepcopy(model.state_dict())

# 优化的训练主函数
def train_optimized_model(model, train_loader, val_loader, num_epochs=50,
                         learning_rate=0.001, device='cuda', save_path='./models'):
    """
    Train ``model`` with focal loss, AdamW, cosine warm restarts and early stopping.

    Args:
        model: network returning an ``(output, extra)`` pair per batch.
        train_loader: iterable of ``(data, target)`` training batches.
        val_loader: iterable of ``(data, target)`` validation batches.
        num_epochs: maximum number of epochs.
        learning_rate: initial AdamW learning rate.
        device: device the model and the batches are moved to.
        save_path: directory receiving ``best_model.pth`` (created if missing).

    Returns:
        Dict with per-epoch lists under 'train_loss', 'train_acc', 'val_loss',
        'val_acc' and 'val_auc'.
    """
    # The checkpoint directory may not exist yet; torch.save does not create it.
    os.makedirs(save_path, exist_ok=True)

    # The epoch helpers move each batch to `device`, so the model must live
    # there as well. This is a no-op when the caller already moved it.
    model = model.to(device)

    # Focal loss
    criterion = FocalLoss(alpha=1, gamma=2)

    # AdamW optimizer
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.01)

    # Learning-rate schedule
    scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

    # Early stopping on validation loss
    early_stopping = EarlyStopping(patience=10, min_delta=0.001)

    # Training history
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [], 'val_auc': []
    }

    best_val_auc = 0.0

    print(f"开始训练，共 {num_epochs} 个epoch")
    print(f"设备: {device}")
    print(f"模型参数数量: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    for epoch in range(num_epochs):
        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device, epoch, num_epochs
        )

        # Validate
        val_loss, val_acc, val_auc = validate_epoch(
            model, val_loader, criterion, device, epoch, num_epochs
        )

        # Advance the learning-rate schedule once per epoch.
        scheduler.step()

        # Record history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['val_auc'].append(val_auc)

        # Epoch summary
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, Val AUC: {val_auc:.4f}')
        print(f'  Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        # Checkpoint whenever the validation AUC improves.
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_auc': val_auc,
                'history': history
            }, f'{save_path}/best_model.pth')
            print(f'  ✅ 保存最佳模型 (AUC: {val_auc:.4f})')

        # Early-stopping check (tracks validation loss, not AUC).
        if early_stopping(val_loss, model):
            print(f'早停触发，在第 {epoch+1} 个epoch停止训练')
            break

        print('-' * 60)

    return history

# Progress marker printed once the training/validation helpers of this cell have run.
print("✅ 优化后的训练和验证函数定义完成")

##  6.模型训练

In [None]:
# =============================================================================
# 第6段：模型训练
# =============================================================================

import time
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
from torch.cuda.amp import GradScaler, autocast

# 焦点损失函数 - 解决类别不平衡问题
class FocalLoss(nn.Module):
    """Focal loss on probabilities, safe to call inside an autocast region.

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - exp(-bce)) ** gamma`` and then reduced according to
    ``reduction`` ('mean', 'sum', or anything else for no reduction).
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (probabilities) against ``targets``."""
        # PyTorch refuses binary_cross_entropy/BCELoss inside CUDA autocast
        # regions, and the training loop of this cell calls the criterion under
        # autocast. Compute the loss with autocast switched off.
        device_type = 'cuda' if inputs.is_cuda else 'cpu'
        with torch.autocast(device_type=device_type, enabled=False):
            # Reduced-precision scores are promoted to float32 first.
            probs = inputs.float() if inputs.dtype in (torch.float16, torch.bfloat16) else inputs
            # Labels follow the score dtype so integer/bool labels are accepted.
            labels = targets.to(probs.dtype)

            ce_loss = nn.functional.binary_cross_entropy(probs, labels, reduction='none')
            # exp(-bce) recovers the probability assigned to the true class.
            pt = torch.exp(-ce_loss)
            focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# 早停机制
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    A loss counts as an improvement when it is lower than the best loss so far
    by more than ``min_delta``. After ``patience`` calls without improvement
    the instance returns ``True`` and, if ``restore_best_weights`` is set,
    loads the best snapshot back into the model.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record ``val_loss``; return ``True`` when training should stop."""
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        """Store an independent snapshot of the model's current weights."""
        import copy

        # dict.copy() is shallow: the copied entries are still the live tensors
        # the optimizer keeps updating, so the snapshot would silently follow
        # the current weights. A deep copy freezes the values.
        self.best_weights = copy.deepcopy(model.state_dict())

# Per-frame preprocessing pipelines for this cell.
# Training: augmentation followed by ImageNet mean/std normalisation.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((144, 144)),  # slightly larger, leaves room for the crop
    transforms.RandomCrop((128, 128)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.1, scale=(0.02, 0.1))  # random erasing
])

# Validation: deterministic resize plus the same normalisation, no augmentation.
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 优化的训练函数
def train_epoch(model, train_loader, criterion, optimizer, device, scaler=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network returning an ``(output, extra)`` pair per batch.
        train_loader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        scaler: optional ``GradScaler``; when given, the forward pass runs
            under ``autocast`` and the scaler drives backward/step.

    Returns:
        ``(mean batch loss, accuracy in percent, ROC-AUC)``. The AUC falls back
        to 0.0 when ``roc_auc_score`` rejects the inputs with ``ValueError``.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    pbar = tqdm(train_loader, desc='Training', leave=False)

    for batch_idx, (data, target) in enumerate(pbar):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        if scaler is not None:
            # Mixed-precision path.
            with autocast():
                output, _ = model(data)
                loss = criterion(output, target)

            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            output, _ = model(data)
            loss = criterion(output, target)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        total_loss += loss.item()

        # Accuracy at a 0.5 threshold.
        predicted = (output > 0.5).float()
        total += target.size(0)
        correct += (predicted == target).sum().item()

        # Keep raw scores and labels for the epoch-level AUC.
        all_preds.extend(output.detach().cpu().numpy())
        all_targets.extend(target.detach().cpu().numpy())

        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })

    avg_loss = total_loss / len(train_loader)
    accuracy = 100. * correct / total

    # Catch only the metric's own input error; a bare except would also
    # swallow KeyboardInterrupt and hide unrelated bugs.
    try:
        auc_score = roc_auc_score(all_targets, all_preds)
    except ValueError:
        auc_score = 0.0

    return avg_loss, accuracy, auc_score

# 优化的验证函数
def validate_epoch(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    The model is expected to return an ``(output, extra)`` pair; ``output`` is
    thresholded at 0.5 for the accuracy count and used as the score for AUC.

    Returns:
        ``(mean batch loss, accuracy in percent, ROC-AUC)``. The AUC falls back
        to 0.0 when ``roc_auc_score`` rejects the inputs with ``ValueError``.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        pbar = tqdm(val_loader, desc='Validation', leave=False)

        for data, target in pbar:
            data, target = data.to(device), target.to(device)

            output, _ = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()

            # Accuracy at a 0.5 threshold.
            predicted = (output > 0.5).float()
            total += target.size(0)
            correct += (predicted == target).sum().item()

            # Keep raw scores and labels for the epoch-level AUC.
            all_preds.extend(output.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

            pbar.set_postfix({
                'Loss': f'{loss.item():.4f}',
                'Acc': f'{100.*correct/total:.2f}%'
            })

    avg_loss = total_loss / len(val_loader)
    accuracy = 100. * correct / total

    # Catch only the metric's own input error; a bare except would also
    # swallow KeyboardInterrupt and hide unrelated bugs.
    try:
        auc_score = roc_auc_score(all_targets, all_preds)
    except ValueError:
        auc_score = 0.0

    return avg_loss, accuracy, auc_score

# Training configuration for this cell.
TRAIN_CONFIG = {
    'batch_size': 8,  # increased batch size
    'learning_rate': 1e-4,  # lowered learning rate
    'num_epochs': 50,  # more training epochs
    'weight_decay': 1e-4,
    'patience': 10,  # early-stopping patience
    'use_focal_loss': True,
    'use_mixed_precision': True,
    'gradient_clip': 1.0
}

# Progress marker plus a dump of the configuration defined above.
print("✅ 优化的模型训练函数定义完成")
print(f"训练配置: {TRAIN_CONFIG}")

## 7.执行训练循环

In [None]:
# =============================================================================
# 第7段：执行训练循环
# =============================================================================

# 添加必要的导入
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import cv2
from sklearn.metrics import roc_auc_score, accuracy_score
from torchvision import transforms
import random
from collections import Counter

# 直接定义DeepfakeVideoDataset类（适用于Kaggle环境）
class DeepfakeVideoDataset(Dataset):
    """Deepfake-detection dataset driven by a CSV index (Kaggle-compatible).

    Each CSV row names the frames of one clip, either as a Python-literal list
    in a ``frames`` column or as a single path in a ``frame_path`` column, and
    optionally carries a ``label`` column. Every item is a
    ``(frames, label)`` pair where ``frames`` stacks exactly ``max_frames``
    per-frame tensors.

    Frames that cannot be read are replaced by random noise images, and short
    clips are padded by repeating the last frame (or a black frame when the
    clip has none).
    """

    def __init__(self, csv_file, transform=None, max_frames=30, mode='train'):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.max_frames = max_frames
        self.mode = mode

        print(f"数据集初始化完成: {len(self.data)} 个样本 ({mode} 模式)")
        if 'label' in self.data.columns:
            print(f"真实视频: {len(self.data[self.data['label']==0])} 个")
            print(f"伪造视频: {len(self.data[self.data['label']==1])} 个")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _parse_frame_paths(row):
        """Return the list of frame paths referenced by one CSV row."""
        import ast

        if 'frames' in row:
            raw = row['frames']
            try:
                # literal_eval only accepts Python literals. The CSV is data,
                # not code, so it must never be handed to eval().
                parsed = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError):
                # Not a literal (e.g. a bare path): treat the cell as one path.
                return [raw]
            if isinstance(parsed, (list, tuple)):
                return list(parsed)
            if isinstance(parsed, str):
                return [parsed]
            return [raw]
        if 'frame_path' in row:
            return [row['frame_path']]
        # No path column at all: the clip is padded with black frames later.
        return []

    @staticmethod
    def _load_frame(frame_path):
        """Read one frame as a 224x224 RGB uint8 array, or random noise on failure."""
        try:
            if os.path.exists(frame_path):
                frame = cv2.imread(frame_path)
                if frame is not None:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    return cv2.resize(frame, (224, 224))
        except Exception:
            # Deliberate best effort: an unreadable frame must not abort a
            # whole epoch, so fall through to the noise placeholder.
            pass
        return np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)

    def _to_tensor(self, frame):
        """Apply ``self.transform`` or the default HWC-uint8 to CHW-float conversion."""
        if self.transform:
            return self.transform(frame)
        return torch.tensor(frame, dtype=torch.float32).permute(2, 0, 1) / 255.0

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        frame_paths = self._parse_frame_paths(row)
        label = float(row['label']) if 'label' in row else 0.0

        frames = [
            self._to_tensor(self._load_frame(frame_path))
            for frame_path in frame_paths[:self.max_frames]
        ]

        # Pad short clips: repeat the last frame, or start from a black frame
        # when no frame could be listed at all.
        while len(frames) < self.max_frames:
            if frames:
                frames.append(frames[-1].clone())
            elif self.transform:
                frames.append(self.transform(np.zeros((224, 224, 3), dtype=np.uint8)))
            else:
                frames.append(torch.zeros(3, 224, 224))

        frames_tensor = torch.stack(frames[:self.max_frames])

        return frames_tensor, torch.tensor(label, dtype=torch.float32)

# 定义简化的模型类（如果需要）
class OptimizedDeepfakeDetector(nn.Module):
    """Per-frame CNN followed by an LSTM, self-attention and a sigmoid classifier.

    Args:
        backbone: 'resnet50' uses a pretrained torchvision ResNet-50 with its
            final layer replaced by identity (2048 features); any other value
            selects a small three-convolution extractor (256 features).
        hidden_dim: LSTM hidden size and attention embedding size.
        num_layers: number of stacked LSTM layers.
        dropout: dropout used in the LSTM (between layers), the attention
            layer and the classifier.
        num_heads: number of attention heads.
    """

    def __init__(self, backbone='resnet50', hidden_dim=512, num_layers=2, dropout=0.3, num_heads=8):
        super().__init__()

        if backbone == 'resnet50':
            from torchvision.models import resnet50
            self.backbone = resnet50(pretrained=True)
            self.backbone.fc = nn.Identity()  # drop the ImageNet classification layer
            feature_dim = 2048
        else:
            # Lightweight CNN feature extractor.
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 64, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(64, 128, 3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
                nn.Conv2d(128, 256, 3, padding=1),
                nn.ReLU(),
                nn.AdaptiveAvgPool2d((1, 1)),
                nn.Flatten()
            )
            feature_dim = 256

        # nn.LSTM applies dropout only between stacked layers and warns when a
        # non-zero value is combined with a single layer.
        self.lstm = nn.LSTM(
            feature_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
        )

        # Self-attention over the LSTM outputs.
        self.attention = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout, batch_first=True)

        # Classification head ending in a sigmoid, so outputs are probabilities.
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, seq_len, c, h, w) to probabilities of shape (batch,)."""
        batch_size, seq_len, c, h, w = x.shape

        # Fold time into the batch axis for the CNN. reshape (rather than view)
        # also accepts non-contiguous input.
        x = x.reshape(-1, c, h, w)
        features = self.backbone(x)
        features = features.reshape(batch_size, seq_len, -1)

        # Temporal modelling.
        lstm_out, _ = self.lstm(features)

        # Self-attention: query, key and value are all the LSTM outputs.
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)

        # Average over the time axis.
        pooled = torch.mean(attn_out, dim=1)

        output = self.classifier(pooled)

        return output.squeeze(-1)

# 定义损失函数
class FocalLoss(nn.Module):
    """Mean focal loss on probabilities, safe to call inside an autocast region.

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - exp(-bce)) ** gamma``; the mean over all elements is returned.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss of ``inputs`` (probabilities) against ``targets``."""
        # PyTorch refuses binary_cross_entropy inside CUDA autocast regions,
        # and the training loop of this cell calls the criterion under
        # autocast. Compute the loss with autocast switched off.
        device_type = 'cuda' if inputs.is_cuda else 'cpu'
        with torch.autocast(device_type=device_type, enabled=False):
            # Reduced-precision scores are promoted to float32 first.
            probs = inputs.float() if inputs.dtype in (torch.float16, torch.bfloat16) else inputs
            # Labels follow the score dtype so integer/bool labels are accepted.
            labels = targets.to(probs.dtype)

            bce_loss = nn.functional.binary_cross_entropy(probs, labels, reduction='none')
            # exp(-bce) recovers the probability assigned to the true class.
            pt = torch.exp(-bce_loss)
            focal_loss = self.alpha * (1-pt)**self.gamma * bce_loss
        return focal_loss.mean()

# 早停机制
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    A loss counts as an improvement when it is lower than the best loss so far
    by more than ``min_delta``. After ``patience`` calls without improvement
    the instance returns ``True`` and, if ``restore_best_weights`` is set,
    loads the best snapshot back into the model.
    """

    def __init__(self, patience=7, min_delta=0, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record ``val_loss``; return ``True`` when training should stop."""
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        """Store an independent snapshot of the model's current weights."""
        import copy

        # dict.copy() is shallow: the copied entries are still the live tensors
        # the optimizer keeps updating, so the snapshot would silently follow
        # the current weights. A deep copy freezes the values.
        self.best_weights = copy.deepcopy(model.state_dict())

# 训练函数
def train_epoch(model, dataloader, criterion, optimizer, device, scaler=None):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network returning one score tensor per batch.
        dataloader: iterable of ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``.
        optimizer: optimizer stepped once per batch.
        device: device the batches are moved to.
        scaler: optional ``GradScaler``; when truthy, the forward pass runs
            under ``torch.cuda.amp.autocast`` and the scaler drives
            backward/step.

    Returns:
        ``(mean batch loss, accuracy in percent, ROC-AUC)``. The AUC falls back
        to 0.0 when ``roc_auc_score`` rejects the inputs with ``ValueError``
        (for example when only one class is present).
    """
    model.train()
    total_loss = 0
    all_preds = []
    all_targets = []

    for batch_idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        if scaler:
            with torch.cuda.amp.autocast():
                output = model(data)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

        total_loss += loss.item()
        all_preds.extend(output.detach().cpu().numpy())
        all_targets.extend(target.detach().cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(np.array(all_targets) > 0.5, np.array(all_preds) > 0.5) * 100

    # An undefined AUC must not abort training; the other epoch helpers in this
    # notebook report 0.0 in that case.
    try:
        auc = roc_auc_score(all_targets, all_preds)
    except ValueError:
        auc = 0.0

    return avg_loss, accuracy, auc

# 验证函数
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` with gradients disabled.

    The model is expected to return one score tensor per batch.

    Returns:
        ``(mean batch loss, accuracy in percent, ROC-AUC)``. The AUC falls back
        to 0.0 when ``roc_auc_score`` rejects the inputs with ``ValueError``
        (for example when only one class is present).
    """
    model.eval()
    total_loss = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)

            total_loss += loss.item()
            all_preds.extend(output.cpu().numpy())
            all_targets.extend(target.cpu().numpy())

    avg_loss = total_loss / len(dataloader)
    accuracy = accuracy_score(np.array(all_targets) > 0.5, np.array(all_preds) > 0.5) * 100

    # An undefined AUC must not abort the run; the other epoch helpers in this
    # notebook report 0.0 in that case.
    try:
        auc = roc_auc_score(all_targets, all_preds)
    except ValueError:
        auc = 0.0

    return avg_loss, accuracy, auc

# Training configuration (keeps the original hyper-parameters).
TRAIN_CONFIG = {
    'batch_size': 16,
    'learning_rate': 1e-4,
    'weight_decay': 1e-4,
    'num_epochs': 50,
    'patience': 10,
    'use_focal_loss': True,
    'use_mixed_precision': True
}

# Device selection: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# Data transforms (keeps the original configuration).
# Training: 224x224 resize, random flip/rotation/colour jitter, ImageNet normalisation.
train_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Validation: deterministic resize plus the same normalisation.
val_transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create the output directories used by this notebook.
os.makedirs('./models', exist_ok=True)
os.makedirs('./logs', exist_ok=True)
os.makedirs('./results', exist_ok=True)

# Build the datasets from the CSV indexes.
print("创建数据集和数据加载器...")

# The training set uses the augmenting pipeline, the validation set the plain one.
train_dataset = DeepfakeVideoDataset('./data/train.csv', transform=train_transform, max_frames=30)
val_dataset = DeepfakeVideoDataset('./data/val.csv', transform=val_transform, max_frames=30)

# ... existing code ...

## 8.模型评估

In [None]:
# =============================================================================
# 第8段：模型评估
# =============================================================================

import time
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, confusion_matrix, classification_report,
    roc_curve, auc, precision_recall_curve, average_precision_score
)
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

# 优化的评估函数
def evaluate_model_optimized(model, test_loader, criterion, device, save_attention=True):
    """
    Run ``model`` over ``test_loader`` and collect predictions, scores and timings.

    The model may return either a score tensor or an ``(scores, attention)``
    pair; both forms are accepted.

    Args:
        model: network to evaluate.
        test_loader: iterable of ``(inputs, labels)`` batches exposing ``.dataset``.
        criterion: loss called as ``criterion(scores, labels)``.
        device: device the batches are moved to.
        save_attention: keep the attention weights of each batch when the
            model provides them.

    Returns:
        Dict with 'loss' (sample-weighted mean), 'predictions' (scores
        thresholded at 0.5), 'targets', 'scores', 'attention_weights' (one
        array per batch), 'avg_inference_time' and 'total_inference_time'
        (wall-clock seconds of the forward calls).
    """
    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_targets = []
    all_scores = []
    all_attention_weights = []
    inference_times = []

    with torch.no_grad():
        progress_bar = tqdm(test_loader, desc='模型评估中')

        for inputs, labels in progress_bar:
            inputs = inputs.to(device)
            labels = labels.float().to(device)

            # Time only the forward call.
            start_time = time.time()
            result = model(inputs)
            inference_times.append(time.time() - start_time)

            # Decide from the returned value, not from forward()'s signature,
            # whether the model also produced attention weights.
            if isinstance(result, (tuple, list)):
                outputs = result[0]
                attention_weights = result[1] if len(result) > 1 else None
            else:
                outputs = result
                attention_weights = None

            # Some models return None in place of attention weights.
            if save_attention and attention_weights is not None:
                all_attention_weights.append(attention_weights.cpu().numpy())

            # Flatten to one score/label per sample. A bare squeeze() would
            # turn a single-sample batch into a 0-d tensor that cannot be
            # extended into the result lists.
            outputs = outputs.reshape(-1)
            labels = labels.reshape(-1)

            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)

            preds = (outputs > 0.5).float().cpu().numpy()
            all_predictions.extend(preds)
            all_targets.extend(labels.cpu().numpy())
            all_scores.extend(outputs.cpu().numpy())

            progress_bar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'avg_time': f'{np.mean(inference_times):.3f}s'
            })

    # Guard the averages so an empty loader yields zeros instead of NaN or a
    # division error.
    n_samples = len(test_loader.dataset)
    test_loss = running_loss / n_samples if n_samples else 0.0
    avg_inference_time = np.mean(inference_times) if inference_times else 0.0

    return {
        'loss': test_loss,
        'predictions': all_predictions,
        'targets': all_targets,
        'scores': all_scores,
        'attention_weights': all_attention_weights,
        'avg_inference_time': avg_inference_time,
        'total_inference_time': sum(inference_times)
    }

# 计算全面的评估指标
def calculate_comprehensive_metrics(predictions, targets, scores):
    """
    Compute and print binary-classification metrics.

    Args:
        predictions: hard 0/1 predictions, one per sample.
        targets: ground-truth 0/1 labels, one per sample.
        scores: predicted probabilities of class 1, one per sample.

    Returns:
        Dict with 'accuracy', 'balanced_accuracy', 'precision', 'recall',
        'specificity', 'npv', 'f1', 'auc_roc', 'auc_pr', 'confusion_matrix'
        and the individual 'tn', 'fp', 'fn', 'tp' counts. Both AUC values are
        0.0 when the targets contain a single class.
    """
    # Labels usually arrive as float32 scalars; fix them to plain ints so the
    # explicit label set [0, 1] below matches.
    y_true = np.rint(np.asarray(targets, dtype=float)).astype(int)
    y_pred = np.rint(np.asarray(predictions, dtype=float)).astype(int)
    y_score = np.asarray(scores, dtype=float)

    # Basic metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, zero_division=0)
    recall = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)

    # ROC / PR summary metrics are undefined when only one class is present.
    if np.unique(y_true).size > 1:
        auc_roc = roc_auc_score(y_true, y_score)
        auc_pr = average_precision_score(y_true, y_score)
    else:
        auc_roc = 0.0
        auc_pr = 0.0

    # Pin the label set so the matrix is always 2x2; otherwise single-class
    # data produces a 1x1 matrix and the four counts cannot be recovered.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Derived metrics
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # negative predictive value
    balanced_accuracy = (recall + specificity) / 2

    # Detailed results
    print("=" * 60)
    print("📊 模型评估结果")
    print("=" * 60)
    print(f"🎯 准确率 (Accuracy): {accuracy:.4f}")
    print(f"🎯 平衡准确率 (Balanced Accuracy): {balanced_accuracy:.4f}")
    print(f"🔍 精确率 (Precision): {precision:.4f}")
    print(f"🔍 召回率 (Recall/Sensitivity): {recall:.4f}")
    print(f"🔍 特异性 (Specificity): {specificity:.4f}")
    print(f"🔍 负预测值 (NPV): {npv:.4f}")
    print(f"⚖️ F1分数: {f1:.4f}")
    print(f"📈 AUC-ROC: {auc_roc:.4f}")
    print(f"📈 AUC-PR: {auc_pr:.4f}")
    print("=" * 60)

    # Confusion-matrix breakdown
    print("\n📋 混淆矩阵详情:")
    print(f"真阴性 (TN): {tn}")
    print(f"假阳性 (FP): {fp}")
    print(f"假阴性 (FN): {fn}")
    print(f"真阳性 (TP): {tp}")

    # Per-class report; labels are pinned so the two target names always match.
    print("\n📊 详细分类报告:")
    print(classification_report(y_true, y_pred,
                             labels=[0, 1],
                             target_names=['真实视频', '伪造视频'],
                             digits=4,
                             zero_division=0))

    return {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'npv': npv,
        'f1': f1,
        'auc_roc': auc_roc,
        'auc_pr': auc_pr,
        'confusion_matrix': cm,
        'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp
    }

# 增强的可视化函数
def plot_enhanced_confusion_matrix(cm, save_path=None):
    """
    Draw a confusion-matrix heat map annotated with counts and row percentages.

    Args:
        cm: square confusion matrix (rows: true labels, columns: predictions).
        save_path: optional file path; when given, the figure is also saved there.
    """
    plt.figure(figsize=(10, 8))

    # Row-wise percentages. A class without any sample has a zero row sum;
    # report 0% for it instead of dividing by zero and showing NaN.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    cm_percent = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    # One "count\n(percent%)" label per cell.
    annotations = np.array([
        f'{cm[i,j]}\n({cm_percent[i,j]:.1f}%)'
        for i in range(cm.shape[0])
        for j in range(cm.shape[1])
    ]).reshape(cm.shape)

    # Heat map
    sns.heatmap(cm, annot=annotations, fmt='', cmap='Blues',
                cbar_kws={'label': '样本数量'},
                xticklabels=['真实视频', '伪造视频'],
                yticklabels=['真实视频', '伪造视频'])

    plt.xlabel('预测标签', fontsize=12, fontweight='bold')
    plt.ylabel('真实标签', fontsize=12, fontweight='bold')
    plt.title('混淆矩阵 (数量和百分比)', fontsize=14, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ 混淆矩阵已保存到: {save_path}")

    plt.show()

# 绘制ROC和PR曲线
def plot_roc_pr_curves(targets, scores, save_path=None):
    """Plot the ROC curve and the precision-recall curve in one figure.

    Args:
        targets: Ground-truth labels passed to the sklearn curve functions.
        scores: Predicted scores passed to the sklearn curve functions.
        save_path: Optional output path; when given, the figure is also
            written there at 300 dpi.
    """
    _, (roc_ax, pr_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: ROC curve with the chance diagonal as reference.
    false_pos_rate, true_pos_rate, _ = roc_curve(targets, scores)
    roc_area = auc(false_pos_rate, true_pos_rate)

    roc_ax.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
                label=f'ROC曲线 (AUC = {roc_area:.4f})')
    roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.8)
    roc_ax.set_xlim([0.0, 1.0])
    roc_ax.set_ylim([0.0, 1.05])
    roc_ax.set_xlabel('假阳性率 (FPR)', fontweight='bold')
    roc_ax.set_ylabel('真阳性率 (TPR)', fontweight='bold')
    roc_ax.set_title('ROC曲线', fontsize=14, fontweight='bold')
    roc_ax.legend(loc='lower right')
    roc_ax.grid(True, linestyle='--', alpha=0.7)

    # Right panel: precision-recall curve; the legend reports average precision.
    pr_precision, pr_recall, _ = precision_recall_curve(targets, scores)
    avg_precision = average_precision_score(targets, scores)
    # Horizontal baseline at the mean of the targets.
    positive_rate = np.mean(targets)

    pr_ax.plot(pr_recall, pr_precision, color='darkgreen', lw=2,
               label=f'PR曲线 (AUC = {avg_precision:.4f})')
    pr_ax.axhline(y=positive_rate, color='navy', linestyle='--', alpha=0.8,
                  label=f'基线 ({positive_rate:.3f})')
    pr_ax.set_xlim([0.0, 1.0])
    pr_ax.set_ylim([0.0, 1.05])
    pr_ax.set_xlabel('召回率 (Recall)', fontweight='bold')
    pr_ax.set_ylabel('精确率 (Precision)', fontweight='bold')
    pr_ax.set_title('精确率-召回率曲线', fontsize=14, fontweight='bold')
    pr_ax.legend(loc='lower left')
    pr_ax.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ ROC和PR曲线已保存到: {save_path}")

    plt.show()

# 注意力权重可视化
def _frame_attention_profile(weights):
    """Reduce one sample's attention weights to a 1-D float array."""
    if hasattr(weights, 'detach'):
        # Tensor-like objects exposing detach() are moved to host memory
        # first so NumPy can read them.
        weights = weights.detach().cpu().numpy()
    profile = np.squeeze(np.asarray(weights, dtype=float))
    if profile.ndim > 1:
        # Average over every leading axis and keep the last one.  For 2-D
        # input this equals the previous mean(axis=0), which the original
        # comment described as averaging attention heads (TODO confirm the
        # axis layout with the producer of these weights).
        profile = profile.reshape(-1, profile.shape[-1]).mean(axis=0)
    # squeeze() turns a single-element input into a 0-d array; bar() and
    # len() need at least one dimension.
    return np.atleast_1d(profile)


def visualize_attention_weights(attention_weights, save_path=None, num_samples=5):
    """Plot per-sample attention weights as bar charts, one row per sample.

    Args:
        attention_weights: Sequence (list or array) holding one weight array
            per sample.  ``None`` or an empty sequence prints a notice and
            returns without plotting.
        save_path: Optional output path; when given, the figure is also
            written there at 300 dpi.
        num_samples: Maximum number of leading samples to draw.
    """
    # len() instead of truthiness: `not array` is ambiguous for NumPy arrays.
    if attention_weights is None or len(attention_weights) == 0:
        print("⚠️ 没有注意力权重数据")
        return

    # Only the first few samples are drawn.
    num_samples = min(num_samples, len(attention_weights))

    # squeeze=False always returns a 2-D axes array, even for a single row.
    fig, axes = plt.subplots(num_samples, 1, figsize=(12, 2*num_samples),
                             squeeze=False)

    for i, ax in enumerate(axes[:, 0]):
        weights = _frame_attention_profile(attention_weights[i])

        ax.bar(range(len(weights)), weights, alpha=0.7)
        ax.set_title(f'样本 {i+1} 的注意力权重分布')
        ax.set_xlabel('帧序号')
        ax.set_ylabel('注意力权重')
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"✅ 注意力权重可视化已保存到: {save_path}")

    plt.show()

# Confirmation message emitted once the evaluation helpers above are defined.
print("✅ 优化的模型评估模块已定义")

## 9.执行模型评估

In [None]:
# =============================================================================
# Kaggle 兼容的深度伪造检测模型评估脚本
# =============================================================================

import os
import json
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, 
    roc_auc_score, confusion_matrix, classification_report, 
    roc_curve, auc, average_precision_score, precision_recall_curve
)
import cv2
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Select the compute device: CUDA when available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"使用设备: {device}")

# Per-frame preprocessing: array -> PIL image -> 224x224 -> tensor, followed
# by channel-wise normalisation with the mean/std listed below (these match
# the values commonly used with ImageNet-pretrained torchvision models).
_preprocess_steps = [
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocess_steps)

# 数据集类定义
class DeepfakeVideoDataset(Dataset):
    """Video-level deepfake dataset backed by a CSV index (Kaggle-compatible).

    Each CSV row describes one sample.  Frame image paths come from the
    ``frames`` column (a Python list literal stored as text, or a single
    path) or, when that column is absent, from ``frame_path``.  The optional
    ``label`` column is the target (``__init__`` reports 0 as real and 1 as
    fake); rows without it get label 0.0.

    Args:
        csv_file: Path of the CSV index file.
        transform: Optional callable applied to every RGB ``uint8`` frame of
            shape (224, 224, 3); it must return a tensor.  When omitted,
            frames become channel-first float tensors scaled to [0, 1].
        max_frames: Number of frames returned per sample.
        mode: Free-form tag that is only echoed in the start-up message.
    """

    # Height and width every frame is resized to.
    FRAME_SIZE = 224

    def __init__(self, csv_file, transform=None, max_frames=30, mode='train'):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.max_frames = max_frames
        self.mode = mode

        print(f"数据集初始化完成: {len(self.data)} 个样本 ({mode} 模式)")
        if 'label' in self.data.columns:
            print(f"真实视频: {len(self.data[self.data['label']==0])} 个")
            print(f"伪造视频: {len(self.data[self.data['label']==1])} 个")

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _frame_paths(row):
        """Return the list of frame paths described by one CSV row."""
        # Local import keeps the class self-contained within this cell.
        import ast

        if 'frames' in row:
            raw = row['frames']
            try:
                # literal_eval only parses Python literals, so text coming
                # from the CSV is never executed (unlike eval()).
                parsed = ast.literal_eval(raw)
            except (ValueError, SyntaxError, TypeError):
                # Not a literal (e.g. a bare path): treat it as one path.
                return [raw]
            if isinstance(parsed, (list, tuple)):
                return list(parsed)
            if isinstance(parsed, str):
                return [parsed]
            # Any other literal (number, dict, ...) is not a path list.
            return [raw]
        if 'frame_path' in row:
            return [row['frame_path']]
        return []

    def _read_frame(self, frame_path):
        """Load one frame as an RGB uint8 array, or a random placeholder."""
        size = self.FRAME_SIZE
        try:
            image = cv2.imread(frame_path) if os.path.exists(frame_path) else None
            if image is not None:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = cv2.resize(image, (size, size))
        except Exception:
            # Best-effort loading: any failure (bad path type, decode error)
            # falls back to the placeholder.  Unlike a bare `except:`, this
            # lets KeyboardInterrupt/SystemExit propagate.
            image = None
        if image is None:
            # NOTE(review): missing/unreadable frames are replaced by random
            # noise, as before; results then depend on the NumPy RNG state.
            image = np.random.randint(0, 255, (size, size, 3), dtype=np.uint8)
        return image

    def _to_tensor(self, frame):
        """Convert an (H, W, 3) uint8 frame into the tensor fed to the model."""
        if self.transform:
            return self.transform(frame)
        return torch.tensor(frame, dtype=torch.float32).permute(2, 0, 1) / 255.0

    def __getitem__(self, idx):
        """Return ``(frames, label)``: frames stacked along dim 0, label as float32."""
        row = self.data.iloc[idx]
        frame_paths = self._frame_paths(row)
        label = float(row['label']) if 'label' in row else 0.0

        # Only the first max_frames paths are used.
        frames = [
            self._to_tensor(self._read_frame(path))
            for path in frame_paths[:self.max_frames]
        ]

        # Pad short clips: start from an all-zero frame when nothing was
        # loaded, then repeat the last frame until max_frames is reached.
        size = self.FRAME_SIZE
        while len(frames) < self.max_frames:
            if frames:
                frames.append(frames[-1].clone())
            else:
                blank = np.zeros((size, size, 3), dtype=np.uint8)
                frames.append(self._to_tensor(blank))

        frames_tensor = torch.stack(frames[:self.max_frames])

        return frames_tensor, torch.tensor(label, dtype=torch.float32)

# 模型定义
class OptimizedDeepfakeDetector(nn.Module):
    """Frame-sequence classifier: per-frame CNN -> LSTM -> optional attention -> MLP.

    Args:
        backbone: ``'resnet50'`` selects a pretrained torchvision ResNet-50
            with its final layer replaced by identity (2048 features); any
            other value selects a small three-convolution CNN (256 features).
        hidden_dim: LSTM hidden size; also the attention embedding size.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout rate used by the LSTM, the attention and the classifier.
        use_attention: Whether to apply 8-head self-attention after the LSTM.
    """

    def __init__(self, backbone='resnet50', hidden_dim=512, num_layers=2, dropout=0.3, use_attention=True):
        super().__init__()

        # Per-frame feature extractor and the width of the vector it emits.
        self.backbone, feature_dim = self._make_backbone(backbone)

        # Temporal model over the sequence of frame features.
        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
        )

        # Optional self-attention across time steps.
        self.use_attention = use_attention
        if use_attention:
            self.attention = nn.MultiheadAttention(
                embed_dim=hidden_dim,
                num_heads=8,
                dropout=dropout,
                batch_first=True,
            )

        # Two-layer head ending in a sigmoid, so the output lies in (0, 1).
        half_dim = hidden_dim // 2
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(half_dim, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _make_backbone(name):
        """Return ``(module, feature_dim)`` for the requested backbone name."""
        if name == 'resnet50':
            resnet = models.resnet50(pretrained=True)
            resnet.fc = nn.Identity()  # drop the final classification layer
            return resnet, 2048

        # Fallback: lightweight CNN ending in global average pooling.
        small_cnn = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        return small_cnn, 256

    def forward(self, x):
        """Map a (batch, seq_len, c, h, w) tensor to one score per batch item."""
        n_videos, n_frames, channels, height, width = x.shape

        # Fold time into the batch dimension so the CNN sees single frames.
        flat_frames = x.view(-1, channels, height, width)
        per_frame = self.backbone(flat_frames).view(n_videos, n_frames, -1)

        # Temporal modelling, then optional self-attention over time steps.
        sequence, _ = self.lstm(per_frame)
        if self.use_attention:
            sequence, _ = self.attention(sequence, sequence, sequence)

        # Average over time and classify; drop the trailing singleton dim.
        pooled = sequence.mean(dim=1)
        return self.classifier(pooled).squeeze(-1)

# 优化的模型评估函数
def evaluate_model_optimized(model, test_loader, criterion, device):
    """Run ``model`` over ``test_loader`` without gradients and collect results.

    Args:
        model: Module called as ``model(inputs)``.  Its outputs are
            thresholded at 0.5, so they are expected to be probabilities.
        test_loader: Iterable of ``(inputs, labels)`` batches exposing a
            ``dataset`` attribute whose length is the sample count.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return the batch mean, since it is re-weighted by batch size.
        device: Device (``torch.device`` or string) the batches are moved to.

    Returns:
        dict with ``loss`` (per-sample average), ``predictions`` (0/1),
        ``targets``, ``scores`` (raw outputs) as NumPy arrays, plus
        ``total_inference_time`` (seconds spent in forward passes) and
        ``avg_inference_time`` (seconds per sample).  For an empty dataset
        both averages are reported as 0.0.
    """
    # Kept as a local import, as in the original cell, but hoisted out of
    # the batch loop.
    import time

    model.eval()
    running_loss = 0.0
    predictions = []
    targets = []
    scores = []
    total_inference_time = 0.0

    # CUDA kernels run asynchronously; without synchronising, the timer
    # would mostly measure kernel launch overhead.
    on_cuda = torch.device(device).type == 'cuda'

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            if on_cuda:
                torch.cuda.synchronize(device)
            start_time = time.perf_counter()

            outputs = model(inputs)

            if on_cuda:
                torch.cuda.synchronize(device)
            total_inference_time += time.perf_counter() - start_time

            # Flatten to one value per sample.  A plain squeeze() would turn
            # a (1, 1) output into a 0-d tensor on a final batch of size 1,
            # breaking both the loss and the list extension below.
            outputs = outputs.reshape(-1)
            labels = labels.reshape(-1)

            # Weight the batch loss by batch size so the final average is
            # per sample even when the last batch is smaller.
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)

            predictions.extend((outputs > 0.5).float().cpu().numpy())
            targets.extend(labels.cpu().numpy())
            scores.extend(outputs.cpu().numpy())

    num_samples = len(test_loader.dataset)
    avg_loss = running_loss / num_samples if num_samples else 0.0
    avg_inference_time = total_inference_time / num_samples if num_samples else 0.0

    return {
        'loss': avg_loss,
        'predictions': np.array(predictions),
        'targets': np.array(targets),
        'scores': np.array(scores),
        'total_inference_time': total_inference_time,
        'avg_inference_time': avg_inference_time
    }

# 计算全面的评估指标
def calculate_comprehensive_metrics(predictions, targets, scores):
    """Compute binary-classification metrics from predictions and scores.

    Args:
        predictions: Hard 0/1 predictions, one per sample.
        targets: Ground-truth 0/1 labels, one per sample.
        scores: Continuous scores used for the ranking metrics.

    Returns:
        dict with ``accuracy``, ``balanced_accuracy``, ``precision``,
        ``recall``, ``specificity``, ``f1``, ``npv``, ``auc_roc``,
        ``auc_pr``, the 2x2 ``confusion_matrix`` and its cells
        ``tn``/``fp``/``fn``/``tp``.  Ratios with a zero denominator are
        reported as 0, and both AUC values are 0.0 when they cannot be
        computed.
    """
    # Threshold-based metrics; zero_division=0 avoids warnings on empty classes.
    accuracy = accuracy_score(targets, predictions)
    precision = precision_score(targets, predictions, zero_division=0)
    recall = recall_score(targets, predictions, zero_division=0)
    f1 = f1_score(targets, predictions, zero_division=0)

    # Force the full 2x2 layout.  Without explicit labels the matrix shrinks
    # to 1x1 when only one class occurs and the four-way unpack below fails.
    # Labels are cast to int so they match the integer label list.
    cm = confusion_matrix(
        np.asarray(targets).astype(int),
        np.asarray(predictions).astype(int),
        labels=[0, 1],
    )
    tn, fp, fn, tp = cm.ravel()

    # Derived rates, guarded against empty denominators.
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    npv = tn / (tn + fn) if (tn + fn) > 0 else 0  # negative predictive value
    balanced_accuracy = (recall + specificity) / 2

    # Ranking metrics; sklearn signals undefined input (e.g. a single class
    # in targets) with ValueError.
    try:
        auc_roc = roc_auc_score(targets, scores)
        auc_pr = average_precision_score(targets, scores)
    except ValueError:
        auc_roc = 0.0
        auc_pr = 0.0

    return {
        'accuracy': accuracy,
        'balanced_accuracy': balanced_accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1': f1,
        'npv': npv,
        'auc_roc': auc_roc,
        'auc_pr': auc_pr,
        'confusion_matrix': cm,
        'tn': tn, 'fp': fp, 'fn': fn, 'tp': tp
    }

# 增强的混淆矩阵可视化
def plot_enhanced_confusion_matrix(cm, save_path):
    """Save a confusion-matrix heatmap annotated with counts and row percentages.

    Args:
        cm: 2-D NumPy array of counts.  Rows are true labels and columns are
            predicted labels, matching the axis titles set below.  The two
            tick labels assume a 2x2 real/fake matrix.
        save_path: Path the figure is written to at 300 dpi.
    """
    plt.figure(figsize=(10, 8))

    # Percentage of each true class (row) assigned to every predicted class.
    # A class without samples has a zero row sum; report 0% for that row
    # instead of dividing by zero and rendering "nan%" in the cells.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    # One "count\n(percent%)" label per cell.
    labels = np.array([[
        f'{cm[i,j]}\n({cm_percent[i,j]:.1f}%)'
        for j in range(cm.shape[1])
    ] for i in range(cm.shape[0])])

    # fmt='' because the annotations are already formatted strings.
    sns.heatmap(cm, annot=labels, fmt='', cmap='Blues',
                xticklabels=['真实', '伪造'],
                yticklabels=['真实', '伪造'],
                cbar_kws={'label': '样本数量'})

    plt.xlabel('预测标签', fontsize=12)
    plt.ylabel('真实标签', fontsize=12)
    plt.title('增强混淆矩阵', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close instead of show: this variant only writes the file.
    plt.close()

# ROC和PR曲线
def plot_roc_pr_curves(targets, scores, save_path):
    """Save a figure with the ROC curve (left) and the PR curve (right).

    Args:
        targets: Ground-truth labels passed to the sklearn curve functions.
        scores: Predicted scores passed to the sklearn curve functions.
        save_path: Path the figure is written to at 300 dpi.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # ROC curve with the chance diagonal as reference.
    fpr, tpr, _ = roc_curve(targets, scores)
    roc_auc = auc(fpr, tpr)

    ax1.plot(fpr, tpr, color='darkorange', lw=2,
             label=f'ROC曲线 (AUC = {roc_auc:.4f})')
    ax1.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax1.set_xlim([0.0, 1.0])
    ax1.set_ylim([0.0, 1.05])
    ax1.set_xlabel('假阳性率')
    ax1.set_ylabel('真阳性率')
    ax1.set_title('ROC曲线')
    ax1.legend(loc='lower right')
    ax1.grid(True, alpha=0.3)

    # Precision-recall curve.  The legend reports average precision, the
    # same statistic calculate_comprehensive_metrics stores as 'auc_pr', so
    # the figure and the written report agree.  Trapezoidal
    # auc(recall, precision) interpolates linearly between points and
    # yields a different, typically more optimistic, number.
    precision_curve, recall_curve, _ = precision_recall_curve(targets, scores)
    pr_auc = average_precision_score(targets, scores)

    ax2.plot(recall_curve, precision_curve, color='darkgreen', lw=2,
             label=f'PR曲线 (AUC = {pr_auc:.4f})')
    ax2.set_xlim([0.0, 1.0])
    ax2.set_ylim([0.0, 1.05])
    ax2.set_xlabel('召回率')
    ax2.set_ylabel('精确率')
    ax2.set_title('精确率-召回率曲线')
    ax2.legend(loc='lower left')
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    # Close instead of show: this variant only writes the file.
    plt.close()

# 主评估流程
print("🔄 创建优化的测试数据加载器...")

# Batch size: 16 when a CUDA device is present, 8 otherwise.
if torch.cuda.is_available():
    optimized_batch_size = 16
else:
    optimized_batch_size = 8

# The evaluation samples are read from the validation CSV.
test_dataset = DeepfakeVideoDataset('./data/val.csv', transform=transform, max_frames=30)

# Unshuffled loader; two worker processes, with pinned host memory requested.
_loader_options = {
    'batch_size': optimized_batch_size,
    'shuffle': False,
    'num_workers': 2,
    'pin_memory': True,
}
test_loader = DataLoader(test_dataset, **_loader_options)

print(f"📊 测试集大小: {len(test_dataset)} 个样本")
print(f"🔧 批次大小: {optimized_batch_size}")
print(f"🔧 批次数量: {len(test_loader)}")

# Load the best checkpoint written during training.
print("\n🤖 加载训练好的模型...")

_checkpoint_path = './models/best_model.pth'
try:
    # NOTE(review): weights_only=False allows full unpickling; only load
    # checkpoint files from a trusted source.
    checkpoint = torch.load(_checkpoint_path, map_location=device, weights_only=False)
except TypeError:
    # Retry without the keyword (presumably for torch versions whose
    # torch.load does not accept weights_only).
    checkpoint = torch.load(_checkpoint_path, map_location=device)

# Rebuild the network with the architecture settings below, then restore
# the saved weights.
_model_config = {
    'backbone': 'resnet50',
    'hidden_dim': 512,
    'num_layers': 2,
    'dropout': 0.3,
    'use_attention': True,
}
model = OptimizedDeepfakeDetector(**_model_config).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
print(f"✅ 模型加载成功")
print(f"📈 最佳验证准确率: {checkpoint.get('best_val_acc', 'N/A')}")
print(f"🔄 训练轮数: {checkpoint.get('epoch', 'N/A')}")

# Record when the evaluation starts.
eval_start_time = datetime.now()
print(f"\n⏰ 评估开始时间: {eval_start_time:%Y-%m-%d %H:%M:%S}")

# Run the evaluation pass.
print("\n🚀 开始执行模型评估...")
print("=" * 60)

criterion = nn.BCELoss()

# Forward passes over the loader: loss, predictions, scores and timings.
eval_results = evaluate_model_optimized(model, test_loader, criterion, device)

# Derive the full metric set from the collected arrays.
metrics = calculate_comprehensive_metrics(
    eval_results['predictions'],
    eval_results['targets'],
    eval_results['scores'],
)

# Wall-clock duration of the whole evaluation step.
eval_end_time = datetime.now()
total_eval_time = (eval_end_time - eval_start_time).total_seconds()

_per_sample_ms = eval_results['avg_inference_time'] * 1000
_samples_per_second = len(test_dataset) / eval_results['total_inference_time']

print(f"\n⏱️ 评估性能分析:")
print(f"总评估时间: {total_eval_time:.2f} 秒")
print(f"平均每样本推理时间: {_per_sample_ms:.2f} ms")
print(f"推理吞吐量: {_samples_per_second:.1f} 样本/秒")

# Print every metric with four decimals, in a fixed order.
print(f"\n📊 详细评估结果:")
print(f"测试损失: {eval_results['loss']:.4f}")
for _title, _key in (
    ('准确率', 'accuracy'),
    ('平衡准确率', 'balanced_accuracy'),
    ('精确率', 'precision'),
    ('召回率', 'recall'),
    ('特异性', 'specificity'),
    ('F1分数', 'f1'),
    ('AUC-ROC', 'auc_roc'),
    ('AUC-PR', 'auc_pr'),
    ('负预测值', 'npv'),
):
    print(f"{_title}: {metrics[_key]:.4f}")

# Make sure the output directory exists.
os.makedirs('./results/evaluation', exist_ok=True)

# Render the evaluation figures.
print("\n📊 生成评估图表...")

# 1. Annotated confusion matrix.
plot_enhanced_confusion_matrix(
    metrics['confusion_matrix'],
    save_path='./results/evaluation/enhanced_confusion_matrix.png',
)

# 2. ROC and precision-recall curves.
plot_roc_pr_curves(
    eval_results['targets'],
    eval_results['scores'],
    save_path='./results/evaluation/roc_pr_curves.png',
)

# Assemble the detailed report section by section.  Values are converted to
# built-in float/int so that json can serialise them.
_evaluation_info = {
    'timestamp': eval_start_time.isoformat(),
    'model_path': './models/best_model.pth',
    'test_dataset_size': len(test_dataset),
    'batch_size': optimized_batch_size,
}

# (report key, metrics key) pairs, in output order.
_metric_fields = (
    ('accuracy', 'accuracy'),
    ('balanced_accuracy', 'balanced_accuracy'),
    ('precision', 'precision'),
    ('recall', 'recall'),
    ('specificity', 'specificity'),
    ('f1_score', 'f1'),
    ('auc_roc', 'auc_roc'),
    ('auc_pr', 'auc_pr'),
    ('npv', 'npv'),
)
_performance_metrics = {'test_loss': float(eval_results['loss'])}
_performance_metrics.update(
    (name, float(metrics[key])) for name, key in _metric_fields
)

_confusion_cells = {
    'true_negative': int(metrics['tn']),
    'false_positive': int(metrics['fp']),
    'false_negative': int(metrics['fn']),
    'true_positive': int(metrics['tp']),
}

_performance_analysis = {
    'total_evaluation_time_seconds': total_eval_time,
    'average_inference_time_ms': eval_results['avg_inference_time'] * 1000,
    'throughput_samples_per_second': len(test_dataset) / eval_results['total_inference_time'],
}

detailed_report = {
    'evaluation_info': _evaluation_info,
    'performance_metrics': _performance_metrics,
    'confusion_matrix': _confusion_cells,
    'performance_analysis': _performance_analysis,
}

# Detailed report as JSON (non-ASCII characters are written unescaped).
with open('./results/evaluation/detailed_evaluation_report.json', 'w', encoding='utf-8') as f:
    f.write(json.dumps(detailed_report, indent=2, ensure_ascii=False))

# One-row CSV summary with pre-formatted values.
_summary_row = {
    '评估时间': eval_start_time.strftime('%Y-%m-%d %H:%M:%S'),
    '测试损失': f"{eval_results['loss']:.4f}",
    '准确率': f"{metrics['accuracy']:.4f}",
    '平衡准确率': f"{metrics['balanced_accuracy']:.4f}",
    '精确率': f"{metrics['precision']:.4f}",
    '召回率': f"{metrics['recall']:.4f}",
    'F1分数': f"{metrics['f1']:.4f}",
    'AUC-ROC': f"{metrics['auc_roc']:.4f}",
    'AUC-PR': f"{metrics['auc_pr']:.4f}",
    '推理时间(ms)': f"{eval_results['avg_inference_time']*1000:.2f}",
    '吞吐量(样本/秒)': f"{len(test_dataset)/eval_results['total_inference_time']:.1f}",
}
summary_df = pd.DataFrame([_summary_row])

summary_df.to_csv('./results/evaluation/evaluation_summary.csv', index=False, encoding='utf-8')

# List the files that were written.
print("\n📁 评估结果已保存到:")
for _saved_line in (
    "  📊 ./results/evaluation/enhanced_confusion_matrix.png",
    "  📈 ./results/evaluation/roc_pr_curves.png",
    "  📋 ./results/evaluation/detailed_evaluation_report.json",
    "  📊 ./results/evaluation/evaluation_summary.csv",
):
    print(_saved_line)

print("\n🎉 模型评估完成！")
print("=" * 60)