# gorgonia/gorgonia

Switch branches/tags
Nothing to show
Fetching contributors…
Cannot retrieve contributors at this time
292 lines (258 sloc) 5.64 KB
 package gorgonia import ( "fmt" "hash/fnv" "log" "math" "github.com/chewxy/math32" "github.com/pkg/errors" "gonum.org/v1/gonum/graph" "gorgonia.org/tensor" ) const ( maxFloat32 = math32.MaxFloat32 maxFloat64 = math.MaxFloat64 ) // NodesToValueGrads is a utility function that converts a Nodes to a slice of ValueGrad for the solvers func NodesToValueGrads(in Nodes) (out []ValueGrad) { out = make([]ValueGrad, len(in)) for i := range in { out[i] = in[i] } return out } func graphNodeToNode(in []graph.Node) (out Nodes) { out = make(Nodes, len(in)) for i, n := range in { out[i] = n.(*Node) // will panic if not. which is a good thng } return } func nodeToGraphNode(in []*Node) (out []graph.Node) { out = make([]graph.Node, len(in)) for i, n := range in { out[i] = n } return } func tensorInfo(t tensor.Tensor) (dt tensor.Dtype, dim int) { dt = t.Dtype() dim = t.Dims() return } func cloneNodes(node Nodes, replacements map[*Node]*Node) Nodes { return nil } // valuesToInts will FORCIBLY cast floats to ints. func valuesToInts(values []Value) (retVal []int, err error) { retVal = tensor.BorrowInts(len(values)) for i, v := range values { var intV int switch sv := v.(type) { case *F64: intV = int(float64(*sv)) case *F32: intV = int(float32(*sv)) case *I: intV = int(*sv) case *I32: intV = int(int32(*sv)) case *I64: intV = int(int64(*sv)) case *U8: intV = int(byte(*sv)) case Scalar: return nil, errors.Errorf(nyiTypeFail, "valueToInts", v) default: return nil, errors.Errorf("Expected values to be all Scalar Value. Got %v of %T instead", v, v) } retVal[i] = intV } return } func valuesToTensors(values []Value) (retVal []tensor.Tensor, err error) { retVal = make([]tensor.Tensor, len(values)) for i, v := range values { if vt, ok := v.(tensor.Tensor); ok { retVal[i] = vt continue } return nil, errors.Errorf("Expected values to all be tensor.Tensor. Got %v of %T in %dth index of the slice", v, v, i) } return } func intRange(start, end int) []int { size := end - start incr := true if start > end { incr = false size = start - end } if size < 0 { panic("Cannot create an int range that is somehow negative in size") } retVal := make([]int, size) for i, v := 0, start; i < size; i++ { retVal[i] = v if incr { v++ } else { v-- } } return retVal } func ones(dt tensor.Dtype, sizes ...int) (retVal Value) { if len(sizes) == 0 { return one(dt) } return tensor.Ones(dt, sizes...) } func hasInf(v Value, dev Device) bool { switch vt := v.(type) { case *F64: return false return math.IsInf(float64(*vt), 0) case *F32: return false return math32.IsInf(float32(*vt), 0) case tensor.Tensor: if e, ok := vt.Engine().(tensor.InfChecker); ok { ok, _ := e.HasInf(vt) // BUG: errors not checked return ok } dt := vt.Dtype() if dt != tensor.Float64 && dt != tensor.Float32 { return false } switch dt { case tensor.Float32: data := vt.Data().([]float32) for _, datum := range data { if math32.IsInf(datum, 0) { return true } } case tensor.Float64: data := vt.Data().([]float64) for _, datum := range data { if math.IsInf(datum, 0) { return true } } } return false case *dualValue: return hasInf(vt.Value, dev) || hasInf(vt.d, dev) default: err := nyi("hasInf", v) panic(err) } } func hasNaN(v Value, dev Device) bool { switch vt := v.(type) { case *F64: return false return math.IsNaN(float64(*vt)) case *F32: return false return math32.IsNaN(float32(*vt)) case tensor.Tensor: if e, ok := vt.Engine().(tensor.NaNChecker); ok { ok, _ := e.HasNaN(vt) // BUG: errors not checked return ok } log.Printf("Value's engine %T", vt.Engine()) dt := vt.Dtype() if dt != tensor.Float64 && dt != tensor.Float32 { return false } switch dt { case tensor.Float32: data := vt.Data().([]float32) for _, datum := range data { if math32.IsNaN(datum) { return true } } case tensor.Float64: data := vt.Data().([]float64) for _, datum := range data { if math.IsNaN(datum) { return true } } } return false case *dualValue: return hasNaN(vt.Value, dev) || hasNaN(vt.d, dev) default: err := nyi("hasNaN", vt) panic(err) } } func setZero(val Value) (retVal Value) { switch v := val.(type) { case Zeroer: v.Zero() return v case Scalar: return zero(v.Dtype()) default: panic(fmt.Sprintf("setZero not implemented yet for %T", v)) } } func checkArity(op arityer, inputs int) error { if inputs != op.Arity() && op.Arity() >= 0 { return errors.Errorf("%v has an arity of %d. Got %d instead", op, op.Arity(), inputs) } return nil } func maxInt(a, b int) int { if a > b { return a } return b } func minInt(a, b int) int { if a < b { return a } return b } func ceilDivInt(a, b int) int { return (a + b - 1) / b } func simpleHash(op hashWriter) uint32 { h := fnv.New32a() op.WriteHash(h) return h.Sum32() } func getDV(x, y *Node) (xdv, ydv *dualValue) { return x.boundTo.(*dualValue), y.boundTo.(*dualValue) } func getDV3(x, y, z *Node) (xdv, ydv, zdv *dualValue) { return x.boundTo.(*dualValue), y.boundTo.(*dualValue), z.boundTo.(*dualValue) } func getConst(x *Node, constant string) (retVal *Node, err error) { var dt tensor.Dtype if dt, err = dtypeOf(x.t); err != nil { return nil, errors.Wrap(err, dtypeOfFail) } if m, ok := constmap[constant]; ok { if n, ok := m[dt]; ok { return n, nil } } return nil, errors.Errorf("constant %v not provided for %v", constant, dt) } func scalarEquiv(s tensor.Shape) bool { if len(s) == 0 { return true } prod := 1 for _, v := range s { prod *= v } return prod == 1 }