# In[1]:
import torch

def get_mu_std_grad(mu, std, data):
    """Compute measure-valued gradient estimates w.r.t. ``mu`` and ``std``.

    For each latent coordinate ``i``, ``f_integrand`` is evaluated at two copies
    of a base sample ``z ~ N(mu, std**2)``. The copies differ in column ``i``,
    and the scaled difference of the two evaluations is the single-sample
    gradient estimate for that coordinate.

    * ``mu`` gradient: column ``i`` is set to ``mu_i + std_i * w1`` in one copy
      and ``mu_i - std_i * w2`` in the other. ``w1`` and ``w2`` are drawn from
      ``sample_weibull``. The difference is scaled by ``1 / (sqrt(2*pi) * std_i)``.
    * ``std`` gradient: column ``i`` is set to ``mu_i + std_i * m`` in one copy
      and ``mu_i + std_i * m * u`` in the other. ``m`` is drawn from
      ``sample_doublesided_maxwell`` and ``u ~ Uniform(0, 1)``. The difference
      is scaled by ``1 / std_i``.

    Args:
        mu: Location parameters, indexed as ``mu[:, i]``. Presumably shaped
            ``(batch, z_dim)``; confirm against callers.
        std: Scale parameters, with the same layout and device as ``mu``.
        data: Passed through unchanged to ``f_integrand``.

    Returns:
        ``(mu_grad, std_grad)``, zero-initialised like ``mu`` and ``std``, with
        the first ``args.z_dim`` columns filled in. They are computed under
        ``torch.no_grad()``, so they carry no autograd history.

    Note:
        Relies on the module-level names ``args`` (for ``args.z_dim``),
        ``f_integrand``, ``sample_weibull`` and ``sample_doublesided_maxwell``.
    """
    import math  # the visible cell imports only torch; avoid a NameError

    sqrt_2pi = math.sqrt(2.0 * math.pi)  # loop-invariant normaliser

    with torch.no_grad():
        # *_like factories allocate on the source tensor's device already.
        z1 = mu + torch.randn_like(std) * std
        z2 = mu + torch.randn_like(std) * std
        mu_grad = torch.zeros_like(mu)
        std_grad = torch.zeros_like(std)
        # The samplers only receive a shape, so move their output next to std.
        dev = std.device

        # NOTE(review): the two evaluations below use independent base samples
        # (z1 vs z2) for the other coordinates and independent w1/w2. Sharing
        # them (common random numbers) could lower variance. Confirm whether
        # independence is intended.
        for i in range(args.z_dim):
            mu_i = mu[:, i]
            std_i = std[:, i]

            # mu gradient: symmetric +/- Weibull perturbation of coordinate i.
            w1 = sample_weibull(std_i.shape).to(dev)
            w2 = sample_weibull(std_i.shape).to(dev)

            z1_copy, z2_copy = z1.clone(), z2.clone()
            z1_copy[:, i] = mu_i + std_i * w1
            z2_copy[:, i] = mu_i - std_i * w2

            f1 = f_integrand(z1_copy, data)
            f2 = f_integrand(z2_copy, data)
            mu_grad[:, i] = (f1 - f2) / sqrt_2pi / std_i

            # std gradient: double-sided-Maxwell perturbation vs. a Gaussian one.
            dsm_sample = sample_doublesided_maxwell(std_i.shape).to(dev)
            unif_sample = torch.rand_like(std_i)
            # A standard double-sided Maxwell draw times Uniform(0, 1) is
            # N(0, 1). Reusing dsm_sample couples the two evaluations. This
            # assumes sample_doublesided_maxwell yields standard (unit-scale)
            # draws; confirm.
            normal_sample = dsm_sample * unif_sample

            z1_copy, z2_copy = z1.clone(), z2.clone()
            z1_copy[:, i] = mu_i + std_i * dsm_sample
            z2_copy[:, i] = mu_i + std_i * normal_sample

            f1 = f_integrand(z1_copy, data)
            f2 = f_integrand(z2_copy, data)
            std_grad[:, i] = (f1 - f2) / std_i

    return mu_grad, std_grad
