Skip to content

Commit

Permalink
Merge 673126a into b8d0ee3
Browse files Browse the repository at this point in the history
  • Loading branch information
Sloane Simmons committed Sep 7, 2019
2 parents b8d0ee3 + 673126a commit 21971c5
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch_struct/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def SubTokenizedField(tokenizer):
return FIELD


def TokenBucket(train, batch_size):
def TokenBucket(train, batch_size, device='cuda:0'):
def batch_size_fn(x, _, size):
return size + max(len(x.word[0]), 5)

Expand All @@ -73,5 +73,5 @@ def batch_size_fn(x, _, size):
sort_key=lambda x: len(x.word[0]),
repeat=True,
batch_size_fn=batch_size_fn,
device="cuda:0",
device=device,
)

0 comments on commit 21971c5

Please sign in to comment.