@@ -145,6 +145,7 @@ def get_ds(config):
     ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')
 
     # Build tokenizers
+    print("Loading tokenizers...")
     tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
     tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
 
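Note: get_or_build_tokenizer itself is not part of this hunk. As a rough sketch only (assumed, not taken from this commit), such a helper typically loads a cached tokenizer file when it exists and otherwise trains a word-level tokenizer on the raw sentences with the Hugging Face tokenizers library; the 'tokenizer_file' config key and the special-token names below are illustrative assumptions.

# Assumed sketch of a get_or_build_tokenizer helper (not this commit's code):
# load the cached tokenizer if present, otherwise train a word-level tokenizer
# on the raw sentences and save it for later runs.
from pathlib import Path
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

def get_or_build_tokenizer(config, ds, lang):
    tokenizer_path = Path(config['tokenizer_file'].format(lang))  # 'tokenizer_file' key is an assumption
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))
    tokenizer = Tokenizer(WordLevel(unk_token='[UNK]'))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=['[UNK]', '[PAD]', '[SOS]', '[EOS]'], min_frequency=2)
    tokenizer.train_from_iterator((item['translation'][lang] for item in ds), trainer=trainer)
    tokenizer.save(str(tokenizer_path))
    return tokenizer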
@@ -188,6 +189,8 @@ def train_model(config):
     # Make sure the weights folder exists
     Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
 
+    # Load the dataset
+    print("Loading dataset...")
     train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
 
@@ -284,7 +287,7 @@ def train_model(config):
         run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)
 
         # Save the model at the end of every epoch
-        model_filename = get_weights_file_path(config, f"{epoch:02d}")
+        model_filename = get_weights_file_path(config, epoch)
         torch.save({
             'epoch': epoch,
             'model_state_dict': model.module.state_dict(), # Need to access module because we are using DDP
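Regarding the DDP comment above: when the network is wrapped in torch.nn.parallel.DistributedDataParallel, the original module sits behind the wrapper's .module attribute, so saving model.module.state_dict() keeps checkpoint keys free of the "module." prefix and loadable into an unwrapped model. A minimal, self-contained sketch (my own example, not this repo's code: single-process "gloo" group on CPU, with a stand-in linear layer instead of the real transformer):

import os
import torch
import torch.distributed as dist
import torch.nn as nn

# Single-process process group so DistributedDataParallel can be built on CPU.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group('gloo', rank=0, world_size=1)

net = nn.Linear(8, 8)                              # stand-in for the real transformer
ddp_net = nn.parallel.DistributedDataParallel(net)

# The wrapper keeps the original module under .module; saving its state_dict
# avoids the 'module.' key prefix, so the file also loads into a plain model.
torch.save({'model_state_dict': ddp_net.module.state_dict()}, 'tmp_ckpt.pt')

dist.destroy_process_group()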