Skip to content

Commit

Permalink
refactor: cleanup, vt client takes string ownership
Browse files Browse the repository at this point in the history
Signed-off-by: Richard Zak <richard.j.zak@gmail.com>
  • Loading branch information
rjzak committed May 3, 2024
1 parent 9ef5674 commit ffd78c7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 20 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ anyhow = { version = "1.0", features = ["std"] }
chrono = { version = "0.4", features = ["clock", "serde"], default-features = false }
clap = { version = "4.5", features = ["derive", "env", "help", "std", "usage"], default-features = false }
hex = { version = "0.4.3", features = ["alloc", "std"], default-features = false }
lazy_static = { version = "1.4.0" }
lazy_static = { version = "1.4.0", default-features = false }
reqwest = { version = "0.12.4", features = ["http2", "multipart", "rustls-tls"], default-features = false }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0", features = ["alloc"], default-features = false }
sha2 = { version = "0.10.8", features = ["std"], default-features = false }
tokio = { version = "1.37", features = ["rt", "macros"], default-features = false }
zeroize = { version = "1.7.0", features = ["alloc", "derive"], default-features = false }
zeroize = { version = "1.7.0", features = ["derive"], default-features = false }

[dev-dependencies]
rstest = { version = "0.19", default-features = false }
37 changes: 19 additions & 18 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ use std::string::FromUtf8Error;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::multipart::Form;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
use zeroize::Zeroize;

/// Capture the error from VirusTotal, plus parsing or networking errors along the way
#[derive(Clone, Debug, Eq, Serialize, Deserialize)]
pub struct VirusTotalError {
/// Message describing the error
pub message: String,

/// Short version of the error
pub code: String,
}

Expand All @@ -39,13 +42,13 @@ impl std::error::Error for VirusTotalError {}
impl From<reqwest::Error> for VirusTotalError {
fn from(err: reqwest::Error) -> Self {
let url = if let Some(url) = err.url() {
format!(" loading {}", url.as_str())
format!(" loading {url}")
} else {
"".into()
};
Self {
message: "Http error".into(),
code: format!("Error {url} {}", err),
code: format!("Error{url} {err}"),
}
}
}
Expand All @@ -54,7 +57,7 @@ impl From<serde_json::Error> for VirusTotalError {
fn from(err: serde_json::Error) -> Self {
Self {
message: "Json error".into(),
code: format!("Json error at line {}: {}", err.line(), err),
code: format!("Json error at line {}: {err}", err.line()),
}
}
}
Expand All @@ -69,21 +72,19 @@ impl From<FromUtf8Error> for VirusTotalError {
}

/// VirusTotal client object
#[derive(Clone)]
#[derive(Clone, Zeroize)]
pub struct VirusTotalClient {
/// The API key used to interact with VirusTotal
key: Zeroizing<String>,
key: String,
}

impl VirusTotalClient {
const API_KEY: &'static str = "x-apikey";
const KEY_LEN: usize = 64;

/// New VirusTotal client given an API key, assuming it's valid
pub fn new(key: &str) -> Self {
Self {
key: Zeroizing::new(key.to_string()),
}
pub fn new(key: String) -> Self {
Self { key }
}

fn header(&self) -> HeaderMap {
Expand Down Expand Up @@ -200,15 +201,15 @@ impl VirusTotalClient {

// Just borrowing the `FileRescanResponseRequest` type get get it's error handling
let error: FileRescanRequestResponse = serde_json::from_str(&json_response)?;
if let FileRescanRequestResponse::Error(error) = error {
return Err(error);
return if let FileRescanRequestResponse::Error(error) = error {
Err(error)
} else {
// Should never happen, since we're only here if some error occurred.
return Err(VirusTotalError {
Err(VirusTotalError {
message: json_response,
code: "VTError".into(),
});
}
})
};
}

let body = response.bytes().await?;
Expand All @@ -226,15 +227,15 @@ impl FromStr for VirusTotalClient {
Err("Invalid API key length")
} else {
Ok(Self {
key: Zeroizing::new(key.to_string()),
key: key.to_string(),
})
}
}
}

impl From<String> for VirusTotalClient {
fn from(value: String) -> Self {
VirusTotalClient::new(&value)
VirusTotalClient::new(value)
}
}

Expand All @@ -248,7 +249,7 @@ mod test {
if let Ok(api_key) = std::env::var("VT_API_KEY") {
const HASH: &str = "fff40032c3dc062147c530e3a0a5c7e6acda4d1f1369fbc994cddd3c19a2de88";

let client = VirusTotalClient::new(&api_key);
let client = VirusTotalClient::new(api_key);

let report = client
.get_report(HASH)
Expand Down

0 comments on commit ffd78c7

Please sign in to comment.