In [2]:
import numpy as np
import matplotlib.pyplot as plt

In [130]:
class LinearLayer:
    """Shape, parallelism layout, and precision of a single linear layer.

    The layer maps d_in inputs to d_out outputs. Its weight matrix is split
    tp_in ways along the input dimension and tp_out ways along the output
    dimension (tensor parallelism), the batch is split dp ways (data
    parallelism), and each value occupies `bytewidth` bytes.

    Derived attributes:
        N            total number of parameters (d_in*d_out)
        d_in_local   per-device input dimension (d_in/tp_in)
        d_out_local  per-device output dimension (d_out/tp_out)
        N_local      per-device number of parameters

    Local sizes use true division, so they are fractional whenever a
    parallelism degree does not divide the corresponding dimension.
    """

    def __init__(self, d_in, d_out, tp_in, tp_out, dp, bytewidth):
        # Global layer shape.
        self.d_in, self.d_out = d_in, d_out
        # Tensor- and data-parallel degrees.
        self.tp_in, self.tp_out, self.dp = tp_in, tp_out, dp
        # Bytes per stored value.
        self.bytewidth = bytewidth

        # Sizes derived from the shape and the tensor-parallel split.
        self.N = d_in*d_out
        self.d_in_local = d_in/tp_in
        self.d_out_local = d_out/tp_out
        self.N_local = self.d_in_local*self.d_out_local

class Device:
    """Roofline-style timing model of a single accelerator.

    Attributes:
        name               label returned by repr()
        flop_per_sec_8bit  tensor FLOP/s at 8-bit precision; wider formats are
                           charged proportionally more time via the layer's
                           bytewidth
        global_Bps         global memory bandwidth, bytes per second
        net_Bps            network bandwidth, bytes per second
        base_util          derating factor applied to FLOP/s and global memory
                           B/s (attributed to thermal throttling); network
                           bandwidth is not derated
    """

    def __init__(self, name, flop_per_sec_8bit, global_Bps, net_Bps,
                 base_util=0.8):
        self.name = name
        self.flop_per_sec_8bit, self.global_Bps = flop_per_sec_8bit, global_Bps
        self.net_Bps, self.base_util = net_Bps, base_util

    def __repr__(self):
        """Represent the device by its name alone."""
        return self.name

    def all_reduce_communication(self, n, p):
        """Words this device sends or receives in one all-reduce.

        The vector has local size n and is reduced among p devices (this one
        plus p - 1 peers). A bandwidth-optimal reduce-scatter + multicast
        pattern is assumed, in which every device owns the reduction of an
        n/p-word part:

          1. Send p - 1 parts of n/p words to the devices that own them.
          2. Receive p - 1 copies of the part this device owns.
          3. Reduce the p copies and multicast the result to the peers; the
             multicast transmission is optimistically treated as free.
          4. Receive the p - 1 reduced parts owned by the peers.

        Steps 1, 2, and 4 each move (p - 1)*n/p words, hence the factor of 3.
        """
        peers = p - 1
        return 3*peers*n/p

    def linear_layer_fwd_bwd_secs(self, lin_layer, b):
        """Seconds for one forward + backward pass through lin_layer.

        b is the number of tokens in the microbatch; each data-parallel
        replica handles b/dp of them. The result is the largest of three
        times: tensor FLOP, global memory IO, and tensor-parallel all-reduce
        network IO.

        Modeling assumptions (all optimistic):

          - Data movement and computation overlap fully, thanks to several
            microbatches being in flight at once.

          - Pipeline-parallel communication takes negligible time (bubbles
            from deep pipelines are modeled elsewhere).

          - Pointwise operations (activation functions, normalizations) are
            negligible, as are attention and any communication it needs.

          - b/dp is substantially smaller than d_in/tp_in and d_out/tp_out.
            That regime tends to be communication-optimal because tensor
            parallelism communicates far more than data parallelism. Two
            consequences follow:

              1. Input and output matrices are quite rectangular, so the
                 on-chip bottleneck sits at global memory rather than shared
                 memory banks, L2 cache, or the SM-to-SM DSMEM network:
                 SM- and warp-level tiles can stay large and roughly square,
                 but no such freedom exists at the global level. Only traffic
                 to/from global memory is therefore modeled.

              2. Recently touched activations and activation gradients (from
                 a network receive, an all-reduce, or the previous layer's
                 matmul) are assumed to sit in L2 cache, whereas weights and
                 weight gradients are not. Global memory is thus charged for
                 weights and weight gradients, for the words moved by each
                 all-reduce, and for reloading input activations on the
                 backward pass (they are presumed to fit in global memory, so
                 no activation recomputation is needed).

          - All-reduces follow the pattern described in
            all_reduce_communication, with the multicast treated as free.

        Parallelism degrees are not required to divide the tensor dimensions.
        """
        rows, cols = lin_layer.d_in_local, lin_layer.d_out_local
        weights = lin_layer.N_local
        split_in, split_out, replicas = lin_layer.tp_in, lin_layer.tp_out, lin_layer.dp
        width = lin_layer.bytewidth
        tokens = b/replicas

        # Words moved by the two tensor-parallel all-reduces. Each one is
        # charged to both the network and global memory.
        fwd_reduce = self.all_reduce_communication(cols*tokens, split_in)   # Output activations.
        bwd_reduce = self.all_reduce_communication(rows*tokens, split_out)  # Input gradients.

        # Three matmuls of identical cost: Y = WX, dL/dW = (dL/dY)X^T, and
        # dL/dX = W^T(dL/dY).
        flop = 0
        for _ in ('outputs', 'weight gradients', 'input gradients'):
            flop += 2*weights*tokens

        # Global memory traffic in words, in execution order.
        global_io = 0
        for words in (
            weights,      # Forward: load weights.
            fwd_reduce,   # Forward: all-reduce of output activations.
            rows*tokens,  # Backward: reload input activations.
            2*weights,    # Backward: accumulate weight gradients (read + write).
            weights,      # Backward: load weights again.
            bwd_reduce,   # Backward: all-reduce of input gradients.
        ):
            global_io += words

        # Network traffic in words.
        net_io = 0
        for words in (fwd_reduce, bwd_reduce):
            net_io += words

        # Turn work into seconds. base_util derates FLOP/s and global memory
        # B/s but leaves network B/s untouched.
        flop_secs = (flop*width)/(self.base_util*self.flop_per_sec_8bit)
        global_io_secs = (global_io*width)/(self.base_util*self.global_Bps)
        net_io_secs = (net_io*width)/self.net_Bps

        # With perfect overlap, the slowest resource sets the pace.
        slowest_io_secs = np.maximum(global_io_secs, net_io_secs)
        return np.maximum(flop_secs, slowest_io_secs)

    def linear_layer_end_of_batch_secs(self, lin_layer):
        """Seconds for the end-of-batch gradient all-reduce of lin_layer.

        The local weight gradients (N_local words) are all-reduced across the
        dp data-parallel replicas. The same byte volume is charged to global
        memory (derated by base_util) and to the network, and the slower of
        the two is returned. The optimizer step itself is assumed negligible.
        """
        grad_words = self.all_reduce_communication(lin_layer.N_local, lin_layer.dp)
        grad_bytes = grad_words*lin_layer.bytewidth
        memory_secs = grad_bytes/(self.base_util*self.global_Bps)
        network_secs = grad_bytes/self.net_Bps
        return np.maximum(memory_secs, network_secs)

In [168]:
# Device under test; the positional arguments are flop_per_sec_8bit,
# global_Bps, and net_Bps (base_util keeps its default).
dev = Device('H100 SXM5', 1979e12, 3352e9, 900e9)

# linear_layer_fwd_bwd_secs takes a LinearLayer plus the microbatch size in
# tokens, so the layer shape, parallelism degrees, and precision are bundled
# into a LinearLayer rather than passed to the method as keyword arguments.
layer = LinearLayer(d_in=16384, d_out=5*16384, tp_in=8, tp_out=40, dp=32, bytewidth=2)
dev.linear_layer_fwd_bwd_secs(layer, b=2048**2/128)

3.2554072481051035e-05