Skip to content

Commit

Permalink
[demo] Optimize Sparse NCC (#77)
Browse files Browse the repository at this point in the history
* demo/sparse-ncc: Optimize predict to only look at necessary indices.
* demo/sparse-ncc: Improved gas costs using a struct to store the classInfo.
* demo/load-model: Add more logging for gasUsed.
  • Loading branch information
juharris committed Feb 10, 2020
1 parent 7ac9310 commit 4ff3064
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {Classifier64} from "./Classifier.sol";

/**
* A nearest centroid classifier that uses Euclidean distance to predict the closest centroid based on a sparse data sample.
* Data must be sorted indices of features with each feature occurring at most once.
*
* https://en.wikipedia.org/wiki/Nearest_centroid_classifier
*/
Expand All @@ -23,23 +24,47 @@ contract SparseNearestCentroidClassifier is Classifier64 {

uint256 constant public dataCountLimit = 2 ** (256 - 64 - 1);

uint64[][] public centroids;
uint[] public dataCounts;
/**
* Information for a class.
*/
struct ClassInfo {
/**
* The number of samples in the class.
*/
uint64 numSamples;

uint64[] centroid;

/**
* The squared 2-norm of the centroid. Multiplied by (toFloat * toFloat).
*/
uint squaredMagnitude;
}

ClassInfo[] public classInfos;

constructor(
string[] memory _classifications,
uint64[][] memory _centroids,
uint[] memory _dataCounts)
uint64[][] memory centroids,
uint64[] memory dataCounts)
Classifier64(_classifications) public {
require(_centroids.length == _classifications.length, "The number of centroids and classifications must be the same.");
require(centroids.length == _classifications.length, "The number of centroids and classifications must be the same.");
require(_classifications.length > 0, "At least one class is required.");
require(_classifications.length < 2 ** 64, "Too many classes given.");
centroids = _centroids;
dataCounts = _dataCounts;
uint dimensions = centroids[0].length;
require(dimensions < 2 ** 63, "First centroid is too long.");
for (uint i = 1; i < centroids.length; ++i) {
require(centroids[i].length == dimensions, "Inconsistent number of dimensions.");
for (uint i = 0; i < centroids.length; ++i) {
uint64[] memory centroid = centroids[i];
require(centroid.length == dimensions, "Inconsistent number of dimensions.");
classInfos.push(ClassInfo(dataCounts[i], centroid, _getSquaredMagnitude(centroid)));
}
}

function _getSquaredMagnitude(uint64[] memory vector) internal pure returns (uint squaredMagnitude) {
squaredMagnitude = 0;
for (uint i = 0; i < vector.length; ++i) {
// Should be safe multiplication and addition because vector entries should be small.
squaredMagnitude += vector[i] * vector[i];
}
}

Expand All @@ -50,20 +75,25 @@ contract SparseNearestCentroidClassifier is Classifier64 {
* @param classification The class to add the extension to.
*/
function extendCentroid(uint64[] memory extension, uint64 classification) public onlyOwner {
require(classification < centroids.length, "This classification has not been added yet.");
require(centroids[classification].length + extension.length < 2 ** 63, "Centroid would be too long.");
require(classification < classInfos.length, "This classification has not been added yet.");
ClassInfo storage classInfo = classInfos[classification];
uint64[] storage centroid = classInfo.centroid;
require(centroid.length + extension.length < 2 ** 63, "Centroid would be too long.");
uint squaredMagnitude = classInfo.squaredMagnitude;
for (uint i = 0; i < extension.length; ++i) {
centroids[classification].push(extension[i]);
centroid.push(extension[i]);
// Should be safe multiplication and addition because vector entries should be small.
squaredMagnitude += extension[i] * extension[i];
}
classInfo.squaredMagnitude = squaredMagnitude;
}

function addClass(uint64[] memory centroid, string memory classification, uint dataCount) public onlyOwner {
function addClass(uint64[] memory centroid, string memory classification, uint64 dataCount) public onlyOwner {
require(classifications.length + 1 < 2 ** 64, "There are too many classes already.");
require(centroid.length == centroids[0].length, "Data doesn't have the correct number of dimensions.");
require(centroid.length == classInfos[0].centroid.length, "Data doesn't have the correct number of dimensions.");
require(dataCount < dataCountLimit, "Data count is too large.");
classifications.push(classification);
centroids.push(centroid);
dataCounts.push(dataCount);
classInfos.push(ClassInfo(dataCount, centroid, _getSquaredMagnitude(centroid)));
emit AddClass(classification, classifications.length - 1);
}

Expand All @@ -77,35 +107,25 @@ contract SparseNearestCentroidClassifier is Classifier64 {

uint minDistance = UINT256_MAX;
bestClass = 0;
for (uint64 currentClass = 0; currentClass < centroids.length; ++currentClass) {
uint distance = 0;
uint dataIndex = 0;
// This can be optimized by storing magnitudes, updating them on updates,
// and for predicting: only use necessary dimensions to find difference from magnitude of the centroid.
for (uint64 j = 0; j < centroids[currentClass].length; ++j) {
if (dataIndex < data.length && data[dataIndex] == int64(j)) {
// Feature is present.
// Safe calculation because both values are int64.
int256 diff = toFloat;
diff -= centroids[currentClass][j];
diff *= diff;
// Convert back to our float representation.
diff /= toFloat;
distance = distance.add(uint256(diff));
++dataIndex;

if (distance >= minDistance) {
break;
}
} else {
// Feature is not present.
uint256 diff = centroids[currentClass][j];
diff *= diff;
// Convert back to our float representation.
diff /= toFloat;
distance = distance.add(diff);
}
for (uint64 currentClass = 0; currentClass < classInfos.length; ++currentClass) {
uint64[] storage centroid = classInfos[currentClass].centroid;
// Default distance for empty data is `squaredMagnitudes[currentClass]`.
// We'll use that as a base and update it.
// distance = squaredMagnitudes[currentClass]
// For each j:
// distance = distance - centroids[currentClass][j]^2 + (centroids[currentClass][j] - toFloat)^2
// = distance - centroids[currentClass][j]^2 + centroids[currentClass][j]^2 - 2 * centroids[currentClass][j] * toFloat + toFloat^2
// = distance - 2 * centroids[currentClass][j] * toFloat + toFloat^2
// = distance + toFloat * (-2 * centroids[currentClass][j] + toFloat)
int distanceUpdate = 0;

for (uint dataIndex = 0; dataIndex < data.length; ++dataIndex) {
// Should be safe since data is not very long.
distanceUpdate += int(toFloat) - 2 * centroid[uint(data[dataIndex])];
}

uint distance = uint(int(classInfos[currentClass].squaredMagnitude) + distanceUpdate * toFloat);

if (distance < minDistance) {
minDistance = distance;
bestClass = currentClass;
Expand All @@ -114,29 +134,53 @@ contract SparseNearestCentroidClassifier is Classifier64 {
}

function update(int64[] memory data, uint64 classification) public onlyOwner {
require(classification < centroids.length, "This classification has not been added yet.");
uint64[] memory centroid = centroids[classification];
uint n = dataCounts[classification];
uint newN;
require(classification < classInfos.length, "This classification has not been added yet.");
ClassInfo storage classInfo = classInfos[classification];
uint64[] memory centroid = classInfo.centroid;
uint n = classInfos[classification].numSamples;
uint64 newN;
// Keep n small enough for multiplication.
if (n >= dataCountLimit) {
newN = dataCounts[classification];
newN = classInfo.numSamples;
} else {
newN = dataCounts[classification] + 1;
dataCounts[classification] = newN;
newN = classInfo.numSamples + 1;
classInfo.numSamples = newN;
}

// Could try to optimize further by not updating zero entries in the centroid that are not in the data.
// This wouldn't help much for our current examples (IMDB + Fake News) since most features occur in all classes.

// Update centroid using moving average calculation.
uint squaredMagnitude = 0;
uint dataIndex = 0;
for (uint64 featureIndex = 0; featureIndex < centroids[classification].length; ++featureIndex) {
for (uint64 featureIndex = 0; featureIndex < centroid.length; ++featureIndex) {
if (dataIndex < data.length && data[dataIndex] == int64(featureIndex)) {
// Feature is present.
centroids[classification][featureIndex] = uint64((int(centroid[featureIndex]) * int(n) + toFloat) / int(newN));
uint64 v = uint64((n * centroid[featureIndex] + toFloat) / newN);
centroid[featureIndex] = v;
squaredMagnitude = squaredMagnitude.add(uint(v) * v);
++dataIndex;
} else {
// Feature is not present.
centroids[classification][featureIndex] = uint64((int(centroid[featureIndex]) * int(n)) / int(newN));
uint64 v = uint64((n * centroid[featureIndex]) / newN);
centroid[featureIndex] = v;
squaredMagnitude = squaredMagnitude.add(uint(v) * v);
}
}
classInfo.centroid = centroid;
classInfo.squaredMagnitude = squaredMagnitude;
}

// Useful methods to view the underlying data:
function getNumSamples(uint classIndex) public view returns (uint64) {
return classInfos[classIndex].numSamples;
}

function getCentroidValue(uint classIndex, uint featureIndex) public view returns (uint64) {
return classInfos[classIndex].centroid[featureIndex];
}

function getSquaredMagnitude(uint classIndex) public view returns (uint) {
return classInfos[classIndex].squaredMagnitude;
}
}
32 changes: 16 additions & 16 deletions demo/client/src/ml-models/load-model-node.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,25 @@ const NearestCentroidClassifier = artifacts.require("./classification/NearestCen
const SparseNearestCentroidClassifier = artifacts.require("./classification/SparseNearestCentroidClassifier")
const SparsePerceptron = artifacts.require("./classification/SparsePerceptron")

const { convertData, convertNum } = require('../../src/float-utils-node');
const { convertData, convertNum } = require('../../src/float-utils-node')

const _toFloat = 1E9

async function loadDensePerceptron(model, web3, toFloat) {
let gasUsed = 0
const weightChunkSize = 450
const { classifications } = model
const weights = convertData(model.weights, web3, toFloat);
const intercept = convertNum(model.bias, web3, toFloat);
const learningRate = convertNum(1, web3, toFloat);
console.log(` Deploying Dense Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights.`);
const weights = convertData(model.weights, web3, toFloat)
const intercept = convertNum(model.bias, web3, toFloat)
const learningRate = convertNum(1, web3, toFloat)
console.log(` Deploying Dense Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights.`)
const classifierContract = await DensePerceptron.new(classifications, weights.slice(0, weightChunkSize), intercept, learningRate)
gasUsed += (await web3.eth.getTransactionReceipt(classifierContract.transactionHash)).gasUsed

// Add remaining weights.
for (let i = weightChunkSize; i < weights.length; i += weightChunkSize) {
console.log(` Adding classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}).`);
const r = await classifierContract.initializeWeights(weights.slice(i, i + weightChunkSize))
console.debug(` Added classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}). gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}

Expand All @@ -37,20 +37,20 @@ async function loadDensePerceptron(model, web3, toFloat) {
}

async function loadSparsePerceptron(model, web3, toFloat) {
let gasUsed = 0
const weightChunkSize = 300
const { classifications } = model
const weights = convertData(model.weights, web3, toFloat)
const intercept = convertNum(model.bias, web3, toFloat)
const learningRate = convertNum(1, web3, toFloat)
console.log(` Deploying Sparse Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights.`)
console.log(` Deploying Sparse Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights...`)
const classifierContract = await SparsePerceptron.new(classifications, weights.slice(0, weightChunkSize), intercept, learningRate)
gasUsed += (await web3.eth.getTransactionReceipt(classifierContract.transactionHash)).gasUsed
let gasUsed = (await web3.eth.getTransactionReceipt(classifierContract.transactionHash)).gasUsed
console.log(` Deployed Sparse Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights. gasUsed: ${gasUsed}`)

// Add remaining weights.
for (let i = weightChunkSize; i < weights.length; i += weightChunkSize) {
console.log(` Adding classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}).`)
const r = await classifierContract.initializeWeights(i, weights.slice(i, i + weightChunkSize))
console.debug(` Added classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}) gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}

Expand Down Expand Up @@ -94,7 +94,7 @@ async function loadNearestCentroidClassifier(model, web3, toFloat) {
addClassPromises.push(classifierContract.addClass(centroids[i], classifications[i], dataCounts[i]))
}
return Promise.all(addClassPromises).then(responses => {
console.log(" All classes added.")
console.debug(" All classes added.")
for (const r of responses) {
gasUsed += r.receipt.gasUsed
}
Expand Down Expand Up @@ -139,7 +139,7 @@ exports.loadSparseNearestCentroidClassifier = async function (model, web3, toFlo
addClassPromises.push(classifierContract.addClass(
centroids[i].slice(0, initialChunkSize), classifications[i], dataCounts[i]
).then(r => {
console.log(` Added class ${i}`)
console.debug(` Added class ${i}. gasUsed: ${r.receipt.gasUsed}`)
return r
}))
}
Expand All @@ -155,7 +155,7 @@ exports.loadSparseNearestCentroidClassifier = async function (model, web3, toFlo
for (let j = initialChunkSize; j < centroids[classification].length; j += chunkSize) {
const r = await classifierContract.extendCentroid(
centroids[classification].slice(j, j + chunkSize), classification)
console.log(` Added dimensions [${j}, ${Math.min(j + chunkSize, centroids[classification].length)}) for class ${classification}`)
console.debug(` Added dimensions [${j}, ${Math.min(j + chunkSize, centroids[classification].length)}) for class ${classification}. gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}
}
Expand All @@ -182,7 +182,7 @@ async function loadNaiveBayes(model, web3, toFloat) {
addClassPromises.push(classifierContract.addClass(
classCounts[i], featureCounts[i].slice(0, initialFeatureChunkSize), classifications[i]
).then(r => {
console.log(` Added class ${i}`)
console.debug(` Added class ${i}. gasUsed: ${r.receipt.gasUsed}`)
return r
}))
}
Expand All @@ -197,11 +197,11 @@ async function loadNaiveBayes(model, web3, toFloat) {
const r = await classifierContract.initializeCounts(
featureCounts[classification].slice(j, j + featureChunkSize), classification
)
console.log(` Added features [${j}, ${Math.min(j + featureChunkSize, featureCounts[classification].length)}) for class ${classification}`)
console.debug(` Added features [${j}, ${Math.min(j + featureChunkSize, featureCounts[classification].length)}) for class ${classification}. gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}
}
console.log(` Deployed all Naive Bayes classifier classes. gasUsed: ${gasUsed}.`)
console.debug(` Deployed all Naive Bayes classifier classes. gasUsed: ${gasUsed}.`)
return {
classifierContract,
gasUsed,
Expand Down
17 changes: 10 additions & 7 deletions demo/client/test/contracts/check-gas-costs.js
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,16 @@ contract('CheckGasUsage', function (accounts) {
// },
// {
// path: `${__dirname}/../../../../simulation/saved_runs/1580943847-imdb-nb-model.json`,
// data: [1, 2, 3, 14, 25, 36, 57, 88, 299, 310, 411, 512, 613, 714, 815],
// data: [1, 2, 3, 14, 15, 26, 37, 48, 59, 110, 111, 112, 213, 314, 515, 616, 717, 818, 919, 920],
// },
// {
// path: `${__dirname}/../../../../simulation/saved_runs/1580945025-imdb-ncc-model.json`,
// data: [1, 2, 3, 14, 25, 36, 57, 88, 299, 310, 411, 512, 613, 714, 815],
// data: [1, 2, 3, 14, 15, 26, 37, 48, 59, 110, 111, 112, 213, 314, 515, 616, 717, 818, 919, 920],
// },
// {
// path: `${__dirname}/../../../../simulation/saved_runs/1580945565-imdb-perceptron-model.json`,
// data: [1, 2, 3, 14, 15, 26, 37, 48, 59, 110, 111, 112, 213, 314, 515, 616, 717, 818, 919, 920],
// },
{
path: `${__dirname}/../../../../simulation/saved_runs/1580945565-imdb-perceptron-model.json`,
data: [1, 2, 3, 14, 25, 36, 57, 88, 299, 310, 411, 512, 613, 714, 815],
},
]
const gasUsages = []
for (const model of models) {
Expand Down Expand Up @@ -149,10 +149,13 @@ contract('CheckGasUsage', function (accounts) {
r = await mainInterface.refund(data, predictedClassification, addedTime)
gasUsage['refund'] = r.receipt.gasUsed
console.log(`Refund gas used: ${r.receipt.gasUsed}`)

// Report
// Someone else adds bad data.
console.debug(" Adding currently incorrect data using another account...")
r = await mainInterface.addData(data, 1 - predictedClassification, { from: accounts[1], value: 1E17 })
console.log(`Adding data (was incorrect) gas used: ${r.receipt.gasUsed}`)
gasUsage['addData (was incorrect)'] = r.receipt.gasUsed
e = r.logs.filter(e => e.event == 'AddData')[0]
addedTime = e.args.t;
r = await mainInterface.report(data, 1 - predictedClassification, addedTime, accounts[1])
Expand Down
Loading

0 comments on commit 4ff3064

Please sign in to comment.