From 81de7acff979ba10d2643502222e15b75ef30e8a Mon Sep 17 00:00:00 2001 From: LeoXing1996 Date: Fri, 30 Dec 2022 15:43:35 +0800 Subject: [PATCH] first commit for dreamfusion --- mmagic/datasets/__init__.py | 4 +- mmagic/datasets/dummy_dataset.py | 30 + mmagic/engine/hooks/__init__.py | 3 +- mmagic/engine/hooks/dreamfusion_hook.py | 55 ++ mmagic/models/editors/__init__.py | 4 +- mmagic/models/utils/tensor_utils.py | 16 +- mmedit/models/editors/dreamfusion/__init__.py | 10 + mmedit/models/editors/dreamfusion/activate.py | 22 + mmedit/models/editors/dreamfusion/camera.py | 438 +++++++++++++++ .../models/editors/dreamfusion/dreamfusion.py | 271 +++++++++ mmedit/models/editors/dreamfusion/renderer.py | 512 ++++++++++++++++++ .../dreamfusion/stable_diffusion_wrapper.py | 122 +++++ .../editors/dreamfusion/vanilla_nerf.py | 309 +++++++++++ 13 files changed, 1790 insertions(+), 6 deletions(-) create mode 100644 mmagic/datasets/dummy_dataset.py create mode 100644 mmagic/engine/hooks/dreamfusion_hook.py create mode 100644 mmedit/models/editors/dreamfusion/__init__.py create mode 100644 mmedit/models/editors/dreamfusion/activate.py create mode 100644 mmedit/models/editors/dreamfusion/camera.py create mode 100644 mmedit/models/editors/dreamfusion/dreamfusion.py create mode 100644 mmedit/models/editors/dreamfusion/renderer.py create mode 100644 mmedit/models/editors/dreamfusion/stable_diffusion_wrapper.py create mode 100644 mmedit/models/editors/dreamfusion/vanilla_nerf.py diff --git a/mmagic/datasets/__init__.py b/mmagic/datasets/__init__.py index 240c68b2e5..cfb8070588 100644 --- a/mmagic/datasets/__init__.py +++ b/mmagic/datasets/__init__.py @@ -6,6 +6,7 @@ from .comp1k_dataset import AdobeComp1kDataset from .controlnet_dataset import ControlNetDataset from .dreambooth_dataset import DreamBoothDataset +from .dummy_dataset import DummyDataset from .grow_scale_image_dataset import GrowScaleImgDataset from .imagenet_dataset import ImageNet from .mscoco_dataset import MSCoCoDataset @@ -19,5 +20,6 @@ 'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset', 'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset', 'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset', - 'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset' + 'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset', + 'DummyDataset' ] diff --git a/mmagic/datasets/dummy_dataset.py b/mmagic/datasets/dummy_dataset.py new file mode 100644 index 0000000000..e8e874b4dc --- /dev/null +++ b/mmagic/datasets/dummy_dataset.py @@ -0,0 +1,30 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
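The DummyDataset added below never touches data on disk: each item only carries the render batch size and optional sampling kwargs, which is all a text-to-3D run needs. A minimal dataloader entry wiring it up might look like the following sketch (only 'DummyDataset' and its arguments come from this patch; the surrounding fields follow the usual MMEngine conventions and are an assumption):

# Hypothetical config snippet, not part of this commit.
train_dataloader = dict(
    batch_size=1,
    num_workers=0,
    sampler=dict(type='InfiniteSampler', shuffle=False),
    dataset=dict(type='DummyDataset', max_length=100, batch_size=1))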
+from copy import deepcopy
+
+from torch.utils.data import Dataset
+
+from mmedit.registry import DATASETS
+
+
+@DATASETS.register_module()
+class DummyDataset(Dataset):
+
+    def __init__(self, max_length=100, batch_size=None, sample_kwargs=None):
+        super().__init__()
+        self.max_length = max_length
+        self.sample_kwargs = sample_kwargs
+        self.batch_size = batch_size
+
+    def __len__(self):
+        return self.max_length
+
+    def __getitem__(self, index):
+        data_dict = dict()
+        input_dict = dict()
+        if self.batch_size is not None:
+            input_dict['num_batches'] = self.batch_size
+        if self.sample_kwargs is not None:
+            input_dict['sample_kwargs'] = deepcopy(self.sample_kwargs)
+
+        data_dict['inputs'] = input_dict
+        return data_dict
diff --git a/mmagic/engine/hooks/__init__.py b/mmagic/engine/hooks/__init__.py
index 8435afa9a5..7c47f0e9cb 100644
--- a/mmagic/engine/hooks/__init__.py
+++ b/mmagic/engine/hooks/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .dreamfusion_hook import DreamFusionTrainingHook
 from .ema import ExponentialMovingAverageHook
 from .iter_time_hook import IterTimerHook
 from .pggan_fetch_data_hook import PGGANFetchDataHook
@@ -9,5 +10,5 @@
 __all__ = [
     'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'VisualizationHook',
     'ExponentialMovingAverageHook', 'IterTimerHook', 'PGGANFetchDataHook',
-    'PickleDataHook'
+    'PickleDataHook', 'DreamFusionTrainingHook'
 ]
diff --git a/mmagic/engine/hooks/dreamfusion_hook.py b/mmagic/engine/hooks/dreamfusion_hook.py
new file mode 100644
index 0000000000..b5e9145184
--- /dev/null
+++ b/mmagic/engine/hooks/dreamfusion_hook.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import random
+
+from mmengine.hooks import Hook
+from mmengine.model import is_model_wrapper
+
+from mmedit.registry import HOOKS
+
+
+@HOOKS.register_module()
+class DreamFusionTrainingHook(Hook):
+
+    def __init__(self, albedo_iters: int):
+        super().__init__()
+        self.albedo_iters = albedo_iters
+
+        self.shading_test = 'albedo'
+        self.ambient_ratio_test = 1.0
+
+    def set_shading_and_ambient(self, runner, shading: str,
+                                ambient_ratio: float) -> None:
+        model = runner.model
+        if is_model_wrapper(model):
+            model = model.module
+        renderer = model.renderer
+        if is_model_wrapper(renderer):
+            renderer = renderer.module
+        renderer.set_shading(shading)
+        renderer.set_ambient_ratio(ambient_ratio)
+
+    def after_train_iter(self, runner, batch_idx: int, *args,
+                         **kwargs) -> None:
+        if batch_idx < self.albedo_iters or self.albedo_iters == -1:
+            shading = 'albedo'
+            ambient_ratio = 1.0
+        else:
+            rand = random.random()
+            if rand > 0.8:  # NOTE: this should be 0.75 in the paper
+                shading = 'albedo'
+                ambient_ratio = 1.0
+            elif rand > 0.4:  # NOTE: this should be 0.75 * 0.5 = 0.375
+                shading = 'textureless'
+                ambient_ratio = 0.1
+            else:
+                shading = 'lambertian'
+                ambient_ratio = 0.1
+        self.set_shading_and_ambient(runner, shading, ambient_ratio)
+
+    def before_test(self, runner) -> None:
+        self.set_shading_and_ambient(runner, self.shading_test,
+                                     self.ambient_ratio_test)
+
+    def before_val(self, runner) -> None:
+        self.set_shading_and_ambient(runner, self.shading_test,
+                                     self.ambient_ratio_test)
diff --git a/mmagic/models/editors/__init__.py b/mmagic/models/editors/__init__.py
index 95499b9d53..1c7589c96b 100644
--- a/mmagic/models/editors/__init__.py
+++ b/mmagic/models/editors/__init__.py
@@ -18,6 +18,7 @@
 from .dim import DIM
 from .disco_diffusion import ClipWrapper, DiscoDiffusion
 from .dreambooth import DreamBooth
+from .dreamfusion import
DreamFusion from .edsr import EDSRNet from .edvr import EDVR, EDVRNet from .eg3d import EG3D @@ -86,8 +87,7 @@ 'StyleGAN1', 'StyleGAN2', 'StyleGAN3', 'BigGAN', 'DCGAN', 'ProgressiveGrowingGAN', 'SinGAN', 'AblatedDiffusionModel', 'DiscoDiffusion', 'IDLossModel', 'PESinGAN', 'MSPIEStyleGAN2', - 'StyleGAN3Generator', 'InstColorization', 'NAFBaseline', 'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DenoisingUnet', 'ClipWrapper', 'EG3D', 'Restormer', 'SwinIRNet', 'StableDiffusion', - 'ControlStableDiffusion', 'DreamBooth', 'TextualInversion' + 'ControlStableDiffusion', 'DreamBooth', 'TextualInversion', 'DreamFusion' ] diff --git a/mmagic/models/utils/tensor_utils.py b/mmagic/models/utils/tensor_utils.py index c50c418092..b7595195ad 100644 --- a/mmagic/models/utils/tensor_utils.py +++ b/mmagic/models/utils/tensor_utils.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch @@ -38,14 +40,24 @@ def get_unknown_tensor(trimap, unknown_value=128 / 255): return weight -def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: +def normalize_vecs(vectors: torch.Tensor, + clamp_eps: Optional[float] = None) -> torch.Tensor: """Normalize vector with it's lengths at the last dimension. If `vector` is two-dimension tensor, this function is same as L2 normalization. Args: vector (torch.Tensor): Vectors to be normalized. + eps (float, optional): If passed, the min value will be clamped to + this value before calculate the square root. Defaults to None. Returns: torch.Tensor: Vectors after normalization. """ - return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + if clamp_eps is None: + return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) + assert clamp_eps >= 0, ( + f'\'clamp_eps\' must not less than 0. But receive \'{clamp_eps}\'.') + + return vectors / torch.sqrt( + torch.clamp( + torch.sum(vectors * vectors, -1, keepdim=True), min=clamp_eps)) diff --git a/mmedit/models/editors/dreamfusion/__init__.py b/mmedit/models/editors/dreamfusion/__init__.py new file mode 100644 index 0000000000..8e205064df --- /dev/null +++ b/mmedit/models/editors/dreamfusion/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .camera import DreamFusionCamera +from .dreamfusion import DreamFusion +from .renderer import DreamFusionRenderer +from .stable_diffusion_wrapper import StableDiffusionWrapper + +__all__ = [ + 'DreamFusion', 'DreamFusionRenderer', 'DreamFusionCamera', + 'StableDiffusionWrapper' +] diff --git a/mmedit/models/editors/dreamfusion/activate.py b/mmedit/models/editors/dreamfusion/activate.py new file mode 100644 index 0000000000..547f8a2e8b --- /dev/null +++ b/mmedit/models/editors/dreamfusion/activate.py @@ -0,0 +1,22 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from torch.autograd import Function +from torch.cuda.amp import custom_bwd, custom_fwd + + +class _trunc_exp(Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float) + def forward(ctx, x): + ctx.save_for_backward(x) + return torch.exp(x) + + @staticmethod + @custom_bwd + def backward(ctx, g): + x = ctx.saved_tensors[0] + return g * torch.exp(x.clamp(max=15)) + + +trunc_exp = _trunc_exp.apply diff --git a/mmedit/models/editors/dreamfusion/camera.py b/mmedit/models/editors/dreamfusion/camera.py new file mode 100644 index 0000000000..1a4edf4ce9 --- /dev/null +++ b/mmedit/models/editors/dreamfusion/camera.py @@ -0,0 +1,438 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
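The normalize_vecs change above keeps the original fast path when clamp_eps is None and otherwise clamps the squared norm before the square root, so zero-length vectors no longer produce NaNs. A standalone re-implementation to illustrate the two branches (a sketch, not an import from this patch):

import torch

def normalize_vecs(vectors, clamp_eps=None):
    # Mirrors the patched mmagic/models/utils/tensor_utils.py.
    if clamp_eps is None:
        return vectors / torch.norm(vectors, dim=-1, keepdim=True)
    squared_norm = torch.sum(vectors * vectors, -1, keepdim=True)
    return vectors / torch.sqrt(torch.clamp(squared_norm, min=clamp_eps))

v = torch.zeros(2, 3)
print(normalize_vecs(v))                   # NaNs: divides by a zero norm
print(normalize_vecs(v, clamp_eps=1e-20))  # finite zeros, safe for autograd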
+import math +import random +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION + +from mmedit.models.utils import normalize_vecs +from mmedit.registry import MODULES + +DeviceType = Optional[Union[str, int]] +VectorType = Optional[Union[list, torch.Tensor]] + + +def get_view_direction(thetas, phis, overhead, front): + # NOTE: thetas and phis is inverse with ours + # phis [B,]; thetas: [B,] + # front = 0 [0, front) + # side (left) = 1 [front, 180) + # back = 2 [180, 180+front) + # side (right) = 3 [180+front, 360) + # top = 4 [0, overhead] + # bottom = 5 [180-overhead, 180] + res = torch.zeros(thetas.shape[0], dtype=torch.long) + # first determine by phis + res[(phis < front)] = 0 + res[(phis >= front) & (phis < np.pi)] = 1 + res[(phis >= np.pi) & (phis < (np.pi + front))] = 2 + res[(phis >= (np.pi + front))] = 3 + # override by thetas + res[thetas <= overhead] = 4 + res[thetas >= (np.pi - overhead)] = 5 + return res + + +@MODULES.register_module() +class DreamFusionCamera(object): + + def __init__(self, + horizontal_mean, + vertical_mean, + horizontal_std, + vertical_std, + fov_mean, + fov_std, + radius_mean, + radius_std, + uniform_sphere_rate=0.5, + jitter_pose=False): + + self.horizontal_mean = horizontal_mean + self.vertical_mean = vertical_mean + self.horizontal_std = horizontal_std + self.vertical_std = vertical_std + self.look_at = torch.FloatTensor([0, 0, 0]) + self.up = torch.FloatTensor([0, 1, 0]) + self.radius_mean = radius_mean + self.radius_std = radius_std + self.fov_mean = fov_mean + self.fov_std = fov_std + + self.uniform_sphere_rate = uniform_sphere_rate + self.jitter_pose = jitter_pose + + def _sample_in_range(self, mean: float, std: float, batch_size: int, + sampling_statregy) -> torch.Tensor: + """Sample value with specific mean and std. + + Args: + mean (float): Mean of the sampled value. + std (float): Standard deviation of the sampled value. + batch_size (int): The batch size of the sampled result. + + Returns: + torch.Tensor: Sampled results. + """ + if sampling_statregy.upper() == 'UNIFORM': + return (torch.rand((batch_size, 1)) - 0.5) * 2 * std + mean + elif sampling_statregy.upper() == 'GAUSSIAN': + return torch.randn((batch_size, 1)) * std + mean + else: + raise ValueError( + 'Only support \'Uniform\' sampling and \'Gaussian\' sampling ' + 'currently. If you want to implement your own sampling ' + 'method, you can overwrite \'_sample_in_range\' function by ' + 'yourself.') + + def sample_intrinsic( + self, + # fov: Optional[float] = None, + fov_mean=None, + fov_std=None, + focal: Optional[float] = None, + device: Optional[DeviceType] = None, + batch_size: int = 1) -> torch.Tensor: + """Sample intrinsic matrix. + + Args: + fov (Optional[float], optional): FOV (field of view) in degree. If + not passed, :attr:`self.fov` will be used. Defaults to None. + focal (Optional[float], optional): Focal in pixel. If + not passed, :attr:`self.focal` will be used. Defaults to None. + batch_size (int): The batch size of the output. Defaults to 1. + device (DeviceType, optional): Device to put the intrinstic + matrix. If not passed, :attr:`self.device` will be used. + Defaults to None. + + Returns: + torch.Tensor: Intrinsic matrix. 
+ """ + fov_mean = self.fov_mean if fov_mean is None else fov_mean + fov_std = self.fov_std if fov_std is None else fov_std + assert not ((fov_mean is None) ^ (fov_std is None)) + + if fov_mean is not None and fov_std is not None: + fov = self._sample_in_range(fov_mean, fov_std, batch_size, + 'uniform') + else: + fov = None + + # 1. check if foc and focal is both passed + assert (fov is None) or (focal is None), ( + '\'fov\' and focal should not be passed at the same time.') + # 2. if fov and focal is neither not passed, use initialized ones. + if fov is None and focal is None: + # do not use self.fov since fov is not defined + # fov = self.fov if fov is None else fov + focal = self.focal if focal is None else focal + + if fov is None and focal is None: + raise ValueError( + '\'fov\', \'focal\', \'self.fov\' and \'self.focal\' should ' + 'not be None neither.') + + if fov is not None: + intrinstic = self.fov_to_intrinsic(fov, device) + else: + intrinstic = self.focal_to_instrinsic(focal, device) + return intrinstic[None, ...].repeat(batch_size, 1, 1) + + def focal_to_instrinsic(self, + focal: Optional[float] = None, + device: DeviceType = None) -> torch.Tensor: + """Calculate intrinsic matrix from focal. + + Args: + focal (Optional[float], optional): Focal in degree. If + not passed, :attr:`self.focal` will be used. Defaults to None. + device (DeviceType, optional): Device to put the intrinsic + matrix. If not passed, :attr:`self.device` will be used. + Defaults to None. + + Returns: + torch.Tensor: Intrinsic matrix. + """ + focal = self.focal if focal is None else focal + assert focal is not None, ( + '\'focal\' and \'self.focal\' should not be None at the ' + 'same time.') + # device = self.device if device is None else device + # intrinsics = [[focal, 0, self.center_x], [0, focal, self.center_y], + # [0, 0, 1]] + intrinsics = [[focal, 0, 0.5], [0, focal, 0.5], [0, 0, 1]] + intrinsics = torch.tensor(intrinsics, device=device) + return intrinsics + + def fov_to_intrinsic(self, + fov: Optional[float] = None, + device: DeviceType = None) -> torch.Tensor: + """Calculate intrinsic matrix from FOV (field of view). + + Args: + fov (Optional[float], optional): FOV (field of view) in degree. If + not passed, :attr:`self.fov` will be used. Defaults to None. + device (DeviceType, optional): Device to put the intrinstic + matrix. If not passed, :attr:`self.device` will be used. + Defaults to None. + + Returns: + torch.Tensor: Intrinsic matrix. + """ + fov = self.fov if fov is None else fov + assert fov is not None, ( + '\'fov\' and \'self.fov\' should not be None at the same time.') + # focal = float(self.H / (math.tan(fov * math.pi / 360))) + focal = float(1 / (math.tan(fov * math.pi / 360))) + intrinsics = [[focal, 0, 0.5], [0, focal, 0.5], [0, 0, 1]] + intrinsics = torch.tensor(intrinsics, device=device) + return intrinsics + + def sample_theta(self, mean: float, std: float, + batch_size: int) -> torch.Tensor: + """Sampling the theta (yaw). + + Args: + mean (float): Mean of theta. + std (float): Standard deviation of theta. + batch_size (int): Target batch size of theta. + + Returns: + torch.Tensor: Sampled theta. + """ + h = self._sample_in_range(mean, std, batch_size, 'uniform') + return h + + def sample_phi(self, mean: float, std: float, + batch_size: int) -> torch.Tensor: + """Sampling the phi (pitch). Unlike sampling theta, we uniformly sample + phi on cosine space to release a spherical uniform sampling. + + Args: + mean (float): Mean of phi. + std (float): Standard deviation of phi. 
+ batch_size (int): Target batch size of phi. + + Returns: + torch.Tensor: Sampled phi. + """ + v = self._sample_in_range(mean, std, batch_size, 'uniform') + + # return a angular uniform sampling with `1-self.uniform_sphere_rate` + if random.random() < (1 - self.uniform_sphere_rate): + return v + + v = torch.clamp(v, 1e-5, math.pi - 1e-5) + + v = v / math.pi + if digit_version(TORCH_VERSION) <= digit_version('1.6.0'): + import numpy as np + phi = torch.from_numpy(np.arccos((1 - 2 * v).numpy())) + else: + phi = torch.arccos(1 - 2 * v) + return phi + + def sample_camera2world(self, + h_mean: Optional[float] = None, + v_mean: Optional[float] = None, + h_std: Optional[float] = None, + v_std: Optional[float] = None, + look_at: VectorType = None, + up: VectorType = None, + r_mean=None, + r_std=None, + batch_size: int = 1, + device: Optional[str] = None, + return_pose=False): + + # parse input + h_mean = self.horizontal_mean if h_mean is None else h_mean + v_mean = self.vertical_mean if v_mean is None else v_mean + h_std = self.horizontal_std if h_std is None else h_std + v_std = self.vertical_std if v_std is None else v_std + r_mean = self.radius_mean if r_mean is None else r_mean + r_std = self.radius_std if r_std is None else r_std + + look_at = self.look_at if look_at is None else look_at + if not isinstance(look_at, torch.FloatTensor): + look_at = torch.FloatTensor(look_at) + look_at = look_at.to(device) + up = self.up if up is None else up + if not isinstance(up, torch.FloatTensor): + up = torch.FloatTensor(up) + up = up.to(device) + + radius = self._sample_in_range(r_mean, r_std, batch_size, + 'Uniform').to(device) + + theta = self.sample_theta(h_mean, h_std, batch_size).to(device) + phi = self.sample_phi(v_mean, v_std, batch_size).to(device) + # construct camera origin + camera_origins = torch.zeros((batch_size, 3), device=device) + + camera_origins[:, 0:1] = radius * torch.sin(phi) * torch.cos(math.pi - + theta) + camera_origins[:, 2:3] = radius * torch.sin(phi) * torch.sin(math.pi - + theta) + camera_origins[:, 1:2] = radius * torch.cos(phi) + + # add noise jitter to camera origins, look_at and up + if self.jitter_pose: + camera_origins = camera_origins + ( + torch.rand_like(camera_origins) * 0.2 - 0.1) + look_at = look_at + torch.randn_like(camera_origins) * 0.2 + + # calculate forward vector and camer2world + forward_vectors = normalize_vecs(look_at - camera_origins) + camera2world = create_cam2world_matrix(forward_vectors, camera_origins, + up) + + if return_pose: + # NOTE: phi shape like [bz, 1], squeeze manually + if phi.ndim == 2: + phi = phi[:, 0] + if theta.ndim == 2: + theta = theta[:, 0] + pose_index = get_view_direction( + phi, theta, overhead=3.141 / 6, front=3.141 / 3) + return camera2world, pose_index + return camera2world + + def interpolation(self, + num_frames: int, + batch_size: int, + device='cuda') -> Tuple[list, list, list]: + """Interpolation camera-to-world matrix in theta. + + Args: + num_frames (int): _description_ + batch_size (int): _description_ + device (str, optional): _description_. Defaults to 'cuda'. 
+ + Returns: + Tuple[list, list, list]: _description_ + """ + # circle pose from sd-dreamfusion + tmp_flag = self.uniform_sphere_rate + self.uniform_sphere_rate = -1 + + cam2world_list, pose_list, intrinsic_list = [], [], [] + # intrinsic are same, across interpolation + intrinsic = self.sample_intrinsic( + fov_std=0, batch_size=batch_size, device=device) + for idx in range(num_frames): + # NOTE: >>> follow sd-dreamfusion + theta = (idx / num_frames) * 2 * 3.141 + phi = 3.141 / 3 # 60 degree + radius = (self.radius_mean + self.radius_std) * 1.2 + # NOTE: <<< follow sd-dreamfusion + + cam2world, pose = self.sample_camera2world( + h_mean=theta, + h_std=0, + v_mean=phi, + v_std=0, + r_mean=radius * 1.2, + r_std=0, + batch_size=batch_size, + return_pose=True, + device=device) + + pose_list.append(pose) + cam2world_list.append(cam2world) + intrinsic_list.append(intrinsic) + + self.uniform_sphere_rate = tmp_flag + + return cam2world_list, pose_list, intrinsic_list + + +def create_cam2world_matrix(forward_vector: torch.Tensor, origin: torch.Tensor, + up: torch.Tensor) -> torch.Tensor: + """Calculate camera-to-world matrix from camera's forward vector, world + origin and world up direction. The calculation is performed in right-hand + coordinate system and the returned matrix is in homogeneous coordinates + (shape like (bz, 4, 4)). + + Args: + forward_vector (torch.Tensor): The forward vector of the camera. + origin (torch.Tensor): The origin of the world coordinate. + up (torch.Tensor): The up direction of the world coordinate. + + Returns: + torch.Tensor: Camera-to-world matrix. + """ + + forward_vector = normalize_vecs(forward_vector) + up_vector = up.type(torch.float).expand_as(forward_vector) + right_vector = -normalize_vecs( + torch.cross(up_vector, forward_vector, dim=-1)) + up_vector = normalize_vecs( + torch.cross(forward_vector, right_vector, dim=-1)) + + rotation_matrix = torch.eye( + 4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], + 1, 1) + rotation_matrix[:, :3, :3] = torch.stack( + (right_vector, up_vector, forward_vector), axis=-1) + + translation_matrix = torch.eye( + 4, device=origin.device).unsqueeze(0).repeat(forward_vector.shape[0], + 1, 1) + translation_matrix[:, :3, 3] = origin + cam2world = (translation_matrix @ rotation_matrix)[:, :, :] + assert (cam2world.shape[1:] == (4, 4)) + return cam2world + + +def circle_poses(device, + radius=1.25, + theta=60, + phi=0, + return_dirs=False, + angle_overhead=30, + angle_front=60): + import numpy as np + theta = np.deg2rad(theta) + phi = np.deg2rad(phi) + angle_overhead = np.deg2rad(angle_overhead) + angle_front = np.deg2rad(angle_front) + + thetas = torch.FloatTensor([theta]).to(device) + phis = torch.FloatTensor([phi]).to(device) + + centers = torch.stack([ + radius * torch.sin(thetas) * torch.sin(phis), + radius * torch.cos(thetas), + radius * torch.sin(thetas) * torch.cos(phis), + ], + dim=-1) # [B, 3] + + # lookat + forward_vector = -safe_normalize(centers) + up_vector = torch.FloatTensor([0, -1, 0]).to(device).unsqueeze(0) + right_vector = safe_normalize( + torch.cross(forward_vector, up_vector, dim=-1)) + up_vector = safe_normalize( + torch.cross(right_vector, forward_vector, dim=-1)) + + poses = torch.eye(4, dtype=torch.float, device=device).unsqueeze(0) + poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), + dim=-1) + poses[:, :3, 3] = centers + + if return_dirs: + dirs = get_view_direction(thetas, phis, angle_overhead, angle_front) + else: + dirs = None + + return poses, dirs + + +# 
TODO: replace with ours later +def safe_normalize(x, eps=1e-20): + return x / torch.sqrt( + torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps)) diff --git a/mmedit/models/editors/dreamfusion/dreamfusion.py b/mmedit/models/editors/dreamfusion/dreamfusion.py new file mode 100644 index 0000000000..692b1f17da --- /dev/null +++ b/mmedit/models/editors/dreamfusion/dreamfusion.py @@ -0,0 +1,271 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from copy import deepcopy +from typing import Optional, Tuple + +import torch +from mmengine.model import BaseModel +from mmengine.utils import ProgressBar +from torch import Tensor + +from mmedit.models.editors.eg3d.ray_sampler import sample_rays +from mmedit.registry import MODELS, MODULES +from mmedit.structures import EditDataSample, PixelData + + +@MODELS.register_module() +class DreamFusion(BaseModel): + + def __init__(self, + diffusion, + renderer, + camera, + resolution, + data_preprocessor, + test_resolution=None, + text: Optional[str] = None, + negative='', + dir_text=True, + suppress_face=False, + guidance_scale=100, + loss_config=dict()): + super().__init__(data_preprocessor) + # NOTE: dreamfusion do not need data preprocessor + self.diffusion = MODULES.build(diffusion) + self.renderer = MODULES.build(renderer) + self.camera = MODULES.build(camera) + + self.guidance_scale = guidance_scale + self.text = text + self.negative = negative + self.suppress_face = suppress_face + + self.dir_text = dir_text + + # >>> loss configs + self.loss_config = deepcopy(loss_config) + self.weight_entropy = loss_config.get('weight_entropy', 1e-4) + self.weight_opacity = loss_config.get('weight_opacity', 0) + self.weight_orient = loss_config.get('weight_orient', 1e-2) + self.weight_smooth = loss_config.get('weight_smooth', 0) + + self.resolutoin = resolution + if test_resolution is None: + self.test_resolution = resolution + else: + self.test_resolution = test_resolution + + self.prepare_text_embeddings() + + @property + def device(self) -> torch.device: + """Get current device of the model. + + Returns: + torch.device: The current device of the model. 
+ """ + return next(self.parameters()).device + + def sample_rays_and_pose(self, num_batches): + cam2world_matrix, pose = self.camera.sample_camera2world( + batch_size=num_batches, return_pose=True) + intrinsics = self.camera.sample_intrinsic(batch_size=num_batches) + rays_o, rays_d = sample_rays( + cam2world_matrix, intrinsics, resolution=self.resolutoin) + return rays_o, rays_d, pose + + def prepare_text_embeddings(self): + assert self.text is not None, ('\'text\' must be defined in configs ' + 'or passed by command line args.') + + if not self.dir_text: + self.text_z = self.diffusion.get_text_embeds([self.text], + [self.negative]) + else: + self.text_z = [] + for d in ['front', 'side', 'back', 'side', 'overhead', 'bottom']: + # construct dir-encoded text + text = f'{self.text}, {d} view' + + negative_text = f'{self.negative}' + + # explicit negative dir-encoded text + if self.suppress_face: + if negative_text != '': + negative_text += ', ' + if d == 'back': + negative_text += 'face' + elif d == 'side': + negative_text += 'face' + elif d == 'overhead': + negative_text += 'face' + elif d == 'bottom': + negative_text += 'face' + + text_z = self.diffusion.get_text_embeds([text], + [negative_text]) + self.text_z.append(text_z) + + def label_fn(self, num_batches: int = 1) -> Tuple[Tensor, Tensor]: + """Label sampling function for DreamFusion model.""" + # sample random conditional from camera + assert self.camera is not None, ( + '\'camera\' is not defined for \'EG3D\'.') + camera2world = self.camera.sample_camera2world(batch_size=num_batches) + intrinsics = self.camera.sample_intrinsic(batch_size=num_batches) + + return camera2world, intrinsics + + def forward(self, inputs, data_samples=None, mode=None): + # TODO: how to design a better sampler, and what should we return + num_batches = inputs['num_batches'][0] + render_kwargs = inputs.get('render_kwargs', dict()) + # TODO: sample a random input, do not support parse input from + # data_samples currently + cam2world, intrinsic = self.label_fn(num_batches) + + rays_o, rays_d = sample_rays( + cam2world, intrinsic, resolution=self.test_resolution) + + # TODO: how can we support other shading mode (e.g., normal)? 
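+        # batchify_render splits the H * W test rays into chunks of at most
+        # 4096 rays per renderer call so a full-resolution render fits in
+        # memory; it returns flat (B, H * W, C) tensors that are reshaped
+        # into image and depth maps below.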
+ rgb, depth, _ = self.batchify_render(rays_o, rays_d, render_kwargs) + B, H, W = 1, self.test_resolution, self.test_resolution + pred_rgb = rgb.reshape(B, H, W, 3) + pred_rgb = pred_rgb.permute(0, 3, 1, 2) + + pred_depth = depth.reshape(B, H, W, 1).permute(0, 3, 1, 2) + pred_depth = torch.cat([pred_depth] * 3, dim=1) + pred_depth = (pred_depth - depth.min()) / (depth.max() - depth.min()) + + output = [ + EditDataSample( + fake_img=PixelData(data=pred_rgb[0]), + depth=PixelData(data=pred_depth[0])) + ] + + return output + + @torch.no_grad() + def interpolation(self, + num_images: int, + num_batches: int = 1, + show_pbar: bool = True): + + assert hasattr(self, 'camera'), ('Camera must be defined.') + assert num_batches == 1, ( + 'DreamFusion only support \'num_batches\' as 1.') + cam2world_list, pose_list, intrinsic_list = self.camera.interpolation( + num_images, num_batches) + + output_list = [] + if show_pbar: + pbar = ProgressBar(num_images) + + for cam2world, pose, intrinsic in zip(cam2world_list, pose_list, + intrinsic_list): + + rays_o, rays_d = sample_rays( + cam2world, intrinsic, resolution=self.test_resolution) + + rgb, depth, weight = self.batchify_render(rays_o, rays_d) + B, H, W = 1, self.test_resolution, self.test_resolution + pred_rgb = rgb.reshape(B, H, W, 3) + pred_rgb = pred_rgb.permute(0, 3, 1, 2) + pred_depth = depth.reshape(B, H, W, 1) + output_list.append(dict(rgb=pred_rgb, depth=pred_depth)) + + if show_pbar: + pbar.update(1) + + if show_pbar: + print('\n') + + return output_list + + def batchify_render(self, rays_o, rays_d, render_kwarge=dict()): + # NOTE: can we implement this function with a decorator? + # If we wrap the renderer, the grad function in train step will not be + # released + + B, N = rays_o.shape[:2] + depth = torch.empty((B, N, 1), device=self.device) + image = torch.empty((B, N, 3), device=self.device) + weights_sum = torch.empty((B, N, 1), device=self.device) + + max_ray_batch = 4096 + for b in range(B): + head = 0 + while head < N: + tail = min(head + max_ray_batch, N) + rgb_, depth_, weight_ = self.renderer( + rays_o[b:b + 1, head:tail], + rays_d[b:b + 1, head:tail], + render_kwarge, + ) + + depth[b:b + 1, head:tail] = depth_ + weights_sum[b:b + 1, head:tail] = weight_ + image[b:b + 1, head:tail] = rgb_ + head += max_ray_batch + + return image, depth, weights_sum + + def train_step(self, data, optim_wrapper): + + # data preprocessor + num_batches = data['inputs']['num_batches'][0] + rays_o, rays_d, pose = self.sample_rays_and_pose( + num_batches=num_batches) + + B = num_batches + # N = self.resolutoin**2 + H = W = self.resolutoin + + # forward nerf + rgb, depth, weight, loss_dict = self.renderer(rays_o, rays_d) + pred_rgb = rgb.reshape(B, H, W, 3).permute(0, 3, 1, 2) + pred_rgb = pred_rgb.contiguous() # [1, 3, H, W] + + # forward diffusion + if self.dir_text: + text_z = self.text_z[pose] + else: + text_z = self.text_z + # encode pred_rgb to latents, + # use train_step to avoid interface conflict + self.diffusion.module.train_step_( + text_z, pred_rgb, guidance_scale=self.guidance_scale) + + pred_ws = weight.reshape(B, 1, H, W) + + # NOTE: not used in stable-dreamfusion + if self.weight_opacity > 0: + loss_opacity = (pred_ws**2).mean() * self.weight_opacity + loss_dict['loss_opacity'] = loss_opacity + + # NOTE: author use this to replace opacity one + if self.weight_entropy > 0: + alphas = (pred_ws).clamp(1e-5, 1 - 1e-5) + # alphas = alphas ** 2 # skewed entropy, favors 0 over 1 + loss_entropy = (-alphas * torch.log2(alphas) - + (1 - alphas) * 
torch.log2(1 - alphas)).mean() + loss_entropy = loss_entropy * self.weight_entropy + loss_dict['loss_entropy'] = loss_entropy + + if 'loss_orient' in loss_dict: + loss_orient = loss_dict['loss_orient'] * self.weight_orient + loss_dict['loss_orient'] = loss_orient + + if 'loss_smooth' in loss_dict: + loss_smooth = loss_dict['loss_smooth'] * self.weight_smooth + loss_dict['loss_smooth'] = loss_smooth + + loss, log_vars = self.parse_losses(loss_dict) + optim_wrapper['renderer'].update_params(loss) + + return log_vars + + def test_step(self, data): + return self.forward(data) + + def val_step(self, data): + return self.forward(data) diff --git a/mmedit/models/editors/dreamfusion/renderer.py b/mmedit/models/editors/dreamfusion/renderer.py new file mode 100644 index 0000000000..f7da9cad6d --- /dev/null +++ b/mmedit/models/editors/dreamfusion/renderer.py @@ -0,0 +1,512 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from mmengine import print_log +from mmengine.model import BaseModule +from mmengine.utils import digit_version +from mmengine.utils.dl_utils import TORCH_VERSION + +# TODO: move these utils to mmedit.models.utils folder +from mmedit.models.editors.eg3d.eg3d_utils import (get_ray_limits_box, + inverse_transform_sampling, + linspace_batch) +from mmedit.models.utils import normalize_vecs +from mmedit.registry import MODULES +from .vanilla_nerf import NeRFNetwork + + +@MODULES.register_module() +class DreamFusionRenderer(BaseModule): + """Renderer for EG3D. This class samples render points on each input ray + and interpolate the triplane feature corresponding to the points' + coordinates. Then, predict each point's RGB feature and density (sigma) by + a neural network and calculate the RGB feature of each ray by integration. + Different from typical NeRF models, the decoder of EG3DRenderer takes + triplane feature of each points as input instead of positional encoding of + the coordinates. + + Args: + decoder_cfg (dict): The config to build neural renderer. + ray_start (float): The start position of all rays. + ray_end (float): The end position of all rays. + box_warp (float): The side length of the cube spanned by the triplanes. + The box is axis-aligned, centered at the origin. The range of each + axis is `[-box_warp/2, box_warp/2]`. If `box_warp=1.8`, it has + vertices at the range of axis is `[-0.9, 0.9]`. Defaults to 1. + depth_resolution (int): Resolution of depth, as well as the number of + points per ray. Defaults to 64. + depth_resolution_importance (int): Resolution of depth in hierarchical + sampling. Defaults to 64. + clamp_mode (str): The clamp mode for density predicted by nerural + renderer. Defaults to 'softplus'. + white_back (bool): Whether render a white background. Defaults to True. 
+ """ + + def __init__( + self, + # bound, + decoder_cfg: dict, + ray_start: float, + ray_end: float, + # NOTE: bound / 2, set as 2, different from eg3d + box_warp: float = 2, + depth_resolution: int = 64, + depth_resolution_importance: int = 64, + # density_noise: float = 0, # NOTE: no use + clamp_mode: Optional[str] = None, + white_back: bool = True): + super().__init__() + + self.decoder = NeRFNetwork(**decoder_cfg) + + self.ray_start = ray_start + self.ray_end = ray_end + self.box_warp = box_warp + self.depth_resolution = depth_resolution + self.depth_resolution_importance = depth_resolution_importance + + self.clamp_mode = clamp_mode + self.white_back = white_back + + if self.white_back and self.decoder.bg_radius: + print_log( + 'Background network in NeRF decoder will be not used ' + 'since \'white_back\' is set as True.', 'current') + + # init value for shding and ambient_ratio + self.shading = 'albedo' + self.ambient_ratio = 1.0 + + def set_shading(self, shading): + self.shading = shading + + def set_ambient_ratio(self, ambient_ratio): + self.ambient_ratio = ambient_ratio + + def get_value(self, + target: str, + render_kwargs: Optional[dict] = None) -> Any: + """Get value of target field. + + Args: + target (str): The key of the target field. + render_kwargs (Optional[dict], optional): The input key word + arguments dict. Defaults to None. + + Returns: + Any: The default value of target field. + """ + if render_kwargs is None: + return getattr(self, target) + return render_kwargs.get(target, getattr(self, target)) + + def forward( + self, + ray_origins: torch.Tensor, + ray_directions: torch.Tensor, + render_kwargs: Optional[dict] = dict() + ) -> Tuple[torch.Tensor]: + """Render 2D RGB feature, weighed depth and weights with the passed + triplane features and rays. 'weights' denotes `w` in Equation 5 of the + NeRF's paper. + + Args: + planes (torch.Tensor): The triplane features shape like + (bz, 3, TriPlane_feat, TriPlane_res, TriPlane_res). + ray_origins (torch.Tensor): The original of each ray to render, + shape like (bz, NeRF_res * NeRF_res, 3). + ray_directions (torch.Tensor): The direction vector of each ray to + render, shape like (bz, NeRF_res * NeRF_res, 3). + render_kwargs (Optional[dict], optional): The specific kwargs for + rendering. Defaults to None. + + Returns: + Tuple[torch.Tensor]: Renderer RGB feature, weighted depths and + weights. + """ + ray_start = self.get_value('ray_start', render_kwargs) + ray_end = self.get_value('ray_end', render_kwargs) + box_warp = self.get_value('box_warp', render_kwargs) + depth_resolution = self.get_value('depth_resolution', render_kwargs) + depth_resolution_importance = self.get_value( + 'depth_resolution_importance', render_kwargs) + shading = self.get_value('shading', render_kwargs) + ambient_ratio = self.get_value('ambient_ratio', render_kwargs) + + if ray_start == ray_end == 'auto': + ray_start, ray_end = get_ray_limits_box( + ray_origins, ray_directions, box_side_length=box_warp) + is_ray_valid = ray_end > ray_start + if torch.any(is_ray_valid).item(): + ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() + ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() + elif ray_start == ray_end == 'sphere': + radius = ray_origins.norm(dim=-1, keepdim=True) + ray_start = radius - self.box_warp / 2 # [B, N, 1] + ray_end = radius + self.box_warp / 2 + else: + assert (isinstance(ray_start, float) and isinstance( + ray_end, float)), ( + '\'ray_start\' and \'ray_end\' must be both float type or ' + f'both \'auto\'. 
But receive {ray_start} and {ray_end}.') + assert ray_start < ray_end, ( + '\'ray_start\' must less than \'ray_end\'.') + + # Create stratified depth samples + depths_coarse, depth_delta = self.sample_stratified( + ray_origins, + ray_start, + ray_end, + depth_resolution, + perturb=self.training) + + # get light direction + if 'light_d' in render_kwargs: + light_d = render_kwargs['light_d'] + else: + # select random light direction + light_d = ( + ray_origins[0, 0] + + torch.randn(3, device=ray_start.device, dtype=torch.float)) + light_d = normalize_vecs(light_d, clamp_eps=1e-20) + + batch_size, num_rays, samples_per_ray, _ = depths_coarse.shape + shape_prefix = [batch_size, num_rays, samples_per_ray] + + # Coarse Pass + # [B, Res^2, N_points, 3] + sample_coordinates = ray_origins[ + ..., None, :] + depths_coarse * ray_directions[..., None, :] + + # NOTE: add clip here + sample_coordinates = torch.clamp(sample_coordinates, + -self.box_warp / 2, self.box_warp / 2) + out = self.neural_rendering(sample_coordinates, mode='density') + densities_coarse = out['sigma'].reshape(*shape_prefix, -1) + + # Fine Pass + N_importance = depth_resolution_importance + if N_importance is not None and N_importance > 0: + # update shape prefix + shape_prefix_fine = (batch_size, num_rays, N_importance) + # shape_prefix[-1] = samples_per_ray + N_importance + with torch.no_grad(): + _, _, weights = self.volume_rendering( + None, + densities_coarse, + depths_coarse, + depth_delta, + mode='padding') + depths_fine = self.sample_importance(depths_coarse, weights, + N_importance) + # [B, Res^2, N_points, 3] + sample_coordinates_fine = ray_origins[ + ..., None, :] + depths_fine * ray_directions[..., None, :] + + out = self.neural_rendering( + sample_coordinates_fine, mode='density') + densities_fine = out['sigma'] + densities_fine = densities_fine.reshape(*shape_prefix_fine, -1) + sort_index = self.sort_depth(depths_coarse, depths_fine) + all_depths = self.unify_samples(depths_coarse, depths_fine, + sort_index) + all_densities = self.unify_samples(densities_coarse, + densities_fine, sort_index) + all_coord = self.unify_samples( + sample_coordinates.reshape(*shape_prefix, -1), + sample_coordinates_fine.reshape(*shape_prefix_fine, -1), + sort_index) + all_coord = torch.clamp(all_coord, -self.box_warp / 2, + self.box_warp / 2) + all_ray_directions = ray_directions[:, :, + None, :].expand_as(all_coord) + out_final = self.neural_rendering( + all_coord, light_d, ambient_ratio, shading, mode='full') + all_colors = out_final['color'] + all_colors = all_colors.reshape(batch_size, num_rays, + samples_per_ray + N_importance, -1) + # Aggregate + rgb_final, depth_final, weights = self.volume_rendering( + all_colors, + all_densities, + all_depths, + depth_delta, + mode='padding') + else: + # TODO: bugs here, fix later + all_colors = out['color'] + rgb_final, depth_final, weights = self.volume_rendering( + all_colors, + densities_coarse, + depths_coarse, + depth_delta, + mode='padding') + out_final = out + + weights_sum = weights.sum(2) + if self.white_back: + bg_color = 1 + else: + bg_color = self.decoder.forward_bg(ray_directions.reshape(-1, 3)) + bg_color = bg_color[None, ...] 
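+        # Composite the background: 'weights_sum' is the accumulated alpha
+        # along each ray, so (1 - weights_sum) is the remaining transmittance
+        # and is filled with either plain white or the color predicted by the
+        # background MLP from the ray direction.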
+ rgb_final = rgb_final + (1 - weights_sum) * bg_color + + # NOTE: we have to calculate some loss terms in renderer, maybe we + # can find some better way to tackle this + if self.training: + loss_dict = dict() + # orientation loss + normals = out_final.get('normal', None) + if normals is not None: + normals = normals.view(batch_size * num_rays, -1, 3) + # NOTE: just a dirty way to reshape + ray_d = all_ray_directions.view(batch_size * num_rays, -1, 3) + weights_ = weights.view(batch_size * num_rays, -1) + loss_orient = weights_.detach() * ( + normals * ray_d).sum(-1).clamp(min=0)**2 + loss_dict['loss_orient'] = loss_orient.sum(-1).mean() + + # surface normal smoothness + normals_perturb = self.decoder.normal( + all_coord + torch.randn_like(all_coord) * 1e-2).view( + batch_size * num_rays, -1, 3) + loss_smooth = (normals - normals_perturb).abs() + loss_dict['loss_smooth'] = loss_smooth.mean() + + return rgb_final, depth_final, weights.sum(2), loss_dict + + return rgb_final, depth_final, weights.sum(2) + + def sample_stratified(self, ray_origins: torch.Tensor, + ray_start: Union[float, torch.Tensor], + ray_end: Union[float, + torch.Tensor], depth_resolution: int, + perturb: bool) -> torch.Tensor: + """Return depths of approximately uniformly spaced samples along rays. + + Args: + ray_origins (torch.Tensor): The original of each ray, shape like + (bz, NeRF_res * NeRF_res, 3). Only used to provide + device and shape info. + ray_start (Union[float, torch.Tensor]): The start position of rays. + If a float is passed, all rays will have the same start + distance. + ray_end (Union[float, torch.Tensor]): The end position of rays. If + a float is passed, all rays will have the same end distance. + depth_resolution (int): Resolution of depth, as well as the number + of points per ray. + + Returns: + torch.Tensor: The sampled coarse depth shape like + (bz, NeRF_res * NeRF_res, N_depth, 1). If padding is True, the + shape will be (bz, NeRF_res * NeRF_res, N_depth+1, 1) + ---> return padding and depths + """ + N, M, _ = ray_origins.shape + if isinstance(ray_start, torch.Tensor): + # perform linspace for batch of tensor + depths = linspace_batch(ray_start, ray_end, depth_resolution) + depths = depths.permute(1, 2, 0, 3) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + if perturb: + # NOTE: this is different from EG3D + # depths += torch.rand_like(depths) * depth_delta[..., None] + depths += (torch.rand_like(depths) - 0.5) * depth_delta[..., + None] + return depths, depth_delta[..., None] + else: + depths = torch.linspace( + ray_start, + ray_end, + depth_resolution, + device=ray_origins.device) + depths = depths.reshape(1, 1, depth_resolution, 1) + depths = depths.repeat(N, M, 1, 1) + depth_delta = (ray_end - ray_start) / (depth_resolution - 1) + if perturb: + # NOTE: this is different from EG3D + # depths += torch.rand_like(depths) * depth_delta + depths += (torch.rand_like(depths) - 0.5) * depth_delta[..., + None] + return depths, depth_delta * torch.ones(N, M, 1, 1) + + def neural_rendering(self, + sample_coordinates: torch.Tensor, + light_d: Optional[torch.Tensor] = None, + ratio: Optional[float] = None, + shading: str = 'albedo', + mode='density') -> dict: + """Predict RGB features (or albedo) and densities of the coordinates by + neural renderer model. + + Args: + sample_coordinates (torch.Tensor): Coordinates of the sampling + points, shape like (bz, N_depth * NeRF_res * NeRF_res, 1). + light_d (torch.Tensor): The direction vector of light. 
+ ratio (float, optional): The ambident ratio in shading. + mode (str): The forward mode of the neural renderer model. + Supported choices are 'density' and 'full'. Defaults to + 'density'. + + Returns: + dict: A dict contains RGB features ('rgb'), densities ('sigma') + and normal. + """ + xyzs = sample_coordinates.reshape(-1, 3) + if mode == 'density': + out = self.decoder.density(xyzs) + else: + out = self.decoder( + xyzs, light_d, ambient_ratio=ratio, shading=shading) + return out + + def unify_samples(self, target_1: torch.Tensor, target_2: torch.Tensor, + indices: torch.Tensor) -> torch.Tensor: + """Unify two input tensor to one with the passed indice. + + Args: + target_1 (torch.Tensor): The first tensor to unify. + target_2 (torch.Tensor): The second tensor to unify. + indices (torch.Tensor): The index of the element in the first and + the second tensor after unifing. + + Returns: + torch.Tensor: The unified tensor. + """ + all_targets = torch.cat([target_1, target_2], dim=-2) + all_targets = torch.gather( + all_targets, -2, indices.expand(-1, -1, -1, all_targets.shape[-1])) + return all_targets + + def sort_depth(self, depths_c: torch.Tensor, + depths_f: torch.Tensor) -> torch.Tensor: + """Sort the coarse depth and fine depth, and return a indices tensor. + + Returns: + torch.Tensor: The index of the depth in the first and the second + tensor after unifing. + """ + all_depths = torch.cat([depths_c, depths_f], dim=-2) + + _, indices = torch.sort(all_depths, dim=-2) + return indices + + def volume_rendering(self, + colors: torch.Tensor, + densities: torch.Tensor, + depths: torch.Tensor, + depths_delta: Optional[torch.Tensor] = None, + mode='mid') -> Tuple[torch.Tensor]: + """Volume rendering. + + Args: + colors (torch.Tensor): Color feature for each points. Shape like + (bz, N_points, N_depth, N_feature). + densities (torch.Tensor): Density for each points. Shape like + (bz, N_points, N_depth, 1). + depths (torch.Tensor): Depths for each points. Shape like + (bz, N_points, N_depth, 1). + depths_delta (torch.Tensor, optional): The distance between two + points on each ray. Shape like (bz, N_points, 1, 1) + mode (str): The volume rendering mode. Supported choices are + 'padding' and 'mid'. If mode is 'padding', the distance + between the last two render points will be set as + 'depths_delta'. Otherwise, will calculate the color, + depths and density of middle points of original render points, + and then conduct volume rendering upon the middle points. + Defaults to 'mid'. + + Returns: + Tuple[torch.Tensor]: A tuple of color feature + `(bz, N_points, N_feature)`, weighted depth + `(bz, N_points, 1)` and weight + `(bz, N_points, N_depth-1, 1)`. 
+ """ + # NOTE: density and depth must not be None, colors may be None + if mode == 'mid': + deltas = depths[:, :, 1:] - depths[:, :, :-1] + depths_ = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 + densities_ = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 + else: + # NOTE: dreamfusion is different from EG3D ones + assert depths_delta is not None + # depths_ = torch.cat([depths, depths_padding], dim=2) + deltas = depths[:, :, 1:] - depths[:, :, :-1] + deltas = torch.cat([deltas, depths_delta], dim=2) + depths_, densities_ = depths, densities + + # NOTE: do not use clamp for density + if self.clamp_mode == 'softplus': + # activation bias of -1 makes things initialize better + densities_ = F.softplus(densities_ - 1) + else: + assert self.clamp_mode is None, ( + f'{self.__class__.__name__} only supports \'softplus\' for ' + f'\'clamp_mode\' but receive \'{self.clamp_mode}\'.') + + density_delta = densities_ * deltas + + alpha = 1 - torch.exp(-density_delta) + + alpha_shifted = torch.cat( + [torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2) + weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] + weight_total = weights.sum(2) + + composite_depth = torch.sum(weights * depths_, -2) / weight_total + + # clip the composite to min/max range of depths + if digit_version(TORCH_VERSION) < digit_version('1.8.0'): + composite_depth[torch.isnan(composite_depth)] = float('inf') + else: + composite_depth = torch.nan_to_num(composite_depth, float('inf')) + composite_depth = torch.clamp(composite_depth, torch.min(depths_), + torch.max(depths_)) + + # NOTE: move bg to forward, since we cannot forward bg decoder in + # volume rendering + if colors is not None: + if mode == 'mid': + colors_ = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 + else: + colors_ = colors + + composite_rgb = torch.sum(weights * colors_, -2) + else: + composite_rgb = None + + return composite_rgb, composite_depth, weights + + @torch.no_grad() + def sample_importance(self, z_vals: torch.Tensor, weights: torch.Tensor, + N_importance: int) -> torch.Tensor: + """Return depths of importance sampled points along rays. + + Args: + z_vals (torch.Tensor): Coarse Z value (depth). Shape like + (bz, N_points, N_depth, N_feature). + weights (torch.Tensor): Weights of the coarse samples. Shape like + (bz, N_points, N_depths-1, 1). + N_importance (int): Number of samples to resample. + """ + batch_size, num_rays, samples_per_ray, _ = z_vals.shape + z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) + # -1 to account for loss of 1 sample in MipRayMarcher + weights = weights.reshape(batch_size * num_rays, -1) + + # smooth weights as MipNeRF + # max(weights[:-1], weights[1:]) + weights = F.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1) + # 0.5 * (weights[:-1] + weights[1:]) + weights = F.avg_pool1d(weights, 2, 1).squeeze() + weights = weights + 0.01 # add resampling padding + + z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) + importance_z_vals = inverse_transform_sampling(z_vals_mid, + weights[:, 1:-1], + N_importance).detach() + importance_z_vals = importance_z_vals.reshape(batch_size, num_rays, + N_importance, 1) + return importance_z_vals diff --git a/mmedit/models/editors/dreamfusion/stable_diffusion_wrapper.py b/mmedit/models/editors/dreamfusion/stable_diffusion_wrapper.py new file mode 100644 index 0000000000..672a6efcbd --- /dev/null +++ b/mmedit/models/editors/dreamfusion/stable_diffusion_wrapper.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
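The wrapper below performs the score distillation sampling (SDS) step of DreamFusion: the rendered image is encoded to latents, noise is added at a random timestep, the frozen UNet predicts that noise with classifier-free guidance, and the weighted residual is pushed back into the latents as a gradient. A stripped-down sketch of that update with stand-in tensors (none of these names come from the patch):

import torch

latents = torch.randn(1, 4, 64, 64, requires_grad=True)  # VAE latents
noise = torch.randn_like(latents)                         # injected noise
noise_pred = torch.randn_like(latents)                    # UNet prediction
w_t = 0.5                                                 # 1 - alphas_cumprod[t]

grad = w_t * (noise_pred - noise)
# Same effect as minimizing (grad.detach() * latents).sum() w.r.t. latents.
latents.backward(gradient=grad)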
+import torch +import torch.nn.functional as F + +from mmedit.models.editors.stable_diffusion import StableDiffusion +from mmedit.registry import MODULES + + +@MODULES.register_module() +class StableDiffusionWrapper(StableDiffusion): + """Stable diffusion wrapper for dreamfusion.""" + + def __init__(self, + diffusion_scheduler, + unet_cfg, + vae_cfg, + pretrained_ckpt_path, + requires_safety_checker=True, + unet_sample_size=64): + super().__init__(diffusion_scheduler, unet_cfg, vae_cfg, + pretrained_ckpt_path, requires_safety_checker, + unet_sample_size) + self.min_step = int(0.02 * self.scheduler.num_train_timesteps) + self.max_step = int(0.98 * self.scheduler.num_train_timesteps) + + def get_text_embeds(self, prompt, negative_prompt): + # prompt, negative_prompt: [str] + + # Tokenize text and get embeddings + text_input = self.tokenizer( + prompt, + padding='max_length', + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors='pt') + + with torch.no_grad(): + text_embeddings = self.text_encoder( + text_input.input_ids.to(self.execution_device))[0] + + # Do the same for unconditional embeddings + uncond_input = self.tokenizer( + negative_prompt, + padding='max_length', + max_length=self.tokenizer.model_max_length, + return_tensors='pt') + + with torch.no_grad(): + uncond_embeddings = self.text_encoder( + uncond_input.input_ids.to(self.execution_device))[0] + + # Cat for final embeddings + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) + return text_embeddings + + def encode_imgs(self, imgs): + imgs = 2 * imgs - 1 + + posterior = self.vae.encode(imgs).latent_dist + latents = posterior.sample() * 0.18215 + + return latents + + # @torch.no_grad() + def decode_latents(self, latents): + # TODO: can we do this by super().decode(latents) ? + latents = 1 / 0.18215 * latents + + with torch.no_grad(): + imgs = self.vae.decode(latents).sample + + imgs = (imgs / 2 + 0.5).clamp(0, 1) + + return imgs + + def train_step_(self, text_embeddings, pred_rgb, guidance_scale=100): + + text_embeddings = text_embeddings.to(self.execution_device) + # interp to 512x512 to be fed into vae. + pred_rgb_512 = F.interpolate( + pred_rgb, (512, 512), mode='bilinear', align_corners=False) + + # timestep ~ U(0.02, 0.98) to avoid very high/low noise level + t = torch.randint( + self.min_step, + self.max_step + 1, [1], + dtype=torch.long, + device=self.execution_device) + + # encode image into latents with vae, requires grad! + latents = self.encode_imgs(pred_rgb_512) + + # predict the noise residual with unet, NO grad! + with torch.no_grad(): + # add noise + noise = torch.randn_like(latents) + latents_noisy = self.scheduler.add_noise(latents, noise, t) + # pred noise + latent_model_input = torch.cat([latents_noisy] * 2) + noise_pred = self.unet( + latent_model_input, t, + encoder_hidden_states=text_embeddings)['outputs'] + + # perform guidance (high scale from paper!) + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond) + + # w(t), sigma_t^2 + # w = (1 - self.alphas[t]) + w = (1 - self.scheduler.alphas_cumprod[t]) + # w = self.alphas[t] ** 0.5 * (1 - self.alphas[t]) + grad = w * (noise_pred - noise) + + grad = torch.nan_to_num(grad) + + # manually backward, since we omitted an item in grad and cannot + # simply autodiff. 
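+        # Passing `gradient=grad` injects d(loss)/d(latents) directly, which
+        # matches backpropagating the surrogate loss
+        # (grad.detach() * latents).sum(); the gradient then flows through
+        # the VAE encoder back into the NeRF renderer parameters.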
+ latents.backward(gradient=grad, retain_graph=True) + + # TODO: return a loss term without grad + return 0 # dummy loss value diff --git a/mmedit/models/editors/dreamfusion/vanilla_nerf.py b/mmedit/models/editors/dreamfusion/vanilla_nerf.py new file mode 100644 index 0000000000..8434f63c55 --- /dev/null +++ b/mmedit/models/editors/dreamfusion/vanilla_nerf.py @@ -0,0 +1,309 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmedit.models.utils import normalize_vecs +from mmedit.registry import MODULES +from .activate import trunc_exp + +# from .utils import auto_batchicy + + +class FreqEncoder(nn.Module): + + def __init__(self, + input_dim, + max_freq_log2, + N_freqs, + log_sampling=True, + include_input=True, + periodic_fns=(torch.sin, torch.cos)): + + super().__init__() + + self.input_dim = input_dim + self.include_input = include_input + self.periodic_fns = periodic_fns + + self.output_dim = 0 + if self.include_input: + self.output_dim += self.input_dim + + self.output_dim += self.input_dim * N_freqs * len(self.periodic_fns) + + if log_sampling: + self.freq_bands = 2**torch.linspace(0, max_freq_log2, N_freqs) + else: + self.freq_bands = torch.linspace(2**0, 2**max_freq_log2, N_freqs) + + self.freq_bands = self.freq_bands.numpy().tolist() + + def forward(self, input, *args, **kwargs): + + out = [] + if self.include_input: + out.append(input) + + for i in range(len(self.freq_bands)): + freq = self.freq_bands[i] + for p_fn in self.periodic_fns: + out.append(p_fn(input * freq)) + + out = torch.cat(out, dim=-1) + + return out + + +class ResBlock(nn.Module): + + def __init__(self, dim_in, dim_out, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + + self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias) + self.norm = nn.LayerNorm(self.dim_out) + self.activation = nn.SiLU(inplace=True) + + if self.dim_in != self.dim_out: + self.skip = nn.Linear(self.dim_in, self.dim_out, bias=False) + else: + self.skip = None + + def forward(self, x): + # x: [B, C] + identity = x + + out = self.dense(x) + out = self.norm(out) + + if self.skip is not None: + identity = self.skip(identity) + + out += identity + out = self.activation(out) + + return out + + +class BasicBlock(nn.Module): + + def __init__(self, dim_in, dim_out, bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + + self.dense = nn.Linear(self.dim_in, self.dim_out, bias=bias) + self.activation = nn.ReLU(inplace=True) + + def forward(self, x): + # x: [B, C] + + out = self.dense(x) + out = self.activation(out) + + return out + + +class MLP(nn.Module): + + def __init__(self, + dim_in, + dim_out, + dim_hidden, + num_layers, + block, + bias=True): + super().__init__() + self.dim_in = dim_in + self.dim_out = dim_out + self.dim_hidden = dim_hidden + self.num_layers = num_layers + + net = [] + for idx in range(num_layers): + if idx == 0: + net.append(BasicBlock(self.dim_in, self.dim_hidden, bias=bias)) + elif idx != num_layers - 1: + net.append(block(self.dim_hidden, self.dim_hidden, bias=bias)) + else: + net.append(nn.Linear(self.dim_hidden, self.dim_out, bias=bias)) + + self.net = nn.Sequential(*net) + + def forward(self, x): + + for idx in range(self.num_layers): + x = self.net[idx](x) + + return x + + +class VanillaMLP(BaseModule): + + def __init__(self, + in_channels: int, + out_channels: int, + hidden_channels: int, + num_layers: int, + use_res_block: bool = False, + 
use_bias: bool = True): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.num_layers = num_layers + + net = [BasicBlock(in_channels, hidden_channels, bias=use_bias)] + block = ResBlock if use_res_block else BasicBlock + for _ in range(num_layers - 2): + net.append(block(hidden_channels, hidden_channels, bias=use_bias)) + net.append(nn.Linear(hidden_channels, out_channels, bias=use_bias)) + + self.net = nn.Sequential(*net) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +@MODULES.register_module('VanillaNeRF') +class NeRFNetwork(BaseModule): + + # TODO: optim n_freq and max_freq_log2 + + def __init__( + self, + num_layers: int = 4, # 5 in paper + hidden_dim: int = 96, # 128 in paper + bg_radius: float = 1.4, + num_layers_bg: int = 2, # 3 in paper + hidden_dim_bg: int = 64, # 64 in paper + n_freq: int = 6, + max_freq_log2: int = 5, + bg_freq: int = 4, + bg_max_freq_log2: int = 3, + init_cfg=None): + + super().__init__(init_cfg=init_cfg) + + self.bound = 1 + self.num_layers = num_layers + self.hidden_dim = hidden_dim + self.encoder = FreqEncoder( + input_dim=3, max_freq_log2=max_freq_log2, N_freqs=n_freq) + self.in_dim = self.encoder.output_dim + self.sigma_net = VanillaMLP( + self.in_dim, 4, hidden_dim, num_layers, use_res_block=True) + + # background network + self.bg_radius = bg_radius + if self.bg_radius > 0: + self.num_layers_bg = num_layers_bg + self.hidden_dim_bg = hidden_dim_bg + # multires = 4 + self.encoder_bg = FreqEncoder( + input_dim=3, max_freq_log2=bg_max_freq_log2, N_freqs=bg_freq) + self.in_dim_bg = self.encoder_bg.output_dim + self.bg_net = VanillaMLP(self.in_dim_bg, 3, hidden_dim_bg, + num_layers_bg) + else: + self.encoder_bg = self.bg_net = None + + def spatial_density_bias(self, x: torch.Tensor) -> torch.Tensor: + """Spatial density bias in Equal (9) in appendix.""" + + # Spatial density bias. 
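+        # Gaussian blob: 5 * exp(-|x|^2 / (2 * 0.2^2)) is added to the raw
+        # (pre-activation) density so optimization starts from a solid blob
+        # centered at the origin rather than from empty space.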
+ d = (x**2).sum(-1) + g = 5 * torch.exp(-d / (2 * 0.2**2)) + + return g + + def forward_bg(self, d: torch.Tensor) -> torch.Tensor: + """Forward functionfor the background network.""" + if self.bg_radius == 0: + return torch.rand_like(d) + + h = self.encoder_bg(d) # [N, C] + + h = self.bg_net(h) + + # sigmoid activation for rgb + rgbs = torch.sigmoid(h) + + return rgbs + + def forward_fg(self, x: torch.Tensor) -> torch.Tensor: + """Forward function for the foreground network.""" + # x: [N, 3], in [-bound, bound] + + # sigma + h = self.encoder(x, bound=self.bound) + h = self.sigma_net(h) + + sigma = trunc_exp(h[..., 0] + self.spatial_density_bias(x)) + albedo = torch.sigmoid(h[..., 1:]) + + return sigma, albedo + + def normal(self, x: torch.Tensor) -> torch.Tensor: + """Calculate the normal with density.""" + + with torch.enable_grad(): + x.requires_grad_(True) + sigma, _ = self.forward_fg(x) + normal = -torch.autograd.grad( + torch.sum(sigma), x, create_graph=True)[0] # [N, 3] + + normal = normalize_vecs(normal, clamp_eps=1e-20) + + return normal + + # @auto_batchicy(no_batchify_args='light_d') + def forward(self, + x: torch.Tensor, + light_d: Optional[torch.Tensor] = None, + ambient_ratio: float = 1, + shading: str = 'albedo') -> dict: + """The forward function.""" + # x: [N, 3], in [-bound, bound] + # d: [N, 3], view direction, nomalized in [-1, 1] + # l: [3], plane light direction, nomalized in [-1, 1] + # ratio: scalar, ambient ratio, 1 == no shading (albedo only), + # 0 == only shading (textureless) + + # NOTE: a dirty way to only get albedo in visualization step + if shading == 'albedo' or not self.training: + sigma, albedo = self.forward_fg(x) + output = dict(sigma=sigma, color=albedo) + else: + with torch.enable_grad(): + x.requires_grad_(True) + sigma, albedo = self.forward_fg(x) + # query gradient + normal = -torch.autograd.grad( + torch.sum(sigma), x, create_graph=True)[0] # [N, 3] + + normal = normalize_vecs(normal, clamp_eps=1e-20) + + if shading == 'normal': + color = (normal + 1) / 2 + else: + # lambertian shading + lambertian = ambient_ratio + (1 - ambient_ratio) * ( + normal @ light_d).clamp(min=0) # [N,] + if shading == 'textureless': + color = lambertian.unsqueeze(-1).repeat(1, 3) + else: # 'lambertian' + color = albedo * lambertian.unsqueeze(-1) + output = dict(sigma=sigma, color=color, normal=normal) + + return output + + def density(self, x): + # x: [N, 3], in [-bound, bound] + sigma, albedo = self.forward_fg(x) + return dict(sigma=sigma, color=albedo)
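Taken together, one sampling pass wires these modules as follows. The sketch below only uses calls that appear in this patch, but the camera ranges are illustrative placeholders; the real values belong to a config file that is not part of this commit:

import torch
from mmedit.models.editors.dreamfusion import DreamFusionCamera
from mmedit.models.editors.eg3d.ray_sampler import sample_rays

# Placeholder pose/fov/radius ranges, roughly "random viewpoint on a sphere".
camera = DreamFusionCamera(
    horizontal_mean=3.141, horizontal_std=3.141,
    vertical_mean=3.141 / 2, vertical_std=3.141 / 6,
    fov_mean=60, fov_std=10,
    radius_mean=1.3, radius_std=0.3)

cam2world, pose = camera.sample_camera2world(batch_size=1, return_pose=True)
intrinsics = camera.sample_intrinsic(batch_size=1)
rays_o, rays_d = sample_rays(cam2world, intrinsics, resolution=64)
# DreamFusionRenderer consumes (rays_o, rays_d) and returns rgb/depth/weights
# (plus regularization losses in training); the rendered RGB is then scored
# against the text prompt by StableDiffusionWrapper.train_step_.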