In [1]:
import plotly.graph_objects as go
import torch
from tqdm.notebook import trange

In [2]:
class KindaLogNormal(torch.nn.Module):
    def __init__(self, uncertainty: float):
        super().__init__()
        self.uncertainty = uncertainty
        self.loc = torch.nn.Parameter(torch.tensor(0, dtype=torch.float))
        self.scale = torch.nn.Parameter(torch.tensor(1, dtype=torch.float))
    
    def torch_distribution(self):
        return torch.distributions.Normal(self.loc, self.uncertainty + self.scale.abs())
    
    def sample(self, shape=()):
        return self.torch_distribution().sample(shape).exp()

distribution = KindaLogNormal(uncertainty=0.001)
distribution.sample()

tensor(2.1205)

In [3]:
def make_reward(sample):
    return -(sample - 0.5) ** 2

make_reward(torch.tensor(1.0, dtype=torch.float))

tensor(-0.2500)

In [4]:
def make_loss(sample: torch.Tensor):
    return -make_reward(sample) * distribution.torch_distribution().log_prob(sample.log())

make_loss(torch.tensor(1.0, dtype=torch.float))

tensor(-0.2300, grad_fn=<MulBackward0>)

In [5]:
optimizer = torch.optim.Adam(distribution.parameters(), lr=5e-2)
for step_id in trange(1024):
    sample = distribution.sample()
    optimizer.zero_grad()
    make_loss(sample).backward()
    optimizer.step()

  0%|          | 0/1024 [00:00<?, ?it/s]

In [6]:
distribution

KindaLogNormal()

In [7]:
go.Figure(
    data=[
        go.Histogram(
            x=distribution.sample([4096])
        )
    ]
)