Skip to content

Commit

Permalink
[demo] Add ability to deploy any of the supported model types (#85)
Browse files Browse the repository at this point in the history
Resolves #82
Improve smoothness. Dismiss some notifications.
Add warning if deploying to main.
  • Loading branch information
juharris committed Mar 10, 2020
1 parent c3efc5a commit ef86f0b
Show file tree
Hide file tree
Showing 16 changed files with 675 additions and 198 deletions.
2 changes: 1 addition & 1 deletion demo/client/migrations/3_deploy_VPA_classifier.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ module.exports = function (deployer) {
return deployer.deploy(NearestCentroidClassifier,
[classifications[0]], [centroids[0]], [dataCounts[0]],
// Block gasLimit by most miners as of May 2019.
{ gas: 8E6 }
{ gas: 8.8E6 }
).then(classifier => {
// Add classes separately to avoid hitting gasLimit.
const addClassPromises = [];
Expand Down
244 changes: 82 additions & 162 deletions demo/client/src/components/addModel.js

Large diffs are not rendered by default.

7 changes: 3 additions & 4 deletions demo/client/src/containers/modelList.js
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,6 @@ class ModelList extends React.Component {
}

componentDidMount = async () => {
this.networkType = await getNetworkType()

checkStorages(this.storages).then(permittedStorageTypes => {
permittedStorageTypes = permittedStorageTypes.filter(storageType => storageType !== undefined)
this.setState({ permittedStorageTypes }, this.updateModels)
Expand Down Expand Up @@ -108,17 +106,18 @@ class ModelList extends React.Component {
}, this.updateModels)
}

updateModels() {
async updateModels() {
// TODO Also get valid contracts that the account has already interacted with.
// TODO Filter out models that are not on this network.
const networkType = await getNetworkType()
const limit = 6
Promise.all(this.state.permittedStorageTypes.map(storageType => {
const afterId = this.storageAfterAddress[storageType]
return this.storages[storageType].getModels(afterId, limit).then(response => {
const newModels = response.models
const { remaining } = response
newModels.forEach(model => {
model.restrictContent = !this.validator.isPermitted(this.networkType, model.address)
model.restrictContent = !this.validator.isPermitted(networkType, model.address)
model.metaDataLocation = storageType
})
if (newModels.length > 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,20 @@ contract NearestCentroidClassifier is Classifier64 {
}
}

/**
* Extend the number of dimensions of a centroid.
* Made to be called just after the contract is created and never again.
* @param extension The values to append to a centroid vector.
* @param classification The class to add the extension to.
*/
function extendCentroid(int64[] memory extension, uint64 classification) public onlyOwner {
require(classification < classifications.length, "This classification has not been added yet.");
require(centroids[classification].length + extension.length < 2 ** 63, "Centroid would be too long.");
for (uint i = 0; i < extension.length; ++i) {
centroids[classification].push(extension[i]);
}
}

function addClass(int64[] memory centroid, string memory classification, uint dataCount) public onlyOwner {
require(classifications.length + 1 < 2 ** 64, "There are too many classes already.");
require(centroid.length == centroids[0].length, "Data doesn't have the correct number of dimensions.");
Expand Down Expand Up @@ -114,4 +128,16 @@ contract NearestCentroidClassifier is Classifier64 {
uint offset = uint(toFloat) * 100;
require(oneSquared - offset < _norm && _norm < oneSquared + offset, "The provided data does not have a norm of 1.");
}

// Useful methods to view the underlying data:
// To match the `SparseNearestCentroidClassifier`.
// These methods are not really needed now but they are added in case the implementation of the class
// changes later after some gas cost analysis.
function getNumSamples(uint classIndex) public view returns (uint) {
return dataCounts[classIndex];
}

function getCentroidValue(uint classIndex, uint featureIndex) public view returns (int64) {
return centroids[classIndex][featureIndex];
}
}
2 changes: 1 addition & 1 deletion demo/client/src/float-utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,6 @@ export function convertData(data, web3, toFloat = _toFloat) {
return data.map(num => convertNum(num, web3, toFloat));
}

export function convertToHexData(data, web3, toFloat = _toFloat) {
export function convertDataToHex(data, web3, toFloat = _toFloat) {
return data.map(num => convertToHex(num, web3, toFloat));
}
2 changes: 1 addition & 1 deletion demo/client/src/getWeb3.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import * as _getWeb3 from '@drizzle-utils/get-web3'
import Web3 from "web3" // Only required for custom/fallback provider option.
import Web3 from 'web3' // Only required for custom/fallback provider option.

export async function getWeb3() {
if (window.ethereum) {
Expand Down
178 changes: 178 additions & 0 deletions demo/client/src/ml-models/__tests__/deploy-model.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import assert from 'assert'
import Web3 from 'web3'
import { convertNum } from '../../float-utils'
import { CentroidInfo, ModelDeployer, NearestCentroidModel } from '../deploy-model'

declare const web3: Web3

function assertEqualNumbers(actual: any, expected: any, message?: string | Error): void {
if (web3.utils.isBN(actual)) {
if (web3.utils.isBN(expected)) {
if (message === undefined) {
message = `actual: ${actual} (${typeof actual})\nexpected: ${expected} (${typeof expected})`
}
return assert(actual.eq(expected), message)
} else {
const expectedBN = web3.utils.toBN(expected)
if (message === undefined) {
message = `actual: ${actual} (${typeof actual})\nexpected: ${expected} (${typeof expected}) => BN: ${expectedBN}`
}
return assert(actual.eq(expectedBN), message)
}
} else if (web3.utils.isBN(expected)) {
const actualBN = web3.utils.toBN(actual)
if (message === undefined) {
message = `actual: ${actual} (${typeof actual}) => BN: ${actualBN}\nexpected: ${expected} (${typeof expected})`
}
return assert(actualBN.eq(expected), message)
} else {
if (typeof actual === 'string') {
actual = parseInt(actual)
}
return assert.equal(actual, expected, message)
}
}

describe("ModelDeployer", () => {
let account: string
const deployer = new ModelDeployer(web3)

beforeAll(async () => {
const accounts = await web3.eth.getAccounts()
// Pick a random account between 2 and 9 since 0 and 1 are usually used in the browser.
account = accounts[2 + Math.min(Math.floor(Math.random() * 8), 7)]
})

it("should deploy Naive Bayes", async () => {
const model = {
classifications: [
"A",
"B"
],
classCounts: [
2,
3
],
featureCounts: [
[[0, 2], [1, 1]],
[[1, 3], [2, 2]],
],
totalNumFeatures: 9,
smoothingFactor: 1.0,
type: "naive bayes"
}
const m = await deployer.deployModel(
model,
{
account,
})

for (let i = 0; i < model.classifications.length; ++i) {
assert.equal(await m.methods.classifications(i).call(), model.classifications[i])
assertEqualNumbers(await m.methods.getNumSamples(i).call(), model.classCounts[i])
for (const [featureIndex, count] of model.featureCounts[i]) {
assertEqualNumbers(await m.methods.getFeatureCount(i, featureIndex).call(), count)
}
}
assertEqualNumbers(await m.methods.getClassTotalFeatureCount(0).call(), 3)
assertEqualNumbers(await m.methods.getClassTotalFeatureCount(1).call(), 5)
})

it("should deploy dense Nearest Centroid", async () => {
const model = new NearestCentroidModel(
'dense nearest centroid classifier',
{
"AA": new CentroidInfo([-1, -1], 2),
"BB": new CentroidInfo([+1, +1], 2),
}
)
const m = await deployer.deployModel(
model,
{
account,
})

let i = -1
for (let [classification, centroidInfo] of Object.entries(model.intents)) {
++i
assert.equal(await m.methods.classifications(i).call(), classification)
assertEqualNumbers(await m.methods.getNumSamples(i).call(), centroidInfo.dataCount)
for (let j = 0; j < centroidInfo.centroid.length; ++j) {
assertEqualNumbers(await m.methods.getCentroidValue(i, j).call(), convertNum(centroidInfo.centroid[j], web3))
}
}
})

it("should deploy sparse Nearest Centroid", async () => {
const model = new NearestCentroidModel(
'sparse nearest centroid classifier',
{
"AA": new CentroidInfo([0, +1], 2),
"BB": new CentroidInfo([+1, 0], 2),
}
)
const m = await deployer.deployModel(
model,
{
account,
})

let i = -1
for (let [classification, centroidInfo] of Object.entries(model.intents)) {
++i
assert.equal(await m.methods.classifications(i).call(), classification)
assertEqualNumbers(await m.methods.getNumSamples(i).call(), centroidInfo.dataCount)
for (let j = 0; j < centroidInfo.centroid.length; ++j) {
assertEqualNumbers(await m.methods.getCentroidValue(i, j).call(), convertNum(centroidInfo.centroid[j], web3))
}
}
})

it("should deploy dense Perceptron", async () => {
const classifications = ["A", "B"]
const weights = [1, -1]
const intercept = 0
const m = await deployer.deployModel(
{
type: 'dense perceptron',
classifications,
weights,
intercept,
},
{
account,
})

for (let i = 0; i < classifications.length; ++i) {
assert.equal(await m.methods.classifications(i).call(), classifications[i])
}
for (let i = 0; i < weights.length; ++i) {
assertEqualNumbers(await m.methods.weights(i).call(), convertNum(weights[i], web3))
}
assertEqualNumbers(await m.methods.intercept().call(), convertNum(intercept, web3))
})

it("should deploy sparse Perceptron", async () => {
const classifications = ["AA", "BB"]
const weights = [2, -2]
const intercept = 3
const m = await deployer.deployModel(
{
type: 'sparse perceptron',
classifications,
weights,
intercept,
},
{
account,
})

for (let i = 0; i < classifications.length; ++i) {
assert.equal(await m.methods.classifications(i).call(), classifications[i])
}
for (let i = 0; i < weights.length; ++i) {
assertEqualNumbers(await m.methods.weights(i).call(), convertNum(weights[i], web3))
}
assertEqualNumbers(await m.methods.intercept().call(), convertNum(intercept, web3))
})
})
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
const fs = require('fs')

const DensePerceptron = artifacts.require("classification/DensePerceptron")
const DensePerceptron = artifacts.require("./classification/DensePerceptron")
const NaiveBayesClassifier = artifacts.require("./classification/NaiveBayesClassifier")
const NearestCentroidClassifier = artifacts.require("./classification/NearestCentroidClassifier")
const SparseNearestCentroidClassifier = artifacts.require("./classification/SparseNearestCentroidClassifier")
const SparsePerceptron = artifacts.require("./classification/SparsePerceptron")

const { convertData, convertNum } = require('../../src/float-utils-node')
const { convertData, convertNum } = require('../float-utils-node')

const _toFloat = 1E9

async function loadDensePerceptron(model, web3, toFloat) {
async function deployDensePerceptron(model, web3, toFloat) {
let gasUsed = 0
const weightChunkSize = 450
const { classifications } = model
Expand All @@ -36,7 +36,7 @@ async function loadDensePerceptron(model, web3, toFloat) {
}
}

async function loadSparsePerceptron(model, web3, toFloat) {
async function deploySparsePerceptron(model, web3, toFloat) {
const weightChunkSize = 300
const { classifications } = model
const weights = convertData(model.weights, web3, toFloat)
Expand All @@ -62,7 +62,7 @@ async function loadSparsePerceptron(model, web3, toFloat) {
}
}

async function loadNearestCentroidClassifier(model, web3, toFloat) {
async function deployNearestCentroidClassifier(model, web3, toFloat) {
let gasUsed = 0
const classifications = []
const centroids = []
Expand Down Expand Up @@ -105,7 +105,7 @@ async function loadNearestCentroidClassifier(model, web3, toFloat) {
})
}

exports.loadSparseNearestCentroidClassifier = async function (model, web3, toFloat) {
exports.deploySparseNearestCentroidClassifier = async function (model, web3, toFloat) {
let gasUsed = 0
const initialChunkSize = 500
const chunkSize = 500
Expand Down Expand Up @@ -167,7 +167,7 @@ exports.loadSparseNearestCentroidClassifier = async function (model, web3, toFlo
})
}

async function loadNaiveBayes(model, web3, toFloat) {
async function deployNaiveBayes(model, web3, toFloat) {
let gasUsed = 0
const initialFeatureChunkSize = 150
const featureChunkSize = 350
Expand Down Expand Up @@ -213,20 +213,20 @@ async function loadNaiveBayes(model, web3, toFloat) {
* @returns The contract for the model, an instance of `Classifier64`
* along with the the total amount of gas used to deploy the model.
*/
exports.loadModel = async function (path, web3, toFloat = _toFloat) {
exports.deployModel = async function (path, web3, toFloat = _toFloat) {
const model = JSON.parse(fs.readFileSync(path, 'utf8'))
switch (model.type) {
case 'dense perceptron':
return loadDensePerceptron(model, web3, toFloat)
return deployDensePerceptron(model, web3, toFloat)
case 'naive bayes':
return loadNaiveBayes(model, web3, toFloat)
return deployNaiveBayes(model, web3, toFloat)
case 'dense nearest centroid classifier':
case 'nearest centroid classifier':
return loadNearestCentroidClassifier(model, web3, toFloat)
return deployNearestCentroidClassifier(model, web3, toFloat)
case 'sparse nearest centroid classifier':
return exports.loadSparseNearestCentroidClassifier(model, web3, toFloat)
return exports.deploySparseNearestCentroidClassifier(model, web3, toFloat)
case 'sparse perceptron':
return loadSparsePerceptron(model, web3, toFloat)
return deploySparsePerceptron(model, web3, toFloat)
default:
// Should not happen.
throw new Error(`Unrecognized model type: "${model.type}"`)
Expand Down

0 comments on commit ef86f0b

Please sign in to comment.