In [215]:
import numpy as np
import torch
from env import Env, RandomEnv
from reward import random_reward
from _types import Reward
from utils import timed

In [228]:
@timed
def minimal_canonx(reward: Reward, env: Env, norm_ord: int) -> Reward:
    r = torch.tensor(reward)
    potential = torch.zeros(env.n_s, requires_grad=True)
    optimizer = optimizer = torch.optim.AdamW([potential], lr=1e-3)

    def smooth_inf_norm(x, alpha):
        return (1 / alpha) * torch.logsumexp(alpha * x, dim=tuple(range(x.ndim)))

    def closure():
        optimizer.zero_grad()
        r_prime = r + env.discount * potential[None, None, :] - potential[:, None, None]
        alpha = 1.0
        loss = smooth_inf_norm(r_prime, alpha)
        loss.backward()
        return loss, r_prime

    for i in range(200000):
        optimizer.step()

        loss, r_prime = closure()
        loss = loss.item()
        grad_norm = torch.norm(potential.grad, 2)

        if grad_norm < 1e-4: break
        if i %1000== 0: print(f"Initial {norm_ord=} loss={loss} grad={grad_norm}")
        if i == 200000-1: print(f"Didn't converge {norm_ord=} loss={loss} grad={grad_norm}")

    return r_prime.detach().numpy()


In [220]:
@timed
def minimal_canon(reward: Reward, env: Env, norm_ord: int) -> Reward:
  r = torch.tensor(reward)
  # potential = torch.tensor(reward.mean(axis=(1, 2)), requires_grad=True)
  potential = torch.zeros(env.n_s, requires_grad=True)
  optimizer = torch.optim.Adam([potential], lr = 1e-2)
  for i in range(200000):
    optimizer.zero_grad()
    r_prime = r + env.discount * potential[None, None, :] - potential[:, None, None]
    loss = torch.norm(r_prime, norm_ord)
    loss.backward()
    optimizer.step()
    if torch.norm(potential.grad, 2) < 1e-4: break
    if i == 0: print(f"Initial {norm_ord=} loss={loss.item()} grad={torch.norm(potential.grad, 2)}")
    if i %1000== 0: print(f"Running {norm_ord=} loss={loss} grad={torch.norm(potential.grad, 2)}")
    if i == 200000-1: print(f"Didn't converge {norm_ord=} loss={loss.item()} grad={torch.norm(potential.grad, 2)}")
  return r_prime.detach().numpy()

In [231]:
for _ in range(1):
  e = RandomEnv(n_s=32, n_a=8)
  r = random_reward(e)
  for n_ord in [np.inf]:
    print(f'{n_ord=}')
    canon1, canon2 = minimal_canonx(r, e, n_ord)

n_ord=inf
Initial norm_ord=inf loss=41.5733948707029 grad=0.9245467782020569
Initial norm_ord=inf loss=39.68673829583661 grad=0.9219754338264465
Initial norm_ord=inf loss=37.82967453155023 grad=0.9138655066490173
Initial norm_ord=inf loss=36.04899127575182 grad=0.8757125735282898
Initial norm_ord=inf loss=34.4831486998651 grad=0.75460284948349
Initial norm_ord=inf loss=33.4046149900243 grad=0.5120911002159119
Initial norm_ord=inf loss=32.90994761315389 grad=0.2743326723575592
Initial norm_ord=inf loss=32.724756265647265 grad=0.14148031175136566
Initial norm_ord=inf loss=32.644359168453946 grad=0.08058717846870422
Initial norm_ord=inf loss=32.59584921618627 grad=0.05616192892193794
Initial norm_ord=inf loss=32.551106444600514 grad=0.04879264906048775
Initial norm_ord=inf loss=32.49586263023192 grad=0.047331809997558594
Initial norm_ord=inf loss=32.42518276395282 grad=0.047433044761419296
Initial norm_ord=inf loss=32.3425120840166 grad=0.047866810113191605
Initial norm_ord=inf loss=32.25

KeyboardInterrupt: 