In [838]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import itertools as it

In [1103]:
base_alpha = 0.25

def set_color(obj, color):
    face, txt = obj
    face.set_alpha(base_alpha)
    face.set_color(color)
    txt.set_alpha(1.0)

def make_grad(obj):
    for o in obj:
        set_color(o, 'r')
        o[1].set_text('$\\nabla$' + o[1].get_text())

def hide(obj):
    face, txt = obj
    face.set_alpha(0.0)
    face.set_color('w')
    txt.set_alpha(0.0)

def show_solid(solid):
    for face in solid:
        face.set_alpha(0.25)

def dim_solid(solid):
    for face in solid:
        face.set_alpha(0.2)

def flash_solid(solids, steps):
    for solid in solids:
        dim_solid(solid)
    yield
    for _ in range(steps):
        for solid in solids:
            show_solid(solid)
        yield
        for solid in solids:
            dim_solid(solid)
        yield
    for solid in solids:
        hide_solid(solid)
    yield

def hide_solid(solid):
    for face in solid:
        face.set_alpha(0.0)

def fwd(ys, flops):
    for f in flops:
        show_solid(f)
    for y in ys:
        set_color(y, 'g')
    yield
    for f in flops:
        hide_solid(f)

def bwd_act(xs, flops):
    for f in flops:
        show_solid(f)
    make_grad(xs)
    yield
    for f in flops:
        hide_solid(f)

def bwd_weight(ys, w, flops, xs=None):
    for i in range(len(ys) + 1):
        if i < len(ys):
            for f in flops[i]:
                show_solid(f)
            if i == len(ys) - 1:
                make_grad(w)
            yield
        if i > 0:
            for f in flops[i-1]:
                hide_solid(f)
            for y in ys[i-1]:
                hide(y)
            if xs is not None:
                for x in xs[i-1]:
                    hide(x)

def make_face(ax, vals1, vals2, depth, dim1, dim2, name=None, color='b', alpha=base_alpha):
    dim3 = 3 - dim1 - dim2
    val1, val2 = np.meshgrid(vals1, vals2)
    val3 = np.full_like(val1, depth)
    arg = np.zeros((3, *val1.shape))
    arg[[dim1, dim2, dim3]] = val1, val2, val3
    face = ax.plot_surface(*arg, color=color, alpha=alpha, linewidth=2, edgecolor=color)
    if name is None:
        return face
    else:
        textpos = np.zeros(3)
        textpos[[dim1, dim2, dim3]] = np.mean(vals1), np.mean(vals2), depth
        txt = ax.text(*textpos, name, color='k', fontsize=10, ha='center', va='center', alpha=(alpha > 0))
        return face, txt

def make_rectangular_solid(ax, vals, dims, color='y'):
    faces = []
    for dim_idx in range(3):
        other_dims = list({0, 1, 2} - {dim_idx})
        for i in range(2):
            faces.append(make_face(ax, vals[other_dims[0]], vals[other_dims[1]], vals[dim_idx][i],
                                   dims[other_dims[0]], dims[other_dims[1]], None, color=color, alpha=0.0))
    return faces

def draw_arrow(ax, dimvals, val1s, val2s, label, dim, otherdim1, otherdim2):
    tail = np.zeros(3)
    vec = np.zeros(3)
    tail[[dim, otherdim1, otherdim2]] = dimvals[0], val1s[0], val2s[0]
    vec[[dim, otherdim1, otherdim2]] = dimvals[1], 0, 0
    ax.quiver(*tail, *vec, color='k', arrow_length_ratio=0.0, linewidth=2)
    text_pos = (tail + vec/2)
    text_pos[otherdim1] += val1s[1]
    text_pos[otherdim2] += val2s[1]
    ax.text(*text_pos, label, color='k', fontsize=10, ha='center', va='center')

In [1141]:
fig = plt.figure(figsize=(16, 16))
ax = fig.add_subplot(111, projection='3d')

d_model = 8
d_hidden = 2*d_model
B = 48
L = 1
dp = 1
pp = 1
tp = 1
microbatches = 1#4*pp
assert L % pp == 0
assert L == 1 or tp == 1
pipeline_intvl = L // pp
show_compute_blocks = True

L_by_hidden_ax = 0
B_ax = 1
L_by_d_model_ax = 2

dp_gpu_delta = 16
pp_gpu_delta = 12
tp_gpu_delta = 16

pp_comm_steps = 2
dp_comm_steps = 2
tp_comm_steps = 2

class Pipeline:
    def __init__(self, ax, d_model, d_hidden, tp, B, B_start, B_next_start, wait_for_dp, microbatches, L, pipeline_intvl, show_axes=False):
        self.pipeline_intvl = pipeline_intvl
        B_end = B_start + B
        if B_next_start is not None:
            B_next_end = B_next_start + B
        self.dp_comms = B_next_start is not None
        self.tp_comms = tp > 1
        self.wait_for_dp = wait_for_dp
        self.microbatches = microbatches

        self.residual_activations = []
        self.terminal_residual_activations = [[] for _ in range(L//pipeline_intvl)]
        self.input_tp_comms = []
        self.output_tp_comms = []
        self.hidden_activations = []
        self.weights1 = []
        self.weights2 = []
        self.flops1m = []
        self.flops2m = []
        self.w1_dp_comms = []
        self.w2_dp_comms = []
        self.pp_comms = [[] for _ in range(L//pipeline_intvl)]

        layer_L_by_d_hidden_start = 0
        for l in range(L):
            # # Alternating
            # lz = 0 # z level of hidden
            # lzrp = ((l + 1) % 2) - 1 # z bottom of previous residual
            # lzr = (l % 2) - 1 # z bottom of residual

            # Staircase
            lz = -l # z level of hidden
            lzrp = -l # z bottom of previous residual
            lzr = -l - 1 # z bottom of residual

            this_w1_dp_comms = []
            this_w2_dp_comms = []
            this_weights1 = []
            this_weights2 = []
            this_hidden_activations = [[] for _ in range(microbatches)]
            this_residual_activations = [[] for _ in range(microbatches)]
            this_flops1m = [[] for _ in range(microbatches)]
            this_flops2m = [[] for _ in range(microbatches)]
            this_terminal_residual_activations = [[] for _ in range(microbatches)]
            this_input_tp_comms = [[] for _ in range(microbatches)]
            this_output_tp_comms = [[] for _ in range(microbatches)]
            for tpi in range(tp):
                local_layer_L_by_d_hidden_start = layer_L_by_d_hidden_start + tpi * (tp_gpu_delta + d_hidden/tp)
                local_layer_L_by_d_hidden_end = local_layer_L_by_d_hidden_start + d_hidden/tp

                if l % pipeline_intvl == 0:
                    for m in range(microbatches):
                        mB_start = B_start + m * (B/microbatches)
                        mB_end = mB_start + B/microbatches
                        this_terminal_residual_activations[m].append(make_face(ax, (lzrp*d_model, (lzrp+1)*d_model), (mB_start, mB_end),
                                                                            local_layer_L_by_d_hidden_start, L_by_d_model_ax, B_ax, f'$x^{{{l-1}}}$', color='g', alpha=0.0))
                        if tpi > 0:
                            this_input_tp_comms[m].append(make_rectangular_solid(ax, [(local_layer_L_by_d_hidden_start-tp_gpu_delta-d_hidden/tp, local_layer_L_by_d_hidden_start),
                                                                                    (lzrp*d_model, (lzrp+1)*d_model), (mB_start, mB_end)],
                                                                                    [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k'))

                if self.dp_comms:
                    this_w1_dp_comms.append(make_rectangular_solid(ax, [(local_layer_L_by_d_hidden_start, local_layer_L_by_d_hidden_end),
                                                                        (lzrp*d_model, (lzrp+1)*d_model), (B_end, B_next_end)],
                                                                    [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k'))
                    this_w2_dp_comms.append(make_rectangular_solid(ax, [(local_layer_L_by_d_hidden_start, local_layer_L_by_d_hidden_end),
                                                                        (lzr*d_model, (lzr+1)*d_model), (B_end, B_next_end)],
                                                                    [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k'))
                this_weights1.append(make_face(ax, (local_layer_L_by_d_hidden_start, local_layer_L_by_d_hidden_end), (lzrp*d_model, (lzrp+1)*d_model),
                                                B_end, L_by_hidden_ax, L_by_d_model_ax, f'$W_1^{l}$'))
                this_weights2.append(make_face(ax, (local_layer_L_by_d_hidden_start, local_layer_L_by_d_hidden_end), (lzr*d_model, (lzr+1)*d_model),
                                                B_end, L_by_hidden_ax, L_by_d_model_ax, f'$W_2^{l}$'))

                for m in range(microbatches):
                    mB_start = B_start + m * (B/microbatches)
                    mB_end = mB_start + B/microbatches
                    this_hidden_activations[m].append(make_face(ax, (local_layer_L_by_d_hidden_start, local_layer_L_by_d_hidden_end), (mB_start, mB_end),
                                                                lz*d_model, L_by_hidden_ax, B_ax, f'$h^{l}$', color='g', alpha=0.0))
                    if tpi > 0:
                        this_output_tp_comms[m].append(make_rectangular_solid(ax, [(local_layer_L_by_d_hidden_start-tp_gpu_delta, local_layer_L_by_d_hidden_start+d_hidden/tp),
                                                                                (lz*d_model, (lz-1)*d_model), (mB_start, mB_end)],
                                                                                [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k'))
                    this_residual_activations[m].append(make_face(ax, (lzr*d_model, (lzr+1)*d_model), (mB_start, mB_end),
                                                            local_layer_L_by_d_hidden_start+d_hidden/tp, L_by_d_model_ax, B_ax, f'$x^{l}$', color='g', alpha=0.0))
                    if show_compute_blocks:
                        this_flops1m[m].append(make_rectangular_solid(ax, [(lzrp*d_model, (lzrp+1)*d_model),
                                                                        (local_layer_L_by_d_hidden_start, local_layer_L_by_d_hidden_end), (mB_start, mB_end)],
                                                                        [L_by_d_model_ax, L_by_hidden_ax, B_ax]))
                        this_flops2m[m].append(make_rectangular_solid(ax, [(lzr*d_model, (lzr+1)*d_model),
                                                                            (local_layer_L_by_d_hidden_start, local_layer_L_by_d_hidden_end), (mB_start, mB_end)],
                                                                            [L_by_d_model_ax, L_by_hidden_ax, B_ax]))
            if l % pipeline_intvl == 0:
                self.terminal_residual_activations[l//pipeline_intvl] = this_terminal_residual_activations
                self.input_tp_comms.append(this_input_tp_comms)
            self.w1_dp_comms.append(this_w1_dp_comms)
            self.w2_dp_comms.append(this_w2_dp_comms)
            self.weights1.append(this_weights1)
            self.weights2.append(this_weights2)
            self.hidden_activations.append(this_hidden_activations)
            self.output_tp_comms.append(this_output_tp_comms)
            self.residual_activations.append(this_residual_activations)
            self.flops1m.append(this_flops1m)
            self.flops2m.append(this_flops2m)
            layer_L_by_d_hidden_start += d_hidden + tp_gpu_delta*(tp-1)
            if ((l + 1) % pipeline_intvl == 0) and (l != L - 1):
                next_layer_L_by_d_hidden_start = layer_L_by_d_hidden_start + pp_gpu_delta
                for m in range(microbatches):
                    mB_start = B_start + m * (B/microbatches)
                    mB_end = mB_start + B/microbatches
                    self.pp_comms[l//pipeline_intvl].append([make_rectangular_solid(ax, [(layer_L_by_d_hidden_start, next_layer_L_by_d_hidden_start),
                                                                                        (lzr*d_model, (lzr+1)*d_model), (mB_start, mB_end)],
                                                                                    [L_by_hidden_ax, L_by_d_model_ax, B_ax], color='k')])
                layer_L_by_d_hidden_start = next_layer_L_by_d_hidden_start

        if show_axes:
            # draw_arrow(ax, (0, L*d_hidden), (d_model, 3), (B, 3), 'Layer', L_by_hidden_ax, L_by_d_model_ax, B_ax)
            draw_arrow(ax, (0, d_hidden/tp), (0, -2), (0, -2), '$d_2$', L_by_hidden_ax, L_by_d_model_ax, B_ax)
            draw_arrow(ax, (B_start, B_end), (0, -2), (d_model, 2), 'Batch', B_ax, L_by_hidden_ax, L_by_d_model_ax)
            draw_arrow(ax, (0, d_model), (0, -2), (0, -2), '$d_1$', L_by_d_model_ax, L_by_hidden_ax, B_ax)

    def microbatch_frame_gen_fwd(self, m, microbatches):
        for _ in range(5):
            yield
        steps_per_stage = 2*self.pipeline_intvl + 1

        # Activation fwd pass.
        for _ in range(m*steps_per_stage):
            yield
        for l0 in range(0, L, self.pipeline_intvl):
            if l0 > 0:
                yield from flash_solid(self.pp_comms[(l0-1)//self.pipeline_intvl][m], pp_comm_steps)
            for x in self.terminal_residual_activations[l0//self.pipeline_intvl][m]:
                set_color(x, 'g')
            if self.tp_comms:
                yield from flash_solid(self.input_tp_comms[l0//self.pipeline_intvl][m], tp_comm_steps)
            yield
            for l in range(l0, l0 + self.pipeline_intvl):
                yield from fwd(self.hidden_activations[l][m], self.flops1m[l][m])
                yield from fwd(self.residual_activations[l][m], self.flops2m[l][m])
                if self.tp_comms:
                    yield from flash_solid(self.output_tp_comms[l][m], tp_comm_steps)
        for _ in range((microbatches-m-1)*steps_per_stage):
            yield
        for _ in range(5):
            yield

    def microbatch_frame_gen_bwd(self, m, microbatches):
        # Activation bwd pass.
        for _ in range(m*(2*self.pipeline_intvl + 1)):
            yield
        for l0 in range(L - self.pipeline_intvl, -self.pipeline_intvl, -self.pipeline_intvl):
            make_grad(self.residual_activations[l0 + self.pipeline_intvl - 1][m])
            if self.tp_comms:
                yield from flash_solid(self.output_tp_comms[l0 + self.pipeline_intvl - 1][m], tp_comm_steps)
            yield
            for l in range(l0 + self.pipeline_intvl - 1, l0, -1):
                yield from bwd_act(self.hidden_activations[l][m], self.flops2m[l][m])
                if self.tp_comms:
                    yield from flash_solid(self.output_tp_comms[l-1][m], tp_comm_steps)
                yield from bwd_act(self.residual_activations[l-1][m], self.flops1m[l][m])
            yield from bwd_act(self.hidden_activations[l0][m], self.flops2m[l0][m])
            yield from bwd_act(self.terminal_residual_activations[l0//self.pipeline_intvl][m], self.flops1m[l0][m])
            if self.tp_comms:
                yield from flash_solid(self.input_tp_comms[l0//self.pipeline_intvl][m], tp_comm_steps)
            if l0 > 0:
                yield from flash_solid(self.pp_comms[(l0-1)//self.pipeline_intvl][m], pp_comm_steps)
                    
        for _ in range((microbatches-m-1)*(2*self.pipeline_intvl + 1)):
            yield
        yield

    def gen_bwd_weight(self, l0):
        # Weight bwd pass.
        for l in range(l0 + self.pipeline_intvl - 1, l0, -1):
            yield from bwd_weight(self.residual_activations[l], self.weights2[l], self.flops2m[l])
            yield from bwd_weight(self.hidden_activations[l], self.weights1[l], self.flops1m[l])
        yield from bwd_weight(self.residual_activations[l0], self.weights2[l0], self.flops2m[l0])
        yield from bwd_weight(self.hidden_activations[l0], self.weights1[l0], self.flops1m[l0], self.terminal_residual_activations[l0//self.pipeline_intvl])
        if self.dp_comms:
            for l in range(l0 + self.pipeline_intvl - 1, l0 - 1, -1):
                yield from flash_solid(self.w1_dp_comms[l], dp_comm_steps)
                yield from flash_solid(self.w2_dp_comms[l], dp_comm_steps)
        elif self.wait_for_dp:
            for _ in range(self.pipeline_intvl*(4*dp_comm_steps + 4)):
                yield
        for _ in range(15):
            yield

    def frame_gen(self):
        fwd_gen = zip(*[self.microbatch_frame_gen_fwd(m, self.microbatches) for m in range(self.microbatches)])
        bwd_act_gen = zip(*[self.microbatch_frame_gen_bwd(m, self.microbatches) for m in range(self.microbatches)])
        bwd_weight_gen = zip(*[self.gen_bwd_weight(l0) for l0 in range(0, L, self.pipeline_intvl)])
        return it.chain(fwd_gen, bwd_act_gen, bwd_weight_gen)

pipelines = []
for dpi in range(dp - 1, -1, -1):
    pipelines.append(Pipeline(ax, d_model, d_hidden, tp,
                              B/dp, dpi*(B/dp+dp_gpu_delta), (dpi+1)*(B/dp+dp_gpu_delta) if dpi < dp - 1 else None,
                              dp > 1,
                              microbatches,
                              L, pipeline_intvl,
                              show_axes=(dpi==0)))

ax.set_aspect('equal')
ax.grid(False)
ax.set_xticks([]) 
ax.set_yticks([])
ax.set_zticks([])

update_gen = zip(*[pipeline.frame_gen() for pipeline in pipelines])
def update(_):
    return None

ani = FuncAnimation(fig, update, frames=update_gen)
ani.save('batch.mp4', fps=3.75, extra_args=['-vcodec', 'libx264'])


  ani = FuncAnimation(fig, update, frames=update_gen)
