diff --git a/blink/main_dense.py b/blink/main_dense.py index c9c01469..557b1bdc 100644 --- a/blink/main_dense.py +++ b/blink/main_dense.py @@ -241,6 +241,7 @@ def _run_biencoder(biencoder, dataloader, candidate_encoding, top_k=100, indexer all_scores = [] for batch in tqdm(dataloader): context_input, _, label_ids = batch + context_input = context_input.to(device=biencoder.device) with torch.no_grad(): if indexer is not None: context_encoding = biencoder.encode_context(context_input).numpy()