#0.import and random seeds


In [2]:
import torch
import torch.nn as nn
import torch.distributions as td
import torch.utils.data
from tqdm import tqdm
from copy import deepcopy
import os
import math
import matplotlib.pyplot as plt
#dataset
from torchvision import datasets, transforms
from torchvision.utils import save_image

In [3]:
import random
import numpy as np

def set_seed(seed=42):
    random.seed(seed)                           # Python built-in random module
    np.random.seed(seed)                        # NumPy random generator
    torch.manual_seed(seed)                     # PyTorch CPU random seed
    torch.cuda.manual_seed(seed)                # PyTorch current GPU random seed
    torch.cuda.manual_seed_all(seed)            # PyTorch all GPUs random seed
    torch.backends.cudnn.deterministic = True   # Ensure deterministic behavior in cuDNN
    torch.backends.cudnn.benchmark = False      # Disable auto-optimization to prevent non-deterministic behavior
    os.environ["PYTHONHASHSEED"] = str(seed)    # Control hash-based randomness in Python

set_seed(42)


#1.Dataset  
just copy the code provided by the professor



In [5]:
def subsample(data, targets, num_data, num_classes):
    idx = targets < num_classes  # Select samples with class labels less than num_classes (e.g., only classes 0, 1, 2)
    new_data = data[idx][:num_data].unsqueeze(1).to(torch.float32) / 255  # Select the first num_data images and normalize to [0,1]
    new_targets = targets[idx][:num_data]  # Select corresponding labels for the subsampled images
    return torch.utils.data.TensorDataset(new_data, new_targets)  # Create a TensorDataset with the filtered images and labels

num_train_data = 2048
num_classes = 3

train_tensors = datasets.MNIST(
    "data/", train=True, download=True,
    transform=transforms.Compose([transforms.ToTensor()])  # Convert images to tensors
)

test_tensors = datasets.MNIST(
    "data/", train=False, download=True,
    transform=transforms.Compose([transforms.ToTensor()])  # Convert images to tensors
)
train_data = subsample(
    train_tensors.data, train_tensors.targets,
    num_train_data, num_classes
)
test_data = subsample(
    test_tensors.data, test_tensors.targets,
    num_train_data, num_classes
)

batch_size=32
mnist_train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True
)
mnist_test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, shuffle=False
)
latent_dim=2
M=latent_dim
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


#2.GaussianPrior and Encoder/Decoder
Just copy the code provided by the professor


In [7]:
class GaussianPrior(nn.Module):
    def __init__(self, M):
        """
        Define a Gaussian prior distribution with zero mean and unit variance.

                Parameters:
        M: [int]
           Dimension of the latent space.
        """
        super(GaussianPrior, self).__init__()
        self.M = M
        self.mean = nn.Parameter(torch.zeros(self.M), requires_grad=False)
        self.std = nn.Parameter(torch.ones(self.M), requires_grad=False)

    def forward(self):
        """
        Return the prior distribution.

        Returns:
        prior: [torch.distributions.Distribution]
        """
        return td.Independent(td.Normal(loc=self.mean, scale=self.std), 1)

class GaussianEncoder(nn.Module):
    def __init__(self, encoder_net):
        """
        Define a Gaussian encoder distribution based on a given encoder network.

        Parameters:
        encoder_net: [torch.nn.Module]
            The encoder network that takes a tensor of dimension
            `(batch_size, feature_dim1, feature_dim2)` as input
            and outputs a tensor of dimension `(batch_size, 2M)`,
            where M is the dimension of the latent space.
        """
        super(GaussianEncoder, self).__init__()
        self.encoder_net = encoder_net

    def forward(self, x):
        """
        Given a batch of input data, return a Gaussian distribution over the latent space.

        Parameters:
        x: [torch.Tensor]
            A tensor of dimension `(batch_size, feature_dim1, feature_dim2)`.

        Returns:
        A Gaussian distribution with computed mean and standard deviation.
        """
        mean, std = torch.chunk(self.encoder_net(x), 2, dim=-1)
        return td.Independent(td.Normal(loc=mean, scale=torch.exp(std)), 1)

        # Example:
        # z = torch.randn(4, 10)  # Assume z is a tensor of shape [batch_size=4, 10]
        # a, b = torch.chunk(z, 2, dim=-1)
        # a and b will have shape [4, 5], as the tensor is split into two parts along the last dimension.
def new_encoder():
        encoder_net = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),
            nn.Softmax(),
            nn.BatchNorm2d(16),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),
            nn.Softmax(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 32, 3, stride=2, padding=1),
            nn.Flatten(),
            nn.Linear(512, 2 * M),
        )
        return encoder_net
class GaussianDecoder(nn.Module):
    def __init__(self, decoder_net):
        """
        Define a Gaussian decoder distribution based on a given decoder network.

        Parameters:
        decoder_net: [torch.nn.Module]
            The decoder network that takes a tensor of dimension `(batch_size, M)`
            as input, where M is the dimension of the latent space, and outputs a
            tensor of dimension `(batch_size, feature_dim1, feature_dim2)`.
        """
        super(GaussianDecoder, self).__init__()
        self.decoder_net = decoder_net
        # self.std = nn.Parameter(torch.ones(28, 28) * 0.5, requires_grad=True)
        # In case you want to learn the standard deviation of the Gaussian.

    def forward(self, z):
        """
        Given a batch of latent variables, return a Gaussian distribution over the data space.

        Parameters:
        z: [torch.Tensor]
            A tensor of dimension `(batch_size, M)`, where M is the dimension of the latent space.

        Returns:
        A Gaussian distribution with computed mean and a fixed standard deviation.
        """
        means = self.decoder_net(z)
        return td.Independent(td.Normal(loc=means, scale=1e-1), 3) #note the variance of decoder is fixed
        # This defines a 784-dimensional independent normal distribution, where each dimension is independent.
def new_decoder():
        decoder_net = nn.Sequential(
            nn.Linear(M, 512),
            nn.Unflatten(-1, (32, 4, 4)),
            nn.Softmax(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 32, 3, stride=2, padding=1, output_padding=0),
            nn.Softmax(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),
            nn.Softmax(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),
        )
        return decoder_net

#3.VAE
having changed the provided code ,so that for each mini-batch of data, we randomly sample a decoder and take a gradient step to optimize the ELBO


In [9]:
class VAE(nn.Module):
    def __init__(self, prior, decoders, encoder):
        """
        Variational Autoencoder (VAE) with multiple decoders.

        Parameters:
        prior: [torch.nn.Module]
            The prior distribution over the latent space.
        decoders: [list of torch.nn.Module]
            A list containing multiple decoders.
        encoder: [torch.nn.Module]
            The encoder network that maps input data to a latent distribution.
        """
        super(VAE, self).__init__()
        self.prior = prior
        self.decoders = nn.ModuleList(decoders)  # Use ModuleList to allow PyTorch to properly track parameters
        self.encoder = encoder

    def elbo(self, x, decoder_idx):
        """
        Compute the Evidence Lower Bound (ELBO) for a given input and selected decoder.

        Parameters:
        x: [torch.Tensor]
            The input data tensor.
        decoder_idx: [int]
            The index of the decoder to be used.

        Returns:
        The computed ELBO value.
        """
        q = self.encoder(x)  # Encode input into a latent distribution
        z = q.rsample()  # Sample from the latent distribution using the reparameterization trick
        decoder = self.decoders[decoder_idx]  # Select the corresponding decoder

        elbo = torch.mean(
            decoder(z).log_prob(x) - q.log_prob(z) + self.prior().log_prob(z)
        )  # Compute ELBO using the likelihood, posterior, and prior

        return elbo

    def sample(self, decoder_idx, n_samples=1):
        """
        Generate samples from the specified decoder.

        Parameters:
        decoder_idx: [int]
            The index of the decoder to be used.
        n_samples: [int, default=1]
            The number of samples to generate.

        Returns:
        A batch of generated samples.
        """
        z = self.prior().sample(torch.Size([n_samples]))  # Sample from the prior distribution
        decoder = self.decoders[decoder_idx]  # Select the corresponding decoder
        return decoder(z).sample()  # Generate samples from the decoder

    def forward(self, x, decoder_idx):
        """
        Compute the negative ELBO for optimization.

        Parameters:
        x: [torch.Tensor]
            The input data tensor.
        decoder_idx: [int]
            The index of the decoder to be used.

        Returns:
        The negative ELBO value.
        """
        return -self.elbo(x, decoder_idx)


#4.Training vae


In [11]:
def train(model, optimizers, data_loader, epochs):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_decoders = len(model.decoders)
    #total epoch should depende on the number of decoders
    total_epochs = epochs * num_decoders
    num_steps = len(data_loader) * total_epochs
    epoch = 0

    losses = []

    def noise(x, std=0.05):
        eps = std * torch.randn_like(x)
        return torch.clamp(x + eps, min=0.0, max=1.0)

    with tqdm(range(num_steps)) as pbar:
        for step in pbar:
            try:
                x = next(iter(data_loader))[0]
                x = noise(x.to(device))
                model=model
                idx = torch.randint(0, num_decoders, (1,)).item()
                #for each mini-batch of data, we randomly sample a decoder
                #and take a gradient step to optimize the ELBO of  that decoder
                # if we [1]
                optimizer = optimizers[idx]
                optimizer.zero_grad()
                loss = model(x, decoder_idx=idx) #correspond to the changed part in VAE
                loss.backward()
                optimizer.step()

                loss_val = loss.detach().cpu().item()
                losses.append(loss_val)

                if step % 5 == 0:
                    pbar.set_description(
                        f"epoch={epoch}, step={step}, decoder={idx}, loss={loss_val:.1f}"
                    )
                if (step + 1) % len(data_loader) == 0:
                    epoch += 1
            except KeyboardInterrupt:
                print(f"Stopped at epoch {epoch}, step {step}, loss {loss_val:.1f}")
                break

    return losses


In [12]:
#because in part B we need 10 independent VAEs with  decoders 1,2,3, so I define a new train function
#S is the number of the decoders
def train_single_vae(seed, save_path, S, epochs_per_decoder):
    set_seed(seed)
    decoders = [GaussianDecoder(new_decoder()) for _ in range(S)]
    #instantiating  S randomly initialized decoders
    # note: set_seed(1001) only ensure the next time we run set_seed(1001), it's still the SAME randomly three decoders
    # it will not destroy of the randomness of three different decoders
    # [dd1,dd2,dd0]  [dd1,dd0,dd2]
    encoder = GaussianEncoder(new_encoder())
    prior = GaussianPrior(M)

    model = VAE(prior, decoders, encoder).to(device)
   # I just use the learning rate provided
    optimizers = [
        torch.optim.Adam(
            list(model.encoder.parameters()) + list(decoder.parameters()), lr=1e-3
        )
        for decoder in model.decoders
    ]

    losses = train(model, optimizers, mnist_train_loader, epochs_per_decoder)

    torch.save(model.state_dict(), save_path)
    plt.figure()
    plt.plot(range(5000, len(losses)), losses[5000:])
    plt.xlabel("Iteration")
    plt.ylabel("ELBO Loss")
    plt.title(f"Training Loss (Seed {seed}) [After 5000 Steps]")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_path.replace(".pt", "_loss.png"))
    plt.close()


#5 train 10 VAES

train 10 VAES with different decoder_counts,1,2,3

I think this depends on how you implemented the retrainings.

If you did something like :
## Training:
for each training run:
    train model with max_number_of_decoders
   10 independent vaes and 10 encoders
## Geodesics
for number_of_decoders in range(1,max_number_of_decoders):
    compute_geodesics using number_of_decoders
Resulting in 10 training runs, then the blue curve should be constant as the variation of the Euclidean distances should not vary.


The difference is in whether you have 10 encoders

We advise you to follow the first approach as it is computationally cheaper. But doing the second approach is by no means wrong.

In [16]:
def train_super_vae_models(Q=10, epochs_per_decoder=400, base_seed=1000, max_decoder_num=3):
    folder = f"experiments/vae_d{max_decoder_num}"
    os.makedirs(folder, exist_ok=True)

    for i in range(Q):
        seed = base_seed + i
        name = f"vae_d{max_decoder_num}_seed{seed}"
        save_path = os.path.join(folder, f"{name}.pt")

        if os.path.exists(save_path):
            print(f"Skipping training for {name}, checkpoint exists.")
            continue

        print(f"Training VAE: decoder={max_decoder_num}, seed={seed}")
        train_single_vae(seed, save_path, S=max_decoder_num, epochs_per_decoder=epochs_per_decoder)


In [17]:
#Q is the number of VAES,here Q=2 epoch=2 just for testing the code.
train_super_vae_models(Q=10, epochs_per_decoder=400, base_seed=1000,max_decoder_num=3)

Skipping training for vae_d3_seed1000, checkpoint exists.
Skipping training for vae_d3_seed1001, checkpoint exists.
Training VAE: decoder=3, seed=1002


  return self._call_impl(*args, **kwargs)
epoch=2, step=190, decoder=1, loss=1441.6: 100%|██████████| 192/192 [00:02<00:00, 79.67it/s]


Training VAE: decoder=3, seed=1003


epoch=2, step=155, decoder=2, loss=3340.2:  81%|████████▏ | 156/192 [00:01<00:00, 83.56it/s]


Stopped at epoch 2, step 156, loss 3340.2
Training VAE: decoder=3, seed=1004


epoch=2, step=190, decoder=1, loss=1727.1: 100%|██████████| 192/192 [00:02<00:00, 85.72it/s]


Training VAE: decoder=3, seed=1005


epoch=2, step=190, decoder=1, loss=1375.7: 100%|██████████| 192/192 [00:02<00:00, 85.05it/s]


Training VAE: decoder=3, seed=1006


epoch=2, step=190, decoder=2, loss=1592.6: 100%|██████████| 192/192 [00:02<00:00, 82.89it/s]


Training VAE: decoder=3, seed=1007


epoch=2, step=190, decoder=0, loss=1787.2: 100%|██████████| 192/192 [00:02<00:00, 84.76it/s]


Training VAE: decoder=3, seed=1008


epoch=2, step=190, decoder=1, loss=1435.0: 100%|██████████| 192/192 [00:02<00:00, 85.30it/s]


Training VAE: decoder=3, seed=1009


epoch=2, step=190, decoder=0, loss=1450.8: 100%|██████████| 192/192 [00:02<00:00, 87.26it/s]


In [27]:
def load_all_vaes(base_folder="experiments", max_decoder_num=3, Q=1, base_seed=1000, device = torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    all_models = {}

    folder = os.path.join(base_folder, f"vae_d{max_decoder_num}")

    for i in range(Q):
        seed = base_seed + i
        name = f"vae_d{max_decoder_num}_seed{seed}"
        path = os.path.join(folder, f"{name}.pt")

        # construct model structure
        decoders = [GaussianDecoder(new_decoder()) for _ in range(max_decoder_num)]
        encoder = GaussianEncoder(new_encoder())
        prior = GaussianPrior(M=2)  # latent_dim = 2

        model = VAE(prior, decoders, encoder).to(device)
        model.load_state_dict(torch.load(path, map_location=device))
        model.eval()

        all_models[name] = model
        print(f"Loaded {name}")

    return all_models
models = load_all_vaes()

Loaded vae_d3_seed1000


In [None]:
#At convergence, we
#have access to one encoder and Sdecoders.

#for test

# one model
m = models["vae_d3_seed1001"]
encoder = m.encoder
decoder0 = m.decoders[0]
print(encoder)
print(decoder0)

Loaded vae_d3_seed1000
Loaded vae_d3_seed1001
GaussianEncoder(
  (encoder_net): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): Softmax(dim=None)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): Softmax(dim=None)
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (7): Flatten(start_dim=1, end_dim=-1)
    (8): Linear(in_features=512, out_features=4, bias=True)
  )
)
GaussianDecoder(
  (decoder_net): Sequential(
    (0): Linear(in_features=2, out_features=512, bias=True)
    (1): Unflatten(dim=-1, unflattened_size=(32, 4, 4))
    (2): Softmax(dim=None)
    (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2)

## 5.2 check vae quality

In [29]:
#5.2
def visualize_all_vaes_all_decoders(models, test_loader, output_folder="vae_vis_outputs", device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
    os.makedirs(output_folder, exist_ok=True)

    for name, model in models.items():
        model.eval()
        out_dir = os.path.join(output_folder, name)
        os.makedirs(out_dir, exist_ok=True)

        num_decoders = len(model.decoders)

        with torch.no_grad():
            data = next(iter(test_loader))[0].to(device)
            z = model.encoder(data).mean

            for decoder_idx in range(num_decoders):
                # 1. Sampling
                samples = model.sample(decoder_idx=decoder_idx, n_samples=64).cpu()
                sample_path = os.path.join(out_dir, f"samples_decoder{decoder_idx}.png")
                save_image(samples.view(64, 1, 28, 28), sample_path)

                # 2. Reconstruction
                recon = model.decoders[decoder_idx](z).mean
                recon_path = os.path.join(out_dir, f"reconstruction_decoder{decoder_idx}.png")
                save_image(torch.cat([data.cpu(), recon.cpu()], dim=0), recon_path)

            print(f"✅ Saved all decoders for {name}")


In [31]:

visualize_all_vaes_all_decoders(models, mnist_test_loader, output_folder="vae_vis_outputs")


✅ Saved all decoders for vae_d3_seed1000


  return self._call_impl(*args, **kwargs)


#6 cruve and energy


##6.1 cubiccurve 参数T=64


In [34]:
# page 68 cubiccurve(t) core code
class CubicCurve(nn.Module):
    def __init__(self, c0, c1):
        """
        Parameters:
        c0: torch.Tensor, start point (D,)
        c1: torch.Tensor, end point (D,)
        """
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use GPU if available
        self.T = 256  # Number of segments, typically in [64, 256] in Q2 we use 64
        self.D = c0.shape[0]  # Dimensionality of the curve

        self.register_buffer("c0", c0.to(self.device))  # Register start point as buffer (non-trainable)
        self.register_buffer("c1", c1.to(self.device))  # Register end point as buffer (non-trainable)

        # Initialize intermediate control points along a straight line from c0 to c1
        # This ensures the initial curve is linear before optimization
        intermediate_points = torch.stack([
            c0 + (c1 - c0) * (i + 1) / self.T  # Equally spaced points between c0 and c1
            for i in range(self.T - 1)
        ], dim=0).to(self.device)  # Shape: (T-1, D), stacked along the 0-th dimension

        self.c_t = nn.Parameter(intermediate_points)  # Make intermediate points trainable

        # Create a uniform grid of time values from 0 to 1 with (T + 1) points
        self.register_buffer("t_grid", torch.linspace(0, 1, self.T + 1, device=self.device))
        # t0, t1, ..., tT

    def forward(self, t):
        """
        Compute the value of the piecewise linear curve at position t

        Parameters:
        t: torch.Tensor, shape (...,), range [0,1]

        Returns:
        torch.Tensor, shape (..., D)

        Each segment is linearly interpolated between two control points:
        Given control points p0 and p1, and a scalar alpha ∈ [0, 1],
        interpolation formula is: p(t) = (1 - alpha) * p0 + alpha * p1
        where alpha = (t - t0) / (t1 - t0)
        """
        t = t.to(self.device)  # Move t to the appropriate device

        # Concatenate all control points: [c0, ..., c_t, ..., c1], shape (T+1, D)
        control_points = torch.cat([self.c0.unsqueeze(0), self.c_t, self.c1.unsqueeze(0)], dim=0)
        # unsqueeze(0) adds a new dimension at index 0: (D,) → (1, D)

        # For each t, find the corresponding segment index in t_grid
        # Example: t_grid = [0.0, 0.25, 0.5, 0.75, 1.0], t = 0.6 → returns 3
        # Even t = 0.5 returns 3 (right=True), then idx = 3 - 1 = 2

        idx = torch.searchsorted(self.t_grid, t, right=True) - 1  # Get the left segment index
        idx = idx.clamp(0, self.T - 1)  # Clamp idx to range [0, T - 1] to avoid overflow
        # If idx == T, then idx + 1 is out of bounds, so clamp to T - 1

        # Get start and end times of the segment
        t0, t1 = self.t_grid[idx], self.t_grid[idx + 1]
        # For example, t = 0.6 → segment is [t2, t3] = [0.5, 0.75]

        # Compute normalized position within the segment
        alpha = (t - t0) / (t1 - t0)

        # Get the corresponding control points for interpolation
        ct0, ct1 = control_points[idx], control_points[idx + 1]

        # Perform linear interpolation:
        # p(t) = (1 - alpha) * c0 + alpha * c1
        # Equivalent to:
        # If t ∈ [t_i, t_{i+1}]:
        #     p(t) = c_i + (c_{i+1} - c_i) * (t - t_i) / (t_{i+1} - t_i)
        #          = (1 - alpha) * c_i + alpha * c_{i+1}
        return (1 - alpha.unsqueeze(-1)) * ct0 + alpha.unsqueeze(-1) * ct1  # (..., D), broadcast over last dim


##6.2 compute_energy

$$
\mathcal{E}[\gamma] \approx \sum_{t=0}^{T-1} \mathbb{E}_{\theta, \theta' \sim q(\theta) q(\theta)}
\left[ \left\| f_{\theta} (\gamma(t + \frac{1}{T})) - f_{\theta'} (\gamma(t / T)) \right\|^2 \right]
$$

$f_{\theta}$ $f_{\theta'}$ denotes deoder ensemble members drawn uniformly

In [38]:
def compute_curve_energy(curve, decoders, T, fixed_indices=None, device='cuda'):
    """
    Compute the energy of a curve using fixed decoder indices (no Monte Carlo).

    Parameters:
    - curve: An instance of CubicCurve
    - decoders: List of decoder modules
    - T: Number of time steps
    - fixed_indices: Pre-fixed decoder indices [(idx1_t0, idx2_t0), (idx1_t1, idx2_t1), ...]
    - device: Computing device

    Returns:
    - Scalar energy value
    """
    total_energy = 0.0  # Accumulate energy over all time steps

    for i in range(T):
        t0 = torch.tensor([i / T], device=device, dtype=torch.float32)
        t1 = torch.tensor([(i + 1) / T], device=device, dtype=torch.float32)

        x0 = curve(t0)  # γ(t0), shape [1, D]
        x1 = curve(t1)  # γ(t1), shape [1, D]

        # Get fixed decoder indices for this segment
        idx1, idx2 = fixed_indices[i]
        #why i don't write this
        # idx1,idx2=random(./)
        # Compute decoder outputs at t0 and t1
        y0 = decoders[idx1](x0).mean
        y1 = decoders[idx2](x1).mean
        #You must implement an algorithm to compute geodesics under the pull-back metric
#associated with the mean of the Gaussian decoder.
        # Energy = L2 norm between decoded outputs
        energy = torch.norm(y1 - y0, p=2) **2

        total_energy += energy

    return total_energy

##6.3 optimize_geodesics

In [75]:
def optimize_geodesic(c0, c1, decoders, T, steps, lr, device,
                      early_stopping_n, early_stopping_delta):
    """
    Optimize a geodesic curve while ensuring the objective function remains unchanged during optimization.

    Parameters:
    - c0: Starting point
    - c1: Endpoint
    - decoders: List of decoder modules
    - T: Number of time steps
    - steps: Number of optimization iterations
    - lr: Learning rate
    - device: Computing device
    - early_stopping_n: Number of steps to check for early stopping
    - early_stopping_delta: Minimum required improvement to continue

    Returns:
    - Optimized curve
    - Logged energy values
    """
    curve = CubicCurve(c0, c1).to(device)
    optimizer = torch.optim.Adam(curve.parameters(), lr=lr)

    # 内部直接定义调度器（自动学习率衰减）
    # min represent we are minimizing the data
    #factor - after the I also added a `torch.optim.lr_scheduler.ReduceLROnPlateau`,
    #which reduces the learning rate by a factor of 0.5 if the change is less than 1e-3 within 100 steps,
    # if the change is less than 1e-3 within 300 steps, early stopping is triggered. Of course,
    # this is just based on experience.I set 2000steps and it’s very slow.
    #I  don’t think my method is good.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.7,
        patience=50,
        threshold=1e-3,
        verbose=True,
        min_lr=1e-4
    )

    energy_log = []

    #generate inside the function but not inside the optimization loop
    fixed_indices = [(torch.randint(0, len(decoders), (1,), device=device).item(),
                      torch.randint(0, len(decoders), (1,), device=device).item())
                     for _ in range(T)]

    best_energy = float('inf')
    no_improve_count = 0

    with tqdm(range(steps)) as pbar:
        for step in pbar:
            optimizer.zero_grad()
            energy = compute_curve_energy(curve, decoders, T=T, fixed_indices=fixed_indices, device=device)
            energy.backward()
            optimizer.step()
            current_lr = optimizer.param_groups[0]['lr']
            # 衰减学习率（每步传当前 energy 给 scheduler）
            energy_value = energy.item()
            scheduler.step(energy_value)
            energy_log.append(energy_value)

            # Early Stopping（从 step >= 300 起判断）
            if step >= 500:
                if energy_value < best_energy - early_stopping_delta:
                    best_energy = energy_value
                    no_improve_count = 0
                else:
                    no_improve_count += 1

                if no_improve_count >= early_stopping_n:
                    print(f"Early stopping at step {step}, energy: {energy_value:.6f}, LR: {current_lr:.2e}")
                    break

            pbar.set_description(f"Energy: {energy_value:.6f}, LR: {current_lr:.2e}")

    return curve, energy_log


# 7 10random latent_varibale pairs


##7.1 data preparation

In [77]:
# ====== Extract all images from the test set ======
all_test_x = []
all_test_y = []
for x, y in mnist_test_loader:
    all_test_x.append(x)
    all_test_y.append(y)

all_test_x = torch.cat(all_test_x, dim=0)  # shape: [N, 1, 28, 28]
all_test_y = torch.cat(all_test_y, dim=0)  # shape: [N]

# ====== random seed3 :to ensure everyone  get the same test points 1 global 2 vae  ======
random.seed(42) #very important!

# ====== Sample 10 image pairs ==here is a way to split the work, one 2-3 pairs====
num_pairs = 25 #要求用10个测试点对
N = all_test_x.shape[0] #total number of test points
indices = random.sample(range(N), 2 * num_pairs) #from this take 20 pair images

# ====== Construct image pair list: [(x_i, x_j), ...] ======
test_image_pairs = [
    (all_test_x[indices[i]], all_test_x[indices[i + 1]])
    for i in range(0, 2 * num_pairs, 2)
]  # Each element shape: ([1, 28, 28], [1, 28, 28]) [0,1][2,3][4,5])
# let (yi,yj ) denote a fixed pair of test points (these should be the same across different models)

##7.2 geodesic_distances

x1 = x1.unsqueeze(0)  # Convert x1 to shape [1, C, H, W] as a batch input
# Models usually expect batched input, even for a single image

posterior = model.encoder(x1)  # Pass the image through the VAE encoder, returns a distribution object

z = posterior.base_dist.loc  # Get the mean vector from the underlying distribution (usually Normal(mean, std))

z = z.squeeze(0)  # Remove the batch dimension: [1, latent_dim] → [latent_dim]


下面的代码两个随机种子助教在群里明确回复了。第一个不能和pair_idx和有。第二个要和pair_idx有关（因为是在用3个decoder算那个vae下,不同测试点的能量的蒙特卡罗近似，用随机算的更准确）

```python
# This logic is quite tricky and must be understood carefully. GPT might directly give a wrong answer.
# Suppose the number of decoders is 2.
# For a given data pair, we need to evaluate it using different VAEs. The way we compute is to treat it as a 2-decoder VAE
# (even though it may have been trained with 3 decoders originally, we treat it as having only 2 decoders).
# We can, for instance, take the same index set (e.g., 1,2) for all VAEs. Since each VAE is different, fixing indices 1,2 or sampling them randomly makes no difference.
# The decoder indices 1,2 from VAE1 and 2,3 from VAE2 are equally random—just random selections.
# Therefore, the hash function and the M-VAE index can be either related or unrelated—it doesn’t matter. To be extra safe, they can be related.
# But can they be related to pair_idx?
# That is, can we, for example, assign the same VAE to all data pairs, and for the first pair use decoder indices 1,2, for the second pair 2,3, etc.?
# I believe this is not allowed. Because if we do this over 10 data pairs, it’s as if the VAE has 3 decoders again!
# What we actually want to evaluate is a VAE with only 2 decoders (fixed for the whole dataset).
# So, it absolutely must not depend on pair_idx.
# Regarding the outer loop: whether the number of decoders is 1 or 2, using the same random seed or not has no effect.
# Because decoder1 selects 1 index, and decoder2 selects 2 indices—different amounts of data.
# Moreover, even if there’s some relationship—e.g., decoder1 selects index 1, decoder2 selects indices 1,2, or decoder1 selects 2, and decoder2 selects 2,3—
# This doesn’t work either. If it happens to be the same VAE, then having two decoders is always more powerful than one.
# Since we are comparing the performance difference between 1 and 2 decoders, these selections must be independent.
```

Next, about random seed sampling. First of all, can it be placed inside the optimization loop? Absolutely not!  
If so, for the same VAE and different data points, the sampling during each optimization would be the same in the for i_T loop—making the objective function exactly the same.  
This weakens the effect of having 10 different data points. Moreover, the calculation is originally based on an expectation—those 10 points effectively perform 10 rounds of Monte Carlo sampling.  
So the hash function must depend on the data point. Each data point’s sampling must be independent.

Next, regarding the VAE: since the VAEs are independent, whether to fix the decoder choice per optimization or not doesn’t matter. To be conservative, we do fix it.

Finally, regarding the decoder count: same logic as before—it’s best to make them independent.  
If decoder count 1 and 2 are correlated, then decoder1 and decoder2 might differ not just in the number of decoders, but because decoder2 includes more powerful decoders.  
We want to isolate the effect of decoder count alone, so they must be sampled independently.


In [79]:
def compute_all_geodesic_distances(models_dict, test_image_pairs, max_decoder_num, device, T, steps, lr, num_vaes, early_stopping_n, early_stopping_delta):
    distances = []  # shape: [decoder_num][pair_idx][vae_idx]
    curves = []     # shape: 同上，保存每条曲线（tensor）

    for number_of_decoders in range(1, max_decoder_num + 1):
        pair_results_energy = []  # 当前 decoder_num 下所有 pair 的能量结果
        pair_results_curve = []   # 当前 decoder_num 下所有 pair 的曲线结果

        for pair_idx, (x1, x2) in enumerate(test_image_pairs):
            vae_energies = []  # 当前 pair，所有 vae 的能量
            vae_curves = []    # 当前 pair，所有 vae 的曲线

            for m in range(num_vaes):
                model_name = f"vae_d{max_decoder_num}_seed{1000 + m}"
                model = models_dict[model_name]
                model.eval()

                with torch.no_grad():
                    z1 = model.encoder(x1.unsqueeze(0).to(device)).base_dist.loc.squeeze(0)
                    z2 = model.encoder(x2.unsqueeze(0).to(device)).base_dist.loc.squeeze(0)

                result_dir = f"results_geodesic/decoders_{number_of_decoders}/pair_{pair_idx}/vae_{m}"
                os.makedirs(result_dir, exist_ok=True)

                curve_path = os.path.join(result_dir, "curve.pt")
                energy_log_path = os.path.join(result_dir, "energy_log.pt")

                if os.path.exists(curve_path) and os.path.exists(energy_log_path):
                    curve = torch.load(curve_path, map_location=device, weights_only=False)
                    energy_log = torch.load(energy_log_path, map_location=device, weights_only=False)
                    print(f"✅ Loaded: decoders={number_of_decoders}, pair={pair_idx}, vae={m}")
                else:
                    seed_dec = hash(("decoder_select", m, number_of_decoders)) % (2**32)
                    random.seed(seed_dec)
                    selected_decoders = random.sample(list(model.decoders), number_of_decoders)

                    seed_fixed = hash(("fixed_indices", m, pair_idx, number_of_decoders)) % (2**32)
                    torch.manual_seed(seed_fixed)

                    curve, energy_log = optimize_geodesic(
                        z1, z2, decoders=selected_decoders, T=T, steps=steps, lr=lr, device=device,
                        early_stopping_n=early_stopping_n, early_stopping_delta=early_stopping_delta
                    )

                    torch.save(curve, curve_path)
                    torch.save(energy_log, energy_log_path)
                    print(f"✅ Computed & saved: decoders={number_of_decoders}, pair={pair_idx}, vae={m}")

                energy = energy_log[-1]
                vae_energies.append(energy)
                vae_curves.append(curve)

            pair_results_energy.append(vae_energies)
            pair_results_curve.append(vae_curves)

        distances.append(pair_results_energy)
        curves.append(pair_results_curve)

    return distances, curves


In [81]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
T = 256 # this T is not the same  as the total segmentation
steps = 2000 #optimization step ,
lr = 5*1e-2 #learning rate
num_vaes = 10 #required 10
max_decoder_num = 3
early_stopping_n=100
early_stopping_delta=1e-4

geodesic_distances, geodesic_curves = compute_all_geodesic_distances(
    models_dict=models,
    test_image_pairs=test_image_pairs,
    max_decoder_num=max_decoder_num,
    device=device,
    T=T,
    steps=steps,
    lr=lr,
    num_vaes=num_vaes,
    early_stopping_n=early_stopping_n,
    early_stopping_delta=early_stopping_delta
)



✅ Loaded: decoders=1, pair=0, vae=0
✅ Loaded: decoders=1, pair=1, vae=0
✅ Loaded: decoders=1, pair=2, vae=0
✅ Loaded: decoders=1, pair=3, vae=0
✅ Loaded: decoders=1, pair=4, vae=0
✅ Loaded: decoders=1, pair=5, vae=0
✅ Loaded: decoders=1, pair=6, vae=0
✅ Loaded: decoders=1, pair=7, vae=0
✅ Loaded: decoders=1, pair=8, vae=0
✅ Loaded: decoders=1, pair=9, vae=0


Energy: 2.347769, LR: 1.66e-04:  77%|███████▋  | 1535/2000 [18:33<05:37,  1.38it/s]


Early stopping at step 1535, energy: 2.347769, LR: 1.66e-04
✅ Computed & saved: decoders=1, pair=10, vae=0


Energy: 5.373148, LR: 1.66e-04:  78%|███████▊  | 1565/2000 [18:33<05:09,  1.41it/s]


Early stopping at step 1565, energy: 5.373149, LR: 1.66e-04
✅ Computed & saved: decoders=1, pair=11, vae=0


Energy: 3.628220, LR: 1.16e-04:  84%|████████▍ | 1689/2000 [20:33<03:47,  1.37it/s]


Early stopping at step 1689, energy: 3.628218, LR: 1.16e-04
✅ Computed & saved: decoders=1, pair=12, vae=0


Energy: 0.973430, LR: 1.41e-03: 100%|██████████| 2000/2000 [23:03<00:00,  1.45it/s]


✅ Computed & saved: decoders=1, pair=13, vae=0


Energy: 0.045527, LR: 1.66e-04:  42%|████▏     | 839/2000 [06:49<09:27,  2.05it/s]


Early stopping at step 839, energy: 0.045526, LR: 1.66e-04
✅ Computed & saved: decoders=1, pair=14, vae=0


Energy: 1.981453, LR: 1.41e-03: 100%|██████████| 2000/2000 [17:21<00:00,  1.92it/s]


✅ Computed & saved: decoders=1, pair=15, vae=0


Energy: 1.111566, LR: 4.84e-04:  64%|██████▍   | 1276/2000 [10:49<06:08,  1.97it/s]


Early stopping at step 1276, energy: 1.111565, LR: 3.39e-04
✅ Computed & saved: decoders=1, pair=16, vae=0


Energy: 4.130901, LR: 1.16e-04:  82%|████████▏ | 1630/2000 [13:52<03:09,  1.96it/s]


Early stopping at step 1630, energy: 4.130900, LR: 1.16e-04
✅ Computed & saved: decoders=1, pair=17, vae=0


Energy: 0.490134, LR: 1.16e-04:  44%|████▍     | 880/2000 [07:15<09:14,  2.02it/s]


Early stopping at step 880, energy: 0.490133, LR: 1.16e-04
✅ Computed & saved: decoders=1, pair=18, vae=0


Energy: 1.938667, LR: 1.41e-03:  38%|███▊      | 754/2000 [06:15<10:20,  2.01it/s]


Early stopping at step 754, energy: 1.938665, LR: 1.41e-03
✅ Computed & saved: decoders=1, pair=19, vae=0


Energy: 10.438465, LR: 1.00e-04: 100%|██████████| 2000/2000 [16:47<00:00,  1.98it/s]


✅ Computed & saved: decoders=1, pair=20, vae=0


Energy: 0.227190, LR: 1.00e-04:  50%|████▉     | 994/2000 [08:29<08:35,  1.95it/s]


Early stopping at step 994, energy: 0.227189, LR: 1.00e-04
✅ Computed & saved: decoders=1, pair=21, vae=0


Energy: 6.337172, LR: 1.16e-04:  71%|███████▏  | 1427/2000 [12:27<05:00,  1.91it/s]


Early stopping at step 1427, energy: 6.337169, LR: 1.16e-04
✅ Computed & saved: decoders=1, pair=22, vae=0


Energy: 0.391794, LR: 3.39e-04:  56%|█████▌    | 1118/2000 [09:35<07:34,  1.94it/s]


Early stopping at step 1118, energy: 0.391794, LR: 3.39e-04
✅ Computed & saved: decoders=1, pair=23, vae=0


Energy: 0.102680, LR: 1.66e-04:  43%|████▎     | 863/2000 [07:20<09:40,  1.96it/s]

Early stopping at step 863, energy: 0.102679, LR: 1.66e-04
✅ Computed & saved: decoders=1, pair=24, vae=0
✅ Loaded: decoders=2, pair=0, vae=0
✅ Loaded: decoders=2, pair=1, vae=0
✅ Loaded: decoders=2, pair=2, vae=0
✅ Loaded: decoders=2, pair=3, vae=0
✅ Loaded: decoders=2, pair=4, vae=0
✅ Loaded: decoders=2, pair=5, vae=0
✅ Loaded: decoders=2, pair=6, vae=0
✅ Loaded: decoders=2, pair=7, vae=0
✅ Loaded: decoders=2, pair=8, vae=0
✅ Loaded: decoders=2, pair=9, vae=0
✅ Loaded: decoders=2, pair=10, vae=0
✅ Loaded: decoders=2, pair=11, vae=0
✅ Loaded: decoders=2, pair=12, vae=0
✅ Loaded: decoders=2, pair=13, vae=0
✅ Loaded: decoders=2, pair=14, vae=0
✅ Loaded: decoders=2, pair=15, vae=0
✅ Loaded: decoders=2, pair=16, vae=0
✅ Loaded: decoders=2, pair=17, vae=0
✅ Loaded: decoders=2, pair=18, vae=0
✅ Loaded: decoders=2, pair=19, vae=0
✅ Loaded: decoders=2, pair=20, vae=0
✅ Loaded: decoders=2, pair=21, vae=0
✅ Loaded: decoders=2, pair=22, vae=0
✅ Loaded: decoders=2, pair=23, vae=0
✅ Loaded: decode




In [None]:
import numpy as np
#加了这一大串，名字还是之前的名字，有报错的话随时和我说
def sqrt_nested_distances(distances):
    return [
        [  # decoder_idx 层
            [np.sqrt(d) for d in vae_dists]  # 对每个 vae 的能量开根号
            for vae_dists in pair_list       # 遍历所有 pair
        ]
        for pair_list in distances           # 遍历 decoder_num
    ]
geodesic_distances = sqrt_nested_distances(geodesic_distances)

In [90]:
#接下来是画图任务！用这两个vae画partA需要的十个点的图
#首先是把，把测试集的全部数据都编码到潜在空间
def encode_all_models_latents(models_dict, test_loader, device):
    """
    对所有模型将测试数据编码为潜变量空间坐标。

    参数:
    - models_dict: 字典，key 为模型名，value 为模型实例
    - test_loader: DataLoader，测试数据集
    - device: 设备 (cuda 或 cpu)

    返回:
    - latents_dict: {模型名: (zs, ys)}，zs 是所有测试样本的潜变量表示，ys 是对应标签
    """
    latents_dict = {}

    for model_name, model in models_dict.items():
        model.eval()
        zs, ys = [], []

        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                q = model.encoder(x)
                z = q.base_dist.loc
                zs.append(z.cpu())
                ys.append(y)

        zs_all = torch.cat(zs, dim=0)
        ys_all = torch.cat(ys, dim=0)
        latents_dict[model_name] = (zs_all, ys_all)
        print(f"✅ Encoded latent space for: {model_name}")

    return latents_dict

# 调用
latents_all_models = encode_all_models_latents(models, mnist_test_loader, device)


✅ Encoded latent space for: vae_d3_seed1000


In [92]:
def encode_test_image_pairs(models_dict, test_image_pairs, device):
    """
    将 test_image_pairs 编码为潜变量空间中的 (z1, z2) 对。

    参数：
    - models_dict: 所有 VAE 模型的字典
    - test_image_pairs: 图像对列表，形如 [(x1, x2), ...]，每个 x 形状是 [1, 28, 28]
    - device: cuda 或 cpu

    返回：
    - z_pairs_dict: {模型名: Tensor[num_pairs, 2, latent_dim]}
    """
    z_pairs_dict = {}

    for model_name, model in models_dict.items():
        model.eval()
        z_list = []

        with torch.no_grad():
            for x1, x2 in test_image_pairs:
                z1 = model.encoder(x1.unsqueeze(0).to(device)).base_dist.loc.squeeze(0)
                z2 = model.encoder(x2.unsqueeze(0).to(device)).base_dist.loc.squeeze(0)
                z_pair = torch.stack([z1, z2], dim=0)  # [2, latent_dim]
                z_list.append(z_pair.cpu())

        z_pairs = torch.stack(z_list, dim=0)  # [num_pairs, 2, latent_dim]
        z_pairs_dict[model_name] = z_pairs
        print(f"✅ Encoded z_pairs for {model_name}")

    return z_pairs_dict

# 调用：
z_pairs_all_models = encode_test_image_pairs(models, test_image_pairs, device)

✅ Encoded z_pairs for vae_d3_seed1000


In [94]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap
import matplotlib
import os

def plot_geodesics_per_vae(model_name, z_all, y_all, z_pairs, curves_list, title, out_path):
    num_pairs = z_pairs.shape[0]
    fig, ax = plt.subplots(figsize=(8, 6))

    base_cmap = matplotlib.colormaps["tab10"]
    tab10_colors = base_cmap(np.arange(10))
    background_colors = tab10_colors[7:10]
    light_colors = [((r + 1)/2, (g + 1)/2, (b + 1)/2, 0.3) for r, g, b, _ in background_colors]
    light_cmap = ListedColormap(light_colors)
    mapped_labels = (y_all % 3).numpy()

    ax.scatter(z_all[:, 0], z_all[:, 1], c=mapped_labels, cmap=light_cmap, s=10, alpha=0.6, zorder=1)

    t_vals = torch.linspace(0, 1, 256).to(z_all.device)
    x_all, y_all_ = [], []

    for i in range(num_pairs):
        z1, z2 = z_pairs[i, 0], z_pairs[i, 1]
        curve_obj = curves_list[i]

        if isinstance(curve_obj, torch.Tensor):
            gamma = curve_obj.cpu()
        else:
            gamma = curve_obj(t_vals).detach().cpu()

        ax.plot(gamma[:, 0], gamma[:, 1], linewidth=1.5, zorder=2)

        # 找到最近的潜变量点，获取标签类别
        idx1 = torch.argmin(torch.norm(z_all - z1, dim=1))
        idx2 = torch.argmin(torch.norm(z_all - z2, dim=1))
        label1_mod3 = int(y_all[idx1]) % 3
        label2_mod3 = int(y_all[idx2]) % 3
        color1 = light_cmap(label1_mod3)
        color2 = light_cmap(label2_mod3)

        ax.scatter(z1[0], z1[1], color=color1, s=30, zorder=3)
        ax.scatter(z2[0], z2[1], color=color2, s=30, zorder=3)

        ax.text(z1[0], z1[1], str(i + 1), fontsize=9, color='black', ha='center', va='center', zorder=4)
        ax.text(z2[0], z2[1], str(i + 1), fontsize=9, color='black', ha='center', va='center', zorder=4)

        x_all.extend(gamma[:, 0].tolist())
        y_all_.extend(gamma[:, 1].tolist())

    if x_all and y_all_:
        ax.set_xlim(min(x_all) - 0.5, max(x_all) + 0.5)
        ax.set_ylim(min(y_all_) - 0.5, max(y_all_) + 0.5)

    ax.set_title(title)
    ax.set_xlabel("z1")
    ax.set_ylabel("z2")
    plt.tight_layout()
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path, dpi=300)
    plt.close()


In [96]:
# decoder 数目索引
max_decoder_num=3
decoder_indices = [0, max_decoder_num - 1]  # decoder=1 和 decoder=max
decoder_nums = [1, max_decoder_num]

for decoder_idx, decoder_num in zip(decoder_indices, decoder_nums):
    for vae_idx in range(num_vaes):
        model_name = f"vae_d{max_decoder_num}_seed{1000 + vae_idx}"
        
        # 获取潜变量和标签
        z_all, y_all = latents_all_models[model_name]
        
        # 获取该模型的测试点对编码 (z1, z2)
        z_pairs = z_pairs_all_models[model_name]  # [num_pairs, 2, 2]
        
        # 获取测地线曲线
        curves_list = geodesic_curves[decoder_idx][
            :  # 所有测试点对
        ][vae_idx]  # 注意：内部结构是 [pair_idx][vae_idx]

        # 转置结构：geodesic_curves[decoder_idx][pair_idx][vae_idx] → [pair_idx]
        curves_per_pair = [geodesic_curves[decoder_idx][pair_idx][vae_idx] for pair_idx in range(len(z_pairs))]

        # 标题和输出路径
        title = f"Geodesics (VAE {vae_idx}, Decoder={decoder_num})"
        out_path = f"geodesic_plots/compare_d{decoder_num}_vae{vae_idx}.png"

        # 画图
        plot_geodesics_per_vae(
            model_name=model_name,
            z_all=z_all,
            y_all=y_all,
            z_pairs=z_pairs,
            curves_list=curves_per_pair,
            title=title,
            out_path=out_path
        )


In [None]:
import os
import shutil

if os.path.exists("results_geodesic"):
    shutil.rmtree("results_geodesic")


## 7.3 energy log plot

In [98]:
def plot_all_energy_logs(max_decoder_num, num_pairs, num_vaes, root='results_geodesic', out_root='energy_plots'):
    for n in range(1, max_decoder_num + 1):
        for pair_idx in range(num_pairs):
            for m in range(num_vaes):
                log_path = f"{root}/decoders_{n}/pair_{pair_idx}/vae_{m}/energy_log.pt"
                if not os.path.exists(log_path):
                    print(f"❌ Missing: {log_path}")
                    continue

                # 加载 energy_log
                energy_log = torch.load(log_path)

                # 准备输出路径
                save_dir = f"{out_root}/decoders_{n}/pair_{pair_idx}/vae_{m}"
                os.makedirs(save_dir, exist_ok=True)
                save_path = os.path.join(save_dir, "energy.png")

                # 绘图
                plt.figure()
                plt.plot(energy_log)
                plt.xlabel("Optimization Step")
                plt.ylabel("Energy")
                plt.title(f"Decoders={n}, Pair={pair_idx}, VAE={m}")
                plt.grid(True)
                plt.tight_layout()
                plt.savefig(save_path)
                plt.close()

                print(f"✅ Saved: {save_path}")


In [104]:
plot_all_energy_logs(max_decoder_num=3, num_pairs=25, num_vaes=1)

✅ Saved: energy_plots/decoders_1/pair_0/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_1/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_2/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_3/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_4/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_5/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_6/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_7/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_8/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_9/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_10/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_11/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_12/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_13/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_14/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_15/vae_0/energy.png
✅ Saved: energy_plots/decoders_1/pair_16/vae_0/energy.png
✅ Saved: energy_plots/de

In [None]:
import os
import shutil

if os.path.exists("results_geodesic"):
    shutil.rmtree("results_geodesic")



##7.4 geodesics plot not relevant to this task,alough helpful to task1


In [None]:
#7.3 energy log plot

# 8.1 Euclidean Distances

In [None]:
def compute_all_euclidean_distances(models_dict, test_image_pairs, max_decoder_num=3, num_vaes=10):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    distances = []  # [decoder_num - 1][pair_idx][vae_idx]

    for number_of_decoders in range(1, max_decoder_num + 1):
        pair_results = []

        for pair_idx, (x1, x2) in enumerate(test_image_pairs):
            vae_results = []

            for m in range(num_vaes):
                model_name = f"vae_d{max_decoder_num}_seed{1000 + m}"
                model = models_dict[model_name]
                model.eval()

                result_dir = f"results_euclidean/decoders_{number_of_decoders}/pair_{pair_idx}/vae_{m}"
                os.makedirs(result_dir, exist_ok=True)
                dist_path = os.path.join(result_dir, "euclidean.pt")

                if os.path.exists(dist_path):
                    euclidean = torch.load(dist_path,map_location=device,weights_only=False)
                    print(f"✅ Loaded: decoders={number_of_decoders}, pair={pair_idx}, vae={m}")
                else:
                    with torch.no_grad():
                        z1 = model.encoder(x1.unsqueeze(0).to(device)).base_dist.loc.squeeze(0)
                        z2 = model.encoder(x2.unsqueeze(0).to(device)).base_dist.loc.squeeze(0)

                    euclidean = torch.norm(z1 - z2, p=2).item()#这里是取了平方根的
                    torch.save(euclidean, dist_path)
                    print(f"✅ Computed & saved: decoders={number_of_decoders}, pair={pair_idx}, vae={m}")

                vae_results.append(euclidean)

            pair_results.append(vae_results)

        distances.append(pair_results)

    return distances


In [None]:
euclidean_distances = compute_all_euclidean_distances(
    models_dict=models,
    test_image_pairs=test_image_pairs,
    max_decoder_num=3,
    num_vaes=10
)
#test_image_pairs = [
    #(all_test_x[indices[i]], all_test_x[indices[i + 1]])
    #for i in range(0, 2 * num_pairs, 2)
#]  还是那10个测试点对，

✅ Computed & saved: decoders=1, pair=0, vae=0
✅ Computed & saved: decoders=1, pair=0, vae=1
✅ Computed & saved: decoders=1, pair=0, vae=2
✅ Computed & saved: decoders=1, pair=1, vae=0
✅ Computed & saved: decoders=1, pair=1, vae=1
✅ Computed & saved: decoders=1, pair=1, vae=2
✅ Computed & saved: decoders=1, pair=2, vae=0
✅ Computed & saved: decoders=1, pair=2, vae=1
✅ Computed & saved: decoders=1, pair=2, vae=2
✅ Computed & saved: decoders=2, pair=0, vae=0
✅ Computed & saved: decoders=2, pair=0, vae=1
✅ Computed & saved: decoders=2, pair=0, vae=2
✅ Computed & saved: decoders=2, pair=1, vae=0
✅ Computed & saved: decoders=2, pair=1, vae=1
✅ Computed & saved: decoders=2, pair=1, vae=2
✅ Computed & saved: decoders=2, pair=2, vae=0
✅ Computed & saved: decoders=2, pair=2, vae=1
✅ Computed & saved: decoders=2, pair=2, vae=2
✅ Computed & saved: decoders=3, pair=0, vae=0
✅ Computed & saved: decoders=3, pair=0, vae=1
✅ Computed & saved: decoders=3, pair=0, vae=2
✅ Computed & saved: decoders=3, pa

In [None]:
euclidean_distances

[[[4.804272651672363, 4.759725093841553, 4.026854038238525],
  [0.97257000207901, 0.9008387923240662, 0.7701483368873596],
  [5.820830345153809, 4.815484523773193, 5.016952991485596]],
 [[4.804272651672363, 4.759725093841553, 4.026854038238525],
  [0.97257000207901, 0.9008387923240662, 0.7701483368873596],
  [5.820830345153809, 4.815484523773193, 5.016952991485596]],
 [[4.804272651672363, 4.759725093841553, 4.026854038238525],
  [0.97257000207901, 0.9008387923240662, 0.7701483368873596],
  [5.820830345153809, 4.815484523773193, 5.016952991485596]]]

In [None]:
import os
import shutil

if os.path.exists("results_euclidean"):
    shutil.rmtree("results_euclidean")



# 8.2  compute Average Cov~num of decoders (both the geodesics and Euclidean) # i forget to take sqrt

In [None]:
def compute_avg_covs_across_pairs(distances):
    avg_covs = []
    decoder_indices = []
    # distances == [decoder_num - 1][pair_idx][vae_idx]

    for decoder_idx, all_pairs in enumerate(distances):  # Outer loop over decoder count
        covs = []  # covs[1]: cov_ij for the first test point, covs[2]: for the second, etc.

        for vae_dists in all_pairs:  # [pair_idx][vae_idx], all_pairs[1] = first test point
            d = np.array(vae_dists)
            mean = np.mean(d)  # Mean across 10 VAEs
            std = np.std(d)    # Std across 10 VAEs

            if mean > 0:
                cov = std / mean  # Compute cov_ij for the current test point
                covs.append(cov)  # Append current test point's cov_ij to covs list

        avg_cov = np.mean(covs)  # Average cov_ij over 10 test points
        avg_covs.append(avg_cov)  # Append to avg_covs list for the current decoder count
        decoder_indices.append(decoder_idx + 1)  # Decoder count starts from 1

    return avg_covs, decoder_indices


In [None]:
geodesic_avg_covs, decoder_counts = compute_avg_covs_across_pairs(geodesic_distances)
euclidean_avg_covs, _ = compute_avg_covs_across_pairs(euclidean_distances)

In [None]:
def plot_and_save_avg_cov(geodesic_avg_covs, euclidean_avg_covs, decoder_counts, save_path="final_plots/avg_cov_vs_decoders.png"):
    plt.figure(figsize=(6, 4))
    plt.plot(decoder_counts, geodesic_avg_covs, marker='o', label='Geodesic distance')
    plt.plot(decoder_counts, euclidean_avg_covs, marker='o', label='Euclidean distance')
    plt.xlabel('Number of ensemble decoders')
    plt.ylabel('Average Coefficient of Variation (CoV)')
    plt.title('Average CoV vs. Number of Ensemble Decoders')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    plt.savefig(save_path)
    plt.show()


In [None]:
plot_and_save_avg_cov(geodesic_avg_covs, euclidean_avg_covs, decoder_counts)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>