
[Help Wanted] Incompatible log_probs shape when combining trajectories #70

Closed
JosephRRB opened this issue Jun 30, 2023 · 4 comments
Labels: help wanted (Extra attention is needed)

@JosephRRB (Contributor)

Hello again everyone,

We are trying to augment our training with backward trajectories that start from states sampled from a reward-prioritized replay buffer, as described here: https://arxiv.org/abs/2305.07170. I used Trajectories.revert_backward_trajectories() to transform the backward trajectories into forward ones, but attempting to combine them with the forward-sampled trajectories raises an error. The code below reproduces it:

import torch

# NOTE: module paths are those of the gfn version we are using and may
# differ in other releases.
from gfn.containers import Trajectories
from gfn.envs import HyperGrid
from gfn.estimators import LogitPBEstimator, LogitPFEstimator
from gfn.samplers import (
    BackwardDiscreteActionsSampler,
    DiscreteActionsSampler,
    TrajectoriesSampler,
)

torch.manual_seed(0)

env = HyperGrid(ndim=2, height=3, R0=0.01)
logit_PF = LogitPFEstimator(env=env, module_name="NeuralNet")
logit_PB = LogitPBEstimator(
    env=env, module_name="NeuralNet", torso=logit_PF.module.torso,
)
forward_sampler = TrajectoriesSampler(
    env=env, actions_sampler=DiscreteActionsSampler(estimator=logit_PF)
)
backward_sampler = TrajectoriesSampler(
    env=env, actions_sampler=BackwardDiscreteActionsSampler(estimator=logit_PB)
)

trajectories = forward_sampler.sample(n_trajectories=4)
states = env.reset(batch_shape=4, random=True) # Would come from replay buffer
backward_trajectories = backward_sampler.sample_trajectories(states)
offline_trajectories = Trajectories.revert_backward_trajectories(backward_trajectories)

trajectories.extend(offline_trajectories)

The error is:

    def extend(self, other: Trajectories) -> None:
        """Extend the trajectories with another set of trajectories."""
        self.extend_actions(required_first_dim=max(self.max_length, other.max_length))
        other.extend_actions(required_first_dim=max(self.max_length, other.max_length))
    
        self.states.extend(other.states)
        self.actions = torch.cat((self.actions, other.actions), dim=1)
        self.when_is_done = torch.cat((self.when_is_done, other.when_is_done), dim=0)
>       self.log_probs = torch.cat((self.log_probs, other.log_probs), dim=1)
E       RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 5 but got size 4 for tensor number 1 in the list.

Inserting the code

offline_trajectories.log_probs = torch.cat([
    offline_trajectories.log_probs,
    torch.full(
        size=(
            1,
            offline_trajectories.n_trajectories,
        ),
        fill_value=0,
        dtype=torch.float,
    )
], dim=0)

before trajectories.extend(offline_trajectories) seems to work, but I don't know whether there will be unexpected behavior downstream. It seems that log_probs needs to be padded after Trajectories.revert_backward_trajectories(). I would appreciate your insight.
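For reference, the same padding can be written as a small helper that pads both containers to a common length (pad_log_probs is just an illustrative name, not part of gfn; it assumes log_probs has shape (max_length, n_trajectories), which is what the error above suggests):

def pad_log_probs(trajs, target_rows):
    # Zero-pad trajs.log_probs along dim 0 up to target_rows, so that the
    # torch.cat along dim 1 inside Trajectories.extend() succeeds.
    missing = target_rows - trajs.log_probs.shape[0]
    if missing > 0:
        padding = torch.zeros(
            missing,
            trajs.n_trajectories,
            dtype=trajs.log_probs.dtype,
            device=trajs.log_probs.device,
        )
        trajs.log_probs = torch.cat([trajs.log_probs, padding], dim=0)

max_rows = max(trajectories.log_probs.shape[0], offline_trajectories.log_probs.shape[0])
pad_log_probs(trajectories, max_rows)
pad_log_probs(offline_trajectories, max_rows)
trajectories.extend(offline_trajectories)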

Possibly a bit off topic: would it be better to sample forward trajectories stored in a ReplayBuffer instead? I would just need to sort the trajectories according to the rewards of their terminating states.

Thank you very much for your time!

@saleml (Collaborator) commented Jun 30, 2023

> before trajectories.extend(offline_trajectories) seems to work, but I don't know whether there will be unexpected behavior downstream. It seems that log_probs needs to be padded after Trajectories.revert_backward_trajectories(). I would appreciate your insight.

If you actually don't care about the log_probs attribute of the obtained trajectories, then yeah your suggestion should work fine.

would it be better to sample forward trajectories stored in ReplayBuffer instead?

I'm not sure what you mean by that. Did you mean "to store" instead of "to sample"? If so, note that a replay buffer essentially works with extends. The difference from what you wrote is that a ReplayBuffer object has a limited capacity. What do you think such a replay buffer would be helpful for?

@JosephRRB (Contributor, Author)

> If you actually don't care about the log_probs attribute of the obtained trajectories, then yeah your suggestion should work fine.

Alright, thank you! I can set on_policy=False on the loss function so that the log_probs attribute won't be used during training, right? I was just wondering whether I was converting the backward trajectories into forward ones correctly, since I'm also using them in training.
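For context, this is roughly how I build the loss (a sketch; the LogZEstimator, TBParametrization, and TrajectoryBalance names reflect my understanding of the gfn version used in the snippet above and may need adjusting):

import torch

from gfn.estimators import LogZEstimator
from gfn.losses import TBParametrization, TrajectoryBalance

logZ = LogZEstimator(torch.tensor(0.0))
parametrization = TBParametrization(logit_PF, logit_PB, logZ)
# on_policy=False: the loss recomputes log-probabilities from the states and
# actions instead of reusing trajectories.log_probs.
loss_fn = TrajectoryBalance(parametrization, on_policy=False)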

> I'm not sure what you mean by that. Did you mean "to store" instead of "to sample"? If so, note that a replay buffer essentially works with extends. The difference from what you wrote is that a ReplayBuffer object has a limited capacity. What do you think such a replay buffer would be helpful for?

I'm implementing a replay buffer from which we can sample the highest- (and lowest-) reward terminal states seen so far. It's the prioritized replay training (PRT) from https://arxiv.org/abs/2305.07170. We need it because our rewards are very skewed.

Currently, every time I add terminal states to the buffer, I remove any duplicated terminal states and sort them according to their corresponding rewards. When I sample from the buffer, I only sample from the first n states and the last n states. If I reach the maximum capacity, I remove states from the middle of the buffer.
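In case it makes the behavior concrete, here is roughly that logic as a standalone sketch (RewardPrioritizedStateBuffer is just an illustrative name; states are handled as plain tensors here rather than gfn States objects):

import torch

class RewardPrioritizedStateBuffer:
    # Keeps unique terminal states sorted by reward, samples only from the
    # two extremes, and evicts from the middle once capacity is exceeded.
    def __init__(self, capacity):
        self.capacity = capacity
        self._data = {}  # terminal state (as a tuple) -> reward

    def add(self, states, rewards):
        # Duplicated terminal states collapse automatically via the dict keys.
        for state, reward in zip(states, rewards):
            self._data[tuple(state.tolist())] = float(reward)
        if len(self._data) > self.capacity:
            items = sorted(self._data.items(), key=lambda kv: kv[1])
            keep_low = self.capacity // 2
            keep_high = self.capacity - keep_low
            # Keep the lowest- and highest-reward states, drop the middle.
            self._data = dict(items[:keep_low] + items[len(items) - keep_high:])

    def sample_extremes(self, n):
        # The n lowest- and n highest-reward terminal states seen so far.
        items = sorted(self._data.items(), key=lambda kv: kv[1])
        chosen = items[:n] + items[-n:]
        return torch.tensor([state for state, _ in chosen])

The sampled states would then play the role of the states passed to backward_sampler.sample_trajectories(...) in the snippet above.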

My question was: instead of storing just terminal states and sampling the highest- (and lowest-) reward terminal states from them, can I store the trajectories themselves and find a way to sample the ones with the desired rewards? That way I wouldn't need to generate the backward trajectories from the terminal states and convert them to forward ones.

What do you think is the best way to implement such a buffer?

saleml added the "help wanted" (Extra attention is needed) label on Jul 7, 2023
@saleml (Collaborator) commented Jul 10, 2023

I have never implemented a prioritized replay buffer, but your idea looks like it can be extended to storing trajectories rather than terminal states.

For example, you can use the provided ReplayBuffer class and keep a sorted list of indices (the indices of the trajectories). Every time you add a trajectory, you re-sort the list of indices, using the corresponding trajectory rewards as the sort key. Then you could use the sorted list of indices to subsample your replay buffer, rather than using the sample() method.
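Something along these lines, for instance (a rough sketch with hypothetical names; the rewards would be those of the terminating states of the trajectories you add, and it ignores eviction, which would require remapping the indices):

import torch

class SortedTrajectoryIndex:
    # Companion to a ReplayBuffer: keeps trajectory indices sorted by reward.
    def __init__(self):
        self.rewards = torch.empty(0)
        self.sorted_idx = torch.empty(0, dtype=torch.long)

    def on_add(self, new_rewards):
        # Call this whenever trajectories are added to the replay buffer,
        # passing the rewards of their terminating states.
        self.rewards = torch.cat([self.rewards, new_rewards])
        self.sorted_idx = torch.argsort(self.rewards)

    def extreme_indices(self, n):
        # Indices of the n lowest- and n highest-reward trajectories; use
        # these to index the stored trajectories instead of calling sample().
        return torch.cat([self.sorted_idx[:n], self.sorted_idx[-n:]])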

@JosephRRB (Contributor, Author)

Alright, let me try that out! Thank you :)
