-
Notifications
You must be signed in to change notification settings - Fork 203
/
resources.rs
204 lines (192 loc) · 6.68 KB
/
resources.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
//! # Resource definitions for model weights, vocabularies and configuration files
//!
//! This crate relies on the concept of Resources to access the files used by the models.
//! This includes:
//! - model weights
//! - configuration files
//! - vocabularies
//! - (optional) merges files for BPE-based tokenizers
//!
//! These are expected in the pipelines configurations or are used as utilities to reference the
//! resource location. Two types of resources exist:
//! - LocalResource: points to a local file
//! - RemoteResource: points to a remote file via a URL and a local cached file
//!
//! For both types of resources, the local location of the file can be retrieved using
//! `get_local_path`, which makes it possible to reference the resource file location regardless
//! of whether it is a remote or a local resource. Default implementations for a number of
//! `RemoteResources` are available as pre-trained models in each model module.
use lazy_static::lazy_static;
use std::path::PathBuf;
use reqwest::Client;
use std::{fs, env};
use tokio::prelude::*;
extern crate dirs;
/// # Resource Enum expected by the `download_resource` function
/// Can be of type:
/// - LocalResource
/// - RemoteResource
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Resource {
    /// A file that already lives on the local file system
    Local(LocalResource),
    /// A file identified by a remote URL together with the local path it is stored at
    Remote(RemoteResource),
}

impl Resource {
    /// Gets the local path for a given resource
    ///
    /// For a remote resource this is the declared local location of the file; this method does
    /// not check that the file exists and does not trigger any download.
    ///
    /// # Returns
    ///
    /// * `PathBuf` pointing to the resource file
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::{Resource, LocalResource};
    /// use std::path::PathBuf;
    /// let config_resource = Resource::Local(LocalResource { local_path: PathBuf::from("path/to/config.json")});
    /// let config_path = config_resource.get_local_path();
    /// ```
    ///
    pub fn get_local_path(&self) -> &PathBuf {
        match self {
            Resource::Local(resource) => &resource.local_path,
            Resource::Remote(resource) => &resource.local_path,
        }
    }
}

/// # Local resource
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct LocalResource {
    /// Local path for the resource
    pub local_path: PathBuf,
}

/// # Remote resource
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct RemoteResource {
    /// Remote path/url for the resource
    pub url: String,
    /// Local path for the resource
    pub local_path: PathBuf,
}
impl RemoteResource {
    /// Builds a `RemoteResource` from a URL and an explicit local destination. Nothing is
    /// downloaded here: the call only records where the file lives remotely and where it is
    /// stored locally.
    ///
    /// # Arguments
    ///
    /// * `url` - `&str` Location of the remote resource
    /// * `target` - `PathBuf` Local path to save the resource to
    ///
    /// # Returns
    ///
    /// * `RemoteResource` RemoteResource object
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::{Resource, RemoteResource};
    /// use std::path::PathBuf;
    /// let config_resource = Resource::Remote(RemoteResource::new("http://config_json_location", PathBuf::from("path/to/config.json")));
    /// ```
    ///
    pub fn new(url: &str, target: PathBuf) -> RemoteResource {
        RemoteResource {
            url: String::from(url),
            local_path: target,
        }
    }

    /// Builds a `RemoteResource` from a `(name, url)` pair. The local path is the global
    /// `CACHE_DIRECTORY` extended with `name`. Nothing is downloaded here: the call only records
    /// the remote and local locations.
    ///
    /// # Arguments
    ///
    /// * `name_url_tuple` - `(&str, &str)` Local name of the model file (relative to the cache
    ///   directory) followed by the URL of the remote resource
    ///
    /// # Returns
    ///
    /// * `RemoteResource` RemoteResource object
    ///
    /// # Example
    ///
    /// ```no_run
    /// use rust_bert::resources::{Resource, RemoteResource};
    /// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
    ///     ("distilbert-sst2/model.ot",
    ///      "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
    ///     )
    /// ));
    /// ```
    ///
    pub fn from_pretrained(name_url_tuple: (&str, &str)) -> RemoteResource {
        let (model_name, remote_url) = name_url_tuple;
        RemoteResource {
            url: remote_url.to_owned(),
            // `join` copies the cache directory and appends the model name to the copy.
            local_path: CACHE_DIRECTORY.join(model_name),
        }
    }
}
lazy_static! {
#[derive(Copy, Clone, Debug)]
/// # Global cache directory
///
/// Base directory that `RemoteResource::from_pretrained` extends with the model name to build
/// the local path of a remote resource.
///
/// If the environment variable `RUSTBERT_CACHE` is set, will save the cache model files at that
/// location. Otherwise defaults to `~/.cache/.rustbert`.
///
/// The value is produced by `_get_cache_directory` the first time the static is accessed and is
/// reused afterwards, so later changes to `RUSTBERT_CACHE` within the same process are not
/// picked up.
pub static ref CACHE_DIRECTORY: PathBuf = _get_cache_directory();
}
/// Resolves the directory in which downloaded resources are cached.
///
/// # Returns
///
/// * `PathBuf` - the value of the `RUSTBERT_CACHE` environment variable when it is set,
///   otherwise `<home directory>/.cache/.rustbert`
///
/// # Panics
///
/// Panics when `RUSTBERT_CACHE` is not set and `dirs::home_dir` cannot determine a home directory.
fn _get_cache_directory() -> PathBuf {
    // `var_os` (rather than `var`) keeps a user-provided path that is not valid Unicode instead
    // of silently ignoring it and falling back to the default location.
    match env::var_os("RUSTBERT_CACHE") {
        Some(custom_dir) => PathBuf::from(custom_dir),
        None => {
            let mut default_dir = dirs::home_dir().expect(
                "could not determine the home directory; set RUSTBERT_CACHE to choose a cache location",
            );
            default_dir.push(".cache");
            default_dir.push(".rustbert");
            default_dir
        }
    }
}
/// # (Download) the resource and return a path to its local path
/// This function will download a remote resource to its local path if the local file does not
/// exist yet. Then, for both `LocalResource` and `RemoteResource`, it returns the local path to
/// the resource. For `LocalResource` only the resource path is returned (nothing is checked or
/// downloaded).
///
/// The download is written to a temporary `<local path>.part` file that is renamed to the final
/// location only once the transfer completed successfully, so a failed or interrupted download
/// is not mistaken for a cached file by subsequent calls.
///
/// NOTE(review): `#[tokio::main]` turns this into a blocking function that starts its own
/// runtime; calling it from inside another Tokio runtime is expected to fail - confirm against
/// the Tokio version in use.
///
/// # Arguments
///
/// * `resource` - Pointer to the `&Resource` to optionally download and get the local path.
///
/// # Returns
///
/// * `&PathBuf` Local path for the resource
///
/// # Errors
///
/// Returns an error if the target directory or file cannot be created or written, if the HTTP
/// request fails, or if the server answers with a client/server error status (4xx/5xx).
///
/// # Example
///
/// ```no_run
/// use rust_bert::resources::{Resource, RemoteResource, download_resource};
/// let model_resource = Resource::Remote(RemoteResource::from_pretrained(
///     ("distilbert-sst2/model.ot",
///      "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-rust_model.ot"
///     )
/// ));
/// let local_path = download_resource(&model_resource);
/// ```
///
#[tokio::main]
pub async fn download_resource(resource: &Resource) -> failure::Fallible<&PathBuf> {
    match resource {
        Resource::Remote(remote_resource) => {
            let target = &remote_resource.local_path;
            let url = &remote_resource.url;
            if !target.exists() {
                println!("Downloading {} to {:?}", url, target);
                // A path without a parent component has no directory to create.
                if let Some(parent) = target.parent() {
                    fs::create_dir_all(parent)?;
                }

                // Issue the request and validate the status before touching the file system, so
                // that an unreachable host or an HTTP error page never ends up in the cache.
                let client = Client::new();
                let mut response = client
                    .get(url.as_str())
                    .send()
                    .await?
                    .error_for_status()?;

                // Stream into `<target>.part`; the final path only appears after a full download.
                let mut partial_name = target.clone().into_os_string();
                partial_name.push(".part");
                let partial_path = PathBuf::from(partial_name);

                let mut output_file = tokio::fs::File::create(&partial_path).await?;
                while let Some(chunk) = response.chunk().await? {
                    output_file.write_all(&chunk).await?;
                }
                // Make sure every buffered byte reached the file (and surface write errors)
                // before the handle is released and the file is moved into place.
                output_file.flush().await?;
                drop(output_file);
                fs::rename(&partial_path, target)?;
            }
            Ok(resource.get_local_path())
        }
        Resource::Local(_) => Ok(resource.get_local_path()),
    }
}