In [None]:
from dataclasses import dataclass
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributions as dist
import torch.nn as nn
import torch.nn.functional as F
from sklearn.datasets import make_moons
from torch.distributions import MultivariateNormal, kl_divergence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm, trange

# Model

In [None]:
@dataclass
class CVAEOutput:
    """
    Dataclass for CVAE output.

    Only ``recog_pass``, ``x`` and ``y`` are required. Every other field defaults
    to ``None``, so a pass that does not compute the prior or the loss
    (e.g. ``CVAE.forward(..., compute_loss=False)``) can leave them unset.

    Attributes:
        recog_pass (bool): Whether this is a recognition pass (True) or a generation pass (False).
        x (torch.Tensor): Input condition x.
        y (torch.Tensor): Input target y.
        z_prior_dist (torch.distributions.Distribution | None): Prior distribution of the latent variable z.
        z_prior_sample (torch.Tensor | None): Sampled value of the prior of latent variable z.
        z_recog_dist (torch.distributions.Distribution | None): Recognition distribution of the latent variable z.
        z_recog_sample (torch.Tensor | None): Sampled value of the recognition distribution of latent variable z.
        y_recon (torch.Tensor | None): The reconstructed y.
        loss (torch.Tensor | None): The overall loss of the CVAE.
        loss_recon (torch.Tensor | None): The reconstruction loss component of the CVAE loss.
        loss_kl (torch.Tensor | None): The KL divergence component of the CVAE loss.
    """
    recog_pass: bool
    x: torch.Tensor
    y: torch.Tensor
    # Everything below is optional: defaults let callers omit what they did not compute.
    z_prior_dist: Optional[torch.distributions.Distribution] = None
    z_prior_sample: Optional[torch.Tensor] = None
    z_recog_dist: Optional[torch.distributions.Distribution] = None
    z_recog_sample: Optional[torch.Tensor] = None
    y_recon: Optional[torch.Tensor] = None

    loss: Optional[torch.Tensor] = None
    loss_recon: Optional[torch.Tensor] = None
    loss_kl: Optional[torch.Tensor] = None


class CVAE(nn.Module):
    """Conditional Variational Autoencoder (CVAE).

    Three fully connected networks (Linear layers with LeakyReLU between them) are built:

    * prior network       p(z | x):    x      -> mean and log-variance of z
    * recognition network q(z | x, y): [x, y] -> mean and log-variance of z
    * generation network  p(y | x, z): [x, z] -> y, squashed into (0, 1) by a sigmoid

    Both latent distributions are diagonal Gaussians, returned as
    ``torch.distributions.MultivariateNormal`` objects.

    Args:
        x_dim (int): Dimensionality of the condition x.
        y_dim (int): Dimensionality of the target data y.
        latent_dim (int): Dimensionality of the latent space z.
        prior_hidden_dim (int or list/tuple of int): Hidden layer width(s) of the prior network.
        gen_hidden_dim (int or list/tuple of int): Hidden layer width(s) of the generation network.
        recog_hidden_dim (int or list/tuple of int): Hidden layer width(s) of the recognition network.

    Raises:
        ValueError: If a hidden-width list/tuple is empty.
    """

    def __init__(
        self,
        x_dim,
        y_dim,
        latent_dim,
        prior_hidden_dim,
        gen_hidden_dim,
        recog_hidden_dim,
    ):
        super().__init__()

        self.x_dim = x_dim
        self.y_dim = y_dim
        self.latent_dim = latent_dim
        # Sequences are copied so later mutation of the caller's list cannot make the
        # recorded configuration disagree with the layers actually built.
        self.prior_hidden_dim = self._copy_dims(prior_hidden_dim)
        self.gen_hidden_dim = self._copy_dims(gen_hidden_dim)
        self.recog_hidden_dim = self._copy_dims(recog_hidden_dim)

        # Build order (prior, gen, recog) fixes the parameter-initialisation RNG sequence
        # under a fixed seed; keep it stable for reproducibility.
        # Prior network: p(z|x) -> [mean, log-variance]
        self.prior_network = self._build_mlp(x_dim, prior_hidden_dim, 2 * latent_dim)
        # Generation network: p(y|x,z) -> y in (0, 1)
        self.gen_network = self._build_mlp(
            x_dim + latent_dim, gen_hidden_dim, y_dim, output_activation=nn.Sigmoid()
        )
        # Recognition/Inference network: q(z|x,y) -> [mean, log-variance]
        self.recog_network = self._build_mlp(x_dim + y_dim, recog_hidden_dim, 2 * latent_dim)

    @staticmethod
    def _copy_dims(hidden):
        """Return a fresh list for a list/tuple of widths; return a scalar width unchanged."""
        if isinstance(hidden, (list, tuple)):
            return list(hidden)
        return hidden

    @staticmethod
    def _build_mlp(in_dim, hidden, out_dim, output_activation=None):
        """Build Linear/LeakyReLU layers: in_dim -> hidden widths... -> out_dim (+ optional activation)."""
        hidden_widths = list(hidden) if isinstance(hidden, (list, tuple)) else [hidden]
        if not hidden_widths:
            raise ValueError("at least one hidden layer width is required")
        widths = [in_dim] + hidden_widths
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(d_in, d_out))
            layers.append(nn.LeakyReLU())
        layers.append(nn.Linear(widths[-1], out_dim))
        if output_activation is not None:
            layers.append(output_activation)
        return nn.Sequential(*layers)

    def _diag_gaussian(self, params, eps):
        """Turn a ``[mean, log_variance]`` network output into a diagonal Gaussian over z.

        The second half of ``params`` is exponentiated (plus ``eps`` to keep the
        variance strictly positive) to give the per-dimension variance.
        """
        means = params[:, :self.latent_dim]
        variances = torch.exp(params[:, self.latent_dim:]) + eps
        # A diagonal covariance has Cholesky factor diag(sqrt(var)); passing it as
        # scale_tril gives the same distribution without a batched Cholesky decomposition.
        return MultivariateNormal(means, scale_tril=torch.diag_embed(variances.sqrt()))

    def recognize(self, x, y, eps: float = 1e-8):
        """Recognition distribution q(z | x, y)."""
        return self._diag_gaussian(self.recog_network(torch.cat((x, y), dim=1)), eps)

    def prior(self, x, eps: float = 1e-8):
        """Prior distribution p(z | x)."""
        return self._diag_gaussian(self.prior_network(x), eps)

    def forward(self, x, y, kl_weight: float = 1.0, compute_loss: bool = True):
        """Recognition pass: encode (x, y), sample z, reconstruct y and optionally score it.

        Args:
            x: Condition batch; concatenated with y (and with z) along dim 1.
            y: Target batch, used as binary cross-entropy targets (values in [0, 1]).
            kl_weight: Multiplier on the KL term of the loss.
            compute_loss: If False, skip the prior and the loss computation.

        Returns:
            CVAEOutput with every field set explicitly. ``z_prior_sample`` is always None
            (no prior sample is drawn here); the prior and loss fields are None when
            ``compute_loss`` is False.

        Raises:
            NotImplementedError: If ``compute_loss`` is True and y has more than one column.
        """
        z_recog_dist = self.recognize(x, y)
        # Reparameterised sample so gradients reach the recognition network.
        z = z_recog_dist.rsample()
        # Reconstruction of y
        recon_y = self.gen_network(torch.cat((x, z), dim=1))
        if not compute_loss:
            return CVAEOutput(
                recog_pass=True,
                x=x,
                y=y,
                z_prior_dist=None,
                z_prior_sample=None,
                z_recog_dist=z_recog_dist,
                z_recog_sample=z,
                y_recon=recon_y,
                loss=None,
                loss_recon=None,
                loss_kl=None,
            )
        # Compute loss
        if y.shape[1] != 1:
            raise NotImplementedError('Just for binary classification')
        loss_recon = F.binary_cross_entropy(recon_y, y, reduction='none').sum(-1).mean()
        z_prior_dist = self.prior(x)
        loss_kl = kl_divergence(z_recog_dist, z_prior_dist).mean()
        loss = loss_recon + kl_weight * loss_kl
        return CVAEOutput(
            recog_pass=True,
            x=x,
            y=y,
            z_prior_dist=z_prior_dist,
            z_prior_sample=None,
            z_recog_dist=z_recog_dist,
            z_recog_sample=z,
            y_recon=recon_y,
            loss=loss,
            loss_recon=loss_recon,
            loss_kl=loss_kl,
        )

    def generate(self, x, sample: bool = False):
        """Generate y from the prior p(z | x); uses the prior mean unless ``sample`` is True."""
        z_prior_dist = self.prior(x)
        z = z_prior_dist.sample() if sample else z_prior_dist.mean
        return self.gen_network(torch.cat((x, z), dim=1))


# Data

In [None]:
class Moons(Dataset):
    """Two-moons binary classification dataset built with sklearn's ``make_moons``.

    ``__getitem__`` returns ``(X[index], y[index])`` as float32 tensors. The labels
    are reshaped into a column, so ``self.y`` has shape ``(num_samples, 1)``.

    Args:
        num_samples (int): Number of points to generate.
        noise (float): Standard deviation of the Gaussian noise added to the points.
        random_state (int or None): Seed forwarded to ``make_moons``. Defaults to 42.
            Give train and test splits different seeds so they do not share an RNG stream.
    """

    def __init__(self, num_samples: int, noise: float, random_state=42):
        super().__init__()
        self.num_samples = num_samples
        self._X, self._y = make_moons(n_samples=num_samples, noise=noise, random_state=random_state)
        # Labels come back flat; make them a column so batches are (batch, 1), matching
        # a single-output (y_dim=1) model.
        self._y = self._y.reshape(-1, 1)
        self.X = torch.from_numpy(self._X).float()
        self.y = torch.from_numpy(self._y).float()

    def __getitem__(self, index):
        return self.X[index], self.y[index]

    def __len__(self):
        return self.num_samples

# Training

In [None]:
# Seed torch's global RNG (used e.g. by weight init, rsample and DataLoader shuffling).
# The moons data itself is seeded separately inside Moons.
torch.manual_seed(123)

# --- Optimisation hyper-parameters ---
batch_size = 150
learning_rate = 1e-3
num_epochs = 300

# --- Model dimensions ---
x_dim = 2  # 2-D moon coordinates act as the condition x
y_dim = 1  # binary class label is the target y
latent_dim = 2
prior_hidden_dim = [128, 64]
gen_hidden_dim = [256, 128, 64]
recog_hidden_dim = [256, 128, 64]

# Multiplier on the KL term of the CVAE loss (1.0 = unweighted ELBO).
# NOTE(review): only takes effect if forwarded to the model call as kl_weight=...
kl_weight = 1.2

# --- Data ---
train_data = Moons(num_samples=3000, noise=0.15)
test_data = Moons(num_samples=600, noise=0.15)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)

# --- Device ---
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
def train(model, dataloader, optimizer, kl_weight: float = 1.0):
    """
    Trains the model for one epoch on the given data.

    Args:
        model (nn.Module): The model to train. It is called as
            ``model(x, y, kl_weight=..., compute_loss=True)`` and must return an
            object with a scalar ``loss`` tensor (e.g. ``CVAE``).
        dataloader (torch.utils.data.DataLoader): Yields ``(x, y)`` batches.
        optimizer: The optimizer stepping ``model``'s parameters.
        kl_weight (float): Weight of the KL term; forwarded to the model's forward pass.

    Returns:
        float: Epoch-average training loss, with each batch's loss weighted by its
        batch size (the per-sample mean, assuming ``loss`` is a mean over the batch).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()  # Set the model to training mode
    # Move data to wherever the model lives instead of relying on a global `device`.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    total_loss = 0.0
    num_samples = 0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()  # Zero the gradients
        output = model(x, y, kl_weight=kl_weight, compute_loss=True)  # Forward pass
        loss = output.loss
        loss.backward()
        optimizer.step()  # Update the model parameters
        batch_n = x.shape[0]
        # Weight by batch size so a smaller final batch is not over-counted.
        total_loss += loss.item() * batch_n
        num_samples += batch_n
    if num_samples == 0:
        raise ValueError('dataloader yielded no batches')
    return total_loss / num_samples


def test(model, dataloader):
    """
    Tests the model on the given data.
    
    Args:
        model (nn.Module): The model to test.
        dataloader (torch.utils.data.DataLoader): The data loader.
    """
    model.eval()  # Set the model to evaluation mode
    recon_loss = 0.0
    gen_loss = 0.0
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            output = model(x, y, compute_loss=True)  # Forward pass
            recon_loss += output.loss_recon.item()
            y_gen = model.generate(x)
            gen_loss += F.binary_cross_entropy(y_gen, y, reduction='none').sum(-1).mean().item()
    recon_loss /= len(dataloader)
    gen_loss /= len(dataloader)
    return recon_loss, gen_loss

In [None]:
# print("Training Conditional Variational Autoencoder...")
# model_CVAE = CVAE(x_dim=x_dim, y_dim=y_dim, hidden_dim=hidden_dim, latent_dim=latent_dim).to(device)
# print(model_CVAE)

# optimizer = torch.optim.Adam(model_CVAE.parameters(), lr=learning_rate)

# train_loss_history_CVAE = []
# test_loss_history_CVAE = []
# for epoch in trange(1, num_epochs + 1, desc='Training', unit='epoch'):
#     train_loss, recon_loss, kl_loss = train(model_CVAE, train_loader, optimizer, kl_weight)
#     recon_loss = test(model_CVAE, test_loader, kl_weight)
#     train_loss_history_CVAE.append(train_loss)
#     test_loss_history_CVAE.append(recon_loss)
#     if epoch % 20 == 0:
#         print(f'Epoch {epoch}, Train Loss: {train_loss:.4f}, Recon Loss: {recon_loss:.4f}, KL Loss: {kl_loss:.4f}, Test Loss: {recon_loss:.4f}')

# Plots

In [None]:
# plt.plot(range(1, len(train_loss_history_CVAE)+1), train_loss_history_CVAE, label="Train loss (epoch average)")
# plt.plot(range(1, len(train_loss_history_CVAE)+1), test_loss_history_CVAE, label="Test loss")
# plt.legend()
# plt.xlabel("Epochs")
# plt.title("Conditional Variational Autoencoder")
# plt.show()

In [None]:
# def plot_latent_space_cvae(model, dataloader):
#     model.eval()
#     z_all = []
#     y_class_all = []
#     with torch.no_grad():
#         for x, y, y_class in tqdm(dataloader, desc='Encoding', unit='batch'):
#             x, y = x.to(device), y.to(device)
#             xy = torch.cat([x, y], dim=1)
#             params_q = model.inference_network(xy)
#             mu_q = params_q[:, :model.latent_dim]
#             z_all.append(mu_q.cpu().numpy())
#             y_class_all.append(y_class.numpy())
#     z_all = np.concatenate(z_all, axis=0)
#     y_class_all = np.concatenate(y_class_all, axis=0)
#     plt.figure(figsize=(10, 10))
#     plt.scatter(z_all[:, 0], z_all[:, 1], c=y_class_all, cmap='tab10', alpha=0.5)
#     plt.title('Latent Space Projection (Mean)')
#     plt.xlabel('z1')
#     plt.ylabel('z2')
#     plt.show()

# plot_latent_space_cvae(model_CVAE, train_loader)

In [None]:
# def plot_generated_samples_cvae(model, dataloader, num_samples=1000):
#     model.eval()
    
#     # Take one batch of x from the dataloader to use as the condition
#     x_cond, _, _ = next(iter(dataloader))
#     x_cond = x_cond.to(device)
#     # Repeat the conditions to generate several samples for each x
#     x_cond = x_cond.repeat(num_samples // x_cond.size(0) + 1, 1)[:num_samples]
    
#     with torch.no_grad():
#         y_gen = model.generate(x_cond)
    
#     x_cond_np = x_cond.cpu().numpy()
#     y_gen_np = y_gen.cpu().numpy()
    
#     plt.figure(figsize=(10, 10))
#     plt.scatter(x_cond_np, y_gen_np, alpha=0.5, label='Generated')
    
#     # Plot the original data for comparison
#     original_data = dataloader.dataset.X.numpy()
#     plt.scatter(original_data[:, 0], original_data[:, 1], alpha=0.1, label='Original')
    
#     plt.title('Generated Samples from CVAE')
#     plt.xlabel('x (condition)')
#     plt.ylabel('y (generated)')
#     plt.legend()
#     plt.show()

# plot_generated_samples_cvae(model_CVAE, test_loader, num_samples=2000)