[demo] Optimize Sparse NCC #77

Merged · 5 commits · Feb 10, 2020
SparseNearestCentroidClassifier.sol
@@ -7,6 +7,7 @@ import {Classifier64} from "./Classifier.sol";

/**
* A nearest centroid classifier that uses Euclidean distance to predict the closest centroid for a sparse data sample.
* Data must be the sorted indices of the features present in the sample, with each feature occurring at most once.
*
* https://en.wikipedia.org/wiki/Nearest_centroid_classifier
*/
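For concreteness, a valid sparse sample under this format looks like the following (illustrative values only; each entry is the index of a present feature, sorted ascending, with no repeats):

// Hypothetical sample in JS: features 2, 5, and 9 are present;
// all other features are implicitly absent (zero).
const data = [2, 5, 9]   // valid: sorted, no duplicate indices
// [5, 2, 9] would be invalid (unsorted); [2, 2, 9] too (duplicate index).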
@@ -23,23 +24,47 @@ contract SparseNearestCentroidClassifier is Classifier64 {

uint256 constant public dataCountLimit = 2 ** (256 - 64 - 1);

uint64[][] public centroids;
uint[] public dataCounts;
/**
* Information for a class.
*/
struct ClassInfo {
/**
* The number of samples in the class.
*/
uint64 numSamples;

uint64[] centroid;

/**
* The squared 2-norm of the centroid. Multiplied by (toFloat * toFloat).
*/
uint squaredMagnitude;
}

ClassInfo[] public classInfos;

constructor(
string[] memory _classifications,
uint64[][] memory _centroids,
uint[] memory _dataCounts)
uint64[][] memory centroids,
uint64[] memory dataCounts)
Classifier64(_classifications) public {
require(_centroids.length == _classifications.length, "The number of centroids and classifications must be the same.");
require(centroids.length == _classifications.length, "The number of centroids and classifications must be the same.");
require(_classifications.length > 0, "At least one class is required.");
require(_classifications.length < 2 ** 64, "Too many classes given.");
centroids = _centroids;
dataCounts = _dataCounts;
uint dimensions = centroids[0].length;
require(dimensions < 2 ** 63, "First centroid is too long.");
for (uint i = 1; i < centroids.length; ++i) {
require(centroids[i].length == dimensions, "Inconsistent number of dimensions.");
for (uint i = 0; i < centroids.length; ++i) {
uint64[] memory centroid = centroids[i];
require(centroid.length == dimensions, "Inconsistent number of dimensions.");
classInfos.push(ClassInfo(dataCounts[i], centroid, _getSquaredMagnitude(centroid)));
}
}

function _getSquaredMagnitude(uint64[] memory vector) internal pure returns (uint squaredMagnitude) {
squaredMagnitude = 0;
for (uint i = 0; i < vector.length; ++i) {
// Should be safe multiplication and addition because vector entries should be small.
squaredMagnitude += vector[i] * vector[i];
}
}
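As a sanity check on the fixed-point convention (assuming the demo's scale of `toFloat = 1E9`, matching `_toFloat` in load-model-node.js below): centroid entries are real values multiplied by toFloat, so the value computed above is the real squared 2-norm multiplied by toFloat * toFloat. A minimal off-chain sketch:

const toFloat = 1E9  // the demo's fixed-point scale (_toFloat)

// Entries of `vector` are real values scaled by toFloat, so the result
// is the real squared 2-norm scaled by toFloat * toFloat. Use BigInt in
// practice; plain JS numbers are only exact up to 2^53.
function getSquaredMagnitude(vector) {
  let squaredMagnitude = 0
  for (const v of vector) {
    squaredMagnitude += v * v
  }
  return squaredMagnitude
}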

@@ -50,20 +75,25 @@ contract SparseNearestCentroidClassifier is Classifier64 {
* @param classification The class to add the extension to.
*/
function extendCentroid(uint64[] memory extension, uint64 classification) public onlyOwner {
require(classification < centroids.length, "This classification has not been added yet.");
require(centroids[classification].length + extension.length < 2 ** 63, "Centroid would be too long.");
require(classification < classInfos.length, "This classification has not been added yet.");
ClassInfo storage classInfo = classInfos[classification];
uint64[] storage centroid = classInfo.centroid;
require(centroid.length + extension.length < 2 ** 63, "Centroid would be too long.");
uint squaredMagnitude = classInfo.squaredMagnitude;
for (uint i = 0; i < extension.length; ++i) {
centroids[classification].push(extension[i]);
centroid.push(extension[i]);
// Should be safe multiplication and addition because vector entries should be small.
squaredMagnitude += extension[i] * extension[i];
}
classInfo.squaredMagnitude = squaredMagnitude;
}

function addClass(uint64[] memory centroid, string memory classification, uint dataCount) public onlyOwner {
function addClass(uint64[] memory centroid, string memory classification, uint64 dataCount) public onlyOwner {
require(classifications.length + 1 < 2 ** 64, "There are too many classes already.");
require(centroid.length == centroids[0].length, "Data doesn't have the correct number of dimensions.");
require(centroid.length == classInfos[0].centroid.length, "Data doesn't have the correct number of dimensions.");
require(dataCount < dataCountLimit, "Data count is too large.");
classifications.push(classification);
centroids.push(centroid);
dataCounts.push(dataCount);
classInfos.push(ClassInfo(dataCount, centroid, _getSquaredMagnitude(centroid)));
emit AddClass(classification, classifications.length - 1);
}

@@ -77,35 +107,25 @@ contract SparseNearestCentroidClassifier is Classifier64 {

uint minDistance = UINT256_MAX;
bestClass = 0;
for (uint64 currentClass = 0; currentClass < centroids.length; ++currentClass) {
uint distance = 0;
uint dataIndex = 0;
// This can be optimized by storing magnitudes, updating them on updates,
// and for predicting: only use necessary dimensions to find difference from magnitude of the centroid.
for (uint64 j = 0; j < centroids[currentClass].length; ++j) {
if (dataIndex < data.length && data[dataIndex] == int64(j)) {
// Feature is present.
// Safe calculation because both values are int64.
int256 diff = toFloat;
diff -= centroids[currentClass][j];
diff *= diff;
// Convert back to our float representation.
diff /= toFloat;
distance = distance.add(uint256(diff));
++dataIndex;

if (distance >= minDistance) {
break;
}
} else {
// Feature is not present.
uint256 diff = centroids[currentClass][j];
diff *= diff;
// Convert back to our float representation.
diff /= toFloat;
distance = distance.add(diff);
}
for (uint64 currentClass = 0; currentClass < classInfos.length; ++currentClass) {
uint64[] storage centroid = classInfos[currentClass].centroid;
// The default distance for empty data is `classInfos[currentClass].squaredMagnitude`
// (the squared magnitude of the centroid). We'll use that as a base and update it.
// distance = squaredMagnitude
// For each feature j present in the data:
// distance = distance - centroid[j]^2 + (centroid[j] - toFloat)^2
//          = distance - centroid[j]^2 + centroid[j]^2 - 2 * centroid[j] * toFloat + toFloat^2
//          = distance - 2 * centroid[j] * toFloat + toFloat^2
//          = distance + toFloat * (toFloat - 2 * centroid[j])
int distanceUpdate = 0;

for (uint dataIndex = 0; dataIndex < data.length; ++dataIndex) {
// Should be safe since data is not very long.
distanceUpdate += int(toFloat) - 2 * centroid[uint(data[dataIndex])];
}

uint distance = uint(int(classInfos[currentClass].squaredMagnitude) + distanceUpdate * toFloat);

if (distance < minDistance) {
minDistance = distance;
bestClass = currentClass;
@@ -114,29 +134,53 @@ contract SparseNearestCentroidClassifier is Classifier64 {
}
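To see why this shortcut matches the old per-dimension loop, here is a small off-chain check (toy numbers; the old code also divided each squared term by toFloat, but a constant factor does not change which centroid is closest):

const toFloat = 100                   // tiny scale so the toy arithmetic is exact
const centroid = [30, 0, 70, 10]      // scaled centroid entries
const data = [0, 2]                   // sparse sample: features 0 and 2 present

// Brute force: visit every dimension.
let bruteForce = 0
for (let j = 0; j < centroid.length; ++j) {
  const x = data.includes(j) ? toFloat : 0
  bruteForce += (x - centroid[j]) ** 2
}

// Optimized: start from the precomputed squared magnitude and only
// touch the dimensions that are present in the sample.
const squaredMagnitude = centroid.reduce((s, c) => s + c * c, 0)
let distanceUpdate = 0
for (const j of data) {
  distanceUpdate += toFloat - 2 * centroid[j]
}
const optimized = squaredMagnitude + distanceUpdate * toFloat

console.log(bruteForce === optimized)  // true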

function update(int64[] memory data, uint64 classification) public onlyOwner {
require(classification < centroids.length, "This classification has not been added yet.");
uint64[] memory centroid = centroids[classification];
uint n = dataCounts[classification];
uint newN;
require(classification < classInfos.length, "This classification has not been added yet.");
ClassInfo storage classInfo = classInfos[classification];
uint64[] memory centroid = classInfo.centroid;
uint n = classInfo.numSamples;
uint64 newN;
// Keep n small enough for multiplication.
if (n >= dataCountLimit) {
newN = dataCounts[classification];
newN = classInfo.numSamples;
} else {
newN = dataCounts[classification] + 1;
dataCounts[classification] = newN;
newN = classInfo.numSamples + 1;
classInfo.numSamples = newN;
}

// Could try to optimize further by not updating zero entries in the centroid that are not in the data.
// This wouldn't help much for our current examples (IMDB + Fake News) since most features occur in all classes.

// Update centroid using moving average calculation.
uint squaredMagnitude = 0;
uint dataIndex = 0;
for (uint64 featureIndex = 0; featureIndex < centroids[classification].length; ++featureIndex) {
for (uint64 featureIndex = 0; featureIndex < centroid.length; ++featureIndex) {
if (dataIndex < data.length && data[dataIndex] == int64(featureIndex)) {
// Feature is present.
centroids[classification][featureIndex] = uint64((int(centroid[featureIndex]) * int(n) + toFloat) / int(newN));
uint64 v = uint64((n * centroid[featureIndex] + toFloat) / newN);
centroid[featureIndex] = v;
squaredMagnitude = squaredMagnitude.add(uint(v) * v);
++dataIndex;
} else {
// Feature is not present.
centroids[classification][featureIndex] = uint64((int(centroid[featureIndex]) * int(n)) / int(newN));
uint64 v = uint64((n * centroid[featureIndex]) / newN);
centroid[featureIndex] = v;
squaredMagnitude = squaredMagnitude.add(uint(v) * v);
}
}
classInfo.centroid = centroid;
classInfo.squaredMagnitude = squaredMagnitude;
}
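The moving-average step can be checked off-chain as well; a sketch under the same assumptions (integer division mirrors Solidity's truncation, and a present feature contributes a value of toFloat):

// New entry after folding one sample into a class with n prior samples:
// (n * oldValue + x) / (n + 1), where x is toFloat if the feature is
// present in the sample and 0 otherwise.
function updatedEntry(oldValue, n, present, toFloat) {
  const x = present ? toFloat : 0
  return Math.floor((n * oldValue + x) / (n + 1))  // truncate like Solidity
}

// Example with toFloat = 100: an entry of 0.5 (stored as 50) averaged
// over 3 samples, with the feature present in the 4th sample:
console.log(updatedEntry(50, 3, true, 100))  // (3*50 + 100) / 4 = 62

The squared magnitude is recomputed from the updated entries in the same pass, so it stays consistent with the centroid.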

// Useful methods to view the underlying data:
function getNumSamples(uint classIndex) public view returns (uint64) {
return classInfos[classIndex].numSamples;
}

function getCentroidValue(uint classIndex, uint featureIndex) public view returns (uint64) {
return classInfos[classIndex].centroid[featureIndex];
}

function getSquaredMagnitude(uint classIndex) public view returns (uint) {
return classInfos[classIndex].squaredMagnitude;
}
}
32 changes: 16 additions & 16 deletions demo/client/src/ml-models/load-model-node.js
@@ -6,25 +6,25 @@ const NearestCentroidClassifier = artifacts.require("./classification/NearestCentroidClassifier")
const SparseNearestCentroidClassifier = artifacts.require("./classification/SparseNearestCentroidClassifier")
const SparsePerceptron = artifacts.require("./classification/SparsePerceptron")

const { convertData, convertNum } = require('../../src/float-utils-node');
const { convertData, convertNum } = require('../../src/float-utils-node')

const _toFloat = 1E9

async function loadDensePerceptron(model, web3, toFloat) {
let gasUsed = 0
const weightChunkSize = 450
const { classifications } = model
const weights = convertData(model.weights, web3, toFloat);
const intercept = convertNum(model.bias, web3, toFloat);
const learningRate = convertNum(1, web3, toFloat);
console.log(` Deploying Dense Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights.`);
const weights = convertData(model.weights, web3, toFloat)
const intercept = convertNum(model.bias, web3, toFloat)
const learningRate = convertNum(1, web3, toFloat)
console.log(` Deploying Dense Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights.`)
const classifierContract = await DensePerceptron.new(classifications, weights.slice(0, weightChunkSize), intercept, learningRate)
gasUsed += (await web3.eth.getTransactionReceipt(classifierContract.transactionHash)).gasUsed

// Add remaining weights.
for (let i = weightChunkSize; i < weights.length; i += weightChunkSize) {
console.log(` Adding classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}).`);
const r = await classifierContract.initializeWeights(weights.slice(i, i + weightChunkSize))
console.debug(` Added classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}). gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}

@@ -37,20 +37,20 @@ async function loadDensePerceptron(model, web3, toFloat) {
}

async function loadSparsePerceptron(model, web3, toFloat) {
let gasUsed = 0
const weightChunkSize = 300
const { classifications } = model
const weights = convertData(model.weights, web3, toFloat)
const intercept = convertNum(model.bias, web3, toFloat)
const learningRate = convertNum(1, web3, toFloat)
console.log(` Deploying Sparse Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights.`)
console.log(` Deploying Sparse Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights...`)
const classifierContract = await SparsePerceptron.new(classifications, weights.slice(0, weightChunkSize), intercept, learningRate)
gasUsed += (await web3.eth.getTransactionReceipt(classifierContract.transactionHash)).gasUsed
let gasUsed = (await web3.eth.getTransactionReceipt(classifierContract.transactionHash)).gasUsed
console.log(` Deployed Sparse Perceptron classifier with first ${Math.min(weights.length, weightChunkSize)} weights. gasUsed: ${gasUsed}`)

// Add remaining weights.
for (let i = weightChunkSize; i < weights.length; i += weightChunkSize) {
console.log(` Adding classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}).`)
const r = await classifierContract.initializeWeights(i, weights.slice(i, i + weightChunkSize))
console.debug(` Added classifier weights [${i}, ${Math.min(i + weightChunkSize, weights.length)}) gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}

@@ -94,7 +94,7 @@ async function loadNearestCentroidClassifier(model, web3, toFloat) {
addClassPromises.push(classifierContract.addClass(centroids[i], classifications[i], dataCounts[i]))
}
return Promise.all(addClassPromises).then(responses => {
console.log(" All classes added.")
console.debug(" All classes added.")
for (const r of responses) {
gasUsed += r.receipt.gasUsed
}
@@ -139,7 +139,7 @@ exports.loadSparseNearestCentroidClassifier = async function (model, web3, toFloat) {
addClassPromises.push(classifierContract.addClass(
centroids[i].slice(0, initialChunkSize), classifications[i], dataCounts[i]
).then(r => {
console.log(` Added class ${i}`)
console.debug(` Added class ${i}. gasUsed: ${r.receipt.gasUsed}`)
return r
}))
}
@@ -155,7 +155,7 @@ exports.loadSparseNearestCentroidClassifier = async function (model, web3, toFloat) {
for (let j = initialChunkSize; j < centroids[classification].length; j += chunkSize) {
const r = await classifierContract.extendCentroid(
centroids[classification].slice(j, j + chunkSize), classification)
console.log(` Added dimensions [${j}, ${Math.min(j + chunkSize, centroids[classification].length)}) for class ${classification}`)
console.debug(` Added dimensions [${j}, ${Math.min(j + chunkSize, centroids[classification].length)}) for class ${classification}. gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}
}
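The initial-chunk-then-extend pattern above exists because uploading a full centroid in one transaction can exceed the block gas limit; the chunk-size constants (`initialChunkSize`, `chunkSize`) are defined in the elided part of this file. A generic sketch of the slicing (sizes illustrative):

// Yield an initial slice for addClass, then extension slices for
// extendCentroid, so each transaction stays under the gas limit.
function* centroidChunks(centroid, initialChunkSize, chunkSize) {
  yield centroid.slice(0, initialChunkSize)
  for (let j = initialChunkSize; j < centroid.length; j += chunkSize) {
    yield centroid.slice(j, j + chunkSize)
  }
}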
@@ -182,7 +182,7 @@ async function loadNaiveBayes(model, web3, toFloat) {
addClassPromises.push(classifierContract.addClass(
classCounts[i], featureCounts[i].slice(0, initialFeatureChunkSize), classifications[i]
).then(r => {
console.log(` Added class ${i}`)
console.debug(` Added class ${i}. gasUsed: ${r.receipt.gasUsed}`)
return r
}))
}
@@ -197,11 +197,11 @@ async function loadNaiveBayes(model, web3, toFloat) {
const r = await classifierContract.initializeCounts(
featureCounts[classification].slice(j, j + featureChunkSize), classification
)
console.log(` Added features [${j}, ${Math.min(j + featureChunkSize, featureCounts[classification].length)}) for class ${classification}`)
console.debug(` Added features [${j}, ${Math.min(j + featureChunkSize, featureCounts[classification].length)}) for class ${classification}. gasUsed: ${r.receipt.gasUsed}`)
gasUsed += r.receipt.gasUsed
}
}
console.log(` Deployed all Naive Bayes classifier classes. gasUsed: ${gasUsed}.`)
console.debug(` Deployed all Naive Bayes classifier classes. gasUsed: ${gasUsed}.`)
return {
classifierContract,
gasUsed,
17 changes: 10 additions & 7 deletions demo/client/test/contracts/check-gas-costs.js
@@ -108,16 +108,16 @@ contract('CheckGasUsage', function (accounts) {
// },
// {
// path: `${__dirname}/../../../../simulation/saved_runs/1580943847-imdb-nb-model.json`,
// data: [1, 2, 3, 14, 25, 36, 57, 88, 299, 310, 411, 512, 613, 714, 815],
// data: [1, 2, 3, 14, 15, 26, 37, 48, 59, 110, 111, 112, 213, 314, 515, 616, 717, 818, 919, 920],
// },
// {
// path: `${__dirname}/../../../../simulation/saved_runs/1580945025-imdb-ncc-model.json`,
// data: [1, 2, 3, 14, 25, 36, 57, 88, 299, 310, 411, 512, 613, 714, 815],
// data: [1, 2, 3, 14, 15, 26, 37, 48, 59, 110, 111, 112, 213, 314, 515, 616, 717, 818, 919, 920],
// },
// {
// path: `${__dirname}/../../../../simulation/saved_runs/1580945565-imdb-perceptron-model.json`,
// data: [1, 2, 3, 14, 15, 26, 37, 48, 59, 110, 111, 112, 213, 314, 515, 616, 717, 818, 919, 920],
// },
{
path: `${__dirname}/../../../../simulation/saved_runs/1580945565-imdb-perceptron-model.json`,
data: [1, 2, 3, 14, 25, 36, 57, 88, 299, 310, 411, 512, 613, 714, 815],
},
]
const gasUsages = []
for (const model of models) {
@@ -149,10 +149,13 @@ contract('CheckGasUsage', function (accounts) {
r = await mainInterface.refund(data, predictedClassification, addedTime)
gasUsage['refund'] = r.receipt.gasUsed
console.log(`Refund gas used: ${r.receipt.gasUsed}`)

// Report
// Someone else adds bad data.
console.debug(" Adding currently incorrect data using another account...")
r = await mainInterface.addData(data, 1 - predictedClassification, { from: accounts[1], value: 1E17 })
console.log(`Adding data (was incorrect) gas used: ${r.receipt.gasUsed}`)
gasUsage['addData (was incorrect)'] = r.receipt.gasUsed
e = r.logs.filter(e => e.event == 'AddData')[0]
addedTime = e.args.t;
r = await mainInterface.report(data, 1 - predictedClassification, addedTime, accounts[1])