In [None]:
import torch
import torch.functional as F

In [None]:
def train(vae, optimizer, x, true_mu_tensor, true_theta_tensor, true_pi_tensor,
          num_epochs=1000):
    """
    Train the VAE model with full-batch updates and log per-epoch errors.

    Each epoch runs one forward pass on ``x``, one backward pass on the
    model's own loss, and one optimizer step. After the step, the mean
    squared error between three of the forward outputs and the supplied
    reference tensors is printed together with the loss.

    Parameters:
        vae (VAE): The VAE model to train. Calling ``vae(x)`` must return
            the 6-tuple ``(mean, disp, pi, mu, logvar, z)`` and the model
            must provide ``loss_function(x, mean, disp, pi, mu, logvar)``.
        optimizer (torch.optim.Optimizer): The optimizer for training.
        x (torch.Tensor): The input data.
        true_mu_tensor (torch.Tensor): The true mean tensor (compared to ``mean``).
        true_theta_tensor (torch.Tensor): The true theta tensor (compared to ``disp``).
        true_pi_tensor (torch.Tensor): The true pi tensor (compared to ``pi``).
        num_epochs (int): Number of training iterations. Defaults to 1000.

    Returns:
        tuple: ``(mean, disp, pi, mu, logvar, z)`` from the last forward
        pass, i.e. computed before the final optimizer step was applied.

    Raises:
        ValueError: If ``num_epochs`` is smaller than 1 (there would be no
            forward pass whose outputs could be returned).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

    # mse_loss is defined in torch.nn.functional; the torch.functional module
    # does not expose it, so it is resolved explicitly through torch.nn here.
    mse_loss = torch.nn.functional.mse_loss

    vae.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        mean, disp, pi, mu, logvar, z = vae(x)
        loss = vae.loss_function(x, mean, disp, pi, mu, logvar)
        loss.backward()
        optimizer.step()

        # Diagnostics only: keep the metric computation out of the autograd graph.
        with torch.no_grad():
            mean_mse = mse_loss(mean, true_mu_tensor).item()
            theta_mse = mse_loss(disp, true_theta_tensor).item()
            pi_mse = mse_loss(pi, true_pi_tensor).item()

            print(f"Epoch {epoch}, Loss: {loss.item():.4f}, Mean MSE: {mean_mse:.4f}, Theta MSE: {theta_mse:.4f}, Pi MSE: {pi_mse:.4f}")

    return mean, disp, pi, mu, logvar, z

In [None]:
def train(vae, optimizer, x, true_mu_tensor, true_theta_tensor, true_pi_tensor,
          num_epochs=1000):
    """
    Train the VAE model with full-batch updates and log per-epoch errors.

    Each epoch runs one forward pass on ``x``, one backward pass on the
    model's own loss, and one optimizer step. After the step, the mean
    squared error between three of the forward outputs and the supplied
    reference tensors is printed together with the loss.

    Parameters:
        vae (VAE): The VAE model to train. Calling ``vae(x)`` must return
            the 6-tuple ``(mean, disp, pi, mu, logvar, z)`` and the model
            must provide ``loss_function(x, mean, disp, pi, mu, logvar)``.
        optimizer (torch.optim.Optimizer): The optimizer for training.
        x (torch.Tensor): The input data.
        true_mu_tensor (torch.Tensor): The true mean tensor (compared to ``mean``).
        true_theta_tensor (torch.Tensor): The true theta tensor (compared to ``disp``).
        true_pi_tensor (torch.Tensor): The true pi tensor (compared to ``pi``).
        num_epochs (int): Number of training iterations. Defaults to 1000.

    Returns:
        tuple: ``(mean, disp, pi, mu, logvar, z)`` from the last forward
        pass, i.e. computed before the final optimizer step was applied.

    Raises:
        ValueError: If ``num_epochs`` is smaller than 1 (there would be no
            forward pass whose outputs could be returned).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

    # mse_loss is defined in torch.nn.functional; the torch.functional module
    # does not expose it, so it is resolved explicitly through torch.nn here.
    mse_loss = torch.nn.functional.mse_loss

    vae.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        mean, disp, pi, mu, logvar, z = vae(x)
        loss = vae.loss_function(x, mean, disp, pi, mu, logvar)
        loss.backward()
        optimizer.step()

        # Diagnostics only: keep the metric computation out of the autograd graph.
        with torch.no_grad():
            mean_mse = mse_loss(mean, true_mu_tensor).item()
            theta_mse = mse_loss(disp, true_theta_tensor).item()
            pi_mse = mse_loss(pi, true_pi_tensor).item()

            print(f"Epoch {epoch}, Loss: {loss.item():.4f}, Mean MSE: {mean_mse:.4f}, Theta MSE: {theta_mse:.4f}, Pi MSE: {pi_mse:.4f}")

    return mean, disp, pi, mu, logvar, z