Skip to content

Commit

Permalink
Merge pull request #511 from cornellius-gp/least_used_cuda_test
Browse files Browse the repository at this point in the history
Add context manager for selecting min mem cuda device
  • Loading branch information
gpleiss committed Feb 9, 2019
2 parents 45ce91f + 2407ae4 commit f6d6adb
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
18 changes: 18 additions & 0 deletions test/_utils.py
@@ -1,5 +1,8 @@
#!/usr/bin/env python3

from contextlib import contextmanager
from typing import Generator

import torch


Expand All @@ -17,3 +20,18 @@ def approx_equal(self, other, epsilon=1e-4):
"Size mismatch between self ({self}) and other ({other})".format(self=self.size(), other=other.size())
)
return torch.max((self - other).abs()) <= epsilon


def get_cuda_max_memory_allocations() -> int:
"""Get the `max_memory_allocated` for each cuda device"""
return torch.tensor([torch.cuda.max_memory_allocated(i) for i in range(torch.cuda.device_count())])


@contextmanager
def least_used_cuda_device() -> Generator:
"""Contextmanager for automatically selecting the cuda device
with the least allocated memory"""
mem_allocs = get_cuda_max_memory_allocations()
least_used_device = torch.argmin(mem_allocs).item()
with torch.cuda.device(least_used_device):
yield
5 changes: 4 additions & 1 deletion test/examples/test_svgp_gp_regression.py
Expand Up @@ -12,6 +12,7 @@
from gpytorch.models import AbstractVariationalGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy
from .._utils import least_used_cuda_device


def train_data(cuda=False):
Expand Down Expand Up @@ -91,7 +92,9 @@ def test_regression_error_skip_logdet_forward(self):
return self.test_regression_error(skip_logdet_forward=True)

def test_regression_error_cuda(self):
if torch.cuda.is_available():
if not torch.cuda.is_available():
return None
with least_used_cuda_device():
train_x, train_y = train_data(cuda=True)
likelihood = GaussianLikelihood().cuda()
model = SVGPRegressionModel(torch.linspace(0, 1, 25)).cuda()
Expand Down

0 comments on commit f6d6adb

Please sign in to comment.