From 4b8d67b4712ae548555f257b7c204ec5dac504bc Mon Sep 17 00:00:00 2001 From: Rodrigo Frassetto Nogueira Date: Sat, 11 Jul 2020 13:25:05 -0300 Subject: [PATCH] Adds reranker example (#58) * Adds reranker example * Uses real example --- README.md | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/README.md b/README.md index 70f24196..d5fb3f9c 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,41 @@ Currently, this repo contains implementations of the rerankers for [CovidQA](htt 0. Install [Anserini](https://github.com/castorini/anserini). +# A simple reranking example +The code below exemplifies how to score two documents for a given query using a T5 reranker from [Document Ranking with a Pretrained +Sequence-to-Sequence Model](https://arxiv.org/pdf/2003.06713.pdf). +```python +import torch +from transformers import AutoTokenizer, T5ForConditionalGeneration +from pygaggle.model import T5BatchTokenizer +from pygaggle.rerank.base import Query, Text +from pygaggle.rerank.transformer import T5Reranker + +model_name = 'castorini/monot5-base-msmarco' +tokenizer_name = 't5-base' +batch_size = 8 + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +model = T5ForConditionalGeneration.from_pretrained(model_name) +model = model.to(device).eval() + +tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) +tokenizer = T5BatchTokenizer(tokenizer, batch_size) +reranker = T5Reranker(model, tokenizer) + +query = Query('what causes low liver enzymes') + +correct_doc = Text('Reduced production of liver enzymes may indicate dysfunction of the liver. This article explains the causes and symptoms of low liver enzymes. Scroll down to know how the production of the enzymes can be accelerated.') + +wrong_doc = Text('Elevated liver enzymes often indicate inflammation or damage to cells in the liver. Inflamed or injured liver cells leak higher than normal amounts of certain chemicals, including liver enzymes, into the bloodstream, elevating liver enzymes on blood tests.') + +documents = [correct_doc, wrong_doc] + +scores = [result.score for result in reranker.rerank(query, documents)] +# scores = [-0.1782158613204956, -0.36637523770332336] +``` + # Evaluations ## Additional Instructions