diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 7910c341..5eabb384 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -18,6 +18,7 @@ # pylint: disable=no-name-in-module from instructlab.dolomite.hf_models import GPTDolomiteForCausalLM +from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoModelForCausalLM, get_scheduler import torch @@ -325,7 +326,7 @@ def train( lr_scheduler, accelerator: Accelerator, tokenizer, - train_loader, + train_loader: DataLoader, grad_accum, metric_logger, ): @@ -457,6 +458,7 @@ def train( "total_loss": float(log_loss / num_loss_counted_tokens), "samples_seen": samples_seen, "gradnorm": global_grad_norm, + "total_samples": len(train_loader.dataset), # "weight_norm": weight_norm, } ) @@ -620,6 +622,7 @@ def main(args): "num_batches": len(train_loader), "avg_samples_per_batch": len(dataset) / len(train_loader), "samples_per_gpu": args.samples_per_gpu, + "total_samples": len(dataset), # emit the total number of samples } )