diff --git a/cml-core/src/core/inference.rs b/cml-core/src/core/inference.rs index d695429..eaef99c 100644 --- a/cml-core/src/core/inference.rs +++ b/cml-core/src/core/inference.rs @@ -2,7 +2,7 @@ use crate::metadata::MetaData; use anyhow::Result; use deadpool::managed::{Manager, Pool}; use derive_getters::Getters; -use std::{future::Future, path::PathBuf}; +use std::path::PathBuf; #[derive(Builder, Getters)] pub struct NewSample { @@ -15,7 +15,7 @@ pub struct NewSample { optional_tags: Option>, } -pub trait Inference { +pub trait Inference { async fn init_inference( &self, target_type: M, @@ -23,15 +23,14 @@ pub trait Inference { optional_tags: Option>, ) -> Result<()>; - async fn inference( + async fn inference( &self, metadata: MetaData, - target_type: M, + available_status: &[&str], data: &mut Vec>, pool: &Pool, inference_fn: FN, ) -> Result<()> where - FN: FnOnce(&mut Vec>, &Pool) -> R, - R: Future>>>; + FN: FnOnce(&mut Vec>, &str, T) -> Vec>; } diff --git a/cml-core/src/core/task.rs b/cml-core/src/core/task.rs index 3fbefc4..870a814 100644 --- a/cml-core/src/core/task.rs +++ b/cml-core/src/core/task.rs @@ -1,15 +1,12 @@ use anyhow::Result; use chrono::Duration; use derive_getters::Getters; -use std::path::PathBuf; #[derive(Builder, Getters, Clone)] -pub struct TaskConfig { +pub struct TaskConfig<'a> { min_start_count: usize, min_update_count: usize, - work_dir: PathBuf, - local_dir: Option, - working_status: Vec, + working_status: Vec<&'a str>, limit_time: Duration, } @@ -27,5 +24,5 @@ pub trait Task { fining_build_fn: FN, ) -> Result<()> where - FN: Fn(&TaskConfig, &str) -> Result<()> + Send + Sync; + FN: Fn(&str) -> Result<()> + Send + Sync; } diff --git a/cml-tdengine/src/core/inference.rs b/cml-tdengine/src/core/inference.rs index 78e94e7..aa58d6c 100644 --- a/cml-tdengine/src/core/inference.rs +++ b/cml-tdengine/src/core/inference.rs @@ -1,16 +1,14 @@ +use crate::{models::stables::STable, TDengine}; use anyhow::Result; use cml_core::{ core::inference::{Inference, NewSample}, handler::Handler, metadata::MetaData, }; -use std::future::Future; use std::time::{Duration, SystemTime}; use taos::{taos_query::Manager, *}; -use crate::{models::stables::STable, TDengine}; - -impl Inference> for TDengine { +impl Inference> for TDengine { async fn init_inference( &self, target_type: Field, @@ -39,24 +37,34 @@ impl Inference> for TDeng Ok(()) } - async fn inference( + async fn inference( &self, metadata: MetaData, - target_type: Field, + available_status: &[&str], data: &mut Vec>, pool: &Pool, inference_fn: FN, ) -> Result<()> where - FN: FnOnce(&mut Vec>, &Pool) -> R, - R: Future>>>, + FN: FnOnce(&mut Vec>, &str, i64) -> Vec>, { - let samples_with_res = inference_fn(data, pool).await?; - let taos = pool.get().await?; - taos.use_database("training_data").await?; let mut stmt = Stmt::init(&taos)?; - + taos.use_database("task").await?; + let last_task_time = taos + .query_one(format!( + "SELECT LAST(ts) FROM task.`{}` WHERE status IN ({})", + metadata.batch(), + available_status + .iter() + .map(|s| format!("'{}'", s)) + .collect::>() + .join(", ") + )) + .await? + .unwrap_or_else(|| panic!("There is no task in batch: {}", metadata.batch())); + let samples_with_res = inference_fn(data, metadata.batch(), last_task_time); + taos.use_database("inference").await?; let (tag_placeholder, field_placeholder) = metadata.get_placeholders(); stmt.prepare(format!( @@ -78,7 +86,15 @@ impl Inference> for TDeng for sample in &samples_with_res { let output = match sample.output() { Some(value) => ColumnView::from(value.clone()), - None => ColumnView::null(1, target_type.ty()), + None => ColumnView::null( + 1, + taos.query("SELECT * FROM inference LIMIT 0") + .await? + .fields() + .get(2) + .unwrap() + .ty(), + ), }; let mut values = vec![ @@ -98,8 +114,8 @@ impl Inference> for TDeng stmt.bind(&values)?; current_ts += Duration::from_nanos(1).as_nanos() as i64; } - stmt.add_batch()?; + stmt.execute()?; Ok(()) } @@ -112,6 +128,11 @@ mod tests { options::{CacheModel, ReplicaNum, SingleSTable}, DatabaseBuilder, }; + use cml_core::{ + core::inference::NewSampleBuilder, handler::Handler, metadata::MetaDataBuilder, + }; + use std::fs; + use std::time::{Duration, SystemTime}; #[tokio::test] async fn test_inference_init() -> Result<()> { @@ -149,6 +170,200 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_concurrent_inference() -> Result<()> { + let cml = TDengine::from_dsn("taos://"); + let pool = cml.build_pool(); + let taos = pool.get().await?; + + taos.exec("DROP DATABASE IF EXISTS inference").await?; + taos.exec("DROP DATABASE IF EXISTS task").await?; + taos.exec( + "CREATE DATABASE IF NOT EXISTS inference + PRECISION 'ns'", + ) + .await?; + taos.exec("CREATE DATABASE IF NOT EXISTS task PRECISION 'ns'") + .await?; + taos.exec( + "CREATE STABLE IF NOT EXISTS inference.inference + (ts TIMESTAMP, data_path NCHAR(255), output FLOAT) + TAGS (model_update_time TIMESTAMP)", + ) + .await?; + taos.exec( + "CREATE STABLE IF NOT EXISTS task.task + (ts TIMESTAMP, status BINARY(8)) + TAGS (model_update_time TIMESTAMP)", + ) + .await?; + taos.exec( + "INSERT INTO task.`FUCK` + USING task.task + TAGS ('2022-08-08 18:18:18.518') + VALUES (NOW, 'TRAIN')", + ) + .await?; + taos.exec( + "INSERT INTO task.`FUCK` + USING task.task + TAGS ('2022-08-08 18:18:18.518') + VALUES (NOW-2s, 'SUCCESS')", + ) + .await?; + taos.exec( + "INSERT INTO task.`FUCK8` + USING task.task + TAGS ('2022-08-08 18:18:18.518') + VALUES (NOW, 'SUCCESS')", + ) + .await?; + + let model_update_time = (SystemTime::now() - Duration::from_secs(86400)) + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_nanos() as i64; + let batch_meta_1: MetaData = MetaDataBuilder::default() + .model_update_time(model_update_time) + .batch("FUCK".to_owned()) + .inherent_field_num(3) + .inherent_tag_num(1) + .optional_field_num(0) + .build()?; + let batch_meta_2: MetaData = MetaDataBuilder::default() + .model_update_time(model_update_time) + .batch("FUCK8".to_owned()) + .inherent_field_num(3) + .inherent_tag_num(1) + .optional_field_num(0) + .build()?; + + fs::create_dir_all("/tmp/inference_dir/")?; + fs::write("/tmp/inference_dir/inference_data1.txt", b"8.8")?; + fs::write("/tmp/inference_dir/inference_data2.txt", b"98.8")?; + let mut batch_data_1 = vec![ + NewSampleBuilder::default() + .data_path("/tmp/inference_dir/inference_data1.txt".into()) + .build()?, + NewSampleBuilder::default() + .data_path("/tmp/inference_dir/inference_data2.txt".into()) + .build()?, + ]; + let mut batch_data_2 = vec![NewSampleBuilder::default() + .data_path("/tmp/inference_dir/inference_data1.txt".into()) + .build()?]; + + let available_status = vec!["SUCCESS"]; + let last_batch_time_1: i64 = taos + .query_one(format!( + "SELECT LAST(ts) FROM task.`{}` WHERE status IN ({}) ", + batch_meta_1.batch(), + available_status + .iter() + .map(|s| format!("'{}'", s)) + .collect::>() + .join(", ") + )) + .await? + .unwrap(); + let last_batch_time_2: i64 = taos + .query_one(format!( + "SELECT LAST(ts) FROM task.`{}` WHERE status IN ({}) ", + batch_meta_2.batch(), + available_status + .iter() + .map(|s| format!("'{}'", s)) + .collect::>() + .join(", ") + )) + .await? + .unwrap(); + fs::write( + "/tmp/inference_dir/".to_string() + + batch_meta_1.batch() + + &last_batch_time_1.to_string() + + ".txt", + b"10", + )?; + fs::write( + "/tmp/inference_dir/".to_string() + + batch_meta_2.batch() + + &last_batch_time_2.to_string() + + ".txt", + b"20", + )?; + let inference_fn = |vec_data: &mut Vec>, + batch: &str, + task_time: i64| + -> Vec> { + let mut result: Vec> = Vec::new(); + let working_dir = "/tmp/inference_dir/".to_string(); + let model_inference = + fs::read_to_string(working_dir + batch + &task_time.to_string() + ".txt") + .unwrap() + .parse::() + .unwrap(); + for inference_data in vec_data.iter() { + // inference + let inference_result = fs::read_to_string(inference_data.data_path()) + .unwrap() + .parse::() + .unwrap() + + model_inference; + let output = if inference_result > 25.0 { + Some(Value::Float(inference_result)) + } else { + None + }; + result.push( + NewSampleBuilder::default() + .data_path(inference_data.data_path().to_path_buf()) + .output(output) + .build() + .unwrap(), + ); + } + result + }; + + tokio::spawn({ + async move { + cml.inference( + batch_meta_1, + &available_status, + &mut batch_data_1, + &pool, + inference_fn, + ) + .await + .unwrap(); + cml.inference( + batch_meta_2, + &available_status, + &mut batch_data_2, + &pool, + inference_fn, + ) + .await + .unwrap(); + } + }) + .await?; + + let mut result = taos + .query("SELECT output FROM inference.inference ORDER BY output ASC") + .await?; + let records = result.to_records().await?; + assert_eq!( + vec![ + vec![Value::Null(Ty::Float)], + vec![Value::Float(28.8)], + vec![Value::Float(108.8)] + ], + records + ); + + fs::remove_dir_all("/tmp/inference_dir/")?; + taos.exec("DROP DATABASE IF EXISTS inference").await?; + taos.exec("DROP DATABASE IF EXISTS task").await?; Ok(()) } } diff --git a/cml-tdengine/src/core/register.rs b/cml-tdengine/src/core/register.rs index 857173f..1256a31 100644 --- a/cml-tdengine/src/core/register.rs +++ b/cml-tdengine/src/core/register.rs @@ -156,7 +156,10 @@ mod tests { #[tokio::test(flavor = "multi_thread")] async fn test_concurrent_register() -> Result<()> { - let taos = TaosBuilder::from_dsn("taos://")?.build().await?; + let cml = TDengine::from_dsn("taos://"); + let pool = cml.build_pool(); + let taos = pool.get().await?; + taos.exec("DROP DATABASE IF EXISTS training_data").await?; taos.exec( "CREATE DATABASE IF NOT EXISTS training_data @@ -169,8 +172,6 @@ mod tests { TAGS (model_update_time TIMESTAMP, fucking_tag_1 BINARY(255), fucking_tag_2 TINYINT)" ).await?; - let cml = TDengine::from_dsn("taos://"); - let pool = cml.build_pool(); let batch_state = BatchState::create(2); let model_update_time = (SystemTime::now() - Duration::from_secs(86400)) @@ -250,11 +251,11 @@ mod tests { }) .await?; - let mut result = taos - .query("SELECT COUNT(*) FROM training_data.training_data") - .await?; - let records = result.to_records().await?; - assert_eq!(vec![vec![Value::BigInt(4)]], records); + let records = taos + .query_one("SELECT COUNT(*) FROM training_data.training_data") + .await? + .unwrap_or(0); + assert_eq!(4, records); taos.exec("DROP DATABASE IF EXISTS training_data").await?; Ok(()) diff --git a/cml-tdengine/src/core/task.rs b/cml-tdengine/src/core/task.rs index 8d8fc99..f9f07fb 100644 --- a/cml-tdengine/src/core/task.rs +++ b/cml-tdengine/src/core/task.rs @@ -58,7 +58,7 @@ impl Task for TDengine { fining_build_fn: FN, ) -> Result<()> where - FN: Fn(&TaskConfig, &str) -> Result<()> + Send + Sync, + FN: Fn(&str) -> Result<()> + Send + Sync, { let taos = self.build_sync().unwrap(); @@ -96,7 +96,7 @@ impl Task for TDengine { } if !timeout_clause.is_empty() { - taos.exec("INSERT INTO".to_owned() + &timeout_clause.join(" "))?; + taos.exec("INSERT INTO ".to_owned() + &timeout_clause.join(" "))?; } batch_info.retain(|b| !batch_with_task.contains(&b.batch)); @@ -106,30 +106,26 @@ impl Task for TDengine { for batch in batch_info { match batch.model_update_time { Some(model_update_time) => { - if let Value::BigInt(count) = taos - .query(format!( + let count = taos + .query_one(format!( "SELECT COUNT(*) FROM training_data.`{}` WHERE ts > {}", batch.batch, model_update_time.timestamp_nanos() ))? - .to_rows_vec()?[0][0] - { - if count as usize > *task_config.min_update_count() { - scratch_in_queue.push(batch.batch); - } + .unwrap_or(0); + if count as usize > *task_config.min_update_count() { + scratch_in_queue.push(batch.batch); } } None => { - if let Value::BigInt(count) = taos - .query(format!( + let count = taos + .query_one(format!( "SELECT COUNT(*) FROM training_data.`{}`", batch.batch ))? - .to_rows_vec()?[0][0] - { - if count as usize > *task_config.min_start_count() { - fining_in_queue.push(batch.batch); - } + .unwrap_or(0); + if count as usize > *task_config.min_start_count() { + fining_in_queue.push(batch.batch); } } } @@ -139,13 +135,13 @@ impl Task for TDengine { || { scratch_in_queue .par_iter() - .map(|b| build_from_scratch_fn(&task_config, b).unwrap()) + .map(|b| build_from_scratch_fn(b).unwrap()) .collect::>() }, || { fining_in_queue .par_iter() - .map(|b| fining_build_fn(&task_config, b).unwrap()) + .map(|b| fining_build_fn(b).unwrap()) .collect::>() }, ); @@ -224,9 +220,7 @@ mod tests { let config: TaskConfig = TaskConfigBuilder::default() .min_start_count(1) .min_update_count(1) - .work_dir("/tmp/work_dir".into()) - .local_dir(Some("/tmp/local_dir".into())) - .working_status(vec!["TRAIN".to_string(), "EVAL".to_string()]) + .working_status(vec!["TRAIN", "EVAL"]) .limit_time(Duration::days(2)) .build()?; @@ -261,8 +255,28 @@ mod tests { VALUES (NOW, 'true', '/tmp/file_1.txt', 1.0), (NOW + 1s, 'false', '/tmp/file_2.txt', 2.0)", )?; + taos.exec( + "INSERT INTO task.`FUCK` + USING task.task + TAGS ('2022-08-08 18:18:18.518') + VALUES (NOW -3d, 'TRAIN')", + )?; + + taos.exec( + "INSERT INTO training_data.`FUCK8` + USING training_data.training_data + TAGS (null) + VALUES (NOW, 'true', '/tmp/file_1.txt', 1.0), + (NOW + 1s, 'false', '/tmp/file_2.txt', 2.0)", + )?; + taos.exec( + "INSERT INTO task.`FUCK8` + USING task.task + TAGS (null) + VALUES (NOW -3d, 'TRAIN')", + )?; - let build_fn = |c: &TaskConfig, b: &str| -> Result<()> { + let build_fn = |b: &str| -> Result<()> { type B = ADBackendDecorator>; B::seed(220225); @@ -440,8 +454,7 @@ mod tests { let config_optimizer = AdamConfig::new().with_weight_decay(Some(WeightDecayConfig::new(5e-5))); - let working_dir = c.work_dir().to_str().unwrap(); - + let working_dir = "/tmp/work_dir"; let learner = LearnerBuilder::new(working_dir) .with_file_checkpointer(1, CompactRecorder::new()) .devices(vec![device])