-
Notifications
You must be signed in to change notification settings - Fork 652
/
storage.py
50 lines (42 loc) · 1.58 KB
/
storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import logging

from .hugging_face import HuggingFace, HuggingFaceDataset
from .s3 import S3
def model_factory(model_provider, model_provider_parameters):
match model_provider:
case "hf":
hf = HuggingFace()
hf.load_config(model_provider_parameters)
hf.download_model_and_tokenizer()
case _:
return "This is the default case"
def dataset_factory(dataset_provider, dataset_provider_parameters):
match dataset_provider:
case "s3":
s3 = S3()
s3.load_config(dataset_provider_parameters)
s3.download_dataset()
case "hf":
hf = HuggingFaceDataset()
hf.load_config(dataset_provider_parameters)
hf.download_dataset()
case _:
return "This is the default case"
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="script for downloading model and datasets to PVC."
)
parser.add_argument("--model_provider", type=str, help="name of model provider")
parser.add_argument(
"--model_provider_parameters",
type=str,
help="model provider serialised arguments",
)
parser.add_argument("--dataset_provider", type=str, help="name of dataset provider")
parser.add_argument(
"--dataset_provider_parameters",
type=str,
help="dataset provider serialized arguments",
)
args = parser.parse_args()
model_factory(args.model_provider, args.model_provider_parameters)
dataset_factory(args.dataset_provider, args.dataset_provider_parameters)