In [838]:
import itertools as it
import sys

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.animation import FuncAnimation

In [907]:
# Face opacity used whenever a labelled face is revealed (and the default for new faces).
base_alpha = 0.25

def set_color(obj, color):
    """Reveal a labelled face: tint it ``color`` at ``base_alpha`` and make its label opaque.

    ``obj`` is a ``(surface, text)`` pair as returned by ``make_face`` with a name.
    """
    surface, label = obj
    # Alpha is applied before the colour so the new colour is built with it.
    surface.set_alpha(base_alpha)
    surface.set_color(color)
    label.set_alpha(1.0)

def make_grad(obj):
    """Turn a labelled ``(surface, text)`` pair into its gradient.

    The face is revealed in red and the label gets a ``$\\nabla$`` prefix.
    """
    set_color(obj, 'r')
    label = obj[1]
    label.set_text('$\\nabla$' + label.get_text())

def hide(obj):
    """Make a labelled ``(surface, text)`` pair fully invisible.

    The surface becomes transparent white and the label transparent.
    """
    surface, label = obj
    for setter, value in ((surface.set_alpha, 0.0),
                          (surface.set_color, 'w'),
                          (label.set_alpha, 0.0)):
        setter(value)

def show_solid(solid):
    """Make every surface of ``solid`` (an iterable of faces) visible at alpha 0.25."""
    for surface in solid:
        surface.set_alpha(0.25)

def dim_solid(solid):
    """Set every surface of ``solid`` to the slightly fainter alpha 0.2 (used when blinking)."""
    for surface in solid:
        surface.set_alpha(0.2)

def flash_solid(solid, steps):
    """Generator: blink ``solid`` ``steps`` times, then hide it; one frame per yield.

    Frame sequence: dimmed, then ``steps`` x (shown, dimmed), then hidden,
    i.e. ``2*steps + 2`` frames in total.
    """
    dim_solid(solid)
    yield
    for _ in range(steps):
        for toggle in (show_solid, dim_solid):
            toggle(solid)
            yield
    hide_solid(solid)
    yield

def hide_solid(solid):
    """Make every surface of ``solid`` fully transparent."""
    for surface in solid:
        surface.set_alpha(0.0)

def fwd(y, flop):
    """Generator for a single forward-matmul frame.

    Shows the FLOP solid ``flop`` and reveals the output face ``y`` in green,
    yields once, and hides ``flop`` when resumed (``y`` stays visible).
    """
    show_solid(flop)
    set_color(y, 'g')
    yield None
    hide_solid(flop)

def bwd_act(x, flop):
    """Generator for a single activation-gradient frame.

    Shows the FLOP solid ``flop`` and turns the face ``x`` into its gradient
    (red, nabla-prefixed label), yields once, and hides ``flop`` when resumed.
    """
    show_solid(flop)
    make_grad(x)
    yield None
    hide_solid(flop)

def bwd_weight(ys, w, flops, xs=None):
    """Generator: sweep the weight-gradient matmul for ``w`` over all microbatches.

    Frame ``i`` shows ``flops[i]``; ``w`` is turned into a gradient on the last
    microbatch's frame. After frame ``i`` (i > 0), microbatch ``i-1``'s FLOP
    solid and its activation ``ys[i-1]`` (and ``xs[i-1]`` if ``xs`` is given)
    are hidden, so two consecutive FLOP solids are visible from the second
    frame on. The last microbatch's artists are hidden when the generator
    finishes, without an extra frame.
    """
    def retire(j):
        # Hide microbatch j's FLOP solid and the activations it consumed.
        hide_solid(flops[j])
        hide(ys[j])
        if xs is not None:
            hide(xs[j])

    count = len(ys)
    for i in range(count):
        show_solid(flops[i])
        if i == count - 1:
            make_grad(w)
        yield
        if i > 0:
            retire(i - 1)
    if count > 0:
        retire(count - 1)

def make_face(ax, vals1, vals2, depth, dim1, dim2, name=None, color='b', alpha=base_alpha):
    """Draw an axis-aligned rectangle on a 3D axes, optionally with a centred label.

    Args:
        ax: 3D axes providing ``plot_surface`` and ``text``.
        vals1, vals2: Coordinates spanned along plot axes ``dim1`` / ``dim2``
            (callers pass ``(low, high)`` pairs; they are combined with
            ``np.meshgrid``).
        depth: Coordinate of the rectangle along the remaining axis
            ``3 - dim1 - dim2``.
        dim1, dim2: Two distinct plot-axis indices in {0, 1, 2}.
        name: Optional label text placed at the rectangle's centre.
        color: Face and edge colour.
        alpha: Face opacity; the label is fully opaque iff ``alpha > 0``.

    Returns:
        The surface artist, or ``(surface, text)`` when ``name`` is given.
    """
    dim3 = 3 - dim1 - dim2
    val1, val2 = np.meshgrid(vals1, vals2)
    # Build the constant coordinate as float: np.full_like would inherit val1's
    # dtype, so integer vals1 would silently truncate a fractional depth.
    val3 = np.full(val1.shape, depth, dtype=float)
    arg = np.zeros((3, *val1.shape))
    arg[[dim1, dim2, dim3]] = val1, val2, val3
    face = ax.plot_surface(*arg, color=color, alpha=alpha, linewidth=2, edgecolor=color)
    if name is None:
        return face
    textpos = np.zeros(3)
    textpos[[dim1, dim2, dim3]] = np.mean(vals1), np.mean(vals2), depth
    # float(...) so a numpy-scalar alpha does not yield an np.bool_, which
    # matplotlib's set_alpha rejects as non-numeric.
    txt = ax.text(*textpos, name, color='k', fontsize=10, ha='center', va='center',
                  alpha=float(alpha > 0))
    return face, txt

def make_rectangular_solid(ax, vals, dims, color='y'):
    """Create an axis-aligned box as six initially transparent faces.

    Args:
        ax: 3D axes passed through to ``make_face``.
        vals: Three extents; ``vals[k]`` is the box's ``(low, high)`` extent
            along plot axis ``dims[k]``.
        dims: Three distinct plot-axis indices.
        color: Colour of every face.

    Returns:
        List of six surfaces, two (low side, high side) per entry of ``vals``.
    """
    faces = []
    for normal in range(3):
        u, v = [k for k in range(3) if k != normal]
        for side in (0, 1):
            faces.append(make_face(ax, vals[u], vals[v], vals[normal][side],
                                   dims[u], dims[v], None, color=color, alpha=0.0))
    return faces

def draw_arrow(ax, dimvals, val1s, val2s, label, dim, otherdim1, otherdim2):
    """Draw a labelled dimension line (a quiver without arrowhead) along plot axis ``dim``.

    Args:
        ax: 3D axes providing ``quiver`` and ``text``.
        dimvals: ``(start, end)`` coordinates of the line along ``dim``.
        val1s: ``(position, label_offset)`` along ``otherdim1``: the line sits
            at ``position``; the label is shifted by ``label_offset``.
        val2s: Same as ``val1s`` for ``otherdim2``.
        label: Text drawn at the line's midpoint (plus the offsets).
        dim, otherdim1, otherdim2: A permutation of the plot axes {0, 1, 2}.
    """
    start, end = dimvals[0], dimvals[1]
    tail = np.zeros(3)
    tail[[dim, otherdim1, otherdim2]] = start, val1s[0], val2s[0]
    # The line must span start..end, so its vector is the difference (not `end`).
    vec = np.zeros(3)
    vec[dim] = end - start
    ax.quiver(*tail, *vec, color='k', arrow_length_ratio=0.0, linewidth=2)
    text_pos = tail + vec / 2
    text_pos[otherdim1] += val1s[1]
    text_pos[otherdim2] += val2s[1]
    ax.text(*text_pos, label, color='k', fontsize=10, ha='center', va='center')

In [956]:
# One large 3D axes shared by every pipeline replica (add_subplot() defaults to 1x1 grid, slot 1).
fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(projection='3d')

# Model and batch sizes (in plot units).
d_model = 8
d_hidden = 2 * d_model
B = 48  # global batch; split evenly across the dp replicas below
L = 6   # number of layers

# Parallelism degrees.
dp = 2  # data-parallel replicas
pp = 3  # pipeline stages per replica
microbatches = 4 * pp
assert L % pp == 0
pipeline_intvl = L // pp  # layers per pipeline stage

# Plot axis (0, 1, 2) along which each logical dimension is drawn.
L_by_hidden_ax, B_ax, L_by_d_model_ax = 0, 1, 2

# Gaps between replicas along the batch axis, and between stages along the layer axis.
dp_gpu_delta = 16
pp_gpu_delta = 12

# Blink cycles (flash_solid steps) used to animate each kind of communication.
pp_comm_steps = dp_comm_steps = 2

class Pipeline:
    """One data-parallel replica of a pipeline-parallel stack of two-matmul layers, as 3D artists.

    The constructor creates every artist up front; the animation then mutates
    their alpha / colour / label text from the generators returned by
    ``frame_gen``. Weight faces start visible; activation faces, FLOP solids
    and communication solids start fully transparent.

    Layout (plot axes come from the module-level ``L_by_hidden_ax``, ``B_ax``
    and ``L_by_d_model_ax``): layers are laid end to end along
    ``L_by_hidden_ax`` (``d_hidden`` units each, with a ``pp_gpu_delta`` gap
    after every non-final pipeline stage); this replica's batch spans
    ``B_start .. B_start + B`` along ``B_ax`` in ``microbatches`` equal slices,
    with the weights drawn in the plane ``B_ax = B_start + B``; each layer's
    input and output residual faces sit on opposite sides of the hidden plane
    along ``L_by_d_model_ax``, alternating from layer to layer.

    Also reads the module-level ``pp_comm_steps`` and ``dp_comm_steps``.

    Args:
        ax: 3D matplotlib axes to draw on.
        d_model: Width of the residual stream.
        d_hidden: Width of each layer's hidden activation.
        B: Batch size handled by this replica (may be fractional).
        B_start: Batch-axis coordinate at which this replica starts.
        B_next_start: Batch-axis start of the next replica, or None. When
            given, DP gradient-communication solids are drawn from this
            replica's weights to the next replica's and flashed during the
            weight backward pass.
        wait_for_dp: If True and this replica has no DP solids of its own,
            idle for as many frames as the DP communication takes elsewhere,
            keeping all replicas in lockstep.
        microbatches: Number of microbatches per replica.
        L: Number of layers; assumed to be a multiple of ``pipeline_intvl``.
        pipeline_intvl: Layers per pipeline stage.
        show_axes: If True, draw labelled dimension lines for d_hidden, batch
            and d_model.
    """

    def __init__(self, ax, d_model, d_hidden, B, B_start, B_next_start, wait_for_dp, microbatches, L, pipeline_intvl, show_axes=False):
        # Keep L on the instance: the frame generators must use this pipeline's
        # depth, not whatever a module-level `L` happens to be.
        self.L = L
        self.pipeline_intvl = pipeline_intvl
        self.wait_for_dp = wait_for_dp
        self.microbatches = microbatches
        # DP gradient-communication solids exist only when a next replica is given.
        self.dp_comms = B_next_start is not None
        B_end = B_start + B
        B_next_end = B_next_start + B if self.dp_comms else None
        n_stages = L // pipeline_intvl

        # Activation / FLOP artists are indexed [layer or stage][microbatch].
        self.residual_activations = [[] for _ in range(L)]
        self.terminal_residual_activations = [[] for _ in range(n_stages)]
        self.hidden_activations = [[] for _ in range(L)]
        self.weights1 = []
        self.weights2 = []
        self.flops1m = [[] for _ in range(L)]
        self.flops2m = [[] for _ in range(L)]
        self.w1_dp_comms = []
        self.w2_dp_comms = []
        self.pp_comms = [[] for _ in range(n_stages)]

        # (start, end) of each microbatch slice along the batch axis.
        mB_size = B / microbatches
        mB_bounds = []
        for m in range(microbatches):
            mB_start = B_start + m * mB_size
            mB_bounds.append((mB_start, mB_start + mB_size))

        layer_start = 0  # running coordinate along L_by_hidden_ax
        for l in range(L):
            stage = l // pipeline_intvl
            # Input residual (previous layer's output) and output residual occupy
            # opposite halves of [-d_model, d_model], swapping every layer.
            lzrp = ((l + 1) % 2) - 1
            lzr = (l % 2) - 1
            res_in = (lzrp * d_model, (lzrp + 1) * d_model)
            res_out = (lzr * d_model, (lzr + 1) * d_model)

            if l % pipeline_intvl == 0:
                # The stage's input activation (x^{-1} for the very first layer).
                for mb in mB_bounds:
                    self.terminal_residual_activations[stage].append(make_face(
                        ax, res_in, mb, layer_start, L_by_d_model_ax, B_ax,
                        f'$x^{{{l-1}}}$', color='g', alpha=0.0))

            layer_end = layer_start + d_hidden
            layer_span = (layer_start, layer_end)
            if self.dp_comms:
                self.w1_dp_comms.append(make_rectangular_solid(
                    ax, [layer_span, res_in, (B_end, B_next_end)],
                    [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k'))
                self.w2_dp_comms.append(make_rectangular_solid(
                    ax, [layer_span, res_out, (B_end, B_next_end)],
                    [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k'))
            # Superscripts are brace-grouped so multi-digit layer indices render
            # entirely as superscript in mathtext.
            self.weights1.append(make_face(ax, layer_span, res_in, B_end,
                                           L_by_hidden_ax, L_by_d_model_ax, f'$W_1^{{{l}}}$'))
            self.weights2.append(make_face(ax, layer_span, res_out, B_end,
                                           L_by_hidden_ax, L_by_d_model_ax, f'$W_2^{{{l}}}$'))
            for mb in mB_bounds:
                # Hidden activations lie in the plane L_by_d_model_ax = 0.
                self.hidden_activations[l].append(make_face(
                    ax, layer_span, mb, 0, L_by_hidden_ax, B_ax,
                    f'$h^{{{l}}}$', color='g', alpha=0.0))
                self.residual_activations[l].append(make_face(
                    ax, res_out, mb, layer_end, L_by_d_model_ax, B_ax,
                    f'$x^{{{l}}}$', color='g', alpha=0.0))
                self.flops1m[l].append(make_rectangular_solid(
                    ax, [res_in, layer_span, mb], [L_by_d_model_ax, L_by_hidden_ax, B_ax]))
                self.flops2m[l].append(make_rectangular_solid(
                    ax, [res_out, layer_span, mb], [L_by_d_model_ax, L_by_hidden_ax, B_ax]))
            layer_start = layer_end

            if (l + 1) % pipeline_intvl == 0 and l != L - 1:
                # Last layer of a non-final stage: leave a pp_gpu_delta gap and
                # draw one inter-stage transfer solid per microbatch across it.
                next_start = layer_start + pp_gpu_delta
                for mb in mB_bounds:
                    self.pp_comms[stage].append(make_rectangular_solid(
                        ax, [(layer_start, next_start), res_out, mb],
                        [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k'))
                layer_start = next_start

        if show_axes:
            draw_arrow(ax, (0, d_hidden), (0, -2), (0, -2), '$d_{hidden}$', L_by_hidden_ax, L_by_d_model_ax, B_ax)
            draw_arrow(ax, (B_start, B_end), (0, -3), (d_model, 3), 'Batch', B_ax, L_by_hidden_ax, L_by_d_model_ax)
            draw_arrow(ax, (0, d_model), (0, -4), (0, -2), '$d_{model}$', L_by_d_model_ax, L_by_hidden_ax, B_ax)

    def _steps_per_stage(self):
        """Frames a microbatch spends inside one stage, excluding the inter-stage transfer."""
        # One frame for the stage's input/output face, plus two matmul frames per layer.
        return 2 * self.pipeline_intvl + 1

    def microbatch_frame_gen_fwd(self, m, microbatches):
        """Generator over the forward-pass frames of microbatch ``m``.

        Microbatch ``m`` idles ``m`` stage-lengths before starting and pads with
        idle frames afterwards, so every microbatch's generator has the same
        length and they can be zipped in lockstep.
        """
        yield  # leading idle frame
        steps_per_stage = self._steps_per_stage()
        for _ in range(m * steps_per_stage):
            yield
        for l0 in range(0, self.L, self.pipeline_intvl):
            stage = l0 // self.pipeline_intvl
            if l0 > 0:
                # Transfer from the previous stage.
                yield from flash_solid(self.pp_comms[stage - 1][m], pp_comm_steps)
            set_color(self.terminal_residual_activations[stage][m], 'g')
            yield
            for l in range(l0, l0 + self.pipeline_intvl):
                yield from fwd(self.hidden_activations[l][m], self.flops1m[l][m])
                yield from fwd(self.residual_activations[l][m], self.flops2m[l][m])
        for _ in range((microbatches - m - 1) * steps_per_stage):
            yield
        yield

    def microbatch_frame_gen_bwd(self, m, microbatches):
        """Generator over the activation-gradient frames of microbatch ``m``.

        Walks the stages from last to first, turning each activation face into
        its gradient; padded like the forward generator so all microbatches
        have the same frame count.
        """
        steps_per_stage = self._steps_per_stage()
        for _ in range(m * steps_per_stage):
            yield
        for l0 in range(self.L - self.pipeline_intvl, -self.pipeline_intvl, -self.pipeline_intvl):
            stage = l0 // self.pipeline_intvl
            top = l0 + self.pipeline_intvl - 1
            make_grad(self.residual_activations[top][m])
            yield
            for l in range(top, l0, -1):
                yield from bwd_act(self.hidden_activations[l][m], self.flops2m[l][m])
                yield from bwd_act(self.residual_activations[l - 1][m], self.flops1m[l][m])
            yield from bwd_act(self.hidden_activations[l0][m], self.flops2m[l0][m])
            yield from bwd_act(self.terminal_residual_activations[stage][m], self.flops1m[l0][m])
            if l0 > 0:
                # Send the gradient back to the previous stage.
                yield from flash_solid(self.pp_comms[stage - 1][m], pp_comm_steps)
        for _ in range((microbatches - m - 1) * steps_per_stage):
            yield
        yield

    def gen_bwd_weight(self, l0):
        """Generator over the weight-gradient frames of the stage starting at layer ``l0``.

        Sweeps every layer of the stage from the top down, then either flashes
        the DP gradient-communication solids or idles for the same number of
        frames, then holds for 20 frames.
        """
        stage = l0 // self.pipeline_intvl
        for l in range(l0 + self.pipeline_intvl - 1, l0, -1):
            yield from bwd_weight(self.residual_activations[l], self.weights2[l], self.flops2m[l])
            yield from bwd_weight(self.hidden_activations[l], self.weights1[l], self.flops1m[l])
        yield from bwd_weight(self.residual_activations[l0], self.weights2[l0], self.flops2m[l0])
        yield from bwd_weight(self.hidden_activations[l0], self.weights1[l0], self.flops1m[l0],
                              self.terminal_residual_activations[stage])
        if self.dp_comms:
            for l in range(l0 + self.pipeline_intvl - 1, l0 - 1, -1):
                yield from flash_solid(self.w1_dp_comms[l], dp_comm_steps)
                yield from flash_solid(self.w2_dp_comms[l], dp_comm_steps)
        elif self.wait_for_dp:
            # Same frame count as the branch above: two flash_solid calls of
            # (2*dp_comm_steps + 2) frames each per layer.
            for _ in range(self.pipeline_intvl * (4 * dp_comm_steps + 4)):
                yield
        for _ in range(20):
            yield

    def frame_gen(self):
        """Iterator over all frames: forward, activation backward, then weight backward.

        Within each phase the per-microbatch (or per-stage) generators are
        zipped so they advance together, one yield per frame.
        """
        fwd_gen = zip(*[self.microbatch_frame_gen_fwd(m, self.microbatches) for m in range(self.microbatches)])
        bwd_act_gen = zip(*[self.microbatch_frame_gen_bwd(m, self.microbatches) for m in range(self.microbatches)])
        bwd_weight_gen = zip(*[self.gen_bwd_weight(l0) for l0 in range(0, self.L, self.pipeline_intvl)])
        return it.chain(fwd_gen, bwd_act_gen, bwd_weight_gen)

# One Pipeline per data-parallel replica, stacked along the batch axis with a
# dp_gpu_delta gap. Replicas are built from the highest index down; every
# replica except the highest-index one gets DP-communication solids reaching
# the next replica.
replica_B = B / dp
replica_stride = replica_B + dp_gpu_delta
pipelines = [
    Pipeline(ax, d_model, d_hidden,
             replica_B, dpi * replica_stride,
             (dpi + 1) * replica_stride if dpi < dp - 1 else None,
             dp > 1,
             microbatches,
             L, pipeline_intvl,
             show_axes=(dpi == 0))
    for dpi in reversed(range(dp))
]

ax.set_aspect('equal')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])

# All replicas advance in lockstep: each frame pulls one step from every pipeline.
update_gen = zip(*[pipeline.frame_gen() for pipeline in pipelines])

def update(_):
    # All drawing happens as a side effect of advancing update_gen.
    return None

# `frames` is a side-effecting generator with no length. Pass an explicit
# save_count so matplotlib < 3.7 does not stop saving after its default of 100
# frames (exhausting the generator still ends the animation). Disable frame-data
# caching: the yielded tuples carry no data, and replaying them would not
# replay the artist updates.
ani = FuncAnimation(fig, update, frames=update_gen,
                    save_count=sys.maxsize, cache_frame_data=False)
ani.save('batch.mp4', fps=15, extra_args=['-vcodec', 'libx264'])
