# Bert Experiments

In [111]:
import torch
from pytorch_transformers import BertTokenizer, BertModel, BertForMaskedLM, BertConfig
import logging
logging.basicConfig(level=logging.INFO)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def batch_to_idx_tensor(text):
    assert isinstance(text, list), 'Must input a list of strings!'
    str_tokens = [tokenizer.tokenize('[CLS] ' + t + ' [SEP]') for t in text]
    indexed_tokens = [tokenizer.convert_tokens_to_ids(t) for t in str_tokens]
    seq_lens = torch.LongTensor(list(map(len, indexed_tokens)))
    seq_tensor = torch.zeros((len(indexed_tokens), seq_lens.max()), dtype=torch.long)
    attn_mask_tensor = torch.zeros((len(indexed_tokens), seq_lens.max()), dtype=torch.long)
    pad_idx = tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0]
    seq_tensor.fill_(pad_idx)

    for idx, (seq, seqlen) in enumerate(zip(indexed_tokens, seq_lens)):
        seq_tensor[idx, :seqlen] = torch.LongTensor(seq)
        attn_mask_tensor[idx, :seqlen] = 1

    scrm_seq_lens, scrm_idxs = seq_lens.sort(0, descending=True)
    scrm_seq_tensor = seq_tensor[scrm_idxs]
    scrm_str_tokens = [str_tokens[i.item()] for i in scrm_idxs]
    scrm_attn_mask = attn_mask_tensor[scrm_idxs]

    return scrm_seq_tensor, scrm_str_tokens, scrm_attn_mask, scrm_idxs

def str_to_idx_tensor(text, masked_words=None):
    assert isinstance(text, str), 'Must input a string!'
    # tokens_tensor = torch.tensor(indexed_tokens)
    text = '[CLS] ' + text + ' [SEP]'
    tokenized_text = tokenizer.tokenize(text)
    print(tokenized_text)

    if masked_words is not None:
        for i, tok in enumerate(tokenized_text):
            if tok in masked_words:
                tokenized_text[i] = '[MASK]'

    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])

    return tokens_tensor

INFO:pytorch_transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/hansonlu/.cache/torch/pytorch_transformers/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


In [7]:
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

INFO:pytorch_transformers.modeling_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json from cache at /home/hansonlu/.cache/torch/pytorch_transformers/4dad0251492946e18ac39290fcfe91b89d370fee250efe9521476438fe8ca185.bf3b9ea126d8c0001ee8a1e8b92229871d06d36d8808208cc2449280da87785c
INFO:pytorch_transformers.modeling_utils:Model config {
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": true,
  "output_hidden_states": false,
  "torchscript": false,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

INFO:pytorch_transformers.file_utils:https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin not f

  6%|▌         | 26711040/440473133 [00:27<09:34, 719909.75B/s][A[A

  6%|▌         | 26924032/440473133 [00:27<07:45, 887534.95B/s][A[A

  6%|▌         | 27104256/440473133 [00:27<06:37, 1039831.39B/s][A[A

  6%|▌         | 27379712/440473133 [00:28<05:23, 1278581.43B/s][A[A

  6%|▋         | 27595776/440473133 [00:28<04:43, 1454059.72B/s][A[A

  6%|▋         | 27825152/440473133 [00:28<04:14, 1620554.98B/s][A[A

  6%|▋         | 28169216/440473133 [00:28<03:34, 1920712.47B/s][A[A

  6%|▋         | 28419072/440473133 [00:29<15:24, 445542.37B/s] [A[A

  6%|▋         | 28599296/440473133 [00:30<13:18, 515642.52B/s][A[A

  7%|▋         | 28873728/440473133 [00:30<10:17, 666041.81B/s][A[A

  7%|▋         | 29152256/440473133 [00:30<07:58, 859178.15B/s][A[A

  7%|▋         | 29430784/440473133 [00:30<06:22, 1075479.44B/s][A[A

  7%|▋         | 29742080/440473133 [00:30<05:10, 1323355.99B/s][A[A

  7%|▋         | 30069760/440473133 [00:30<04:17, 1593191.79B/s][A[

 12%|█▏        | 51533824/440473133 [00:54<13:41, 473667.90B/s][A[A

 12%|█▏        | 51745792/440473133 [00:54<10:50, 597785.01B/s][A[A

 12%|█▏        | 51975168/440473133 [00:54<08:39, 747710.25B/s][A[A

 12%|█▏        | 52155392/440473133 [00:54<07:09, 903215.60B/s][A[A

 12%|█▏        | 52352000/440473133 [00:54<06:05, 1062069.08B/s][A[A

 12%|█▏        | 52532224/440473133 [00:54<05:20, 1210706.33B/s][A[A

 12%|█▏        | 52728832/440473133 [00:54<04:43, 1367017.86B/s][A[A

 12%|█▏        | 52903936/440473133 [00:54<04:39, 1386182.65B/s][A[A

 12%|█▏        | 53138432/440473133 [00:54<04:25, 1460372.35B/s][A[A

 12%|█▏        | 53304320/440473133 [00:55<12:28, 517440.41B/s] [A[A

 12%|█▏        | 53427200/440473133 [00:56<13:09, 490176.03B/s][A[A

 12%|█▏        | 53597184/440473133 [00:56<10:52, 592757.47B/s][A[A

 12%|█▏        | 53777408/440473133 [00:56<09:00, 715239.35B/s][A[A

 12%|█▏        | 53974016/440473133 [00:56<07:37, 844869.66B/s][A[A



 17%|█▋        | 73501696/440473133 [01:16<10:06, 605343.90B/s][A[A

 17%|█▋        | 73651200/440473133 [01:16<08:38, 707050.42B/s][A[A

 17%|█▋        | 73928704/440473133 [01:17<06:42, 910618.09B/s][A[A

 17%|█▋        | 74109952/440473133 [01:17<05:44, 1063585.34B/s][A[A

 17%|█▋        | 74306560/440473133 [01:17<05:00, 1220073.80B/s][A[A

 17%|█▋        | 74617856/440473133 [01:17<04:05, 1489158.27B/s][A[A

 17%|█▋        | 74880000/440473133 [01:17<03:36, 1687865.48B/s][A[A

 17%|█▋        | 75191296/440473133 [01:17<03:09, 1928253.29B/s][A[A

 17%|█▋        | 75435008/440473133 [01:18<11:08, 545808.92B/s] [A[A

 17%|█▋        | 75748352/440473133 [01:18<08:39, 702447.30B/s][A[A

 17%|█▋        | 76108800/440473133 [01:19<06:47, 894717.63B/s][A[A

 17%|█▋        | 76314624/440473133 [01:19<06:14, 972857.87B/s][A[A

 17%|█▋        | 76567552/440473133 [01:19<05:07, 1185052.93B/s][A[A

 17%|█▋        | 76780544/440473133 [01:19<04:29, 1351586.37B/s][A[A

 23%|██▎       | 102552576/440473133 [01:46<03:45, 1501481.87B/s][A[A

 23%|██▎       | 102913024/440473133 [01:46<03:11, 1764288.76B/s][A[A

 23%|██▎       | 103195648/440473133 [01:48<11:07, 505091.44B/s] [A[A

 24%|██▎       | 103535616/440473133 [01:48<08:30, 659723.68B/s][A[A

 24%|██▎       | 103879680/440473133 [01:48<06:26, 870411.95B/s][A[A

 24%|██▎       | 104174592/440473133 [01:48<05:06, 1095731.54B/s][A[A

 24%|██▎       | 104496128/440473133 [01:48<04:05, 1365834.92B/s][A[A

 24%|██▍       | 104863744/440473133 [01:48<03:19, 1682850.06B/s][A[A

 24%|██▍       | 105168896/440473133 [01:49<08:30, 656858.35B/s] [A[A

 24%|██▍       | 105452544/440473133 [01:50<06:41, 834796.50B/s][A[A

 24%|██▍       | 105747456/440473133 [01:50<05:16, 1056599.56B/s][A[A

 24%|██▍       | 106107904/440473133 [01:50<04:09, 1339380.24B/s][A[A

 24%|██▍       | 106468352/440473133 [01:50<03:29, 1596026.40B/s][A[A

 24%|██▍       | 106779648/440473133 [01:50<02:58, 186

 30%|██▉       | 131297280/440473133 [02:16<03:33, 1449146.68B/s][A[A

 30%|██▉       | 131503104/440473133 [02:16<03:15, 1576804.30B/s][A[A

 30%|██▉       | 131691520/440473133 [02:17<06:35, 780560.40B/s] [A[A

 30%|██▉       | 131834880/440473133 [02:18<10:42, 480438.14B/s][A[A

 30%|██▉       | 132060160/440473133 [02:18<08:32, 601561.31B/s][A[A

 30%|███       | 132289536/440473133 [02:18<06:55, 742211.99B/s][A[A

 30%|███       | 132453376/440473133 [02:18<05:47, 886436.21B/s][A[A

 30%|███       | 132633600/440473133 [02:18<04:59, 1028150.43B/s][A[A

 30%|███       | 132797440/440473133 [02:18<04:28, 1144558.11B/s][A[A

 30%|███       | 132998144/440473133 [02:18<03:54, 1313887.52B/s][A[A

 30%|███       | 133164032/440473133 [02:18<03:43, 1377196.72B/s][A[A

 30%|███       | 133370880/440473133 [02:19<03:22, 1517126.01B/s][A[A

 30%|███       | 133544960/440473133 [02:19<03:18, 1543403.60B/s][A[A

 30%|███       | 133714944/440473133 [02:19<04:23, 1166

 33%|███▎      | 146543616/440473133 [02:32<04:58, 983687.34B/s][A[A

 33%|███▎      | 146674688/440473133 [02:32<04:38, 1053888.06B/s][A[A

 33%|███▎      | 146782208/440473133 [02:33<04:54, 997706.06B/s] [A[A

 33%|███▎      | 146920448/440473133 [02:33<04:51, 1006631.21B/s][A[A

 33%|███▎      | 147022848/440473133 [02:33<05:04, 962907.37B/s] [A[A

 33%|███▎      | 147133440/440473133 [02:33<05:00, 976812.95B/s][A[A

 33%|███▎      | 147232768/440473133 [02:33<05:23, 907853.20B/s][A[A

 33%|███▎      | 147330048/440473133 [02:33<05:30, 888052.04B/s][A[A

 33%|███▎      | 147428352/440473133 [02:33<05:47, 843028.77B/s][A[A

 33%|███▎      | 147543040/440473133 [02:33<05:27, 895475.74B/s][A[A

 34%|███▎      | 147641344/440473133 [02:33<05:22, 906967.24B/s][A[A

 34%|███▎      | 147739648/440473133 [02:34<05:16, 925929.54B/s][A[A

 34%|███▎      | 147837952/440473133 [02:34<05:12, 936048.73B/s][A[A

 34%|███▎      | 147936256/440473133 [02:34<05:10, 941427.88

 36%|███▋      | 160528384/440473133 [02:47<03:21, 1387936.80B/s][A[A

 36%|███▋      | 160683008/440473133 [02:47<03:15, 1429818.11B/s][A[A

 37%|███▋      | 160827392/440473133 [02:47<03:15, 1431286.75B/s][A[A

 37%|███▋      | 160994304/440473133 [02:47<03:07, 1491008.73B/s][A[A

 37%|███▋      | 161144832/440473133 [02:47<03:19, 1401625.38B/s][A[A

 37%|███▋      | 161321984/440473133 [02:48<09:52, 471512.10B/s] [A[A

 37%|███▋      | 161427456/440473133 [02:48<09:42, 478676.96B/s][A[A

 37%|███▋      | 161551360/440473133 [02:49<08:14, 563756.00B/s][A[A

 37%|███▋      | 161730560/440473133 [02:49<06:32, 709664.18B/s][A[A

 37%|███▋      | 161848320/440473133 [02:49<06:09, 753303.26B/s][A[A

 37%|███▋      | 162010112/440473133 [02:49<05:22, 862937.97B/s][A[A

 37%|███▋      | 162223104/440473133 [02:49<04:38, 1000017.69B/s][A[A

 37%|███▋      | 162403328/440473133 [02:49<04:01, 1151104.89B/s][A[A

 37%|███▋      | 162567168/440473133 [02:49<03:40, 12613

 41%|████      | 180332544/440473133 [03:08<04:30, 962414.87B/s][A[A

 41%|████      | 180442112/440473133 [03:08<04:41, 923610.99B/s][A[A

 41%|████      | 180574208/440473133 [03:08<04:16, 1015019.74B/s][A[A

 41%|████      | 180679680/440473133 [03:08<04:33, 948393.55B/s] [A[A

 41%|████      | 180802560/440473133 [03:08<04:39, 927668.11B/s][A[A

 41%|████      | 180950016/440473133 [03:08<04:53, 884208.61B/s][A[A

 41%|████      | 181097472/440473133 [03:09<04:41, 921982.91B/s][A[A

 41%|████      | 181244928/440473133 [03:09<04:41, 922134.64B/s][A[A

 41%|████      | 181376000/440473133 [03:09<04:16, 1012089.06B/s][A[A

 41%|████      | 181490688/440473133 [03:09<04:20, 993836.15B/s] [A[A

 41%|████      | 181621760/440473133 [03:09<04:02, 1069230.57B/s][A[A

 41%|████▏     | 181732352/440473133 [03:09<04:02, 1067872.30B/s][A[A

 41%|████▏     | 181841920/440473133 [03:09<04:00, 1074248.47B/s][A[A

 41%|████▏     | 181965824/440473133 [03:09<03:54, 110130

 44%|████▍     | 194917376/440473133 [03:24<11:37, 352110.03B/s] [A[A

 44%|████▍     | 195038208/440473133 [03:24<09:52, 413905.17B/s][A[A

 44%|████▍     | 195204096/440473133 [03:24<08:10, 499639.84B/s][A[A

 44%|████▍     | 195351552/440473133 [03:24<06:39, 613681.36B/s][A[A

 44%|████▍     | 195499008/440473133 [03:25<06:11, 659308.70B/s][A[A

 44%|████▍     | 195679232/440473133 [03:25<05:40, 718228.40B/s][A[A

 44%|████▍     | 195810304/440473133 [03:25<04:54, 830443.85B/s][A[A

 44%|████▍     | 195918848/440473133 [03:25<05:49, 700507.03B/s][A[A

 44%|████▍     | 196009984/440473133 [03:25<07:52, 517437.77B/s][A[A

 45%|████▍     | 196083712/440473133 [03:26<08:07, 501355.15B/s][A[A

 45%|████▍     | 196149248/440473133 [03:26<07:47, 522746.16B/s][A[A

 45%|████▍     | 196212736/440473133 [03:26<08:07, 500981.60B/s][A[A

 45%|████▍     | 196301824/440473133 [03:26<07:57, 511726.59B/s][A[A

 45%|████▍     | 196367360/440473133 [03:26<07:38, 531989.67B/s

 46%|████▌     | 201511936/440473133 [03:40<14:00, 284329.56B/s][A[A

 46%|████▌     | 201544704/440473133 [03:40<13:50, 287705.76B/s][A[A

 46%|████▌     | 201577472/440473133 [03:40<13:53, 286737.10B/s][A[A

 46%|████▌     | 201610240/440473133 [03:40<15:00, 265192.13B/s][A[A

 46%|████▌     | 201659392/440473133 [03:40<13:39, 291566.31B/s][A[A

 46%|████▌     | 201692160/440473133 [03:40<14:28, 274991.85B/s][A[A

 46%|████▌     | 201741312/440473133 [03:41<14:10, 280778.88B/s][A[A

 46%|████▌     | 201774080/440473133 [03:41<13:35, 292882.10B/s][A[A

 46%|████▌     | 201823232/440473133 [03:41<12:22, 321454.16B/s][A[A

 46%|████▌     | 201857024/440473133 [03:41<14:00, 283957.03B/s][A[A

 46%|████▌     | 201888768/440473133 [03:41<13:55, 285453.70B/s][A[A

 46%|████▌     | 201937920/440473133 [03:41<12:23, 320890.36B/s][A[A

 46%|████▌     | 201972736/440473133 [03:41<12:47, 310729.74B/s][A[A

 46%|████▌     | 202006528/440473133 [03:41<12:31, 317466.52B/s]

 47%|████▋     | 206246912/440473133 [03:55<18:00, 216790.85B/s][A[A

 47%|████▋     | 206296064/440473133 [03:55<16:22, 238320.26B/s][A[A

 47%|████▋     | 206328832/440473133 [03:56<15:30, 251609.37B/s][A[A

 47%|████▋     | 206361600/440473133 [03:56<16:34, 235455.34B/s][A[A

 47%|████▋     | 206394368/440473133 [03:56<15:29, 251911.12B/s][A[A

 47%|████▋     | 206427136/440473133 [03:56<17:04, 228348.15B/s][A[A

 47%|████▋     | 206459904/440473133 [03:56<15:55, 244899.57B/s][A[A

 47%|████▋     | 206492672/440473133 [03:56<16:16, 239712.45B/s][A[A

 47%|████▋     | 206525440/440473133 [03:56<16:19, 238910.43B/s][A[A

 47%|████▋     | 206550016/440473133 [03:57<21:36, 180477.17B/s][A[A

 47%|████▋     | 206590976/440473133 [03:57<19:20, 201560.82B/s][A[A

 47%|████▋     | 206623744/440473133 [03:57<18:43, 208098.59B/s][A[A

 47%|████▋     | 206656512/440473133 [03:57<18:13, 213758.46B/s][A[A

 47%|████▋     | 206689280/440473133 [03:57<18:02, 215884.23B/s]

 48%|████▊     | 211686400/440473133 [04:11<23:49, 160100.54B/s][A[A

 48%|████▊     | 211719168/440473133 [04:11<21:27, 177734.74B/s][A[A

 48%|████▊     | 211751936/440473133 [04:11<19:58, 190886.39B/s][A[A

 48%|████▊     | 211772416/440473133 [04:11<20:12, 188579.51B/s][A[A

 48%|████▊     | 211801088/440473133 [04:12<19:30, 195299.95B/s][A[A

 48%|████▊     | 211833856/440473133 [04:12<18:21, 207612.82B/s][A[A

 48%|████▊     | 211866624/440473133 [04:12<16:28, 231282.18B/s][A[A

 48%|████▊     | 211899392/440473133 [04:12<16:12, 235135.48B/s][A[A

 48%|████▊     | 211932160/440473133 [04:12<18:33, 205305.91B/s][A[A

 48%|████▊     | 211981312/440473133 [04:12<16:32, 230159.92B/s][A[A

 48%|████▊     | 212014080/440473133 [04:12<16:58, 224231.45B/s][A[A

 48%|████▊     | 212038656/440473133 [04:13<17:24, 218616.86B/s][A[A

 48%|████▊     | 212063232/440473133 [04:13<19:16, 197581.13B/s][A[A

 48%|████▊     | 212084736/440473133 [04:13<19:19, 196951.50B/s]

 50%|█████     | 221812736/440473133 [04:27<07:17, 499745.25B/s][A[A

 50%|█████     | 221942784/440473133 [04:27<06:16, 581180.58B/s][A[A

 50%|█████     | 222106624/440473133 [04:27<05:12, 699068.41B/s][A[A

 50%|█████     | 222238720/440473133 [04:27<04:28, 813865.90B/s][A[A

 50%|█████     | 222401536/440473133 [04:27<04:00, 907736.23B/s][A[A

 51%|█████     | 222614528/440473133 [04:27<03:31, 1030327.82B/s][A[A

 51%|█████     | 222843904/440473133 [04:28<02:58, 1219835.56B/s][A[A

 51%|█████     | 222996480/440473133 [04:28<02:50, 1273981.85B/s][A[A

 51%|█████     | 223253504/440473133 [04:28<02:34, 1407403.85B/s][A[A

 51%|█████     | 223463424/440473133 [04:28<02:22, 1526188.28B/s][A[A

 51%|█████     | 223728640/440473133 [04:28<02:07, 1702668.46B/s][A[A

 51%|█████     | 223974400/440473133 [04:28<01:57, 1841194.67B/s][A[A

 51%|█████     | 224269312/440473133 [04:28<01:47, 2012649.78B/s][A[A

 51%|█████     | 224487424/440473133 [04:30<09:06, 39490

 56%|█████▋    | 248288256/440473133 [04:53<01:45, 1826144.10B/s][A[A

 56%|█████▋    | 248529920/440473133 [04:55<06:21, 503285.75B/s] [A[A

 56%|█████▋    | 248733696/440473133 [04:55<04:54, 650154.14B/s][A[A

 57%|█████▋    | 248976384/440473133 [04:55<03:59, 799221.84B/s][A[A

 57%|█████▋    | 249176064/440473133 [04:55<03:16, 974514.16B/s][A[A

 57%|█████▋    | 249451520/440473133 [04:55<02:39, 1200327.44B/s][A[A

 57%|█████▋    | 249697280/440473133 [04:55<02:15, 1408333.58B/s][A[A

 57%|█████▋    | 249926656/440473133 [04:55<02:02, 1558343.32B/s][A[A

 57%|█████▋    | 250188800/440473133 [04:55<01:49, 1737449.59B/s][A[A

 57%|█████▋    | 250409984/440473133 [04:57<06:18, 502354.15B/s] [A[A

 57%|█████▋    | 250570752/440473133 [04:57<05:02, 627377.14B/s][A[A

 57%|█████▋    | 250745856/440473133 [04:57<04:18, 733987.64B/s][A[A

 57%|█████▋    | 250991616/440473133 [04:57<03:24, 928321.06B/s][A[A

 57%|█████▋    | 251188224/440473133 [04:57<02:52, 109832

 63%|██████▎   | 276807680/440473133 [05:25<05:29, 496376.22B/s] [A[A

 63%|██████▎   | 276936704/440473133 [05:25<05:11, 524757.96B/s][A[A

 63%|██████▎   | 277156864/440473133 [05:25<04:04, 667155.50B/s][A[A

 63%|██████▎   | 277387264/440473133 [05:25<03:12, 847791.64B/s][A[A

 63%|██████▎   | 277615616/440473133 [05:25<02:36, 1039569.59B/s][A[A

 63%|██████▎   | 277812224/440473133 [05:26<02:14, 1209407.72B/s][A[A

 63%|██████▎   | 278090752/440473133 [05:26<01:51, 1455705.68B/s][A[A

 63%|██████▎   | 278302720/440473133 [05:26<01:41, 1603614.46B/s][A[A

 63%|██████▎   | 278514688/440473133 [05:26<01:33, 1729160.89B/s][A[A

 63%|██████▎   | 278726656/440473133 [05:26<01:37, 1660922.40B/s][A[A

 63%|██████▎   | 278920192/440473133 [05:26<01:34, 1700923.97B/s][A[A

 63%|██████▎   | 279110656/440473133 [05:26<01:34, 1700202.52B/s][A[A

 63%|██████▎   | 279294976/440473133 [05:26<01:37, 1660838.38B/s][A[A

 63%|██████▎   | 279499776/440473133 [05:26<01:35, 167

 68%|██████▊   | 298931200/440473133 [05:48<02:02, 1157933.65B/s][A[A

 68%|██████▊   | 299063296/440473133 [05:48<01:57, 1202148.20B/s][A[A

 68%|██████▊   | 299193344/440473133 [05:48<01:59, 1181791.52B/s][A[A

 68%|██████▊   | 299340800/440473133 [05:48<01:52, 1255624.39B/s][A[A

 68%|██████▊   | 299471872/440473133 [05:48<01:52, 1255782.37B/s][A[A

 68%|██████▊   | 299602944/440473133 [05:48<01:52, 1249078.91B/s][A[A

 68%|██████▊   | 299750400/440473133 [05:48<01:49, 1281339.88B/s][A[A

 68%|██████▊   | 299897856/440473133 [05:48<01:46, 1318114.84B/s][A[A

 68%|██████▊   | 300045312/440473133 [05:49<01:44, 1341928.30B/s][A[A

 68%|██████▊   | 300180480/440473133 [05:49<01:46, 1316866.08B/s][A[A

 68%|██████▊   | 300340224/440473133 [05:49<01:41, 1374984.91B/s][A[A

 68%|██████▊   | 300479488/440473133 [05:49<01:42, 1367810.91B/s][A[A

 68%|██████▊   | 300617728/440473133 [05:49<02:45, 842699.23B/s] [A[A

 68%|██████▊   | 300728320/440473133 [05:49<02:48, 

 72%|███████▏  | 314954752/440473133 [06:04<02:08, 979672.49B/s][A[A

 72%|███████▏  | 315134976/440473133 [06:05<01:57, 1063326.88B/s][A[A

 72%|███████▏  | 315282432/440473133 [06:05<01:48, 1159050.97B/s][A[A

 72%|███████▏  | 315413504/440473133 [06:05<01:51, 1117686.25B/s][A[A

 72%|███████▏  | 315593728/440473133 [06:05<01:46, 1175659.61B/s][A[A

 72%|███████▏  | 315773952/440473133 [06:05<01:36, 1295624.01B/s][A[A

 72%|███████▏  | 315912192/440473133 [06:05<01:38, 1260122.64B/s][A[A

 72%|███████▏  | 316068864/440473133 [06:05<01:36, 1287432.25B/s][A[A

 72%|███████▏  | 316216320/440473133 [06:05<01:33, 1331812.93B/s][A[A

 72%|███████▏  | 316363776/440473133 [06:05<01:32, 1344350.17B/s][A[A

 72%|███████▏  | 316511232/440473133 [06:06<01:32, 1343510.17B/s][A[A

 72%|███████▏  | 316675072/440473133 [06:06<01:31, 1349654.59B/s][A[A

 72%|███████▏  | 316811264/440473133 [06:06<03:26, 600192.47B/s] [A[A

 72%|███████▏  | 316914688/440473133 [06:07<05:14, 3

 75%|███████▌  | 331303936/440473133 [06:21<01:16, 1426960.81B/s][A[A

 75%|███████▌  | 331453440/440473133 [06:21<01:15, 1434738.92B/s][A[A

 75%|███████▌  | 331597824/440473133 [06:21<01:16, 1430609.69B/s][A[A

 75%|███████▌  | 331742208/440473133 [06:21<01:16, 1420028.95B/s][A[A

 75%|███████▌  | 331895808/440473133 [06:22<01:15, 1432612.64B/s][A[A

 75%|███████▌  | 332040192/440473133 [06:22<01:15, 1427141.78B/s][A[A

 75%|███████▌  | 332183552/440473133 [06:23<04:46, 378268.88B/s] [A[A

 75%|███████▌  | 332288000/440473133 [06:23<04:04, 442738.06B/s][A[A

 75%|███████▌  | 332383232/440473133 [06:23<03:40, 489601.28B/s][A[A

 75%|███████▌  | 332502016/440473133 [06:23<03:04, 585364.27B/s][A[A

 76%|███████▌  | 332600320/440473133 [06:23<02:43, 659416.43B/s][A[A

 76%|███████▌  | 332731392/440473133 [06:23<02:19, 774934.44B/s][A[A

 76%|███████▌  | 332846080/440473133 [06:23<02:09, 829202.02B/s][A[A

 76%|███████▌  | 333009920/440473133 [06:24<01:55, 932118

 79%|███████▉  | 349590528/440473133 [06:41<02:18, 658310.05B/s][A[A

 79%|███████▉  | 349787136/440473133 [06:41<01:53, 796967.52B/s][A[A

 79%|███████▉  | 349948928/440473133 [06:41<01:36, 940046.04B/s][A[A

 79%|███████▉  | 350131200/440473133 [06:41<01:25, 1061595.30B/s][A[A

 80%|███████▉  | 350269440/440473133 [06:41<01:20, 1114368.21B/s][A[A

 80%|███████▉  | 350524416/440473133 [06:41<01:11, 1263846.75B/s][A[A

 80%|███████▉  | 350737408/440473133 [06:42<01:02, 1436755.95B/s][A[A

 80%|███████▉  | 350966784/440473133 [06:42<00:56, 1582921.65B/s][A[A

 80%|███████▉  | 351212544/440473133 [06:42<00:51, 1742650.48B/s][A[A

 80%|███████▉  | 351407104/440473133 [06:43<02:42, 547818.87B/s] [A[A

 80%|███████▉  | 351549440/440473133 [06:43<02:20, 633857.11B/s][A[A

 80%|███████▉  | 351720448/440473133 [06:43<02:01, 731295.77B/s][A[A

 80%|███████▉  | 351917056/440473133 [06:43<01:38, 900904.02B/s][A[A

 80%|███████▉  | 352130048/440473133 [06:43<01:25, 103456

 83%|████████▎ | 365990912/440473133 [06:58<01:11, 1040583.52B/s][A[A

 83%|████████▎ | 366105600/440473133 [06:58<01:18, 948936.74B/s] [A[A

 83%|████████▎ | 366236672/440473133 [06:59<01:19, 928102.09B/s][A[A

 83%|████████▎ | 366400512/440473133 [06:59<01:14, 990297.78B/s][A[A

 83%|████████▎ | 366564352/440473133 [06:59<01:10, 1043046.62B/s][A[A

 83%|████████▎ | 366728192/440473133 [06:59<01:07, 1085019.37B/s][A[A

 83%|████████▎ | 366892032/440473133 [06:59<01:04, 1132915.88B/s][A[A

 83%|████████▎ | 367007744/440473133 [06:59<01:05, 1122808.28B/s][A[A

 83%|████████▎ | 367154176/440473133 [06:59<01:01, 1193288.43B/s][A[A

 83%|████████▎ | 367276032/440473133 [06:59<01:02, 1176999.29B/s][A[A

 83%|████████▎ | 367432704/440473133 [07:00<00:57, 1264639.55B/s][A[A

 83%|████████▎ | 367562752/440473133 [07:00<00:59, 1216290.14B/s][A[A

 83%|████████▎ | 367711232/440473133 [07:00<00:58, 1249351.54B/s][A[A

 84%|████████▎ | 367842304/440473133 [07:00<00:57, 12

 88%|████████▊ | 385799168/440473133 [07:19<01:54, 477736.41B/s][A[A

 88%|████████▊ | 385898496/440473133 [07:19<01:43, 525127.13B/s][A[A

 88%|████████▊ | 386126848/440473133 [07:19<01:21, 669295.88B/s][A[A

 88%|████████▊ | 386323456/440473133 [07:19<01:05, 831575.62B/s][A[A

 88%|████████▊ | 386552832/440473133 [07:19<00:53, 1006529.45B/s][A[A

 88%|████████▊ | 386749440/440473133 [07:20<00:45, 1176103.11B/s][A[A

 88%|████████▊ | 386978816/440473133 [07:20<00:38, 1372912.87B/s][A[A

 88%|████████▊ | 387175424/440473133 [07:20<00:35, 1508280.50B/s][A[A

 88%|████████▊ | 387421184/440473133 [07:20<00:31, 1702018.20B/s][A[A

 88%|████████▊ | 387627008/440473133 [07:20<00:37, 1394164.09B/s][A[A

 88%|████████▊ | 387800064/440473133 [07:20<00:39, 1348450.71B/s][A[A

 88%|████████▊ | 387978240/440473133 [07:20<00:38, 1371530.17B/s][A[A

 88%|████████▊ | 388158464/440473133 [07:20<00:35, 1475782.65B/s][A[A

 88%|████████▊ | 388322304/440473133 [07:21<00:35, 1477

 93%|█████████▎| 410440704/440473133 [07:44<00:30, 972949.33B/s][A[A

 93%|█████████▎| 410647552/440473133 [07:44<00:25, 1156706.11B/s][A[A

 93%|█████████▎| 410883072/440473133 [07:44<00:22, 1330322.85B/s][A[A

 93%|█████████▎| 411112448/440473133 [07:45<00:19, 1496507.88B/s][A[A

 93%|█████████▎| 411390976/440473133 [07:45<00:16, 1733720.21B/s][A[A

 93%|█████████▎| 411606016/440473133 [07:46<00:54, 526887.77B/s] [A[A

 93%|█████████▎| 411800576/440473133 [07:46<00:44, 647775.85B/s][A[A

 94%|█████████▎| 412079104/440473133 [07:46<00:34, 822741.97B/s][A[A

 94%|█████████▎| 412275712/440473133 [07:46<00:28, 981358.98B/s][A[A

 94%|█████████▎| 412587008/440473133 [07:46<00:22, 1222626.03B/s][A[A

 94%|█████████▎| 412800000/440473133 [07:46<00:20, 1373819.81B/s][A[A

 94%|█████████▍| 413176832/440473133 [07:46<00:16, 1673670.78B/s][A[A

 94%|█████████▍| 413422592/440473133 [07:47<00:45, 595513.37B/s] [A[A

 94%|█████████▍| 413601792/440473133 [07:48<00:41, 6502

 98%|█████████▊| 431019008/440473133 [08:06<00:16, 569065.39B/s][A[A

 98%|█████████▊| 431199232/440473133 [08:06<00:13, 686948.18B/s][A[A

 98%|█████████▊| 431330304/440473133 [08:06<00:11, 799806.74B/s][A[A

 98%|█████████▊| 431510528/440473133 [08:06<00:09, 941291.67B/s][A[A

 98%|█████████▊| 431690752/440473133 [08:06<00:08, 1078476.22B/s][A[A

 98%|█████████▊| 431870976/440473133 [08:07<00:07, 1111402.82B/s][A[A

 98%|█████████▊| 432133120/440473133 [08:07<00:06, 1272421.91B/s][A[A

 98%|█████████▊| 432411648/440473133 [08:07<00:05, 1517391.14B/s][A[A

 98%|█████████▊| 432601088/440473133 [08:08<00:18, 434348.46B/s] [A[A

 98%|█████████▊| 432739328/440473133 [08:08<00:18, 424516.11B/s][A[A

 98%|█████████▊| 432935936/440473133 [08:09<00:13, 542248.13B/s][A[A

 98%|█████████▊| 433156096/440473133 [08:09<00:10, 700677.44B/s][A[A

 98%|█████████▊| 433312768/440473133 [08:09<00:08, 825610.15B/s][A[A

 98%|█████████▊| 433492992/440473133 [08:09<00:07, 984949.2

In [8]:
model.eval()



In [90]:
text = 'burma has put five cities on a security alert after religious unrest involving buddhists and moslems in the northern city of mandalay , an informed source said wednesday.' 
text1 = 'police arrested five anti-nuclear protesters friday after they sought to disrupt loading of a french antarctic research and supply vessel , a spokesman for the protesters said .'
text2 = 'turkmen president gurbanguly berdymukhammedov will begin a two-day visit to russia , his country \'s main energy partner , on monday for trade talks , the kremlin press office said .'
text3 = 'israel \'s new government barred yasser arafat from flying to the west bank to meet with former prime minister shimon peres on thursday , a move palestinian officials said violated the israel-plo peace accords .'
toks = str_to_idx_tensor(text)

with torch.no_grad():
    outputs = model(toks)

['[CLS]', 'burma', 'has', 'put', 'five', 'cities', 'on', 'a', 'security', 'alert', 'after', 'religious', 'unrest', 'involving', 'buddhist', '##s', 'and', 'mo', '##sle', '##ms', 'in', 'the', 'northern', 'city', 'of', 'mandal', '##ay', ',', 'an', 'informed', 'source', 'said', 'wednesday', '.', '[SEP]']
tensor([[  101, 22883,  3549,  2343, 19739, 28483,  3070,  5313,  2100,  2022,
         17460, 12274, 15256, 20058,  3527,  2615,  2097,  4088,  1037,  2048,
          1011,  2154,  3942,  2000,  3607,  1010,  2010,  2406,  1005,  1055,
          2364,  2943,  4256,  1010,  2006,  6928,  2005,  3119,  7566,  1010,
          1996,  1047, 28578,  4115,  2811,  2436,  2056,  1012,   102],
        [  101,  3956,  1005,  1055,  2047,  2231, 15605,  8038, 18116, 19027,
         27753,  2013,  3909,  2000,  1996,  2225,  2924,  2000,  3113,  2007,
          2280,  3539,  2704, 11895,  8202, 23976,  2015,  2006,  9432,  1010,
          1037,  2693,  9302,  4584,  2056, 14424,  1996,  3956,  1011, 

In [68]:
# print(outputs[0].shape) 
# (batch_size, sequence_length, hidden_size) 
# Sequence of hidden-states at the output of the last layer of the model.

# print(outputs[1].shape)
# pooler_output, probably not relevant to our usage

# print(len(outputs[2])) 
# attentions of all 12 layers
# print(outputs[2][0].shape) 
# (batch_size, num_heads, sequence_length, sequence_length)
# the third dimension sums up to one
# print(torch.sum(outputs[2][0], dim=3))
tokens = tokenizer.convert_ids_to_tokens(toks.tolist()[0])
cum_attn = []
layers = [0, 8, 9, 10, 11]
for l in layers:
    print('---- Layer', l)
    layer = outputs[2][l]
    summed = layer.sum(dim=2).sum(dim=1).view(-1)
    summed = (summed / summed.sum(dim=0)).tolist()
    cum_attn.append(summed)
    sorted_by_attn = sorted(list(zip(summed, tokens)), key=lambda p: p[0], reverse=True)
    print(sorted_by_attn[:10])
print('---- total')
sum_attn = torch.tensor(cum_attn).sum(dim=0)
sum_attn = (sum_attn / sum_attn.sum(dim=0)).tolist()
sorted_by_attn = sorted(list(zip(sum_attn, tokens)), key=lambda p: p[0], reverse=True)
print(sorted_by_attn[:10])

---- Layer 0
[(0.08032415807247162, '[CLS]'), (0.04123055562376976, 'palestinian'), (0.03168165683746338, 'israel'), (0.0302461925894022, 'thursday'), (0.02916713058948517, 'violated'), (0.02887265384197235, 'minister'), (0.02798117883503437, 'israel'), (0.026993650943040848, '##fat'), (0.025710370391607285, 'accord'), (0.0246791560202837, '[SEP]')]
---- Layer 9
[(0.4811535179615021, '[SEP]'), (0.0437617227435112, '.'), (0.037608399987220764, '[CLS]'), (0.029492225497961044, 'israel'), (0.02293320745229721, 'palestinian'), (0.021103298291563988, '##fat'), (0.020826423540711403, 'barred'), (0.01810401678085327, 'thursday'), (0.018085891380906105, '##sser'), (0.017743797972798347, 'pere')]
---- Layer 10
[(0.3741353750228882, '.'), (0.19038298726081848, ','), (0.059014927595853806, '[CLS]'), (0.050655387341976166, '[SEP]'), (0.02185649797320366, 'palestinian'), (0.020123891532421112, 'pere'), (0.01932804472744465, 'thursday'), (0.019161606207489967, 'israel'), (0.013851759023964405, 'flyi

In [112]:
text = 'burma has put five cities on a security alert after religious unrest involving buddhists and moslems in the northern city of mandalay , an informed source said wednesday.' 
text1 = 'police arrested five anti-nuclear protesters friday after they sought to disrupt loading of a french antarctic research and supply vessel , a spokesman for the protesters said .'
text2 = 'turkmen president gurbanguly berdymukhammedov will begin a two-day visit to russia , his country \'s main energy partner , on monday for trade talks , the kremlin press office said .'
text3 = 'israel \'s new government barred yasser arafat from flying to the west bank to meet with former prime minister shimon peres on thursday , a move palestinian officials said violated the israel-plo peace accords .'
batch = [text, text1, text2, text3]
batch_toks, str_toks, attn_mask, scrm_idxs = batch_to_idx_tensor(batch)

print(batch_toks)
print(attn_mask)
print([len(s) for s in str_toks])

with torch.no_grad():
    outputs = model(batch_toks, attention_mask=attn_mask)

print(outputs[2][0].shape)

['[CLS]', 'burma', 'has', 'put', 'five', 'cities', 'on', 'a', 'security', 'alert', 'after', 'religious', 'unrest', 'involving', 'buddhist', '##s', 'and', 'mo', '##sle', '##ms', 'in', 'the', 'northern', 'city', 'of', 'mandal', '##ay', ',', 'an', 'informed', 'source', 'said', 'wednesday', '.', '[SEP]']
tensor([[  101, 22883,  3549,  2343, 19739, 28483,  3070,  5313,  2100,  2022,
         17460, 12274, 15256, 20058,  3527,  2615,  2097,  4088,  1037,  2048,
          1011,  2154,  3942,  2000,  3607,  1010,  2010,  2406,  1005,  1055,
          2364,  2943,  4256,  1010,  2006,  6928,  2005,  3119,  7566,  1010,
          1996,  1047, 28578,  4115,  2811,  2436,  2056,  1012,   102],
        [  101,  3956,  1005,  1055,  2047,  2231, 15605,  8038, 18116, 19027,
         27753,  2013,  3909,  2000,  1996,  2225,  2924,  2000,  3113,  2007,
          2280,  3539,  2704, 11895,  8202, 23976,  2015,  2006,  9432,  1010,
          1037,  2693,  9302,  4584,  2056, 14424,  1996,  3956,  1011, 

In [138]:
from collections import defaultdict as DD

attended_words = DD(int)

layers = torch.tensor([0, 9, 10, 11])

attn = torch.stack(outputs[2]).index_select(0, layers) # [4, 4, 12, 49, 49]
summed = attn.sum(dim=3).sum(dim=2).sum(dim=0).view(attn.shape[1], attn.shape[4])
summed = summed / summed.sum(dim=1, keepdim=True)
_, topk_idxs = summed.topk(10, sorted=False)
print(topk_idxs)

attended_word_tensor = torch.zeros(tokenizer.vocab_size)

split_batch_toks = batch_toks.split(1, dim=0)
split_topk_idxs = topk_idxs.split(1, dim=0)
print(split_batch_toks)
for idxs, toks in zip(topk_idxs, split_batch_toks):
    attended_toks = toks.view(-1)[idxs.view(-1)]
    attended_word_tensor[attended_toks] += 1
print(attended_word_tensor.nonzero().view(-1))

non_zero_idxs = attended_word_tensor.nonzero().view(-1)
counts = attended_word_tensor[non_zero_idxs].tolist()
toks = tokenizer.convert_ids_to_tokens(non_zero_idxs.tolist())
attended_words.update(zip(toks, counts))
print(attended_words)
print(to)
# offset = (torch.arange(batch_toks.shape[0]) * batch_toks.shape[1]).unsqueeze(1)
# attended_idxs = topk_idxs + offset
# attended_tok_idxs = batch_toks.take(attended_idxs)
# print(attended_tok_idxs)
# attended_word_tensor[attended_tok_idxs] += 1
# print(attended_word_tensor[102])

# summed = (summed / summed.sum(dim=0)).tolist()
# cum_attn.append(summed)
# sorted_by_attn = sorted(list(zip(summed, tokens)), key=lambda p: p[0], reverse=True)
# print(sorted_by_attn[:10])

tensor([[16, 24,  1,  3,  2, 31,  0, 33, 47, 48],
        [25, 10,  6, 28,  1, 32, 45, 44, 29,  0],
        [25, 11,  2, 12, 32,  1, 14,  0, 34, 33],
        [ 1,  8, 28, 17,  6, 31, 30,  7, 18,  0]])
(tensor([[  101, 22883,  3549,  2343, 19739, 28483,  3070,  5313,  2100,  2022,
         17460, 12274, 15256, 20058,  3527,  2615,  2097,  4088,  1037,  2048,
          1011,  2154,  3942,  2000,  3607,  1010,  2010,  2406,  1005,  1055,
          2364,  2943,  4256,  1010,  2006,  6928,  2005,  3119,  7566,  1010,
          1996,  1047, 28578,  4115,  2811,  2436,  2056,  1012,   102]]), tensor([[  101,  3956,  1005,  1055,  2047,  2231, 15605,  8038, 18116, 19027,
         27753,  2013,  3909,  2000,  1996,  2225,  2924,  2000,  3113,  2007,
          2280,  3539,  2704, 11895,  8202, 23976,  2015,  2006,  9432,  1010,
          1037,  2693,  9302,  4584,  2056, 14424,  1996,  3956,  1011, 20228,
          2080,  3521, 15802,  2015,  1012,   102,     0,     0,     0]]), tensor([[  101, 