From 8851b95ea1e1698ed18b93cecb71e27d7a0370d2 Mon Sep 17 00:00:00 2001 From: okotaku Date: Tue, 19 Sep 2023 16:10:23 +0900 Subject: [PATCH 01/14] fix test --- .../editors/stable_diffusion_xl/stable_diffusion_xl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index fa46f433e..892b2f0d2 100644 --- a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -643,7 +643,7 @@ def _encode_prompt(self, prompt, prompt_2, device, num_images_per_prompt, prompt_embeds_list.append(text_embeddings) - text_embeddings = torch.concat(prompt_embeds_list, dim=-1) + text_embeddings = torch.cat(prompt_embeds_list, dim=-1) # duplicate text embeddings for each generation per prompt, bs_embed, seq_len, _ = text_embeddings.shape @@ -702,7 +702,7 @@ def _encode_prompt(self, prompt, prompt_2, device, num_images_per_prompt, negative_prompt_embeds_list.append(negative_prompt_embeds) - negative_prompt_embeds = torch.concat( + negative_prompt_embeds = torch.cat( negative_prompt_embeds_list, dim=-1) bs_embed, seq_len, _ = text_embeddings.shape @@ -967,7 +967,7 @@ def encode_prompt_train(self, text_one, text_two): prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + prompt_embeds = torch.cat(prompt_embeds_list, dim=-1) pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) return prompt_embeds, pooled_prompt_embeds From 6381d35b0baac7d0f6924522030fe42c6cfff291 Mon Sep 17 00:00:00 2001 From: okotaku Date: Tue, 19 Sep 2023 16:43:09 +0900 Subject: [PATCH 02/14] fix test --- .../editors/stable_diffusion_xl/stable_diffusion_xl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index 892b2f0d2..fa46f433e 100644 --- a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -643,7 +643,7 @@ def _encode_prompt(self, prompt, prompt_2, device, num_images_per_prompt, prompt_embeds_list.append(text_embeddings) - text_embeddings = torch.cat(prompt_embeds_list, dim=-1) + text_embeddings = torch.concat(prompt_embeds_list, dim=-1) # duplicate text embeddings for each generation per prompt, bs_embed, seq_len, _ = text_embeddings.shape @@ -702,7 +702,7 @@ def _encode_prompt(self, prompt, prompt_2, device, num_images_per_prompt, negative_prompt_embeds_list.append(negative_prompt_embeds) - negative_prompt_embeds = torch.cat( + negative_prompt_embeds = torch.concat( negative_prompt_embeds_list, dim=-1) bs_embed, seq_len, _ = text_embeddings.shape @@ -967,7 +967,7 @@ def encode_prompt_train(self, text_one, text_two): prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) prompt_embeds_list.append(prompt_embeds) - prompt_embeds = torch.cat(prompt_embeds_list, dim=-1) + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) return prompt_embeds, pooled_prompt_embeds From 36a79974108e0a5b400a8811507571af5ae82be6 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 21 Sep 2023 15:00:38 +0900 Subject: [PATCH 03/14] support SDXL training --- .gitignore | 1 + configs/_base_/datasets/pokemon_blip_xl.py | 
23 ++ .../stable_diffusion_xl.py | 36 +++ configs/_base_/schedules/sdxl_10e.py | 16 + configs/_base_/sd_default_runtime.py | 46 +++ .../stable-diffusion_xl_pokemon_blip.py | 21 ++ mmagic/datasets/__init__.py | 4 +- mmagic/datasets/hf_dataset.py | 84 +++++ mmagic/datasets/transforms/__init__.py | 8 +- mmagic/datasets/transforms/loading.py | 87 +++++ mmagic/datasets/transforms/sdxl.py | 296 ++++++++++++++++++ mmagic/engine/hooks/visualization_hook.py | 20 +- .../data_preprocessors/data_preprocessor.py | 2 +- .../stable_diffusion_xl.py | 16 +- requirements/optional.txt | 1 + requirements/tests.txt | 1 + tests/data/sd/color.jpg | Bin 0 -> 39779 bytes tests/data/sd/metadata.csv | 2 + tests/data/sd/metadata2.csv | 2 + tests/test_datasets/test_hf_dataset.py | 26 ++ .../test_transforms/test_loading.py | 42 ++- .../test_transforms/test_sdxl.py | 206 ++++++++++++ .../test_hooks/test_visualization_hook.py | 17 + .../test_data_preprocessor.py | 23 +- .../test_stable_diffusion_xl.py | 6 +- 25 files changed, 972 insertions(+), 14 deletions(-) create mode 100644 configs/_base_/datasets/pokemon_blip_xl.py create mode 100644 configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py create mode 100644 configs/_base_/schedules/sdxl_10e.py create mode 100644 configs/_base_/sd_default_runtime.py create mode 100644 configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py create mode 100644 mmagic/datasets/hf_dataset.py create mode 100644 mmagic/datasets/transforms/sdxl.py create mode 100644 tests/data/sd/color.jpg create mode 100644 tests/data/sd/metadata.csv create mode 100644 tests/data/sd/metadata2.csv create mode 100644 tests/test_datasets/test_hf_dataset.py create mode 100644 tests/test_datasets/test_transforms/test_sdxl.py diff --git a/.gitignore b/.gitignore index b41436af8..6aa8edcbe 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,4 @@ batchscript-* *.zip work_dir work_dir/ +!tests/data/sd/* diff --git a/configs/_base_/datasets/pokemon_blip_xl.py b/configs/_base_/datasets/pokemon_blip_xl.py new file mode 100644 index 000000000..e2fdb5ab6 --- /dev/null +++ b/configs/_base_/datasets/pokemon_blip_xl.py @@ -0,0 +1,23 @@ +pipeline = [ + dict( + type='LoadImageFromHuggingFaceDataset', key='img', + channel_order='rgb'), + dict(type='ResizeEdge', scale=1024), + dict(type='RandomCropXL', size=1024), + dict(type='FlipXL', keys=['img'], flip_ratio=0.5, direction='horizontal'), + dict(type='ComputeTimeIds'), + dict(type='PackInputs', keys=['merged', 'img', 'time_ids']), +] +dataset = dict( + type='HuggingFaceDataset', + dataset='lambdalabs/pokemon-blip-captions', + pipeline=pipeline) +train_dataloader = dict( + batch_size=1, + num_workers=2, + dataset=dataset, + sampler=dict(type='DefaultSampler', shuffle=True), +) + +val_dataloader = val_evaluator = None +test_dataloader = test_evaluator = None diff --git a/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py new file mode 100644 index 000000000..455c9fb27 --- /dev/null +++ b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl.py @@ -0,0 +1,36 @@ +# Use DiffuserWrapper! 
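+# The ``type`` names below (UNet2DConditionModel, AutoencoderKL, DDPMScheduler)
+# are diffusers classes that the wrapper loads with ``from_pretrained``.
+# The fp16-safe VAE ``madebyollin/sdxl-vae-fp16-fix`` is used in place of the
+# original SDXL VAE, which tends to overflow when run in half precision
+# (this config sets ``dtype='fp16'``).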
+stable_diffusion_xl_url = 'stabilityai/stable-diffusion-xl-base-1.0' +vae_url = 'madebyollin/sdxl-vae-fp16-fix' +unet = dict( + type='UNet2DConditionModel', + subfolder='unet', + from_pretrained=stable_diffusion_xl_url) +vae = dict(type='AutoencoderKL', from_pretrained=vae_url) + +diffusion_scheduler = dict( + type='DDPMScheduler', + from_pretrained=stable_diffusion_xl_url, + subfolder='scheduler') + +model = dict( + type='StableDiffusionXL', + dtype='fp16', + with_cp=True, + unet=unet, + vae=vae, + enable_xformers=False, + text_encoder_one=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder'), + tokenizer_one=stable_diffusion_xl_url, + text_encoder_two=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder_2'), + tokenizer_two=stable_diffusion_xl_url, + scheduler=diffusion_scheduler, + test_scheduler=diffusion_scheduler, + data_preprocessor=dict(type='DataPreprocessor', data_keys=None)) diff --git a/configs/_base_/schedules/sdxl_10e.py b/configs/_base_/schedules/sdxl_10e.py new file mode 100644 index 000000000..008adf94c --- /dev/null +++ b/configs/_base_/schedules/sdxl_10e.py @@ -0,0 +1,16 @@ +optim_wrapper = dict( + type='AmpOptimWrapper', + dtype='float16', + optimizer=dict( + type='Adafactor', + lr=1e-5, + weight_decay=1e-2, + scale_parameter=False, + relative_step=False), + clip_grad=dict(max_norm=1.0), + accumulative_counts=1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=10) +val_cfg = None +test_cfg = None diff --git a/configs/_base_/sd_default_runtime.py b/configs/_base_/sd_default_runtime.py new file mode 100644 index 000000000..6d05f8caa --- /dev/null +++ b/configs/_base_/sd_default_runtime.py @@ -0,0 +1,46 @@ +default_scope = 'mmagic' + +# configure for default hooks +default_hooks = dict( + # record time of every iteration. + timer=dict(type='IterTimerHook'), + # print log every 100 iterations. + logger=dict(type='LoggerHook', interval=100), + # save checkpoint per 10000 iterations + checkpoint=dict( + type='CheckpointHook', + interval=1, + by_epoch=True, + max_keep_ckpts=3, + save_optimizer=False)) + +# config for environment +env_cfg = dict( + # whether to enable cudnn benchmark. + cudnn_benchmark=True, + # set multi process parameters. + mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), + # set distributed parameters. 
+ dist_cfg=dict(backend='nccl')) + +# set log level +log_level = 'INFO' +log_processor = dict(type='LogProcessor', by_epoch=True) + +# load from which checkpoint +load_from = None + +# whether to resume training from the loaded checkpoint +resume = None + +# config for model wrapper +model_wrapper_cfg = dict( + type='MMSeparateDistributedDataParallel', + broadcast_buffers=False, + find_unused_parameters=False) + +# set visualizer +vis_backends = [dict(type='VisBackend')] +visualizer = dict(type='Visualizer', vis_backends=vis_backends) + +randomness = dict(seed=None, deterministic=False) diff --git a/configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py b/configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py new file mode 100644 index 000000000..0e4833f6b --- /dev/null +++ b/configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py @@ -0,0 +1,21 @@ +_base_ = [ + '../_base_/models/stable_diffusion_xl/stable_diffusion_xl.py', + '../_base_/datasets/pokemon_blip_xl.py', '../_base_/schedules/sdxl_10e.py', + '../_base_/sd_default_runtime.py' +] + +val_prompts = ['yoda pokemon'] * 4 + +model = dict(val_prompts=val_prompts) + +# hooks +custom_hooks = [ + dict( + type='VisualizationHook', + by_epoch=True, + interval=1, + fixed_input=True, + # visualize train dataset + vis_kwargs_list=dict(type='Data', name='fake_img'), + n_samples=1) +] diff --git a/mmagic/datasets/__init__.py b/mmagic/datasets/__init__.py index 80307867b..d25e6cdb2 100644 --- a/mmagic/datasets/__init__.py +++ b/mmagic/datasets/__init__.py @@ -7,6 +7,7 @@ from .controlnet_dataset import ControlNetDataset from .dreambooth_dataset import DreamBoothDataset from .grow_scale_image_dataset import GrowScaleImgDataset +from .hf_dataset import HuggingFaceDataset from .imagenet_dataset import ImageNet from .mscoco_dataset import MSCoCoDataset from .paired_image_dataset import PairedImageDataset @@ -19,5 +20,6 @@ 'BasicConditionalDataset', 'UnpairedImageDataset', 'PairedImageDataset', 'ImageNet', 'CIFAR10', 'GrowScaleImgDataset', 'SinGANDataset', 'MSCoCoDataset', 'ControlNetDataset', 'DreamBoothDataset', 'ViCoDataset', - 'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset' + 'ControlNetDataset', 'SDFinetuneDataset', 'TextualInversionDataset', + 'HuggingFaceDataset' ] diff --git a/mmagic/datasets/hf_dataset.py b/mmagic/datasets/hf_dataset.py new file mode 100644 index 000000000..148842322 --- /dev/null +++ b/mmagic/datasets/hf_dataset.py @@ -0,0 +1,84 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import random +from pathlib import Path +from typing import Callable, List, Optional, Union + +import numpy as np +from mmengine.dataset import BaseDataset + +from mmagic.registry import DATASETS + + +@DATASETS.register_module() +class HuggingFaceDataset(BaseDataset): + """Huggingface Dataset for DreamBooth. + + Args: + dataset (str): Dataset name for Huggingface datasets. + image_column (str): Image column name. Defaults to 'image'. + caption_column (str): Caption column name. Defaults to 'text'. + csv (str): Caption csv file name when loading local folder. + Defaults to 'metadata.csv'. + cache_dir (str, optional): The directory where the downloaded datasets + will be stored.Defaults to None. + pipeline (list[dict | callable]): A sequence of data transforms. 
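+
+    Examples:
+        Load either a dataset hosted on the HuggingFace Hub or a local
+        folder that contains images and a caption csv:
+
+        >>> dataset = HuggingFaceDataset(
+        ...     dataset='lambdalabs/pokemon-blip-captions')
+        >>> dataset = HuggingFaceDataset(
+        ...     dataset='tests/data/sd', image_column='file_name')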
+ """ + + def __init__(self, + dataset: str, + image_column: str = 'image', + caption_column: str = 'text', + csv: str = 'metadata.csv', + cache_dir: Optional[str] = None, + pipeline: List[Union[dict, Callable]] = []): + + self.dataset = dataset + self.image_column = image_column + self.caption_column = caption_column + self.csv = csv + self.cache_dir = cache_dir + + super().__init__(pipeline=pipeline) + + def load_data_list(self) -> list: + """Load data list from concept_dir and class_dir.""" + try: + from datasets import load_dataset + except BaseException: + raise ImportError( + 'HuggingFaceDreamBoothDataset requires datasets, please ' + 'install it by `pip install datasets`.') + + data_list = [] + + if Path(self.dataset).exists(): + # load local folder + data_file = os.path.join(self.dataset, self.csv) + dataset = load_dataset( + 'csv', data_files=data_file, cache_dir=self.cache_dir)['train'] + else: + # load huggingface online + dataset = load_dataset( + self.dataset, cache_dir=self.cache_dir)['train'] + + for i in range(len(dataset)): + caption = dataset[i][self.caption_column] + if isinstance(caption, str): + pass + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + caption = random.choice(caption) + else: + raise ValueError( + f'Caption column `{self.caption_column}` should contain' + ' either strings or lists of strings.') + + img = dataset[i][self.image_column] + if type(img) == str: + img = os.path.join(self.dataset, img) + + data_info = dict(img=img, prompt=caption) + data_list.append(data_info) + + return data_list diff --git a/mmagic/datasets/transforms/__init__.py b/mmagic/datasets/transforms/__init__.py index e6b483443..a85f34811 100644 --- a/mmagic/datasets/transforms/__init__.py +++ b/mmagic/datasets/transforms/__init__.py @@ -20,7 +20,8 @@ GenerateFrameIndiceswithPadding, GenerateSegmentIndices) from .get_masked_image import GetMaskedImage -from .loading import (GetSpatialDiscountMask, LoadImageFromFile, LoadMask, +from .loading import (GetSpatialDiscountMask, LoadImageFromFile, + LoadImageFromHuggingFaceDataset, LoadMask, LoadPairedImageFromFile) from .matlab_like_resize import MATLABLikeResize from .normalization import Normalize, RescaleToZeroOne @@ -28,6 +29,7 @@ RandomJPEGCompression, RandomNoise, RandomResize, RandomVideoCompression) from .random_down_sampling import RandomDownSampling +from .sdxl import ComputeTimeIds, FlipXL, RandomCropXL, ResizeEdge from .trimap import (FormatTrimap, GenerateTrimap, GenerateTrimapWithDistTransform, TransformTrimap) from .values import CopyValues, SetValues @@ -49,5 +51,7 @@ 'GenerateTrimapWithDistTransform', 'CompositeFg', 'RandomLoadResizeBg', 'MergeFgAndBg', 'PerturbBg', 'RandomJitter', 'LoadPairedImageFromFile', 'CenterCropLongEdge', 'RandomCropLongEdge', 'NumpyPad', 'InstanceCrop', - 'Albumentations', 'AlbuCorruptFunction', 'PairedAlbuTransForms' + 'Albumentations', 'AlbuCorruptFunction', 'PairedAlbuTransForms', + 'LoadImageFromHuggingFaceDataset', 'RandomCropXL', 'FlipXL', + 'ComputeTimeIds', 'ResizeEdge' ] diff --git a/mmagic/datasets/transforms/loading.py b/mmagic/datasets/transforms/loading.py index 738a19f50..9c918f956 100644 --- a/mmagic/datasets/transforms/loading.py +++ b/mmagic/datasets/transforms/loading.py @@ -6,6 +6,7 @@ import numpy as np from mmcv.transforms import BaseTransform from mmengine.fileio import get_file_backend, list_from_file +from PIL import Image from mmagic.registry import TRANSFORMS from mmagic.utils import (bbox2mask, brush_stroke_mask, 
get_irregular_mask, @@ -536,3 +537,89 @@ def transform(self, results: dict) -> dict: results[f'ori_{self.key}'] = ori_image return results + + +@TRANSFORMS.register_module() +class LoadImageFromHuggingFaceDataset(BaseTransform): + """Load a single image from corresponding paths. Required + Keys: + - [Key]_path + + New Keys: + - [KEY] + - ori_[KEY]_shape + - ori_[KEY] + + Args: + key (str): Keys in results to find corresponding path. + channel_order (str): Order of channel, candidates are 'bgr' and 'rgb'. + Default: 'bgr'. + imdecode_backend (str): The image decoding backend type. The backend + argument for :func:``mmcv.imfrombytes``. + See :func:``mmcv.imfrombytes`` for details. + candidates are 'cv2', 'turbojpeg', 'pillow', and 'tifffile'. + Defaults to None. + to_float32 (bool): Whether to convert the loaded image to a float32 + numpy array. If set to False, the loaded image is an uint8 array. + Defaults to False. + """ + + def __init__( + self, + key: str, + channel_order: str = 'bgr', + to_float32: bool = False, + save_original_img: bool = False, + ) -> None: + + self.key = key + self.channel_order = channel_order + self.save_original_img = save_original_img + + # convert + self.to_float32 = to_float32 + + def transform(self, results: dict) -> dict: + """Functions to load image or frames. + + Args: + results (dict): Result dict from :obj:``mmcv.BaseDataset``. + Returns: + dict: The dict contains loaded image and meta information. + """ + + img = results[f'{self.key}'] + if type(img) == str: + img = Image.open(img) + + if self.channel_order == 'rgb': + img = img.convert('RGB') + img = np.array(img) + elif self.channel_order == 'bgr': + img = np.array(img) + img = img[..., ::-1] + + if img.ndim == 2: + img = np.expand_dims(img, axis=2) + + if self.to_float32: + img = img.astype(np.float32) + + results[self.key] = img + results[f'ori_{self.key}_shape'] = img.shape + results[f'{self.key}_channel_order'] = self.channel_order + results[f'{self.key}_color_type'] = 'color' + if self.save_original_img: + results[f'ori_{self.key}'] = img.copy() + + return results + + def __repr__(self): + + repr_str = (f'{self.__class__.__name__}(' + f'key={self.key}, ' + f'channel_order={self.channel_order}, ' + f'to_float32={self.to_float32}, ' + f'save_original_img={self.save_original_img})') + + return repr_str diff --git a/mmagic/datasets/transforms/sdxl.py b/mmagic/datasets/transforms/sdxl.py new file mode 100644 index 000000000..f6e9ffebc --- /dev/null +++ b/mmagic/datasets/transforms/sdxl.py @@ -0,0 +1,296 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Sequence, Union + +import mmcv +import numpy as np +from mmcv.transforms import BaseTransform, to_tensor + +from mmagic.datasets.transforms.aug_shape import Flip +from mmagic.registry import TRANSFORMS + + +@TRANSFORMS.register_module() +class RandomCropXL(BaseTransform): + """Random crop the given image. Required Keys: + + - [KEYS] + + Modified Keys: + - [KEYS] + + New Keys: + - [KEYS]_crop_bbox + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. If provided a sequence of length 1, it will be interpreted + as (size[0], size[0]) + keys (str or list[str]): The images to be cropped. + """ + + def __init__(self, size: int, keys: Union[str, List[str]] = 'img'): + if not isinstance(size, Sequence): + size = (size, size) + self.size = size + + assert keys, 'Keys should not be empty.' 
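+        # a single key is normalized to a list so the crop below can be
+        # applied uniformly to every entry in ``keys``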
+ if not isinstance(keys, list): + keys = [keys] + self.keys = keys + + def transform(self, results: Dict) -> Dict: + """Call function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. + """ + assert all(results[self.keys[0]].size == results[k].size + for k in self.keys) + + h, w, _ = results[self.keys[0]].shape + + if h < self.size[0] or w < self.size[1]: + raise ValueError( + f'({h}, {w}) is smaller than crop size {self.size}.') + + # randomly choose top and left coordinates for img patch + top = np.random.randint(h - self.size[0] + 1) + left = np.random.randint(w - self.size[1] + 1) + + for key in self.keys: + results[key] = results[key][top:top + self.size[0], + left:left + self.size[1], ...] + results[f'{key}_crop_bbox'] = [ + top, left, top + self.size[0], left + self.size[1] + ] + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(size={self.size}, ' f'keys={self.keys})') + return repr_str + + +@TRANSFORMS.register_module() +class FlipXL(Flip): + """Flip the input data with a probability. + + The differences between FlipXL & Flip: + 1. Fix [KEYS]_crop_bbox. + + Required Keys: + - [KEYS] + - [KEYS]_crop_bbox + + Modified Keys: + - [KEYS] + - [KEYS]_crop_bbox + + Args: + keys (Union[str, List[str]]): The images to be flipped. + flip_ratio (float): The probability to flip the images. Default: 0.5. + direction (str): Flip images horizontally or vertically. Options are + "horizontal" | "vertical". Default: "horizontal". + """ + + def transform(self, results: Dict) -> Dict: + """transform function. + + Args: + results (dict): A dict containing the necessary information and + data for augmentation. + + Returns: + dict: A dict containing the processed data and information. + """ + + flip = np.random.random() < self.flip_ratio + + if flip: + for key in self.keys: + mmcv.imflip_(results[key], self.direction) + h, w, _ = results[key].shape + if self.direction == 'horizontal': + results[f'{key}_crop_bbox'] = [ + results[f'{key}_crop_bbox'][0], + w - results[f'{key}_crop_bbox'][3], + results[f'{key}_crop_bbox'][2], + w - results[f'{key}_crop_bbox'][1] + ] + elif self.direction == 'vertical': + results[f'{key}_crop_bbox'] = [ + h - results[f'{key}_crop_bbox'][2], + results[f'{key}_crop_bbox'][1], + h - results[f'{key}_crop_bbox'][0], + results[f'{key}_crop_bbox'][3] + ] + + if 'flip_infos' not in results: + results['flip_infos'] = [] + + flip_info = dict( + keys=self.keys, + direction=self.direction, + ratio=self.flip_ratio, + flip=flip) + results['flip_infos'].append(flip_info) + + return results + + +@TRANSFORMS.register_module() +class ComputeTimeIds(BaseTransform): + """Load a single image from corresponding paths. Required Required Keys: + + - [Key] + - ori_[KEY]_shape + - [KEYS]_crop_bbox + + New Keys: + - time_ids + + Args: + key (str): Keys in results to find corresponding path. + Defaults to `img`. + """ + + def __init__( + self, + key: str = 'img', + ) -> None: + + self.key = key + + def transform(self, results: Dict) -> Dict: + """ + Args: + results (dict): The result dict. + + Returns: + dict: 'time_ids' key is added as original image shape. 
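+                Concretely ``time_ids`` is
+                ``[ori_h, ori_w, crop_top, crop_left, target_h, target_w]``,
+                which ``train_step`` later passes to the SDXL UNet as
+                ``unet_added_conditions['time_ids']``.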
+ """ + assert f'ori_{self.key}_shape' in results + assert f'{self.key}_crop_bbox' in results + target_size = list(results[self.key].shape)[:2] + time_ids = list(results[f'ori_{self.key}_shape'][:2] + ) + results[f'{self.key}_crop_bbox'][:2] + target_size + results['time_ids'] = to_tensor(time_ids) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += (f'(key={self.key})') + return repr_str + + +@TRANSFORMS.register_module() +class ResizeEdge(BaseTransform): + """Resize images along the specified edge. + + Required Keys: + - [KEYS] + + Modified Keys: + - [KEYS] + - [KEYS]_shape + + New Keys: + - keep_ratio + - scale_factor + - interpolation + + Args: + scale (int): The edge scale to resizing. + keys (str | list[str]): The image(s) to be resized. + edge (str): The edge to resize. Defaults to 'short'. + backend (str): Image resize backend, choices are 'cv2' and 'pillow'. + These two backends generates slightly different results. + Defaults to 'cv2'. + interpolation (str): Interpolation method, accepted values are + "nearest", "bilinear", "bicubic", "area", "lanczos" for 'cv2' + backend, "nearest", "bilinear" for 'pillow' backend. + Defaults to 'bilinear'. + """ + + def __init__(self, + scale: int, + keys: Union[str, List[str]] = 'img', + edge: str = 'short', + backend: str = 'cv2', + interpolation: str = 'bilinear') -> None: + assert keys, 'Keys should not be empty.' + keys = [keys] if not isinstance(keys, list) else keys + + allow_edges = ['short', 'long', 'width', 'height'] + assert edge in allow_edges, \ + f'Invalid edge "{edge}", please specify from {allow_edges}.' + + self.keys = keys + self.edge = edge + self.scale = scale + self.backend = backend + self.interpolation = interpolation + + def _resize_img(self, results: dict, key: str) -> None: + """Resize images with ``results['scale']``.""" + + img, w_scale, h_scale = mmcv.imresize( + results[key], + results['scale'], + interpolation=self.interpolation, + return_scale=True, + backend=self.backend) + results[key] = img + results[f'{key}_shape'] = img.shape[:2] + results['scale_factor'] = (w_scale, h_scale) + results['keep_ratio'] = True + results['interpolation'] = self.interpolation + + def transform(self, results: Dict) -> Dict: + """Transform function to resize images. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Resized results, 'img', 'scale', 'scale_factor', + 'img_shape' keys are updated in result dict. + """ + for k in self.keys: + assert k in results, f'No {k} field in the input.' + + h, w = results[k].shape[:2] + if any([ + # conditions to resize the width + self.edge == 'short' and w < h, + self.edge == 'long' and w > h, + self.edge == 'width', + ]): + width = self.scale + height = int(self.scale * h / w) + else: + height = self.scale + width = int(self.scale * w / h) + results['scale'] = (width, height) + + self._resize_img(results, k) + return results + + def __repr__(self): + """Print the basic information of the transform. + + Returns: + str: Formatted string. 
+ """ + repr_str = self.__class__.__name__ + repr_str += f'(scale={self.scale}, ' + repr_str += f'edge={self.edge}, ' + repr_str += f'backend={self.backend}, ' + repr_str += f'interpolation={self.interpolation})' + return repr_str diff --git a/mmagic/engine/hooks/visualization_hook.py b/mmagic/engine/hooks/visualization_hook.py index 48affeccf..b4ff8c2b2 100644 --- a/mmagic/engine/hooks/visualization_hook.py +++ b/mmagic/engine/hooks/visualization_hook.py @@ -149,6 +149,7 @@ class VisualizationHook(Hook): If None is passed, all samples will be saved. Defaults to 100. show (bool): Whether to display the drawn image. Default to False. wait_time (float): The interval of show (s). Defaults to 0. + by_epoch (bool): Whether to visualize by epoch. Defaults to False. """ priority = 'NORMAL' @@ -182,10 +183,12 @@ def __init__(self, max_save_at_test: int = 100, test_vis_keys: Optional[Union[str, List[str]]] = None, show: bool = False, - wait_time: float = 0): + wait_time: float = 0, + by_epoch: bool = False): self._visualizer: Visualizer = Visualizer.get_current_instance() self.interval = interval + self.by_epoch = by_epoch self.vis_kwargs_list = deepcopy(vis_kwargs_list) if isinstance(self.vis_kwargs_list, dict): @@ -287,9 +290,22 @@ def after_train_iter(self, Defaults to None. outputs (dict, optional): Outputs from model. Defaults to None. """ - if self.every_n_inner_iters(batch_idx, self.interval): + if not self.by_epoch and self.every_n_inner_iters( + batch_idx, self.interval): self.vis_sample(runner, batch_idx, data_batch, outputs) + @master_only + def after_train_epoch(self, runner) -> None: + """Visualize samples after train iteration. + + Args: + runner (Runner): The runner of the training process. + """ + batch_idx = runner.epoch + if self.by_epoch and self.every_n_inner_iters(batch_idx, + self.interval): + self.vis_sample(runner, batch_idx, {}, None) + @torch.no_grad() def vis_sample(self, runner: Runner, diff --git a/mmagic/models/data_preprocessors/data_preprocessor.py b/mmagic/models/data_preprocessors/data_preprocessor.py index bdea5f044..a1113f308 100644 --- a/mmagic/models/data_preprocessors/data_preprocessor.py +++ b/mmagic/models/data_preprocessors/data_preprocessor.py @@ -68,7 +68,7 @@ class DataPreprocessor(ImgDataPreprocessor): data sample. Only support with input data samples are `DataSamples`. Defaults to True. """ - _NON_IMAGE_KEYS = ['noise'] + _NON_IMAGE_KEYS = ['noise', 'time_ids'] _NON_CONCATENATE_KEYS = ['num_batches', 'mode', 'sample_kwargs', 'eq_cfg'] def __init__(self, diff --git a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py index fa46f433e..31bd03f39 100644 --- a/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py +++ b/mmagic/models/editors/stable_diffusion_xl/stable_diffusion_xl.py @@ -72,6 +72,9 @@ class StableDiffusionXL(BaseModel): force_zeros_for_empty_prompt (bool): Whether the negative prompt embeddings shall be forced to always be set to 0. Defaults to True. + with_cp (bool): Whether or not to use gradient + checkpointing to save memory at the expense of slower backward + pass. Defaults to False. init_cfg (dict, optional): The weight initialized config for :class:`BaseModule`. 
""" @@ -95,6 +98,7 @@ def __init__(self, val_prompts: Union[str, List[str]] = None, finetune_text_encoder: bool = False, force_zeros_for_empty_prompt: bool = True, + with_cp: bool = False, init_cfg: Optional[dict] = None): # TODO: support `from_pretrained` for this class @@ -150,6 +154,7 @@ def __init__(self, self.val_prompts = val_prompts self.lora_config = deepcopy(lora_config) self.force_zeros_for_empty_prompt = force_zeros_for_empty_prompt + self.with_cp = with_cp self.prepare_model() self.set_lora() @@ -165,6 +170,12 @@ def prepare_model(self): Move model to target dtype and disable gradient for some models. """ + if self.with_cp: + self.unet.enable_gradient_checkpointing() + if self.finetune_text_encoder: + self.text_encoder_one.gradient_checkpointing_enable() + self.text_encoder_two.gradient_checkpointing_enable() + self.vae.requires_grad_(False) print_log('Set VAE untrainable.', 'current') self.vae.to(self.dtype) @@ -989,7 +1000,8 @@ def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict): vae = self.vae.module if hasattr(self.vae, 'module') else self.vae with optim_wrapper.optim_context(self.unet): - image = inputs + image = inputs['img'] + time_ids = inputs['time_ids'] prompt = data_samples.prompt num_batches = image.shape[0] @@ -1032,7 +1044,7 @@ def train_step(self, data: List[dict], optim_wrapper: OptimWrapperDict): pooled_prompt_embeds) = self.encode_prompt_train( input_ids_one, input_ids_two) unet_added_conditions = { - 'time_ids': data['time_ids'], + 'time_ids': time_ids, 'text_embeds': pooled_prompt_embeds } diff --git a/requirements/optional.txt b/requirements/optional.txt index e44ef1c30..03fb83219 100644 --- a/requirements/optional.txt +++ b/requirements/optional.txt @@ -1,5 +1,6 @@ albumentations -e git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1#egg=clip +datasets imageio-ffmpeg==0.4.4 mmdet >= 3.0.0 open_clip_torch diff --git a/requirements/tests.txt b/requirements/tests.txt index 51e76ecd5..bc0deeda3 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -9,6 +9,7 @@ controlnet_aux # pytest-runner # yapf coverage < 7.0.0 +datasets imageio-ffmpeg==0.4.4 interrogate mmdet >= 3.0.0 diff --git a/tests/data/sd/color.jpg b/tests/data/sd/color.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2f19ebc6c6e867372f61dceadba4d66de46e31ab GIT binary patch literal 39779 zcmbTcWl$YY&^~%_2|2m}l68rOc8UL;okRF#l;d zSXdZXcsO|Y|2Ywm5aAJ!5a8hvKOrI^|0f?;D4&o~{=4~akpJxp4F>}Qhl~J^@IOud zKdJX#00tt|H>gAyC~^Qa1{4ei)cYWS^utcL|1s~w*#8wMXc$;Hc!Uo#K7BN3LjN!y z2Ij+PxDSgzTKj)I2f$*$VSZ&5fyYudMj&_o!WNvAk4Pa}*N3e#3!-ErLALXW^Q3=Wo_f?=I-I?Wol@%US zL>a-@8H=1P81ai}Qhr?@5(T>o2;0PE78!@~`!3b>f6)F1+5bCWA^*RS{a?WTZ(M5t z6d0(F$%DZF2mxM^e!m%)K*dABr}?i(Qk?yO$$kguz;BMs6PAC$p*+~XDWuABHJYJl z0ke|D=<0889p7T#Uxj3{=jVSOLlBDda|dN7Rfi+s@h1#wF6BVCQHcM@VI5iOw&j&? 
zY!QM-6`OF2zh`bmZx)iZhLy+~?m~AL;10%+-pa#iIr`Ld#=DPRwGtpQ5J~A)l3?^@ z$ZuR@cjjwC-tJzCGm6iKBxV`)tM*W#9QLD@B5qARO$%$;8HNzxcdZ*MsTsgg(x_R+ zzcA;Zt*tuB+>!F0YWX5#%_|e>^00Bcrg=E&Te^MRoB%s|RVzC<3j?;dG|NbF_j(FW LDJue#wx|Eu-QcpD literal 0 HcmV?d00001 diff --git a/tests/data/sd/metadata.csv b/tests/data/sd/metadata.csv new file mode 100644 index 000000000..c10a81565 --- /dev/null +++ b/tests/data/sd/metadata.csv @@ -0,0 +1,2 @@ +file_name,text +color.jpg,"a dog" \ No newline at end of file diff --git a/tests/data/sd/metadata2.csv b/tests/data/sd/metadata2.csv new file mode 100644 index 000000000..458dfb8ed --- /dev/null +++ b/tests/data/sd/metadata2.csv @@ -0,0 +1,2 @@ +file_name,text +color.jpg,"a cat" \ No newline at end of file diff --git a/tests/test_datasets/test_hf_dataset.py b/tests/test_datasets/test_hf_dataset.py new file mode 100644 index 000000000..1862c9417 --- /dev/null +++ b/tests/test_datasets/test_hf_dataset.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.testing import RunnerTestCase + +from mmagic.datasets import HuggingFaceDataset + + +class TestHFDataset(RunnerTestCase): + + def test_dataset_from_local(self): + dataset = HuggingFaceDataset( + dataset='tests/data/sd', image_column='file_name') + assert len(dataset) == 1 + + data = dataset[0] + assert data['prompt'] == 'a dog' + assert data['img'] == 'tests/data/sd/color.jpg' + + dataset = HuggingFaceDataset( + dataset='tests/data/sd', + image_column='file_name', + csv='metadata2.csv') + assert len(dataset) == 1 + + data = dataset[0] + assert data['prompt'] == 'a cat' + assert data['img'] == 'tests/data/sd/color.jpg' diff --git a/tests/test_datasets/test_transforms/test_loading.py b/tests/test_datasets/test_transforms/test_loading.py index 6597a3f29..5bd72d069 100644 --- a/tests/test_datasets/test_transforms/test_loading.py +++ b/tests/test_datasets/test_transforms/test_loading.py @@ -5,9 +5,12 @@ import numpy as np import pytest from mmengine.fileio.backends import LocalBackend +from PIL import Image from mmagic.datasets.transforms import (GetSpatialDiscountMask, - LoadImageFromFile, LoadMask) + LoadImageFromFile, + LoadImageFromHuggingFaceDataset, + LoadMask) def test_load_image_from_file(): @@ -298,6 +301,43 @@ def test_load_mask(self): results = loader(results) +def test_load_image_from_huggingface_dataset(): + + path_baboon = Path( + __file__).parent.parent.parent / 'data' / 'image' / 'gt' / 'baboon.png' + img_baboon = mmcv.imread(str(path_baboon), flag='color') + h, w, _ = img_baboon.shape + + # read gt image + # input path is Path object + results = dict(img=Image.open(str(path_baboon))) + config = dict(key='img') + image_loader = LoadImageFromHuggingFaceDataset(**config) + results = image_loader(results) + assert results['img'].shape == (h, w, 3) + assert results['ori_img_shape'] == (h, w, 3) + np.testing.assert_almost_equal(results['img'], img_baboon) + assert results['img_channel_order'] == 'bgr' + assert results['img_color_type'] == 'color' + + # test save_original_img + results = dict(img=Image.open(str(path_baboon))) + config = dict(key='img', save_original_img=True, channel_order='rgb') + image_loader = LoadImageFromHuggingFaceDataset(**config) + results = image_loader(results) + assert results['img'].shape == (h, w, 3) + assert results['ori_img_shape'] == (h, w, 3) + np.testing.assert_almost_equal(results['ori_img'], results['img']) + np.testing.assert_almost_equal(results['img'], img_baboon[..., ::-1]) + assert id(results['ori_img']) != id(results['img']) + assert 
results['img_channel_order'] == 'rgb' + assert results['img_color_type'] == 'color' + + assert image_loader.__repr__() == ( + image_loader.__class__.__name__ + '(key=img, channel_order=rgb,' + ' to_float32=False, save_original_img=True)') + + def teardown_module(): import gc gc.collect() diff --git a/tests/test_datasets/test_transforms/test_sdxl.py b/tests/test_datasets/test_transforms/test_sdxl.py new file mode 100644 index 000000000..e2ecb3c92 --- /dev/null +++ b/tests/test_datasets/test_transforms/test_sdxl.py @@ -0,0 +1,206 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from unittest import TestCase + +import numpy as np +import torch +from PIL import Image + +from mmagic.registry import TRANSFORMS + + +class TestComputeTimeIds(TestCase): + + def test_register(self): + self.assertIn('ComputeTimeIds', TRANSFORMS) + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + img = Image.open(img_path) + data = { + 'img': np.array(img), + 'ori_img_shape': [32, 32], + 'img_crop_bbox': [0, 0, 32, 32] + } + + # test transform + trans = TRANSFORMS.build(dict(type='ComputeTimeIds')) + data = trans(data) + self.assertIsInstance(data['time_ids'], torch.Tensor) + self.assertListEqual( + list(data['time_ids'].numpy()), + [32, 32, 0, 0, img.height, img.width]) + + assert trans.__repr__() == (trans.__class__.__name__ + '(key=img)') + + +class TestRandomCropXL(TestCase): + crop_size = 32 + + def test_register(self): + self.assertIn('RandomCropXL', TRANSFORMS) + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = {'img': np.array(Image.open(img_path))} + + # test transform + trans = TRANSFORMS.build( + dict(type='RandomCropXL', size=self.crop_size)) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + assert data['img'].shape[0] == data['img'].shape[1] == self.crop_size + upper, left, lower, right = data['img_crop_bbox'] + assert lower == upper + self.crop_size + assert right == left + self.crop_size + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).crop((left, upper, right, lower)))) + + assert trans.__repr__() == ( + trans.__class__.__name__ + + f'(size={(self.crop_size, self.crop_size)},' + f" keys=['img'])") + + def test_transform_multiple_keys(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = { + 'img': np.array(Image.open(img_path)), + 'condition_img': np.array(Image.open(img_path)) + } + + # test transform + trans = TRANSFORMS.build( + dict( + type='RandomCropXL', + size=self.crop_size, + keys=['img', 'condition_img'])) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + assert data['img'].shape[0] == data['img'].shape[1] == self.crop_size + upper, left, lower, right = data['img_crop_bbox'] + assert lower == upper + self.crop_size + assert right == left + self.crop_size + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).crop((left, upper, right, lower)))) + np.equal(np.array(data['img']), np.array(data['condition_img'])) + + +class TestFlipXL(TestCase): + + def test_register(self): + self.assertIn('FlipXL', TRANSFORMS) + + def test_transform(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = { + 'img': np.array(Image.open(img_path)), + 'img_crop_bbox': [0, 0, 10, 10] + } + + # test transform + trans = TRANSFORMS.build( + dict(type='FlipXL', flip_ratio=1., 
keys=['img'])) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + self.assertListEqual( + data['img_crop_bbox'], + [0, data['img'].shape[1] - 10, 10, data['img'].shape[1] - 0]) + + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).transpose(Image.FLIP_LEFT_RIGHT))) + + assert trans.__repr__() == ( + trans.__class__.__name__ + + "(keys=['img'], flip_ratio=1.0, direction=horizontal)") + + # test transform p=0.0 + data = { + 'img': np.array(Image.open(img_path)), + 'img_crop_bbox': [0, 0, 10, 10] + } + trans = TRANSFORMS.build( + dict(type='FlipXL', flip_ratio=0., keys='img')) + data = trans(data) + self.assertIn('img_crop_bbox', data) + self.assertListEqual(data['img_crop_bbox'], [0, 0, 10, 10]) + + np.equal(np.array(data['img']), np.array(Image.open(img_path))) + + def test_transform_multiple_keys(self): + img_path = osp.join(osp.dirname(__file__), '../../data/sd/color.jpg') + data = { + 'img': np.array(Image.open(img_path)), + 'condition_img': np.array(Image.open(img_path)), + 'img_crop_bbox': [0, 0, 10, 10], + 'condition_img_crop_bbox': [0, 0, 10, 10] + } + + # test transform + trans = TRANSFORMS.build( + dict(type='FlipXL', flip_ratio=1., keys=['img', 'condition_img'])) + data = trans(data) + self.assertIn('img_crop_bbox', data) + assert len(data['img_crop_bbox']) == 4 + self.assertListEqual( + data['img_crop_bbox'], + [0, data['img'].shape[1] - 10, 10, data['img'].shape[1] - 0]) + + np.equal( + np.array(data['img']), + np.array(Image.open(img_path).transpose(Image.FLIP_LEFT_RIGHT))) + np.equal(np.array(data['img']), np.array(data['condition_img'])) + + +class TestResizeEdge(TestCase): + + def test_transform(self): + results = dict(img=np.random.randint(0, 256, (128, 256, 3), np.uint8)) + + # test resize short edge by default. + cfg = dict(type='ResizeEdge', scale=224) + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 448, 3)) + + # test resize long edge. + cfg = dict(type='ResizeEdge', scale=224, edge='long') + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (112, 224, 3)) + + # test resize width. + cfg = dict(type='ResizeEdge', scale=224, edge='width') + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (112, 224, 3)) + + # test resize height. 
+ cfg = dict(type='ResizeEdge', scale=224, edge='height') + transform = TRANSFORMS.build(cfg) + results = transform(results) + self.assertTupleEqual(results['img'].shape, (224, 448, 3)) + + # test invalid edge + with self.assertRaisesRegex(AssertionError, 'Invalid edge "hi"'): + cfg = dict(type='ResizeEdge', scale=224, edge='hi') + TRANSFORMS.build(cfg) + + def test_repr(self): + cfg = dict(type='ResizeEdge', scale=224, edge='height') + transform = TRANSFORMS.build(cfg) + self.assertEqual( + repr(transform), 'ResizeEdge(scale=224, edge=height, backend=cv2, ' + 'interpolation=bilinear)') + + +def teardown_module(): + import gc + gc.collect() + globals().clear() + locals().clear() diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index 267af16be..862021ae2 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -617,6 +617,23 @@ def test_after_test_iter(self): hook.after_test_iter(runner, 0, [], outputs) assert mock_visualuzer.add_datasample.call_count == 3 + def test_after_train_epoch(self): + model = MagicMock() + hook = VisualizationHook( + interval=1, + n_samples=2, + vis_kwargs_list=dict(type='GAN'), + by_epoch=True) + mock_visualuzer = MagicMock() + mock_visualuzer.add_datasample = MagicMock() + hook._visualizer = mock_visualuzer + + runner = MagicMock() + runner.model = model + + hook.after_train_epoch(runner) + mock_visualuzer.assert_not_called() + def teardown_module(): import gc diff --git a/tests/test_models/test_data_preprocessors/test_data_preprocessor.py b/tests/test_models/test_data_preprocessors/test_data_preprocessor.py index 60d7dd20d..bf97c4ccb 100644 --- a/tests/test_models/test_data_preprocessors/test_data_preprocessor.py +++ b/tests/test_models/test_data_preprocessors/test_data_preprocessor.py @@ -476,6 +476,7 @@ def test_preprocess_dict_inputs(self): img_A=[torch.randint(0, 255, (3, 5, 5)) for _ in range(3)], img_B=[torch.randint(0, 255, (3, 5, 5)) for _ in range(3)], noise=[torch.randn(16) for _ in range(3)], + time_ids=[torch.randn(6) for _ in range(3)], num_batches=3, tensor=torch.randint(0, 255, (3, 4, 5, 5)), mode=['ema', 'ema', 'ema'], @@ -490,6 +491,7 @@ def test_preprocess_dict_inputs(self): target_B = (torch.stack(inputs['img_B']) - 127.5) / 127.5 target_B = target_B[:, [2, 1, 0]] target_noise = torch.stack(inputs['noise']) + target_time_ids = torch.stack(inputs['time_ids']) # no metainfo, parse as BGR, do conversion target_tensor = ((inputs['tensor'] - 127.5) / 127.5)[:, [2, 1, 0, 3]] @@ -497,6 +499,7 @@ def test_preprocess_dict_inputs(self): assert_allclose(outputs['img_A'], target_A) assert_allclose(outputs['img_B'], target_B) assert_allclose(outputs['noise'], target_noise) + assert_allclose(outputs['time_ids'], target_time_ids) assert_allclose(outputs['tensor'], target_tensor) self.assertEqual(outputs['mode'], 'ema') self.assertEqual(outputs['num_batches'], 3) @@ -769,21 +772,35 @@ def test_forward(self): img2 = torch.randn(3, 4, 4) noise1 = torch.randn(3, 4, 4) noise2 = torch.randn(3, 4, 4) + time_ids1 = torch.randn(6) + time_ids2 = torch.randn(6) target_input1 = (img1[[2, 1, 0], ...].clone() - 127.5) / 127.5 target_input2 = (img2[[2, 1, 0], ...].clone() - 127.5) / 127.5 data = dict(inputs=[ - dict(noise=noise1, img=img1, num_batches=2, mode='ema'), - dict(noise=noise2, img=img2, num_batches=2, mode='ema'), + dict( + noise=noise1, + img=img1, + num_batches=2, + mode='ema', + time_ids=time_ids1), + dict( 
+ noise=noise2, + img=img2, + num_batches=2, + mode='ema', + time_ids=time_ids2), ]) data_preprocessor = DataPreprocessor(output_channel_order='RGB') data = data_preprocessor(data) self.assertEqual( list(data['inputs'].keys()), - ['noise', 'img', 'num_batches', 'mode']) + ['noise', 'img', 'num_batches', 'mode', 'time_ids']) assert_allclose(data['inputs']['noise'][0], noise1) assert_allclose(data['inputs']['noise'][1], noise2) + assert_allclose(data['inputs']['time_ids'][0], time_ids1) + assert_allclose(data['inputs']['time_ids'][1], time_ids2) assert_allclose(data['inputs']['img'][0], target_input1) assert_allclose(data['inputs']['img'][1], target_input2) self.assertEqual(data['inputs']['num_batches'], 2) diff --git a/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py index 64c6345eb..fefe6cc8e 100644 --- a/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/test_models/test_editors/test_stable_diffusion_xl/test_stable_diffusion_xl.py @@ -85,8 +85,10 @@ def test_stable_diffusion_xl_step(): # train step data = dict( - inputs=torch.ones([1, 3, 64, 64]), - time_ids=torch.zeros((1, 6)), + inputs={ + 'img': torch.ones([1, 3, 64, 64]), + 'time_ids': torch.zeros((1, 6)) + }, data_samples=[ DataSample(prompt='an insect robot preparing a delicious meal') ]) From 6d0a29f6fb5f08e9e20bff127c1850b88522f868 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 21 Sep 2023 16:45:41 +0900 Subject: [PATCH 04/14] update readme --- configs/stable_diffusion_xl/README.md | 7 ++++--- configs/stable_diffusion_xl/metafile.yml | 8 ++++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/configs/stable_diffusion_xl/README.md b/configs/stable_diffusion_xl/README.md index bc67a69ee..ac68382fd 100644 --- a/configs/stable_diffusion_xl/README.md +++ b/configs/stable_diffusion_xl/README.md @@ -20,9 +20,10 @@ We present SDXL, a latent diffusion model for text-to-image synthesis. Compared ## Pretrained models -| Model | Task | Dataset | Download | -| :----------------------------------------------------------------: | :--------: | :-----: | :------: | -| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - | +| Model | Task | Dataset | Download | +| :-----------------------------------------------------------------------: | :--------: | :--------------------------------------------------------------------------------------: | :---------: | +| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - | +| [stable_diffusion_xl_pokemon_blip](./stable-diffusion_xl_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) | We use stable diffusion xl weights. This model has several weights including vae, unet and clip. 
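The new `stable-diffusion_xl_pokemon_blip.py` config plugs into the standard MMEngine training flow (usually launched through `tools/train.py`). A minimal sketch of starting the same run from Python, assuming the repository root as working directory; the `work_dir` value is a hypothetical example and not part of this patch:

```python
# Minimal sketch of launching the new SDXL fine-tuning config with MMEngine's
# Runner. Only the work_dir is invented here; the config path is the one
# registered in metafile.yml below.
from mmengine.config import Config
from mmengine.runner import Runner

from mmagic.utils import register_all_modules

register_all_modules()  # register mmagic datasets, transforms and models

cfg = Config.fromfile(
    'configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py')
cfg.work_dir = './work_dirs/stable-diffusion_xl_pokemon_blip'  # hypothetical

runner = Runner.from_cfg(cfg)
runner.train()
```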
diff --git a/configs/stable_diffusion_xl/metafile.yml b/configs/stable_diffusion_xl/metafile.yml index 4844b7daf..8d79a68a5 100644 --- a/configs/stable_diffusion_xl/metafile.yml +++ b/configs/stable_diffusion_xl/metafile.yml @@ -16,3 +16,11 @@ Models: - Dataset: '-' Metrics: {} Task: Text2Image +- Config: configs/stable_diffusion_xl/stable-diffusion_xl_pokemon_blip.py + In Collection: Stable Diffusion XL + Name: stable-diffusion_xl_pokemon_blip + Results: + - Dataset: '[pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)' + Metrics: {} + Task: Text2Image + Weights: <> From c631719f58741a96be417e44807eda8a88e88616 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 21 Sep 2023 18:31:00 +0900 Subject: [PATCH 05/14] fix test --- tests/test_datasets/test_hf_dataset.py | 6 +++++ .../test_hooks/test_visualization_hook.py | 25 ++++++++++++++++--- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/tests/test_datasets/test_hf_dataset.py b/tests/test_datasets/test_hf_dataset.py index 1862c9417..b039f42a1 100644 --- a/tests/test_datasets/test_hf_dataset.py +++ b/tests/test_datasets/test_hf_dataset.py @@ -1,9 +1,15 @@ # Copyright (c) OpenMMLab. All rights reserved. +import platform + +import pytest from mmengine.testing import RunnerTestCase from mmagic.datasets import HuggingFaceDataset +@pytest.mark.skipif( + 'win' in platform.system().lower(), + reason='skip on windows due to limited RAM.') class TestHFDataset(RunnerTestCase): def test_dataset_from_local(self): diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index 862021ae2..cb68a411f 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -22,6 +22,19 @@ register_all_modules() +class DummyModel(torch.nn.Module): + + def __init__(self): + super().__init__() + self.conv1 = torch.nn.Conv2d(1, 1, 1) + + def forward(self, x): + return self.conv1(x) + + def val_step(self, *args, **kwargs): + return DataSample(fake_img=torch.randn(3, 6, 6), prompt='dummy') + + class TestBasicVisualizationHook(TestCase): def setUp(self) -> None: @@ -621,18 +634,22 @@ def test_after_train_epoch(self): model = MagicMock() hook = VisualizationHook( interval=1, - n_samples=2, - vis_kwargs_list=dict(type='GAN'), - by_epoch=True) + n_samples=1, + vis_kwargs_list=dict(type='Data', name='fake_img'), + by_epoch=True, + fixed_input=True) + hook.inputs_buffer = {'Data': ['dummy']} mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() + runner.train_dataloader.batch_size = 1 runner.model = model + runner.epoch = 5 hook.after_train_epoch(runner) - mock_visualuzer.assert_not_called() + self.assertEqual(mock_visualuzer.add_datasample.call_count, 1) def teardown_module(): From f11f6a9491941b3295de358d1776def89898b7a3 Mon Sep 17 00:00:00 2001 From: okotaku Date: Thu, 21 Sep 2023 19:56:33 +0900 Subject: [PATCH 06/14] fix test --- configs/_base_/sd_default_runtime.py | 2 +- .../test_hooks/test_visualization_hook.py | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/configs/_base_/sd_default_runtime.py b/configs/_base_/sd_default_runtime.py index 6d05f8caa..ef66eae79 100644 --- a/configs/_base_/sd_default_runtime.py +++ b/configs/_base_/sd_default_runtime.py @@ -12,7 +12,7 @@ interval=1, by_epoch=True, max_keep_ckpts=3, - save_optimizer=False)) + save_optimizer=True)) # config for 
environment env_cfg = dict( diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index cb68a411f..bddc6e561 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -22,19 +22,6 @@ register_all_modules() -class DummyModel(torch.nn.Module): - - def __init__(self): - super().__init__() - self.conv1 = torch.nn.Conv2d(1, 1, 1) - - def forward(self, x): - return self.conv1(x) - - def val_step(self, *args, **kwargs): - return DataSample(fake_img=torch.randn(3, 6, 6), prompt='dummy') - - class TestBasicVisualizationHook(TestCase): def setUp(self) -> None: From 49a9e12d873c530d0d2eaba1d92955c7c7b5f685 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 22 Sep 2023 11:48:00 +0900 Subject: [PATCH 07/14] add lora training config --- .../stable_diffusion_xl_lora.py | 39 +++++++++++++++++++ configs/_base_/schedules/sd_10e.py | 11 ++++++ configs/stable_diffusion_xl/README.md | 9 +++-- configs/stable_diffusion_xl/metafile.yml | 8 ++++ .../stable-diffusion_xl_lora_pokemon_blip.py | 23 +++++++++++ mmagic/models/archs/lora.py | 2 +- 6 files changed, 87 insertions(+), 5 deletions(-) create mode 100644 configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py create mode 100644 configs/_base_/schedules/sd_10e.py create mode 100644 configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py diff --git a/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py new file mode 100644 index 000000000..1bf1eff1a --- /dev/null +++ b/configs/_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py @@ -0,0 +1,39 @@ +# Use DiffuserWrapper! 
+stable_diffusion_xl_url = 'stabilityai/stable-diffusion-xl-base-1.0' +vae_url = 'madebyollin/sdxl-vae-fp16-fix' +unet = dict( + type='UNet2DConditionModel', + subfolder='unet', + from_pretrained=stable_diffusion_xl_url) +vae = dict(type='AutoencoderKL', from_pretrained=vae_url) + +diffusion_scheduler = dict( + type='DDPMScheduler', + from_pretrained=stable_diffusion_xl_url, + subfolder='scheduler') + +lora_config = dict(rank=8, target_modules=['to_q', 'to_k', 'to_v']) + +model = dict( + type='StableDiffusionXL', + dtype='fp16', + with_cp=True, + unet=unet, + vae=vae, + enable_xformers=False, + text_encoder_one=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder'), + tokenizer_one=stable_diffusion_xl_url, + text_encoder_two=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_url, + subfolder='text_encoder_2'), + tokenizer_two=stable_diffusion_xl_url, + scheduler=diffusion_scheduler, + test_scheduler=diffusion_scheduler, + data_preprocessor=dict(type='DataPreprocessor', data_keys=None), + lora_config=lora_config) diff --git a/configs/_base_/schedules/sd_10e.py b/configs/_base_/schedules/sd_10e.py new file mode 100644 index 000000000..cb017cdf2 --- /dev/null +++ b/configs/_base_/schedules/sd_10e.py @@ -0,0 +1,11 @@ +optim_wrapper = dict( + type='AmpOptimWrapper', + dtype='float16', + optimizer=dict(type='AdamW', lr=1e-5, weight_decay=1e-2), + clip_grad=dict(max_norm=1.0), + accumulative_counts=1) + +# train, val, test setting +train_cfg = dict(by_epoch=True, max_epochs=10) +val_cfg = None +test_cfg = None diff --git a/configs/stable_diffusion_xl/README.md b/configs/stable_diffusion_xl/README.md index ac68382fd..e4c175d14 100644 --- a/configs/stable_diffusion_xl/README.md +++ b/configs/stable_diffusion_xl/README.md @@ -20,10 +20,11 @@ We present SDXL, a latent diffusion model for text-to-image synthesis. Compared ## Pretrained models -| Model | Task | Dataset | Download | -| :-----------------------------------------------------------------------: | :--------: | :--------------------------------------------------------------------------------------: | :---------: | -| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - | -| [stable_diffusion_xl_pokemon_blip](./stable-diffusion_xl_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) | +| Model | Task | Dataset | Download | +| :---------------------------------------------------------------------------------: | :--------: | :--------------------------------------------------------------------------------------: | :---------: | +| [stable_diffusion_xl](./stable-diffusion_xl_ddim_denoisingunet.py) | Text2Image | - | - | +| [stable_diffusion_xl_pokemon_blip](./stable-diffusion_xl_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) | +| [stable-diffusion_xl_lora_pokemon_blip](./stable-diffusion_xl_lora_pokemon_blip.py) | Text2Image | [pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) | [model](<>) | We use stable diffusion xl weights. This model has several weights including vae, unet and clip. 
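The LoRA base config keeps rank and target modules in a single `lora_config` dict, so downstream configs only need to override that key. A short sketch of a derived config, reusing the `_base_` files added in this series; the rank and learning-rate values below are illustrative assumptions, not recommendations:

```python
# Hypothetical derived config: reuse the LoRA base from this patch and only
# override the LoRA rank and the optimizer learning rate.
_base_ = [
    '../_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py',
    '../_base_/datasets/pokemon_blip_xl.py',
    '../_base_/schedules/sd_10e.py',
    '../_base_/sd_default_runtime.py',
]

# a larger rank gives the adapter more capacity at the cost of more parameters
model = dict(
    lora_config=dict(rank=16, target_modules=['to_q', 'to_k', 'to_v']))

# LoRA fine-tuning commonly tolerates a higher learning rate than full tuning
optim_wrapper = dict(optimizer=dict(lr=1e-4))
```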
diff --git a/configs/stable_diffusion_xl/metafile.yml b/configs/stable_diffusion_xl/metafile.yml index 8d79a68a5..2d3dc99e4 100644 --- a/configs/stable_diffusion_xl/metafile.yml +++ b/configs/stable_diffusion_xl/metafile.yml @@ -24,3 +24,11 @@ Models: Metrics: {} Task: Text2Image Weights: <> +- Config: configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py + In Collection: Stable Diffusion XL + Name: stable-diffusion_xl_lora_pokemon_blip + Results: + - Dataset: '[pokemon-blip-caption](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)' + Metrics: {} + Task: Text2Image + Weights: <> diff --git a/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py b/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py new file mode 100644 index 000000000..e96bd8bef --- /dev/null +++ b/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py @@ -0,0 +1,23 @@ +_base_ = [ + '../_base_/models/stable_diffusion_xl/stable_diffusion_xl_lora.py', + '../_base_/datasets/pokemon_blip_xl.py', '../_base_/schedules/sd_10e.py', + '../_base_/sd_default_runtime.py' +] + +val_prompts = ['yoda pokemon'] * 4 + +model = dict(val_prompts=val_prompts) + +train_dataloader = dict(batch_size=4, num_workers=4) + +# hooks +custom_hooks = [ + dict( + type='VisualizationHook', + by_epoch=True, + interval=1, + fixed_input=True, + # visualize train dataset + vis_kwargs_list=dict(type='Data', name='fake_img'), + n_samples=1) +] diff --git a/mmagic/models/archs/lora.py b/mmagic/models/archs/lora.py index 066a13cbd..86de9a5e3 100644 --- a/mmagic/models/archs/lora.py +++ b/mmagic/models/archs/lora.py @@ -219,7 +219,7 @@ def forward_lora_mapping(self, x: Tensor) -> Tensor: mapping_out = self.scale * self.lora_mapping(x) return mapping_out - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: Tensor, **kwargs) -> Tensor: """Forward and add LoRA mapping. Args: From 1b92c333fefdcc3ec8ceba77793e318be31d85d2 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 22 Sep 2023 17:51:09 +0900 Subject: [PATCH 08/14] fix test --- tests/test_datasets/test_hf_dataset.py | 16 +++++++++++++--- tests/test_datasets/test_paired_image_dataset.py | 3 +++ tests/test_datasets/test_singan_dataset.py | 3 +++ .../test_datasets/test_unpaired_image_dataset.py | 3 +++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tests/test_datasets/test_hf_dataset.py b/tests/test_datasets/test_hf_dataset.py index b039f42a1..5f16f6054 100644 --- a/tests/test_datasets/test_hf_dataset.py +++ b/tests/test_datasets/test_hf_dataset.py @@ -1,4 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp import platform import pytest @@ -13,13 +14,15 @@ class TestHFDataset(RunnerTestCase): def test_dataset_from_local(self): + data_root = osp.join(osp.dirname(__file__), '../../') + dataset_path = data_root + 'tests/data/sd' dataset = HuggingFaceDataset( - dataset='tests/data/sd', image_column='file_name') + dataset=dataset_path, image_column='file_name') assert len(dataset) == 1 data = dataset[0] assert data['prompt'] == 'a dog' - assert data['img'] == 'tests/data/sd/color.jpg' + assert 'tests/data/sd/color.jpg' in data['img'] dataset = HuggingFaceDataset( dataset='tests/data/sd', @@ -29,4 +32,11 @@ def test_dataset_from_local(self): data = dataset[0] assert data['prompt'] == 'a cat' - assert data['img'] == 'tests/data/sd/color.jpg' + assert 'tests/data/sd/color.jpg' in data['img'] + + +def teardown_module(): + import gc + gc.collect() + globals().clear() + locals().clear() diff --git a/tests/test_datasets/test_paired_image_dataset.py b/tests/test_datasets/test_paired_image_dataset.py index 837c063cc..333eb80fa 100644 --- a/tests/test_datasets/test_paired_image_dataset.py +++ b/tests/test_datasets/test_paired_image_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +from mmengine.registry import init_default_scope + from mmagic.datasets import PairedImageDataset from mmagic.utils import register_all_modules @@ -11,6 +13,7 @@ class TestPairedImageDataset(object): @classmethod def setup_class(cls): + init_default_scope('mmagic') cls.imgs_root = osp.join( osp.dirname(osp.dirname(__file__)), 'data/paired') cls.default_pipeline = [ diff --git a/tests/test_datasets/test_singan_dataset.py b/tests/test_datasets/test_singan_dataset.py index 24bfb8f74..907254b3a 100644 --- a/tests/test_datasets/test_singan_dataset.py +++ b/tests/test_datasets/test_singan_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import os.path as osp +from mmengine.registry import init_default_scope + from mmagic.datasets import SinGANDataset from mmagic.utils import register_all_modules @@ -11,6 +13,7 @@ class TestSinGANDataset(object): @classmethod def setup_class(cls): + init_default_scope('mmagic') cls.imgs_root = osp.join( osp.dirname(osp.dirname(__file__)), 'data/image/gt/baboon.png') cls.min_size = 25 diff --git a/tests/test_datasets/test_unpaired_image_dataset.py b/tests/test_datasets/test_unpaired_image_dataset.py index 4b7ff1bf8..694069d97 100644 --- a/tests/test_datasets/test_unpaired_image_dataset.py +++ b/tests/test_datasets/test_unpaired_image_dataset.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import os.path as osp +from mmengine.registry import init_default_scope + from mmagic.datasets import UnpairedImageDataset from mmagic.utils import register_all_modules @@ -11,6 +13,7 @@ class TestUnpairedImageDataset(object): @classmethod def setup_class(cls): + init_default_scope('mmagic') cls.imgs_root = osp.join( osp.dirname(osp.dirname(__file__)), 'data/unpaired') cls.default_pipeline = [ From 2bf5db37d4d978729156985e91c4fa5e26b10305 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 22 Sep 2023 18:11:22 +0900 Subject: [PATCH 09/14] fix test --- tests/test_engine/test_hooks/test_visualization_hook.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index bddc6e561..336ae3c9f 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -7,6 +7,7 @@ import numpy as np import torch from mmengine import MessageHub +from mmengine.registry import init_default_scope from mmengine.testing import assert_allclose from mmengine.visualization import Visualizer from torch.utils.data.dataset import Dataset @@ -25,6 +26,7 @@ class TestBasicVisualizationHook(TestCase): def setUp(self) -> None: + init_default_scope('mmagic') input = torch.rand(2, 3, 32, 32) data_sample = DataSample( path_rgb='rgb.png', @@ -89,6 +91,7 @@ class TestVisualizationHook(TestCase): MessageHub.get_instance('test-gen-visualizer') def test_init(self): + init_default_scope('mmagic') hook = VisualizationHook( interval=10, vis_kwargs_list=dict(type='Noise')) self.assertEqual(hook.interval, 10) From 250f567769f73de1f266ee8225a13077c611dca9 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 22 Sep 2023 18:36:09 +0900 Subject: [PATCH 10/14] fix test --- tests/test_engine/test_hooks/test_visualization_hook.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index 336ae3c9f..c6c512c34 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -126,6 +126,7 @@ def test_vis_sample_with_gan_alias(self): interval=10, vis_kwargs_list=dict(type='GAN'), n_samples=9) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -221,6 +222,7 @@ def __getitem__(self, index): n_samples=9) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -348,6 +350,7 @@ def test_after_val_iter(self): interval=10, n_samples=2, vis_kwargs_list=dict(type='GAN')) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() @@ -373,6 +376,7 @@ def test_after_train_iter(self): interval=2, vis_kwargs_list=dict(type='GAN'), n_samples=9) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -503,6 +507,7 @@ def train(self): interval=2, vis_kwargs_list=dict(type='GAN'), n_samples=3, n_row=8) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = 
MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -526,6 +531,7 @@ def test_after_test_iter(self): vis_kwargs_list=dict(type='GAN')) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() @@ -605,6 +611,7 @@ def test_after_test_iter(self): mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() @@ -631,6 +638,7 @@ def test_after_train_epoch(self): hook.inputs_buffer = {'Data': ['dummy']} mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() + mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() From 995359c35adb333c78e6f75ed80ee7a6646e66b1 Mon Sep 17 00:00:00 2001 From: okotaku Date: Fri, 22 Sep 2023 18:55:41 +0900 Subject: [PATCH 11/14] fix test --- .../test_hooks/test_visualization_hook.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py index c6c512c34..6aaae91eb 100644 --- a/tests/test_engine/test_hooks/test_visualization_hook.py +++ b/tests/test_engine/test_hooks/test_visualization_hook.py @@ -122,11 +122,11 @@ def test_vis_sample_with_gan_alias(self): runner.train_dataloader = MagicMock() runner.train_dataloader.batch_size = 10 + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, vis_kwargs_list=dict(type='GAN'), n_samples=9) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -213,6 +213,7 @@ def __getitem__(self, index): runner.val_loop = MagicMock() runner.val_loop.dataloader = val_dataloader + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, vis_kwargs_list=[ @@ -222,7 +223,6 @@ def __getitem__(self, index): n_samples=9) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -346,11 +346,11 @@ def __getitem__(self, index): def test_after_val_iter(self): model = MagicMock() + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, n_samples=2, vis_kwargs_list=dict(type='GAN')) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() @@ -372,11 +372,11 @@ def test_after_train_iter(self): runner.train_dataloader = MagicMock() runner.train_dataloader.batch_size = 10 + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=2, vis_kwargs_list=dict(type='GAN'), n_samples=9) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -503,11 +503,11 @@ def train(self): runner.train_dataloader = MagicMock() runner.train_dataloader.batch_size = 2 + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=2, vis_kwargs_list=dict(type='GAN'), n_samples=3, n_row=8) mock_visualuzer = MagicMock() 
mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer # build a empty data sample @@ -523,6 +523,7 @@ def train(self): def test_after_test_iter(self): model = MagicMock() + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, n_samples=2, @@ -531,7 +532,6 @@ def test_after_test_iter(self): vis_kwargs_list=dict(type='GAN')) mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() @@ -602,6 +602,7 @@ def test_after_test_iter(self): hook.after_test_iter(runner, 42, [], outputs) # test max save time + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=10, n_samples=2, @@ -611,7 +612,6 @@ def test_after_test_iter(self): mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() @@ -629,6 +629,7 @@ def test_after_test_iter(self): def test_after_train_epoch(self): model = MagicMock() + _ = Visualizer.get_instance('name1') hook = VisualizationHook( interval=1, n_samples=1, @@ -638,7 +639,6 @@ def test_after_train_epoch(self): hook.inputs_buffer = {'Data': ['dummy']} mock_visualuzer = MagicMock() mock_visualuzer.add_datasample = MagicMock() - mock_visualuzer.get_current_instance = MagicMock() hook._visualizer = mock_visualuzer runner = MagicMock() From 323ee20872e76a6e15208fc299c7ef311b182ce5 Mon Sep 17 00:00:00 2001 From: okotaku Date: Sat, 23 Sep 2023 14:09:47 +0900 Subject: [PATCH 12/14] add LoRACheckpointToSaveHook --- .../stable-diffusion_xl_lora_pokemon_blip.py | 3 +- mmagic/engine/hooks/__init__.py | 3 +- .../hooks/lora_checkpoint_to_save_hook.py | 35 ++++++ .../test_lora_checkpoint_to_save_hook.py | 103 ++++++++++++++++++ 4 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 mmagic/engine/hooks/lora_checkpoint_to_save_hook.py create mode 100644 tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py diff --git a/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py b/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py index e96bd8bef..ea5183818 100644 --- a/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py +++ b/configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py @@ -19,5 +19,6 @@ fixed_input=True, # visualize train dataset vis_kwargs_list=dict(type='Data', name='fake_img'), - n_samples=1) + n_samples=1), + dict(type='LoRACheckpointToSaveHook') ] diff --git a/mmagic/engine/hooks/__init__.py b/mmagic/engine/hooks/__init__.py index 8435afa9a..e7de12ea1 100644 --- a/mmagic/engine/hooks/__init__.py +++ b/mmagic/engine/hooks/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
 from .ema import ExponentialMovingAverageHook
 from .iter_time_hook import IterTimerHook
+from .lora_checkpoint_to_save_hook import LoRACheckpointToSaveHook
 from .pggan_fetch_data_hook import PGGANFetchDataHook
 from .pickle_data_hook import PickleDataHook
 from .reduce_lr_scheduler_hook import ReduceLRSchedulerHook
@@ -9,5 +10,5 @@
 __all__ = [
     'ReduceLRSchedulerHook', 'BasicVisualizationHook', 'VisualizationHook',
     'ExponentialMovingAverageHook', 'IterTimerHook', 'PGGANFetchDataHook',
-    'PickleDataHook'
+    'PickleDataHook', 'LoRACheckpointToSaveHook'
 ]
diff --git a/mmagic/engine/hooks/lora_checkpoint_to_save_hook.py b/mmagic/engine/hooks/lora_checkpoint_to_save_hook.py
new file mode 100644
index 000000000..8e83a352b
--- /dev/null
+++ b/mmagic/engine/hooks/lora_checkpoint_to_save_hook.py
@@ -0,0 +1,35 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import List
+
+from mmengine.hooks import Hook
+from mmengine.registry import HOOKS
+
+
+@HOOKS.register_module()
+class LoRACheckpointToSaveHook(Hook):
+    """Pick up LoRA weights from checkpoint.
+
+    Args:
+        lora_keys (List[str]): Key substrings of the LoRA weights to keep.
+    """
+    priority = 'VERY_LOW'
+
+    def __init__(self, lora_keys: List[str] = ['lora_mapping']):
+        super().__init__()
+        self.lora_keys = lora_keys
+
+    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
+        """
+        Args:
+            runner (Runner): The runner of the training, validation or testing
+                process.
+            checkpoint (dict): Model's checkpoint.
+        """
+
+        new_ckpt = OrderedDict()
+        for k in checkpoint['state_dict'].keys():
+            if any(key in k for key in self.lora_keys):
+                new_ckpt[k] = checkpoint['state_dict'][k]
+
+        checkpoint['state_dict'] = new_ckpt
diff --git a/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py
new file mode 100644
index 000000000..52118beff
--- /dev/null
+++ b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py
@@ -0,0 +1,103 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy +import platform + +import pytest +import torch +from diffusers.models.unet_2d_condition import UNet2DConditionModel +from mmengine import Config +from mmengine.registry import MODELS +from mmengine.testing import RunnerTestCase + +from mmagic.engine.hooks import LoRACheckpointToSaveHook +from mmagic.models.archs import DiffusersWrapper +from mmagic.models.data_preprocessors import DataPreprocessor +from mmagic.models.editors import ClipWrapper, StableDiffusionXL +from mmagic.models.editors.stable_diffusion import AutoencoderKL + +stable_diffusion_xl_tiny_url = 'hf-internal-testing/tiny-stable-diffusion-xl-pipe' # noqa +lora_config = dict(target_modules=['to_q', 'to_k', 'to_v']) +diffusion_scheduler = dict( + type='EditDDIMScheduler', + variance_type='learned_range', + beta_end=0.012, + beta_schedule='scaled_linear', + beta_start=0.00085, + num_train_timesteps=1000, + set_alpha_to_one=False, + clip_sample=False) +model = dict( + type='StableDiffusionXL', + unet=dict( + type='UNet2DConditionModel', + subfolder='unet', + from_pretrained=stable_diffusion_xl_tiny_url), + vae=dict(type='AutoencoderKL', sample_size=64), + text_encoder_one=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_tiny_url, + subfolder='text_encoder'), + tokenizer_one=stable_diffusion_xl_tiny_url, + text_encoder_two=dict( + type='ClipWrapper', + clip_type='huggingface', + pretrained_model_name_or_path=stable_diffusion_xl_tiny_url, + subfolder='text_encoder_2'), + tokenizer_two=stable_diffusion_xl_tiny_url, + scheduler=diffusion_scheduler, + val_prompts=['a dog', 'a dog'], + lora_config=lora_config) + + +@pytest.mark.skipif( + 'win' in platform.system().lower(), + reason='skip on windows due to limited RAM.') +@pytest.mark.skipif( + torch.__version__ < '1.9.0', + reason='skip on torch<1.9 due to unsupported torch.concat') +class TestLoRACheckpointToSaveHook(RunnerTestCase): + + def setUp(self) -> None: + MODELS.register_module( + name='StableDiffusionXL', module=StableDiffusionXL) + MODELS.register_module(name='ClipWrapper', module=ClipWrapper) + + def gen_wrapped_cls(module, module_name): + return type( + module_name, (DiffusersWrapper, ), + dict( + _module_cls=module, + _module_name=module_name, + __module__=__name__)) + + wrapped_module = gen_wrapped_cls(UNet2DConditionModel, + 'UNet2DConditionModel') + MODELS.register_module( + name='UNet2DConditionModel', module=wrapped_module, force=True) + MODELS.register_module(name='AutoencoderKL', module=AutoencoderKL) + MODELS.register_module( + name='DataPreprocessor', module=DataPreprocessor) + return super().setUp() + + def tearDown(self): + MODELS.module_dict.pop('StableDiffusionXL') + MODELS.module_dict.pop('ClipWrapper') + MODELS.module_dict.pop('UNet2DConditionModel') + MODELS.module_dict.pop('AutoencoderKL') + MODELS.module_dict.pop('DataPreprocessor') + return super().tearDown() + + def test_init(self): + LoRACheckpointToSaveHook() + + def test_before_save_checkpoint(self): + cfg = copy.deepcopy(self.epoch_based_cfg) + runner = self.build_runner(cfg) + runner.model = MODELS.build(Config(model)) + checkpoint = dict(state_dict=MODELS.build(Config(model)).state_dict()) + hook = LoRACheckpointToSaveHook() + hook.before_save_checkpoint(runner, checkpoint) + + for key in checkpoint['state_dict'].keys(): + assert 'lora_mapping' in key From bb2cbec739fafaad52b1c7afc38b625fe69e9d99 Mon Sep 17 00:00:00 2001 From: okotaku Date: Sat, 23 Sep 2023 14:29:16 +0900 Subject: [PATCH 13/14] add LoRACheckpointToSaveHook 
--- .../test_hooks/test_lora_checkpoint_to_save_hook.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py index 52118beff..faf480baf 100644 --- a/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py +++ b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import gc import platform import pytest @@ -86,6 +87,11 @@ def tearDown(self): MODELS.module_dict.pop('UNet2DConditionModel') MODELS.module_dict.pop('AutoencoderKL') MODELS.module_dict.pop('DataPreprocessor') + + gc.collect() + globals().clear() + locals().clear() + return super().tearDown() def test_init(self): @@ -101,3 +107,10 @@ def test_before_save_checkpoint(self): for key in checkpoint['state_dict'].keys(): assert 'lora_mapping' in key + + +def teardown_module(): + import gc + gc.collect() + globals().clear() + locals().clear() From 0e6356a7141745af38f0a814f8aa02413b170476 Mon Sep 17 00:00:00 2001 From: okotaku Date: Sat, 23 Sep 2023 16:05:06 +0900 Subject: [PATCH 14/14] add LoRACheckpointToSaveHook --- .../test_hooks/test_lora_checkpoint_to_save_hook.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py index faf480baf..3f1808884 100644 --- a/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py +++ b/tests/test_engine/test_hooks/test_lora_checkpoint_to_save_hook.py @@ -1,6 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -import gc import platform import pytest @@ -87,11 +86,6 @@ def tearDown(self): MODELS.module_dict.pop('UNet2DConditionModel') MODELS.module_dict.pop('AutoencoderKL') MODELS.module_dict.pop('DataPreprocessor') - - gc.collect() - globals().clear() - locals().clear() - return super().tearDown() def test_init(self):
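Since `LoRACheckpointToSaveHook` drops every key that does not contain `'lora_mapping'`, the checkpoints it writes hold only the adapter weights. A rough sketch of loading such a checkpoint back into a model built with the same `lora_config`; the checkpoint path is a placeholder, and `strict=False` is needed because the base weights are intentionally absent:

```python
# Rough sketch: restoring a LoRA-only checkpoint written by
# LoRACheckpointToSaveHook. The checkpoint path is a placeholder.
import torch
from mmengine.config import Config

from mmagic.registry import MODELS
from mmagic.utils import register_all_modules

register_all_modules()

cfg = Config.fromfile(
    'configs/stable_diffusion_xl/stable-diffusion_xl_lora_pokemon_blip.py')
model = MODELS.build(cfg.model)  # rebuilds the LoRA-wrapped SDXL model

ckpt = torch.load('work_dirs/epoch_10.pth', map_location='cpu')
missing, unexpected = model.load_state_dict(ckpt['state_dict'], strict=False)
# every key kept by the hook should map onto a lora_mapping parameter
assert len(unexpected) == 0
```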