# In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torchvision.utils as vutils
import os


# Flags and paths
isGcloud = False  # Set to True to keep models/results on Google Drive (Colab)

# Path definitions
if isGcloud:
    from google.colab import drive
    # Mount Google Drive
    drive.mount('/content/drive')
    BASE_PATH = '/content/drive/MyDrive/ML2'  # Base path on Drive
    # DATASET_PATH must exist on both branches because PATHS reads it below.
    # NOTE(review): the dataset location on Drive is an assumption - adjust
    # it if the LFW data lives elsewhere.
    DATASET_PATH = os.path.join(BASE_PATH, 'datasets')
else:
    # os.makedirs / torch.save do not expand '~' themselves; without
    # expanduser a literal './~/...' tree is created in the working directory.
    BASE_PATH = os.path.expanduser('~/eric/BEGAN')  # Local base path
    DATASET_PATH = os.path.expanduser('~/datasets')

# Directories used by the rest of the script
PATHS = {
    'data': os.path.join(DATASET_PATH, 'lfw'),
    'models': os.path.join(BASE_PATH, 'models'),
    'results': os.path.join(BASE_PATH, 'results')
}

# Create any missing directory (no error if it already exists)
for path in PATHS.values():
    os.makedirs(path, exist_ok=True)

# Hyperparameters (effective values; every name is assigned exactly once).
# The names mirror the parameters of train_began.
latent_size = 128   # Size of the noise vector fed to the generator
hidden_size = 256   # Base layer width (raised 64 -> 128 -> 256)
image_size = 64
lambda_k = 0.0015   # Step size of the k update (raised from 0.001)
gamma = 0.8         # Balance target used in the k update (raised 0.5 -> 0.7 -> 0.8)
batch_size = 64     # Raised from 32 for more stable training
num_epochs = 100

class Reshape(nn.Module):
    """Layer that views its input as ``(-1, *shape)``.

    Lets a reshape be placed inside an ``nn.Sequential`` pipeline; the
    leading dimension is inferred from the number of elements.
    """

    def __init__(self, shape):
        super().__init__()
        # Target shape of everything after the inferred leading dimension.
        self.shape = shape

    def forward(self, x):
        target = (-1,) + tuple(self.shape)
        return x.view(target)

class Generator(nn.Module):
    """Decoder network mapping a latent vector to a 3-channel image.

    A linear layer projects the latent vector to an 8x8 feature map with
    ``hidden_size`` channels; three stride-2 transposed convolutions then
    double the spatial size each time (8 -> 16 -> 32 -> 64) and a final
    convolution plus ``Tanh`` produces 3 channels in ``[-1, 1]``.

    Args:
        latent_size: Number of features of the input vector ``z``.
        hidden_size: Channel width kept throughout the network.

    NOTE: a deeper variant was tried previously (LeakyReLU(0.2) activations,
    an extra 3x3 Conv2d + BatchNorm after each of the first two upsampling
    stages, and channel widths halving to hidden_size // 2 and
    hidden_size // 4). It is not in use; recover it from version control if
    needed rather than keeping it as commented-out code.
    """

    def __init__(self, latent_size, hidden_size):
        super().__init__()
        self.decoder = nn.Sequential(
            # Latent vector -> flat 8x8xhidden_size activation
            nn.Linear(latent_size, 8 * 8 * hidden_size),
            nn.BatchNorm1d(8 * 8 * hidden_size),
            nn.ReLU(True),
            Reshape((hidden_size, 8, 8)),
            # Upsampling stage 1: 8x8 -> 16x16
            nn.ConvTranspose2d(hidden_size, hidden_size, 3, 2, 1, 1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            # Upsampling stage 2: 16x16 -> 32x32
            nn.ConvTranspose2d(hidden_size, hidden_size, 3, 2, 1, 1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            # Upsampling stage 3: 32x32 -> 64x64
            nn.ConvTranspose2d(hidden_size, hidden_size, 3, 2, 1, 1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(True),
            # Project to RGB; Tanh keeps the output in [-1, 1]
            nn.Conv2d(hidden_size, 3, 3, 1, 1),
            nn.Tanh()
        )

    def forward(self, z):
        """Return the image batch generated from latent batch ``z``."""
        return self.decoder(z)

class Discriminator(nn.Module):
    """Convolutional autoencoder: the output is a reconstruction of the input.

    The encoder applies four stride-2 3x3 convolutions whose widths are
    ``hidden_size`` times 1, 2, 4 and 8; the decoder mirrors it with four
    stride-2 transposed convolutions and ends in ``Tanh`` with 3 channels.

    Args:
        hidden_size: Channel width of the first encoder layer.
    """

    def __init__(self, hidden_size):
        super().__init__()
        widths = [hidden_size * factor for factor in (1, 2, 4, 8)]

        # Encoder: the first convolution has no batch norm, the rest do.
        down_layers = [nn.Conv2d(3, widths[0], 3, 2, 1), nn.ReLU(True)]
        for c_in, c_out in zip(widths, widths[1:]):
            down_layers += [
                nn.Conv2d(c_in, c_out, 3, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        self.encoder = nn.Sequential(*down_layers)

        # Decoder: walk the widths backwards, then project to 3 channels.
        reversed_widths = widths[::-1]
        up_layers = []
        for c_in, c_out in zip(reversed_widths, reversed_widths[1:]):
            up_layers += [
                nn.ConvTranspose2d(c_in, c_out, 3, 2, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        up_layers += [nn.ConvTranspose2d(widths[0], 3, 3, 2, 1, 1), nn.Tanh()]
        self.decoder = nn.Sequential(*up_layers)

    def forward(self, x):
        """Encode ``x`` and return its decoded reconstruction."""
        return self.decoder(self.encoder(x))

def save_checkpoint(state, filename):
    """Serialize ``state`` to ``filename`` and print where it was stored.

    Args:
        state: Object handed to ``torch.save``.
        filename: Destination path of the checkpoint file.
    """
    torch.save(state, filename)
    # The message only differs by storage backend, chosen via the isGcloud flag.
    message = (
        f"Checkpoint guardado en Google Drive: {filename}"
        if isGcloud
        else f"Checkpoint guardado localmente: {filename}"
    )
    print(message)

def load_checkpoint(filename, generator, discriminator, g_optimizer, d_optimizer):
    """Restore training state from ``filename`` if that file exists.

    Args:
        filename: Path of the checkpoint file.
        generator: Module receiving ``checkpoint['generator_state_dict']``.
        discriminator: Module receiving ``checkpoint['discriminator_state_dict']``.
        g_optimizer: Optimizer receiving ``checkpoint['g_optimizer_state_dict']``.
        d_optimizer: Optimizer receiving ``checkpoint['d_optimizer_state_dict']``.

    Returns:
        A tuple ``(epoch, k, M_global)`` read from the checkpoint, or
        ``(0, 0.5, None)`` when no checkpoint file exists (the modules and
        optimizers are left untouched in that case).
    """
    if not os.path.exists(filename):
        return 0, 0.5, None  # Start k at 0.5 instead of 0.0

    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # still be opened where CUDA is unavailable; load_state_dict then copies
    # the values onto whatever device each module/optimizer already uses.
    checkpoint = torch.load(filename, map_location='cpu')
    generator.load_state_dict(checkpoint['generator_state_dict'])
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
    g_optimizer.load_state_dict(checkpoint['g_optimizer_state_dict'])
    d_optimizer.load_state_dict(checkpoint['d_optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['k'], checkpoint['M_global']

def _save_comparison_figure(real_images, reconstructions, fake_images, path, max_columns=4):
    """Save a 3-row figure (real / reconstructed / generated samples) to ``path``."""
    # Never index past the end of a short batch.
    columns = min(max_columns, real_images.size(0))
    fig, ax = plt.subplots(3, columns, figsize=(12, 8), squeeze=False)
    rows = (
        ('Real', real_images),
        ('Reconstructed', reconstructions),
        ('Generated', fake_images),
    )
    for row, (title, batch) in enumerate(rows):
        for col in range(columns):
            # Undo the (0.5, 0.5) normalisation: [-1, 1] -> [0, 1], channels last.
            image = (batch[col].cpu().detach().permute(1, 2, 0) + 1) / 2
            ax[row, col].imshow(torch.clamp(image, 0, 1))
            ax[row, col].axis('off')
        ax[row, 0].set_title(title)
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def train_began(latent_size, hidden_size, image_size, lambda_k, gamma, batch_size, num_epochs):
    """Train the Generator/Discriminator pair on LFW with a k-balanced L1 loss.

    The discriminator is an autoencoder scored by its mean absolute
    reconstruction error. Its loss is ``L(real) - k * L(fake)``, the
    generator minimises ``L(fake)``, and ``k`` is nudged after every step by
    ``lambda_k * (gamma * L(real) - L(fake))`` and clamped to ``[0, 1]``.

    Side effects: downloads LFW into ``PATHS['data']`` if needed, resumes
    from / writes ``began_checkpoint.pth`` in ``PATHS['models']`` after every
    epoch, and every 100 steps prints the losses and writes sample images to
    ``PATHS['results']``.

    Args:
        latent_size: Size of the noise vector fed to the generator.
        hidden_size: Width passed to ``Generator`` and ``Discriminator``.
        image_size: Images are resized and center-cropped to this size.
        lambda_k: Step size of the ``k`` update.
        gamma: Target ratio used in the ``k`` update.
        batch_size: Batch size of the data loader.
        num_epochs: Epoch index at which training stops.

    Returns:
        None.
    """
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    dataset = torchvision.datasets.LFWPeople(
        root=PATHS['data'],
        split='train',
        transform=transform,
        download=True
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = Generator(latent_size, hidden_size).to(device)
    discriminator = Discriminator(hidden_size).to(device)

    g_optimizer = optim.Adam(generator.parameters(), lr=0.0001, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0001, betas=(0.5, 0.999))

    # Without a checkpoint this yields epoch 0, k = 0.5 and M_global = None;
    # with one, all three values are resumed as stored.
    checkpoint_path = os.path.join(PATHS['models'], 'began_checkpoint.pth')
    start_epoch, k, M_global = load_checkpoint(
        checkpoint_path, generator, discriminator, g_optimizer, d_optimizer
    )

    for epoch in range(start_epoch, num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            n_samples = real_images.size(0)
            if n_samples < 2:
                # The generator's BatchNorm1d cannot compute batch statistics
                # from a single sample in training mode; skip such a batch.
                continue
            real_images = real_images.to(device)

            # Train discriminator
            d_optimizer.zero_grad()

            d_real = discriminator(real_images)
            d_real_loss = torch.mean(torch.abs(real_images - d_real))

            z = torch.randn(n_samples, latent_size, device=device)
            fake_images = generator(z)
            # detach(): the discriminator step must not backpropagate into the generator.
            d_fake = discriminator(fake_images.detach())
            d_fake_loss = torch.mean(torch.abs(fake_images.detach() - d_fake))

            d_loss = d_real_loss - k * d_fake_loss
            d_loss.backward()
            d_optimizer.step()

            # Train generator (fresh forward pass through the updated discriminator)
            g_optimizer.zero_grad()
            g_reconstruction = discriminator(fake_images)
            g_loss = torch.mean(torch.abs(fake_images - g_reconstruction))
            g_loss.backward()
            g_optimizer.step()

            # Update k and keep it a float inside [0, 1]
            balance = (gamma * d_real_loss - g_loss).item()
            k = min(max(k + lambda_k * balance, 0.0), 1.0)

            # Update convergence measure
            M_global = d_real_loss.item() + abs(balance)

            if i % 100 == 0:  # Logging / visualisation frequency
                print(f'Epoch [{epoch}/{num_epochs}] Step [{i}/{len(dataloader)}] '
                      f'd_loss: {d_loss.item():.4f} g_loss: {g_loss.item():.4f} '
                      f'M: {M_global:.4f} k: {k:.4f}')

                results_path = os.path.join(PATHS['results'], f'fake_images_epoch_{epoch}_step_{i}.png')
                vutils.save_image(fake_images.detach()[:16], results_path, normalize=True, nrow=4)

                comparison_path = os.path.join(PATHS['results'], f'comparison_epoch_{epoch}_step_{i}.png')
                _save_comparison_figure(real_images, d_real, fake_images, comparison_path)

        # Persist the state after every epoch so training can be resumed.
        checkpoint = {
            'epoch': epoch + 1,
            'generator_state_dict': generator.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'g_optimizer_state_dict': g_optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
            'k': k,
            'M_global': M_global
        }
        save_checkpoint(checkpoint, checkpoint_path)
