-
Notifications
You must be signed in to change notification settings - Fork 228
/
device.h
96 lines (73 loc) · 2.08 KB
/
device.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#pragma once
#include <cmath>
#include <cstdint>
#include "common/definitions.h"
namespace marian {
// Abstract base for a memory-holding device. Stores the device id, the
// current buffer pointer/size, and the alignment used to round requested
// sizes. Concrete backends implement reserve().
class Device {
protected:
  DeviceId deviceId_;
  uint8_t* data_{nullptr};  // current buffer; null until a subclass sets it
  size_t size_{0};          // size of the current buffer in bytes
  size_t alignment_;        // rounding granularity used by align()

  // Rounds `size` up to the next multiple of alignment_.
  //
  // Uses exact integer arithmetic. The previous float-based formula
  // (ceil(size / (float)alignment_) * alignment_) lost precision for sizes
  // above 2^24 (float has a 24-bit significand) and could return a value
  // SMALLER than `size`, e.g. 16777217 -> 16777216 with alignment 256.
  //
  // An alignment of 0 is treated as "no rounding" (the old code hit a
  // float->integer conversion of NaN/inf, which is undefined behavior).
  // NOTE(review): if the rounded result exceeds SIZE_MAX it wraps; callers
  // are assumed never to request sizes that close to the limit.
  size_t align(size_t size) {
    if(alignment_ == 0)
      return size;
    const size_t remainder = size % alignment_;
    // Already-aligned sizes are returned as-is, which also avoids any
    // overflow from adding padding to a size that needs none.
    return remainder == 0 ? size : size + (alignment_ - remainder);
  }

public:
  // @param deviceId  identifier of the device this object represents
  // @param alignment granularity used by align(); defaults to 256
  Device(DeviceId deviceId, size_t alignment = 256)
      : deviceId_(deviceId), alignment_(alignment) {}

  virtual ~Device() = default;

  // Ensures at least `size` bytes are available; backend-specific.
  virtual void reserve(size_t size) = 0;

  // Current buffer pointer (null until reserved or otherwise set).
  virtual uint8_t* data() { return data_; }

  // Current buffer size in bytes (0 until reserved or otherwise set).
  virtual size_t size() { return size_; }

  virtual DeviceId getDeviceId() { return deviceId_; }
};
namespace gpu {
class Device : public marian::Device {
public:
Device(DeviceId deviceId, size_t alignment = 256)
: marian::Device(deviceId, alignment) {}
~Device();
void reserve(size_t size) override;
};
} // namespace gpu
namespace cpu {
class Device : public marian::Device {
public:
Device(DeviceId deviceId, size_t alignment = 256)
: marian::Device(deviceId, alignment) {}
~Device();
void reserve(size_t size) override;
};
class WrappedDevice : public marian::Device {
public:
WrappedDevice(DeviceId deviceId, size_t alignment = 256)
: marian::Device(deviceId, alignment) {}
~WrappedDevice() {}
void set(uint8_t* data, size_t size) {
marian::Device::data_ = data;
marian::Device::size_ = size;
}
// doesn't allocate anything, just checks size.
void reserve(size_t size) override {
ABORT_IF(size > size_,
"Requested size {} is larger than pre-allocated size {}",
size,
size_);
}
};
} // namespace cpu
// Creates the concrete Device matching deviceId.type: a cpu::Device for any
// non-GPU id, a gpu::Device for GPU ids when built with CUDA_FOUND, and an
// abort otherwise.
static inline Ptr<Device> DispatchDevice(DeviceId deviceId,
                                         size_t alignment = 256) {
  if(deviceId.type != DeviceType::gpu)
    return New<cpu::Device>(deviceId, alignment);

#ifdef CUDA_FOUND
  return New<gpu::Device>(deviceId, alignment);
#else
  ABORT("CUDA support not compiled into marian");
#endif
}
} // namespace marian