In [1]:
import numpy as np


# ============================================================
# Configuration (Modify for Experiments)
# ============================================================

K = 3                 # Pipeline stages
D = 4                 # Data parallel replicas
M = 6                 # Micro-batches
GRAD_SIZE = 120       # Must be divisible by D
COMM_MODE = "INT8"    # "FP16" or "INT8"


# ============================================================
# Quantization Utilities
# ============================================================

def quantize_int8(t):
    max_val = np.max(np.abs(t))
    scale = max_val / 127.0 if max_val != 0 else 1.0
    q = np.round(t / scale).astype(np.int8)
    return q, scale

def dequantize_int8(q, scale):
    return q.astype(np.float32) * scale


# ============================================================
# Device Definition
# ============================================================

class Device:
    def __init__(self, rank, stage):
        self.rank = rank
        self.stage = stage
        self.weight = np.ones(GRAD_SIZE, dtype=np.float32)
        self.local_grad = np.zeros(GRAD_SIZE, dtype=np.float32)
        self.comm_bytes = 0


# ============================================================
# Hybrid Parallel Simulator
# ============================================================

class HybridSimulator:

    def __init__(self):

        self.devices = []
        self.timeline = []
        self.logical_time = 0

        # Create K × D devices
        for dp in range(D):
            for stage in range(K):
                rank = dp * K + stage
                self.devices.append(Device(rank, stage))

        # Group by pipeline stage (Data Parallel groups)
        self.stage_groups = {
            s: [d for d in self.devices if d.stage == s]
            for s in range(K)
        }

    # --------------------------------------------------------
    def log(self, rank, event):
        self.timeline.append((self.logical_time, rank, event))
        self.logical_time += 1

    # --------------------------------------------------------
    # 1F1B Pipeline Scheduling
    # --------------------------------------------------------
    def run_pipeline(self):

        total_steps = M + K - 1

        for step in range(total_steps):

            for d in self.devices:

                stage = d.stage
                warmup = K - stage - 1

                # Forward Phase
                if step < M and step >= stage:
                    self.log(d.rank, f"F{step - stage}")

                # Backward Phase
                if step >= warmup:
                    mb = step - warmup
                    if 0 <= mb < M:
                        self.log(d.rank, f"B{mb}")

                        # Gradient accumulation
                        grad = np.random.randn(GRAD_SIZE).astype(np.float32)
                        d.local_grad += grad

    # --------------------------------------------------------
    # True Chunked Ring All-Reduce (Deterministic)
    # --------------------------------------------------------
    def ring_all_reduce(self, stage):

        group = self.stage_groups[stage]

        # Stack gradients (D × GRAD_SIZE)
        grads = np.stack([d.local_grad for d in group])
        chunks = np.split(grads, D, axis=1)

        # -------- Reduce-Scatter --------
        for step in range(D - 1):
            for i in range(D):
                src = i
                dst = (i + 1) % D
                chunk_idx = (i - step) % D

                send_chunk = chunks[chunk_idx][src]

                if COMM_MODE == "FP16":
                    payload = send_chunk.astype(np.float16)
                    group[src].comm_bytes += payload.nbytes
                    recv_chunk = payload.astype(np.float32)
                else:
                    q, scale = quantize_int8(send_chunk)
                    group[src].comm_bytes += q.nbytes
                    recv_chunk = dequantize_int8(q, scale)

                target_idx = (dst - step - 1) % D
                chunks[target_idx][dst] += recv_chunk

        # -------- All-Gather --------
        for step in range(D - 1):
            for i in range(D):
                src = i
                dst = (i + 1) % D
                chunk_idx = (i - step) % D

                send_chunk = chunks[chunk_idx][src]

                if COMM_MODE == "FP16":
                    payload = send_chunk.astype(np.float16)
                    group[src].comm_bytes += payload.nbytes
                    recv_chunk = payload.astype(np.float32)
                else:
                    q, scale = quantize_int8(send_chunk)
                    group[src].comm_bytes += q.nbytes
                    recv_chunk = dequantize_int8(q, scale)

                target_idx = (dst - step - 1) % D
                chunks[target_idx][dst] = recv_chunk

        # Average
        final = np.concatenate(chunks, axis=1) / D

        for i, d in enumerate(group):
            d.local_grad = final[i]

    # --------------------------------------------------------
    def optimizer_step(self, stage):

        for d in self.stage_groups[stage]:
            self.log(d.rank, "OPT")
            d.weight -= 0.01 * d.local_grad
            d.local_grad[:] = 0

    # --------------------------------------------------------
    def run(self):

        print(f"\nHybrid Parallel Simulation (K={K}, D={D})")
        print(f"Total Nodes = {K * D}")
        print(f"Communication Mode = {COMM_MODE}\n")

        # 1F1B Pipeline
        self.run_pipeline()

        # Data Parallel Synchronization
        for stage in range(K):
            self.ring_all_reduce(stage)
            self.optimizer_step(stage)

        print("Simulation Completed.\n")

        self.verify()
        self.report_comm()
        self.print_timeline()

    # --------------------------------------------------------
    def verify(self):

        print("Replica Consistency Check:")
        for stage in range(K):
            weights = [d.weight for d in self.stage_groups[stage]]
            diffs = [np.linalg.norm(weights[0] - w) for w in weights]
            print(f"  Stage {stage}: Max Weight Difference = {max(diffs)}")

    # --------------------------------------------------------
    def report_comm(self):

        total = sum(d.comm_bytes for d in self.devices)
        print("\nCommunication Statistics:")
        print(f"  Total Communication Bytes = {total}")

    # --------------------------------------------------------
    def print_timeline(self):

        print("\n================ EXECUTION TIMELINE ================\n")

        header = f"{'Time':<6} | {'Rank':<6} | {'Stage':<6} | Event"
        print(header)
        print("-" * len(header))

        for t, rank, event in self.timeline:
            stage = self.devices[rank].stage
            print(f"{t:<6} | {rank:<6} | {stage:<6} | {event}")

        print("\n====================================================\n")


# ============================================================
if __name__ == "__main__":
    HybridSimulator().run()



Hybrid Parallel Simulation (K=3, D=4)
Total Nodes = 12
Communication Mode = INT8

Simulation Completed.

Replica Consistency Check:
  Stage 0: Max Weight Difference = 0.00024210211995523423
  Stage 1: Max Weight Difference = 0.0003216343466192484
  Stage 2: Max Weight Difference = 0.0002757864131126553

Communication Statistics:
  Total Communication Bytes = 2160


Time   | Rank   | Stage  | Event
--------------------------------
0      | 0      | 0      | F0
1      | 2      | 2      | B0
2      | 3      | 0      | F0
3      | 5      | 2      | B0
4      | 6      | 0      | F0
5      | 8      | 2      | B0
6      | 9      | 0      | F0
7      | 11     | 2      | B0
8      | 0      | 0      | F1
9      | 1      | 1      | F0
10     | 1      | 1      | B0
11     | 2      | 2      | B1
12     | 3      | 0      | F1
13     | 4      | 1      | F0
14     | 4      | 1      | B0
15     | 5      | 2      | B1
16     | 6      | 0      | F1
17     | 7      | 1      | F0
18     | 7      | 1      