# Example: quantization-aware MNIST training with eve, driven by a DDPG agent.
# NOTE(review): this line previously contained GitHub page boilerplate left over
# from a copy-paste ("You signed in with another tab or window...") — removed.
class MnistTrainer(eve.app.trainer.BaseTrainer):
    """Trainer that periodically evaluates, caches the best model, and emits obs.

    NOTE(review): the source was whitespace-mangled (keywords fused to
    identifiers, statements fused onto single lines); this is a faithful
    reconstruction. The exact nesting of ``load_from_RAM()`` and the
    best-reward branch was ambiguous in the mangled text — confirm against
    the upstream eve example.
    """

    def reset(self) -> np.ndarray:
        """Evaluate current trainer, reload trainer and then return the initial obs.

        Returns:
            obs: np.ndarray, the initial observation of trainer.
        """
        # do a fast valid
        self.steps += 1
        if self.steps % self.eval_steps == 0:
            self.steps = 0
            finetune_acc = self.valid()["acc"]
            # eval model
            if finetune_acc > self.finetune_acc:
                self.finetune_acc = finetune_acc
            # reset model to explore more possibility
            self.load_from_RAM()

        # save best model which achieve higher reward
        if self.accumulate_reward > self.best_reward:
            self.cache_to_RAM()
            self.best_reward = self.accumulate_reward

        # clear accumulate reward
        self.accumulate_reward = 0.0

        # reset related last value
        # WARN: don't forget to reset self._obs_gen and self._last_eve_obs to None.
        # sometimes, the episode may be interrupted, but the gen do not reset.
        self.last_eve = None
        self.obs_generator = None

        self.upgrader.zero_obs()
        self.fit_step()

        return self.fetch_obs()

    def close(self):
        """Override close in your subclass to perform any necessary cleanup.

        Environments will automatically close() themselves when garbage
        collected or when the program exits.
        """
        # load best model first
        self.load_from_RAM()
        finetune_acc = self.test()["acc"]

        # Count the quantization bits actually used vs. the 8-bit budget per
        # eve parameter. v.mean() is assumed to lie in [0, 1] so floor(mean*8)
        # is the chosen bit width — TODO confirm against upgrader semantics.
        bits = 0
        bound = 0
        for v in self.upgrader.eve_parameters():
            bits = bits + th.floor(v.mean() * 8)
            bound = bound + 8
        bits = bits.item()

        print(
            f"baseline: {self.baseline_acc}, ours: {finetune_acc}, bits: {bits} / {bound}"
        )

        if self.tensorboard_log is not None:
            save_path = self.kwargs.get(
                "save_path", os.path.join(self.tensorboard_log, "model.ckpt"))
            self.save_checkpoint(path=save_path)
            print(f"save trained model to {save_path}")

    def reward(self) -> float:
        """A simple reward function.

        You have to rewrite this function based on your tasks.

        Reward is the step accuracy minus a penalty proportional to the mean
        of the last eve observation (bit usage), weighted by 0.4.
        """
        self.upgrader.zero_obs()
        info = self.fit_step()
        return info["acc"] - self.last_eve.mean().item() * 0.4
# define a mnist classifier
neuron_wise = True
sample_episode = False

mnist_classifier = MnistClassifier(mnist(neuron_wise))
mnist_classifier.prepare_data(data_root="/home/densechen/dataset")
mnist_classifier.setup_train()  # use default configuration

# set mnist classifier to quantization mode
mnist_classifier.quantize()

# set neurons and states
# if neuron wise, we just set neurons as the member of max neurons of the network
# else set it to 1.
mnist_classifier.set_neurons(16 if neuron_wise else 1)
mnist_classifier.set_states(1)

# None will use a default case
mnist_classifier.set_action_space(None)
mnist_classifier.set_observation_space(None)

# define a trainer
MnistTrainer.assign_model(mnist_classifier)

# define an experiment manager
exp_manager = eve.app.ExperimentManager(
    algo="ddpg",
    env_id="mnist_trainer",
    env=MnistTrainer,
    log_folder="examples/logs",
    n_timesteps=100000,
    save_freq=1000,
    default_hyperparameter_yaml="hyperparams",
    log_interval=100,
    sample_episode=sample_episode,
)

model = exp_manager.setup_experiment()
exp_manager.learn(model)
exp_manager.save_trained_model(model)