Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions src/torchcodec/_core/Cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torch/types.h>
#include <memory>
#include <mutex>
Expand All @@ -27,7 +29,7 @@ class Cache {
public:
using element_type = std::unique_ptr<T, D>;

Cache(int capacity) : capacity_(capacity) {}
explicit Cache(int capacity) : capacity_(capacity) {}

// Adds an object to the cache if the cache has capacity. Returns true
// if object was added and false otherwise.
Expand Down Expand Up @@ -56,8 +58,9 @@ bool Cache<T, D>::addIfCacheHasCapacity(element_type&& obj) {
template <typename T, typename D>
typename Cache<T, D>::element_type Cache<T, D>::get() {
std::scoped_lock lock(mutex_);
if (cache_.empty())
if (cache_.empty()) {
return nullptr;
}

element_type obj = std::move(cache_.back());
cache_.pop_back();
Expand Down Expand Up @@ -92,7 +95,15 @@ class PerGpuCache {
std::vector<std::unique_ptr<Cache<T, D>>> cache_;
};

torch::DeviceIndex getNonNegativeDeviceIndex(const torch::Device& device) {
// Note: this function is inline for convenience, not performance. The rest of
// this file consists of templates, which must all be defined in this header.
// This function is not a template and should, in principle, be defined in a
// .cpp file so that including this header from multiple translation units does
// not violate the One Definition Rule. That's annoying for such a small amount
// of code, so we just mark it inline instead. If this file grows and gains more
// such functions, we should break them out into a .cpp file.
inline torch::DeviceIndex getNonNegativeDeviceIndex(
const torch::Device& device) {
torch::DeviceIndex deviceIndex = device.index();
// For single GPU machines libtorch returns -1 for the device index. So for
// that case we set the device index to 0. That's used in per-gpu cache
Expand Down
Loading