-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
train.py
178 lines (138 loc) · 6.64 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import argparse
import logging
import os
import sys
import tensorflow as tf
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
from transformers.file_utils import is_sagemaker_dp_enabled
# Distributed data-parallel mode is switched on either explicitly through the
# SDP_ENABLED environment variable or by transformers' own detection helper.
SDP_ENABLED = bool(os.environ.get("SDP_ENABLED") or is_sagemaker_dp_enabled())

if SDP_ENABLED:
    os.environ["SAGEMAKER_INSTANCE_TYPE"] = "p3dn.24xlarge"
    # The smdistributed package is imported only when the mode is enabled.
    import smdistributed.dataparallel.tensorflow as sdp
def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None):
    """Train ``model`` with a manual gradient-tape loop and report the last loss.

    Args:
        model: Callable model whose output exposes ``.logits`` and which
            provides ``trainable_variables`` and ``variables``.
        loss: Callable invoked as ``loss(targets, logits)``.
        opt: Optimizer providing ``apply_gradients`` and ``variables()``.
        train_dataset: Iterable yielding ``(inputs, targets)`` batches.
        epochs: Number of passes to make over ``train_dataset``.
        train_batch_size: Not used inside the loop (the dataset is iterated
            as given); kept so existing callers keep working.
        max_steps: Optional cap on the total number of optimizer steps across
            all epochs. ``None`` or ``0`` means no cap.

    Returns:
        dict: ``{"loss": value}`` where ``value`` is ``.numpy()`` of the loss
        computed on the last executed step.

    Raises:
        ValueError: If no training step was executed (``epochs`` < 1 or the
            dataset yielded no batches).
    """
    loss_value = None
    steps_done = 0
    reached_cap = False

    for _ in range(epochs):
        pbar = tqdm(train_dataset)
        for inputs, targets in pbar:
            with tf.GradientTape() as tape:
                # Feed only the features (not the (features, labels) pair) and
                # run the forward pass in training mode.
                outputs = model(inputs, training=True)
                loss_value = loss(targets, outputs.logits)

            if SDP_ENABLED:
                tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True)

            grads = tape.gradient(loss_value, model.trainable_variables)
            opt.apply_gradients(zip(grads, model.trainable_variables))

            pbar.set_description(f"Loss: {loss_value:.4f}")

            if SDP_ENABLED and steps_done == 0:
                # Sync every worker to rank 0's state once, right after the
                # very first optimizer step of the run.
                sdp.broadcast_variables(model.variables, root_rank=0)
                sdp.broadcast_variables(opt.variables(), root_rank=0)

            steps_done += 1
            # Compare completed steps so exactly ``max_steps`` steps are run.
            if max_steps and steps_done >= max_steps:
                reached_cap = True
                break

        if reached_cap:
            break

    if loss_value is None:
        raise ValueError("fit() executed no training steps: check `epochs` and the training dataset size")

    return {"loss": loss_value.numpy()}
def get_datasets():
    """Build the batched TensorFlow train and test datasets for "imdb".

    Reads the module-level ``tokenizer`` (for encoding and the padded length)
    and ``args`` (for ``train_batch_size`` / ``eval_batch_size``). When
    ``SDP_ENABLED`` is set, each dataset is sharded by ``sdp.size()`` and
    ``sdp.rank()`` before batching.

    Returns:
        tuple: ``(tf_train_dataset, tf_test_dataset)``, both batched with
        ``drop_remainder=True``.
    """

    def _to_tf_dataset(split):
        # Tokenize the "text" column, truncating and padding to max length.
        encoded = split.map(
            lambda e: tokenizer(e["text"], truncation=True, padding="max_length"), batched=True
        )
        encoded.set_format(type="tensorflow", columns=["input_ids", "attention_mask", "label"])

        # Densify each feature column, filling with 0.
        features = {}
        for column in ["input_ids", "attention_mask"]:
            features[column] = encoded[column].to_tensor(
                default_value=0, shape=[None, tokenizer.model_max_length]
            )
        return tf.data.Dataset.from_tensor_slices((features, encoded["label"]))

    # Load dataset
    raw_train, raw_test = load_dataset("imdb", split=["train", "test"])

    # Preprocess train dataset, then test dataset
    train_ds = _to_tf_dataset(raw_train)
    test_ds = _to_tf_dataset(raw_test)

    if SDP_ENABLED:
        train_ds = train_ds.shard(sdp.size(), sdp.rank())
        test_ds = test_ds.shard(sdp.size(), sdp.rank())

    train_ds = train_ds.batch(args.train_batch_size, drop_remainder=True)
    test_ds = test_ds.batch(args.eval_batch_size, drop_remainder=True)
    return train_ds, test_ds
if __name__ == "__main__":

    def _str_to_bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn any non-empty string, including "False",
        into ``True``; this parser maps the usual spellings explicitly.
        """
        normalized = str(value).strip().lower()
        if normalized in ("true", "1", "yes", "y"):
            return True
        if normalized in ("false", "0", "no", "n"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    def _write_results(path, title, results):
        """Log ``results`` and write them to ``path`` as ``key = value`` lines."""
        with open(path, "w") as writer:
            logger.info("***** %s results *****", title)
            logger.info(results)
            for key, value in results.items():
                logger.info("  %s = %s", key, value)
                writer.write("%s = %s\n" % (key, value))

    parser = argparse.ArgumentParser()

    # Hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train-batch-size", type=int, default=16)
    parser.add_argument("--eval-batch-size", type=int, default=8)
    # The model name is needed to load both the model and the tokenizer.
    parser.add_argument("--model_name", type=str, required=True)
    # Parsed as float so a value given on the command line reaches the optimizer as a number.
    parser.add_argument("--learning_rate", type=float, default=5e-5)
    parser.add_argument("--do_train", type=_str_to_bool, default=True)
    parser.add_argument("--do_eval", type=_str_to_bool, default=True)

    # Data, model, and output directories. The SM_* environment variables only
    # provide defaults; when one is absent the flag must be given explicitly
    # instead of the script failing with a KeyError while building the parser.
    parser.add_argument(
        "--output_data_dir",
        type=str,
        default=os.environ.get("SM_OUTPUT_DATA_DIR"),
        required="SM_OUTPUT_DATA_DIR" not in os.environ,
    )
    parser.add_argument(
        "--model_dir",
        type=str,
        default=os.environ.get("SM_MODEL_DIR"),
        required="SM_MODEL_DIR" not in os.environ,
    )
    parser.add_argument("--n_gpus", type=str, default=os.environ.get("SM_NUM_GPUS"))

    # Unknown arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    if SDP_ENABLED:
        sdp.init()

        # Enable memory growth on every GPU and expose only the GPU whose
        # index equals this process's local rank.
        gpus = tf.config.experimental.list_physical_devices("GPU")
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        if gpus:
            tf.config.experimental.set_visible_devices(gpus[sdp.local_rank()], "GPU")

    # Only one process writes result files and the saved model.
    is_main_process = not SDP_ENABLED or sdp.rank() == 0

    # Load model and tokenizer
    model = TFAutoModelForSequenceClassification.from_pretrained(args.model_name)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # get datasets
    tf_train_dataset, tf_test_dataset = get_datasets()

    # define optimizer and loss
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    # Training
    if args.do_train:
        train_results = fit(
            model, loss, optimizer, tf_train_dataset, args.epochs, args.train_batch_size, max_steps=None
        )
        logger.info("*** Train ***")

        output_train_file = os.path.join(args.output_data_dir, "train_results.txt")

        if is_main_process:
            _write_results(output_train_file, "Train", train_results)

    # Evaluation
    if args.do_eval and is_main_process:
        # The dataset is already batched, so no batch_size is passed here.
        result = model.evaluate(tf_test_dataset, return_dict=True)
        logger.info("*** Evaluate ***")

        output_eval_file = os.path.join(args.output_data_dir, "eval_results.txt")
        _write_results(output_eval_file, "Eval", result)

    # Save result
    if is_main_process:
        model.save_pretrained(args.model_dir)
        tokenizer.save_pretrained(args.model_dir)