### Creation of the environment

In [None]:
import os

# Set USE_AUTH_EPHEM to '0' for this process before the authentication call
# made in a later cell.
# NOTE(review): presumably read by the Colab auth flow - confirm.
os.environ.update({'USE_AUTH_EPHEM': '0'})

In [None]:
# Colab line magic: select the TensorFlow 2.x runtime.
%tensorflow_version 2.x
!pip3 install --upgrade pip
# The PyPI install of t5 is left disabled below; T5 is installed from the
# GitHub repository instead (the trailing shell comment on that line gives the
# reason: extra_id_x support).
# NOTE(review): neither install is pinned to a version or commit, so a re-run
# may resolve different package versions.
#!pip install -qU t5
!pip3 install git+https://github.com/google-research/text-to-text-transfer-transformer.git #extra_id_x support

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

# Base directory of the notebook: a Google Cloud Storage bucket.
BASE_DIR = "gs://code-generation"

# C.B.: SEQ_LENGTH is the number of tokens per sequence.
SEQ_LENGTH = 512

# Fail fast when the bucket path is empty or still the bare "gs://" prefix.
if BASE_DIR == "gs://" or not BASE_DIR:
  raise ValueError("You must enter a BASE_DIR.")

# Enables the Colab/TPU/GCS specific setup that follows.
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise a regular Exception subclass (not a bare BaseException) so that
    # ordinary `except Exception` handlers and tooling see the failure, and
    # chain the resolver's ValueError so the root cause stays in the traceback.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  # Authenticate the Colab user, connect to the TPU host, then configure GCS
  # access from the Colab credentials.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# Turn off TF2 behavior for the tf.compat.v1 module imported as `tf`.
# This runs after the TPU connection made earlier in this cell.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Keep TensorFlow's logger from handing its records to ancestor loggers,
  # and let the root Python logger emit INFO-level messages.
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the TensorFlow logging verbosity to `level`.

  Used as `with tf_verbosity_level(level): ...`. The verbosity that was
  active on entry is put back when the block exits, including when the block
  raises.

  Args:
    level: verbosity value handed to `tf.logging.set_verbosity`.

  Yields:
    None.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without the finally clause an exception inside the with-body would
    # leave the temporary verbosity in place for the rest of the session.
    tf.logging.set_verbosity(og_level)

Collecting pip
  Downloading pip-22.0.4-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 5.2 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.0.4
Collecting git+https://github.com/google-research/text-to-text-transfer-transformer.git
  Cloning https://github.com/google-research/text-to-text-transfer-transformer.git to /tmp/pip-req-build-bl_z769r
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/text-to-text-transfer-transformer.git /tmp/pip-req-build-bl_z769r
  Resolved https://github.com/google-research/text-to-text-transfer-transformer.git to commit c070da4626d936bab4039b007a5202f039d55f0a
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mesh-tensorflow[transformer]>=0.1.13
  Downloading mesh_tensorflow-0.1.19-py3-none-any.whl (366 kB)

Instructions for updating:
non-resource variables are not supported in the long term


### Loading of tsv files
With these cells you can load each TSV file used for fine-tuning.
Please make sure that the paths to all TSV files are correct.

In [None]:
# Example counts for the three datasets, listed as (train, test, eval).
# They are used further down as the num_input_examples of each task.
train_construct_length, test__construct_length, eval__construct_length = 251271, 37722, 36734

train_block_length, test__block_length, eval__block_length = 101646, 12706, 12705

train_token_length, test__token_length, eval__token_length = 307779, 38480, 38475

In [None]:
# TSV files of the construct task. "validation" currently points at the eval
# split; the commented-out entries switch both the file and the example count
# to the test split.
nq_tsv_path_construct = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/construct_train.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/construct_eval.tsv',
    # validation=BASE_DIR + '/T5_extension/ft_datasets/construct_test.tsv',
)

num_nq_examples_construct = {
    "train": train_construct_length,
    "validation": eval__construct_length,
    # "validation": test__construct_length,
}

In [None]:
# TSV files of the block task. "validation" currently points at the eval
# split; the commented-out entries switch both the file and the example count
# to the test split.
nq_tsv_path_block = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/block_train.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/block_eval.tsv',
    # validation=BASE_DIR + '/T5_extension/ft_datasets/block_test.tsv',
)

num_nq_examples_block = {
    "train": train_block_length,
    "validation": eval__block_length,
    # "validation": test__block_length,
}

In [None]:
# TSV files of the token task. "validation" currently points at the eval
# split; the commented-out entries switch both the file and the example count
# to the test split.
nq_tsv_path_token = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/token_train.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/token_eval.tsv',
    # validation=BASE_DIR + '/T5_extension/ft_datasets/token_test.tsv',
)

num_nq_examples_token = {
    "train": train_token_length,
    "validation": eval__token_length,
    # "validation": test__token_length,
}

### Preprocess of the dataset
In this step we preprocess the dataset.  
You have to change the paths to the vocab files (*vocab_model_path* and *vocab_path*).  
We are going to preprocess all the TSV files so that T5 can use them for evaluation.

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# Paths of the SentencePiece model and vocab files.
# They must be the same files that were used for the pre-training phase.
vocab_model_path, vocab_path = (
    BASE_DIR + '/T5_extension/code.model',
    BASE_DIR + '/T5_extension/code.vocab',
)


# Short aliases for two t5.data classes.
TaskRegistry, TfdsTask = t5.data.TaskRegistry, t5.data.TfdsTask


def get_default_vocabulary():
  """Return a new SentencePieceVocabulary built from `vocab_model_path`.

  Every call constructs a separate vocabulary object.
  """
  # Second positional argument of SentencePieceVocabulary.
  # NOTE(review): presumably the number of <extra_id_N> sentinel tokens - confirm.
  extra_id_count = 100
  vocabulary = SentencePieceVocabulary(vocab_model_path, extra_id_count)
  return vocabulary

# Feature specification shared by every task registered in this notebook.
# Both features append EOS; only "inputs" is flagged as not required.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=True,
        required=False,
    ),
    targets=Feature(
        vocabulary=get_default_vocabulary(),
        add_eos=True,
    ),
)

JAVA CONSTRUCT

In [None]:
def nq_construct(split, shuffle_files=True):
  """Load one split of the construct task as raw string examples.

  Each line of the TSV file registered for `split` is parsed as two
  tab-separated columns, with quote handling disabled.

  Args:
    split: key of `nq_tsv_path_construct` selecting the file to read.
    shuffle_files: ignored; it is deleted immediately because each split
      consists of a single file.

  Returns:
    A tf.data dataset of dicts with the string entries "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_construct[split])
  ds = ds.map(
      # record_defaults holds the per-column default VALUE (its dtype fixes the
      # column type). An empty string keeps empty fields empty; the previous
      # value turned them into the literal text "string".
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_construct("train").take(5)):
  print(ex)

A few raw train examples...
{'input': b'private String formatEventDateRange(Date beginDate, Date endDate) { if ( <extra_id_0>) { if (isEndOfDay(endDate)) { return formatEventDate(beginDate); } else if (isMidnight(beginDate)) { return formatEventDate(beginDate) + " until " + eventOutDayOnlyDf.format(endDate); } else { return formatEventDate(beginDate) + " - " + eventOutTimeOnlyDf.format(endDate); } } else { return formatEventDate(beginDate) + " - " + formatEventDate(endDate); } } <SEP> /** Get the string representation of an event date range. */', 'output': b'DateUtils.isSameDay(beginDate, endDate)'}
{'input': b'private String formatEventDateRange(Date beginDate, Date endDate) { if (DateUtils.isSameDay(beginDate, endDate)) { if ( <extra_id_0>) { return formatEventDate(beginDate); } else if (isMidnight(beginDate)) { return formatEventDate(beginDate) + " until " + eventOutDayOnlyDf.format(endDate); } else { return formatEventDate(beginDate) + " - " + eventOutTimeOnlyDf.format(endDate); } 

In [None]:
def construct_preprocessing(ds):
  """Convert raw construct examples into the text-to-text format.

  Args:
    ds: dataset of dicts with the string entries 'input' and 'output'.

  Returns:
    The mapped dataset, whose examples are dicts with 'inputs' (the raw input
    prefixed with 'CONSTRUCT:') and 'targets' (the raw output, unchanged).
  """

  def to_inputs_and_targets(ex):
    # One join of prefix and text with the default empty separator. The former
    # single-element joins were no-ops whose separator=' ' could never apply.
    inputs = tf.strings.join(['CONSTRUCT:', ex['input']])
    return {'inputs': inputs, 'targets': ex['output']}

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove the "construct" registration first, then add it again.
# NOTE(review): presumably so that this cell can be re-run - confirm.
t5.data.TaskRegistry.remove('construct')
t5.data.TaskRegistry.add(
    "construct",
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_construct,
    dataset_fn=nq_construct,
    text_preprocessor=[construct_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

<t5.data.dataset_providers.FunctionTask at 0x7f097f9884d0>

In [None]:
# Fetch the registered "construct" task and print five preprocessed examples
# from its train split.
nq_task = t5.data.TaskRegistry.get("construct")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'CONSTRUCT:private void invalidatePageTransformer() { if (mViewPager.getAdapter().getCount() > 0) { new Handler().post(new Runnable() { @Override public void run() { if (mViewPager.beginFakeDrag()) { mViewPager.fakeDragBy( <extra_id_0>); mViewPager.endFakeDrag(); } } }); } } <SEP> /** * Trick to notify the pageTransformer of a data set change. */', 'inputs': array([12094, 30813,    78, 14566,    65, 14635,  1068,  5880,  1927,
          20,    56,     4,   104,  1351, 23775,    12,  1323,  2690,
        1927,    12,  1323,   531,  1927,    57,    88,    91,    20,
          39, 14748,  1927,    12,  6254,   451,  1973,  6854,  1927,
          20,    55,  8203,    38,    65,   506,  1927,    20,    56,
           4,   104,  1351, 23775,    12, 22734, 19378, 17581, 23620,
          20,   101,  1351, 23775,    12, 19913, 17581,   950,   451,
       32099,  4767,   101,  1351, 23775,    12,   904, 19378, 17581,
       12329, 

JAVA TOKEN

In [None]:
def nq_token(split, shuffle_files=False):
  """Load one split of the token task as raw string examples.

  Each line of the TSV file registered for `split` is parsed as two
  tab-separated columns, with quote handling disabled.

  Args:
    split: key of `nq_tsv_path_token` selecting the file to read.
    shuffle_files: ignored; it is deleted immediately because each split
      consists of a single file.

  Returns:
    A tf.data dataset of dicts with the string entries "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_token[split])
  ds = ds.map(
      # record_defaults holds the per-column default VALUE (its dtype fixes the
      # column type). An empty string keeps empty fields empty; the previous
      # value turned them into the literal text "string".
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_token("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public boolean initialize <extra_id_0> if (bluetoothManager == null) { bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); if (bluetoothManager == null) { Log.e(TAG, "Unable to initialize BluetoothManager."); return false; } } adapter = bluetoothManager.getAdapter(); if (adapter == null) { Log.e(TAG, "Unable to obtain a BluetoothAdapter."); return false; } return true; } <SEP> /** * Initializes a reference to the local Bluetooth adapter. * * @return Return true if the initialization is successful. */', 'output': b'() {'}
{'input': b'public boolean initialize() { if (bluetoothManager == <extra_id_0> bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); if (bluetoothManager == null) { Log.e(TAG, "Unable to initialize BluetoothManager."); return false; } } adapter = bluetoothManager.getAdapter(); if (adapter == null) { Log.e(TAG, "Unable to obtain a BluetoothAdapter."); return false; } retu

In [None]:
def token_preprocessing(ds):
  """Convert raw token examples into the text-to-text format.

  Args:
    ds: dataset of dicts with the string entries 'input' and 'output'.

  Returns:
    The mapped dataset, whose examples are dicts with 'inputs' (the raw input
    prefixed with 'TOKEN:') and 'targets' (the raw output, unchanged).
  """

  def to_inputs_and_targets(ex):
    # One join of prefix and text with the default empty separator. The former
    # single-element joins were no-ops whose separator=' ' could never apply.
    inputs = tf.strings.join(['TOKEN:', ex['input']])
    return {'inputs': inputs, 'targets': ex['output']}

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove the "token" registration first, then add it again.
# NOTE(review): presumably so that this cell can be re-run - confirm.
t5.data.TaskRegistry.remove('token')
t5.data.TaskRegistry.add(
    "token",
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_token,
    dataset_fn=nq_token,
    text_preprocessor=[token_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

<t5.data.dataset_providers.FunctionTask at 0x7f096842d050>

In [None]:
# Fetch the registered "token" task and print five preprocessed examples
# from its train split.
nq_task = t5.data.TaskRegistry.get("token")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'TOKEN:public BlockDeviceObserver(Shell rootShell, PartitionListener listener) { super("/dev/block/", FileObserver.CREATE | FileObserver.DELETE); mVolumes = new HashMap<String, Volume>(); mListener = listener; mHandler = new Handler(Looper.getMainLooper()); mRootShell = rootShell <extra_id_0> detectDevices(); } <SEP> /** * Creates a new block device observer. * Does not start observing until {@link #startWatching()} is called. * @param rootShell The shell to execute mount commands in. * @param listener A listener to receive block device events. * Calls are received on the main thread. */', 'inputs': array([18964,    78,  2825,  3901,  3790,  8655,   451, 13259,  1475,
       13259,     9,  7277,  1128,  2173,    91,    20,   436,   451,
          26,    98, 11361,    98,  9062,    98,    26,     9,   629,
        8655,    12,  8543,  1338,   629,  8655,    12, 13112,  4767,
         101,  7491,     8,    24,    39,  1410,

JAVA BLOCK

In [None]:
def nq_block(split, shuffle_files=False):
  """Load one split of the block task as raw string examples.

  Each line of the TSV file registered for `split` is parsed as two
  tab-separated columns, with quote handling disabled.

  Args:
    split: key of `nq_tsv_path_block` selecting the file to read.
    shuffle_files: ignored; it is deleted immediately because each split
      consists of a single file.

  Returns:
    A tf.data dataset of dicts with the string entries "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_block[split])
  ds = ds.map(
      # record_defaults holds the per-column default VALUE (its dtype fixes the
      # column type). An empty string keeps empty fields empty; the previous
      # value turned them into the literal text "string".
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_block("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b"public String getChannelTitle(String homeCampus) { if (channelTitle == null) <extra_id_0> else return channelTitle.getTitle(homeCampus); } <SEP> /** * Get channel display title based on home campus. * @param homeCampus User's home campus * @return Campus-localized channel display title */", 'output': b'return getTitle(homeCampus);'}
{'input': b"public String getChannelTitle(String homeCampus) { if (channelTitle == null) return getTitle(homeCampus); else <extra_id_0> } <SEP> /** * Get channel display title based on home campus. * @param homeCampus User's home campus * @return Campus-localized channel display title */", 'output': b'return channelTitle.getTitle(homeCampus);'}
{'input': b'public int getColor() { String hex = options.optString("color", null); if (hex == null) <extra_id_0> int aRGB = Integer.parseInt(hex, 16); return aRGB + 0xFF000000; } <SEP> /** * @return * The notification background color for the small icon * Returns null, if no co

In [None]:
def block_preprocessing(ds):
  """Convert raw block examples into the text-to-text format.

  Args:
    ds: dataset of dicts with the string entries 'input' and 'output'.

  Returns:
    The mapped dataset, whose examples are dicts with 'inputs' (the raw input
    prefixed with 'BLOCK:') and 'targets' (the raw output, unchanged).
  """

  def to_inputs_and_targets(ex):
    # One join of prefix and text with the default empty separator. The former
    # single-element joins were no-ops whose separator=' ' could never apply.
    inputs = tf.strings.join(['BLOCK:', ex['input']])
    return {'inputs': inputs, 'targets': ex['output']}

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove the "block" registration first, then add it again.
# NOTE(review): presumably so that this cell can be re-run - confirm.
t5.data.TaskRegistry.remove('block')
t5.data.TaskRegistry.add(
    "block",
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_block,
    dataset_fn=nq_block,
    text_preprocessor=[block_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

<t5.data.dataset_providers.FunctionTask at 0x7f095265e710>

In [None]:
# Fetch the registered "block" task and print five preprocessed examples
# from its train split.
nq_task = t5.data.TaskRegistry.get("block")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'BLOCK:private void updateFromMatrix(boolean updateZoom, boolean updateRotation) { matrix.getValues(matrixValues); x = matrixValues[2]; y = matrixValues[5]; if (updateZoom) { zoom = (float) Math.hypot(matrixValues[1], matrixValues[4]); } if (updateRotation) <extra_id_0> } <SEP> /** * Applying state from current matrix. * <p> * Having matrix: * <pre> * | a b tx | * A = | c d ty | * | 0 0 1 | * * x = tx * y = ty * scale = sqrt(b^2+d^2) * rotation = atan(c/d) = atan(-b/a) * </pre> * See <a href="http://stackoverflow.com/questions/4361242">here</a>. * * @param updateZoom Whether to extract zoom from matrix * @param updateRotation Whether to extract rotation from matrix */', 'inputs': array([17090,    78, 14566,    65,   611,   784,  6340,   451, 14653,
         611, 21657,     9,   177,   611, 22210,    91,    20,  7901,
          12,  1323,  1574,   451, 29841,  1574,  4767,     6,    15,
          24,  7901,  1574, 19667,  

### Evaluation
You can run the evaluation using the following cells.  
Please set the variable *MODEL_DIR* to the correct path (the folder the pretrained model is saved in).


In [None]:
def _rate_num_input_examples(task):
  if "train" in task.splits:
    return float(task.num_input_examples("train"))
  elif "validation" in task.splits:
    return float(task.num_input_examples("validation"))
  else:
    raise ValueError("Task %s does not have a train or validation split." % (task.name))


# Remove the "all_tasks" mixture first, then register it again with the three
# tasks of this notebook.
t5.data.MixtureRegistry.remove("all_tasks")
t5.data.MixtureRegistry.add(
    "all_tasks",
    ["construct", "token", "block"],
    # The rate of each task comes from _rate_num_input_examples; the
    # alternative that was tried is a constant default_rate=1.0.
    default_rate=_rate_num_input_examples,
)

<seqio.dataset_providers.Mixture at 0x7f0968716210>

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular
import t5.models

MODEL_SIZE = "small"

# Folder where the checkpoints and all other information will be written.
MODEL_DIR = BASE_DIR + '/T5_extension/finetuning'

# Settings per model size, in the order
# (model_parallelism, train_batch_size, keep_checkpoint_max).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16000),
    "base":  (2, 128, 8),
    "large": (8, 64, 4),
    "3B":    (8, 16, 1),
    "11B":   (8, 16, 1),
}[MODEL_SIZE]

# Create the model directory before the model is constructed.
tf.io.gfile.makedirs(MODEL_DIR)

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    # TPU connection.
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    # Size-dependent settings selected above.
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    learning_rate_schedule=slanted_triangular,
    # Checkpointing; the retention limit is only passed when ON_CLOUD is set.
    save_checkpoints_steps=5000,
    keep_checkpoint_max=(keep_checkpoint_max if ON_CLOUD else None),
    iterations_per_loop=100,
)

In [None]:
# C.B.: run this cell only to create the files.

# Override the batch size of the model object for this evaluation run.
model.batch_size = 512
model.eval(
    checkpoint_steps=-1,  # evaluate only the last checkpoint
    mixture_or_task_name="all_tasks",
)

In [None]:
# Checkpoint steps handed to model.predict below: one every
# `checkpoint_interval` steps starting at `first_checkpoint`, with
# `last_checkpoint` always appended as the final entry.
first_checkpoint = 500000
last_checkpoint = 900000
checkpoint_interval = 20000
# range() stops before last_checkpoint, so the last step is added explicitly.
checkpoints = list(range(first_checkpoint, last_checkpoint, checkpoint_interval)) + [last_checkpoint]
checkpoints

[500000,
 520000,
 540000,
 560000,
 580000,
 600000,
 620000,
 640000,
 660000,
 680000,
 700000,
 720000,
 740000,
 760000,
 780000,
 800000,
 820000,
 840000,
 860000,
 880000,
 900000]

In [None]:
# Predictions are produced with model.predict, with beam_size set explicitly.

vocabulary_predict = get_default_vocabulary()

# File with the inputs to predict on, and file that receives the predictions.
input_file = BASE_DIR + '/T5_extension/finetuning/predict/inputs.txt'
output_file = BASE_DIR + '/T5_extension/finetuning/predict/predictions.txt'

model.predict(
    input_file=input_file,
    output_file=output_file,
    vocabulary=vocabulary_predict,
    checkpoint_steps=checkpoints,
    beam_size=1,
    temperature=0.0,
    keep_top_k=-1,
)

INFO:root:system_path_file_exists:gs://code-generation/T5_extension/finetuning/operative_config.gin
ERROR:root:Path not found: gs://code-generation/T5_extension/finetuning/operative_config.gin


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of 