-
Notifications
You must be signed in to change notification settings - Fork 4
/
wasm.rs
97 lines (82 loc) · 2.97 KB
/
wasm.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
use js_sys::{Array, Float32Array, Uint8Array};
use ndarray::ArrayD;
use wasm_bindgen::{prelude::wasm_bindgen, JsValue};
use crate::{CPUBackend, Dataset, Logger, PredictOptions, TrainOptions, RESOURCES};
// Import from the JavaScript host: `js_namespace = console` combined with the
// function name `log` resolves this declaration to `console.log`.
#[wasm_bindgen]
extern "C" {
// Takes the message as a borrowed string slice and returns nothing.
#[wasm_bindgen(js_namespace = console)]
fn log(s: &str);
}
/// Logging callback stored in [`Logger`]: takes ownership of the message and
/// forwards it to the imported JavaScript `log` binding.
fn console_log(string: String) {
    // `&String` deref-coerces to the `&str` expected by `log`.
    log(&string)
}
/// Builds a CPU backend from a JSON configuration and registers it in
/// `RESOURCES`.
///
/// * `config` - JSON text deserialized into the configuration value accepted
///   by `CPUBackend::new`.
/// * `shape` - JavaScript array used as an out-parameter: it is resized to the
///   length of the new backend's `size` and filled with its entries.
///
/// Returns the index of the new backend inside `RESOURCES`. The other exported
/// functions in this file take that index as their `id` argument.
///
/// # Panics
///
/// Panics if `config` cannot be deserialized.
#[wasm_bindgen]
pub fn wasm_backend_create(config: String, shape: Array) -> usize {
    let config = serde_json::from_str(&config)
        .expect("wasm_backend_create: `config` is not valid backend configuration JSON");
    let logger = Logger { log: console_log };
    let net_backend = CPUBackend::new(config, logger, None);

    // Report the backend's size back to the JavaScript caller through `shape`.
    shape.set_length(net_backend.size.len() as u32);
    for (i, s) in net_backend.size.iter().enumerate() {
        shape.set(i as u32, JsValue::from(*s));
    }

    RESOURCES.with(|cell| {
        let mut backends = cell.backend.borrow_mut();
        // The slot the backend is about to occupy doubles as its handle.
        let id = backends.len();
        backends.push(net_backend);
        id
    })
}
#[wasm_bindgen]
pub fn wasm_backend_train(id: usize, buffers: Vec<Float32Array>, options: String) {
let options: TrainOptions = serde_json::from_str(&options).unwrap();
let mut datasets = Vec::new();
for i in 0..options.datasets {
let input = buffers[i * 2].to_vec();
let output = buffers[i * 2 + 1].to_vec();
datasets.push(Dataset {
inputs: ArrayD::from_shape_vec(options.input_shape.clone(), input).unwrap(),
outputs: ArrayD::from_shape_vec(options.output_shape.clone(), output).unwrap(),
});
}
RESOURCES.with(|cell| {
let mut backend = cell.backend.borrow_mut();
backend[id].train(datasets, options.epochs, options.batches, options.rate)
});
}
/// Runs a prediction on the backend registered under `id`.
///
/// * `id` - index of the backend inside `RESOURCES`.
/// * `buffer` - flat input values, reshaped to `options.input_shape`.
/// * `options` - JSON text deserialized into [`PredictOptions`].
///
/// Returns the prediction copied into a new `Float32Array`.
///
/// # Panics
///
/// Panics if `options` is not valid JSON for `PredictOptions`, if `buffer`
/// does not fit `input_shape`, if `id` is not a registered backend index, or
/// if the prediction cannot be viewed as a contiguous slice.
#[wasm_bindgen]
pub fn wasm_backend_predict(id: usize, buffer: Float32Array, options: String) -> Float32Array {
    let options: PredictOptions = serde_json::from_str(&options)
        .expect("wasm_backend_predict: `options` is not valid PredictOptions JSON");
    let inputs = ArrayD::from_shape_vec(options.input_shape, buffer.to_vec())
        .expect("wasm_backend_predict: `buffer` does not match `input_shape`");

    // Return the prediction straight out of the closure; no zero-filled
    // placeholder array needs to be allocated first.
    let res = RESOURCES.with(|cell| {
        let mut backends = cell.backend.borrow_mut();
        backends[id].predict(inputs)
    });

    Float32Array::from(
        res.as_slice()
            .expect("wasm_backend_predict: prediction is not a contiguous array"),
    )
}
/// Serializes the backend registered under `id`.
///
/// * `id` - index of the backend inside `RESOURCES`.
///
/// Returns the bytes produced by the backend's `save` method, copied into a
/// new `Uint8Array`.
///
/// # Panics
///
/// Panics if `id` is not a registered backend index.
#[wasm_bindgen]
pub fn wasm_backend_save(id: usize) -> Uint8Array {
    let bytes = RESOURCES.with(|cell| {
        // Saving only reads the backend, so a shared borrow is sufficient.
        let backends = cell.backend.borrow();
        backends[id].save()
    });
    Uint8Array::from(bytes.as_slice())
}
/// Restores a backend from serialized bytes and registers it in `RESOURCES`.
///
/// * `buffer` - bytes handed to `CPUBackend::load`.
/// * `shape` - JavaScript array used as an out-parameter: it is resized to the
///   length of the loaded backend's `size` and filled with its entries.
///
/// Returns the index of the loaded backend inside `RESOURCES`.
#[wasm_bindgen]
pub fn wasm_backend_load(buffer: Uint8Array, shape: Array) -> usize {
    let bytes = buffer.to_vec();
    let net_backend = CPUBackend::load(bytes.as_slice(), Logger { log: console_log });

    // Mirror the backend's size into the caller-provided JavaScript array.
    let dims = &net_backend.size;
    shape.set_length(dims.len() as u32);
    dims.iter()
        .enumerate()
        .for_each(|(idx, dim)| shape.set(idx as u32, JsValue::from(*dim)));

    RESOURCES.with(|cell| {
        let mut backends = cell.backend.borrow_mut();
        // The length before pushing is the index the new backend will occupy.
        let id = backends.len();
        backends.push(net_backend);
        id
    })
}