From 48f8574f4dcb7998f63856bd7139dd6a905338d0 Mon Sep 17 00:00:00 2001 From: gaomingqi Date: Sat, 29 Apr 2023 07:30:10 +0800 Subject: [PATCH] update inpainter for less gpu memory usage and accessible configs --- README.md | 4 ++- inpainter/base_inpainter.py | 68 +++++++++++++++++++++--------------- inpainter/config/config.yaml | 3 ++ 3 files changed, 46 insertions(+), 29 deletions(-) diff --git a/README.md b/README.md index 1f8d84e..b479dd5 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,9 @@ ## :rocket: Updates -- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, an inventive project from our lab that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT. +- 2023/04/29: We improved inpainting by decoupling GPU memory usage and video length. Now Track-Anything can inpaint videos with any length! Check [HERE] for our GPU memory requirements. + +- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, an inventive project from our lab that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT. - 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=trueg) on Hugging Face :hugs:! 
diff --git a/inpainter/base_inpainter.py b/inpainter/base_inpainter.py index d80d613..c32422c 100644 --- a/inpainter/base_inpainter.py +++ b/inpainter/base_inpainter.py @@ -28,6 +28,9 @@ def __init__(self, E2FGVI_checkpoint, device) -> None: self.neighbor_stride = config['neighbor_stride'] self.num_ref = config['num_ref'] self.step = config['step'] + # config for E2FGVI with splits + self.num_subset_frames = config['num_subset_frames'] + self.num_external_ref = config['num_external_ref'] # sample reference frames from the whole video def get_ref_index(self, f, neighbor_ids, length): @@ -104,11 +107,9 @@ def inpaint_efficient(self, frames, masks, num_tcb, num_tca, dilate_radius=15, r if num_tcb > 0: tcb_imgs = imgs[:, :num_tcb] tcb_masks = masks[:, :num_tcb] - tcb_binary = binary_masks[:num_tcb] if num_tca > 0: tca_imgs = imgs[:, -num_tca:] tca_masks = masks[:, -num_tca:] - tca_binary = binary_masks[-num_tca:] end_idx = -num_tca else: end_idx = T @@ -185,18 +186,25 @@ def inpaint(self, frames, masks, dilate_radius=15, ratio=1): assert frames.shape[:3] == masks.shape, 'different size between frames and masks' assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]' - # set interval - interval = 45 - context_range = 10 # for each split, consider its temporal context [-context_range] frames and [context_range] frames + # set num_subset_frames + num_subset_frames = self.num_subset_frames # split frames into subsets video_length = len(frames) - num_splits = video_length // interval - id_splits = [[i*interval, (i+1)*interval] for i in range(num_splits)] # id splits - # if remaining split > interval/2, add a new split, else, append to the last split - if video_length - id_splits[-1][-1] > interval / 2: - id_splits.append([num_splits*interval, video_length]) + num_splits = video_length // num_subset_frames + id_splits = [[i*num_subset_frames, (i+1)*num_subset_frames] for i in range(num_splits)] # id splits + + if num_splits == 0: + id_splits = [[0, video_length]] + + # if 
remaining split > num_subset_frames/3, add a new split, else, append to the last split + if video_length - id_splits[-1][-1] > num_subset_frames / 3: + id_splits.append([num_splits*num_subset_frames, video_length]) else: - id_splits[-1][-1] = video_length + diff = video_length - id_splits[-1][-1] + id_splits = [[ids[0]+diff, ids[1]+diff] for ids in id_splits] + id_splits[0][0] = 0 # if OOM, let it happen at the beginning :D + + # if appending, convert the appended split to the FIRST one, avoiding OOM at last # perform inpainting for each split inpainted_splits = [] @@ -205,15 +213,15 @@ def inpaint(self, frames, masks, dilate_radius=15, ratio=1): mask_split = masks[id_split[0]:id_split[1]] # | id_before | ----- | id_split[0] | ----- | id_split[1] | ----- | id_after | - # add temporal context - id_before = max(0, id_split[0] - self.step * context_range) + # for each split, consider its temporal context [-context_range] frames and [context_range] frames + id_before = max(0, id_split[0] - self.step * self.num_external_ref) try: - tcb_frames = np.stack([frames[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0) - tcb_masks = np.stack([masks[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0) + tcb_frames = np.stack([frames[idb] for idb in range(id_before, (id_split[0]-self.step) + 1, self.step)], 0) + tcb_masks = np.stack([masks[idb] for idb in range(id_before, (id_split[0]-self.step) + 1, self.step)], 0) num_tcb = len(tcb_frames) except: num_tcb = 0 - id_after = min(video_length, id_split[1] + self.step * context_range) + id_after = min(video_length, id_split[1] + self.step * self.num_external_ref + 1) try: tca_frames = np.stack([frames[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0) tca_masks = np.stack([masks[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0) @@ -229,10 +237,12 @@ def inpaint(self, frames, masks, dilate_radius=15, ratio=1): video_split = np.concatenate([video_split, 
tca_frames], 0) mask_split = np.concatenate([mask_split, tca_masks], 0) + torch.cuda.empty_cache() # inpaint each split inpainted_splits.append(self.inpaint_efficient(video_split, mask_split, num_tcb, num_tca, dilate_radius, ratio)) - + torch.cuda.empty_cache() inpainted_frames = np.concatenate(inpainted_splits, 0) + return inpainted_frames.astype(np.uint8) def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1): @@ -265,6 +275,8 @@ def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1): if min(size) < 50: ratio = 50. / min(H, W) size = [int(W*ratio), int(H*ratio)] + + size = [160, 120] binary_masks = resize_masks(masks, tuple(size)) frames = resize_frames(frames, tuple(size)) # T, H, W, 3 # frames and binary_masks are numpy arrays @@ -346,12 +358,12 @@ def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1): # ---------------------------------------------- # 1/3: set checkpoint and device checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth' - device = 'cuda:8' + device = 'cuda:4' # 2/3: initialise inpainter base_inpainter = BaseInpainter(checkpoint, device) # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W) # ratio: (0, 1], ratio for down sample, default value is 1 - inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.2) # numpy array, T, H, W, 3 + inpainted_frames = base_inpainter.inpaint(frames[:1000], masks[:1000], ratio=0.1) # numpy array, T, H, W, 3 # save for ti, inpainted_frame in enumerate(inpainted_frames): @@ -361,12 +373,12 @@ def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1): torch.cuda.empty_cache() print('switch to ori') - inpainted_frames = base_inpainter.inpaint_ori(frames, masks, ratio=0.2) - save_path = '/ssd1/gaomingqi/results/inpainting/avengers' - # ---------------------------------------------- - # end - # ---------------------------------------------- - # save - for ti, inpainted_frame in enumerate(inpainted_frames): - frame = 
Image.fromarray(inpainted_frame).convert('RGB') - frame.save(os.path.join(save_path, f'{ti:05d}.jpg')) + # inpainted_frames = base_inpainter.inpaint_ori(frames[:50], masks[:50], ratio=0.1) + # save_path = '/ssd1/gaomingqi/results/inpainting/avengers' + # # ---------------------------------------------- + # # end + # # ---------------------------------------------- + # # save + # for ti, inpainted_frame in enumerate(inpainted_frames): + # frame = Image.fromarray(inpainted_frame).convert('RGB') + # frame.save(os.path.join(save_path, f'{ti:05d}.jpg')) diff --git a/inpainter/config/config.yaml b/inpainter/config/config.yaml index ef4c180..ec79c29 100644 --- a/inpainter/config/config.yaml +++ b/inpainter/config/config.yaml @@ -2,3 +2,6 @@ neighbor_stride: 5 num_ref: -1 step: 10 +# config info for E2FGVI with splits (updated on 23/04/29) +num_subset_frames: 50 +num_external_ref: 2 # (>=0)