Skip to content

Commit

Permalink
SplatterPhongShader 1: Pull out common Shader functionality into Shad…
Browse files Browse the repository at this point in the history
…erBase

Summary: Most of the shaders copypaste exactly the same code into `__init__` and `to`. I will be adding a new shader in the next diff, so let's make it a bit easier.

Reviewed By: bottler

Differential Revision: D35767884

fbshipit-source-id: 0057e3e2ae3be4eaa49ae7e2bf3e4176953dde9d
  • Loading branch information
Krzysztof Chalupka authored and facebook-github-bot committed Apr 27, 2022
1 parent 9f443ed commit 96889de
Showing 1 changed file with 20 additions and 124 deletions.
144 changes: 20 additions & 124 deletions pytorch3d/renderer/mesh/shader.py
Expand Up @@ -32,22 +32,7 @@
# - sample colors from a texture map
# - apply per pixel lighting
# - blend colors across top K faces per pixel.


class HardPhongShader(nn.Module):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function hard assigns
the color of the closest face for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = HardPhongShader(device=torch.device("cuda:0"))
"""

class ShaderBase(nn.Module):
def __init__(
self,
device: Device = "cpu",
Expand All @@ -74,6 +59,21 @@ def to(self, device: Device):
self.lights = self.lights.to(device)
return self


class HardPhongShader(ShaderBase):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function hard assigns
the color of the closest face for each pixel.
To use the default values, simply initialize the shader with the desired
device e.g.
.. code-block::
shader = HardPhongShader(device=torch.device("cuda:0"))
"""

def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
Expand All @@ -97,7 +97,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
return images


class SoftPhongShader(nn.Module):
class SoftPhongShader(ShaderBase):
"""
Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the
Expand All @@ -111,32 +111,6 @@ class SoftPhongShader(nn.Module):
shader = SoftPhongShader(device=torch.device("cuda:0"))
"""

def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()

# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self

def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
Expand Down Expand Up @@ -164,7 +138,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
return images


class HardGouraudShader(nn.Module):
class HardGouraudShader(ShaderBase):
"""
Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to
Expand All @@ -179,32 +153,6 @@ class HardGouraudShader(nn.Module):
shader = HardGouraudShader(device=torch.device("cuda:0"))
"""

def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()

# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self

def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
Expand All @@ -231,7 +179,7 @@ def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tenso
return images


class SoftGouraudShader(nn.Module):
class SoftGouraudShader(ShaderBase):
"""
Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to
Expand All @@ -246,32 +194,6 @@ class SoftGouraudShader(nn.Module):
shader = SoftGouraudShader(device=torch.device("cuda:0"))
"""

def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()

# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self

def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
Expand Down Expand Up @@ -320,7 +242,7 @@ def TexturedSoftPhongShader(
)


class HardFlatShader(nn.Module):
class HardFlatShader(ShaderBase):
"""
Per face lighting - the lighting model is applied using the average face
position and the face normal. The blending function hard assigns
Expand All @@ -334,32 +256,6 @@ class HardFlatShader(nn.Module):
shader = HardFlatShader(device=torch.device("cuda:0"))
"""

def __init__(
self,
device: Device = "cpu",
cameras: Optional[TensorProperties] = None,
lights: Optional[TensorProperties] = None,
materials: Optional[Materials] = None,
blend_params: Optional[BlendParams] = None,
) -> None:
super().__init__()
self.lights = lights if lights is not None else PointLights(device=device)
self.materials = (
materials if materials is not None else Materials(device=device)
)
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()

# pyre-fixme[14]: `to` overrides method defined in `Module` inconsistently.
def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module
cameras = self.cameras
if cameras is not None:
self.cameras = cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
return self

def forward(self, fragments: Fragments, meshes: Meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
Expand Down

0 comments on commit 96889de

Please sign in to comment.