diff --git a/.gitignore b/.gitignore index 8d09ed7c..5d516766 100644 --- a/.gitignore +++ b/.gitignore @@ -19,4 +19,5 @@ Thumbs.db /target # Environment -.env \ No newline at end of file +.env +oidc.toml \ No newline at end of file diff --git a/Makefile b/Makefile index be630b35..a1a1408a 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,7 @@ help: @echo "NetVisor Development Commands" @echo "" + @echo " make fresh-db - Clean and set up a new database" @echo " make setup-db - Set up database" @echo " make clean-db - Clean up database" @echo " make clean-daemon - Remove daemon config file" @@ -22,6 +23,10 @@ help: @echo " make install-dev-mac - Install development dependencies on macOS" @echo " make install-dev-linux - Install development dependencies on Linux" +fresh-db: + make clean-db + make setup-db + setup-db: @echo "Setting up PostgreSQL..." @docker run -d \ @@ -97,6 +102,9 @@ lint: @echo "Linting UI..." cd ui && npm run lint && npm run format -- --check && npm run check +stripe-webhook: + stripe listen --forward-to http://localhost:60072/api/billing/webhooks + clean: make clean-db docker compose down -v diff --git a/backend/Cargo.lock b/backend/Cargo.lock index 0c4e8677..0da53ca3 100644 --- a/backend/Cargo.lock +++ b/backend/Cargo.lock @@ -452,6 +452,17 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-client-ip" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f08a543641554404b42acd0d2494df12ca2be034d7b8ee4dbbf7446f940a2ef" +dependencies = [ + "axum", + "client-ip", + "serde", +] + [[package]] name = "axum-core" version = "0.5.5" @@ -506,6 +517,18 @@ dependencies = [ "syn 2.0.110", ] +[[package]] +name = "bad_email" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c4cec0ab6cbc30d146e8f4c3a2a09b0ee82ac02dfd9dc87d707324eb45db0e9" +dependencies = [ + "serde", + "serde_derive", + "serde_json", + "toml", +] + [[package]] name = "base16ct" version = "0.2.0" @@ -825,6 +848,15 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" +[[package]] +name = "client-ip" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31211fc26899744f5b22521fdc971e5f3875991d8880537537470685a0e9552d" +dependencies = [ + "http", +] + [[package]] name = "colorchoice" version = "1.0.4" @@ -1567,6 +1599,7 @@ dependencies = [ "pear", "serde", "serde_json", + "toml", "uncased", "version_check", ] @@ -2944,8 +2977,10 @@ dependencies = [ "async-stripe-webhook", "async-trait", "axum", + "axum-client-ip", "axum-extra", "axum-macros", + "bad_email", "base64ct", "bollard", "chrono", diff --git a/backend/Cargo.toml b/backend/Cargo.toml index c75fb547..43e4e72e 100644 --- a/backend/Cargo.toml +++ b/backend/Cargo.toml @@ -75,7 +75,7 @@ config = "0.14" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } dotenv = "0.15" -figment = { version = "0.10", features = ["json", "env"] } +figment = { version = "0.10", features = ["json", "env", "toml"] } # === CLI === clap = { version = "4.0", features = ["derive"] } @@ -158,6 +158,8 @@ serde_with = "3.15.1" lettre = { version = "0.11.19", default-features = false, features = ["smtp-transport", "builder", "tokio1", "tokio1-rustls", "ring", "webpki-roots"] } html2text = "0.16.4" json_value_merge = "2.0.1" +bad_email = "0.1.1" +axum-client-ip = "1.1.3" # === Platform-specific Dependencies === [target.'cfg(target_os = "linux")'.dependencies] diff --git a/backend/migrations/20251128035448_org-onboarding-status.sql b/backend/migrations/20251128035448_org-onboarding-status.sql new file mode 100644 index 00000000..e1a32d0d --- /dev/null +++ b/backend/migrations/20251128035448_org-onboarding-status.sql @@ -0,0 +1,10 @@ +-- Add migration script here +ALTER TABLE organizations ADD COLUMN onboarding JSONB DEFAULT '[]'; + +-- Set onboarding for existing organizations where is_onboarded is true +UPDATE organizations +SET onboarding = '["OrgCreated", "OnboardingModalCompleted"]'::JSONB +WHERE is_onboarded = true; + +-- Drop the old is_onboarded column +ALTER TABLE organizations DROP COLUMN is_onboarded; \ No newline at end of file diff --git a/backend/src/bin/server.rs b/backend/src/bin/server.rs index 2a1c7c3e..8bc91be8 100644 --- a/backend/src/bin/server.rs +++ b/backend/src/bin/server.rs @@ -1,14 +1,15 @@ -use std::{net::SocketAddr, sync::Arc, time::Duration}; +use std::{net::SocketAddr, str::FromStr, sync::Arc, time::Duration}; use axum::{ Extension, Router, http::{HeaderValue, Method}, }; +use axum_client_ip::ClientIpSource; use clap::Parser; use netvisor::server::{ auth::middleware::AuthenticatedEntity, billing::types::base::{BillingPlan, BillingRate, PlanConfig}, - config::{AppState, CliArgs, ServerConfig}, + config::{AppState, ServerCli, ServerConfig}, organizations::r#impl::base::{Organization, OrganizationBase}, shared::{ handlers::{cache::AppCache, factory::create_router}, @@ -27,124 +28,17 @@ use tower_http::{ }; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; -#[derive(Parser)] -#[command(name = "netvisor-server")] -#[command(about = "NetVisor server")] -struct Cli { - /// Override server port - #[arg(long)] - server_port: Option, - - /// Override log level - #[arg(long)] - log_level: Option, - - /// Override rust system log level - #[arg(long)] - rust_log: Option, - - /// Override database path - #[arg(long)] - database_url: Option, - - /// Override integrated daemon url - #[arg(long)] - integrated_daemon_url: Option, - - /// Use secure session cookies (if serving UI behind HTTPS) - #[arg(long)] - use_secure_session_cookies: Option, - - /// Enable or disable registration flow - #[arg(long)] - disable_registration: bool, - - /// OIDC client ID - #[arg(long)] - oidc_client_id: Option, - - /// OIDC client secret - #[arg(long)] - oidc_client_secret: Option, - - /// OIDC issuer url - #[arg(long)] - oidc_issuer_url: Option, - - /// OIDC issuer url - #[arg(long)] - oidc_provider_name: Option, - - /// OIDC redirect url - #[arg(long)] - oidc_redirect_url: Option, - - /// OIDC redirect url - #[arg(long)] - stripe_secret: Option, - - /// OIDC redirect url - #[arg(long)] - stripe_webhook_secret: Option, - - #[arg(long)] - smtp_username: Option, - - #[arg(long)] - smtp_password: Option, - - /// Email used as to/from in emails send by NetVisor - #[arg(long)] - smtp_email: Option, - - #[arg(long)] - smtp_relay: Option, - - #[arg(long)] - smtp_port: Option, - - /// Server URL used in features like password reset and invite links - #[arg(long)] - public_url: Option, -} - -impl From for CliArgs { - fn from(cli: Cli) -> Self { - Self { - server_port: cli.server_port, - log_level: cli.log_level, - rust_log: cli.rust_log, - database_url: cli.database_url, - integrated_daemon_url: cli.integrated_daemon_url, - use_secure_session_cookies: cli.use_secure_session_cookies, - disable_registration: cli.disable_registration, - oidc_client_id: cli.oidc_client_id, - oidc_client_secret: cli.oidc_client_secret, - oidc_issuer_url: cli.oidc_issuer_url, - oidc_provider_name: cli.oidc_provider_name, - oidc_redirect_url: cli.oidc_redirect_url, - stripe_secret: cli.stripe_secret, - stripe_webhook_secret: cli.stripe_webhook_secret, - smtp_email: cli.smtp_email, - smtp_password: cli.smtp_password, - smtp_relay: cli.smtp_relay, - smtp_username: cli.smtp_username, - public_url: cli.public_url, - } - } -} - #[tokio::main] async fn main() -> anyhow::Result<()> { let _ = dotenv::dotenv(); - let cli = Cli::parse(); - let cli_args = CliArgs::from(cli); + let cli = ServerCli::parse(); // Load configuration using figment - let config = ServerConfig::load(cli_args)?; + let config = ServerConfig::load(cli)?; let listen_addr = format!("0.0.0.0:{}", &config.server_port); let web_external_path = config.web_external_path.clone(); + let client_ip_source = config.client_ip_source.clone(); // Initialize tracing tracing_subscriber::registry() @@ -247,6 +141,10 @@ async fn main() -> anyhow::Result<()> { CorsLayer::permissive() }; + let client_ip_source = client_ip_source + .map(|s| ClientIpSource::from_str(&s)) + .unwrap_or(Ok(ClientIpSource::ConnectInfo))?; + let cache_headers = SetResponseHeaderLayer::if_not_present( header::CACHE_CONTROL, HeaderValue::from_static("no-store, no-cache, must-revalidate, private"), @@ -260,7 +158,8 @@ async fn main() -> anyhow::Result<()> { .layer(TraceLayer::new_for_http()) .layer(cors) .layer(Extension(app_cache)) - .layer(cache_headers), + .layer(cache_headers) + .layer(client_ip_source.into_extension()), ); let listener = tokio::net::TcpListener::bind(&listen_addr).await?; @@ -339,7 +238,7 @@ async fn main() -> anyhow::Result<()> { plan: None, plan_status: None, name: "My Organization".to_string(), - is_onboarded: false, + onboarding: vec![], }), AuthenticatedEntity::System, ) @@ -351,8 +250,6 @@ async fn main() -> anyhow::Result<()> { AuthenticatedEntity::System, ) .await?; - } else { - tracing::debug!("Server already has data, skipping seed data"); } tokio::signal::ctrl_c().await?; diff --git a/backend/src/daemon/runtime/service.rs b/backend/src/daemon/runtime/service.rs index dcf0105b..794970d2 100644 --- a/backend/src/daemon/runtime/service.rs +++ b/backend/src/daemon/runtime/service.rs @@ -75,11 +75,20 @@ impl DaemonRuntimeService { let error_msg = api_response .error .unwrap_or_else(|| "Unknown error".to_string()); - tracing::warn!( - daemon_id = %daemon_id, - err = %error_msg, - "Failed to check for work" - ); + + if error_msg.contains("not found") { + tracing::error!( + daemon_id = %daemon_id, + error = %error_msg, + "Failed to check for work - the Daemon ID present in the config on this host could not be found on the server. Please remove the config and install a new daemon." + ); + } else { + tracing::error!( + daemon_id = %daemon_id, + err = %error_msg, + "Failed to check for work" + ); + } } else if let Some((payload, cancel_current_session)) = api_response.data { if !cancel_current_session && payload.is_none() { tracing::info!( @@ -162,10 +171,20 @@ impl DaemonRuntimeService { let error_msg = api_response .error .unwrap_or_else(|| "Unknown error".to_string()); - tracing::error!( - error = %error_msg, - "Heartbeat failed - check network connectivity" - ); + + if error_msg.contains("not found") { + tracing::error!( + error = %error_msg, + daemon_id = %daemon_id, + "Heartbeat failed - the Daemon ID present in the config on this host could not be found on the server. Please remove the config and install a new daemon." + ); + } else { + tracing::error!( + error = %error_msg, + daemon_id = %daemon_id, + "Heartbeat failed - check network connectivity" + ); + } } if let Err(e) = self.config_store.update_heartbeat().await { diff --git a/backend/src/server/api_keys/handlers.rs b/backend/src/server/api_keys/handlers.rs index b80f9a79..2bcc32ba 100644 --- a/backend/src/server/api_keys/handlers.rs +++ b/backend/src/server/api_keys/handlers.rs @@ -3,10 +3,11 @@ use crate::server::{ auth::middleware::RequireMember, config::AppState, shared::{ + events::types::{TelemetryEvent, TelemetryOperation}, handlers::traits::{ CrudHandlers, bulk_delete_handler, delete_handler, get_all_handler, get_by_id_handler, }, - services::traits::CrudService, + services::traits::{CrudService, EventBusService}, types::api::{ApiError, ApiResponse, ApiResult}, }, }; @@ -15,6 +16,9 @@ use axum::{ extract::{Path, State}, routing::{delete, get, post, put}, }; +use axum_client_ip::ClientIp; +use axum_extra::{TypedHeader, headers::UserAgent}; +use chrono::Utc; use std::sync::Arc; use uuid::Uuid; @@ -32,7 +36,7 @@ pub fn create_router() -> Router> { pub async fn create_handler( State(state): State>, RequireMember(user): RequireMember, - Json(api_key): Json, + Json(mut api_key): Json, ) -> ApiResult>> { tracing::debug!( api_key_name = %api_key.base.name, @@ -42,6 +46,7 @@ pub async fn create_handler( ); let service = ApiKey::get_service(&state); + api_key.base.key = service.generate_api_key(); let api_key = service .create(api_key, user.clone().into()) .await @@ -54,6 +59,30 @@ pub async fn create_handler( ApiError::internal_error(&e.to_string()) })?; + let organization = state + .services + .organization_service + .get_by_id(&user.organization_id) + .await?; + + if let Some(organization) = organization + && organization.not_onboarded(&TelemetryOperation::FirstApiKeyCreated) + { + service + .event_bus() + .publish_telemetry(TelemetryEvent { + id: Uuid::new_v4(), + authentication: user.clone().into(), + organization_id: user.organization_id, + operation: TelemetryOperation::FirstApiKeyCreated, + timestamp: Utc::now(), + metadata: serde_json::json!({ + "is_onboarding_step": true + }), + }) + .await?; + } + Ok(Json(ApiResponse::success(ApiKeyResponse { key: api_key.base.key.clone(), api_key, @@ -63,17 +92,15 @@ pub async fn create_handler( pub async fn rotate_key_handler( State(state): State>, RequireMember(user): RequireMember, + ClientIp(ip): ClientIp, + user_agent: Option>, Path(api_key_id): Path, ) -> ApiResult>> { - tracing::debug!( - api_key_id = %api_key_id, - user_id = %user.user_id, - "API key rotation request received" - ); + let user_agent = user_agent.map(|u| u.to_string()); let service = ApiKey::get_service(&state); let key = service - .rotate_key(api_key_id, user.clone().into()) + .rotate_key(api_key_id, ip, user_agent, user.clone()) .await .map_err(|e| { tracing::error!( diff --git a/backend/src/server/api_keys/impl/base.rs b/backend/src/server/api_keys/impl/base.rs index 0fa01868..ec63ea31 100644 --- a/backend/src/server/api_keys/impl/base.rs +++ b/backend/src/server/api_keys/impl/base.rs @@ -33,6 +33,16 @@ pub struct ApiKey { pub base: ApiKeyBase, } +impl ApiKey { + pub fn suppress_logs(&self, other: &Self) -> bool { + self.base.key == other.base.key + && self.base.name == other.base.name + && self.base.expires_at == other.base.expires_at + && self.base.network_id == other.base.network_id + && self.base.is_enabled == other.base.is_enabled + } +} + impl Display for ApiKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}: {}", self.base.name, self.id) diff --git a/backend/src/server/api_keys/service.rs b/backend/src/server/api_keys/service.rs index d8fc7518..31226971 100644 --- a/backend/src/server/api_keys/service.rs +++ b/backend/src/server/api_keys/service.rs @@ -1,17 +1,18 @@ use anyhow::{Result, anyhow}; use async_trait::async_trait; use chrono::Utc; +use std::net::IpAddr; use std::sync::Arc; use uuid::Uuid; use crate::server::{ - api_keys::r#impl::base::{ApiKey, ApiKeyBase}, - auth::middleware::AuthenticatedEntity, + api_keys::r#impl::base::ApiKey, + auth::middleware::{AuthenticatedEntity, AuthenticatedUser}, shared::{ entities::ChangeTriggersTopologyStaleness, events::{ bus::EventBus, - types::{EntityEvent, EntityOperation}, + types::{AuthEvent, AuthOperation, EntityEvent, EntityOperation}, }, services::traits::{CrudService, EventBusService}, storage::{ @@ -45,44 +46,39 @@ impl CrudService for ApiKeyService { &self.storage } - async fn create(&self, api_key: ApiKey, authentication: AuthenticatedEntity) -> Result { - let key = self.generate_api_key(); - - tracing::debug!( - api_key_name = %api_key.base.name, - network_id = %api_key.base.network_id, - "Creating API key" - ); - - let api_key = ApiKey::new(ApiKeyBase { - key: key.clone(), - name: api_key.base.name, - last_used: None, - expires_at: api_key.base.expires_at, - network_id: api_key.base.network_id, - is_enabled: true, - }); + /// Update entity + async fn update( + &self, + entity: &mut ApiKey, + authentication: AuthenticatedEntity, + ) -> Result { + let current = self + .get_by_id(&entity.id()) + .await? + .ok_or_else(|| anyhow!("Could not find {}", entity))?; + let updated = self.storage().update(entity).await?; - let created = self.storage.create(&api_key).await?; - let trigger_stale = created.triggers_staleness(None); + let suppress_logs = updated.suppress_logs(¤t); + let trigger_stale = updated.triggers_staleness(Some(current)); self.event_bus() .publish_entity(EntityEvent { id: Uuid::new_v4(), - entity_type: created.clone().into(), - entity_id: created.id(), - network_id: self.get_network_id(&created), - organization_id: self.get_organization_id(&created), - operation: EntityOperation::Created, + entity_id: updated.id(), + network_id: self.get_network_id(&updated), + organization_id: self.get_organization_id(&updated), + entity_type: updated.clone().into(), + operation: EntityOperation::Updated, timestamp: Utc::now(), metadata: serde_json::json!({ - "trigger_stale": trigger_stale + "trigger_stale": trigger_stale, + "suppress_logs": suppress_logs }), authentication, }) .await?; - Ok(created) + Ok(updated) } } @@ -98,19 +94,30 @@ impl ApiKeyService { pub async fn rotate_key( &self, api_key_id: Uuid, - authentication: AuthenticatedEntity, + ip_address: IpAddr, + user_agent: Option, + user: AuthenticatedUser, ) -> Result { - tracing::info!( - api_key_id = %api_key_id, - "Rotating API key" - ); - if let Some(mut api_key) = self.get_by_id(&api_key_id).await? { let new_key = self.generate_api_key(); api_key.base.key = new_key.clone(); - let _updated = self.update(&mut api_key, authentication).await?; + self.event_bus + .publish_auth(AuthEvent { + id: Uuid::new_v4(), + user_id: Some(user.user_id), + organization_id: Some(user.organization_id), + operation: AuthOperation::RotateKey, + timestamp: Utc::now(), + ip_address, + user_agent, + metadata: serde_json::json!({}), + authentication: user.clone().into(), + }) + .await?; + + let _updated = self.update(&mut api_key, user.into()).await?; Ok(new_key) } else { diff --git a/backend/src/server/auth/handlers.rs b/backend/src/server/auth/handlers.rs index 250d2fea..199f61c6 100644 --- a/backend/src/server/auth/handlers.rs +++ b/backend/src/server/auth/handlers.rs @@ -7,9 +7,10 @@ use crate::server::{ RegisterRequest, ResetPasswordRequest, UpdateEmailPasswordRequest, }, base::LoginRegisterParams, - oidc::OidcPendingAuth, + oidc::{OidcFlow, OidcPendingAuth, OidcProviderMetadata}, }, middleware::AuthenticatedUser, + oidc::OidcService, }, config::AppState, organizations::handlers::process_pending_invite, @@ -21,12 +22,14 @@ use crate::server::{ }; use axum::{ Router, - extract::{ConnectInfo, Query, State}, + extract::{Path, Query, State}, response::{Json, Redirect}, routing::{get, post}, }; +use axum_client_ip::ClientIp; use axum_extra::{TypedHeader, headers::UserAgent}; -use std::{net::SocketAddr, sync::Arc}; +use bad_email::is_email_unwanted; +use std::{net::IpAddr, sync::Arc}; use tower_sessions::Session; use url::Url; use uuid::Uuid; @@ -39,16 +42,17 @@ pub fn create_router() -> Router> { .route("/me", post(get_current_user)) .nest("/keys", api_keys::handlers::create_router()) .route("/update", post(update_password_auth)) - .route("/oidc/authorize", get(oidc_authorize)) - .route("/oidc/callback", get(oidc_callback)) - .route("/oidc/unlink", post(unlink_oidc_account)) + .route("/oidc/providers", get(list_oidc_providers)) + .route("/oidc/{slug}/authorize", get(oidc_authorize)) + .route("/oidc/{slug}/callback", get(oidc_callback)) + .route("/oidc/{slug}/unlink", post(unlink_oidc_account)) .route("/forgot-password", post(forgot_password)) .route("/reset-password", post(reset_password)) } async fn register( State(state): State>, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, session: Session, Json(request): Json, @@ -57,7 +61,14 @@ async fn register( return Err(ApiError::forbidden("User registration is disabled")); } - let ip = addr.ip(); + if is_email_unwanted(request.email.as_str()) { + return Err(ApiError::conflict( + "Email address uses a disposable domain. Please register with a non-disposable email address.", + )); + } + + let subscribed = request.subscribed; + let user_agent = user_agent.map(|u| u.to_string()); let (org_id, permissions, network_ids) = match process_pending_invite(&state, &session).await { @@ -84,6 +95,7 @@ async fn register( ip, user_agent, network_ids, + subscribed, }, ) .await?; @@ -98,12 +110,11 @@ async fn register( async fn login( State(state): State>, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, session: Session, Json(request): Json, ) -> ApiResult>> { - let ip = addr.ip(); let user_agent = user_agent.map(|u| u.to_string()); let user = state @@ -122,12 +133,11 @@ async fn login( async fn logout( State(state): State>, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, session: Session, ) -> ApiResult>> { if let Ok(Some(user_id)) = session.get::("user_id").await { - let ip = addr.ip(); let user_agent = user_agent.map(|u| u.to_string()); state @@ -168,7 +178,7 @@ async fn get_current_user( async fn update_password_auth( State(state): State>, session: Session, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, auth_user: AuthenticatedUser, Json(request): Json, @@ -179,7 +189,6 @@ async fn update_password_auth( .map_err(|e| ApiError::internal_error(&format!("Failed to read session: {}", e)))? .ok_or_else(|| ApiError::unauthorized("Not authenticated".to_string()))?; - let ip = addr.ip(); let user_agent = user_agent.map(|u| u.to_string()); let user = state @@ -200,11 +209,10 @@ async fn update_password_auth( async fn forgot_password( State(state): State>, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, Json(request): Json, ) -> ApiResult>> { - let ip = addr.ip(); let user_agent = user_agent.map(|u| u.to_string()); state @@ -223,12 +231,11 @@ async fn forgot_password( async fn reset_password( State(state): State>, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, session: Session, Json(request): Json, ) -> ApiResult>> { - let ip = addr.ip(); let user_agent = user_agent.map(|u| u.to_string()); let user = state @@ -245,8 +252,21 @@ async fn reset_password( Ok(Json(ApiResponse::success(user))) } +async fn list_oidc_providers( + State(state): State>, +) -> ApiResult>>> { + let oidc_service = state + .services + .oidc_service + .as_ref() + .ok_or_else(|| ApiError::internal_error("OIDC not configured"))?; + + Ok(Json(ApiResponse::success(oidc_service.list_providers()))) +} + async fn oidc_authorize( State(state): State>, + Path(slug): Path, session: Session, Query(params): Query, ) -> ApiResult { @@ -256,8 +276,37 @@ async fn oidc_authorize( .as_ref() .ok_or_else(|| ApiError::internal_error("OIDC not configured"))?; - let (auth_url, pending_auth) = oidc_service - .authorize_url() + // Verify provider exists + let provider = oidc_service + .get_provider(&slug) + .ok_or_else(|| ApiError::not_found(format!("OIDC provider '{}' not found", slug)))?; + + // Parse and validate flow parameter + let flow = match params.flow.as_deref() { + Some("login") => OidcFlow::Login, + Some("register") => OidcFlow::Register, + Some("link") => OidcFlow::Link, + Some(other) => { + return Err(ApiError::bad_request(&format!( + "Invalid flow '{}'. Must be 'login', 'register', or 'link'", + other + ))); + } + None => { + return Err(ApiError::bad_request( + "flow parameter is required (login, register, or link)", + )); + } + }; + + // Validate return_url is present + let return_url = params + .return_url + .ok_or_else(|| ApiError::bad_request("return_url parameter is required"))?; + + // Generate authorization URL using provider + let (auth_url, pending_auth) = provider + .authorize_url(flow) .await .map_err(|e| ApiError::internal_error(&format!("Failed to generate auth URL: {}", e)))?; @@ -265,34 +314,40 @@ async fn oidc_authorize( session .insert("oidc_pending_auth", pending_auth) .await - .map_err(|e| ApiError::internal_error(&format!("Failed to save session: {}", e)))?; + .map_err(|e| ApiError::internal_error(&format!("Failed to save pending auth: {}", e)))?; + session - .insert("oidc_is_linking", params.link.unwrap_or(false)) + .insert("oidc_provider_slug", slug) .await - .map_err(|e| ApiError::internal_error(&format!("Failed to save session: {}", e)))?; + .map_err(|e| ApiError::internal_error(&format!("Failed to save provider slug: {}", e)))?; + session - .insert( - "oidc_return_url", - params - .return_url - .ok_or_else(|| ApiError::bad_request("return_url parameter is required"))?, - ) + .insert("oidc_return_url", return_url) .await - .map_err(|e| ApiError::internal_error(&format!("Failed to save session: {}", e)))?; + .map_err(|e| ApiError::internal_error(&format!("Failed to save return URL: {}", e)))?; + + // Store subscribed flag if present + if let Some(subscribed) = params.subscribed { + session + .insert("oidc_subscribed", subscribed) + .await + .map_err(|e| ApiError::internal_error(&format!("Failed to save subscribed: {}", e)))?; + } Ok(Redirect::to(&auth_url)) } async fn oidc_callback( State(state): State>, + Path(slug): Path, session: Session, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, Query(params): Query, ) -> Result { - let ip = addr.ip(); let user_agent = user_agent.map(|u| u.to_string()); + // Verify OIDC is configured let oidc_service = match state.services.oidc_service.as_ref() { Some(service) => service, None => { @@ -303,7 +358,15 @@ async fn oidc_callback( } }; - // Extract session data + // Verify provider exists + if oidc_service.get_provider(&slug).is_none() { + return Err(Redirect::to(&format!( + "/error?message={}", + urlencoding::encode(&format!("OIDC provider '{}' not found", slug)) + ))); + } + + // Extract and validate session data let return_url: String = session .get("oidc_return_url") .await @@ -312,7 +375,7 @@ async fn oidc_callback( .ok_or_else(|| { Redirect::to(&format!( "/error?message={}", - urlencoding::encode("Session error: Unable to determine return URL") + urlencoding::encode("Session error: No return URL found") )) })?; @@ -329,6 +392,28 @@ async fn oidc_callback( )) })?; + let session_slug: String = session + .get("oidc_provider_slug") + .await + .ok() + .flatten() + .ok_or_else(|| { + Redirect::to(&format!( + "{}?error={}", + return_url, + urlencoding::encode("Session error: No provider slug found") + )) + })?; + + // Verify provider slug matches + if session_slug != slug { + return Err(Redirect::to(&format!( + "{}?error={}", + return_url, + urlencoding::encode("Provider mismatch in callback") + ))); + } + // Verify CSRF token if pending_auth.csrf_token != params.state { return Err(Redirect::to(&format!( @@ -338,122 +423,269 @@ async fn oidc_callback( ))); } - let is_linking: bool = session - .get("oidc_is_linking") + // Get subscribed flag from session + let subscribed: bool = session + .get("oidc_subscribed") .await .ok() .flatten() .unwrap_or(false); - let mut return_url_parsed = Url::parse(&return_url).map_err(|_| { + + // Parse return URL for error handling + let return_url_parsed = Url::parse(&return_url).map_err(|_| { Redirect::to(&format!( "/error?message={}", urlencoding::encode("Invalid return URL") )) })?; - if is_linking { - // LINK FLOW - return_url_parsed - .query_pairs_mut() - .append_pair("auth_modal", "true"); - - let user_id: Uuid = session.get("user_id").await.ok().flatten().ok_or_else(|| { - let mut url = return_url_parsed.clone(); - url.query_pairs_mut() - .append_pair("error", "You must be logged in to link an OIDC account."); - Redirect::to(url.as_str()) - })?; - - match oidc_service - .link_to_user(&user_id, ¶ms.code, pending_auth, ip, user_agent) + // Handle different flows + match pending_auth.flow { + OidcFlow::Link => { + handle_link_flow(HandleLinkFlowParams { + oidc_service, + slug: &slug, + code: ¶ms.code, + pending_auth, + ip, + user_agent, + session, + return_url: return_url_parsed, + }) .await - { - Ok(_) => { - // Clear session data - let _ = session.remove::("oidc_pending_auth").await; - let _ = session.remove::("oidc_is_linking").await; - let _ = session.remove::("oidc_return_url").await; - - Ok(Redirect::to(return_url_parsed.as_str())) - } - Err(e) => { - tracing::error!("Failed to link OIDC: {}", e); - let _ = session.remove::("oidc_pending_auth").await; - let _ = session.remove::("oidc_is_linking").await; - let _ = session.remove::("oidc_return_url").await; - - return_url_parsed - .query_pairs_mut() - .append_pair("error", &format!("Failed to link OIDC account: {}", e)); - Err(Redirect::to(return_url_parsed.as_str())) - } } - } else { - let (org_id, permissions, network_ids) = - match process_pending_invite(&state, &session).await { - Ok(Some((org_id, permissions, network_ids))) => { - (Some(org_id), Some(permissions), network_ids) - } - Ok(_) => (None, None, vec![]), - Err(e) => { - return Err(Redirect::to(&format!( - "{}?error={}", - return_url, - urlencoding::encode(&format!("Failed to process invite: {}", e)) - ))); - } - }; - - match oidc_service - .login_or_register( - ¶ms.code, + OidcFlow::Login => { + handle_login_flow(HandleLinkFlowParams { + oidc_service, + slug: &slug, + code: ¶ms.code, pending_auth, - LoginRegisterParams { - org_id, - permissions, + ip, + user_agent, + session, + return_url: return_url_parsed, + }) + .await + } + OidcFlow::Register => { + handle_register_flow( + state.clone(), + subscribed, + HandleLinkFlowParams { + oidc_service, + slug: &slug, + code: ¶ms.code, + pending_auth, ip, user_agent, - network_ids, + session, + return_url: return_url_parsed, }, ) .await - { - Ok(user) => { - if let Err(e) = session.insert("user_id", user.id).await { - tracing::error!("Failed to save session: {}", e); - return Err(Redirect::to(&format!( - "{}?error={}", - return_url, - urlencoding::encode(&format!("Failed to create session: {}", e)) - ))); - } - - // Clear session data - let _ = session.remove::("oidc_pending_auth").await; - let _ = session.remove::("oidc_is_linking").await; - let _ = session.remove::("oidc_return_url").await; - - Ok(Redirect::to(&return_url)) + } + } +} + +struct HandleLinkFlowParams<'a> { + oidc_service: &'a OidcService, + slug: &'a str, + code: &'a str, + pending_auth: OidcPendingAuth, + ip: IpAddr, + user_agent: Option, + session: Session, + return_url: Url, +} + +async fn handle_link_flow(params: HandleLinkFlowParams<'_>) -> Result { + let HandleLinkFlowParams { + oidc_service, + slug, + code, + pending_auth, + ip, + user_agent, + session, + mut return_url, + } = params; + + // Add auth_modal query param to return URL + return_url + .query_pairs_mut() + .append_pair("auth_modal", "true"); + + // Verify user is logged in + let user_id: Uuid = session.get("user_id").await.ok().flatten().ok_or_else(|| { + let mut url = return_url.clone(); + url.query_pairs_mut() + .append_pair("error", "You must be logged in to link an OIDC account."); + Redirect::to(url.as_str()) + })?; + + // Link OIDC account to user + match oidc_service + .link_to_user(slug, &user_id, code, pending_auth, ip, user_agent) + .await + { + Ok(_) => { + // Clear session data + let _ = session.remove::("oidc_pending_auth").await; + let _ = session.remove::("oidc_provider_slug").await; + let _ = session.remove::("oidc_return_url").await; + let _ = session.remove::("oidc_subscribed").await; + + Ok(Redirect::to(return_url.as_str())) + } + Err(e) => { + tracing::error!("Failed to link OIDC: {}", e); + + // Clear session data + let _ = session.remove::("oidc_pending_auth").await; + let _ = session.remove::("oidc_provider_slug").await; + let _ = session.remove::("oidc_return_url").await; + let _ = session.remove::("oidc_subscribed").await; + + return_url + .query_pairs_mut() + .append_pair("error", &format!("Failed to link OIDC account: {}", e)); + Err(Redirect::to(return_url.as_str())) + } + } +} + +async fn handle_login_flow(params: HandleLinkFlowParams<'_>) -> Result { + let HandleLinkFlowParams { + oidc_service, + slug, + code, + pending_auth, + ip, + user_agent, + session, + return_url, + } = params; + + // Login user + match oidc_service + .login(slug, code, pending_auth, ip, user_agent) + .await + { + Ok(user) => { + // Save user_id to session + if let Err(e) = session.insert("user_id", user.id).await { + tracing::error!("Failed to save session: {}", e); + return Err(Redirect::to(&format!( + "{}?error={}", + return_url, + urlencoding::encode(&format!("Failed to create session: {}", e)) + ))); } - Err(e) => { - tracing::error!("Failed to login/register via OIDC: {}", e); - Err(Redirect::to(&format!( + + // Clear OIDC session data + let _ = session.remove::("oidc_pending_auth").await; + let _ = session.remove::("oidc_provider_slug").await; + let _ = session.remove::("oidc_return_url").await; + let _ = session.remove::("oidc_subscribed").await; + + Ok(Redirect::to(return_url.as_str())) + } + Err(e) => { + tracing::error!("Failed to login via OIDC: {}", e); + Err(Redirect::to(&format!( + "{}?error={}", + return_url, + urlencoding::encode(&format!("Failed to login: {}", e)) + ))) + } + } +} + +async fn handle_register_flow( + state: Arc, + subscribed: bool, + params: HandleLinkFlowParams<'_>, +) -> Result { + let HandleLinkFlowParams { + oidc_service, + slug, + code, + pending_auth, + ip, + user_agent, + session, + return_url, + } = params; + + // Process pending invite if present + let (org_id, permissions, network_ids) = match process_pending_invite(&state, &session).await { + Ok(Some((org_id, permissions, network_ids))) => { + (Some(org_id), Some(permissions), network_ids) + } + Ok(_) => (None, None, vec![]), + Err(e) => { + return Err(Redirect::to(&format!( + "{}?error={}", + return_url, + urlencoding::encode(&format!("Failed to process invite: {}", e)) + ))); + } + }; + + // Register user + match oidc_service + .register( + slug, + code, + pending_auth, + LoginRegisterParams { + org_id, + permissions, + ip, + user_agent, + network_ids, + subscribed, + }, + ) + .await + { + Ok(user) => { + // Save user_id to session + if let Err(e) = session.insert("user_id", user.id).await { + tracing::error!("Failed to save session: {}", e); + return Err(Redirect::to(&format!( "{}?error={}", return_url, - urlencoding::encode(&format!("Failed to authenticate: {}", e)) - ))) + urlencoding::encode(&format!("Failed to create session: {}", e)) + ))); } + + // Clear OIDC session data + let _ = session.remove::("oidc_pending_auth").await; + let _ = session.remove::("oidc_provider_slug").await; + let _ = session.remove::("oidc_return_url").await; + let _ = session.remove::("oidc_subscribed").await; + + Ok(Redirect::to(return_url.as_str())) + } + Err(e) => { + tracing::error!("Failed to register via OIDC: {}", e); + Err(Redirect::to(&format!( + "{}?error={}", + return_url, + urlencoding::encode(&format!("Failed to register: {}", e)) + ))) } } } async fn unlink_oidc_account( State(state): State>, + Path(slug): Path, session: Session, - ConnectInfo(addr): ConnectInfo, + ClientIp(ip): ClientIp, user_agent: Option>, ) -> ApiResult>> { - let ip = addr.ip(); let user_agent = user_agent.map(|u| u.to_string()); let oidc_service = state @@ -462,14 +694,24 @@ async fn unlink_oidc_account( .as_ref() .ok_or_else(|| ApiError::internal_error("OIDC not configured"))?; + // Verify provider exists + if oidc_service.get_provider(&slug).is_none() { + return Err(ApiError::not_found(format!( + "OIDC provider '{}' not found", + slug + ))); + } + + // Get user_id from session let user_id: Uuid = session .get("user_id") .await .map_err(|e| ApiError::internal_error(&format!("Failed to read session: {}", e)))? .ok_or_else(|| ApiError::unauthorized("Not authenticated".to_string()))?; + // Unlink OIDC account let updated_user = oidc_service - .unlink_from_user(&user_id, ip, user_agent) + .unlink_from_user(&slug, &user_id, ip, user_agent) .await .map_err(|e| ApiError::internal_error(&format!("Failed to unlink OIDC: {}", e)))?; diff --git a/backend/src/server/auth/impl/api.rs b/backend/src/server/auth/impl/api.rs index f45c1b3a..80e1a379 100644 --- a/backend/src/server/auth/impl/api.rs +++ b/backend/src/server/auth/impl/api.rs @@ -22,6 +22,7 @@ pub struct RegisterRequest { #[validate(length(min = 12, message = "Password must be at least 12 characters"))] #[validate(custom(function = "validate_password_complexity"))] pub password: String, + pub subscribed: bool, } /// Validate password complexity requirements @@ -61,11 +62,11 @@ pub struct UpdateEmailPasswordRequest { pub email: Option, } -// Query params for authorize #[derive(Debug, Deserialize)] pub struct OidcAuthorizeParams { - pub link: Option, + pub flow: Option, // "login", "register", or "link" pub return_url: Option, + pub subscribed: Option, } #[derive(Debug, Deserialize)] diff --git a/backend/src/server/auth/impl/base.rs b/backend/src/server/auth/impl/base.rs index 4c04f4d3..d4f750c8 100644 --- a/backend/src/server/auth/impl/base.rs +++ b/backend/src/server/auth/impl/base.rs @@ -9,6 +9,7 @@ pub struct LoginRegisterParams { pub ip: IpAddr, pub user_agent: Option, pub network_ids: Vec, + pub subscribed: bool, } pub struct ProvisionUserParams { @@ -19,4 +20,5 @@ pub struct ProvisionUserParams { pub org_id: Option, pub permissions: Option, pub network_ids: Vec, + pub subscribed: bool, } diff --git a/backend/src/server/auth/impl/oidc.rs b/backend/src/server/auth/impl/oidc.rs index 1cbe9826..03efe787 100644 --- a/backend/src/server/auth/impl/oidc.rs +++ b/backend/src/server/auth/impl/oidc.rs @@ -1,15 +1,181 @@ +use anyhow::Result; +use openidconnect::{ + AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, + PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, + core::{CoreClient, CoreProviderMetadata, CoreResponseType}, + reqwest::Client as ReqwestClient, +}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OidcPendingAuth { + pub pkce_verifier: String, + pub nonce: String, + pub csrf_token: String, + pub flow: OidcFlow, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum OidcFlow { + Login, + Register, + Link, +} + +#[derive(Debug, Clone)] pub struct OidcUserInfo { pub subject: String, pub email: Option, pub name: Option, } -#[derive(Debug, Serialize, Deserialize)] -pub struct OidcPendingAuth { - pub pkce_verifier: String, - pub nonce: String, - pub csrf_token: String, +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct OidcProviderConfig { + pub name: String, + pub slug: String, + pub logo: Option, + pub issuer_url: String, + pub client_id: String, + pub client_secret: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OidcProviderMetadata { + pub name: String, + pub slug: String, + pub logo: Option, +} + +impl OidcProviderConfig { + pub fn to_metadata(&self) -> OidcProviderMetadata { + OidcProviderMetadata { + name: self.name.clone(), + slug: self.slug.clone(), + logo: self.logo.clone(), + } + } +} + +/// Individual OIDC provider - just handles protocol operations +pub struct OidcProvider { + pub slug: String, + pub name: String, + pub logo: Option, + issuer_url: String, + client_id: String, + client_secret: String, + redirect_url: String, +} + +impl OidcProvider { + pub fn new( + slug: String, + name: String, + logo: Option, + issuer_url: String, + client_id: String, + client_secret: String, + redirect_url: String, + ) -> Self { + Self { + slug, + name, + logo, + issuer_url, + client_id, + client_secret, + redirect_url, + } + } + + /// Generate authorization URL for user to visit + pub async fn authorize_url(&self, flow: OidcFlow) -> Result<(String, OidcPendingAuth)> { + let http_client = ReqwestClient::builder() + .redirect(reqwest::redirect::Policy::none()) + .build()?; + + let provider_metadata = CoreProviderMetadata::discover_async( + IssuerUrl::new(self.issuer_url.clone())?, + &http_client, + ) + .await?; + + let client = CoreClient::from_provider_metadata( + provider_metadata, + ClientId::new(self.client_id.clone()), + Some(ClientSecret::new(self.client_secret.clone())), + ) + .set_redirect_uri(RedirectUrl::new(self.redirect_url.clone())?); + + let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); + + let (auth_url, csrf_token, nonce) = client + .authorize_url( + AuthenticationFlow::::AuthorizationCode, + CsrfToken::new_random, + Nonce::new_random, + ) + .add_scope(Scope::new("openid".to_string())) + .add_scope(Scope::new("email".to_string())) + .add_scope(Scope::new("profile".to_string())) + .set_pkce_challenge(pkce_challenge) + .url(); + + let pending_auth = OidcPendingAuth { + pkce_verifier: pkce_verifier.secret().clone(), + nonce: nonce.secret().clone(), + csrf_token: csrf_token.secret().clone(), + flow, + }; + + Ok((auth_url.to_string(), pending_auth)) + } + + /// Exchange authorization code for user info + pub async fn exchange_code( + &self, + code: &str, + pending_auth: &OidcPendingAuth, + ) -> Result { + let http_client = ReqwestClient::builder() + .redirect(reqwest::redirect::Policy::none()) + .build()?; + + let provider_metadata = CoreProviderMetadata::discover_async( + IssuerUrl::new(self.issuer_url.clone())?, + &http_client, + ) + .await?; + + let client = CoreClient::from_provider_metadata( + provider_metadata, + ClientId::new(self.client_id.clone()), + Some(ClientSecret::new(self.client_secret.clone())), + ) + .set_redirect_uri(RedirectUrl::new(self.redirect_url.clone())?); + + let pkce_verifier = PkceCodeVerifier::new(pending_auth.pkce_verifier.clone()); + let nonce = Nonce::new(pending_auth.nonce.clone()); + + let token_response = client + .exchange_code(AuthorizationCode::new(code.to_string()))? + .set_pkce_verifier(pkce_verifier) + .request_async(&http_client) + .await?; + + let id_token = token_response + .id_token() + .ok_or_else(|| anyhow::anyhow!("No ID token in response"))?; + + let claims = id_token.claims(&client.id_token_verifier(), &nonce)?; + + Ok(OidcUserInfo { + subject: claims.subject().to_string(), + email: claims.email().map(|e| e.to_string()), + name: claims + .name() + .and_then(|n| n.get(None).map(|s| s.to_string())), + }) + } } diff --git a/backend/src/server/auth/middleware.rs b/backend/src/server/auth/middleware.rs index 3d7fb68b..43c9deda 100644 --- a/backend/src/server/auth/middleware.rs +++ b/backend/src/server/auth/middleware.rs @@ -14,6 +14,7 @@ use axum::{ response::{IntoResponse, Response}, }; use chrono::Utc; +use email_address::EmailAddress; use serde::Deserialize; use serde::Serialize; use tower_sessions::Session; @@ -35,6 +36,7 @@ pub enum AuthenticatedEntity { organization_id: Uuid, permissions: UserOrgPermissions, network_ids: Vec, + email: EmailAddress, }, Daemon { network_id: Uuid, @@ -50,7 +52,15 @@ impl Display for AuthenticatedEntity { AuthenticatedEntity::Anonymous => write!(f, "Anonymous"), AuthenticatedEntity::System => write!(f, "System"), AuthenticatedEntity::Daemon { .. } => write!(f, "Daemon"), - AuthenticatedEntity::User { .. } => write!(f, "User"), + AuthenticatedEntity::User { + user_id, + permissions, + .. + } => write!( + f, + "User {{ user_id: {}, permissions: {} }}", + user_id, permissions + ), } } } @@ -107,6 +117,7 @@ impl From for AuthenticatedEntity { organization_id: value.base.organization_id, permissions: value.base.permissions, network_ids: vec![], + email: value.base.email, } } } @@ -221,6 +232,7 @@ where organization_id: user.base.organization_id, permissions: user.base.permissions, network_ids, + email: user.base.email, }) } } @@ -232,6 +244,7 @@ pub struct AuthenticatedUser { pub organization_id: Uuid, pub permissions: UserOrgPermissions, pub network_ids: Vec, + pub email: EmailAddress, } impl From for AuthenticatedEntity { @@ -241,6 +254,7 @@ impl From for AuthenticatedEntity { organization_id: value.organization_id, permissions: value.permissions, network_ids: value.network_ids, + email: value.email, } } } @@ -260,11 +274,13 @@ where organization_id, permissions, network_ids, + email, } => Ok(AuthenticatedUser { user_id, organization_id, permissions, network_ids, + email, }), _ => Err(AuthError(ApiError::unauthorized( "User authentication required".to_string(), @@ -274,7 +290,7 @@ where } /// Extractor that only accepts authenticated daemons (rejects users) -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Copy)] pub struct AuthenticatedDaemon { pub network_id: Uuid, pub api_key_id: Uuid, diff --git a/backend/src/server/auth/oidc.rs b/backend/src/server/auth/oidc.rs index 29b42110..9a3a6bff 100644 --- a/backend/src/server/auth/oidc.rs +++ b/backend/src/server/auth/oidc.rs @@ -1,20 +1,15 @@ use anyhow::{Error, Result, anyhow}; +use bad_email::is_email_unwanted; use chrono::Utc; use email_address::EmailAddress; -use openidconnect::{ - AuthenticationFlow, AuthorizationCode, ClientId, ClientSecret, CsrfToken, IssuerUrl, Nonce, - PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, TokenResponse, - core::{CoreClient, CoreProviderMetadata, CoreResponseType}, - reqwest::Client as ReqwestClient, -}; -use std::{net::IpAddr, str::FromStr, sync::Arc}; +use std::{collections::HashMap, net::IpAddr, str::FromStr, sync::Arc}; use uuid::Uuid; use crate::server::{ auth::{ r#impl::{ base::{LoginRegisterParams, ProvisionUserParams}, - oidc::{OidcPendingAuth, OidcUserInfo}, + oidc::{OidcPendingAuth, OidcProvider, OidcProviderConfig, OidcProviderMetadata}, }, middleware::AuthenticatedEntity, service::AuthService, @@ -29,214 +24,101 @@ use crate::server::{ users::{r#impl::base::User, service::UserService}, }; -#[derive(Clone)] pub struct OidcService { - pub issuer_url: String, - pub client_id: String, - pub client_secret: String, - pub redirect_url: String, - pub provider_name: String, - pub auth_service: Arc, - pub user_service: Arc, - pub event_bus: Arc, + providers: HashMap>, + auth_service: Arc, + user_service: Arc, + event_bus: Arc, } impl OidcService { - pub fn new(params: OidcService) -> Self { - params - } + pub fn new( + configs: Vec, + public_url: &str, + auth_service: Arc, + user_service: Arc, + event_bus: Arc, + ) -> Self { + let mut providers = HashMap::new(); + + for config in configs { + // Build provider-specific callback URL + let redirect_url = format!("{}/api/auth/oidc/{}/callback", public_url, config.slug); + + let provider = OidcProvider::new( + config.slug.clone(), + config.name.clone(), + config.logo.clone(), + config.issuer_url.clone(), + config.client_id.clone(), + config.client_secret.clone(), + redirect_url, + ); + + providers.insert(config.slug.clone(), Arc::new(provider)); + } - /// Generate authorization URL for user to visit - /// Returns: (auth_url, pending_auth to store in session) - pub async fn authorize_url(&self) -> Result<(String, OidcPendingAuth)> { - let http_client = ReqwestClient::builder() - .redirect(reqwest::redirect::Policy::none()) - .build()?; - - let provider_metadata = CoreProviderMetadata::discover_async( - IssuerUrl::new(self.issuer_url.clone())?, - &http_client, - ) - .await?; - - let client = CoreClient::from_provider_metadata( - provider_metadata, - ClientId::new(self.client_id.clone()), - Some(ClientSecret::new(self.client_secret.clone())), - ) - .set_redirect_uri(RedirectUrl::new(self.redirect_url.clone())?); - - let (pkce_challenge, pkce_verifier) = PkceCodeChallenge::new_random_sha256(); - - let (auth_url, csrf_token, nonce) = client - .authorize_url( - AuthenticationFlow::::AuthorizationCode, - CsrfToken::new_random, - Nonce::new_random, - ) - .add_scope(Scope::new("openid".to_string())) - .add_scope(Scope::new("email".to_string())) - .add_scope(Scope::new("profile".to_string())) - .set_pkce_challenge(pkce_challenge) - .url(); - - let pending_auth = OidcPendingAuth { - pkce_verifier: pkce_verifier.secret().clone(), - nonce: nonce.secret().clone(), - csrf_token: csrf_token.secret().clone(), - }; - - Ok((auth_url.to_string(), pending_auth)) + Self { + providers, + auth_service, + user_service, + event_bus, + } } - /// Exchange authorization code for user info - async fn exchange_code( - &self, - code: &str, - pending_auth: OidcPendingAuth, - ) -> Result { - let http_client = ReqwestClient::builder() - .redirect(reqwest::redirect::Policy::none()) - .build()?; - - let provider_metadata = CoreProviderMetadata::discover_async( - IssuerUrl::new(self.issuer_url.clone())?, - &http_client, - ) - .await?; - - let client = CoreClient::from_provider_metadata( - provider_metadata, - ClientId::new(self.client_id.clone()), - Some(ClientSecret::new(self.client_secret.clone())), - ) - .set_redirect_uri(RedirectUrl::new(self.redirect_url.clone())?); - - let pkce_verifier = PkceCodeVerifier::new(pending_auth.pkce_verifier); - let nonce = Nonce::new(pending_auth.nonce); - - let token_response = client - .exchange_code(AuthorizationCode::new(code.to_string()))? - .set_pkce_verifier(pkce_verifier) - .request_async(&http_client) - .await?; - - let id_token = token_response - .id_token() - .ok_or_else(|| anyhow::anyhow!("No ID token in response"))?; - - let claims = id_token.claims(&client.id_token_verifier(), &nonce)?; - - Ok(OidcUserInfo { - subject: claims.subject().to_string(), - email: claims.email().map(|e| e.to_string()), - name: claims - .name() - .and_then(|n| n.get(None).map(|s| s.to_string())), - }) + pub fn get_provider(&self, slug: &str) -> Option<&Arc> { + self.providers.get(slug) } - /// Link OIDC account to existing user - pub async fn link_to_user( - &self, - user_id: &Uuid, - code: &str, - pending_auth: OidcPendingAuth, - ip: IpAddr, - user_agent: Option, - ) -> Result { - let user_info = self.exchange_code(code, pending_auth).await?; - - // Check if this OIDC account is already linked to another user - if let Some(existing_user) = self - .auth_service - .user_service - .get_user_by_oidc(&user_info.subject) - .await? - { - if existing_user.id != *user_id { - return Err(anyhow!( - "This OIDC account is already linked to another user" - )); - } - // Already linked to this user - return Ok(existing_user); - } - - let mut user = self - .user_service - .get_by_id(user_id) - .await? - .ok_or_else(|| anyhow::anyhow!("User not found"))?; - - user.base.oidc_provider = Some(self.provider_name.clone()); - user.base.oidc_subject = Some(user_info.subject); - user.base.oidc_linked_at = Some(chrono::Utc::now()); - - let authentication: AuthenticatedEntity = user.clone().into(); - - self.event_bus - .publish_auth(AuthEvent { - id: Uuid::new_v4(), - user_id: Some(user.id), - organization_id: Some(user.base.organization_id), - timestamp: Utc::now(), - operation: AuthOperation::OidcLinked, - ip_address: ip, - user_agent, - metadata: serde_json::json!({ - "method": "oidc", - "provider": self.provider_name - }), - authentication: authentication.clone(), + pub fn list_providers(&self) -> Vec { + self.providers + .values() + .map(|provider| OidcProviderMetadata { + name: provider.name.clone(), + slug: provider.slug.clone(), + logo: provider.logo.clone(), }) - .await?; + .collect() + } - self.user_service.update(&mut user, authentication).await + pub fn is_empty(&self) -> bool { + self.providers.is_empty() } - /// Login or register user via OIDC - pub async fn login_or_register( + /// Register new user via OIDC (fails if account already exists) + pub async fn register( &self, + provider_slug: &str, code: &str, pending_auth: OidcPendingAuth, params: LoginRegisterParams, ) -> Result { + let provider = self + .get_provider(provider_slug) + .ok_or_else(|| anyhow!("Provider '{}' not found", provider_slug))?; + let LoginRegisterParams { org_id, permissions, ip, user_agent, network_ids, + subscribed, } = params; - let user_info = self.exchange_code(code, pending_auth).await?; + // Exchange code for user info using provider + let user_info = provider.exchange_code(code, &pending_auth).await?; - // Check if user exists with this OIDC account, login if so - if let Some(user) = self - .auth_service + // Check if user already exists with this OIDC account + if let Some(_existing_user) = self .user_service .get_user_by_oidc(&user_info.subject) .await? { - self.event_bus - .publish_auth(AuthEvent { - id: Uuid::new_v4(), - user_id: Some(user.id), - organization_id: Some(user.base.organization_id), - timestamp: Utc::now(), - operation: AuthOperation::LoginSuccess, - ip_address: ip, - user_agent, - metadata: serde_json::json!({ - "method": "oidc", - "provider": self.provider_name - }), - authentication: user.clone().into(), - }) - .await?; - - return Ok(user); + return Err(anyhow!( + "An account with this {} login already exists. Please use the login flow instead.", + provider.name + )); } // Parse or create fallback email @@ -250,6 +132,12 @@ impl OidcService { Ok::(EmailAddress::new_unchecked(fallback_email_str)) })?; + if is_email_unwanted(email.as_str()) { + return Err(anyhow!( + "Email address uses a disposable domain. Please register with a non-disposable email address." + )); + } + // Register new user let user = self .auth_service @@ -257,13 +145,15 @@ impl OidcService { email, password_hash: None, oidc_subject: Some(user_info.subject), - oidc_provider: Some(self.provider_name.clone()), + oidc_provider: Some(provider.slug.clone()), org_id, permissions, network_ids, + subscribed, }) .await?; + // Publish event self.event_bus .publish_auth(AuthEvent { id: Uuid::new_v4(), @@ -275,7 +165,8 @@ impl OidcService { user_agent, metadata: serde_json::json!({ "method": "oidc", - "provider": self.provider_name + "provider": provider.slug, + "provider_name": provider.name }), authentication: user.clone().into(), }) @@ -284,13 +175,153 @@ impl OidcService { Ok(user) } + /// Login existing user via OIDC (fails if account doesn't exist) + pub async fn login( + &self, + provider_slug: &str, + code: &str, + pending_auth: OidcPendingAuth, + ip: IpAddr, + user_agent: Option, + ) -> Result { + let provider = self + .get_provider(provider_slug) + .ok_or_else(|| anyhow!("Provider '{}' not found", provider_slug))?; + + // Exchange code for user info using provider + let user_info = provider.exchange_code(code, &pending_auth).await?; + + // Check if user exists with this OIDC account + let user = self + .user_service + .get_user_by_oidc(&user_info.subject) + .await? + .ok_or_else(|| { + anyhow!( + "No account found with this {} login. Please register first.", + provider.name + ) + })?; + + // Publish event + self.event_bus + .publish_auth(AuthEvent { + id: Uuid::new_v4(), + user_id: Some(user.id), + organization_id: Some(user.base.organization_id), + timestamp: Utc::now(), + operation: AuthOperation::LoginSuccess, + ip_address: ip, + user_agent, + metadata: serde_json::json!({ + "method": "oidc", + "provider": provider.slug, + "provider_name": provider.name + }), + authentication: user.clone().into(), + }) + .await?; + + Ok(user) + } + + /// Link OIDC account to existing user + pub async fn link_to_user( + &self, + provider_slug: &str, + user_id: &Uuid, + code: &str, + pending_auth: OidcPendingAuth, + ip: IpAddr, + user_agent: Option, + ) -> Result { + let provider = self + .get_provider(provider_slug) + .ok_or_else(|| anyhow!("Provider '{}' not found", provider_slug))?; + + // Exchange code for user info using provider + let user_info = provider.exchange_code(code, &pending_auth).await?; + + // Check if this OIDC account is already linked to another user + if let Some(existing_user) = self + .user_service + .get_user_by_oidc(&user_info.subject) + .await? + { + if existing_user.id != *user_id { + return Err(anyhow!( + "This {} account is already linked to another user", + provider.name + )); + } + // Already linked to this user + return Ok(existing_user); + } + + // Get and update user + let mut user = self + .user_service + .get_by_id(user_id) + .await? + .ok_or_else(|| anyhow::anyhow!("User not found"))?; + + // ERROR if user already has a different OIDC provider linked + if let Some(existing_provider) = &user.base.oidc_provider + && existing_provider != provider_slug + { + let existing_provider_name = self + .get_provider(existing_provider) + .map(|p| p.name.as_str()) + .unwrap_or(existing_provider); + + return Err(anyhow!( + "You already have {} linked to your account. Please unlink it first before linking {}.", + existing_provider_name, + provider.name + )); + } + + user.base.oidc_provider = Some(provider.slug.clone()); + user.base.oidc_subject = Some(user_info.subject); + user.base.oidc_linked_at = Some(chrono::Utc::now()); + + let authentication: AuthenticatedEntity = user.clone().into(); + + // Publish event + self.event_bus + .publish_auth(AuthEvent { + id: Uuid::new_v4(), + user_id: Some(user.id), + organization_id: Some(user.base.organization_id), + timestamp: Utc::now(), + operation: AuthOperation::OidcLinked, + ip_address: ip, + user_agent, + metadata: serde_json::json!({ + "method": "oidc", + "provider": provider.slug, + "provider_name": provider.name + }), + authentication: authentication.clone(), + }) + .await?; + + self.user_service.update(&mut user, authentication).await + } + /// Unlink OIDC from user pub async fn unlink_from_user( &self, + provider_slug: &str, user_id: &Uuid, ip: IpAddr, user_agent: Option, ) -> Result { + let provider = self + .get_provider(provider_slug) + .ok_or_else(|| anyhow!("Provider '{}' not found", provider_slug))?; + + // Get user let mut user = self .user_service .get_by_id(user_id) @@ -311,6 +342,7 @@ impl OidcService { let authentication: AuthenticatedEntity = user.clone().into(); + // Publish event self.event_bus .publish_auth(AuthEvent { id: Uuid::new_v4(), @@ -322,7 +354,8 @@ impl OidcService { user_agent, metadata: serde_json::json!({ "method": "oidc", - "provider": self.provider_name + "provider": provider.slug, + "provider_name": provider.name }), authentication: authentication.clone(), }) diff --git a/backend/src/server/auth/service.rs b/backend/src/server/auth/service.rs index 941488af..b7041a71 100644 --- a/backend/src/server/auth/service.rs +++ b/backend/src/server/auth/service.rs @@ -1,3 +1,4 @@ +use crate::server::shared::events::types::TelemetryOperation; use crate::server::{ auth::{ r#impl::{ @@ -6,7 +7,7 @@ use crate::server::{ }, middleware::{AuthenticatedEntity, AuthenticatedUser}, }, - email::service::EmailService, + email::traits::EmailService, organizations::{ r#impl::base::{Organization, OrganizationBase}, service::OrganizationService, @@ -14,7 +15,7 @@ use crate::server::{ shared::{ events::{ bus::EventBus, - types::{AuthEvent, AuthOperation}, + types::{AuthEvent, AuthOperation, TelemetryEvent}, }, services::traits::CrudService, storage::{filter::EntityFilter, traits::StorableEntity}, @@ -80,6 +81,7 @@ impl AuthService { ip, user_agent, network_ids, + subscribed, } = params; request @@ -106,6 +108,7 @@ impl AuthService { org_id, permissions, network_ids, + subscribed, }) .await?; @@ -138,6 +141,7 @@ impl AuthService { org_id, permissions, network_ids, + subscribed, } = params; let all_users = self @@ -151,7 +155,9 @@ impl AuthService { .find(|u| u.base.password_hash.is_none() && u.base.oidc_subject.is_none()) .cloned(); - if let Some(mut seed_user) = seed_user { + let mut is_new_org = false; + + let user = if let Some(mut seed_user) = seed_user { // First user ever - claim seed user tracing::info!("First user registration - claiming seed user"); seed_user.base.email = email; @@ -166,6 +172,8 @@ impl AuthService { seed_user.base.oidc_linked_at = Some(chrono::Utc::now()); } + is_new_org = true; + self.user_service .update(&mut seed_user, AuthenticatedEntity::System) .await @@ -174,6 +182,8 @@ impl AuthService { let organization_id = if let Some(org_id) = org_id { org_id } else { + is_new_org = true; + // Create new organization for this user let organization = self .organization_service @@ -183,7 +193,7 @@ impl AuthService { name: "My Organization".to_string(), plan: None, plan_status: None, - is_onboarded: false, + onboarding: vec![], }), AuthenticatedEntity::System, ) @@ -225,7 +235,24 @@ impl AuthService { } else { Err(anyhow!("Must provide either password or OIDC credentials")) } + }; + + if is_new_org && let Ok(user) = &user { + self.event_bus + .publish_telemetry(TelemetryEvent { + id: Uuid::new_v4(), + authentication: user.clone().into(), + organization_id: user.base.organization_id, + operation: TelemetryOperation::OrgCreated, + timestamp: Utc::now(), + metadata: serde_json::json!({ + "subscribed": subscribed + }), + }) + .await?; } + + user } /// Login with username and password @@ -434,14 +461,7 @@ impl AuthService { tokens.insert(token.clone(), (user.id, Instant::now())); email_service - .send_email( - user.base.email.clone(), - "NetVisor Password Reset", - &format!( - "Click here to reset your password", - url, token - ), - ) + .send_password_reset(user.base.email.clone(), url, token) .await?; Ok(()) diff --git a/backend/src/server/billing/service.rs b/backend/src/server/billing/service.rs index ea78408a..c711df34 100644 --- a/backend/src/server/billing/service.rs +++ b/backend/src/server/billing/service.rs @@ -4,6 +4,9 @@ use crate::server::billing::types::features::Feature; use crate::server::networks::service::NetworkService; use crate::server::organizations::r#impl::base::Organization; use crate::server::organizations::service::OrganizationService; +use crate::server::shared::events::bus::EventBus; +use crate::server::shared::events::types::TelemetryEvent; +use crate::server::shared::events::types::TelemetryOperation; use crate::server::shared::services::traits::CrudService; use crate::server::shared::storage::filter::EntityFilter; use crate::server::shared::types::metadata::TypeMetadataProvider; @@ -11,6 +14,7 @@ use crate::server::users::r#impl::permissions::UserOrgPermissions; use crate::server::users::service::UserService; use anyhow::Error; use anyhow::anyhow; +use chrono::Utc; use std::sync::Arc; use std::sync::OnceLock; use stripe::Client; @@ -48,6 +52,7 @@ pub struct BillingService { pub user_service: Arc, pub network_service: Arc, pub plans: OnceLock>, + pub event_bus: Arc, } const SEAT_PRODUCT_ID: &str = "extra_seats"; @@ -62,6 +67,7 @@ impl BillingService { organization_service: Arc, user_service: Arc, network_service: Arc, + event_bus: Arc, ) -> Self { Self { stripe: Client::new(stripe_secret), @@ -70,6 +76,7 @@ impl BillingService { network_service, user_service, plans: OnceLock::new(), + event_bus, } } @@ -299,7 +306,7 @@ impl BillingService { authentication: AuthenticatedEntity, ) -> Result { // Get or create Stripe customer - let customer_id = self + let (_, customer_id) = self .get_or_create_customer(organization_id, authentication) .await?; @@ -509,7 +516,7 @@ impl BillingService { &self, organization_id: Uuid, authentication: AuthenticatedEntity, - ) -> Result { + ) -> Result<(Organization, CustomerId), Error> { // Check if org already has stripe_customer_id let mut organization = self .organization_service @@ -517,8 +524,8 @@ impl BillingService { .await? .ok_or_else(|| anyhow!("Organization {} doesn't exist.", organization_id))?; - if let Some(customer_id) = organization.base.stripe_customer_id { - return Ok(CustomerId::from(customer_id)); + if let Some(customer_id) = organization.base.stripe_customer_id.clone() { + return Ok((organization, CustomerId::from(customer_id.to_owned()))); } let organization_owners = self @@ -550,7 +557,7 @@ impl BillingService { .update(&mut organization, authentication) .await?; - Ok(customer.id) + Ok((organization, customer.id)) } /// Handle webhook events @@ -633,31 +640,64 @@ impl BillingService { .await? .ok_or_else(|| anyhow!("Could not find organization to update subscriptions status"))?; - // Update enabled features to match new plan - // if let Some(included_networks) = plan.config().included_networks { - // let networks = self - // .network_service - // .get_all(EntityFilter::unfiltered().organization_id(&org_id)) - // .await?; - // let keep_ids = networks - // .iter() - // .take(included_networks) - // .map(|n| n.id) - // .collect::>(); - - // for network in networks { - // if !keep_ids.contains(&network.id) { - // self.network_service - // .delete(&network.id, AuthenticatedEntity::System) - // .await?; - // tracing::info!( - // organization_id = %org_id, - // network_id = %network.id, - // "Deleted network due to plan downgrade" - // ); - // } - // } - // } + let owners = self + .user_service + .get_organization_owners(&organization.id) + .await?; + + // First time signing up for a plan + if let Some(owner) = owners.first() + && organization.base.plan.is_none() + && organization.not_onboarded(&TelemetryOperation::CommercialPlanSelected) + && organization.not_onboarded(&TelemetryOperation::PersonalPlanSelected) + { + let operation = if plan.is_commercial() { + TelemetryOperation::CommercialPlanSelected + } else { + TelemetryOperation::PersonalPlanSelected + }; + + self.event_bus + .publish_telemetry(TelemetryEvent { + id: Uuid::new_v4(), + authentication: owner.clone().into(), + organization_id: organization.id, + operation, + timestamp: Utc::now(), + metadata: serde_json::json!({ + "is_onboarding_step": true + }), + }) + .await?; + } + + // If they can't pay for networks, remove them + if let Some(included_networks) = plan.config().included_networks + && plan.config().network_cents.is_none() + { + let networks = self + .network_service + .get_all(EntityFilter::unfiltered().organization_id(&org_id)) + .await?; + let keep_ids = networks + .iter() + .take(included_networks.try_into().unwrap_or(3)) + .map(|n| n.id) + .collect::>(); + + for network in networks { + if !keep_ids.contains(&network.id) { + self.network_service + .delete(&network.id, AuthenticatedEntity::System) + .await?; + tracing::info!( + organization_id = %org_id, + network_id = %network.id, + "Deleted network due to plan downgrade" + ); + } + } + } match plan { BillingPlan::Community { .. } => {} diff --git a/backend/src/server/billing/subscriber.rs b/backend/src/server/billing/subscriber.rs index 176255db..c38148a8 100644 --- a/backend/src/server/billing/subscriber.rs +++ b/backend/src/server/billing/subscriber.rs @@ -81,10 +81,6 @@ impl EventSubscriber for BillingService { Ok(()) } - fn debounce_window_ms(&self) -> u64 { - 50 // Small window to batch multiple subnet deletions - } - fn name(&self) -> &str { "billing_quota_update" } diff --git a/backend/src/server/billing/types/base.rs b/backend/src/server/billing/types/base.rs index 6ff69135..3394215f 100644 --- a/backend/src/server/billing/types/base.rs +++ b/backend/src/server/billing/types/base.rs @@ -19,6 +19,7 @@ use strum::{Display, EnumDiscriminants, EnumIter, IntoStaticStr}; EnumDiscriminants, Eq, )] +#[strum_discriminants(derive(IntoStaticStr, Serialize))] #[serde(tag = "type")] pub enum BillingPlan { Community(PlanConfig), diff --git a/backend/src/server/config.rs b/backend/src/server/config.rs index f804f8fa..fe6e4288 100644 --- a/backend/src/server/config.rs +++ b/backend/src/server/config.rs @@ -1,110 +1,125 @@ -use crate::server::shared::services::factory::ServiceFactory; +use crate::server::auth::r#impl::oidc::OidcProviderMetadata; +use crate::server::{ + auth::r#impl::oidc::OidcProviderConfig, shared::services::factory::ServiceFactory, +}; use anyhow::{Error, Result}; +use clap::Parser; use figment::{ Figment, - providers::{Env, Serialized}, + providers::{Env, Format, Serialized, Toml}, }; use serde::{Deserialize, Serialize}; use std::{path::PathBuf, sync::Arc}; use crate::server::shared::storage::factory::StorageFactory; -/// CLI arguments structure (for figment integration) -#[derive(Debug)] -pub struct CliArgs { - pub server_port: Option, - pub log_level: Option, - pub rust_log: Option, - pub database_url: Option, - pub integrated_daemon_url: Option, - pub use_secure_session_cookies: Option, - pub disable_registration: bool, - pub oidc_issuer_url: Option, - pub oidc_client_id: Option, - pub oidc_client_secret: Option, - pub oidc_redirect_url: Option, - pub oidc_provider_name: Option, - pub stripe_secret: Option, - pub stripe_webhook_secret: Option, - pub smtp_username: Option, - pub smtp_password: Option, - pub smtp_relay: Option, - pub smtp_email: Option, - pub public_url: Option, -} +#[derive(Parser)] +#[command(name = "netvisor-server")] +#[command(about = "NetVisor server")] +pub struct ServerCli { + /// Override server port + #[arg(long)] + server_port: Option, -/// Flattened server configuration struct -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ServerConfig { - // Server settings - /// What port the server should listen on - pub server_port: u16, + /// Override log level + #[arg(long)] + log_level: Option, - /// Level of logs to show - pub log_level: String, + /// Override rust system log level + #[arg(long)] + rust_log: Option, - /// Rust log level - pub rust_log: String, + /// Override database path + #[arg(long)] + database_url: Option, - /// Where database should be located - pub database_url: String, + /// Override integrated daemon url + #[arg(long)] + integrated_daemon_url: Option, - /// Where static web assets are located for serving - pub web_external_path: Option, + /// Use secure session cookies (if serving UI behind HTTPS) + #[arg(long)] + use_secure_session_cookies: Option, - /// Public URL for server for email links, webhooks, etc - pub public_url: String, + /// Enable or disable registration flow + #[arg(long)] + disable_registration: bool, - /// URL for daemon running in same docker stack or in other local context - pub integrated_daemon_url: Option, + /// OIDC redirect url + #[arg(long)] + stripe_secret: Option, - /// Use secure with issued session cookies - pub use_secure_session_cookies: bool, + /// OIDC redirect url + #[arg(long)] + stripe_webhook_secret: Option, - /// Disable user registration endpoint - pub disable_registration: bool, + #[arg(long)] + smtp_username: Option, - /// OIDC issuer URL - pub oidc_issuer_url: Option, + #[arg(long)] + smtp_password: Option, - /// OIDC client ID - pub oidc_client_id: Option, + /// Email used as to/from in emails send by NetVisor using SMTP + #[arg(long)] + smtp_email: Option, - /// OIDC client secret - pub oidc_client_secret: Option, + #[arg(long)] + smtp_relay: Option, - /// OIDC redirect url - pub oidc_redirect_url: Option, + #[arg(long)] + smtp_port: Option, - /// OIDC redirect url - pub oidc_provider_name: Option, + /// Server URL used in features like password reset and invite links + #[arg(long)] + public_url: Option, - /// Stripe key - pub stripe_key: Option, + #[arg(long)] + pub plunk_api_key: Option, - /// Stripe Secret - pub stripe_secret: Option, + /// Configure what proxy (if any) is providing IP address for requests, ie in a reverse proxy setup, for accurate IP in auth event logging + #[arg(long)] + pub client_ip_source: Option, - pub stripe_webhook_secret: Option, + /// List of OIDC providers + #[arg(long)] + pub oidc_providers: Option, +} +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ServerConfig { + pub server_port: u16, + pub log_level: String, + pub rust_log: String, + pub database_url: String, + pub web_external_path: Option, + pub public_url: String, + pub integrated_daemon_url: Option, + pub use_secure_session_cookies: bool, + pub disable_registration: bool, + pub client_ip_source: Option, pub smtp_username: Option, - pub smtp_password: Option, - pub smtp_relay: Option, - pub smtp_email: Option, + #[serde(default)] + pub oidc_providers: Option>, + + // Used in SaaS deployment + pub plunk_api_key: Option, + pub stripe_key: Option, + pub stripe_secret: Option, + pub stripe_webhook_secret: Option, } #[derive(Debug, Serialize, Deserialize)] pub struct PublicConfigResponse { pub server_port: u16, pub disable_registration: bool, - pub oidc_enabled: bool, - pub oidc_provider_name: String, + pub oidc_providers: Vec, pub billing_enabled: bool, pub has_integrated_daemon: bool, pub has_email_service: bool, + pub has_email_opt_in: bool, pub public_url: String, } @@ -120,11 +135,6 @@ impl Default for ServerConfig { use_secure_session_cookies: false, integrated_daemon_url: None, disable_registration: false, - oidc_client_id: None, - oidc_client_secret: None, - oidc_issuer_url: None, - oidc_redirect_url: None, - oidc_provider_name: None, stripe_key: None, stripe_secret: None, stripe_webhook_secret: None, @@ -132,17 +142,19 @@ impl Default for ServerConfig { smtp_password: None, smtp_email: None, smtp_relay: None, + plunk_api_key: None, + client_ip_source: None, + oidc_providers: None, } } } impl ServerConfig { - pub fn load(cli_args: CliArgs) -> anyhow::Result { + pub fn load(cli_args: ServerCli) -> anyhow::Result { // Standard configuration layering: Defaults → Env → CLI (highest priority) - let mut figment = Figment::from(Serialized::defaults(ServerConfig::default())); - - // Add environment variables with NETVISOR_ prefix - figment = figment.merge(Env::prefixed("NETVISOR_")); + let mut figment = Figment::from(Serialized::defaults(ServerConfig::default())) + .merge(Toml::file("../oidc.toml")) + .merge(Env::prefixed("NETVISOR_")); // Add CLI overrides (highest priority) - only if explicitly provided if let Some(server_port) = cli_args.server_port { @@ -163,21 +175,6 @@ impl ServerConfig { if let Some(use_secure_session_cookies) = cli_args.use_secure_session_cookies { figment = figment.merge(("use_secure_session_cookies", use_secure_session_cookies)); } - if let Some(oidc_issuer_url) = cli_args.oidc_issuer_url { - figment = figment.merge(("oidc_issuer_url", oidc_issuer_url)); - } - if let Some(oidc_client_id) = cli_args.oidc_client_id { - figment = figment.merge(("oidc_client_id", oidc_client_id)); - } - if let Some(oidc_client_secret) = cli_args.oidc_client_secret { - figment = figment.merge(("oidc_client_secret", oidc_client_secret)); - } - if let Some(oidc_redirect_url) = cli_args.oidc_redirect_url { - figment = figment.merge(("oidc_redirect_url", oidc_redirect_url)); - } - if let Some(oidc_provider_name) = cli_args.oidc_provider_name { - figment = figment.merge(("oidc_provider_name", oidc_provider_name)); - } if let Some(stripe_secret) = cli_args.stripe_secret { figment = figment.merge(("stripe_secret", stripe_secret)); } @@ -199,6 +196,15 @@ impl ServerConfig { if let Some(public_url) = cli_args.public_url { figment = figment.merge(("public_url", public_url)); } + if let Some(plunk_api_key) = cli_args.plunk_api_key { + figment = figment.merge(("plunk_api_key", plunk_api_key)); + } + if let Some(client_ip_source) = cli_args.client_ip_source { + figment = figment.merge(("client_ip_source", client_ip_source)); + } + if let Some(oidc_providers) = cli_args.oidc_providers { + figment = figment.merge(("oidc_providers", oidc_providers)); + } figment = figment.merge(("disable_registration", cli_args.disable_registration)); diff --git a/backend/src/server/daemons/handlers.rs b/backend/src/server/daemons/handlers.rs index f25f5dd3..8b0bd60f 100644 --- a/backend/src/server/daemons/handlers.rs +++ b/backend/src/server/daemons/handlers.rs @@ -1,3 +1,4 @@ +use crate::server::shared::events::types::TelemetryOperation; use crate::server::{ auth::middleware::{AuthenticatedDaemon, AuthenticatedEntity}, config::AppState, @@ -14,11 +15,12 @@ use crate::server::{ }, hosts::r#impl::base::{Host, HostBase}, shared::{ + events::types::TelemetryEvent, handlers::traits::{ bulk_delete_handler, create_handler, delete_handler, get_all_handler, get_by_id_handler, update_handler, }, - services::traits::CrudService, + services::traits::{CrudService, EventBusService}, storage::traits::StorableEntity, types::api::{ApiError, ApiResponse, ApiResult}, }, @@ -65,7 +67,7 @@ async fn register_daemon( let (host, _) = state .services .host_service - .create_host_with_services(dummy_host, Vec::new(), auth_daemon.clone().into()) + .create_host_with_services(dummy_host, Vec::new(), auth_daemon.into()) .await?; let mut daemon = Daemon::new(DaemonBase { @@ -85,6 +87,39 @@ async fn register_daemon( .await .map_err(|e| ApiError::internal_error(&format!("Failed to register daemon: {}", e)))?; + let org_id = state + .services + .network_service + .get_by_id(&request.network_id) + .await? + .map(|n| n.base.organization_id) + .unwrap_or_default(); + let organization = state + .services + .organization_service + .get_by_id(&org_id) + .await?; + + if let Some(organization) = organization + && organization.not_onboarded(&TelemetryOperation::FirstDaemonRegistered) + { + state + .services + .daemon_service + .event_bus() + .publish_telemetry(TelemetryEvent { + id: Uuid::new_v4(), + authentication: auth_daemon.into(), + organization_id: organization.id, + operation: TelemetryOperation::FirstDaemonRegistered, + timestamp: Utc::now(), + metadata: serde_json::json!({ + "is_onboarding_step": true + }), + }) + .await?; + } + let discovery_service = state.services.discovery_service.clone(); let self_report_discovery = discovery_service @@ -231,7 +266,7 @@ async fn receive_work_request( daemon.base.last_seen = Utc::now(); service - .update(&mut daemon, auth_daemon.clone().into()) + .update(&mut daemon, auth_daemon.into()) .await .map_err(|e| ApiError::internal_error(&format!("Failed to update heartbeat: {}", e)))?; diff --git a/backend/src/server/daemons/impl/base.rs b/backend/src/server/daemons/impl/base.rs index 507032c3..3e11211e 100644 --- a/backend/src/server/daemons/impl/base.rs +++ b/backend/src/server/daemons/impl/base.rs @@ -31,6 +31,17 @@ pub struct Daemon { pub base: DaemonBase, } +impl Daemon { + pub fn suppress_logs(&self, other: &Self) -> bool { + self.base.capabilities == other.base.capabilities + && self.base.mode == other.base.mode + && self.base.ip == other.base.ip + && self.base.port == other.base.port + && self.base.network_id == other.base.network_id + && self.base.host_id == other.base.host_id + } +} + impl Display for Daemon { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}: {}", self.base.ip, self.id) diff --git a/backend/src/server/daemons/service.rs b/backend/src/server/daemons/service.rs index e51ed59c..e068cb05 100644 --- a/backend/src/server/daemons/service.rs +++ b/backend/src/server/daemons/service.rs @@ -9,16 +9,21 @@ use crate::{ hosts::r#impl::ports::PortBase, services::r#impl::endpoints::{ApplicationProtocol, Endpoint}, shared::{ + entities::ChangeTriggersTopologyStaleness, events::{ bus::EventBus, types::{EntityEvent, EntityOperation}, }, services::traits::{CrudService, EventBusService}, - storage::generic::GenericPostgresStorage, + storage::{ + generic::GenericPostgresStorage, + traits::{StorableEntity, Storage}, + }, types::api::ApiResponse, }, }, }; +use anyhow::anyhow; use anyhow::{Error, Result}; use async_trait::async_trait; use chrono::Utc; @@ -49,6 +54,41 @@ impl CrudService for DaemonService { fn storage(&self) -> &Arc> { &self.daemon_storage } + + /// Update entity + async fn update( + &self, + entity: &mut Daemon, + authentication: AuthenticatedEntity, + ) -> Result { + let current = self + .get_by_id(&entity.id()) + .await? + .ok_or_else(|| anyhow!("Could not find {}", entity))?; + let updated = self.storage().update(entity).await?; + + let suppress_logs = updated.suppress_logs(¤t); + let trigger_stale = updated.triggers_staleness(Some(current)); + + self.event_bus() + .publish_entity(EntityEvent { + id: Uuid::new_v4(), + entity_id: updated.id(), + network_id: self.get_network_id(&updated), + organization_id: self.get_organization_id(&updated), + entity_type: updated.clone().into(), + operation: EntityOperation::Updated, + timestamp: Utc::now(), + metadata: serde_json::json!({ + "trigger_stale": trigger_stale, + "suppress_logs": suppress_logs + }), + authentication, + }) + .await?; + + Ok(updated) + } } impl DaemonService { diff --git a/backend/src/server/discovery/service.rs b/backend/src/server/discovery/service.rs index 3703d5ab..8a64ad7e 100644 --- a/backend/src/server/discovery/service.rs +++ b/backend/src/server/discovery/service.rs @@ -188,8 +188,17 @@ impl DiscoveryService { .await? .ok_or_else(|| anyhow::anyhow!("Could not find discovery {}", discovery))?; - // If it's a scheduled discovery, need to reschedule - let updated = if matches!(discovery.base.run_type, RunType::Scheduled { .. }) { + // If it's a scheduled discovery and schedule has changed, need to reschedule + let updated = if let RunType::Scheduled { + cron_schedule: new_cron, + .. + } = &discovery.base.run_type + && let RunType::Scheduled { + cron_schedule: current_cron, + .. + } = ¤t.base.run_type + && current_cron != new_cron + { // Remove old schedule first if let Some(scheduler) = &self.scheduler { let _ = scheduler.write().await.remove(&discovery.id).await; @@ -214,9 +223,7 @@ impl DiscoveryService { updated } else { // For non-scheduled, just update - let updated = self.discovery_storage.update(&mut discovery).await?; - tracing::info!("Updated discovery {}: {}", updated.base.name, updated.id); - updated + self.discovery_storage.update(&mut discovery).await? }; let trigger_stale = updated.triggers_staleness(Some(current)); @@ -394,7 +401,7 @@ impl DiscoveryService { let job_id = scheduler.write().await.add(job).await?; - tracing::info!( + tracing::debug!( "Scheduled discovery {} with cron: {}", discovery_id, cron_schedule diff --git a/backend/src/server/email/mod.rs b/backend/src/server/email/mod.rs index 8cf2a986..a293e8e4 100644 --- a/backend/src/server/email/mod.rs +++ b/backend/src/server/email/mod.rs @@ -1,2 +1,6 @@ -pub mod service; +pub mod plunk; +pub mod smtp; +pub mod subscriber; +pub mod templates; +pub mod traits; pub mod types; diff --git a/backend/src/server/email/plunk.rs b/backend/src/server/email/plunk.rs new file mode 100644 index 00000000..8e059150 --- /dev/null +++ b/backend/src/server/email/plunk.rs @@ -0,0 +1,144 @@ +use crate::server::email::{templates::PASSWORD_RESET_TITLE, traits::EmailProvider}; +use anyhow::Error; +use anyhow::anyhow; +use async_trait::async_trait; +use email_address::EmailAddress; +// use plunk::{PlunkClient, PlunkClientTrait, PlunkPayloads}; +use reqwest::Client; +use serde_json::Value; +use serde_json::json; + +/// Plunk-based email provider +pub struct PlunkEmailProvider { + api_key: String, + client: Client, +} + +impl PlunkEmailProvider { + pub fn new(api_key: String) -> Self { + Self { + api_key, + client: Client::new(), + } + } + + pub async fn send_transactional_email( + &self, + to: EmailAddress, + subject: String, + body: String, + ) -> Result<(), Error> { + let url = "https://api.useplunk.com/v1/send"; + let payload = json!({ + "to": to.to_string(), + "subject": subject, + "body": body, + "name": "NetVisor", + "from": "no-reply@email.netvisor.io", + "reply": "no-reply@email.netvisor.io" + }); + + let response = self + .client + .post(url) + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&payload) + .send() + .await?; + + if response.status().is_success() { + Ok(()) + } else { + Err(anyhow!( + "Failed to send email via Plunk: {}", + response.text().await? + )) + } + } +} + +#[async_trait] +impl EmailProvider for PlunkEmailProvider { + async fn send_password_reset( + &self, + to: EmailAddress, + url: String, + token: String, + ) -> Result<(), Error> { + self.send_transactional_email( + to, + PASSWORD_RESET_TITLE.to_string(), + self.build_password_reset_email(url, token), + ) + .await + .map_err(|e| anyhow!("{}", e)) + .map(|_| ()) + } + + /// Send an invite via email + async fn send_invite( + &self, + to: EmailAddress, + from: EmailAddress, + url: String, + ) -> Result<(), Error> { + self.send_transactional_email( + to, + self.build_invite_title(from.clone()), + self.build_invite_email(url, from), + ) + .await + .map_err(|e| anyhow!("{}", e)) + .map(|_| ()) + } + + async fn track_event( + &self, + event: String, + email: EmailAddress, + subscribed: bool, + data: Value, + ) -> Result<(), Error> { + // Convert all values in the object to strings + let normalized_data = if let Value::Object(map) = data { + let stringified: serde_json::Map = map + .into_iter() + .map(|(k, v)| { + let string_value = match v { + Value::String(s) => Value::String(s), + other => Value::String(serde_json::to_string(&other).unwrap_or_default()), + }; + (k, string_value) + }) + .collect(); + Value::Object(stringified) + } else { + serde_json::json!({}) + }; + + let body = serde_json::json!({ + "event": event, + "email": email.to_string(), + "subscribed": subscribed, + "data": normalized_data + }); + + let response = self + .client + .post("https://api.useplunk.com/v1/track") + .header("Content-Type", "application/json") + .header("Authorization", format!("Bearer {}", self.api_key)) + .json(&body) + .send() + .await?; + + if response.status().is_success() { + Ok(()) + } else { + Err(anyhow!( + "Failed to track Plunk event: {}", + response.text().await? + )) + } + } +} diff --git a/backend/src/server/email/service.rs b/backend/src/server/email/smtp.rs similarity index 55% rename from backend/src/server/email/service.rs rename to backend/src/server/email/smtp.rs index 448b3e2a..7ff52712 100644 --- a/backend/src/server/email/service.rs +++ b/backend/src/server/email/smtp.rs @@ -1,24 +1,30 @@ -use anyhow::{Result, anyhow}; -use email_address::EmailAddress; use lettre::{ AsyncSmtpTransport, AsyncTransport, Tokio1Executor, message::{Mailbox, MultiPart, SinglePart}, transport::smtp::authentication::Credentials, }; -#[derive(Clone)] -pub struct EmailService { +use anyhow::{Error, anyhow}; +use async_trait::async_trait; +use email_address::EmailAddress; + +use crate::server::email::{ + templates::PASSWORD_RESET_TITLE, + traits::{EmailProvider, strip_html_tags}, +}; + +pub struct SmtpEmailProvider { mailer: AsyncSmtpTransport, from: Mailbox, } -impl EmailService { +impl SmtpEmailProvider { pub fn new( smtp_username: String, smtp_password: String, smtp_email: String, smtp_relay: String, - ) -> Result { + ) -> Result { let creds = Credentials::new(smtp_username, smtp_password); let mailer = AsyncSmtpTransport::::relay(&smtp_relay) @@ -33,11 +39,10 @@ impl EmailService { .map_err(|e| anyhow!("Invalid from email address: {}", e))?, ); - Ok(EmailService { mailer, from }) + Ok(Self { mailer, from }) } - /// Send an HTML email - pub async fn send_email(&self, to: EmailAddress, subject: &str, html_body: &str) -> Result<()> { + async fn send_email(&self, to: EmailAddress, title: String, body: String) -> Result<(), Error> { let to_mbox = Mailbox::new( None, to.email() @@ -48,11 +53,11 @@ impl EmailService { let email = lettre::Message::builder() .from(self.from.clone()) .to(to_mbox) - .subject(subject) + .subject(title) .multipart( MultiPart::alternative() - .singlepart(SinglePart::plain(strip_html_tags(html_body))) - .singlepart(SinglePart::html(html_body.to_string())), + .singlepart(SinglePart::plain(strip_html_tags(body.clone()))) + .singlepart(SinglePart::html(body)), )?; self.mailer @@ -64,7 +69,33 @@ impl EmailService { } } -/// Strip HTML tags for plain text fallback -fn strip_html_tags(html: &str) -> String { - html2text::from_read(html.as_bytes(), 80).unwrap_or_else(|_| html.to_string()) +#[async_trait] +impl EmailProvider for SmtpEmailProvider { + async fn send_invite( + &self, + to: EmailAddress, + from: EmailAddress, + url: String, + ) -> Result<(), Error> { + self.send_email( + to, + self.build_invite_title(from.clone()), + self.build_invite_email(url, from), + ) + .await + } + + async fn send_password_reset( + &self, + to: EmailAddress, + url: String, + token: String, + ) -> Result<(), Error> { + self.send_email( + to, + PASSWORD_RESET_TITLE.to_string(), + self.build_password_reset_email(url, token), + ) + .await + } } diff --git a/backend/src/server/email/subscriber.rs b/backend/src/server/email/subscriber.rs new file mode 100644 index 00000000..4b5b0fea --- /dev/null +++ b/backend/src/server/email/subscriber.rs @@ -0,0 +1,55 @@ +use crate::server::{ + auth::middleware::AuthenticatedEntity, + email::traits::EmailService, + shared::events::{ + bus::{EventFilter, EventSubscriber}, + types::Event, + }, +}; +use anyhow::Error; +use async_trait::async_trait; +use serde_json::Value; +use std::collections::HashMap; + +#[async_trait] +impl EventSubscriber for EmailService { + fn event_filter(&self) -> EventFilter { + // All telemetry events + EventFilter::telemetry_only(None) + } + + async fn handle_events(&self, events: Vec) -> Result<(), Error> { + if events.is_empty() { + return Ok(()); + } + + for event in events { + if let Event::Telemetry(e) = event + && let AuthenticatedEntity::User { email, .. } = e.authentication + { + let mut metadata_map: HashMap = serde_json::from_value(e.metadata)?; + + let subscribed = metadata_map + .remove("subscribed") + .map(|v| serde_json::from_value::(v).unwrap_or(false)) + .unwrap_or(false); + + let metadata = serde_json::to_value(metadata_map)?; + + self.track_event( + e.operation.to_string().to_lowercase(), + email, + subscribed, + metadata, + ) + .await?; + } + } + + Ok(()) + } + + fn name(&self) -> &str { + "email_triggers" + } +} diff --git a/backend/src/server/email/templates.rs b/backend/src/server/email/templates.rs new file mode 100644 index 00000000..321d716c --- /dev/null +++ b/backend/src/server/email/templates.rs @@ -0,0 +1,118 @@ +// Email template constants + +pub const EMAIL_HEADER: &str = r#" + + + + + NetVisor + + + + + + +
+ + + + + +"#; + +pub const EMAIL_FOOTER: &str = r#" + + + +
+ NetVisor +
+ + + + + + +
+ + Discord + + + + GitHub + +
+ +

© 2025 NetVisor. All rights reserved.

+
+
+ + +"#; + +pub const PASSWORD_RESET_TITLE: &str = "NetVisor Password Reset"; + +pub const PASSWORD_RESET_BODY: &str = r#" + + +

Reset Your Password

+

Hi there,

+

We received a request to reset your password for your NetVisor account. Click the button below to create a new password:

+ + + + + + + Reset Password + + + + + + +

If the button doesn't work, copy and paste this link into your browser:

+

{reset_url}

+ + + + + + +

This password reset link will expire in 24 hours. If you didn't request a password reset, you can safely ignore this email.

+ + +"#; + +pub const INVITE_LINK_BODY: &str = r#" + + +

You've Been Invited to NetVisor

+

Hi there,

+

{inviter_name} has invited you to join their NetVisor instance to visualize and explore their network infrastructure.

+

Click the button below to accept the invitation and create your account:

+ + + + + + + Accept Invitation + + + + + + +

If the button doesn't work, copy and paste this link into your browser:

+

{invite_url}

+ + + + + + +

This invitation link will expire in 7 days. If you didn't expect this invitation, you can safely ignore this email.

+ + +"#; diff --git a/backend/src/server/email/traits.rs b/backend/src/server/email/traits.rs new file mode 100644 index 00000000..7f2c6411 --- /dev/null +++ b/backend/src/server/email/traits.rs @@ -0,0 +1,120 @@ +use std::sync::Arc; + +use anyhow::{Error, Result}; +use async_trait::async_trait; +use email_address::EmailAddress; +use serde_json::Value; + +use crate::server::{ + email::templates::{EMAIL_FOOTER, EMAIL_HEADER, INVITE_LINK_BODY, PASSWORD_RESET_BODY}, + users::service::UserService, +}; + +/// Trait for email provider implementations +#[async_trait] +pub trait EmailProvider: Send + Sync { + // Example usage function + fn build_email(&self, body: String) -> String { + format!("{}{}{}", EMAIL_HEADER, body, EMAIL_FOOTER) + } + + fn build_invite_title(&self, from_user: EmailAddress) -> String { + format!("You've been invited to join {} on NetVisor", from_user) + } + + fn build_password_reset_email(&self, url: String, token: String) -> String { + self.build_email(PASSWORD_RESET_BODY.replace( + "{reset_url}", + &format!("{}/reset-password?token={}", url, token), + )) + } + + fn build_invite_email(&self, url: String, from: EmailAddress) -> String { + self.build_email( + INVITE_LINK_BODY + .replace("{invite_url}", &url) + .replace("{inviter_name}", from.as_str()), + ) + } + + /// Send an HTML email + async fn send_password_reset( + &self, + to: EmailAddress, + url: String, + token: String, + ) -> Result<(), Error>; + + /// Send an invite via email + async fn send_invite( + &self, + to: EmailAddress, + from: EmailAddress, + url: String, + ) -> Result<(), Error>; + + /// Track an event (optional, only for providers that support it) + async fn track_event( + &self, + event: String, + email: EmailAddress, + subscribed: bool, + data: Value, + ) -> Result<()> { + // Default implementation does nothing + let _ = (event, email, subscribed, data); + Ok(()) + } +} + +/// Email service that wraps the provider +pub struct EmailService { + provider: Box, + pub user_service: Arc, +} + +impl EmailService { + pub fn new(provider: Box, user_service: Arc) -> Self { + Self { + provider, + user_service, + } + } + + /// Send an HTML email + pub async fn send_password_reset( + &self, + to: EmailAddress, + url: String, + token: String, + ) -> Result<()> { + self.provider.send_password_reset(to, url, token).await + } + + pub async fn send_invite( + &self, + to: EmailAddress, + from: EmailAddress, + url: String, + ) -> Result<()> { + self.provider.send_invite(to, from, url).await + } + + /// Track an event (delegates to provider) + pub async fn track_event( + &self, + event: String, + email: EmailAddress, + subscribed: bool, + data: Value, + ) -> Result<()> { + self.provider + .track_event(event, email, subscribed, data) + .await + } +} + +/// Strip HTML tags for plain text fallback +pub fn strip_html_tags(html: String) -> String { + html2text::from_read(html.as_bytes(), 80).unwrap_or_else(|_| html.to_string()) +} diff --git a/backend/src/server/hosts/subscriber.rs b/backend/src/server/hosts/subscriber.rs index 34284b6f..89a55728 100644 --- a/backend/src/server/hosts/subscriber.rs +++ b/backend/src/server/hosts/subscriber.rs @@ -95,7 +95,7 @@ impl EventSubscriber for HostService { } fn debounce_window_ms(&self) -> u64 { - 50 // Small window to batch multiple subnet deletions + 50 // Small window to batch bulk subnet deletions } fn name(&self) -> &str { diff --git a/backend/src/server/logging/subscriber.rs b/backend/src/server/logging/subscriber.rs index d47c8730..8b44dc72 100644 --- a/backend/src/server/logging/subscriber.rs +++ b/backend/src/server/logging/subscriber.rs @@ -18,17 +18,20 @@ impl EventSubscriber for LoggingService { async fn handle_events(&self, events: Vec) -> Result<(), Error> { // Log each event individually for event in events { - event.log(); - tracing::debug!("{}", event); + let suppress_logs = event + .metadata() + .get("suppress_logs") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or(false); + + if !suppress_logs { + tracing::info!("{}", event); + } } Ok(()) } - fn debounce_window_ms(&self) -> u64 { - 0 // No batching for logging - we want immediate logs - } - fn name(&self) -> &str { "logging" } diff --git a/backend/src/server/networks/service.rs b/backend/src/server/networks/service.rs index cbe33fa0..867b0f6a 100644 --- a/backend/src/server/networks/service.rs +++ b/backend/src/server/networks/service.rs @@ -91,8 +91,6 @@ impl NetworkService { .create_host_with_services(remote_host, vec![client_service], authenticated.clone()) .await?; - tracing::info!("Default data seeded successfully"); - Ok(()) } } diff --git a/backend/src/server/organizations/handlers.rs b/backend/src/server/organizations/handlers.rs index b2b58e04..e46cd2c6 100644 --- a/backend/src/server/organizations/handlers.rs +++ b/backend/src/server/organizations/handlers.rs @@ -73,6 +73,9 @@ async fn create_invite( )); } + let send_to = request.send_to.clone(); + let from_user = user.email.clone(); + let invite = state .services .organization_service @@ -82,10 +85,22 @@ async fn create_invite( user.user_id, state.config.public_url.clone(), user.into(), + send_to.clone(), ) .await .map_err(|e| ApiError::internal_error(&e.to_string()))?; + if let Some(send_to) = send_to + && let Some(email_service) = &state.services.email_service + { + let url = format!( + "{}/api/organizations/invites/{}/accept", + invite.url.clone(), + invite.id + ); + email_service.send_invite(send_to, from_user, url).await?; + } + Ok(Json(ApiResponse::success(invite))) } diff --git a/backend/src/server/organizations/impl/api.rs b/backend/src/server/organizations/impl/api.rs index 59e70013..d7a36540 100644 --- a/backend/src/server/organizations/impl/api.rs +++ b/backend/src/server/organizations/impl/api.rs @@ -1,4 +1,5 @@ use crate::server::users::r#impl::permissions::UserOrgPermissions; +use email_address::EmailAddress; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -7,4 +8,5 @@ pub struct CreateInviteRequest { pub expiration_hours: Option, pub permissions: UserOrgPermissions, pub network_ids: Vec, + pub send_to: Option, } diff --git a/backend/src/server/organizations/impl/base.rs b/backend/src/server/organizations/impl/base.rs index c21da5dd..80159a4a 100644 --- a/backend/src/server/organizations/impl/base.rs +++ b/backend/src/server/organizations/impl/base.rs @@ -5,7 +5,8 @@ use uuid::Uuid; use validator::Validate; use crate::server::{ - billing::types::base::BillingPlan, shared::entities::ChangeTriggersTopologyStaleness, + billing::types::base::BillingPlan, + shared::{entities::ChangeTriggersTopologyStaleness, events::types::TelemetryOperation}, }; #[derive(Debug, Clone, Serialize, Validate, Deserialize, Default, PartialEq, Eq, Hash)] @@ -15,7 +16,7 @@ pub struct OrganizationBase { pub name: String, pub plan: Option, pub plan_status: Option, - pub is_onboarded: bool, + pub onboarding: Vec, } #[derive(Debug, Clone, Validate, Serialize, Deserialize, PartialEq, Eq, Hash)] @@ -28,6 +29,16 @@ pub struct Organization { pub base: OrganizationBase, } +impl Organization { + pub fn not_onboarded(&self, step: &TelemetryOperation) -> bool { + !self.base.onboarding.contains(step) + } + + pub fn has_onboarded(&self, step: &TelemetryOperation) -> bool { + self.base.onboarding.contains(step) + } +} + impl Display for Organization { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{:?}: {:?}", self.base.name, self.id) diff --git a/backend/src/server/organizations/impl/invites.rs b/backend/src/server/organizations/impl/invites.rs index 8fee160d..8419a6a7 100644 --- a/backend/src/server/organizations/impl/invites.rs +++ b/backend/src/server/organizations/impl/invites.rs @@ -1,4 +1,5 @@ use chrono::{DateTime, Utc}; +use email_address::EmailAddress; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -17,6 +18,7 @@ pub struct Invite { pub created_by: Uuid, pub created_at: DateTime, pub expires_at: DateTime, + pub send_to: Option, } impl Invite { @@ -27,6 +29,7 @@ impl Invite { expiration_hours: i64, permissions: UserOrgPermissions, network_ids: Vec, + send_to: Option, ) -> Self { let now = Utc::now(); Self { @@ -38,6 +41,7 @@ impl Invite { url, created_at: now, expires_at: now + chrono::Duration::hours(expiration_hours), + send_to, } } diff --git a/backend/src/server/organizations/impl/storage.rs b/backend/src/server/organizations/impl/storage.rs index cf9ccde8..04096617 100644 --- a/backend/src/server/organizations/impl/storage.rs +++ b/backend/src/server/organizations/impl/storage.rs @@ -6,7 +6,10 @@ use uuid::Uuid; use crate::server::{ billing::types::base::BillingPlan, organizations::r#impl::base::{Organization, OrganizationBase}, - shared::storage::traits::{SqlValue, StorableEntity}, + shared::{ + events::types::TelemetryOperation, + storage::traits::{SqlValue, StorableEntity}, + }, }; impl StorableEntity for Organization { @@ -58,7 +61,7 @@ impl StorableEntity for Organization { stripe_customer_id, plan, plan_status, - is_onboarded, + onboarding, }, } = self.clone(); @@ -71,7 +74,7 @@ impl StorableEntity for Organization { "stripe_customer_id", "plan", "plan_status", - "is_onboarded", + "onboarding", ], vec![ SqlValue::Uuid(id), @@ -81,7 +84,7 @@ impl StorableEntity for Organization { SqlValue::OptionalString(stripe_customer_id), SqlValue::OptionBillingPlan(plan), SqlValue::OptionalString(plan_status), - SqlValue::Bool(is_onboarded), + SqlValue::TelemetryOperation(onboarding), ], )) } @@ -92,6 +95,10 @@ impl StorableEntity for Organization { .unwrap_or(None) .and_then(|v| serde_json::from_value(v).ok()); + let onboarding: Vec = + serde_json::from_value(row.get::("onboarding")) + .map_err(|e| anyhow::anyhow!("Failed to deserialize onboarding: {}", e))?; + Ok(Organization { id: row.get("id"), created_at: row.get("created_at"), @@ -101,7 +108,7 @@ impl StorableEntity for Organization { stripe_customer_id: row.get("stripe_customer_id"), plan, plan_status: row.get("plan_status"), - is_onboarded: row.get("is_onboarded"), + onboarding, }, }) } diff --git a/backend/src/server/organizations/mod.rs b/backend/src/server/organizations/mod.rs index cfb50050..0a734d6f 100644 --- a/backend/src/server/organizations/mod.rs +++ b/backend/src/server/organizations/mod.rs @@ -1,3 +1,4 @@ pub mod handlers; pub mod r#impl; pub mod service; +pub mod subscriber; diff --git a/backend/src/server/organizations/service.rs b/backend/src/server/organizations/service.rs index 42e49685..9e1c87b6 100644 --- a/backend/src/server/organizations/service.rs +++ b/backend/src/server/organizations/service.rs @@ -11,6 +11,7 @@ use crate::server::{ use anyhow::{Error, anyhow}; use async_trait::async_trait; use chrono::Utc; +use email_address::EmailAddress; use std::collections::HashMap; use std::sync::Arc; use tokio::sync::RwLock; @@ -112,10 +113,7 @@ impl OrganizationService { invites.retain(|_, invite| invite.expires_at > now); - tracing::debug!( - "Cleaned up expired invites. Current count: {}", - invites.len() - ); + tracing::debug!("Cleaned up expired invites."); } pub async fn create_invite( @@ -125,6 +123,7 @@ impl OrganizationService { user_id: Uuid, url: String, authentication: AuthenticatedEntity, + send_to: Option, ) -> Result { let expiration_hours = request.expiration_hours.unwrap_or(168); // Default 7 days @@ -135,6 +134,7 @@ impl OrganizationService { expiration_hours, request.permissions, request.network_ids, + send_to, ); // Store invite diff --git a/backend/src/server/organizations/subscriber.rs b/backend/src/server/organizations/subscriber.rs new file mode 100644 index 00000000..44f4bcf3 --- /dev/null +++ b/backend/src/server/organizations/subscriber.rs @@ -0,0 +1,52 @@ +use anyhow::Error; +use async_trait::async_trait; + +use crate::server::{ + auth::middleware::AuthenticatedEntity, + organizations::service::OrganizationService, + shared::{ + events::{ + bus::{EventFilter, EventSubscriber}, + types::Event, + }, + services::traits::CrudService, + }, +}; + +#[async_trait] +impl EventSubscriber for OrganizationService { + fn event_filter(&self) -> EventFilter { + EventFilter::telemetry_only(None) + } + + async fn handle_events(&self, events: Vec) -> Result<(), Error> { + if events.is_empty() { + return Ok(()); + } + + for event in events { + if let Event::Telemetry(event) = event { + let is_onboarding_step = event + .metadata + .get("is_onboarding_step") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or(false); + + if let Some(mut organization) = self.get_by_id(&event.organization_id).await? + && is_onboarding_step + && organization.not_onboarded(&event.operation) + { + organization.base.onboarding.push(event.operation); + self.update(&mut organization, AuthenticatedEntity::System) + .await?; + } + } + } + + Ok(()) + } + + fn name(&self) -> &str { + "organization_onboarding" + } +} diff --git a/backend/src/server/shared/events/bus.rs b/backend/src/server/shared/events/bus.rs index ab54c0e8..0946881f 100644 --- a/backend/src/server/shared/events/bus.rs +++ b/backend/src/server/shared/events/bus.rs @@ -10,7 +10,10 @@ use uuid::Uuid; use crate::server::shared::{ entities::EntityDiscriminants, - events::types::{AuthEvent, AuthOperation, EntityEvent, EntityOperation, Event}, + events::types::{ + AuthEvent, AuthOperation, EntityEvent, EntityOperation, Event, TelemetryEvent, + TelemetryOperation, + }, }; // Trait for event subscribers @@ -38,6 +41,7 @@ pub struct EventFilter { // None = match all values (ignore as a filter) pub entity_operations: Option>>>, pub auth_operations: Option>, + pub telemetry_operations: Option>, pub network_ids: Option>, } @@ -46,6 +50,7 @@ impl EventFilter { Self { entity_operations: None, auth_operations: None, + telemetry_operations: None, network_ids: None, } } @@ -55,16 +60,27 @@ impl EventFilter { ) -> Self { Self { entity_operations: Some(entity_operations), - auth_operations: None, + auth_operations: Some(vec![]), + telemetry_operations: Some(vec![]), network_ids: None, } } - pub fn auth_only(auth_operations: Vec) -> Self { + pub fn auth_only(auth_operations: Option>) -> Self { Self { - entity_operations: None, - auth_operations: Some(auth_operations), - network_ids: None, + entity_operations: Some(HashMap::new()), + telemetry_operations: Some(vec![]), + auth_operations, + network_ids: Some(vec![]), + } + } + + pub fn telemetry_only(telemetry_operations: Option>) -> Self { + Self { + entity_operations: Some(HashMap::new()), + telemetry_operations, + auth_operations: Some(vec![]), + network_ids: Some(vec![]), } } @@ -72,6 +88,7 @@ impl EventFilter { match event { Event::Entity(entity_event) => self.matches_entity(entity_event), Event::Auth(auth_event) => self.matches_auth(auth_event), + Event::Telemetry(telemetry_event) => self.matches_telemetry(telemetry_event), } } @@ -109,6 +126,15 @@ impl EventFilter { true } + + fn matches_telemetry(&self, event: &TelemetryEvent) -> bool { + // Check auth operation filter + if let Some(telemetry_operations) = &self.telemetry_operations { + return telemetry_operations.contains(&event.operation); + } + + true + } } /// Internal: Manages batching state for a subscriber @@ -167,10 +193,7 @@ impl SubscriberState { // Count events per org before processing let mut events_per_org: HashMap, usize> = HashMap::new(); for event in &events { - let org_id = match event { - Event::Entity(e) => e.network_id, - Event::Auth(e) => e.organization_id, - }; + let org_id = event.org_id(); *events_per_org.entry(org_id).or_default() += 1; } @@ -278,26 +301,13 @@ impl EventBus { self.publish(Event::Auth(event)).await } + /// Publish an auth event + pub async fn publish_telemetry(&self, event: TelemetryEvent) -> Result<()> { + self.publish(Event::Telemetry(event)).await + } + /// Publish an event to all subscribers async fn publish(&self, event: Event) -> Result<()> { - match &event { - Event::Entity(e) => { - tracing::debug!( - operation = %e.operation, - entity_type = %e.entity_type, - entity_id = %e.entity_id, - "Publishing entity event", - ); - } - Event::Auth(e) => { - tracing::debug!( - operation = ?e.operation, - user_id = ?e.user_id, - "Publishing auth event", - ); - } - } - // Send to broadcast channel (non-blocking) let _ = self.sender.send(event.clone()); diff --git a/backend/src/server/shared/events/types.rs b/backend/src/server/shared/events/types.rs index f0bda320..d53893ca 100644 --- a/backend/src/server/shared/events/types.rs +++ b/backend/src/server/shared/events/types.rs @@ -1,6 +1,6 @@ use crate::server::{auth::middleware::AuthenticatedEntity, shared::entities::Entity}; use chrono::{DateTime, Utc}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use std::{fmt::Display, net::IpAddr}; use strum::IntoDiscriminant; use uuid::Uuid; @@ -9,6 +9,7 @@ use uuid::Uuid; pub enum Event { Entity(Box), Auth(AuthEvent), + Telemetry(TelemetryEvent), } impl Event { @@ -16,6 +17,7 @@ impl Event { match self { Event::Auth(a) => a.id, Event::Entity(e) => e.id, + Event::Telemetry(t) => t.id, } } @@ -23,6 +25,7 @@ impl Event { match self { Event::Auth(a) => a.organization_id, Event::Entity(e) => e.organization_id, + Event::Telemetry(t) => Some(t.organization_id), } } @@ -30,6 +33,15 @@ impl Event { match self { Event::Auth(_) => None, Event::Entity(e) => e.network_id, + Event::Telemetry(_) => None, + } + } + + pub fn metadata(&self) -> serde_json::Value { + match self { + Event::Auth(e) => e.metadata.clone(), + Event::Entity(e) => e.metadata.clone(), + Event::Telemetry(e) => e.metadata.clone(), } } @@ -51,7 +63,6 @@ impl Event { network_id = %network_id_str, organization_id = %org_id_str, operation = %event.operation, - "Entity Event Logged" ); } Event::Auth(event) => { @@ -75,7 +86,12 @@ impl Event { user_id = %user_id_str, user_agent = %user_agent_str, operation = %event.operation, - "Auth Event Logged" + ); + } + Event::Telemetry(event) => { + tracing::info!( + organization_id = %event.organization_id, + operation = %event.operation, ); } } @@ -119,6 +135,11 @@ impl Display for Event { e.metadata, e.authentication ), + Event::Telemetry(t) => write!( + f, + "{{ id: {}, authentication: {}, organization_id: {}, operation: {}, timestamp: {}, metadata: {} }}", + t.id, t.authentication, t.organization_id, t.operation, t.timestamp, t.metadata, + ), } } } @@ -134,7 +155,9 @@ impl PartialEq for Event { } #[derive(Debug, Clone, Serialize, PartialEq, Eq, strum::Display)] +#[strum(serialize_all = "snake_case")] pub enum AuthOperation { + // User Auth Register, LoginSuccess, LoginFailed, @@ -146,6 +169,9 @@ pub enum AuthOperation { OidcLinked, OidcUnlinked, LoggedOut, + + // Api Key Auth + RotateKey, } #[derive(Debug, Clone, Serialize)] @@ -173,7 +199,22 @@ impl PartialEq for AuthEvent { } } +impl Display for AuthEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{ id: {}, operation: {}, ip: {}, user_agent: {}, authentication: {} }}", + self.id, + self.operation, + self.ip_address, + self.user_agent.clone().unwrap_or("unknown".to_string()), + self.authentication + ) + } +} + #[derive(Debug, Clone, Serialize, PartialEq, Eq, strum::Display)] +#[strum(serialize_all = "snake_case")] pub enum EntityOperation { Get, GetAll, @@ -182,7 +223,6 @@ pub enum EntityOperation { Deleted, DiscoveryStarted, DiscoveryCancelled, - Custom(&'static str), } #[derive(Debug, Clone, Serialize, Eq)] @@ -213,8 +253,43 @@ impl Display for EntityEvent { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!( f, - "Event: {{ id: {}, entity_type: {}, entity_id: {} }}", - self.id, self.entity_type, self.entity_id + "{{ id: {}, entity_type: {}, entity_id: {}, operation: {} }}", + self.id, self.entity_type, self.entity_id, self.operation + ) + } +} + +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, Deserialize, strum::Display)] +#[strum(serialize_all = "snake_case")] +pub enum TelemetryOperation { + // Onboarding funnel + OrgCreated, + OnboardingModalCompleted, + PersonalPlanSelected, + CommercialPlanSelected, + FirstApiKeyCreated, + FirstDaemonRegistered, + FirstTopologyRebuild, // FirstDiscoveryStarted, + // FirstDiscoveryCompleted, + // FirstHostDiscovered, +} + +#[derive(Debug, Clone, Serialize, PartialEq)] +pub struct TelemetryEvent { + pub id: Uuid, + pub organization_id: Uuid, + pub operation: TelemetryOperation, + pub timestamp: DateTime, + pub authentication: AuthenticatedEntity, + pub metadata: serde_json::Value, +} + +impl Display for TelemetryEvent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{{ id: {}, organization_id: {}, operation: {}, authentication: {} }}", + self.id, self.organization_id, self.operation, self.authentication ) } } diff --git a/backend/src/server/shared/handlers/factory.rs b/backend/src/server/shared/handlers/factory.rs index a1a9a66a..159d371b 100644 --- a/backend/src/server/shared/handlers/factory.rs +++ b/backend/src/server/shared/handlers/factory.rs @@ -12,6 +12,7 @@ use crate::server::organizations::r#impl::base::Organization; use crate::server::services::definitions::ServiceDefinitionRegistry; use crate::server::shared::concepts::Concept; use crate::server::shared::entities::EntityDiscriminants; +use crate::server::shared::events::types::{TelemetryEvent, TelemetryOperation}; use crate::server::shared::services::traits::CrudService; use crate::server::shared::storage::traits::StorableEntity; use crate::server::shared::types::api::{ApiError, ApiResult}; @@ -34,11 +35,13 @@ use axum::extract::State; use axum::http::HeaderValue; use axum::routing::post; use axum::{Json, Router, routing::get}; +use chrono::Utc; use reqwest::header; use serde::{Deserialize, Serialize}; use std::sync::Arc; use strum::{IntoDiscriminant, IntoEnumIterator}; use tower_http::set_header::SetResponseHeaderLayer; +use uuid::Uuid; pub fn create_router() -> Router> { Router::new() @@ -103,26 +106,26 @@ async fn get_health() -> Json> { pub async fn get_public_config( State(state): State>, ) -> Json> { + let oidc_providers = state + .services + .oidc_service + .as_ref() + .map(|o| o.as_ref().list_providers()) + .unwrap_or_default(); + Json(ApiResponse::success(PublicConfigResponse { server_port: state.config.server_port, disable_registration: state.config.disable_registration, - oidc_enabled: state.config.oidc_client_id.is_some() - && state.config.oidc_client_secret.is_some() - && state.config.oidc_issuer_url.is_some() - && state.config.oidc_provider_name.is_some() - && state.config.oidc_redirect_url.is_some(), - oidc_provider_name: state - .config - .oidc_provider_name - .clone() - .unwrap_or("OIDC Provider".to_string()), + oidc_providers, billing_enabled: state.config.stripe_secret.is_some(), has_integrated_daemon: state.config.integrated_daemon_url.is_some(), - has_email_service: state.config.smtp_password.is_some() + has_email_service: (state.config.smtp_password.is_some() && state.config.smtp_username.is_some() && state.config.smtp_email.is_some() - && state.config.smtp_relay.is_some(), + && state.config.smtp_relay.is_some()) + || state.config.plunk_api_key.is_some(), public_url: state.config.public_url.clone(), + has_email_opt_in: state.config.plunk_api_key.is_some(), })) } @@ -145,8 +148,10 @@ pub async fn onboarding( .await? .ok_or_else(|| anyhow!("Could not find organization."))?; - if org.base.is_onboarded { - return Err(ApiError::bad_request("Org is already onboarded")); + if org.has_onboarded(&TelemetryOperation::OnboardingModalCompleted) { + return Err(ApiError::bad_request( + "Org has already completed onboarding modal", + )); } // Billing not enabled = self hosted @@ -155,7 +160,9 @@ pub async fn onboarding( } org.base.name = request.organization_name; - org.base.is_onboarded = true; + org.base + .onboarding + .push(TelemetryOperation::OnboardingModalCompleted); let updated_org = state .services .organization_service @@ -171,6 +178,14 @@ pub async fn onboarding( .create(network, user.clone().into()) .await?; + if request.populate_seed_data { + state + .services + .network_service + .seed_default_data(network.id, user.clone().into()) + .await?; + } + let topology = Topology::new(TopologyBase::new("My Topology".to_string(), network.id)); state @@ -179,14 +194,6 @@ pub async fn onboarding( .create(topology, user.clone().into()) .await?; - if request.populate_seed_data { - state - .services - .network_service - .seed_default_data(network.id, user.into()) - .await?; - } - if let Some(integrated_daemon_url) = &state.config.integrated_daemon_url { let api_key = state .services @@ -211,5 +218,20 @@ pub async fn onboarding( .await?; } + state + .services + .event_bus + .publish_telemetry(TelemetryEvent { + id: Uuid::new_v4(), + organization_id: org.id, + operation: TelemetryOperation::OnboardingModalCompleted, + timestamp: Utc::now(), + authentication: user.into(), + metadata: serde_json::json!({ + "is_onboarding_step": true + }), + }) + .await?; + Ok(Json(ApiResponse::success(updated_org))) } diff --git a/backend/src/server/shared/services/factory.rs b/backend/src/server/shared/services/factory.rs index 30a4e73d..c89b0186 100644 --- a/backend/src/server/shared/services/factory.rs +++ b/backend/src/server/shared/services/factory.rs @@ -5,7 +5,7 @@ use crate::server::{ config::ServerConfig, daemons::service::DaemonService, discovery::service::DiscoveryService, - email::service::EmailService, + email::{plunk::PlunkEmailProvider, smtp::SmtpEmailProvider, traits::EmailService}, groups::service::GroupService, hosts::service::HostService, logging::service::LoggingService, @@ -103,8 +103,32 @@ impl ServiceFactory { subnet_service.clone(), event_bus.clone(), )); + let user_service = Arc::new(UserService::new(storage.users.clone(), event_bus.clone())); + let email_service = config.clone().and_then(|c| { + // Prefer Plunk if API key is provided + if let Some(plunk_api_key) = c.plunk_api_key { + let provider = Box::new(PlunkEmailProvider::new(plunk_api_key)); + return Some(Arc::new(EmailService::new(provider, user_service.clone()))); + } + + // Fall back to SMTP + if let (Some(smtp_username), Some(smtp_password), Some(smtp_email), Some(smtp_relay)) = + (c.smtp_username, c.smtp_password, c.smtp_email, c.smtp_relay) + { + let provider = + SmtpEmailProvider::new(smtp_username, smtp_password, smtp_email, smtp_relay) + .ok()?; + return Some(Arc::new(EmailService::new( + Box::new(provider), + user_service.clone(), + ))); + } + + None + }); + let billing_service = config.clone().and_then(|c| { if let Some(strip_secret) = c.stripe_secret && let Some(webhook_secret) = c.stripe_webhook_secret @@ -115,23 +139,12 @@ impl ServiceFactory { organization_service.clone(), user_service.clone(), network_service.clone(), + event_bus.clone(), ))); } None }); - let email_service = config.clone().and_then(|c| { - if let (Some(smtp_username), Some(smtp_password), Some(smtp_email), Some(smtp_relay)) = - (c.smtp_username, c.smtp_password, c.smtp_email, c.smtp_relay) - { - return Some(Arc::new( - EmailService::new(smtp_username, smtp_password, smtp_email, smtp_relay) - .unwrap(), - )); - } - None - }); - let auth_service = Arc::new(AuthService::new( user_service.clone(), organization_service.clone(), @@ -140,29 +153,14 @@ impl ServiceFactory { )); let oidc_service = config.and_then(|c| { - if let ( - Some(issuer_url), - Some(redirect_url), - Some(client_id), - Some(client_secret), - Some(provider_name), - ) = ( - &c.oidc_issuer_url, - &c.oidc_redirect_url, - &c.oidc_client_id, - &c.oidc_client_secret, - &c.oidc_provider_name, - ) { - return Some(Arc::new(OidcService::new(OidcService { - issuer_url: issuer_url.to_owned(), - client_id: client_id.to_owned(), - client_secret: client_secret.to_owned(), - redirect_url: redirect_url.to_owned(), - provider_name: provider_name.to_owned(), - auth_service: auth_service.clone(), - user_service: user_service.clone(), - event_bus: event_bus.clone(), - }))); + if let Some(oidc_providers) = c.oidc_providers { + return Some(Arc::new(OidcService::new( + oidc_providers, + &c.public_url, + auth_service.clone(), + user_service.clone(), + event_bus.clone(), + ))); } None }); @@ -173,13 +171,19 @@ impl ServiceFactory { .await; event_bus.register_subscriber(logging_service.clone()).await; - event_bus.register_subscriber(host_service.clone()).await; + event_bus + .register_subscriber(organization_service.clone()) + .await; if let Some(billing_service) = billing_service.clone() { event_bus.register_subscriber(billing_service).await; } + if let Some(email_service) = email_service.clone() { + event_bus.register_subscriber(email_service).await; + } + Ok(Self { user_service, auth_service, diff --git a/backend/src/server/shared/storage/generic.rs b/backend/src/server/shared/storage/generic.rs index e842a612..e6e75b15 100644 --- a/backend/src/server/shared/storage/generic.rs +++ b/backend/src/server/shared/storage/generic.rs @@ -96,6 +96,7 @@ where SqlValue::Subnets(v) => query.bind(serde_json::to_value(v)?), SqlValue::Services(v) => query.bind(serde_json::to_value(v)?), SqlValue::Groups(v) => query.bind(serde_json::to_value(v)?), + SqlValue::TelemetryOperation(v) => query.bind(serde_json::to_value(v)?), }; Ok(value) @@ -117,7 +118,7 @@ where } query.execute(&self.pool).await?; - tracing::debug!("Created {}: {}", T::table_name(), entity); + tracing::trace!("Created {}: {}", T::table_name(), entity); Ok(entity.clone()) } @@ -173,7 +174,7 @@ where query = Self::bind_value(query, value)?; } - tracing::debug!("Updated {}", entity); + tracing::trace!("Updated {}", entity); query.execute(&self.pool).await?; Ok(entity.clone()) @@ -184,7 +185,7 @@ where sqlx::query(&query_str).bind(id).execute(&self.pool).await?; - tracing::debug!("Deleted {} with id: {}", T::table_name(), id); + tracing::trace!("Deleted {} with id: {}", T::table_name(), id); Ok(()) } @@ -203,7 +204,7 @@ where let deleted_count = result.rows_affected() as usize; - tracing::debug!( + tracing::trace!( "Bulk deleted {} {}s (requested: {}, deleted: {})", deleted_count, T::table_name(), diff --git a/backend/src/server/shared/storage/traits.rs b/backend/src/server/shared/storage/traits.rs index 9cac8d01..724f1a98 100644 --- a/backend/src/server/shared/storage/traits.rs +++ b/backend/src/server/shared/storage/traits.rs @@ -2,6 +2,7 @@ use std::net::IpAddr; use crate::server::groups::r#impl::base::Group; use crate::server::services::r#impl::base::Service; +use crate::server::shared::events::types::TelemetryOperation; use crate::server::subnets::r#impl::base::Subnet; use crate::server::{ billing::types::base::BillingPlan, @@ -108,4 +109,5 @@ pub enum SqlValue { Subnets(Vec), Services(Vec), Groups(Vec), + TelemetryOperation(Vec), } diff --git a/backend/src/server/subnets/handlers.rs b/backend/src/server/subnets/handlers.rs index ee10ff99..2e94fee3 100644 --- a/backend/src/server/subnets/handlers.rs +++ b/backend/src/server/subnets/handlers.rs @@ -2,6 +2,7 @@ use crate::server::auth::middleware::{AuthenticatedEntity, MemberOrDaemon}; use crate::server::shared::handlers::traits::{ CrudHandlers, bulk_delete_handler, delete_handler, get_by_id_handler, update_handler, }; +use crate::server::shared::storage::traits::StorableEntity; use crate::server::shared::types::api::ApiError; use crate::server::{ config::AppState, @@ -80,14 +81,14 @@ async fn get_all_subnets( match &entity { AuthenticatedEntity::User { user_id, .. } => { tracing::debug!( - entity_type = "subnet", + entity_type = Subnet::table_name(), user_id = %user_id, "Get all request received" ); } AuthenticatedEntity::Daemon { .. } => { tracing::debug!( - entity_type = "subnet", + entity_type = Subnet::table_name(), daemon_id = %entity.entity_id(), "Get all request received" ); diff --git a/backend/src/server/topology/handlers.rs b/backend/src/server/topology/handlers.rs index af9f210c..7f1d189b 100644 --- a/backend/src/server/topology/handlers.rs +++ b/backend/src/server/topology/handlers.rs @@ -2,6 +2,7 @@ use crate::server::{ auth::middleware::{AuthenticatedUser, RequireMember}, config::AppState, shared::{ + events::types::{TelemetryEvent, TelemetryOperation}, handlers::traits::{ CrudHandlers, delete_handler, get_all_handler, get_by_id_handler, update_handler, }, @@ -20,8 +21,10 @@ use axum::{ }, routing::{delete, get, post, put}, }; +use chrono::Utc; use futures::{Stream, stream}; use std::{convert::Infallible, sync::Arc}; +use uuid::Uuid; pub fn create_router() -> Router> { Router::new() @@ -169,7 +172,32 @@ async fn rebuild( topology.base.nodes = nodes; topology.clear_stale(); - service.update(&mut topology, user.into()).await?; + service.update(&mut topology, user.clone().into()).await?; + + let organization = state + .services + .organization_service + .get_by_id(&user.organization_id) + .await?; + + if let Some(organization) = organization + && organization.not_onboarded(&TelemetryOperation::FirstTopologyRebuild) + { + state + .services + .event_bus + .publish_telemetry(TelemetryEvent { + id: Uuid::new_v4(), + organization_id: user.organization_id, + operation: TelemetryOperation::FirstTopologyRebuild, + timestamp: Utc::now(), + authentication: user.into(), + metadata: serde_json::json!({ + "is_onboarding_step": true + }), + }) + .await?; + } // Return will be handled through event subscriber which triggers SSE diff --git a/backend/src/server/topology/service/main.rs b/backend/src/server/topology/service/main.rs index 2c0b5eae..6a2b73bd 100644 --- a/backend/src/server/topology/service/main.rs +++ b/backend/src/server/topology/service/main.rs @@ -2,18 +2,27 @@ use std::{collections::HashMap, sync::Arc}; use anyhow::Error; use async_trait::async_trait; +use chrono::Utc; use petgraph::{Graph, graph::NodeIndex, visit::EdgeRef}; use tokio::sync::broadcast; use uuid::Uuid; use crate::server::{ + auth::middleware::AuthenticatedEntity, groups::{r#impl::base::Group, service::GroupService}, hosts::{r#impl::base::Host, service::HostService}, services::{r#impl::base::Service, service::ServiceService}, shared::{ - events::bus::EventBus, + events::{ + bus::EventBus, + types::{EntityEvent, EntityOperation}, + }, services::traits::{CrudService, EventBusService}, - storage::{filter::EntityFilter, generic::GenericPostgresStorage}, + storage::{ + filter::EntityFilter, + generic::GenericPostgresStorage, + traits::{StorableEntity, Storage}, + }, }, subnets::{r#impl::base::Subnet, service::SubnetService}, topology::{ @@ -58,6 +67,65 @@ impl CrudService for TopologyService { fn storage(&self) -> &Arc> { &self.storage } + + /// Create entity + async fn create( + &self, + entity: Topology, + authentication: AuthenticatedEntity, + ) -> Result { + let mut topology = if entity.id() == Uuid::nil() { + Topology::new(entity.get_base()) + } else { + entity + }; + + let (hosts, subnets, groups) = self.get_entity_data(topology.base.network_id).await?; + + let services = self + .get_service_data(topology.base.network_id, &topology.base.options) + .await?; + + let params = BuildGraphParams { + hosts: &hosts, + services: &services, + subnets: &subnets, + groups: &groups, + old_edges: &[], + old_nodes: &[], + options: &topology.base.options, + }; + + let (nodes, edges) = self.build_graph(params); + + topology.base.edges = edges; + topology.base.nodes = nodes; + topology.base.hosts = hosts; + topology.base.services = services; + topology.base.subnets = subnets; + topology.base.groups = groups; + topology.clear_stale(); + + let created = self.storage().create(&topology).await?; + + self.event_bus() + .publish_entity(EntityEvent { + id: Uuid::new_v4(), + entity_id: created.id(), + network_id: self.get_network_id(&created), + organization_id: self.get_organization_id(&created), + entity_type: created.clone().into(), + operation: EntityOperation::Created, + timestamp: Utc::now(), + metadata: serde_json::json!({ + "clear_stale": true + }), + authentication, + }) + .await?; + + Ok(created) + } } pub struct BuildGraphParams<'a> { diff --git a/backend/src/server/topology/service/subscriber.rs b/backend/src/server/topology/service/subscriber.rs index 8f19a47c..904bf2a9 100644 --- a/backend/src/server/topology/service/subscriber.rs +++ b/backend/src/server/topology/service/subscriber.rs @@ -29,6 +29,7 @@ struct TopologyChanges { removed_subnets: std::collections::HashSet, removed_groups: std::collections::HashSet, should_mark_stale: bool, + clear_stale: bool, } #[async_trait] @@ -41,7 +42,7 @@ impl EventSubscriber for TopologyService { (EntityDiscriminants::Group, None), ( EntityDiscriminants::Topology, - Some(vec![EntityOperation::Updated]), + Some(vec![EntityOperation::Created, EntityOperation::Updated]), ), ])) } @@ -68,10 +69,19 @@ impl EventSubscriber for TopologyService { .and_then(|v| serde_json::from_value::(v.clone()).ok()) .unwrap_or(false); + // Check if any event clears staleness (only set on topology create to avoid showing topology as stale on first load) + let clear_stale = entity_event + .metadata + .get("clear_stale") + .and_then(|v| serde_json::from_value::(v.clone()).ok()) + .unwrap_or(false); + // Topology updates from changes to options should be applied immediately and not processed alongside // other changes, otherwise another call to topology_service.update will be made which will trigger // an infinite loop - if let Entity::Topology(mut topology) = entity_event.entity_type { + if let Entity::Topology(mut topology) = entity_event.entity_type.clone() + && entity_event.operation == EntityOperation::Updated + { if trigger_stale { topology.base.is_stale = true; } @@ -106,6 +116,8 @@ impl EventSubscriber for TopologyService { if trigger_stale { // User will be prompted to update entities changes.should_mark_stale = true; + } else if clear_stale { + changes.clear_stale = true; } else { // It's safe to automatically update entities match entity_event.entity_type { @@ -155,10 +167,15 @@ impl EventSubscriber for TopologyService { } // Mark stale if needed - if changes.should_mark_stale { + if changes.should_mark_stale && !changes.clear_stale { topology.base.is_stale = true; } + // Clear stale - this only happens on topology create to avoid a stale state when loading app for the first time + if changes.clear_stale { + topology.base.is_stale = false; + } + if changes.updated_hosts { topology.base.hosts = hosts.clone() } @@ -196,6 +213,6 @@ impl EventSubscriber for TopologyService { } fn name(&self) -> &str { - "topology_validation" + "topology_stale" } } diff --git a/backend/tests/integration.rs b/backend/tests/integration.rs index cff050ff..4e333479 100644 --- a/backend/tests/integration.rs +++ b/backend/tests/integration.rs @@ -117,6 +117,7 @@ impl TestClient { let register_request = RegisterRequest { email: email.clone(), password: password.to_string(), + subscribed: false, }; let response = self diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index 69a1de79..1283c535 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -142,15 +142,11 @@ environment: | **Secure Cookies** | `--use-secure-session-cookies` | `NETVISOR_USE_SECURE_SESSION_COOKIES` | `false` | Enable HTTPS-only cookies | | **Integrated Daemon URL** | `--integrated-daemon-url` | `NETVISOR_INTEGRATED_DAEMON_URL` | `http://172.17.0.1:60073` | URL to reach daemon in default docker compose | | **Disable Registration** | `--disable-registration` | `NETVISOR_DISABLE_REGISTRATION` | `false` | Disable new user registration | -| **OIDC Issuer URL** | `--oidc-issuer-url` | `NETVISOR_OIDC_ISSUER_URL` | - | OIDC provider's issuer URL (must end with `/`) | -| **OIDC Client ID** | `--oidc-client-id` | `NETVISOR_OIDC_CLIENT_ID` | - | OAuth2 client ID from provider | -| **OIDC Client Secret** | `--oidc-client-secret` | `NETVISOR_OIDC_CLIENT_SECRET` | - | OAuth2 client secret from provider | -| **OIDC Provider Name** | `--oidc-provider-name` | `NETVISOR_OIDC_PROVIDER_NAME` | - | Display name shown in UI (e.g., "Authentik", "Keycloak") | -| **OIDC Redirect URL** | `--oidc-redirect-url` | `NETVISOR_OIDC_REDIRECT_URL` | - | URL from OIDC provider for authentication redirect | | **SMTP Username** | `--smtp-username` | `NETVISOR_SMTP_USERNAME` | - | SMTP username for email features (password reset, notifications) | | **SMTP Password** | `--smtp-password` | `NETVISOR_SMTP_PASSWORD` | - | SMTP password for email authentication | | **SMTP Relay** | `--smtp-relay` | `NETVISOR_SMTP_RELAY` | - | SMTP server address (e.g., `smtp.gmail.com`) | | **SMTP Email** | `--smtp-email` | `NETVISOR_SMTP_EMAIL` | - | Sender email address for outgoing emails | +| **Client IP Source** | `--client-ip-source` | `NETVISOR_CLIENT_IP_SOURCE` | - | Source of IP address from request headers, used to log accurate IP address in auth logs while using a reverse proxy. Refer to [axum-client-ip](https://github.com/imbolc/axum-client-ip?tab=readme-ov-file#configurable-vs-specific-extractors) docs for values you can set. | ### Integrated Daemon URL @@ -215,29 +211,7 @@ netvisor-server: NetVisor supports OpenID Connect (OIDC) for enterprise authentication with providers like Authentik, Keycloak, Auth0, Okta, and others. -### Server Configuration - -Add these environment variables to your server configuration: - -```yaml -environment: - # Required OIDC settings - - NETVISOR_OIDC_ISSUER_URL=https://your-provider.com/application/o/netvisor/ - - NETVISOR_OIDC_CLIENT_ID=your-client-id - - NETVISOR_OIDC_CLIENT_SECRET=your-client-secret - - NETVISOR_OIDC_REDIRECT_URL=https://auth.example.com/callback - - NETVISOR_OIDC_PROVIDER_NAME=Authentik -``` - -### Parameter Details - -| Parameter | Environment Variable | Description | -|-----------|---------------------|-------------| -| **Issuer URL** | `NETVISOR_OIDC_ISSUER_URL` | Your OIDC provider's issuer URL (ends in `/`) | -| **Client ID** | `NETVISOR_OIDC_CLIENT_ID` | OAuth2 client ID from your provider | -| **Client Secret** | `NETVISOR_OIDC_CLIENT_SECRET` | OAuth2 client secret from your provider | -| **Redirect URL** | `NETVISOR_OIDC_REDIRECT_URL` | URL provider redirects to after auth | -| **Provider Name** | `NETVISOR_OIDC_PROVIDER_NAME` | Display name shown in UI (e.g., "Authentik", "Keycloak") | +To get started, refer to oidc.toml.example. You can set up multiple OIDC providers by adding entries with a `[[oidc_providers]]` header and the listd fields. Create a copy of the file named oidc.toml and fill the fields for your provider(s). ### Provider Configuration @@ -263,20 +237,13 @@ https://your-netvisor-domain/api/auth/oidc/callback - Provider: OAuth2/OpenID Provider 2. **Configure Provider**: - - Redirect URI: `http://netvisor.local:60072/api/auth/oidc/callback` + - Redirect URI: `http://netvisor.local:60072/api/auth/oidc/authentik/callback` + - Note: the value you use in place of `authentik` in this url for your provider needs to match the `slug` field in oidc.toml. - Scopes: `openid email profile` - Client Type: Confidential - Copy Client ID and Client Secret -3. **Set NetVisor Environment Variables**: -```yaml -environment: - - NETVISOR_OIDC_ISSUER_URL=https://authentik.company.com/application/o/netvisor/ - - NETVISOR_OIDC_CLIENT_ID=ABC123DEF456 - - NETVISOR_OIDC_CLIENT_SECRET=xyz789uvw012 - - NETVISOR_OIDC_REDIRECT_URL=https://auth.example.com/callback - - NETVISOR_OIDC_PROVIDER_NAME=Authentik -``` +3. **Set Variables in oidc.toml** 4. **Restart server** and test login diff --git a/oidc.toml.example b/oidc.toml.example new file mode 100644 index 00000000..ddfef3fe --- /dev/null +++ b/oidc.toml.example @@ -0,0 +1,18 @@ +[[oidc_providers]] +# Display name shown in UI (e.g., "Authentik", "Keycloak") +name = "Authentik" + +# Unique lowercase slug to use in callback URLs +slug = "authentik" + +# Optional logo to display next to provider +logo = "https://cdn.jsdelivr.net/gh/homarr-labs/dashboard-icons/svg/authentik.svg" + +# OIDC provider's issuer URL (must end with `/`) +issuer_url = "YOUR_ISSUER_URL" + +# OAuth2 client ID from provider +client_id = "YOUR_CLIENT_ID" + +# OAuth2 client secret from provider +client_secret = "YOUR_CLIENT_SECRET" \ No newline at end of file diff --git a/package-lock.json b/package-lock.json index 45a0bf75..4458c8ff 100644 --- a/package-lock.json +++ b/package-lock.json @@ -5,6 +5,8 @@ "packages": { "": { "dependencies": { + "-": "^0.0.1", + "baseline-browser-mapping": "^2.8.31", "html-to-image": "^1.11.13" }, "devDependencies": { @@ -13,6 +15,12 @@ "prismjs": "^1.30.0" } }, + "node_modules/-": { + "version": "0.0.1", + "resolved": "https://registry.npmjs.org/-/-/--0.0.1.tgz", + "integrity": "sha512-3HfneK3DGAm05fpyj20sT3apkNcvPpCuccOThOPdzz8sY7GgQGe0l93XH9bt+YzibcTIgUAIMoyVJI740RtgyQ==", + "license": "UNLICENSED" + }, "node_modules/@jridgewell/gen-mapping": { "version": "0.3.13", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.13.tgz", @@ -135,6 +143,15 @@ "node": ">= 0.4" } }, + "node_modules/baseline-browser-mapping": { + "version": "2.8.31", + "resolved": "https://registry.npmjs.org/baseline-browser-mapping/-/baseline-browser-mapping-2.8.31.tgz", + "integrity": "sha512-a28v2eWrrRWPpJSzxc+mKwm0ZtVx/G8SepdQZDArnXYU/XS+IF6mp8aB/4E+hH1tyGCoDo3KlUCdlSxGDsRkAw==", + "license": "Apache-2.0", + "bin": { + "baseline-browser-mapping": "dist/cli.js" + } + }, "node_modules/clsx": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz", diff --git a/package.json b/package.json index 8a02e31f..2a5b3f15 100644 --- a/package.json +++ b/package.json @@ -1,5 +1,7 @@ { "dependencies": { + "-": "^0.0.1", + "baseline-browser-mapping": "^2.8.31", "html-to-image": "^1.11.13" }, "devDependencies": { diff --git a/ui/package-lock.json b/ui/package-lock.json index c701d761..d618886d 100644 --- a/ui/package-lock.json +++ b/ui/package-lock.json @@ -13,6 +13,7 @@ "@xyflow/svelte": "^1.2.4", "deepmerge": "^4.3.1", "elkjs": "^0.10.0", + "free-email-domains-list": "^1.0.16", "html-to-image": "^1.11.13", "ipaddr.js": "^2.2.0", "jquery": "^3.7.1", @@ -2692,6 +2693,15 @@ "url": "https://github.com/sponsors/rawify" } }, + "node_modules/free-email-domains-list": { + "version": "1.0.16", + "resolved": "https://registry.npmjs.org/free-email-domains-list/-/free-email-domains-list-1.0.16.tgz", + "integrity": "sha512-HTHSFzp1zSqRCT7HFefOsoOetHNu7s9fuLdsNb4FduGQx445dT/F5eIHwuNKkFJvIV8FrMKElWlBCFvrh/EEFw==", + "license": "MIT", + "peerDependencies": { + "validator": "^13.12.0" + } + }, "node_modules/fsevents": { "version": "2.3.3", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", @@ -4664,6 +4674,16 @@ "uuid": "dist-node/bin/uuid" } }, + "node_modules/validator": { + "version": "13.15.23", + "resolved": "https://registry.npmjs.org/validator/-/validator-13.15.23.tgz", + "integrity": "sha512-4yoz1kEWqUjzi5zsPbAS/903QXSYp0UOtHsPpp7p9rHAw/W+dkInskAE386Fat3oKRROwO98d9ZB0G4cObgUyw==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">= 0.10" + } + }, "node_modules/vite": { "version": "7.2.4", "resolved": "https://registry.npmjs.org/vite/-/vite-7.2.4.tgz", diff --git a/ui/package.json b/ui/package.json index a8543f2a..b973f268 100644 --- a/ui/package.json +++ b/ui/package.json @@ -42,6 +42,7 @@ "@xyflow/svelte": "^1.2.4", "deepmerge": "^4.3.1", "elkjs": "^0.10.0", + "free-email-domains-list": "^1.0.16", "html-to-image": "^1.11.13", "ipaddr.js": "^2.2.0", "jquery": "^3.7.1", diff --git a/ui/src/app.d.ts b/ui/src/app.d.ts index 9634bfd0..6a0e6540 100644 --- a/ui/src/app.d.ts +++ b/ui/src/app.d.ts @@ -11,4 +11,9 @@ declare global { } } +declare module 'freemail' { + export function isFree(email: string): boolean; + export function disposable(email: string): boolean; +} + export {}; diff --git a/ui/src/lib/features/api_keys/components/ApiKeyCard.svelte b/ui/src/lib/features/api_keys/components/ApiKeyCard.svelte index 8fc5404a..613c59b3 100644 --- a/ui/src/lib/features/api_keys/components/ApiKeyCard.svelte +++ b/ui/src/lib/features/api_keys/components/ApiKeyCard.svelte @@ -1,7 +1,6 @@ @@ -58,7 +87,7 @@ - +
- + + {#if hasOidcProviders} +
+
+
+
+
+ or +
+
+ +
+ {#each oidcProviders as provider (provider.slug)} + + {/each} +
+ {/if} + + {#if enableEmailOptIn} +
+ +
+ {/if} + {#if onSwitchToLogin}
@@ -117,7 +186,7 @@ Already have an account? - - - - {#if !collapsedCategories[category]} - {#each categoryFeatures as featureKey (featureKey)} - {@const featureDescription = features.getDescription(featureKey)} - {@const comingSoon = isComingSoon(featureKey)} - - -
- {features.getName(featureKey)} -
- {#if featureDescription} -
- {featureDescription} -
- {/if} - - - {#each filteredPlans as plan (plan.type)} - {@const value = getFeatureValue(plan.type, featureKey)} - - {#if comingSoon && value} - - {:else if typeof value === 'boolean'} - {#if value} - - {:else} - - {/if} - {:else if value === null} - - {:else} - {value} - {/if} - - {/each} - - {/each} - {/if} - {/each} - - -
-
-
-
-
- {#each filteredPlans as plan (plan.type)} -
- -
- {/each} +
+ {/each} +
diff --git a/ui/src/lib/features/organizations/store.ts b/ui/src/lib/features/organizations/store.ts index e1df490e..09fbdf31 100644 --- a/ui/src/lib/features/organizations/store.ts +++ b/ui/src/lib/features/organizations/store.ts @@ -44,12 +44,14 @@ export async function updateOrganization(org: Organization) { export async function createInvite( permissions: UserOrgPermissions, - network_ids: string[] + network_ids: string[], + email: string ): Promise { const request: CreateInviteRequest = { expiration_hours: null, permissions, - network_ids + network_ids, + send_to: email?.length == 0 ? null : email }; const result = await api.request( diff --git a/ui/src/lib/features/organizations/types.ts b/ui/src/lib/features/organizations/types.ts index 67064368..5c0c494d 100644 --- a/ui/src/lib/features/organizations/types.ts +++ b/ui/src/lib/features/organizations/types.ts @@ -9,19 +9,21 @@ export interface Organization { name: string; plan: BillingPlan; plan_status: string; - is_onboarded: boolean; + onboarding: string[]; } export interface CreateInviteRequest { expiration_hours: number | null; permissions: UserOrgPermissions; network_ids: string[]; + send_to: string | null; } export interface OrganizationInvite { id: string; permissions: UserOrgPermissions; url: string; + send_to: string | null; expires_at: string; created_at: string; created_by: string; diff --git a/ui/src/lib/features/users/components/InviteCard.svelte b/ui/src/lib/features/users/components/InviteCard.svelte index 11502b28..f3bd21e1 100644 --- a/ui/src/lib/features/users/components/InviteCard.svelte +++ b/ui/src/lib/features/users/components/InviteCard.svelte @@ -39,6 +39,10 @@ label: 'Created By', value: $users.find((u) => u.id == invite.created_by)?.email || 'Unknown User' }, + { + label: 'Sent To', + value: invite.send_to ? invite.send_to : 'N/A' + }, { label: 'Expires', value: formatTimestamp(invite.expires_at) diff --git a/ui/src/lib/features/users/components/InviteModal.svelte b/ui/src/lib/features/users/components/InviteModal.svelte index 12dff9c0..0f41fa31 100644 --- a/ui/src/lib/features/users/components/InviteModal.svelte +++ b/ui/src/lib/features/users/components/InviteModal.svelte @@ -1,7 +1,15 @@