In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.datasets import make_moons
from tqdm import tqdm, trange
import torch.distributions as dist

# Model

In [2]:
class CVAE(nn.Module):
    """Conditional Variational Autoencoder (CVAE).

    Three MLPs are held by the module:

    * ``inference_network`` -- approximate posterior q(z | x, y)
    * ``prior_network``     -- conditional prior p(z | x)
    * ``generator_network`` -- mean of the decoder p(y | x, z)

    The posterior and the prior are diagonal Gaussians. Their networks emit
    ``2 * latent_dim`` values per sample: the first ``latent_dim`` are the
    mean, the remaining ``latent_dim`` are the log-variance.

    Args:
        x_dim (int): Dimensionality of the condition x.
        y_dim (int): Dimensionality of the input/output data y.
        hidden_dim (int): Dimensionality of the hidden layers.
        latent_dim (int): Dimensionality of the latent space.
    """
    def __init__(self, x_dim, y_dim, hidden_dim, latent_dim):
        super(CVAE, self).__init__()

        self.x_dim = x_dim
        self.y_dim = y_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim

        # 1.1 Inference model (variational encoder) q(z | x, y)
        self.inference_network = nn.Sequential(
            nn.Linear(x_dim + y_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, 2 * latent_dim)  # mu_q and log sigma_q^2
        )

        # 1.2 Conditional prior model p(z | x)
        self.prior_network = nn.Sequential(
            nn.Linear(x_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, 2 * latent_dim)  # mu_p and log sigma_p^2
        )

        # 1.3 Generative model (decoder) p(y | x, z)
        self.generator_network = nn.Sequential(
            nn.Linear(x_dim + latent_dim, hidden_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim // 2, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, y_dim)  # mean of y
        )

    def _diag_normal(self, params):
        """Turn a ``(batch, 2 * latent_dim)`` tensor into a diagonal Normal.

        The first ``latent_dim`` columns are read as the mean and the rest as
        the log-variance, so the standard deviation is ``exp(0.5 * log_var)``.
        """
        mu = params[:, :self.latent_dim]
        log_sigma_sq = params[:, self.latent_dim:]
        return dist.Normal(mu, torch.exp(0.5 * log_sigma_sq))

    def _posterior(self, x, y):
        """Approximate posterior q(z | x, y) from the inference network."""
        return self._diag_normal(self.inference_network(torch.cat([x, y], dim=1)))

    def _prior(self, x):
        """Conditional prior p(z | x) from the prior network."""
        return self._diag_normal(self.prior_network(x))

    def forward(self, x, y):
        """Encode ``(x, y)``, sample a latent code and decode it.

        Args:
            x: Condition tensor; concatenated with ``y`` along dim 1.
            y: Target tensor to reconstruct.

        Returns:
            tuple: ``(y_recon_mean, q_dist, z)`` -- the decoder output, the
            posterior ``Normal`` distribution, and the latent sample that was
            decoded.
        """
        # Encoder q(z|x,y)
        q_dist = self._posterior(x, y)

        # Reparameterised sample, so gradients flow back into the encoder.
        z = q_dist.rsample()

        # Decoder p(y|x,z)
        y_recon_mean = self.generator_network(torch.cat([x, z], dim=1))

        return y_recon_mean, q_dist, z

    def compute_loss(self, x, y, kl_weight=1.0):
        """Compute the (weighted) negative ELBO summed over the batch.

        Args:
            x: Condition tensor.
            y: Target tensor.
            kl_weight (float): Multiplier applied to the KL term.

        Returns:
            tuple: ``(loss, reconstruction_loss, kl_loss)`` as scalar tensors,
            where ``loss = reconstruction_loss + kl_weight * kl_loss``. Both
            terms are sums (not means) over the batch.
        """
        y_recon_mean, q_dist, _ = self.forward(x, y)

        # Reconstruction term: summed squared error, which matches a
        # fixed-variance Gaussian NLL up to scale and an additive constant.
        reconstruction_loss = nn.functional.mse_loss(y_recon_mean, y, reduction='sum')

        # KL(q(z|x,y) || p(z|x)), summed over latent dimensions and batch.
        kl_loss = dist.kl_divergence(q_dist, self._prior(x)).sum()

        # Final loss
        loss = reconstruction_loss + kl_weight * kl_loss

        return loss, reconstruction_loss, kl_loss

    def generate(self, x):
        """Sample z from the prior p(z | x) and decode it into y.

        Args:
            x: Condition tensor; one output row is produced per input row.

        Returns:
            The decoder output for the sampled latent codes. No observation
            noise is added on top of the decoder mean.
        """
        z = self._prior(x).sample()
        return self.generator_network(torch.cat([x, z], dim=1))

# Data

In [3]:
class Moons(Dataset):
    """Two-moons toy data exposed as a 1-D conditional regression problem.

    Points are drawn once, at construction time, with
    ``sklearn.datasets.make_moons``. Each item yields the first coordinate as
    the condition, the second coordinate as the target, and the moon label.

    Args:
        num_samples (int): Number of points to generate.
        noise (float): ``noise`` argument forwarded to ``make_moons``.
        random_state (int): Seed forwarded to ``make_moons``. Defaults to 123;
            pass different values to obtain independently drawn splits.
    """
    def __init__(self, num_samples: int, noise: float, random_state: int = 123):
        super().__init__()
        self.num_samples = num_samples
        self.random_state = random_state
        self._X, self._y = make_moons(n_samples=num_samples, noise=noise, random_state=random_state)
        # Keep float32 tensor copies for training; labels become a column vector.
        self.X = torch.from_numpy(self._X).float()
        self.y_class = torch.from_numpy(self._y.reshape(-1, 1)).float()

    def __getitem__(self, index):
        """Return ``(x, y, label)`` for one point, each as a 1-element tensor."""
        # x is the first coordinate, y is the second
        return self.X[index, 0].unsqueeze(0), self.X[index, 1].unsqueeze(0), self.y_class[index]

    def __len__(self):
        """Number of points requested at construction time."""
        return self.num_samples

# Training

In [4]:
# Seed PyTorch's global random number generator before anything stochastic runs.
torch.manual_seed(123)

# --- Optimisation settings ---
batch_size = 150
learning_rate = 1e-3
num_epochs = 300
kl_weight = 0.1  # beta of the beta-VAE objective (weight on the KL term)

# --- Model sizes ---
x_dim = 1
y_dim = 1
latent_dim = 2
hidden_dim = 256

# --- Datasets and loaders ---
# NOTE(review): both splits are generated with the same make_moons seed (123);
# consider distinct seeds if an independently drawn test set is wanted.
train_data = Moons(3000, noise=0.15)
test_data = Moons(600, noise=0.15)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

# --- Compute device: use CUDA when it is available, otherwise the CPU ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [5]:
def train(model, dataloader, optimizer, kl_weight):
    """Run one optimisation epoch and return per-sample average losses.

    Args:
        model: Module exposing ``compute_loss(x, y, kl_weight)`` that returns
            ``(loss, recon_loss, kl_loss)`` scalar tensors summed over the batch.
        dataloader: Iterable of ``(x, y, _)`` batches; must expose ``.dataset``
            so the sums can be normalised by the number of examples.
        optimizer: Optimizer stepping ``model``'s parameters.
        kl_weight (float): Weight of the KL term, forwarded to ``compute_loss``.

    Returns:
        tuple[float, float, float]: total, reconstruction and KL loss, each
        accumulated over all batches and divided by ``len(dataloader.dataset)``.
    """
    model.train()
    # Move batches to wherever the model lives instead of relying on a
    # notebook-level ``device`` global being defined by an earlier cell.
    model_device = next(model.parameters()).device

    sum_loss, sum_recon, sum_kl = 0.0, 0.0, 0.0
    for x, y, _ in dataloader:
        x, y = x.to(model_device), y.to(model_device)
        optimizer.zero_grad()
        loss, recon_loss, kl_loss = model.compute_loss(x, y, kl_weight)
        loss.backward()
        optimizer.step()
        sum_loss += loss.item()
        sum_recon += recon_loss.item()
        sum_kl += kl_loss.item()

    num_examples = len(dataloader.dataset)
    return sum_loss / num_examples, sum_recon / num_examples, sum_kl / num_examples

def test(model, dataloader, kl_weight):
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for x, y, _ in dataloader:
            x, y = x.to(device), y.to(device)
            loss, _, _ = model.compute_loss(x, y, kl_weight)
            test_loss += loss.item()
    return test_loss / len(dataloader.dataset)

In [6]:
print("Training Conditional Variational Autoencoder...")
# Build the model from the configured sizes and place it on the selected device.
model_CVAE = CVAE(
    x_dim=x_dim,
    y_dim=y_dim,
    hidden_dim=hidden_dim,
    latent_dim=latent_dim,
)
model_CVAE = model_CVAE.to(device)
print(model_CVAE)

optimizer = torch.optim.Adam(model_CVAE.parameters(), lr=learning_rate)

# Per-epoch loss histories, consumed by the learning-curve plot below.
train_loss_history_CVAE, test_loss_history_CVAE = [], []
epoch_bar = trange(1, num_epochs + 1, desc='Training', unit='epoch')
for epoch in epoch_bar:
    train_loss, recon_loss, kl_loss = train(model_CVAE, train_loader, optimizer, kl_weight)
    test_loss = test(model_CVAE, test_loader, kl_weight)
    train_loss_history_CVAE.append(train_loss)
    test_loss_history_CVAE.append(test_loss)
    # Only report every 20th epoch to keep the output short.
    if epoch % 20 != 0:
        continue
    print(f'Epoch {epoch}, Train Loss: {train_loss:.4f}, Recon Loss: {recon_loss:.4f}, KL Loss: {kl_loss:.4f}, Test Loss: {test_loss:.4f}')

# Plots

In [7]:
# Learning curves: per-sample average loss after each epoch, for both splits.
# Each curve gets an x-range derived from its own history length.
train_epochs = range(1, len(train_loss_history_CVAE) + 1)
test_epochs = range(1, len(test_loss_history_CVAE) + 1)
plt.plot(train_epochs, train_loss_history_CVAE, label="Train loss (epoch average)")
plt.plot(test_epochs, test_loss_history_CVAE, label="Test loss")
plt.legend()
plt.xlabel("Epochs")
plt.ylabel("Loss (per-sample average)")
plt.title("Conditional Variational Autoencoder")
plt.show()

In [8]:
def plot_latent_space_cvae(model, dataloader):
    """Scatter-plot the posterior means of q(z | x, y), coloured by class label.

    Every batch is pushed through ``model.inference_network``; only the mean
    half of its output (the first ``model.latent_dim`` columns) is kept. The
    first two latent dimensions are plotted, so ``model.latent_dim`` must be
    at least 2.

    Args:
        model: Module exposing ``inference_network`` and ``latent_dim``.
        dataloader: Iterable of ``(x, y, y_class)`` batches.
    """
    model.eval()
    # Follow the model's own device rather than a notebook-level ``device`` global.
    model_device = next(model.parameters()).device

    latent_means = []
    class_labels = []
    with torch.no_grad():
        for x, y, y_class in tqdm(dataloader, desc='Encoding', unit='batch'):
            x, y = x.to(model_device), y.to(model_device)
            params_q = model.inference_network(torch.cat([x, y], dim=1))
            latent_means.append(params_q[:, :model.latent_dim].cpu().numpy())
            class_labels.append(y_class.cpu().numpy())

    z_all = np.concatenate(latent_means, axis=0)
    # Labels arrive as an (N, 1) column; flatten to (N,) so each point gets
    # exactly one scalar to map through the colormap.
    y_class_all = np.concatenate(class_labels, axis=0).reshape(-1)

    plt.figure(figsize=(10, 10))
    plt.scatter(z_all[:, 0], z_all[:, 1], c=y_class_all, cmap='tab10', alpha=0.5)
    plt.title('Latent Space Projection (Mean)')
    plt.xlabel('z1')
    plt.ylabel('z2')
    plt.show()

# Show where the training examples land in latent space (posterior means).
plot_latent_space_cvae(model_CVAE, train_loader)

In [9]:
def plot_generated_samples_cvae(model, dataloader, num_samples=1000):
    """Plot y values generated by the model against the original data.

    The x values of the first batch of ``dataloader`` are used as conditions.
    They are tiled until at least ``num_samples`` rows exist and then truncated
    to exactly ``num_samples``, so each condition is sampled several times.

    Args:
        model: Module exposing ``generate(x)``.
        dataloader: Iterable of ``(x, y, y_class)`` batches whose ``.dataset``
            has a two-column tensor attribute ``X`` holding the original points.
        num_samples (int): Number of generated points to draw.
    """
    model.eval()
    # Follow the model's own device rather than a notebook-level ``device`` global.
    model_device = next(model.parameters()).device

    # Take one batch of x from the dataloader to use as conditions.
    first_x, _, _ = next(iter(dataloader))
    first_x = first_x.to(model_device)
    # Repeat the conditions so several samples are generated for each x.
    n_tiles = num_samples // first_x.size(0) + 1
    x_cond = first_x.repeat(n_tiles, 1)[:num_samples]

    with torch.no_grad():
        y_gen = model.generate(x_cond)

    x_cond_np = x_cond.cpu().numpy()
    y_gen_np = y_gen.cpu().numpy()

    plt.figure(figsize=(10, 10))
    plt.scatter(x_cond_np, y_gen_np, alpha=0.5, label='Generated')

    # Overlay the original data for comparison.
    original_data = dataloader.dataset.X.numpy()
    plt.scatter(original_data[:, 0], original_data[:, 1], alpha=0.1, label='Original')

    plt.title('Generated Samples from CVAE')
    plt.xlabel('x (condition)')
    plt.ylabel('y (generated)')
    plt.legend()
    plt.show()

# Draw 2000 generated points, conditioned on x values from the first test batch.
plot_generated_samples_cvae(model_CVAE, test_loader, num_samples=2000)