Skip to content

Commit

Permalink
[Enhanchment] Fix unit test of EG3D Render for torch < 1.8.0 (#1499)
Browse files Browse the repository at this point in the history
adopt torch.nan_to_num for torch < 1.8.0 + mock torch version in unit test to cover more lines
  • Loading branch information
LeoXing1996 committed Dec 5, 2022
1 parent a9bd113 commit fe7ac22
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
7 changes: 6 additions & 1 deletion mmedit/models/editors/eg3d/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

from ..stylegan3.stylegan3_modules import FullyConnectedLayer
from .eg3d_utils import (get_ray_limits_box, inverse_transform_sampling,
Expand Down Expand Up @@ -431,7 +433,10 @@ def volume_rendering(self, colors: torch.Tensor, densities: torch.Tensor,
composite_depth = torch.sum(weights * depths_mid, -2) / weight_total

# clip the composite to min/max range of depths
composite_depth = torch.nan_to_num(composite_depth, float('inf'))
if digit_version(TORCH_VERSION) < digit_version('1.8.0'):
composite_depth[torch.isnan(composite_depth)] = float('inf')
else:
composite_depth = torch.nan_to_num(composite_depth, float('inf'))
composite_depth = torch.clamp(composite_depth, torch.min(depths),
torch.max(depths))

Expand Down
9 changes: 9 additions & 0 deletions tests/test_models/test_editors/test_eg3d/test_camera.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
from copy import deepcopy
from unittest import TestCase
from unittest.mock import patch

import torch
from mmengine.testing import assert_allclose
Expand Down Expand Up @@ -111,6 +112,14 @@ def test_sample_camera2world(self):
cam2world = camera.sample_camera2world()
self.assertEqual(cam2world.shape, (1, 4, 4))

mock_path = 'mmedit.models.editors.eg3d.camera.TORCH_VERSION'
with patch(mock_path, '1.6.0'):
print(torch.__version__)
cfg_ = deepcopy(self.default_cfg)
camera = BaseCamera(**cfg_)
cam2world = camera.sample_camera2world()
self.assertEqual(cam2world.shape, (1, 4, 4))

def test_sample_in_range(self):
cfg_ = deepcopy(self.default_cfg)
cfg_['sampling_strategy'] = 'unknow'
Expand Down
11 changes: 11 additions & 0 deletions tests/test_models/test_editors/test_eg3d/test_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,14 @@ def test_forward(self):
ray_directions,
render_kwargs=render_kwargs)
mock_func.assert_called_once()

# cover TORCH_VERSION < 1.8.0
mock_path = 'mmedit.models.editors.eg3d.renderer.TORCH_VERSION'
with patch(mock_path, '1.6.0'):
cfg_ = deepcopy(self.renderer_cfg)
renderer = EG3DRenderer(**cfg_)
renderer(
plane,
ray_origins,
ray_directions,
render_kwargs=render_kwargs)

0 comments on commit fe7ac22

Please sign in to comment.