Skip to content

Commit

Permalink
first commit for dreamfusion
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed Aug 7, 2023
1 parent 5b5f895 commit 81de7ac
Show file tree
Hide file tree
Showing 13 changed files with 1,790 additions and 6 deletions.
4 changes: 3 additions & 1 deletion mmagic/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .comp1k_dataset import AdobeComp1kDataset
from .controlnet_dataset import ControlNetDataset
from .dreambooth_dataset import DreamBoothDataset
from .dummy_dataset import DummyDataset
from .grow_scale_image_dataset import GrowScaleImgDataset
from .imagenet_dataset import ImageNet
from .mscoco_dataset import MSCoCoDataset
Expand All @@ -19,5 +20,6 @@
'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset',
'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset',
'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset',
'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset'
'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset',
'DummyDataset'
]
30 changes: 30 additions & 0 deletions mmagic/datasets/dummy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy

from torch.utils.data import Dataset

from mmedit.registry import DATASETS


@DATASETS.register_module()
class DummyDataset(Dataset):

def __init__(self, max_length=100, batch_size=None, sample_kwargs=None):
super().__init__()
self.max_length = max_length
self.sample_kwargs = sample_kwargs
self.batch_size = batch_size

def __len__(self):
return self.max_length

def __getitem__(self, index):
data_dict = dict()
input_dict = dict()
if self.batch_size is not None:
input_dict['num_batches'] = self.batch_size
if self.sample_kwargs is not None:
input_dict['sample_kwargs'] = deepcopy(self.sample_kwargs)

data_dict['inputs'] = input_dict
return data_dict
3 changes: 2 additions & 1 deletion mmagic/engine/hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dreamfusion_hook import DreamFusionTrainingHook
from .ema import ExponentialMovingAverageHook
from .iter_time_hook import IterTimerHook
from .pggan_fetch_data_hook import PGGANFetchDataHook
Expand All @@ -9,5 +10,5 @@
__all__ = [
'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'VisualizationHook',
'ExponentialMovingAverageHook', 'IterTimerHook', 'PGGANFetchDataHook',
'PickleDataHook'
'PickleDataHook', 'DreamFusionTrainingHook'
]
55 changes: 55 additions & 0 deletions mmagic/engine/hooks/dreamfusion_hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmedit.registry import HOOKS


@HOOKS.register_module()
class DreamFusionTrainingHook(Hook):

def __init__(self, albedo_iters: int):
super().__init__()
self.albedo_iters = albedo_iters

self.shading_test = 'albedo'
self.ambident_ratio_test = 1.0

def set_shading_and_ambient(self, runner, shading: str,
ambient_ratio: str) -> None:
model = runner.model
if is_model_wrapper(model):
model = model.module
renderer = model.renderer
if is_model_wrapper(renderer):
renderer = renderer.module
renderer.set_shading(shading)
renderer.set_ambient_ratio(ambient_ratio)

def after_train_iter(self, runner, batch_idx: int, *args,
**kwargs) -> None:
if batch_idx < self.albedo_iters or self.albedo_iters == -1:
shading = 'albedo'
ambient_ratio = 1.0
else:
rand = random.random()
if rand > 0.8: # NOTE: this should be 0.75 in paper
shading = 'albedo'
ambient_ratio = 1.0
elif rand > 0.4: # NOTE: this should be 0.75 * 0.5 = 0.325
shading = 'textureless'
ambient_ratio = 0.1
else:
shading = 'lambertian'
ambient_ratio = 0.1
self.set_shading_and_ambient(runner, shading, ambient_ratio)

def before_test(self, runner) -> None:
self.set_shading_and_ambient(runner, self.shading_test,
self.ambident_ratio_test)

def before_val(self, runner) -> None:
self.set_shading_and_ambient(runner, self.shading_test,
self.ambident_ratio_test)
4 changes: 2 additions & 2 deletions mmagic/models/editors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .dim import DIM
from .disco_diffusion import ClipWrapper, DiscoDiffusion
from .dreambooth import DreamBooth
from .dreamfusion import DreamFusion
from .edsr import EDSRNet
from .edvr import EDVR, EDVRNet
from .eg3d import EG3D
Expand Down Expand Up @@ -86,8 +87,7 @@
'StyleGAN1', 'StyleGAN2', 'StyleGAN3', 'BigGAN', 'DCGAN',
'ProgressiveGrowingGAN', 'SinGAN', 'AblatedDiffusionModel',
'DiscoDiffusion', 'IDLossModel', 'PESinGAN', 'MSPIEStyleGAN2',
'StyleGAN3Generator', 'InstColorization', 'NAFBaseline',
'NAFBaselineLocal', 'NAFNet', 'NAFNetLocal', 'DenoisingUnet',
'ClipWrapper', 'EG3D', 'Restormer', 'SwinIRNet', 'StableDiffusion',
'ControlStableDiffusion', 'DreamBooth', 'TextualInversion'
'ControlStableDiffusion', 'DreamBooth', 'TextualInversion', 'DreamFusion'
]
16 changes: 14 additions & 2 deletions mmagic/models/utils/tensor_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch


Expand Down Expand Up @@ -38,14 +40,24 @@ def get_unknown_tensor(trimap, unknown_value=128 / 255):
return weight


def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
def normalize_vecs(vectors: torch.Tensor,
clamp_eps: Optional[float] = None) -> torch.Tensor:
"""Normalize vector with it's lengths at the last dimension. If `vector` is
two-dimension tensor, this function is same as L2 normalization.
Args:
vector (torch.Tensor): Vectors to be normalized.
eps (float, optional): If passed, the min value will be clamped to
this value before calculate the square root. Defaults to None.
Returns:
torch.Tensor: Vectors after normalization.
"""
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
if clamp_eps is None:
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
assert clamp_eps >= 0, (
f'\'clamp_eps\' must not less than 0. But receive \'{clamp_eps}\'.')

return vectors / torch.sqrt(
torch.clamp(
torch.sum(vectors * vectors, -1, keepdim=True), min=clamp_eps))
10 changes: 10 additions & 0 deletions mmedit/models/editors/dreamfusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .camera import DreamFusionCamera
from .dreamfusion import DreamFusion
from .renderer import DreamFusionRenderer
from .stable_diffusion_wrapper import StableDiffusionWrapper

__all__ = [
'DreamFusion', 'DreamFusionRenderer', 'DreamFusionCamera',
'StableDiffusionWrapper'
]
22 changes: 22 additions & 0 deletions mmedit/models/editors/dreamfusion/activate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd, custom_fwd


class _trunc_exp(Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float)
def forward(ctx, x):
ctx.save_for_backward(x)
return torch.exp(x)

@staticmethod
@custom_bwd
def backward(ctx, g):
x = ctx.saved_tensors[0]
return g * torch.exp(x.clamp(max=15))


trunc_exp = _trunc_exp.apply

0 comments on commit 81de7ac

Please sign in to comment.