In [237]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [638]:
# Sizes of the toy layer stack being drawn (units are plot cells).
d_model = 8              # width of each x^l face
d_hidden = d_model * 2   # width of each h^l face
B = 48                   # batch size
L = 6                    # number of layers

# Plot axis (0 = x, 1 = y, 2 = z) that carries each quantity.
L_by_hidden_ax, B_ax, L_by_d_model_ax = 0, 1, 2

# One 3-D axes that every face, label and marker is drawn on.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

def make_face(ax, vals1, vals2, depth, dim1, dim2, name=None, color='b', alpha=0.25):
    """Draw one axis-aligned rectangle on a 3-D axes, optionally labelled.

    The rectangle spans ``vals1`` along plot axis ``dim1`` and ``vals2`` along
    plot axis ``dim2``, and sits at coordinate ``depth`` on the remaining axis.

    Args:
        ax: 3-D axes providing ``plot_surface`` and ``text``.
        vals1, vals2: ``(start, end)`` extents along ``dim1`` / ``dim2``.
        depth: coordinate of the face on the third axis (may be fractional).
        dim1, dim2: two distinct plot-axis indices from ``{0, 1, 2}``.
        name: optional label drawn at the centre of the face.
        color: face and edge color.
        alpha: face opacity; the label is drawn at ``4*alpha``, capped at 1.

    Returns:
        The surface artist, or ``(surface, text)`` when ``name`` is given.
    """
    dim3 = 3 - dim1 - dim2
    val1, val2 = np.meshgrid(vals1, vals2)
    # Force a float fill: with integer extents, full_like would otherwise take
    # val1's int dtype and truncate a fractional depth.
    val3 = np.full_like(val1, depth, dtype=float)
    arg = np.zeros((3, *val1.shape))
    arg[[dim1, dim2, dim3]] = val1, val2, val3
    face = ax.plot_surface(*arg, color=color, alpha=alpha, linewidth=2, edgecolor=color)
    if name is None:
        return face
    textpos = np.zeros(3)
    textpos[[dim1, dim2, dim3]] = np.mean(vals1), np.mean(vals2), depth
    # The label is 4x more opaque than the face; matplotlib rejects alpha > 1,
    # so cap it (alpha=0.25 still gives exactly 1.0).
    txt = ax.text(*textpos, name, color='k', fontsize=12, ha='center', va='center',
                  alpha=min(1.0, 4*alpha))
    return face, txt

def make_rectangular_solid(ax, vals, dims, draw=False):
    """Build the six faces of an axis-aligned box (used to visualise FLOPs).

    Drawing is off by default: the function then returns an empty list, so
    ``show_flops`` / ``hide_flops`` have nothing to toggle. Pass ``draw=True``
    to create the faces (yellow, initially fully transparent).

    Args:
        ax: 3-D axes to draw on.
        vals: three ``(start, end)`` extents, one per entry of ``dims``.
        dims: the plot-axis index (0, 1 or 2) that each extent runs along.
        draw: whether to actually create the faces.

    Returns:
        A list of surface artists (empty when ``draw`` is False).
    """
    faces = []
    if not draw:
        return faces
    for dim_idx in range(3):
        # The two extents spanning this pair of opposite faces.
        other_a, other_b = sorted({0, 1, 2} - {dim_idx})
        for depth in vals[dim_idx][:2]:
            faces.append(make_face(ax, vals[other_a], vals[other_b], depth,
                                   dims[other_a], dims[other_b], None, color='y', alpha=0.0))
    return faces

def draw_arrow(ax, dimvals, val1s, val2s, label, dim, otherdim1, otherdim2):
    """Draw a labelled, headless dimension marker along plot axis ``dim``.

    The line starts at ``dimvals[0]`` on ``dim`` and at ``val1s[0]`` /
    ``val2s[0]`` on ``otherdim1`` / ``otherdim2``. It extends by ``dimvals[1]``
    along ``dim``. The label sits at the line's midpoint, shifted by
    ``val1s[1]`` on ``otherdim1`` and ``val2s[1]`` on ``otherdim2``.
    """
    order = [dim, otherdim1, otherdim2]

    start = np.zeros(3)
    start[order] = dimvals[0], val1s[0], val2s[0]
    extent = np.zeros(3)
    extent[order] = dimvals[1], 0, 0
    # arrow_length_ratio=0 turns the quiver into a plain line segment.
    ax.quiver(*start, *extent, color='k', arrow_length_ratio=0.0, linewidth=2)

    label_at = start + extent / 2
    label_at[otherdim1] += val1s[1]
    label_at[otherdim2] += val2s[1]
    ax.text(*label_at, label, color='k', fontsize=12, ha='center', va='center')

# Artist handles, one entry per layer. Labelled entries are the
# (surface, text) pairs returned by make_face; the animation below recolours
# and hides them in place.
residual_activations = []  # x^l for l = 0..L (L+1 entries)
hidden_activations = []    # h^l for l = 0..L-1
weights1 = []              # W_1^l for l = 0..L-1
weights2 = []              # W_2^l for l = 0..L-1
flops1 = []                # per-layer make_rectangular_solid lists, used when producing h^l
flops2 = []                # per-layer make_rectangular_solid lists, used when producing x^{l+1}
# Iterate L+1 times so the final output x^L also gets a face; every other
# artist is per-layer and only exists for l < L.
# In the labels, '{{{l}}}' renders as '{<l>}' so mathtext superscripts the
# whole layer index; a bare '{l}' would drop the braces and, for l >= 10,
# only the first digit would be raised.
for l in range(L+1):
    if l < L:
        # Weight faces lie at coordinate B along B_ax.
        weights1.append(make_face(ax, (l*d_hidden, (l+1)*d_hidden), (-l*d_model, (1-l)*d_model), B, L_by_hidden_ax, L_by_d_model_ax, f'$W_1^{{{l}}}$'))
        weights2.append(make_face(ax, (l*d_hidden, (l+1)*d_hidden), ((-1-l)*d_model, -l*d_model), B, L_by_hidden_ax, L_by_d_model_ax, f'$W_2^{{{l}}}$'))
    # x^l lies at l*d_hidden along L_by_hidden_ax; created invisible (alpha=0.0).
    residual_activations.append(make_face(ax, (-l*d_model, (1-l)*d_model), (0, B), l*d_hidden, L_by_d_model_ax, B_ax, f'$x^{{{l}}}$', color='g', alpha=0.0))
    if l < L:
        # h^l lies at -l*d_model along L_by_d_model_ax; created invisible.
        hidden_activations.append(make_face(ax, (l*d_hidden, (l+1)*d_hidden), (0, B), -l*d_model, L_by_hidden_ax, B_ax, f'$h^{{{l}}}$', color='g', alpha=0.0))
        flops1.append(make_rectangular_solid(ax, [(-l*d_model, (1-l)*d_model), (l*d_hidden, (l+1)*d_hidden), (0, B)], [L_by_d_model_ax, L_by_hidden_ax, B_ax]))
        flops2.append(make_rectangular_solid(ax, [(-(1+l)*d_model, -l*d_model), (l*d_hidden, (l+1)*d_hidden), (0, B)], [L_by_d_model_ax, L_by_hidden_ax, B_ax]))

# Optional "Layer" marker along the stacked-layer axis (disabled):
# draw_arrow(ax, (0, L*d_hidden), (d_model, 3), (B, 3), 'Layer', L_by_hidden_ax, L_by_d_model_ax, B_ax)
# Dimension markers: labelled lines of length d_hidden, B and d_model.
draw_arrow(ax, (0, d_hidden), (0, -2), (0, -2), '$d_{hidden}$', L_by_hidden_ax, L_by_d_model_ax, B_ax)
draw_arrow(ax, (0, B), (0, -3), (d_model, 3), 'Batch', B_ax, L_by_hidden_ax, L_by_d_model_ax)
draw_arrow(ax, (0, d_model), (0, -4), (0, -2), '$d_{model}$', L_by_d_model_ax, L_by_hidden_ax, B_ax)


# Equal scaling on all three axes, and no grid or ticks, so only the faces,
# labels and markers are visible.
ax.set_aspect('equal')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

def set_color(obj, color):
    """Reveal a labelled face: paint it ``color`` at alpha 0.25, label opaque."""
    surface, label = obj
    surface.set_alpha(0.25)
    surface.set_color(color)
    label.set_alpha(1.0)

def make_grad(obj):
    """Mark a labelled face as a gradient: colour it red and prefix a nabla."""
    set_color(obj, 'r')
    label = obj[1]
    label.set_text('$\\nabla$' + label.get_text())

def hide(obj):
    """Make a labelled face and its label fully transparent (face set to white)."""
    surface, label = obj
    surface.set_alpha(0.0)
    surface.set_color('w')
    label.set_alpha(0.0)

def show_flops(flops):
    """Fade in every surface of a FLOP solid to alpha 0.1."""
    flop_alpha = 0.1
    for surface in flops:
        surface.set_alpha(flop_alpha)

def hide_flops(flops):
    """Make every surface of a FLOP solid fully transparent."""
    for surface in flops:
        surface.set_alpha(0.0)

def fwd(y, flops):
    """One-frame animation step for a forward matmul.

    Shows ``flops`` and reveals the output face ``y`` in green, then yields
    once. When resumed, it hides ``flops`` again.
    """
    show_flops(flops)
    set_color(y, 'g')
    yield
    hide_flops(flops)

def bwd_act(x, flops):
    """One-frame animation step for an activation-gradient matmul.

    Shows ``flops`` and turns ``x`` into its (red, nabla-labelled) gradient,
    then yields once. When resumed, it hides ``flops`` again.
    """
    show_flops(flops)
    make_grad(x)
    yield
    hide_flops(flops)

def bwd_weight(y, w, flops):
    """One-frame animation step for a weight-gradient matmul.

    Shows ``flops`` and turns weight ``w`` into its gradient, then yields
    once. When resumed, it hides ``flops`` and hides ``y`` (its face is no
    longer drawn after this step).
    """
    show_flops(flops)
    make_grad(w)
    yield
    hide_flops(flops)
    hide(y)

def update_gen():
    """Yield once per animation frame, mutating the scene between yields.

    Sequence:
      1. the scene as built;
      2. x^0 revealed;
      3. the forward pass, layer by layer (h^l via flops1, then x^{l+1} via flops2);
      4. the gradient seeded at x^L;
      5. the backward pass from the last layer down;
      6. x^0 hidden;
      7. ten extra frames that hold the final picture.
    """
    yield
    set_color(residual_activations[0], 'g')
    yield

    for layer in range(L):
        yield from fwd(hidden_activations[layer], flops1[layer])
        yield from fwd(residual_activations[layer + 1], flops2[layer])

    make_grad(residual_activations[L])
    yield

    for layer in reversed(range(L)):
        # Second matmul of the layer: grad of h^l, then of W_2^l (frees x^{l+1}).
        yield from bwd_act(hidden_activations[layer], flops2[layer])
        yield from bwd_weight(residual_activations[layer + 1], weights2[layer], flops2[layer])
        # First matmul of the layer: grad of x^l, then of W_1^l (frees h^l).
        yield from bwd_act(residual_activations[layer], flops1[layer])
        yield from bwd_weight(hidden_activations[layer], weights1[layer], flops1[layer])

    hide(residual_activations[0])
    yield
    for _ in range(10):
        yield
# NOTE: this rebinds `update_gen` from the generator function to the
# generator object that drives the frames.
update_gen = update_gen()


def update(frame):
    """FuncAnimation callback: advance the scene by one step.

    ``frame`` is ignored; all scene state lives in the ``update_gen``
    generator. Returns None (blitting is not used). Once the generator is
    exhausted, further calls are no-ops instead of raising StopIteration.
    This matters because FuncAnimation without ``init_func`` spends an extra
    call on its initial draw, and may redraw later.
    """
    return next(update_gen, None)


# update_gen yields 6*L + 14 times. One call goes to the initial draw and one
# to each frame, so `frames` must stay <= 6*L + 13.
ani = FuncAnimation(fig, update, frames=range(6*L+12))
ani.save('batch.mp4', fps=30, extra_args=['-vcodec', 'libx264'])


In [None]:
    hide(y)
    yield

def update_gen():
    yield
    set_color(residual_activations[0], 'g')
    yield
    for l in range(L):
        yield from fwd(hidden_activations[l], residual_activations[l], weights1[l], flops1[l])
        yield from fwd(residual_activations[l+1], hidden_activations[l], weights2[l], flops2[l])

    make_grad(residual_activations[L])
    yield
    for l in range(L-1, -1, -1):
        yield from bwd_act(hidden_activations[l], residual_activations[l+1], weights2[l], flops2[l])
        yield from bwd_weight(hidden_activations[l], residual_activations[l+1], weights2[l], flops2[l])
        yield from bwd_act(residual_activations[l], hidden_activations[l], weights1[l], flops1[l])
        yield from bwd_weight(residual_activations[l], hidden_activations[l], weights1[l], flops1[l])
    hide(residual_activations[0])
    yield
    for i in range(10):
        yield
update_gen = update_gen()
def update(frame):
    return next(update_gen)

ani = FuncAnimation(fig, update, frames=range(18*L+12))
ani.save('batch.mp4', fps=30, extra_args=['-vcodec', 'libx264'])
