[Docs] Fix unused parameters (#1288)
zhouzaida committed Aug 3, 2023
1 parent 5c5ec8b commit a54e814
Showing 2 changed files with 206 additions and 1 deletion.
105 changes: 104 additions & 1 deletion docs/en/common_usage/debug_tricks.md
@@ -1,6 +1,6 @@
# Debug Tricks

- ## Set the Dataset's length
+ ## Set the Dataset's Length

During debugging, it is sometimes necessary to train for only a few epochs, for example to debug the validation process or to check whether checkpoint saving meets expectations. However, if the dataset is too large, it may take a long time to complete one epoch; in that case, you can set the length of the dataset. Note that only datasets inheriting from [BaseDataset](mmengine.dataset.BaseDataset) support this feature; see the [BaseDataset](../advanced_tutorials/basedataset.md) tutorial for its usage.
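
For reference, a minimal sketch of what this can look like, assuming a CIFAR10 dataset config (the surrounding field values here are illustrative, not part of this commit): `indices` is the `BaseDataset` argument that keeps only the first few samples.

```python
# Minimal sketch (illustrative values): `indices` truncates the dataset to
# its first 5000 samples, so one "epoch" finishes quickly while debugging.
train_dataloader = dict(
    batch_size=16,
    num_workers=2,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='CIFAR10',
        data_root='data/cifar10',
        test_mode=False,
        indices=5000,  # only use the first 5000 samples
    ),
)
```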

@@ -49,3 +49,106 @@ As we can see, the number of iterations has changed to `313`. Compared to before
02/20 14:44:59 - mmengine - INFO - Epoch(train) [1][200/313] lr: 1.0000e-01 eta: 0:23:18 time: 0.0143 data_time: 0.0002 memory: 214 loss: 2.0424
02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814
```

## Find Unused Parameters

When training with multiple GPUs, if some of the model's parameters are involved in the forward computation but not used in producing the loss, the program may throw the following error:

```
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
making sure all `forward` function outputs participate in calculating loss.
```

Let's take the model defined in [15 minutes to get started with MMEngine](../get_started/15_minutes.md) as an example:

```python
# imports as in the 15-minutes example
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```

Modify it to:

```python
import torch
from torch import nn


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()
        self.param = nn.Parameter(torch.ones(1))

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        # self.param is involved in the forward computation,
        # but y is not involved in the loss calculation
        y = self.param + x
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```

Start training with two GPUs:

```bash
torchrun --nproc-per-node 2 examples/distributed_training.py --launcher pytorch
```

The program will throw the error mentioned above.

This issue can be resolved by setting `find_unused_parameters=True`:

```python
cfg = dict(
    model_wrapper_cfg=dict(
        type='MMDistributedDataParallel', find_unused_parameters=True)
)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
    cfg=cfg,
)
runner.train()
```

Restart training, and you can see that the program trains normally and prints logs.

However, setting `find_unused_parameters=True` will slow down the program, so we want to find these parameters and analyze why they did not participate in the loss calculation.

This can be done by setting `detect_anomalous_params=True` to print the unused parameters.

```python
cfg = dict(
    model_wrapper_cfg=dict(
        type='MMDistributedDataParallel',
        find_unused_parameters=True,
        detect_anomalous_params=True),
)
```

Restart training, and you can see that the log prints the parameters not involved in the loss calculation.

```
08/03 15:04:42 - mmengine - ERROR - mmengine/logging/logger.py - print_log - 323 - module.param with shape torch.Size([1]) is not in the computational graph
```
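
If you prefer to verify this outside of MMEngine, a quick plain-PyTorch check (a sketch, not part of this commit) is to run one backward pass and look for trainable parameters whose gradient is still `None`:

```python
import torch

# Sketch: after one backward pass, any trainable parameter whose .grad is
# still None never contributed to the loss.
model = MMResNet50()
imgs = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 1000, (2,))
model(imgs, labels, mode='loss')['loss'].backward()
for name, param in model.named_parameters():
    if param.requires_grad and param.grad is None:
        print(f'{name} is not in the computational graph')  # prints "param ..."
```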

Once these parameters are found, we can analyze why they did not participate in the loss calculation.
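
In this toy example, the fix is straightforward: either remove `self.param`, or let `y` feed into the loss so the parameter joins the computational graph, e.g.:

```python
def forward(self, imgs, labels, mode):
    x = self.resnet(imgs)
    y = self.param + x
    if mode == 'loss':
        # Computing the loss from y puts self.param in the computational
        # graph, so DDP no longer reports it as unused.
        return {'loss': F.cross_entropy(y, labels)}
    elif mode == 'predict':
        return x, labels
```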

```{important}
`find_unused_parameters=True` and `detect_anomalous_params=True` should only be set when debugging.
```
102 changes: 102 additions & 0 deletions docs/zh_cn/common_usage/debug_tricks.md
@@ -49,3 +49,105 @@ python tools/train.py configs/resnet/resnet18_8xb16_cifar10.py
02/20 14:44:59 - mmengine - INFO - Epoch(train) [1][200/313] lr: 1.0000e-01 eta: 0:23:18 time: 0.0143 data_time: 0.0002 memory: 214 loss: 2.0424
02/20 14:45:01 - mmengine - INFO - Epoch(train) [1][300/313] lr: 1.0000e-01 eta: 0:20:39 time: 0.0143 data_time: 0.0003 memory: 214 loss: 1.814
```

## Find Unused Parameters

When training with multiple GPUs, if some of the model's parameters are involved in the forward computation but not used in producing the loss, the program will throw the following error:

```
RuntimeError: Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing loss. You can enable unused parameter detection by passing the keyword argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`, and by
making sure all `forward` function outputs participate in calculating loss.
```

Let's take the model defined in [15 minutes to get started with MMEngine](../get_started/15_minutes.md) as an example:

```python
# imports as in the 15-minutes example
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```

Modify it to the following:

```python
import torch
from torch import nn


class MMResNet50(BaseModel):

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()
        self.param = nn.Parameter(torch.ones(1))

    def forward(self, imgs, labels, mode):
        x = self.resnet(imgs)
        # self.param is involved in the forward computation,
        # but y is not involved in the loss calculation
        y = self.param + x
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
```

Start training with two GPUs:

```bash
torchrun --nproc-per-node 2 examples/distributed_training.py --launcher pytorch
```

The program will throw the error mentioned above.

This issue can be resolved by setting `find_unused_parameters=True`:

```python
cfg = dict(
    model_wrapper_cfg=dict(
        type='MMDistributedDataParallel', find_unused_parameters=True)
)
runner = Runner(
    model=MMResNet50(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=2, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
    launcher=args.launcher,
    cfg=cfg,
)
runner.train()
```

Restart training, and you can see that the program trains normally and prints logs.

However, setting `find_unused_parameters=True` slows the program down, so we want to find these parameters and analyze why they did not participate in the loss calculation.

This can be done by setting `detect_anomalous_params=True` to print the unused parameters.

```python
cfg = dict(
    model_wrapper_cfg=dict(
        type='MMDistributedDataParallel',
        find_unused_parameters=True,
        detect_anomalous_params=True),
)
```

Restart training, and you can see that the log prints the parameters not involved in the loss calculation.

```
08/03 15:04:42 - mmengine - ERROR - mmengine/logging/logger.py - print_log - 323 - module.param with shape torch.Size([1]) is not in the computational graph
```
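
As in the English version of this page, a quick plain-PyTorch check (a sketch, not part of this commit) is to run one backward pass and look for trainable parameters whose gradient is still `None`:

```python
import torch

# Sketch: after one backward pass, any trainable parameter whose .grad is
# still None never contributed to the loss.
model = MMResNet50()
imgs = torch.randn(2, 3, 224, 224)
labels = torch.randint(0, 1000, (2,))
model(imgs, labels, mode='loss')['loss'].backward()
for name, param in model.named_parameters():
    if param.requires_grad and param.grad is None:
        print(f'{name} is not in the computational graph')  # prints "param ..."
```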

Once these parameters are found, we can analyze why they did not participate in the loss calculation.
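
In this toy example, the fix is straightforward: either remove `self.param`, or let `y` feed into the loss so the parameter joins the computational graph, e.g.:

```python
def forward(self, imgs, labels, mode):
    x = self.resnet(imgs)
    y = self.param + x
    if mode == 'loss':
        # Computing the loss from y puts self.param in the computational
        # graph, so DDP no longer reports it as unused.
        return {'loss': F.cross_entropy(y, labels)}
    elif mode == 'predict':
        return x, labels
```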

```{important}
`find_unused_parameters=True` and `detect_anomalous_params=True` should only be set when debugging.
```
