diff --git a/.gitignore b/.gitignore index b41436af8a..6aa8edcbe4 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,4 @@ batchscript-* *.zip work_dir work_dir/ +!tests/data/sd/* diff --git a/configs/_base_/datasets/pokemon_blip_xl.py b/configs/_base_/datasets/pokemon_blip_xl.py new file mode 100644 index 0000000000..e2fdb5ab65 --- /dev/null +++ b/configs/_base_/datasets/pokemon_blip_xl.py @@ -0,0 +1,23 @@ +pipeline = [ + dict( + type='LoadImageFromHuggingFaceDataset', key='img', + channel_order='rgb'), + dict(type='ResizeEdge', scale=1024), + dict(type='RandomCropXL', size=1024), + dict(type='FlipXL', keys=['img'], flip_ratio=0.5, direction='horizontal'), + dict(type='ComputeTimeIds'), + dict(type='PackInputs', keys=['merged', 'img', 'time_ids']), +] +dataset = dict( + type='HuggingFaceDataset', + dataset='lambdalabs/pokemon-blip-captions', + pipeline=pipeline) +train_dataloader = dict( + batch_size=1, + num_workers=2, + dataset=dataset, + sampler=dict(type='DefaultSampler', shuffle=True), +) + +val_dataloader = val_evaluator = None +test_dataloader = test_evaluator = None diff --git a/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py new file mode 100644 index 0000000000..455c9fb27c --- /dev/null +++ b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py @@ -0,0 +1,36 @@ +# Use DiffuserWrapper! +stable_diffusion_xl_url = 'stabilityai/stable-diffusion-xl-base-1.0' +vae_url = 'madebyollin/sdxl-vae-fp16-fix' +unet = dict( + type='UNet2DConditionModel', + subfolder='unet', + from_pretrained=stable_diffusion_xl_url) +vae = dict(type='AutoencoderKL', from_pretrained=vae_url) + +diffusion_scheduler = dict( + type='DDPMScheduler', + from_pretrained=stable_diffusion_xl_url, + subfolder='scheduler') + +model = dict( + type='StableDiffusionXL', + dtype='fp16', + with_cp=True, + unet=unet, + vae=vae, + enable_xformers=False, + text_encoder_one=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder'), + tokenizer_one=stable_diffusion_xl_url, + text_encoder_two=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder_2'), + tokenizer_two=stable_diffusion_xl_url, + scheduler=diffusion_scheduler, + test_scheduler=diffusion_scheduler, + data_preprocessor=dict(type='DataPreprocessor', data_keys=None)) diff --git a/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py new file mode 100644 index 0000000000..1bf1eff1a3 --- /dev/null +++ b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py @@ -0,0 +1,39 @@ +# Use DiffuserWrapper! 
+stable_diffusion_xl_url = 'stabilityai/stable-diffusion-xl-base-1.0' +vae_url = 'madebyollin/sdxl-vae-fp16-fix' +unet = dict( + type='UNet2DConditionModel', + subfolder='unet', + from_pretrained=stable_diffusion_xl_url) +vae = dict(type='AutoencoderKL', from_pretrained=vae_url) + +diffusion_scheduler = dict( + type='DDPMScheduler', + from_pretrained=stable_diffusion_xl_url, + subfolder='scheduler') + +lora_config = dict(rank=8, target_modules=['to_q', 'to_k', 'to_v']) + +model = dict( + type='StableDiffusionXL', + dtype='fp16', + with_cp=True, + unet=unet, + vae=vae, + enable_xformers=False, + text_encoder_one=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder'), + tokenizer_one=stable_diffusion_xl_url, + text_encoder_two=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder_2'), + tokenizer_two=stable_diffusion_xl_url, + scheduler=diffusion_scheduler, + test_scheduler=diffusion_scheduler, + data_preprocessor=dict(type='DataPreprocessor', data_keys=None), + lora_config=lora_config) diff --git a/configs/_base_/schedules/sd_10e.py b/configs/_base_/schedules/sd_10e.py new file mode 100644 index 0000000000..cb017cdf27 --- /dev/null +++ b/configs/_base_/schedules/sd_10e.py @@ -0,0 +1,11 @@ +optim_wrapper = dict( + type='AmpOptimWrapper', + dtype='float16', + optimizer=dict(type='AdamW', lr=1e-5, weight_decay=1e-2), + clip_grad=dict(max_norm=1.0), + accumulative_counts=1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=10) +val_cfg = None +test_cfg = None diff --git a/configs/_base_/schedules/sdxl_10e.py b/configs/_base_/schedules/sdxl_10e.py new file mode 100644 index 0000000000..008adf94cf --- /dev/null +++ b/configs/_base_/schedules/sdxl_10e.py @@ -0,0 +1,16 @@ +optim_wrapper = dict( + type='AmpOptimWrapper', + dtype='float16', + optimizer=dict( + type='Adafactor', + lr=1e-5, + weight_decay=1e-2, + scale_parameter=False, + relative_step=False), + clip_grad=dict(max_norm=1.0), + accumulative_counts=1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=10) +val_cfg = None +test_cfg = None diff --git a/configs/_base_/sd_default_runtime.py b/configs/_base_/sd_default_runtime.py new file mode 100644 index 0000000000..ef66eae79b --- /dev/null +++ b/configs/_base_/sd_default_runtime.py @@ -0,0 +1,46 @@ +default_scope = 'mmagic' + +# configure for default hooks +default_hooks = dict( + # record time of every iteration. + timer=dict(type='IterTimerHook'), + # print log every 100 iterations. + logger=dict(type='LoggerHook', interval=100), + # save checkpoint per 10000 iterations + checkpoint=dict( + type='CheckpointHook', + interval=1, + by_epoch=True, + max_keep_ckpts=3, + save_optimizer=True)) + +# config for environment +env_cfg = dict( + # whether to enable cudnn benchmark. + cudnn_benchmark=True, + # set multi process parameters. + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters. 
+ dist_cfg=dict(backend='nccl')) + +# set log level +log_level = 'INFO' +log_processor = dict(type='LogProcessor', by_epoch=True) + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = None + +# config for model wrapper +model_wrapper_cfg = dict( + type='MMSeparateDistributedDataParallel', + broadcast_buffers=False, + find_unused_parameters=False) + +# set visualizer +vis_backends = [dict(type='VisBackend')] +visualizer = dict(type='Visualizer', vis_backends=vis_backends) + +randomness = dict(seed=None, deterministic=False) diff --git a/configs/stable_diffusion_xl/README.md b/configs/stable_diffusion_xl/README.md index bc67a69ee9..e4c175d147 100644 --- a/configs/stable_diffusion_xl/README.md +++ b/configs/stable_diffusion_xl/README.md @@ -20,9 +20,11 @@ We present SDXL, a latent diffusion model for text-to-image synthesis. Compared ## Pretrained models -| Model | Task | Dataset | Download | -| :----------------------------------------------------------------: | :--------: | :-----: | :------: | -| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - | +| Model | Task | Dataset | Download | +| :---------------------------------------------------------------------------------: | :--------: | :--------------------------------------------------------------------------------------: | :---------: | +| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - | +| [stable_diffusion_xl_pokemon_blip](./stable-diffusion_xl_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) | +| [stable-diffusion_xl_lora_pokemon_blip](./stable-diffusion_xl_lora_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) | We use stable diffusion xl weights. This model has several weights including vae, unet and clip. 
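For reviewers who want to try the two new configs, here is a minimal sketch of launching the fine-tuning run through the standard MMEngine `Runner` (this snippet is not part of the diff, and the `work_dir` value is only illustrative):

```python
# Minimal sketch: build and run the SDXL fine-tuning config added in this PR.
# Assumes mmagic is installed and the usual MMEngine entry points are available.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py')
cfg.work_dir = './work_dirs/stable-diffusion_xl_pokemon_blip'  # illustrative path

runner = Runner.from_cfg(cfg)  # builds model, dataloader and hooks from the config
runner.train()
```

The LoRA variant is launched the same way with `stable-diffusion_xl_lora_pokemon_blip.py`; its `LoRACheckpointToSaveHook` keeps only the LoRA weights when checkpoints are written.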
diff --git a/configs/stable_diffusion_xl/metafile.yml b/configs/stable_diffusion_xl/metafile.yml index 4844b7daf6..2d3dc99e41 100644 --- a/configs/stable_diffusion_xl/metafile.yml +++ b/configs/stable_diffusion_xl/metafile.yml @@ -16,3 +16,19 @@ Models: - Dataset: '-' Metrics: {} Task: Text2Image +- Config: configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py + In Collection: Stable Diffusion XL + Name: stable-diffusion_xl_pokemon_blip + Results: + - Dataset: '[pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)' + Metrics: {} + Task: Text2Image + Weights: <> +- Config: configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py + In Collection: Stable Diffusion XL + Name: stable-diffusion_xl_lora_pokemon_blip + Results: + - Dataset: '[pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)' + Metrics: {} + Task: Text2Image + Weights: <> diff --git a/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py b/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py new file mode 100644 index 0000000000..ea51838184 --- /dev/null +++ b/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py @@ -0,0 +1,24 @@ +_base_ = [ + '../_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py', + '../_base_/datasets/pokemon_blip_xl.py', '../_base_/schedules/sd_10e.py', + '../_base_/sd_default_runtime.py' +] + +val_prompts = ['yoda pokemon'] * 4 + +model = dict(val_prompts=val_prompts) + +train_dataloader = dict(batch_size=4, num_workers=4) + +# hooks +custom_hooks = [ + dict( + type='VisualizationHook', + by_epoch=True, + interval=1, + fixed_input=True, + # visualize train dataset + vis_kwargs_list=dict(type='Data', name='fake_img'), + n_samples=1), + dict(type='LoRACheckpointToSaveHook') +] diff --git a/configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py b/configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py new file mode 100644 index 0000000000..0e4833f6b2 --- /dev/null +++ b/configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py @@ -0,0 +1,21 @@ +_base_ = [ + '../_base_/models/stable_diffusion_xl/stable_diffusion_xl.py', + '../_base_/datasets/pokemon_blip_xl.py', '../_base_/schedules/sdxl_10e.py', + '../_base_/sd_default_runtime.py' +] + +val_prompts = ['yoda pokemon'] * 4 + +model = dict(val_prompts=val_prompts) + +# hooks +custom_hooks = [ + dict( + type='VisualizationHook', + by_epoch=True, + interval=1, + fixed_input=True, + # visualize train dataset + vis_kwargs_list=dict(type='Data', name='fake_img'), + n_samples=1) +] diff --git a/mmagic/datasets/__init__.py b/mmagic/datasets/__init__.py index 80307867be..d25e6cdb21 100644 --- a/mmagic/datasets/__init__.py +++ b/mmagic/datasets/__init__.py @@ -7,6 +7,7 @@ from .controlnet_dataset import ControlNetDataset from .dreambooth_dataset import DreamBoothDataset from .grow_scale_image_dataset import GrowScaleImgDataset +from .hf_dataset import HuggingFaceDataset from .imagenet_dataset import ImageNet from .mscoco_dataset import MSCoCoDataset from .paired_image_dataset import PairedImageDataset @@ -19,5 +20,6 @@ 'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset', 'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset', 'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset', 'ViCoDataset', - 'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset' + 'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset', + 'HuggingFaceDataset' ] diff --git 
a/mmagic/datasets/hf_dataset.py b/mmagic/datasets/hf_dataset.py
new file mode 100644
index 0000000000..1488423228
--- /dev/null
+++ b/mmagic/datasets/hf_dataset.py
@@ -0,0 +1,84 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import random
+from pathlib import Path
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+from mmengine.dataset import BaseDataset
+
+from mmagic.registry import DATASETS
+
+
+@DATASETS.register_module()
+class HuggingFaceDataset(BaseDataset):
+    """Huggingface Dataset for text-to-image fine-tuning.
+
+    Args:
+        dataset (str): Dataset name for Huggingface datasets.
+        image_column (str): Image column name. Defaults to 'image'.
+        caption_column (str): Caption column name. Defaults to 'text'.
+        csv (str): Caption csv file name when loading local folder.
+            Defaults to 'metadata.csv'.
+        cache_dir (str, optional): The directory where the downloaded datasets
+            will be stored. Defaults to None.
+        pipeline (list[dict | callable]): A sequence of data transforms.
+    """
+
+    def __init__(self,
+                 dataset: str,
+                 image_column: str = 'image',
+                 caption_column: str = 'text',
+                 csv: str = 'metadata.csv',
+                 cache_dir: Optional[str] = None,
+                 pipeline: List[Union[dict, Callable]] = []):
+
+        self.dataset = dataset
+        self.image_column = image_column
+        self.caption_column = caption_column
+        self.csv = csv
+        self.cache_dir = cache_dir
+
+        super().__init__(pipeline=pipeline)
+
+    def load_data_list(self) -> list:
+        """Load the data list from a local folder or the Huggingface Hub."""
+        try:
+            from datasets import load_dataset
+        except BaseException:
+            raise ImportError(
+                'HuggingFaceDataset requires datasets, please '
+                'install it by `pip install datasets`.')
+
+        data_list = []
+
+        if Path(self.dataset).exists():
+            # load local folder
+            data_file = os.path.join(self.dataset, self.csv)
+            dataset = load_dataset(
+                'csv', data_files=data_file, cache_dir=self.cache_dir)['train']
+        else:
+            # load huggingface online
+            dataset = load_dataset(
+                self.dataset, cache_dir=self.cache_dir)['train']
+
+        for i in range(len(dataset)):
+            caption = dataset[i][self.caption_column]
+            if isinstance(caption, str):
+                pass
+            elif isinstance(caption, (list, np.ndarray)):
+                # take a random caption if there are multiple
+                caption = random.choice(caption)
+            else:
+                raise ValueError(
+                    f'Caption column `{self.caption_column}` should contain'
+                    ' either strings or lists of strings.')
+
+            img = dataset[i][self.image_column]
+            if isinstance(img, str):
+                img = os.path.join(self.dataset, img)
+
+            data_info = dict(img=img, prompt=caption)
+            data_list.append(data_info)
+
+        return data_list
diff --git a/mmagic/datasets/transforms/__init__.py b/mmagic/datasets/transforms/__init__.py
index e6b483443b..a85f348116 100644
--- a/mmagic/datasets/transforms/__init__.py
+++ b/mmagic/datasets/transforms/__init__.py
@@ -20,7 +20,8 @@ GenerateFrameIndiceswithPadding, GenerateSegmentIndices)
 from .get_masked_image import GetMaskedImage
-from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask,
+from .loading import (GetSpatialDiscountMask, LoadImageFromFile,
+                      LoadImageFromHuggingFaceDataset, LoadMask,
                       LoadPairedImageFromFile)
 from .matlab_like_resize import MATLABLikeResize
 from .normalization import Normalize, RescaleToZeroOne
@@ -28,6 +29,7 @@ RandomJPEGCompression, RandomNoise, RandomResize,
                                  RandomVideoCompression)
 from .random_down_sampling import RandomDownSampling
+from .sdxl import ComputeTimeIds, FlipXL, RandomCropXL, ResizeEdge
 from .trimap import (FormatTrimap, GenerateTrimap,
                     GenerateTrimapWithDistTransform, TransformTrimap)
 from .values import CopyValues, SetValues
@@ -49,5 +51,7 @@ 'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg',
     'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile',
     'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop',
-    'Albumentations', 'AlbuCorruptFunction', 'PairedAlbuTransForms'
+    'Albumentations', 'AlbuCorruptFunction', 'PairedAlbuTransForms',
+    'LoadImageFromHuggingFaceDataset', 'RandomCropXL', 'FlipXL',
+    'ComputeTimeIds', 'ResizeEdge'
 ]
diff --git a/mmagic/datasets/transforms/loading.py b/mmagic/datasets/transforms/loading.py
index 738a19f503..9c918f9560 100644
--- a/mmagic/datasets/transforms/loading.py
+++ b/mmagic/datasets/transforms/loading.py
@@ -6,6 +6,7 @@ import numpy as np
 from mmcv.transforms import BaseTransform
 from mmengine.fileio import get_file_backend, list_from_file
+from PIL import Image
 
 from mmagic.registry import TRANSFORMS
 from mmagic.utils import (bbox2mask, brush_stroke_mask, get_irregular_mask,
@@ -536,3 +537,89 @@ def transform(self, results: dict) -> dict:
         results[f'ori_{self.key}'] = ori_image
 
         return results
+
+
+@TRANSFORMS.register_module()
+class LoadImageFromHuggingFaceDataset(BaseTransform):
+    """Load a single image provided by a Huggingface dataset sample. The
+    value stored under ``key`` may be a PIL image or a path to a local file.
+
+    Required Keys:
+    - [KEY]
+
+    Modified Keys:
+    - [KEY]
+
+    New Keys:
+    - ori_[KEY]_shape
+    - ori_[KEY]
+
+    Args:
+        key (str): Key in results that holds the image (or its path).
+        channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'.
+            Default: 'bgr'.
+        to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is an uint8 array.
+            Defaults to False.
+        save_original_img (bool): If True, save a copy of the loaded image in
+            ``ori_[KEY]``. Defaults to False.
+    """
+
+    def __init__(
+        self,
+        key: str,
+        channel_order: str = 'bgr',
+        to_float32: bool = False,
+        save_original_img: bool = False,
+    ) -> None:
+
+        self.key = key
+        self.channel_order = channel_order
+        self.save_original_img = save_original_img
+
+        # convert
+        self.to_float32 = to_float32
+
+    def transform(self, results: dict) -> dict:
+        """Function to load the image.
+
+        Args:
+            results (dict): Result dict from :obj:``mmcv.BaseDataset``.
+        Returns:
+            dict: The dict containing the loaded image and meta information.
+ """ + + img = results[f'{self.key}'] + if type(img) == str: + img = Image.open(img) + + if self.channel_order == 'rgb': + img = img.convert('RGB') + img = np.array(img) + elif self.channel_order == 'bgr': + img = np.array(img) + img = img[..., ::-1] + + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + + if self.to_float32: + img = img.astype(np.float32) + + results[self.key] = img + results[f'ori_{self.key}_shape'] = img.shape + results[f'{self.key}_channel_order'] = self.channel_order + results[f'{self.key}_color_type'] = 'color' + if self.save_original_img: + results[f'ori_{self.key}'] = img.copy() + + return results + + def __repr__(self): + + repr_str = (f'{self.__class__.__name__}(' + f'key={self.key}, ' + f'channel_order={self.channel_order}, ' + f'to_float32={self.to_float32}, ' + f'save_original_img={self.save_original_img})') + + return repr_str diff --git a/mmagic/datasets/transforms/sdxl.py b/mmagic/datasets/transforms/sdxl.py new file mode 100644 index 0000000000..f6e9ffebc9 --- /dev/null +++ b/mmagic/datasets/transforms/sdxl.py @@ -0,0 +1,296 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Sequence, Union + +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform, to_tensor + +from mmagic.datasets.transforms.aug_shape import Flip +from mmagic.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class RandomCropXL(BaseTransform): + """Random crop the given image. Required Keys: + + - [KEYS] + + Modified Keys: + - [KEYS] + + New Keys: + - [KEYS]_crop_bbox + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted + as (size[0], size[0]) + keys (str or list[str]): The images to be cropped. + """ + + def __init__(self, size: int, keys: Union[str, List[str]] = 'img'): + if not isinstance(size, Sequence): + size = (size, size) + self.size = size + + assert keys, 'Keys should not be empty.' + if not isinstance(keys, list): + keys = [keys] + self.keys = keys + + def transform(self, results: Dict) -> Dict: + """Call function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. + """ + assert all(results[self.keys[0]].size == results[k].size + for k in self.keys) + + h, w, _ = results[self.keys[0]].shape + + if h < self.size[0] or w < self.size[1]: + raise ValueError( + f'({h}, {w}) is smaller than crop size {self.size}.') + + # randomly choose top and left coordinates for img patch + top = np.random.randint(h - self.size[0] + 1) + left = np.random.randint(w - self.size[1] + 1) + + for key in self.keys: + results[key] = results[key][top:top + self.size[0], + left:left + self.size[1], ...] + results[f'{key}_crop_bbox'] = [ + top, left, top + self.size[0], left + self.size[1] + ] + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size={self.size}, ' f'keys={self.keys})') + return repr_str + + +@TRANSFORMS.register_module() +class FlipXL(Flip): + """Flip the input data with a probability. + + The differences between FlipXL & Flip: + 1. Fix [KEYS]_crop_bbox. + + Required Keys: + - [KEYS] + - [KEYS]_crop_bbox + + Modified Keys: + - [KEYS] + - [KEYS]_crop_bbox + + Args: + keys (Union[str, List[str]]): The images to be flipped. + flip_ratio (float): The probability to flip the images. 
+            Default: 0.5.
+        direction (str): Flip images horizontally or vertically. Options are
+            "horizontal" | "vertical". Default: "horizontal".
+    """
+
+    def transform(self, results: Dict) -> Dict:
+        """Transform function.
+
+        Args:
+            results (dict): A dict containing the necessary information and
+                data for augmentation.
+
+        Returns:
+            dict: A dict containing the processed data and information.
+        """
+
+        flip = np.random.random() < self.flip_ratio
+
+        if flip:
+            for key in self.keys:
+                mmcv.imflip_(results[key], self.direction)
+                h, w, _ = results[key].shape
+                if self.direction == 'horizontal':
+                    results[f'{key}_crop_bbox'] = [
+                        results[f'{key}_crop_bbox'][0],
+                        w - results[f'{key}_crop_bbox'][3],
+                        results[f'{key}_crop_bbox'][2],
+                        w - results[f'{key}_crop_bbox'][1]
+                    ]
+                elif self.direction == 'vertical':
+                    results[f'{key}_crop_bbox'] = [
+                        h - results[f'{key}_crop_bbox'][2],
+                        results[f'{key}_crop_bbox'][1],
+                        h - results[f'{key}_crop_bbox'][0],
+                        results[f'{key}_crop_bbox'][3]
+                    ]
+
+        if 'flip_infos' not in results:
+            results['flip_infos'] = []
+
+        flip_info = dict(
+            keys=self.keys,
+            direction=self.direction,
+            ratio=self.flip_ratio,
+            flip=flip)
+        results['flip_infos'].append(flip_info)
+
+        return results
+
+
+@TRANSFORMS.register_module()
+class ComputeTimeIds(BaseTransform):
+    """Compute the additional time ids used by SDXL, i.e. the original image
+    size, the crop top-left coordinate and the target size.
+
+    Required Keys:
+
+    - [KEY]
+    - ori_[KEY]_shape
+    - [KEY]_crop_bbox
+
+    New Keys:
+    - time_ids
+
+    Args:
+        key (str): Key in results of the image to compute time ids for.
+            Defaults to `img`.
+    """
+
+    def __init__(
+        self,
+        key: str = 'img',
+    ) -> None:
+
+        self.key = key
+
+    def transform(self, results: Dict) -> Dict:
+        """
+        Args:
+            results (dict): The result dict.
+
+        Returns:
+            dict: The result dict with 'time_ids' added, i.e. the original
+                size, crop top-left coordinate and target size.
+        """
+        assert f'ori_{self.key}_shape' in results
+        assert f'{self.key}_crop_bbox' in results
+        target_size = list(results[self.key].shape)[:2]
+        time_ids = list(results[f'ori_{self.key}_shape'][:2]
+                        ) + results[f'{self.key}_crop_bbox'][:2] + target_size
+        results['time_ids'] = to_tensor(time_ids)
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += (f'(key={self.key})')
+        return repr_str
+
+
+@TRANSFORMS.register_module()
+class ResizeEdge(BaseTransform):
+    """Resize images along the specified edge.
+
+    Required Keys:
+    - [KEYS]
+
+    Modified Keys:
+    - [KEYS]
+    - [KEYS]_shape
+
+    New Keys:
+    - keep_ratio
+    - scale_factor
+    - interpolation
+
+    Args:
+        scale (int): The target size of the resized edge.
+        keys (str | list[str]): The image(s) to be resized.
+        edge (str): The edge to resize. Defaults to 'short'.
+        backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results.
+            Defaults to 'cv2'.
+        interpolation (str): Interpolation method, accepted values are
+            "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2'
+            backend, "nearest", "bilinear" for 'pillow' backend.
+            Defaults to 'bilinear'.
+    """
+
+    def __init__(self,
+                 scale: int,
+                 keys: Union[str, List[str]] = 'img',
+                 edge: str = 'short',
+                 backend: str = 'cv2',
+                 interpolation: str = 'bilinear') -> None:
+        assert keys, 'Keys should not be empty.'
+        keys = [keys] if not isinstance(keys, list) else keys
+
+        allow_edges = ['short', 'long', 'width', 'height']
+        assert edge in allow_edges, \
+            f'Invalid edge "{edge}", please specify from {allow_edges}.'
+ + self.keys = keys + self.edge = edge + self.scale = scale + self.backend = backend + self.interpolation = interpolation + + def _resize_img(self, results: dict, key: str) -> None: + """Resize images with ``results['scale']``.""" + + img, w_scale, h_scale = mmcv.imresize( + results[key], + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + results[key] = img + results[f'{key}_shape'] = img.shape[:2] + results['scale_factor'] = (w_scale, h_scale) + results['keep_ratio'] = True + results['interpolation'] = self.interpolation + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img', 'scale', 'scale_factor', + 'img_shape' keys are updated in result dict. + """ + for k in self.keys: + assert k in results, f'No {k} field in the input.' + + h, w = results[k].shape[:2] + if any([ + # conditions to resize the width + self.edge == 'short' and w < h, + self.edge == 'long' and w > h, + self.edge == 'width', + ]): + width = self.scale + height = int(self.scale * h / w) + else: + height = self.scale + width = int(self.scale * w / h) + results['scale'] = (width, height) + + self._resize_img(results, k) + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. + """ + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'edge={self.edge}, ' + repr_str += f'backend={self.backend}, ' + repr_str += f'interpolation={self.interpolation})' + return repr_str diff --git a/mmagic/engine/hooks/__init__.py b/mmagic/engine/hooks/__init__.py index 8435afa9a5..e7de12ea1c 100644 --- a/mmagic/engine/hooks/__init__.py +++ b/mmagic/engine/hooks/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. from .ema import ExponentialMovingAverageHook from .iter_time_hook import IterTimerHook +from .lora_checkpoint_to_save_hook import LoRACheckpointToSaveHook from .pggan_fetch_data_hook import PGGANFetchDataHook from .pickle_data_hook import PickleDataHook from .reduce_lr_scheduler_hook import ReduceLRSchedulerHook @@ -9,5 +10,5 @@ __all__ = [ 'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'VisualizationHook', 'ExponentialMovingAverageHook', 'IterTimerHook', 'PGGANFetchDataHook', - 'PickleDataHook' + 'PickleDataHook', 'LoRACheckpointToSaveHook' ] diff --git a/mmagic/engine/hooks/lora_checkpoint_to_save_hook.py b/mmagic/engine/hooks/lora_checkpoint_to_save_hook.py new file mode 100644 index 0000000000..8e83a352b4 --- /dev/null +++ b/mmagic/engine/hooks/lora_checkpoint_to_save_hook.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from collections import OrderedDict +from typing import List + +from mmengine.hooks import Hook +from mmengine.registry import HOOKS + + +@HOOKS.register_module() +class LoRACheckpointToSaveHook(Hook): + """Pick up LoRA weights from checkpoint. + + Args: + lora_keys (List[str]): + """ + priority = 'VERY_LOW' + + def __init__(self, lora_keys: List[str] = ['lora_mapping']): + super().__init__() + self.lora_keys = lora_keys + + def before_save_checkpoint(self, runner, checkpoint: dict) -> None: + """ + Args: + runner (Runner): The runner of the training, validation or testing + process. + checkpoint (dict): Model's checkpoint. 
+ """ + + new_ckpt = OrderedDict() + for k in checkpoint['state_dict'].keys(): + if any(key in k for key in self.lora_keys): + new_ckpt[k] = checkpoint['state_dict'][k] + + checkpoint['state_dict'] = new_ckpt diff --git a/mmagic/engine/hooks/visualization_hook.py b/mmagic/engine/hooks/visualization_hook.py index 48affeccfe..b4ff8c2b20 100644 --- a/mmagic/engine/hooks/visualization_hook.py +++ b/mmagic/engine/hooks/visualization_hook.py @@ -149,6 +149,7 @@ class VisualizationHook(Hook): If None is passed, all samples will be saved. Defaults to 100. show (bool): Whether to display the drawn image. Default to False. wait_time (float): The interval of show (s). Defaults to 0. + by_epoch (bool): Whether to visualize by epoch. Defaults to False. """ priority = 'NORMAL' @@ -182,10 +183,12 @@ def __init__(self, max_save_at_test: int = 100, test_vis_keys: Optional[Union[str, List[str]]] = None, show: bool = False, - wait_time: float = 0): + wait_time: float = 0, + by_epoch: bool = False): self._visualizer: Visualizer = Visualizer.get_current_instance() self.interval = interval + self.by_epoch = by_epoch self.vis_kwargs_list = deepcopy(vis_kwargs_list) if isinstance(self.vis_kwargs_list, dict): @@ -287,9 +290,22 @@ def after_train_iter(self, Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ - if self.every_n_inner_iters(batch_idx, self.interval): + if not self.by_epoch and self.every_n_inner_iters( + batch_idx, self.interval): self.vis_sample(runner, batch_idx, data_batch, outputs) + @master_only + def after_train_epoch(self, runner) -> None: + """Visualize samples after train iteration. + + Args: + runner (Runner): The runner of the training process. + """ + batch_idx = runner.epoch + if self.by_epoch and self.every_n_inner_iters(batch_idx, + self.interval): + self.vis_sample(runner, batch_idx, {}, None) + @torch.no_grad() def vis_sample(self, runner: Runner, diff --git a/mmagic/models/archs/lora.py b/mmagic/models/archs/lora.py index 066a13cbd4..86de9a5e3c 100644 --- a/mmagic/models/archs/lora.py +++ b/mmagic/models/archs/lora.py @@ -219,7 +219,7 @@ def forward_lora_mapping(self, x: Tensor) -> Tensor: mapping_out = self.scale * self.lora_mapping(x) return mapping_out - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, **kwargs) -> Tensor: """Forward and add LoRA mapping. Args: diff --git a/mmagic/models/data_preprocessors/data_preprocessor.py b/mmagic/models/data_preprocessors/data_preprocessor.py index bdea5f044b..a1113f3084 100644 --- a/mmagic/models/data_preprocessors/data_preprocessor.py +++ b/mmagic/models/data_preprocessors/data_preprocessor.py @@ -68,7 +68,7 @@ class DataPreprocessor(ImgDataPreprocessor): data sample. Only support with input data samples are `DataSamples`. Defaults to True. """ - _NON_IMAGE_KEYS = ['noise'] + _NON_IMAGE_KEYS = ['noise', 'time_ids'] _NON_CONCATENATE_KEYS = ['num_batches', 'mode', 'sample_kwargs', 'eq_cfg'] def __init__(self, diff --git a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index fa46f433e2..31bd03f39a 100644 --- a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -72,6 +72,9 @@ class StableDiffusionXL(BaseModel): force_zeros_for_empty_prompt (bool): Whether the negative prompt embeddings shall be forced to always be set to 0. Defaults to True. 
+ with_cp (bool): Whether or not to use gradient + checkpointing to save memory at the expense of slower backward + pass. Defaults to False. init_cfg (dict, optional): The weight initialized config for :class:`BaseModule`. """ @@ -95,6 +98,7 @@ def __init__(self, val_prompts: Union[str, List[str]] = None, finetune_text_encoder: bool = False, force_zeros_for_empty_prompt: bool = True, + with_cp: bool = False, init_cfg: Optional[dict] = None): # TODO: support `from_pretrained` for this class @@ -150,6 +154,7 @@ def __init__(self, self.val_prompts = val_prompts self.lora_config = deepcopy(lora_config) self.force_zeros_for_empty_prompt = force_zeros_for_empty_prompt + self.with_cp = with_cp self.prepare_model() self.set_lora() @@ -165,6 +170,12 @@ def prepare_model(self): Move model to target dtype and disable gradient for some models. """ + if self.with_cp: + self.unet.enable_gradient_checkpointing() + if self.finetune_text_encoder: + self.text_encoder_one.gradient_checkpointing_enable() + self.text_encoder_two.gradient_checkpointing_enable() + self.vae.requires_grad_(False) print_log('Set VAE untrainable.', 'current') self.vae.to(self.dtype) @@ -989,7 +1000,8 @@ def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict): vae = self.vae.module if hasattr(self.vae, 'module') else self.vae with optim_wrapper.optim_context(self.unet): - image = inputs + image = inputs['img'] + time_ids = inputs['time_ids'] prompt = data_samples.prompt num_batches = image.shape[0] @@ -1032,7 +1044,7 @@ def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict): pooled_prompt_embeds) = self.encode_prompt_train( input_ids_one, input_ids_two) unet_added_conditions = { - 'time_ids': data['time_ids'], + 'time_ids': time_ids, 'text_embeds': pooled_prompt_embeds } diff --git a/requirements/optional.txt b/requirements/optional.txt index e44ef1c306..03fb832199 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,5 +1,6 @@ albumentations -e git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1#egg=clip +datasets imageio-ffmpeg==0.4.4 mmdet >= 3.0.0 open_clip_torch diff --git a/requirements/tests.txt b/requirements/tests.txt index 51e76ecd50..bc0deeda3b 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -9,6 +9,7 @@ controlnet_aux # pytest-runner # yapf coverage < 7.0.0 +datasets imageio-ffmpeg==0.4.4 interrogate mmdet >= 3.0.0 diff --git a/tests/data/sd/color.jpg b/tests/data/sd/color.jpg new file mode 100644 index 0000000000..2f19ebc6c6 Binary files /dev/null and b/tests/data/sd/color.jpg differ diff --git a/tests/data/sd/metadata.csv b/tests/data/sd/metadata.csv new file mode 100644 index 0000000000..c10a815652 --- /dev/null +++ b/tests/data/sd/metadata.csv @@ -0,0 +1,2 @@ +file_name,text +color.jpg,"a dog" \ No newline at end of file diff --git a/tests/data/sd/metadata2.csv b/tests/data/sd/metadata2.csv new file mode 100644 index 0000000000..458dfb8edd --- /dev/null +++ b/tests/data/sd/metadata2.csv @@ -0,0 +1,2 @@ +file_name,text +color.jpg,"a cat" \ No newline at end of file diff --git a/tests/test_datasets/test_hf_dataset.py b/tests/test_datasets/test_hf_dataset.py new file mode 100644 index 0000000000..5f16f6054a --- /dev/null +++ b/tests/test_datasets/test_hf_dataset.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +import platform + +import pytest +from mmengine.testing import RunnerTestCase + +from mmagic.datasets import HuggingFaceDataset + + +@pytest.mark.skipif( + 'win' in platform.system().lower(), + reason='skip on windows due to limited RAM.') +class TestHFDataset(RunnerTestCase): + + def test_dataset_from_local(self): + data_root = osp.join(osp.dirname(__file__), '../../') + dataset_path = data_root + 'tests/data/sd' + dataset = HuggingFaceDataset( + dataset=dataset_path, image_column='file_name') + assert len(dataset) == 1 + + data = dataset[0] + assert data['prompt'] == 'a dog' + assert 'tests/data/sd/color.jpg' in data['img'] + + dataset = HuggingFaceDataset( + dataset='tests/data/sd', + image_column='file_name', + csv='metadata2.csv') + assert len(dataset) == 1 + + data = dataset[0] + assert data['prompt'] == 'a cat' + assert 'tests/data/sd/color.jpg' in data['img'] + + +def teardown_module(): + import gc + gc.collect() + globals().clear() + locals().clear() diff --git a/tests/test_datasets/test_paired_image_dataset.py b/tests/test_datasets/test_paired_image_dataset.py index 837c063cc0..333eb80fae 100644 --- a/tests/test_datasets/test_paired_image_dataset.py +++ b/tests/test_datasets/test_paired_image_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +from mmengine.registry import init_default_scope + from mmagic.datasets import PairedImageDataset from mmagic.utils import register_all_modules @@ -11,6 +13,7 @@ class TestPairedImageDataset(object): @classmethod def setup_class(cls): + init_default_scope('mmagic') cls.imgs_root = osp.join( osp.dirname(osp.dirname(__file__)), 'data/paired') cls.default_pipeline = [ diff --git a/tests/test_datasets/test_singan_dataset.py b/tests/test_datasets/test_singan_dataset.py index 24bfb8f744..907254b3ab 100644 --- a/tests/test_datasets/test_singan_dataset.py +++ b/tests/test_datasets/test_singan_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os.path as osp +from mmengine.registry import init_default_scope + from mmagic.datasets import SinGANDataset from mmagic.utils import register_all_modules @@ -11,6 +13,7 @@ class TestSinGANDataset(object): @classmethod def setup_class(cls): + init_default_scope('mmagic') cls.imgs_root = osp.join( osp.dirname(osp.dirname(__file__)), 'data/image/gt/baboon.png') cls.min_size = 25 diff --git a/tests/test_datasets/test_transforms/test_loading.py b/tests/test_datasets/test_transforms/test_loading.py index 6597a3f299..5bd72d0695 100644 --- a/tests/test_datasets/test_transforms/test_loading.py +++ b/tests/test_datasets/test_transforms/test_loading.py @@ -5,9 +5,12 @@ import numpy as np import pytest from mmengine.fileio.backends import LocalBackend +from PIL import Image from mmagic.datasets.transforms import (GetSpatialDiscountMask, - LoadImageFromFile, LoadMask) + LoadImageFromFile, + LoadImageFromHuggingFaceDataset, + LoadMask) def test_load_image_from_file(): @@ -298,6 +301,43 @@ def test_load_mask(self): results = loader(results) +def test_load_image_from_huggingface_dataset(): + + path_baboon = Path( + __file__).parent.parent.parent / 'data' / 'image' / 'gt' / 'baboon.png' + img_baboon = mmcv.imread(str(path_baboon), flag='color') + h, w, _ = img_baboon.shape + + # read gt image + # input path is Path object + results = dict(img=Image.open(str(path_baboon))) + config = dict(key='img') + image_loader = LoadImageFromHuggingFaceDataset(**config) + results = image_loader(results) + assert results['img'].shape == (h, w, 3) + assert results['ori_img_shape'] == (h, w, 3) + np.testing.assert_almost_equal(results['img'], img_baboon) + assert results['img_channel_order'] == 'bgr' + assert results['img_color_type'] == 'color' + + # test save_original_img + results = dict(img=Image.open(str(path_baboon))) + config = dict(key='img', save_original_img=True, channel_order='rgb') + image_loader = LoadImageFromHuggingFaceDataset(**config) + results = image_loader(results) + assert results['img'].shape == (h, w, 3) + assert results['ori_img_shape'] == (h, w, 3) + np.testing.assert_almost_equal(results['ori_img'], results['img']) + np.testing.assert_almost_equal(results['img'], img_baboon[..., ::-1]) + assert id(results['ori_img']) != id(results['img']) + assert results['img_channel_order'] == 'rgb' + assert results['img_color_type'] == 'color' + + assert image_loader.__repr__() == ( + image_loader.__class__.__name__ + '(key=img, channel_order=rgb,' + ' to_float32=False, save_original_img=True)') + + def teardown_module(): import gc gc.collect() diff --git a/tests/test_datasets/test_transforms/test_sdxl.py b/tests/test_datasets/test_transforms/test_sdxl.py new file mode 100644 index 0000000000..e2ecb3c92d --- /dev/null +++ b/tests/test_datasets/test_transforms/test_sdxl.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +from unittest import TestCase + +import numpy as np +import torch +from PIL import Image + +from mmagic.registry import TRANSFORMS + + +class TestComputeTimeIds(TestCase): + + def test_register(self): + self.assertIn('ComputeTimeIds', TRANSFORMS) + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + img = Image.open(img_path) + data = { + 'img': np.array(img), + 'ori_img_shape': [32, 32], + 'img_crop_bbox': [0, 0, 32, 32] + } + + # test transform + trans = TRANSFORMS.build(dict(type='ComputeTimeIds')) + data = trans(data) + self.assertIsInstance(data['time_ids'], torch.Tensor) + self.assertListEqual( + list(data['time_ids'].numpy()), + [32, 32, 0, 0, img.height, img.width]) + + assert trans.__repr__() == (trans.__class__.__name__ + '(key=img)') + + +class TestRandomCropXL(TestCase): + crop_size = 32 + + def test_register(self): + self.assertIn('RandomCropXL', TRANSFORMS) + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = {'img': np.array(Image.open(img_path))} + + # test transform + trans = TRANSFORMS.build( + dict(type='RandomCropXL', size=self.crop_size)) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + assert data['img'].shape[0] == data['img'].shape[1] == self.crop_size + upper, left, lower, right = data['img_crop_bbox'] + assert lower == upper + self.crop_size + assert right == left + self.crop_size + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).crop((left, upper, right, lower)))) + + assert trans.__repr__() == ( + trans.__class__.__name__ + + f'(size={(self.crop_size, self.crop_size)},' + f" keys=['img'])") + + def test_transform_multiple_keys(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = { + 'img': np.array(Image.open(img_path)), + 'condition_img': np.array(Image.open(img_path)) + } + + # test transform + trans = TRANSFORMS.build( + dict( + type='RandomCropXL', + size=self.crop_size, + keys=['img', 'condition_img'])) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + assert data['img'].shape[0] == data['img'].shape[1] == self.crop_size + upper, left, lower, right = data['img_crop_bbox'] + assert lower == upper + self.crop_size + assert right == left + self.crop_size + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).crop((left, upper, right, lower)))) + np.equal(np.array(data['img']), np.array(data['condition_img'])) + + +class TestFlipXL(TestCase): + + def test_register(self): + self.assertIn('FlipXL', TRANSFORMS) + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = { + 'img': np.array(Image.open(img_path)), + 'img_crop_bbox': [0, 0, 10, 10] + } + + # test transform + trans = TRANSFORMS.build( + dict(type='FlipXL', flip_ratio=1., keys=['img'])) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + self.assertListEqual( + data['img_crop_bbox'], + [0, data['img'].shape[1] - 10, 10, data['img'].shape[1] - 0]) + + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).transpose(Image.FLIP_LEFT_RIGHT))) + + assert trans.__repr__() == ( + trans.__class__.__name__ + + "(keys=['img'], flip_ratio=1.0, direction=horizontal)") + + # test transform p=0.0 + data = { + 'img': np.array(Image.open(img_path)), + 'img_crop_bbox': [0, 0, 10, 10] + } + trans = 
TRANSFORMS.build( + dict(type='FlipXL', flip_ratio=0., keys='img')) + data = trans(data) + self.assertIn('img_crop_bbox', data) + self.assertListEqual(data['img_crop_bbox'], [0, 0, 10, 10]) + + np.equal(np.array(data['img']), np.array(Image.open(img_path))) + + def test_transform_multiple_keys(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = { + 'img': np.array(Image.open(img_path)), + 'condition_img': np.array(Image.open(img_path)), + 'img_crop_bbox': [0, 0, 10, 10], + 'condition_img_crop_bbox': [0, 0, 10, 10] + } + + # test transform + trans = TRANSFORMS.build( + dict(type='FlipXL', flip_ratio=1., keys=['img', 'condition_img'])) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + self.assertListEqual( + data['img_crop_bbox'], + [0, data['img'].shape[1] - 10, 10, data['img'].shape[1] - 0]) + + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).transpose(Image.FLIP_LEFT_RIGHT))) + np.equal(np.array(data['img']), np.array(data['condition_img'])) + + +class TestResizeEdge(TestCase): + + def test_transform(self): + results = dict(img=np.random.randint(0, 256, (128, 256, 3), np.uint8)) + + # test resize short edge by default. + cfg = dict(type='ResizeEdge', scale=224) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 448, 3)) + + # test resize long edge. + cfg = dict(type='ResizeEdge', scale=224, edge='long') + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (112, 224, 3)) + + # test resize width. + cfg = dict(type='ResizeEdge', scale=224, edge='width') + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (112, 224, 3)) + + # test resize height. + cfg = dict(type='ResizeEdge', scale=224, edge='height') + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 448, 3)) + + # test invalid edge + with self.assertRaisesRegex(AssertionError, 'Invalid edge "hi"'): + cfg = dict(type='ResizeEdge', scale=224, edge='hi') + TRANSFORMS.build(cfg) + + def test_repr(self): + cfg = dict(type='ResizeEdge', scale=224, edge='height') + transform = TRANSFORMS.build(cfg) + self.assertEqual( + repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, ' + 'interpolation=bilinear)') + + +def teardown_module(): + import gc + gc.collect() + globals().clear() + locals().clear() diff --git a/tests/test_datasets/test_unpaired_image_dataset.py b/tests/test_datasets/test_unpaired_image_dataset.py index 4b7ff1bf82..694069d97f 100644 --- a/tests/test_datasets/test_unpaired_image_dataset.py +++ b/tests/test_datasets/test_unpaired_image_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os.path as osp +from mmengine.registry import init_default_scope + from mmagic.datasets import UnpairedImageDataset from mmagic.utils import register_all_modules @@ -11,6 +13,7 @@ class TestUnpairedImageDataset(object): @classmethod def setup_class(cls): + init_default_scope('mmagic') cls.imgs_root = osp.join( osp.dirname(osp.dirname(__file__)), 'data/unpaired') cls.default_pipeline = [ diff --git a/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py new file mode 100644 index 0000000000..3f18088848 --- /dev/null +++ b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py @@ -0,0 +1,110 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import platform + +import pytest +import torch +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from mmengine import Config +from mmengine.registry import MODELS +from mmengine.testing import RunnerTestCase + +from mmagic.engine.hooks import LoRACheckpointToSaveHook +from mmagic.models.archs import DiffusersWrapper +from mmagic.models.data_preprocessors import DataPreprocessor +from mmagic.models.editors import ClipWrapper, StableDiffusionXL +from mmagic.models.editors.stable_diffusion import AutoencoderKL + +stable_diffusion_xl_tiny_url = 'hf-internal-testing/tiny-stable-diffusion-xl-pipe' # noqa +lora_config = dict(target_modules=['to_q', 'to_k', 'to_v']) +diffusion_scheduler = dict( + type='EditDDIMScheduler', + variance_type='learned_range', + beta_end=0.012, + beta_schedule='scaled_linear', + beta_start=0.00085, + num_train_timesteps=1000, + set_alpha_to_one=False, + clip_sample=False) +model = dict( + type='StableDiffusionXL', + unet=dict( + type='UNet2DConditionModel', + subfolder='unet', + from_pretrained=stable_diffusion_xl_tiny_url), + vae=dict(type='AutoencoderKL', sample_size=64), + text_encoder_one=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_tiny_url, + subfolder='text_encoder'), + tokenizer_one=stable_diffusion_xl_tiny_url, + text_encoder_two=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_tiny_url, + subfolder='text_encoder_2'), + tokenizer_two=stable_diffusion_xl_tiny_url, + scheduler=diffusion_scheduler, + val_prompts=['a dog', 'a dog'], + lora_config=lora_config) + + +@pytest.mark.skipif( + 'win' in platform.system().lower(), + reason='skip on windows due to limited RAM.') +@pytest.mark.skipif( + torch.__version__ < '1.9.0', + reason='skip on torch<1.9 due to unsupported torch.concat') +class TestLoRACheckpointToSaveHook(RunnerTestCase): + + def setUp(self) -> None: + MODELS.register_module( + name='StableDiffusionXL', module=StableDiffusionXL) + MODELS.register_module(name='ClipWrapper', module=ClipWrapper) + + def gen_wrapped_cls(module, module_name): + return type( + module_name, (DiffusersWrapper, ), + dict( + _module_cls=module, + _module_name=module_name, + __module__=__name__)) + + wrapped_module = gen_wrapped_cls(UNet2DConditionModel, + 'UNet2DConditionModel') + MODELS.register_module( + name='UNet2DConditionModel', module=wrapped_module, force=True) + MODELS.register_module(name='AutoencoderKL', module=AutoencoderKL) + MODELS.register_module( + name='DataPreprocessor', module=DataPreprocessor) + return super().setUp() + + def tearDown(self): + MODELS.module_dict.pop('StableDiffusionXL') + MODELS.module_dict.pop('ClipWrapper') + MODELS.module_dict.pop('UNet2DConditionModel') + 
MODELS.module_dict.pop('AutoencoderKL') + MODELS.module_dict.pop('DataPreprocessor') + return super().tearDown() + + def test_init(self): + LoRACheckpointToSaveHook() + + def test_before_save_checkpoint(self): + cfg = copy.deepcopy(self.epoch_based_cfg) + runner = self.build_runner(cfg) + runner.model = MODELS.build(Config(model)) + checkpoint = dict(state_dict=MODELS.build(Config(model)).state_dict()) + hook = LoRACheckpointToSaveHook() + hook.before_save_checkpoint(runner, checkpoint) + + for key in checkpoint['state_dict'].keys(): + assert 'lora_mapping' in key + + +def teardown_module(): + import gc + gc.collect() + globals().clear() + locals().clear() diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index 267af16bef..6aaae91ebd 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -7,6 +7,7 @@ import numpy as np import torch from mmengine import MessageHub +from mmengine.registry import init_default_scope from mmengine.testing import assert_allclose from mmengine.visualization import Visualizer from torch.utils.data.dataset import Dataset @@ -25,6 +26,7 @@ class TestBasicVisualizationHook(TestCase): def setUp(self) -> None: + init_default_scope('mmagic') input = torch.rand(2, 3, 32, 32) data_sample = DataSample( path_rgb='rgb.png', @@ -89,6 +91,7 @@ class TestVisualizationHook(TestCase): MessageHub.get_instance('test-gen-visualizer') def test_init(self): + init_default_scope('mmagic') hook = VisualizationHook( interval=10, vis_kwargs_list=dict(type='Noise')) self.assertEqual(hook.interval, 10) @@ -119,6 +122,7 @@ def test_vis_sample_with_gan_alias(self): runner.train_dataloader = MagicMock() runner.train_dataloader.batch_size = 10 + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, vis_kwargs_list=dict(type='GAN'), n_samples=9) mock_visualuzer = MagicMock() @@ -209,6 +213,7 @@ def __getitem__(self, index): runner.val_loop = MagicMock() runner.val_loop.dataloader = val_dataloader + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, vis_kwargs_list=[ @@ -341,6 +346,7 @@ def __getitem__(self, index): def test_after_val_iter(self): model = MagicMock() + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, n_samples=2, vis_kwargs_list=dict(type='GAN')) mock_visualuzer = MagicMock() @@ -366,6 +372,7 @@ def test_after_train_iter(self): runner.train_dataloader = MagicMock() runner.train_dataloader.batch_size = 10 + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=2, vis_kwargs_list=dict(type='GAN'), n_samples=9) mock_visualuzer = MagicMock() @@ -496,6 +503,7 @@ def train(self): runner.train_dataloader = MagicMock() runner.train_dataloader.batch_size = 2 + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=2, vis_kwargs_list=dict(type='GAN'), n_samples=3, n_row=8) mock_visualuzer = MagicMock() @@ -515,6 +523,7 @@ def train(self): def test_after_test_iter(self): model = MagicMock() + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, n_samples=2, @@ -593,6 +602,7 @@ def test_after_test_iter(self): hook.after_test_iter(runner, 42, [], outputs) # test max save time + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, n_samples=2, @@ -617,6 +627,28 @@ def test_after_test_iter(self): hook.after_test_iter(runner, 0, [], outputs) assert mock_visualuzer.add_datasample.call_count == 
3 + def test_after_train_epoch(self): + model = MagicMock() + _ = Visualizer.get_instance('name1') + hook = VisualizationHook( + interval=1, + n_samples=1, + vis_kwargs_list=dict(type='Data', name='fake_img'), + by_epoch=True, + fixed_input=True) + hook.inputs_buffer = {'Data': ['dummy']} + mock_visualuzer = MagicMock() + mock_visualuzer.add_datasample = MagicMock() + hook._visualizer = mock_visualuzer + + runner = MagicMock() + runner.train_dataloader.batch_size = 1 + runner.model = model + runner.epoch = 5 + + hook.after_train_epoch(runner) + self.assertEqual(mock_visualuzer.add_datasample.call_count, 1) + def teardown_module(): import gc diff --git a/tests/test_models/test_data_preprocessors/test_data_preprocessor.py b/tests/test_models/test_data_preprocessors/test_data_preprocessor.py index 60d7dd20d9..bf97c4ccb0 100644 --- a/tests/test_models/test_data_preprocessors/test_data_preprocessor.py +++ b/tests/test_models/test_data_preprocessors/test_data_preprocessor.py @@ -476,6 +476,7 @@ def test_preprocess_dict_inputs(self): img_A=[torch.randint(0, 255, (3, 5, 5)) for _ in range(3)], img_B=[torch.randint(0, 255, (3, 5, 5)) for _ in range(3)], noise=[torch.randn(16) for _ in range(3)], + time_ids=[torch.randn(6) for _ in range(3)], num_batches=3, tensor=torch.randint(0, 255, (3, 4, 5, 5)), mode=['ema', 'ema', 'ema'], @@ -490,6 +491,7 @@ def test_preprocess_dict_inputs(self): target_B = (torch.stack(inputs['img_B']) - 127.5) / 127.5 target_B = target_B[:, [2, 1, 0]] target_noise = torch.stack(inputs['noise']) + target_time_ids = torch.stack(inputs['time_ids']) # no metainfo, parse as BGR, do conversion target_tensor = ((inputs['tensor'] - 127.5) / 127.5)[:, [2, 1, 0, 3]] @@ -497,6 +499,7 @@ def test_preprocess_dict_inputs(self): assert_allclose(outputs['img_A'], target_A) assert_allclose(outputs['img_B'], target_B) assert_allclose(outputs['noise'], target_noise) + assert_allclose(outputs['time_ids'], target_time_ids) assert_allclose(outputs['tensor'], target_tensor) self.assertEqual(outputs['mode'], 'ema') self.assertEqual(outputs['num_batches'], 3) @@ -769,21 +772,35 @@ def test_forward(self): img2 = torch.randn(3, 4, 4) noise1 = torch.randn(3, 4, 4) noise2 = torch.randn(3, 4, 4) + time_ids1 = torch.randn(6) + time_ids2 = torch.randn(6) target_input1 = (img1[[2, 1, 0], ...].clone() - 127.5) / 127.5 target_input2 = (img2[[2, 1, 0], ...].clone() - 127.5) / 127.5 data = dict(inputs=[ - dict(noise=noise1, img=img1, num_batches=2, mode='ema'), - dict(noise=noise2, img=img2, num_batches=2, mode='ema'), + dict( + noise=noise1, + img=img1, + num_batches=2, + mode='ema', + time_ids=time_ids1), + dict( + noise=noise2, + img=img2, + num_batches=2, + mode='ema', + time_ids=time_ids2), ]) data_preprocessor = DataPreprocessor(output_channel_order='RGB') data = data_preprocessor(data) self.assertEqual( list(data['inputs'].keys()), - ['noise', 'img', 'num_batches', 'mode']) + ['noise', 'img', 'num_batches', 'mode', 'time_ids']) assert_allclose(data['inputs']['noise'][0], noise1) assert_allclose(data['inputs']['noise'][1], noise2) + assert_allclose(data['inputs']['time_ids'][0], time_ids1) + assert_allclose(data['inputs']['time_ids'][1], time_ids2) assert_allclose(data['inputs']['img'][0], target_input1) assert_allclose(data['inputs']['img'][1], target_input2) self.assertEqual(data['inputs']['num_batches'], 2) diff --git a/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py index 
64c6345ebe..fefe6cc8e8 100644 --- a/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py @@ -85,8 +85,10 @@ def test_stable_diffusion_xl_step(): # train step data = dict( - inputs=torch.ones([1, 3, 64, 64]), - time_ids=torch.zeros((1, 6)), + inputs={ + 'img': torch.ones([1, 3, 64, 64]), + 'time_ids': torch.zeros((1, 6)) + }, data_samples=[ DataSample(prompt='an insect robot preparing a delicious meal') ])
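To summarize the data flow this diff introduces, below is a small sketch (not part of the change set) of the micro-conditioning that `ComputeTimeIds` assembles and that `train_step` now reads from `inputs['time_ids']`; the image sizes are made-up example values:

```python
# Sketch of the SDXL micro-conditioning built by ComputeTimeIds:
# (original_size, crop_top_left, target_size) flattened into six integers.
import numpy as np
from mmcv.transforms import to_tensor  # same helper imported in sdxl.py

results = {
    'img': np.zeros((1024, 1024, 3), dtype=np.uint8),  # after ResizeEdge + RandomCropXL
    'ori_img_shape': (1365, 2048, 3),  # original shape recorded by the loading transform
    'img_crop_bbox': [0, 512, 1024, 1536],  # top, left, bottom, right from RandomCropXL
}

target_size = list(results['img'].shape)[:2]
time_ids = list(results['ori_img_shape'][:2]) + results['img_crop_bbox'][:2] + target_size
print(to_tensor(time_ids))  # tensor([1365, 2048, 0, 512, 1024, 1024])
```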