In [6]:
import torch
import torch.distributions as dists

# Batch of 3 independent 10-dimensional unit Gaussians (zero mean, unit stddev).
batch_shape = (3, 10)
d = dists.Normal(loc=torch.zeros(batch_shape), scale=torch.ones(batch_shape))


def f(x, d):
    """Sum of ``-0.5 * z**2`` over the last axis, with ``z = (x - mean) / stddev``.

    This is the quadratic part of a Gaussian log-density; the
    ``-log(stddev) - 0.5 * log(2 * pi)`` terms are deliberately left out.

    Args:
        x: values to score; must broadcast against ``d.mean`` / ``d.stddev``.
        d: any object exposing ``mean`` and ``stddev`` attributes.

    Returns:
        The per-row sum, i.e. ``x`` with its last dimension reduced away.
    """
    z_score = (x - d.mean) / d.stddev
    neg_half_sq = -0.5 * z_score ** 2
    return neg_half_sq.sum(-1)


# Last expression of the cell: displays the score of a random sample.
f(torch.randn(batch_shape), d)

tensor([ -2.6603, -10.1530,  -3.0829])

In [1]:
import sys
from pathlib import Path
# Parent of the current working directory; appended to sys.path so that the
# `graph_irl` imports below can be resolved when the notebook is run from a
# sub-directory of the project.
p = Path('.').absolute().parent
# Membership test instead of comparing only the last entry: the old check
# (`sys.path[-1] != str(p)`) appended a duplicate whenever the cell was
# re-run after something else had been appended to sys.path.
if str(p) not in sys.path:
    sys.path.append(str(p))

from graph_irl.buffer_v2 import GraphBuffer
from graph_irl.policy import GaussPolicy, TwoStageGaussPolicy, GCN, Qfunc
from graph_irl.graph_rl_utils import GraphEnv
from graph_irl.sac import SACAgentGraph, TEST_OUTPUTS_PATH
from graph_irl.reward import GraphReward
from graph_irl.examples.circle_graph import create_circle_graph

import numpy as np
import torch
# Seed torch and numpy before any module is built so that a
# Restart-and-Run-All starts from the same RNG state.
torch.manual_seed(0)
np.random.seed(0)


# Circle graph with `n_nodes` (= 11) nodes; `torch.ones` is handed to the
# helper together with `node_dim` (presumably to fill the node features
# -- TODO confirm against create_circle_graph).
n_nodes = 11
node_dim = 5
nodes, edge_index = create_circle_graph(n_nodes, node_dim, torch.ones)
encoder_hiddens = [7, 7, 2]

# Every GCN encoder shares the same architecture; only the instances differ.
# Construction order is kept fixed because it determines which random
# initial weights each encoder receives under the seed set above.
gcn_kwargs = dict(with_layer_norm=True, final_tanh=True)
encoder = GCN(node_dim, encoder_hiddens, **gcn_kwargs)         # policy
encoderq1 = GCN(node_dim, encoder_hiddens, **gcn_kwargs)       # Q1
encoderq2 = GCN(node_dim, encoder_hiddens, **gcn_kwargs)       # Q2
encoderq1t = GCN(node_dim, encoder_hiddens, **gcn_kwargs)      # Q1 (Q1t_kwargs)
encoderq2t = GCN(node_dim, encoder_hiddens, **gcn_kwargs)      # Q2 (Q2t_kwargs)
encoder_reward = GCN(node_dim, encoder_hiddens, **gcn_kwargs)  # reward function

# Reward model on top of its own encoder; the second argument is the width
# of the encoder's final layer.
reward_fn = GraphReward(
    encoder_reward,
    encoder_hiddens[-1],
    [5, 5],
    with_layer_norm=True,
)

# Constructor arguments for the single-stage Gaussian policy. Observation and
# action dimensions both equal the width of the encoder's final layer.
gauss_policy_kwargs = {
    'obs_dim': encoder_hiddens[-1],
    'action_dim': encoder_hiddens[-1],
    'hiddens': [5, 5],
    'with_layer_norm': True,
    'encoder': encoder,
    'two_action_vectors': True,
}

# Constructor arguments for the two-stage Gaussian policy (defined here but
# not passed to the agent config in this cell).
tsg_policy_kwargs = {
    'obs_dim': encoder_hiddens[-1],
    'action_dim': encoder_hiddens[-1],
    'hiddens1': [5, 5],
    'hiddens2': [7, 7],
    'encoder': encoder,
    'with_layer_norm': True,
}

# Shared Q-function template; the input is three times the encoder output
# width (presumably observation + two action vectors -- TODO confirm).
qfunc_kwargs = {
    'obs_action_dim': encoder_hiddens[-1] * 3,
    'hiddens': [5, 5],
    'with_layer_norm': True,
    'encoder': None,
}

# One shallow copy of the template per Q network, each with its own encoder.
Q1_kwargs = {**qfunc_kwargs, 'encoder': encoderq1}
Q2_kwargs = {**qfunc_kwargs, 'encoder': encoderq2}
Q1t_kwargs = {**qfunc_kwargs, 'encoder': encoderq1t}
Q2t_kwargs = {**qfunc_kwargs, 'encoder': encoderq2t}

# Agent-level settings for SACAgentGraph: which classes to instantiate,
# which optimiser to use for each learnable part, and the learning
# hyper-parameters.
agent_kwargs = {
    'name': 'SACAgentGraph',
    'policy_constructor': GaussPolicy,
    'qfunc_constructor': Qfunc,
    'env_constructor': GraphEnv,
    'buffer_constructor': GraphBuffer,
    # Plain SGD for every optimised component.
    'optimiser_constructors': dict.fromkeys(
        ('policy_optim', 'temperature_optim', 'Q1_optim', 'Q2_optim'),
        torch.optim.SGD,
    ),
    'entropy_lb': encoder_hiddens[-1],
    'policy_lr': 1e-3,
    'temperature_lr': 1e-3,
    'qfunc_lr': 1e-3,
    'tau': 0.005,
    'discount': 1.,
    'save_to': TEST_OUTPUTS_PATH,
    'cache_best_policy': False,
    'clip_grads': True,
    'UT_trick': False,
    'with_entropy': False,
}

# Remaining constructor arguments for the agent, grouped by the component
# they configure (training loop, Q networks, policy, buffer, environment).
config = {
    'training_kwargs': {
        'seed': 0,
        'num_iters': 50,
        'num_steps_to_sample': 100,
        'num_grad_steps': 1,
        'batch_size': 100,
        'num_eval_steps_to_sample': n_nodes,
        'min_steps_to_presample': 100,
    },
    'Q1_kwargs': Q1_kwargs,
    'Q2_kwargs': Q2_kwargs,
    'Q1t_kwargs': Q1t_kwargs,
    'Q2t_kwargs': Q2t_kwargs,
    'policy_kwargs': gauss_policy_kwargs,
    'buffer_kwargs': {
        'max_size': 10_000,
        'nodes': nodes,
        'seed': 0,
        'drop_repeats_or_self_loops': True,
        'get_batch_reward': True,
        'graphs_per_batch': 100,
        'action_is_index': True,
        'per_decision_imp_sample': False,
    },
    'env_kwargs': {
        'x': nodes,
        'reward_fn': reward_fn,
        # Episode length and expert step count are both tied to graph size.
        'max_episode_steps': n_nodes,
        'num_expert_steps': n_nodes,
        'max_repeats': n_nodes // 3,
        'max_self_loops': n_nodes // 3,
        'drop_repeats_or_self_loops': True,
        'id': None,
        'reward_fn_termination': False,
        'calculate_reward': False,
        'min_steps_to_do': 3,
    },
}

# Both dicts are unpacked as keyword arguments; a key present in both would
# raise a TypeError rather than being silently overridden.
agent = SACAgentGraph(**agent_kwargs, **config)

# print(agent.buffer.idx, agent.buffer.reward_idx)
# agent.buffer.collect_path(agent.env, agent, agent.num_steps_to_sample)
# print(agent.buffer.idx, agent.buffer.reward_idx, 
#       np.mean(agent.buffer.path_lens), 
#       np.max(agent.buffer.path_lens), 
#       agent.buffer.path_lens,
#       agent.buffer.reward_t[:agent.buffer.reward_idx],
#       sep='\n')

pygame 2.5.0 (SDL 2.28.0, Python 3.8.17)
Hello from the pygame community. https://www.pygame.org/contribute.html
/home/mario/coding/urban-nets-style-transfer/tests


In [2]:
# NOTE(review): the recorded run of this cell ended in a KeyboardInterrupt, so
# no result is shown. The meaning of the positional `True` flag is not visible
# from this notebook -- TODO confirm against GraphBuffer and pass it by keyword.
agent.buffer.get_single_ep_rewards_and_weights(agent.env, agent, True)

KeyboardInterrupt: 

In [4]:
# Display the buffer's `idx` and `reward_idx` attributes as a tuple
# (both were 0 in the recorded run, after the interrupted cell above).
agent.buffer.idx, agent.buffer.reward_idx

(0, 0)

In [5]:
from graph_irl.irl_trainer import IRLGraphTrainer

# Source node of every directed expert edge: node 0 once at each end and
# every other node twice in between, i.e. [0, 1, 1, 2, 2, ..., 0].
a = [0] + torch.repeat_interleave(torch.arange(1, n_nodes), 2).tolist() + [0]

# Alternate +1 / -1 offsets so consecutive sources point to their next and
# previous neighbour; the modulo wraps around the ends of the circle.
edge_sources = torch.tensor(a)
edge_offsets = torch.tensor([1, -1]).repeat(len(a) // 2)
edge_targets = (edge_offsets + edge_sources) % n_nodes

# Shape (2, len(a)): row 0 holds sources, row 1 holds targets.
expert_edge_index = torch.stack([edge_sources, edge_targets]).to(torch.long)

# Reuse the buffer settings so the trainer and the agent's buffer agree on
# batch size and on the importance-sampling mode.
buffer_cfg = config['buffer_kwargs']

irl_trainer_kwargs = {
    'reward_fn': reward_fn,
    'reward_optim': torch.optim.Adam(reward_fn.parameters(), lr=1e-3),
    'agent': agent,
    'nodes': nodes,
    'expert_edge_index': expert_edge_index,
    'num_expert_traj': 5,
    # 'num_generated_traj': 5,  (left disabled, as in the original cell)
    'graphs_per_batch': buffer_cfg['graphs_per_batch'],
    'reward_optim_lr_scheduler': None,
    'reward_grad_clip': False,
    'reward_scale': 1.,
    'per_decision_imp_sample': buffer_cfg['per_decision_imp_sample'],
    'add_expert_to_generated': False,
    'lcr_regularisation_coef': None,
    'mono_regularisation_on_demo_coef': None,
    'verbose': True,
}

irl_trainer = IRLGraphTrainer(**irl_trainer_kwargs)

# irl_trainer = IRLGraphTrainer(
    # reward_fn,
    # reward_optim,
    # agent,
    # nodes,
    # expert_edge_index,
    # num_expert_traj,
    # graphs_per_batch,
    # reward_optim_lr_scheduler=None,
    # reward_grad_clip=False,
    # reward_scale=1.,
    # per_decision_imp_sample=False,
    # add_expert_to_generated=False,
    # lcr_regularisation_coef=None,
    # mono_regularisation_on_demo_coef=None,
    # verbose=False,
# )

In [6]:
# Train the policy for k=1 epoch. The recorded output shows a 50-step inner
# progress bar, presumably the `num_iters=50` from the training config.
irl_trainer.train_policy_k_epochs(k=1)

100%|██████████| 50/50 [00:14<00:00,  3.49it/s]
100%|██████████| 1/1 [00:14<00:00, 14.55s/it]


In [35]:
# NOTE(review): in the recorded run the first element of the returned tuple is
# nan -- worth investigating before relying on per-decision importance sampling.
# This cell also carries execution count 35, i.e. it was run out of order
# relative to the cells around it and calls a private (underscore) method.
irl_trainer._get_per_dec_imp_samp_returns()

9 tensor(0.7223, grad_fn=<AddBackward0>)
7 tensor(1.1016, grad_fn=<AddBackward0>)
4 tensor(0.4583, grad_fn=<AddBackward0>)
10 tensor(5.0503, grad_fn=<AddBackward0>)
6 tensor(0.5155, grad_fn=<AddBackward0>)
10 tensor(1.9955, grad_fn=<AddBackward0>)
6 tensor(1.0445, grad_fn=<AddBackward0>)


(tensor(nan, grad_fn=<SumBackward0>), tensor(1.5554, grad_fn=<AddBackward0>))

In [7]:

# NOTE(review): the weights printed in the recorded run range from ~1e11 to
# ~1e19; weights of that magnitude may indicate a numerical-stability problem
# in the importance-sampling estimate -- TODO confirm.
irl_trainer._get_vanilla_imp_sampled_returns()

weights vanilla
[tensor(4.5529e+11), tensor(7.4860e+12), tensor(5.6435e+19), tensor(1.0507e+17), tensor(4.5316e+16), tensor(5.8864e+16)]


(tensor(-1.8085, grad_fn=<SumBackward0>), 0.0)

In [8]:
# Per its name, performs one gradient step on the reward function. The recorded
# output lists the losses and per-module gradient norms (presumably because the
# trainer was built with verbose=True) plus a stray "hi" line that looks like a
# leftover debug print inside the trainer.
irl_trainer.do_reward_grad_step()

weights vanilla
[tensor(8.4403e+19), tensor(4.7524e+14), tensor(3.9331e+08), tensor(2.5771e+15), tensor(707796.2500), tensor(1.3794e+18), tensor(6.6347e+16)]
hi
expert avg rewards: -2.7328226566314697
imp sampled rewards: -4.173259258270264
mono loss: 0.0
lcr_expert_loss: 0.0
lcr_sampled_loss: 0.0
overall reward loss: -1.440436601638794
len module param: 7 l2 norm of grad of params: 454.623046875
len module param: 7 l2 norm of grad of params: 855.0362548828125
len module param: 7 l2 norm of grad of params: 308.23040771484375
len module param: 7 l2 norm of grad of params: 818.640625
len module param: 2 l2 norm of grad of params: 12.563790321350098
len module param: 2 l2 norm of grad of params: 113.96983337402344
len module param: 5 l2 norm of grad of params: 259.1490478515625
len module param: 5 l2 norm of grad of params: 195.12933349609375
len module param: 5 l2 norm of grad of params: 0.9488996267318726
len module param: 5 l2 norm of grad of params: 3.7890665531158447
len module param

In [9]:
# NOTE(review): positional arguments; judging by the recorded progress bars
# (outer 2/2, inner 1/1) they are presumably the number of IRL iterations and
# the number of policy epochs per iteration -- confirm and pass by keyword.
irl_trainer.train_irl(2, 1)

  0%|          | 0/2 [00:00<?, ?it/s]

weights vanilla
[tensor(5.6648e+13), tensor(3.4684e+18), tensor(2.0338e+20), tensor(2.0412e+17), tensor(7.7866e+18), tensor(1.2702e+19)]
hi
expert avg rewards: -2.25759220123291
imp sampled rewards: -1.1205915212631226
mono loss: 0.0
lcr_expert_loss: 0.0
lcr_sampled_loss: 0.0
overall reward loss: 1.1370006799697876
len module param: 7 l2 norm of grad of params: 7.995495319366455
len module param: 7 l2 norm of grad of params: 18.172767639160156
len module param: 7 l2 norm of grad of params: 29.354124069213867
len module param: 7 l2 norm of grad of params: 15.44437313079834
len module param: 2 l2 norm of grad of params: 1.6760761737823486
len module param: 2 l2 norm of grad of params: 2.1715078353881836
len module param: 5 l2 norm of grad of params: 13.020103454589844
len module param: 5 l2 norm of grad of params: 10.493112564086914
len module param: 5 l2 norm of grad of params: 1.2404022216796875
len module param: 5 l2 norm of grad of params: 2.262253999710083
len module param: 5 l2 nor


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 50/50 [00:14<00:00,  3.40it/s]
100%|██████████| 1/1 [00:14<00:00, 14.72s/it]
  ax.set_ylim(ylow, yhigh)
 50%|█████     | 1/2 [00:15<00:15, 15.76s/it]

weights vanilla
[tensor(7.1979e+17), tensor(1.4116e+08), tensor(1.1959e+19), tensor(1.9929e+15), tensor(24201576.), tensor(7.7659e+18), tensor(2.4562e+18)]
hi
expert avg rewards: -1.7531073093414307
imp sampled rewards: -1.7056950330734253
mono loss: 0.0
lcr_expert_loss: 0.0
lcr_sampled_loss: 0.0
overall reward loss: 0.04741227626800537
len module param: 7 l2 norm of grad of params: 14.223734855651855
len module param: 7 l2 norm of grad of params: 22.31707000732422
len module param: 7 l2 norm of grad of params: 83.16067504882812
len module param: 7 l2 norm of grad of params: 15.936647415161133
len module param: 2 l2 norm of grad of params: 1.0145635604858398
len module param: 2 l2 norm of grad of params: 6.396261215209961
len module param: 5 l2 norm of grad of params: 13.428030014038086
len module param: 5 l2 norm of grad of params: 3.036083936691284
len module param: 5 l2 norm of grad of params: 2.091089963912964
len module param: 5 l2 norm of grad of params: 2.1539509296417236
len mo


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
100%|██████████| 50/50 [00:14<00:00,  3.55it/s]
100%|██████████| 1/1 [00:14<00:00, 14.10s/it]
100%|██████████| 2/2 [00:30<00:00, 15.42s/it]
