Skip to content

Commit

Permalink
Merge pull request #2 from rhaskia/main
Browse files Browse the repository at this point in the history
Ability to request with a certain scope
  • Loading branch information
jakewilkins committed Jan 26, 2024
2 parents 4a3732e + 6ca0cb1 commit a2298cf
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 51 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ OPTIONS:
-h, --host <HOST> The host to authenticate with
--help Print help information
-r, --refresh <REFRESH> A Refresh Token to exchange
-s, --scope <SCOPE> The scope required for the auth app
-V, --version Print version information
```
164 changes: 115 additions & 49 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@

use std::{result::Result, thread, time, fmt};
use std::collections::HashMap;
use std::{fmt, result::Result, thread, time};

use chrono::{DateTime, Duration};
use chrono::offset::Utc;
use chrono::{DateTime, Duration};

mod util;

#[derive(Debug, Default, Clone, serde_derive::Serialize)]
#[derive(Debug, Default, Clone, serde_derive::Serialize, serde_derive::Deserialize)]
pub struct Credential {
pub token: String,
pub expiry: String,
Expand All @@ -26,25 +25,24 @@ impl Credential {
pub fn is_expired(&self) -> bool {
let exp = match DateTime::parse_from_rfc3339(self.expiry.as_str()) {
Ok(time) => time,
Err(_) => return false
Err(_) => return false,
};
let now = Utc::now();
now > exp
}
}


#[derive(Debug, Clone)]
pub enum DeviceFlowError {
HttpError(String),
GitHubError(String),
HttpError(String),
GitHubError(String),
}

impl fmt::Display for DeviceFlowError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
DeviceFlowError::HttpError(string) => write!(f, "DeviceFlowError: {}", string),
DeviceFlowError::GitHubError(string) => write!(f, "DeviceFlowError: {}", string)
DeviceFlowError::GitHubError(string) => write!(f, "DeviceFlowError: {}", string),
}
}
}
Expand All @@ -57,52 +55,83 @@ impl From<reqwest::Error> for DeviceFlowError {
}
}

pub fn authorize(client_id: String, host: Option<String>) -> Result<Credential, DeviceFlowError> {
pub fn authorize(
client_id: String,
host: Option<String>,
scope: Option<String>,
) -> Result<Credential, DeviceFlowError> {
let my_string: String;
let thost = match host {
Some(string) => {
my_string = string;
Some(my_string.as_str())
},
None => None
}
None => None,
};

let mut flow = DeviceFlow::start(client_id.as_str(), thost)?;
let binding: String;
let tscope = match scope {
Some(string) => {
binding = string;
Some(binding.as_str())
}
None => None,
};

let mut flow = DeviceFlow::start(client_id.as_str(), thost, tscope)?;

// eprintln!("res is {:?}", res);
eprintln!("Please visit {} in your browser", flow.verification_uri.clone().unwrap());
eprintln!(
"Please visit {} in your browser",
flow.verification_uri.clone().unwrap()
);
eprintln!("And enter code: {}", flow.user_code.clone().unwrap());

thread::sleep(FIVE_SECONDS);

flow.poll(20)
}

pub fn refresh(client_id: &str, refresh_token: &str, host: Option<String>) -> Result<Credential, DeviceFlowError> {
pub fn refresh(
client_id: &str,
refresh_token: &str,
host: Option<String>,
scope: Option<String>,
) -> Result<Credential, DeviceFlowError> {
let my_string: String;
let thost = match host {
Some(string) => {
my_string = string;
Some(my_string.as_str())
},
None => None
}
None => None,
};

refresh_access_token(client_id, refresh_token, thost)
let scope_binding;
let tscope = match scope {
Some(string) => {
scope_binding = string;
Some(scope_binding.as_str())
}
None => None,
};

refresh_access_token(client_id, refresh_token, thost, tscope)
}

#[derive(Debug, Clone)]
pub enum DeviceFlowState {
Pending,
Processing(time::Duration),
Success(Credential),
Failure(DeviceFlowError)
Failure(DeviceFlowError),
}

#[derive(Clone)]
pub struct DeviceFlow {
pub host: String,
pub client_id: String,
pub scope: String,
pub user_code: Option<String>,
pub device_code: Option<String>,
pub verification_uri: Option<String>,
Expand All @@ -112,12 +141,16 @@ pub struct DeviceFlow {
const FIVE_SECONDS: time::Duration = time::Duration::new(5, 0);

impl DeviceFlow {
pub fn new(client_id: &str, maybe_host: Option<&str>) -> Self {
Self{
pub fn new(client_id: &str, maybe_host: Option<&str>, scope: Option<&str>) -> Self {
Self {
client_id: String::from(client_id),
scope: match scope {
Some(string) => String::from(string),
None => String::new(),
},
host: match maybe_host {
Some(string) => String::from(string),
None => String::from("github.com")
None => String::from("github.com"),
},
user_code: None,
device_code: None,
Expand All @@ -126,31 +159,43 @@ impl DeviceFlow {
}
}

pub fn start(client_id: &str, maybe_host: Option<&str>) -> Result<DeviceFlow, DeviceFlowError> {
let mut flow = DeviceFlow::new(client_id, maybe_host);
pub fn start(
client_id: &str,
maybe_host: Option<&str>,
scope: Option<&str>,
) -> Result<DeviceFlow, DeviceFlowError> {
let mut flow = DeviceFlow::new(client_id, maybe_host, scope);

flow.setup();

match flow.state {
DeviceFlowState::Processing(_) => Ok(flow.to_owned()),
DeviceFlowState::Failure(err) => Err(err),
_ => Err(util::credential_error("Something truly unexpected happened".into()))
_ => Err(util::credential_error(
"Something truly unexpected happened".into(),
)),
}
}

pub fn setup(&mut self) {
let body = format!("client_id={}", &self.client_id);
let body = format!("client_id={}&scope={}", &self.client_id, &self.scope);
let entry_url = format!("https://{}/login/device/code", &self.host);

if let Some(res) = util::send_request(self, entry_url, body) {
if res.contains_key("error") && res.contains_key("error_description"){
self.state = DeviceFlowState::Failure(util::credential_error(res["error_description"].as_str().unwrap().into()))
if res.contains_key("error") && res.contains_key("error_description") {
self.state = DeviceFlowState::Failure(util::credential_error(
res["error_description"].as_str().unwrap().into(),
))
} else if res.contains_key("error") {
self.state = DeviceFlowState::Failure(util::credential_error(format!("Error response: {:?}", res["error"].as_str().unwrap())))
self.state = DeviceFlowState::Failure(util::credential_error(format!(
"Error response: {:?}",
res["error"].as_str().unwrap()
)))
} else {
self.user_code = Some(String::from(res["user_code"].as_str().unwrap()));
self.device_code = Some(String::from(res["device_code"].as_str().unwrap()));
self.verification_uri = Some(String::from(res["verification_uri"].as_str().unwrap()));
self.verification_uri =
Some(String::from(res["verification_uri"].as_str().unwrap()));
self.state = DeviceFlowState::Processing(FIVE_SECONDS);
}
};
Expand All @@ -162,51 +207,57 @@ impl DeviceFlow {

if let DeviceFlowState::Processing(interval) = self.state {
if count == iterations {
return Err(util::credential_error("Max poll iterations reached".into()))
return Err(util::credential_error("Max poll iterations reached".into()));
}

thread::sleep(interval);
} else {
break
break;
}
};
}

match &self.state {
DeviceFlowState::Success(cred) => Ok(cred.to_owned()),
DeviceFlowState::Failure(err) => Err(err.to_owned()),
_ => Err(util::credential_error("Unable to fetch credential, sorry :/".into()))
_ => Err(util::credential_error(
"Unable to fetch credential, sorry :/".into(),
)),
}
}

pub fn update(&mut self) {
let poll_url = format!("https://{}/login/oauth/access_token", self.host);
let poll_payload = format!("client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
let poll_payload = format!(
"client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
self.client_id,
&self.device_code.clone().unwrap()
);

if let Some(res) = util::send_request(self, poll_url, poll_payload) {
if res.contains_key("error") {
match res["error"].as_str().unwrap() {
"authorization_pending" => {},
"authorization_pending" => {}
"slow_down" => {
if let DeviceFlowState::Processing(current_interval) = self.state {
self.state = DeviceFlowState::Processing(current_interval + FIVE_SECONDS);
self.state =
DeviceFlowState::Processing(current_interval + FIVE_SECONDS);
};
},
}
other_reason => {
self.state = DeviceFlowState::Failure(
util::credential_error(format!("Error checking for token: {}", other_reason))
);
},
self.state = DeviceFlowState::Failure(util::credential_error(format!(
"Error checking for token: {}",
other_reason
)));
}
}
} else {
let mut this_credential = Credential::empty();
this_credential.token = res["access_token"].as_str().unwrap().to_string();

if let Some(expires_in) = res.get("expires_in") {
this_credential.expiry = calculate_expiry(expires_in.as_i64().unwrap());
this_credential.refresh_token = res["refresh_token"].as_str().unwrap().to_string();
this_credential.refresh_token =
res["refresh_token"].as_str().unwrap().to_string();
}

self.state = DeviceFlowState::Success(this_credential);
Expand All @@ -222,25 +273,40 @@ fn calculate_expiry(expires_in: i64) -> String {
expiry.to_rfc3339()
}

fn refresh_access_token(client_id: &str, refresh_token: &str, maybe_host: Option<&str>) -> Result<Credential, DeviceFlowError> {
fn refresh_access_token(
client_id: &str,
refresh_token: &str,
maybe_host: Option<&str>,
maybe_scope: Option<&str>,
) -> Result<Credential, DeviceFlowError> {
let host = match maybe_host {
Some(string) => string,
None => "github.com"
None => "github.com",
};

let scope = match maybe_scope {
Some(string) => string,
None => "",
};

let client = reqwest::blocking::Client::new();
let entry_url = format!("https://{}/login/oauth/access_token", &host);
let request_body = format!("client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token",
&client_id, &refresh_token);
let request_body = format!(
"client_id={}&refresh_token={}&client_secret=&grant_type=refresh_token&scope={}",
&client_id, &refresh_token, &scope
);

let res = client.post(&entry_url)
let res = client
.post(&entry_url)
.header("Accept", "application/json")
.body(request_body)
.send()?
.json::<HashMap<String, serde_json::Value>>()?;

if res.contains_key("error") {
Err(util::credential_error(res["error"].as_str().unwrap().into()))
Err(util::credential_error(
res["error"].as_str().unwrap().into(),
))
} else {
let mut credential = Credential::empty();
// eprintln!("res: {:?}", &res);
Expand Down
8 changes: 6 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ struct Args {
#[clap(short, long, value_parser)]
host: Option<String>,

/// The scope the user wants to have
#[clap(short, long, value_parser)]
scope: Option<String>,

/// A Refresh Token to exchange
#[clap(short, long, value_parser)]
refresh: Option<String>,
Expand All @@ -29,10 +33,10 @@ fn main() {

match args.refresh {
None => {
cred = authorize(args.client_id, args.host);
cred = authorize(args.client_id, args.host, args.scope);
},
Some(rt) => {
cred = refresh(args.client_id.as_str(), rt.as_str(), args.host);
cred = refresh(args.client_id.as_str(), rt.as_str(), args.host, args.scope);
}
}
match cred {
Expand Down

0 comments on commit a2298cf

Please sign in to comment.