From 528ca3e68b0b60de4f5e9434c076efca74e0b947 Mon Sep 17 00:00:00 2001
From: Yu Shiyang
Date: Wed, 2 Nov 2022 01:34:07 +0000
Subject: [PATCH] Transformer encoder (#97)

* Positional encoding

* Feed forward network

* Transformer encoder

* Added documentation comments for Transformer
---
 shumai/module/transformer.ts | 290 +++++++++++++++++++++++++++++++++++
 test/transformer.test.ts     | 253 +++++++++++++++++++++++++++++-
 2 files changed, 542 insertions(+), 1 deletion(-)

diff --git a/shumai/module/transformer.ts b/shumai/module/transformer.ts
index 0dd03898..b4d1304e 100644
--- a/shumai/module/transformer.ts
+++ b/shumai/module/transformer.ts
@@ -4,9 +4,115 @@ import * as ops from '../tensor/tensor_ops'
 import * as util from '../util'
 import { Linear } from './linear'
 import { Module } from './module'
+import { LayerNorm } from './norm'
+import { Sequential } from './sequential'
 
 const sm = { ...ops, ...tensor, util }
 
+/**
+ * A module to generate the positional encoding for a Transformer of a given input dimension,
+ *
+ * $$ \mathrm{PE}_{i, 2z} = \sin \left( \frac{i}{10000^{2z/d}} \right) $$
+ *
+ * $$ \mathrm{PE}_{i, 2z + 1} = \cos \left( \frac{i}{10000^{2z/d}} \right) $$
+ *
+ * where $i$ is the sequence position, $2z$ and $2z + 1$ index the dimensions of the input embedding, and $d$ is the dimensionality of the input embedding.
+ *
+ * The multiplicative factors $\frac{1}{10000^{2z/d}}$ are precomputed during object creation, as they are constant for all $i$.
+ *
+ * The full PE is initially precomputed for all $i$ up to 256 (configurable). If the module is called with a sequence length larger than the initial value, the additional values are also calculated and stored.
+ */
+export class TransformerPositionalEncoding extends Module {
+  /**
+   * The default `initSequenceLength` if none is supplied in the constructor.
+   */
+  static readonly DEFAULT_SEQUENCE_LENGTH = 256
+  /**
+   * The base of the exponent in the positional encoding.
+   */
+  static readonly ENCODING_BASE = 10000
+
+  private dim: number
+  private sequenceLength: number
+  private encodingFactors: Tensor
+  private encoding: Tensor
+
+  /**
+   * @param dim - Number of dimensions of each input embedding
+   * @param initSequenceLength - Initial sequence length that the positional embedding should be computed for, or {@link DEFAULT_SEQUENCE_LENGTH} if not specified
+   */
+  constructor(dim: number, initSequenceLength?: number) {
+    super()
+
+    if (dim <= 0) {
+      throw new Error(`Module dimension must be > 0: got ${dim}`)
+    }
+
+    this.dim = dim
+    if (initSequenceLength === undefined) {
+      this.sequenceLength = TransformerPositionalEncoding.DEFAULT_SEQUENCE_LENGTH
+    } else if (initSequenceLength <= 0) {
+      throw new Error(`Initial sequenceLength must be > 0: got ${initSequenceLength}`)
+    } else {
+      this.sequenceLength = initSequenceLength
+    }
+
+    // base and numerator must be full([1], x) instead of scalar(x)
+    // Otherwise, if the other operand has shape [1], the result will be reduced to scalar
+    const base = sm.full([1], TransformerPositionalEncoding.ENCODING_BASE)
+    const numerator = sm.full([1], 1)
+    const denominator = sm.scalar(this.dim)
+    const evenDims = sm.arange(0, this.dim, 2) // the even dimensions 2z; arange excludes the end value
+    this.encodingFactors = numerator.div(base.power(evenDims.div(denominator))) // shape [floor((dim + 1) / 2)]
+
+    this.encoding = this.calculate(0, this.sequenceLength) // shape [sequenceLength, dim]
+  }
+
+  /**
+   * Calculate positional encodings at a given range of sequence positions.
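+   *
+   * @example
+   * A small sanity check of the values (hypothetical snippet; with `dim = 2` the single precomputed factor is 1, so row $i$ is $[\sin i, \cos i]$):
+   * ```ts
+   * const pe = new TransformerPositionalEncoding(2)
+   * pe.calculate(0, 2) // values ≈ [[0, 1], [0.841, 0.540]]
+   * ```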
+   *
+   * @param start - Start of the range to calculate
+   * @param end - End of the range to calculate (exclusive)
+   *
+   * @returns a Tensor of calculated positional embeddings with shape `[end - start, dim]`
+   */
+  calculate(start: number, end: number): Tensor {
+    const length = end - start
+    const pairedDim = this.encodingFactors.shape[0]
+    const pos = sm.arange(start, end).reshape([length, 1])
+
+    const evenEncoding = sm.sin(pos.mul(this.encodingFactors)).reshape([length, pairedDim, 1])
+    const oddEncoding = sm.cos(pos.mul(this.encodingFactors)).reshape([length, pairedDim, 1])
+    let encoding = sm.concatenate([evenEncoding, oddEncoding], -1) // shape [length, pairedDim, 2]
+    encoding = encoding.reshape([length, pairedDim * 2])
+
+    if (this.dim % 2 !== 0) {
+      encoding = encoding.index([':', `:${this.dim}`]).reshape([length, this.dim])
+      // reshape is necessary to preserve the last axis if this.dim is 1
+    }
+
+    return encoding
+  }
+
+  /**
+   * @param sequenceLength - Length of the sequence for which the positional embedding should be calculated
+   * @returns a Tensor of positional embeddings with shape `[sequenceLength, dim]`, using precomputed values if available
+   */
+  forward(sequenceLength: number): Tensor {
+    if (sequenceLength > this.sequenceLength) {
+      const extension = this.calculate(this.sequenceLength, sequenceLength)
+      this.encoding = sm.concatenate([this.encoding, extension], 0)
+      this.sequenceLength = sequenceLength
+    }
+
+    if (sequenceLength === this.sequenceLength) {
+      return this.encoding
+    } else {
+      return this.encoding.index([`:${sequenceLength}`, ':'])
+    }
+  }
+}
+
 function checkAttentionInputs(
   attentionDim: number,
   queries: Tensor,
 }
 }
 
+/**
+ * Scaled dot-product attention mechanism as described by Vaswani et al. The {@link scaleFactor} is computed during object creation as $\frac{1}{\sqrt{d}}$, where $d$ is the dimensionality of the inputs.
+ */
 export class TransformerDotProductAttention extends Module {
   private dim: number
   private scaleFactor: Tensor
 
+  /**
+   * @param dim - Number of dimensions of the inputs
+   */
   constructor(dim: number) {
     super()
     this.dim = dim
 
     return tensor.mul(this.scaleFactor)
   }
 
+  /**
+   * @param queries - Tensor of query embeddings, shape `[..., queryTokens, dim]`
+   * @param keys - Tensor of key embeddings, shape `[..., keyTokens, dim]`
+   * @param values - Tensor of value embeddings each corresponding to a key, shape `[..., keyTokens, dim]`
+   * @returns A Tensor of shape `[..., queryTokens, dim]`
+   */
   forward(queries: Tensor, keys: Tensor, values: Tensor, mask?: Tensor): Tensor {
     // queries shape [..., queryTokens, dim]
     // keys and values shape [..., keyTokens, dim]
 
   }
 }
 
+/**
+ * Multi-head attention mechanism as described by Vaswani et al. The input Tensors are linearly embedded before being passed to {@link TransformerDotProductAttention | scaled dot-product attentions}.
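+ *
+ * @example
+ * A self-attention sketch (illustrative names only; `sm` is the top-level shumai import used in the tests):
+ * ```ts
+ * const mha = new TransformerMultiheadAttention(16, 4) // dim 16, 4 heads
+ * const tokens = sm.randn([10, 16]) // 10 tokens of dimension 16
+ * const selfAttended = mha(tokens, tokens, tokens) // shape [10, 16]
+ * ```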
+ */
 export class TransformerMultiheadAttention extends Module {
   private dim: number
   private heads: number
   private attentionDim: number
   private queryEmbed: Linear
   private keyEmbed: Linear
   private valueEmbed: Linear
   private attention: TransformerDotProductAttention
   private concatEmbed: Linear
 
+  /**
+   * @param dim - Number of dimensions of the input embeddings
+   * @param heads - Number of heads for the multi-head attention
+   * @param attentionDim - Number of dimensions of the further embeddings used by the {@link TransformerDotProductAttention | scaled dot-product attentions}, or `dim` if not specified
+   */
   constructor(dim: number, heads: number, attentionDim?: number) {
     super()
 
     this.concatEmbed = new Linear(this.attentionDim * heads, dim)
   }
 
+  /**
+   * @param queries - Tensor of query vectors, shape `[..., queryTokens, dim]`
+   * @param keys - Tensor of key vectors, shape `[..., keyTokens, dim]`
+   * @param values - Tensor of value vectors each corresponding to a key, shape `[..., keyTokens, dim]`
+   * @returns A Tensor of shape `[..., queryTokens, dim]`
+   */
   forward(queries: Tensor, keys: Tensor, values: Tensor): Tensor {
     // queries shape [..., queryTokens, dim]
     // keys and values shape [..., keyTokens, dim]
 
     return output
   }
 }
+
+/**
+ * Position-wise feed forward network: two affine transformations with a ReLU activation in between.
+ */
+class FeedForward extends Module {
+  private dim: number
+  private hiddenDim: number
+  private affineIn: Linear
+  private affineOut: Linear
+
+  constructor(dim: number, hiddenDim?: number) {
+    super()
+    this.dim = dim
+    if (hiddenDim === undefined) {
+      this.hiddenDim = dim
+    } else {
+      this.hiddenDim = hiddenDim
+    }
+    this.affineIn = new Linear(this.dim, this.hiddenDim)
+    this.affineOut = new Linear(this.hiddenDim, this.dim)
+  }
+
+  forward(input: Tensor): Tensor {
+    // shape [..., dim]
+    let output = this.affineIn(input).relu() // shape [..., hiddenDim]
+    output = this.affineOut(output) // shape [..., dim]
+    return output
+  }
+}
+
+/**
+ * A layer of the Transformer encoder, as described by Vaswani et al., consisting of a {@link TransformerMultiheadAttention | multi-head attention} layer and a fully-connected feed forward network. Both of these use residual connections and are normalised with {@link LayerNorm}.
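+ *
+ * @example
+ * A usage sketch (illustrative names only; `sm` is the top-level shumai import):
+ * ```ts
+ * const layer = new TransformerEncoderLayer(16, 4) // dim 16, 4 heads
+ * const tokens = sm.randn([10, 16]) // 10 tokens of dimension 16
+ * const output = layer(tokens) // shape [10, 16]
+ * ```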
+ */
+export class TransformerEncoderLayer extends Module {
+  private dim: number
+  private heads: number
+  private attentionDim: number
+  private feedForwardDim: number
+  private mha: TransformerMultiheadAttention
+  private mhaNorm: LayerNorm
+  private ff: FeedForward
+  private ffNorm: LayerNorm
+
+  /**
+   * @param dim - Number of dimensions of the input embeddings
+   * @param heads - Number of heads for the multi-head attention
+   * @param attentionDim - Number of dimensions of the embeddings used in the scaled dot-product attention, or `dim` if not specified
+   * @param feedForwardDim - Number of dimensions in the hidden layer of the feed forward network, or `dim` if not specified
+   */
+  constructor(dim: number, heads: number, attentionDim?: number, feedForwardDim?: number) {
+    super()
+
+    this.dim = dim
+    this.heads = heads
+    if (attentionDim === undefined) {
+      this.attentionDim = dim
+    } else {
+      this.attentionDim = attentionDim
+    }
+    if (feedForwardDim === undefined) {
+      this.feedForwardDim = dim
+    } else {
+      this.feedForwardDim = feedForwardDim
+    }
+
+    this.mha = new TransformerMultiheadAttention(this.dim, this.heads, this.attentionDim)
+    this.mhaNorm = new LayerNorm([this.dim])
+    this.ff = new FeedForward(this.dim, this.feedForwardDim)
+    this.ffNorm = new LayerNorm([this.dim])
+  }
+
+  /**
+   * @param input - Input Tensor of shape `[..., tokens, dim]`
+   * @returns A Tensor of shape `[..., tokens, dim]`
+   */
+  forward(input: Tensor): Tensor {
+    // shape [..., tokens, dim]
+    let mhaOutput = this.mha(input, input, input) // shape [..., tokens, dim]
+    mhaOutput = this.mhaNorm(input.add(mhaOutput))
+
+    let ffOutput = this.ff(mhaOutput) // shape [..., tokens, dim]
+    ffOutput = this.ffNorm(mhaOutput.add(ffOutput))
+
+    return ffOutput
+  }
+}
+
+/**
+ * Transformer encoder as described by Vaswani et al., containing an arbitrary number of {@link TransformerEncoderLayer | TransformerEncoderLayers}.
+ *
+ * This module includes the {@link TransformerPositionalEncoding | positional encoding}, but does not include any initial embedding of an input sequence into vectors (which should have been done separately, by e.g. word2vec).
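+ *
+ * @example
+ * A usage sketch (illustrative names only; the input is assumed to be already embedded):
+ * ```ts
+ * const encoder = new TransformerEncoder(16, 4, 6) // dim 16, 4 heads, 6 layers
+ * const embedded = sm.randn([10, 16]) // 10 tokens of dimension 16
+ * const encoded = encoder(embedded) // adds the positional encoding, then applies the layers; shape [10, 16]
+ * ```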
+ */
+export class TransformerEncoder extends Module {
+  private dim: number
+  private heads: number
+  private depth: number
+  private attentionDim: number
+  private feedForwardDim: number
+  private positional: TransformerPositionalEncoding
+  private layers: Sequential
+
+  /**
+   * @param dim - Number of dimensions of the input embeddings
+   * @param heads - Number of heads for the multi-head attention
+   * @param depth - Number of encoder layers
+   * @param attentionDim - Number of dimensions of the embeddings used in the scaled dot-product attention, or `dim` if not specified
+   * @param feedForwardDim - Number of dimensions in the hidden layer of the feed forward network, or `dim` if not specified
+   * @param initSequenceLength - Initial sequence length that the positional embedding should be computed for, or {@link TransformerPositionalEncoding.DEFAULT_SEQUENCE_LENGTH} if not specified
+   */
+  constructor(
+    dim: number,
+    heads: number,
+    depth: number,
+    attentionDim?: number,
+    feedForwardDim?: number,
+    initSequenceLength?: number
+  ) {
+    super()
+
+    this.dim = dim
+    this.heads = heads
+    this.depth = depth
+    if (attentionDim === undefined) {
+      this.attentionDim = dim
+    } else {
+      this.attentionDim = attentionDim
+    }
+    if (feedForwardDim === undefined) {
+      this.feedForwardDim = dim
+    } else {
+      this.feedForwardDim = feedForwardDim
+    }
+
+    if (initSequenceLength === undefined) {
+      this.positional = new TransformerPositionalEncoding(this.dim)
+    } else {
+      this.positional = new TransformerPositionalEncoding(this.dim, initSequenceLength)
+    }
+
+    const layers: TransformerEncoderLayer[] = []
+    for (let i = 0; i < this.depth; i++) {
+      layers.push(
+        new TransformerEncoderLayer(this.dim, this.heads, this.attentionDim, this.feedForwardDim)
+      )
+    }
+    this.layers = new Sequential(...layers)
+  }
+
+  /**
+   * @param input - Input Tensor of shape `[..., tokens, dim]`
+   * @returns A Tensor of shape `[..., tokens, dim]`
+   */
+  forward(input: Tensor): Tensor {
+    // shape [..., tokens, dim]
+    const positionalEncoding = this.positional(input.shape[input.shape.length - 2]) // shape [tokens, dim]
+
+    let output = input.add(positionalEncoding) // shape [..., tokens, dim]
+    output = this.layers(output) // shape [..., tokens, dim]
+    return output
+  }
+}
diff --git a/test/transformer.test.ts b/test/transformer.test.ts
index 72808d94..186f7593 100644
--- a/test/transformer.test.ts
+++ b/test/transformer.test.ts
@@ -1,6 +1,117 @@
 import * as sm from '@shumai/shumai'
 import { describe, expect, it } from 'bun:test'
-import { areSameShape, expectArraysClose, expectThrows } from './utils'
+import { areSameShape, expectArraysClose, expectThrows, isShape } from './utils'
+
+describe('TransformerPositionalEncoding', () => {
+  it('dim=1', () => {
+    const module = new sm.module.TransformerPositionalEncoding(1)
+
+    const result = module(3)
+    const expected = [0, 1, 2].map((x) => Math.sin(x))
+    expect(isShape(result, [3, 1])).toBe(true)
+    expectArraysClose(result.toFloat32Array(), expected)
+  })
+  it('dim=2', () => {
+    const module = new sm.module.TransformerPositionalEncoding(2)
+
+    const result = module(3)
+    const expected = [0, 0, 1, 1, 2, 2].map((x, i) => (i % 2 ? Math.cos(x) : Math.sin(x)))
+    expect(isShape(result, [3, 2])).toBe(true)
+    expectArraysClose(result.toFloat32Array(), expected)
+  })
+  it('dim=3', () => {
+    const dim = 3
+    const module = new sm.module.TransformerPositionalEncoding(dim)
+
+    const result = module(3)
+    const expected = [0, 0, 0, 1, 1, 1, 2, 2, 2].map((x, i) =>
+      ((i % dim) % 2 ?
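+        // odd embedding dimensions (2z + 1) take cos, even ones (2z) take sin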
+        Math.cos : Math.sin)(
+        x /
+          Math.pow(
+            sm.module.TransformerPositionalEncoding.ENCODING_BASE,
+            ((i % dim) - ((i % dim) % 2 ? 1 : 0)) / dim
+          )
+      )
+    )
+    expect(isShape(result, [3, dim])).toBe(true)
+    expectArraysClose(result.toFloat32Array(), expected)
+  })
+  it('dim=4', () => {
+    const dim = 4
+    const module = new sm.module.TransformerPositionalEncoding(dim)
+
+    const result = module(3)
+    const expected = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].map((x, i) =>
+      ((i % dim) % 2 ? Math.cos : Math.sin)(
+        x /
+          Math.pow(
+            sm.module.TransformerPositionalEncoding.ENCODING_BASE,
+            ((i % dim) - ((i % dim) % 2 ? 1 : 0)) / dim
+          )
+      )
+    )
+    expect(isShape(result, [3, dim])).toBe(true)
+    expectArraysClose(result.toFloat32Array(), expected)
+  })
+  it('init sequence length, dim=4', () => {
+    const dim = 4
+    const module = new sm.module.TransformerPositionalEncoding(dim, dim)
+
+    const result = module(3)
+    const expected = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].map((x, i) =>
+      ((i % dim) % 2 ? Math.cos : Math.sin)(
+        x /
+          Math.pow(
+            sm.module.TransformerPositionalEncoding.ENCODING_BASE,
+            ((i % dim) - ((i % dim) % 2 ? 1 : 0)) / dim
+          )
+      )
+    )
+    expect(isShape(result, [3, dim])).toBe(true)
+    expectArraysClose(result.toFloat32Array(), expected)
+  })
+  it('extends sequence length, dim=4', () => {
+    const dim = 4
+    const module = new sm.module.TransformerPositionalEncoding(dim, 2)
+
+    const result = module(3)
+    const expected = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2].map((x, i) =>
+      ((i % dim) % 2 ? Math.cos : Math.sin)(
+        x /
+          Math.pow(
+            sm.module.TransformerPositionalEncoding.ENCODING_BASE,
+            ((i % dim) - ((i % dim) % 2 ? 1 : 0)) / dim
+          )
+      )
+    )
+    expect(isShape(result, [3, dim])).toBe(true)
+    expectArraysClose(result.toFloat32Array(), expected)
+  })
+  it('dim=0 is invalid', () => {
+    expectThrows(
+      () => new sm.module.TransformerPositionalEncoding(0),
+      new RegExp('dimension must be > 0')
+    )
+  })
+  it('dim=-1 is invalid', () => {
+    expectThrows(
+      () => new sm.module.TransformerPositionalEncoding(-1),
+      new RegExp('dimension must be > 0')
+    )
+  })
+  it('0 init sequenceLength is invalid', () => {
+    expectThrows(
+      () => new sm.module.TransformerPositionalEncoding(3, 0),
+      new RegExp('sequenceLength must be > 0')
+    )
+  })
+  it('-1 init sequenceLength is invalid', () => {
+    expectThrows(
+      () => new sm.module.TransformerPositionalEncoding(3, -1),
+      new RegExp('sequenceLength must be > 0')
+    )
+  })
+})
 
 describe('TransformerDotProductAttention', () => {
   it('single matching token', () => {
@@ -385,3 +496,143 @@ describe('TransformerMultiheadAttention', () => {
     expect(!!values.grad).toBe(true)
   })
 })
+
+describe('TransformerEncoderLayer', () => {
+  it('single token, two heads', () => {
+    const module = new sm.module.TransformerEncoderLayer(6, 2)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('single token, two heads (attentionDim)', () => {
+    const module = new sm.module.TransformerEncoderLayer(6, 2, 7)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('single token, two heads (feedForwardDim)', () => {
+    const module = new sm.module.TransformerEncoderLayer(6, 2, 6, 12)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('batch samples are independent', () => {
+    const module =
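+      // same configuration as the single-token tests above; batching must not mix samples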
+      new sm.module.TransformerEncoderLayer(6, 2)
+    const singleInput = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+    const batchInput = sm
+      .tensor(
+        new Float32Array(
+          [2, 3, 0.5, 0.25, 1, 2.25].concat([
+            Math.random(),
+            Math.random(),
+            Math.random(),
+            Math.random(),
+            Math.random(),
+            Math.random()
+          ])
+        )
+      )
+      .reshape([2, 1, 6])
+
+    const singleResult = module(singleInput)
+    const batchResult = module(batchInput)
+    areSameShape(batchResult, batchInput)
+
+    expectArraysClose(
+      batchResult.index([0, ':', ':']).toFloat32Array(),
+      singleResult.toFloat32Array()
+    )
+    areSameShape(batchResult.index([0, ':', ':']), singleResult)
+  })
+  it('calculates gradient', () => {
+    const module = new sm.module.TransformerEncoderLayer(6, 2)
+    const input = sm
+      .tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25]))
+      .reshape([1, 6])
+      .requireGrad()
+
+    const result = module(input).sum()
+    result.backward()
+    expect(!!input.grad).toBe(true)
+  })
+})
+
+describe('TransformerEncoder', () => {
+  it('depth=1', () => {
+    const module = new sm.module.TransformerEncoder(6, 2, 1)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('depth=2', () => {
+    const module = new sm.module.TransformerEncoder(6, 2, 2)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('depth=2 (attentionDim)', () => {
+    const module = new sm.module.TransformerEncoder(6, 2, 2, 7)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('depth=2 (feedForwardDim)', () => {
+    const module = new sm.module.TransformerEncoder(6, 2, 2, 6, 12)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('depth=2 (initSequenceLength)', () => {
+    const module = new sm.module.TransformerEncoder(6, 2, 2, 6, 6, 1)
+    const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+
+    const result = module(input)
+    areSameShape(result, input)
+  })
+  it('batch samples are independent', () => {
+    const module = new sm.module.TransformerEncoder(6, 2, 2)
+    const singleInput = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
+    const batchInput = sm
+      .tensor(
+        new Float32Array(
+          [2, 3, 0.5, 0.25, 1, 2.25].concat([
+            Math.random(),
+            Math.random(),
+            Math.random(),
+            Math.random(),
+            Math.random(),
+            Math.random()
+          ])
+        )
+      )
+      .reshape([2, 1, 6])
+
+    const singleResult = module(singleInput)
+    const batchResult = module(batchInput)
+    areSameShape(batchResult, batchInput)
+
+    expectArraysClose(
+      batchResult.index([0, ':', ':']).toFloat32Array(),
+      singleResult.toFloat32Array()
+    )
+    areSameShape(batchResult.index([0, ':', ':']), singleResult)
+  })
+  it('calculates gradient', () => {
+    const module = new sm.module.TransformerEncoder(6, 2, 2)
+    const input = sm
+      .tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25]))
+      .reshape([1, 6])
+      .requireGrad()
+
+    const result = module(input).sum()
+    result.backward()
+    expect(!!input.grad).toBe(true)
+  })
+})