[TORCH] torch distributed test script #4

Open
kuizhiqing opened this issue Mar 23, 2023 · 2 comments

@kuizhiqing (Owner)
import os

import torch

# Bind each process to its own GPU before creating CUDA tensors; with the NCCL
# backend, leaving every local rank on device 0 can make the collectives error out or hang.
torch.distributed.init_process_group(backend="nccl", init_method="env://")
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
print(f"rank {rank} world_size {world_size}")

# Each rank contributes 1, so the summed result should equal world_size.
a = torch.tensor([1]).cuda()
torch.distributed.all_reduce(a)
print(f"rank {rank} world_size {world_size} {a}")

torch.distributed.barrier()
print(f"rank {rank} world_size {world_size}")
@kuizhiqing (Owner, Author) commented Jul 11, 2023

import sys
import paddle
import paddle.distributed.fleet as fleet
import time


def get_rank(rank, mp_degree, pp_degree, dp_degree):
    mp_rank = rank % mp_degree
    assert (rank - mp_rank) % mp_degree == 0
    tmp = int((rank - mp_rank) / mp_degree)
    dp_rank = int(tmp / pp_degree)
    pp_rank = tmp % pp_degree
    assert dp_rank * pp_degree * mp_degree + pp_rank * mp_degree + mp_rank == rank
    return mp_rank, pp_rank, dp_rank
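# Worked example (degrees chosen only for illustration): get_rank(5, 2, 2, 2)
# gives (mp_rank=1, pp_rank=0, dp_rank=1), and 1*2*2 + 0*2 + 1 == 5, so the
# decomposition round-trips with rank = dp_rank*pp_degree*mp_degree + pp_rank*mp_degree + mp_rank.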


def main(mp_degree, pp_degree, dp_degree):
    fleet.init(is_collective=True)
    world_size = paddle.distributed.get_world_size()
    
    if dp_degree is None:
        assert world_size % (mp_degree * pp_degree) == 0
        dp_degree = int(world_size / (mp_degree * pp_degree))
    else:
        assert mp_degree * pp_degree * dp_degree == world_size
        
    items = [(rank,) + get_rank(rank, mp_degree, pp_degree, dp_degree) for rank in range(world_size)]
    
    # Build one data-parallel group per (mp_rank, pp_rank) pair; every rank
    # should land in exactly one of them.
    all_groups = []
    my_groups = []
    for i in range(mp_degree):
        for j in range(pp_degree):
            ranks = [item[0] for item in items if item[1] == i and item[2] == j]
            group = paddle.distributed.new_group(ranks)
            all_groups.append(group)
            if paddle.distributed.get_rank() in ranks:
                my_groups.append(group)
    
    assert len(my_groups) == 1, len(my_groups)
    group = my_groups[0]
    
    # Total all-reduce payload in bytes, scaled down by mp_degree * pp_degree;
    # num is the corresponding float32 element count.
    nbytes = int(65 * 1024 * 1024 * 1024 * 4 / (1.0 * mp_degree * pp_degree))
    dtype = paddle.float32
    num = int(nbytes / 4)
    
    x = paddle.zeros([num], dtype=dtype)
    
    def allreduce_test(iteration):
        paddle.distributed.barrier()
        paddle.device.cuda.synchronize()
        start_t = time.time()
        for _ in range(iteration):
            paddle.distributed.all_reduce(x, group=group)
        paddle.device.cuda.synchronize()
        end_t = time.time()
        return (end_t - start_t) / iteration
    
    allreduce_test(10)
    print('Warmup ends')
    rank = paddle.distributed.get_rank()
    ranks = group.ranks
    local_device_num = paddle.device.cuda.device_count()
    local_rank = rank % local_device_num
    node_id = int(rank / local_device_num)
    ret = allreduce_test(50)
    
    print('ranks {} node_id {} local_rank {} {} : {}'.format(ranks, node_id, local_rank, num, ret))
    

if __name__ == "__main__":
    assert len(sys.argv) in [3, 4]
    mp_degree = int(sys.argv[1])
    pp_degree = int(sys.argv[2])
    if len(sys.argv) == 3:
        dp_degree = None
    else:
        dp_degree = int(sys.argv[3])
        
    main(mp_degree, pp_degree, dp_degree)
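A possible launch, assuming 8 GPUs per node and Paddle's launcher (flags vary by Paddle version, and the script name is hypothetical); the positional arguments are mp_degree and pp_degree, with dp_degree inferred from the world size when omitted:

python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" fleet_allreduce_bench.py 2 4

The printed value is the average seconds per all_reduce over the 50 timed iterations; for a ring all-reduce the bus bandwidth is roughly 2 * (n - 1) / n * nbytes / t for a group of size n.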

@kuizhiqing (Owner, Author)
import filelog  # local helper (pasted below) that tees stdout/stderr to a per-rank log file
import os
import time
import torch

torch.distributed.init_process_group(backend="nccl", init_method="env://")
world_size = torch.distributed.get_world_size()
rank = torch.distributed.get_rank()

local_device_num = torch.cuda.device_count()
local_rank = int(os.environ["LOCAL_RANK"])

torch.cuda.set_device(local_rank)

node_num = world_size // local_device_num
node_ids = [0, 1]  # only ranks on these node ids join the benchmark groups

torch.distributed.barrier()

# import nccl_env

# One group per local device index: the ranks that share the same GPU slot
# across the selected nodes (cross-node groups of size len(node_ids)).
groups = []
for i in range(local_device_num):
    ranks = [j * local_device_num + i for j in range(node_num)]
    new_ranks = [r for r in ranks if r // local_device_num in node_ids]

    group = torch.distributed.new_group(new_ranks)
    if rank in new_ranks:
        groups.append(group)

# if local_rank < 4:
#    num = 5313036288
# else:
#    num = 5245960192
num = 5313036288  # payload size in bytes; the float32 element count is num // 4

dtype = torch.float32
x = torch.zeros([num // 4], dtype=dtype).cuda()

def allreduce_test(iteration):
    torch.distributed.barrier()
    # Ranks outside node_ids have no group to benchmark.
    if len(groups) == 0:
        return
    # if local_rank not in [0, 2, 3, 4, 5, 6, 7]:
    #    return

    torch.cuda.synchronize()
    start_t = time.time()
    for _ in range(iteration):
        torch.distributed.all_reduce(x, group=groups[0])
    torch.cuda.synchronize()
    end_t = time.time()
    return (end_t - start_t) / iteration

allreduce_test(10)
print('Warmup ends')
ranks = torch.distributed.get_process_group_ranks(groups[0]) if len(groups) > 0 else None 
print('ranks {} node_id {} local_rank {} {} : {}'.format(ranks, rank // local_device_num, local_rank, num, allreduce_test(50)))

torch.distributed.barrier()
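The filelog helper imported at the top of the script (presumably saved as filelog.py next to it):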
import os
import sys
import atexit  


class FileLogger(object):
    def __init__(self, path):
        self.fids = [open(path, 'w'), sys.stdout, sys.stderr] 

    def write(self, *args, **kwargs):
        for fid in self.fids:
            fid.write(*args, **kwargs)
        self.flush()

    def flush(self):
        for fid in self.fids:
            fid.flush()

    def close(self):
        self.fids[0].close() 


def redirect(): 
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))
    dir_path = 'torch_test_logs_{}'.format(world_size)
    file_name = 'log_{}_{}.log'.format(rank, world_size)
    os.makedirs(dir_path, exist_ok=True)
    logger = FileLogger(os.path.join(dir_path, file_name))
    sys.stdout = logger
    sys.stderr = logger
    atexit.register(lambda: logger.close())

    
redirect()
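A sketch of how the benchmark might be launched, assuming torchrun and 8 GPUs per node; node_rank, the rendezvous address, and the script name are placeholders:

torchrun --nnodes=2 --nproc_per_node=8 --node_rank=0 --master_addr=$MASTER_ADDR --master_port=29500 cross_node_allreduce.py

With node_ids = [0, 1], ranks on any other nodes skip the timing loop, and each participating rank writes its output to torch_test_logs_<world_size>/log_<rank>_<world_size>.log via filelog.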
