In [2]:
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import head_view

from smdebug.pytorch import Hook
from smdebug.core.reduction_config import ReductionConfig
from smdebug.core.save_config import SaveConfig

utils.logging.set_verbosity_error()  # Raise the transformers log level to ERROR so routine warnings are not printed

In [3]:
# Configure the smdebug hook: write under ./smdebugger/ with a save interval of 1
# and save_all enabled (presumably "every step, every tensor" - confirm in the smdebug docs).
every_step = SaveConfig(save_interval=1)
smdhook = Hook(
    out_dir='./smdebugger/',
    save_config=every_step,
    save_all=True,
)

[2022-11-10 21:03:58.824 ip-172-31-2-142:326041 INFO hook.py:201] tensorboard_dir has not been set for the hook. SMDebug will not be exporting tensorboard summaries.
[2022-11-10 21:03:58.827 ip-172-31-2-142:326041 INFO hook.py:254] Saving to ./smdebugger/
[2022-11-10 21:03:58.828 ip-172-31-2-142:326041 INFO state_store.py:77] The checkpoint config file /opt/ml/input/config/checkpointconfig.json does not exist.


In [4]:
# Tokenizer and model are loaded from the same pretrained checkpoint.
checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# output_attentions=True so the forward pass also returns the attention weights.
model = AutoModel.from_pretrained(checkpoint, output_attentions=True)
# Attach the smdebug hook to the model so its tensors are written to the hook's out_dir.
smdhook.register_module(model)

[2022-11-10 21:04:18.576 ip-172-31-2-142:326041 INFO hook.py:560] name:embeddings.word_embeddings.weight count_params:23440896
[2022-11-10 21:04:18.578 ip-172-31-2-142:326041 INFO hook.py:560] name:embeddings.position_embeddings.weight count_params:393216
[2022-11-10 21:04:18.578 ip-172-31-2-142:326041 INFO hook.py:560] name:embeddings.LayerNorm.weight count_params:768
[2022-11-10 21:04:18.579 ip-172-31-2-142:326041 INFO hook.py:560] name:embeddings.LayerNorm.bias count_params:768
[2022-11-10 21:04:18.579 ip-172-31-2-142:326041 INFO hook.py:560] name:transformer.layer.0.attention.q_lin.weight count_params:589824
[2022-11-10 21:04:18.580 ip-172-31-2-142:326041 INFO hook.py:560] name:transformer.layer.0.attention.q_lin.bias count_params:768
[2022-11-10 21:04:18.581 ip-172-31-2-142:326041 INFO hook.py:560] name:transformer.layer.0.attention.k_lin.weight count_params:589824
[2022-11-10 21:04:18.581 ip-172-31-2-142:326041 INFO hook.py:560] name:transformer.layer.0.attention.k_lin.bias count

In [7]:
# Encode one example sentence as a PyTorch tensor of token ids and run it through the model.
sentence = "The cat sat on the mat"
inputs = tokenizer.encode(sentence, return_tensors="pt")
outputs = model(inputs)
# The model was loaded with output_attentions=True, so the attention weights
# are the last element of the output.
attention = outputs[-1]
# Turn the ids of the (single) encoded sequence back into token strings for the viewer.
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
head_view(attention, tokens)

<IPython.core.display.Javascript object>

In [37]:
# Display the token ids that tokenizer.encode produced for the example sentence.
inputs

tensor([[  101,  1996,  4937,  2938,  2006,  1996, 13523,   102]])

In [None]:
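'''
Notes on reading the attention data back from the smdebug output:

- The attention weights are saved under tensor names such as
  'transformer.layer.0.attention_output_1' (one tensor per transformer layer).

- The input token ids are saved under a tensor whose name contains
  'word_embeddings_input_0'.

- To convert the saved token ids back into tokens, invert the tokenizer vocabulary:
  vocab = tokenizer.vocab
  untoken = {j:i for i,j in vocab.items()}

- Store the tokenizer as a JSON file and read it at the start of the plugin.
'''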

In [39]:
# Tokenizer vocabulary; the cells below invert it to look tokens up by integer id.
vocab = tokenizer.vocab

In [40]:
# Reverse lookup table: swap each (token, id) pair of the vocabulary so ids map back to tokens.
untoken = {token_id: token for token, token_id in vocab.items()}

In [49]:
# Spot-check the reverse lookup with id 102, the last id in the encoded sentence.
untoken[102]

'[SEP]'

In [8]:
from smdebug.core.reader import EventFileReader

In [51]:
# Event file located under the hook's out_dir (./smdebugger/);
# presumably step 0 / worker 0, judging by the file name.
event_file = './smdebugger/events/000000000000/000000000000_worker_0.tfevents'
reader = EventFileReader(event_file)

In [52]:
# Materialise every record the reader yields into a list so the cells below
# can scan it repeatedly. Each record is a tuple: the name is at index 0 and
# the array at index 2.
tensors = [record for record in reader.read_tensors()]

In [53]:
# Number of records read from the event file.
len(tensors)

279

In [57]:
# Find the record(s) holding the ids fed to the word-embedding layer:
# match on the tensor name stored at index 0 of each record.
[record for record in tensors if 'word_embeddings_input_0' in record[0]]

[('embeddings.word_embeddings_input_0',
  0,
  array([[  101,  1996,  4937,  2938,  2006,  1996, 13523,   102]]),
  <ModeKeys.GLOBAL: 4>,
  0)]

In [28]:
# List the names of all saved tensors that have the same shape as the model's
# attention weights. The shape is taken from the live `attention` output of the
# inference cell rather than hard-coded as (1, 12, 8, 8), so this cell keeps
# working if the example sentence (token count) or the model (head count) changes.
# tuple(...) converts the torch shape to a plain tuple for comparison with the
# numpy array shapes stored at index 2 of each record.
attention_shape = tuple(attention[0].shape)
[record[0] for record in tensors if record[2].shape == attention_shape]

['transformer.layer.0.attention.dropout_input_0',
 'transformer.layer.0.attention.dropout_output_0',
 'transformer.layer.0.attention_output_1',
 'transformer.layer.0_output_0',
 'transformer.layer.1.attention.dropout_input_0',
 'transformer.layer.1.attention.dropout_output_0',
 'transformer.layer.1.attention_output_1',
 'transformer.layer.1_output_0',
 'transformer.layer.2.attention.dropout_input_0',
 'transformer.layer.2.attention.dropout_output_0',
 'transformer.layer.2.attention_output_1',
 'transformer.layer.2_output_0',
 'transformer.layer.3.attention.dropout_input_0',
 'transformer.layer.3.attention.dropout_output_0',
 'transformer.layer.3.attention_output_1',
 'transformer.layer.3_output_0',
 'transformer.layer.4.attention.dropout_input_0',
 'transformer.layer.4.attention.dropout_output_0',
 'transformer.layer.4.attention_output_1',
 'transformer.layer.4_output_0',
 'transformer.layer.5.attention.dropout_input_0',
 'transformer.layer.5.attention.dropout_output_0',
 'transformer.

In [30]:
# Keep only the records whose tensor name contains 'attention_output_1'.
# NOTE: this rebinds `outputs`, which the inference cell also assigns (the model's return value).
outputs = [record for record in tensors if 'attention_output_1' in record[0]]

In [33]:
# First entry of the attention output returned directly by the model,
# shown for comparison with the copy read back from the event file.
attention[0]

tensor([[[[6.3869e-02, 1.1991e-01, 9.5766e-02, 7.6306e-02, 9.5222e-02,
           1.5218e-01, 1.1428e-01, 2.8247e-01],
          [3.5801e-01, 1.0320e-01, 6.0095e-02, 5.7474e-02, 1.1596e-01,
           1.1652e-01, 6.7821e-02, 1.2092e-01],
          [2.0004e-01, 3.3206e-02, 6.3609e-02, 1.7562e-01, 3.0221e-02,
           3.8151e-02, 2.6589e-01, 1.9327e-01],
          [2.5242e-01, 6.9496e-02, 1.9354e-01, 7.1771e-02, 4.7177e-02,
           6.5658e-02, 1.7099e-01, 1.2895e-01],
          [4.0801e-01, 8.8196e-02, 5.9198e-02, 6.2429e-02, 1.2339e-01,
           8.0026e-02, 6.4638e-02, 1.1411e-01],
          [3.9437e-01, 1.2084e-01, 5.3797e-02, 5.5683e-02, 1.0843e-01,
           1.1435e-01, 5.6914e-02, 9.5617e-02],
          [8.8113e-02, 2.8448e-02, 3.5049e-01, 2.5977e-01, 1.9738e-02,
           2.9143e-02, 1.0683e-01, 1.1747e-01],
          [2.0127e-01, 1.3804e-01, 4.0319e-02, 3.8544e-02, 9.7579e-02,
           1.5626e-01, 6.4869e-02, 2.6313e-01]],

         [[9.5100e-01, 1.1581e-02, 7.9453e-03,

In [34]:
# First record whose name contains 'attention_output_1'. Here `outputs` is the
# filtered list of smdebug records, not the model output from the inference cell.
outputs[0]

('transformer.layer.0.attention_output_1',
 0,
 array([[[[6.38691485e-02, 1.19910367e-01, 9.57656801e-02,
           7.63061941e-02, 9.52220336e-02, 1.52176768e-01,
           1.14283741e-01, 2.82466084e-01],
          [3.58006090e-01, 1.03195086e-01, 6.00947104e-02,
           5.74744530e-02, 1.15964971e-01, 1.16521776e-01,
           6.78205118e-02, 1.20922327e-01],
          [2.00041905e-01, 3.32055353e-02, 6.36089817e-02,
           1.75616726e-01, 3.02206706e-02, 3.81510742e-02,
           2.65886128e-01, 1.93269074e-01],
          [2.52417296e-01, 6.94959015e-02, 1.93541437e-01,
           7.17712864e-02, 4.71770987e-02, 6.56584278e-02,
           1.70991302e-01, 1.28947303e-01],
          [4.08011675e-01, 8.81964490e-02, 5.91980740e-02,
           6.24285080e-02, 1.23394787e-01, 8.00255910e-02,
           6.46377876e-02, 1.14107177e-01],
          [3.94367516e-01, 1.20836005e-01, 5.37972525e-02,
           5.56833372e-02, 1.08430594e-01, 1.14353657e-01,
           5.69143184e-02