Support for storing models from the last N epochs and decoding with those models. #61

Open
wants to merge 2 commits into master
7 changes: 6 additions & 1 deletion cfg/TIMIT_baselines/TIMIT_MLP_fmllr.cfg
@@ -11,6 +11,8 @@ use_cuda = True
multi_gpu = False
save_gpumem = False
n_epochs_tr = 24
+# Last n_mdls_store models will be stored. Leave empty to store only the final model.
+n_mdls_store = 5

[dataset1]
data_name = TIMIT_tr
@@ -235,4 +237,7 @@ skip_scoring = false
scoring_script = local/score.sh
scoring_opts = "--min-lmwt 1 --max-lmwt 10"
norm_vars = False

+# Decode with model from ep_to_decode epoch. Note that epoch indexing starts from 0,
+# so e.g. decoding with ep_to_decode=3 will decode with the model stored after the 4th epoch.
+# Leave empty to decode with the final model.
+ep_to_decode =
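For orientation, here is a minimal stand-alone sketch of how run_exp.py below interprets the two new options; the section/field names and the empty-value defaults follow the diff, while the zero-padded epoch format string is assumed to mirror the chunk formatting used in run_exp.py:

import configparser, math

config = configparser.ConfigParser()
config.read('cfg/TIMIT_baselines/TIMIT_MLP_fmllr.cfg')

N_ep = int(config['exp']['n_epochs_tr'])                                # 24 in this cfg
N_ep_str_format = '0' + str(max(math.ceil(math.log10(N_ep)), 1)) + 'd'  # '02d' for 24 epochs

# empty value -> keep only the final model
n_mdls_store = int(config['exp']['n_mdls_store']) if config['exp']['n_mdls_store'] else 0

# empty value -> decode with the model of the last training epoch
decode_epoch = config['decoding']['ep_to_decode'] or format(N_ep - 1, N_ep_str_format)

print(n_mdls_store, decode_epoch)                                       # -> 5 23

When only decoding is re-run, ep_to_decode can also be overridden from the command line with the --section,field=value syntax mentioned in the comment near the top of run_exp.py, e.g. python run_exp.py cfg/TIMIT_baselines/TIMIT_MLP_fmllr.cfg --decoding,ep_to_decode=19 (the override syntax is assumed from that comment, not introduced by this PR).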
74 changes: 38 additions & 36 deletions run_exp.py
@@ -32,7 +32,6 @@
config = configparser.ConfigParser()
config.read(cfg_file)


# Reading and parsing optional arguments from command line (e.g.,--optimization,lr=0.002)
[section_args,field_args,value_args]=read_args_command_line(sys.argv,config)

@@ -87,17 +86,15 @@
create_lists(config)

# Writing the config files
create_configs(config)

print("- Chunk creation......OK!\n")

# create res_file
res_file_path=out_folder+'/res.res'
-res_file = open(res_file_path, "w")
+res_file = open(res_file_path, "a")
res_file.close()



# Learning rates and architecture-specific optimization parameters
arch_lst=get_all_archs(config)
lr={}
@@ -144,7 +141,6 @@
lab_dict=[]
arch_dict=[]


# --------TRAINING LOOP--------#
for ep in range(N_ep):

@@ -157,7 +153,7 @@
for tr_data in tr_data_lst:

# Compute the total number of chunks for each training epoch
-N_ck_tr=compute_n_chunks(out_folder,tr_data,ep,N_ep_str_format,'train')
+N_ck_tr=compute_n_chunks(out_folder,tr_data,format(ep, N_ep_str_format),'train')
N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_tr)),1))+'d'

# ***Epoch training***
@@ -180,7 +176,6 @@

# update learning rate in the cfg file (if needed)
change_lr_cfg(config_chunk_file,lr,ep)


# if this chunk has not already been processed, do training...
if not(os.path.exists(info_file)):
@@ -210,14 +205,18 @@
for pt_arch in pt_files.keys():
pt_files[pt_arch]=out_folder+'/exp_files/train_'+tr_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'_'+pt_arch+'.pkl'

-# remove previous pkl files
+# remove previous pkl files but store last n_mdls_store models
+if config['exp']['n_mdls_store']:
+n_mdls_store = int(config['exp']['n_mdls_store'])
+else:
+n_mdls_store = 0

if len(model_files_past.keys())>0:
for pt_arch in pt_files.keys():
-if os.path.exists(model_files_past[pt_arch]):
+if os.path.exists(model_files_past[pt_arch]) and (ep <= N_ep-n_mdls_store or ck != 0):
os.remove(model_files_past[pt_arch])


# Training Loss and Error
tr_info_lst=sorted(glob.glob(out_folder+'/exp_files/train_'+tr_data+'_ep'+format(ep, N_ep_str_format)+'*.info'))
[tr_loss,tr_error,tr_time]=compute_avg_performance(tr_info_lst)
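As a sanity check on the retention condition introduced above, a small stand-alone sketch of which checkpoints survive with the values from the cfg (n_epochs_tr = 24, n_mdls_store = 5); it assumes, as the surrounding loop suggests, that model_files_past points at the checkpoint written for the previous chunk:

N_ep, n_mdls_store, N_ck_tr = 24, 5, 5       # N_ck_tr is arbitrary for this illustration

kept = []
for ep in range(N_ep):
    for ck in range(N_ck_tr):
        # mirror of the condition above: the previous checkpoint is removed
        # unless this is the first chunk of one of the last n_mdls_store epochs
        if not (ep <= N_ep - n_mdls_store or ck != 0):
            kept.append(ep - 1)              # end-of-epoch model of the previous epoch is kept
kept.append(N_ep - 1)                        # the final model is never removed

print(kept)                                  # -> [19, 20, 21, 22, 23], i.e. the last 5 epochs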

@@ -238,7 +237,7 @@
for valid_data in valid_data_lst:

# Compute the number of chunks for each validation dataset
-N_ck_valid=compute_n_chunks(out_folder,valid_data,ep,N_ep_str_format,'valid')
+N_ck_valid=compute_n_chunks(out_folder,valid_data,format(ep, N_ep_str_format),'valid')
N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_valid)),1))+'d'

for ck in range(N_ck_valid):
@@ -276,11 +275,10 @@
valid_peformance_dict[valid_data]=[valid_loss,valid_error,valid_time]
tot_time=tot_time+valid_time

-# Print results in both res_file and stdout
-dump_epoch_results(res_file_path, ep, tr_data_lst, tr_loss_tot, tr_error_tot, tot_time, valid_data_lst, valid_peformance_dict, lr, N_ep)
+# Print results in both res_file and stdout; do not overwrite the res.res file when rerunning decoding
+if not(os.path.exists(out_folder+'/exp_files/final_'+pt_arch+'.pkl')):
+dump_epoch_results(res_file_path, ep, tr_data_lst, tr_loss_tot, tr_error_tot, tot_time, valid_data_lst, valid_peformance_dict, lr, N_ep)


# Check for learning rate annealing
if ep>0:
# computing average validation error (on all the dataset specified)
@@ -302,21 +300,26 @@
# --------FORWARD--------#
for forward_data in forward_data_lst:

+if config['decoding']['ep_to_decode']:
+decode_epoch = config['decoding']['ep_to_decode']
+else:
+decode_epoch = format(ep, N_ep_str_format)

# Compute the number of chunks
-N_ck_forward=compute_n_chunks(out_folder,forward_data,ep,N_ep_str_format,'forward')
+N_ck_forward=compute_n_chunks(out_folder,forward_data,decode_epoch,'forward')
N_ck_str_format='0'+str(max(math.ceil(np.log10(N_ck_forward)),1))+'d'

for ck in range(N_ck_forward):


if not is_production:
-print('Testing %s chunk = %i / %i' %(forward_data,ck+1, N_ck_forward))
+print('Testing %s chunk = %i / %i with model stored in epoch %s' %(forward_data,ck+1, N_ck_forward, decode_epoch))
else:
-print('Forwarding %s chunk = %i / %i' %(forward_data,ck+1, N_ck_forward))
+print('Forwarding %s chunk = %i / %i with model stored in epoch %s' %(forward_data,ck+1, N_ck_forward, decode_epoch))

# output file
-info_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.info'
-config_chunk_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+format(ep, N_ep_str_format)+'_ck'+format(ck, N_ck_str_format)+'.cfg'
+info_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+ decode_epoch +'_ck'+format(ck, N_ck_str_format) + '.info'
+config_chunk_file=out_folder+'/exp_files/forward_'+forward_data+'_ep'+ decode_epoch +'_ck'+format(ck, N_ck_str_format) + '.cfg'

# Do forward if the chunk was not already processed
if not(os.path.exists(info_file)):
@@ -329,15 +332,13 @@
# run chunk processing
[data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict]=run_nn(data_name,data_set,data_end_index,fea_dict,lab_dict,arch_dict,config_chunk_file,processed_first,next_config_file)


# update the first_processed variable
processed_first=False

if not(os.path.exists(info_file)):
sys.stderr.write("ERROR: forward chunk %i of dataset %s not done! File %s does not exist.\nSee %s \n" % (ck,forward_data,info_file,log_file))
sys.exit(0)



# update the operation counter
op_counter+=1

@@ -350,18 +351,15 @@
forward_outs=config['forward']['forward_out'].split(',')
forward_dec_outs=list(map(strtobool,config['forward']['require_decoding'].split(',')))


for data in forward_data_lst:
for k in range(len(forward_outs)):
if forward_dec_outs[k]:

-print('Decoding %s output %s' %(data,forward_outs[k]))
-info_file=out_folder+'/exp_files/decoding_'+data+'_'+forward_outs[k]+'.info'
+print('Decoding %s output %s for model stored in epoch %s' %(data,forward_outs[k],decode_epoch))
+info_file=out_folder + '/exp_files/decoding_' + data + '_' + forward_outs[k] + '_e' + decode_epoch + '.info'

# create decode config file
-config_dec_file=out_folder+'/decoding_'+data+'_'+forward_outs[k]+'.conf'
+config_dec_file=out_folder + '/decoding_' + data + '_' + forward_outs[k] + '_e' + decode_epoch + '.conf'
config_dec = configparser.ConfigParser()
config_dec.add_section('decoding')

@@ -402,14 +400,18 @@

out_folder=os.path.abspath(out_folder)
files_dec=out_folder+'/exp_files/forward_'+data+'_ep*_ck*_'+forward_outs[k]+'_to_decode.ark'
-out_dec_folder=out_folder+'/decode_'+data+'_'+forward_outs[k]
+out_dec_folder=out_folder+'/decode_' + data + '_' + forward_outs[k] + '_e' + decode_epoch

if not(os.path.exists(info_file)):

# Run the decoder
cmd_decode=cmd+config['decoding']['decoding_script_folder'] +'/'+ config['decoding']['decoding_script']+ ' '+os.path.abspath(config_dec_file)+' '+ out_dec_folder + ' \"'+ files_dec + '\"'
run_shell(cmd_decode,log_file)


+# Create decoding info file
+with open(info_file, 'a'):
+os.utime(info_file, None)

# remove ark files if needed
if not forward_save_files[k]:
list_rem=glob.glob(files_dec)
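Because the epoch suffix is appended to the decoding outputs above, re-running with a different ep_to_decode writes into separate files and folders instead of overwriting earlier results. A hypothetical example of the resulting names (the concrete values for out_folder, data, the forward output and the epoch are illustrative, not taken from the diff):

out_folder = 'exp/TIMIT_MLP_fmllr'                     # hypothetical experiment folder
data, forward_out, decode_epoch = 'TIMIT_test', 'out_dnn1', '19'

info_file       = out_folder + '/exp_files/decoding_' + data + '_' + forward_out + '_e' + decode_epoch + '.info'
config_dec_file = out_folder + '/decoding_' + data + '_' + forward_out + '_e' + decode_epoch + '.conf'
out_dec_folder  = out_folder + '/decode_' + data + '_' + forward_out + '_e' + decode_epoch

print(out_dec_folder)                                  # -> exp/TIMIT_MLP_fmllr/decode_TIMIT_test_out_dnn1_e19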