In [None]:
import torch
import torch.nn as nn
from torch.distributions import kl_divergence, Normal
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import make_moons, make_circles
from sklearn.preprocessing import StandardScaler

from scipy.stats import multivariate_normal, gaussian_kde


In [None]:
# Seed both random number generators used by this notebook so that a
# "Restart & Run All" reproduces the same results.
np.random.seed(1)      # NumPy: toy-data generation and prior samples for plots
torch.manual_seed(1)   # PyTorch: weight init, data shuffling, reparameterisation noise

In [None]:
def make_connected_dataset(N=1000):
    """Draw ``N`` points from a connected, non-Gaussian 2-D toy distribution.

    Two independent standard-normal latents ``z1`` and ``z2`` are pushed
    through a fixed non-linear ("banana"-like) map::

        x1 = z1 ** 2
        x2 = z2 + 0.3 * z1 ** 4

    Randomness comes from NumPy's global generator, so call
    ``np.random.seed`` beforehand if reproducible data is needed.

    Parameters
    ----------
    N : int, default 1000
        Number of points to generate.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, 2)`` whose columns are ``x1`` and ``x2``.
    """
    # Gaussian latents (z1 is drawn before z2; the order matters for seeding)
    z1 = np.random.randn(N)
    z2 = np.random.randn(N)

    # Non-linear transformation
    x1 = z1 ** 2
    x2 = z2 + 0.3 * (z1 ** 4)

    return np.column_stack([x1, x2])

## Dataset 

Let's generate data from $p(x)$ and also from the latent prior $p(z)$. Here both spaces have the same dimensionality, although this is not required.

In [None]:
# Fix the NumPy seed for reproducibility
np.random.seed(42)

# One figure with two side-by-side panels: prior (left) and data (right)
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# =========================
# == Prior p(z) = N(0,I) ==
# =========================

# Draw 2-D samples from a standard normal
mean = [0, 0]
cov = [[1, 0], [0, 1]]
samples = np.random.multivariate_normal(mean, cov, 500)

# Scatter plot of the samples
axes[0].scatter(samples[:, 0], samples[:, 1], color='C1')

# Evaluation grid for the contour lines
x = np.linspace(-4, 4, 200)
y = np.linspace(-4, 4, 200)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))

rv = multivariate_normal(mean, cov)
Z = rv.pdf(pos)

# Contour lines of the prior density
axes[0].contour(X, Y, Z, levels=10, cmap='Oranges')

axes[0].set_title(r"$p({\bf z})$")
# This panel shows the latent variable, so its axes are z_1 / z_2; the "$"
# delimiters are required for matplotlib to render the subscripts.
axes[0].set_xlabel(r"$z_1$")
axes[0].set_ylabel(r"$z_2$")

# =========================
# ======== p(x) ===========
# =========================
# Alternative toy datasets. X_moons is overwritten by each following line, so
# the data actually used is the one from make_connected_dataset(); only the
# labels y_moons (from make_circles) survive this cell. Both sklearn calls use
# their own random_state and do not touch NumPy's global generator.
X_moons, y_moons = make_moons(n_samples=1000, noise=0.1, random_state=42)
X_moons, y_moons = make_circles(n_samples=1000, noise=0.05, factor=0.5, random_state=42)
X_moons = make_connected_dataset()

axes[1].scatter(X_moons[:, 0], X_moons[:, 1], color='C0')
axes[1].set_title(r"$p({\bf x})$")
axes[1].set_xlabel(r"$x_1$")
axes[1].set_ylabel(r"$x_2$")

plt.tight_layout()
plt.show()


#### Create Torch dataset and dataloader

In [None]:
# Standardise every feature to zero mean / unit variance. The fitted scaler is
# kept so that model samples can later be mapped back to the original units.
scaler = StandardScaler().fit(X_moons)
X_scaled = scaler.transform(X_moons)

# Wrap the standardised data as a float32 tensor dataset.
X = torch.tensor(X_scaled, dtype=torch.float32)
dataset = TensorDataset(X)

# Two shuffled views of the same dataset: large batches for training and
# small batches used as starting points when sampling.
train_loader = DataLoader(dataset, shuffle=True, batch_size=1000)
sampling_loader = DataLoader(dataset, shuffle=True, batch_size=20)

## Train a Variational Autoencoder

In [None]:
class VAE(nn.Module):
    """Variational autoencoder with diagonal-Gaussian encoder and decoder.

    The encoder maps ``x`` to the mean and log-variance of ``q(z|x)``; the
    decoder maps ``z`` to the mean and log-variance of ``p(x|z)``. The prior
    ``p(z)`` is a standard normal over ``latent_dim`` dimensions.
    """

    # Log-variance heads end in tanh and are multiplied by this factor, which
    # bounds every predicted log-variance to [-LOGVAR_SCALE, LOGVAR_SCALE].
    LOGVAR_SCALE = 5.0

    def __init__(self, encoder_layers, decoder_layers, latent_dim, decoder_type, N):
        """
        encoder_layers: list of tuples (in_dim, out_dim, activation)
        decoder_layers: list of tuples (in_dim, out_dim, activation)
            activation is an nn.Module class, or None to specify a linear
            activation. The last tuple of each list only provides the
            (in_dim, out_dim) of the mean / log-variance heads; its activation
            entry is ignored.
        latent_dim: dimensionality of the latent variable z.
        decoder_type: to specify the observation model p(x|z). For the moment
            only Gaussian is considered; the value is stored but not otherwise
            used by this class.
        N: total number of training points, used to rescale a mini-batch ELBO
            to the scale of the full dataset.
        """
        super().__init__()

        self.latent_dim = latent_dim
        self.decoder_type = decoder_type
        self.N = N

        # Prior parameters and the sampling helper tensor are registered as
        # buffers so that .to(device) / .double() move them with the module.
        # persistent=False keeps them out of state_dict(), so checkpoints keep
        # exactly the same keys as before.
        self.register_buffer("_prior_loc", torch.zeros(latent_dim), persistent=False)
        self.register_buffer("_prior_scale", torch.ones(latent_dim), persistent=False)
        self.register_buffer("zero_tensor", torch.tensor(0.0), persistent=False)

        # The builders attach their sub-modules to self and return nothing.
        # Encoder first, decoder second, so parameters are initialised in the
        # same order as before for a given seed.
        self.build_encoder(encoder_layers)
        self.build_decoder(decoder_layers)

    @property
    def p_z(self):
        """Standard normal prior N(0, I) over z, on the module's device/dtype."""
        return Normal(self._prior_loc, self._prior_scale)

    @staticmethod
    def _mlp(layers_config):
        """Build an nn.Sequential of Linear (+ optional activation) layers."""
        layers = []
        for in_dim, out_dim, activation in layers_config:
            layers.append(nn.Linear(in_dim, out_dim))
            if activation is not None:
                layers.append(activation())
        return nn.Sequential(*layers)

    def build_encoder(self, layers_config):
        """Create ``self.enc`` (trunk), ``self.enc_mean`` and ``self.enc_logvar``.

        All but the last entry of ``layers_config`` form the shared trunk; the
        last entry gives the dimensions of the two output heads.
        """
        self.enc = self._mlp(layers_config[:-1])
        in_dim, out_dim, _ = layers_config[-1]
        self.enc_mean = nn.Sequential(nn.Linear(in_dim, out_dim))
        self.enc_logvar = nn.Sequential(nn.Linear(in_dim, out_dim), nn.Tanh())

    def build_decoder(self, layers_config):
        """Create ``self.dec`` (trunk), ``self.dec_mean`` and ``self.dec_logvar``.

        All but the last entry of ``layers_config`` form the shared trunk; the
        last entry gives the dimensions of the two output heads.
        """
        self.dec = self._mlp(layers_config[:-1])
        in_dim, out_dim, _ = layers_config[-1]
        self.dec_mean = nn.Linear(in_dim, out_dim)
        self.dec_logvar = nn.Sequential(nn.Linear(in_dim, out_dim), nn.Tanh())

    def encoder_forward(self, x):
        """Return ``(mean, logvar)`` of q(z|x); logvar is bounded by LOGVAR_SCALE."""
        h = self.enc(x)
        return self.enc_mean(h), self.LOGVAR_SCALE * self.enc_logvar(h)

    def decoder_forward(self, z):
        """Return ``(mean, logvar)`` of p(x|z); logvar is bounded by LOGVAR_SCALE."""
        h = self.dec(z)
        return self.dec_mean(h), self.LOGVAR_SCALE * self.dec_logvar(h)

    def sample_gaussian(self, mean, logvar, return_mean=False):
        """Reparameterised sample ``mean + eps * exp(0.5 * logvar)``.

        If ``return_mean`` is True the mean is returned and no noise is drawn.
        """
        if return_mean:
            return mean
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def ELBO(self, x, mc_samples=1, kld_scale=1.0):
        """Monte Carlo estimate of the evidence lower bound for a mini-batch.

        x: tensor of shape [M, x_dim].
        mc_samples: number of latent samples used to estimate the expected
            log-likelihood (must be >= 1).
        kld_scale: multiplier applied to the KL term.

        Returns ``(ELBO, LLH, KLD)``. ELBO is ``(LLH - KLD) * N / M``; LLH and
        KLD are the un-rescaled sums over the batch (KLD already includes
        ``kld_scale``).

        Raises ValueError if ``mc_samples < 1``.
        """
        if mc_samples < 1:
            raise ValueError("mc_samples must be a positive integer")

        M = x.size(0)

        # Encoder distribution q(z|x)
        q_mean, q_logvar = self.encoder_forward(x)
        q_zx = Normal(q_mean, torch.exp(0.5 * q_logvar))

        # KL(q(z|x) || p(z)), summed over training points and latent dimensions
        KLD = (kld_scale * kl_divergence(q_zx, self.p_z)).sum()

        # Log-likelihood using Monte Carlo. We could vectorize mc_sampling but
        # for academic purposes this is clearer; vectorizing also requires
        # taking care of memory. Starting from a Python float lets the
        # accumulator take the device and dtype of the decoder output.
        LLH = 0.0
        for _ in range(mc_samples):
            z = self.sample_gaussian(q_mean, q_logvar)
            dec_mean, dec_logvar = self.decoder_forward(z)

            # Sum over training points and over dimensions
            LLH = LLH + Normal(dec_mean, torch.exp(0.5 * dec_logvar)).log_prob(x).sum()

        # Monte Carlo average
        LLH = LLH / mc_samples

        # Rescale the mini-batch estimate to the size of the full dataset
        ELBO = (LLH - KLD) * (self.N / M)

        return ELBO, LLH, KLD

    def sample_from_prior(self, n_samples, return_mean=False):
        """Sample from the model via ancestral sampling: z ~ p(z), x ~ p(x|z).

        Returns ``(z, x)``. With ``return_mean=True`` x is the decoder mean
        instead of a noisy sample; z is always sampled.
        """
        # Zero mean and zero log-variance give a standard normal.
        zeros = self.zero_tensor.expand(n_samples, self.latent_dim)
        z = self.sample_gaussian(zeros, zeros, return_mean=False)

        mean, logvar = self.decoder_forward(z)
        return z, self.sample_gaussian(mean, logvar, return_mean=return_mean)

    def sample_from_posterior(self, x, return_mean=False):
        """Reconstruct x: sample z ~ q(z|x), then x' ~ p(x|z) (or its mean)."""
        mean, logvar = self.encoder_forward(x)
        z = self.sample_gaussian(mean, logvar, return_mean=False)
        dec_mean, dec_logvar = self.decoder_forward(z)
        return self.sample_gaussian(dec_mean, dec_logvar, return_mean=return_mean)

    def run_mcmc(self, x, num_steps, n_chains=1, return_mean=False):
        """Alternate z_t ~ q(z|x_t) and x_{t+1} ~ p(x|z_t), starting from ``x``.

        x: tensor of shape [batch_size, x_dim]; every row starts ``n_chains``
            independent chains.
        num_steps: number of (z, x) transitions per chain.
        return_mean: if True, x_{t+1} is the decoder mean instead of a sample.

        Returns ``(z_chain, x_chain)`` with shapes
        ``[batch_size, n_chains, num_steps, z_dim]`` and
        ``[batch_size, n_chains, num_steps + 1, x_dim]``;
        ``x_chain[:, :, 0]`` is the starting point.
        """
        batch_size, x_dim = x.shape

        # expand x to [batch_size, n_chains, x_dim] then flatten to
        # [batch_size*n_chains, x_dim] so all chains run as one batch
        x_t = x.unsqueeze(1).repeat(1, n_chains, 1).view(batch_size * n_chains, x_dim)

        x_states = [x_t.view(batch_size, n_chains, x_dim)]
        z_states = []

        for _ in range(num_steps):
            # z_t ~ q(z|x_t)
            q_mean, q_logvar = self.encoder_forward(x_t)
            z_t = self.sample_gaussian(q_mean, q_logvar)

            # x_{t+1} ~ p(x|z_t)
            dec_mean, dec_logvar = self.decoder_forward(z_t)
            x_t = self.sample_gaussian(dec_mean, dec_logvar, return_mean=return_mean)

            x_states.append(x_t.reshape(batch_size, n_chains, -1))
            z_states.append(z_t.reshape(batch_size, n_chains, -1))

        # Stack along a new step axis placed after the chain axis. Stacking on
        # dim 0 and then re-viewing the [batch, step, chain, dim] memory as
        # [batch, chain, step, dim] would mix steps and chains when n_chains > 1.
        x_chain = torch.stack(x_states, dim=2)
        if z_states:
            z_chain = torch.stack(z_states, dim=2)
        else:
            # num_steps == 0: no latent samples were drawn
            z_chain = x.new_empty((batch_size, n_chains, 0, self.latent_dim))

        return z_chain, x_chain
    

In [None]:
# Sizes of the observed space, the latent space and the hidden layers.
data_dim, latent_dim, hidden_dim = 2, 2, 64

# Each entry is (in_features, out_features, activation class or None).
# Encoder: data -> hidden -> hidden -> latent.
encoder_layers = [
    (data_dim, hidden_dim, nn.Tanh),
    (hidden_dim, hidden_dim, nn.Tanh),
]
encoder_layers.append((hidden_dim, latent_dim, None))

# Decoder: latent -> hidden -> hidden -> data.
decoder_layers = [
    (latent_dim, hidden_dim, nn.Tanh),
    (hidden_dim, hidden_dim, nn.Tanh),
]
decoder_layers.append((hidden_dim, data_dim, None))



In [None]:
# Build the model. N is the dataset size, used by VAE.ELBO to rescale
# mini-batch estimates.
vae = VAE(
    encoder_layers,
    decoder_layers,
    latent_dim=latent_dim,
    decoder_type="Gaussian",
    N=X_moons.shape[0],
)

optimizer = optim.Adam(vae.parameters(), lr=1e-3)

num_epochs = 2000

# Step-wise learning-rate annealing: {epoch index: learning rate applied once
# that epoch has finished}.
lr_schedule = {1200: 1e-4}

# Print the running totals every `log_every` epochs.
log_every = 20

vae.train()
for epoch in range(num_epochs):
    # Per-epoch accumulators (sums over the mini-batches of this epoch)
    total_elbo = 0.0
    total_kld = 0.0
    total_ell = 0.0

    for x_tr, in train_loader:
        optimizer.zero_grad()
        elbo, ell, kld = vae.ELBO(x_tr, mc_samples=1, kld_scale=1)
        loss = -elbo  # maximise the ELBO by minimising its negative
        loss.backward()
        optimizer.step()

        total_elbo += elbo.item()
        total_ell += ell.item()
        total_kld += kld.item()

    if epoch in lr_schedule:  # annealing
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr_schedule[epoch]

    if (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch+1}, ELBO: {total_elbo:.4f}, ELL:{total_ell:.4f}, KLD: {total_kld:.4f}")

### Generating Samples from the prior p(z)

In [None]:
%matplotlib tk
try:
    plt.close("all")
except:
    pass
    
# Fijar semilla para reproducibilidad
np.random.seed(42)

# Crear figura con dos filas y dos columnas
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# =========================
# == Prior p(z) = N(0,I) ==
# =========================

# Generar muestras 2D de una normal est치ndar
mean = [0, 0]
cov = [[1, 0], [0, 1]]
samples = np.random.multivariate_normal(mean, cov, 500)

# Scatter de las muestras
axes[0, 0].scatter(samples[:, 0], samples[:, 1], color='C1')

# Crear grid para curvas de nivel
x = np.linspace(-4, 4, 200)
y = np.linspace(-4, 4, 200)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))

rv = multivariate_normal(mean, cov)
Z = rv.pdf(pos)

# Dibujar curvas de nivel
axes[0, 0].contour(X, Y, Z, levels=10, cmap='Oranges')

axes[0, 0].set_title(r"$p({\bf z})$")
axes[0, 0].set_xlabel(r"$x_1$")
axes[0, 0].set_ylabel(r"$x_2$")

# =========================
# ======== p(x) ===========
# =========================
axes[0, 1].scatter(X_moons[:, 0], X_moons[:, 1], color='C0')
axes[0, 1].set_title(r"$p({\bf x})$")
axes[0, 1].set_xlabel(r"$x_1$")
axes[0, 1].set_ylabel(r"$x_2$")

# ================================================
# == Sample from p(x,z) through ancestral sampling
# z ~ p(z)
# x ~ p(x|z)
# ================================================
torch.manual_seed(10)
n_samples = 500
vae.eval()
with torch.no_grad():
    z, x_z = vae.sample_from_prior(n_samples)

    # come back to original scale.
    x_z = scaler.inverse_transform(x_z.numpy())


axes[1, 0].plot(z[:, 0], z[:, 1],'x', color='C1', alpha = 0.1)        
axes[1, 0].contour(X, Y, Z, levels=10, cmap='Oranges')
axes[1, 0].set_title(r"Sample from $p({\bf z})$")
axes[1, 0].set_xlabel(r"$z_1$")
axes[1, 0].set_ylabel(r"$z_2$")


axes[1, 1].plot(x_z[:, 0], x_z[:, 1],'x', color='C0', alpha = 0.1)
axes[1, 1].set_title(r"Sample from $p({\bf x}|{\bf z})$")
axes[1, 1].set_xlabel(r"$x_1$")
axes[1, 1].set_ylabel(r"$x_2$")

for s in range(10):
    plot_z, = axes[1, 0].plot(z[s, 0], z[s, 1], 'o', color='C1', markersize = 10)    
    
    plt.pause(0.5)
    
    plot_x_z, = axes[1, 1].plot(x_z[s, 0], x_z[s, 1],'o', color='C0', markersize = 10)

    plt.pause(0.5)

    plot_z.remove()
    plot_x_z.remove()

%matplotlib inline


In [None]:
# Close any figure windows left open by the previous animation.
try:
    plt.close("all")
except Exception:
    # Best effort only; unlike a bare `except:`, this does not swallow
    # KeyboardInterrupt or SystemExit.
    pass

### Generating Samples from the model using approximate MCMC

In [None]:
%matplotlib tk
try:
    plt.close("all")
except:
    pass
    
# Fijar semilla para reproducibilidad
np.random.seed(42)

# Crear figura con dos filas y dos columnas
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# =========================
# == Prior p(z) = N(0,I) ==
# =========================

# Generar muestras 2D de una normal est치ndar
mean = [0, 0]
cov = [[1, 0], [0, 1]]
samples = np.random.multivariate_normal(mean, cov, 500)

# Scatter de las muestras
axes[0, 0].plot(samples[:, 0], samples[:, 1],'x', color='C1', alpha = 0.1)

# Crear grid para curvas de nivel
x = np.linspace(-4, 4, 200)
y = np.linspace(-4, 4, 200)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))

rv = multivariate_normal(mean, cov)
Z = rv.pdf(pos)

# Dibujar curvas de nivel
axes[0, 0].contour(X, Y, Z, levels=10, cmap='Oranges')

axes[0, 0].set_title(r"$p({\bf z})$")
axes[0, 0].set_xlabel(r"$z_1$")
axes[0, 0].set_ylabel(r"$z_2$")

# =========================
# ======== p(x) ===========
# =========================
axes[0, 1].plot(X_moons[:, 0], X_moons[:, 1],'x', color='C0', alpha = 0.1)
axes[0, 1].set_title(r"$p({\bf x})$")
axes[0, 1].set_xlabel(r"$x_1$")
axes[0, 1].set_ylabel(r"$x_2$")

# ================================================
# == Sample from p(x,z) through ancestral sampling
# z ~ p(z)
# x ~ p(x|z)
# ================================================
torch.manual_seed(1)
num_steps = 20
n_chains = 3
vae.eval()
with torch.no_grad():
    for x, in sampling_loader:
        # add the point(10,30), normalized, which looks nice
        x_n = torch.tensor(scaler.transform(np.array([[10.,30.]])), dtype = torch.float)
        x = torch.cat([x_n, x], dim=0)
        
        z_x, x_z = vae.run_mcmc(
            x = x, 
            num_steps = num_steps,
            n_chains = n_chains,
            return_mean = True
        )
    # print(z_x.shape) #  [batch_size, n_chains, num_steps, dim] 
    # print(x_z.shape)

## Estimate a Kernel density estimator on q(z|x) to grab idea on what is the posterior
#  looking. Get final half of the chain.
kde = gaussian_kde(z_x[:,:,int(num_steps/2):,:].contiguous().view(-1,data_dim).numpy().T)

# come back to original scale.
x_z_flat = x_z.view(-1, x_z.shape[-1])
x_z_orig = scaler.inverse_transform(x_z_flat)
x_z = torch.tensor(x_z_orig).view(x_z.shape)

## Plot kde level curves
Z = kde(pos.reshape(-1, 2).T).reshape(X.shape)
axes[1,0].contour(X, Y, Z, levels=20,  cmap='Oranges')

## plot all. Pick half of the chain
z_x_plot = z_x[:,:,int(num_steps/2):,:].contiguous().view(-1,data_dim).numpy()
x_z_plot = x_z[:,:,int(num_steps/2):,:].contiguous().view(-1,data_dim).numpy()

axes[1, 0].plot(z_x_plot[:, 0], z_x_plot[:, 1],'x', color='C1', alpha = 0.2)  
axes[1, 0].set_xlim([-4,4])
axes[1, 0].set_ylim([-4,4])
axes[1, 0].set_title(r"Sample from $q({\bf z} \mid {\bf x})$")
axes[1, 0].set_xlabel(r"$z_1$")
axes[1, 0].set_ylabel(r"$z_2$")

axes[1, 1].plot(x_z_plot[:, 0], x_z_plot[:, 1],'x', color='C0', alpha = 0.2)
axes[1, 1].set_title(r"Sample from $p({\bf x}|{\bf z})$")
axes[1, 1].set_xlabel(r"$x_1$")
axes[1, 1].set_ylabel(r"$x_2$")


# =========================
# Animate MCMC trajectories 
# =========================
batch_size, n_chains, num_steps, data_dim = x_z.shape

# Initialize trajectory lists and plot handles
plot_handles = [[None for _ in range(n_chains)] for _ in range(batch_size)]

# Loop over MCMC steps
for b in range(batch_size):
    for c in range(n_chains):
        traj_x_x = []
        traj_x_y = []
        traj_z_x = []
        traj_z_y = []
        plot_handle_x = None

        # initial samples where the chain is started
        init_x_x = x_z[b, c, 0, 0].item() # traj_x_x.append()
        init_x_y = x_z[b, c, 0, 1].item() #traj_x_y.append()

        # highlight the initial samples of the chain
        init_chain, = axes[0,1].plot(init_x_x, init_x_y, 'x', color = 'black', markersize=8)

        traj_x_x.append(init_x_x)
        traj_x_y.append(init_x_y)
        
        for t in range(num_steps-1):
            # append current point to trajectory
            traj_x_x.append(x_z[b, c, t+1, 0].item())
            traj_x_y.append(x_z[b, c, t+1, 1].item())

            traj_z_x.append(z_x[b, c, t, 0].item())
            traj_z_y.append(z_x[b, c, t, 1].item())

            if plot_handle_x is not None:
                plot_handle_z.remove()
            
            # plot trajectory for this chain of this batch element
            plot_handle_z, = axes[1,0].plot(traj_z_x, traj_z_y,
                                          '--x', markersize=4, color='black', alpha=0.5, zorder = 10)
            plt.pause(0.5)

            if plot_handle_x is not None:
                plot_handle_x.remove()
            
            plot_handle_x, = axes[1,1].plot(traj_x_x, traj_x_y,
                                          '--x', markersize=4, color='black', alpha=0.5, zorder = 10)

            plt.pause(0.5)  # pause to animate each step

        # remove previous line
        plot_handle_x.remove()
        plot_handle_z.remove()
        init_chain.remove()
        
        # highlight final element in the chain
        axes[1, 1].plot(traj_x_x[-1], traj_x_y[-1],'o', markersize = 8, color='C0')
        axes[1, 0].plot(traj_z_x[-1], traj_z_y[-1],'o', markersize = 8, color='C1')

%matplotlib inline

In [None]:
# Close any figure windows left open by the previous animation.
try:
    plt.close("all")
except Exception:
    # Best effort only; unlike a bare `except:`, this does not swallow
    # KeyboardInterrupt or SystemExit.
    pass