Restore optimizer state
cswinter committed Oct 25, 2019
1 parent 857f583 commit 52bced1
Showing 3 changed files with 58 additions and 26 deletions.
1 change: 1 addition & 0 deletions hyper_params.py
@@ -22,6 +22,7 @@ def __init__(self):
         self.fp16 = False  # Whether to use half-precision floating point
         self.zero_init_vf = True  # Set all initial weights for value function head to zero
         self.small_init_pi = False  # Set initial weights for policy head to small values and biases to zero
+        self.resume_from = ''  # Filepath to saved policy
 
         # Eval
         self.eval_envs = 256
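A hypothetical use of the new field (illustrative only; how HyperParams values are populated, for example from command-line flags, is outside this diff):

    # Hypothetical: point resume_from at a checkpoint written by save_policy.
    hps = HyperParams()
    hps.resume_from = '/path/to/out_dir/model-5000000.pt'  # illustrative path
    # The default empty string keeps the original behavior: train from scratch.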
79 changes: 55 additions & 24 deletions main.py
@@ -14,6 +14,8 @@
 from hyper_params import HyperParams
 from policy import Policy
 
+logger = logging.getLogger(__name__)
+
 TEST_LOG_ROOT_DIR = '/home/clemens/Dropbox/artifacts/DeepCodeCraft_test'
 LOG_ROOT_DIR = '/home/clemens/Dropbox/artifacts/DeepCodeCraft'
 EVAL_MODELS_PATH = '/home/clemens/Dropbox/artifacts/DeepCodeCraft/golden-models'
@@ -59,34 +61,44 @@ def train(hps: HyperParams, out_dir: str) -> None:
print("Running on CPU")
device = "cpu"

policy = Policy(hps.depth, hps.width, hps.conv, hps.small_init_pi, hps.zero_init_vf, hps.fp16).to(device)
if hps.fp16:
policy = policy.half()
if hps.optimizer == 'SGD':
optimizer = optim.SGD(policy.parameters(), lr=hps.lr, momentum=hps.momentum, weight_decay=hps.weight_decay)
optimizer_fn = optim.SGD
optimizer_kwargs = dict(lr=hps.lr, momentum=hps.momentum, weight_decay=hps.weight_decay)
elif hps.optimizer == 'RMSProp':
optimizer = optim.RMSprop(policy.parameters(), lr=hps.lr, momentum=hps.momentum, weight_decay=hps.weight_decay)
optimizer_fn = optim.RMSprop
optimizer_kwargs = dict(lr=hps.lr, momentum=hps.momentum, weight_decay=hps.weight_decay)
elif hps.optimizer == 'Adam':
optimizer = optim.Adam(policy.parameters(), lr=hps.lr, weight_decay=hps.weight_decay, eps=1e-5)
optimizer_fn = optim.Adam
optimizer_kwargs = dict(lr=hps.lr, weight_decay=hps.weight_decay, eps=1e-5)
else:
raise Exception(f'Invalid optimizer name `{hps.optimizer}`')

resume_steps = 0
if hps.resume_from == '':
policy = Policy(hps.depth, hps.width, hps.conv, hps.small_init_pi, hps.zero_init_vf, hps.fp16).to(device)
optimizer = optimizer_fn(policy.parameters(), **optimizer_kwargs)
else:
policy, optimizer, resume_steps = load_policy(hps.resume_from, device, optimizer_fn, optimizer_kwargs)

if hps.fp16:
policy = policy.half()

wandb.watch(policy)

total_steps = 0
total_steps = resume_steps
epoch = 0
obs, action_masks = env.reset()
eprewmean = 0
eplenmean = 0
completed_episodes = 0
while total_steps < hps.steps:
while total_steps < hps.steps + resume_steps:
if total_steps >= next_eval and hps.eval_envs > 0:
eval(policy, hps, device, total_steps)
next_eval += hps.eval_frequency
next_model_save -= 1
if next_model_save == 0:
next_model_save = hps.model_save_frequency
save_policy(policy, out_dir, total_steps)
save_policy(policy, out_dir, total_steps, optimizer)

episode_start = time.time()
entropies = []
@@ -250,16 +262,7 @@ def train(hps: HyperParams, out_dir: str) -> None:

     if hps.eval_envs > 0:
         eval(policy, hps, device, total_steps)
-        save_policy(policy, out_dir, total_steps)
-
-
-def save_policy(policy, out_dir, total_steps):
-    model_path = os.path.join(out_dir, f'model-{total_steps}.pt')
-    print(f'Saving policy to {model_path}')
-    torch.save({
-        'model_state_dict': policy.state_dict(),
-        'model_kwargs': policy.kwargs,
-    }, model_path)
+        save_policy(policy, out_dir, total_steps, optimizer)
 
 
 def eval(policy, hps, device, total_steps):
@@ -296,7 +299,9 @@ def eval(policy, hps, device, total_steps):

     i = 0
     for name, opp in opponents.items():
-        opp['policy'] = load_policy(opp['model_file']).to(device)
+        policy, _, _ = load_policy(opp['model_file'], device)
+        policy.eval()
+        opp['policy'] = policy
         opp['envs'] = odds[i * len(odds) // len(opponents):(i+1) * len(odds) // len(opponents)]
         i += 1

@@ -335,20 +340,46 @@ def eval(policy, hps, device, total_steps):
         'eval_max_score': scores.max(),
         'eval_min_score': scores.min(),
     }, step=total_steps)
+    print(f'Eval: {scores.mean()}')
     for opp_name, scores in scores_by_opp.items():
         scores = np.array(scores)
         wandb.log({f'eval_mean_score_vs_{opp_name}': scores.mean()}, step=total_steps)
-    print(f'Eval: {scores.mean()}')
 
     env.close()
 
 
-def load_policy(name):
+def save_policy(policy, out_dir, total_steps, optimizer=None):
+    model_path = os.path.join(out_dir, f'model-{total_steps}.pt')
+    print(f'Saving policy to {model_path}')
+    model = {
+        'model_state_dict': policy.state_dict(),
+        'model_kwargs': policy.kwargs,
+        'total_steps': total_steps,
+    }
+    if optimizer:
+        model['optimizer_state_dict'] = optimizer.state_dict()
+    torch.save(model, model_path)
+
+
+def load_policy(name, device, optimizer_fn=None, optimizer_kwargs=None):
     checkpoint = torch.load(os.path.join(EVAL_MODELS_PATH, name))
     policy = Policy(**checkpoint['model_kwargs'])
     policy.load_state_dict(checkpoint['model_state_dict'])
-    policy.eval()
-    return policy
+    policy.to(device)
+
+    optimizer = None
+    if optimizer_fn:
+        optimizer = optimizer_fn(policy.parameters(), **optimizer_kwargs)
+        if 'optimizer_state_dict' in checkpoint:
+            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+            for state in optimizer.state.values():
+                for k, v in state.items():
+                    if torch.is_tensor(v):
+                        state[k] = v.to(device)
+        else:
+            logger.warning(f'Failed to restore optimizer state: No `optimizer_state_dict` in saved model.')
+
+    return policy, optimizer, checkpoint.get('total_steps', 0)
 
 
 def explained_variance(ypred,y):
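The core of the change is the checkpoint round trip: save_policy now bundles the optimizer's state_dict and the step count with the model, and load_policy rebuilds the optimizer from the deferred optimizer_fn/optimizer_kwargs pair before restoring that state. Below is a minimal, self-contained sketch of the same pattern; the toy TinyNet model, the dummy training step, and the checkpoint path are illustrative and not part of this repository:

    # Minimal sketch of the save/resume round trip (toy model, illustrative paths).
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return self.fc(x)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Deferred construction: an optimizer can only be built once the parameter
    # objects it should update exist, hence the factory + kwargs split.
    optimizer_fn = optim.Adam
    optimizer_kwargs = dict(lr=1e-3, eps=1e-5)

    model = TinyNet().to(device)
    optimizer = optimizer_fn(model.parameters(), **optimizer_kwargs)

    # One dummy step so the optimizer accumulates state (e.g. Adam's exp_avg).
    optimizer.zero_grad()
    model(torch.randn(8, 4, device=device)).sum().backward()
    optimizer.step()

    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'total_steps': 1,
    }, 'checkpoint.pt')

    # --- later, in a fresh process ---
    checkpoint = torch.load('checkpoint.pt', map_location='cpu')
    model = TinyNet()
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)

    optimizer = optimizer_fn(model.parameters(), **optimizer_kwargs)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    for state in optimizer.state.values():  # restored buffers may sit on the
        for k, v in state.items():          # wrong device; move them explicitly
            if torch.is_tensor(v):
                state[k] = v.to(device)

    total_steps = checkpoint.get('total_steps', 0)  # resume step counting here

Deferring optimizer construction behind a factory matters because an optimizer must be built against the parameter objects it will update, and when resuming those objects only exist after the policy has been reloaded. Starting total_steps at the restored count and looping until hps.steps + resume_steps then continues the step budget where the previous run left off.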
4 changes: 2 additions & 2 deletions showmatch.py
@@ -20,8 +20,8 @@ def showmatch(model1_path, model2_path, task, randomize):

     nenv = 128
 
-    policy1 = load_policy(model1_path).to(device)
-    policy2 = load_policy(model2_path).to(device)
+    policy1, _, _ = load_policy(model1_path, device)
+    policy2, _, _ = load_policy(model2_path, device)
 
     env = envs.CodeCraftVecEnv(nenv,
                                nenv // 2,
