
Merge pull request #53 from gomlx/xlago
Fixes to xla package and variable checks when loading.
janpfeifer committed Jun 12, 2024
2 parents 7c3243f + 33440fd commit e56d203
Showing 9 changed files with 113 additions and 206 deletions.
2 changes: 1 addition & 1 deletion docker/jupyterlab/Dockerfile
@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.3.2-cudnn9-runtime-ubuntu22.04

LABEL maintainer="pfeifer@gmail.com"

ARG GO_VERSION="1.22.2"
ARG GO_VERSION="1.22.4"

# Pre-built GoMLX C library: if not building yourself, you can download it from
# https://github.com/gomlx/gomlx/releases/download/${GOMLX_VERSION}/gomlx_xla-linux-amd64.tar.gz
13 changes: 9 additions & 4 deletions docs/CHANGELOG.md
@@ -1,6 +1,6 @@
# GoMLX changelog

## Next
## 0.10.1 - 2024/06/12

* `types.shapes` package:
* **Added support for `Float16` training -- tested with GNNs.**
@@ -11,22 +11,27 @@
* Added support for `Int8`, `Int16`, `Uint8` and `Uint16`.
* Renamed `UInt{X}` to `Uint{X}` and added a deprecated alias to the old form (so it still compiles).
* Added logging of the time to build and compile graphs. The last version improved execution time considerably but slowed down compilation.
* Fixed `Variable.SetValueGraph` when the shape changes. Improved some documentation.
* Context.Variable:
* Fixed `Variable.SetValueGraph` when the shape changes. Improved some documentation.
* Fixed `Variable.SetValuePreservingOld` when shapes change.
* Fixed the check on loaded variables, ensuring they are not treated as newly created.
* Package `optimizers`:
* Fixed the optimizer constructor `FromContext` to allow further configuration of the optimizer by setting other hyperparameters in the context.
* Added the hyperparameter `clip_step_by_value`, a clip-by-value applied to gradient updates.
* `Adam` optimizer: added support for the `"clip_step_by_value"`, `"adam_epsilon"` and `"adam_dtype"` hyperparameters.
* **`MustOptimizerByName` now also takes the context, from which the optimizer hyperparameters are read** -- this breaks the API (see the sketch at the end of this list).
* Package `checkpoints`:
* Allow adding variables to exclude from saving after checkpoint is created -- for newly created variables.
* Allow adding variables to exclude from saving after checkpoint is created -- for newly created variables
* Added `slices.CloseToEpsilon` to easily customize tests.
* Fixed `Variable.SetValuePreservingOld` when shapes change.
* `Scatter` doesn't assume indices are sorted or unique.
* Plotly training plots: added `WithCustomMetricFn` for custom metrics and `ScheduleEveryNSteps`.
* Added OGBN_MAG GNN example:
* Including Layer-Wise Inference.
* Package graph:
* Added `Shift`, `ShiftLeft`, `ShiftRight`, `ShiftWithScalar`, `ShiftWithValue`.
* Added a dummy package for the xla.AOT and xla.StableHLO APIs, enabled with the "google3" build tag: this allows the dependency on the corresponding C++ code to be dropped. (Thanks @tdegris.)
* Removed xla.AOTExecute: see issue #52.
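A hedged sketch of how the optimizer-related entries above fit together (not part of this commit): hyperparameters are stored in the context and the optimizer is then built by name from it. `ctx.SetParam` as the hyperparameter setter, the import paths, and the `(ctx, name)` argument order of `MustOptimizerByName` are assumptions inferred from the changelog entries, not confirmed API.

```go
package sketch

import (
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/ml/train/optimizers"
)

// configureOptimizer stores optimizer hyperparameters in the context and then
// builds the optimizer by name. ctx.SetParam and the (ctx, name) signature of
// MustOptimizerByName are assumptions based on the changelog, not verified API.
func configureOptimizer(ctx *context.Context) {
	ctx.SetParam("clip_step_by_value", 0.1) // clip each gradient update to [-0.1, 0.1]
	ctx.SetParam("adam_epsilon", 1e-7)

	// 0.10.1 change: the context is passed so the constructor can read the
	// hyperparameters set above.
	_ = optimizers.MustOptimizerByName(ctx, "adam")
}
```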

## 0.9.1 - 2024/04/19

97 changes: 0 additions & 97 deletions ml/aot/aot_test.go

This file was deleted.

63 changes: 27 additions & 36 deletions ml/context/context.go
@@ -72,6 +72,17 @@ import (
//
// Finally, Context also allows one to checkpoint the variable values (save and load). See checkpoint package.
//
// Variable duplicate creation checking:
// the context is by default configured with Context.Checked(true), which checks at every variable creation whether
// the variable already exists. This is useful to prevent unintended reuse of variables. When checked, variable creation
// (with Context.VariableWithShape and Context.VariableWithValue) will panic if:
//
// - Context.Unique() (the default) and variable already exists (or was loaded);
// - Context.Reuse() and variable didn't exist (or was not loaded);
//
// Remember to set Context.Reuse if you expect to load the variables, or disable checking with Context.Checked(false)
// if only some variables are going to be loaded (a usage sketch follows this hunk).
//
// TODO: Handling of devices with multiple instances (e.g.: multiple GPUs/TPUs).
type Context struct {
// scope for currently created variables and registration.
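A minimal usage sketch of the three modes described in the comment above (not part of this commit). The method names -- Checked, Reuse, VariableWithShape, VariableWithValue -- come from the diff; the import paths and the assumption that Checked and Reuse return the derived *Context are taken from the test change below and may differ slightly from the actual API.

```go
package sketch

import (
	"github.com/gomlx/gomlx/ml/context"
	"github.com/gomlx/gomlx/types/shapes"
)

// setupVariables sketches how the duplicate-creation checks behave.
func setupVariables(ctx *context.Context) {
	// Default mode: Checked(true) + Unique(). Creating "w" a second time in
	// the same scope would panic.
	w := ctx.VariableWithShape("w", shapes.Make(shapes.Float32, 4, 4))
	_ = w

	// If all variables are expected to come from a checkpoint loader, switch
	// to Reuse: creating a variable that was NOT loaded then panics, which
	// catches missing checkpoint entries early.
	reusing := ctx.Reuse()
	_ = reusing // e.g. reusing.VariableWithValue("y", 1) panics unless "y" was loaded.

	// If only some variables are loaded, relax the check instead.
	relaxed := ctx.Checked(false)
	_ = relaxed.VariableWithShape("z", shapes.Make(shapes.Float32))
}
```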
@@ -554,12 +565,11 @@ func (ctx *Context) ExecSetVariablesInParams(params graph.ParamsMap, g *Graph) {
// The root scope is "/" (RootScope).
func (ctx *Context) InspectVariable(scope, name string) *Variable {
scopeVars, ok := ctx.data.variablesMap[scope]
if !ok {
return nil
}
v, found := scopeVars[name]
if found {
return v
if ok {
v, found := scopeVars[name]
if found {
return v
}
}

// Try to load it, if a loader (checkpoint handler) is configured.
@@ -571,7 +581,7 @@ func (ctx *Context) InspectVariable(scope, name string) *Variable {
if !found {
return nil
}
v = &Variable{
v := &Variable{
ctx: ctx,
name: name,
scope: scope,
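For reference, a short sketch of the lookup path the refactor above preserves (illustrative, not part of the diff): InspectVariable first checks the in-memory scope map and only then falls back to the configured checkpoint loader, returning nil if neither has the variable. The scope string "/" is the root scope (RootScope), as noted in the comment above.

```go
package sketch

import "github.com/gomlx/gomlx/ml/context"

// findVariable returns the variable "x" in the root scope, or nil if it was
// neither created in this context nor found by an attached checkpoint loader.
func findVariable(ctx *context.Context) *context.Variable {
	return ctx.InspectVariable("/", "x") // "/" is the root scope (RootScope).
}
```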
@@ -677,6 +687,11 @@ func (ctx *Context) DeleteVariablesInScope() {
//
// Notice that variables information is stored in the "data" component of Context objects, and is shared
// among all connected context references.
//
// If Context is set with Context.Checked(true), this may panic if:
//
// - Context.Unique() and variable already exists (or was loaded);
// - Context.Reuse() and variable didn't exist (or was not loaded);
func (ctx *Context) VariableWithShape(name string, shape shapes.Shape) *Variable {
v := ctx.InspectVariable(ctx.scope, name)
if v == nil && ctx.checked && ctx.reuse {
@@ -705,36 +720,12 @@ func (ctx *Context) VariableWithShape(name string, shape shapes.Shape) *Variable
}
ctx.setVariableInScope(name, v)

// Try to load the variable. Report if something failed.
if ctx.tryToLoad(v) {
return v
}

// Set up variable for initialization.
v.initializer = ctx.initializer
ctx.data.needsInitialization = true
return v
}

// tryToLoad tries to load the variable from the loader. It returns true if it succeeded.
func (ctx *Context) tryToLoad(v *Variable) bool {
loader := ctx.data.loader
if loader == nil {
return false
}
value, found := loader.LoadVariable(ctx, v.Scope(), v.Name())
if found {
if value.Shape().Eq(v.shape) {
v.value = value
} else {
Panicf("loading of variable %q returned shape %s, but variable was created "+
"with shape %s -- did some hyperparameter change since variable was saved that changed "+
"the variable shape?", v.ParameterName(), value.Shape(), v.shape)
}
}
return found
}

func valueToTensor(value any) tensor.Tensor {
if tensorValue, ok := value.(tensor.Tensor); ok {
return tensorValue
@@ -757,6 +748,11 @@ func valueToTensor(value any) tensor.Tensor {
//
// Notice that variables' information is stored in the "data" component of Context objects, and is shared
// among all connected context references.
//
// If Context is set with Context.Checked(true), this may panic if:
//
// - Context.Unique() and variable already exists (or was loaded);
// - Context.Reuse() and variable didn't exist (or was not loaded);
func (ctx *Context) VariableWithValue(name string, value any) *Variable {
v := ctx.InspectVariable(ctx.scope, name)

@@ -794,11 +790,6 @@ func (ctx *Context) VariableWithValue(name string, value any) *Variable {
graphToNodes: make(map[graph.GraphId]*variableNodes),
}
ctx.setVariableInScope(name, v)

// Try to load the variable. Report if something failed.
if ctx.tryToLoad(v) {
return v
}
return v
}

1 change: 1 addition & 0 deletions ml/context/context_test.go
@@ -222,6 +222,7 @@ func TestContext_SetLoader(t *testing.T) {
},
},
})
ctx = ctx.Reuse()
e := NewExec(manager, ctx, func(ctx *Context, g *Graph) (*Node, *Node) {
v0 := ctx.WithInitializer(initializers.Zero).VariableWithShape("x", shapes.Make(shapes.Float32))
v1 := ctx.VariableWithValue("y", 1)
8 changes: 6 additions & 2 deletions types/tensor/tensor_test.go
@@ -80,10 +80,14 @@ func TestFromValue(t *testing.T) {
shape, err = shapeForValue([]complex128{1.0i, 1.0})
cmpShapes(t, shape, wantShape, err)

// Test for invalid `DType`.
wantShape = shapes.Shape{DType: shapes.Uint16, Dimensions: []int{1, 1}}
shape, err = shapeForValue([][]uint16{{3}})
cmpShapes(t, shape, wantShape, err)

// Test for invalid `DType`.
shape, err = shapeForValue([][]string{{"blah"}})
if shape.DType != shapes.InvalidDType {
t.Fatalf("Wanted InvalidDType for uint16, instead got %q", shape.DType)
t.Fatalf("Wanted InvalidDType for string, instead got %q", shape.DType)
}
if err == nil {
t.Fatalf("Should have returned error for unsupported DType")
66 changes: 0 additions & 66 deletions xla/aotexec.go

This file was deleted.

2 changes: 2 additions & 0 deletions xla/stablehlo.go
@@ -1,3 +1,5 @@
//go:build !google3

/*
* Copyright 2023 Jan Pfeifer
*
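The `//go:build !google3` constraint added above is what keeps this file out of "google3" builds; the changelog pairs it with a dummy package that drops the C++ dependency. A minimal illustration of the pattern follows -- the file contents and the stub function are hypothetical, not the actual repo code.

```go
//go:build google3

// Package xla stub selected by the "google3" build tag: it stands in for the
// files that need the corresponding C++ code, so that dependency can be
// dropped. The declaration below is hypothetical; the real dummy package in
// the repo defines its own stubs.
package xla

// StableHLOStub is a hypothetical placeholder for the StableHLO entry points.
func StableHLOStub() {}
```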
