#0.import and random seeds


In [1]:
import torch
import torch.nn as nn
import torch.distributions as td
import torch.utils.data
from tqdm import tqdm
from copy import deepcopy
import os
import math
import matplotlib.pyplot as plt
#dataset
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [2]:
import random
import numpy as np

def set_seed(seed=42):
    random.seed(seed)                           # Python built-in random module
    np.random.seed(seed)                        # NumPy random generator
    torch.manual_seed(seed)                     # PyTorch CPU random seed
    torch.cuda.manual_seed(seed)                # PyTorch current GPU random seed
    torch.cuda.manual_seed_all(seed)            # PyTorch all GPUs random seed
    torch.backends.cudnn.deterministic = True   # Ensure deterministic behavior in cuDNN
    torch.backends.cudnn.benchmark = False      # Disable auto-optimization to prevent non-deterministic behavior
    os.environ["PYTHONHASHSEED"] = str(seed)    # Control hash-based randomness in Python

set_seed(42)


#1.Dataset  
just copy the code provided by the professor



In [3]:
def subsample(data, targets, num_data, num_classes):
    idx = targets < num_classes  # Select samples with class labels less than num_classes (e.g., only classes 0, 1, 2)
    new_data = data[idx][:num_data].unsqueeze(1).to(torch.float32) / 255  # Select the first num_data images and normalize to [0,1]
    new_targets = targets[idx][:num_data]  # Select corresponding labels for the subsampled images
    return torch.utils.data.TensorDataset(new_data, new_targets)  # Create a TensorDataset with the filtered images and labels

num_train_data = 2048
num_classes = 3

train_tensors = datasets.MNIST(
    "data/", train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor()])  # Convert images to tensors
)

test_tensors = datasets.MNIST(
    "data/", train=False, download=True,
    transform=transforms.Compose([transforms.ToTensor()])  # Convert images to tensors
)
train_data = subsample(
    train_tensors.data, train_tensors.targets,
    num_train_data, num_classes
)
test_data = subsample(
    test_tensors.data, test_tensors.targets,
    num_train_data, num_classes
)

batch_size=32
mnist_train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True
)
mnist_test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False
)
latent_dim=2
M=latent_dim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


100%|██████████| 9.91M/9.91M [00:00<00:00, 41.6MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.16MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 10.7MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.74MB/s]


#2.GaussianPrior and Encoder/Decoder
Just copy the code provided by the professor


In [4]:
class GaussianPrior(nn.Module):
    def __init__(self, M):
        """
        Define a Gaussian prior distribution with zero mean and unit variance.

                Parameters:
        M: [int]
           Dimension of the latent space.
        """
        super(GaussianPrior, self).__init__()
        self.M = M
        self.mean = nn.Parameter(torch.zeros(self.M), requires_grad=False)
        self.std = nn.Parameter(torch.ones(self.M), requires_grad=False)

    def forward(self):
        """
        Return the prior distribution.

        Returns:
        prior: [torch.distributions.Distribution]
        """
        return td.Independent(td.Normal(loc=self.mean, scale=self.std), 1)

class GaussianEncoder(nn.Module):
    def __init__(self, encoder_net):
        """
        Define a Gaussian encoder distribution based on a given encoder network.

        Parameters:
        encoder_net: [torch.nn.Module]
            The encoder network that takes a tensor of dimension
            `(batch_size, feature_dim1, feature_dim2)` as input
            and outputs a tensor of dimension `(batch_size, 2M)`,
            where M is the dimension of the latent space.
        """
        super(GaussianEncoder, self).__init__()
        self.encoder_net = encoder_net

    def forward(self, x):
        """
        Given a batch of input data, return a Gaussian distribution over the latent space.

        Parameters:
        x: [torch.Tensor]
            A tensor of dimension `(batch_size, feature_dim1, feature_dim2)`.

        Returns:
        A Gaussian distribution with computed mean and standard deviation.
        """
        mean, std = torch.chunk(self.encoder_net(x), 2, dim=-1)
        return td.Independent(td.Normal(loc=mean, scale=torch.exp(std)), 1)

        # Example:
        # z = torch.randn(4, 10)  # Assume z is a tensor of shape [batch_size=4, 10]
        # a, b = torch.chunk(z, 2, dim=-1)
        # a and b will have shape [4, 5], as the tensor is split into two parts along the last dimension.
def new_encoder():
        encoder_net = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.Softmax(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.Softmax(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, stride=2, padding=1),
            nn.Flatten(),
            nn.Linear(512, 2 * M),
        )
        return encoder_net
class GaussianDecoder(nn.Module):
    def __init__(self, decoder_net):
        """
        Define a Gaussian decoder distribution based on a given decoder network.

        Parameters:
        decoder_net: [torch.nn.Module]
            The decoder network that takes a tensor of dimension `(batch_size, M)`
            as input, where M is the dimension of the latent space, and outputs a
            tensor of dimension `(batch_size, feature_dim1, feature_dim2)`.
        """
        super(GaussianDecoder, self).__init__()
        self.decoder_net = decoder_net
        # self.std = nn.Parameter(torch.ones(28, 28) * 0.5, requires_grad=True)
        # In case you want to learn the standard deviation of the Gaussian.

    def forward(self, z):
        """
        Given a batch of latent variables, return a Gaussian distribution over the data space.

        Parameters:
        z: [torch.Tensor]
            A tensor of dimension `(batch_size, M)`, where M is the dimension of the latent space.

        Returns:
        A Gaussian distribution with computed mean and a fixed standard deviation.
        """
        means = self.decoder_net(z)
        return td.Independent(td.Normal(loc=means, scale=1e-1), 3) #note the variance of decoder is fixed
        # This defines a 784-dimensional independent normal distribution, where each dimension is independent.
def new_decoder():
        decoder_net = nn.Sequential(
            nn.Linear(M, 512),
            nn.Unflatten(-1, (32, 4, 4)),
            nn.Softmax(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1, output_padding=0),
            nn.Softmax(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.Softmax(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
        )
        return decoder_net

#3.VAE
having changed the provided code ,so that for each mini-batch of data, we randomly sample a decoder and take a gradient step to optimize the ELBO


In [5]:
class VAE(nn.Module):
    def __init__(self, prior, decoders, encoder):
        """
        Variational Autoencoder (VAE) with multiple decoders.

        Parameters:
        prior: [torch.nn.Module]
            The prior distribution over the latent space.
        decoders: [list of torch.nn.Module]
            A list containing multiple decoders.
        encoder: [torch.nn.Module]
            The encoder network that maps input data to a latent distribution.
        """
        super(VAE, self).__init__()
        self.prior = prior
        self.decoders = nn.ModuleList(decoders)  # Use ModuleList to allow PyTorch to properly track parameters
        self.encoder = encoder

    def elbo(self, x, decoder_idx):
        """
        Compute the Evidence Lower Bound (ELBO) for a given input and selected decoder.

        Parameters:
        x: [torch.Tensor]
            The input data tensor.
        decoder_idx: [int]
            The index of the decoder to be used.

        Returns:
        The computed ELBO value.
        """
        q = self.encoder(x)  # Encode input into a latent distribution
        z = q.rsample()  # Sample from the latent distribution using the reparameterization trick
        decoder = self.decoders[decoder_idx]  # Select the corresponding decoder

        elbo = torch.mean(
            decoder(z).log_prob(x) - q.log_prob(z) + self.prior().log_prob(z)
        )  # Compute ELBO using the likelihood, posterior, and prior

        return elbo

    def sample(self, decoder_idx, n_samples=1):
        """
        Generate samples from the specified decoder.

        Parameters:
        decoder_idx: [int]
            The index of the decoder to be used.
        n_samples: [int, default=1]
            The number of samples to generate.

        Returns:
        A batch of generated samples.
        """
        z = self.prior().sample(torch.Size([n_samples]))  # Sample from the prior distribution
        decoder = self.decoders[decoder_idx]  # Select the corresponding decoder
        return decoder(z).sample()  # Generate samples from the decoder

    def forward(self, x, decoder_idx):
        """
        Compute the negative ELBO for optimization.

        Parameters:
        x: [torch.Tensor]
            The input data tensor.
        decoder_idx: [int]
            The index of the decoder to be used.

        Returns:
        The negative ELBO value.
        """
        return -self.elbo(x, decoder_idx)


#4.Training vae


In [6]:
def train(model, optimizers, data_loader, epochs, device):
    num_decoders = len(model.decoders)
    #one error
    # have changed the code to fit with different number of decoders
    total_epochs = epochs_per_decoder * num_decoders
    num_steps = len(data_loader) * total_epochs
    epoch = 0

    losses = []

    def noise(x, std=0.05):
        eps = std * torch.randn_like(x)
        return torch.clamp(x + eps, min=0.0, max=1.0)

    with tqdm(range(num_steps)) as pbar:
        for step in pbar:
            try:
                x = next(iter(data_loader))[0]
                x = noise(x.to(device))
                model=model
                idx = torch.randint(0, num_decoders, (1,)).item() #for each mini-batch of data, we randomly sample a decoder and take a gradient step to optimize the ELBO
                optimizer = optimizers[idx]
                optimizer.zero_grad()
                loss = model(x, decoder_idx=idx) #correspond to the changed part in VAE
                loss.backward()
                optimizer.step()

                loss_val = loss.detach().cpu().item()
                losses.append(loss_val)

                if step % 5 == 0:
                    pbar.set_description(
                        f"epoch={epoch}, step={step}, decoder={idx}, loss={loss_val:.1f}"
                    )
                if (step + 1) % len(data_loader) == 0:
                    epoch += 1
            except KeyboardInterrupt:
                print(f"Stopped at epoch {epoch}, step {step}, loss {loss_val:.1f}")
                break

    return losses


In [7]:
#because in part B we need 10 independent VAEs with  decoders 1,2,3, so I define a new train function
#S is the number of the decoders
def train_single_vae(seed, save_path, S, epochs_per_decoder):
    set_seed(seed)
    decoders = [GaussianDecoder(new_decoder()) for _ in range(S)]
    #instantiating  S randomly initialized decoders
    # note: set_seed(1001) only ensure the next time we run set_seed(1001), it's still the SAME randomly three decoders
    # it will not destroy of the randomness of three different decoders
    encoder = GaussianEncoder(new_encoder())
    prior = GaussianPrior(M)

    model = VAE(prior, decoders, encoder).to(device)
   # I just use the learning rate provided
    optimizers = [
        torch.optim.Adam(
            list(model.encoder.parameters()) + list(decoder.parameters()), lr=1e-3
        )
        for decoder in model.decoders
    ]

    losses = train(model, optimizers, mnist_train_loader, epochs_per_decoder, device)

    torch.save(model.state_dict(), save_path)
    plt.figure()
    plt.plot(range(5000, len(losses)), losses[5000:])
    plt.xlabel("Iteration")
    plt.ylabel("ELBO Loss")
    plt.title(f"Training Loss (Seed {seed}) [After 5000 Steps]")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path.replace(".pt", "_loss.png"))
    plt.close()


#5 train ONE VAE single and ensemble

In [8]:
# General training parameters
epochs_per_decoder = 400 # in fact, If I use 300,400 total epochs for 3-decoder vae, the peformance is not well
seed_base = 1000

# Single decoder model
num_decoders_single = 1
experiments_folder_single = "experiments/vae_single"
os.makedirs(experiments_folder_single, exist_ok=True)  # Create directory if it doesn't exist

# Multiple decoder model
num_decoders_ensemble = 3
experiments_folder_ensemble = "experiments/vae_ensemble"
os.makedirs(experiments_folder_ensemble, exist_ok=True)  # Create directory if it doesn't exist

In [9]:
# Train single decoder model
for i in range(1):  # Can be expanded to train multiple VAE models,see another file.
    seed = seed_base + i
    save_path = f"{experiments_folder_single}/sinmodel_seed{seed}.pt"
    print(f"Training SINGLE model with seed {seed}...")
    train_single_vae(seed, save_path, num_decoders_single, epochs_per_decoder)

Training SINGLE model with seed 1000...


  return self._call_impl(*args, **kwargs)
epoch=399, step=25595, decoder=0, loss=37.4: 100%|██████████| 25600/25600 [03:06<00:00, 137.34it/s]


In [10]:
# Train ensemble decoder model one vae
for i in range(1):  # Can be expanded to train multiple models
    seed = seed_base + i
    save_path = f"{experiments_folder_ensemble}/enmodel_seed{seed}.pt"
    print(f"Training ENSEMBLE model with seed {seed}...")
    train_single_vae(seed, save_path, num_decoders_ensemble, epochs_per_decoder)

Training ENSEMBLE model with seed 1000...


epoch=1199, step=76795, decoder=2, loss=-127.0: 100%|██████████| 76800/76800 [09:17<00:00, 137.83it/s]


In [11]:
# After training download the model_seed.pt, next time directly load the model
# Parameters
M = latent_dim
num_classes = 3
seed = 1000
# Load single decoder model
decoders_single = [GaussianDecoder(new_decoder())]
encoder_single = GaussianEncoder(new_encoder())
prior_single = GaussianPrior(M)
vae_single = VAE(prior_single, decoders_single, encoder_single).to(device)
vae_single.load_state_dict(torch.load(f"{experiments_folder_single}/sinmodel_seed{seed}.pt"))
# Load multiple decoder model
decoders_ensemble = [GaussianDecoder(new_decoder()) for _ in range(num_decoders_ensemble)]
encoder_ensemble = GaussianEncoder(new_encoder())
prior_ensemble = GaussianPrior(M)
vae_ensemble = VAE(prior_ensemble, decoders_ensemble, encoder_ensemble).to(device)
vae_ensemble.load_state_dict(torch.load(f"{experiments_folder_ensemble}/enmodel_seed{seed}.pt"))

<All keys matched successfully>

##5.1 latent space:
not relevant, just for checking whether the model encode the right latent space

In [12]:
def visualize_latent_space(model, dataloader, num_classes, save_path):
    model.eval()
    zs, labels = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            z = model.encoder(x).base_dist.loc
            zs.append(z.cpu())
            labels.append(y)
    zs = torch.cat(zs, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    plt.figure(figsize=(6, 6))
    for i in range(num_classes):
        idx = labels == i
        plt.scatter(zs[idx, 0], zs[idx, 1], s=5, alpha=0.6, label=f"Class {i}")
    plt.legend()
    plt.title("Latent Space")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


In [13]:

# Visualize the latent space for the single decoder model
visualize_latent_space(
    vae_single,
    mnist_test_loader,
    num_classes=num_classes,
    save_path="latent_space_single.png"
)

# Visualize the latent space for the multiple decoder model
visualize_latent_space(
    vae_ensemble,
    mnist_test_loader,
    num_classes=num_classes,
    save_path="latent_space_ensemble.png"
)


##5.2 sample
not relevant to this task, just for checking the quality of VAE


In [14]:
experiments_single = "experiments/vae_single"
outputs_single = "experiments/vae_single_outputs"
os.makedirs(outputs_single, exist_ok=True)  # Create output directory if it doesn't exist
vae_single.eval()  # Set model to evaluation mode
out_dir_single = f"{outputs_single}/vae_seed{seed}"
os.makedirs(out_dir_single, exist_ok=True)  # Create output directory for the specific seed
with torch.no_grad():
    # Sampling (only one decoder)
    samples = vae_single.sample(decoder_idx=0, n_samples=64).cpu()
    save_image(samples.view(64, 1, 28, 28), f"{out_dir_single}/samples_decoder0.png")

    # Reconstruction
    data = next(iter(mnist_test_loader))[0].to(device)  # Get a batch of test data
    z = vae_single.encoder(data).mean  # Encode data into latent space
    recon = vae_single.decoders[0](z).mean  # Decode from latent representation
    save_image(
        torch.cat([data.cpu(), recon.cpu()], dim=0),  # Concatenate original and reconstructed images
        f"{out_dir_single}/reconstruction_decoder0.png"
    )

In [15]:
num_decoders=3
outputs_ensemble = "experiments/vae_ensemble_outputs"
os.makedirs(outputs_ensemble, exist_ok=True)  # Create output directory if it doesn't exist
vae_ensemble.eval()  # Set model to evaluation mode
# Load multiple decoder model
out_dir_ensemble = f"{outputs_ensemble}/vae_seed{seed}"
os.makedirs(out_dir_ensemble, exist_ok=True)  # Create output directory for the specific seed

with torch.no_grad():
    # Sampling from each decoder
    for i in range(num_decoders):
        samples = vae_ensemble.sample(decoder_idx=i, n_samples=64).cpu()
        save_image(samples.view(64, 1, 28, 28), f"{out_dir_ensemble}/samples_decoder{i}.png")

    # Reconstruction using each decoder
    data = next(iter(mnist_test_loader))[0].to(device)  # Get a batch of test data
    z = vae_ensemble.encoder(data).mean  # Encode data into latent space
    for i in range(num_decoders):
        recon = vae_ensemble.decoders[i](z).mean  # Decode from latent representation
        save_image(
            torch.cat([data.cpu(), recon.cpu()], dim=0),  # Concatenate original and reconstructed images
            f"{out_dir_ensemble}/reconstruction_decoder{i}.png"
        )


##5.3 elbo not relevant to this task


In [16]:
'''
num_vaes =1
experiments_folder = "experiments/vae_retrain_seeds"

for vae_idx in range(num_vaes):
    seed = 1000 + vae_idx
    model_path = f"{experiments_folder}/model_seed{seed}.pt"


    decoders = [GaussianDecoder(new_decoder()) for _ in range(num_decoders)]
    encoder = GaussianEncoder(new_encoder())
    prior = GaussianPrior(M)

    model = VAE(prior, decoders, encoder).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    elbos_per_decoder = [[] for _ in range(num_decoders)]

    with torch.no_grad():
        for x, _ in mnist_test_loader:
            x = x.to(device)
            for i in range(num_decoders):
                elbo = model.elbo(x, decoder_idx=i)
                elbos_per_decoder[i].append(elbo)


    print(f"\nVAE model with seed {seed}:")
    for i in range(num_decoders):
        mean_elbo = torch.tensor(elbos_per_decoder[i]).mean()
        print(f"  Decoder {i} mean test ELBO: {mean_elbo.item():.4f}")
'''


'\nnum_vaes =1\nexperiments_folder = "experiments/vae_retrain_seeds"\n\nfor vae_idx in range(num_vaes):\n    seed = 1000 + vae_idx\n    model_path = f"{experiments_folder}/model_seed{seed}.pt"\n\n\n    decoders = [GaussianDecoder(new_decoder()) for _ in range(num_decoders)]\n    encoder = GaussianEncoder(new_encoder())\n    prior = GaussianPrior(M)\n\n    model = VAE(prior, decoders, encoder).to(device)\n    model.load_state_dict(torch.load(model_path, map_location=device))\n    model.eval()\n\n    elbos_per_decoder = [[] for _ in range(num_decoders)]\n\n    with torch.no_grad():\n        for x, _ in mnist_test_loader:\n            x = x.to(device)\n            for i in range(num_decoders):\n                elbo = model.elbo(x, decoder_idx=i)\n                elbos_per_decoder[i].append(elbo)\n\n\n    print(f"\nVAE model with seed {seed}:")\n    for i in range(num_decoders):\n        mean_elbo = torch.tensor(elbos_per_decoder[i]).mean()\n        print(f"  Decoder {i} mean test ELBO: {m

#6 cruve and energy


##6.1 cubiccurve


In [111]:
#correspond to page 69 in our textbook
class CubicCurve(nn.Module):
    def __init__(self, c0, c1):
        """
        Cubic polynomial curve module with fixed endpoints and parameterized middle section.

        Parameters:
        - c0: [d] Tensor representing the start point.
        - c1: [d] Tensor representing the end point.
        """
        super().__init__()
        self.register_buffer("c0", c0)
        self.register_buffer("c1", c1)

        d = c0.shape[0]
        # Learnable parameters: w1, w2 ∈ R^d
        self.w1 = nn.Parameter(torch.zeros(d, requires_grad=True))
        self.w2 = nn.Parameter(torch.zeros(d, requires_grad=True))

    def forward(self, t):
        """
        Forward pass to compute the cubic curve.

        Parameters:
        - t: [B] or [B, 1], curve parameter t ∈ [0, 1].

        Returns:
        - c(t): [B, d], computed points on the cubic curve.
        """
        if t.dim() == 1:
            t = t.unsqueeze(1)  # Reshape to [B, 1],so that t can * c0

        # [B, d] broadcasting
        t1 = t
        t2 = t ** 2
        t3 = t ** 3

        w1 = self.w1  # [d]
        w2 = self.w2  # [d]
        w3 = -w1 - w2  # Ensure smooth transition

        # Linear interpolation between c0 and c1
        linear = (1 - t) * self.c0 + t * self.c1  # [B, d]

        # Residual polynomial component
        residual = w1 * t1 + w2 * t2 + w3 * t3  # [B, d]

        return linear + residual  # [B, d], final cubic curve output


## 6.2 compute_energy

$$
\mathcal{E}[\gamma] \approx \sum_{t=0}^{T-1} \mathbb{E}_{\theta, \theta' \sim q(\theta) q(\theta)}
\left[ \left\| f_{\theta} (\gamma(t + \frac{1}{T})) - f_{\theta'} (\gamma(t / T)) \right\|^2 \right]
$$

$f_{\theta}$ $f_{\theta'}$ denotes deoder ensemble members drawn uniformly

In [112]:
import torch

def compute_curve_energy(curve, decoders, T=16, num_samples=1, fixed_indices=None, device='cuda'):
    """
    Compute the energy of a curve using fixed decoder indices to ensure the objective function remains consistent.

    Parameters:
    - curve: An instance of CubicCurve
    - decoders: List of decoder modules
    - T: Number of time steps, default is 16
    - num_samples: Number of Monte Carlo samples, default is 1
    - fixed_indices: Pre-fixed decoder indices [(idx1_t0, idx2_t0), (idx1_t1, idx2_t1), ...]
    - device: Computing device, default is 'cuda'

    Returns:
    - Scalar energy value
    """
    total_energy = 0.0  # Accumulate energy over all time steps

    for i in range(T):
        t0 = torch.tensor([i / T], device=device, dtype=torch.float32)
        t1 = torch.tensor([(i + 1) / T], device=device, dtype=torch.float32)

        x0 = curve(t0)  # γ(t0), shape [1, d]
        x1 = curve(t1)  # γ(t1)

        energy = 0.0  # Energy for the current time step

        for _ in range(num_samples):
            idx1, idx2 = fixed_indices[i]  # Retrieve fixed indices

            # **Compute only the required decoder outputs**
            sampled_mean_x0 = decoders[idx1](x0).mean  # Directly compute the mean for idx1
            sampled_mean_x1 = decoders[idx2](x1).mean  # Directly compute the mean for idx2

            # Compute L2 norm
            energy += torch.norm(sampled_mean_x1 - sampled_mean_x0, p=2)

        # Take Monte Carlo average, negligible when num_samples = 1
        total_energy += energy / num_samples

    return total_energy  # Return total energy


##6.3 optimize_geodesics

In [114]:
def optimize_geodesic(c0, c1, decoders, T=16, steps=500, lr=1e-2, device='cuda',
                      early_stopping_n=100, early_stopping_delta=1e-4):
    """
    Optimize a geodesic curve while ensuring the objective function remains unchanged during optimization.

    Parameters:
    - c0: Starting point
    - c1: Endpoint
    - decoders: List of decoder modules
    - T: Number of time steps
    - steps: Number of optimization iterations
    - lr: Learning rate
    - device: Computing device
    - early_stopping_n: Number of steps to check for early stopping
    - early_stopping_delta: Minimum required improvement to continue

    Returns:
    - Optimized curve
    - Logged energy values
    """
    curve = CubicCurve(c0, c1).to(device)  # Initialize the curve
    optimizer = torch.optim.Adam(curve.parameters(), lr=lr)  # Adam optimizer
    energy_log = []  # Store energy values

    # **Pre-generate fixed decoder indices**
    fixed_indices = [(torch.randint(0, len(decoders), (1,), device=device).item(),
                      torch.randint(0, len(decoders), (1,), device=device).item())
                     for _ in range(T)]

    best_energy = float('inf')
    no_improve_count = 0  # Counter for early stopping

    with tqdm(range(steps)) as pbar:
        for step in pbar:
            optimizer.zero_grad()  # Clear gradients
            energy = compute_curve_energy(curve, decoders, T=T, fixed_indices=fixed_indices, device=device)  # Compute energy
            energy.backward()  # Backpropagation
            optimizer.step()  # Update parameters

            energy_value = energy.item()
            energy_log.append(energy_value)  # Store energy value

            # Early Stopping Logic
            if energy_value < best_energy - early_stopping_delta:
                best_energy = energy_value
                no_improve_count = 0  # Reset counter
            else:
                no_improve_count += 1

            if no_improve_count >= early_stopping_n:
                print(f"Early stopping at step {step}, energy: {energy_value:.6f}")
                break  # Stop training if no improvement

            # Update progress bar
            pbar.set_description(f"Energy: {energy_value:.6f}")

    return curve, energy_log


#7.plot  25 random latent_varibale pairs


##7.1 data preparation

In [115]:
#参数设置
# ====== 参数设置 ======
num_pairs = 25  # 需要处理的点对数量
T = 256  # 测地线离散化段数
steps = 4200  # 迭代步数
lr = 1e-2  # 学习率
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random_seed = 42
random.seed(random_seed)

# 存储路径
save_dir = "energy/"
results_dir = "results"
energy_dir = "energy_logs"
os.makedirs(save_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)
os.makedirs(energy_dir, exist_ok=True)

In [116]:
# ====== Obtain Full Test Set Latent Representations (for Scatter Plot) ======
def encode_all_test_latents(model):
    """
    Encode the entire test dataset into the latent space.

    Parameters:
    - model: VAE model (single or ensemble) used for encoding.

    Returns:
    - zs: Tensor of latent representations.
    - ys: Corresponding labels.
    """
    zs, ys = [], []
    with torch.no_grad():
        for x, y in mnist_test_loader:
            x = x.to(device)
            q = model.encoder(x)
            z = q.base_dist.loc  # Extract mean of the latent distribution
            zs.append(z.cpu())
            ys.append(y)
    return torch.cat(zs, dim=0), torch.cat(ys, dim=0)

# Encode the test set using both models
latent_z_single, labels_single = encode_all_test_latents(vae_single)
latent_z_ensemble, labels_ensemble = encode_all_test_latents(vae_ensemble)

In [117]:
# 每次计算完一个点对，就画 energy 和 geodesics 图
def plot_energy_log_single(log, pair_idx, decoder_count):
    folder = os.path.join(energy_dir, f"vae_d{decoder_count}")
    os.makedirs(folder, exist_ok=True)

    plt.figure(figsize=(8, 6))
    plt.plot(log)
    plt.xlabel("Optimization Step")
    plt.ylabel("Energy")
    plt.title(f"Geodesic Energy (Decoder={decoder_count}, Pair={pair_idx})")
    plt.grid(True)
    plt.tight_layout()

    save_path = os.path.join(folder, f"pair_{pair_idx}.png")
    plt.savefig(save_path, dpi=300)
    plt.close()


In [118]:
def plot_geodesics_progress(latent_z, labels, z_pairs, curves, decoder_count, pair_idx):
    """
    累积绘制测地线路径，并叠加潜在变量空间（背景）。

    Parameters:
    - latent_z: 测试集潜在表示（背景散点）
    - labels: 对应的类别标签（背景散点颜色）
    - z_pairs: 累计的起始点对
    - curves: 累计的测地线曲线
    - decoder_count: 解码器数量
    - pair_idx: 当前点对索引
    """
    folder = os.path.join(results_dir, f"vae_d{decoder_count}")
    os.makedirs(folder, exist_ok=True)

    plt.figure(figsize=(8, 6))

    # 背景散点图，使用原来的 tab10 配色
    scatter = plt.scatter(latent_z[:, 0], latent_z[:, 1], c=labels, cmap="tab10", s=8, alpha=0.4)

    t_vals = torch.linspace(0, 1, sample_steps).unsqueeze(1).to(device)

    for i in range(len(curves)):
        gamma = curves[i](t_vals).detach().cpu()
        c0, c1 = z_pairs[i, 0].cpu(), z_pairs[i, 1].cpu()

        # 不指定颜色，自动配色 + 线加粗
        plt.plot(gamma[:, 0], gamma[:, 1], linewidth=2.2)  # 更粗线条
        plt.plot([c0[0], c1[0]], [c0[1], c1[1]], 'k--', linewidth=1.0)  # 虚线连接端点

    plt.title(f"Geodesics (Decoder={decoder_count}, Pairs=1~{pair_idx+1})")
    plt.xlabel("z1")
    plt.ylabel("z2")
    plt.grid(True)
    plt.legend(*scatter.legend_elements(), title="Class")
    plt.tight_layout()

    save_path = os.path.join(folder, f"geodesics_pairs_{pair_idx+1}.png")
    plt.savefig(save_path, dpi=300)
    plt.close()


# 7.2 load data and encode into latent  space z

In [119]:
# ====== 读取测试数据 ======
test_images = []
test_labels = []
for x, y in mnist_test_loader:
    test_images.append(x)
    test_labels.append(y)

test_images = torch.cat(test_images, dim=0)  # [N, 1, 28, 28]
test_labels = torch.cat(test_labels, dim=0)

# ====== 读取或生成固定的测试点对索引 ======
index_file = os.path.join(save_dir, "test_indices.pt")
if os.path.exists(index_file):
    indices = torch.load(index_file)
    print(f"加载已有测试点对索引，共 {len(indices)//2} 对")
else:
    N = test_images.shape[0]
    indices = random.sample(range(N), 2 * num_pairs)
    torch.save(indices, index_file)
    print(f"生成并保存新测试点对索引，共 {num_pairs} 对")

# ====== 生成固定的测试点对 ======
x_pairs = torch.stack([
    torch.stack([test_images[indices[i]], test_images[indices[i + 1]]], dim=0)
    for i in range(0, 2 * num_pairs, 2)
])  # Shape: [num_pairs, 2, 1, 28, 28]

# ====== 进入 VAE 潜在空间 ======
vae_single.eval()
vae_ensemble.eval()

z_pairs_single = []
z_pairs_ensemble = []

with torch.no_grad():
    for i in range(num_pairs):
        x0 = x_pairs[i, 0].to(device)  # 第一张图片
        x1 = x_pairs[i, 1].to(device)  # 第二张图片

        # 单解码器 VAE
        z0_single = vae_single.encoder(x0.unsqueeze(0)).base_dist.loc.squeeze(0)
        z1_single = vae_single.encoder(x1.unsqueeze(0)).base_dist.loc.squeeze(0)
        z_pairs_single.append(torch.stack([z0_single, z1_single], dim=0))

        # 多解码器 VAE
        z0_ens = vae_ensemble.encoder(x0.unsqueeze(0)).base_dist.loc.squeeze(0)
        z1_ens = vae_ensemble.encoder(x1.unsqueeze(0)).base_dist.loc.squeeze(0)
        z_pairs_ensemble.append(torch.stack([z0_ens, z1_ens], dim=0))

z_pairs_single = torch.stack(z_pairs_single)     # Shape: [num_pairs, 2, latent_dim]
z_pairs_ensemble = torch.stack(z_pairs_ensemble) # Shape: [num_pairs, 2, latent_dim]





生成并保存新测试点对索引，共 25 对


## 7.3 compute geodesics


In [None]:
# ====== 计算测地线并保存 ======
curves_single = []
energy_logs_single = []
curves_ensemble = []
energy_logs_ensemble = []

for i in range(num_pairs):
    pair_file = os.path.join(save_dir, f"pair_{i}.pt")

    if os.path.exists(pair_file):
        # **已存在，直接加载**
        data = torch.load(pair_file, weights_only=False)
        curves_single.append(data["curve_single"])
        energy_logs_single.append(data["energy_log_single"])
        curves_ensemble.append(data["curve_ensemble"])
        energy_logs_ensemble.append(data["energy_log_ensemble"])
        print(f"点对 {i} 已加载")
    else:
        # **计算 single**
        c0, c1 = z_pairs_single[i, 0].to(device), z_pairs_single[i, 1].to(device)
        curve_single, energy_log_single = optimize_geodesic(
            c0, c1, decoders=[vae_single.decoders[0]],
            T=T, steps=steps, lr=lr, device=device,
            early_stopping_n=100, early_stopping_delta=1e-4
        )

        # **计算 ensemble**
        c0_ens, c1_ens = z_pairs_ensemble[i, 0].to(device), z_pairs_ensemble[i, 1].to(device)
        curve_ensemble, energy_log_ensemble = optimize_geodesic(
            c0_ens, c1_ens, decoders=vae_ensemble.decoders,
            T=T, steps=steps, lr=lr, device=device
        )

        # **存储结果**
        torch.save({
            "curve_single": curve_single,
            "curve_ensemble": curve_ensemble,
            "energy_log_single": energy_log_single,
            "energy_log_ensemble": energy_log_ensemble
        }, pair_file)

        print(f"点对 {i} 计算完成，已保存到 {pair_file}")

        # **加载计算结果，保证后续代码可用**
        curves_single.append(curve_single)
        energy_logs_single.append(energy_log_single)
        curves_ensemble.append(curve_ensemble)
        energy_logs_ensemble.append(energy_log_ensemble)

    # **画能量曲线**
    plot_energy_log_single(energy_logs_single[-1], i, 1)
    plot_energy_log_single(energy_logs_ensemble[-1], i, len(vae_ensemble.decoders))

    # **画测地线**
    plot_geodesics_progress(latent_z_single, labels_single, z_pairs_single[:i+1], curves_single, 1, i)
    plot_geodesics_progress(latent_z_ensemble, labels_ensemble, z_pairs_ensemble[:i+1], curves_ensemble, len(vae_ensemble.decoders), i)


Energy: 30.850645:   4%|▍         | 173/4200 [03:24<1:18:44,  1.17s/it]

In fact,different points has different decreasing covergence speed, so I guess we can use early stopping, setting max=5000, if within patience=50 steps , the energy is basically the same, then early stop.

just refer to the orginial paper (also use early stopping)
https://github.com/mustass/ensertainty/blob/main/configs/inference/ensemble_geodesics.yaml

##7.3 plot the required comparision plot of geodesics

In [50]:
# ====== Plotting Function ======
sample_steps=32 #after optimizing , we get final curve. using how many points to draw this curve
# 生成 25 种不同的颜色（从 colormap 获取）
colors = plt.cm.viridis(np.linspace(0, 1, 25))  # 可以换成其他 colormap，如 plt.cm.jet

def plot_geodesics(latent_z, labels, z_pairs, curves, title, out_path):
    """
    Plot geodesic paths in the latent space.

    Parameters:
    - latent_z: Latent representations of the test dataset.
    - labels: Corresponding class labels.
    - z_pairs: Pairs of latent points between which geodesics are computed.
    - curves: Optimized geodesic curves.
    - title: Plot title.
    - out_path: File path to save the plot.
    """
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(latent_z[:, 0], latent_z[:, 1], c=labels, cmap="tab10", s=8, alpha=0.4)
    t_vals = torch.linspace(0, 1, sample_steps).unsqueeze(1).to(device)  # Interpolation points along the geodesic

    for i in range(len(curves)):
        gamma = curves[i](t_vals).detach().cpu()  # Compute geodesic path
        c0, c1 = z_pairs[i, 0], z_pairs[i, 1]

        plt.plot(gamma[:, 0], gamma[:, 1], linewidth=1.5, color=colors[i])  # 固定颜色
        plt.plot([c0[0], c1[0]], [c0[1], c1[1]], 'k--', linewidth=0.8)  # Dashed line connecting endpoints

    plt.title(title)
    plt.xlabel("z1")
    plt.ylabel("z2")
    plt.grid(True)
    plt.legend(*scatter.legend_elements(), title="Class")
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close()





In [80]:
# ====== Generate Plots for Single and Ensemble Models ======
plot_geodesics(latent_z_single, labels_single, z_pairs_single.cpu(), curves_single,
               "Geodesics in Latent Space (Single Decoder)", "vae_single_geodesics.png")

In [81]:
plot_geodesics(latent_z_ensemble, labels_ensemble, z_pairs_ensemble.cpu(), curves_ensemble,
               "Geodesics in Latent Space (Ensemble Decoder)", "vae_ensemble_geodesics.png")
