Skip to content

Commit

Permalink
Merge pull request #28 from nikoshet/chore/add-flags-to-allow-invalid…
Browse files Browse the repository at this point in the history
…-certs-on-db-connections

Add flags to allow invalid certs on db connection
  • Loading branch information
nikoshet committed Jun 11, 2024
2 parents 2d50589 + a6457cb commit e881de2
Show file tree
Hide file tree
Showing 10 changed files with 213 additions and 491 deletions.
582 changes: 104 additions & 478 deletions Cargo.lock

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dms-cdc-operator"
version = "0.1.12"
version = "0.1.13"
edition = "2021"
license = "MIT"
description = "The dms-cdc-operator is a Rust-based utility for comparing the state of a list of tables in an Amazon RDS database with data stored in Parquet files on Amazon S3, particularly useful for change data capture (CDC) scenarios"
Expand All @@ -20,7 +20,7 @@ tokio = { version = "1", features = ["full"] }
anyhow = "1.0"
log = "0.4.21"
colored = "2.1.0"
polars = { version = "0.39.2", features = [
polars = { version = "0.40.0", features = [
"timezones",
"json",
"lazy",
Expand All @@ -32,9 +32,9 @@ polars = { version = "0.39.2", features = [
] }
chrono = "0.4.37"
async-trait = "0.1.79"
rust-pgdatadiff = "0.1.4"
rust-pgdatadiff = "0.1.6"
indexmap = { version = "2.2.6", features = ["serde"] }
polars-core = "0.39.2"
polars-core = "0.40.0"
rust_decimal = "1.35.0"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
Expand All @@ -44,6 +44,8 @@ clap = "4.5.4"
mockall = "0.12.1"
cargo-nextest = "0.9.72"
dms-cdc-operator = { path = ".", version = "0.1.12" }
native-tls = "0.2.12"
postgres-native-tls = "0.5.0"

[dependencies]
indexmap.workspace = true
Expand All @@ -63,6 +65,8 @@ deadpool-postgres.workspace = true
futures.workspace = true
clap.workspace = true
tracing-subscriber.workspace = true
postgres-native-tls.workspace = true
native-tls.workspace = true

[dev-dependencies]
mockall.workspace = true
Expand Down
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ In order to use the tool as a client, you can use `cargo`.
The tool provides two features for running it, which are `Inquire` and `Clap`.

### Using Clap
```shell
```
Usage: dms-cdc-operator-client validate [OPTIONS] --bucket-name <BUCKET_NAME> --s3-prefix <S3_PREFIX> --source-postgres-url <SOURCE_POSTGRES_URL> --target-postgres-url <TARGET_POSTGRES_URL>
Options:
Expand Down Expand Up @@ -88,6 +88,10 @@ Options:
Run only the datadiff
--only-snapshot
Take only a snapshot from S3 to target DB
--accept-invalid-certs-first-db
Accept invalid TLS certificates for the first database
--accept-invalid-certs-second-db
Accept invalid TLS certificates for the second database
-h, --help
Print help
-V, --version
Expand Down
2 changes: 1 addition & 1 deletion dms-cdc-operator-client/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dms-cdc-operator-client"
version = "0.1.12"
version = "0.1.13"
edition = "2021"
license = "MIT"
description = "The dms-cdc-operator-client is a Rust-based client for comparing the state of a list of tables in an Amazon RDS database with data stored in Parquet files on Amazon S3, particularly useful for change data capture (CDC) scenarios"
Expand Down
34 changes: 32 additions & 2 deletions dms-cdc-operator-client/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ enum Commands {
conflicts_with("only_datadiff")
)]
only_snapshot: bool,
/// Accept invalid TLS certificates for the first database
#[arg(long, default_value_t = false, required = false)]
accept_invalid_certs_first_db: bool,
/// Accept invalid TLS certificates for the second database
#[arg(long, default_value_t = false, required = false)]
accept_invalid_certs_second_db: bool,
},
}

Expand All @@ -120,6 +126,8 @@ fn main_clap() -> Result<CDCOperatorPayload> {
start_position,
only_datadiff,
only_snapshot,
accept_invalid_certs_first_db,
accept_invalid_certs_second_db,
} => {
let payload = CDCOperatorPayload::new(
bucket_name,
Expand All @@ -137,6 +145,8 @@ fn main_clap() -> Result<CDCOperatorPayload> {
start_position,
only_datadiff,
only_snapshot,
accept_invalid_certs_first_db,
accept_invalid_certs_second_db,
);

Ok(payload)
Expand Down Expand Up @@ -231,6 +241,18 @@ fn main_inquire() -> Result<CDCOperatorPayload> {
.with_help_message("Take only a snapshot from S3 to target DB (no data comparison)")
.prompt()?;

let accept_invalid_certs_first_db =
Confirm::new("Accept invalid TLS certificates for the first database")
.with_default(false)
.with_help_message("Accept invalid TLS certificates for the first database")
.prompt()?;

let accept_invalid_certs_second_db =
Confirm::new("Accept invalid TLS certificates for the second database")
.with_default(false)
.with_help_message("Accept invalid TLS certificates for the second database")
.prompt()?;

let payload = CDCOperatorPayload::new(
bucket_name,
s3_prefix,
Expand All @@ -255,6 +277,8 @@ fn main_inquire() -> Result<CDCOperatorPayload> {
start_position.parse::<i64>().unwrap(),
only_datadiff,
only_snapshot,
accept_invalid_certs_first_db,
accept_invalid_certs_second_db,
);

Ok(payload)
Expand Down Expand Up @@ -282,7 +306,9 @@ async fn main() -> Result<()> {
cdc_operator_payload.database_name(),
cdc_operator_payload.max_connections(),
);
let pg_pool = db_client.connect_to_postgres().await;
let pg_pool = db_client
.connect_to_postgres(cdc_operator_payload.accept_invalid_certs_first_db())
.await;
// Create a PostgresOperatorImpl instance
let postgres_operator = PostgresOperatorImpl::new(pg_pool);

Expand All @@ -292,7 +318,9 @@ async fn main() -> Result<()> {
"public",
cdc_operator_payload.max_connections(),
);
let target_pg_pool = target_db_client.connect_to_postgres().await;
let target_pg_pool = target_db_client
.connect_to_postgres(cdc_operator_payload.accept_invalid_certs_second_db())
.await;
// Create a PostgresOperatorImpl instance for the target database
let target_postgres_operator = PostgresOperatorImpl::new(target_pg_pool);

Expand Down Expand Up @@ -338,6 +366,8 @@ async fn main() -> Result<()> {
cdc_operator_payload.schema_name(),
cdc_operator_payload.chunk_size(),
cdc_operator_payload.start_position(),
cdc_operator_payload.accept_invalid_certs_first_db(),
cdc_operator_payload.accept_invalid_certs_second_db(),
);

let _ = CDCOperator::validate(cdc_operator_validate_payload).await;
Expand Down
2 changes: 2 additions & 0 deletions src/cdc/cdc_operator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ impl CDCOperator {
cdc_operator_validate_payload.included_tables().to_vec(),
cdc_operator_validate_payload.excluded_tables().to_vec(),
cdc_operator_validate_payload.schema_name(),
cdc_operator_validate_payload.accept_invalid_certs_first_db(),
cdc_operator_validate_payload.accept_invalid_certs_second_db(),
);
let diff_result = Differ::diff_dbs(payload).await;
if diff_result.is_err() {
Expand Down
20 changes: 20 additions & 0 deletions src/cdc/cdc_operator_payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ pub struct CDCOperatorPayload {
start_position: i64,
only_datadiff: bool,
only_snapshot: bool,
accept_invalid_certs_first_db: bool,
accept_invalid_certs_second_db: bool,
}

impl CDCOperatorPayload {
Expand All @@ -39,6 +41,8 @@ impl CDCOperatorPayload {
/// * `start_position` - The start position for pgdatadiff validation.
/// * `only_datadiff` - Whether to only validate the data difference.
/// * `only_snapshot` - Whether to only take a snapshot and skip validation.
/// * `accept_invalid_certs_first_db` - Whether to accept invalid certificates for the first database.
/// * `accept_invalid_certs_second_db` - Whether to accept invalid certificates for the second database.
///
/// # Returns
///
Expand All @@ -60,6 +64,8 @@ impl CDCOperatorPayload {
start_position: i64,
only_datadiff: bool,
only_snapshot: bool,
accept_invalid_certs_first_db: bool,
accept_invalid_certs_second_db: bool,
) -> Self {
if only_datadiff && only_snapshot {
panic!("Cannot run both only_datadiff and only_snapshot at the same time");
Expand All @@ -81,6 +87,8 @@ impl CDCOperatorPayload {
start_position,
only_datadiff,
only_snapshot,
accept_invalid_certs_first_db,
accept_invalid_certs_second_db,
}
}

Expand Down Expand Up @@ -151,6 +159,14 @@ impl CDCOperatorPayload {
pub fn only_snapshot(&self) -> bool {
self.only_snapshot
}

pub fn accept_invalid_certs_first_db(&self) -> bool {
self.accept_invalid_certs_first_db
}

pub fn accept_invalid_certs_second_db(&self) -> bool {
self.accept_invalid_certs_second_db
}
}

#[cfg(test)]
Expand All @@ -175,6 +191,8 @@ mod tests {
let start_position = 0;
let only_datadiff = true;
let only_snapshot = true;
let accept_invalid_certs_first_db = false;
let accept_invalid_certs_second_db = false;

let _validator = CDCOperatorPayload::new(
bucket_name,
Expand All @@ -192,6 +210,8 @@ mod tests {
start_position,
only_datadiff,
only_snapshot,
accept_invalid_certs_first_db,
accept_invalid_certs_second_db,
);
}
}
15 changes: 15 additions & 0 deletions src/cdc/validate_payload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ pub struct CDCOperatorValidatePayload {
pub schema_name: String,
pub chunk_size: i64,
pub start_position: i64,
pub accept_invalid_certs_first_db: bool,
pub accept_invalid_certs_second_db: bool,
}

impl CDCOperatorValidatePayload {
#[allow(clippy::too_many_arguments)]
pub fn new(
source_postgres_url: impl Into<String>,
target_postgres_url: impl Into<String>,
Expand All @@ -17,6 +20,8 @@ impl CDCOperatorValidatePayload {
schema_name: impl Into<String>,
chunk_size: i64,
start_position: i64,
accept_invalid_certs_first_db: bool,
accept_invalid_certs_second_db: bool,
) -> Self {
CDCOperatorValidatePayload {
source_postgres_url: source_postgres_url.into(),
Expand All @@ -26,6 +31,8 @@ impl CDCOperatorValidatePayload {
schema_name: schema_name.into(),
chunk_size,
start_position,
accept_invalid_certs_first_db,
accept_invalid_certs_second_db,
}
}

Expand Down Expand Up @@ -56,4 +63,12 @@ impl CDCOperatorValidatePayload {
pub fn start_position(&self) -> i64 {
self.start_position
}

pub fn accept_invalid_certs_first_db(&self) -> bool {
self.accept_invalid_certs_first_db
}

pub fn accept_invalid_certs_second_db(&self) -> bool {
self.accept_invalid_certs_second_db
}
}
6 changes: 3 additions & 3 deletions src/dataframe/dataframe_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub trait DataframeOperator {
async fn create_dataframe_from_parquet_file(
&self,
payload: &CreateDataframePayload,
) -> Result<Option<DataFrame>>;
) -> Result<Option<polars::prelude::DataFrame>>;
}

pub struct DataframeOperatorImpl<'a> {
Expand All @@ -48,7 +48,7 @@ impl DataframeOperator for DataframeOperatorImpl<'_> {
async fn create_dataframe_from_parquet_file(
&self,
payload: &CreateDataframePayload,
) -> Result<Option<DataFrame>> {
) -> Result<Option<polars::prelude::DataFrame>> {
// If we used LazyFrame, we would have an issue with tokio, since we should have to block on the tokio runtime untill the
// result is ready with .collect(). To avoid this, we use the ParquetReader, which is a synchronous reader.
// For LazyFrame, we would have to use the following code:
Expand Down Expand Up @@ -89,7 +89,7 @@ impl DataframeOperator for DataframeOperatorImpl<'_> {

#[cfg(test)]
mod tests {
use polars::frame::DataFrame;
use polars::prelude::DataFrame;

use crate::dataframe::dataframe_ops::{
CreateDataframePayload, DataframeOperator, MockDataframeOperator,
Expand Down
25 changes: 23 additions & 2 deletions src/postgres/postgres_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,34 @@ impl PostgresConfig {
/// # Returns
///
/// A connection pool to the Postgres database.
pub async fn connect_to_postgres(&self) -> Pool {
pub async fn connect_to_postgres(&self, accept_invalid_certs: bool) -> Pool {
let connection_string = self.postgres_url.to_string();
let max_connections: usize = self.max_connections as usize;
let mut cfg = Config::new();
cfg.url = Some(connection_string);
cfg.pool = Some(deadpool_postgres::PoolConfig::new(max_connections));
cfg.create_pool(Some(Runtime::Tokio1), NoTls).unwrap()

let tls_connector = if accept_invalid_certs {
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;

let tls_connector = TlsConnector::builder()
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true)
.build()
.unwrap();

Some(MakeTlsConnector::new(tls_connector))
} else {
None
};

if accept_invalid_certs {
cfg.create_pool(Some(Runtime::Tokio1), tls_connector.clone().unwrap())
.unwrap()
} else {
cfg.create_pool(Some(Runtime::Tokio1), NoTls).unwrap()
}
}

/// Returns the connection string.
Expand Down

0 comments on commit e881de2

Please sign in to comment.