In [None]:
# Identify the Cloud TPU to use. These three variables are read back through
# os.environ when the TPUClusterResolver is created further down.
# The *...* values look like placeholders; replace each one before running.
%env GCE_PROJECT_NAME *insertProjectName*
%env TPU_ZONE *insert zone*
%env TPU_NAME *insert tpu name*

In [None]:
import os
import datetime
import tensorflow as tf
import xlnet
import run_classifier

# TASK DEFINITION

In [None]:
# Where fine-tuning checkpoints and evaluation results are written:
#   gs://<BUCKET>/xlnet_large/<TASK_VERSION>
# Replace both placeholder values below before running this cell.
BUCKET = '*insert your bucket name*'
TASK_VERSION = '*task name*'
# The shipped placeholder is a non-empty string, so a plain truthiness check
# would let it through and the failure would only surface inside MakeDirs.
# '*' is never valid in a GCS bucket name, so it reliably marks an unedited value.
assert BUCKET and '*' not in BUCKET, 'Must specify an existing GCS bucket name'
OUTPUT_DIR = 'gs://{}/xlnet_large/{}'.format(BUCKET, TASK_VERSION)
# Creates the output directory (and any missing parents) on GCS.
tf.gfile.MakeDirs(OUTPUT_DIR)
print('***** Model output directory: {} *****'.format(OUTPUT_DIR))

In [None]:
# Directory holding the pretrained XLNet files; 'xlnet_config.json' is loaded
# from here when the XLNetConfig is built. The value below is a placeholder
# path: point it at your own copy of the pretrained model.
XLNET_PRETRAINED_DIR = 'gs://path/to/xlnet model/'

In [None]:
# Resolve the TPU described by the TPU_NAME / TPU_ZONE / GCE_PROJECT_NAME
# environment variables.
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(
    [os.environ['TPU_NAME']],
    zone=os.environ['TPU_ZONE'],
    project=os.environ['GCE_PROJECT_NAME'],
)
# NOTE(review): tpu_grpc_url is not referenced again in the visible cells.
tpu_grpc_url = tpu_cluster_resolver.get_master()

# Estimator run configuration. model_dir is the path where the fine-tuned
# model is stored; a checkpoint is saved every 1000 steps.
run_config = tf.contrib.tpu.RunConfig(
    model_dir=OUTPUT_DIR,
    cluster=tpu_cluster_resolver,
    save_checkpoints_steps=1000,
    tpu_config=tf.contrib.tpu.TPUConfig(
        num_shards=8,
        iterations_per_loop=1000,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2,
    ),
)

# NOTE(review): xlnet_config is loaded here but not passed to anything in the
# visible cells.
xlnet_config = xlnet.XLNetConfig(
    json_path=os.path.join(XLNET_PRETRAINED_DIR, 'xlnet_config.json'))
# NOTE(review): num_classes is not assigned in any visible cell; define it
# (the number of labels in your task) before running this cell.
xlnet_model = run_classifier.get_model_fn(num_classes)

In [None]:
# TRAIN_BATCH_SIZE, EVAL_BATCH_SIZE, PREDICT_BATCH_SIZE and NUM_TRAIN_EPOCHS are
# not assigned in the visible cells. Batch sizes can be defined through
# run_classifier's flags: they are required flags, so adjust their default
# values to suit your setup.
# x, y, z are placeholders for the number of records in your train / dev / test
# .tfrecord files; replace them with real counts.
num_train_examples, num_dev_examples, num_test_examples = x, y, z
# Total optimisation steps = examples per batch-sized step, times epochs.
num_train_steps = int(num_train_examples / TRAIN_BATCH_SIZE * NUM_TRAIN_EPOCHS)

# TPU-backed estimator wrapping the XLNet classification model_fn.
estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=xlnet_model,
    config=run_config,
    use_tpu=True,
    params=None,
    train_batch_size=TRAIN_BATCH_SIZE,
    eval_batch_size=EVAL_BATCH_SIZE,
    predict_batch_size=PREDICT_BATCH_SIZE,
)

# FINE TUNING

In [None]:
# Input function for fine-tuning. input_file is a placeholder: replace it with
# the path of your training .tfrecord file.
train_input_fn = run_classifier.file_based_input_fn_builder(
    is_training=True,
    drop_remainder=True,
    seq_length=MAX_SEQ_LENGTH,  # must match the features you created
    input_file="input tfrecord file",
)

In [None]:
%%time

# Fine-tune the model, printing a short run summary and start/finish timestamps.
print('***** Started training at {} *****'.format(datetime.datetime.now()))
print('Num examples = {}'.format(num_train_examples))
print('Batch size = {}'.format(TRAIN_BATCH_SIZE))
# Printed like the other summary lines rather than sent through tf.logging.info:
# this notebook never raises TF's log verbosity, so an INFO-level message would
# normally not appear in the cell output.
print('Num steps = {}'.format(num_train_steps))
# max_steps is an absolute step count, so re-running this cell against an
# existing model_dir resumes from the latest checkpoint instead of adding steps.
estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
print('***** Finished training at {} *****'.format(datetime.datetime.now()))

# EVALUATION

In [None]:
%%time
# Evaluate the fine-tuned model on the test set and persist the metrics.
print('***** Started test eval at {} *****'.format(datetime.datetime.now()))
print('  Num examples = {}'.format(num_test_examples))
print('  Batch size = {}'.format(EVAL_BATCH_SIZE))

# Whole batches only: drop_remainder=True below discards any partial batch.
test_steps = int(num_test_examples / EVAL_BATCH_SIZE)
# input_file is a placeholder: replace it with your evaluation .tfrecord path.
test_input_fn = run_classifier.file_based_input_fn_builder(
    is_training=False,
    drop_remainder=True,
    seq_length=MAX_SEQ_LENGTH,
    input_file="gs://input evaluation tfrecord file",
)
result = estimator.evaluate(steps=test_steps, input_fn=test_input_fn)
print('***** Finished evaluation at {} *****'.format(datetime.datetime.now()))

# Echo every metric and write the same lines to test_results.txt in OUTPUT_DIR.
output_test_eval_file = os.path.join(OUTPUT_DIR, "test_results.txt")
with tf.gfile.GFile(output_test_eval_file, "w") as writer:
  print("***** Test Eval results *****")
  # Keys are unique, so sorting the (key, value) pairs orders them by key alone.
  for key, value in sorted(result.items()):
    print('  {} = {}'.format(key, str(value)))
    writer.write("%s = %s\n" % (key, str(value)))

# EXPORT MODEL

In [None]:
def serving_input_fn():
  """Build the serving input receiver used when exporting the SavedModel.

  A string placeholder of shape [None] accepts serialized tf.Example protos.
  Each one is parsed into int64 `input_ids`, `input_mask` and `segment_ids`
  of length MAX_SEQ_LENGTH, plus a scalar int64 `label_ids`.

  Returns:
    A tf.estimator.export.ServingInputReceiver whose features are the parsed
    tensors and whose only receiver tensor is registered under 'examples'.
  """
  with tf.variable_scope("sorting_hat_sa1_5"):
    # The three per-token features share the same fixed-length int64 layout.
    parse_spec = {
        feature_name: tf.FixedLenFeature([MAX_SEQ_LENGTH], tf.int64)
        for feature_name in ("input_ids", "input_mask", "segment_ids")
    }
    # label_ids is part of the parse spec too, so incoming examples must carry it.
    parse_spec["label_ids"] = tf.FixedLenFeature([], tf.int64)

    serialized_examples = tf.placeholder(
        name='input_example_tensor', shape=[None], dtype=tf.string)
    parsed_features = tf.parse_example(serialized_examples, parse_spec)
    return tf.estimator.export.ServingInputReceiver(
        parsed_features, {'examples': serialized_examples})

# Destination for the exported SavedModel. The value is a placeholder path:
# replace it with your own GCS location.
EXPORT_DIR = 'gs://path/to/export'
# NOTE(review): this sets a private TPUEstimator attribute and must happen
# before the export call below. Presumably it skips the TPU-specific export so
# the SavedModel can be served off-TPU; confirm against your TF version.
estimator._export_to_tpu = False  # this is important
# Export using serving_input_fn as the receiver; the call returns the path of
# the export it wrote, which is printed below.
path = estimator.export_savedmodel(EXPORT_DIR, serving_input_fn)
print(path)