-
Notifications
You must be signed in to change notification settings - Fork 1
/
config.py
71 lines (61 loc) · 1.56 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from net_arch.feature_extractor import PensieveFeatureExtractor
from net_arch.mlp import CartPoleNetwork, SB3MLPDQNNetwork
from net_arch.single_path_net import SinglePathPolicy
from utils import create_training_data, get_fcc_test_data, get_linear_exp_decay
# Torch device string used when building the model: the GPU when CUDA is
# usable on this machine, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Settings for the training environment.
# NOTE(review): how "num_env" and "params" are consumed is not visible in this
# file -- presumably "params" is forwarded to the environment constructor;
# confirm against the code that reads this dict.
TRAIN_ENV_CONF = {
    "env_name": "SinglePath-v0",
    "num_env": 1,
    "params": {
        "train": True,
        # Evaluated once, when this module is imported.
        "bitrate_list": create_training_data(),
    },
}
# Settings for the evaluation environment; same layout as the training
# settings but with "train" switched off and a different bitrate source.
# NOTE(review): "num_env" is presumably the number of environment instances
# used for testing -- confirm against the code that reads this dict.
TEST_ENV_CONF = {
    "env_name": "SinglePath-v0",
    "num_env": 10,
    "params": {
        "train": False,
        # Evaluated once, when this module is imported.
        "bitrate_list": get_fcc_test_data(),
    },
}
# Model, optimizer and learning hyper-parameters.
# The policy object and the exploration schedule are created right here, i.e.
# when this module is imported.
MODEL_CONF = {
    "model": SinglePathPolicy(
        PensieveFeatureExtractor(device=DEVICE),
        # NOTE(review): the meaning of the positional 256 and 7 is defined by
        # SB3MLPDQNNetwork (presumably feature width and number of actions)
        # -- confirm in net_arch.mlp.
        SB3MLPDQNNetwork(256, 7, torch.nn.ReLU()),
    ),
    # The optimizer class is stored uninstantiated, alongside its settings.
    "optimizer_class": torch.optim.Adam,
    "optimizer_conf": {
        "lr": 0.0005,
    },
    # NOTE(review): presumably the discount factor and n-step return length
    # of the learning algorithm -- confirm against the trainer.
    "gamma": 0.99,
    "n_step": 3,
    "target_update_freq": 1,
    "exploration_update": get_linear_exp_decay(),
    "step_per_collect": 100,
}
# Server-side settings.
# NOTE(review): "chosen_prob" is presumably the per-round probability (or
# fraction) with which a client is selected -- confirm against the server code.
SERVER_CONF = {
    "num_clients": 100,
    "chosen_prob": 0.1,
}
# Client-side settings.
# NOTE(review): units and exact semantics of these values are defined by the
# client code that reads them, which is not visible in this file.
CLIENT_CONF = {
    "local_batch_size": 32,
    "local_epochs": 10,
    "max_buffer_size": 2000,
    "test_num": 10,
}
# Run-level settings: number of rounds, seed, and output locations.
# The seed value also appears in "model_name"; keep the two in sync by hand
# when changing either one.
TRAINING_CONF = {
    "num_rounds": 200,
    "seed": 44,
    "model_dir": "tmp",
    "model_name": "dqn_fl_44",
    "result_dir": "results",
    "result_fig": "result.png",
}
# Weights & Biases (wandb) settings; None is the value set by default.
# To use wandb, uncomment the dict below, fill in "WANDB_API_KEY", and
# delete the `WANDB_CONFIG = None` line that follows it -- left in place, that
# assignment would rebind the name back to None.
# NOTE(review): avoid committing a real API key to version control.
# WANDB_CONFIG = {
#     "WANDB_API_KEY": "",
#     "project": "fl_rl",
#     "name": "fl_rl_dqn"
# }
WANDB_CONFIG = None