# Attention Visualization in Trax

For more information see the [tensor2tensor](https://trax-ml.readthedocs.io/en/latest/) visualization colab. All JS tools and the attention-processing methods are taken from the tensor2tensor version. The "viz" mode of the Trax model used in this colab [was added to Trax](https://github.com/google/trax/commit/e9a171379ef206a3e351b67cef91fe40bf37589c) with attention visualization in mind. This colab reuses some parts of the [Intro to Trax](https://github.com/google/trax/blob/master/trax/intro.ipynb) colab.



**General Setup**

Execute the following few cells once before running the visualization code.

In [None]:
#@title
# Copyright 2020 Google LLC.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import numpy as np
import os
import IPython.display as display
import gin

In [None]:
#@title
# Import Trax

!pip install -q -U trax
import trax


In [None]:
#@title Some cool tooling for attention (make sure that you run the cell)
def resize(att_mat, max_length=None):
  """Normalize attention matrices and reshape as necessary."""
  for i, att in enumerate(att_mat):
    # Add extra batch dim for viz code to work.
    if att.ndim == 3:
      att = np.expand_dims(att, axis=0)
    if max_length is not None:
      # Crop to the first max_length positions.
      att = att[:, :, :max_length, :max_length]
      # Sum across different attention values for each token.
      row_sums = np.sum(att, axis=2)
      # Normalize
      att /= row_sums[:, :, np.newaxis]
    att_mat[i] = att
  return att_mat


def _get_attention(inp_text, out_text, enc_atts, dec_atts, encdec_atts):
  """Compute representation of the attention ready for the d3 visualization.

  Args:
    inp_text: list of strings, words to be displayed on the left of the vis
    out_text: list of strings, words to be displayed on the right of the vis
    enc_atts: numpy array, encoder self-attentions
        [num_layers, batch_size, num_heads, enc_length, enc_length]
    dec_atts: numpy array, decoder self-attentions
        [num_layers, batch_size, num_heads, dec_length, dec_length]
    encdec_atts: numpy array, encoder-decoder attentions
        [num_layers, batch_size, num_heads, dec_length, enc_length]

  Returns:
    Dictionary of attention representations with the structure:
    {
      'all': Representations for showing all attentions at the same time.
      'inp_inp': Representations for showing encoder self-attentions
      'inp_out': Representations for showing encoder-decoder attentions
      'out_out': Representations for showing decoder self-attentions
    }
    and each sub-dictionary has structure:
    {
      'att': list of inter attentions matrices, one for each attention head
      'top_text': list of strings, words to be displayed on the left of the vis
      'bot_text': list of strings, words to be displayed on the right of the vis
    }
  """
  def get_full_attention(layer):
    """Get the full input+output - input+output attentions."""
    enc_att = enc_atts[layer][0]
    dec_att = dec_atts[layer][0]
    encdec_att = encdec_atts[layer][0]
    enc_att = np.transpose(enc_att, [0, 2, 1])
    dec_att = np.transpose(dec_att, [0, 2, 1])
    encdec_att = np.transpose(encdec_att, [0, 2, 1])
    # [heads, query_length, memory_length]
    enc_length = enc_att.shape[1]
    dec_length = dec_att.shape[1]
    num_heads = enc_att.shape[0]
    first = np.concatenate([enc_att, encdec_att], axis=2)
    second = np.concatenate(
        [np.zeros((num_heads, dec_length, enc_length)), dec_att], axis=2)
    full_att = np.concatenate([first, second], axis=1)
    return [ha.T.tolist() for ha in full_att]

  def get_inp_inp_attention(layer):
    att = np.transpose(enc_atts[layer][0], (0, 2, 1))
    return [ha.T.tolist() for ha in att]

  def get_out_inp_attention(layer):
    att = np.transpose(encdec_atts[layer][0], (0, 2, 1))
    return [ha.T.tolist() for ha in att]

  def get_out_out_attention(layer):
    att = np.transpose(dec_atts[layer][0], (0, 2, 1))
    return [ha.T.tolist() for ha in att]

  def get_attentions(get_attention_fn):
    num_layers = len(enc_atts)
    return [get_attention_fn(i) for i in range(num_layers)]

  attentions = {
      'all': {
          'att': get_attentions(get_full_attention),
          'top_text': inp_text + out_text,
          'bot_text': inp_text + out_text,
      },
      'inp_inp': {
          'att': get_attentions(get_inp_inp_attention),
          'top_text': inp_text,
          'bot_text': inp_text,
      },
      'inp_out': {
          'att': get_attentions(get_out_inp_attention),
          'top_text': inp_text,
          'bot_text': out_text,
      },
      'out_out': {
          'att': get_attentions(get_out_out_attention),
          'top_text': out_text,
          'bot_text': out_text,
      },
  }

  return attentions
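The crop-and-renormalize branch of `resize` can be tried in isolation. The sketch below mirrors that branch for a single array; the helper name `crop_and_renormalize` and the random input are made up for illustration and are not part of the notebook.

```python
import numpy as np


def crop_and_renormalize(att, max_length):
  """Crop a [batch, heads, length, length] array and renormalize along axis 2."""
  # Copy so the caller's array is not modified (resize itself works in place).
  att = np.array(att[:, :, :max_length, :max_length], dtype=np.float64)
  sums = np.sum(att, axis=2)            # shape [batch, heads, max_length]
  return att / sums[:, :, np.newaxis]   # broadcast the sums back over axis 2


rng = np.random.default_rng(0)
att = rng.random((1, 2, 5, 5))  # made-up positive weights, not model output
cropped = crop_and_renormalize(att, max_length=3)
print(cropped.shape)
print(np.allclose(cropped.sum(axis=2), 1.0))
```

After cropping, the weights no longer sum to one, so dividing by the sums over axis 2 restores a normalized distribution along that axis.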

In [None]:
#@title Some cool HTML and js stuff (make sure that you run the cell)
vis_html = """
  <span style="user-select:none">
    Layer: <select id="layer"></select>
    Attention: <select id="att_type">
      <option value="all">All</option>
      <option value="inp_inp">Input - Input</option>
      <option value="inp_out">Input - Output</option>
      <option value="out_out">Output - Output</option>
    </select>
  </span>
  <div id='vis'></div>
"""
def call_html():
  """Load require.js and register the d3 and jquery paths used by vis_js."""
  display.display(display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              "d3": "https://cdnjs.cloudflare.com/ajax/libs/d3/3.5.8/d3.min",
              jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
            },
          });
        </script>
        '''))
vis_js = """
/**
 * @fileoverview Transformer Visualization D3 javascript code.
 */

requirejs(['jquery', 'd3'],
function($, d3) {

var attention = window.attention;

const TEXT_SIZE = 15;
const BOXWIDTH = TEXT_SIZE * 8;
const BOXHEIGHT = TEXT_SIZE * 1.5;
const WIDTH = 2000;
const HEIGHT = attention.all.bot_text.length * BOXHEIGHT * 2 + 100;
const MATRIX_WIDTH = 150;
const head_colours = d3.scale.category10();
const CHECKBOX_SIZE = 20;

function lighten(colour) {
  var c = d3.hsl(colour);
  var increment = (1 - c.l) * 0.6;
  c.l += increment;
  c.s -= increment;
  return c;
}

function transpose(mat) {
  return mat[0].map(function(col, i) {
    return mat.map(function(row) {
      return row[i];
    });
  });
}

function zip(a, b) {
  return a.map(function (e, i) {
    return [e, b[i]];
  });
}


function renderVis(id, top_text, bot_text, attention_heads, config) {
  $(id).empty();
  var svg = d3.select(id)
            .append('svg')
            .attr("width", WIDTH)
            .attr("height", HEIGHT);

  var att_data = [];
  for (var i=0; i < attention_heads.length; i++) {
    var att_trans = transpose(attention_heads[i]);
    att_data.push(zip(attention_heads[i], att_trans));
  }

  renderText(svg, top_text, true, att_data, 0);
  renderText(svg, bot_text, false, att_data, MATRIX_WIDTH + BOXWIDTH);

  renderAttentionHighlights(svg, att_data);

  svg.append("g").classed("attention_heads", true);

  renderAttention(svg, attention_heads);

  draw_checkboxes(config, 0, svg, attention_heads);
}


function renderText(svg, text, is_top, att_data, left_pos) {
  var id = is_top ? "top" : "bottom";
  var textContainer = svg.append("svg:g")
                         .attr("id", id);

  textContainer.append("g").classed("attention_boxes", true)
               .selectAll("g")
               .data(att_data)
               .enter()
               .append("g")
               .selectAll("rect")
               .data(function(d) {return d;})
               .enter()
               .append("rect")
               .attr("x", function(d, i, j) {
                 return left_pos + box_offset(j);
               })
               .attr("y", function(d, i) {
                 return (+1) * BOXHEIGHT;
               })
               .attr("width", BOXWIDTH/active_heads())
               .attr("height", function() { return BOXHEIGHT; })
               .attr("fill", function(d, i, j) {
                  return head_colours(j);
                })
               .style("opacity", 0.0);


  var tokenContainer = textContainer.append("g").selectAll("g")
                                    .data(text)
                                    .enter()
                                    .append("g");

  tokenContainer.append("rect")
                .classed("background", true)
                .style("opacity", 0.0)
                .attr("fill", "lightgray")
                .attr("x", left_pos)
                .attr("y", function(d, i) {
                  return (i+1) * BOXHEIGHT;
                })
                .attr("width", BOXWIDTH)
                .attr("height", BOXHEIGHT);

  var theText = tokenContainer.append("text")
                              .text(function(d) { return d; })
                              .attr("font-size", TEXT_SIZE + "px")
                              .style("cursor", "default")
                              .style("-webkit-user-select", "none")
                              .attr("x", left_pos)
                              .attr("y", function(d, i) {
                                return (i+1) * BOXHEIGHT;
                              });

  if (is_top) {
    theText.style("text-anchor", "end")
           .attr("dx", BOXWIDTH - TEXT_SIZE)
           .attr("dy", TEXT_SIZE);
  } else {
    theText.style("text-anchor", "start")
           .attr("dx", + TEXT_SIZE)
           .attr("dy", TEXT_SIZE);
  }

  tokenContainer.on("mouseover", function(d, index) {
    textContainer.selectAll(".background")
                 .style("opacity", function(d, i) {
                   return i == index ? 1.0 : 0.0;
                 });

    svg.selectAll(".attention_heads").style("display", "none");

    svg.selectAll(".line_heads")  // To get the nesting to work.
       .selectAll(".att_lines")
       .attr("stroke-opacity", function(d) {
          return 1.0;
        })
       .attr("y1", function(d, i) {
        if (is_top) {
          return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);
        } else {
          return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);
        }
     })
     .attr("x1", BOXWIDTH)
     .attr("y2", function(d, i) {
       if (is_top) {
          return (i+1) * BOXHEIGHT + (BOXHEIGHT/2);
        } else {
          return (index+1) * BOXHEIGHT + (BOXHEIGHT/2);
        }
     })
     .attr("x2", BOXWIDTH + MATRIX_WIDTH)
     .attr("stroke-width", 2)
     .attr("stroke", function(d, i, j) {
        return head_colours(j);
      })
     .attr("stroke-opacity", function(d, i, j) {
      if (is_top) {d = d[0];} else {d = d[1];}
      if (config.head_vis[j]) {
        if (d) {
          return d[index];
        } else {
          return 0.0;
        }
      } else {
        return 0.0;
      }
     });


    function updateAttentionBoxes() {
      var id = is_top ? "bottom" : "top";
      var the_left_pos = is_top ? MATRIX_WIDTH + BOXWIDTH : 0;
      svg.select("#" + id)
         .selectAll(".attention_boxes")
         .selectAll("g")
         .selectAll("rect")
         .attr("x", function(d, i, j) { return the_left_pos + box_offset(j); })
         .attr("y", function(d, i) { return (i+1) * BOXHEIGHT; })
         .attr("width", BOXWIDTH/active_heads())
         .attr("height", function() { return BOXHEIGHT; })
         .style("opacity", function(d, i, j) {
            if (is_top) {d = d[0];} else {d = d[1];}
            if (config.head_vis[j])
              if (d) {
                return d[index];
              } else {
                return 0.0;
              }
            else
              return 0.0;

         });
    }

    updateAttentionBoxes();
  });

  textContainer.on("mouseleave", function() {
    d3.select(this).selectAll(".background")
                   .style("opacity", 0.0);

    svg.selectAll(".att_lines").attr("stroke-opacity", 0.0);
    svg.selectAll(".attention_heads").style("display", "inline");
    svg.selectAll(".attention_boxes")
       .selectAll("g")
       .selectAll("rect")
       .style("opacity", 0.0);
  });
}

function renderAttentionHighlights(svg, attention) {
  var line_container = svg.append("g");
  line_container.selectAll("g")
                .data(attention)
                .enter()
                .append("g")
                .classed("line_heads", true)
                .selectAll("line")
                .data(function(d){return d;})
                .enter()
                .append("line").classed("att_lines", true);
}

function renderAttention(svg, attention_heads) {
  var line_container = svg.selectAll(".attention_heads");
  line_container.html(null);
  for(var h=0; h<attention_heads.length; h++) {
    for(var a=0; a<attention_heads[h].length; a++) {
      for(var s=0; s<attention_heads[h][a].length; s++) {
        line_container.append("line")
        .attr("y1", (s+1) * BOXHEIGHT + (BOXHEIGHT/2))
        .attr("x1", BOXWIDTH)
        .attr("y2", (a+1) * BOXHEIGHT + (BOXHEIGHT/2))
        .attr("x2", BOXWIDTH + MATRIX_WIDTH)
        .attr("stroke-width", 2)
        .attr("stroke", head_colours(h))
        .attr("stroke-opacity", function() {
          if (config.head_vis[h]) {
            return attention_heads[h][a][s]/active_heads();
          } else {
            return 0.0;
          }
        }());
      }
    }
  }
}

// Checkboxes
function box_offset(i) {
  var num_head_above = config.head_vis.reduce(
      function(acc, val, cur) {return val && cur < i ? acc + 1: acc;}, 0);
  return num_head_above*(BOXWIDTH / active_heads());
}

function active_heads() {
  return config.head_vis.reduce(function(acc, val) {
    return val ? acc + 1: acc;
  }, 0);
}

function draw_checkboxes(config, top, svg, attention_heads) {
  var checkboxContainer = svg.append("g");
  var checkbox = checkboxContainer.selectAll("rect")
                                  .data(config.head_vis)
                                  .enter()
                                  .append("rect")
                                  .attr("fill", function(d, i) {
                                    return head_colours(i);
                                  })
                                  .attr("x", function(d, i) {
                                    return (i+1) * CHECKBOX_SIZE;
                                  })
                                  .attr("y", top)
                                  .attr("width", CHECKBOX_SIZE)
                                  .attr("height", CHECKBOX_SIZE);

  function update_checkboxes() {
    checkboxContainer.selectAll("rect")
                              .data(config.head_vis)
                              .attr("fill", function(d, i) {
      var head_colour = head_colours(i);
      var colour = d ? head_colour : lighten(head_colour);
      return colour;
    });
  }

  update_checkboxes();

  checkbox.on("click", function(d, i) {
    if (config.head_vis[i] && active_heads() == 1) return;
    config.head_vis[i] = !config.head_vis[i];
    update_checkboxes();
    renderAttention(svg, attention_heads);
  });

  checkbox.on("dblclick", function(d, i) {
    // If we double click on the only active head then reset
    if (config.head_vis[i] && active_heads() == 1) {
      config.head_vis = new Array(config.num_heads).fill(true);
    } else {
      config.head_vis = new Array(config.num_heads).fill(false);
      config.head_vis[i] = true;
    }
    update_checkboxes();
    renderAttention(svg, attention_heads);
  });
}

var config = {
  layer: 0,
  att_type: 'all',
};

function visualize() {
  var num_heads = attention['all']['att'][0].length;
  config.head_vis  = new Array(num_heads).fill(true);
  config.num_heads = num_heads;
  config.attention = attention;

  render();
}

function render() {
  var conf = config.attention[config.att_type];

  var top_text = conf.top_text;
  var bot_text = conf.bot_text;
  var attention = conf.att[config.layer];

  $("#vis svg").empty();
  renderVis("#vis", top_text, bot_text, attention, config);
}

$("#layer").empty();
for(var i=0; i<attention.all.att.length; i++) {
  $("#layer").append($("<option />").val(i).text(i));
}

$("#layer").on('change', function(e) {
  config.layer = +e.currentTarget.value;
  render();
});

$("#att_type").on('change', function(e) {
  config.att_type = e.currentTarget.value;
  render();
});

$("button").on('click', visualize);

visualize();

});
"""

## 1. Run a pre-trained Transformer

* create a Transformer model in Trax with [trax.models.Transformer](https://trax-ml.readthedocs.io/en/latest/trax.models.html#trax.models.transformer.Transformer)
* initialize it from a file with pre-trained weights with [model.init_from_file](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Layer.init_from_file)
* tokenize your input sentence with [trax.data.tokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.tokenize) so it can be fed to the model
* decode from the Transformer with [trax.supervised.decoding.autoregressive_sample](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html#trax.supervised.decoding.autoregressive_sample)
* de-tokenize the decoded result to get the translation with [trax.data.detokenize](https://trax-ml.readthedocs.io/en/latest/trax.data.html#trax.data.tf_inputs.detokenize)


In [None]:
# Create a Transformer model.
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='predict')

# Initialize using pre-trained weights.
model.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                     weights_only=True)

# Tokenize a sentence.
sentence = 'It is nice to learn new things today!'
tokenized = list(trax.data.tokenize(iter([sentence]),  # Operates on streams.
                                    vocab_dir='gs://trax-ml/vocabs/',
                                    vocab_file='ende_32k.subword'))[0]

# Decode from the Transformer.
tokenized = tokenized[None, :]  # Add batch dimension.
tokenized_translation = trax.supervised.decoding.autoregressive_sample(
    model, tokenized, temperature=0.0)  # Higher temperature: more diverse results.

# De-tokenize.
tokenized_translation = tokenized_translation[0][:-1]  # Remove batch and EOS.
translation = trax.data.detokenize(tokenized_translation,
                                   vocab_dir='gs://trax-ml/vocabs/',
                                   vocab_file='ende_32k.subword')
print(translation)

Es ist schön, heute neue Dinge zu lernen!


In [None]:
tokenized, tokenized_translation

(array([[ 118,   16, 1902,    9, 3197,  141, 1059,  420,  207]]),
 array([ 168,   24, 9358,    2,  352,  367, 2427,   18, 3580,  207]))
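The `temperature` argument in the decoding call controls how random the choice of each next token is. The sketch below shows the general idea with plain NumPy on made-up logits. It is not the code behind `autoregressive_sample`, and the helper name `sample_next_token` is hypothetical.

```python
import numpy as np


def sample_next_token(logits, temperature, rng):
  """Pick a token id from unnormalized scores (illustrative only)."""
  if temperature == 0.0:
    return int(np.argmax(logits))  # no randomness: always the top score
  scaled = logits / temperature
  scaled = scaled - np.max(scaled)  # subtract the max for numerical stability
  probs = np.exp(scaled) / np.sum(np.exp(scaled))
  return int(rng.choice(len(logits), p=probs))


logits = np.array([2.0, 1.0, 0.1])  # made-up scores for a 3-token vocabulary
rng = np.random.default_rng(0)
greedy = sample_next_token(logits, 0.0, rng)
samples = [sample_next_token(logits, 1.0, rng) for _ in range(200)]
print(greedy)
print(sorted(set(samples)))
```

With a temperature of 0.0 the same input always yields the same token, which is convenient here because the attention plots then refer to a reproducible translation.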

## 2. Prepare the tokens for visualization

In [None]:
def decode(single_token):
  return trax.data.detokenize(single_token,
                              vocab_dir='gs://trax-ml/vocabs/',
                              vocab_file='ende_32k.subword')

In [None]:
def get_tokens_str(integers):
  token_strs = []
  for i in range(integers.shape[1]):
    token_strs.append(decode(integers[:,i]))
  return token_strs

In [None]:
tokenized_translation_with_start = np.array([0]+list(tokenized_translation), dtype=np.int64)
tokenized_translation_with_start = tokenized_translation_with_start[np.newaxis, ...]
tokenized_translation = np.array(tokenized_translation, dtype=np.int64)
tokenized_translation = tokenized_translation[np.newaxis, ...]

In [None]:
tokenized_str = get_tokens_str(tokenized)
tokenized_translation_str = get_tokens_str(tokenized_translation_with_start)

In [None]:
tokenized_str, tokenized_translation_str

(['It', 'is', 'nice', 'to', 'learn', 'new', 'things', 'today', '!'],
 ['<pad>',
  'Es',
  'ist',
  'schön',
  ', ',
  'heute',
  'neue',
  'Dinge',
  'zu',
  'lernen',
  '!'])

In [None]:
max_len = max(tokenized.shape[1], tokenized_translation.shape[1])

In [None]:
tokenized_translation_pad = np.zeros((1,max_len), dtype=np.int64)
tokenized_translation_pad[:,:tokenized_translation.shape[1]] = tokenized_translation

tokenized_pad = np.zeros((1,max_len), dtype=np.int64)
tokenized_pad[:,:tokenized.shape[1]] = tokenized

In [None]:
tokenized_translation_pad.shape, tokenized_pad.shape

((1, 10), (1, 10))
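The two padding cells above follow the same pattern, so they could be folded into one helper. This is a minimal sketch using the token ids printed earlier; `pad_to_length` is a hypothetical name, and the pad id of 0 matches the `np.zeros` buffers used above.

```python
import numpy as np


def pad_to_length(batch, length, pad_id=0):
  """Right-pad a [batch, n] integer array with pad_id up to [batch, length]."""
  padded = np.full((batch.shape[0], length), pad_id, dtype=np.int64)
  padded[:, :batch.shape[1]] = batch
  return padded


# Token ids copied from the outputs shown earlier in this notebook.
source = np.array([[118, 16, 1902, 9, 3197, 141, 1059, 420, 207]])
target = np.array([[168, 24, 9358, 2, 352, 367, 2427, 18, 3580, 207]])

max_len = max(source.shape[1], target.shape[1])
source_pad = pad_to_length(source, max_len)
target_pad = pad_to_length(target, max_len)
print(source_pad.shape, target_pad.shape)  # (1, 10) (1, 10)
```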

## 3. Create the same pre-trained model in the "viz" mode

In [None]:
# Create a Transformer model in the "viz" mode
# Pre-trained model config in gs://trax-ml/models/translation/ende_wmt32k.gin
model_viz = trax.models.Transformer(
    input_vocab_size=33300,
    d_model=512, d_ff=2048,
    n_heads=8, n_encoder_layers=6, n_decoder_layers=6,
    max_len=2048, mode='viz')

# Initialize using pre-trained weights.
model_viz.init_from_file('gs://trax-ml/models/translation/ende_wmt32k.pkl.gz',
                         weights_only=True)

In [None]:
# Run the viz model once; its layer states are inspected below for the attention weights.
_ = model_viz((tokenized_pad, tokenized_translation_pad))

## 4. Find the attention weights (aka dots)

In [None]:
attention_weights = []
def attention_sublayers(layer):
  if 'Attention' in layer.name:
    print("Found layer {}".format(layer.name))
    attention_weights.append(layer.state)
  if layer.sublayers:
    for sublayer in layer.sublayers:
      attention_sublayers(sublayer)

In [None]:
attention_sublayers(model_viz)

Found layer PureAttention
Found layer PureAttention
Found layer PureAttention
Found layer PureAttention
Found layer PureAttention
Found layer PureAttention
Found layer DotProductCausalAttention
Found layer PureAttention
Found layer DotProductCausalAttention
Found layer PureAttention
Found layer DotProductCausalAttention
Found layer PureAttention
Found layer DotProductCausalAttention
Found layer PureAttention
Found layer DotProductCausalAttention
Found layer PureAttention
Found layer DotProductCausalAttention
Found layer PureAttention


In [None]:
len(attention_weights)

18
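`attention_sublayers` records a layer before descending into its sublayers, so the states are collected in depth-first order. The sketch below reproduces that walk on a toy tree and returns a list instead of appending to a global. `ToyLayer` is a made-up stand-in that only mimics the `name`, `state` and `sublayers` attributes used above; it is not a Trax class.

```python
class ToyLayer:
  """Stand-in exposing name, state and sublayers (not a Trax class)."""

  def __init__(self, name, state=None, sublayers=()):
    self.name = name
    self.state = state
    self.sublayers = list(sublayers)


def collect_attention_states(layer):
  """Return (name, state) pairs of attention layers, parents before children."""
  found = []
  if 'Attention' in layer.name:
    found.append((layer.name, layer.state))
  for sublayer in layer.sublayers:
    found.extend(collect_attention_states(sublayer))
  return found


# A tiny made-up tree: one "encoder" block and one "decoder" block.
toy = ToyLayer('Serial', sublayers=[
    ToyLayer('Serial', sublayers=[ToyLayer('PureAttention', state='enc0')]),
    ToyLayer('Serial', sublayers=[
        ToyLayer('DotProductCausalAttention', state='dec0'),
        ToyLayer('PureAttention', state='encdec0'),
    ]),
])
found = collect_attention_states(toy)
print([name for name, _ in found])
# ['PureAttention', 'DotProductCausalAttention', 'PureAttention']
```

Returning the list makes the function safe to call more than once, whereas the global `attention_weights` list grows on every call unless it is reset first.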

In [None]:
# Manual identification of the layers would be difficult, so we rely on the attention_sublayers function.
enc_atts = attention_weights[:6]
dec_atts = attention_weights[6::2] # the DotProductCausalAttention layers (indices 6, 8, ..., 16)
encdec_atts = attention_weights[7::2] # the PureAttention layers that follow them (indices 7, 9, ..., 17)

# Here we use a number of Python utils inherited from tensor2tensor
enc_atts_res = resize(enc_atts)
dec_atts_res = resize(dec_atts)
encdec_atts_res = resize(encdec_atts)
attention_dict = _get_attention(tokenized_str, tokenized_translation_str, enc_atts_res, dec_atts_res, encdec_atts_res)
attention_json = json.dumps(attention_dict)
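The three slices above depend on the order in which the layers were found. As a quick sanity check, the same slices can be applied to the layer names printed by `attention_sublayers` (six `PureAttention` entries, then six alternating pairs); the list below is typed in by hand from that output.

```python
# Layer names in the order printed by attention_sublayers above.
names = (['PureAttention'] * 6 +
         ['DotProductCausalAttention', 'PureAttention'] * 6)

enc = names[:6]       # encoder self-attention
dec = names[6::2]     # decoder self-attention
encdec = names[7::2]  # encoder-decoder attention

print(len(names), len(enc), len(dec), len(encdec))  # 18 6 6 6
print(set(enc), set(dec), set(encdec))
```

Each slice contains a single layer type, which is what the slicing assumes; a model with a different layer count would need different slice boundaries.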

## 5. Display attention

In [None]:
call_html()
display.display(display.HTML(vis_html))
display.display(display.Javascript('window.attention = %s' % attention_json))
display.display(display.Javascript(vis_js))
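If the visualization renders empty or misaligned, it can help to look at the sizes inside the JSON payload before handing it to the browser. The sketch below summarizes a payload with the same four keys that `_get_attention` returns. The helper names `summarize_payload` and `uniform` and the dummy data are hypothetical, and the sizes are only reported, not validated.

```python
import json


def summarize_payload(attention):
  """Map each attention type to the sizes of its matrices and token lists."""
  summary = {}
  for att_type, entry in attention.items():
    first_head = entry['att'][0][0]  # first layer, first head
    summary[att_type] = {
        'layers': len(entry['att']),
        'heads': len(entry['att'][0]),
        'rows': len(first_head),
        'cols': len(first_head[0]),
        'top': len(entry['top_text']),
        'bot': len(entry['bot_text']),
    }
  return summary


def uniform(layers, heads, rows, cols):
  """Nested lists of equal weights, used only as placeholder data."""
  return [[[[1.0 / cols] * cols for _ in range(rows)]
           for _ in range(heads)] for _ in range(layers)]


inp, out = ['a', 'b'], ['<pad>', 'x', 'y']  # made-up tokens
dummy = {
    'all': {'att': uniform(1, 2, 5, 5),
            'top_text': inp + out, 'bot_text': inp + out},
    'inp_inp': {'att': uniform(1, 2, 2, 2), 'top_text': inp, 'bot_text': inp},
    'inp_out': {'att': uniform(1, 2, 3, 2), 'top_text': inp, 'bot_text': out},
    'out_out': {'att': uniform(1, 2, 3, 3), 'top_text': out, 'bot_text': out},
}
payload = json.loads(json.dumps(dummy))  # round-trip through JSON
summary = summarize_payload(payload)
print(summary['inp_out'])
```

In the notebook, the same summary could be obtained with `summarize_payload(json.loads(attention_json))` and compared by eye against the token lists from section 2.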
