<a href="https://colab.research.google.com/github/beatobongco/tinybert_ranker/blob/main/TinyBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

Collecting transformers
  Downloading https://files.pythonhosted.org/packages/d8/f4/9f93f06dd2c57c7cd7aa515ffbf9fcfd8a084b92285732289f4a5696dd91/transformers-3.2.0-py3-none-any.whl (1.0MB)

In [None]:
!nvidia-smi

Thu Sep 24 09:33:42 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.66       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P0    60W / 149W |    769MiB / 11441MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import torch
import numpy as np
from typing import List, Tuple
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class PTRanker:
  """Rerank candidate passages for a query with a pre-trained BERT-like sequence classifier."""
  def __init__(self, model_path):
    """
        model_path - local directory or Hugging Face model hub ID (e.g. "nboost/pt-tinybert-msmarco")
    """
    self.max_seq_len = 512
    self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.tokenizer = AutoTokenizer.from_pretrained(model_path)
    self.rerank_model = AutoModelForSequenceClassification.from_pretrained(model_path).to(self.device, non_blocking=True)

  def rank(self, query: str, choices: List[str], filter_results=False) -> Tuple[List[int], List[float]]:
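    """Return choice indices sorted by descending "relevant" logit, and their scores.

    With filter_results=True, choices whose "relevant" logit does not exceed
    the "not relevant" logit are dropped.
    """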
    if len(choices) == 0:
        return [], []

    logits = self.get_logits(query, choices)
    scores = []
    all_scores = []
    index_map = []
    for i, logit in enumerate(logits):
        neg_logit = logit[0]
        score = logit[1]
        all_scores.append(score)
        if score > neg_logit or not filter_results:
            scores.append(score)
            index_map.append(i)
    sorted_indices = [index_map[i] for i in np.argsort(scores)[::-1]]
    return sorted_indices, [all_scores[i] for i in sorted_indices]

  def get_logits(self, query: str, choices: List[str]):
    """Get search ranking logits for query, choices"""
    input_ids, attention_mask, token_type_ids = self.encode(query, choices)

    with torch.no_grad():
        logits = self.rerank_model(input_ids,
                                    attention_mask=attention_mask,
                                    token_type_ids=token_type_ids)[0]
        logits = logits.detach().cpu().numpy()

    return logits

  def encode(self, query: str, choices: List[str]):
    """Tokenize (query, choice) pairs and pad them into batch tensors on self.device."""
    inputs = [self.tokenizer.encode_plus(
        query, choice, add_special_tokens=True, return_token_type_ids=True, 
        max_length=self.max_seq_len, truncation=True
        ) for choice in choices]

    # Right-pad every pair to the longest one in the batch (capped at max_seq_len),
    # using the tokenizer's own pad id instead of a hardcoded 0.
    pad_id = self.tokenizer.pad_token_id
    max_len = min(max(len(t['input_ids']) for t in inputs), self.max_seq_len)
    input_ids = [t['input_ids'][:max_len] +
                 [pad_id] * (max_len - len(t['input_ids'][:max_len])) for t in inputs]
    attention_mask = [[1] * len(t['input_ids'][:max_len]) +
                      [0] * (max_len - len(t['input_ids'][:max_len])) for t in inputs]
    token_type_ids = [t['token_type_ids'][:max_len] +
                      [0] * (max_len - len(t['token_type_ids'][:max_len])) for t in inputs]

    input_ids = torch.tensor(input_ids).to(self.device, non_blocking=True)
    attention_mask = torch.tensor(attention_mask).to(self.device, non_blocking=True)
    token_type_ids = torch.tensor(token_type_ids).to(self.device, non_blocking=True)

    return input_ids, attention_mask, token_type_ids

model = PTRanker("nboost/pt-tinybert-msmarco")

Some weights of the model checkpoint at nboost/pt-tinybert-msmarco were not used when initializing BertForSequenceClassification: ['fit_dense.weight', 'fit_dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
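`encode` pads each (query, choice) pair by hand so the whole list can be scored as one tensor batch. The padding step can be isolated as a small pure-Python function; this is a minimal sketch, where `pad_batch` is a hypothetical helper name and the token ids are toy values rather than real tokenizer output. It assumes, as the notebook does, right-side padding and a pad id of 0 (BERT's `[PAD]`).

```python
from typing import List, Tuple


def pad_batch(sequences: List[List[int]], pad_id: int = 0,
              max_seq_len: int = 512) -> Tuple[List[List[int]], List[List[int]]]:
    """Right-pad token id lists to a common length and build attention masks.

    Hypothetical helper mirroring the padding step in PTRanker.encode.
    """
    # Batch length: the longest sequence, capped at max_seq_len.
    max_len = min(max(len(s) for s in sequences), max_seq_len)
    padded, masks = [], []
    for seq in sequences:
        seq = seq[:max_len]  # truncate anything longer than the cap
        n_pad = max_len - len(seq)
        padded.append(seq + [pad_id] * n_pad)
        # 1 marks a real token the model should attend to, 0 marks padding.
        masks.append([1] * len(seq) + [0] * n_pad)
    return padded, masks


# Toy token ids (illustrative only).
ids, mask = pad_batch([[101, 2040, 102], [101, 2040, 2003, 102]])
print(ids)   # [[101, 2040, 102, 0], [101, 2040, 2003, 102]]
print(mask)  # [[1, 1, 1, 0], [1, 1, 1, 1]]
```

The attention mask is what keeps the pad positions from influencing the classifier, so the pad id itself only needs to be a valid vocabulary entry.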


In [None]:
choices = [
  """Daft Punk are a French electronic music duo formed in Paris in 1993 by Guy-Manuel de Homem-Christo and Thomas Bangalter.[5][6][7][8] They achieved popularity in the late 1990s as part of the French house movement; they also had success in the years following, combining elements of house music with funk, techno, disco, rock and synthpop.[2][6][7][9] They have worn ornate helmets and gloves to assume robot personas in most public appearances since 1999[10] and rarely grant interviews or appear on television. The duo were managed from 1996 to 2008 by Pedro Winter (also known as Busy P), the head of Ed Banger Records.""",
  """Tron (stylized as TRON) is a 1982 American science fiction action-adventure film written and directed by Steven Lisberger from a story by Lisberger and Bonnie MacBird. The film stars Jeff Bridges as a computer programmer who is transported inside the software world of a mainframe computer where he interacts with programs in his attempt to escape. Bruce Boxleitner, David Warner, Cindy Morgan, and Barnard Hughes star in supporting roles.""",
  """Joel Thomas Zimmerman (born January 5, 1981),[2] known professionally as Deadmau5 (stylized as deadmau5; pronounced "dead mouse"), is a Canadian electronic music producer, DJ, and musician. Zimmerman mainly produces progressive house music, though he also produces and DJs other genres of electronic music, including techno under the alias Testpilot. Zimmerman has received six Grammy Award nominations for his work.""",
  "Daft punk is cool",
  "A French electronic music duo formed in Paris in 1993 by Guy-Manuel de Homem-Christo and Thomas Bangalter.",
  "A French electronic music duo consisting of Gaspard Augé and Xavier de Rosnay."
]

In [None]:
def rerank(query, choices=choices, filter_results=False):
  print(f"Query: {query}, filter: {filter_results}")
  ranks, scores = model.rank(query, choices, filter_results=filter_results)
  for i, (rank, score) in enumerate(zip(ranks, scores), start=1):
    print(f"{i}. ({score}) {choices[int(rank)]}")
  print("---")

In [None]:
rerank("Who is daft punk?", filter_results=True)
rerank("Who is deadmau5?")
rerank("Who is deadmau5?", filter_results=True)
rerank("sci-fi")
rerank("electronic music")

Query: Who is daft punk?, filter: True
1. (1.4061083793640137) Daft Punk are a French electronic music duo formed in Paris in 1993 by Guy-Manuel de Homem-Christo and Thomas Bangalter.[5][6][7][8] They achieved popularity in the late 1990s as part of the French house movement; they also had success in the years following, combining elements of house music with funk, techno, disco, rock and synthpop.[2][6][7][9] They have worn ornate helmets and gloves to assume robot personas in most public appearances since 1999[10] and rarely grant interviews or appear on television. The duo were managed from 1996 to 2008 by Pedro Winter (also known as Busy P), the head of Ed Banger Records.
2. (1.3816384077072144) Daft punk is cool
---
Query: Who is deadmau5?, filter: False
1. (1.3276888132095337) Joel Thomas Zimmerman (born January 5, 1981),[2] known professionally as Deadmau5 (stylized as deadmau5; pronounced "dead mouse"), is a Canadian electronic music producer, DJ, and musician. Zimmerman mainly
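The printed scores are raw logits for the "relevant" class, not probabilities. The sketch below reproduces the sorting and `filter_results` logic of `rank` on a plain NumPy array so it can be inspected without loading the model. `rank_from_logits` and `relevance_prob` are hypothetical helpers, the logits are made up, and column 1 is assumed to be the "relevant" class, as `rank` assumes. `relevance_prob` adds a two-class softmax, which the notebook itself does not use.

```python
import numpy as np


def rank_from_logits(logits: np.ndarray, filter_results: bool = False):
    """Rank choices by the positive-class logit, following PTRanker.rank.

    logits has shape (n_choices, 2): column 0 = "not relevant", column 1 = "relevant".
    """
    pos, neg = logits[:, 1], logits[:, 0]
    # Candidates: every choice, or only those the classifier calls relevant.
    keep = np.flatnonzero(pos > neg) if filter_results else np.arange(len(logits))
    # Sort the kept indices by descending "relevant" logit.
    order = keep[np.argsort(-pos[keep], kind="stable")]
    return order.tolist(), pos[order].tolist()


def relevance_prob(logits: np.ndarray) -> np.ndarray:
    """Softmax over the two classes; returns P(relevant) for each choice."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp[:, 1] / exp.sum(axis=1)


# Made-up logits for three choices (illustrative only, not model output).
toy = np.array([[-1.0, 1.4], [0.5, 0.2], [-0.8, 1.3]])
print(rank_from_logits(toy))                       # ([0, 2, 1], [1.4, 1.3, 0.2])
print(rank_from_logits(toy, filter_results=True))  # ([0, 2], [1.4, 1.3])
```

Because P(relevant) > 0.5 exactly when the "relevant" logit exceeds the other one, thresholding `relevance_prob` at 0.5 keeps the same choices as `filter_results=True`. Sorting by that probability, however, orders choices by the logit difference rather than by the "relevant" logit alone, so it can give a different order. Ties may also come out differently than with the notebook's `np.argsort(scores)[::-1]`.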

In [None]:
%%time
# Test reranking on 50 choices
rerank("Who is daft punk?", choices=(choices * 9)[:50])

Query: Who is daft punk?, filter: False
1. (1.4061083793640137) Daft Punk are a French electronic music duo formed in Paris in 1993 by Guy-Manuel de Homem-Christo and Thomas Bangalter.[5][6][7][8] They achieved popularity in the late 1990s as part of the French house movement; they also had success in the years following, combining elements of house music with funk, techno, disco, rock and synthpop.[2][6][7][9] They have worn ornate helmets and gloves to assume robot personas in most public appearances since 1999[10] and rarely grant interviews or appear on television. The duo were managed from 1996 to 2008 by Pedro Winter (also known as Busy P), the head of Ed Banger Records.
2. (1.4061083793640137) Daft Punk are a French electronic music duo formed in Paris in 1993 by Guy-Manuel de Homem-Christo and Thomas Bangalter.[5][6][7][8] They achieved popularity in the late 1990s as part of the French house movement; they also had success in the years following, combining elements of house mu
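`get_logits` sends all choices through the model in a single batch, padded to the longest pair. For much larger candidate lists, one option is to score fixed-size slices and concatenate the results, which should help bound peak GPU memory. This is a minimal sketch: `batched_logits`, the default `batch_size=32`, and the word-overlap `fake_get_logits` stand-in are all hypothetical, the stand-in being there only so the example runs without the model.

```python
from typing import Callable, List, Sequence

import numpy as np


def batched_logits(get_logits: Callable[[str, List[str]], np.ndarray],
                   query: str, choices: Sequence[str], batch_size: int = 32) -> np.ndarray:
    """Call get_logits on fixed-size slices of choices and stack the results.

    get_logits is assumed to behave like PTRanker.get_logits: it returns a
    (len(batch), 2) array for each call.
    """
    if not choices:
        return np.zeros((0, 2), dtype=np.float32)
    parts = [get_logits(query, list(choices[i:i + batch_size]))
             for i in range(0, len(choices), batch_size)]
    return np.concatenate(parts, axis=0)


# Hypothetical stand-in scorer: "relevant" logit = number of query words in the choice.
def fake_get_logits(query: str, batch: List[str]) -> np.ndarray:
    q = set(query.lower().split())
    return np.array([[0.0, float(len(q & set(c.lower().split())))] for c in batch])


out = batched_logits(fake_get_logits, "daft punk", ["daft punk is cool", "tron"] * 40)
print(out.shape)  # (80, 2)
```

With the notebook's model, the equivalent call would pass `model.get_logits` as the scorer, and the ranking logic from `rank` could then be applied to the stacked array. Since `encode` pads each call to its own longest pair, smaller slices may also mean less padding per batch.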