Skip to content
Permalink
Branch: master
Find file Copy path
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
192 lines (149 sloc) 4.96 KB
import glob
import cPickle as pickle
import os
class Checkpointer(object):
    """Bookkeeping for on-disk model checkpoints and related artefacts.

    All locations are relative to the current working directory and are
    built with '/' separators:

    * model checkpoints:
      ``checkpoint/<model_nm>/<cell_nm>/<attention_type>/model_<id>.ckpt``
    * step counter file:
      ``checkpoint/<model_nm>/<cell_nm>/<attention_type>/steps.p``
    * data folder and mapper file:
      ``checkpoint/data`` and ``checkpoint/data/mapper.p``
    """

    def __init__(self, model_nm, cell_nm, attention_type):
        """
        Store the identifying names and make sure the step file exists.

        :param model_nm: model name, first folder level below ``checkpoint/``
        :param cell_nm: cell name, second folder level
        :param attention_type: attention type, third folder level
        """
        self.model_nm = model_nm
        self.cell_nm = cell_nm
        self.attention_type = attention_type
        self.last_ckpt = None
        self.last_id = 0
        self.step_save_location = 'steps.p'
        self.data_save_location = 'data'
        self.mapper_save_location = 'mapper.p'
        self.steps_per_ckpt = None
        self.num_steps_per_prediction = None
        self.present_checkpoints = None
        self.outfile = None

        # Create the checkpoint folder on a fresh run so that the step file
        # can be written instead of failing on a missing directory.
        ckpt_dir = self.get_checkpoint_location()
        if not os.path.isdir(ckpt_dir):
            os.makedirs(ckpt_dir)

        # Initialize the step counter to 0 if it has not been written yet.
        step_file = self.get_step_file()
        if not os.path.exists(step_file):
            # 'with' guarantees the handle is flushed and closed.
            with open(step_file, 'wb') as handle:
                pickle.dump(0, handle)

    def steps_per_checkpoint(self, num_steps):
        """
        Set the number of steps between two checkpoints.

        :param num_steps: value later returned by get_checkpoint_steps()
        :return: None
        """
        self.steps_per_ckpt = num_steps

    def get_checkpoint_steps(self):
        """
        :return: the value set through steps_per_checkpoint(), None if unset
        """
        return self.steps_per_ckpt

    def steps_per_prediction(self, num_steps):
        """
        Set the number of steps between two predictions.

        :param num_steps: value later returned by
            get_prediction_checkpoint_steps()
        :return: None
        """
        self.num_steps_per_prediction = num_steps

    def get_prediction_checkpoint_steps(self):
        """
        :return: the value set through steps_per_prediction(), None if unset
        """
        return self.num_steps_per_prediction

    def get_checkpoint_location(self):
        """
        :return: folder holding this model's checkpoints, i.e.
            ``checkpoint/<model_nm>/<cell_nm>/<attention_type>``
        """
        return 'checkpoint/' + self.model_nm + '/' + self.cell_nm + '/' + self.attention_type

    def __checkpoints_by_id(self):
        """
        Refresh ``self.present_checkpoints`` from disk.

        :return: the ``*.ckpt`` files in the checkpoint folder, sorted by
            ascending checkpoint id
        """
        self.present_checkpoints = glob.glob(self.get_checkpoint_location() + '/*.ckpt')
        return sorted(self.present_checkpoints, key=self.__get_id)

    def get_last_checkpoint(self):
        """
        Assumes that the last checkpoint has the highest checkpoint id.
        Checkpoints are saved in exactly this format:
        model_<checkpoint_id>.ckpt, e.g. model_100.ckpt

        Updates ``self.last_id`` and ``self.last_ckpt`` when at least one
        checkpoint is found; otherwise both keep their previous values.

        :return: path of the checkpoint with the highest id, or the
            previously remembered value (initially None) if none is on disk
        """
        ordered = self.__checkpoints_by_id()
        if ordered:
            self.last_id = self.__get_id(ordered[-1])
            self.last_ckpt = self.get_checkpoint_location() + '/model_' + \
                str(self.last_id) + '.ckpt'
        return self.last_ckpt

    def __get_id(self, ckpt_file):
        """
        Extract the numeric id from a ``model_<id>.ckpt`` file name.

        Only the base name is parsed, so '.' or '_' characters in the folder
        names (model, cell or attention names) cannot corrupt the result.

        :param ckpt_file: path or file name of a checkpoint
        :return: the checkpoint id as an int
        :raises IndexError: if the file name contains no '_'
        :raises ValueError: if the part after the last '_' is not an integer
        """
        stem = os.path.splitext(os.path.basename(ckpt_file))[0]
        return int(stem.rsplit('_', 1)[1])

    def delete_previous_checkpoints(self, num_previous=5):
        """
        Deletes all checkpoints except the <num_previous> ones with the
        highest ids. This is done to prevent running out of storage due to
        too many checkpoints.

        :param num_previous: number of most recent checkpoints to keep
        :return: None
        """
        ordered = self.__checkpoints_by_id()
        num_to_delete = len(ordered) - num_previous
        if num_to_delete > 0:
            # Remove the files glob actually returned rather than rebuilding
            # their names from the parsed ids.
            for ckpt_file_nm in ordered[:num_to_delete]:
                os.remove(ckpt_file_nm)

    def get_save_address(self):
        """
        Refreshes the last checkpoint id from disk as a side effect.

        :return: path for the next checkpoint, i.e. the one with id
            ``last_id + 1``
        """
        self.get_last_checkpoint()
        next_id = self.last_id + 1
        return self.get_checkpoint_location() + '/model_' + str(next_id) + '.ckpt'

    def is_checkpointed(self):
        """
        Reports the state recorded by the most recent get_last_checkpoint()
        or get_save_address() call; the disk is not re-scanned here.

        :return: True if a checkpoint with a positive id has been seen
        """
        return self.last_id > 0

    def get_data_file_location(self):
        """
        :return: the data location, ``checkpoint/data``
        """
        return 'checkpoint/' + self.data_save_location

    def get_mapper_file_location(self):
        """
        :return: path of the mapper file, ``checkpoint/data/mapper.p``
        """
        return 'checkpoint/' + self.data_save_location + '/' + self.mapper_save_location

    def get_mapper_folder_location(self):
        """
        :return: folder containing the mapper file, ``checkpoint/data``
        """
        return 'checkpoint/' + self.data_save_location

    def get_step_file(self):
        """
        :return: path of the step counter file inside the checkpoint folder
        """
        return self.get_checkpoint_location() + '/' + self.step_save_location

    def is_mapper_checkpointed(self):
        """
        :return: True if the mapper file exists; False otherwise, including
            when the mapper folder itself is missing
        """
        return os.path.exists(self.get_mapper_file_location())

    def is_output_file_present(self):
        """
        Works for result paths of any depth.

        :return: True if the result location set through
            set_result_location() exists on disk; False if it does not exist
            or no location has been set
        """
        if not self.outfile:
            return False
        return os.path.exists(self.outfile)

    def set_result_location(self, outfile):
        """
        :param outfile: path of the result/output file
        :return: None
        """
        self.outfile = outfile

    def get_result_location(self):
        """
        :return: the path set through set_result_location(), None if unset
        """
        return self.outfile
You can’t perform that action at this time.