In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
from torchvision import models, transforms
import pandas as pd
import cv2
from PIL import Image
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader


# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Early stopping helper
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Each call with a loss strictly lower than the best seen so far resets the
    patience counter and saves ``model.state_dict()`` to ``checkpoint_path``.
    Otherwise the counter grows, and once it reaches ``patience`` the call
    returns True.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: when True, print progress messages on every call.
        checkpoint_path: file the best ``state_dict`` is written to.
    """

    def __init__(self, patience=5, verbose=False, checkpoint_path='./best_model.pth'):
        self.patience = patience
        self.verbose = verbose
        self.checkpoint_path = checkpoint_path
        # Start at +inf so the very first loss always counts as an improvement.
        self.best_loss = float('inf')
        self.counter = 0

    def _save_checkpoint(self, loss, model):
        """Record ``loss`` as the new best, reset the counter, and persist the weights."""
        self.best_loss = loss
        self.counter = 0
        torch.save(model.state_dict(), self.checkpoint_path)
        if self.verbose:
            print(f"Validation loss decreased. Saving model to {self.checkpoint_path}")

    def __call__(self, loss, model):
        """Register one epoch's ``loss``; return True when training should stop."""
        if loss < self.best_loss:
            self._save_checkpoint(loss, model)
            return False

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter}/{self.patience}")
        should_stop = self.counter >= self.patience
        if should_stop:
            print("Early stopping triggered")
        return should_stop

# Dataset class definition
class DeepFakeDataset(Dataset):
    """Video dataset yielding two views of a single frame per video, plus its label.

    The metadata CSV must have the video file name (relative to ``root_dir``) in
    its first column and a ``label`` column. Labels are encoded as 0 for 'FAKE'
    and 1 for any other value.

    Args:
        csv_path: path of the metadata CSV.
        root_dir: directory the video file names are relative to.
        transform: optional callable applied independently to produce each view.
            When None, both views are the untransformed PIL image.
        frame_num: 0-based index of the frame to read. If it is out of range for
            a video, the middle frame is used instead.
    """

    def __init__(self, csv_path, root_dir, transform=None, frame_num=1):
        self.data = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform
        self.frame_num = frame_num
        # Encode labels as integers: 'FAKE' -> 0, anything else -> 1.
        self.data['label'] = self.data['label'].apply(lambda x: 0 if x == 'FAKE' else 1)

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _select_frame_index(frame_num, frame_count):
        """Return ``frame_num`` if it is a valid 0-based frame index, else the middle frame."""
        # Frames are 0-based, so ``frame_count`` itself is already past the end.
        if 0 <= frame_num < frame_count:
            return frame_num
        return max(frame_count // 2, 0)

    def _make_views(self, image):
        """Return two views of ``image``: independently transformed, or the image itself twice."""
        if self.transform is None:
            return image, image
        return self.transform(image), self.transform(image)

    def __getitem__(self, idx):
        """Return ``(view1, view2, label)`` for the idx-th video.

        Raises:
            RuntimeError: if the selected frame cannot be decoded.
        """
        video_path = os.path.join(self.root_dir, self.data.iloc[idx, 0])
        # Read the label by name: that is the column encoded in __init__.
        label = self.data['label'].iloc[idx]

        cap = cv2.VideoCapture(video_path)
        try:
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            cap.set(cv2.CAP_PROP_POS_FRAMES, self._select_frame_index(self.frame_num, frame_count))
            ret, frame = cap.read()
        finally:
            # Always free the decoder handle, even if OpenCV raises.
            cap.release()

        if not ret:
            raise RuntimeError(f"Failed to read frame from video: {video_path}")
        # OpenCV decodes to BGR; PIL/torchvision expect RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame)

        image1, image2 = self._make_views(image)
        return image1, image2, label

# Data locations for training
csv_path = './metadata.csv'
root_dir = './train_data'

# SimCLR-style augmentation pipeline. DeepFakeDataset applies it twice per frame,
# so each call yields an independently augmented view.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Dataset and shuffled mini-batch loader
train_dataset = DeepFakeDataset(csv_path=csv_path, root_dir=root_dir, transform=train_transform)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

# SimCLR model definition
class SimCLRModel(nn.Module):
    """Encoder followed by a two-layer MLP projection head (512 -> 256 -> 128).

    The encoder output is flattened per sample, so it must provide 512 values
    per sample to match the first projection layer.
    """

    def __init__(self, base_model):
        super().__init__()
        self.encoder = base_model
        self.projector = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
        )

    def forward(self, x):
        """Return the 128-d projection of every sample in ``x``."""
        return self.projector(self.encoder(x).flatten(start_dim=1))

# ImageNet-pretrained ResNet18 trunk with its final fully-connected layer dropped,
# used as the SimCLR encoder.
base_model = nn.Sequential(*list(models.resnet18(pretrained=True).children())[:-1])
model = SimCLRModel(base_model).to(device)  # place the model on the selected device

# NT-Xent loss function
def nt_xent_loss(features, temperature=0.5):
    """SimCLR NT-Xent (normalized temperature-scaled cross-entropy) loss.

    Args:
        features: (2N, D) tensor of projections where rows ``i`` and ``i + N`` are
            the two augmented views of the same sample. This is the layout
            produced by ``torch.cat([z1, z2], dim=0)`` in the training loop.
        temperature: softmax temperature applied to the cosine similarities.

    Returns:
        Scalar mean cross-entropy over all 2N anchors. Each anchor must pick out
        its partner view among the other 2N - 1 rows.

    Raises:
        ValueError: if the number of rows is odd, so the views cannot be paired.
    """
    n = features.shape[0]
    if n % 2 != 0:
        raise ValueError(f"expected an even number of rows (two views per sample), got {n}")
    half = n // 2

    # Cosine similarity of every pair = dot product of L2-normalised rows.
    z = F.normalize(features, dim=1)
    logits = z @ z.T / temperature

    # Self-similarity is always 1 and is neither positive nor negative: remove it.
    self_mask = torch.eye(n, dtype=torch.bool, device=features.device)
    logits = logits.masked_fill(self_mask, float('-inf'))

    # Row i's positive is its other view: i + N in the first half, i - N in the second.
    targets = torch.cat([torch.arange(half, n), torch.arange(0, half)]).to(features.device)
    return F.cross_entropy(logits, targets)

# Training configuration
num_epochs = 1000
optimizer = optim.Adam(model.parameters(), lr=3e-3)
# `verbose` is deprecated for LR schedulers in recent PyTorch releases; the current
# LR is logged explicitly every epoch below instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=3, factor=0.5)
early_stopping = EarlyStopping(patience=5, verbose=True, checkpoint_path='./best_model.pth')

if len(train_loader) == 0:
    raise RuntimeError("train_loader is empty; check csv_path and root_dir")

# Training loop
for epoch in tqdm(range(num_epochs)):
    model.train()
    total_loss = 0.0
    for img1, img2, _ in train_loader:
        img1, img2 = img1.to(device), img2.to(device)  # move the two views to the device
        # Stack view-1 projections above view-2 projections: rows i and i + N are a positive pair.
        features = torch.cat([model(img1), model(img2)], dim=0)
        loss = nt_xent_loss(features)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, LR: {current_lr:.2e}")
    scheduler.step(avg_loss)

    # NOTE: there is no validation split; both the scheduler and early stopping monitor
    # the average *training* loss.
    if early_stopping(avg_loss, model):
        print("Training stopped early")
        break


In [None]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import cv2

# Inference device: CUDA when available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# SimCLR model definition (must match the training architecture for checkpoint loading)
class SimCLRModel(nn.Module):
    """Encoder plus 512 -> 256 -> 128 MLP projection head.

    Attribute names ``encoder`` and ``projector`` determine the ``state_dict``
    keys, so they must match the model that produced the checkpoint.
    """

    def __init__(self, base_model):
        super().__init__()
        self.encoder = base_model
        head_layers = [nn.Linear(512, 256), nn.ReLU(), nn.Linear(256, 128)]
        self.projector = nn.Sequential(*head_layers)

    def forward(self, x):
        """Encode ``x``, flatten each sample, and project it to 128 dimensions."""
        encoded = self.encoder(x)
        flat = torch.flatten(encoded, 1)
        return self.projector(flat)

# Rebuild the training architecture. Weights come from the checkpoint, so no
# pretrained download is needed here.
checkpoint_path = './best_model.pth'
base_model = nn.Sequential(*list(models.resnet18(pretrained=False).children())[:-1])
model = SimCLRModel(base_model).to(device)

# Load the saved state_dict onto the current device and switch to inference mode.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()

# Deterministic inference preprocessing: resize to 224x224, convert to a tensor,
# and apply ImageNet mean/std normalisation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Frame extraction and preprocessing
def _select_frame_index(frame_num, frame_count):
    """Return ``frame_num`` if it is a valid 0-based frame index, else the middle frame."""
    # Frames are 0-based, so ``frame_count`` itself is already past the end.
    if 0 <= frame_num < frame_count:
        return frame_num
    return max(frame_count // 2, 0)


def process_video(video_path, frame_num=1):
    """Read one frame of a video and return it as a preprocessed batch of one.

    Args:
        video_path: path of the video file.
        frame_num: 0-based index of the frame to read. If it is out of range,
            the middle frame is used instead.

    Returns:
        The frame after ``transform``, with a leading batch dimension added,
        moved to ``device``.

    Raises:
        RuntimeError: if the selected frame cannot be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.set(cv2.CAP_PROP_POS_FRAMES, _select_frame_index(frame_num, frame_count))
        ret, frame = cap.read()
    finally:
        # Always free the decoder handle, even if OpenCV raises.
        cap.release()

    if not ret:
        raise RuntimeError(f"Failed to read frame from video: {video_path}")

    # OpenCV decodes to BGR; PIL/torchvision expect RGB.
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(frame)

    image = transform(image)
    return image.unsqueeze(0).to(device)  # add the batch dimension, then move to the device

# Per-video prediction
def predict(video_path, classifier=None, threshold=0.5):
    """Label one video "FAKE" or "REAL" from a single preprocessed frame.

    Args:
        video_path: path of the video file to score.
        classifier: optional callable mapping the flattened encoder features
            (one row per image) to one logit per image. For example, an
            ``nn.Linear(512, 1)`` linear probe trained on the frozen encoder
            with the dataset's encoding (FAKE=0, REAL=1).
            When omitted, the original heuristic is used: the mean of ``sigmoid``
            over the raw encoder features. That score carries no label
            information. The encoder was trained without labels, and the ResNet
            trunk's features are non-negative (it ends in ReLU + average pooling).
            The score is therefore always >= 0.5, and every video is reported as
            "REAL".
        threshold: scores strictly below this value are labelled "FAKE".

    Returns:
        ``(label, score)`` on success, where ``score`` is a sigmoid output in
        [0, 1]. With a classifier, it is the predicted probability of REAL.
        Returns ``("Error: <message>", None)`` if reading or scoring fails.
    """
    if classifier is None:
        import warnings

        warnings.warn(
            "predict() called without a trained classifier: the score is the mean "
            "sigmoid of raw SimCLR encoder features and carries no FAKE/REAL information.",
            RuntimeWarning,
            stacklevel=2,
        )

    try:
        image = process_video(video_path)
        with torch.no_grad():
            features = torch.flatten(model.encoder(image), start_dim=1)
            logits = features if classifier is None else classifier(features)
            score = torch.sigmoid(logits).mean().item()
    except Exception as e:  # boundary: report per-file failures instead of aborting a folder scan
        return f"Error: {e}", None

    label = "FAKE" if score < threshold else "REAL"
    return label, score

# Predict every video in a folder
def _format_result(file_name, label, confidence):
    """Return the console line for one prediction, tolerating a missing confidence."""
    if confidence is None:
        # predict() returns ("Error: ...", None) on failure; ".4f" would raise on None.
        return f"{file_name}: {label}"
    return f"{file_name}: {label} (Confidence: {confidence:.4f})"


def predict_folder(folder_path, extensions=('.mp4', '.avi', '.mov', '.mkv')):
    """Run ``predict`` on every video file directly inside ``folder_path``.

    Args:
        folder_path: directory to scan (not recursive).
        extensions: file-name suffixes treated as videos, compared case-insensitively.

    Returns:
        List of ``(file_name, label, confidence)`` tuples in file-name order.
        ``confidence`` is None for files that ``predict`` failed on.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    results = []
    # Sort so the output order is stable across runs and platforms.
    for file_name in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, file_name)
        if not (os.path.isfile(file_path) and file_name.lower().endswith(suffixes)):
            continue
        label, confidence = predict(file_path)
        results.append((file_name, label, confidence))
        print(_format_result(file_name, label, confidence))
    return results

# Save results to CSV (optional)
def save_results(results, output_file="results.csv"):
    """Write prediction rows to a CSV file with a header row.

    Args:
        results: iterable of ``(file_name, prediction, confidence)`` tuples, as
            built by ``predict_folder``. A None confidence is written as an
            empty cell.
        output_file: destination path; an existing file is overwritten.
    """
    import csv

    # Use explicit UTF-8 so non-ASCII file names are written the same way on every OS.
    # Without it, open() uses the platform's locale encoding.
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["File Name", "Prediction", "Confidence"])
        writer.writerows(results)
    print(f"Results saved to {output_file}")

# Inference entry point: score every video in the folder, then persist the results.
folder_path = './train_data'  # folder containing the videos to run inference on
results = predict_folder(folder_path=folder_path)

# Write the predictions to a CSV file
save_results(results=results, output_file="prediction_results.csv")
