# IMDB Movie Review Dataset

This is a library to download and parse the [IMDB's Large Movie Review Dataset](http://ai.stanford.edu/~amaas/data/sentiment/), plus a demo of bag-of-words, CNN and transformer-based models trained on it. The dataset has 25K training and 25K test examples, plus 50K unlabeled examples.

It's inspired by [Keras' Text classification with Transformer](https://keras.io/examples/nlp/text_classification_with_transformer/) demo.
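The Stanford archive (`aclImdb`) stores one review per text file under `train/pos`, `train/neg`, `train/unsup`, `test/pos` and `test/neg`, with file names following the `<id>_<rating>.txt` convention (ratings out of 10 stars). The `imdb` package already parses all of this into its own `aclImdb.bin` file; the plain-Go sketch below, with hypothetical helper names `parseReviewFileName` and `labelFromDir`, only illustrates how labels can be recovered from the paths.

```go
package main

import (
	"fmt"
	"path/filepath"
	"strconv"
	"strings"
)

// parseReviewFileName extracts the review id and star rating from a file
// name following the "<id>_<rating>.txt" convention of aclImdb.
func parseReviewFileName(path string) (id, rating int, err error) {
	base := strings.TrimSuffix(filepath.Base(path), ".txt")
	parts := strings.Split(base, "_")
	if len(parts) != 2 {
		return 0, 0, fmt.Errorf("unexpected file name %q", path)
	}
	if id, err = strconv.Atoi(parts[0]); err != nil {
		return 0, 0, err
	}
	if rating, err = strconv.Atoi(parts[1]); err != nil {
		return 0, 0, err
	}
	return id, rating, nil
}

// labelFromDir maps the parent directory to a label (hypothetical encoding):
// 1 for "pos", 0 for "neg" and -1 for anything else, e.g. "unsup".
func labelFromDir(path string) int {
	switch filepath.Base(filepath.Dir(path)) {
	case "pos":
		return 1
	case "neg":
		return 0
	default:
		return -1
	}
}

func main() {
	path := "aclImdb/train/pos/200_8.txt" // Hypothetical example path.
	id, rating, err := parseReviewFileName(path)
	if err != nil {
		panic(err)
	}
	fmt.Println(id, rating, labelFromDir(path)) // 200 8 1
}
```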




## Environment Set Up

Let's set up `go.mod` to use the local copy of GoMLX, so the dataset code and the model can be developed jointly. That's often how data pre-processing and model code are developed: together, alongside experimentation.

If you are not changing the GoMLX code, feel free to skip this cell. If you keep your projects in a different directory, change the path below.

Notice that `${HOME}/Projects/gomlx` is where the GoMLX code is copied by default in [its Docker image](https://hub.docker.com/repository/docker/janpfeifer/gomlx_jupyterlab/general).

In [1]:
!*go mod edit -replace github.com/gomlx/gomlx="${HOME}/Projects/gomlx"

## Data Preparation

### Downloading data files

To download, uncompress and untar the dataset into the local directory, simply run the following. Notice that if the data is already in the given `--data` directory, it returns immediately.

In [2]:
import (
    "flag"
    "log"
    "os"

    "github.com/gomlx/gomlx/examples/imdb"
    "github.com/gomlx/gomlx/ml/data"
)

var flagDataDir = flag.String("data", "~/work/imdb", "Directory to cache downloaded and generated dataset files.")

func AssertNoError(err error) {
    if err != nil {
        log.Fatalf("Failed: %+v", err)
    }
}

func AssertDownloaded() {
    *flagDataDir = data.ReplaceTildeInDir(*flagDataDir)
    if !data.FileExists(*flagDataDir) {
        AssertNoError(os.MkdirAll(*flagDataDir, 0777))
    }

    AssertNoError(imdb.Download(*flagDataDir))
}

%%
AssertDownloaded()

> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.


### Sampling some examples

The cell below creates a small dataset and prints out some random examples.

It also defines `DType`, used for all internal representations of the model, and the flag `--max_len`, which sets the maximum number of tokens used per observation. Both are used in the modeling later.

In [3]:
import (
    "flag"
    "fmt"

    "github.com/gomlx/gomlx/examples/imdb"
    "github.com/gomlx/gomlx/types/shapes"
)

// DType used for the models.
const DType = shapes.Float32

var (
    flagMaxLen = flag.Int("max_len", 200, "Maximum number of tokens to take from observation.")
)

func Sample() {
    ds := imdb.NewDataset("Test", imdb.Test, *flagMaxLen, 3, DType, true, nil)
    _, inputs, labels, err := ds.Yield()
    AssertNoError(err)
    labelsData := shapes.CastAsDType(labels[0].Local().FlatCopy(), shapes.Int64).([]int64)
    for ii := 0; ii < 3; ii++ {
        fmt.Printf("\nlabel=%v, input=%q\n", labelsData[ii], imdb.InputToString(inputs[0], ii))
    }
    fmt.Println()
}

%% --max_len=200
AssertDownloaded()
Sample()


> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.

label=0, input="<START> the most obvious flaw horrible horrible script this movie had a potentially good story but it was ruined with bad dialogue continuity problems things that were never explained gaping plotholes sub plots that went nowhere and just plain stupidity not to mention the awful cliched directing of sandra locke not even two great performances could ve saved this movie so it didn t matter that devon and rosanna arquette give horrific performances the thing is they re better actors than this movie would have you believe the best of the arquettes rosanna arquette silverado after hours desperately seeking susan has some fine moments like a great scene in the beginning when she painfully pulls her handcuffs off but gives an overall weak performance by her standards and devon gummersall dick when trumpets fade and the brill

## Training

We will create 3 different types of models for this demo: **Bag of Words** (or simply **bow**), **CNNs** and **Transformers**.

### Model Support

We first define a few components shared among all models:

* `Normalize`: normalizes according to the `--norm` flag; works for sequence nodes (shaped `[batch_size, sequence_len, embedding_dim]`) and feature nodes (shaped `[batch_size, embedding_dim]`).
* `EmbedTokensGraph`: embeds the tokens before they are consumed by the models.
* `ReadoutGraph`: takes the embeddings after they are pooled on the sequence axis (so shaped `[batch_size, embedding_dim]`), adds an FNN (feed-forward neural network) with a few layers, and converts the result to the final logits.

We also define the corresponding hyperparameters.

In [4]:
import (
    . "github.com/gomlx/gomlx/graph"
    
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/types/shapes"
)

var (
    flagNormalization      = flag.String("norm", "layer", "Type of normalization to use. Valid values are \"none\", \"batch\", \"layer\".")
    flagDropoutRate        = flag.Float64("dropout", 0.15, "Dropout rate")
    flagMaxVocab           = flag.Int("max_vocab", 20000, "Top most frequent words to consider, the rest is considered unknown.")
    flagTokenEmbeddingSize = flag.Int("token_embed", 32, "Size of token embedding table. There are ~140K unique tokens")
    flagNumHiddenLayers    = flag.Int("hidden_layers", 2, "Number of output hidden layers, stacked with residual connection.")
    flagNumNodes           = flag.Int("num_nodes", 32, "Number of nodes in output hidden layers.")
    flagWordDropoutRate    = flag.Float64("word_dropout", 0, "Dropout rate for whole words of the input")
)

// Normalize `x` according to `--norm` flag. Works for sequence nodes (rank-3) or plain feature nodes (rank-2).
func Normalize(ctx *context.Context, x *Node) *Node {
    switch *flagNormalization {
    case "layer":
        if x.Rank() == 3 {
            // Normalize sequence.
            return layers.LayerNormalization(ctx, x, -2, -1).
                LearnedOffset(true).LearnedScale(true).ScaleNormalization(true).Done()
        } else {
            // Normalize features only.
            return layers.LayerNormalization(ctx, x, -1).Done()
        }
    case "batch":
        return layers.BatchNormalization(ctx, x, -1).Done()
    case "none":
        return x
    }
    exceptions.Panicf("invalid normalization selected %q -- valid values are batch, layer, none", *flagNormalization)
    return nil
}

// EmbedTokensGraph creates embeddings for tokens and returns them along with the mask of used tokens --
// set to false where padding was used.
func EmbedTokensGraph(ctx *context.Context, tokens *Node) (embed, mask *Node) {
    g := tokens.Graph()
    mask = NotEqual(tokens, ZerosLike(tokens)) // Mask of tokens actually used.

    // The token ids are indexed by frequency. Truncate to the vocabulary size considered, replacing
    // ids equal to or higher than that by maxVocab-1 (the least frequent token kept).
    maxVocab := len(imdb.LoadedVocab.ListEntries)
    if maxVocab > *flagMaxVocab {
        maxVocab = *flagMaxVocab
    }

    // Limits tokens to the maxVocab.
    tokens = Where(GreaterOrEqual(tokens, Const(g, maxVocab)),
        MulScalar(OnesLike(tokens), float64(maxVocab-1)),
        tokens)

    // Embed tokens: shape=[batchSize, maxLen, embedDim]
    embed = layers.Embedding(ctx.In("tokens"), tokens, DType, maxVocab, *flagTokenEmbeddingSize)
    embed = Where(mask, embed, ZerosLike(embed))
    
    if *flagWordDropoutRate > 0 {
        dims := embed.Shape().Dimensions[:len(embed.Shape().Dimensions)-1]
        dropoutMask := Ones(g, shapes.Make(DType, dims...))
        dropoutMask = layers.Dropout(ctx, dropoutMask, ConstAsDType(g, DType, *flagWordDropoutRate))
        dropoutMask = ExpandDims(dropoutMask, -1)
        embed = Mul(embed, dropoutMask)
    }    
    return
}

// Activation function used for models.
func Activation(x *Node) *Node {
    return layers.Relu(x)
}

// ReadoutGraph takes the embeddings after they have been pooled on the sequence axis, so shaped `[batch_size, embed_dim]`,
// adds an FNN on top and reads out the final logit.
func ReadoutGraph(ctx *context.Context, embed *Node) *Node {
    g := embed.Graph()
    var dropoutRate *Node
    if *flagDropoutRate > 0 {
        dropoutRate = ConstAsDType(g, DType, *flagDropoutRate)
    }

    // Output layers.
    for ii := 0; ii < *flagNumHiddenLayers; ii++ {
        ctx := ctx.In(fmt.Sprintf("output_dense_%d", ii))
        residual := embed
        if *flagDropoutRate > 0 {
            embed = layers.Dropout(ctx, embed, dropoutRate)
        }
        //embed = Tanh(embed)
        embed = Activation(embed)
        embed = layers.DenseWithBias(ctx, embed, *flagNumNodes)
        embed = Normalize(ctx, embed)
        if ii > 0 {
            // Add residual connection.
            embed = Add(embed, residual)
        }
    }

    // Final embed layer with dimension 1.
    {
        ctx := ctx.In("readout")
        if *flagDropoutRate > 0 {
            embed = layers.Dropout(ctx, embed, dropoutRate)
        }
        embed = Activation(embed)
        embed = layers.DenseWithBias(ctx, embed, 1)
    }
    return embed
}

%%
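Stripped of the graph machinery, the vocabulary clipping and padding handling in `EmbedTokensGraph` above amount to the following plain-Go sketch. The helper names are hypothetical, the embedding table is a toy stand-in for the learned `layers.Embedding` variable, and word dropout is left out.

```go
package main

import "fmt"

// clipToVocab mirrors the Where(...) in EmbedTokensGraph: every id >= maxVocab
// is replaced by maxVocab-1, so it shares the embedding of the least frequent
// token kept.
func clipToVocab(tokens []int, maxVocab int) []int {
	out := make([]int, len(tokens))
	for i, t := range tokens {
		if t >= maxVocab {
			t = maxVocab - 1
		}
		out[i] = t
	}
	return out
}

// embed looks up each token in table (shaped [maxVocab][embedDim]) and zeroes
// the rows of padding tokens (id 0), like Where(mask, embed, ZerosLike(embed)).
func embed(tokens []int, table [][]float64) [][]float64 {
	out := make([][]float64, len(tokens))
	for i, t := range tokens {
		out[i] = make([]float64, len(table[t]))
		if t != 0 {
			copy(out[i], table[t])
		}
	}
	return out
}

func main() {
	table := [][]float64{{9, 9}, {1, 0}, {0, 1}, {5, 5}} // maxVocab=4, embedDim=2.
	tokens := clipToVocab([]int{1, 2, 70, 0}, len(table))
	fmt.Println(tokens)               // [1 2 3 0]
	fmt.Println(embed(tokens, table)) // [[1 0] [0 1] [5 5] [0 0]]
}
```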

### Bag Of Words Model

This is the simplest model we are going to train. It's basically a no-op that connects the embedding table we defined above, max-pooled over the sequence axis, to the `ReadoutGraph` FNN.

We define here placeholders for our future *CNN* and *Transformer* models.

Finally, we test that the output shape is correct. To actually train and evaluate, we still need to define the training loop, which we do in the following section.
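Since the *bow* model applies no transformation, its only sequence-level operation is the `ReduceMax(embed, 1)` pooling in `ModelGraph`. A plain-Go sketch of that max pooling for a single example (the function name `maxPool` is hypothetical):

```go
package main

import "fmt"

// maxPool reduces one example's embeddings, shaped [sequence_len][embed_dim],
// to [embed_dim] by taking the maximum over the sequence axis.
func maxPool(embed [][]float64) []float64 {
	out := append([]float64(nil), embed[0]...) // Copy, so the input is not modified.
	for _, row := range embed[1:] {
		for j, v := range row {
			if v > out[j] {
				out[j] = v
			}
		}
	}
	return out
}

func main() {
	// Two real tokens followed by one zeroed padding position.
	embed := [][]float64{{0.5, -1}, {-2, -0.25}, {0, 0}}
	fmt.Println(maxPool(embed)) // [0.5 0]
}
```

Because padded positions are zeroed rather than masked out, any review shorter than `--max_len` pools to a value of at least 0 on every feature, as the example shows.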

In [5]:
import (
    . "github.com/gomlx/gomlx/graph"
    
    "github.com/gomlx/gomlx/ml/context"
)

var (
    flagModel = flag.String("model", "transformer", "Model type: bow, cnn or transformer.")
)

// ModelGraph builds the model for our demo. It returns the logits, not the predictions, which works with most losses.
func ModelGraph(ctx *context.Context, spec any, inputs []*Node) []*Node {
    _ = spec // Not used.
    tokens := inputs[0]
    embed, mask := EmbedTokensGraph(ctx, tokens)

    // Select the model architecture applied on top of the token embeddings.
    if *flagModel == "bow" {
        // Bag-Of-Words model doesn't do anything, it's just the embedding table for each token.
    } else if *flagModel == "cnn" {
        embed = Conv1DGraph(ctx, embed, mask)
    } else if *flagModel == "transformer" {
        embed = TransformerGraph(ctx, tokens, embed, mask)
    } else {
        exceptions.Panicf("unknown model type %q, only types \"bow\", \"cnn\" and \"transformer\" are implemented", *flagModel)
    }

    // Max-pool the per-token embeddings over the sequence axis and apply an FNN on the output. From now on,
    // the dimensions are `[batch_dim, embed_dim]`. Notice we are not using the mask here.
    embed = ReduceMax(embed, 1)
    logits := ReadoutGraph(ctx, embed)
    return []*Node{logits}
}

func Conv1DGraph(ctx *context.Context, embed, mask *Node) *Node {
    panic("Not implemented.")
    return nil
}

func TransformerGraph(ctx *context.Context, tokens, embed, mask *Node) *Node {
    panic("Not implemented.")
    return nil
}

%% --model=bow

// Let's test that the logits come out with the right shape: we want [batch_size, 1], since this is a binary classification with a single logit per example.
AssertDownloaded()
manager := NewManager()
ds := imdb.NewDataset("Test", imdb.Test, *flagMaxLen, 3, DType, true, nil)
_, inputs, _, err := ds.Yield()
AssertNoError(err)

g := manager.NewGraph("test")
ctx := context.NewContext(manager)
logits := ModelGraph(ctx, nil, []*Node{Const(g, inputs[0])})
fmt.Printf("Logits shape for batch_size=%d: %s\n", 3, logits[0].Shape())

> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Logits shape for batch_size=3: (Float32)[3 1]


### Training Loop

The training loop, with the usual bells and whistles: parallelized datasets, accuracy metrics, and optional checkpoints and plots. We also do a quick test with 100 steps to check that things are working.

In [6]:
import (
    . "github.com/gomlx/gomlx/graph"
    
    "github.com/gomlx/gomlx/ml/context"
    "github.com/gomlx/gomlx/ml/train"
)

var (
    flagOptimizer        = flag.String("optimizer", "adam", "Optimizer, options: adam, adamw or sgd.")
    flagNumSteps         = flag.Int("steps", 5000, "Number of gradient descent steps to perform")
    flagBatchSize        = flag.Int("batch", 32, "Batch size for training")
    flagLearningRate     = flag.Float64("learning_rate", 0.0001, "Initial learning rate.")
    flagL2Regularization = flag.Float64("l2_reg", 0, "L2 regularization on kernels. It doesn't interact well with --norm=batch or with --optimizer=adam.")

    flagCheckpoint       = flag.String("checkpoint", "", "Directory to save and load checkpoints from. If left empty, no checkpoints are created.")
    flagCheckpointKeep   = flag.Int("checkpoint_keep", 10, "Number of checkpoints to keep, if --checkpoint is set.")
    flagPlots            = flag.Bool("plots", true, "Plots during training: perform periodic evaluations, "+
                                     "save results if --checkpoint is set and draw plots, if in a Jupyter notebook.")
)

func trainModel() {
    // Make sure data is downloaded.
    AssertDownloaded()

    // Manager handles creation of ML computation graphs, accelerator resources, etc.
    manager := NewManager()
    fmt.Printf("Platform: %s\n", manager.Platform())

    // Datasets.
    var trainDS, trainEvalDS, testEvalDS train.Dataset
    trainDS = imdb.NewDataset("train", imdb.Train, *flagMaxLen, *flagBatchSize, DType, true, nil)
    trainEvalDS = imdb.NewDataset("train-eval", imdb.Train, *flagMaxLen, *flagBatchSize, DType, false, nil)
    testEvalDS = imdb.NewDataset("test-eval", imdb.Test, *flagMaxLen, *flagBatchSize, DType, false, nil)

    // Parallelize generation of batches, to prevent the dataset from being a bottleneck to the accelerator (GPU/TPU).
    trainDS = data.Parallel(trainDS)
    trainEvalDS = data.Parallel(trainEvalDS)
    testEvalDS = data.Parallel(testEvalDS)

    // Metrics we are interested in.
    meanAccuracyMetric := metrics.NewMeanBinaryLogitsAccuracy("Mean Accuracy", "#acc")
    movingAccuracyMetric := metrics.NewMovingAverageBinaryLogitsAccuracy("Moving Average Accuracy", "~acc", 0.01)

    // Context holds the variables and hyperparameters for the model.
    ctx := context.NewContext(manager)
    ctx.SetParam(optimizers.LearningRateKey, *flagLearningRate)
    ctx.SetParam(layers.L2RegularizationKey, *flagL2Regularization)

    // Checkpoints saving.
    var checkpoint *checkpoints.Handler
    if *flagCheckpoint != "" {
        var err error
        checkpoint, err = checkpoints.Build(ctx).DirFromBase(*flagCheckpoint, *flagDataDir).Keep(*flagCheckpointKeep).Done()
        AssertNoError(err)
    }

    // Create a train.Trainer: this object will orchestrate running the model, feeding
    // results to the optimizer, evaluating the metrics, etc. (all happens in trainer.TrainStep)
    loss := losses.BinaryCrossentropyLogits
    trainer := train.NewTrainer(
        manager, ctx, ModelGraph, loss,
        optimizers.MustOptimizerByName(*flagOptimizer),
        []metrics.Interface{movingAccuracyMetric}, // trainMetrics
        []metrics.Interface{meanAccuracyMetric})   // evalMetrics

    // Use standard training loop.
    loop := train.NewLoop(trainer)
    commandline.AttachProgressBar(loop) // Attaches a progress bar to the loop.

    // Attach a checkpoint: checkpoint every 1 minute of training.
    if checkpoint != nil {
        period := time.Minute * 1
        train.PeriodicCallback(loop, period, true, "saving checkpoint", 100,
            func(loop *train.Loop, metrics []tensor.Tensor) error {
                fmt.Printf("\n[saving checkpoint@%d] [median train step (ms): %d]\n", loop.LoopStep, loop.MedianTrainStepDuration().Milliseconds())
                return checkpoint.Save()
            })
    }

    // Attach Margaid plots: plot points at exponential steps.
    // The points generated are saved along the checkpoint directory (if one is given).
    if *flagPlots {
        var plotsDir string // Stays empty if no checkpoint was configured.
        if checkpoint != nil {
            plotsDir = checkpoint.Dir()
        }
        _ = margaid.NewDefault(loop, plotsDir, 100, 1.1, trainEvalDS, testEvalDS)
    }

    // Loop for given number of steps.
    _, err := loop.RunSteps(trainDS, *flagNumSteps)
    AssertNoError(err)
    fmt.Printf("\t[Step %d] median train step: %d microseconds\n",
        loop.LoopStep, loop.MedianTrainStepDuration().Microseconds())


    // Report final evaluation.
    fmt.Println()
    err = commandline.ReportEval(trainer, trainEvalDS, testEvalDS)
    AssertNoError(err)
}

%% --model=bow --steps=100 --plots=false
trainModel()


> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Platform: CUDA
Training (100 steps):  100% [========================================] (100 steps/s) [loss=0.676] [~loss=0.688] [~acc=55.22%]
	[Step 100] median train step: 1093 microseconds

Results on train-eval [Parallelized]:
	Mean Loss (#loss): 0.701
	Mean Accuracy (#acc): 49.99%
Results on test-eval [Parallelized]:
	Mean Loss (#loss): 0.701
	Mean Accuracy (#acc): 50.00%


### Training Bag-Of-Words ("bow") Model

A full run of our *bow* model: 20,000 steps with the AdamW optimizer.

In [7]:
%% --model=bow --steps=20000 --optimizer=adamw
trainModel()

> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Platform: CUDA


Training (20000 steps):  100% [========================================] (748 steps/s) [loss=0.216] [~loss=0.252] [~acc=89.69%]


	[Step 20000] median train step: 944 microseconds

Results on train-eval [Parallelized]:
	Mean Loss (#loss): 0.150
	Mean Accuracy (#acc): 94.98%
Results on test-eval [Parallelized]:
	Mean Loss (#loss): 0.360
	Mean Accuracy (#acc): 84.68%


### CNN Model

Below we redefine our CNN model function with 2 convolution layers, and train it.

Notice how well it overfits the training data (99.38% train accuracy), but this doesn't help the test results (82.58%, below the bag-of-words model's 84.68%). Improving it would require careful regularization.
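Each `layers.Convolution` call in the cell below uses kernel size 7 and stride 3. The plain-Go sketch of a strided 1D convolution shows how quickly the sequence shrinks: with `--max_len=200` it goes to 65 and then 20 positions before the max-pooling in `ModelGraph`. It assumes no padding ("valid" convolution) and no bias; GoMLX's default padding may differ, and with padding the lengths would be larger.

```go
package main

import "fmt"

// outputLen is the sequence length after one strided convolution without
// padding: (seqLen-kernelSize)/stride + 1.
func outputLen(seqLen, kernelSize, stride int) int {
	return (seqLen-kernelSize)/stride + 1
}

// conv1D applies a 1D convolution without padding and without bias.
// x is shaped [sequence_len][in_channels], kernel is shaped
// [kernel_size][in_channels][filters]; the result is [outputLen][filters].
// It returns nil if the sequence is shorter than the kernel.
func conv1D(x [][]float64, kernel [][][]float64, stride int) [][]float64 {
	k := len(kernel)
	if len(x) < k {
		return nil
	}
	filters := len(kernel[0][0])
	out := make([][]float64, outputLen(len(x), k, stride))
	for o := range out {
		out[o] = make([]float64, filters)
		start := o * stride // First input position covered by this output.
		for t := 0; t < k; t++ {
			for c, v := range x[start+t] {
				for f := 0; f < filters; f++ {
					out[o][f] += v * kernel[t][c][f]
				}
			}
		}
	}
	return out
}

func main() {
	// --max_len=200 through two layers with kernel size 7 and stride 3.
	l1 := outputLen(200, 7, 3)
	l2 := outputLen(l1, 7, 3)
	fmt.Println(l1, l2) // 65 20
}
```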

In [8]:
// Conv1DGraph stacks two 1D convolution layers on top of the sequence embeddings (shaped `[batch_size, sequence_len, embed_dim]`).
func Conv1DGraph(ctx *context.Context, embed, mask *Node) *Node {
    g := embed.Graph()
    // 1D Convolution:
    {
        ctx := ctx.In("conv1")
        embed = Activation(embed)
        embed = layers.Dropout(ctx, embed, ConstAsDType(g, DType, *flagDropoutRate))
        embed = layers.Convolution(ctx, embed).KernelSize(7).Filters(*flagTokenEmbeddingSize).Strides(3).Done()
        embed = Normalize(ctx, embed)
    }
    {
        ctx := ctx.In("conv2")
        embed = Activation(embed)
        embed = layers.Convolution(ctx, embed).KernelSize(7).Filters(*flagTokenEmbeddingSize).Strides(3).Done()
        embed = Normalize(ctx, embed)
    }
    return embed
}

%% --model=cnn --steps=20000 --optimizer=adamw
trainModel()

> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Platform: CUDA


Training (20000 steps):  100% [========================================] (544 steps/s) [loss=0.080] [~loss=0.097] [~acc=96.62%]


	[Step 20000] median train step: 1222 microseconds

Results on train-eval [Parallelized]:
	Mean Loss (#loss): 0.027
	Mean Accuracy (#acc): 99.38%
Results on test-eval [Parallelized]:
	Mean Loss (#loss): 0.602
	Mean Accuracy (#acc): 82.58%


### Transformer Model

Finally, a Transformer version of the model.

Notice it's not significantly better than our simple Bag-Of-Words model (85.74% vs. 84.68% test accuracy), likely because there is not enough data for the transformer to make a difference. The success of transformers in large language models is in large part due to training on huge amounts of unsupervised (or self-supervised) data.


In [9]:
var (
    flagMaxAttLen           = flag.Int("max_att_len", 200, "Maximum attention length: input will be split in ranges of this size.")
    flagNumAttHeads         = flag.Int("att_heads", 2, "Number of attention heads, if --model=transformer.")
    flagNumAttLayers        = flag.Int("att_layers", 1, "Number of stacked attention layers, if --model=transformer.")
    flagAttKeyQueryEmbedDim = flag.Int("att_key_dim", 8, "Dimension of the Key/Query attention embedding.")    
)

// TransformerGraph is the part of the model that takes the word/token embeddings to a transformed
// embedding through attention ready to be pooled and read out.
func TransformerGraph(ctx *context.Context, input, embed, mask *Node) *Node {
    var newEmbed *Node
    if *flagMaxAttLen >= *flagMaxLen {
        // One transformer window covers the whole length, which makes it trivial.
        newEmbed = TransformerLayers(ctx.In("transformer"), embed, mask)
        embed = Add(embed, newEmbed)
        return embed
    }

    // Split embedding in multiple split embeddings and apply transformer in each of them.
    attLen := *flagMaxAttLen
    sequenceFrom := 0
    for {
        // x.shape = [batchSize, sequence, embedding]
        sequenceTo := sequenceFrom + attLen
        if sequenceTo > *flagMaxLen {
            sequenceTo = *flagMaxLen
            sequenceFrom = sequenceTo - attLen
        }
        // part = x[:, sequenceFrom:sequenceTo, :]
        residual := Slice(embed, AxisRange(), AxisRange(sequenceFrom, sequenceTo), AxisRange())
        partMask := Slice(mask, AxisRange(), AxisRange(sequenceFrom, sequenceTo))
        // Reuse "transformer" scope.
        part := TransformerLayers(ctx.In("transformer").Checked(false), residual, partMask)
        part = Add(residual, part)
        if newEmbed == nil {
            newEmbed = part
        } else {
            newEmbed = Add(newEmbed, part)
        }

        if sequenceTo == *flagMaxLen {
            // Reached end of parts.
            break
        }
        sequenceFrom += attLen - 20 // Consecutive attention windows overlap by 20 tokens.
    }
    embed = newEmbed // Notice shape changed to `[batchSize, maxAttLen, embedDim]`
    return embed
}

// TransformerLayers builds the stacked transformer layers for the model.
func TransformerLayers(ctx *context.Context, embed, mask *Node) *Node {
    g := embed.Graph()
    shape := embed.Shape()
    embedDim := shape.Dimensions[2]

    var dropoutRate *Node
    if *flagDropoutRate > 0 {
        dropoutRate = ConstAsDType(g, DType, *flagDropoutRate)
    }

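    // Create the positional embedding variable: it has dimension 1 on the batch axis, so it is
    // broadcast across examples, and keeps the sequence and embedding axes: one embedding per position.
    // Shape: [1, sequenceLen, embedDim], where sequenceLen is --max_len, or --max_att_len when the input is split in windows.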
    posEmbedShape := shape.Copy()
    posEmbedShape.Dimensions[0] = 1
    posEmbedVar := ctx.VariableWithShape("positional", posEmbedShape)
    posEmbed := posEmbedVar.ValueGraph(g)
    embed = Add(embed, posEmbed) // Just add the embeddings, seems to work well.

    // Add the requested number of attention layers.
    for ii := 0; ii < *flagNumAttLayers; ii++ {
        // Each layer in its own scope.
        ctx := ctx.In(fmt.Sprintf("AttLayer_%d", ii))
        residual := embed
        embed = layers.MultiHeadAttention(ctx, embed, embed, embed, *flagNumAttHeads, *flagAttKeyQueryEmbedDim).
            SetKeyMask(mask).SetQueryMask(mask).
            SetOutputDim(embedDim).
            SetValueHeadDim(embedDim).Done()
        if *flagDropoutRate > 0 {
            embed = layers.Dropout(ctx.In("dropout_1"), embed, dropoutRate)
        }
        embed = Normalize(ctx.In("normalization_1"), embed)
        attentionOutput := embed

        // Transformers recipe: 2 dense layers after attention.
        embed = layers.Dense(ctx.In("ffn_1"), embed, true, embedDim)
        embed = Tanh(embed)
        embed = layers.Dense(ctx.In("ffn_2"), embed, true, embedDim)
        if *flagDropoutRate > 0 {
            embed = layers.Dropout(ctx.In("dropout_2"), embed, dropoutRate)
        }
        embed = Add(embed, attentionOutput)
        embed = Normalize(ctx.In("normalization_2"), embed)

        // Residual connection: not part of the usual transformer layer ...
        if ii > 0 {
            embed = Add(residual, embed)
        }
    }
    return embed
}

%% --model=transformer --steps=3000 --batch=100 --optimizer=adamw --max_len=200 --max_att_len=200 --att_heads=4 --att_layers=3 --att_key_dim=32 --max_vocab=10000
trainModel()

> Loading previously generated preprocessed binary file.
Loaded data from "aclImdb.bin": 100000 examples, 141088 unique tokens, 23727054 tokens in total.
Platform: CUDA


Training (3000 steps):  100% [========================================] (8 steps/s) [loss=0.204] [~loss=0.160] [~acc=95.14%]


	[Step 3000] median train step: 114828 microseconds

Results on train-eval [Parallelized]:
	Mean Loss (#loss): 0.135
	Mean Accuracy (#acc): 96.54%
Results on test-eval [Parallelized]:
	Mean Loss (#loss): 0.382
	Mean Accuracy (#acc): 85.74%
