Skip to content

Commit

Permalink
Fix default args in MonoBERT/T5 (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuxuan-ji committed Sep 20, 2020
1 parent d88f8ce commit 73539ea
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pygaggle/rerank/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

class MonoT5(Reranker):
def __init__(self,
model_name_or_instance: Union[str, T5ForConditionalGeneration] = 'castorini/monoT5-base-msmarco',
model_name_or_instance: Union[str, T5ForConditionalGeneration] = 'castorini/monot5-base-msmarco',
tokenizer_name_or_instance: Union[str, QueryDocumentBatchTokenizer] = 't5-base'):
if isinstance(model_name_or_instance, str):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down Expand Up @@ -109,7 +109,7 @@ def rerank(self, query: Query, texts: List[Text]) -> List[Text]:

class MonoBERT(Reranker):
def __init__(self,
model_name_or_instance: Union[str, PreTrainedModel] = 'castorini/monoBERT-large-msmarco',
model_name_or_instance: Union[str, PreTrainedModel] = 'castorini/monobert-large-msmarco',
tokenizer_name_or_instance: Union[str, PreTrainedTokenizer] = 'bert-large-uncased'):
if isinstance(model_name_or_instance, str):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Expand Down

0 comments on commit 73539ea

Please sign in to comment.