
update inpainter for less gpu memory usage and accessible configs
gaomingqi committed Apr 28, 2023
1 parent 91d7172 commit 48f8574
Showing 3 changed files with 46 additions and 29 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -28,7 +28,9 @@
<!-- ![avengers]() -->

## :rocket: Updates
- 2023/04/29: We improved inpainting by decoupling GPU memory usage from video length. Now Track-Anything can inpaint videos of any length! Check [HERE] for our GPU memory requirements.

- 2023/04/25: We are delighted to introduce [Caption-Anything](https://github.com/ttengwang/Caption-Anything) :writing_hand:, an inventive project from our lab that combines the capabilities of Segment Anything, Visual Captioning, and ChatGPT.

- 2023/04/20: We deployed [[DEMO]](https://huggingface.co/spaces/watchtowerss/Track-Anything?duplicate=trueg) on Hugging Face :hugs:!

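The 2023/04/29 update above works by inpainting the video in fixed-size subsets instead of feeding the whole clip to E2FGVI at once, so peak GPU memory is bounded by the subset size rather than the video length. A minimal sketch of that idea (with `inpaint_subset` as a hypothetical stand-in for the per-split model call; the actual split logic, including temporal context frames, is in the `inpaint` diff below):

```python
import numpy as np

def inpaint_in_splits(frames, masks, inpaint_subset, num_subset_frames=50):
    """frames: (T, H, W, 3) uint8; masks: (T, H, W) uint8.
    inpaint_subset stands in for the per-split E2FGVI call."""
    outputs = []
    for start in range(0, len(frames), num_subset_frames):
        end = min(start + num_subset_frames, len(frames))
        # only this slice (plus a few context frames) needs to sit on the GPU at once
        outputs.append(inpaint_subset(frames[start:end], masks[start:end]))
    return np.concatenate(outputs, 0)
```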
68 changes: 40 additions & 28 deletions inpainter/base_inpainter.py
@@ -28,6 +28,9 @@ def __init__(self, E2FGVI_checkpoint, device) -> None:
self.neighbor_stride = config['neighbor_stride']
self.num_ref = config['num_ref']
self.step = config['step']
# config for E2FGVI with splits
self.num_subset_frames = config['num_subset_frames']
self.num_external_ref = config['num_external_ref']

# sample reference frames from the whole video
def get_ref_index(self, f, neighbor_ids, length):
@@ -104,11 +107,9 @@ def inpaint_efficient(self, frames, masks, num_tcb, num_tca, dilate_radius=15, ratio=1):
if num_tcb > 0:
tcb_imgs = imgs[:, :num_tcb]
tcb_masks = masks[:, :num_tcb]
tcb_binary = binary_masks[:num_tcb]
if num_tca > 0:
tca_imgs = imgs[:, -num_tca:]
tca_masks = masks[:, -num_tca:]
tca_binary = binary_masks[-num_tca:]
end_idx = -num_tca
else:
end_idx = T
@@ -185,18 +186,25 @@ def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]'

# set interval
interval = 45
context_range = 10 # for each split, consider its temporal context [-context_range] frames and [context_range] frames
# set num_subset_frames
num_subset_frames = self.num_subset_frames
# split frames into subsets
video_length = len(frames)
num_splits = video_length // interval
id_splits = [[i*interval, (i+1)*interval] for i in range(num_splits)] # id splits
# if remaining split > interval/2, add a new split, else, append to the last split
if video_length - id_splits[-1][-1] > interval / 2:
id_splits.append([num_splits*interval, video_length])
num_splits = video_length // num_subset_frames
id_splits = [[i*num_subset_frames, (i+1)*num_subset_frames] for i in range(num_splits)] # id splits

if num_splits == 0:
id_splits = [[0, video_length]]

# if remaining split > num_subset_frames/3, add a new split, else, append to the last split
if video_length - id_splits[-1][-1] > num_subset_frames / 3:
id_splits.append([num_splits*num_subset_frames, video_length])
else:
id_splits[-1][-1] = video_length
diff = video_length - id_splits[-1][-1]
id_splits = [[ids[0]+diff, ids[1]+diff] for ids in id_splits]
id_splits[0][0] = 0 # if OOM, let it happen at the beginning :D

# if appending, convert the appended split to the FIRST one, avoiding OOM at the end

# perform inpainting for each split
inpainted_splits = []
@@ -205,15 +213,15 @@ def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
mask_split = masks[id_split[0]:id_split[1]]

# | id_before | ----- | id_split[0] | ----- | id_split[1] | ----- | id_after |
# add temporal context
id_before = max(0, id_split[0] - self.step * context_range)
# for each split, consider its temporal context [-context_range] frames and [context_range] frames
id_before = max(0, id_split[0] - self.step * self.num_external_ref)
try:
tcb_frames = np.stack([frames[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
tcb_masks = np.stack([masks[idb] for idb in range(id_before, id_split[0]-self.step, self.step)], 0)
tcb_frames = np.stack([frames[idb] for idb in range(id_before, (id_split[0]-self.step) + 1, self.step)], 0)
tcb_masks = np.stack([masks[idb] for idb in range(id_before, (id_split[0]-self.step) + 1, self.step)], 0)
num_tcb = len(tcb_frames)
except:
num_tcb = 0
id_after = min(video_length, id_split[1] + self.step * context_range)
id_after = min(video_length, id_split[1] + self.step * self.num_external_ref + 1)
try:
tca_frames = np.stack([frames[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
tca_masks = np.stack([masks[ida] for ida in range(id_split[1]+self.step, id_after, self.step)], 0)
@@ -229,10 +237,12 @@ def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
video_split = np.concatenate([video_split, tca_frames], 0)
mask_split = np.concatenate([mask_split, tca_masks], 0)

torch.cuda.empty_cache()
# inpaint each split
inpainted_splits.append(self.inpaint_efficient(video_split, mask_split, num_tcb, num_tca, dilate_radius, ratio))

torch.cuda.empty_cache()
inpainted_frames = np.concatenate(inpainted_splits, 0)

return inpainted_frames.astype(np.uint8)

def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1):
@@ -265,6 +275,8 @@ def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1):
if min(size) < 50:
ratio = 50. / min(H, W)
size = [int(W*ratio), int(H*ratio)]

size = [160, 120]
binary_masks = resize_masks(masks, tuple(size))
frames = resize_frames(frames, tuple(size)) # T, H, W, 3
# frames and binary_masks are numpy arrays
@@ -346,12 +358,12 @@ def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1):
# ----------------------------------------------
# 1/3: set checkpoint and device
checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
device = 'cuda:8'
device = 'cuda:4'
# 2/3: initialise inpainter
base_inpainter = BaseInpainter(checkpoint, device)
# 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
# ratio: (0, 1], ratio for down sample, default value is 1
inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.2) # numpy array, T, H, W, 3
inpainted_frames = base_inpainter.inpaint(frames[:1000], masks[:1000], ratio=0.1) # numpy array, T, H, W, 3

# save
for ti, inpainted_frame in enumerate(inpainted_frames):
Expand All @@ -361,12 +373,12 @@ def inpaint_ori(self, frames, masks, dilate_radius=15, ratio=1):
torch.cuda.empty_cache()
print('switch to ori')

inpainted_frames = base_inpainter.inpaint_ori(frames, masks, ratio=0.2)
save_path = '/ssd1/gaomingqi/results/inpainting/avengers'
# ----------------------------------------------
# end
# ----------------------------------------------
# save
for ti, inpainted_frame in enumerate(inpainted_frames):
frame = Image.fromarray(inpainted_frame).convert('RGB')
frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
# inpainted_frames = base_inpainter.inpaint_ori(frames[:50], masks[:50], ratio=0.1)
# save_path = '/ssd1/gaomingqi/results/inpainting/avengers'
# # ----------------------------------------------
# # end
# # ----------------------------------------------
# # save
# for ti, inpainted_frame in enumerate(inpainted_frames):
# frame = Image.fromarray(inpainted_frame).convert('RGB')
# frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
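To make the new split computation concrete, here is a small sketch tracing just the subset-boundary logic from `inpaint` above (the subsequent `diff` shift is omitted), assuming the `num_subset_frames: 50` default added to config.yaml:

```python
def compute_splits(video_length, num_subset_frames=50):
    num_splits = video_length // num_subset_frames
    id_splits = [[i * num_subset_frames, (i + 1) * num_subset_frames] for i in range(num_splits)]
    if num_splits == 0:
        # videos shorter than one subset become a single split
        return [[0, video_length]]
    # a large remainder (> num_subset_frames / 3) becomes its own split;
    # a small one is merged into the last split
    if video_length - id_splits[-1][-1] > num_subset_frames / 3:
        id_splits.append([num_splits * num_subset_frames, video_length])
    else:
        id_splits[-1][-1] = video_length
    return id_splits

print(compute_splits(30))    # [[0, 30]]
print(compute_splits(110))   # [[0, 50], [50, 110]]  (remainder 10 merged into the last split)
print(compute_splits(132))   # [[0, 50], [50, 100], [100, 132]]  (remainder 32 gets its own split)
```

With the defaults `step: 10` and `num_external_ref: 2`, each split is additionally given reference frames sampled 10 and 20 frames before its start and after its end (the `tcb`/`tca` frames above) as temporal context.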
3 changes: 3 additions & 0 deletions inpainter/config/config.yaml
@@ -2,3 +2,6 @@
neighbor_stride: 5
num_ref: -1
step: 10
# config info for E2FGVI with splits (updated on 23/04/29)
num_subset_frames: 50
num_external_ref: 2 # (>=0)
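
A minimal sketch of how the two new keys could be read, assuming the config is loaded with PyYAML into the `config` dict used in `BaseInpainter.__init__` (the loading code itself is not part of this diff):

```python
import yaml

with open('inpainter/config/config.yaml', 'r') as f:
    config = yaml.safe_load(f)

num_subset_frames = config['num_subset_frames']  # 50: frames per split
num_external_ref = config['num_external_ref']    # 2: temporal context refs on each side (>= 0)
```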
