In [21]:
import numpy as np
import torch
from env import Env, RandomEnv
from reward import random_reward
from _types import Reward
from utils import timed


In [53]:
@timed
def minimal_canon(reward: Reward, env: Env, norm_ord: int|float) -> Reward:
  r = torch.tensor(reward)
  # potential = torch.tensor(reward.mean(axis=(1, 2)), requires_grad=True)
  potential = torch.zeros(env.n_s, requires_grad=True)
  optimizer = torch.optim.Adam([potential])
  for i in range(20000):
    optimizer.zero_grad()
    r_prime = r + env.discount * potential[None, None, :] - potential[:, None, None]
    loss = torch.norm(r_prime, norm_ord)
    loss.backward()
    optimizer.step()
    if torch.norm(potential.grad, 2) < 1e-4: break
    if i == 0: print(f"Initial {norm_ord=} loss={loss.item()} grad={torch.norm(potential.grad, 2)}")
    if i == 10000-1: print(f"Didn't converge {norm_ord=} loss={loss.item()} grad={torch.norm(potential.grad, 2)}")
  return r_prime.detach().numpy()


In [54]:
for _ in range(10):
  e = RandomEnv(n_s=32, n_a=8)
  r = random_reward(e)
  for n_ord in [1, 2, float('inf')]:
    print(f'{n_ord=}')
    canon1 = minimal_canon(r, e, n_ord)
    canon2 = minimal_canon(r, e, n_ord)
    assert np.isclose(canon1, canon2).all()

n_ord=1
Initial norm_ord=1 loss=63939.03575974314 grad=103.81232452392578
Didn't converge norm_ord=1 loss=63769.264895336026 grad=8.109253883361816
minimal_canon took 2.8044s
Initial norm_ord=1 loss=63939.03575974314 grad=103.81232452392578
Didn't converge norm_ord=1 loss=63769.264895336026 grad=8.109253883361816
minimal_canon took 2.6185s
n_ord=2
Initial norm_ord=2 loss=885.7309486838012 grad=1.44558584690094
Didn't converge norm_ord=2 loss=883.5622704776193 grad=0.00011598818673519418
minimal_canon took 1.4641s
Initial norm_ord=2 loss=885.7309486838012 grad=1.44558584690094
Didn't converge norm_ord=2 loss=883.5622704776193 grad=0.00011598818673519418
minimal_canon took 1.3243s
n_ord=inf
Initial norm_ord=inf loss=38.7035637769601 grad=1.3453624248504639
Didn't converge norm_ord=inf loss=31.83190739804587 grad=1.3453624248504639
minimal_canon took 3.5038s
Initial norm_ord=inf loss=38.7035637769601 grad=1.3453624248504639
Didn't converge norm_ord=inf loss=31.83190739804587 grad=1.345362

KeyboardInterrupt: 