In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
from PIL import Image
import matplotlib.pyplot as plt

def target_transform_fixed(label):
    """Identity target transform: hand *label* back untouched.

    Used as a ``target_transform`` placeholder so the dataset's label hook is
    wired up without altering any label value.
    """
    unchanged = label
    return unchanged

# Image preprocessing pipeline: resize to 224x224, then convert the PIL image
# into a float tensor (ToTensor scales 8-bit pixel values into [0, 1]).
data_transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

def make_dataset(root, transform=None, target_transform=None,
                 fake_keywords=('easy',), real_keywords=('real',)):
    """Collect ``(path, label)`` pairs for every regular file directly in *root*.

    The label is inferred from the file name:

    * ``0`` (fake) if the name contains any of *fake_keywords*;
    * otherwise ``1`` (real) if it contains any of *real_keywords*;
    * otherwise ``-1`` (unknown / unlabelled).

    Fake keywords are checked first, so a name matching both lists is fake.

    Args:
        root: Directory to scan. The scan is not recursive, and sub-directories
            are skipped.
        transform: Accepted for signature compatibility but unused. Pass image
            transforms to ``CustomDataset`` instead.
        target_transform: Accepted for signature compatibility but unused.
        fake_keywords: Substring (or iterable of substrings) that marks a fake
            image. The default ``('easy',)`` reproduces the original rule.
            NOTE(review): with the default, fake files whose names lack
            'easy' are labelled -1; extend this if the data uses other
            prefixes.
        real_keywords: Substring (or iterable of substrings) that marks a
            real image.

    Returns:
        A list of ``(path, label)`` tuples, sorted by file name so that the
        order is reproducible across runs and platforms.

    Raises:
        FileNotFoundError / NotADirectoryError: if *root* is not a directory
            (raised by ``os.listdir``).
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(fake_keywords, str):
        fake_keywords = (fake_keywords,)
    if isinstance(real_keywords, str):
        real_keywords = (real_keywords,)

    samples = []
    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    for filename in sorted(os.listdir(root)):
        path = os.path.join(root, filename)
        # Skip sub-directories and other non-files; Image.open cannot read them.
        if not os.path.isfile(path):
            continue

        if any(key in filename for key in fake_keywords):
            label = 0  # fake
        elif any(key in filename for key in real_keywords):
            label = 1  # real
        else:
            label = -1  # unknown or unlabelled
        samples.append((path, label))

    return samples

class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over a list of ``(image_path, label)`` pairs.

    Each item is loaded with PIL and converted to RGB. It then goes through the
    optional *transform* and is returned as a tensor together with its label.
    The label goes through the optional *target_transform*.
    """

    def __init__(self, dataset_list, transform=None, target_transform=None):
        """Store the sample list and optional transforms.

        Args:
            dataset_list: Indexable sequence of ``(path, label)`` pairs, for
                example the output of ``make_dataset``.
            transform: Optional callable applied to the RGB PIL image. It may
                return a PIL image or ndarray, which is converted with
                ``ToTensor`` afterwards. It may also return a tensor, which is
                used as-is, so pipelines ending in ``ToTensor`` work too.
            target_transform: Optional callable applied to the label.
        """
        self.dataset_list = dataset_list
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.dataset_list)

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image_tensor, label)``."""
        img_path, label = self.dataset_list[idx]

        # Open in a context manager so the file handle is released promptly.
        # convert() returns a fully loaded copy that outlives the `with`.
        # Forcing RGB gives grayscale/RGBA/palette files the 3 channels a
        # ResNet backbone expects.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        # ToTensor raises TypeError when handed a tensor, so only convert
        # when the user transform has not already done so.
        if not isinstance(img, torch.Tensor):
            img = transforms.ToTensor()(img)

        if self.target_transform is not None:
            label = self.target_transform(label)

        return img, label


# Model: a ResNet-18 backbone (optionally pretrained) with a small
# binary-classification head.
class CNN(nn.Module):
    """ResNet-18 feature extractor followed by an MLP head with a sigmoid output.

    ``forward`` returns a ``(batch, 1)`` tensor of values in (0, 1). This is
    the form ``nn.BCELoss`` expects.
    """

    def __init__(self, pretrained=True):
        """Build the backbone and the classification head.

        Args:
            pretrained: If True (the default, as before), load torchvision's
                pretrained ResNet-18 weights. These are downloaded on first
                use. Pass False to build the architecture offline with random
                weights.
        """
        super().__init__()
        # NOTE: `pretrained=` is deprecated in torchvision >= 0.13 in favour of
        # `weights=`. It is kept here so older torchvision versions keep working.
        resnet18 = models.resnet18(pretrained=pretrained)
        # Keep every layer except the final fully connected classifier, so the
        # backbone ends at global average pooling.
        self.features = nn.Sequential(*list(resnet18.children())[:-1])
        # Feature width of the removed classifier (512 for ResNet-18).
        in_features = resnet18.fc.in_features
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map an image batch *x* to per-sample scores of shape ``(batch, 1)``."""
        x = self.features(x)
        # Flatten the pooled (batch, C, 1, 1) features to (batch, C).
        x = torch.flatten(x, 1)
        return self.fc(x)

# Data: one directory per class. Labels come from the directory
# (fake = 0.0, real = 1.0), matching the targets the model is trained on.
# No separate test directory is defined, so a seeded validation split is held
# out from each class and used for the per-epoch accuracy.
FAKE_DIR = r'./Datos/real_and_fake_face/training_fake'
REAL_DIR = r'./Datos/real_and_fake_face/training_real'
BATCH_SIZE = 8
VAL_FRACTION = 0.2  # share of each class held out for evaluation
SPLIT_SEED = 0      # fixed so the train/validation split is reproducible

# Resize to the same 224x224 size used by `data_transform`. It must stay a
# PIL -> PIL transform, because CustomDataset applies ToTensor itself.
image_transform = transforms.Resize((224, 224))


def _labelled_samples(root, label):
    """Return the files ``make_dataset`` finds in *root*, all labelled *label*."""
    return [(path, float(label)) for path, _ in make_dataset(root)]


def _train_val_split(dataset, val_fraction, seed):
    """Randomly split *dataset* into ``(train, validation)`` subsets."""
    n_val = int(len(dataset) * val_fraction)
    generator = torch.Generator().manual_seed(seed)
    return torch.utils.data.random_split(
        dataset, [len(dataset) - n_val, n_val], generator=generator
    )


def _accuracy(net, loader, target, device):
    """Return the fraction of samples in *loader* whose prediction equals *target*.

    A prediction is the model output thresholded at 0.5. Returns 0 for an
    empty loader.
    """
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, _ in loader:
            outputs = net(inputs.to(device))
            predictions = (outputs > 0.5).float()
            correct += (predictions == target).sum().item()
            total += inputs.size(0)
    return correct / total if total > 0 else 0


train_dataset_fake = CustomDataset(
    _labelled_samples(FAKE_DIR, 0),
    transform=image_transform,
    target_transform=target_transform_fixed,
)
train_dataset_real = CustomDataset(
    _labelled_samples(REAL_DIR, 1),
    transform=image_transform,
    target_transform=target_transform_fixed,
)

train_split_fake, val_split_fake = _train_val_split(train_dataset_fake, VAL_FRACTION, SPLIT_SEED)
train_split_real, val_split_real = _train_val_split(train_dataset_real, VAL_FRACTION, SPLIT_SEED)

# Per-class loaders. The training ones keep their original names for use in
# other cells; the test ones are what the evaluation below reads.
train_loader_fake = DataLoader(train_split_fake, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
train_loader_real = DataLoader(train_split_real, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader_fake = DataLoader(val_split_fake, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader_real = DataLoader(val_split_real, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Training draws shuffled batches from both classes together. Training on all
# fake batches and then all real batches biases the weights at the end of each
# epoch toward whichever class was seen last.
train_loader = DataLoader(
    torch.utils.data.ConcatDataset([train_split_fake, train_split_real]),
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)

# Device selection.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model and move it to the device.
model = CNN().to(device)

# Loss and optimiser.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Number of epochs.
num_epochs = 5

# Training loop.
for epoch in range(num_epochs):
    model.train()
    running_loss, seen = 0.0, 0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        # Collated Python-float labels form a 1-D float64 tensor. BCELoss needs
        # float32 shaped like the (batch, 1) model output.
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        seen += inputs.size(0)

    # Sample-weighted mean loss over the epoch, rather than the last batch only.
    train_loss = running_loss / seen if seen > 0 else float('nan')

    # Evaluate on the held-out validation split of each class.
    model.eval()
    accuracy_fake = _accuracy(model, test_loader_fake, 0.0, device)
    accuracy_real = _accuracy(model, test_loader_real, 1.0, device)

    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}')
    print(f'Accuracy Fake: {accuracy_fake}, Accuracy Real: {accuracy_real}')

    # Put the model back in training mode, as the original cell did.
    model.train()
