In [2]:
:dep ort = { version = "2.0.0-alpha.2", features = [ "load-dynamic" ] }
:dep ndarray
:dep image
:dep ndarray-stats


In [3]:
use std::path::Path;
use ort::{inputs, Session, CUDAExecutionProvider, SessionOutputs};
use ndarray::{s, Array, Axis, Dim};
use image::{imageops::FilterType, GenericImageView};
use ndarray_stats::QuantileExt;


In [4]:
let model = Session::builder()?
    .with_execution_providers([CUDAExecutionProvider::default().build()])?
    .with_intra_threads(4)?
    .with_model_from_file("./yolov8m.onnx")?;

In [5]:
println!("{model:?}")

Session { inner: SharedSessionInner { session_ptr: 0x564b3509e470, allocator: Allocator { ptr: 0x7f10f9fcb3c8, is_default: true } }, inputs: [Input { name: "images", input_type: Tensor { ty: Float32, dimensions: [1, 3, 640, 640] } }], outputs: [Output { name: "output0", output_type: Tensor { ty: Float32, dimensions: [1, 84, 8400] } }] }


()

In [6]:
let orig_img = image::open(Path::new("baseball.jpg"))?;

In [7]:
// orig_img

In [8]:
println!("{}, {}", orig_img.width(), orig_img.height());

640, 480


In [9]:
fn print_type_of<T>(_: &T) {
    println!("{}", std::any::type_name::<T>())
}

In [10]:
print_type_of(&orig_img.dimensions());

(u32, u32)


In [11]:
orig_img[(0, 0)]

Error: cannot index into a value of type `DynamicImage`

In [12]:
print_type_of(&orig_img);

image::dynimage::DynamicImage


In [13]:
// let orig_img = orig_img.to_rgb8();

In [14]:
print_type_of(&orig_img);

image::dynimage::DynamicImage


In [15]:
// let rgb = orig_img[(639, 479)];

In [16]:
for (x, y, rgb) in orig_img.pixels() {
    println!("{}, {}, {:?}", x, y, rgb.0);
    break;
}

0, 0, [73, 1, 0, 255]


()

In [17]:
let img = orig_img.resize_exact(640, 640, FilterType::CatmullRom);

In [18]:
let mut input: Array<f32, _> = Array::zeros((1, 3, 640, 640));

In [19]:
for (x, y, rgb) in img.pixels() {
    let x = x as _;
    let y = y as _;
    let [r, g, b, _] = rgb.0;
    input[(0, 0, y, x)] = (r as f32) / 255.;
    input[(0, 1, y, x)] = (g as f32) / 255.;
    input[(0, 2, y, x)] = (b as f32) / 255.;
}

()

In [20]:
print_type_of(&input.view());

ndarray::ArrayBase<ndarray::ViewRepr<&f32>, ndarray::dimension::dim::Dim<[usize; 4]>>


In [21]:
print_type_of(&input);

ndarray::ArrayBase<ndarray::data_repr::OwnedRepr<f32>, ndarray::dimension::dim::Dim<[usize; 4]>>


In [22]:
const YOLOV8_CLASS_LABELS: [&str; 80] = [
    "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light",
	"fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant",
	"bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard",
	"sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
	"wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli",
	"carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet",
	"tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator",
	"book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"
];


In [27]:
//let mut boxes = Vec::new();

{
    let outputs: SessionOutputs = model.run(inputs!["images" => input.view()]?)?;
    print_type_of(&outputs);
    println!("{:?}", outputs.keys());
    let output = &outputs["output0"];
    print_type_of(output);
    let (shape, data) = output.extract_raw_tensor::<f32>()?;
    println!("{:?}, {}", shape, data.len());
    let value = output.extract_tensor::<f32>()?;
    print_type_of(&value);
    let view = value.view();
    print_type_of(&view);
    println!("{:?}", view.shape());
    let view = view.t();
    print_type_of(&view);
    println!("{:?}", view.shape());

    let output = view.slice(s![.., .., 0]);
    println!("{:?}", output.shape());

    for row in output.axis_iter(Axis(0)) {
        // print_type_of(&row);
        // println!("{:?}", row);
        let argmax = row.slice(s![4..]).argmax().unwrap();
        let val = row[[4usize + argmax,]];
        //println!("{}, {}", argmax, val);
        if val > 0.5 {
            println!("{} {} {}", row.slice(s![0..4]), argmax, val);
        }
    }
}


ort::session::output::SessionOutputs
["output0"]
ort::value::Value
[1, 84, 8400], 705600
ort::tensor::Tensor<f32>
ort::tensor::ArrayViewHolder<f32>
[1, 84, 8400]
ndarray::ArrayBase<ndarray::ViewRepr<&f32>, ndarray::dimension::dim::Dim<ndarray::dimension::dynindeximpl::IxDynImpl>>
[8400, 84, 1]
[8400, 84]
[249.19936, 502.17923, 52.591095, 13.788635] 34 0.509558
[226.05466, 572.4138, 28.570679, 56.183655] 35 0.85798615
[226.27249, 572.4507, 28.232681, 56.3844] 35 0.88766855
[226.47989, 572.3797, 27.732666, 56.32361] 35 0.69679815
[225.82463, 572.3745, 29.121002, 56.09662] 35 0.8554197
[225.95483, 572.4892, 28.88478, 56.11682] 35 0.8656875
[226.17255, 572.4514, 28.336472, 56.25354] 35 0.84642035
[225.8009, 572.2546, 29.139725, 55.99597] 35 0.83526874
[225.87451, 572.28564, 28.91397, 55.9505] 35 0.83369225
[225.99356, 572.29364, 28.734787, 56.052612] 35 0.8384652
[226.1612, 572.36865, 28.588638, 56.066284] 35 0.82622015
[42.031322, 426.9093, 44.543785, 117.85742] 0 0.53654045
[41.944942, 4

()