Skip to content

Commit

Permalink
Adds reranker example (#58)
Browse files Browse the repository at this point in the history
* Adds reranker example

* Uses real example
  • Loading branch information
rodrigonogueira4 committed Jul 11, 2020
1 parent 3e07b5c commit 4b8d67b
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,41 @@ Currently, this repo contains implementations of the rerankers for [CovidQA](htt
0. Install [Anserini](https://github.com/castorini/anserini).


# A simple reranking example
The code below exemplifies how to score two documents for a given query using a T5 reranker from [Document Ranking with a Pretrained
Sequence-to-Sequence Model](https://arxiv.org/pdf/2003.06713.pdf).
```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration
from pygaggle.model import T5BatchTokenizer
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import T5Reranker

model_name = 'castorini/monot5-base-msmarco'
tokenizer_name = 't5-base'
batch_size = 8

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
tokenizer = T5BatchTokenizer(tokenizer, batch_size)
reranker = T5Reranker(model, tokenizer)

query = Query('what causes low liver enzymes')

correct_doc = Text('Reduced production of liver enzymes may indicate dysfunction of the liver. This article explains the causes and symptoms of low liver enzymes. Scroll down to know how the production of the enzymes can be accelerated.')

wrong_doc = Text('Elevated liver enzymes often indicate inflammation or damage to cells in the liver. Inflamed or injured liver cells leak higher than normal amounts of certain chemicals, including liver enzymes, into the bloodstream, elevating liver enzymes on blood tests.')

documents = [correct_doc, wrong_doc]

scores = [result.score for result in reranker.rerank(query, documents)]
# scores = [-0.1782158613204956, -0.36637523770332336]
```

# Evaluations

## Additional Instructions
Expand Down

0 comments on commit 4b8d67b

Please sign in to comment.