In [1]:
'''
파이토치 쿠다설정
https://blog.naver.com/me_a_me/223570004477

python = 3.11.11
torch = 2.6.0
cuda = 12.4
cudnn = 9.1.0.70

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
'''

'\n파이토치 쿠다설정\nhttps://blog.naver.com/me_a_me/223570004477\n\npython = 3.11.11\ntorch = 2.6.0\ncuda = 12.4\ncudnn = 9.1.0.70\n\npip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n'

In [2]:
import kagglehub

# Kaggle handle of the CIFAKE dataset (real vs. AI-generated images).
DATASET_HANDLE = "birdy654/cifake-real-and-ai-generated-synthetic-images"

# Download the latest version; `path` is the local directory kagglehub reports
# and is what the later cells build their folder paths from.
path = kagglehub.dataset_download(DATASET_HANDLE)

print("Path to dataset files:", path)

Path to dataset files: C:\Users\qwer\.cache\kagglehub\datasets\birdy654\cifake-real-and-ai-generated-synthetic-images\versions\3


In [3]:
import os
import io
import timm
import cv2

from PIL import Image
from tqdm.auto import tqdm  # Progress bars
from PIL import ImageFile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from collections import Counter


import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder

import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import Subset, random_split

import albumentations as A
from albumentations.pytorch import ToTensorV2

In [4]:
import os

dataset_base_path = path


def _list_image_names(folder):
    """Return the entries of `folder` whose names end in '.jpg' or '.png'."""
    return [name for name in os.listdir(folder) if name.endswith(('.jpg', '.png'))]


# One folder per (split, class) combination.
dataset_path_train_real = os.path.join(dataset_base_path, 'train/REAL')
dataset_path_train_fake = os.path.join(dataset_base_path, 'train/FAKE')
dataset_path_test_real = os.path.join(dataset_base_path, 'test/REAL')
dataset_path_test_fake = os.path.join(dataset_base_path, 'test/FAKE')

# Collect the image file names found in each folder.
image_files_train_real = _list_image_names(dataset_path_train_real)
image_files_test_real = _list_image_names(dataset_path_test_real)
image_files_train_fake = _list_image_names(dataset_path_train_fake)
image_files_test_fake = _list_image_names(dataset_path_test_fake)

# Report the per-folder counts.
for _tag, _names in (
    ("train_REAL", image_files_train_real),
    ("test_REAL", image_files_test_real),
    ("train_FAKE", image_files_train_fake),
    ("test_FAKE", image_files_test_fake),
):
    print(f"{_tag} 이미지 파일 수: {len(_names)}")

# Merge the four lists (train REAL, train FAKE, test REAL, test FAKE).
image_files = [
    *image_files_train_real,
    *image_files_train_fake,
    *image_files_test_real,
    *image_files_test_fake,
]

print(f"이미지 파일 수: {len(image_files)}")

# Later code is meant to continue from `image_files`.
# ...

train_REAL 이미지 파일 수: 50000
test_REAL 이미지 파일 수: 10000
train_FAKE 이미지 파일 수: 50000
test_FAKE 이미지 파일 수: 10000
이미지 파일 수: 120000


In [5]:
# Image preprocessing / augmentation pipelines (albumentations).

IMG_SIZE = 224  # Swin Transformer input size

# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): the pretrained timm backbone may expect different (e.g. ImageNet)
# statistics — confirm against the model's data config.
_NORM_MEAN = (0.5, 0.5, 0.5)
_NORM_STD = (0.5, 0.5, 0.5)

# Training: resize, random augmentation, normalise, convert to tensor.
_train_steps = [
    A.Resize(IMG_SIZE, IMG_SIZE),       # resize to the model input size
    A.HorizontalFlip(p=0.5),            # random left-right flip
    A.RandomBrightnessContrast(p=0.2),  # random brightness / contrast change
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
train_transform = A.Compose(_train_steps)

# Evaluation: same resize and normalisation, no random augmentation.
_eval_steps = [
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ToTensorV2(),
]
test_transform = A.Compose(_eval_steps)

In [6]:
#Custom dataloader
class CustomDataset(ImageFolder):
    """ImageFolder that applies a keyword-style (albumentations) transform.

    The parent class is used only to discover the samples; the transform is
    applied here because it is called as ``transform(image=array)["image"]``
    rather than with a PIL image.
    """

    def __init__(self, root, transform=None):
        """Index the images under `root`.

        Args:
            root: Directory passed to ImageFolder (one sub-folder per class).
            transform: Optional callable invoked as ``transform(image=ndarray)``
                that returns a mapping with an ``"image"`` entry.
        """
        # The parent gets transform=None so that it never calls the transform
        # itself; the transform is stored and applied in __getitem__ instead.
        super().__init__(root, transform=None)
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at `index`.

        The image is loaded as RGB, converted to a NumPy array and, when a
        transform is set, replaced by the transform's ``"image"`` output.
        """
        path, label = self.samples[index]
        # The context manager releases the file handle even when decoding fails,
        # instead of leaving it to garbage collection.
        with Image.open(path) as img:
            image = np.array(img.convert("RGB"))
        if self.transform:
            image = self.transform(image=image)["image"]
        return image, label


In [7]:
# Seed for the validation/test split. The training cell resumes from a checkpoint
# and picks the best model on validation loss, so the split has to be identical on
# every run; otherwise images could move between validation and test across sessions.
SPLIT_SEED = 42

train_dataset = CustomDataset(root=os.path.join(dataset_base_path, 'train'), transform=train_transform)
# Full 'test' folder; kept under its own name so this cell can be re-run safely.
full_test_dataset = CustomDataset(root=os.path.join(dataset_base_path, 'test'), transform=test_transform)

# Positions (into full_test_dataset.samples) of every sample of each class.
# Collecting positions directly avoids a linear samples.index(...) search per sample.
fake_label = full_test_dataset.class_to_idx['FAKE']
real_label = full_test_dataset.class_to_idx['REAL']
fake_indices = [i for i, (_, label) in enumerate(full_test_dataset.samples) if label == fake_label]
real_indices = [i for i, (_, label) in enumerate(full_test_dataset.samples) if label == real_label]
fake_samples = [full_test_dataset.samples[i] for i in fake_indices]
real_samples = [full_test_dataset.samples[i] for i in real_indices]

# Split each class in half: one half for validation, the rest for test.
# With 10000 images per class this yields 5000 + 5000, as before.
fake_validation_size = len(fake_indices) // 2
fake_test_size = len(fake_indices) - fake_validation_size
real_validation_size = len(real_indices) // 2
real_test_size = len(real_indices) - real_validation_size

# A seeded generator makes random_split deterministic.
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
fake_validation_dataset, fake_test_dataset = random_split(
    fake_indices, [fake_validation_size, fake_test_size], generator=split_generator
)
real_validation_dataset, real_test_dataset = random_split(
    real_indices, [real_validation_size, real_test_size], generator=split_generator
)

# Map the per-class split positions back to positions in the full test dataset.
validation_indices = (
    [fake_indices[i] for i in fake_validation_dataset.indices]
    + [real_indices[i] for i in real_validation_dataset.indices]
)
test_indices = (
    [fake_indices[i] for i in fake_test_dataset.indices]
    + [real_indices[i] for i in real_test_dataset.indices]
)
validation_samples = [full_test_dataset.samples[i] for i in validation_indices]
test_samples = [full_test_dataset.samples[i] for i in test_indices]

# The two splits must not share any image.
assert not set(validation_indices) & set(test_indices), "validation and test splits overlap"

validation_dataset = Subset(full_test_dataset, validation_indices)
test_dataset = Subset(full_test_dataset, test_indices)

# Quick-experiment option (disabled): keep only the first 2000 samples of each dataset.
# Taking the first N indices does not keep REAL and FAKE balanced, so this would need
# a class-aware selection before being enabled.

print(len(train_dataset))
print(len(test_dataset))
print(len(validation_dataset))

100000
10000
10000


In [8]:
# Count the samples of each class in the test and validation splits.


def _root_dataset(dataset):
    """Follow Subset wrappers down to the underlying dataset."""
    while isinstance(dataset, Subset):
        dataset = dataset.dataset
    return dataset


def _subset_labels(dataset):
    """Return the integer class label of every sample without loading any image.

    Labels are read from the underlying dataset's `targets` list and selected
    through each Subset's `indices`, so no file is opened or transformed.
    """
    if isinstance(dataset, Subset):
        parent_labels = _subset_labels(dataset.dataset)
        return [parent_labels[i] for i in dataset.indices]
    return list(dataset.targets)


test_labels = _subset_labels(test_dataset)
test_class_counts = Counter(test_labels)
validation_labels = _subset_labels(validation_dataset)
validation_class_counts = Counter(validation_labels)

# Use the dataset's own index -> name mapping instead of a hand-written list, so the
# printed names cannot disagree with the integer labels.
class_names = _root_dataset(test_dataset).classes
test_class_counts_by_name = {class_names[label]: count for label, count in test_class_counts.items()}
print(f"test 데이터셋 클래스별 개수: {test_class_counts_by_name}")
validation_class_counts_by_name = {class_names[label]: count for label, count in validation_class_counts.items()}
print(f"Validation 데이터셋 클래스별 개수: {validation_class_counts_by_name}")

test 데이터셋 클래스별 개수: {'REAL': 5000, 'FAKE': 5000}
Validation 데이터셋 클래스별 개수: {'REAL': 5000, 'FAKE': 5000}


In [9]:
# Lift Pillow's image size limit by setting MAX_IMAGE_PIXELS to None.
# NOTE(review): this turns off Pillow's guard against oversized images; confirm it is
# actually needed for this dataset.
Image.MAX_IMAGE_PIXELS = None

In [None]:
class SwinClassifier(nn.Module):
    """Image classifier wrapping timm's ``swin_tiny_patch4_window7_224`` model.

    The wrapped network lives in ``self.model``; the attribute name is part of
    the saved state-dict keys, so it must not be renamed.
    """

    def __init__(self, num_classes=2):
        """Create the pretrained backbone with `num_classes` outputs."""
        super().__init__()
        # timm is asked for `num_classes` outputs directly, so no manual
        # replacement of the classification head is needed.
        self.model = timm.create_model(
            'swin_tiny_patch4_window7_224',
            pretrained=True,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return the wrapped model's output for the batch `x`."""
        logits = self.model(x)
        return logits

# Pick the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Build the classifier and place it on the selected device.
model = SwinClassifier(num_classes=2)
model = model.to(device)

cuda


In [11]:
# Loss function and optimiser for the classifier.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)

In [None]:
ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allow loading truncated images

# Names expected from earlier cells:
# - train_dataset / validation_dataset: training and validation datasets
# - SwinClassifier: the model definition
# - device: torch.device the model runs on

# DataLoader parameters; num_workers=0 loads every batch in the main process.
BATCH_SIZE = 32
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Training hyperparameters
EPOCHS = 10
PATIENCE = 3  # consecutive epochs without a lower validation loss before stopping

# Model, loss, optimizer and learning-rate schedule
model = SwinClassifier(num_classes=2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

# Checkpoint file path
checkpoint_file = "checkpoint.pth"


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=""):
    """Run one pass over `loader` and return ``(mean_loss, accuracy)``.

    When `optimizer` is given the model is put in train mode and updated after
    every batch; otherwise it is put in eval mode and gradients are disabled.
    Both results are averaged over the number of samples seen and are 0.0 for
    an empty loader.
    """
    training = optimizer is not None
    model.train(training)
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    pbar = tqdm(loader, desc=desc, leave=False)
    with torch.set_grad_enabled(training):
        for images, labels in pbar:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels.long())
            if training:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by its size so that a smaller last batch
            # does not skew the epoch average.
            batch_size = images.size(0)
            loss_sum += loss.item() * batch_size
            _, predicted = torch.max(outputs, 1)
            n_correct += (predicted == labels).sum().item()
            n_seen += batch_size

            pbar.set_postfix({
                "Batch Loss": f"{loss.item():.4f}",
                "Avg Loss": f"{loss_sum/n_seen:.4f}",
                "Acc": f"{n_correct/n_seen:.4f}"
            })

    if n_seen == 0:
        return 0.0, 0.0
    return loss_sum / n_seen, n_correct / n_seen


# Initialize or resume training variables
start_epoch = 0
if os.path.exists(checkpoint_file):
    print("Checkpoint found. Resuming training from checkpoint...")
    # Load the checkpoint onto the CPU; load_state_dict copies the values into the
    # model that already sits on `device`.
    checkpoint = torch.load(checkpoint_file, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_val_loss = checkpoint['best_val_loss']
    epochs_no_improve = checkpoint['epochs_no_improve']
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    model.to(device)
else:
    best_val_loss = float("inf")
    epochs_no_improve = 0
    train_losses, val_losses = [], []

# A checkpoint written after early stopping already satisfies the stop criterion;
# do not train further unless PATIENCE has been raised.
if epochs_no_improve >= PATIENCE:
    print("Early stopping triggered. Training stopped.")
    start_epoch = EPOCHS

for epoch in range(start_epoch, EPOCHS):
    epoch_train_loss, train_acc = _run_epoch(
        model, train_loader, criterion, device,
        optimizer=optimizer, desc=f"Epoch {epoch+1}/{EPOCHS} Training",
    )
    train_losses.append(epoch_train_loss)

    epoch_val_loss, val_acc = _run_epoch(
        model, val_loader, criterion, device,
        desc=f"Epoch {epoch+1}/{EPOCHS} Validation",
    )
    val_losses.append(epoch_val_loss)

    print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {epoch_train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Early-stopping bookkeeping
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        epochs_no_improve = 0
        torch.save(model.state_dict(), "best_swin_model.pth")  # Save best model separately
    else:
        epochs_no_improve += 1

    scheduler.step()

    # Save the checkpoint only after the bookkeeping and the scheduler step, so that
    # a resumed run restores the state the next epoch should start from (saving
    # earlier would store a stale best_val_loss / epochs_no_improve and a scheduler
    # that is one step behind).
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_val_loss': best_val_loss,
        'epochs_no_improve': epochs_no_improve,
        'train_losses': train_losses,
        'val_losses': val_losses,
    }
    torch.save(checkpoint, checkpoint_file)
    # Also keep a separate weights file for each epoch
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")

    if epochs_no_improve >= PATIENCE:
        print("Early stopping triggered. Training stopped.")
        break

print("Training complete!")

Checkpoint found. Resuming training from checkpoint...


                                                                                                                         

Epoch 11/20 - Train Loss: 0.6496, Train Acc: 0.6247 | Val Loss: 0.6320, Val Acc: 0.6384


                                                                                                                         

Epoch 12/20 - Train Loss: 0.6447, Train Acc: 0.6306 | Val Loss: 0.6316, Val Acc: 0.6383


                                                                                                                         

Epoch 13/20 - Train Loss: 0.6530, Train Acc: 0.6233 | Val Loss: 0.6364, Val Acc: 0.6378


                                                                                                                         

Epoch 14/20 - Train Loss: 0.6549, Train Acc: 0.6190 | Val Loss: 0.6364, Val Acc: 0.6366


                                                                                                                         

Epoch 15/20 - Train Loss: 0.6613, Train Acc: 0.6137 | Val Loss: 0.6345, Val Acc: 0.6401
Early stopping triggered. Training stopped.
Training complete!


In [13]:
# Save the weights `model` holds right now, i.e. its state when the training loop
# ended. These can differ from the lowest-validation-loss weights, which the training
# cell writes to "best_swin_model.pth".
torch.save(model.state_dict(), "swin_transformer_real_fake.pth")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm  # same progress-bar flavour as the rest of the notebook
from torchvision import transforms
from PIL import ImageFile
from sklearn.metrics import f1_score, confusion_matrix
import numpy as np

# Names expected from earlier cells: test_dataset, SwinClassifier, device.
BATCH_SIZE = 32
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Evaluate the weights with the lowest validation loss
# (alternative: "swin_transformer_real_fake.pth").
checkpoint_path = "best_swin_model.pth"

model = SwinClassifier(num_classes=2)
# map_location makes the load work on whichever device is in use, including a
# CPU-only machine reading weights that were saved from a GPU.
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)
model.to(device)
model.eval()

# Loss function used to report the test loss
criterion = nn.CrossEntropyLoss()

# Run testing/inference
correct = 0
total = 0
test_loss = 0.0
all_preds = []
all_labels = []

test_pbar = tqdm(test_loader, desc="Testing", leave=False)
with torch.no_grad():
    for images, labels in test_pbar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels.long())

        # Weight by batch size so the average is per sample.
        batch_size = images.size(0)
        test_loss += loss.item() * batch_size
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += batch_size

        all_preds.extend(predicted.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

        test_pbar.set_postfix({"Batch Loss": f"{loss.item():.4f}"})

avg_loss = test_loss / total if total > 0 else 0
accuracy = correct / total if total > 0 else 0

# Binary F1 scores the class with label 1 as the positive class (scikit-learn's
# default pos_label).
# NOTE(review): with the folder-derived label mapping, label 1 is presumably REAL —
# check the dataset's class_to_idx and pass pos_label explicitly if FAKE is the
# class of interest.
f1 = f1_score(all_labels, all_preds, average='binary')
# Rows are true labels, columns are predictions, both in label order 0, 1.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])

print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}")
print("Confusion Matrix:")
print(conf_matrix)

                                                                             

Test Loss: 0.6305, Test Accuracy: 0.6406, F1 Score: 0.6241
Confusion Matrix:
[[3422 1578]
 [2016 2984]]


