In [1]:
import numpy as np
import torch
from env import Env, RandomEnv
from reward import random_reward
from _types import Reward
from utils import timed


In [9]:
@timed
def minimal_canon(reward: Reward, env: Env, norm_ord: int|float, steps: int = 100000,
                  lr: float = 1e-5, grad_tol: float = 1e-4,
                  check_every: int = 10000, move_tol: float = 1e-3) -> Reward:
  """Return the potential-shaped version of `reward` with (approximately) minimal norm.

  Runs Adam on a state potential phi to minimise the `norm_ord` vector norm
  (taken over all entries) of the shaped reward

      r'[s, a, s'] = reward[s, a, s'] + env.discount * phi[s'] - phi[s]

  Args:
    reward: reward array supporting numpy-style `.mean(axis=...)` (presumably a
      numpy array). Axis 0 and axis 2 are broadcast against phi, so both must
      have length n_s — presumably shaped (n_s, n_a, n_s); TODO confirm.
    env: environment; only `env.discount` is read.
    norm_ord: order of the vector norm, e.g. 1, 2 or float('inf').
    steps: maximum number of optimiser steps.
    lr: Adam learning rate.
    grad_tol: stop once the L2 norm of phi's gradient falls below this.
    check_every: every `check_every` steps, stop if phi is still within
      `move_tol` (used as both rtol and atol) of its value at the previous check.
    move_tol: tolerance for the "no movement" convergence test.

  Returns:
    The shaped reward as a numpy array, evaluated at the final potential.
  """
  r = torch.tensor(reward)
  # Initialise phi(s) to the mean of reward[s, :, :].
  potential = torch.tensor(reward.mean(axis=(1, 2)), requires_grad=True)

  def shape(phi: torch.Tensor) -> torch.Tensor:
    # +discount*phi broadcast over the next-state axis, -phi over the current-state axis.
    return r + env.discount * phi[None, None, :] - phi[:, None, None]

  # Detached snapshot: cloning the requires_grad leaf directly would record the
  # copy in the autograd graph.
  frozen_potential = potential.detach().clone()

  optimizer = torch.optim.Adam([potential], lr=lr)
  for i in range(steps):
    optimizer.zero_grad()
    loss = torch.linalg.vector_norm(shape(potential), ord=norm_ord)
    loss.backward()
    optimizer.step()
    # Gradient was computed for the pre-step potential; evaluate its norm once.
    grad_norm = torch.linalg.vector_norm(potential.grad, ord=2)

    if i == 0: print(f"Initial {norm_ord=} loss={loss.item()} grad={grad_norm}")

    # convergence = small gradient or potential hasn't changed in a while
    if grad_norm < grad_tol:
      print(f"Converged by low grad at {i=}")
      break
    if i % check_every == 0 and i != 0:
      current = potential.detach()
      if torch.isclose(current, frozen_potential, rtol=move_tol, atol=move_tol).all():
        print(f"Converged by no movement {i=}")
        break
      frozen_potential = current.clone()

    if i == steps - 1: print(f"Didn't converge {norm_ord=} loss={loss.item()} grad={grad_norm}")
    if i % check_every == 0: print(f'{i=} grad norm={grad_norm}')

  # Recompute with the final potential: the in-loop shaped reward predates the
  # last optimizer.step(), and would not exist at all when steps == 0.
  with torch.no_grad():
    r_prime = shape(potential)
  return r_prime.numpy()


In [12]:
e = RandomEnv(n_s=32, n_a=8)
r = random_reward(e)
# One long L1 run; the canonical reward is discarded — only the convergence
# log and the timing printed by `minimal_canon` are of interest here.
_ = minimal_canon(r, e, norm_ord=1, steps=1_000_000)

Initial norm_ord=1 loss=6642.052112415129 grad=120.64990675504062
i=0 grad norm=120.64990675504062
i=10000 grad norm=7.211102550927978
Converged by no movement i=20000
minimal_canon took 3.3108s


In [22]:
# Canonicalise each random reward twice per norm and require matching results.
# Each call runs up to `minimal_canon`'s default 100000 Adam steps, so this cell is slow.
# NOTE(review): both calls receive identical inputs, and `minimal_canon` derives
# its initial potential deterministically from the reward. This therefore only
# checks that the optimisation is deterministic, not that the minimal canon is
# unique. A uniqueness check could canonicalise a randomly potential-shaped copy
# of `r` and compare with a looser tolerance — TODO decide which was intended.
for trial in range(10):
  e = RandomEnv(n_s=32, n_a=8)
  r = random_reward(e)
  for n_ord in [1, 2, float('inf')]:
    print(f'{trial=} {n_ord=}')
    canon1 = minimal_canon(r, e, n_ord)
    canon2 = minimal_canon(r, e, n_ord)
    # Same tolerances and NaN handling as np.isclose's defaults, but on failure
    # reports the mismatching entries instead of a bare AssertionError.
    np.testing.assert_allclose(canon1, canon2, rtol=1e-5, atol=1e-8, equal_nan=False)

n_ord=1
Initial norm_ord=1 loss=74605.61017625945 grad=144.8155059814453


KeyboardInterrupt: 