diff --git a/example/hello_world.cpp b/example/hello_world.cpp index 492d64316..d6525f57f 100644 --- a/example/hello_world.cpp +++ b/example/hello_world.cpp @@ -15,12 +15,12 @@ int main() { - // get the default GPU device - boost::compute::device gpu = - boost::compute::system::default_gpu_device(); + // get the default device + boost::compute::device device = + boost::compute::system::default_device(); // print the GPU's name - std::cout << "hello from " << gpu.name() << std::endl; + std::cout << "hello from " << device.name() << std::endl; return 0; } diff --git a/include/boost/compute/system.hpp b/include/boost/compute/system.hpp index d9f894911..cbc4e2994 100644 --- a/include/boost/compute/system.hpp +++ b/include/boost/compute/system.hpp @@ -31,10 +31,21 @@ class system static device default_device() { // check for device from environment variable - const char *name = std::getenv("BOOST_COMPUTE_DEFAULT_DEVICE"); - if(name){ - device device = find_device(name); - if(device.id()){ + const char *name = std::getenv("BOOST_COMPUTE_DEFAULT_DEVICE"); + const char *platform = std::getenv("BOOST_COMPUTE_DEFAULT_PLATFORM"); + const char *vendor = std::getenv("BOOST_COMPUTE_DEFAULT_VENDOR"); + + if(name || platform || vendor){ + BOOST_FOREACH(const device &device, devices()){ + if (name && !matches(device.name(), name)) + continue; + + if (platform && !matches(device_platform(device).name(), platform)) + continue; + + if (vendor && !matches(device.vendor(), vendor)) + continue; + return device; } } @@ -80,7 +91,7 @@ class system static device find_device(const std::string &name) { BOOST_FOREACH(const device &device, devices()){ - if(device.name() == name){ + if(device.name().find(name) != std::string::npos){ return device; } } @@ -141,6 +152,15 @@ class system clGetPlatformIDs(0, 0, &count); return static_cast(count); } + +private: + static platform device_platform(const device &device) { + return platform( device.get_info(CL_DEVICE_PLATFORM) ); + } + + static bool matches(const std::string &str, const std::string &pattern) { + return str.find(pattern) != std::string::npos; + } }; } // end compute namespace