In [1]:
from bokeh.io import push_notebook, show, output_notebook
from bokeh.layouts import row
from bokeh.plotting import figure
from bokeh.models import LinearAxis, Range1d
output_notebook()

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:

def plot_head_pruning(perfs, times, title='', ratio_steps=10,
                      wer_range=(0, 0.5), time_range=(0, 100)):
    """Plot WER and time against the ratio of pruned heads, then show the figure.

    WER is drawn as a red line on the left y-axis (labelled 'WER'); time is
    drawn on a secondary right-hand y-axis labelled 'Time (sec)'.

    The i-th value of each series is placed at x = i / ratio_steps. Each series
    gets x-coordinates sized to its own length, so a series with one more value
    than the other simply extends one step further along x instead of handing
    Bokeh a line whose x and y columns have different lengths.

    Args:
        perfs: WER values, one per pruning ratio.
        times: times in seconds, one per pruning ratio.
        title: figure title.
        ratio_steps: number of x steps per unit pruning ratio; the default 10
            gives x = 0.0, 0.1, 0.2, ...
        wer_range: (start, end) of the WER axis. WER values above `end` fall
            outside the visible plot area.
        time_range: (start, end) of the secondary time axis.

    Returns:
        None. The figure is displayed via bokeh's ``show``.
    """
    plot = figure(
        y_range=list(wer_range),
        x_axis_label='Ratio of pruned heads',
        y_axis_label='WER',
        width=800, height=400,
        title=title
    )
    # Register the named secondary range, then attach the right-hand axis that refers to it.
    plot.extra_y_ranges = {"time": Range1d(start=time_range[0], end=time_range[1])}
    plot.add_layout(LinearAxis(y_range_name="time", axis_label='Time (sec)'), 'right')

    def ratios_for(values):
        # Integer division by ratio_steps (rather than i * 0.1) keeps x values
        # exactly 0.0, 0.1, 0.2, ... as produced by r / 10.
        return [i / ratio_steps for i in range(len(values))]

    plot.line(ratios_for(perfs), perfs, legend_label='WER', line_width=2, color='red')
    plot.line(ratios_for(times), times, legend_label='Time', line_width=2, y_range_name='time')

    show(plot)

In [4]:
# Head-pruning sweeps, one entry per decoding configuration:
# (figure title, WER per pruning ratio, time in seconds per pruning ratio).
head_pruning_runs = [
    # beam_size=5
    # NOTE(review): 11 timing values vs 10 WER values (90.66 appears twice in a row) — verify the lists line up.
    (
        'Beam size: 5',
        [0.032, 0.032, 0.035, 0.038, 0.046, 0.062, 0.103, 0.165, 0.598, 0.884],
        [91.31, 90.66, 90.66, 94.36, 87.88, 78.45, 78.87, 79.71, 74.91, 56.98, 28.19],
    ),
    # beam_size=3
    (
        'Beam size: 3',
        [
            0.03408399269628728, 0.034388314059646985, 0.03638542300669507, 0.040227480219111385, 0.04842513694461351,
            0.06516281192939745, 0.10649345709068776, 0.17121880706025563, 0.5995701460742544, 0.8795458003651856,
        ],
        [69.9738, 69.5449, 67.6510, 60.0101, 61.1342, 62.1088, 59.7755, 62.6023, 36.6256, 27.8764],
    ),
    # beam_size=1
    (
        'Beam size: 1',
        [
            0.046351947656725505, 0.045096622032866705, 0.047150791235544734, 0.05137325015216068, 0.06480143031040779,
            0.08627510651247718, 0.14246043822276325, 0.21918746195982958, 0.6591410529519173, 0.8765025867315885,
        ],
        [44.4279, 44.8921, 43.1760, 38.7973, 39.5165, 40.1261, 37.5563, 39.2884, 21.2038, 16.1966],
    ),
    # beam_size=1, lm_weight=0
    (
        'Beam size: 1, no LM',
        [
            0.04456405964698722, 0.04481132075471698, 0.046998630553864884, 0.05243837492391966, 0.06383140596469872,
            0.08882379793061473, 0.15029671332927572, 0.23767498478393184, 0.6199786975045648, 0.8607919963481436,
        ],
        [
            21.681570053100586, 20.974687576293945, 19.74081039428711, 17.510927200317383, 17.081056594848633,
            17.590673446655273, 15.915518760681152, 17.284034729003906, 9.869168281555176, 9.287040710449219,
        ],
    ),
    (
        'Beam size: 1, no LM, None',
        [
            0.04456405964698722, 0.04465916007303713, 0.04549604382227632, 0.04762629336579428, 0.054245283018867926,
            0.07788724893487523, 0.1262363055386488, 0.34211427267194155, 0.6326650943396226, 0.8607919963481436,
        ],
        [
            20.26384735107422, 21.20014190673828, 19.76799774169922, 17.644161224365234, 17.260480880737305,
            18.03213119506836, 16.215112686157227, 15.384403228759766, 9.348855972290039, 9.18283748626709,
        ],
    ),
    # L1
    # NOTE(review): these timings are identical to the 'Beam size: 1, no LM' run — possibly copied; verify.
    (
        'Beam size: 1, no LM, L1',
        [
            0.04456405964698722, 0.04446895922093731, 0.04732197200243457, 0.052495435179549604, 0.06430690809494827,
            0.09306527693244065, 0.14689211807668898, 0.2505515824710895, 0.6199786975045648, 0.8607919963481436,
        ],
        [
            21.681570053100586, 20.974687576293945, 19.74081039428711, 17.510927200317383, 17.081056594848633,
            17.590673446655273, 15.915518760681152, 17.284034729003906, 9.869168281555176, 9.287040710449219,
        ],
    ),
    # prune layer-by-layer
    (
        'Beam size: 1, no LM, layer-by-layer',
        # WER: 100 samples only
        [
            0.05030274802049371, 0.21055045871559633, 0.35036144578313255, 0.5752935171005615, 0.7029603612644255,
            0.8498765432098765, 0.8890015205271161, 0.890748031496063, 0.8949554896142433, 0.9015369836695485,
        ],
        [
            20.292512893676758, 21.21649169921875, 19.498455047607422, 15.736644744873047, 13.525660514831543,
            11.547018051147461, 10.438002586364746, 11.876256942749023, 9.106951713562012, 11.318511962890625,
        ],
    ),
]

# After the loop, `perfs` / `times` hold the layer-by-layer values.
for title, perfs, times in head_pruning_runs:
    plot_head_pruning(perfs, times, title=title)




In [16]:
# Latency per pruning ratio for the two pruning strategies. times1 matches the
# "Beam size: 1, no LM" timings and times2 the "layer-by-layer" timings in the
# sweep cell.
times1 = [
    21.681570053100586, 20.974687576293945, 19.74081039428711, 17.510927200317383, 17.081056594848633, 17.590673446655273, 15.915518760681152, 17.284034729003906, 9.869168281555176, 9.287040710449219
]
times2 = [
    20.292512893676758, 21.21649169921875, 19.498455047607422, 15.736644744873047, 13.525660514831543, 11.547018051147461, 10.438002586364746, 11.876256942749023, 9.106951713562012, 11.318511962890625
]

plot = figure(
    x_axis_label='Ratio of pruned heads',
    y_axis_label='Latency',
    title='Latency by pruning strategy',
)
# x = 0.0, 0.1, ... — one point per measurement; both series have the same length.
xs = [i / 10 for i in range(len(times1))]
# `legend_label` replaces the `legend` keyword (deprecated in Bokeh 1.4,
# rejected by Bokeh 3.x) and matches plot_head_pruning.
plot.line(xs, times1, color='blue', line_width=2, legend_label='Prune by Importance')
plot.line(xs, times2, color='red', line_width=2, legend_label='Prune by Layer')
show(plot)



In [19]:
# WER per pruning ratio for the two pruning strategies. perfs1 matches the
# "Beam size: 1, no LM" WER values; perfs2 matches the "layer-by-layer" values,
# which were recorded on only 100 samples, so the two curves are not measured
# on the same evaluation set.
perfs1 = [
    0.04456405964698722, 0.04481132075471698, 0.046998630553864884, 0.05243837492391966, 0.06383140596469872, 0.08882379793061473, 0.15029671332927572, 0.23767498478393184, 0.6199786975045648, 0.8607919963481436
]
perfs2 = [
    0.05030274802049371, 0.21055045871559633, 0.35036144578313255, 0.5752935171005615, 0.7029603612644255, 0.8498765432098765, 0.8890015205271161, 0.890748031496063, 0.8949554896142433, 0.9015369836695485
]
plot = figure(
    x_axis_label='Ratio of pruned heads',
    y_axis_label='WER',
    title='WER by pruning strategy',
)
# x = 0.0, 0.1, ... — one point per measurement; both series have the same length.
xs = [i / 10 for i in range(len(perfs1))]
# `legend_label` replaces the `legend` keyword (deprecated in Bokeh 1.4,
# rejected by Bokeh 3.x) and matches plot_head_pruning.
plot.line(xs, perfs1, color='blue', line_width=2, legend_label='Prune by Importance')
plot.line(xs, perfs2, color='red', line_width=2, legend_label='Prune by Layer')
show(plot)




In [9]:
import torch
from head_prune import rank_heads

# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# map_location='cpu' lets the checkpoint load on a kernel without the device
# its tensors were saved from (e.g. a GPU-saved accumulator on a CPU machine).
pt = torch.load('head_grad.pt', map_location='cpu')
# Presumably a (num_layers, num_heads) score matrix — the heatmap cell labels
# rows 'Layer' and columns 'Head'; TODO confirm against head_prune.
head_scores = pt['accumulator']['encoder']
# NOTE(review): head_ranks is not used by any cell visible here.
head_ranks = rank_heads(head_scores, normalize='l2')


In [18]:
from bokeh_utils import plot_matrix

# Scale each row of head_scores to unit L2 norm, then sort each row ascending,
# so the heatmap shows the distribution of scores within a row rather than
# per-head identity. Rows are presumably layers (see the axis labels below).
# F.normalize divides by max(norm, eps), so an all-zero row stays zero instead
# of becoming NaN; .sort(dim=-1) is the vectorized form of sorting each row and
# keeps the tensor's dtype and device.
head_scores_ = torch.nn.functional.normalize(head_scores, p=2, dim=-1).sort(dim=-1).values

plot = plot_matrix(head_scores_, cell_size=50, include_color_bar=False, x_axis_label='Head', y_axis_label='Layer')
show(plot)