-
Notifications
You must be signed in to change notification settings - Fork 0
/
mnist.rs
51 lines (41 loc) · 1.47 KB
/
mnist.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
use gradients::prelude::*;
/// Fully-connected MNIST classifier: 784 inputs (28x28 pixels) -> 128 hidden
/// -> 10 -> 10 outputs, with ReLU activations and a final softmax.
///
/// The `#[network]` attribute macro (from the `gradients` crate) generates the
/// forward/backward implementations from the fields in declaration order, so
/// field order here defines the layer order of the network.
// NOTE(review): `lin3` maps 10 -> 10, i.e. a square layer after the 128 -> 10
// projection — presumably intentional extra capacity before softmax, but
// confirm it is not a leftover from an earlier architecture.
#[network]
pub struct Network {
// Input layer: flattened 28x28 image (784 features) -> 128 hidden units.
lin1: Linear<784, 128>,
relu1: ReLU,
// Hidden layer: 128 -> 10.
lin2: Linear<128, 10>,
relu2: ReLU,
// Output layer: 10 -> 10 class logits.
lin3: Linear<10, 10>,
// Converts logits to a probability distribution over the 10 digit classes.
softmax: Softmax,
}
/// Trains the MNIST classifier for 200 epochs with Adam (lr = 0.01) and
/// categorical cross-entropy on an OpenCL device, printing the loss and
/// training accuracy each epoch.
///
/// # Errors
/// Returns an error if the OpenCL device cannot be created or the training
/// CSV cannot be loaded/parsed.
fn main() -> Result<(), Box<dyn std::error::Error>> {
// let device = gradients::CPU::new(); // use cpu (no framework specific features enabled):
// let device = gradients::CudaDevice::new(0)?; // use cuda device (cuda feature enabled):
// use opencl device (opencl feature enabled):
let device = CLDevice::new(0)?;
// Allocate the network's parameters on the chosen device.
let mut net = Network::with(&device);
// NOTE(review): the `true` argument presumably tells the loader the CSV has a
// header row to skip — confirm against the CSVLoader API.
let loader = CSVLoader::new(true);
let loaded_data: CSVReturn<f32> = loader.load("PATH/TO/DATASET/mnist_train.csv")?;
// Build the (samples x 784) input matrix from the flat feature buffer.
let mut i = Matrix::from((
&device,
(loaded_data.sample_count, loaded_data.features),
&loaded_data.x,
));
// Normalize pixel intensities from [0, 255] into [0, 1].
i /= 255.;
// Labels as a (samples x 1) column, then expanded to one-hot targets for CCE.
let y = Matrix::from((&device, (loaded_data.sample_count, 1), &loaded_data.y));
let y = y.onehot();
let mut opt = Adam::new(0.01);
// `range` is the gradients/custos prelude helper, not std — presumably it also
// drives the framework's per-iteration cache indexing, so keep it as-is.
for epoch in range(200) {
// Forward pass over the full training set (batch = whole dataset).
let preds = net.forward(&i);
// Count predictions matching the integer labels, for the accuracy printout.
let correct_training = correct_classes(&loaded_data.y.as_usize(), &preds) as f32;
let loss = cce(&device, &preds, &y);
// NOTE(review): `sample_count()` is called as a method here but read as a
// field (`loaded_data.sample_count`) above — verify both exist on CSVReturn.
println!(
"epoch: {epoch}, loss: {loss}, training_acc: {acc}",
acc = correct_training / loaded_data.sample_count() as f32
);
// Backward pass: CCE gradient w.r.t. predictions, then backprop + Adam step.
let grad = cce_grad(&device, &preds, &y);
net.backward(&grad);
opt.step(&device, net.params());
}
Ok(())
}