Skip to content

Commit

Permalink
brain floating point (#21)
Browse files Browse the repository at this point in the history
* brain floating point

* fix ref

* test and fix Exp

* fix NaN test

* 2^i for bfloat is too much(-126~127), removed
  • Loading branch information
dannypsnl committed Dec 4, 2020
1 parent 16bbe2f commit 505706a
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 0 deletions.
103 changes: 103 additions & 0 deletions bfloat/bfloat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package bfloat

import (
"fmt"
"math/big"
)

const (
// precision is the width in bits of the bfloat16 significand: the 7 stored
// fraction bits plus the implicit lead bit. Float.Big uses it as the
// precision of the big.Float values it produces.
precision = 8
// bias is the exponent bias. Float.Big subtracts it from the stored exponent
// field of a normalized number to obtain the actual power of two.
bias = 127
)

// Float is a floating-point number in bfloat16 floating-point format.
//
// The zero value has all bits clear, which Big decodes as positive zero.
type Float struct {
// bits holds the raw 16-bit encoding. From the most significant bit down:
//
// 1 bit: sign
// 8 bits: exponent
// 7 bits: fraction
bits uint16
}

// NewFromBits returns the Float whose raw bfloat16 encoding is bits. The value
// is stored as given; no validation or normalization is performed.
func NewFromBits(bits uint16) Float {
	var f Float
	f.bits = bits
	return f
}

// Big converts f to an arbitrary-precision float.
//
// The returned x has a precision of 8 bits (the bfloat16 significand width)
// and rounding mode big.ToNearestEven. nan reports whether f encodes a NaN;
// in that case x is a zero carrying the sign of f. Infinities and signed
// zeros are returned as the corresponding big.Float values with nan false.
//
// Big panics if the intermediate textual form cannot be parsed by big.Float.
func (f Float) Big() (x *big.Float, nan bool) {
	negative := f.Signbit()
	rawExp := f.Exp()
	mant := f.Frac()

	x = big.NewFloat(0)
	x.SetPrec(precision)
	x.SetMode(big.ToNearestEven)

	// ref: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Contrast_with_bfloat16_and_single_precision
	//
	// Encodings that do not go through the textual path below.
	switch {
	case rawExp == 0xFF && mant == 0:
		// Exponent all ones, fraction zero: +-Inf.
		x.SetInf(negative)
		return x, false
	case rawExp == 0xFF, rawExp == 0x00 && mant == 0:
		// Either +-NaN (exponent all ones, fraction non-zero) or +-Zero
		// (exponent and fraction both zero). Both yield a signed zero; only
		// the NaN flag differs.
		if negative {
			x.Neg(x)
		}
		return x, rawExp == 0xFF
	}

	// Exponent field 0b00000001 - 0b11111110: normalized number.
	//
	// (-1)^signbit * 2^(exp-127) * 1.mant_2
	hidden, scale := 1, rawExp-bias
	if rawExp == 0x00 {
		// Denormalized number (fraction is non-zero here).
		//
		// (-1)^signbit * 2^(-126) * 0.mant_2
		hidden, scale = 0, -126
	}

	// Build a binary literal understood by big.Float.Parse:
	//
	// number = [ sign ] [ prefix ] mantissa [ exponent ] | infinity .
	prefix := "+"
	if negative {
		prefix = "-"
	}
	literal := fmt.Sprintf("%s0b%d.%07bp%d", prefix, hidden, mant, scale)
	if _, _, err := x.Parse(literal, 0); err != nil {
		panic(err)
	}
	return x, false
}

// Signbit reports whether the sign bit of f is set, i.e. whether f is
// negative or negative zero.
func (f Float) Signbit() bool {
	// The sign occupies the most significant of the 16 bits.
	return f.bits>>15 == 1
}

// Exp returns the raw (still biased) 8-bit exponent field of f.
func (f Float) Exp() int {
	// Drop the 7 fraction bits, then keep the 8 exponent bits and discard the
	// sign bit above them.
	field := (f.bits >> 7) & 0xFF
	return int(field)
}

// Frac returns the 7-bit fraction field of f, without the implicit lead bit.
func (f Float) Frac() uint16 {
	// Mask selecting the low 7 bits: 0b0000000001111111.
	const fracMask = 1<<7 - 1
	return f.bits & fracMask
}
47 changes: 47 additions & 0 deletions bfloat/bfloat_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package bfloat

import (
"math"
"testing"
)

// TestNewFromBits decodes a table of raw bfloat16 bit patterns through
// NewFromBits and Float.Big and compares the result with the expected float64.
//
// Finite values, zeros and infinities are compared bit-for-bit, so that -0 is
// distinguished from +0. NaN cannot be represented by big.Float, so for those
// entries only the NaN flag returned by Big is checked. The flag is checked
// for every entry, so a spurious NaN report on an ordinary value also fails.
func TestNewFromBits(t *testing.T) {
	golden := []struct {
		bits uint16
		want float64
	}{
		// Special numbers.
		// 0 00000000 0000000 = 0
		{bits: 0, want: 0},
		// 1 00000000 0000000 = -0
		{bits: 0x8000, want: 1. / math.Inf(-1)},
		// 0 11111111 0000000 = +Inf
		{bits: 0x7f80, want: math.Inf(1)},
		// 1 11111111 0000000 = -Inf
		{bits: 0xff80, want: math.Inf(-1)},

		// 0 11111111 0000001 = +NaN
		{bits: 0x7f81, want: math.NaN()},
		// 1 11111111 0000001 = -NaN
		{bits: 0xff81, want: -math.NaN()},

		// from: https://en.wikipedia.org/wiki/Bfloat16_floating-point_format#Examples
		{bits: 0x3f80, want: 1},
		{bits: 0xc000, want: -2},
		{bits: 0x4049, want: 3.140625},
		{bits: 0x3eab, want: 0.333984375},
	}
	for _, g := range golden {
		f := NewFromBits(g.bits)
		b, isNan := f.Big()

		// The NaN flag must agree with the expectation in both directions;
		// previously a wrongly reported NaN made the entry pass trivially.
		if wantNan := math.IsNaN(g.want); isNan != wantNan {
			t.Errorf("0x%04X: NaN mismatch; expected nan=%v, got nan=%v", g.bits, wantNan, isNan)
			continue
		}
		if isNan {
			// Nothing further to compare: the big.Float returned for a NaN
			// is only a signed zero placeholder.
			continue
		}

		// The accuracy result is irrelevant here: every bfloat16 value is
		// exactly representable as a float64.
		got, _ := b.Float64()
		wantBits := math.Float64bits(g.want)
		gotBits := math.Float64bits(got)
		if wantBits != gotBits {
			t.Errorf("0x%04X: number mismatch; expected 0x%016X (%v), got 0x%016X (%v)", g.bits, wantBits, g.want, gotBits, got)
		}
	}
}

0 comments on commit 505706a

Please sign in to comment.