@@ -145,7 +145,8 @@ def get_ds(config):
     ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')

     # Build tokenizers
-    print("Loading tokenizers...")
+    if config['local_rank'] == 0:
+        print("Loading tokenizers...")
     tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
     tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

@@ -167,8 +168,9 @@ def get_ds(config):
         max_len_src = max(max_len_src, len(src_ids))
         max_len_tgt = max(max_len_tgt, len(tgt_ids))

-    print(f'Max length of source sentence: {max_len_src}')
-    print(f'Max length of target sentence: {max_len_tgt}')
+    if config['local_rank'] == 0:
+        print(f'Max length of source sentence: {max_len_src}')
+        print(f'Max length of target sentence: {max_len_tgt}')


     train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
@@ -184,13 +186,15 @@ def train_model(config):
     # Define the device
     assert torch.cuda.is_available(), "Training on CPU is not supported"
     device = torch.device("cuda")
-    print("Using device:", device)
+    if config['local_rank'] == 0:
+        print("Using device:", device)

     # Make sure the weights folder exists
     Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

     # Load the dataset
-    print("Loading dataset...")
+    if config['local_rank'] == 0:
+        print("Loading dataset...")
     train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

@@ -209,7 +213,8 @@ def train_model(config):

     # If we couldn't find a model to preload, just start from scratch
     if model_filename is not None:
-        print(f'Preloading model {model_filename}')
+        if config['local_rank'] == 0:
+            print(f'Preloading model {model_filename}')
         state = torch.load(model_filename)
         model.load_state_dict(state['model_state_dict'])
         initial_epoch = state['epoch'] + 1
@@ -218,7 +223,8 @@ def train_model(config):
         wandb_run_id = state['wandb_run_id']
         del state
     else:
-        print(f'Could not find model to preload: {config["preload"]}. Starting from scratch')
+        if config['local_rank'] == 0:
+            print(f'Could not find model to preload: {config["preload"]}. Starting from scratch')

     # Only initialize W&B on the rank 0 node
     if config['global_rank'] == 0:
@@ -324,9 +330,10 @@ def train_model(config):
     config['global_rank'] = int(os.environ['RANK'])

     # Print configuration
-    print("Configuration:")
-    for key, value in config.items():
-        print(f"{key:>20}: {value}")
+    if config['local_rank'] == 0:
+        print("Configuration:")
+        for key, value in config.items():
+            print(f"{key:>20}: {value}")

     # Setup distributed training
     init_process_group(backend='nccl')
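
Every hunk in this commit applies the same pattern: console output is gated on `config['local_rank'] == 0` so that a multi-GPU launch prints each message once per node instead of once per worker. Below is a minimal, self-contained sketch of that pattern under a `torchrun` launch; the `rank_zero_print` helper and the example messages are illustrative and not part of the repository's code.

```python
import os
from torch.distributed import init_process_group, destroy_process_group


def rank_zero_print(local_rank: int, *args, **kwargs):
    # Print only from the first process on each node, so logs are not
    # duplicated once per GPU worker.
    if local_rank == 0:
        print(*args, **kwargs)


if __name__ == '__main__':
    # torchrun sets LOCAL_RANK (rank within the node) and RANK (global rank)
    # in the environment of every worker it spawns.
    local_rank = int(os.environ['LOCAL_RANK'])
    global_rank = int(os.environ['RANK'])

    init_process_group(backend='nccl')  # one process per GPU, NCCL backend for CUDA
    rank_zero_print(local_rank, f'Process group initialized (global rank {global_rank})')
    destroy_process_group()
```

Launched with, for example, `torchrun --nproc_per_node=2 example.py`, only one "Process group initialized" line appears per node, which is the same effect the diff achieves for the dataset, tokenizer, and configuration messages.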