Skip to content

Commit

Permalink
feat: add more checks to prevent mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
Tomio committed May 5, 2022
1 parent 7d1e165 commit 1a11545
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 53 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "website-screenshot"
version = "1.0.0"
version = "1.1.0"
authors = ["Tomio <mail@tomio.fun>"]
license = "MIT/Apache-2.0"
edition = "2021"
Expand Down
3 changes: 3 additions & 0 deletions src/error.rs
Expand Up @@ -16,6 +16,8 @@ pub enum Error {
MissingAuthToken,
#[display(fmt = "Invalid token provided.")]
Unauthorized,
#[display(fmt = "The screenshot with that slug can't be found.")]
ScreenshotNotFound,
}

impl ResponseError for Error {
Expand All @@ -29,6 +31,7 @@ impl ResponseError for Error {
match self.deref() {
Error::InvalidUrl | Error::MissingAuthToken => StatusCode::BAD_REQUEST,
Error::Unauthorized => StatusCode::UNAUTHORIZED,
Error::ScreenshotNotFound => StatusCode::NOT_FOUND,
}
}
}
2 changes: 1 addition & 1 deletion src/main.rs
Expand Up @@ -99,7 +99,7 @@ async fn main() -> anyhow::Result<()> {
.wrap(Governor::new(&governor_config))
.app_data(state.clone())
.service(routes::screenshot_route)
.service(routes::get_route)
.service(routes::get_screenshot)
.service(routes::index_route)
})
.bind(("0.0.0.0", port))?
Expand Down
47 changes: 35 additions & 12 deletions src/providers/cloudinary.rs
Expand Up @@ -12,18 +12,23 @@ use serde_json::Value;
use super::Provider;

#[derive(Debug)]
pub struct CloudinaryProvider(Arc<RedisClient>, Client);
pub struct CloudinaryProvider {
redis: Arc<RedisClient>,
reqwest: Client,
}

#[async_trait]
impl Provider for CloudinaryProvider {
fn new() -> Self {
Self(
Arc::new(
RedisClient::open(std::env::var("REDIS_URL").expect("Failed to get redis url"))
.expect("Failed to open redis client"),
),
Client::new(),
)
let redis = Arc::new(
RedisClient::open(std::env::var("REDIS_URL").expect("Failed to get redis url"))
.expect("Failed to open redis client"),
);

Self {
redis,
reqwest: Client::new(),
}
}

#[inline]
Expand All @@ -32,15 +37,15 @@ impl Provider for CloudinaryProvider {
}

async fn get(&self, slug: String) -> Result<Vec<u8>> {
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;
let url: String = con.get(format!("{}:{slug}", CloudinaryProvider::prefix())).await?;
let data = self.1.get(url).send().await?.bytes().await?;
let data = self.reqwest.get(url).send().await?.bytes().await?;

Ok(data.as_ref().to_vec())
}

async fn set(&self, slug: String, data: Vec<u8>) -> Result<()> {
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;
let base_64_img = format!("data:image/png;base64,{}", encode(data));
let mut params: HashMap<&'static str, String> = HashMap::new();

Expand All @@ -50,7 +55,7 @@ impl Provider for CloudinaryProvider {
params.insert("file", base_64_img);

let res = self
.1
.reqwest
.post(format!(
"https://api.cloudinary.com/v1_1/{}/image/upload",
env::var("CLOUDINARY_CLOUD_NAME")?
Expand All @@ -69,4 +74,22 @@ impl Provider for CloudinaryProvider {

Ok(())
}

async fn check(&self, slug: String) -> Result<bool> {
let mut con = self.redis.get_async_connection().await?;

match con.get::<String, String>(format!("{}:{slug}", CloudinaryProvider::prefix())).await {
Ok(url) => {
let req = self.reqwest.head(url).send().await?;
let status = req.status();

if status.is_client_error() && status.is_server_error() {
return Ok(false);
} else {
return Ok(true);
}
},
Err(_) => Ok(false),
}
}
}
25 changes: 21 additions & 4 deletions src/providers/fs.rs
Expand Up @@ -13,7 +13,9 @@ use tokio::io::AsyncWriteExt;
use super::Provider;

#[derive(Debug)]
pub struct FsProvider(Arc<Client>);
pub struct FsProvider {
redis: Arc<Client>,
}

#[async_trait]
impl Provider for FsProvider {
Expand All @@ -29,7 +31,9 @@ impl Provider for FsProvider {
.expect("Failed to open redis client"),
);

Self(redis)
Self {
redis,
}
}

#[inline]
Expand All @@ -38,7 +42,7 @@ impl Provider for FsProvider {
}

async fn get(&self, slug: String) -> Result<Vec<u8>> {
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;
let path: String = con.get(format!("{}:{slug}", FsProvider::prefix())).await?;
let contents = read(path).await?;

Expand All @@ -49,11 +53,24 @@ impl Provider for FsProvider {
let file_name = format!("{}.png", slug);
let file_path = format!("screenshots/{}", file_name);
let mut file = File::create(&file_path).await?;
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;

file.write_all(&data).await?;
con.set(format!("{}:{slug}", FsProvider::prefix()), file_path).await?;

Ok(())
}

async fn check(&self, slug: String) -> Result<bool> {
let mut con = self.redis.get_async_connection().await?;

match con.get::<String, String>(format!("{}:{slug}", FsProvider::prefix())).await {
Ok(path) => {
let path = Path::new(&path);

Ok(path.exists() && path.is_file())
},
Err(_) => Ok(false),
}
}
}
1 change: 1 addition & 0 deletions src/providers/mod.rs
Expand Up @@ -9,6 +9,7 @@ pub trait Provider {

async fn get(&self, slug: String) -> Result<Vec<u8>>;
async fn set(&self, slug: String, data: Vec<u8>) -> Result<()>;
async fn check(&self, slug: String) -> Result<bool>;
}

cfg_if! {
Expand Down
40 changes: 34 additions & 6 deletions src/providers/s3.rs
Expand Up @@ -10,7 +10,10 @@ use s3::{Bucket, Region};
use super::Provider;

#[derive(Debug)]
pub struct S3Provider(Arc<Client>, Arc<Bucket>);
pub struct S3Provider {
redis: Arc<Client>,
bucket: Arc<Bucket>,
}

#[async_trait]
impl Provider for S3Provider {
Expand Down Expand Up @@ -54,7 +57,10 @@ impl Provider for S3Provider {
.expect("Failed to initialize s3 bucket"),
);

Self(redis, bucket)
Self {
redis,
bucket,
}
}

#[inline]
Expand All @@ -63,20 +69,42 @@ impl Provider for S3Provider {
}

async fn get(&self, slug: String) -> Result<Vec<u8>> {
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;
let path: String = con.get(format!("{}:{slug}", S3Provider::prefix())).await?;
let (data, _) = self.1.get_object(path).await?;
let (data, _) = self.bucket.get_object(path).await?;

Ok(data)
}

async fn set(&self, slug: String, data: Vec<u8>) -> Result<()> {
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;
let path = format!("{}.png", slug.clone());

self.1.put_object(&path, &data).await?;
self.bucket.put_object(&path, &data).await?;
con.set(format!("{}:{slug}", S3Provider::prefix()), path).await?;

Ok(())
}

async fn check(&self, slug: String) -> Result<bool> {
let mut con = self.redis.get_async_connection().await?;

match con.get::<String, String>(format!("{}:{slug}", S3Provider::prefix())).await {
Ok(path) => match self.bucket.head_object(path).await {
Ok((res, code)) => {
if code >= 400 || code <= 599 {
return Ok(false);
}

if res.content_type.is_none() || res.content_length.is_none() {
return Ok(false);
}

Ok(true)
},
Err(_) => Ok(false),
},
Err(_) => Ok(false),
}
}
}
27 changes: 21 additions & 6 deletions src/providers/sled.rs
Expand Up @@ -10,7 +10,10 @@ use sled::Db;
use super::Provider;

#[derive(Debug)]
pub struct SledProvider(Arc<Client>, Arc<Db>);
pub struct SledProvider {
redis: Arc<Client>,
db: Arc<Db>,
}

#[async_trait]
impl Provider for SledProvider {
Expand All @@ -25,7 +28,10 @@ impl Provider for SledProvider {
.expect("Failed to open sled database"),
);

Self(redis, db)
Self {
redis,
db,
}
}

#[inline]
Expand All @@ -34,20 +40,29 @@ impl Provider for SledProvider {
}

async fn get(&self, slug: String) -> Result<Vec<u8>> {
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;
let key: String = con.get(format!("{}:{slug}", SledProvider::prefix())).await?;
let data = self.1.get(key)?.expect("Failed to get data").as_ref().to_vec();
let data = self.db.get(key)?.expect("Failed to get data").as_ref().to_vec();

Ok(data)
}

async fn set(&self, slug: String, data: Vec<u8>) -> Result<()> {
let mut con = self.0.get_async_connection().await?;
let mut con = self.redis.get_async_connection().await?;
let key = cuid()?;

con.set(format!("{}:{slug}", SledProvider::prefix()), &key).await?;
self.1.insert(key, data)?;
self.db.insert(key, data)?;

Ok(())
}

async fn check(&self, slug: String) -> Result<bool> {
let mut con = self.redis.get_async_connection().await?;

match con.get::<String, String>(format!("{}:{slug}", SledProvider::prefix())).await {
Ok(key) => Ok(self.db.contains_key(key)?),
Err(_) => Ok(false),
}
}
}

0 comments on commit 1a11545

Please sign in to comment.