Skip to content

Commit

Permalink
fixed args for patch size
Browse files Browse the repository at this point in the history
  • Loading branch information
philschmid committed Jul 28, 2021
1 parent c3fa5b5 commit 9df51d5
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 12 deletions.
4 changes: 2 additions & 2 deletions sagemaker/02_getting_started_tensorflow/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@

# Hyperparameters sent by the client are passed as command-line arguments to the script.
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--train-batch-size", type=int, default=16)
parser.add_argument("--eval-batch-size", type=int, default=8)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=8)
parser.add_argument("--model_name", type=str)
parser.add_argument("--learning_rate", type=str, default=5e-5)
parser.add_argument("--do_train", type=bool, default=True)
Expand Down
8 changes: 4 additions & 4 deletions sagemaker/05_spot_instances/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@

# hyperparameters sent by the client are passed as command-line arguments to the script.
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--train-batch-size", type=int, default=32)
parser.add_argument("--eval-batch-size", type=int, default=64)
parser.add_argument("--train_batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=64)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--model_name", type=str)
parser.add_argument("--learning_rate", type=str, default=5e-5)
parser.add_argument("--output_dir", type=str)

# Data, model, and output directories
parser.add_argument("--output-data-dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
parser.add_argument("--output_data_dir", type=str, default=os.environ["SM_OUTPUT_DATA_DIR"])
parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])
Expand Down
8 changes: 4 additions & 4 deletions sagemaker/06_sagemaker_metrics/scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@

# hyperparameters sent by the client are passed as command-line arguments to the script.
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--train-batch-size", type=int, default=32)
parser.add_argument("--eval-batch-size", type=int, default=64)
parser.add_argument("--train_batch_size", type=int, default=32)
parser.add_argument("--eval_batch_size", type=int, default=64)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument("--model_name", type=str)
parser.add_argument("--learning_rate", type=float, default=5e-5)

# Data, model, and output directories
parser.add_argument("--checkpoints", type=str, default="/opt/ml/checkpoints/")
parser.add_argument("--model-dir", type=str, default=os.environ["SM_MODEL_DIR"])
parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
parser.add_argument("--n_gpus", type=str, default=os.environ["SM_NUM_GPUS"])
parser.add_argument("--training_dir", type=str, default=os.environ["SM_CHANNEL_TRAIN"])
parser.add_argument("--test_dir", type=str, default=os.environ["SM_CHANNEL_TEST"])
Expand Down Expand Up @@ -94,4 +94,4 @@ def compute_metrics(pred):
writer.write(f"{key} = {value}\n")

# Saves the model locally. In SageMaker, writing in /opt/ml/model sends it to S3
trainer.save_model(args.model_dir)
trainer.save_model(args.model_dir)
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def get_datasets():

# Hyperparameters sent by the client are passed as command-line arguments to the script.
parser.add_argument("--epochs", type=int, default=3)
parser.add_argument("--train-batch-size", type=int, default=16)
parser.add_argument("--eval-batch-size", type=int, default=8)
parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=8)
parser.add_argument("--model_name", type=str)
parser.add_argument("--learning_rate", type=str, default=5e-5)
parser.add_argument("--do_train", type=bool, default=True)
Expand Down

0 comments on commit 9df51d5

Please sign in to comment.