Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

draw lines between keypoints #10

Open
github-actions bot opened this issue Mar 17, 2024 · 0 comments
Open

draw lines between keypoints #10

github-actions bot opened this issue Mar 17, 2024 · 0 comments
Assignees
Labels

Comments

@github-actions
Copy link

// TODO: draw lines between keypoints

use bevy::{
    prelude::*,
    window::PrimaryWindow,
};

use bevy_ort::{
    BevyOrtPlugin,
    models::lightglue::{
        GluedPair,
        lightglue_inference,
        Lightglue,
        LightgluePlugin,
    },
    Onnx,
};


/// Builds and runs the LightGlue demo app.
///
/// Registers the default Bevy plugins plus the ONNX runtime and LightGlue plugins.
/// It initialises the `LightglueInput` resource, schedules `load_lightglue` once at
/// startup, and runs `inference` every frame.
/// NOTE(review): `load_lightglue` needs a `Lightglue` resource; presumably
/// `LightgluePlugin` inserts it — confirm against bevy_ort.
fn main() {
    let mut app = App::new();

    app.add_plugins((DefaultPlugins, BevyOrtPlugin, LightgluePlugin));
    app.init_resource::<LightglueInput>();
    app.add_systems(Startup, load_lightglue);
    app.add_systems(Update, inference);

    app.run();
}


/// Handles to the pair of images that LightGlue matches against each other.
///
/// Initialised with default handles; a startup system replaces them with the
/// real asset handles.
#[derive(Default, Resource)]
pub struct LightglueInput {
    /// First image of the pair; its keypoints are the `from_*` side of each match.
    pub a: Handle<Image>,
    /// Second image of the pair; its keypoints are the `to_*` side of each match.
    pub b: Handle<Image>,
}


fn load_lightglue(
    asset_server: Res<AssetServer>,
    mut lightglue: ResMut<Lightglue>,
    mut input: ResMut<LightglueInput>,
) {
    let lightglue_handle: Handle<Onnx> = asset_server.load("models/disk_lightglue_end2end_fused_cpu.onnx");
    lightglue.onnx = lightglue_handle;

    input.a = asset_server.load("images/sacre_coeur1.png");
    input.b = asset_server.load("images/sacre_coeur2.png");
}


fn inference(
    mut commands: Commands,
    lightglue: Res<Lightglue>,
    input: Res<LightglueInput>,
    onnx_assets: Res<Assets<Onnx>>,
    images: Res<Assets<Image>>,
    primary_window: Query<&Window, With<PrimaryWindow>>,
    mut complete: Local<bool>,
) {
    if *complete {
        return;
    }

    let window = primary_window.single();

    let images = vec![
        images.get(&input.a).expect("failed to get image asset"),
        images.get(&input.b).expect("failed to get image asset"),
    ];
    let images = images.iter().map(|image| *image).collect::<Vec<_>>();

    let glued_pairs: Result<Vec<(usize, usize, Vec<GluedPair>)>, String> = (|| {
        let onnx = onnx_assets.get(&lightglue.onnx).ok_or("failed to get ONNX asset")?;
        let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
        let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?;

        Ok(lightglue_inference(
            session,
            images.as_slice(),
        ))
    })();

    match glued_pairs {
        Ok(glued_pairs) => {
            println!("glued_pairs: {:?}", glued_pairs[0].2.len());

            commands.spawn(NodeBundle {
                style: Style {
                    display: Display::Grid,
                    width: Val::Percent(100.0),
                    height: Val::Percent(100.0),
                    grid_template_columns: RepeatedGridTrack::flex(2, 1.0),
                    grid_template_rows: RepeatedGridTrack::flex(2, 1.0),
                    ..default()
                },
                background_color: BackgroundColor(Color::DARK_GRAY),
                ..default()
            })
            .with_children(|builder| {
                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.a.clone()),
                    ..default()
                })
                .with_children(|builder| {
                    let image_width = images[0].width() as f32;
                    let image_height = images[0].height() as f32;

                    let display_width = window.width() as f32 / 2.0;
                    let display_height = window.height() as f32 / 2.0;

                    let scale_x = display_width / image_width;
                    let scale_y = display_height / image_height;

                    glued_pairs[0].2.iter().for_each(|pair| {
                        let scaled_x = pair.from_x as f32 * scale_x;
                        let scaled_y = pair.from_y as f32 * scale_y;

                        builder.spawn(NodeBundle {
                            style: Style {
                                position_type: PositionType::Absolute,
                                left: Val::Px(scaled_x),
                                top: Val::Px(scaled_y),
                                width: Val::Px(2.0),
                                height: Val::Px(2.0),
                                ..default()
                            },
                            background_color: Color::rgb(1.0, 0.0, 0.0).into(),
                            ..default()
                        });
                    });
                });

                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.b.clone()),
                    ..default()
                })
                .with_children(|builder| {
                    let image_width = images[1].width() as f32;
                    let image_height = images[1].height() as f32;

                    let display_width = window.width() as f32 / 2.0;
                    let display_height = window.height() as f32 / 2.0;

                    let scale_x = display_width / image_width;
                    let scale_y = display_height / image_height;

                    glued_pairs[0].2.iter().for_each(|pair| {
                        let scaled_x = pair.to_x as f32 * scale_x;
                        let scaled_y = pair.to_y as f32 * scale_y;

                        builder.spawn(NodeBundle {
                            style: Style {
                                position_type: PositionType::Absolute,
                                left: Val::Px(scaled_x),
                                top: Val::Px(scaled_y),
                                width: Val::Px(2.0),
                                height: Val::Px(2.0),
                                ..default()
                            },
                            background_color: Color::rgb(0.0, 1.0, 0.0).into(),
                            ..default()
                        });
                    });
                });

                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.a.clone()),
                    ..default()
                });

                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.b.clone()),
                    ..default()
                });

                // TODO: draw lines between keypoints
            });

            commands.spawn(Camera2dBundle::default());

            *complete = true;
        }
        Err(e) => {
            eprintln!("inference failed: {}", e);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

No branches or pull requests

1 participant