-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
executable file
·79 lines (61 loc) · 2.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from pathlib import Path
import torch
from torch.backends import cudnn as cudnn
import orchestrator
from helpers import logger
from helpers.argparser_util import agg_argparser
from helpers.experiment import ExperimentInitializer
from algos.compression.compressor import Compressor
def run(args):
    """Configure the experiment, select the compute device, seed RNGs, and launch training.

    Args:
        args: parsed CLI namespace; must carry `fp16`, `cuda`, `seed`,
            `dataset_handle`, and `algo_handle`. The chosen `torch.device`
            is written back onto it as `args.device`.

    Raises:
        ValueError: if `fp16` is requested without `cuda`.
        RuntimeError: if `cuda` is requested but no CUDA device is available.
        NotImplementedError: for unsupported dataset or algorithm handles.
    """
    # Initialize and configure experiment
    experiment = ExperimentInitializer(args)
    experiment.configure_logging()
    # Create experiment name
    experiment_name = experiment.get_name()
    # Set device-related knobs.
    # NOTE: explicit raises instead of `assert` — asserts vanish under `python -O`.
    if args.fp16 and not args.cuda:
        raise ValueError("fp16 requires cuda (fp16 ==> cuda)")
    if args.cuda:
        # Use cuda
        if not torch.cuda.is_available():
            raise RuntimeError("--cuda was set but no CUDA device is available")
        # Favor reproducibility over cudnn autotuning speed.
        cudnn.benchmark = False
        cudnn.deterministic = True
        device = torch.device("cuda:0")
    else:
        # `torch.has_mps` is deprecated; use the supported availability check.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            # Use Apple's Metal Performance Shaders (MPS)
            device = torch.device("mps")
        else:
            # Default case: just use plain old cpu, no cuda or m-chip gpu
            device = torch.device("cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""  # kill any possibility of usage
    args.device = device  # add the device to hps for convenience
    logger.info(f"device in use: {device}")
    # Seedify (cuda seeding is a no-op when cuda is unused)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    if args.dataset_handle == 'bigearthnet':
        pass
    else:
        raise NotImplementedError("dataset not covered")
    if args.algo_handle in ['vqae', 'residualvqae']:
        algo_class_handle = Compressor
    else:
        raise NotImplementedError("algorithm not covered")

    def algo_wrapper():
        # Defer construction so the orchestrator controls instantiation timing.
        return algo_class_handle(
            hps=args,
        )

    # Train
    orchestrator.learn(
        args=args,
        algo_wrapper=algo_wrapper,
        experiment_name=experiment_name,
    )
if __name__ == '__main__':
    # Parse CLI args, then anchor all paths at this file's directory.
    _args = agg_argparser().parse_args()
    _args.root = Path(__file__).resolve().parent  # make the paths absolute
    # Derive singular `*_dir` attrs (checkpoint_dir, log_dir) under <root>/data/.
    for subdir in ('checkpoints', 'logs'):
        setattr(_args, f"{subdir[:-1]}_dir", Path(_args.root) / 'data' / subdir)
    run(_args)