Skip to content

Commit

Permalink
Fix GAIL expert dataloader
Browse files Browse the repository at this point in the history
Set drop_last=True only if the dataset size is more than GAIL batch size (otherwise it just returns an empty dataloader).
  • Loading branch information
ranamihir committed May 3, 2019
1 parent b4133ec commit ddbfd9f
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions main.py
Expand Up @@ -79,13 +79,15 @@ def main():
file_name = os.path.join(
args.gail_experts_dir, "trajs_{}.pt".format(
args.env_name.split('-')[0].lower()))


expert_dataset = gail.ExpertDataset(
file_name, num_trajectories=4, subsample_frequency=20)
drop_last = len(expert_dataset) > args.gail_batch_size
gail_train_loader = torch.utils.data.DataLoader(
gail.ExpertDataset(
file_name, num_trajectories=4, subsample_frequency=20),
dataset=expert_dataset,
batch_size=args.gail_batch_size,
shuffle=True,
drop_last=True)
drop_last=drop_last)

rollouts = RolloutStorage(args.num_steps, args.num_processes,
envs.observation_space.shape, envs.action_space,
Expand Down

0 comments on commit ddbfd9f

Please sign in to comment.