Skip to content

Commit

Permalink
Enable greedy sampling (#70)
Browse files: browse the repository at this point in the history
* Enable greedy sampling

* Refactored based on comments in revision 1

* Supporting decoder model sampling

* Addressed comments in rev2

* Formatted files with make style

* Reformatting after rebase

---------

Co-authored-by: Aashiq Muhamed <muhaaash@amazon.com>
  • Loading branch information
aashiqmuhamed and Aashiq Muhamed committed Jun 5, 2023
1 parent 872a38e commit 7bc8e9b
Show file tree
Hide file tree
Showing 5 changed files with 990 additions and 49 deletions.
30 changes: 15 additions & 15 deletions notebooks/text-classification/scripts/train.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import os
import argparse
import logging
import os

import evaluate
import numpy as np
from datasets import load_from_disk
from huggingface_hub import HfFolder
from transformers import (
AutoModelForSequenceClassification,
DataCollatorForSeq2Seq,
AutoTokenizer,
TrainingArguments,
set_seed,
)
from datasets import load_from_disk
import torch
import evaluate
import numpy as np
import logging
from huggingface_hub import HfFolder
from transformers import TrainingArguments

from optimum.neuron import TrainiumTrainer as Trainer


Expand All @@ -22,15 +21,14 @@

print(f"is precompilation: {os.environ.get('NEURON_PARALLEL_COMPILE')}")


def parse_args():
"""Parse the arguments."""
parser = argparse.ArgumentParser()
# add model id and dataset path argument
parser.add_argument("--model_id", type=str, default="bert-large-uncased", help="Model id to use for training.")
parser.add_argument("--dataset_path", type=str, default="dataset", help="Path to the already processed dataset.")
parser.add_argument(
"--output_dir", type=str, default=None, help="Hugging Face Repository id for uploading models"
)
parser.add_argument("--output_dir", type=str, default=None, help="Hugging Face Repository id for uploading models")
parser.add_argument(
"--repository_id", type=str, default=None, help="Hugging Face Repository id for uploading models"
)
Expand All @@ -55,9 +53,11 @@ def parse_args():
args = parser.parse_known_args()
return args


# Metric ID: load the F1 metric via `evaluate`
metric = evaluate.load("f1")


# Metric helper: receives eval_pred as a (predictions, labels) pair
def compute_metrics(eval_pred):
predictions, labels = eval_pred
Expand All @@ -73,11 +73,11 @@ def training_function(args):
train_dataset = load_from_disk(os.path.join(args.dataset_path, "train"))
eval_dataset = load_from_disk(os.path.join(args.dataset_path, "eval"))
tokenizer = AutoTokenizer.from_pretrained(args.model_id)

# Prepare model labels - useful for inference
labels = train_dataset.features["labels"].names
num_labels = len(labels)
label2id, id2label = dict(), dict()
label2id, id2label = {}, {}
for i, label in enumerate(labels):
label2id[label] = str(i)
id2label[str(i)] = label
Expand Down Expand Up @@ -143,4 +143,4 @@ def main():


if __name__ == "__main__":
main()
main()

0 comments on commit 7bc8e9b

Please sign in to comment.