diff --git a/Cargo.lock b/Cargo.lock index 15e0106..b3ad0b8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2385,7 +2385,7 @@ dependencies = [ [[package]] name = "website-screenshot" -version = "1.0.0" +version = "1.1.0" dependencies = [ "actix-governor", "actix-web", diff --git a/Cargo.toml b/Cargo.toml index 7120e4c..1d54f4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "website-screenshot" -version = "1.0.0" +version = "1.1.0" authors = ["Tomio "] license = "MIT/Apache-2.0" edition = "2021" diff --git a/src/error.rs b/src/error.rs index 14924f9..f306c08 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,6 +16,8 @@ pub enum Error { MissingAuthToken, #[display(fmt = "Invalid token provided.")] Unauthorized, + #[display(fmt = "The screenshot with that slug can't be found.")] + ScreenshotNotFound, } impl ResponseError for Error { @@ -29,6 +31,7 @@ impl ResponseError for Error { match self.deref() { Error::InvalidUrl | Error::MissingAuthToken => StatusCode::BAD_REQUEST, Error::Unauthorized => StatusCode::UNAUTHORIZED, + Error::ScreenshotNotFound => StatusCode::NOT_FOUND, } } } diff --git a/src/main.rs b/src/main.rs index c999444..95ed853 100644 --- a/src/main.rs +++ b/src/main.rs @@ -99,7 +99,7 @@ async fn main() -> anyhow::Result<()> { .wrap(Governor::new(&governor_config)) .app_data(state.clone()) .service(routes::screenshot_route) - .service(routes::get_route) + .service(routes::get_screenshot) .service(routes::index_route) }) .bind(("0.0.0.0", port))? diff --git a/src/providers/cloudinary.rs b/src/providers/cloudinary.rs index 9301b3d..da95c68 100644 --- a/src/providers/cloudinary.rs +++ b/src/providers/cloudinary.rs @@ -12,18 +12,23 @@ use serde_json::Value; use super::Provider; #[derive(Debug)] -pub struct CloudinaryProvider(Arc, Client); +pub struct CloudinaryProvider { + redis: Arc, + reqwest: Client, +} #[async_trait] impl Provider for CloudinaryProvider { fn new() -> Self { - Self( - Arc::new( - RedisClient::open(std::env::var("REDIS_URL").expect("Failed to get redis url")) - .expect("Failed to open redis client"), - ), - Client::new(), - ) + let redis = Arc::new( + RedisClient::open(std::env::var("REDIS_URL").expect("Failed to get redis url")) + .expect("Failed to open redis client"), + ); + + Self { + redis, + reqwest: Client::new(), + } } #[inline] @@ -32,15 +37,15 @@ impl Provider for CloudinaryProvider { } async fn get(&self, slug: String) -> Result> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let url: String = con.get(format!("{}:{slug}", CloudinaryProvider::prefix())).await?; - let data = self.1.get(url).send().await?.bytes().await?; + let data = self.reqwest.get(url).send().await?.bytes().await?; Ok(data.as_ref().to_vec()) } async fn set(&self, slug: String, data: Vec) -> Result<()> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let base_64_img = format!("data:image/png;base64,{}", encode(data)); let mut params: HashMap<&'static str, String> = HashMap::new(); @@ -50,7 +55,7 @@ impl Provider for CloudinaryProvider { params.insert("file", base_64_img); let res = self - .1 + .reqwest .post(format!( "https://api.cloudinary.com/v1_1/{}/image/upload", env::var("CLOUDINARY_CLOUD_NAME")? @@ -69,4 +74,22 @@ impl Provider for CloudinaryProvider { Ok(()) } + + async fn check(&self, slug: String) -> Result { + let mut con = self.redis.get_async_connection().await?; + + match con.get::(format!("{}:{slug}", CloudinaryProvider::prefix())).await { + Ok(url) => { + let req = self.reqwest.head(url).send().await?; + let status = req.status(); + + if status.is_client_error() && status.is_server_error() { + return Ok(false); + } else { + return Ok(true); + } + }, + Err(_) => Ok(false), + } + } } diff --git a/src/providers/fs.rs b/src/providers/fs.rs index 5a14716..ad20326 100644 --- a/src/providers/fs.rs +++ b/src/providers/fs.rs @@ -13,7 +13,9 @@ use tokio::io::AsyncWriteExt; use super::Provider; #[derive(Debug)] -pub struct FsProvider(Arc); +pub struct FsProvider { + redis: Arc, +} #[async_trait] impl Provider for FsProvider { @@ -29,7 +31,9 @@ impl Provider for FsProvider { .expect("Failed to open redis client"), ); - Self(redis) + Self { + redis, + } } #[inline] @@ -38,7 +42,7 @@ impl Provider for FsProvider { } async fn get(&self, slug: String) -> Result> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let path: String = con.get(format!("{}:{slug}", FsProvider::prefix())).await?; let contents = read(path).await?; @@ -49,11 +53,24 @@ impl Provider for FsProvider { let file_name = format!("{}.png", slug); let file_path = format!("screenshots/{}", file_name); let mut file = File::create(&file_path).await?; - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; file.write_all(&data).await?; con.set(format!("{}:{slug}", FsProvider::prefix()), file_path).await?; Ok(()) } + + async fn check(&self, slug: String) -> Result { + let mut con = self.redis.get_async_connection().await?; + + match con.get::(format!("{}:{slug}", FsProvider::prefix())).await { + Ok(path) => { + let path = Path::new(&path); + + Ok(path.exists() && path.is_file()) + }, + Err(_) => Ok(false), + } + } } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 4ddc80d..1bc56b9 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -9,6 +9,7 @@ pub trait Provider { async fn get(&self, slug: String) -> Result>; async fn set(&self, slug: String, data: Vec) -> Result<()>; + async fn check(&self, slug: String) -> Result; } cfg_if! { diff --git a/src/providers/s3.rs b/src/providers/s3.rs index c3d93b9..22a2531 100644 --- a/src/providers/s3.rs +++ b/src/providers/s3.rs @@ -10,7 +10,10 @@ use s3::{Bucket, Region}; use super::Provider; #[derive(Debug)] -pub struct S3Provider(Arc, Arc); +pub struct S3Provider { + redis: Arc, + bucket: Arc, +} #[async_trait] impl Provider for S3Provider { @@ -54,7 +57,10 @@ impl Provider for S3Provider { .expect("Failed to initialize s3 bucket"), ); - Self(redis, bucket) + Self { + redis, + bucket, + } } #[inline] @@ -63,20 +69,42 @@ impl Provider for S3Provider { } async fn get(&self, slug: String) -> Result> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let path: String = con.get(format!("{}:{slug}", S3Provider::prefix())).await?; - let (data, _) = self.1.get_object(path).await?; + let (data, _) = self.bucket.get_object(path).await?; Ok(data) } async fn set(&self, slug: String, data: Vec) -> Result<()> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let path = format!("{}.png", slug.clone()); - self.1.put_object(&path, &data).await?; + self.bucket.put_object(&path, &data).await?; con.set(format!("{}:{slug}", S3Provider::prefix()), path).await?; Ok(()) } + + async fn check(&self, slug: String) -> Result { + let mut con = self.redis.get_async_connection().await?; + + match con.get::(format!("{}:{slug}", S3Provider::prefix())).await { + Ok(path) => match self.bucket.head_object(path).await { + Ok((res, code)) => { + if code >= 400 || code <= 599 { + return Ok(false); + } + + if res.content_type.is_none() || res.content_length.is_none() { + return Ok(false); + } + + Ok(true) + }, + Err(_) => Ok(false), + }, + Err(_) => Ok(false), + } + } } diff --git a/src/providers/sled.rs b/src/providers/sled.rs index 95721ee..21abecf 100644 --- a/src/providers/sled.rs +++ b/src/providers/sled.rs @@ -10,7 +10,10 @@ use sled::Db; use super::Provider; #[derive(Debug)] -pub struct SledProvider(Arc, Arc); +pub struct SledProvider { + redis: Arc, + db: Arc, +} #[async_trait] impl Provider for SledProvider { @@ -25,7 +28,10 @@ impl Provider for SledProvider { .expect("Failed to open sled database"), ); - Self(redis, db) + Self { + redis, + db, + } } #[inline] @@ -34,20 +40,29 @@ impl Provider for SledProvider { } async fn get(&self, slug: String) -> Result> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let key: String = con.get(format!("{}:{slug}", SledProvider::prefix())).await?; - let data = self.1.get(key)?.expect("Failed to get data").as_ref().to_vec(); + let data = self.db.get(key)?.expect("Failed to get data").as_ref().to_vec(); Ok(data) } async fn set(&self, slug: String, data: Vec) -> Result<()> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let key = cuid()?; con.set(format!("{}:{slug}", SledProvider::prefix()), &key).await?; - self.1.insert(key, data)?; + self.db.insert(key, data)?; Ok(()) } + + async fn check(&self, slug: String) -> Result { + let mut con = self.redis.get_async_connection().await?; + + match con.get::(format!("{}:{slug}", SledProvider::prefix())).await { + Ok(key) => Ok(self.db.contains_key(key)?), + Err(_) => Ok(false), + } + } } diff --git a/src/providers/tixte.rs b/src/providers/tixte.rs index d482b26..0363359 100644 --- a/src/providers/tixte.rs +++ b/src/providers/tixte.rs @@ -13,7 +13,10 @@ use serde_json::Value; use super::Provider; #[derive(Debug)] -pub struct TixteProvider(Arc, Client); +pub struct TixteProvider { + redis: Arc, + reqwest: Client, +} #[async_trait] impl Provider for TixteProvider { @@ -23,9 +26,12 @@ impl Provider for TixteProvider { .expect("Failed to open redis client"), ); - let client = Client::new(); + let reqwest = Client::new(); - Self(redis, client) + Self { + redis, + reqwest, + } } #[inline] @@ -34,15 +40,15 @@ impl Provider for TixteProvider { } async fn get(&self, slug: String) -> Result> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let url: String = con.get(format!("{}:{slug}", TixteProvider::prefix())).await?; - let data = self.1.get(url).send().await?.bytes().await?; + let data = self.reqwest.get(url).send().await?.bytes().await?; Ok(data.as_ref().to_vec()) } async fn set(&self, slug: String, data: Vec) -> Result<()> { - let mut con = self.0.get_async_connection().await?; + let mut con = self.redis.get_async_connection().await?; let file = Part::bytes(data).mime_str("image/png")?.file_name(format!("{slug}.png")); let form = Form::new().part("file", file); let domain_conf = match &env::var("TIXTE_DOMAIN_CONFIG") @@ -62,7 +68,7 @@ impl Provider for TixteProvider { params.insert("random_name", false); let res = self - .1 + .reqwest .post("https://api.tixte.com/v1/upload") .multipart(form) .query(¶ms) @@ -88,7 +94,7 @@ impl Provider for TixteProvider { params.insert("random", true); let res = self - .1 + .reqwest .post("https://api.tixte.com/v1/upload") .multipart(form) .query(¶ms) @@ -112,8 +118,27 @@ impl Provider for TixteProvider { Ok(()) } + + async fn check(&self, slug: String) -> Result { + let mut con = self.redis.get_async_connection().await?; + + match con.get::(format!("{}:{slug}", TixteProvider::prefix())).await { + Ok(url) => { + let req = self.reqwest.head(url).send().await?; + let status = req.status(); + + if status.is_client_error() && status.is_server_error() { + return Ok(false); + } else { + return Ok(true); + } + }, + Err(_) => Ok(false), + } + } } +#[derive(Debug)] enum DomainConfig { Standard(String), Random, diff --git a/src/routes/get.rs b/src/routes/get.rs index 0551a9b..e1b95da 100644 --- a/src/routes/get.rs +++ b/src/routes/get.rs @@ -1,22 +1,24 @@ use actix_web::http::header; use actix_web::{get, web, HttpResponse}; -use serde_json::json; +use crate::error::Error; use crate::providers::Provider; use crate::{Result, State}; #[get("/s/{slug}")] -pub async fn get(data: web::Data, slug: web::Path) -> Result { - let screenshot = data.storage.get(slug.into_inner()).await; - - match screenshot { - Ok(screenshot) => Ok(HttpResponse::Ok() - .content_type("image/png") - .append_header(header::CacheControl(vec![header::CacheDirective::MaxAge(31536000)])) - .body(screenshot)), - Err(_) => Ok(HttpResponse::NotFound().json(json!({ - "error": 404, - "message": "The screenshot could not be found." - }))), +pub async fn get_screenshot( + data: web::Data, + slug: web::Path, +) -> Result { + if let false = data.storage.check(slug.clone()).await.expect("Failed checking slug") { + return Err(Error::ScreenshotNotFound); } + + // Safe to unwrap now + let screenshot = data.storage.get(slug.into_inner()).await.unwrap(); + + Ok(HttpResponse::Ok() + .content_type("image/png") + .append_header(header::CacheControl(vec![header::CacheDirective::MaxAge(31536000)])) + .body(screenshot)) } diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 32fe81a..24c3f02 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -2,6 +2,6 @@ mod get; mod index; mod screenshot; -pub use get::get as get_route; +pub use get::get_screenshot; pub use index::index as index_route; pub use screenshot::screenshot as screenshot_route;