# Entropy optimization using PyTorch

First, let's import the parts of PyTorch we need. We'll treat the distribution as a trainable parameter and use PyTorch's optimization framework to adjust it until it has our desired entropy.

- We'll rely on some functions from the root `torch` module.
- We'll use the `nn` module to create an `nn.Parameter` that stores the (unnormalized) distribution itself.
- Finally, for compatibility with a Python language server, I find it's nice to import `Tensor` so that I can use it in type hints.

```python
import torch
from torch import nn
from torch import Tensor
```

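Let's define a custom loss function that pushes our distribution's entropy toward the target entropy. The parameters we optimize are unconstrained real numbers, so each evaluation of the loss (with gradients enabled) does three things:

- Normalize the parameters into a probability distribution that sums to 1, using a softmax.
- Calculate the Shannon entropy of that distribution, the sum of `-p * log(p)` over all outcomes (measured in nats, since `torch` uses the natural logarithm).
- Compare that entropy to the target entropy with the squared error. Since both are single numbers, the Mean Squared Error reduces to `(entropy - target) ** 2`.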

Despite its simplicity, this works shockingly well!
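Written out, with `z` the vector of `n` unconstrained parameters and `H*` the target entropy, the quantities the loss computes and the gradient the optimizer follows are as below (the entropy gradient comes from differentiating through the softmax, using `dp_j/dz_i = p_j (δ_ij - p_i)`):

```latex
p_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}, \qquad
H(p) = -\sum_{i=1}^{n} p_i \log p_i, \qquad
\mathcal{L}(z) = \bigl(H(p) - H^{*}\bigr)^{2}

\frac{\partial H}{\partial z_i} = -p_i \bigl(\log p_i + H(p)\bigr), \qquad
\frac{\partial \mathcal{L}}{\partial z_i} = -2\bigl(H(p) - H^{*}\bigr)\, p_i \bigl(\log p_i + H(p)\bigr)
```

The entropy gradient vanishes when every `p_i` equals `exp(-H)`, i.e. at the uniform distribution, so starting from exactly equal parameters would leave the optimizer stuck there; a random initialization sidesteps this.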

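```python
class MSEAgainstEntropyLoss(nn.Module):
    """Squared error between the entropy of softmax(dist) and a target entropy (both in nats)."""

    def forward(self, dist: Tensor, target_entropy: Tensor) -> Tensor:

        # normalize the unconstrained parameters into log-probabilities;
        # log_softmax stays finite even when a probability underflows to 0,
        # whereas softmax(...).log() would give log(0) = -inf and then NaN
        log_p = torch.log_softmax(dist, dim=0)
        p = log_p.exp()

        # Shannon entropy in nats: H(p) = -sum_i p_i * log(p_i)
        entropy = -(p * log_p).sum()

        # squared error against the target (a single scalar, so MSE is just the square)
        return (entropy - target_entropy) ** 2
```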
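Taking `.log()` of a softmax output breaks down once a probability underflows to exactly zero, because `0 * log(0)` evaluates to `0 * -inf = nan`, and that NaN then propagates into the loss and its gradients. Computing the log-probabilities directly with `torch.log_softmax` avoids this. The sketch below compares the two approaches using two hypothetical helpers, `entropy_naive` and `entropy_stable` (names chosen for illustration only): they agree on ordinary inputs, but only the second survives a logit of -1000.

```python
import math

import torch
from torch import Tensor


def entropy_naive(logits: Tensor) -> Tensor:
    # softmax first, then log: a probability of exactly 0 gives 0 * -inf = nan
    p = torch.softmax(logits, dim=0)
    return -(p * p.log()).sum()


def entropy_stable(logits: Tensor) -> Tensor:
    # log-probabilities computed directly, so they stay finite
    log_p = torch.log_softmax(logits, dim=0)
    return -(log_p.exp() * log_p).sum()


# equal logits give the uniform distribution, whose entropy is log(n)
uniform = torch.zeros(4, dtype=torch.float64)
print(entropy_naive(uniform).item(), entropy_stable(uniform).item(), math.log(4))

# a very negative logit makes its probability underflow to exactly 0 in float64
peaked = torch.tensor([0.0, -1000.0], dtype=torch.float64)
print(entropy_naive(peaked).item())   # nan
print(entropy_stable(peaked).item())  # zero (may be printed as -0.0)
```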

Now, we'll define a function that returns a distribution with the desired entropy (or as close as we can get). It will take the following arguments:

- `criterion (nn.Module)`: The loss function to use to optimize the entropy (in this case, an instance of the class we defined above).
- `support_size (int)`: The number of outcomes we want our random variable to have.
- `desired_entropy (float)`: The entropy we want our distribution to have.
- `lr (float)`: The learning rate for the optimization algorithm.
- `tol (float)`: The tolerance for the optimization algorithm. We will stop when the loss is less than this value.
- `max_iter (int)`: The maximum number of iterations to run the optimization algorithm.
- `do_logging (bool)`: Whether to log the loss during optimization.
- `log_freq (int)`: How many iterations to wait between log messages.
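Not every `desired_entropy` is reachable. For a support of size `n`, entropy (in nats) is largest for the uniform distribution, where it equals `ln(n)`, and smallest (0) for a distribution that puts all of its mass on one outcome; a softmax over finite parameters can only approach 0, never reach it exactly. A target outside `[0, ln(n)]` therefore cannot converge however long the loop runs. A minimal sketch of an up-front check, using a hypothetical helper `check_target_entropy` that is not part of the original code:

```python
import math


def check_target_entropy(support_size: int, desired_entropy: float) -> None:
    """Hypothetical helper: raise if desired_entropy (in nats) is unreachable for this support size."""
    if support_size < 1:
        raise ValueError(f"support_size must be at least 1, got {support_size}")
    max_entropy = math.log(support_size)  # entropy of the uniform distribution
    if not 0.0 <= desired_entropy <= max_entropy:
        raise ValueError(
            f"desired_entropy must lie in [0, ln({support_size})] = [0, {max_entropy:.4f}], "
            f"got {desired_entropy}"
        )


check_target_entropy(support_size=5, desired_entropy=1.0)  # fine: ln(5) is about 1.609
```

Calling it with `support_size=5, desired_entropy=2.0` raises a `ValueError` instead of letting the optimizer run to `max_iter` without converging.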

```python
import warnings


def get_dist(
    criterion: nn.Module,
    support_size: int,
    desired_entropy: float,
    lr: float = 0.001,
    tol: float = 1e-6,
    max_iter: int = 100_000,
    do_logging: bool = True,
    log_freq: int = 200,
) -> Tensor:

    # unconstrained parameters with the right support size;
    # nn.Parameter already sets requires_grad=True
    dist = nn.Parameter(torch.randn(support_size, dtype=torch.float64))

    # the desired entropy as a constant tensor (no gradient needed)
    DE = torch.tensor(desired_entropy, dtype=torch.float64)

    # define an optimizer over the parameter
    optimizer = torch.optim.AdamW([dist], lr=lr)

    if do_logging:
        print('-----------------------------------------------------')
    for i in range(max_iter):

        # one optimization step
        optimizer.zero_grad()
        loss = criterion(dist, DE)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()

        # log the loss every log_freq iterations
        if do_logging and i % log_freq == 0:
            print(f'iter {i}: loss {loss_val:.4}')

        # we are done as soon as the loss is small enough
        if loss_val < tol:
            break
    else:
        # the loop ran to max_iter without reaching the tolerance
        warnings.warn('Optimization did not converge!')

    # convert the parameters into a probability distribution
    with torch.no_grad():
        final_dist = torch.softmax(dist, dim=0)

    # summary of results
    if do_logging:
        # torch.special.entr computes -p * log(p) elementwise (and 0 where p == 0)
        achieved_entropy = torch.special.entr(final_dist).sum()
        print('-----------------------------------------------------')
        print(f'sum of probabilities (should be 1): {final_dist.sum().item()}')
        print('-----------------------------------------------------')
        print(f'desired entropy:    {desired_entropy}')
        print(f'achieved entropy:   {achieved_entropy.item()}')
        print('-----------------------------------------------------')

    return final_dist
```
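Because many distributions share the same entropy, the answer is not unique: runs that start from different random parameters generally end at different distributions with approximately the same entropy. The sketch below checks this with a compact, self-contained variant of the same loop. The function name `fit_entropy`, the learning rate of 0.01, the support size of 5, the target of 1.0 nats and the seeds 0 and 1 are all assumptions chosen for illustration, not values from the original.

```python
import torch


def fit_entropy(support_size: int, target: float, seed: int,
                lr: float = 0.01, tol: float = 1e-6, max_iter: int = 20_000) -> torch.Tensor:
    """Hypothetical compact variant of get_dist: returns probabilities whose entropy is near target."""
    torch.manual_seed(seed)  # seeds the global RNG so each run is reproducible
    logits = torch.nn.Parameter(torch.randn(support_size, dtype=torch.float64))
    target_t = torch.tensor(target, dtype=torch.float64)
    optimizer = torch.optim.AdamW([logits], lr=lr)
    for _ in range(max_iter):
        optimizer.zero_grad()
        # same squared-error-on-entropy loss, computed via log_softmax
        log_p = torch.log_softmax(logits, dim=0)
        entropy = -(log_p.exp() * log_p).sum()
        loss = (entropy - target_t) ** 2
        if loss.item() < tol:  # stop before stepping once the current logits are good enough
            break
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        return torch.softmax(logits, dim=0)


p_a = fit_entropy(support_size=5, target=1.0, seed=0)
p_b = fit_entropy(support_size=5, target=1.0, seed=1)
for p in (p_a, p_b):
    # torch.special.entr(p) is -p * log(p) elementwise
    print([round(x, 4) for x in p.tolist()], torch.special.entr(p).sum().item())
```

Both printed entropies should sit close to 1.0 while the probability vectors themselves differ, so if a particular distribution is needed (for example, one that is as close as possible to some reference), that preference has to be added to the loss explicitly.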