Clearer naming for partial derivatives
salieri committed Jun 20, 2019
1 parent bc9deca commit 4ca3420
Showing 15 changed files with 245 additions and 93 deletions.
39 changes: 39 additions & 0 deletions examples/dense-simple-2x.ts
@@ -0,0 +1,39 @@
/**
* Simple example network which learns to multiply input by two
*
* Input: 5, 3, 8
* Output: 10, 6, 16
*/

import _ from 'lodash';

import {
DeferredInputFeed,
DeferredMemoryInputFeed,
Dense,
Model,
NDArray,
} from '../src';

import { SampleGenerator } from './sample-generator';


export class DenseSimple2x extends SampleGenerator {
public model(): Model {
const model = new Model({ seed: this.params.seed });

model
.input(1)
.push(new Dense({ units: 3, activation: 'identity' }, 'hidden-1'))
.push(new Dense({ units: 1, activation: 'identity' }, 'output'));

return model;
}


public samples(count: number): DeferredInputFeed {
return DeferredMemoryInputFeed.factory(
_.times(count, n => ({ x: new NDArray([n]), y: new NDArray([n * 2]) })),
);
}
}
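
For context, a minimal sketch of how this generator might be driven. It uses only the two methods defined above, model() and samples(); the constructor option is an assumption inferred from this.params.seed, since the params schema is not part of this diff.

import { DenseSimple2x } from './dense-simple-2x';

const generator = new DenseSimple2x({ seed: 1 }); // assumed option; schema not shown in this diff
const model = generator.model();       // 1 -> 3 -> 1 dense network, identity activations
const samples = generator.samples(4);  // feed of (x: [n], y: [2 * n]) for n = 0..3
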
1 change: 1 addition & 0 deletions examples/index.ts
@@ -1,2 +1,3 @@
export * from './dense-2x';
export * from './dense-simple-2x';
export * from './sample-generator';
24 changes: 24 additions & 0 deletions src/generic/parameterized.ts
@@ -11,6 +11,10 @@ export type FinalParameters<T> = {
export abstract class Parameterized<TInput extends Parameters, TCoerced extends TInput = TInput> {
protected readonly params: FinalParameters<TCoerced>;

protected instantiated: boolean = false;

protected defaultInstantiation: boolean = false;

/* eslint-disable-next-line @typescript-eslint/no-explicit-any */
public constructor(params: TInput = {} as any) {
this.params = this.validateParams(params);
@@ -31,5 +35,25 @@ export abstract class Parameterized<TInput extends Parameters, TCoerced extends

return result.value as unknown as FinalParameters<TCoerced>;
}


public setDefaultInstantiationFlag(flag: boolean): void {
this.defaultInstantiation = flag;
}


public setInstantiatedFlag(flag: boolean): void {
this.instantiated = flag;
}


public getDefaultInstantiationFlag(): boolean {
return this.defaultInstantiation;
}


public getInstantiatedFlag(): boolean {
return this.instantiated;
}
}

7 changes: 4 additions & 3 deletions src/nn/graph/data-feed.ts
@@ -11,11 +11,11 @@ export class GraphDataFeed {

public readonly backpropInput = new DeferredCollectionWrapper();

public readonly backpropFit = new DeferredCollection();
public readonly fitter = new DeferredCollection();

public readonly optimizer = new DeferredCollection();

public readonly train = new DeferredCollectionWrapper();
public readonly trainer = new DeferredCollectionWrapper();

public activationDerivative?: Vector;

@@ -27,6 +27,7 @@ export class GraphDataFeed {
this.backpropOutput.unsetValues();
this.backpropInput.unsetValues();

this.train.unsetValues();
this.trainer.unsetValues();
this.fitter.unsetValues();
}
}
128 changes: 78 additions & 50 deletions src/nn/layer/dense.ts
@@ -44,17 +44,17 @@ export class Dense extends Layer<DenseParamsInput, DenseParamsCoerced> {


protected async optimizeExec(): Promise<void> {
const backpropFit = this.data.backpropFit;
const fitter = this.data.fitter;
const optimizer = this.data.optimizer;

const weightError = backpropFit.getValue(Dense.WEIGHT_ERROR, Matrix);
const weightError = fitter.getValue(Dense.WEIGHT_ERROR, Matrix);
const weights = optimizer.getValue(Dense.WEIGHT_MATRIX, Matrix);
const optimizedWeight = this.params.weightOptimizer.optimize(weights, weightError);

optimizer.setValue(Dense.WEIGHT_MATRIX, optimizedWeight, Matrix);

if (this.params.bias) {
const biasError = backpropFit.getValue(Dense.BIAS_ERROR, Vector);
const biasError = fitter.getValue(Dense.BIAS_ERROR, Vector);
const bias = optimizer.getValue(Dense.BIAS_VECTOR, Vector);
const optimizedBias = this.params.biasOptimizer.optimize(bias, biasError);

@@ -63,24 +63,16 @@
}


protected calculateActivationDerivative(): void {
// dActivated/dLinear = dO/dZ = dOut/dNet
protected calculateActivationDerivative(): Vector {
const output = this.data.output;
const train = this.data.train;
const train = this.data.trainer;

const activated = output.getValue(Dense.ACTIVATED_OUTPUT, Vector);
const linear = output.getValue(Dense.LINEAR_OUTPUT, Vector);
const y = train.hasDefaultValue() ? train.getDefaultValue(Vector) : undefined;

this.data.activationDerivative = this.params.activation.derivative(activated, linear, y);
}


protected getActivationDerivative(): Vector {
if (!this.data.activationDerivative) {
throw new Error('Activation derivative has not been calculated');
}

return this.data.activationDerivative;
return this.params.activation.derivative(activated, linear, y);
}


@@ -89,80 +81,116 @@ export class Dense extends Layer<DenseParamsInput, DenseParamsCoerced> {
*/
protected async backwardExec(): Promise<void> {
const backpropOutput = this.data.backpropOutput;
const backpropFit = this.data.backpropFit;

this.calculateActivationDerivative();
const fitter = this.data.fitter;

// dError/dLinear
const errorTerm = this.calculateErrorTerm();

// dError/dWeights
const weightError = this.calculateWeightDerivative(errorTerm);

backpropOutput.setValue(Layer.ERROR_TERM, errorTerm, Vector);
backpropOutput.setValue(Dense.WEIGHT_MATRIX, this.data.optimizer.getValue(Dense.WEIGHT_MATRIX, Matrix), Matrix);
backpropFit.setValue(Dense.WEIGHT_ERROR, weightError, Matrix);
fitter.setValue(Dense.WEIGHT_ERROR, weightError, Matrix);

if (this.params.bias) {
// dError/dBias
const biasError = this.calculateBiasDerivative(errorTerm);

backpropFit.setValue(Dense.BIAS_ERROR, biasError, Vector);
fitter.setValue(Dense.BIAS_ERROR, biasError, Vector);
}
}


/**
* dError/dActivated = L'(yHat, y)
*/
protected calculateActivationErrorDerivativeFromLabel(): Vector {
const yHat = this.data.output.getDefaultValue(Vector);
const y = this.data.trainer.getDefaultValue(Vector);

return this.params.loss.gradient(yHat, y);
}


/**
* @link https://brilliant.org/wiki/backpropagation/
* dError/dLinear = (dError/dActivated) * (dActivated/dLinear)
*/
protected calculateErrorTermFromLabel(): Vector {
const yHat = this.data.output.getDefaultValue(Vector);
const y = this.data.train.getDefaultValue(Vector);
const derivative = this.getActivationDerivative();
const loss = this.params.loss.gradient(yHat, y);
// dError/dActivated
const loss = this.calculateActivationErrorDerivativeFromLabel();

// dActivated/dLinear
const derivative = this.calculateActivationDerivative();

// (a[final] - y) = (yHat - y) = -(y - yHat) = dErrorTotal / dOutput
// layerError = g'(a[final])(yHat - y)
// return derivative.mul(yHat.sub(y));

return derivative.mul(loss);
// dError/dLinear = (dError/dActivated) * (dActivated/dLinear)
return loss.mul(derivative);
}


protected calculateErrorTermFromChain(): Vector {
/**
* dError/dActivated = (dLinearNext/dActivated)^T . (dError/dLinearNext) = weightNext^T . errorTermNext
*/
protected calculateActivationErrorDerivativeFromChain(): Vector {
const backpropInput = this.data.backpropInput;
const layerErrorNext = backpropInput.getValue(Layer.ERROR_TERM, Vector);

// (dLinearNext/dActivated) = weights[l+1]
const weightNext = backpropInput.getValue(Dense.WEIGHT_MATRIX, Matrix);
const derivative = this.getActivationDerivative();

// layerError = (wNext)T dNext .* g'(z)
return weightNext.transpose().vecmul(layerErrorNext).mul(derivative);
// dError[l+1]/dLinear[l+1]
const dErrorOverDLinearNext = backpropInput.getValue(Layer.ERROR_TERM, Vector);

// dError/dActivated
return weightNext.transpose().vecmul(dErrorOverDLinearNext);
}


/**
* dError/dLinear[l] = (weights[l+1]T . (dError[l+1]/dLinear[l+1])) .* (dActivated/dLinear)
* dError/dLinear = (dError/dActivated) * (dActivated/dLinear)
*/
protected calculateErrorTermFromChain(): Vector {
// dError/dActivated
const dErrorOverDActivated = this.calculateActivationErrorDerivativeFromChain();

// dActivated/dLinear
const dActivatedOverDLinear = this.calculateActivationDerivative();

// dError/dLinear
return dErrorOverDActivated.mul(dActivatedOverDLinear);
}


/**
* dError/dLinear
*/
protected calculateErrorTerm(): Vector {
return this.data.train.hasDefaultValue() ? this.calculateErrorTermFromLabel() : this.calculateErrorTermFromChain();
return this.data.trainer.hasDefaultValue() ? this.calculateErrorTermFromLabel() : this.calculateErrorTermFromChain();
}


/**
* dError/dWeights = (dError/dLinear) (o) X
*/
protected calculateWeightDerivative(errorTerm: Vector): Matrix {
const inputVector = this.data.input.getDefaultValue(Vector);

return inputVector.outer(errorTerm).transpose();
return errorTerm.outer(inputVector); // inputVector.outer(errorTerm).transpose();
}


/**
* dError/dBias = sum(dError/dLinear)
*/
protected calculateBiasDerivative(errorTerm: Vector): Vector {
return new Vector([errorTerm.sum()]);

/* if (this.data.train.hasDefaultValue()) {
return new Vector([errorTerm.sum()]);
}
const backpropInput = this.data.backpropInput;
const layerErrorNext = backpropInput.getValue(Layer.ERROR_TERM, Vector);
const weightNext = backpropInput.getValue(Dense.WEIGHT_MATRIX, Matrix);
const diagonalWeights = weightNext.pickDiagonal();
const derivative = this.getActivationDerivative();
return new Vector([layerErrorNext.mul(diagonalWeights).mul(derivative).sum()]); */
return this.data.trainer.hasDefaultValue()
? new Vector([errorTerm.sum()])
: new Vector([this.data.optimizer.getValue(Dense.BIAS_VECTOR, Vector).dot(errorTerm)]);
}


@@ -271,7 +299,7 @@ export class Dense extends Layer<DenseParamsInput, DenseParamsCoerced> {


protected async compileInitialization(): Promise<void> {
this.raw.trainingLabels.setDefault(this.data.train);
this.raw.trainingLabels.setDefault(this.data.trainer);
this.raw.backpropOutputs.setDefault(this.data.backpropOutput);
this.raw.outputs.setDefault(this.data.output);
}
@@ -302,10 +330,10 @@ export class Dense extends Layer<DenseParamsInput, DenseParamsCoerced> {


protected async compileBackPropagation(): Promise<void> {
const train = this.data.train;
const train = this.data.trainer;
const backpropInput = this.data.backpropInput;
const backpropOutput = this.data.backpropOutput;
const backpropFit = this.data.backpropFit;
const fitter = this.data.fitter;
const optimizer = this.data.optimizer;

const rawTrainingLabels = this.raw.trainingLabels;
@@ -329,10 +357,10 @@

backpropOutput.declare(Layer.ERROR_TERM, this.countOutputUnits());
backpropOutput.declare(Dense.WEIGHT_MATRIX, weightDims);
backpropFit.declare(Dense.WEIGHT_ERROR, weightDims);
fitter.declare(Dense.WEIGHT_ERROR, weightDims);

if (this.params.bias) {
backpropFit.declare(Dense.BIAS_ERROR, 1);
fitter.declare(Dense.BIAS_ERROR, 1);
}
}

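
To make the derivative chain in this file concrete, here is a worked example with plain numbers in place of the library's Vector and Matrix types. It is a sketch only: a 1-unit output layer with identity activation and squared-error loss, plus the hidden-layer chain step, with all weights and inputs made up for illustration.

// Output layer: input x = [5], weight w = [0.3], bias b = 0.1, label y = 10.
const x = 5;
const w = 0.3;
const b = 0.1;
const y = 10;

const linear = w * x + b;               // z = 1.6
const yHat = linear;                    // identity activation: a = z

const dErrorOverDActivated = yHat - y;  // L'(yHat, y) = yHat - y = -8.4
const dActivatedOverDLinear = 1;        // identity: g'(z) = 1

// dError/dLinear = (dError/dActivated) * (dActivated/dLinear)
const errorTerm = dErrorOverDActivated * dActivatedOverDLinear;    // -8.4

// dError/dWeights = (dError/dLinear) (o) X
const weightError = errorTerm * x;      // -42

// dError/dBias = sum(dError/dLinear)
const biasError = errorTerm;            // -8.4

// Hidden layer (three units feeding the output above): the error arrives
// through the next layer's weights, as in calculateErrorTermFromChain().
// dError/dActivated = weightNext^T . errorTermNext
const weightNext = [0.4, -0.2, 0.7];    // the output layer's 1x3 weight row
const dErrOverDActHidden = weightNext.map((wj) => wj * errorTerm); // [-3.36, 1.68, -5.88]
// ...then multiply elementwise by this layer's own g'(z).
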
2 changes: 1 addition & 1 deletion src/nn/layer/layer.ts
@@ -47,7 +47,7 @@ export type LayerParams = Parameters;
export abstract class Layer <TInput extends LayerParams = LayerParams, TCoerced extends TInput = TInput>
extends Parameterized<TInput, TCoerced>
implements GraphEntity {
public static readonly ERROR_TERM: string = 'error';
public static readonly ERROR_TERM: string = 'error-term';

public static readonly TRAINING_LABEL: string = 'train';

5 changes: 3 additions & 2 deletions src/nn/loss/mean-squared-error.ts
@@ -4,11 +4,12 @@ import { Vector } from '../../math';
/**
* Mean Squared Error
* A.k.a. Quadratic Loss
* @link https://ml-cheatsheet.readthedocs.io/en/latest/calculus.html
*/
export class MeanSquaredError extends Loss {
public calculate(yHat: Vector, y: Vector): number {
// sum( ( yHat - y ) ^ 2 ) / y.size
return yHat.sub(y).pow(2).mean();
// 0.5 * sum( ( yHat - y ) ^ 2 ) / y.size
return 0.5 * yHat.sub(y).pow(2).mean();
}


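
A quick numeric check of the rescaled loss, using plain arrays rather than the library's Vector type; the predictions are borrowed from the dense-simple-2x example, the targets are made up for illustration.

const yHat = [10, 6, 16];
const y = [9, 6, 14];

// 0.5 * sum( ( yHat - y ) ^ 2 ) / y.size = 0.5 * (1 + 0 + 4) / 3
const mse = 0.5 * yHat.reduce((acc, v, i) => acc + (v - y[i]) ** 2, 0) / y.length;
// 0.8333...
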
19 changes: 19 additions & 0 deletions src/nn/loss/squared-error.ts
@@ -0,0 +1,19 @@
import { Loss } from './loss';
import { Vector } from '../../math';

/**
* Squared Error
*/
export class SquaredError extends Loss {
public calculate(yHat: Vector, y: Vector): number {
// 0.5 * sum( ( yHat - y ) ^ 2 )
return 0.5 * yHat.sub(y).pow(2).sum();
}


public gradient(yHat: Vector, y: Vector): Vector {
// d/dyHat [ 0.5 * ( yHat - y ) ^ 2 ] = yHat - y
return yHat.sub(y);
}
}
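
The gradient follows from d/dyHat [ 0.5 * ( yHat - y ) ^ 2 ] = yHat - y, which is why it returns yHat.sub(y) above. A quick check with plain numbers (a sketch, not library code):

const yHat = [10, 6, 16];
const y = [9, 6, 14];

const loss = 0.5 * yHat.reduce((acc, v, i) => acc + (v - y[i]) ** 2, 0); // 0.5 * 5 = 2.5
const grad = yHat.map((v, i) => v - y[i]);                               // [1, 0, 2]
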
