Skip to content

Commit

Permalink
feat: offline grid view
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Mar 18, 2024
1 parent a28416a commit 41368e0
Show file tree
Hide file tree
Showing 4 changed files with 448 additions and 149 deletions.
145 changes: 145 additions & 0 deletions src/grid_view.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
use bevy::{
prelude::*,
window::PrimaryWindow,
};

use crate::materials::foreground::ForegroundMaterial;


pub struct GridViewPlugin;
impl Plugin for GridViewPlugin {
fn build(&self, app: &mut App) {
app.init_resource::<GridView>();
app.add_systems(Update, draw_grid_view);
}
}

#[derive(Debug, Clone)]
pub enum Element {
Image(Handle<Image>),
Alphablend(Handle<ForegroundMaterial>),
}

#[derive(Resource, Default)]
pub struct GridView {
pub source: Vec<Element>,
}


#[derive(Component, Default)]
pub struct GridViewParent;


fn draw_grid_view(
mut commands: Commands,
primary_window: Query<
&Window,
With<PrimaryWindow>
>,
grid_view: Res<GridView>,
grid_view_parent: Query<
Entity,
With<GridViewParent>
>,
) {
if !grid_view.is_changed() {
return;
}

for entity in grid_view_parent.iter() {
commands.entity(entity).despawn_recursive();
}

let window = primary_window.single();

let (
columns,
rows,
_width,
_height,
) = calculate_grid_dimensions(
window.width(),
window.height(),
grid_view.source.len(),
);

commands.spawn(NodeBundle {
style: Style {
display: Display::Grid,
width: Val::Percent(100.0),
height: Val::Percent(100.0),
grid_template_columns: RepeatedGridTrack::flex(columns as u16, 1.0),
grid_template_rows: RepeatedGridTrack::flex(rows as u16, 1.0),
..default()
},
background_color: BackgroundColor(Color::BLACK),
..default()
})
.insert(GridViewParent)
.with_children(|builder| {
grid_view.source.iter()
.for_each(|element| {
match element {
Element::Image(image) => {
builder.spawn(ImageBundle {
style: Style {
width: Val::Percent(100.0),
height: Val::Percent(100.0),
..default()
},
image: UiImage::new(image.clone()),
..default()
});
}
Element::Alphablend(material) => {
builder.spawn(MaterialNodeBundle {
style: Style {
width: Val::Percent(100.0),
height: Val::Percent(100.0),
..default()
},
material: material.clone(),
..default()
});
}
}
});
});
}


fn calculate_grid_dimensions(
window_width: f32,
window_height: f32,
num_streams: usize,
) -> (usize, usize, f32, f32) {
let window_aspect_ratio = window_width / window_height;
let stream_aspect_ratio: f32 = 16.0 / 9.0;
let mut best_layout = (1, num_streams);
let mut best_diff = f32::INFINITY;
let mut best_sprite_size = (0.0, 0.0);

for columns in 1..=num_streams {
let rows = (num_streams as f32 / columns as f32).ceil() as usize;
let sprite_width = window_width / columns as f32;
let sprite_height = sprite_width / stream_aspect_ratio;
let total_height_needed = sprite_height * rows as f32;
let (final_sprite_width, final_sprite_height) = if total_height_needed > window_height {
let adjusted_sprite_height = window_height / rows as f32;
let adjusted_sprite_width = adjusted_sprite_height * stream_aspect_ratio;
(adjusted_sprite_width, adjusted_sprite_height)
} else {
(sprite_width, sprite_height)
};
let grid_aspect_ratio = final_sprite_width * columns as f32 / (final_sprite_height * rows as f32);
let diff = (window_aspect_ratio - grid_aspect_ratio).abs();

if diff < best_diff {
best_diff = diff;
best_layout = (columns, rows);
best_sprite_size = (final_sprite_width, final_sprite_height);
}
}

(best_layout.0, best_layout.1, best_sprite_size.0, best_sprite_size.1)
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use bevy::prelude::*;
use bevy_ort::BevyOrtPlugin;

pub mod ffmpeg;
pub mod grid_view;
pub mod materials;
pub mod matting;
pub mod mp4;
Expand All @@ -19,6 +20,7 @@ impl Plugin for LightFieldPlugin {
fn build(&self, app: &mut App) {
app.add_plugins(BevyOrtPlugin);

app.add_plugins(grid_view::GridViewPlugin);
app.add_plugins(materials::StreamMaterialsPlugin);
app.add_plugins(person_detect::PersonDetectPlugin);
app.add_plugins(pipeline::PipelinePlugin);
Expand Down
120 changes: 115 additions & 5 deletions src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ use image::{
ImageBuffer,
Luma,
Rgb,
Rgba,
};
use imageproc::geometric_transformations::{
rotate_about_center,
Expand Down Expand Up @@ -59,6 +60,7 @@ impl Plugin for PipelinePlugin {
generate_raw_frames,
generate_rotated_frames,
generate_mask_frames,
generate_alphablend_frames,
generate_yolo_frames,
)
);
Expand All @@ -70,6 +72,7 @@ impl Plugin for PipelinePlugin {
pub struct PipelineConfig {
pub raw_frames: bool,
pub rotate_raw_frames: bool,
pub alphablend_frames: bool,
pub yolo: bool, // https://github.com/ultralytics/ultralytics
pub repair_frames: bool, // https://huggingface.co/docs/diffusers/en/optimization/onnx & https://github.com/bnm6900030/swintormer
pub upsample_frames: bool, // https://huggingface.co/ssube/stable-diffusion-x4-upscaler-onnx
Expand All @@ -85,9 +88,10 @@ impl Default for PipelineConfig {
raw_frames: true,
rotate_raw_frames: true,
yolo: true,
repair_frames: false,
upsample_frames: false,
alphablend_frames: true,
mask_frames: true,
upsample_frames: false,
repair_frames: false,
light_field_cameras: false,
depth_maps: false,
gaussian_cloud: false,
Expand Down Expand Up @@ -227,8 +231,6 @@ fn generate_rotated_frames(
raw_frames,
session,
) in raw_frames.iter() {
// TODO: get stream descriptor rotation

if config.rotate_raw_frames {
let run_node = !RotatedFrames::exists(session);
let mut rotated_frames = RotatedFrames::load_from_session(session);
Expand Down Expand Up @@ -391,6 +393,66 @@ fn generate_mask_frames(
}


fn generate_alphablend_frames(
mut commands: Commands,
session: Query<
(
Entity,
&PipelineConfig,
&RotatedFrames,
&MaskFrames,
&Session,
),
Without<AlphablendFrames>,
>,
) {
for (
entity,
config,
rotated_frames,
mask_frames,
session,
) in session.iter() {
if config.alphablend_frames {
let run_node = !AlphablendFrames::exists(session);
let mut alphablend_frames = AlphablendFrames::load_from_session(session);

if run_node {
info!("generating alphablend frames for session {}", session.id);

rotated_frames.frames.iter()
.for_each(|(stream_id, frames)| {
let output_directory = format!("{}/{}", alphablend_frames.directory, stream_id.0);
std::fs::create_dir_all(&output_directory).unwrap();

let frames = frames.par_iter()
.zip(mask_frames.frames.get(stream_id).unwrap())
.map(|(frame, mask)| {
let frame_idx = std::path::Path::new(frame).file_stem().unwrap().to_str().unwrap();
let output_path = format!("{}/{}.png", output_directory, frame_idx);

alphablend_image(
std::path::Path::new(frame),
std::path::Path::new(mask),
std::path::Path::new(&output_path),
).unwrap();

output_path
})
.collect::<Vec<_>>();

alphablend_frames.frames.insert(*stream_id, frames);
});
} else {
info!("alphablend frames already exist for session {}", session.id);
}

commands.entity(entity).insert(alphablend_frames);
}
}
}


fn generate_yolo_frames(
mut commands: Commands,
raw_frames: Query<
Expand Down Expand Up @@ -504,7 +566,6 @@ fn generate_yolo_frames(
}


// TODO: alphablend frames
#[derive(Component, Default)]
pub struct AlphablendFrames {
pub frames: HashMap<StreamId, Vec<String>>,
Expand Down Expand Up @@ -837,6 +898,30 @@ fn get_next_session_id(output_directory: &str) -> usize {
}


pub fn load_png(
image_path: &std::path::Path,
) -> Image {
let image = image::open(image_path).unwrap();
let image = image.into_rgba8();
let width = image.width();
let height = image.height();

let image_bytes = image.into_raw();

Image::new(
Extent3d {
width,
height,
depth_or_array_layers: 1,
},
bevy::render::render_resource::TextureDimension::D2,
image_bytes,
bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
RenderAssetUsages::all(),
)
}


fn rotate_image(
image_path: &std::path::Path,
output_path: &std::path::Path,
Expand Down Expand Up @@ -869,3 +954,28 @@ fn rotate_image(

Ok(())
}


fn alphablend_image(
image_path: &std::path::Path,
mask_path: &std::path::Path,
output_path: &std::path::Path,
) -> image::ImageResult<()> {
let img = image::open(image_path).unwrap();

let mask = image::open(mask_path).unwrap();
let mask = mask.resize_exact(img.width(), img.height(), image::imageops::FilterType::Triangle);

let mut output_img: ImageBuffer<Rgba<u8>, Vec<u8>> = ImageBuffer::new(img.dimensions().0, img.dimensions().1);

for (x, y, pixel) in img.pixels() {
let mask_pixel = mask.get_pixel(x, y).0[0];
let mut img_pixel = pixel.0;
img_pixel[3] = mask_pixel;
output_img.put_pixel(x, y, Rgba(img_pixel));
}

output_img.save(output_path)?;

Ok(())
}
Loading

0 comments on commit 41368e0

Please sign in to comment.