From 831dbc2926c9d1b5a32089e138256fe87ca80a45 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Wed, 4 Jan 2023 20:26:56 +0000
Subject: [PATCH] Generate: Fix CI related to #20727 (#21003)

---
 tests/generation/test_utils.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index dfe8be0efd7e8b..aeb2bf480b25bf 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -3035,7 +3035,7 @@ def test_eos_token_id_int_and_list_greedy_search(self):
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
 
         text = """Hello, my dog is cute and"""
-        tokens = tokenizer(text, return_tensors="pt")
+        tokens = tokenizer(text, return_tensors="pt").to(torch_device)
 
         model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
@@ -3060,7 +3060,7 @@ def test_eos_token_id_int_and_list_contrastive_search(self):
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
 
         text = """Hello, my dog is cute and"""
-        tokens = tokenizer(text, return_tensors="pt")
+        tokens = tokenizer(text, return_tensors="pt").to(torch_device)
 
         model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
@@ -3086,7 +3086,7 @@ def test_eos_token_id_int_and_list_top_k_top_sampling(self):
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
 
         text = """Hello, my dog is cute and"""
-        tokens = tokenizer(text, return_tensors="pt")
+        tokens = tokenizer(text, return_tensors="pt").to(torch_device)
 
         model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
@@ -3109,7 +3109,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
         tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
 
         text = """Hello, my dog is cute and"""
-        tokens = tokenizer(text, return_tensors="pt")
+        tokens = tokenizer(text, return_tensors="pt").to(torch_device)
 
         model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
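
Below is a minimal standalone sketch of the pattern this patch applies in each test: the
tokenizer output has to be moved onto the same device as the model before calling
generate(), otherwise CUDA CI runs fail with a device mismatch. It reuses the tiny test
checkpoint named in the patch, but the explicit device selection and the max_new_tokens
value are illustrative choices of this sketch, not part of the patch or of the test
suite's torch_device helper.

    # Hypothetical, self-contained illustration of the fix (not part of the patch).
    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
    model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(device)

    # The BatchEncoding returned by the tokenizer supports .to(), which moves
    # input_ids and attention_mask onto the model's device in one call.
    tokens = tokenizer("Hello, my dog is cute and", return_tensors="pt").to(device)

    generated = model.generate(**tokens, max_new_tokens=5)
    print(tokenizer.decode(generated[0]))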