In [1]:
import transformers
from transformers import TFBertForNextSentencePrediction, AutoTokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = TFBertForNextSentencePrediction.from_pretrained("bert-base-uncased")

All PyTorch model weights were used when initializing TFBertForNextSentencePrediction.

All the weights of TFBertForNextSentencePrediction were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForNextSentencePrediction for predictions without further training.


In [3]:
# sentences with connection
# prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
# next_sentence = "pizza is eaten with the use of a knife and fork. In casual settings, however, it is cut into wedges to be eaten while held in the hand."

In [4]:
# sentences with no connection
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."

In [5]:
encoding = tokenizer(prompt, next_sentence, return_tensors='tf')
print(encoding.input_ids)
print(encoding.token_type_ids)

tf.Tensor(
[[  101  1999  3304  1010 10733  2366  1999  5337 10906  1010  2107  2004
   2012  1037  4825  1010  2003  3591  4895 14540  6610  2094  1012   102
   1996  3712  2003  2630  2349  2000  1996  7820 19934  1997  2630  2422
   1012   102]], shape=(1, 38), dtype=int32)
tf.Tensor(
[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1
  1 1]], shape=(1, 38), dtype=int32)
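The `token_type_ids` printed above follow the sentence-pair layout `[CLS] A [SEP] B [SEP]`: segment 0 covers `[CLS]`, sentence A, and the first `[SEP]`, while segment 1 covers sentence B and the final `[SEP]`. The sketch below rebuilds that pattern from the two token counts alone. The helper name `build_token_type_ids` is hypothetical, and the counts (22 and 13 word pieces) are read off the output above rather than computed by the tokenizer.

```python
def build_token_type_ids(len_a, len_b):
    """Segment ids for the layout [CLS] A [SEP] B [SEP] (illustrative helper)."""
    # [CLS], the tokens of A, and the first [SEP] belong to segment 0.
    segment_a = [0] * (len_a + 2)
    # The tokens of B and the final [SEP] belong to segment 1.
    segment_b = [1] * (len_b + 1)
    return segment_a + segment_b

# Counts taken from the printed input_ids: 22 word pieces in prompt,
# 13 word pieces in next_sentence.
ids = build_token_type_ids(22, 13)
print(len(ids), ids.count(0), ids.count(1))  # prints: 38 24 14
```

The totals match the tensor shape `(1, 38)` and the run of 24 zeros followed by 14 ones shown in the cell output.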


In [6]:
print(tokenizer.cls_token, ':', tokenizer.cls_token_id)
print(tokenizer.sep_token, ':' , tokenizer.sep_token_id)

[CLS] : 101
[SEP] : 102


In [7]:
# Decode the ids back to text to see where [CLS] and [SEP] were inserted.
print(tokenizer.decode(encoding.input_ids[0]))

[CLS] in italy, pizza served in formal settings, such as at a restaurant, is presented unsliced. [SEP] the sky is blue due to the shorter wavelength of blue light. [SEP]


# Predict The Next Sentence using 'TFBertForNextSentencePrediction'

In [8]:
import tensorflow as tf
from keras.layers import Softmax

In [9]:
logits = model(encoding.input_ids, token_type_ids=encoding.token_type_ids)[0]
softmax = Softmax()
probabilities = softmax(logits)
print(probabilities)

tf.Tensor([[1.2606435e-04 9.9987388e-01]], shape=(1, 2), dtype=float32)
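The `Softmax` layer converts the two NSP logits into probabilities that sum to 1. For readers who want to see the arithmetic, here is a minimal pure-Python sketch of the same computation. The function and the logit values are illustrative assumptions; they are not the logits the model produced above.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a flat list of floats."""
    shift = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical NSP logits: index 0 = "is next", index 1 = "is random".
example_logits = [-3.0, 6.0]
example_probs = softmax(example_logits)
print(example_probs)
```

A large gap between the two logits, as in this made-up pair, pushes almost all of the probability mass onto one class, which is the pattern seen in the model output above.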


In [10]:
# Index 0 means next_sentence is a continuation of prompt; index 1 means it is a random sentence.
print(tf.math.argmax(input=probabilities, axis=-1).numpy())

[1]
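The raw index is easy to misread, so it can help to map it to a label and report the associated probability. This is a small sketch under the label convention stated in the comment above (0 = continuation, 1 = random); `NSP_LABELS` and `interpret_nsp` are hypothetical names, and the input is the probability row printed earlier for the pizza / sky pair.

```python
# Label convention for the NSP head, as stated in the notebook comment.
NSP_LABELS = {0: "continuation", 1: "random"}

def interpret_nsp(probs):
    """Return (label, confidence) for one row of NSP probabilities."""
    # Pure-Python argmax over the row.
    index = max(range(len(probs)), key=lambda i: probs[i])
    return NSP_LABELS[index], probs[index]

# Probability row printed earlier for the unrelated sentence pair.
label, confidence = interpret_nsp([1.2606435e-04, 9.9987388e-01])
print(label, confidence)
```

With a TensorFlow tensor, the row could be passed in as `probabilities[0].numpy().tolist()`.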


# Klue/BERT-base for Korean

In [12]:
tokenizer_k = AutoTokenizer.from_pretrained("klue/bert-base")
model_k = TFBertForNextSentencePrediction.from_pretrained("klue/bert-base")

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertForNextSentencePrediction: ['bert.embeddings.position_ids']
- This IS expected if you are initializing TFBertForNextSentencePrediction from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForNextSentencePrediction from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertForNextSentencePrediction were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForNextSentencePrediction for predictions without further training.


In [13]:
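# sentences with connection
# prompt (English): "The 2002 World Cup was a major worldwide festival co-hosted with Japan."
# next_sentence (English): "When I traveled there, I found that Korea's preparations for the 2002 World Cup were flawless."
prompt = "2002년 월드컵 축구대회는 일본과 공동으로 개최되었던 세계적인 큰 잔치입니다."
next_sentence = "여행을 가보니 한국의 2002년 월드컵 축구대회의 준비는 완벽했습니다."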

In [16]:
encoding_k = tokenizer_k(prompt, next_sentence, return_tensors='tf')
logits_k = model_k(encoding_k.input_ids, token_type_ids=encoding_k.token_type_ids)[0]

softmax = Softmax()
probabilities = softmax(logits_k)
print(probabilities)

tf.Tensor([[9.9988782e-01 1.1218969e-04]], shape=(1, 2), dtype=float32)


In [17]:
print(tf.math.argmax(input=probabilities, axis=-1).numpy())

[0]
