In [2]:
import os
from PIL import Image
import torch
import torchvision
from torchvision import datasets, transforms, models
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import KFold


## Procesamiento de datos

In [3]:
# Directorio de los datos
directory = "./Datos/real_and_fake_face"

# Preprocesar imágenes
for filename in os.listdir(directory):
    if filename.endswith((".jpg", ".png")):
        img_path = os.path.join(directory, filename)
        with Image.open(img_path) as img:
            img = img.convert('RGB')
            img = img.resize((256, 256))
            img.save(img_path)
            print("Converted: ", filename)

## Carga y transformaciones

In [4]:
# Aumento de datos y transformaciones
train_transform = transforms.Compose([
    transforms.RandomRotation(15),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Cargar datos
train_dataset = datasets.ImageFolder(directory, transform=train_transform)
val_dataset = datasets.ImageFolder(directory, transform=val_transform)

# Cargador de datos
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=True)

## Definición del modelo y entrenamiento

In [7]:
# Definir el modelo
resnet18_model = models.resnet18(pretrained=True)
resnet18_model.fc = nn.Sequential(
    nn.Linear(512, 256),
    nn.ELU(),
    nn.Dropout(0.5),
    nn.Linear(256, 1)
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
resnet18_model = resnet18_model.to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(resnet18_model.parameters(), lr=0.0001)

# Número de divisiones para la validación cruzada
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)

# Número de épocas
epochs = 2

for epoch in range(epochs):
    print(f"Epoch [{epoch + 1}/{epochs}]")

    # Validación cruzada
    for fold, (train_idx, val_idx) in enumerate(kf.split(train_dataset)):
        train_fold = torch.utils.data.Subset(train_dataset, train_idx)
        val_fold = torch.utils.data.Subset(train_dataset, val_idx)

        train_loader = torch.utils.data.DataLoader(train_fold, batch_size=128, shuffle=True, num_workers=4)
        val_loader = torch.utils.data.DataLoader(val_fold, batch_size=128, shuffle=True, num_workers=4)

        resnet18_model.train()
        running_loss = 0.0
        correct_predictions = 0
        total_samples = 0

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = resnet18_model(inputs)
            loss = criterion(outputs.squeeze(), labels.float())
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            predictions = (outputs > 0.5).float()
            correct_predictions += (predictions == labels.unsqueeze(1)).sum().item()
            total_samples += labels.size(0)

        average_loss = running_loss / len(train_loader)
        accuracy = correct_predictions / total_samples
        print(f'Fold {fold + 1}/{n_splits}, Train Loss: {average_loss:.4f}, Train Accuracy: {accuracy:.2%}')

        resnet18_model.eval()
        correct_predictions_val = 0
        total_samples_val = 0

        with torch.no_grad():
            for inputs_val, labels_val in val_loader:
                inputs_val, labels_val = inputs_val.to(device), labels_val.to(device)

                outputs_val = resnet18_model(inputs_val)
                predictions_val = (outputs_val > 0.5).float()

                correct_predictions_val += (predictions_val == labels_val.unsqueeze(1)).sum().item()
                total_samples_val += labels_val.size(0)

        accuracy_val = correct_predictions_val / total_samples_val
        print(f'Fold {fold + 1}/{n_splits}, Validation Accuracy: {accuracy_val:.2%}')

Epoch [1/2]
Fold 1/5, Train Loss: 0.6872, Train Accuracy: 50.61%
Fold 1/5, Validation Accuracy: 51.34%
Fold 2/5, Train Loss: 0.6626, Train Accuracy: 56.64%
Fold 2/5, Validation Accuracy: 61.52%
Fold 3/5, Train Loss: 0.6255, Train Accuracy: 61.48%
Fold 3/5, Validation Accuracy: 58.58%
Fold 4/5, Train Loss: 0.6179, Train Accuracy: 62.65%
Fold 4/5, Validation Accuracy: 66.42%
Fold 5/5, Train Loss: 0.5890, Train Accuracy: 65.71%
Fold 5/5, Validation Accuracy: 61.27%
Epoch [2/2]
Fold 1/5, Train Loss: 0.5671, Train Accuracy: 68.32%
Fold 1/5, Validation Accuracy: 69.19%
Fold 2/5, Train Loss: 0.5543, Train Accuracy: 69.44%
Fold 2/5, Validation Accuracy: 63.97%
Fold 3/5, Train Loss: 0.5410, Train Accuracy: 70.36%
Fold 3/5, Validation Accuracy: 72.30%
Fold 4/5, Train Loss: 0.5063, Train Accuracy: 72.63%
Fold 4/5, Validation Accuracy: 75.49%
Fold 5/5, Train Loss: 0.4807, Train Accuracy: 76.42%
Fold 5/5, Validation Accuracy: 77.45%


## Evaluación del rendimiento

In [8]:
# Función para convertir las predicciones y etiquetas a "real" o "fake"
def get_label(value, threshold=0.5):
    return 'real' if value > threshold else 'fake'

# Validación del modelo después de todas las épocas
resnet18_model.eval()
correct_predictions_val = 0
total_samples_val = 0
predicted_labels_val = []

with torch.no_grad():
    for inputs_val, labels_val in val_loader:
        inputs_val, labels_val = inputs_val.to(device), labels_val.to(device)

        outputs_val = resnet18_model(inputs_val)
        predictions_val = (outputs_val > 0.5).float()

        correct_predictions_val += (predictions_val == labels_val.unsqueeze(1)).sum().item()
        total_samples_val += labels_val.size(0)

        predicted_labels_val.extend(predictions_val.cpu().numpy().tolist())  # Guardar las predicciones

# Calcular métricas al final de todas las épocas en el conjunto de validación
accuracy_val = correct_predictions_val / total_samples_val
print(f'Validation - After all epochs, Accuracy: {accuracy_val:.2%}')

# Obtener un lote de imágenes y etiquetas del conjunto de validación
validation_images, validation_labels = next(iter(val_loader))

# Mover el modelo a la GPU si está disponible
resnet18_model = resnet18_model.to(device)

# Obtener predicciones del modelo
with torch.no_grad():
    resnet18_model.eval()
    outputs = resnet18_model(validation_images.to(device))
    predictions = torch.sigmoid(outputs)

# Umbral para clasificación binaria
threshold = 0.5

# Contadores para aciertos e incorrectos
correct_count = 0
incorrect_count = 0

# Validación sin mostrar imágenes y nombres
sample_limit = 100
for i in range(min(sample_limit, len(validation_images))):
    predicted_label = get_label(predictions[i].item(), threshold)
    actual_label = get_label(validation_labels[i].item(), threshold)

    # Imprimir si la predicción fue correcta e incorrecta
    if predicted_label == actual_label:
        correct_count += 1
    else:
        incorrect_count += 1

    # Imprimir información sobre la muestra actual
    print(f"Sample {i + 1}: Predicted: {predicted_label}, Actual: {actual_label}")

# Imprimir el número total de muestras y resultados
print(f"\nTotal de muestras evaluadas: {min(sample_limit, len(validation_images))}")
print(f"Total de predicciones correctas: {correct_count}")
print(f"Total de predicciones incorrectas: {incorrect_count}")

Validation - After all epochs, Accuracy: 76.96%
Sample 1: Predicted: fake, Actual: real
Sample 2: Predicted: fake, Actual: fake
Sample 3: Predicted: real, Actual: fake
Sample 4: Predicted: fake, Actual: real
Sample 5: Predicted: real, Actual: fake
Sample 6: Predicted: fake, Actual: real
Sample 7: Predicted: fake, Actual: fake
Sample 8: Predicted: fake, Actual: fake
Sample 9: Predicted: fake, Actual: fake
Sample 10: Predicted: real, Actual: fake
Sample 11: Predicted: real, Actual: fake
Sample 12: Predicted: fake, Actual: real
Sample 13: Predicted: fake, Actual: fake
Sample 14: Predicted: fake, Actual: fake
Sample 15: Predicted: real, Actual: real
Sample 16: Predicted: real, Actual: real
Sample 17: Predicted: fake, Actual: fake
Sample 18: Predicted: real, Actual: real
Sample 19: Predicted: real, Actual: real
Sample 20: Predicted: fake, Actual: real
Sample 21: Predicted: real, Actual: real
Sample 22: Predicted: fake, Actual: real
Sample 23: Predicted: real, Actual: real
Sample 24: Predict