In [2]:
:dep candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }

In [16]:
use candle_core::Device;

// Prefer CUDA device 0. If that device cannot be created, the error is discarded
// and the notebook falls back to the CPU. The output below shows `Cpu` —
// presumably because the `:dep` line enables no `cuda` feature (confirm if a GPU
// is expected).
let device = Device::new_cuda(0).unwrap_or(Device::Cpu);

// Trailing expression: evcxr displays the selected device.
device

Cpu

# Basic Tensor Ops

In [4]:
use candle_core::Tensor;

// Random stand-in for a single input: shape (1, 784), values drawn from N(0, 1).
// Presumably a flattened 28x28 digit image — NOTE(review): confirm intent.
let img = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

// First layer parameters (784 -> 100). Weights come from an unscaled N(0, 1),
// so activations grow large — compare the magnitudes of the printed logits below.
let w1 = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let b1 = Tensor::randn(0f32, 1.0, (100, ), &device)?;

// Second layer parameters (100 -> 10), presumably one output per digit class.
let w2 = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let b2 = Tensor::randn(0f32, 1.0, (10, ), &device)?;

In [5]:
// First layer: (1, 784) x (784, 100) -> (1, 100), then broadcast-add the (100,) bias.
let x = img.matmul(&w1)?.broadcast_add(&b1)?;

x

Tensor[dims 1, 100; f32]

In [6]:
// Element-wise ReLU non-linearity. The shape stays (1, 100).
let x = x.relu()?;

x

Tensor[dims 1, 100; f32]

In [7]:
// Second layer: (1, 100) x (100, 10) -> (1, 10) raw scores. No softmax is applied.
let digit = x.matmul(&w2)?.broadcast_add(&b2)?;

digit

Tensor[dims 1, 10; f32]

In [8]:
// Print the tensor's values. The bare `digit` expression only shows dims and dtype.
println!("{}", digit);

[[-257.8495,    5.0175,  379.4254,  -19.7939, -149.9870,  235.7281, -256.8443,
    -3.6007,  -85.4078,   81.2617]]
Tensor[[1, 10], f32]


# Basic Abstractions

In [9]:
use candle_core::{Tensor, Result};

/// A two-layer network: `first`, then ReLU, then `second`.
struct Model {
    /// Layer applied directly to the input.
    first: Linear,
    /// Layer applied to the ReLU-activated output of `first`.
    second: Linear,
}

impl Model {
    /// Runs the forward pass on `img` and returns the output of `second`.
    ///
    /// # Errors
    /// Propagates any error returned by either layer's `forward` or by `relu`
    /// (for example, mismatched shapes).
    fn forward(&self, img: &Tensor) -> Result<Tensor> {
        // `img` is already a `&Tensor`; pass it straight through instead of
        // re-borrowing it as `&&Tensor`.
        let x = self.first.forward(img)?;
        let x = x.relu()?;
        self.second.forward(&x)
    }
}

/// A fully connected layer computing `x.matmul(weight)` plus a broadcast `bias`.
struct Linear {
    /// Matrix multiplied on the right of the input, so it is laid out as
    /// (inputs, outputs).
    weight: Tensor,
    /// Offset broadcast-added to the matrix product.
    bias: Tensor,
}

impl Linear {
    /// Applies the affine map to `x`.
    ///
    /// # Errors
    /// Returns the error from `matmul` or `broadcast_add` if either fails
    /// (for example, incompatible shapes).
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let Linear { weight, bias } = self;
        x.matmul(weight)?.broadcast_add(bias)
    }
}

In [10]:
use candle_core::Tensor;

// Same parameter shapes as the manual example, now wrapped in `Linear` layers.
// These are fresh random draws (shadowing the earlier w1..b2), so the outputs
// will differ from the earlier cells.
let w1 = Tensor::randn(0f32, 1.0, (784, 100), &device)?;
let b1 = Tensor::randn(0f32, 1.0, (100, ), &device)?;
let l1 = Linear { weight: w1, bias: b1 };

let w2 = Tensor::randn(0f32, 1.0, (100, 10), &device)?;
let b2 = Tensor::randn(0f32, 1.0, (10, ), &device)?;
let l2 = Linear { weight: w2, bias: b2 };

// The tensors are moved into the layers, and the layers into the model.
let model = Model { first: l1, second: l2 };

In [11]:
// A new random (1, 784) input, passed through Linear -> ReLU -> Linear.
let img = Tensor::randn(0f32, 1.0, (1, 784), &device)?;

let digit = model.forward(&img)?;

digit

Tensor[dims 1, 10; f32]

In [12]:
// Print the values of the model's (1, 10) output.
println!("{}", digit);

[[ 310.6161, -111.7927,  -45.6357,    0.7160,  -43.6480,  -60.4750,  284.5006,
  -410.4315,  -33.8279,  100.1221]]
Tensor[[1, 10], f32]


# Built-in Abstractions

In [13]:
:dep candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }