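# PyTorch Distributed in Practice

This article is a very simple exploration of `PyTorch`'s [distributed API](https://pytorch.org/docs/stable/distributed.html#distributed-basics).

It also includes two distributed-aware metric implementations, `Accuracy` and `Mean Average Precision`, built on `reduce` and `gather` respectively.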

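## Environment

All code in this article can actually be run. The software versions used while writing are listed below; you are free to try other versions as well.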

* Python==3.7.0
* torch==1.10.0

You can also run the code from this article directly in Colab.

## Helper functions

In [1]:
import os

def clear_log():
    with open("/tmp/pytorch-dist.log", "w") as f:
        f.write("")

os.environ['OMP_NUM_THREADS'] = "1"

In [2]:
%%file /tmp/helper.py

import torch.distributed as dist
import logging


def init_logger():
    assert dist.is_initialized()
    rank = dist.get_rank()
    role = "master" if rank == 0 else "slave"

    logging.basicConfig(filename="/tmp/pytorch-dist.log",
                        filemode='a',
                        format=f'{role.ljust(6)} ==> %(asctime)8s %(levelname)5s %(message)4s',
                        datefmt='%H:%M:%S',
                        level=logging.DEBUG)

    
def is_master():
    return dist.get_rank() == 0


def is_slave():
    # Any rank other than 0 is a slave, matching the role name used in init_logger.
    return dist.get_rank() != 0

Overwriting /tmp/helper.py


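## Getting environment information

These calls return information about the current node's environment, such as the backend in use, the current node's rank, and the total number of nodes. In this article a "node" is one worker process started by `torchrun`.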

In [4]:
%%file /tmp/test.py

import torch
import logging
from helper import init_logger, is_master, is_slave

torch.distributed.init_process_group(backend="gloo")

init_logger()

logging.info(f"backend: {torch.distributed.get_backend()}")
logging.info(f"rank: {torch.distributed.get_rank()}")
logging.info(f"world size: {torch.distributed.get_world_size()}")

Overwriting /tmp/test.py


In [5]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log

master ==> 13:52:24  INFO backend: gloo
master ==> 13:52:24  INFO rank: 0
master ==> 13:52:24  INFO world size: 2
slave  ==> 13:52:24  INFO backend: gloo
slave  ==> 13:52:24  INFO rank: 1
slave  ==> 13:52:24  INFO world size: 2
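
The script above calls `init_process_group` without passing a rank or a world size. With the default `env://` initialization these values come from environment variables that `torchrun` exports for every worker. Below is a minimal sketch of reading them by hand; the helper name `read_launch_env` and the fallback defaults are assumptions made for illustration and are not part of PyTorch.

```python
import os


def read_launch_env(environ=os.environ):
    """Collect the launch variables that torchrun exports for each worker.

    The fallback values are assumptions so the helper also works when the
    script is started with plain `python` instead of `torchrun`.
    """
    return {
        "rank": int(environ.get("RANK", 0)),
        "local_rank": int(environ.get("LOCAL_RANK", 0)),
        "world_size": int(environ.get("WORLD_SIZE", 1)),
        "master_addr": environ.get("MASTER_ADDR", "127.0.0.1"),
        "master_port": int(environ.get("MASTER_PORT", 29500)),
    }


# Simulate what the second of two workers would see.
fake_env = {
    "RANK": "1",
    "LOCAL_RANK": "1",
    "WORLD_SIZE": "2",
    "MASTER_ADDR": "127.0.0.1",
    "MASTER_PORT": "29500",
}
info = read_launch_env(fake_env)
print(info)
```

Passing the environment as a parameter keeps the helper easy to try out without actually launching several processes.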


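## Distributed key-value store

This works like a distributed dictionary, but the key-value store only accepts strings (values come back as `bytes`), so it seems to be used mainly for synchronizing metadata between nodes.

If you really need to store a `torch.Tensor`, you can: pickle it and base64-encode the result, as the example below shows.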

In [7]:
%%file /tmp/test.py

import torch
import pickle
import time
from base64 import b64encode, b64decode
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()

store = dist.TCPStore(
    host_name="127.0.0.1",
    port=5848,
    world_size=dist.get_world_size(),
    is_master=is_master(),
)

if is_slave():
    store.set("the_key", "Hi master")
    logging.info("slave has set the key")

if is_master():
    result = store.get("the_key")
    logging.info(f"master has get the value: {result}")
    
    b64tensor = b64encode(pickle.dumps(torch.rand(4))).decode()
    store.set("another_key", b64tensor)
    logging.info(f"master set a torch.Tensor object to the same key")
    time.sleep(2)  # should wait for the slave to read the data before quit

if is_slave():
    time.sleep(1)  # wait for master to set the torch.Tensor object
    store.get("another_key")
    tensor = pickle.loads(b64decode(store.get("another_key")))
    logging.info(f"slave got object from the same key {tensor}")

Overwriting /tmp/test.py


In [8]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log

slave  ==> 13:52:51  INFO slave has set the key
master ==> 13:52:51  INFO master has get the value: b'Hi master'
master ==> 13:52:51  INFO master set a torch.Tensor object to the same key
slave  ==> 13:52:52  INFO slave got object from the same key tensor([0.4105, 0.1806, 0.0714, 0.2835])
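
The pickle plus base64 round trip used above can be tried without any distributed setup. This is a small sketch; `tensor_to_str` and `str_to_tensor` are hypothetical helper names, not part of PyTorch. Note that `pickle.loads` should only ever be applied to data from a trusted source.

```python
import pickle
from base64 import b64decode, b64encode

import torch


def tensor_to_str(tensor):
    """Serialize a tensor into an ASCII string that a string-only store accepts."""
    return b64encode(pickle.dumps(tensor)).decode("ascii")


def str_to_tensor(payload):
    """Inverse of tensor_to_str; accepts str or bytes (store.get returns bytes)."""
    return pickle.loads(b64decode(payload))


original = torch.arange(4, dtype=torch.float32)
encoded = tensor_to_str(original)
restored = str_to_tensor(encoded)
print(type(encoded).__name__, restored)
```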


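## Point-to-point communication

One node sends data to one specific other node.

Point-to-point calls can only transfer a `torch.Tensor`. The receiver has to create a tensor of matching shape and dtype in advance and pass it to `recv`, which then overwrites it in place.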

In [10]:
%%file /tmp/test.py

import torch
import pickle
import time
from base64 import b64encode, b64decode
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()

if is_master():
    res = torch.rand(4)
    dist.recv(res)
    logging.info(f"master has recv a tensor: {res}")
    time.sleep(2)

if is_slave():
    payload = torch.rand(4)
    time.sleep(1)  # wait until the master is ready
    res = dist.send(payload, 0)
    logging.info(f"slave has sent a tensor: {payload}")

Overwriting /tmp/test.py


In [11]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log

slave  ==> 13:53:03  INFO slave has sent a tensor: tensor([0.1837, 0.5282, 0.9892, 0.8568])
master ==> 13:53:03  INFO master has recv a tensor: tensor([0.1837, 0.5282, 0.9892, 0.8568])


## Collective functions

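These are the main subject of this article. When implementing distributed features, this family of functions is probably the one used most.

In general, a collective function operates on all nodes at once, so every node has to make the call.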

### broadcast

Copies a `torch.Tensor` from one node to all other nodes.

In [13]:
%%file /tmp/test.py

import time
import torch
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()

if is_slave():
    tensor = torch.zeros(3)

if is_master():
    tensor = torch.rand(3)

dist.broadcast(tensor, src=0)

time.sleep(1)

logging.info(tensor)

Overwriting /tmp/test.py


In [14]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log

master ==> 13:53:15  INFO tensor([0.3944, 0.0291, 0.6223])
slave  ==> 13:53:15  INFO tensor([0.3944, 0.0291, 0.6223])


### broadcast_object_list

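Copies a list of Python objects from one node to all other nodes. The receiving nodes pass in a list of the same length, whose entries are overwritten.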

In [16]:
%%file /tmp/test.py

import time
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()

if is_slave():
    objects = [None, None, None]

if is_master():
    objects = ["foo", 12, {1: 2}]

dist.broadcast_object_list(objects, src=0)

time.sleep(1)

logging.info(objects)

Overwriting /tmp/test.py


In [17]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log

master ==> 13:53:22  INFO ['foo', 12, {1: 2}]
slave  ==> 13:53:22  INFO ['foo', 12, {1: 2}]


### all_reduce

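reduce means "aggregate".

This operation applies a reduction from `torch.distributed.ReduceOp` (sum, product, min, max, and so on) to a `torch.Tensor` held on every node, then synchronizes the result to all nodes. An average is usually obtained by summing and then dividing by the world size.

A typical use case is gradient synchronization:

Each node first computes gradients on its own batch, then the nodes average their gradients and every node receives the result.

PyTorch's own `torch.nn.parallel.DistributedDataParallel` also synchronizes gradients this way internally, so for real training just use `DistributedDataParallel`. Implementing distributed training yourself with `all_reduce` involves many more details; if you are interested, read the `DistributedDataParallel` source code.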

In [19]:
%%file /tmp/test.py

import time
import torch
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()

if is_master():
    tensor = torch.tensor([3, 4])

if is_slave():
    tensor = torch.tensor([5, 6])

logging.info(f"before reduce: {tensor}")

dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

logging.info(f"after reduce: {tensor}")

Overwriting /tmp/test.py


In [None]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log
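
To make the arithmetic concrete without launching several processes, here is a single-process simulation of a sum `all_reduce` followed by the division that turns the sum into an average. `simulate_all_reduce_sum` is a hypothetical stand-in written for this sketch; it is not a `torch.distributed` call and only mimics the result each rank would end up with.

```python
import torch


def simulate_all_reduce_sum(per_rank_tensors):
    """Return what every rank would hold after all_reduce with ReduceOp.SUM."""
    total = torch.stack(per_rank_tensors).sum(dim=0)
    # Every rank ends up with its own copy of the same total.
    return [total.clone() for _ in per_rank_tensors]


# The same values the master and the slave use in the script above.
per_rank = [torch.tensor([3, 4]), torch.tensor([5, 6])]
world_size = len(per_rank)

summed = simulate_all_reduce_sum(per_rank)
# Gradient averaging: sum across ranks, then divide by the world size.
averaged = [t / world_size for t in summed]

print(summed)
print(averaged)
```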

### reduce

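Applies a reduction to the `torch.Tensor` on every node, but delivers the result to a single node only.

A typical use case is implementing metrics that support distributed evaluation.

For example, with 2 nodes available and 100 samples to score, we can split the samples into two halves and let the two nodes score their halves in parallel. A `reduce` then aggregates the scores of all 100 samples onto the `master` node (rank=0), for instance as a sum that is divided by the sample count to get the mean. The `master` node then writes the log or plots the curve.

A distributed `Accuracy` example appears later in this article.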

In [None]:
%%file /tmp/test.py

import time
import torch
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()

if is_master():
    tensor = torch.tensor([3, 4])

if is_slave():
    tensor = torch.tensor([5, 6])

logging.info(f"before reduce: {tensor}")

dist.reduce(tensor, op=dist.ReduceOp.SUM, dst=0)

logging.info(f"after reduce: {tensor}")

In [None]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log
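
When a mean is computed this way, it helps to reduce the per-node sums and counts rather than the per-node means, because the shards may have different sizes. The following pure-Python sketch illustrates the difference; `local_stats` and `simulate_reduce_sum` are hypothetical helpers that imitate a sum `reduce` onto one node, and the scores are made up for illustration.

```python
def local_stats(scores):
    """What one node computes on its own shard: (sum of scores, sample count)."""
    return sum(scores), len(scores)


def simulate_reduce_sum(per_rank_values):
    """Element-wise sum across ranks, as the destination rank would see it."""
    return tuple(sum(column) for column in zip(*per_rank_values))


# Two shards of unequal size, chosen on purpose.
shards = [[1.0, 0.0, 1.0, 1.0], [0.0, 1.0]]
stats = [local_stats(shard) for shard in shards]

total, count = simulate_reduce_sum(stats)
global_mean = total / count

# Averaging the per-node means weights the small shard too heavily.
mean_of_means = sum(s / n for s, n in stats) / len(stats)

print(global_mean, mean_of_means)
```

This is also the reason the `Accuracy` class later in the article reduces two counters and divides only at the end.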

### all_gather

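gather means "collect".

As the name suggests, it collects the `torch.Tensor` from every node into a `list` and then delivers that list to all nodes.

`all_gather` has a counterpart, `gather`; the only difference is whether the final result goes to a single node or to all nodes (the same distinction as between `reduce` and `all_reduce`).

Unlike `reduce`, `gather` only collects tensors and performs no aggregation.

`gather` handles cases that `reduce` cannot: when the result cannot be obtained from the tensors on all nodes with a `torch.distributed.ReduceOp` alone and a more complex computation is needed. Later in this article, `gather` is used to implement a distributed `mAP Metric` compatible with `pycocotools`.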

In [None]:
%%file /tmp/test.py

import torch
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()

if is_master():
    tensor = torch.tensor([3, 4])

if is_slave():
    tensor = torch.tensor([5, 6])

tensor_list = [torch.zeros(2, dtype=torch.int64) for _ in range(2)]

dist.all_gather(tensor_list, tensor)

logging.info(f"gathered: {tensor_list}")

In [None]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log
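
A median is a simple case where a reduction is not enough: it cannot be recovered from per-node medians, so the raw values have to be collected first. The sketch below uses plain Python lists and a hypothetical `simulate_gather` helper in place of real collectives, with made-up values.

```python
from statistics import median


def simulate_gather(per_rank_values):
    """Return the list of per-rank payloads, as the destination rank would see it."""
    return [list(values) for values in per_rank_values]


per_rank = [[1, 2, 9], [3, 4, 5, 6]]

gathered = simulate_gather(per_rank)
flat = [x for part in gathered for x in part]
true_median = median(flat)

# Combining per-node medians gives a different, wrong answer.
median_of_medians = median([median(part) for part in per_rank])

print(true_median, median_of_medians)
```

With real tensors, `all_gather` expects every rank to contribute a tensor of the same shape, so payloads of varying size like these would need padding or an object-based call such as the `gather_object` used in the mAP example later on.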

### scatter

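scatter means "distribute".

Sends the tensors held on one node (a list of tensors) to the nodes, one tensor per node. Every node that receives a tensor (including the sending node itself) has to create the receiving tensor in advance.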

In [None]:
%%file /tmp/test.py

import torch
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave

dist.init_process_group(backend="gloo")
init_logger()
tensor = torch.tensor([0, 0]).float()

if is_master():
    tensors = [torch.ones(2), torch.ones(2) * 2]
else:
    tensors = None
dist.scatter(tensor, tensors, src=0)

logging.info(f"scattered: {tensor}")

In [None]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log
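
In practice the list passed to `scatter` is often built by splitting one larger tensor. The sketch below shows that preparation step in a single process and uses `copy_` as a stand-in for what each rank's buffer would receive. It assumes the data length divides evenly by the world size, so that all parts have the same shape.

```python
import torch

world_size = 2
data = torch.arange(6, dtype=torch.float32)

# One part per rank; on the source rank this would be the scatter list.
scatter_list = list(data.chunk(world_size))

# Each rank allocates a buffer with the shape and dtype of its part.
buffers = [torch.empty_like(part) for part in scatter_list]

# Stand-in for the collective: rank r ends up with scatter_list[r].
for buf, part in zip(buffers, scatter_list):
    buf.copy_(part)

print(buffers)
```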

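## Example 1: a distributed Accuracy

Accuracy is probably the simplest metric: divide the number of correct predictions (tp + tn) by the total number of samples.

In the implementation below, the `Accuracy` class keeps two variables, `self._num_correct` and `self._num_samples`, which record the number of correctly predicted samples and the total number of samples predicted so far. In a distributed environment, each node computes these two values only over its own share of the data. Once all data has been predicted, call `compute` to get the final result. Inside `compute`, `dist.reduce` sums `self._num_correct` and `self._num_samples` across all nodes and sends the totals to the node with `rank==0`.

The implementation also works in single-process mode, which takes just one `if` statement; see the code for details.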

In [None]:
%%file /tmp/metric.py
import torch
import torch.distributed as dist


class Accuracy:
    """Number of correctly predicted samples / total number of samples
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self._num_correct = torch.tensor(0)
        self._num_samples = torch.tensor(0)

    def update(self, labels, logits):
        """Accumulate the result of one mini batch

        Args:
            labels(torch.Tensor): labels, a 1-D vector
            logits(torch.Tensor): model outputs, a 2-D matrix
        """
        preds = torch.argmax(logits, dim=1)
        correct = (labels == preds).sum()
        self._num_correct += correct.item()
        self._num_samples += len(labels)

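    def compute(self):
        """Return the final result. `reduce` sums self._num_correct and
        self._num_samples across all nodes, and the division is then
        done on the main node. Only rank 0 holds the reduced totals, so
        the value returned on other ranks should be ignored.
        """
        if dist.is_initialized():
            # Only reduce when running in a distributed environment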
            dist.reduce(
                self._num_correct,
                dst=0,
                op=torch.distributed.ReduceOp.SUM
            )
            dist.reduce(
                self._num_samples,
                dst=0,
                op=torch.distributed.ReduceOp.SUM
            )

        return round((self._num_correct / self._num_samples).item(), 2)

Next we use the `Accuracy` class above in both a single-process and a distributed environment; the results should be identical.

First we need to mock some data.

In [None]:
%%file /tmp/mock.py
import torch

labels = torch.tensor([2, 0, 2, 1, 0, 1])
logits = torch.tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])

Compute the score in a distributed environment:

In [None]:
%%file /tmp/test.py
import torch
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave
from mock import labels, logits
from metric import Accuracy

dist.init_process_group(backend="gloo")
init_logger()

labels = labels[dist.get_rank()::dist.get_world_size()]
logits = logits[dist.get_rank()::dist.get_world_size()]

metric = Accuracy()
metric.update(labels, logits)

result = metric.compute()

if is_master():
    logging.info(f"Accuracy is: {result}")

In [None]:
!echo > /tmp/pytorch-dist.log; torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log
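
The strided slicing in the script above gives rank 0 the samples at indices 0, 2, 4 and rank 1 those at indices 1, 3, 5. The value the metric should produce can be checked by hand in a single process. This sketch repeats the mock data and replaces the `reduce` with an ordinary Python sum.

```python
import torch

labels = torch.tensor([2, 0, 2, 1, 0, 1])
logits = torch.tensor([
    [0.0266, 0.1719, 0.3055],
    [0.6886, 0.3978, 0.8176],
    [0.9230, 0.0197, 0.8395],
    [0.1785, 0.2670, 0.6084],
    [0.8448, 0.7177, 0.7288],
    [0.7748, 0.9542, 0.8573],
])

world_size = 2
per_rank = []  # (num_correct, num_samples) as each rank would count them
for rank in range(world_size):
    shard_labels = labels[rank::world_size]
    shard_preds = torch.argmax(logits[rank::world_size], dim=1)
    correct = int((shard_labels == shard_preds).sum())
    per_rank.append((correct, len(shard_labels)))

# Stand-in for the two sum reductions onto rank 0.
num_correct = sum(c for c, _ in per_rank)
num_samples = sum(n for _, n in per_rank)
accuracy = round(num_correct / num_samples, 2)

print(per_rank, accuracy)  # [(2, 3), (1, 3)] 0.5
```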

----

Now let's check whether a single-process run gives the same result.

In [None]:
%%file /tmp/test.py
import torch
from mock import labels, logits
from metric import Accuracy


metric = Accuracy()
metric.update(labels, logits)
result = metric.compute()

print(f"Accuracy is: {result}")

In [None]:
!python /tmp/test.py

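## Example 2: a distributed COCO mAP

This example is more involved. COCO mAP is not computed by scoring each sample and aggregating the scores, so `reduce` does not apply. Here we use `gather` to collect "intermediate results" and then compute the final score from them in the main process.

One point worth noting: as much computation as possible should stay on the individual nodes, rather than gathering the raw data and leaving all the work to the main process. The implementation therefore requires careful thought about what the "intermediate results" actually are.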

In [None]:
clear_log()

In [None]:
%%file /tmp/metric.py

import torch
import itertools
from typing import Dict
from torchvision.ops import box_iou
import torch.distributed as dist
import pickle


class MeanAveragePrecision():

    def __init__(self):
        self.iou_thresholds = [round(iou.item(), 2) for iou in torch.arange(0.5, 0.99, 0.05)]
        self.rec_thresholds = torch.linspace(0, 1, 101)
        self.reset()

    def reset(self) -> None:
        self._cm: Dict[float, Dict[float, Dict[str, list]]] = {}

    def update(self, y, y_pred) -> None:
        """
        Args:
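            y(torch.Tensor): ground-truth boxes, shape (N, 5), where N is the number of boxes in the sample and the 5 values are (x1, y1, x2, y2, class_number)
            y_pred(torch.Tensor): predicted boxes, shape (M, 6), where M is the number of predicted boxes and the 6 values are (x1, y1, x2, y2, confidence, class_number)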
        """
        iou = box_iou(y_pred[:, :4], y[:, :4])
        categories = torch.cat((y[:, 4], y_pred[:, 5])).unique().tolist()
        for category in categories:
            if category not in self._cm:
                self._cm[category] = {iou: {"tp": [], "fp": [], "gt": [], "score": []} for iou in self.iou_thresholds}

        for iou_thres_item in self.iou_thresholds:
            valid_iou = torch.clone(iou)
            valid_iou[iou <= iou_thres_item] = 0
            for category in categories:
                class_index_gt = y[:, 4] == category
                class_index_dt = y_pred[:, 5] == category
                class_iou = valid_iou[:, class_index_gt][class_index_dt, :]

                if class_iou.shape[1] == 0:
                    # no ground truth of the category
                    n_gt = 0
                    tp = torch.tensor([False] * class_iou.shape[0])
                    fp = torch.tensor([True] * class_iou.shape[0])
                    score = y_pred[class_index_dt, 4]
                elif class_iou.shape[0] == 0:
                    # no predictions of the category
                    n_gt = class_iou.shape[1]
                    tp = torch.tensor([]).bool()
                    fp = torch.tensor([]).bool()
                    score = torch.tensor([])
                else:
                    class_iou[~(class_iou == class_iou.max(dim=0)[0])] = 0
                    class_iou[~(class_iou.T == class_iou.max(dim=1)[0]).T] = 0

                    n_gt = class_iou.shape[1]
                    tp = (class_iou != 0).any(dim=1)
                    fp = (class_iou == 0).all(dim=1)
                    score = y_pred[class_index_dt, 4]
                
                self._cm[category][iou_thres_item]["tp"].append(tp)
                self._cm[category][iou_thres_item]["fp"].append(fp)
                self._cm[category][iou_thres_item]["gt"].append(n_gt)
                self._cm[category][iou_thres_item]["score"].append(score)

    def compute(self) -> float:
        if dist.is_initialized():
            # In a distributed run, first gather the per-node state onto rank 0.
            cms = [None for _ in range(dist.get_world_size())]
            dist.gather_object(self._cm, cms if dist.get_rank() == 0 else None, dst=0)

            if dist.get_rank() != 0:
                return None  # rank 0 now holds every node's state; the rest only needs to run there

            # Merge the gathered states into self._cm.
            self.reset()
            for category in set(itertools.chain(*[cm.keys() for cm in cms])):
                if category not in self._cm:
                    self._cm[category] = {iou: {"tp": [], "fp": [], "gt": [], "score": []} for iou in self.iou_thresholds}
                for iou_thres in self.iou_thresholds:
                    for cm in cms:
                        if category not in cm:
                            continue
                        self._cm[category][iou_thres]["fp"].extend(cm[category][iou_thres]["fp"])
                        self._cm[category][iou_thres]["tp"].extend(cm[category][iou_thres]["tp"])
                        self._cm[category][iou_thres]["gt"].extend(cm[category][iou_thres]["gt"])
                        self._cm[category][iou_thres]["score"].extend(cm[category][iou_thres]["score"])

        results = []
        for _, cm in self._cm.items():
            category_pr = torch.ones(len(self.iou_thresholds), len(self.rec_thresholds)) * -1

            for idx, (_, cm_iou) in enumerate(cm.items()):
                n_gt = sum(cm_iou["gt"])
                if n_gt == 0:
                    # no ground truth of the class
                    continue
                scores = torch.cat(cm_iou["score"], dim=0)
                indx = torch.argsort(scores, descending=True)
                
                tp = torch.cat(cm_iou["tp"], dim=0)[indx].cumsum(dim=0)
                fp = torch.cat(cm_iou["fp"], dim=0)[indx].cumsum(dim=0)
                rc = tp / n_gt
                pr = tp / (fp + tp)

                for i in range(len(tp) - 1, 0, -1):
                    if pr[i] > pr[i - 1]:
                        pr[i - 1] = pr[i]

                inds = torch.searchsorted(rc, self.rec_thresholds)
                pr_at_recthres = torch.zeros(len(self.rec_thresholds))
                try:
                    for ri, pi in enumerate(inds):
                        pr_at_recthres[ri] = pr[pi]
                except IndexError:
                    # Recall never reaches the remaining thresholds; precision stays 0 there.
                    pass
                category_pr[idx, :] = pr_at_recthres
            if torch.all(category_pr == -1):
                continue
            category_ap = category_pr[category_pr > -1].mean()
            results.append(category_ap)
        return round(torch.stack(results).mean().item(), 3)
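
The core of `compute` is turning sorted true/false positives into an average precision value. The following stand-alone sketch walks through that step on three made-up detections. `average_precision` is a hypothetical helper written for this illustration; it follows the same idea as the loop above (cumulative sums, precision envelope, sampling at 101 recall thresholds) but uses `cummax` and a mask instead of the explicit loops.

```python
import torch


def average_precision(scores, is_tp, n_gt, rec_thresholds=None):
    """101-point interpolated AP for one class at one IoU threshold."""
    if rec_thresholds is None:
        rec_thresholds = torch.linspace(0, 1, 101)

    # Rank detections by confidence, then accumulate hits and misses.
    order = torch.argsort(scores, descending=True)
    tp = is_tp[order].cumsum(dim=0)
    fp = (~is_tp[order]).cumsum(dim=0)
    recall = tp / n_gt
    precision = tp / (tp + fp)

    # Precision envelope: each entry becomes the max of itself and all later entries.
    precision = precision.flip(0).cummax(dim=0).values.flip(0)

    # For each recall threshold, take the precision at the first position
    # whose recall reaches it; thresholds that are never reached keep 0.
    inds = torch.searchsorted(recall, rec_thresholds)
    sampled = torch.zeros(len(rec_thresholds))
    reachable = inds < len(precision)
    sampled[reachable] = precision[inds[reachable]]
    return sampled.mean().item()


# Three detections, two ground-truth boxes; the second detection is a miss.
scores = torch.tensor([0.9, 0.8, 0.7])
is_tp = torch.tensor([True, False, True])
ap = average_precision(scores, is_tp, n_gt=2)
print(round(ap, 3))
```

Here precision is 1 up to a recall of 0.5 and about 0.667 beyond that, so the result lands at roughly 0.83.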


Mock two samples:

In [None]:
%%file /tmp/mock.py
import torch

gts = [
    torch.tensor(
        [
            [126, 90, 523, 534, 6],
            [190, 304, 249, 369, 1],
            [435, 338, 451, 362, 1],
            [298, 334, 332, 367, 1],
            [174, 170, 203, 192, 1],
            [297, 160, 322, 180, 1],
            [121, 389, 127, 410, 1],
            [568, 316, 611, 404, 6],
            [91, 388, 104, 422, 1],
            [212, 168, 230, 188, 1],
            [78, 377, 97, 429, 1],
            [101, 397, 114, 429, 1],
            [113, 391, 126, 429, 1],
            [502, 315, 565, 384, 6],
        ]
    ),
    torch.tensor(
        [
            [256, 70, 353, 302, 1],
            [87, 184, 240, 324, 1],
            [87, 71, 153, 140, 1],
            [169, 74, 221, 130, 1],
            [387, 60, 475, 113, 1],
            [215, 76, 262, 129, 1],
            [301, 28, 363, 97, 39],
            [39, 0, 75, 14, 1],
            [138, 115, 191, 133, 15],
            [340, 97, 418, 118, 15],
            [213, 230, 241, 263, 40],
            [147, 103, 152, 118, 44],
            [17, 96, 67, 148, 1],
            [49, 80, 102, 142, 1],
            [86, 107, 90, 123, 44],
            [16, 141, 126, 323, 1],
        ]
    ),
]

preds = [
    torch.tensor(
        [
            [124, 90, 522, 534, 1, 6],
            [507, 316, 567, 385, 1, 6],
            [76, 375, 97, 427, 0.99, 1],
            [189, 307, 250, 368, 0.98, 1],
            [113, 383, 128, 430, 0.98, 1],
            [208, 21, 226, 43, 0.92, 18],
            [208, 21, 226, 43, 0.92, 18],
            [208, 21, 226, 43, 0.92, 16],
            [208, 21, 226, 43, 0.92, 16],
            [208, 21, 226, 43, 0.92, 16],
            [208, 21, 226, 43, 0.92, 16],
            [208, 21, 226, 43, 0.92, 16],
            [208, 21, 226, 44, 0.92, 16],
            [208, 21, 226, 43, 0.92, 16],
            [208, 21, 226, 43, 0.92, 16],
            [99, 397, 115, 429, 0.91, 1],
            [176, 166, 201, 192, 0.88, 1],
            [570, 320, 612, 404, 0.86, 6],
            [91, 390, 103, 430, 0.84, 1],
            [397, 178, 411, 198, 0.79, 1],
            [569, 339, 610, 403, 0.65, 3],
            [61, 354, 94, 368, 0.56, 3],
            [177, 177, 198, 192, 0.38, 77],
        ]
    ),
    torch.tensor(
        [
            [262, 71, 350, 297, 1, 1],
            [94, 185, 233, 320, 1, 1],
            [17, 139, 126, 324, 1, 1],
            [214, 231, 241, 263, 1, 40],
            [397, 60, 452, 114, 0.99, 1],
            [167, 74, 221, 133, 0.99, 1],
            [89, 72, 149, 140, 0.99, 1],
            [300, 28, 363, 98, 0.99, 39],
            [212, 72, 262, 129, 0.99, 1],
            [24, 94, 71, 147, 0.98, 1],
            [48, 81, 77, 137, 0.96, 1],
            [285, 91, 303, 119, 0.92, 40],
            [85, 106, 91, 124, 0.89, 44],
            [145, 102, 151, 119, 0.87, 44],
            [29, 0, 78, 14, 0.68, 1],
            [0, 92, 22, 146, 0.61, 1],
            [50, 156, 195, 323, 0.56, 1],
            [13, 0, 56, 15, 0.53, 1],
            [83, 0, 109, 12, 0.50, 1],
            [136, 116, 179, 135, 0.43, 15],
            [261, 80, 299, 131, 0.40, 1],
            [250, 105, 263, 121, 0.31, 40],
            [111, 118, 155, 140, 0.30, 31],
            [12, 87, 54, 146, 0.29, 1],
        ]
    ),
]

------

First, test the single-process result:

In [None]:
%%file /tmp/test.py

import torch
from metric import MeanAveragePrecision
from mock import gts, preds

metric = MeanAveragePrecision()
for gt, pred in zip(gts, preds):
    metric.update(gt, pred)
print(f"mAP: {metric.compute()}")

In [None]:
!python /tmp/test.py

-----

Then test the distributed result:

In [None]:
%%file /tmp/test.py
import torch
import torch.distributed as dist
import logging
from helper import init_logger, is_master, is_slave
from mock import gts, preds
from metric import MeanAveragePrecision


dist.init_process_group(backend="gloo")
init_logger()

# The master handles the first sample, the slave handles the second.
gt = gts[0] if is_master() else gts[1]
pred = preds[0] if is_master() else preds[1]

metric = MeanAveragePrecision()
metric.update(gt, pred)

result = metric.compute()

if is_master():
    logging.info(f"mAP: {result}")

In [None]:
!echo > /tmp/pytorch-dist.log && torchrun --nproc_per_node 2 --nnodes 1 /tmp/test.py && cat /tmp/pytorch-dist.log