Skip to content

Commit 1ebf3ec

Browse files
committed
added verbose logging
1 parent f296354 commit 1ebf3ec

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def get_ds(config):
145145
ds_raw = load_dataset('opus_books', f"{config['lang_src']}-{config['lang_tgt']}", split='train')
146146

147147
# Build tokenizers
148+
print("Loading tokenizers...")
148149
tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
149150
tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])
150151

@@ -188,6 +189,8 @@ def train_model(config):
188189
# Make sure the weights folder exists
189190
Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
190191

192+
# Load the dataset
193+
print("Loading dataset...")
191194
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
192195
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
193196

@@ -284,7 +287,7 @@ def train_model(config):
284287
run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step)
285288

286289
# Save the model at the end of every epoch
287-
model_filename = get_weights_file_path(config, f"{epoch:02d}")
290+
model_filename = get_weights_file_path(config, epoch)
288291
torch.save({
289292
'epoch': epoch,
290293
'model_state_dict': model.module.state_dict(), # Need to access module because we are using DDP

0 commit comments

Comments
 (0)