forked from cms-sw/cmssw
-
Notifications
You must be signed in to change notification settings - Fork 0
/
libminifloat.h
168 lines (155 loc) · 7.02 KB
/
libminifloat.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#ifndef libminifloat_h
#define libminifloat_h
#include "FWCore/Utilities/interface/thread_safety_macros.h"
#include <cstdint>
#include <cassert>
#include <algorithm>
// Table-driven float16 <-> float32 conversion based on "Fast Half Float Conversions": ftp://ftp.fox-toolkit.org/pub/fasthalffloatconversion.pdf
class MiniFloatConverter {
public:
  /// Defined out of line (not visible in this header); presumably it fills the
  /// lookup tables via filltables() — NOTE(review): confirm in the .cc file.
  MiniFloatConverter() ;

  /// Convert an IEEE-754 binary16 bit pattern into a float.
  /// The result is assembled from the lookup tables, indexed by the sign+exponent
  /// bits (h >> 10) and the 10 mantissa bits (h & 0x3ff).
  inline static float float16to32(uint16_t h) {
    union { float flt; uint32_t i32; } conv;
    conv.i32 = mantissatable[offsettable[h>>10]+(h&0x3ff)]+exponenttable[h>>10];
    return conv.flt;
  }

  /// Convert a float into a binary16 bit pattern; same as float32to16round().
  inline static uint16_t float32to16(float x) {
    return float32to16round(x);
  }

  /// Fast implementation, but it crops (truncates) the dropped mantissa bits,
  /// so results are biased toward zero.
  inline static uint16_t float32to16crop(float x) {
    union { float flt; uint32_t i32; } conv;
    conv.flt = x;
    const uint32_t index = (conv.i32>>23)&0x1ff; // sign bit + 8 exponent bits
    return basetable[index]+((conv.i32&0x007fffff)>>shifttable[index]);
  }

  /// Slower implementation, but it rounds to avoid biases.
  /// Rounding (half up on the magnitude) is applied only when the table shift for
  /// this input is 13, i.e. when 13 of the 23 mantissa bits are dropped; every
  /// other input is cropped exactly like float32to16crop(). The rounded 10-bit
  /// mantissa is clamped at 1023 so it never carries into the exponent.
  inline static uint16_t float32to16round(float x) {
    union { float flt; uint32_t i32; } conv;
    conv.flt = x;
    const uint32_t index = (conv.i32>>23)&0x1ff; // sign bit + 8 exponent bits
    const uint8_t shift = shifttable[index];
    if (shift == 13) {
      // keep 11 bits: the 10 result bits plus the first dropped bit (round bit)
      const uint16_t base2 = (conv.i32&0x007fffff)>>12;
      uint16_t base = base2 >> 1;
      if (((base2 & 1) != 0) && (base < 1023)) base++;
      return basetable[index]+base;
    }
    return basetable[index]+((conv.i32&0x007fffff)>>shift);
  }

  /// Truncate the 23-bit float mantissa to its 'bits' most significant bits
  /// (lower mantissa bits are zeroed; sign and exponent are untouched).
  template<int bits>
  inline static float reduceMantissaToNbits(const float &f)
  {
    static_assert(bits <= 23,"max mantissa size is 23 bits");
    static_assert(bits >= 0,"mantissa size cannot be negative");
    constexpr uint32_t mask = (0xFFFFFFFF >> (23-bits)) << (23-bits);
    union { float flt; uint32_t i32; } conv;
    conv.flt=f;
    conv.i32&=mask;
    return conv.flt;
  }

  /// Runtime-'bits' version of reduceMantissaToNbits<bits>(); requires 0 <= bits <= 23.
  inline static float reduceMantissaToNbits(const float &f, int bits)
  {
    assert(bits >= 0 && bits <= 23); // "max mantissa size is 23 bits"
    uint32_t mask = (0xFFFFFFFF >> (23-bits)) << (23-bits);
    union { float flt; uint32_t i32; } conv;
    conv.flt=f;
    conv.i32&=mask;
    return conv.flt;
  }

  /// Reduce the 23-bit float mantissa to 'bits' bits, rounding half up on the
  /// magnitude. If the kept mantissa is already all ones it is left as is
  /// (truncated) so the rounding never carries into the exponent.
  /// bits == 23 returns f unchanged.
  template<int bits>
  inline static float reduceMantissaToNbitsRounding(const float &f)
  {
    static_assert(bits <= 23,"max mantissa size is 23 bits");
    static_assert(bits >= 0,"mantissa size cannot be negative");
    constexpr int shift = (23-bits); // bits I throw away
    constexpr uint32_t mask = (0xFFFFFFFF >> (shift)) << (shift); // mask for truncation
    // most significant bit I throw away; 0 when shift == 0 (nothing dropped).
    // Written this way because 1 << (shift-1) is a negative shift for bits == 23.
    constexpr uint32_t test = (1u << shift) >> 1;
    constexpr uint32_t low23 = (0x007FFFFF); // mask to keep lowest 23 bits = mantissa
    constexpr uint32_t hi9 = (0xFF800000); // mask to keep highest 9 bits = the rest
    // largest value the kept 'bits'-wide mantissa can hold: anything below it can be
    // incremented without overflowing into the exponent
    constexpr uint32_t maxn = (1u << bits) - 1u;
    union { float flt; uint32_t i32; } conv;
    conv.flt=f;
    if (conv.i32 & test) { // need to round
      uint32_t mantissa = (conv.i32 & low23) >> shift;
      if (mantissa < maxn) mantissa++;
      conv.i32 = (conv.i32 & hi9) | (mantissa << shift);
    } else {
      conv.i32 &= mask;
    }
    return conv.flt;
  }

  /// Function object doing the same rounding as reduceMantissaToNbitsRounding<bits>()
  /// with 'bits' chosen at run time (0 <= bits <= 23, checked by assert).
  class ReduceMantissaToNbitsRounding {
  public:
    ReduceMantissaToNbitsRounding(int bits) :
      shift(checkedShift(bits)), mask((0xFFFFFFFF >> (shift)) << (shift)),
      test((1u << shift) >> 1), maxn((1u << bits) - 1u) {
    }

    float operator()(float f) const {
      constexpr uint32_t low23 = (0x007FFFFF); // mask to keep lowest 23 bits = mantissa
      constexpr uint32_t hi9 = (0xFF800000); // mask to keep highest 9 bits = the rest
      union { float flt; uint32_t i32; } conv;
      conv.flt=f;
      if (conv.i32 & test) { // need to round
        uint32_t mantissa = (conv.i32 & low23) >> shift;
        if (mantissa < maxn) mantissa++;
        conv.i32 = (conv.i32 & hi9) | (mantissa << shift);
      } else {
        conv.i32 &= mask;
      }
      return conv.flt;
    }

  private:
    // Validate 'bits' before any member initializer shifts by (23 - bits).
    static int checkedShift(int bits) {
      assert(bits >= 0 && bits <= 23); // "max mantissa size is 23 bits"
      return 23 - bits;
    }

    const int shift;      // number of mantissa bits thrown away
    const uint32_t mask;  // truncation mask
    const uint32_t test;  // most significant dropped bit (0 if none dropped)
    const uint32_t maxn;  // largest kept mantissa value
  };

  /// Runtime-'bits' version of reduceMantissaToNbitsRounding<bits>().
  inline static float reduceMantissaToNbitsRounding(float f, int bits)
  {
    return ReduceMantissaToNbitsRounding(bits)(f);
  }

  /// Apply reduceMantissaToNbitsRounding(·, bits) to [begin, end), writing to out.
  template<typename InItr, typename OutItr>
  static void reduceMantissaToNbitsRounding(int bits, InItr begin, InItr end, OutItr out)
  {
    std::transform(begin, end, out, ReduceMantissaToNbitsRounding(bits));
  }

  /// Largest finite float16 value (bit pattern 0x477fe000 as float32, i.e. 65504).
  inline static float max() {
    union { float flt; uint32_t i32; } conv;
    conv.i32 = 0x477fe000; // = mantissatable[offsettable[0x1e]+0x3ff]+exponenttable[0x1e]
    return conv.flt;
  }

  // Maximum float32 value that gets rounded to max()
  inline static float max32RoundedToMax16() {
    union { float flt; uint32_t i32; } conv;
    // 2^16 in float32 is the first to result inf in float16, so the largest
    // float32 strictly below 2^16 (bit pattern of 2^16 minus one ulp) is the
    // last float32 to result max() in float16
    conv.i32 = (0x8f<<23) - 1;
    return conv.flt;
  }

  /// Smallest positive normal float16 value, 2^-14 (bit pattern 0x38800000 as float32).
  inline static float min() {
    union { float flt; uint32_t i32; } conv;
    conv.i32 = 0x38800000; // = mantissatable[offsettable[1]+0]+exponenttable[1]
    return conv.flt;
  }

  // Minimum float32 value that gets rounded to min()
  inline static float min32RoundedToMin16() {
    union { float flt; uint32_t i32; } conv;
    // float32 values just below 2^-14 already map to float16 denormals, so
    // 2^-14 itself is the first float32 to result min() in float16
    conv.i32 = (0x71<<23);
    return conv.flt;
  }

  /// Smallest positive denormal float16 value, 2^-24 (bit pattern 0x33800000 as float32).
  inline static float denorm_min() {
    union { float flt; uint32_t i32; } conv;
    conv.i32 = 0x33800000; // mantissatable[offsettable[0]+1]+exponenttable[0]
    return conv.flt;
  }

  /// True if the float16 bit pattern h is a denormal number.
  inline static bool isdenorm(uint16_t h) {
    // if exponent is zero (sign-bit excluded of course) and mantissa is not zero
    return ((h >> 10) & 0x1f) == 0 && (h & 0x3ff) != 0;
  }

private:
  // Lookup tables; filled by filltables() (defined out of line, not visible here).
  CMS_THREAD_SAFE static uint32_t mantissatable[2048];
  CMS_THREAD_SAFE static uint32_t exponenttable[64];
  CMS_THREAD_SAFE static uint16_t offsettable[64];
  CMS_THREAD_SAFE static uint16_t basetable[512];
  CMS_THREAD_SAFE static uint8_t shifttable[512];
  static void filltables() ;
};
#endif