# UCI Adult Dataset or Census Income

This is a very popular ML task on tabular data: the UCI Adult dataset, also known as the "Census Income" dataset. The objective is to predict whether income exceeds $50K/yr based on census data.

The data is old and biased in several ways, but it can still be used as an opaque benchmark for ML experimentation.



## Environment Set Up

Let's set up `go.mod` to use the local copy of GoMLX, so that the dataset code can be developed jointly with the model. That's often how data pre-processing and model code are developed: together, during experimentation.

If you are not changing GoMLX code, feel free to simply skip this cell. If you use a different directory for your projects, change it below.

Notice that `${HOME}/Projects/gomlx` is where the GoMLX code is copied by default in [its Docker image](https://hub.docker.com/repository/docker/janpfeifer/gomlx_jupyterlab/general).

In [1]:
!*go mod edit -replace github.com/gomlx/gomlx="${HOME}/Projects/gomlx"

## Data Preparation

GoMLX provides [a simple `adult` library](https://pkg.go.dev/github.com/gomlx/gomlx/examples/adult) to facilitate downloading and preprocessing the data. The data is available in the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Adult).

After downloading the data (both training and test sets) and validating their checksums, it generates the quantiles for the continuous features and the vocabularies for the categorical features. It saves all this information in a binary file, so these steps are skipped on later runs.
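The two preprocessing outputs can be pictured with the plain-Go sketch below. It is not the `adult` package's code: `buildVocabulary` and `uniqueQuantiles` are hypothetical helpers, the ages are toy values, and the quantile rule (evenly spaced picks over the sorted unique values, capped at `--quantiles`) is an assumption consistent with that flag's description.

```go
package main

import (
	"fmt"
	"sort"
)

// buildVocabulary maps each distinct categorical value to a dense integer index,
// in order of first appearance. Hypothetical helper, for illustration only.
func buildVocabulary(values []string) map[string]int {
	vocab := make(map[string]int)
	for _, v := range values {
		if _, found := vocab[v]; !found {
			vocab[v] = len(vocab)
		}
	}
	return vocab
}

// uniqueQuantiles returns at most maxQuantiles sorted, distinct keypoints for a
// continuous feature. With fewer unique values than maxQuantiles, it returns all
// of them. Hypothetical helper: the adult package's exact rule may differ.
func uniqueQuantiles(values []float32, maxQuantiles int) []float32 {
	if maxQuantiles < 2 {
		maxQuantiles = 2 // Need at least the minimum and the maximum.
	}
	// Sort a copy and drop duplicates.
	sorted := append([]float32(nil), values...)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })
	var unique []float32
	for i, v := range sorted {
		if i == 0 || v != sorted[i-1] {
			unique = append(unique, v)
		}
	}
	if len(unique) <= maxQuantiles {
		return unique
	}
	// Evenly spaced positions over the unique values, including min and max.
	quantiles := make([]float32, maxQuantiles)
	for i := range quantiles {
		quantiles[i] = unique[i*(len(unique)-1)/(maxQuantiles-1)]
	}
	return quantiles
}

func main() {
	vocab := buildVocabulary([]string{"Private", "State-gov", "Private", "Self-emp-not-inc"})
	fmt.Println(len(vocab), vocab["State-gov"]) // 3 1
	ages := []float32{39, 50, 38, 53, 28, 37, 49, 52, 31, 42, 37, 30}
	fmt.Println(uniqueQuantiles(ages, 4)) // [28 37 42 53]
}
```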

The quantiles are used to calibrate the continuous values with a piece-wise linear calibration, which works very well for this kind of feature. See the [`layers.PieceWiseLinearCalibration` documentation](https://pkg.go.dev/github.com/gomlx/gomlx@v0.1.0/ml/layers#PieceWiseLinearCalibration).
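To illustrate what the calibration computes, the sketch below maps a value through keypoints (the quantiles) to output values by linear interpolation, clamping values outside the keypoint range. It is a simplified assumption, not GoMLX's implementation: `calibrate` is a hypothetical name, and the outputs here are fixed numbers, whereas the notebook's `--trainable_calibration` flag lets the layer adjust its outputs during training.

```go
package main

import (
	"fmt"
	"sort"
)

// calibrate maps x through the piece-wise linear function that passes through
// (keypoints[i], outputs[i]). keypoints must be sorted, distinct and have at
// least 2 entries; outputs must have the same length. Inputs outside the
// keypoint range are clamped to the first/last output.
func calibrate(x float64, keypoints, outputs []float64) float64 {
	last := len(keypoints) - 1
	if x <= keypoints[0] {
		return outputs[0]
	}
	if x >= keypoints[last] {
		return outputs[last]
	}
	// Smallest i with keypoints[i] >= x; here 1 <= i <= last,
	// so x lies in (keypoints[i-1], keypoints[i]].
	i := sort.SearchFloat64s(keypoints, x)
	t := (x - keypoints[i-1]) / (keypoints[i] - keypoints[i-1])
	return outputs[i-1] + t*(outputs[i]-outputs[i-1])
}

func main() {
	keypoints := []float64{0, 10, 20, 40} // E.g. quantiles of one feature.
	outputs := []float64{0, 0.5, 0.75, 1} // Fixed here; adjustable in a model.
	for _, x := range []float64{-7, 5, 15, 30, 99} {
		fmt.Printf("calibrate(%g) = %g\n", x, calibrate(x, keypoints, outputs))
	}
}
```

Using quantiles as keypoints places more segments where the data is dense, so the calibration has the most resolution where most values fall.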

We create a flag `--data` to define the directory where the intermediary files are saved: the downloaded and the preprocessed datasets.
In this example we set it to `~/work/uci-adult`. Verbosity can be controlled with the `--verbosity` flag.

We set defaults in Go for these flags, but they can easily be overridden for a new run by passing them after the `%%` meta-command of the Jupyter kernel ([GoNB](https://github.com/janpfeifer/gonb)), which indicates that the subsequent lines should be placed into a `func main`.
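Conceptually, the arguments after `%%` become the program's command line and are parsed by Go's standard `flag` package, so any flag not mentioned keeps its Go default. The sketch below reproduces that with a dedicated `flag.FlagSet` and a hypothetical `parseArgs` helper; that GoNB's generated `main` parses the global flag set in the same way is an assumption about its internals.

```go
package main

import (
	"flag"
	"fmt"
)

// parseArgs parses args with a dedicated FlagSet, mirroring two of the notebook's flags.
func parseArgs(args []string) (dataDir string, verbosity int, err error) {
	fs := flag.NewFlagSet("adult", flag.ContinueOnError)
	dataDirFlag := fs.String("data", "~/work/uci-adult", "Directory to save and load dataset files.")
	verbosityFlag := fs.Int("verbosity", 0, "Level of verbosity.")
	err = fs.Parse(args)
	return *dataDirFlag, *verbosityFlag, err
}

func main() {
	// Roughly what `%% --verbosity=2` does: only --verbosity changes.
	dataDir, verbosity, err := parseArgs([]string{"--verbosity=2"})
	if err != nil {
		panic(err)
	}
	fmt.Println(dataDir, verbosity) // ~/work/uci-adult 2
}
```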


In [2]:
import (
    "flag"
    
    "github.com/gomlx/gomlx/examples/adult"
)

var (
    flagDataDir       = flag.String("data", "~/work/uci-adult", "Directory to save and load downloaded and generated dataset files.")
    flagVerbosity     = flag.Int("verbosity", 0, "Level of verbosity, the higher the more verbose.")
    flagForceDownload = flag.Bool("force_download", false, "Force re-download of Adult dataset files.")
    flagNumQuantiles  = flag.Int("quantiles", 100, "Max number of quantiles to use for numeric features, used during piece-wise linear calibration. It will only use unique values, so if there is less variability, fewer quantiles are used.")
)

%% --verbosity=2
adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)


Sample Categorical: (24.08% positive ratio, 23.86% weighted positive ratio)
	Row 0:	[7 10 5 1 2 5 2 39]
	Row 1:	[6 10 3 4 1 5 2 39]
	Row 2:	[4 12 1 6 2 5 2 39]
	...
	Row 32558:	[4 12 7 1 5 5 1 39]
	Row 32559:	[4 12 5 1 4 5 2 39]
	Row 32560:	[5 12 3 4 6 5 1 39]

Sample Continuous:
	Row 0:	[39 13 2174 0 40]
	Row 1:	[50 13 0 0 13]
	Row 2:	[38 9 0 0 40]
	...
	Row 32558:	[58 9 0 0 40]
	Row 32559:	[22 9 0 0 20]
	Row 32560:	[52 9 15024 0 40]


In [3]:
!ls -lh ~/work/uci-adult

total 7.0M
-rw-r--r-- 1 janpf janpf 3.8M Mar 21  2023 adult.data
-rw-r--r-- 1 janpf janpf 1.3M Mar 21  2023 adult_data-100_quantiles.bin
-rw-r--r-- 1 janpf janpf 2.0M Mar 21  2023 adult.test
drwxr-x--- 2 janpf janpf 4.0K Mar 27 09:26 base_model


### Creating Datasets

First we create GoMLX's `Manager`: it's the object that manages the underlying XLA
setup, connection and execution. It's needed to create tensors.

With it we create the data samplers used for training and evaluation. They implement
GoMLX's `train.Dataset` interface, which our training loop uses to draw batches for
training and our evaluation loop uses to draw batches for evaluation.

The inputs are 3 tensors: *categorical values*, *continuous values* and *weights*.

In the cell below we define the `--batch` flag, the global `manager`, the `BuildDatasets` and `PositiveRatio` functions, and print out a sample batch along with the label distributions.

In [4]:
import (
    "flag"
    "fmt"
    "io"

    . "github.com/gomlx/gomlx/graph"
    "github.com/gomlx/gomlx/examples/adult"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/types/tensor"
)

var (
    flagBatchSize      = flag.Int("batch", 128, "Batch size used for training and evaluation.")
)

// Global manager, created at initialization and used everywhere.
var manager = NewManager()

// BuildDatasets returns 3 `train.Dataset`:
// * trainingSampler is an endless random sampler used for training.
// * trainingEvalSampler samples through exactly one epoch of the train dataset.
// * testEvalSampler samples through exactly one epoch of the test dataset.
func BuildDatasets(manager *Manager) (trainDS, trainEvalDS, testEvalDS train.Dataset) {
    baseDS := adult.NewDataset(manager, adult.Data.Train, "batched train")
    trainEvalDS = baseDS.Copy().BatchSize(*flagBatchSize, false)
    testEvalDS = adult.NewDataset(manager, adult.Data.Test, "test").
        BatchSize(*flagBatchSize, false)
    // For training, we shuffle and loop indefinitely.
    trainDS = baseDS.BatchSize(*flagBatchSize, true).Shuffle().Infinite(true)
    return
}

// PositiveRatio finds the ratio of positive labels in the given dataset
// (e.g. training or testing data), going through it exactly once.
//
// We could do this easily with the GoMLX computation model (just `ReduceAllSum`), but
// this example shows it's also fine to mix in plain Go computations.
func PositiveRatio(ds train.Dataset) float32 {
    ds.Reset()  // Start from beginning.
    var sum float32
    var count float32
    for {
        _, _, labels, err := ds.Yield()
        if err == io.EOF {
            break
        }
        if err != nil { panic(err) }
        data := labels[0].Local().FlatCopy().([]float32)
        for _, value := range data {
            sum += value
        }
        count += float32(len(data))
    }
    return sum/count
}

%%
adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    
trainingDS, trainingEvalDS, testEvalDS := BuildDatasets(manager)

// Take one batch.
_, inputs, labels, err := trainingDS.Yield()
if err != nil { panic(err) }
fmt.Printf("Inputs of batch (size %d):\n", *flagBatchSize)
fmt.Printf("\tcategorical:\n\t\tFeatures=%v\n", adult.Data.VocabulariesFeatures)
fmt.Printf("\t\tValues: %s\n", inputs[0].Local().StringN(16))
fmt.Printf("\tcontinuous:\n\t\tFeatures=%v\n", adult.Data.QuantilesFeatures)
fmt.Printf("\t\tValues: %s\n", inputs[1].Local().StringN(10))
fmt.Printf("\tweights: %s\n", inputs[2].Local().StringN(5))
fmt.Printf("\nLabels of batch:\n\t%s\n", labels[0].Local().StringN(10))
fmt.Printf("\nLabels distributions:\n\tTrain:\t%.2f%% positive\n\tTest:\t%.2f%% positive\n",
           PositiveRatio(trainingEvalDS)*100.0, PositiveRatio(testEvalDS)*100.0)


Inputs of batch (size 128):
	categorical:
		Features=[workclass education marital-status occupation relationship race sex native-country]
		Values: (Int64)[128 8]: (... too large, 1024 values ..., first 16 values: [4 16 4 10 2 5 2 39 6 9 3 5 1 5 2 39])
	continuous:
		Features=[age education-num capital-gain capital-loss hours-per-week]
		Values: (Float32)[128 5]: (... too large, 640 values ..., first 10 values: [42 10 0 0 40 64 11 0 0 30])
	weights: (Float32)[128 1]: (... too large, 128 values ..., first 5 values: [27444 30664 200783 226668 91666])

Labels of batch:
	(Float32)[128 1]: (... too large, 128 values ..., first 10 values: [1 0 0 0 0 1 0 0 0 0])

Labels distributions:
	Train:	24.08% positive
	Test:	23.62% positive
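`PositiveRatio` counts every row equally, while the preprocessing log earlier also reports a *weighted* positive ratio (24.08% vs 23.86% on the training data). Assuming the weighted version counts each row with its weight (the third input tensor), i.e. Σ wᵢyᵢ / Σ wᵢ, both can be sketched in plain Go as follows; `positiveRatios` is a hypothetical helper, and the data is the first five rows of the sample batch printed above.

```go
package main

import "fmt"

// positiveRatios returns the unweighted and the weighted fraction of positive
// labels. labels hold 0 or 1; weights hold one sample weight per row.
func positiveRatios(labels, weights []float32) (unweighted, weighted float32) {
	var sumLabels, sumWeightedLabels, sumWeights float32
	for i, y := range labels {
		sumLabels += y
		sumWeightedLabels += weights[i] * y
		sumWeights += weights[i]
	}
	return sumLabels / float32(len(labels)), sumWeightedLabels / sumWeights
}

func main() {
	// First 5 rows of the sample batch above (labels and weights).
	labels := []float32{1, 0, 0, 0, 0}
	weights := []float32{27444, 30664, 200783, 226668, 91666}
	u, w := positiveRatios(labels, weights)
	fmt.Printf("unweighted: %.2f%%, weighted: %.2f%%\n", u*100, w*100) // unweighted: 20.00%, weighted: 4.75%
}
```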


## Model Definition

Lots of hyperparameter flags, but otherwise a straightforward feed-forward neural network (FNN), using piece-wise linear calibration of the continuous features and embeddings for the categorical features.
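Each hidden layer after the first computes `layer + Dense(Sigmoid(layer))`, keeping the `--num_nodes` width so the residual addition is well defined; the first layer is a plain `DenseWithBias`, and the final logit is a `DenseWithBias` to 1 unit after one more `Sigmoid`. A plain-Go sketch of one residual step for a single example, with hypothetical helper names, fixed weights instead of trained variables, and no dropout:

```go
package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// dense computes w·x + b for a single example; w is shaped [len(b)][len(x)].
func dense(w [][]float64, b, x []float64) []float64 {
	out := make([]float64, len(b))
	for i := range out {
		out[i] = b[i]
		for j, xj := range x {
			out[i] += w[i][j] * xj
		}
	}
	return out
}

// residualStep returns x + dense(sigmoid(x)): the update applied by each hidden
// layer after the first. Input and output widths must match (--num_nodes).
func residualStep(w [][]float64, b, x []float64) []float64 {
	activated := make([]float64, len(x))
	for i, v := range x {
		activated[i] = sigmoid(v)
	}
	delta := dense(w, b, activated)
	out := make([]float64, len(x))
	for i := range x {
		out[i] = x[i] + delta[i]
	}
	return out
}

func main() {
	identity := [][]float64{{1, 0}, {0, 1}}
	zeros := []float64{0, 0}
	// sigmoid(0) = 0.5, so with identity weights the step adds 0.5 to each node.
	fmt.Println(residualStep(identity, zeros, []float64{0, 0})) // [0.5 0.5]
}
```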

> **Note**: building models involves constantly checking that shapes are compatible. It's a bit annoying, in particular because shapes are only known at runtime: there is no compile-time check. GoMLX tries to help by providing a stack trace of where errors happen, so one can pinpoint issues quickly. But it often involves lots of experimentation (more than ordinary Go code).
>
> Developing with a notebook (see [GoNB](https://github.com/janpfeifer/gonb)) or simply writing a unit test for your `ModelGraph` function are quick and convenient ways to develop models before actually training them. You can also use shape asserts in the middle of the `ModelGraph`, as we do below.
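`AssertDims` fails as soon as a shape differs from the expected one, with `-1` meaning "don't check this axis" (as the comment in `ModelGraph` notes). A stdlib-only sketch of the same check that could back a unit test on shape arithmetic; `checkDims` is hypothetical and returns an error, whereas GoMLX's assertion is a method on the node. The example computes the concatenated width from the 8 categorical features (embedding dimension 8) and the 5 continuous features listed earlier.

```go
package main

import "fmt"

// checkDims verifies that dims matches want, where -1 in want means "any size".
func checkDims(dims []int, want ...int) error {
	if len(dims) != len(want) {
		return fmt.Errorf("rank mismatch: got %d dims %v, want %d dims %v", len(dims), dims, len(want), want)
	}
	for i, w := range want {
		if w != -1 && dims[i] != w {
			return fmt.Errorf("axis %d: got %d, want %d (dims %v, want %v)", i, dims[i], w, dims, want)
		}
	}
	return nil
}

func main() {
	const batchSize, numCategorical, embeddingDim, numContinuous = 128, 8, 8, 5
	// Width after concatenating the 8 embeddings and the 5 calibrated values.
	concatenated := []int{batchSize, numCategorical*embeddingDim + numContinuous}
	fmt.Println(concatenated, checkDims(concatenated, batchSize, -1)) // [128 69] <nil>
	fmt.Println(checkDims([]int{batchSize, embeddingDim}, batchSize, 1))
}
```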

In [5]:
import (
    "flag"
    "fmt"
    "io"

    . "github.com/gomlx/gomlx/graph"

    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/examples/adult"
    "github.com/gomlx/gomlx/ml/layers"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/optimizers"
    "github.com/gomlx/gomlx/types/shapes"
)

var (
    // ModelDType used for the model. Must match RawData Go types.
    ModelDType = shapes.Float32
    

    // Model hyperparameters.
    flagUseCategorical       = flag.Bool("use_categorical", true, "Use categorical features.")
    flagUseContinuous        = flag.Bool("use_continuous", true, "Use continuous features.")
    flagTrainableCalibration = flag.Bool("trainable_calibration", true, "Allow piece-wise linear calibration to adjust outputs.")
    flagEmbeddingDim    = flag.Int("embedding_dim", 8, "Default embedding dimension for categorical values.")
    flagNumHiddenLayers = flag.Int("hidden_layers", 8, "Number of hidden layers, stacked with residual connection.")
    flagNumNodes        = flag.Int("num_nodes", 32, "Number of nodes in hidden layers.")
    flagDropoutRate     = flag.Float64("dropout", 0, "Dropout rate")
    
    // Training parameters, also used here: --steps sets the period of the learning rate schedule.
    flagLearningRate  = flag.Float64("learning_rate", 0.001, "Initial learning rate.")
    flagNumSteps      = flag.Int("steps", 5000, "Number of gradient descent steps to perform")
)


// ModelGraph outputs the logits (not the probabilities). The parameter inputs should contain 3 tensors:
//
// - categorical inputs, shaped  `(int64)[batch_size, len(VocabulariesFeatures)]`
// - continuous inputs, shaped `(float32)[batch_size, len(Quantiles)]`
// - weights: not currently used, but shaped `(float32)[batch_size, 1]`.
func ModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
    _ = spec // Not used, since the dataset is always the same.
    g := inputs[0].Graph()
    
    // Use Cosine schedule of the learning rate.
    optimizers.CosineAnnealingSchedule(ctx, g, ModelDType).
        PeriodInSteps(*flagNumSteps/3).Done()
    
    categorical, continuous := inputs[0], inputs[1]
    batchSize := categorical.Shape().Dimensions[0]
    
    var allEmbeddings []*Node

    if *flagUseCategorical {
        // Embedding of categorical values, each with its own vocabulary.
        numCategorical := categorical.Shape().Dimensions[1]
        for catIdx := 0; catIdx < numCategorical; catIdx++ {
            // Take one column at a time of the categorical values.
            split := Slice(categorical, AxisRange(), AxisRange(catIdx, catIdx+1))
            // Embed it accordingly.
            embedCtx := ctx.In(fmt.Sprintf("categorical_%d_%s", catIdx, adult.Data.VocabulariesFeatures[catIdx]))
            vocab := adult.Data.Vocabularies[catIdx]
            vocabSize := len(vocab)
            embedding := layers.Embedding(embedCtx, split, ModelDType, vocabSize, *flagEmbeddingDim)
            embedding.AssertDims(batchSize, *flagEmbeddingDim) // 2-dim tensor, with batch size as the leading dimension.
            allEmbeddings = append(allEmbeddings, embedding)
        }
    }

    if *flagUseContinuous {
        // Piecewise-linear calibration of the continuous values. Each feature has its own number of quantiles.
        numContinuous := continuous.Shape().Dimensions[1]
        for contIdx := 0; contIdx < numContinuous; contIdx++ {
            // Take one column at a time of the continuous values.
            split := Slice(continuous, AxisRange(), AxisRange(contIdx, contIdx+1))
            featureName := adult.Data.QuantilesFeatures[contIdx]
            calibrationCtx := ctx.In(fmt.Sprintf("continuous_%d_%s", contIdx, featureName))
            quantiles := adult.Data.Quantiles[contIdx]
            layers.AssertQuantilesForPWLCalibrationValid(quantiles)
            calibrated := layers.PieceWiseLinearCalibration(calibrationCtx, split, Const(g, quantiles), *flagTrainableCalibration)
            calibrated.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.
            allEmbeddings = append(allEmbeddings, calibrated)
        }
    }
    layer := Concatenate(allEmbeddings, -1)
    layer.AssertDims(batchSize, -1) // 2-dim tensor, with batch size as the leading dimension (-1 means it is not checked).
    
    layer = layers.DenseWithBias(ctx.In(fmt.Sprintf("DenseLayer_%d", 0)), layer, *flagNumNodes)
    for ii := 1; ii < *flagNumHiddenLayers; ii++ {
        ctx := ctx.In(fmt.Sprintf("DenseLayer_%d", ii))
        // Add layer with residual connection.
        tmp := Sigmoid(layer)
        if *flagDropoutRate > 0 {
            tmp = layers.Dropout(ctx, tmp, Scalar(g, ModelDType, *flagDropoutRate))
        }
        tmp = layers.DenseWithBias(ctx, tmp, *flagNumNodes)
        layer = Add(layer, tmp)  // Residual connections
    }
    layer = Sigmoid(layer)
    logits := layers.DenseWithBias(ctx.In("DenseFinal"), layer, 1)
    logits.AssertDims(batchSize, 1) // 2-dim tensor, with batch size as the leading dimension.
    return []*Node{logits}
}

%%
adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    

// Let's just check that we get the right shape from the model function, without any real data.
graph := NewGraph(manager, "test")
ctx := context.NewContext(manager)
ctx.SetParam(optimizers.LearningRateKey, *flagLearningRate)

// Create placeholder (parameters) graph nodes, just to test the graph building is working.
inputs := []*Node{
    // Categorical: shaped [batch_size, num_categorical]
    graph.Parameter("categorical", shapes.Make(shapes.Int64, *flagBatchSize, len(adult.Data.VocabulariesFeatures))),
    // Continuous: shaped [batch_size, num_continuous]
    graph.Parameter("continuous", shapes.Make(shapes.Float32, *flagBatchSize, len(adult.Data.QuantilesFeatures))),
    // Weights: shaped [batch_size, 1]
    graph.Parameter("weights", shapes.Make(shapes.Float32, *flagBatchSize, 1)),    
}
logits := ModelGraph(ctx, nil, inputs)
fmt.Printf("Logits shape for batch_size=%d: %s\n", *flagBatchSize, logits[0].Shape())

Logits shape for batch_size=128: (Float32)[128 1]
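`ModelGraph` sets up `optimizers.CosineAnnealingSchedule` with a period of `--steps/3`. The textbook cosine annealing with warm restarts sets the learning rate at step t to `minLR + 0.5*(maxLR - minLR)*(1 + cos(π·(t mod P)/P))`. The sketch below evaluates that formula in plain Go with the default flags, assuming a minimum rate of 0 and a restart at every period; GoMLX's schedule (its minimum value, and whether it restarts) may differ, so treat this as an approximation. `cosineLR` is hypothetical.

```go
package main

import (
	"fmt"
	"math"
)

// cosineLR returns the learning rate at step: it decays from maxLR to minLR
// over each period, and restarts at maxLR when the next period begins.
func cosineLR(step, period int, maxLR, minLR float64) float64 {
	phase := float64(step%period) / float64(period)
	return minLR + 0.5*(maxLR-minLR)*(1+math.Cos(math.Pi*phase))
}

func main() {
	const steps, lr = 5000, 0.001 // Defaults of --steps and --learning_rate.
	period := steps / 3
	for _, step := range []int{0, period / 2, period - 1, period, 2 * period, steps - 1} {
		fmt.Printf("step %4d: lr=%.6f\n", step, cosineLR(step, period, lr, 0))
	}
}
```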


## Training Loop

We can create a training loop with only a `Manager`, a `Context` (for the model variables) and the `ModelGraph` function.

To make it more interesting we also add the following:

* Accuracy metrics for training and testing.
* Checkpoints, so the trained model can be saved and reloaded.
* A progress bar that also shows training metrics.
* Dynamic plots of how the loss and accuracy evolve.

First we define the corresponding flags and the `trainModel` function, and run it for very few steps to make sure
it is working.
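Both accuracy metrics in `trainModel` work on logits: a logit above 0 means a predicted probability above 0.5. `#acc` (`NewMeanBinaryLogitsAccuracy`) averages over a whole evaluation dataset, while `~acc` (`NewMovingAverageBinaryLogitsAccuracy`, created with `0.01`) tracks training batches. The plain-Go sketch below assumes `~acc` is an exponential moving average in which `0.01` is the weight of the newest batch; GoMLX's implementation may differ in details such as the warm-up and how a logit of exactly 0 is classified. `batchAccuracy` and `movingAverage` are hypothetical.

```go
package main

import "fmt"

// batchAccuracy returns the fraction of examples whose thresholded logit
// (positive if > 0) matches the 0/1 label.
func batchAccuracy(logits, labels []float64) float64 {
	correct := 0
	for i, logit := range logits {
		predicted := 0.0
		if logit > 0 {
			predicted = 1
		}
		if predicted == labels[i] {
			correct++
		}
	}
	return float64(correct) / float64(len(logits))
}

// movingAverage keeps an exponential moving average of a per-batch metric.
type movingAverage struct {
	factor, value float64
	initialized   bool
}

// Update folds a new batch value into the average and returns the new average.
func (m *movingAverage) Update(x float64) float64 {
	if !m.initialized {
		m.value, m.initialized = x, true
	} else {
		m.value = (1-m.factor)*m.value + m.factor*x
	}
	return m.value
}

func main() {
	acc := batchAccuracy([]float64{2.3, -0.7, 0.4, -1.5}, []float64{1, 0, 0, 0})
	fmt.Println(acc) // 0.75
	ma := &movingAverage{factor: 0.01}
	ma.Update(acc)
	fmt.Println(ma.Update(1.0))
}
```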

In [6]:
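import (
    "flag"
    "fmt"
    "io"
    "log"
    "path"
    "time"

    . "github.com/gomlx/gomlx/graph"

    "github.com/gomlx/exceptions"
    "github.com/gomlx/gomlx/examples/adult"
    "github.com/gomlx/gomlx/examples/notebook/gonb/margaid"
    "github.com/gomlx/gomlx/examples/notebook/gonb/plotly"
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/ml/context/checkpoints"
    "github.com/gomlx/gomlx/ml/data"
    "github.com/gomlx/gomlx/ml/train"
    "github.com/gomlx/gomlx/ml/train/commandline"
    "github.com/gomlx/gomlx/ml/train/losses"
    "github.com/gomlx/gomlx/ml/train/metrics"
    "github.com/gomlx/gomlx/ml/train/optimizers"
    "github.com/gomlx/gomlx/types/shapes"
    "github.com/gomlx/gomlx/types/slices"
    "github.com/gomlx/gomlx/types/tensor"
    "github.com/janpfeifer/gonb/gonbui"
)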

var (
    flagOptimizer      = flag.String("optimizer", "adam", "Type of optimizer to use: 'sgd' or 'adam'")
    flagLearningRate   = flag.Float64("learning_rate", 0.001, "Initial learning rate.")
    flagCheckpoint     = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")
    flagCheckpointKeep = flag.Int("checkpoint_keep", 10, "Number of checkpoints to keep, if --checkpoint is set.")
    flagPlots          = flag.Bool("plots", true, "Plots during training: perform periodic evaluations, "+
                                   "save results if --checkpoint is set and draw plots, if in a Jupyter notebook.")
    flagPlotType       = flag.String("plot_type", "plotly", "Type of plot to use, values are \"plotly\" or \"margaid\"")
)

func trainModel() {
    // Resolve directories: expand "~", and place a relative --checkpoint path under --data.
    *flagDataDir = data.ReplaceTildeInDir(*flagDataDir)
    *flagCheckpoint = data.ReplaceTildeInDir(*flagCheckpoint)
    if *flagCheckpoint != "" && !path.IsAbs(*flagCheckpoint) {
        *flagCheckpoint = path.Join(*flagDataDir, *flagCheckpoint)
    }

    // Load data and create datasets.
    adult.LoadAndPreprocessData(*flagDataDir, *flagNumQuantiles, *flagForceDownload, *flagVerbosity)    
    trainDS, trainEvalDS, testEvalDS := BuildDatasets(manager)

    // Context holds the variables and optionally hyperparameters for the model.
    ctx := context.NewContext(manager)
    ctx.SetParam(optimizers.LearningRateKey, *flagLearningRate)

    // Metrics we are interested in.
    meanAccuracyMetric := metrics.NewMeanBinaryLogitsAccuracy("Mean Accuracy", "#acc")
    movingAccuracyMetric := metrics.NewMovingAverageBinaryLogitsAccuracy("Moving Average Accuracy", "~acc", 0.01)

    // Checkpoints saving.
    var checkpoint *checkpoints.Handler
    if *flagCheckpoint != "" {
        var err error
        checkpoint, err = checkpoints.Build(ctx).Dir(*flagCheckpoint).Keep(*flagCheckpointKeep).Done()
        if err != nil { panic(err) }
        fmt.Printf("Checkpointing model to %q\n", checkpoint.Dir())
        globalStep := optimizers.GetGlobalStep(ctx)
        if globalStep != 0 {
            fmt.Printf("Restarting training from global_step=%d\n", globalStep)
        }
    }

    // Pick a known optimizer.
    optimizerFn, found := optimizers.KnownOptimizers[*flagOptimizer]
    if !found {
        log.Fatalf("Unknown optimizer %q, please use one of %v",
            *flagOptimizer, slices.Keys(optimizers.KnownOptimizers))
    }

    // Create a train.Trainer: this object will orchestrate running the model, feeding
    // results to the optimizer, evaluating the metrics, etc. (all happens in trainer.TrainStep)
    trainer := train.NewTrainer(manager, ctx, ModelGraph, losses.BinaryCrossentropyLogits,
        optimizerFn(),
        []metrics.Interface{movingAccuracyMetric}, // trainMetrics
        []metrics.Interface{meanAccuracyMetric})   // evalMetrics

    // Use standard training loop.
    loop := train.NewLoop(trainer)
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.

    // Attach a checkpoint: checkpoint every 1 minute of training.
    if checkpoint != nil {
        period := time.Minute * 1
        train.PeriodicCallback(loop, period, true, "saving checkpoint", 100,
            func(loop *train.Loop, metrics []tensor.Tensor) error {
                fmt.Printf("\n[saving checkpoint@%d] [median train step (ms): %d]\n", loop.LoopStep, loop.MedianTrainStepDuration().Milliseconds())
                return checkpoint.Save()
            })
    }

    // Attach plots (margaid or plotly): plot points at exponentially spaced steps.
    // Points (metrics) are saved in the checkpoint directory (if one is given).
    if *flagPlots {
        switch *flagPlotType {
        case "margaid":
            _ = margaid.NewDefault(loop, checkpoint.Dir(), 100, 1.1, trainEvalDS, testEvalDS)
        case "plotly":
            _ = plotly.New().
                WithCheckpoint(checkpoint.Dir()).
                Dynamic(trainEvalDS, testEvalDS).
                ScheduleExponential(loop, 100, 1.1)
        default:
            exceptions.Panicf("Invalid --plot_type=%q, valid values are %q or %q", *flagPlotType, "margaid", "plotly")
        }
    }

    // Run the given number of steps.
    _, err := loop.RunSteps(trainDS, *flagNumSteps)
    if err != nil { panic(err) }
    fmt.Printf("\t[Step %d] median train step: %d microseconds\n", loop.LoopStep, loop.MedianTrainStepDuration().Microseconds())

    // Print a final evaluation on train and test datasets.
    fmt.Println()
    err = commandline.ReportEval(trainer, trainEvalDS, testEvalDS)
    if err != nil { panic(err) }
    fmt.Println()
}

// Notice command line flags are passed in the %% notebook command. We set --plots=false here to disable plotting,
// since this is only a quick test that our trainModel() function is working. The final run below does the full training.
%% --steps=500 --plots=false
trainModel()

Training (500 steps):  100% [========================================] (110 steps/s) [step=499] [loss+=0.354] [~loss+=0.398] [~loss=0.398] [~acc=82.11%]
	[Step 500] median train step: 967 microseconds

Results on batched train:
	Mean Loss+Regularization (#loss+): 0.361
	Mean Loss (#loss): 0.361
	Mean Accuracy (#acc): 84.17%
Results on test:
	Mean Loss+Regularization (#loss+): 0.358
	Mean Loss (#loss): 0.358
	Mean Accuracy (#acc): 84.15%
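The plots in `trainModel` are scheduled with the arguments `100, 1.1` (`margaid.NewDefault` and `ScheduleExponential`). One plausible reading, assumed here and not taken from GoMLX's code, is a first evaluation at step 100 with each following step 1.1 times the previous one; under that reading, points are dense early in training and sparse later, which keeps evaluation cost low on long runs. A sketch with a hypothetical `exponentialSteps`:

```go
package main

import (
	"fmt"
	"math"
)

// exponentialSteps lists the steps in [start, maxStep] at which an evaluation
// would run if each step is the previous one multiplied by factor (rounded).
func exponentialSteps(start int, factor float64, maxStep int) []int {
	var steps []int
	for step := start; step <= maxStep; {
		steps = append(steps, step)
		next := int(math.Round(float64(step) * factor))
		if next <= step {
			next = step + 1 // Guarantee progress for factors close to 1.
		}
		step = next
	}
	return steps
}

func main() {
	steps := exponentialSteps(100, 1.1, 5000)
	fmt.Println(len(steps), "evaluations; first ones:", steps[:5])
}
```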



## Final run with 5K steps

With everything working, we can do our final run.

> **Note**: this is where one might want to tune hyperparameters, trying out different values.

In [7]:
// Remove the previously trained model; skip this cell if you want to continue training.
!rm -rf ~/work/uci-adult/base_model

In [8]:
%% --steps=5000 --checkpoint base_model --plot_type=margaid
trainModel()

Checkpointing model to "/home/janpf/work/uci-adult/base_model"


Training (5000 steps):  100% [========================================] (374 steps/s) [step=4999] [loss+=0.316] [~loss+=0.271] [~loss=0.271] [~acc=87.50%]

[saving checkpoint@5000] [median train step (ms): 0]


	[Step 5000] median train step: 916 microseconds

Results on batched train:
	Mean Loss+Regularization (#loss+): 0.274
	Mean Loss (#loss): 0.274
	Mean Accuracy (#acc): 87.43%
Results on test:
	Mean Loss+Regularization (#loss+): 0.280
	Mean Loss (#loss): 0.280
	Mean Accuracy (#acc): 87.06%



## Extend training another 5K steps

Since the model training went well and it doesn't seem to be overfitting much yet,
let's train further: another 5K steps, for 10K steps in total.

Notice the plots continue from where they stopped. This time we use [Plotly](https://plotly.com/javascript/) to plot the training results; these plots don't display on GitHub, since they depend on JavaScript.

Unfortunately, it doesn't help: test accuracy barely changes (87.06% after 5K steps, 87.07% after 10K steps), so 5K steps was already enough.

In [9]:
import (
    "fmt"
    "os"

    "github.com/gomlx/exceptions"
)

%% --steps=5000 --checkpoint base_model --plot_type=plotly
err := exceptions.TryCatch[error](func() { trainModel() })
if err != nil {
    fmt.Fprintf(os.Stderr, "Error: %+v", err)
}

loading: "checkpoint-n0000000-20240418-155646-step-00005000"
Checkpointing model to "/home/janpf/work/uci-adult/base_model"
Restarting training from global_step=5000


Training (5000 steps):  100% [========================================] (439 steps/s) [step=9999] [loss+=0.293] [~loss+=0.268] [~loss=0.268] [~acc=87.51%]

[saving checkpoint@10000] [median train step (ms): 0]


	[Step 10000] median train step: 905 microseconds

Results on batched train:
	Mean Loss+Regularization (#loss+): 0.270
	Mean Loss (#loss): 0.270
	Mean Accuracy (#acc): 87.44%
Results on test:
	Mean Loss+Regularization (#loss+): 0.278
	Mean Loss (#loss): 0.278
	Mean Accuracy (#acc): 87.07%

