In [237]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

In [728]:
# Opacity applied to highlighted faces by set_color(); also make_face()'s default alpha.
base_alpha = 0.25

def set_color(obj, color):
    """Highlight a (face, text) pair: tint the face and make its label opaque.

    obj is unpacked as exactly two artists, (face, text); color is forwarded
    unchanged to the face's set_color().
    """
    surface, label = obj
    # The label does not depend on the face, so reveal it first.
    label.set_alpha(1.0)
    # Opacity is set before the colour, in the same order as always.
    surface.set_alpha(base_alpha)
    surface.set_color(color)

def make_grad(obj):
    """Mark a (face, text) pair as a gradient: red face, nabla-prefixed label."""
    set_color(obj, 'r')
    label = obj[1]
    # Prepend the nabla symbol to whatever the label currently says.
    label.set_text('$\\nabla$' + label.get_text())

def hide(obj):
    """Make a (face, text) pair invisible: zero alpha on both, face reset to white."""
    surface, label = obj
    transparent = 0.0
    label.set_alpha(transparent)
    # Face opacity is cleared before its colour is reset.
    surface.set_alpha(transparent)
    surface.set_color('w')

def show_flops(flops):
    """Make every face in ``flops`` faintly visible (alpha 0.1)."""
    faint = 0.1
    for surface in flops:
        surface.set_alpha(faint)

def hide_flops(flops):
    """Make every face in ``flops`` fully transparent (alpha 0.0)."""
    invisible = 0.0
    for surface in flops:
        surface.set_alpha(invisible)

def fwd(y, flops):
    """Generator for one forward step; yields exactly once.

    Before the yield the FLOP faces are shown and ``y`` (a (face, text) pair)
    is coloured green; the FLOP faces are hidden again only when the generator
    is resumed after that yield.
    """
    show_flops(flops)
    set_color(y, 'g')
    # Single pause: the frame captured here shows the FLOPs and the new output.
    yield
    hide_flops(flops)

def bwd_act(x, flops):
    """Generator for one backward-activation step; yields exactly once.

    Before the yield the FLOP faces are shown and ``x`` (a (face, text) pair)
    is turned into a gradient via make_grad(); the FLOP faces are hidden again
    only when the generator is resumed after that yield.
    """
    show_flops(flops)
    make_grad(x)
    # Single pause: the frame captured here shows the FLOPs and the new gradient.
    yield
    hide_flops(flops)

def bwd_weight(y, w, flops):
    """Generator for one backward-weight step; yields exactly once.

    Before the yield the FLOP faces are shown and ``w`` is turned into a
    gradient via make_grad(). When resumed, the FLOP faces are hidden and the
    (face, text) pair ``y`` is hidden as well.
    """
    show_flops(flops)
    make_grad(w)
    # Single pause: the frame captured here shows the FLOPs and the weight gradient.
    yield
    hide_flops(flops)
    # y is removed from view once this step has been displayed.
    hide(y)

def make_face(ax, vals1, vals2, depth, dim1, dim2, name=None, color='b', alpha=base_alpha):
    """Draw an axis-aligned rectangular face on a 3D axes, optionally with a centred label.

    Args:
        ax: Axes object providing ``plot_surface`` and ``text``.
        vals1: Coordinates of the face along data axis ``dim1``.
        vals2: Coordinates of the face along data axis ``dim2``.
        depth: Constant coordinate of the face along the remaining axis.
        dim1: Axis index (0, 1 or 2) that ``vals1`` runs along.
        dim2: Axis index (0, 1 or 2, different from ``dim1``) that ``vals2`` runs along.
        name: Label text placed at the face centre; ``None`` draws no label.
        color: Used for both the face and its edges.
        alpha: Face opacity. The label is fully opaque when ``alpha > 0`` and
            fully transparent otherwise.

    Returns:
        The surface alone when ``name`` is ``None``, else ``(surface, text)``.
    """
    # The three axis indices sum to 3, so the remaining one follows directly.
    dim3 = 3 - dim1 - dim2
    val1, val2 = np.meshgrid(vals1, vals2)
    # Build the depth plane as float explicitly: np.full_like(val1, depth) would
    # inherit val1's dtype and truncate a fractional depth when vals1 are ints.
    val3 = np.full(val1.shape, depth, dtype=float)
    arg = np.zeros((3, *val1.shape))
    arg[[dim1, dim2, dim3]] = val1, val2, val3
    face = ax.plot_surface(*arg, color=color, alpha=alpha, linewidth=2, edgecolor=color)
    if name is None:
        return face

    textpos = np.zeros(3)
    textpos[[dim1, dim2, dim3]] = np.mean(vals1), np.mean(vals2), depth
    # The label is either fully visible or fully hidden, following the face.
    txt = ax.text(*textpos, name, color='k', fontsize=10, ha='center', va='center',
                  alpha=float(alpha > 0))
    return face, txt

def make_rectangular_solid(ax, vals, dims, enabled=False):
    """Build the six (initially transparent, yellow) faces of an axis-aligned box.

    Args:
        ax: Axes object forwarded to make_face().
        vals: Three ``(low, high)`` pairs, one per box dimension.
        dims: Three axis indices; ``dims[i]`` is the data axis ``vals[i]`` runs along.
        enabled: When false (the default) nothing is drawn and an empty list is
            returned. This replaces a bare early ``return`` that previously made
            the face-building code unreachable, so default behaviour is unchanged.

    Returns:
        A list of surface artists: empty when disabled, otherwise six faces.
    """
    faces = []
    if not enabled:
        return faces

    for dim_idx in range(3):
        # The two dimensions spanning the face perpendicular to dim_idx.
        first, second = [d for d in range(3) if d != dim_idx]
        # One face at each end of the box along dim_idx.
        for bound in vals[dim_idx]:
            faces.append(make_face(ax, vals[first], vals[second], bound,
                                   dims[first], dims[second], None, color='y', alpha=0.0))
    return faces

def draw_arrow(ax, dimvals, val1s, val2s, label, dim, otherdim1, otherdim2):
    """Draw a headless line along one axis with a text label near its midpoint.

    Args:
        ax: Axes object providing ``quiver`` and ``text``.
        dimvals: ``(start, end)`` coordinates of the line along axis ``dim``.
        val1s: ``(position, label_offset)`` along axis ``otherdim1``.
        val2s: ``(position, label_offset)`` along axis ``otherdim2``.
        label: Text drawn next to the line.
        dim: Axis index the line runs along.
        otherdim1: Axis index that ``val1s`` refers to.
        otherdim2: Axis index that ``val2s`` refers to.
    """
    start, end = dimvals[0], dimvals[1]
    tail = np.zeros(3)
    vec = np.zeros(3)
    tail[[dim, otherdim1, otherdim2]] = start, val1s[0], val2s[0]
    # Callers pass (start, end), so the line's length is the difference. Using
    # the end value as the length was only right when the start happened to be 0.
    vec[[dim, otherdim1, otherdim2]] = end - start, 0, 0
    # arrow_length_ratio=0.0 suppresses the arrow head, leaving a plain segment.
    ax.quiver(*tail, *vec, color='k', arrow_length_ratio=0.0, linewidth=2)
    # Label sits at the segment midpoint, shifted by the two given offsets.
    text_pos = (tail + vec/2)
    text_pos[otherdim1] += val1s[1]
    text_pos[otherdim2] += val2s[1]
    ax.text(*text_pos, label, color='k', fontsize=10, ha='center', va='center')

In [740]:
# Square canvas holding a single 3D axes.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(1, 1, 1, projection='3d')

# Sizes used to lay out the drawing.
d_model = 8
d_hidden = d_model * 2
B, L = 64, 6
# Grid of GPU instances built further down: dp along the batch, tp along hidden.
dp, tp = 2, 1

# Which plot axis (0, 1, 2) carries which quantity.
L_by_hidden_ax, B_ax, L_by_d_model_ax = 0, 1, 2

class GPU:
    """Face artists for one instance's slice of the layer stack, plus its animation script.

    For each layer index in ``range(L)`` the constructor creates two weight
    faces (labelled ``W_1``, ``W_2``), one hidden-activation face (``h``) and
    two FLOP solids from make_rectangular_solid(); residual-activation faces
    (``x``) are created for ``range(L + 1)``. Activation faces start fully
    transparent. frame_gen() then reveals and recolours them step by step.

    Attributes (lists indexed by layer):
        residual_activations, hidden_activations, weights1, weights2:
            ``(face, text)`` pairs returned by make_face().
        flops1, flops2: lists of faces returned by make_rectangular_solid().
        num_layers: the ``L`` this instance was built with.
    """

    def __init__(self, ax, d_model, d_hidden_local, B, L, B_start, L_by_d_hidden_start, L_by_d_hidden_layer_step, show_axes=False):
        """Create all artists on ``ax``.

        Args:
            ax: 3D axes forwarded to make_face(), make_rectangular_solid() and draw_arrow().
            d_model: Extent of one layer along ``L_by_d_model_ax``.
            d_hidden_local: Extent of this instance's faces along ``L_by_hidden_ax``.
            B: Extent of this instance's faces along ``B_ax``.
            L: Number of layers.
            B_start: Offset of this instance along ``B_ax``.
            L_by_d_hidden_start: Offset of layer 0 along ``L_by_hidden_ax``.
            L_by_d_hidden_layer_step: Shift along ``L_by_hidden_ax`` between consecutive layers.
            show_axes: When true, also draw the three labelled dimension lines.
        """
        B_end = B_start + B
        batch_span = (B_start, B_end)

        # Kept so frame_gen() animates exactly the layers built here rather than
        # depending on a module-level L that may differ from the argument.
        self.num_layers = L

        self.residual_activations = []
        self.hidden_activations = []
        self.weights1 = []
        self.weights2 = []
        self.flops1 = []
        self.flops2 = []

        for layer in range(L+1):
            hidden_lo = L_by_d_hidden_start + layer*L_by_d_hidden_layer_step
            hidden_span = (hidden_lo, hidden_lo + d_hidden_local)
            # Span of this layer's residual face, and of the next layer's, along L_by_d_model_ax.
            model_span = (-layer*d_model, (1-layer)*d_model)
            next_model_span = ((-1-layer)*d_model, -layer*d_model)
            # The final index only has a residual activation, no weights or hidden face.
            has_weights = layer < L

            # Labels brace the index ("^{10}") so multi-digit layers superscript fully.
            if has_weights:
                self.weights1.append(make_face(ax, hidden_span, model_span,
                                               B_end, L_by_hidden_ax, L_by_d_model_ax, f'$W_1^{{{layer}}}$'))
                self.weights2.append(make_face(ax, hidden_span, next_model_span,
                                               B_end, L_by_hidden_ax, L_by_d_model_ax, f'$W_2^{{{layer}}}$'))
            self.residual_activations.append(make_face(ax, model_span, batch_span,
                                                       hidden_lo, L_by_d_model_ax, B_ax, f'$x^{{{layer}}}$', color='g', alpha=0.0))
            if has_weights:
                self.hidden_activations.append(make_face(ax, hidden_span, batch_span,
                                                         -layer*d_model, L_by_hidden_ax, B_ax, f'$h^{{{layer}}}$', color='g', alpha=0.0))
                self.flops1.append(make_rectangular_solid(ax, [model_span, hidden_span, batch_span],
                                                          [L_by_d_model_ax, L_by_hidden_ax, B_ax]))
                self.flops2.append(make_rectangular_solid(ax, [next_model_span, hidden_span, batch_span],
                                                          [L_by_d_model_ax, L_by_hidden_ax, B_ax]))

        if show_axes:
            # Disabled 'Layer' arrow, kept for reference:
            # draw_arrow(ax, (0, L*d_hidden), (d_model, 3), (B, 3), 'Layer', L_by_hidden_ax, L_by_d_model_ax, B_ax)
            draw_arrow(ax, (L_by_d_hidden_start, L_by_d_hidden_start+d_hidden_local), (0, -2), (0, -2), '$d_{hidden}$', L_by_hidden_ax, L_by_d_model_ax, B_ax)
            draw_arrow(ax, (B_start, B_end), (0, -3), (d_model, 3), 'Batch', B_ax, L_by_hidden_ax, L_by_d_model_ax)
            draw_arrow(ax, (0, d_model), (0, -4), (0, -2), '$d_{model}$', L_by_d_model_ax, L_by_hidden_ax, B_ax)

    def frame_gen(self):
        """Yield once per animation step, mutating the artists between yields.

        Sequence: an untouched first frame; the input activation appears; a
        forward sweep over all layers (two steps each); the last residual
        becomes a gradient; a backward sweep from the last layer to the first
        (four steps each); the first residual is hidden; ten idle frames.
        Yields ``6 * num_layers + 14`` times in total.
        """
        num_layers = self.num_layers

        yield
        set_color(self.residual_activations[0], 'g')
        yield
        for layer in range(num_layers):
            yield from fwd(self.hidden_activations[layer], self.flops1[layer])
            yield from fwd(self.residual_activations[layer+1], self.flops2[layer])

        make_grad(self.residual_activations[num_layers])
        yield
        for layer in range(num_layers-1, -1, -1):
            yield from bwd_act(self.hidden_activations[layer], self.flops2[layer])
            yield from bwd_weight(self.residual_activations[layer+1], self.weights2[layer], self.flops2[layer])
            yield from bwd_act(self.residual_activations[layer], self.flops1[layer])
            yield from bwd_weight(self.hidden_activations[layer], self.weights1[layer], self.flops1[layer])
        hide(self.residual_activations[0])
        yield
        # Hold the final state for a few frames.
        for _ in range(10):
            yield

# Gaps left between neighbouring GPU instances along the batch and hidden directions.
dp_gpu_delta = 16
tp_gpu_delta = 1

# One GPU per (dp, tp) grid cell; the dp index counts down so that the cell
# drawing the dimension labels (dpi == 0, tpi == 0) is created in the last outer pass.
gpus = []
for dpi in range(dp - 1, -1, -1):
    for tpi in range(tp):
        gpus.append(GPU(ax, d_model, d_hidden/tp, B/dp, L,
                        dpi*(B/dp+dp_gpu_delta),
                        tpi*(d_hidden/tp+tp_gpu_delta), d_hidden+(tp-1)*tp_gpu_delta,
                        show_axes=(dpi==0 and tpi==0)))

# Strip all axis decoration so only the drawn faces remain.
ax.set_aspect('equal')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

# Advance every GPU's script in lockstep: one zip item == one step on all of them.
update_gen = zip(*[gpu.frame_gen() for gpu in gpus])
def update(frame):
    """Advance all GPU scripts by one step; the frame value itself is unused."""
    # The default keeps a draw requested after the scripts are exhausted from
    # raising StopIteration inside the animation callback; the final state
    # simply stays on screen.
    return next(update_gen, None)

# Each GPU.frame_gen() yields 6*L + 14 times, slightly more than the frames requested here.
ani = FuncAnimation(fig, update, frames=range(6*L+12))
ani.save('batch.mp4', fps=30, extra_args=['-vcodec', 'libx264'])
