
Commit 402d058

only printing on local_rank 0
1 parent 480b818 commit 402d058

File tree

1 file changed: +17 -10 lines changed

train.py

Lines changed: 17 additions & 10 deletions
@@ -145,7 +145,8 @@ def get_ds(config):
     ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')
 
     # Build tokenizers
-    print("Loading tokenizers...")
+    if config['local_rank'] == 0:
+        print("Loading tokenizers...")
     tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
     tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
 
@@ -167,8 +168,9 @@ def get_ds(config):
         max_len_src = max(max_len_src, len(src_ids))
         max_len_tgt = max(max_len_tgt, len(tgt_ids))
 
-    print(f'Max length of source sentence: {max_len_src}')
-    print(f'Max length of target sentence: {max_len_tgt}')
+    if config['local_rank'] == 0:
+        print(f'Max length of source sentence: {max_len_src}')
+        print(f'Max length of target sentence: {max_len_tgt}')
 
 
     train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=False, sampler=DistributedSampler(train_ds, shuffle=True))
@@ -184,13 +186,15 @@ def train_model(config):
     # Define the device
     assert torch.cuda.is_available(), "Training on CPU is not supported"
     device = torch.device("cuda")
-    print("Using device:", device)
+    if config['local_rank'] == 0:
+        print("Using device:", device)
 
     # Make sure the weights folder exists
     Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
 
     # Load the dataset
-    print("Loading dataset...")
+    if config['local_rank'] == 0:
+        print("Loading dataset...")
     train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
     model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
 
@@ -209,7 +213,8 @@ def train_model(config):
 
     # If we couldn't find a model to preload, just start from scratch
     if model_filename is not None:
-        print(f'Preloading model {model_filename}')
+        if config['local_rank'] == 0:
+            print(f'Preloading model {model_filename}')
         state = torch.load(model_filename)
         model.load_state_dict(state['model_state_dict'])
         initial_epoch = state['epoch'] + 1
@@ -218,7 +223,8 @@ def train_model(config):
         wandb_run_id = state['wandb_run_id']
         del state
     else:
-        print(f'Could not find model to preload: {config["preload"]}. Starting from scratch')
+        if config['local_rank'] == 0:
+            print(f'Could not find model to preload: {config["preload"]}. Starting from scratch')
 
     # Only initialize W&B on the rank 0 node
     if config['global_rank'] == 0:
@@ -324,9 +330,10 @@ def train_model(config):
     config['global_rank'] = int(os.environ['RANK'])
 
     # Print configuration
-    print("Configuration:")
-    for key, value in config.items():
-        print(f"{key:>20}: {value}")
+    if config['local_rank'] == 0:
+        print("Configuration:")
+        for key, value in config.items():
+            print(f"{key:>20}: {value}")
 
     # Setup distributed training
    init_process_group(backend='nccl')
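
For reference, the pattern this commit repeats at each call site (gating print on config['local_rank'] == 0) is often factored into a small helper so each log line does not need its own if. A minimal sketch, not part of this commit: it assumes the launcher (e.g. torchrun) has set the LOCAL_RANK environment variable, and the helper name print_rank0 is hypothetical.

import os

def print_rank0(*args, **kwargs):
    # Hypothetical helper, not part of this commit: only the first process
    # on each node (LOCAL_RANK == 0, as set by the launcher) actually prints.
    if int(os.environ.get("LOCAL_RANK", 0)) == 0:
        print(*args, **kwargs)

# Usage, mirroring the gated prints in the diff above:
print_rank0("Loading dataset...")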

0 commit comments
