Transformer encoder
yushiyangk committed Oct 30, 2022
1 parent 9631929 commit 704704a
Showing 2 changed files with 244 additions and 0 deletions.
104 changes: 104 additions & 0 deletions shumai/module/transformer.ts
@@ -4,6 +4,8 @@ import * as ops from '../tensor/tensor_ops'
import * as util from '../util'
import { Linear } from './linear'
import { Module } from './module'
import { LayerNorm } from './norm'
import { Sequential } from './sequential'

const sm = { ...ops, ...tensor, util }

@@ -256,3 +258,105 @@ class FeedForward extends Module {
return output
}
}

export class TransformerEncoderLayer extends Module {
private dim: number
private heads: number
private attentionDim: number
private feedForwardDim: number
private mha: TransformerMultiheadAttention
private mhaNorm: LayerNorm
private ff: FeedForward
private ffNorm: LayerNorm

constructor(dim: number, heads: number, attentionDim?: number, feedForwardDim?: number) {
super()

this.dim = dim
this.heads = heads
if (attentionDim === undefined) {
this.attentionDim = dim
} else {
this.attentionDim = attentionDim
}
if (feedForwardDim === undefined) {
this.feedForwardDim = dim
} else {
this.feedForwardDim = feedForwardDim
}

this.mha = new TransformerMultiheadAttention(this.dim, this.heads, this.attentionDim)
this.mhaNorm = new LayerNorm([this.dim])
this.ff = new FeedForward(this.dim, this.feedForwardDim)
this.ffNorm = new LayerNorm([this.dim])
}

forward(input: Tensor): Tensor {
// shape [..., tokens, dim]
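// Post-norm residuals: each sublayer's output is added back to its input and then layer-normalized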
let mhaOutput = this.mha(input, input, input) // shape [..., tokens, dim]
mhaOutput = this.mhaNorm(input.add(mhaOutput))

let ffOutput = this.ff(mhaOutput) // shape [..., tokens, dim]
ffOutput = this.ffNorm(mhaOutput.add(ffOutput))

return ffOutput
}
}

export class TransformerEncoder extends Module {
private dim: number
private heads: number
private depth: number
private attentionDim: number
private feedForwardDim: number
private positional: TransformerPositionalEncoding
private layers: Sequential

constructor(
dim: number,
heads: number,
depth: number,
attentionDim?: number,
feedForwardDim?: number,
initSequenceLength?: number
) {
super()

this.dim = dim
this.heads = heads
this.depth = depth
if (attentionDim === undefined) {
this.attentionDim = dim
} else {
this.attentionDim = attentionDim
}
if (feedForwardDim === undefined) {
this.feedForwardDim = dim
} else {
this.feedForwardDim = feedForwardDim
}

if (initSequenceLength === undefined) {
this.positional = new TransformerPositionalEncoding(this.dim)
} else {
this.positional = new TransformerPositionalEncoding(this.dim, initSequenceLength)
}

const layers: TransformerEncoderLayer[] = []
for (let i = 0; i < this.depth; i++) {
layers.push(
new TransformerEncoderLayer(this.dim, this.heads, this.attentionDim, this.feedForwardDim)
)
}
this.layers = new Sequential(...layers)
}

forward(input: Tensor): Tensor {
// shape [..., tokens, dim]
const positionalEncoding = this.positional(input.shape[input.shape.length - 2]) // shape [tokens, dim]
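// The [tokens, dim] encoding broadcasts over any leading batch dimensions when added below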

let output = input.add(positionalEncoding) // shape [..., tokens, dim]
output = this.layers(output) // shape [..., tokens, dim]
return output
}
}
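
For reference, a minimal usage sketch of the new encoder (assuming the package is imported as @shumai/shumai and that sm.module re-exports TransformerEncoder, as the tests below use it):

import * as sm from '@shumai/shumai'

// Model dimension 6, 2 attention heads, a stack of 2 encoder layers
const encoder = new sm.module.TransformerEncoder(6, 2, 2)

// A single sequence of one 6-dimensional token, shaped [tokens, dim]
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

// Adds positional encodings, then applies the layer stack; output shape matches the input
const output = encoder(input)
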
140 changes: 140 additions & 0 deletions test/transformer.test.ts
@@ -496,3 +496,143 @@ describe('TransformerMultiheadAttention', () => {
expect(!!values.grad).toBe(true)
})
})

describe('TransformerEncoderLayer', () => {
it('single token, two heads', () => {
const module = new sm.module.TransformerEncoderLayer(6, 2)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('single token, two heads (attentionDim)', () => {
const module = new sm.module.TransformerEncoderLayer(6, 2, 7)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('single token, two heads (feedForwardDim)', () => {
const module = new sm.module.TransformerEncoderLayer(6, 2, 6, 12)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('batch samples are independent', () => {
const module = new sm.module.TransformerEncoderLayer(6, 2)
const singleInput = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
const batchInput = sm
.tensor(
new Float32Array(
[2, 3, 0.5, 0.25, 1, 2.25].concat([
Math.random(),
Math.random(),
Math.random(),
Math.random(),
Math.random(),
Math.random()
])
)
)
.reshape([2, 1, 6])

const singleResult = module(singleInput)
const batchResult = module(batchInput)
areSameShape(batchResult, batchInput)

expectArraysClose(
batchResult.index([0, ':', ':']).toFloat32Array(),
singleResult.toFloat32Array()
)
areSameShape(batchResult.index([0, ':', ':']), singleResult)
})
it('calculates gradient', () => {
const module = new sm.module.TransformerEncoderLayer(6, 2)
const input = sm
.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25]))
.reshape([1, 6])
.requireGrad()

const result = module(input).sum()
result.backward()
expect(!!input.grad).toBe(true)
})
})

describe('TransformerEncoder', () => {
it('depth=1', () => {
const module = new sm.module.TransformerEncoder(6, 2, 1)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('depth=2', () => {
const module = new sm.module.TransformerEncoder(6, 2, 2)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('depth=2 (attentionDim)', () => {
const module = new sm.module.TransformerEncoder(6, 2, 2, 7)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('depth=2 (feedForwardDim)', () => {
const module = new sm.module.TransformerEncoder(6, 2, 2, 6, 12)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('depth=2 (initSequenceLength)', () => {
const module = new sm.module.TransformerEncoder(6, 2, 2, 6, 6, 1)
const input = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])

const result = module(input)
areSameShape(result, input)
})
it('batch samples are independent', () => {
const module = new sm.module.TransformerEncoder(6, 2, 2)
const singleInput = sm.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25])).reshape([1, 6])
const batchInput = sm
.tensor(
new Float32Array(
[2, 3, 0.5, 0.25, 1, 2.25].concat([
Math.random(),
Math.random(),
Math.random(),
Math.random(),
Math.random(),
Math.random()
])
)
)
.reshape([2, 1, 6])

const singleResult = module(singleInput)
const batchResult = module(batchInput)
areSameShape(batchResult, batchInput)

expectArraysClose(
batchResult.index([0, ':', ':']).toFloat32Array(),
singleResult.toFloat32Array()
)
areSameShape(batchResult.index([0, ':', ':']), singleResult)
})
it('calculates gradient', () => {
const module = new sm.module.TransformerEncoder(6, 2, 2)
const input = sm
.tensor(new Float32Array([2, 3, 0.5, 0.25, 1, 2.25]))
.reshape([1, 6])
.requireGrad()

const result = module(input).sum()
result.backward()
expect(!!input.grad).toBe(true)
})
})
