In [None]:
import json
import warnings
from pathlib import Path

import altair as alt
import numpy as np
import pandas as pd

In [None]:
def read_data(encoder_file, decoder_file, data_dir='data'):
    '''
    Load the encoder and decoder JSON dumps for one sentence.

    Input:
        - encoder_file - file name of the encoder dump, relative to `data_dir` [string]
        - decoder_file - file name of the decoder dump, relative to `data_dir` [string]
        - data_dir     - folder holding both files; defaults to 'data', the location
                         that used to be hard-coded [string or path-like]

    Output:
        - encoder, decoder - the parsed JSON content of each file
    '''
    folder = Path(data_dir)
    # Explicit UTF-8 so loading does not depend on the platform's default encoding.
    with open(folder / encoder_file, encoding='utf-8') as f:
        encoder = json.load(f)
    with open(folder / decoder_file, encoding='utf-8') as f:
        decoder = json.load(f)
    return encoder, decoder

# [PAPER] Attention evolution

In [None]:
'''
Tinput  = Input tokens (English)
Toutput = Output tokens (French)
Ba      = Batch size
Be      = Beam dimension
S       = Nr of states
E       = Embedding dimensions
H       = Heads
'''

tag = 'carrot'

encoder_carrot, decoder_carrot = read_data('EncoderOutWrite_carrot.json', 'DecoderOutWrite_carrot.json')

# Encoder tensors (shapes as recorded for this sentence):
#   encoder_out        9 x 1 x 1024        Tinput x Ba x E
#   encoder_embedding  1 x 9 x 1024        Ba x Tinput x E
#   encoder_states     6 x 9 x 1 x 1024    S x Tinput x Ba x E
encoder_out_carrot, encoder_embedding_carrot, encoder_states_carrot = (
    np.array(encoder_carrot[key])
    for key in ('encoder_out', 'encoder_embedding', 'encoder_states')
)

# Decoder tensors (shapes as recorded for this sentence):
#   decoder_out   9 x 1 x 44512         recorded as Tinput x Be x E
#                                       NOTE(review): 44512 looks like a vocabulary size, not E — confirm
#   attn          1 x 1 x 9             Be x Ba x Tinput
#   inner_states  7 x 1 x 1 x 1024      S x Ba x Be x E
#   my_attn       8 x 16 x 1 x 1 x 9    Toutput x H x Ba x Be x Tinput
decoder_out_carrot, decoder_attn_carrot, decoder_inner_states_carrot, decoder_full_attn_carrot = (
    np.array(decoder_carrot[key])
    for key in ('decoder_out', 'attn', 'inner_states', 'my_attn')
)

# Sub-word tokens including the end-of-sentence marker.
english_carrot = 'My m um e ats a car rot <eos>'.split(' ')   # input tokens
french_carrot = 'Ma mère m ange une car otte <eos>'.split(' ')  # output tokens

In [None]:
# Display fonts
font_size, title_size, legend_size = 12, 16, 16

# Data preparation: average the decoder attention over the heads axis (axis 1),
# keeping batch 0 / beam 0 -> one row per French token, one column per English token.
attn_carrot = decoder_full_attn_carrot[:, :, 0, 0, :].mean(axis=1)
x, y = np.meshgrid(np.arange(len(french_carrot)), np.arange(len(english_carrot)))
z = attn_carrot.T  # rows now follow the English (y) tokens

source = pd.DataFrame({
    'x': x.ravel(),
    'French': np.asarray(french_carrot)[x.ravel()],
    'y': y.ravel(),
    'English': np.asarray(english_carrot)[y.ravel()],
    'Attention': z.ravel(),
    'tag': tag,
})

# Plot: full heatmap, including the <eos> row and column.
first = alt.Chart(source).mark_rect().encode(
    x=alt.X(
        'French:O',
        sort=list(french_carrot),
        axis=alt.Axis(labelFontSize=font_size, titleFontSize=title_size, title="French token"),
    ),
    y=alt.Y(
        'English:O',
        sort=list(english_carrot),
        axis=alt.Axis(labelFontSize=font_size, titleFontSize=title_size, title="English token"),
    ),
    color=alt.Color('Attention'),
    tooltip=['Attention'],  # shown on hover
).properties(width=400, height=400).interactive()

In [None]:
# Data preparation: same carrot heatmap, but with the trailing <eos> row and
# column removed from both the token grid and the attention matrix.
x, y = np.meshgrid(np.arange(len(french_carrot) - 1), np.arange(len(english_carrot) - 1))
z = attn_carrot[:-1, :-1].T

source = pd.DataFrame({
    'x': x.ravel(),
    'French': np.asarray(french_carrot)[x.ravel()],
    'y': y.ravel(),
    'English': np.asarray(english_carrot)[y.ravel()],
    'Attention': z.ravel(),
    'tag': tag,
})

# Plot
second = alt.Chart(source).mark_rect().encode(
    x=alt.X(
        'French:O',
        sort=list(french_carrot),
        axis=alt.Axis(labelFontSize=font_size, titleFontSize=title_size, title="French token"),
    ),
    y=alt.Y(
        'English:O',
        sort=list(english_carrot),
        axis=alt.Axis(labelFontSize=font_size, titleFontSize=title_size, title="English token"),
    ),
    color=alt.Color('Attention'),
    tooltip=['Attention'],  # shown on hover
).properties(width=400, height=400).interactive()

In [None]:
# Data preparation: <eos>-free carrot heatmap, this time drawn with the magma scheme.
x, y = np.meshgrid(np.arange(len(french_carrot) - 1), np.arange(len(english_carrot) - 1))
z = attn_carrot[:-1, :-1].T

source = pd.DataFrame({
    'x': x.ravel(),
    'French': np.asarray(french_carrot)[x.ravel()],
    'y': y.ravel(),
    'English': np.asarray(english_carrot)[y.ravel()],
    'Attention': z.ravel(),
    'tag': tag,
})

# Plot
third = alt.Chart(source).mark_rect().encode(
    x=alt.X(
        'French:O',
        sort=list(french_carrot),
        axis=alt.Axis(labelFontSize=font_size, titleFontSize=title_size, title="French token"),
    ),
    y=alt.Y(
        'English:O',
        sort=list(english_carrot),
        axis=alt.Axis(labelFontSize=font_size, titleFontSize=title_size, title="English token"),
    ),
    color=alt.Color('Attention', scale=alt.Scale(scheme='magma')),
    tooltip=['Attention'],  # shown on hover
).properties(width=400, height=400).interactive()

In [None]:
# Merge all plots: stack the three carrot heatmaps vertically (`&`), give each
# panel its own colour scale instead of a shared one, and put legends on the right.
(first & second & third
).resolve_scale(
    color='independent'
).configure_legend(
    titleFontSize=legend_size,
    orient='right'
) 

# [ATTENTION] Attention Plots

In [None]:
# Lift Altair's default row limit on embedded data; the concatenated
# per-sentence attention frames built below may exceed it.
alt.data_transformers.disable_max_rows()

In [None]:
def attention_plotter(english_tokens, french_tokens, attention_matrix, tag, strip_eos=True):
    '''
    Build a tidy DataFrame and an Altair heatmap from one attention matrix.

    Input:
        - english_tokens - the input tokens, drawn on the y axis [list]
        - french_tokens  - the output tokens, drawn on the x axis [list]
                           (for encoder self-attention, pass the English tokens again)
        - attention_matrix - the attention matrix (either from encoder or decoder), with one
                             row per output (x) token and one column per input (y) token [np.array]
        - tag - a tag / word to identify the dataset and sentence [string]
        - strip_eos - drop the last row and column of `attention_matrix` (the <eos> token)
                      before plotting; the token lists must then exclude <eos> [bool, default True]

    Output:
        - source - the processed dataset [pd.DataFrame]
        - chart  - the resulting plot [alt.Chart]

    Raises:
        - ValueError if the (optionally trimmed) matrix is not
          len(french_tokens) x len(english_tokens).
    '''
    matrix = np.asarray(attention_matrix)
    if strip_eos:
        matrix = matrix[:-1, :-1]
    # Transpose so rows follow the English (y) tokens and columns the French (x) tokens.
    z = np.transpose(matrix)

    if z.shape != (len(english_tokens), len(french_tokens)):
        # Without this check a mismatch either fails deep inside pandas or, when the
        # sizes happen to coincide, silently pairs attention values with the wrong tokens.
        raise ValueError(
            f"attention matrix for '{tag}' has shape {matrix.shape}"
            f"{' after removing <eos>' if strip_eos else ''}, expected "
            f"{(len(french_tokens), len(english_tokens))} (French tokens x English tokens)")

    for label, tokens in (('English', english_tokens), ('French', french_tokens)):
        if len(set(tokens)) != len(tokens):
            # Tokens are nominal axis categories: repeated tokens collapse into a single
            # row/column and their cells are drawn on top of each other.
            warnings.warn(f"{label} tokens for '{tag}' contain duplicates; "
                          "their heatmap cells will overlap")

    # Data Preparation
    x, y = np.meshgrid(range(len(french_tokens)), range(len(english_tokens)))
    source = pd.DataFrame({'x': x.ravel(),
                           'French': np.array([french_tokens[i] for i in x.ravel()]),
                           'y': y.ravel(),
                           'English': np.array([english_tokens[i] for i in y.ravel()]),
                           'Attention': z.ravel(),
                           'tag': tag})

    # Plot: axes keep the token order given by the caller.
    chart = alt.Chart(source).mark_rect().encode(
        x=alt.X('French:N', sort=list(french_tokens),
                axis=alt.Axis(title="French token")),
        y=alt.Y('English:N', sort=list(english_tokens),
                axis=alt.Axis(title="English token")),
        color=alt.Color('Attention',
                        scale=alt.Scale(scheme='magma')),
        tooltip=['Attention']
    ).properties(
        width=400,
        height=400
    ).interactive()
    return source, chart

In [None]:
# Carrot tokens without the trailing <eos>: attention_plotter removes the
# <eos> row/column from the attention matrix itself.
english_carrot = ['My', 'm', 'um', 'e', 'ats', 'a', 'car', 'rot']
french_carrot = ['Ma', 'mère', 'm', 'ange', 'une', 'car', 'otte']

In [None]:
# Encoder self-attention for the carrot sentence: keep the last entry of the
# leading axis and index 0 of the second axis, leaving a 2-D matrix.
# NOTE(review): presumably the leading axis is the layer and the second the batch — confirm.
encoder_carrot_attn = np.array(encoder_carrot['encoder_attn'])
encoder_carrot_attn = encoder_carrot_attn[-1, 0, :, :]

# Self-attention: English tokens on both axes.
source_encoder_carrot, chart_encoder_carrot = attention_plotter(english_carrot, english_carrot, encoder_carrot_attn, 'carrot')
#chart_encoder_carrot

In [None]:
# Encoder-decoder attention averaged over heads (axis 1), batch 0 / beam 0.
attention_decoder_carrot = decoder_full_attn_carrot[:, :, 0, 0, :].mean(axis=1)

source_decoder_carrot, chart_decoder_carrot = attention_plotter(
    english_carrot, french_carrot, attention_decoder_carrot, 'carrot')
#chart_decoder_carrot

# DRESS

Je veux juste une robe avec des manches bouffantes

All I want is a dress with puffy sleaves

## encoder

In [None]:
encoder_dress, decoder_dress = read_data('EncoderOutWrite_dress.json', 'DecoderOutWrite_dress.json')

# Encoder tensors
encoder_out_dress, encoder_embedding_dress, encoder_states_dress = (
    np.array(encoder_dress[key])
    for key in ('encoder_out', 'encoder_embedding', 'encoder_states')
)

# Decoder tensors ('my_attn' holds the per-head encoder-decoder attention)
decoder_out_dress, decoder_attn_dress, decoder_inner_states_dress, decoder_full_attn_dress = (
    np.array(decoder_dress[key])
    for key in ('decoder_out', 'attn', 'inner_states', 'my_attn')
)

# Sub-word tokens without <eos>, as attention_plotter expects.
english_dress = 'All I want is a d ress with pu ff y sle aves'.split(' ')
french_dress = 'Je veux seulement une ro be à man ches bou ff antes'.split(' ')

In [None]:
# Encoder self-attention: last entry of the leading axis, index 0 of the second.
encoder_dress_attn = np.array(encoder_dress['encoder_attn'])[-1, 0, :, :]

source_encoder_dress, chart_encoder_dress = attention_plotter(
    english_dress, english_dress, encoder_dress_attn, 'dress')
#chart_encoder_dress

## decoder

In [None]:
# Encoder-decoder attention averaged over heads (axis 1), batch 0 / beam 0.
attention_decoder_dress = decoder_full_attn_dress[:, :, 0, 0, :].mean(axis=1)

source_decoder_dress, chart_decoder_dress = attention_plotter(
    english_dress, french_dress, attention_decoder_dress, 'dress')
#chart_decoder_dress

# DAY

Tomorrow is a new day with no mistakes in it... yet

Demain est une nouvelle journée sans erreur... pour l'instant

## encoder

In [None]:
encoder_day, decoder_day = read_data('EncoderOutWrite_day.json', 'DecoderOutWrite_day.json')

# Encoder tensors
encoder_out_day, encoder_embedding_day, encoder_states_day = (
    np.array(encoder_day[key])
    for key in ('encoder_out', 'encoder_embedding', 'encoder_states')
)

# Decoder tensors ('my_attn' holds the per-head encoder-decoder attention)
decoder_out_day, decoder_attn_day, decoder_inner_states_day, decoder_full_attn_day = (
    np.array(decoder_day[key])
    for key in ('decoder_out', 'attn', 'inner_states', 'my_attn')
)

# Sub-word tokens without <eos>, as attention_plotter expects.
english_day = "Tom orrow is a new day with no mistakes in it ... yet".split(' ')
french_day = "Dem ain est une nouvelle journée sans erreur ... pour l' instant".split(' ')

In [None]:
# Encoder self-attention: last entry of the leading axis, index 0 of the second.
encoder_day_attn = np.array(encoder_day['encoder_attn'])[-1, 0, :, :]

source_encoder_day, chart_encoder_day = attention_plotter(
    english_day, english_day, encoder_day_attn, 'day')
#chart_encoder_day

## decoder

In [None]:
# Encoder-decoder attention averaged over heads (axis 1), batch 0 / beam 0.
attention_day = decoder_full_attn_day[:, :, 0, 0, :].mean(axis=1)

source_decoder_day, chart_decoder_day = attention_plotter(
    english_day, french_day, attention_day, 'day')
#chart_decoder_day

# JOURNALISM

## decoder

In [None]:
encoder_journalism, decoder_journalism = read_data('EncoderOutWrite_journalism.json', 'DecoderOutWrite_journalism.json')

# Only the decoder's per-head attention is extracted for this sentence;
# average it over heads (axis 1), keeping batch 0 / beam 0.
decoder_full_attn_journalism = np.array(decoder_journalism['my_attn'])
attention_journalism = decoder_full_attn_journalism[:, :, 0, 0, :].mean(axis=1)

In [None]:
# Sub-word tokens without <eos>. The trailing underscore in "les_" keeps it
# distinct from the earlier "les" on the nominal axis.
english_journalism = "sound journ- alism must defend the vo- ic- eless , not send them further into silence".split(' ')
french_journalism = "un journ- alisme sérieux doit défendre les sans - voix , ne pas les_ envoyer dans le silence".split(' ')

source_decoder_journalism, chart_decoder_journalism = attention_plotter(
    english_journalism, french_journalism, attention_journalism, 'journalism')
#chart_decoder_journalism

## encoder

In [None]:
# Encoder self-attention: last entry of the leading axis, index 0 of the second.
encoder_journalism_attn = np.array(encoder_journalism['encoder_attn'])[-1, 0, :, :]

In [None]:
# Self-attention heatmap: English tokens on both axes.
source_encoder_journalism, chart_encoder_journalism = attention_plotter(english_journalism, english_journalism, encoder_journalism_attn, 'journalism')
chart_encoder_journalism

# BEE
If I wasn't a human girl I think I'd like to be a bee and live among the flowers

Si je n'étais pas une fille humaine, je pense que j'aimerais être une abeille et vivre parmi les fleurs.

## encoder

In [None]:
encoder_bee, decoder_bee = read_data('EncoderOutWrite_bee.json', 'DecoderOutWrite_bee.json')

# Encoder tensors
encoder_out_bee = np.array(encoder_bee['encoder_out'])
encoder_embedding_bee = np.array(encoder_bee['encoder_embedding'])
encoder_states_bee = np.array(encoder_bee['encoder_states'])

# Decoder tensors ('my_attn' holds the per-head encoder-decoder attention)
decoder_out_bee = np.array(decoder_bee['decoder_out'])
decoder_attn_bee = np.array(decoder_bee['attn'])
decoder_inner_states_bee = np.array(decoder_bee['inner_states'])
decoder_full_attn_bee = np.array(decoder_bee['my_attn'])

# Sub-word tokens without <eos>. Tokens are nominal axis categories, so a repeated
# token gets an underscore to keep its own row/column (otherwise cells overlap).
english_bee = "If I was n 't a human girl I_ think _I 'd like to be a_ be_ e and live among the flowers".split(' ')
# "une_" keeps the second "une" distinct from the first; "j'" replaces the
# HTML-escaped "j&apos;", which would otherwise be shown literally as a label.
french_bee = "Si je n' étais pas une fille humaine , je_ pense que j' aimerais être une_ abe ille et vivre parmi les fleurs .".split(' ')

In [None]:
# Encoder self-attention: last entry of the leading axis, index 0 of the second.
encoder_bee_attn = np.array(encoder_bee['encoder_attn'])[-1, 0, :, :]

source_encoder_bee, chart_encoder_bee = attention_plotter(
    english_bee, english_bee, encoder_bee_attn, 'bee')
chart_encoder_bee

## decoder

In [None]:
# Encoder-decoder attention averaged over heads (axis 1), batch 0 / beam 0.
attention_decoder_bee = decoder_full_attn_bee[:, :, 0, 0, :].mean(axis=1)

source_decoder_bee, chart_decoder_bee = attention_plotter(
    english_bee, french_bee, attention_decoder_bee, 'bee')
#chart_decoder_bee

# WORLD

In [None]:
encoder_world, decoder_world = read_data('EncoderOutWrite_world.json', 'DecoderOutWrite_world.json')

# Encoder tensors
encoder_out_world = np.array(encoder_world['encoder_out'])
encoder_embedding_world = np.array(encoder_world['encoder_embedding'])
encoder_states_world = np.array(encoder_world['encoder_states'])

# Decoder tensors ('my_attn' holds the per-head encoder-decoder attention)
decoder_out_world = np.array(decoder_world['decoder_out'])
decoder_attn_world = np.array(decoder_world['attn'])
decoder_inner_states_world = np.array(decoder_world['inner_states'])
decoder_full_attn_world = np.array(decoder_world['my_attn'])

# Sub-word tokens without <eos>. Tokens are nominal axis categories, so a repeated
# token gets a trailing underscore to keep its own row/column: "is_" in English
# and "est_" in French (the duplicate "est" previously collapsed two columns).
english_world = 'life is short and the world is_ wide'.split(' ')
french_world = 'la vie est courte et le monde est_ vaste'.split(' ')


## encoder

In [None]:
# Encoder self-attention: last entry of the leading axis, index 0 of the second.
encoder_world_attn = np.array(encoder_world['encoder_attn'])[-1, 0, :, :]

source_encoder_world, chart_encoder_world = attention_plotter(
    english_world, english_world, encoder_world_attn, 'world')
#chart_encoder_world

## decoder

In [None]:
# Encoder-decoder attention averaged over heads (axis 1), batch 0 / beam 0.
attention_decoder_world = decoder_full_attn_world[:, :, 0, 0, :].mean(axis=1)

source_decoder_world, chart_decoder_world = attention_plotter(
    english_world, french_world, attention_decoder_world, 'world')
chart_decoder_world

# ALL

In [None]:
# Stack the per-sentence decoder frames into one table for the dropdown chart.
# ignore_index gives a clean 0..n-1 index instead of repeating each frame's labels.
frames_decoder = [source_decoder_carrot, source_decoder_dress, source_decoder_journalism, source_decoder_day,
                  source_decoder_bee, source_decoder_world]
source_decoder = pd.concat(frames_decoder, ignore_index=True)
# Global row position; the combined chart sorts both axes by this field.
source_decoder['idx'] = np.arange(len(source_decoder))

In [None]:
# Dropdown that filters the combined charts to one sentence (matched on the 'tag' column).
input_dropdown = alt.binding_select(options=['carrot', 'dress', 'journalism', 'day', 'bee', 'world'])
selection = alt.selection_single(fields=['tag'], bind=input_dropdown, name='sentence', init={'tag': 'carrot'})

# Encoder-decoder attention heatmap for the selected sentence; both axes keep
# token order via the global 'idx' column.
att_dec = alt.Chart(source_decoder).mark_rect().encode(
    x=alt.X('French:N', sort=alt.SortField("idx", order="ascending"), axis=alt.Axis(
        minExtent=59,
        orient="top",
        title="French token",
        titleAnchor="start",
        titlePadding=-50,
        tickWidth=0,
        offset=64,
        labelAlign="left",
        labelPadding=-65)),
    y=alt.Y('English:N', sort=alt.SortField('idx', order="ascending"), axis=alt.Axis(
        offset=54,
        orient='right',
        tickBand="extent",
        title="English token",
        titlePadding=0,
        tickWidth=0,
        labelPadding=-57,
        labelAlign="left",
        titleAnchor="start",
        titleY=316)),
    color=alt.Color('Attention', scale=alt.Scale(scheme='magma'), title="Attention"),
    tooltip=['Attention', 'English', 'French']
).properties(
    width=400,
    height=400,
    title=alt.TitleParams(
        text='ENCODER-DECODER ATTENTION',
        # Boldness comes from fontWeight; 'bold' is not a valid CSS font-style
        # value, so the former fontStyle='bold' only produced an invalid font spec.
        fontWeight='bold',
        font='Raleway, sans-serif',
        anchor='start',
        dx=0,
        dy=30)
).interactive(
).add_selection(
    selection
).transform_filter(
    selection
)

#att_dec

In [None]:
# Stack the per-sentence encoder self-attention frames into one table.
# ignore_index gives a clean 0..n-1 index instead of repeating each frame's labels.
frames_encoder = [source_encoder_carrot, source_encoder_dress, source_encoder_journalism, source_encoder_day,
                  source_encoder_bee, source_encoder_world]
source_encoder = pd.concat(frames_encoder, ignore_index=True)
# Global row position; the combined chart sorts both axes by this field.
source_encoder['idx'] = np.arange(len(source_encoder))

In [None]:
# Encoder self-attention heatmap for the sentence picked in the dropdown. The encoder
# frames were built with English tokens on both axes, so the 'French' column here
# holds English tokens too.
att_enc = alt.Chart(source_encoder).mark_rect().encode(
    x=alt.X('French:N', sort=alt.SortField("idx", order="ascending"),
            axis=alt.Axis(
                minExtent=59,
                orient="top",
                title='English token',
                titleAnchor="start",
                titlePadding=-50,
                tickWidth=0,
                offset=64,
                labelAlign="left",
                labelPadding=-65)),
    y=alt.Y('English:N', sort=alt.SortField('idx', order="ascending"),
            axis=alt.Axis(
                offset=54,
                tickBand="extent",
                title="English token",
                tickWidth=0,
                labelPadding=-54,
                labelAlign="right",
                titleAnchor="end",
                titleY=316)),
    color=alt.Color('Attention', scale=alt.Scale(scheme='magma'), title="Attention"),
    tooltip=['Attention', 'English']
).properties(
    width=400,
    height=400,
    title=alt.TitleParams(
        text='SELF-ATTENTION',
        # Boldness comes from fontWeight; 'bold' is not a valid CSS font-style
        # value, so the former fontStyle='bold' only produced an invalid font spec.
        fontWeight=900,
        anchor='start',
        dx=73,
        dy=30,
    ),
).interactive(
).add_selection(
    selection
).transform_filter(
    selection
)

In [None]:
# Place the self-attention and encoder-decoder charts side by side (`|`); both are
# driven by the same `selection` dropdown. Legends go below the charts.
attention_full = (att_enc | att_dec).configure_legend(
    orient='bottom'
)
attention_full

In [None]:
# Set to True to export the combined chart as a standalone HTML file.
SAVE_ORIGINAL_HTML = False
if SAVE_ORIGINAL_HTML:
    attention_full.save('attention_original.html')

# Raleway font adaptation

This option was dropped from the final interactive document because the font did not render correctly when the chart was saved.

In [None]:
%%html
<!-- Load the Raleway web font so charts rendered in this notebook can use it. -->
<style>
@import url('https://fonts.googleapis.com/css?family=Raleway');
</style>

In [None]:
def raleway(font="Raleway, sans-serif"):
    '''
    Altair theme that switches chart text to the given font (Raleway by default).

    Input:
        - font - CSS font-family string applied to titles, axes, headers and legends [string]

    Output:
        - a Vega-Lite config dictionary in the form passed to `alt.themes.register` [dict]
    '''
    return {
        "config": {
            # fontWeight (not fontStyle) controls boldness; 'bold' is not a valid font-style.
            "title": {"font": font, "fontWeight": "bold"},
            "axis": {
                "labelFont": font,
                "titleFont": font,
            },
            # X-axis-only settings belong in the "axisX" config block; an "xaxis"
            # key nested inside "axis" is not a Vega-Lite property and was ignored.
            "axisX": {"minExtent": 100},
            "header": {
                "labelFont": font,
                "titleFont": font,
                "titleFontWeight": "bold",
            },
            "legend": {
                "labelFont": font,
                "titleFont": font,
                "titleFontWeight": "bold",
            },
        },
    }

In [None]:
# Register the custom theme under a chosen name,
alt.themes.register('my_theme', raleway)
# then enable the newly registered theme for the charts rendered from here on.
alt.themes.enable('my_theme')

In [None]:
# Dropdown for the Raleway version: offers every sentence except 'carrot' and starts on 'dress'.
input_dropdown = alt.binding_select(options=['dress', 'journalism', 'day', 'bee', 'world'])
selection = alt.selection_single(fields=['tag'], bind=input_dropdown, name='sentence', init={'tag': 'dress'})

# Encoder-decoder attention heatmap for the selected sentence; both axes keep
# token order via the global 'idx' column.
att_dec = alt.Chart(source_decoder).mark_rect().encode(
    x=alt.X('French:N', sort=alt.SortField("idx", order="ascending"), axis=alt.Axis(
        minExtent=59,
        orient="top",
        title="French token",
        titleAnchor="start",
        titlePadding=-50,
        tickWidth=0,
        offset=58,
        labelAlign="left",
        labelPadding=-60)),
    y=alt.Y('English:N', sort=alt.SortField('idx', order="ascending"), axis=alt.Axis(
        offset=56,
        orient='right',
        tickBand="extent",
        title="English token",
        titlePadding=0,
        tickWidth=0,
        labelPadding=-58,
        labelAlign="left",
        titleAnchor="start",
        titleY=329)),
    color=alt.Color('Attention', scale=alt.Scale(scheme='magma'), title="Attention"),
    tooltip=['Attention', 'English', 'French']
).properties(
    width=400,
    height=400,
    title=alt.TitleParams(
        text='ENCODER-DECODER ATTENTION',
        # Boldness comes from fontWeight; 'bold' is not a valid CSS font-style
        # value, so the former fontStyle='bold' only produced an invalid font spec.
        fontWeight='bold',
        font='Raleway, sans-serif',
        anchor='start',
        dx=0,
        dy=30,
    )
).interactive(
).add_selection(
    selection
).transform_filter(
    selection
)

In [None]:
# Encoder self-attention heatmap (Raleway version). The encoder frames were built
# with English tokens on both axes, so the 'French' column holds English tokens and
# the x axis is titled accordingly (it was previously mislabelled "French token").
att_enc = alt.Chart(source_encoder).mark_rect().encode(
    x=alt.X('French:N', sort=alt.SortField("idx", order="ascending"),
            axis=alt.Axis(
                minExtent=59,
                orient="top",
                title='English token',
                titleAnchor="start",
                titlePadding=-50,
                tickWidth=0,
                offset=58,
                labelAlign="left",
                labelPadding=-60)),
    y=alt.Y('English:N', sort=alt.SortField('idx', order="ascending"),
            axis=alt.Axis(
                offset=54,
                tickBand="extent",
                title="English token",
                tickWidth=0,
                labelPadding=-54,
                labelAlign="right",
                titleAnchor="end",
                titleY=329)),
    color=alt.Color('Attention', scale=alt.Scale(scheme='magma'), title="Attention"),
    tooltip=['Attention', 'English', 'French']
).properties(
    width=400,
    height=400,
    title=alt.TitleParams(
        text='SELF-ATTENTION',
        # Boldness comes from fontWeight; 'bold' is not a valid CSS font-style
        # value, so the former fontStyle='bold' only produced an invalid font spec.
        fontWeight=900,
        anchor='start',
        dx=73,
        dy=30,
    ),
).interactive(
).add_selection(
    selection
).transform_filter(
    selection
)
# The combined chart is assembled (once) in the next cell.

In [None]:
# Place the self-attention and encoder-decoder charts built above side by side (`|`),
# with legends below, and display the result.
attention_full = (att_enc | att_dec).configure_legend(
    orient='bottom'
)

attention_full

In [None]:
# Set to True to export the Raleway-styled combined chart as a standalone HTML file.
SAVE_RALEWAY_HTML = False
if SAVE_RALEWAY_HTML:
    attention_full.save('attention_raleway.html')
