# 分布式训练
* 将数据与模型 分布到 单机多卡与多机多卡上

In [5]:
import torch

# Number of CUDA devices visible to this process.
device_count = torch.cuda.device_count()
# Pick the first GPU when CUDA is available, otherwise fall back to the CPU.
# This single `device` is reused below so the cell also runs on CPU-only hosts.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Move data or a model onto the selected device.
data = torch.ones((3, 3))
print(data.device)

# Tensor.to() returns the tensor on the target device; `data` itself is unchanged.
data_gpu = data.to(device)
print(data_gpu.device)

net = torch.nn.Sequential(torch.nn.Linear(3, 3))
# Module.to() moves the parameters in place and returns the module, so as the
# last expression of the cell it also displays the network structure.
net.to(device)


cpu
cuda:0


Sequential(
  (0): Linear(in_features=3, out_features=3, bias=True)
)

## 单机多卡
`torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)`
* module 需要并行化的模型
* device_ids GPU设备号
* output_device 输出结果的设备，默认为 device_ids[0]，即第一块卡

在模型前向计算过程中，数据被划分为多个块，推送到不同的GPU，模型则在每个GPU上都会复制一份
```
class ASimpleNet(nn.Module):
    """Stack of bias-free ``Linear(3, 3)`` layers followed by a single ReLU.

    Args:
        layers: Number of linear layers held in ``self.linears``.
    """

    def __init__(self, layers=3):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(3, 3, bias=False) for _ in range(layers)]
        )

    def forward(self, x):
        """Apply every linear layer in order, then ReLU.

        Prints the size of the first dimension of ``x`` on each call, which
        shows how large a slice each replica receives under ``DataParallel``.

        Args:
            x: Input tensor whose last dimension is 3.

        Returns:
            Tensor with the same shape as ``x``.
        """
        print("forward batchsize is: {}".format(x.size()[0]))
        # nn.ModuleList is only a container and cannot be called directly;
        # its layers have to be applied one after another.
        for linear in self.linears:
            x = linear(x)
        return torch.relu(x)
        
# One batch of 16 random samples with 3 features each, plus matching targets.
batch_size = 16
inputs = torch.randn(batch_size, 3)
labels = torch.randn(batch_size, 3)
inputs, labels = inputs.to(device), labels.to(device)
# Wrap the model so that each forward call is split across the visible GPUs.
net = ASimpleNet()
net = nn.DataParallel(net)
net.to(device)
# CUDA_VISIBLE_DEVICES is optional; use .get so an unset variable does not
# raise KeyError before the demo has run.
print("CUDA_VISIBLE_DEVICES :{}".format(os.environ.get("CUDA_VISIBLE_DEVICES", "<not set>")))

for epoch in range(1):
    outputs = net(inputs)

# Get:
# CUDA_VISIBLE_DEVICES : 3, 2, 1, 0
# forward batchsize is: 4
# forward batchsize is: 4
# forward batchsize is: 4
# forward batchsize is: 4

```

## 多机多卡
### DP (dataparallel)
* 单进程控制多GPU，将输入的一个batch数据分成了n份，分别送到对应的GPU进行计算。
* 前向传播时，模型从主GPU复制到其他GPU，反向传播时，每个GPU上的梯度汇总到主GPU，主GPU求得梯度均值更新模型参数后，再把模型复制到其他GPU
* 主GPU承担梯度汇总和模型更新，以及下发任务，负载高  

### DDP (distributedDataParallel)
* 多进程控制GPU
* 数据加载采用分布式数据采集器，确保数据在各个进程没有重叠
* 反向传播时，各个GPU梯度计算完成后，通过 all-reduce 通信将梯度汇总平均，然后各个进程在各自的GPU上进行梯度更新，确保各个GPU上的模型参数保持一致
* 无需在GPU之间复制模型，DDP传输数据量更少，速度更快
* 也适用于单机多卡

#### DDP训练
  * group  进程组,默认一个组
  * world_size，全局进程个数
  * rank 进程序号,用于进程间通信，表示进程优先级，rank=0表示主节点
  * ![](images/distribution-train-2022-10-03-10-20-37.png)
    * `torch.distributed.init_process_group(backend, init_method=None, world_size=-1, rank=-1, group_name='')`
    * 1初始化进程组 
      * backend 通信所用后端，可以是nccl(gpu)或者gloo(cpu)
      * init_method 指定进程的初始化方式，默认"env://"，表示从环境变量初始化，也可以通过TCP方式或共享文件系统初始化
      * world_size 参与训练的进程总数（每个节点只启动一个进程时才等于节点数）
      * rank 进程的编号，也是其优先级，表示当前节点的编号
      * group_name 进程组名字
    * 2模型并行化
    * `torch.nn.parallel.DistributedDataParallel(module, device_ids=None, output_device=None, dim=0)`
    * 3创建分布式数据采集器
    * `train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)`    
    * `data_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler)`
    * 

In [None]:

# imagenet demo https://github.com/pytorch/examples/blob/master/imagenet/main.py

if args.distributed:  # use DDP
    # With env:// initialisation and no explicit rank, take the rank from
    # the RANK environment variable.
    if args.dist_url == "env://" and args.rank == -1:
        args.rank = int(os.environ["RANK"])
    if args.multiprocessing_distributed:
        # Turn the per-node rank into the global rank among all processes:
        # each node contributes ngpus_per_node processes, and `gpu` is this
        # process's index within its node.
        args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )


In [None]:

# Decide where the model lives and which parallel wrapper (if any) is used.
if not torch.cuda.is_available():
    print('using CPU, this will be slow')
elif args.distributed:
    # For multiprocessing distributed, the DistributedDataParallel constructor
    # should always be given a single device scope; otherwise it will use all
    # available devices.
    if args.gpu is None:
        model.cuda()
        # Without device_ids, DistributedDataParallel divides and allocates
        # batch_size across all available GPUs itself.
        model = torch.nn.parallel.DistributedDataParallel(model)
    else:
        torch.cuda.set_device(args.gpu)
        model.cuda(args.gpu)
        # One GPU per process: split the batch size and the data-loading
        # workers (rounded up) by the number of GPUs on this node ourselves.
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
elif args.gpu is not None:
    # Single, explicitly chosen GPU without any parallel wrapper.
    torch.cuda.set_device(args.gpu)
    model = model.cuda(args.gpu)
elif args.arch.startswith(('alexnet', 'vgg')):
    # For these architectures only the `features` sub-module is wrapped in
    # DataParallel; the rest of the model is just moved to the GPU.
    model.features = torch.nn.DataParallel(model.features)
    model.cuda()
else:
    # DataParallel will divide and allocate batch_size to all available GPUs.
    model = torch.nn.DataParallel(model).cuda()

In [None]:

# In distributed mode a DistributedSampler is built over train_dataset;
# otherwise no sampler is used.
if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
    train_sampler = None
# When a sampler is supplied, shuffle must not be enabled on the DataLoader,
# so shuffling is requested only when there is no sampler.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

#### 启动进程
* 为每个机器节点上的gpu启动进程，
* `torch.multiprocessing.spawn` 在一个节点启动该节点所有进程

In [None]:
# ngpus_per_node: number of GPUs on each node.
ngpus_per_node = torch.cuda.device_count()
if args.multiprocessing_distributed:
    # Since we have ngpus_per_node processes per node, the total world_size
    # needs to be adjusted accordingly.
    args.world_size = ngpus_per_node * args.world_size
    # Use torch.multiprocessing.spawn to launch one main_worker process per
    # GPU; spawn passes the process index as the first argument.
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
else:
    # Simply call main_worker in the current process.
    main_worker(args.gpu, ngpus_per_node, args)

In [None]:

# Save a checkpoint from one process only: always when multiprocessing
# distributed mode is off, otherwise only from the process whose rank is a
# multiple of ngpus_per_node.
if not args.multiprocessing_distributed or args.rank % ngpus_per_node == 0:
    save_checkpoint(
        {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        },
        is_best,
    )