# Entropy optimization using PyTorch

Our goal is to construct a discrete probability distribution over a fixed number of outcomes whose Shannon entropy matches a target value. We'll find such a distribution by gradient descent, using PyTorch's optimization framework. First, let's import what we need from PyTorch.

- We'll rely on some functions from the root `torch` module.
- We'll use the `nn` module to create an `nn.Parameter` that stores the distribution's unnormalized logits.
- Finally, for compatibility with a Python language server, I find it's nice to import `Tensor` so that I can use it in type hints.
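The distribution lives in an `nn.Parameter`, so here is a quick sanity check of what that wrapper provides. It is a `Tensor` subclass, it has `requires_grad=True` by default, and an optimizer updates it in place on `step()`. The snippet is illustrative only and is not part of the optimization. The logit values and the toy loss are arbitrary.

```python
import torch
from torch import nn

# Wrap a plain tensor of logits in an nn.Parameter
logits = nn.Parameter(torch.tensor([0.5, -1.0, 2.0], dtype=torch.float64))

# A Parameter is a Tensor subclass and tracks gradients by default
print(isinstance(logits, torch.Tensor))  # True
print(logits.requires_grad)              # True

# An optimizer receives the parameter and updates it in place on step()
optimizer = torch.optim.AdamW([logits], lr=1e-3)
loss = (logits ** 2).sum()  # any differentiable scalar works for this check
loss.backward()
before = logits.detach().clone()
optimizer.step()
print(torch.equal(before, logits.detach()))  # False: the values moved
```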

```python
import torch
from torch import nn
from torch import Tensor
```

Let's define a custom loss function that pulls our distribution's entropy toward the target entropy. A simple way to do this, keeping gradients enabled throughout, is to:

- Map the unconstrained parameters (logits) to a probability distribution with a softmax, so the values are non-negative and sum to 1.
- Calculate the Shannon entropy of that distribution.
- Compare the distribution's entropy to the target entropy with a distance metric such as squared error (the Mean Squared Error of a single scalar).

Despite its simplicity, this works shockingly well!
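Here is the loss described above written out. It uses logits $z \in \mathbb{R}^n$ and a target entropy $H^*$. Entropy is computed with the natural log, so it is measured in nats:

$$
p_i = \frac{e^{z_i}}{\sum_{j=1}^{n} e^{z_j}}, \qquad
H(p) = -\sum_{i=1}^{n} p_i \log p_i, \qquad
\mathcal{L}(z) = \big(H(p) - H^*\big)^2 .
$$

By the chain rule, $\nabla_z \mathcal{L} = 2\,\big(H(p) - H^*\big)\,\nabla_z H$ with $\partial H / \partial z_i = -p_i\big(\log p_i + H(p)\big)$. The gradient therefore vanishes as the entropy approaches the target.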

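```python
class MSEAgainstEntropyLoss(nn.Module):
    """Squared error between the entropy of softmax(logits) and a target entropy."""

    def forward(self, logits: Tensor, target_entropy: Tensor) -> Tensor:
        # log_softmax computes log-probabilities directly, so an underflowing
        # probability never produces log(0) = -inf (and 0 * -inf = NaN)
        log_probs = torch.log_softmax(logits, dim=0)
        probs = log_probs.exp()

        # Shannon entropy in nats: H(p) = -sum_i p_i * log(p_i)
        approx_entropy = -(probs * log_probs).sum()

        # Squared error between the current and the target entropy
        mse = (approx_entropy - target_entropy) ** 2

        return mse
```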
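Computing the entropy as `p * p.log()` after a softmax can break down when a probability underflows to exactly zero. In that case `log(0)` is `-inf`, and `0 * -inf` is `NaN`. The sketch below compares that naive formula with one based on `torch.log_softmax`, which keeps the log-probabilities finite. The helper names `naive_entropy` and `stable_entropy` are illustrative, not part of PyTorch.

```python
import math

import torch
from torch import Tensor


def naive_entropy(logits: Tensor) -> Tensor:
    # softmax first, then log: gives -inf when a probability underflows to 0
    p = torch.softmax(logits, dim=0)
    return -(p * p.log()).sum()


def stable_entropy(logits: Tensor) -> Tensor:
    # log_softmax returns finite log-probabilities even for tiny probabilities
    log_p = torch.log_softmax(logits, dim=0)
    return -(log_p.exp() * log_p).sum()


# Uniform logits: both formulas give the maximum entropy, ln(8)
uniform = torch.zeros(8, dtype=torch.float64)
print(naive_entropy(uniform).item(), stable_entropy(uniform).item(), math.log(8))

# Extreme logits: exp(-1000) underflows to 0.0 even in float64
extreme = torch.tensor([0.0, -1000.0], dtype=torch.float64)
print(naive_entropy(extreme).item())        # nan
print(abs(stable_entropy(extreme).item()))  # 0.0
```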

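Now, we'll define a function that optimizes the logits with AdamW until the loss drops below `tol` (or `max_iter` steps have passed) and returns the resulting probability distribution.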
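Not every `desired_entropy` is reachable. In nats, the entropy of a distribution over `vocab_size` outcomes lies between 0 (all mass on one outcome) and `ln(vocab_size)` (the uniform distribution). With finite logits, a softmax can reach `ln(vocab_size)` exactly (all logits equal) but can only approach 0. A target outside `(0, ln(vocab_size)]` can therefore never drive the loss below `tol`, and the optimization just runs out of iterations. A small guard might look like the sketch below. The name `check_target_entropy` is hypothetical.

```python
import math


def check_target_entropy(vocab_size: int, desired_entropy: float) -> None:
    """Hypothetical guard: raise if the target entropy (in nats) is unreachable."""
    if vocab_size < 2:
        raise ValueError("vocab_size must be at least 2")
    max_entropy = math.log(vocab_size)  # entropy of the uniform distribution
    if not 0.0 < desired_entropy <= max_entropy:
        raise ValueError(
            f"desired_entropy must lie in (0, {max_entropy:.4f}] "
            f"for vocab_size={vocab_size}, got {desired_entropy}"
        )


check_target_entropy(10, 1.5)  # fine: ln(10) is about 2.3026
try:
    check_target_entropy(10, 3.0)  # too large for 10 outcomes
except ValueError as err:
    print(err)
```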

```python
import warnings


def get_dist(
    criterion: nn.Module,
    vocab_size: int,
    desired_entropy: float,
    do_logging: bool = True,
    tol: float = 1e-6,
    lr: float = 0.001,
    log_freq: int = 200,
    max_iter: int = 100_000,
) -> Tensor:

    # Unconstrained logits; softmax maps them to a probability distribution.
    # nn.Parameter already sets requires_grad=True.
    dist = nn.Parameter(torch.randn((vocab_size,), dtype=torch.float64))
    # The target is a constant, so it needs no gradient (the default).
    target_entropy = torch.tensor(desired_entropy, dtype=torch.float64)

    optimizer = torch.optim.AdamW([dist], lr=lr)

    if do_logging:
        print('-----------------------------------------------------')
    for i in range(max_iter):

        optimizer.zero_grad()
        loss = criterion(dist, target_entropy)
        loss.backward()
        optimizer.step()

        loss_val = loss.item()
        if do_logging and i % log_freq == 0:
            print(f'loss: {loss_val:.4}')
        # Check convergence on every iteration, not only when logging
        if loss_val < tol:
            break
    else:
        # Runs only if the loop finished without hitting `break`
        warnings.warn('Optimization did not converge!')

    # Final probabilities; no gradient tracking is needed from here on
    with torch.no_grad():
        log_probs = torch.log_softmax(dist, dim=0)
        final_dist = log_probs.exp()

    if do_logging:
        print('-----------------------------------------------------')
        print(f'sum of probabilities (should be 1): {final_dist.sum().item()}')
        achieved_entropy = -(final_dist * log_probs).sum()
        print('-----------------------------------------------------')
        print(f'desired entropy:    {desired_entropy}')
        print(f'achieved entropy:   {achieved_entropy.item()}')
        print('-----------------------------------------------------')
    return final_dist
```
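A note on reporting non-convergence: in Python, writing `Warning(msg)` on its own only constructs an exception object and discards it, so nothing is shown to the user. The standard library's `warnings.warn` is what actually emits a warning. The snippet below records warnings with `warnings.catch_warnings` to show the difference. The two function names are illustrative.

```python
import warnings


def constructs_warning() -> None:
    # Builds a Warning object and throws it away: nothing is emitted
    Warning("Optimization did not converge!")


def emits_warning() -> None:
    # Goes through the warnings machinery (default category: UserWarning)
    warnings.warn("Optimization did not converge!")


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # don't suppress repeated warnings
    constructs_warning()
    n_after_construct = len(caught)
    emits_warning()
    n_after_emit = len(caught)

print(n_after_construct, n_after_emit)  # 0 1
```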