### Implementations of collective communications (Tree & Ring)

In [1]:
import os
import torch
import torch.distributed as dist

In [2]:
class NCCL:
    """Collective operations hand-built from ``torch.distributed`` point-to-point calls.

    Provides a ring all-reduce plus tree-based broadcast, scatter, gather and
    reduce.  The tree collectives use a binary-heap layout over *virtual*
    ranks, where ``virtual = (rank - root) % world_size``; this makes the
    requested ``root`` the top of the tree whichever rank it is.

    Every collective must be called by all ranks with the same arguments
    (``root``, ``op``) and with tensors of identical shape and dtype.
    """

    def __init__(self, rank: int, world_size: int, backend: str = 'nccl'):
        """Join the default process group as ``rank`` out of ``world_size``.

        Args:
            rank: Index of this process in the group.
            world_size: Total number of processes in the group.
            backend: ``torch.distributed`` backend name; defaults to ``'nccl'``.
        """
        self.rank = rank
        self.world_size = world_size

        # Respect a rendezvous address/port already supplied by the launcher
        # and only fall back to the single-node defaults.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12355')

        # Select the device *before* the process group is created so that the
        # communicator is bound to the intended GPU; wrap the index so ranks
        # beyond the local device count still map to a valid device.
        if torch.cuda.is_available():
            torch.cuda.set_device(rank % torch.cuda.device_count())

        dist.init_process_group(
            backend=backend,
            rank=rank,
            world_size=world_size
        )

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _check_op(op):
        """Raise ``ValueError`` unless ``op`` is SUM, PRODUCT, MIN or MAX."""
        supported = (
            dist.ReduceOp.SUM,
            dist.ReduceOp.PRODUCT,
            dist.ReduceOp.MIN,
            dist.ReduceOp.MAX,
        )
        if not any(op == candidate for candidate in supported):
            raise ValueError(f"unsupported reduce op: {op}")

    @staticmethod
    def _reduce_into(target, incoming, op):
        """Combine ``incoming`` into ``target`` in place according to ``op``."""
        if op == dist.ReduceOp.SUM:
            target.add_(incoming)
        elif op == dist.ReduceOp.PRODUCT:
            target.mul_(incoming)
        elif op == dist.ReduceOp.MIN:
            torch.minimum(target, incoming, out=target)
        elif op == dist.ReduceOp.MAX:
            torch.maximum(target, incoming, out=target)
        else:
            raise ValueError(f"unsupported reduce op: {op}")

    @staticmethod
    def _exchange(outgoing, incoming, dst, src):
        """Send ``outgoing`` to ``dst`` while receiving ``incoming`` from ``src``.

        Both operations are issued as one batch so that neither side has to
        finish its send before its receive is posted.  Empty tensors are
        skipped; chunk sizes are identical on every rank, so the peer skips
        the matching operation as well.
        """
        ops = []
        if outgoing.numel() > 0:
            ops.append(dist.P2POp(dist.isend, outgoing, dst))
        if incoming.numel() > 0:
            ops.append(dist.P2POp(dist.irecv, incoming, src))
        if ops:
            for request in dist.batch_isend_irecv(ops):
                request.wait()

    def _check_root(self, root):
        """Raise ``ValueError`` if ``root`` is not a valid rank."""
        if not 0 <= root < self.world_size:
            raise ValueError(
                f"root must be in [0, {self.world_size}), got {root}"
            )

    def _virtual(self, root):
        """Return this rank's position in the tree rooted at ``root``."""
        return (self.rank - root) % self.world_size

    def _real(self, vrank, root):
        """Map a virtual tree position back to the real rank."""
        return (vrank + root) % self.world_size

    def _children(self, vrank):
        """Return the virtual ranks of the (up to two) children of ``vrank``."""
        candidates = (2 * vrank + 1, 2 * vrank + 2)
        return [child for child in candidates if child < self.world_size]

    def _subtree(self, vrank):
        """Return the virtual ranks of the subtree at ``vrank`` in pre-order."""
        members = []
        stack = [vrank]
        while stack:
            node = stack.pop()
            members.append(node)
            # Push right before left so the left child is visited first.
            stack.extend(reversed(self._children(node)))
        return members

    # -------------------------------------------------------------- collectives

    def all_reduce_ring(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM):
        """All-reduce ``tensor`` in place with a ring (reduce-scatter + all-gather).

        Args:
            tensor: Local contribution; overwritten with the reduced result.
            op: One of ``ReduceOp.SUM``, ``PRODUCT``, ``MIN`` or ``MAX``.

        Returns:
            ``tensor`` (the same object), holding the reduced values.

        Raises:
            ValueError: If ``op`` is not supported (only checked when
                ``world_size > 1``).
        """
        if self.world_size == 1:
            return tensor

        # Fail before any communication so no rank is left waiting on a peer.
        self._check_op(op)

        world = self.world_size
        send_rank = (self.rank + 1) % world
        recv_rank = (self.rank - 1) % world

        # Work on a flat contiguous view so every chunk is a contiguous slice
        # that aliases the data being reduced.
        is_contiguous = tensor.is_contiguous()
        work = tensor if is_contiguous else tensor.contiguous()
        # tensor_split always yields exactly `world` pieces (possibly empty),
        # the first one being the largest.
        chunks = torch.tensor_split(work.view(-1), world)
        recv_buffer = torch.empty_like(chunks[0])

        # Reduce-scatter: after world-1 steps this rank owns the fully
        # reduced chunk with index (rank + 1) % world.
        for step in range(world - 1):
            send_idx = (self.rank - step) % world
            recv_idx = (self.rank - step - 1) % world

            incoming = recv_buffer[:chunks[recv_idx].numel()]
            self._exchange(chunks[send_idx], incoming, send_rank, recv_rank)

            if incoming.numel() > 0:
                self._reduce_into(chunks[recv_idx], incoming, op)

        # All-gather: circulate the finished chunks, starting with the one
        # this rank owns, receiving directly into the destination slice.
        for step in range(world - 1):
            send_idx = (self.rank + 1 - step) % world
            recv_idx = (self.rank - step) % world

            self._exchange(
                chunks[send_idx], chunks[recv_idx], send_rank, recv_rank
            )

        if not is_contiguous:
            tensor.copy_(work)

        return tensor

    def broadcast_mst(self, tensor: torch.Tensor, root: int):
        """Broadcast ``tensor`` from ``root`` to every rank along a binary tree.

        Args:
            tensor: Source data on ``root``; receive buffer (overwritten in
                place) on every other rank.
            root: Rank that owns the data.

        Returns:
            ``tensor`` (the same object) on every rank.

        Raises:
            ValueError: If ``root`` is not a valid rank (only checked when
                ``world_size > 1``).
        """
        if self.world_size == 1:
            return tensor

        self._check_root(root)
        vrank = self._virtual(root)

        # Point-to-point transfers need contiguous memory.
        is_contiguous = tensor.is_contiguous()
        buffer = tensor if is_contiguous else tensor.contiguous()

        if vrank != 0:
            dist.recv(buffer, src=self._real((vrank - 1) // 2, root))
            if not is_contiguous:
                tensor.copy_(buffer)

        for child in self._children(vrank):
            dist.send(buffer, dst=self._real(child, root))

        return tensor

    def scatter_mst(self, tensor: torch.Tensor, root: int):
        """Scatter equal slices of ``tensor`` along its last dimension.

        Rank ``i`` ends up with slice ``i``.  Every rank must pass a tensor
        of the full (pre-scatter) shape and dtype; only the values on
        ``root`` are used, the other ranks' tensors only describe the layout.

        Args:
            tensor: Full tensor; its last dimension must be a non-zero
                multiple of ``world_size``.
            root: Rank that owns the data.

        Returns:
            This rank's slice.  On ``root`` it is a view of ``tensor``; on
            other ranks it is a view of the receive buffer.  With a single
            rank the input tensor is returned unchanged.

        Raises:
            ValueError: If ``root`` is invalid or the last dimension cannot be
                split evenly (only checked when ``world_size > 1``).
        """
        if self.world_size == 1:
            return tensor

        self._check_root(root)
        world = self.world_size
        if tensor.dim() == 0 or tensor.shape[-1] == 0 or tensor.shape[-1] % world:
            raise ValueError(
                "last dimension of tensor must be a non-zero multiple of "
                f"world_size ({world}), got shape {tuple(tensor.shape)}"
            )

        width = tensor.shape[-1] // world
        vrank = self._virtual(root)

        # `pieces` maps a virtual rank to the slice destined for it.
        if vrank == 0:
            pieces = {
                v: tensor.narrow(-1, self._real(v, root) * width, width)
                for v in range(world)
            }
        else:
            # The parent sends the slices of this whole subtree concatenated
            # in pre-order, which is the order `_subtree` reproduces here.
            members = self._subtree(vrank)
            buffer = torch.empty(
                (*tensor.shape[:-1], width * len(members)),
                dtype=tensor.dtype,
                device=tensor.device,
            )
            dist.recv(buffer, src=self._real((vrank - 1) // 2, root))
            pieces = dict(zip(members, torch.split(buffer, width, dim=-1)))

        for child in self._children(vrank):
            payload = torch.cat(
                [pieces[v] for v in self._subtree(child)], dim=-1
            )
            dist.send(payload, dst=self._real(child, root))

        return pieces[vrank]

    def gather_mst(self, tensor: torch.Tensor, root: int):
        """Gather every rank's ``tensor`` on ``root`` along a binary tree.

        Args:
            tensor: Local contribution; same shape and dtype on every rank.
            root: Rank that receives the gathered result.

        Returns:
            On ``root``: the contributions concatenated along dimension 0 in
            rank order (shape ``(world_size,)`` for 0-dim inputs).  On other
            ranks: ``None``.  With a single rank the input tensor is returned.

        Raises:
            ValueError: If ``root`` is not a valid rank (only checked when
                ``world_size > 1``).
        """
        if self.world_size == 1:
            return tensor

        self._check_root(root)
        world = self.world_size
        vrank = self._virtual(root)

        children = self._children(vrank)
        child_sizes = [len(self._subtree(child)) for child in children]

        # One row per rank of this subtree, stored in pre-order:
        # row 0 is this rank, then the left subtree, then the right subtree.
        buffer = torch.empty(
            (1 + sum(child_sizes), tensor.numel()),
            dtype=tensor.dtype,
            device=tensor.device,
        )
        buffer[0].copy_(tensor.reshape(-1))

        offset = 1
        for child, size in zip(children, child_sizes):
            # A block of whole rows is contiguous, so receive straight into it.
            dist.recv(buffer[offset:offset + size], src=self._real(child, root))
            offset += size

        if vrank != 0:
            dist.send(buffer, dst=self._real((vrank - 1) // 2, root))
            return None

        # Rows are in pre-order of the tree; reorder them by real rank.
        row_rank = [self._real(v, root) for v in self._subtree(0)]
        rows_by_rank = sorted(range(world), key=row_rank.__getitem__)
        ordered = buffer[rows_by_rank]

        if tensor.dim() == 0:
            return ordered.reshape(world)
        return ordered.reshape((world * tensor.shape[0], *tensor.shape[1:]))

    def reduce_tree(self, tensor: torch.Tensor, root: int, op=dist.ReduceOp.SUM):
        """Reduce every rank's ``tensor`` onto ``root`` along a binary tree.

        The reduction is accumulated in place, so on ranks that have children
        ``tensor`` is overwritten with the partial result of their subtree.

        Args:
            tensor: Local contribution; same shape and dtype on every rank.
            root: Rank that receives the reduced result.
            op: One of ``ReduceOp.SUM``, ``PRODUCT``, ``MIN`` or ``MAX``.

        Returns:
            ``tensor`` holding the full reduction on ``root``; ``None`` on
            every other rank.  With a single rank the input tensor is returned.

        Raises:
            ValueError: If ``root`` or ``op`` is invalid (only checked when
                ``world_size > 1``).
        """
        if self.world_size == 1:
            return tensor

        # Validate before communicating so no rank is left waiting on a peer.
        self._check_root(root)
        self._check_op(op)

        vrank = self._virtual(root)
        children = self._children(vrank)

        if children:
            # Same dtype as the payload: the peer sends `tensor`'s dtype.
            recv_buffer = torch.empty_like(
                tensor, memory_format=torch.contiguous_format
            )
            for child in children:
                dist.recv(recv_buffer, src=self._real(child, root))
                self._reduce_into(tensor, recv_buffer, op)

        if vrank != 0:
            dist.send(
                tensor.contiguous(), dst=self._real((vrank - 1) // 2, root)
            )
            return None

        return tensor