In [1]:
import os
# Move the kernel's working directory one level up. The path is relative, so
# running this cell again in the same kernel moves up another level.
# NOTE(review): presumably done so project-root-relative paths and imports
# resolve from the notebook's folder — confirm.
os.chdir('..')

# Turn on IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes so edits to imported code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
from functools import partial

import torch
from transformers import BertTokenizer, BertForNextSentencePrediction


def get_probability_of_next_sentence(tokenizer, model, text1, text2):
    """Score how likely ``text2`` is to follow ``text1`` using a next-sentence head.

    The two texts are encoded as a single sequence
    ``[CLS] text1 [SEP] text2 [SEP]`` with segment id 0 for the first part
    (including ``[CLS]`` and its ``[SEP]``) and segment id 1 for the second.

    Args:
        tokenizer: Object providing ``tokenize(str) -> list[str]`` and
            ``convert_tokens_to_ids(list[str]) -> list[int]``
            (a ``BertTokenizer`` in this notebook).
        model: ``torch.nn.Module`` called as
            ``model(input_ids, token_type_ids=...)`` whose output's first
            element holds the logits with the batch on dim 0
            (a ``BertForNextSentencePrediction`` in this notebook).
        text1: The first sentence.
        text2: The candidate follow-up sentence.

    Returns:
        A 1-D tensor with the softmax probabilities for the single input pair.
        With Hugging Face's ``BertForNextSentencePrediction`` index 0 means
        "``text2`` continues ``text1``" and index 1 means "``text2`` is a
        random sentence". The tensor is detached from autograd.

    Side effects:
        Switches ``model`` to evaluation mode (disables dropout).
    """
    first_segment = ["[CLS]"] + tokenizer.tokenize(text1) + ["[SEP]"]
    second_segment = tokenizer.tokenize(text2) + ["[SEP]"]

    input_ids = tokenizer.convert_tokens_to_ids(first_segment + second_segment)
    segment_ids = [0] * len(first_segment) + [1] * len(second_segment)

    # Wrap in an outer list to add the batch dimension (batch size 1).
    tokens_tensor = torch.tensor([input_ids])
    segments_tensor = torch.tensor([segment_ids])

    model.eval()
    # Pure inference: skip building the autograd graph so no activations are
    # retained and the returned tensor carries no grad_fn.
    with torch.no_grad():
        logits = model(tokens_tensor, token_type_ids=segments_tensor)[0]
        probabilities = torch.softmax(logits, dim=1)

    # Drop the batch dimension.
    return probabilities[0]

I0128 00:28:57.202524 140120516495104 file_utils.py:35] PyTorch version 1.0.1.post2 available.
W0128 00:28:57.979238 140120516495104 __init__.py:28] To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [3]:
# Load the tokenizer and the next-sentence-prediction model for the
# 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')

# Pre-bind tokenizer and model so the cells below only pass the two texts.
partial_get_probability_of_next_sentence = partial(get_probability_of_next_sentence, tokenizer, model)

I0128 00:28:58.603835 140120516495104 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/gabrielamelo/.cache/torch/transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
I0128 00:28:59.222384 140120516495104 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/gabrielamelo/.cache/torch/transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c
I0128 00:28:59.224655 140120516495104 configuration_utils.py:199] Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "ini

In [4]:
# Unrelated pair: the second text does not answer the first.
# In the recorded output almost all probability lands on index 1.
text1 = "How old are you?"
text2 = "The Eiffel Tower is in Paris"
prediction = partial_get_probability_of_next_sentence(text1, text2)
print(prediction)

tensor([4.1673e-04, 9.9958e-01], grad_fn=<SelectBackward>)


In [5]:
# Coherent pair: the second text answers the first.
# In the recorded output almost all probability lands on index 0.
text1 = "How old are you?"
text2 = "I am 22 years old"
prediction = partial_get_probability_of_next_sentence(text1, text2)
print(prediction)
# Index 0 on its own.
print(prediction[0])

tensor([9.9999e-01, 9.6342e-06], grad_fn=<SelectBackward>)
tensor(1.0000, grad_fn=<SelectBackward>)


In [6]:
# Switch to the 'bert-base-multilingual-cased' checkpoint.
# NOTE: this rebinds the same `tokenizer`, `model` and partial names assigned
# for 'bert-base-uncased' earlier, so any cell run after this one (including
# re-runs of the English examples) uses the multilingual model.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertForNextSentencePrediction.from_pretrained('bert-base-multilingual-cased')

partial_get_probability_of_next_sentence = partial(get_probability_of_next_sentence, tokenizer, model)

I0128 00:29:02.621064 140120516495104 tokenization_utils.py:398] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-vocab.txt from cache at /home/gabrielamelo/.cache/torch/transformers/96435fa287fbf7e469185f1062386e05a075cadbf6838b74da22bf64b080bc32.99bcd55fc66f4f3360bc49ba472b940b8dcf223ea6a345deb969d607ca900729
I0128 00:29:03.348755 140120516495104 configuration_utils.py:185] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json from cache at /home/gabrielamelo/.cache/torch/transformers/45629519f3117b89d89fd9c740073d8e4c1f0a70f9842476185100a8afe715d1.83b0fa3d7f1ac0e113ad300189a938c6f14d0588a4200f30eef109d0a047c484
I0128 00:29:03.349600 140120516495104 configuration_utils.py:199] Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {


In [7]:
# Portuguese version of the unrelated pair ("How old are you?" /
# "The Eiffel Tower is in Paris").
# In the recorded output index 0 still receives about 0.86 for this pair.
text1 = "Quantos anos você tem?"
text2 = "A Torre Eiffel fica em Paris"
prediction = partial_get_probability_of_next_sentence(text1, text2)
print(prediction)

tensor([0.8567, 0.1433], grad_fn=<SelectBackward>)


In [8]:
# Portuguese version of the coherent pair ("How old are you?" /
# "I am 22 years old").
# In the recorded output index 0 receives about 0.94.
text1 = "Quantos anos você tem?"
text2 = "Eu tenho 22 anos"
prediction = partial_get_probability_of_next_sentence(text1, text2)
print(prediction)

tensor([0.9411, 0.0589], grad_fn=<SelectBackward>)


In [9]:
def get_sentence_breaks(first_sentence, second_sentence):
    """Return the number of leading whitespace-separated words the sentences share.

    The result is the index of the first word at which the two sentences
    differ, so ``sentence.split()[:i]`` is the common prefix and
    ``sentence.split()[i:]`` is each sentence's own tail.

    Args:
        first_sentence: First sentence as a string.
        second_sentence: Second sentence as a string.

    Returns:
        int: Index of the first differing word. When one sentence is a word
        prefix of the other (including identical or empty sentences) this is
        the word count of the shorter sentence.
    """
    # Split once up front instead of re-splitting both strings per iteration.
    first_words = first_sentence.split()
    second_words = second_sentence.split()

    # zip stops at the shorter sentence, so a shorter second sentence cannot
    # cause an out-of-range lookup.
    for index, (first_word, second_word) in enumerate(zip(first_words, second_words)):
        if first_word != second_word:
            return index

    # No mismatch found: everything in the shorter sentence is shared.
    return min(len(first_words), len(second_words))

In [10]:
def test_get_sentence_breaks():
    """Check that the break index splits sentence pairs into shared prefix and tails.

    Each scenario is a pair of sentences differing in one noun phrase (one
    English pair, one Portuguese pair). Slicing both word lists at the index
    returned by ``get_sentence_breaks`` must yield the same leading words for
    both sentences and the expected differing remainders.
    """
    # (first sentence, second sentence, shared prefix, first tail, second tail)
    scenarios = (
        (
            'The city councilmen refused the demonstrators a permit because the city councilmen feared violence.',
            'The city councilmen refused the demonstrators a permit because the demonstrators feared violence.',
            'The city councilmen refused the demonstrators a permit because the',
            'city councilmen feared violence.',
            'demonstrators feared violence.',
        ),
        (
            'Os vereadores recusaram a autorização aos manifestantes porque os vereadores temiam a violência.',
            'Os vereadores recusaram a autorização aos manifestantes porque os manifestantes temiam a violência.',
            'Os vereadores recusaram a autorização aos manifestantes porque os',
            'vereadores temiam a violência.',
            'manifestantes temiam a violência.',
        ),
    )

    for sentence_a, sentence_b, shared_head, tail_a, tail_b in scenarios:
        cut = get_sentence_breaks(sentence_a, sentence_b)
        words_a = sentence_a.split()
        words_b = sentence_b.split()

        # Everything before the cut is common to both sentences.
        assert ' '.join(words_a[:cut]) == shared_head
        assert ' '.join(words_b[:cut]) == shared_head
        # Everything from the cut onward is where they diverge.
        assert ' '.join(words_a[cut:]) == tail_a
        assert ' '.join(words_b[cut:]) == tail_b
    
# Run the checks right away; a failed assertion raises and stops the cell.
test_get_sentence_breaks()