### Implementations of collective communications (Tree & Ring)

In [2]:
import torch
import torch.distributed as dist

In [3]:
class NCCL:
    """Hand-written tree and ring collectives on top of ``torch.distributed`` P2P.

    Every method must be called by every rank of the group (like the built-in
    collectives); a rank that skips a call leaves its peers blocked in
    ``send``/``recv``.

    Tree-based methods arrange ranks as a binary heap rooted at ``root``: a
    rank's *virtual* rank is ``v = (rank - root) % world_size``, its children
    are virtual ranks ``2v + 1`` and ``2v + 2`` (when ``< world_size``) and its
    parent is ``(v - 1) // 2``.
    """

    def __init__(self, rank, world_size, device=None):
        """Bind this process to a GPU and join the NCCL process group.

        Args:
            rank: Global rank of this process.
            world_size: Total number of processes.
            device: CUDA device index for this process. Defaults to ``rank``,
                which is only right when all ranks share one node; pass the
                local rank for multi-node runs.
        """
        self.rank = rank
        self.world_size = world_size

        # Select the GPU before creating the process group so NCCL setup
        # happens on this rank's device rather than the current default device.
        torch.cuda.set_device(rank if device is None else device)

        # No init_method is passed, so torch.distributed uses its default
        # ("env://"): MASTER_ADDR / MASTER_PORT must be set in the environment.
        dist.init_process_group(
            backend='nccl',
            rank=rank,
            world_size=world_size
        )

    def all_reduce_ring(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM):
        """All-reduce ``tensor`` in place with ring reduce-scatter + all-gather.

        The tensor is flattened and split into ``world_size`` chunks (the first
        ``numel % world_size`` chunks are one element larger). After the
        reduce-scatter phase rank ``r`` owns the fully reduced chunk
        ``(r + 1) % world_size``; the all-gather phase circulates the reduced
        chunks so every rank ends with the full result.

        Args:
            tensor: Input and output, any shape. A non-contiguous tensor is
                reduced through a contiguous copy that is written back.
            op: One of ``SUM``, ``PRODUCT``, ``MIN`` or ``MAX``.

        Returns:
            ``tensor`` (modified in place).

        Raises:
            ValueError: If ``op`` is not supported.
        """
        if self.world_size == 1:
            return tensor

        combine = self._reducer(op)  # validate op before any communication
        world_size, rank = self.world_size, self.rank
        send_rank = (rank + 1) % world_size
        recv_rank = (rank - 1) % world_size

        # Chunks are views into `work`, so reductions/copies land in it directly.
        work = tensor if tensor.is_contiguous() else tensor.contiguous()
        chunks = work.view(-1).tensor_split(world_size)
        # chunks[0] is the largest chunk; smaller receives use a prefix of it.
        recv_buffer = torch.empty_like(chunks[0])

        def exchange(send_idx, recv_into):
            """Send chunk ``send_idx`` to the right neighbour, receive from the left."""
            send_handle = dist.isend(chunks[send_idx], dst=send_rank)
            recv_handle = dist.irecv(recv_into, src=recv_rank)
            send_handle.wait()
            recv_handle.wait()

        # Reduce-scatter: accumulate the neighbour's partial sum into our chunk.
        for step in range(world_size - 1):
            target = chunks[(rank - step - 1) % world_size]
            incoming = recv_buffer[: target.numel()]
            exchange((rank - step) % world_size, incoming)
            combine(target, incoming)

        # All-gather: first forward the chunk this rank fully reduced,
        # (rank + 1), then each chunk received in the previous step. This phase
        # is a pure copy for every op, so receive straight into the chunk.
        for step in range(world_size - 1):
            exchange((rank - step + 1) % world_size, chunks[(rank - step) % world_size])

        if work is not tensor:
            tensor.copy_(work)
        return tensor

    def broadcast_mst(self, tensor, root):
        """Broadcast ``tensor`` from ``root`` down a binary tree, in place.

        Args:
            tensor: The data on ``root``; a receive buffer of the same shape
                and dtype on other ranks. Assumed contiguous (P2P send/recv).
            root: Rank that owns the data.

        Returns:
            ``tensor``.

        Raises:
            ValueError: If ``root`` is out of range.
        """
        if self.world_size == 1:
            return tensor
        self._check_root(root)

        vrank = self._virtual_rank(root)
        if vrank != 0:
            dist.recv(tensor, src=self._real_rank(self._tree_parent(vrank), root))

        for vchild in self._tree_children(vrank, self.world_size):
            dist.send(tensor, dst=self._real_rank(vchild, root))

        return tensor

    def scatter_tree(self, tensor, root):
        """Scatter equal slices of ``root``'s ``tensor`` to all ranks via a binary tree.

        On ``root`` the tensor is flattened and split into ``world_size`` equal
        chunks; rank ``r`` receives chunk ``r``. Every rank forwards to each
        child the concatenated chunks of that child's whole subtree (in preorder
        of virtual rank).

        Args:
            tensor: On ``root``, the data to scatter. On other ranks only its
                ``numel()``, ``dtype`` and ``device`` are read; ``numel()`` and
                ``dtype`` must match ``root``'s tensor.
            root: Rank that owns the data.

        Returns:
            This rank's 1-D chunk of ``numel() // world_size`` elements (on
            ``root`` it may be a view into ``tensor``). With ``world_size == 1``
            the input is returned unchanged.

        Raises:
            ValueError: If ``root`` is out of range or ``tensor.numel()`` is not
                divisible by ``world_size``.
        """
        if self.world_size == 1:
            return tensor
        self._check_root(root)

        world_size = self.world_size
        if tensor.numel() % world_size:
            raise ValueError(
                f"tensor.numel()={tensor.numel()} is not divisible by world_size={world_size}"
            )
        chunk_size = tensor.numel() // world_size

        vrank = self._virtual_rank(root)
        # Virtual ranks of this subtree in preorder; also the payload layout.
        subtree = self._tree_subtree(vrank, world_size)

        if vrank == 0:
            by_rank = tensor.reshape(-1).tensor_split(world_size)
            held = {v: by_rank[self._real_rank(v, root)] for v in subtree}
        else:
            payload = torch.empty(
                chunk_size * len(subtree), dtype=tensor.dtype, device=tensor.device
            )
            dist.recv(payload, src=self._real_rank(self._tree_parent(vrank), root))
            held = dict(zip(subtree, payload.tensor_split(len(subtree))))

        for vchild in self._tree_children(vrank, world_size):
            # torch.cat yields a fresh contiguous tensor, as P2P send expects.
            child_payload = torch.cat([held[v] for v in self._tree_subtree(vchild, world_size)])
            dist.send(child_payload, dst=self._real_rank(vchild, root))

        return held[vrank]

    def gather_mst(self, tensor, root):
        """Gather every rank's ``tensor`` onto ``root`` via a binary tree.

        Each rank sends its parent one flat payload: its own (flattened) data
        followed by its children's payloads, i.e. its subtree in preorder of
        virtual rank. ``root`` reorders the final payload into real-rank order.

        Args:
            tensor: This rank's contribution; all ranks must pass the same
                ``numel()`` and ``dtype``.
            root: Rank that receives the result.

        Returns:
            On ``root``, a 1-D tensor of ``world_size * tensor.numel()``
            elements: rank 0's data, then rank 1's, and so on. ``None`` on other
            ranks. With ``world_size == 1`` the input is returned unchanged.

        Raises:
            ValueError: If ``root`` is out of range.
        """
        if self.world_size == 1:
            return tensor
        self._check_root(root)

        world_size = self.world_size
        chunk_size = tensor.numel()
        vrank = self._virtual_rank(root)

        parts = [tensor.reshape(-1)]
        for vchild in self._tree_children(vrank, world_size):
            subtree_len = len(self._tree_subtree(vchild, world_size))
            buffer = torch.empty(
                subtree_len * chunk_size, dtype=tensor.dtype, device=tensor.device
            )
            dist.recv(buffer, src=self._real_rank(vchild, root))
            parts.append(buffer)
        data = torch.cat(parts)

        if vrank != 0:
            dist.send(data, dst=self._real_rank(self._tree_parent(vrank), root))
            return None
        return self._preorder_to_rank_order(data, world_size, root)

    def reduce_tree(self, tensor: torch.Tensor, root: int, op=dist.ReduceOp.SUM):
        """Reduce ``tensor`` across all ranks onto ``root`` up a binary tree.

        Each rank folds its children's partial results into ``tensor`` in place
        and then sends the result to its parent.

        Args:
            tensor: This rank's contribution; overwritten in place with the
                reduction over this rank's subtree (on ``root``: the full
                result). All ranks must pass the same shape and dtype. Assumed
                contiguous (P2P send/recv).
            root: Rank that receives the result.
            op: One of ``SUM``, ``PRODUCT``, ``MIN`` or ``MAX``.

        Returns:
            ``tensor`` on ``root``, ``None`` elsewhere. With ``world_size == 1``
            the input is returned unchanged.

        Raises:
            ValueError: If ``op`` is unsupported or ``root`` is out of range.
        """
        if self.world_size == 1:
            return tensor
        combine = self._reducer(op)
        self._check_root(root)

        vrank = self._virtual_rank(root)
        children = self._tree_children(vrank, self.world_size)
        # Same dtype/shape as the input so the received byte count matches what
        # children send (a fixed float32 buffer would not).
        recv_buffer = (
            torch.empty_like(tensor, memory_format=torch.contiguous_format) if children else None
        )

        for vchild in children:
            dist.recv(recv_buffer, src=self._real_rank(vchild, root))
            combine(tensor, recv_buffer)

        if vrank != 0:
            dist.send(tensor, dst=self._real_rank(self._tree_parent(vrank), root))
            return None
        return tensor

    # ---- helpers --------------------------------------------------------

    @staticmethod
    def _reducer(op):
        """Return an in-place ``combine(dst, src)`` for ``op``; raise ValueError if unsupported."""
        if op == dist.ReduceOp.SUM:
            return lambda dst, src: dst.add_(src)
        if op == dist.ReduceOp.PRODUCT:
            return lambda dst, src: dst.mul_(src)
        if op == dist.ReduceOp.MIN:
            return lambda dst, src: dst.copy_(torch.minimum(dst, src))
        if op == dist.ReduceOp.MAX:
            return lambda dst, src: dst.copy_(torch.maximum(dst, src))
        raise ValueError(f"unsupported reduce op: {op}")

    def _check_root(self, root):
        """Raise ValueError unless ``0 <= root < world_size``."""
        if not 0 <= root < self.world_size:
            raise ValueError(f"root must be in [0, {self.world_size}), got {root}")

    def _virtual_rank(self, root):
        """This rank's position in the tree rooted at ``root``."""
        return (self.rank - root) % self.world_size

    def _real_rank(self, vrank, root):
        """Map a virtual rank in the tree rooted at ``root`` back to a real rank."""
        return (vrank + root) % self.world_size

    @staticmethod
    def _tree_parent(vrank):
        """Virtual parent of a non-root virtual rank."""
        return (vrank - 1) // 2

    @staticmethod
    def _tree_children(vrank, world_size):
        """Virtual children of ``vrank`` that exist in a group of ``world_size``."""
        return [c for c in (2 * vrank + 1, 2 * vrank + 2) if c < world_size]

    @classmethod
    def _tree_subtree(cls, vrank, world_size):
        """Virtual ranks of the subtree rooted at ``vrank``, in preorder."""
        order, stack = [], [vrank]
        while stack:
            node = stack.pop()
            order.append(node)
            # Push right before left so the left subtree is visited first.
            stack.extend(reversed(cls._tree_children(node, world_size)))
        return order

    @classmethod
    def _preorder_to_rank_order(cls, data, world_size, root):
        """Reorder a root-side gather payload (virtual preorder) into real-rank order."""
        pieces = data.tensor_split(world_size)
        by_rank = [None] * world_size
        for vrank, piece in zip(cls._tree_subtree(0, world_size), pieces):
            by_rank[(vrank + root) % world_size] = piece
        return torch.cat(by_rank)