In [1]:
import plotly.graph_objects as go
import torch
from tqdm.notebook import trange

In [2]:
# Seed the global torch RNG so Restart Kernel -> Run All reproduces every sample.
SEED = 0
torch.manual_seed(SEED)

# Trainable policy parameters: the Gaussian mean and its (pre-clipping) scale.
loc = torch.tensor(1.0, dtype=torch.float, requires_grad=True)
scale_param = torch.tensor(0.5, dtype=torch.float, requires_grad=True)

# Floor applied to scale_param so the Normal's scale stays strictly positive.
# Note: torch.clip passes no gradient to scale_param while it sits below the floor.
MIN_SCALE = 0.001


def make_distribution(min_scale=MIN_SCALE):
    """Build the current Gaussian policy from the global ``loc`` / ``scale_param``.

    Args:
        min_scale: lower bound applied to ``scale_param`` via ``torch.clip``.

    Returns:
        A ``torch.distributions.Normal`` whose parameters are differentiable
        functions of ``loc`` and ``scale_param``.
    """
    return torch.distributions.Normal(loc=loc, scale=torch.clip(scale_param, min=min_scale))


def make_sample():
    """Draw one 0-d sample from the current policy (``sample`` attaches no gradient)."""
    return make_distribution().sample(sample_shape=())


make_sample()

tensor(2.0084)

In [3]:
def make_reward(sample, target=0.5):
    """Reward for a sampled action: the negative squared distance to ``target``.

    The reward is maximal (zero) exactly at ``sample == target`` and is applied
    elementwise (with broadcasting) when ``sample`` is a tensor.

    Args:
        sample: value(s) drawn from the policy.
        target: the point the policy should learn to hit; defaults to 0.5.

    Returns:
        ``-(sample - target) ** 2``.
    """
    # `**` binds tighter than unary minus, so this is -((sample - target) ** 2).
    return -(sample - target) ** 2


make_reward(torch.tensor(1.0, dtype=torch.float))

tensor(-0.2500)

In [4]:
def make_loss(sample):
    """REINFORCE (score-function) surrogate loss for a single policy sample.

    The gradient of ``-R(x) * log p(x)`` with respect to the policy parameters
    is the negated single-sample score-function estimate of grad E[R(x)], so
    minimising this loss ascends the expected reward. The estimator requires
    ``x`` to be treated as a constant. The sample is therefore detached first,
    which keeps the estimate correct even if a caller passes a
    gradient-carrying (e.g. reparameterised ``rsample()``) value.

    Args:
        sample: a value drawn from ``make_distribution()``.

    Returns:
        A tensor whose gradient flows only through ``log_prob`` into the
        distribution's parameters.
    """
    sample = torch.as_tensor(sample).detach()
    return -make_reward(sample) * make_distribution().log_prob(sample)


make_loss(torch.tensor(1.0, dtype=torch.float))

tensor(-0.0564, grad_fn=<MulBackward0>)

In [5]:
# Optimisation hyper-parameters.
LEARNING_RATE = 1e-2
N_STEPS = 1024

optimizer = torch.optim.Adam([loc, scale_param], lr=LEARNING_RATE)
for _ in trange(N_STEPS, desc="REINFORCE"):
    # One single-sample REINFORCE update: draw x from the policy, then take a
    # gradient step on the surrogate loss -R(x) * log p(x).
    sample = make_sample()
    optimizer.zero_grad()
    make_loss(sample).backward()
    optimizer.step()

  0%|          | 0/1024 [00:00<?, ?it/s]

In [6]:
# Learned parameters. The policy uses torch.clip(scale_param, min=0.001) as its
# scale, so a scale_param shown below 0.001 means the effective scale is pinned
# at that floor (where clip passes no gradient back to scale_param).
(loc, scale_param)

(tensor(0.4997, requires_grad=True), tensor(0.0009, requires_grad=True))

In [7]:
# Histogram of draws from the trained policy. After training it is expected to
# concentrate near 0.5, the maximiser of make_reward.
N_PLOT_SAMPLES = 4096
policy_samples = make_distribution().sample([N_PLOT_SAMPLES]).numpy()
go.Figure(
    data=[go.Histogram(x=policy_samples, name="policy samples")]
).update_layout(
    title=f"{N_PLOT_SAMPLES} samples from the trained policy",
    xaxis_title="sample value",
    yaxis_title="count",
)