29 changes: 26 additions & 3 deletions src/aws.rs
@@ -1,3 +1,7 @@
//! AWS S3 upload module
//!
//! This was adapted from the
//! [official examples](https://github.com/awslabs/aws-sdk-rust/blob/main/examples/s3/README.md)
use std::{env, fs::File, path::PathBuf};

use anyhow::{anyhow, Result};
@@ -16,14 +20,21 @@ use crate::config::Params;
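// S3 multipart uploads require parts of at least 5 MiB (except the last part)
// and allow at most 10,000 parts per upload, hence these two constants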
const CHUNK_SIZE: u64 = 1024 * 1024 * 5;
const MAX_CHUNKS: u64 = 10000;

/// Upload a file to AWS S3 using multipart.
///
/// The `_temp_dir` parameter is not used directly, but it must be kept alive until the upload
/// completes: when it goes out of scope, the temp folder is deleted.
pub async fn upload_file(archive_path: PathBuf, _temp_dir: TempDir, params: &Params) -> Result<()> {
// we want to use `from_env` below, so make sure that environment variables are set properly, even if data comes
// from the command line args
env::set_var("AWS_ACCESS_KEY_ID", &params.aws_key_id);
env::set_var("AWS_SECRET_ACCESS_KEY", &params.aws_key);
let shared_config = aws_config::from_env()
.region(params.aws_region.region().await) // set the region
.load()
.await;
let client = Client::new(&shared_config);
// if the desired filename was specified, append the file extension in case it was not already provided
let filename = params
.filename
.clone()
@@ -32,6 +43,7 @@ pub async fn upload_file(archive_path: PathBuf, _temp_dir: TempDir, params: &Par
f => f,
})
.unwrap_or_else(|| {
// default filename is awsbck_ + the folder name + .tar.gz
format!(
"awsbck_{}.tar.gz",
params
@@ -49,37 +61,44 @@ pub async fn upload_file(archive_path: PathBuf, _temp_dir: TempDir, params: &Par
.await?;
let upload_id = multipart_upload_res
.upload_id()
.ok_or_else(|| anyhow!("upload_id not found"))?; // convert option to error if None
let file_size = get_file_size(&archive_path)?;
let mut chunk_count = (file_size / CHUNK_SIZE) + 1;
let mut size_of_last_chunk = file_size % CHUNK_SIZE;
// if the file size is exactly a multiple of CHUNK_SIZE, we don't need the extra chunk
if size_of_last_chunk == 0 {
size_of_last_chunk = CHUNK_SIZE;
chunk_count -= 1;
}

// something went very wrong if we get a size of zero here
if file_size == 0 {
return Err(anyhow!("file size is 0"));
}
// AWS will not accept an upload with too many chunks
if chunk_count > MAX_CHUNKS {
return Err(anyhow!("too many chunks, try increasing the chunk size"));
}

let mut upload_parts: Vec<CompletedPart> = Vec::new();

// upload all chunks
for chunk_index in 0..chunk_count {
let this_chunk = match chunk_index {
i if i == chunk_count - 1 => size_of_last_chunk,
_ => CHUNK_SIZE,
};
// take the relevant part of the file corresponding to this chunk
let stream = ByteStream::read_from()
.path(&archive_path)
.offset(chunk_index * CHUNK_SIZE)
.length(Length::Exact(this_chunk))
.build()
.await?;

// part numbers would naturally be unsigned, but the API expects an i32 that starts at 1
let part_number = (chunk_index as i32) + 1;

// send chunk and record the ETag and part number
let upload_part_res = client
.upload_part()
.key(&filename)
@@ -89,13 +108,16 @@ pub async fn upload_file(archive_path: PathBuf, _temp_dir: TempDir, params: &Par
.part_number(part_number)
.send()
.await?;

// this vec of chunks is required to finalize the upload
upload_parts.push(
CompletedPart::builder()
.e_tag(upload_part_res.e_tag.unwrap_or_default())
.part_number(part_number)
.build(),
);
}
// complete the upload
let completed_multipart_upload: CompletedMultipartUpload = CompletedMultipartUpload::builder()
.set_parts(Some(upload_parts))
.build();
@@ -110,6 +132,7 @@ pub async fn upload_file(archive_path: PathBuf, _temp_dir: TempDir, params: &Par
Ok(())
}

/// Utility function to get the file size
fn get_file_size(archive_path: &PathBuf) -> Result<u64> {
let file = File::open(archive_path)?;
let metadata = file.metadata()?;
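
As a quick illustration of the chunking arithmetic in `upload_file`, here is a minimal standalone sketch (a hypothetical helper, not part of this PR) that mirrors the same computation:

const CHUNK_SIZE: u64 = 1024 * 1024 * 5;

/// Mirrors the chunk arithmetic above: returns (chunk_count, size_of_last_chunk).
fn chunk_layout(file_size: u64) -> (u64, u64) {
    let mut chunk_count = (file_size / CHUNK_SIZE) + 1;
    let mut size_of_last_chunk = file_size % CHUNK_SIZE;
    // an exact multiple of CHUNK_SIZE needs no extra tail chunk
    if size_of_last_chunk == 0 {
        size_of_last_chunk = CHUNK_SIZE;
        chunk_count -= 1;
    }
    (chunk_count, size_of_last_chunk)
}

fn main() {
    // 12 MiB -> two full 5 MiB chunks plus a 2 MiB tail
    assert_eq!(chunk_layout(12 * 1024 * 1024), (3, 2 * 1024 * 1024));
    // exactly 10 MiB -> two full chunks and no tail chunk
    assert_eq!(chunk_layout(10 * 1024 * 1024), (2, CHUNK_SIZE));
}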
9 changes: 9 additions & 0 deletions src/backup.rs
@@ -19,15 +19,24 @@ pub async fn backup(params: &Params) -> Result<()> {
Ok(())
}

/// Compress the folder into a randomly named tar.gz archive in a temp directory
fn compress_folder(folder: &Path) -> Result<(PathBuf, TempDir)> {
// create a temp directory; it is deleted when this handle goes out of scope
let dir = TempDir::new()?;
// generate a random filename
let filename = format!("{}.tar.gz", Uuid::new_v4());
let file_path = dir.child(filename);
// create the file handle
let tar_gz: File = File::create(&file_path)?;
let enc = GzEncoder::new(tar_gz, Compression::default());
let mut tar = tar::Builder::new(enc);
// insert the contents of folder into the archive, recursively, at the root of the archive
// note that the folder itself is not present in the archive, only its contents
tar.append_dir_all(".", folder)?;
// make sure the tar layer data is written
let res = tar.into_inner()?;
// make sure the gz layer data is written
res.finish()?;
// we return the temp dir reference to keep it around until the file is uploaded
Ok((file_path, dir))
}
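
As a quick sanity check that `compress_folder` stores the folder's contents at the archive root (and not the folder itself), a hypothetical helper along these lines could list the entries, using the same `tar` and `flate2` crates:

use std::fs::File;
use flate2::read::GzDecoder;

/// List the paths stored in a .tar.gz archive (sketch, not part of this PR).
fn list_archive(path: &std::path::Path) -> std::io::Result<()> {
    let tar_gz = File::open(path)?;
    let mut archive = tar::Archive::new(GzDecoder::new(tar_gz));
    for entry in archive.entries()? {
        // entries are relative to the archive root, e.g. "./subdir/file.txt"
        println!("{}", entry?.path()?.display());
    }
    Ok(())
}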
14 changes: 10 additions & 4 deletions src/config.rs
Expand Up @@ -8,6 +8,9 @@ use log::*;

use crate::prelude::*;

/// The CLI parser, built with `clap`.
///
/// Command-line arguments take precedence over environment variables.
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Cli {
@@ -57,15 +60,15 @@
aws_key: Option<String>,
}

/// Runtime parameters, parsed, validated and ready to be used
pub struct Params {
/// Which folder to backup
pub folder: PathBuf,
/// An optional interval duration in seconds
pub interval: Option<u64>,
/// The optional name of the archive that will be uploaded to S3 (without extension)
pub filename: Option<String>,
/// The AWS S3 region, defaults to us-east-1
pub aws_region: RegionProviderChain,
/// The AWS S3 bucket name
pub aws_bucket: String,
@@ -77,18 +80,21 @@ pub struct Params {

/// Parse the command-line arguments and environment variables into runtime params
pub async fn parse_config() -> Result<Params> {
// Read from the command-line args, and if not present, check environment variables
let params = Cli::parse();

// ensure we have all the required parameters
let Some(folder) = params.folder else {
return Err(anyhow!("No folder path was provided"));
};
// make sure folder exists
let folder = folder
.canonicalize()
.with_context(|| anyhow!("Could not resolve path {}", folder.to_string_lossy()))?;
if !folder.is_dir() {
return Err(anyhow!("'{}' is not a folder", folder.to_string_lossy()));
}

// all AWS parameters are required
let aws_region = RegionProviderChain::first_try(params.aws_region.map(Region::new))
.or_default_provider()
.or_else(Region::new("us-east-1"));
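
For reference, the region chain above resolves in order: the explicit CLI/env value, then the SDK's default provider chain, then a hard-coded us-east-1. A minimal standalone sketch of the same pattern (hypothetical example; the `Region` import path varies across aws-sdk versions):

use aws_config::meta::region::RegionProviderChain;
use aws_sdk_s3::config::Region;

#[tokio::main]
async fn main() {
    // e.g. a value parsed from the command line; None here to exercise the fallbacks
    let explicit: Option<String> = None;
    let chain = RegionProviderChain::first_try(explicit.map(Region::new))
        .or_default_provider()
        .or_else(Region::new("us-east-1"));
    // with no explicit value and no ambient AWS config, this prints Some(Region("us-east-1"))
    println!("{:?}", chain.region().await);
}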