In [5]:
import numpy as np 
import pandas as pd 
import os
import apex

Tutorial taken from: https://pypi.org/project/pytorch-pretrained-bert/#usage
This notebook showcases how BERT can be used to predict a missing (masked) word in a sentence. You can also inspect the tokenized text to see how BERT splits the input into tokens.

In [3]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

# OPTIONAL: turn on INFO-level logging to see what the library is doing
# (e.g. which cached vocabulary / weight files it loads).
import logging
logging.basicConfig(level=logging.INFO)

# Tokenizer (vocabulary) that belongs to the pre-trained uncased base model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Sentence pair written with explicit [CLS]/[SEP] markers, then split into tokens
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)

# Replace one token with [MASK]; `BertForMaskedLM` is later asked to recover it
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
# Sanity check of the full token sequence (note the sub-word split 'puppet', '##eer')
assert tokenized_text == [
    '[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]',
    'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]',
]

# Map every token to its vocabulary index
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
# Sentence A / sentence B membership (see paper):
# the first 7 tokens belong to sentence A (0), the remaining 7 to sentence B (1)
segments_ids = [0] * 7 + [1] * 7

# Wrap each id list in an outer list (batch of one) and convert to PyTorch tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/curtis/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [6]:
# Load pre-trained model (weights)
model = BertModel.from_pretrained('bert-base-uncased')

# Use the GPU when one is available; otherwise stay on the CPU instead of
# failing on an unconditional .to('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens_tensor = tokens_tensor.to(device)
segments_tensors = segments_tensors.to(device)

model = model.to(device)
if device.type == 'cuda':
    # Convert to half precision exactly once (the conversion used to be
    # repeated) and only on the GPU; fp16 operator support on CPU varies
    # between PyTorch versions, so the CPU fallback keeps full precision.
    model = model.half()
# Inference mode: disables dropout so the outputs are deterministic
model.eval()

# Predict hidden states features for each layer (no gradients needed)
with torch.no_grad():
    encoded_layers, _ = model(tokens_tensor, segments_tensors)
# We have one hidden-state tensor for each of the 12 layers in bert-base-uncased
assert len(encoded_layers) == 12

INFO:pytorch_pretrained_bert.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz not found in cache, downloading to /tmp/tmpw07_m1qv

  0%|          | 0/407873900 [00:00<?, ?B/s][A
  0%|          | 34816/407873900 [00:00<32:29, 209205.05B/s][A
  0%|          | 173056/407873900 [00:00<25:15, 268958.41B/s][A
  0%|          | 295936/407873900 [00:00<19:22, 350719.36B/s][A
  0%|          | 382976/407873900 [00:00<15:57, 425541.23B/s][A
  0%|          | 557056/407873900 [00:00<12:34, 539873.83B/s][A
  0%|          | 731136/407873900 [00:00<09:59, 678932.79B/s][A
  0%|          | 887808/407873900 [00:00<08:18, 815666.47B/s][A
  0%|          | 1044480/407873900 [00:00<07:08, 948781.06B/s][A
  0%|          | 1218560/407873900 [00:01<06:10, 1097232.67B/s][A
  0%|          | 1392640/407873900 [00:01<05:34, 1213859.74B/s][A
  0%|          | 1549312/407873900 [00:01<05:16, 1282499.78B/s][A
  0%|          | 1698816/407873900 [00:01<05:09, 1310741.14B/

  7%|▋         | 28603392/407873900 [00:30<06:45, 935603.62B/s][A
  7%|▋         | 28723200/407873900 [00:30<06:55, 913136.10B/s][A
  7%|▋         | 28879872/407873900 [00:30<06:19, 998723.55B/s][A
  7%|▋         | 28984320/407873900 [00:30<06:15, 1008701.58B/s][A
  7%|▋         | 29158400/407873900 [00:31<05:36, 1123795.02B/s][A
  7%|▋         | 29277184/407873900 [00:31<06:12, 1015747.27B/s][A
  7%|▋         | 29385728/407873900 [00:31<07:25, 849841.10B/s] [A
  7%|▋         | 29506560/407873900 [00:31<07:25, 848685.81B/s][A
  7%|▋         | 29597696/407873900 [00:31<08:15, 763518.93B/s][A
  7%|▋         | 29767680/407873900 [00:31<07:27, 844403.32B/s][A
  7%|▋         | 29872128/407873900 [00:31<08:32, 737513.83B/s][A
  7%|▋         | 29993984/407873900 [00:32<07:37, 826184.74B/s][A
  7%|▋         | 30133248/407873900 [00:32<06:53, 913492.57B/s][A
  7%|▋         | 30237696/407873900 [00:32<07:23, 852109.71B/s][A
  7%|▋         | 30376960/407873900 [00:32<06:44, 932865.3

 12%|█▏        | 50058240/407873900 [01:02<11:26, 520872.76B/s][A
 12%|█▏        | 50169856/407873900 [01:02<10:35, 563142.07B/s][A
 12%|█▏        | 50309120/407873900 [01:02<08:56, 665978.36B/s][A
 12%|█▏        | 50385920/407873900 [01:02<08:37, 690205.10B/s][A
 12%|█▏        | 50500608/407873900 [01:02<07:56, 749789.49B/s][A
 12%|█▏        | 50605056/407873900 [01:02<07:37, 781713.72B/s][A
 12%|█▏        | 50688000/407873900 [01:03<09:22, 634603.31B/s][A
 12%|█▏        | 50795520/407873900 [01:03<08:17, 718301.16B/s][A
 12%|█▏        | 50877440/407873900 [01:03<09:51, 603350.34B/s][A
 12%|█▏        | 50948096/407873900 [01:03<09:26, 630042.45B/s][A
 13%|█▎        | 51022848/407873900 [01:03<09:13, 645183.13B/s][A
 13%|█▎        | 51127296/407873900 [01:03<08:29, 699875.26B/s][A
 13%|█▎        | 51214336/407873900 [01:03<08:00, 742355.61B/s][A
 13%|█▎        | 51301376/407873900 [01:03<08:55, 665403.63B/s][A
 13%|█▎        | 51373056/407873900 [01:04<10:58, 541344.93B/s

 18%|█▊        | 73287680/407873900 [01:36<09:06, 611682.23B/s][A
 18%|█▊        | 73374720/407873900 [01:36<09:05, 613429.64B/s][A
 18%|█▊        | 73461760/407873900 [01:36<08:51, 628837.26B/s][A
 18%|█▊        | 73548800/407873900 [01:36<08:54, 625698.43B/s][A
 18%|█▊        | 73653248/407873900 [01:36<08:21, 667090.11B/s][A
 18%|█▊        | 73740288/407873900 [01:36<07:51, 708967.29B/s][A
 18%|█▊        | 73812992/407873900 [01:36<09:28, 587401.34B/s][A
 18%|█▊        | 73896960/407873900 [01:37<09:43, 572532.70B/s][A
 18%|█▊        | 73966592/407873900 [01:37<10:20, 538086.04B/s][A
 18%|█▊        | 74023936/407873900 [01:37<11:16, 493817.71B/s][A
 18%|█▊        | 74088448/407873900 [01:37<11:50, 469569.12B/s][A
 18%|█▊        | 74175488/407873900 [01:37<10:33, 526653.29B/s][A
 18%|█▊        | 74245120/407873900 [01:37<09:57, 558324.94B/s][A
 18%|█▊        | 74314752/407873900 [01:37<10:10, 546395.37B/s][A
 18%|█▊        | 74401792/407873900 [01:37<10:41, 519631.58B/s

 23%|██▎       | 94510080/407873900 [02:12<09:28, 551179.82B/s][A
 23%|██▎       | 94571520/407873900 [02:12<09:12, 566957.49B/s][A
 23%|██▎       | 94632960/407873900 [02:12<10:02, 519772.90B/s][A
 23%|██▎       | 94689280/407873900 [02:12<09:53, 527335.24B/s][A
 23%|██▎       | 94745600/407873900 [02:12<12:40, 411991.66B/s][A
 23%|██▎       | 94803968/407873900 [02:13<13:21, 390829.83B/s][A
 23%|██▎       | 94873600/407873900 [02:13<12:12, 427037.85B/s][A
 23%|██▎       | 94920704/407873900 [02:13<14:01, 371895.06B/s][A
 23%|██▎       | 94962688/407873900 [02:13<17:09, 303951.00B/s][A
 23%|██▎       | 94998528/407873900 [02:13<19:15, 270797.21B/s][A
 23%|██▎       | 95030272/407873900 [02:13<22:17, 233928.38B/s][A
 23%|██▎       | 95057920/407873900 [02:13<21:47, 239326.18B/s][A
 23%|██▎       | 95084544/407873900 [02:14<24:58, 208775.66B/s][A
 23%|██▎       | 95117312/407873900 [02:14<27:55, 186718.63B/s][A
 23%|██▎       | 95138816/407873900 [02:14<32:08, 162151.15B/s

 28%|██▊       | 112483328/407873900 [02:45<18:22, 267889.81B/s][A
 28%|██▊       | 112542720/407873900 [02:45<15:25, 318980.36B/s][A
 28%|██▊       | 112629760/407873900 [02:46<12:52, 382052.21B/s][A
 28%|██▊       | 112699392/407873900 [02:46<11:22, 432235.72B/s][A
 28%|██▊       | 112785408/407873900 [02:46<09:47, 502591.93B/s][A
 28%|██▊       | 112847872/407873900 [02:46<10:52, 451917.21B/s][A
 28%|██▊       | 112903168/407873900 [02:46<11:21, 432869.75B/s][A
 28%|██▊       | 112977920/407873900 [02:46<10:12, 481139.20B/s][A
 28%|██▊       | 113064960/407873900 [02:46<09:27, 519448.15B/s][A
 28%|██▊       | 113152000/407873900 [02:46<08:35, 571887.79B/s][A
 28%|██▊       | 113256448/407873900 [02:47<07:51, 624351.39B/s][A
 28%|██▊       | 113360896/407873900 [02:47<07:07, 689509.37B/s][A
 28%|██▊       | 113447936/407873900 [02:47<06:53, 712213.53B/s][A
 28%|██▊       | 113534976/407873900 [02:47<07:04, 692641.49B/s][A
 28%|██▊       | 113607680/407873900 [02:47<07:2

 33%|███▎      | 134737920/407873900 [03:18<20:15, 224792.89B/s][A
 33%|███▎      | 134772736/407873900 [03:18<18:30, 246025.90B/s][A
 33%|███▎      | 134807552/407873900 [03:18<19:05, 238335.44B/s][A
 33%|███▎      | 134859776/407873900 [03:18<17:33, 259253.11B/s][A
 33%|███▎      | 134912000/407873900 [03:18<16:01, 283825.39B/s][A
 33%|███▎      | 134964224/407873900 [03:18<15:14, 298454.99B/s][A
 33%|███▎      | 134999040/407873900 [03:18<14:46, 307773.78B/s][A
 33%|███▎      | 135051264/407873900 [03:19<15:40, 290221.04B/s][A
 33%|███▎      | 135120896/407873900 [03:19<13:37, 333816.97B/s][A
 33%|███▎      | 135157760/407873900 [03:19<14:42, 308990.51B/s][A
 33%|███▎      | 135191552/407873900 [03:19<14:38, 310314.15B/s][A
 33%|███▎      | 135225344/407873900 [03:19<14:47, 307287.52B/s][A
 33%|███▎      | 135277568/407873900 [03:19<13:31, 335803.31B/s][A
 33%|███▎      | 135329792/407873900 [03:19<13:16, 342165.09B/s][A
 33%|███▎      | 135365632/407873900 [03:19<15:3

 38%|███▊      | 155592704/407873900 [03:48<04:04, 1032683.54B/s][A
 38%|███▊      | 155731968/407873900 [03:48<04:18, 975467.82B/s] [A
 38%|███▊      | 155906048/407873900 [03:49<03:59, 1053942.48B/s][A
 38%|███▊      | 156016640/407873900 [03:49<04:10, 1003466.90B/s][A
 38%|███▊      | 156121088/407873900 [03:49<04:42, 890882.01B/s] [A
 38%|███▊      | 156215296/407873900 [03:49<05:05, 824382.97B/s][A
 38%|███▊      | 156306432/407873900 [03:49<05:23, 777715.60B/s][A
 38%|███▊      | 156393472/407873900 [03:49<05:19, 786405.86B/s][A
 38%|███▊      | 156475392/407873900 [03:49<05:19, 786193.27B/s][A
 38%|███▊      | 156567552/407873900 [03:50<05:36, 747494.12B/s][A
 38%|███▊      | 156672000/407873900 [03:50<05:40, 737153.90B/s][A
 38%|███▊      | 156776448/407873900 [03:50<05:31, 756340.03B/s][A
 38%|███▊      | 156880896/407873900 [03:50<05:10, 809280.56B/s][A
 38%|███▊      | 156967936/407873900 [03:50<05:03, 826224.46B/s][A
 39%|███▊      | 157054976/407873900 [03:50

 44%|████▍     | 178501632/407873900 [04:19<05:02, 758260.95B/s][A
 44%|████▍     | 178581504/407873900 [04:19<05:29, 694827.37B/s][A
 44%|████▍     | 178693120/407873900 [04:19<05:04, 752244.48B/s][A
 44%|████▍     | 178772992/407873900 [04:19<05:20, 715156.44B/s][A
 44%|████▍     | 178884608/407873900 [04:19<05:16, 722397.15B/s][A
 44%|████▍     | 179006464/407873900 [04:20<05:41, 669284.26B/s][A
 44%|████▍     | 179076096/407873900 [04:20<07:28, 510321.08B/s][A
 44%|████▍     | 179197952/407873900 [04:20<06:18, 603609.33B/s][A
 44%|████▍     | 179302400/407873900 [04:20<05:32, 687747.52B/s][A
 44%|████▍     | 179406848/407873900 [04:20<05:00, 761141.36B/s][A
 44%|████▍     | 179494912/407873900 [04:20<05:08, 740940.35B/s][A
 44%|████▍     | 179615744/407873900 [04:20<04:55, 771379.19B/s][A
 44%|████▍     | 179702784/407873900 [04:21<05:15, 723886.67B/s][A
 44%|████▍     | 179842048/407873900 [04:21<04:35, 828014.96B/s][A
 44%|████▍     | 179934208/407873900 [04:21<04:4

 50%|█████     | 204456960/407873900 [04:51<12:50, 263990.28B/s][A
 50%|█████     | 204526592/407873900 [04:51<12:12, 277648.69B/s][A
 50%|█████     | 204561408/407873900 [04:51<13:10, 257299.05B/s][A
 50%|█████     | 204595200/407873900 [04:51<13:31, 250635.44B/s][A
 50%|█████     | 204621824/407873900 [04:51<17:30, 193561.97B/s][A
 50%|█████     | 204644352/407873900 [04:51<17:11, 196939.83B/s][A
 50%|█████     | 204666880/407873900 [04:52<22:05, 153341.73B/s][A
 50%|█████     | 204685312/407873900 [04:52<27:38, 122530.07B/s][A
 50%|█████     | 204718080/407873900 [04:52<26:44, 126606.85B/s][A
 50%|█████     | 204735488/407873900 [04:52<29:41, 114039.42B/s][A
 50%|█████     | 204752896/407873900 [04:52<30:15, 111912.08B/s][A
 50%|█████     | 204770304/407873900 [04:53<30:36, 110603.20B/s][A
 50%|█████     | 204787712/407873900 [04:53<29:23, 115188.00B/s][A
 50%|█████     | 204805120/407873900 [04:53<27:14, 124208.60B/s][A
 50%|█████     | 204839936/407873900 [04:53<23:1

 55%|█████▌    | 225174528/407873900 [05:23<03:51, 788322.09B/s][A
 55%|█████▌    | 225259520/407873900 [05:24<04:41, 649531.70B/s][A
 55%|█████▌    | 225416192/407873900 [05:24<03:53, 779793.66B/s][A
 55%|█████▌    | 225511424/407873900 [05:24<03:51, 788688.68B/s][A
 55%|█████▌    | 225607680/407873900 [05:24<03:52, 785544.53B/s][A
 55%|█████▌    | 225694720/407873900 [05:24<04:02, 752361.61B/s][A
 55%|█████▌    | 225776640/407873900 [05:24<04:47, 633236.71B/s][A
 55%|█████▌    | 225848320/407873900 [05:24<04:43, 641306.89B/s][A
 55%|█████▌    | 225917952/407873900 [05:25<05:20, 568218.48B/s][A
 55%|█████▌    | 225980416/407873900 [05:25<06:44, 449302.76B/s][A
 55%|█████▌    | 226042880/407873900 [05:25<06:48, 445503.74B/s][A
 55%|█████▌    | 226129920/407873900 [05:25<06:23, 473393.41B/s][A
 55%|█████▌    | 226199552/407873900 [05:25<06:20, 477722.60B/s][A
 55%|█████▌    | 226268160/407873900 [05:25<05:47, 523253.56B/s][A
 55%|█████▌    | 226338816/407873900 [05:25<05:3

 61%|██████▏   | 250530816/407873900 [05:54<02:16, 1154311.67B/s][A
 61%|██████▏   | 250651648/407873900 [05:54<02:33, 1025590.98B/s][A
 61%|██████▏   | 250779648/407873900 [05:55<02:26, 1073914.99B/s][A
 62%|██████▏   | 250892288/407873900 [05:55<02:29, 1049750.64B/s][A
 62%|██████▏   | 251039744/407873900 [05:55<02:16, 1145279.43B/s][A
 62%|██████▏   | 251162624/407873900 [05:55<02:17, 1143261.85B/s][A
 62%|██████▏   | 251319296/407873900 [05:55<02:06, 1236350.05B/s][A
 62%|██████▏   | 251448320/407873900 [05:55<02:16, 1149124.98B/s][A
 62%|██████▏   | 251615232/407873900 [05:55<02:05, 1247408.13B/s][A
 62%|██████▏   | 251754496/407873900 [05:55<02:02, 1276903.88B/s][A
 62%|██████▏   | 251893760/407873900 [05:55<02:00, 1290481.98B/s][A
 62%|██████▏   | 252025856/407873900 [05:56<02:03, 1261031.60B/s][A
 62%|██████▏   | 252154880/407873900 [05:56<02:37, 991114.05B/s] [A
 62%|██████▏   | 252265472/407873900 [05:56<02:47, 927321.87B/s][A
 62%|██████▏   | 252366848/40787390

 70%|███████   | 287284224/407873900 [06:23<04:30, 445517.85B/s][A
 70%|███████   | 287353856/407873900 [06:24<04:30, 445719.12B/s][A
 70%|███████   | 287406080/407873900 [06:24<04:24, 455632.69B/s][A
 70%|███████   | 287453184/407873900 [06:24<05:49, 344934.85B/s][A
 71%|███████   | 287562752/407873900 [06:24<04:53, 410570.20B/s][A
 71%|███████   | 287613952/407873900 [06:24<04:43, 424088.28B/s][A
 71%|███████   | 287663104/407873900 [06:24<04:45, 420341.05B/s][A
 71%|███████   | 287719424/407873900 [06:24<04:36, 435119.98B/s][A
 71%|███████   | 287771648/407873900 [06:24<04:30, 444224.58B/s][A
 71%|███████   | 287823872/407873900 [06:25<04:24, 453915.61B/s][A
 71%|███████   | 287876096/407873900 [06:25<04:23, 454593.03B/s][A
 71%|███████   | 287945728/407873900 [06:25<04:12, 474536.75B/s][A
 71%|███████   | 287997952/407873900 [06:25<04:08, 483303.88B/s][A
 71%|███████   | 288067584/407873900 [06:25<03:58, 501691.54B/s][A
 71%|███████   | 288137216/407873900 [06:25<03:4

 75%|███████▌  | 306433024/407873900 [06:56<01:59, 851515.12B/s][A
 75%|███████▌  | 306572288/407873900 [06:56<01:46, 952510.49B/s][A
 75%|███████▌  | 306727936/407873900 [06:56<01:33, 1077966.30B/s][A
 75%|███████▌  | 306846720/407873900 [06:56<01:40, 1002918.02B/s][A
 75%|███████▌  | 306956288/407873900 [06:56<02:03, 814602.10B/s] [A
 75%|███████▌  | 307077120/407873900 [06:56<01:58, 851348.29B/s][A
 75%|███████▌  | 307216384/407873900 [06:56<01:59, 842012.77B/s][A
 75%|███████▌  | 307373056/407873900 [06:56<01:50, 910699.13B/s][A
 75%|███████▌  | 307512320/407873900 [06:57<01:40, 1000042.34B/s][A
 75%|███████▌  | 307634176/407873900 [06:57<01:39, 1008136.84B/s][A
 75%|███████▌  | 307790848/407873900 [06:57<01:28, 1127401.27B/s][A
 76%|███████▌  | 307959808/407873900 [06:57<01:19, 1252400.36B/s][A
 76%|███████▌  | 308094976/407873900 [06:57<01:18, 1276289.18B/s][A
 76%|███████▌  | 308260864/407873900 [06:57<01:18, 1268620.57B/s][A
 76%|███████▌  | 308451328/407873900 [0

 82%|████████▏ | 334119936/407873900 [07:25<00:59, 1231848.69B/s][A
 82%|████████▏ | 334268416/407873900 [07:25<00:56, 1296902.25B/s][A
 82%|████████▏ | 334402560/407873900 [07:25<01:01, 1185298.23B/s][A
 82%|████████▏ | 334564352/407873900 [07:25<00:58, 1247541.64B/s][A
 82%|████████▏ | 334755840/407873900 [07:25<00:53, 1354090.12B/s][A
 82%|████████▏ | 334947328/407873900 [07:25<00:49, 1477477.50B/s][A
 82%|████████▏ | 335141888/407873900 [07:25<00:45, 1592354.45B/s][A
 82%|████████▏ | 335312896/407873900 [07:26<00:44, 1615670.16B/s][A
 82%|████████▏ | 335486976/407873900 [07:26<00:43, 1649353.23B/s][A
 82%|████████▏ | 335678464/407873900 [07:26<00:42, 1689464.83B/s][A
 82%|████████▏ | 335904768/407873900 [07:26<00:41, 1749194.97B/s][A
 82%|████████▏ | 336131072/407873900 [07:26<00:38, 1846831.78B/s][A
 82%|████████▏ | 336319488/407873900 [07:26<00:39, 1800521.01B/s][A
 83%|████████▎ | 336531456/407873900 [07:26<00:39, 1817164.71B/s][A
 83%|████████▎ | 336757760/4078739

 89%|████████▉ | 362312704/407873900 [07:55<01:35, 479163.42B/s][A
 89%|████████▉ | 362416128/407873900 [07:55<01:20, 565025.34B/s][A
 89%|████████▉ | 362496000/407873900 [07:55<01:15, 599186.43B/s][A
 89%|████████▉ | 362607616/407873900 [07:56<01:05, 688808.91B/s][A
 89%|████████▉ | 362713088/407873900 [07:56<01:02, 722949.51B/s][A
 89%|████████▉ | 362817536/407873900 [07:56<01:01, 732432.17B/s][A
 89%|████████▉ | 362956800/407873900 [07:56<00:54, 823197.26B/s][A
 89%|████████▉ | 363078656/407873900 [07:56<00:54, 827662.83B/s][A
 89%|████████▉ | 363168768/407873900 [07:56<00:58, 764379.23B/s][A
 89%|████████▉ | 363270144/407873900 [07:56<01:03, 705851.86B/s][A
 89%|████████▉ | 363357184/407873900 [07:57<01:06, 665332.79B/s][A
 89%|████████▉ | 363427840/407873900 [07:57<01:05, 676889.99B/s][A
 89%|████████▉ | 363513856/407873900 [07:57<01:12, 612293.42B/s][A
 89%|████████▉ | 363618304/407873900 [07:57<01:10, 623921.86B/s][A
 89%|████████▉ | 363687936/407873900 [07:57<01:0

 93%|█████████▎| 381061120/407873900 [08:28<00:45, 592410.79B/s][A
 93%|█████████▎| 381165568/407873900 [08:28<00:44, 596793.69B/s][A
 93%|█████████▎| 381270016/407873900 [08:28<00:39, 675568.18B/s][A
 94%|█████████▎| 381391872/407873900 [08:28<00:35, 743017.21B/s][A
 94%|█████████▎| 381513728/407873900 [08:29<00:32, 822395.55B/s][A
 94%|█████████▎| 381635584/407873900 [08:29<00:31, 840992.11B/s][A
 94%|█████████▎| 381773824/407873900 [08:29<00:27, 938651.74B/s][A
 94%|█████████▎| 381896704/407873900 [08:29<00:29, 888119.98B/s][A
 94%|█████████▎| 382035968/407873900 [08:29<00:27, 947039.67B/s][A
 94%|█████████▎| 382157824/407873900 [08:29<00:26, 983495.14B/s][A
 94%|█████████▎| 382297088/407873900 [08:29<00:25, 999586.71B/s][A
 94%|█████████▍| 382418944/407873900 [08:29<00:24, 1042121.11B/s][A
 94%|█████████▍| 382558208/407873900 [08:30<00:23, 1076500.60B/s][A
 94%|█████████▍| 382697472/407873900 [08:30<00:22, 1137562.83B/s][A
 94%|█████████▍| 382819328/407873900 [08:30<0

 99%|█████████▉| 404753408/407873900 [08:58<00:03, 979206.38B/s][A
 99%|█████████▉| 404852736/407873900 [08:58<00:03, 914851.55B/s][A
 99%|█████████▉| 404945920/407873900 [08:58<00:03, 918069.20B/s][A
 99%|█████████▉| 405084160/407873900 [08:58<00:02, 976076.58B/s][A
 99%|█████████▉| 405206016/407873900 [08:58<00:02, 1007681.81B/s][A
 99%|█████████▉| 405327872/407873900 [08:59<00:02, 956873.58B/s] [A
 99%|█████████▉| 405467136/407873900 [08:59<00:02, 1035090.51B/s][A
 99%|█████████▉| 405573632/407873900 [08:59<00:02, 1032449.79B/s][A
 99%|█████████▉| 405679104/407873900 [08:59<00:02, 871404.91B/s] [A
 99%|█████████▉| 405832704/407873900 [08:59<00:02, 942601.59B/s][A
100%|█████████▉| 405937152/407873900 [08:59<00:02, 952782.70B/s][A
100%|█████████▉| 406059008/407873900 [08:59<00:01, 981613.45B/s][A
100%|█████████▉| 406180864/407873900 [08:59<00:01, 1021414.38B/s][A
100%|█████████▉| 406286336/407873900 [09:00<00:01, 1028093.38B/s][A
100%|█████████▉| 406403072/407873900 [09:

In [18]:
# Shape of the first layer's hidden states; the recorded run shows
# torch.Size([1, 14, 768]) for the 14-token input.
encoded_layers[0].shape

torch.Size([1, 14, 768])

In [7]:
# Load pre-trained model (weights) with the masked-language-model head
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

# Use the GPU when one is available; otherwise stay on the CPU instead of
# failing on an unconditional .to('cuda').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokens_tensor = tokens_tensor.to(device)
segments_tensors = segments_tensors.to(device)

model = model.to(device)
if device.type == 'cuda':
    # Convert to half precision exactly once (the conversion used to be
    # repeated) and only on the GPU; fp16 operator support on CPU varies
    # between PyTorch versions, so the CPU fallback keeps full precision.
    model = model.half()
# Inference mode: disables dropout so the prediction is deterministic
model.eval()

# Predict all tokens (no gradients needed)
with torch.no_grad():
    predictions = model(tokens_tensor, segments_tensors)

# Confirm we were able to predict 'henson': take the highest-scoring
# vocabulary entry at the masked position of the single sequence in the batch.
predicted_index = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'henson', predicted_token

INFO:pytorch_pretrained_bert.modeling:loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /home/curtis/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
INFO:pytorch_pretrained_bert.modeling:extracting archive file /home/curtis/.pytorch_pretrained_bert/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpi0_4gzpw
INFO:pytorch_pretrained_bert.modeling:Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

INFO:pytorch_pretrained_bert.modeling:Weights from pretrai