<a href="https://colab.research.google.com/github/chiyeon01/CNN_Model_Mechanism/blob/main/pytorch/deepfake_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torchvision
!pip install torchmetrics
!pip install torchinfo

Collecting torchmetrics
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m59.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.15.2 torchmetrics-1.8.2
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0


In [2]:
!wget https://raw.githubusercontent.com/chiyeon01/CNN_Model_Mechanism/refs/heads/main/pytorch/utils.py

--2025-12-22 10:23:36--  https://raw.githubusercontent.com/chiyeon01/CNN_Model_Mechanism/refs/heads/main/pytorch/utils.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 10814 (11K) [text/plain]
Saving to: ‘utils.py’


2025-12-22 10:23:37 (152 MB/s) - ‘utils.py’ saved [10814/10814]



In [3]:
# Module import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torch.optim import SGD, Adagrad, RMSprop, Adam, AdamW
import torchmetrics

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import cv2
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from PIL import Image
from utils import Trainer, Predictor, Custom_Dataset, create_pretrained_model

In [4]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("manjilkarki/deepfake-and-real-images")

print("Path to dataset files:", path)

Using Colab cache for faster access to the 'deepfake-and-real-images' dataset.
Path to dataset files: /kaggle/input/deepfake-and-real-images


In [5]:
# Kaggle DeepFake Dataset
image_paths = []
labels = []
gubuns = []

for dirname, _, filenames in os.walk("/kaggle/input/deepfake-and-real-images"):
    for filename in filenames:
        image_path = os.path.join(dirname, filename)
        image_paths.append(image_path)
        if 'Train' in image_path:
            gubuns.append('Train')
        elif 'Validation' in image_path:
            gubuns.append('Validation')
        elif 'Test' in image_path:
            gubuns.append('Test')

        if 'Real' in image_path:
            labels.append('Real')
        elif 'Fake' in image_path:
            labels.append('Fake')

print(f"학습 데이터 수: {len(image_paths)}")
print(f"타겟 수: {len(labels)}")

학습 데이터 수: 190335
타겟 수: 190335


In [6]:
# Train, Validation, Test 모두를 담는 DataFrame
deepfake_df = pd.DataFrame({
    'image_path': image_paths,
    'label': labels,
    'gubun': gubuns
})
deepfake_df['label'] = deepfake_df['label'].map({
    'Real': 0,
    'Fake': 1
})

# Train DataFrame
train_df = deepfake_df[deepfake_df['gubun'] == 'Train']

# Validation DataFrame
val_df = deepfake_df[deepfake_df['gubun'] == 'Validation']

# Test DataFrame
test_df = deepfake_df[deepfake_df['gubun'] == 'Test']

print(f"Train DataFrame Shape: {train_df.shape}")
print(f"Validation DataFrame Shape: {val_df.shape}")
print(f"Test DataFrame Shape: {test_df.shape}")
print(f"All DataFrame Shape: {deepfake_df.shape}")

Train DataFrame Shape: (140002, 3)
Validation DataFrame Shape: (39428, 3)
Test DataFrame Shape: (10905, 3)
All DataFrame Shape: (190335, 3)


In [7]:
# Config & Settings
class Config:
    batch_size = 32
    FineTune = False
    shuffle = False
    first_learning_rate = 1e-3
    second_learning_rate = 1e-4
    image_size = [224, 224] # 이미지는 3채널로 간주(image is 3 channels)
    model_name = 'efficientnet_b0'
    make_summary = False
    callbacks = []
    metric = torchmetrics.classification.Accuracy(task="multiclass", num_classes=10)
    classifier_layer = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(in_features=1280, out_features=2)
        )

In [8]:
# 모델 생성
model = create_pretrained_model(
    model_name = Config.model_name,
    classifier_layer = Config.classifier_layer,
    image_size=Config.image_size,
    make_summary = Config.make_summary
)

# albumentation 기반 transform 생성.
transform = A.Compose([
    A.Resize(*Config.image_size),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

# dataset 생성
train_dataset = Custom_Dataset(train_df['image_path'].values, targets=train_df['label'].values, transform=transform)
val_dataset = Custom_Dataset(val_df['image_path'].values, targets=val_df['label'].values, transform=transform)
test_dataset = Custom_Dataset(test_df['image_path'].values, targets=test_df['label'].values, transform=transform)

# dataloader 생성
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, pin_memory=True)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=Config.first_learning_rate)

trainer = Trainer(model=model, loss_fn=loss_fn, metric=Config.metric, optimizer=optimizer)

history = trainer.fit(epochs=1, train_dataloader=train_dataloader, val_dataloader=val_dataloader)

Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 231MB/s]
[Training...] : 100%|██████████| 8751/8751 [16:56<00:00,  8.61it/s, Train_Loss=0.0863, Train_Accuracy=0.5]
[Validating..] : 100%|██████████| 2465/2465 [03:38<00:00, 11.29it/s, Validate_Loss=0.0805, Validate_Accuracy=1]


In [9]:
# evaluate
trainer.evaluate(test_dataloader)

[Evaluating..] : 100%|██████████| 682/682 [01:02<00:00, 10.93it/s, Evaluate_Loss=0.21, Evaluate_Accuracy=0.889]


(0.21004895801393758, 0.8888888955116272)

In [18]:
# return trained model
model = trainer.get_trained_model()

In [12]:
# model save
torch.save(model.state_dict(), "model_weights.pt")

In [20]:
# model load
model = create_pretrained_model(
    model_name = Config.model_name,
    classifier_layer = Config.classifier_layer,
    image_size=Config.image_size,
    make_summary = Config.make_summary
)

model.load_state_dict(torch.load("model_weights.pt"))

<All keys matched successfully>

In [21]:
model

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNormActivat