
Out of memory #21

Open
tungyen opened this issue Jun 10, 2024 · 1 comment
Comments

tungyen commented Jun 10, 2024

Hi, I can run the demo code of your model on an Nvidia RTX 3090, but it only succeeds on a 3-second video and runs out of memory on a 9-second one. Is there any way to work around this without changing the GPU? Thank you very much.

m43 commented Jul 30, 2024

Perhaps try running with half precision? E.g., see the PyTorch mixed-precision docs, or wrap your code with PyTorch Lightning and use its out-of-the-box mixed-precision flags.
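As a minimal sketch of the autocast approach (the model and input here are hypothetical stand-ins for your actual tracking model and video tensor):

```python
import torch

# Hypothetical stand-ins for the real model and a batch of inputs.
model = torch.nn.Linear(64, 64)
x = torch.randn(8, 64)

# Pick the autocast dtype: float16 on CUDA, bfloat16 on CPU.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device_type == "cuda" else torch.bfloat16
model = model.to(device_type)
x = x.to(device_type)

# Run inference under autocast so matmuls execute in half precision,
# roughly halving activation memory compared to float32.
with torch.no_grad(), torch.autocast(device_type=device_type, dtype=amp_dtype):
    out = model(x)
```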

If you are using the monocular depth estimators to obtain the depths, the memory bottleneck can move there for longer videos. You can batch the calls to the depth estimator as in the following snippet, which predicts depths 10 frames at a time:

with torch.no_grad():
    batch_size = 10
    num_frames = sample.video[0].shape[0]
    if num_frames > batch_size:
        # Run the depth estimator on chunks of `batch_size` frames to bound peak memory.
        vid_depths = []
        for start_idx in range(0, num_frames, batch_size):
            end_idx = min(start_idx + batch_size, num_frames)
            video = sample.video[0][start_idx:end_idx]
            vid_depths.append(depth_predictor.infer(video / 255))
        videodepth = torch.cat(vid_depths, dim=0)
    else:
        videodepth = depth_predictor.infer(sample.video[0] / 255)

# Clamp the predicted depths to a fixed near/far range.
args.depth_near = 0.01
args.depth_far = 65.0
depths = videodepth.clamp(args.depth_near, args.depth_far)

Also, make sure that gradients are not computed during inference, for example by wrapping the calls in the torch.no_grad() context manager as above.
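As a toy illustration of why this matters (using a generic nn.Module, not the actual tracking model): outside no_grad, PyTorch builds an autograd graph and keeps activations alive for a potential backward pass; inside no_grad, those buffers can be freed immediately.

```python
import torch

model = torch.nn.Linear(1024, 1024)
x = torch.randn(4, 1024)

# With gradients: an autograd graph is built and activations are retained.
y_train = model(x)

# Without gradients: no graph is built, so intermediate memory is released.
with torch.no_grad():
    y_eval = model(x)
```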
