### Creation of the environment

In [None]:
## RUN THIS CELL

%tensorflow_version 2.x
!pip3 install --upgrade pip
#!pip install -qU t5
!pip3 install git+https://github.com/google-research/text-to-text-transfer-transformer.git #extra_id_x support

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

# Base Google Cloud Storage bucket for this project.
BASE_DIR = "gs://bucket_code_completion" 

if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  # TPU topology passed to t5.models.MtfModel in the evaluation cells.
  TPU_TOPOLOGY = "2x2"
  # Detect the Colab TPU; TPUClusterResolver raises ValueError when none is attached.
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # RuntimeError (an Exception subclass) instead of a bare BaseException, which
    # is reserved for interpreter-exit signals and escapes `except Exception`.
    # Chain the original error so its details remain visible.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; select a TPU runtime '
                       '(Runtime > Change runtime type) and re-run this cell.') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# The t5/mesh-tensorflow pipeline below runs in TF1 graph mode.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

# On Colab, stop TF's logger from propagating records to the root Python
# logger, and let the root logger emit INFO-level messages.
if ON_CLOUD:
  py_logging.root.setLevel('INFO')
  tf.get_logger().propagate = False

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the `tf.logging` verbosity for the body of a `with` block.

  Args:
    level: verbosity level accepted by `tf.logging.set_verbosity`.

  The previous verbosity is restored on exit, including when the body raises,
  so a failing cell does not leave the logging level permanently changed.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    tf.logging.set_verbosity(og_level)

Collecting pip
[?25l  Downloading https://files.pythonhosted.org/packages/fe/ef/60d7ba03b5c442309ef42e7d69959f73aacccd0d86008362a681c4698e83/pip-21.0.1-py3-none-any.whl (1.5MB)
[K     |████████████████████████████████| 1.5MB 6.7MB/s 
[?25hInstalling collected packages: pip
  Found existing installation: pip 19.3.1
    Uninstalling pip-19.3.1:
      Successfully uninstalled pip-19.3.1
Successfully installed pip-21.0.1
Collecting git+https://github.com/google-research/text-to-text-transfer-transformer.git
  Cloning https://github.com/google-research/text-to-text-transfer-transformer.git to /tmp/pip-req-build-32oqr5ju
  Running command git clone -q https://github.com/google-research/text-to-text-transfer-transformer.git /tmp/pip-req-build-32oqr5ju
Collecting mesh-tensorflow[transformer]>=0.1.13
  Downloading mesh_tensorflow-0.1.18-py3-none-any.whl (361 kB)
[K     |████████████████████████████████| 361 kB 6.9 MB/s 
Collecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any

Instructions for updating:
non-resource variables are not supported in the long term


### Loading the TSV files
These cells define the path and the number of examples of each TSV file used for fine-tuning and evaluation.
Please make sure that the paths to all TSV files are correct.
To run the evaluation of a **single model** on a specific dataset, load only the TSV file you're interested in (e.g. *java_construct*).

In [None]:
# Train and test on the same dataset: the "validation" split is the
# dataset's test file.
nq_tsv_path_java_construct = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_java_construct.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_java_construct.tsv',
)

# Number of examples in each split.
num_nq_examples_java_construct = {"train": 750_000, "validation": 106_237}

In [None]:
# Train and test on the same dataset: the "validation" split is the
# dataset's test file.
nq_tsv_path_android_construct = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_android_construct.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_android_construct.tsv',
)

# Number of examples in each split.
num_nq_examples_android_construct = {"train": 750_000, "validation": 100_536}

In [None]:
# Train and test on the same dataset: the "validation" split is the
# dataset's test file.
nq_tsv_path_java_block = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_java_block.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_java_block.tsv',
)

# Number of examples in each split.
num_nq_examples_java_block = {"train": 298_470, "validation": 40_008}

In [None]:
# Train and test on the same dataset: the "validation" split is the
# dataset's test file.
nq_tsv_path_android_block = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_android_block.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_android_block.tsv',
)

# Number of examples in each split.
num_nq_examples_android_block = {"train": 204_580, "validation": 26_978}

In [None]:
# Train and test on the same dataset: the "validation" split is the
# dataset's test file.
nq_tsv_path_java_token = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_java_token.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_java_token.tsv',
)

# Number of examples in each split.
num_nq_examples_java_token = {"train": 750_000, "validation": 219_486}

In [None]:
# Train and test on the same dataset: the "validation" split is the
# dataset's test file.
nq_tsv_path_android_token = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_android_token.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_android_token.tsv',
)

# Number of examples in each split.
num_nq_examples_android_token = {"train": 750_000, "validation": 200_504}

### Preprocessing of the datasets
In this step we preprocess the datasets.  
Update the paths to the vocabulary files (*vocab_model_path* and *vocab_path*) so that they point to your copies.
We preprocess all the TSV files so that T5 can use them for fine-tuning.  
Please be sure to run **only the cells related to the specific model** you want to evaluate: run the following cell, and then only the group of cells related to that model (e.g. all the cells under the JAVA CONSTRUCT section).

In [None]:
## RUN THIS CELL
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# Paths of the SentencePiece model and vocab files.
# They must be the same files used for the pre-training phase.
vocab_model_path = 'gs://bucket_code_completion/T5_extension/code.model'
vocab_path = 'gs://bucket_code_completion/T5_extension/code.vocab'


# Short aliases for t5 data classes.
TaskRegistry = t5.data.TaskRegistry
TfdsTask = t5.data.TfdsTask


def get_default_vocabulary():
  """Build a new SentencePieceVocabulary from `vocab_model_path`.

  The second argument (100) is passed positionally to SentencePieceVocabulary;
  presumably it is the number of <extra_id_N> sentinel tokens. TODO confirm
  against the t5.seqio version in use.
  """
  vocabulary = SentencePieceVocabulary(vocab_model_path, 100)
  return vocabulary

# Output features for every task: both append EOS and use vocabularies built
# from the same SentencePiece model; "inputs" is declared with required=False.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(vocabulary=get_default_vocabulary(), add_eos=True,
                   required=False),
    targets=Feature(vocabulary=get_default_vocabulary(), add_eos=True),
)

JAVA CONSTRUCT

In [None]:
def nq_java_construct(split, shuffle_files=True):
  """Load one split of the java_construct TSV as {"input", "output"} examples.

  Args:
    split: key of `nq_tsv_path_java_construct` ("train" or "validation").
    shuffle_files: required by the t5 `dataset_fn` signature; ignored.

  Returns:
    A `tf.data.Dataset` of dicts with string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples: "<input>\t<output>" per line.
  # use_quote_delim=False treats '"' as an ordinary character, since the code
  # snippets contain double quotes.
  # record_defaults gives each column's dtype AND the value used for an empty
  # field; "" keeps empty fields empty instead of decoding them as the literal
  # text "string".
  ds = tf.data.TextLineDataset(nq_tsv_path_java_construct[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_java_construct("train").take(5)):
  print(ex)

A few raw train examples...
{'input': b'public void updateLockdownExceptions(ManagedObjectReference _this, String[] users) throws AuthMinimumAdminPermission, RemoteException, RuntimeFault, UserNotFound { Argument[] params = new Argument[2]; params[0] = new Argument("_this", "ManagedObjectReference", _this); params[1] = new Argument("users", "String[]", users); getWsc().invoke( <extra_id_0>); }', 'output': b'"UpdateLockdownExceptions", params, null'}
{'input': b'@Override public Collection<AgentProjectInfo> collectDependencies(String folder) { if ( <extra_id_0>){ devDependencies = findDevDependencies(folder); } File yarnLock = new File(folder + fileSeparator + YARN_LOCK); boolean yarnLockFound = yarnLock.isFile(); Collection<DependencyInfo> dependencies = new ArrayList<>(); if (yarnLockFound){ dependencies = parseYarnLock(yarnLock); } else { npmLsFailureStatus = true; } return getSingleProjectList(dependencies); }', 'output': b'!includeDevDependencies'}
{'input': b'@Override public Coll

In [None]:
def java_construct_preprocessing(ds):
  """Map raw {"input", "output"} examples to T5 {"inputs", "targets"} features.

  Each input is prefixed with the tag "JAVA_CONSTRUCT:" (no separating space);
  the output is used unchanged as the target.
  """
  def to_inputs_and_targets(ex):
    return {
        'inputs': tf.strings.join(['JAVA_CONSTRUCT:', ex['input']]),
        'targets': tf.strings.join([ex['output']]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Register the task; the preceding remove presumably lets this cell be re-run
# without a duplicate-registration error.
t5.data.TaskRegistry.remove('java_construct')
t5.data.TaskRegistry.add(
    "java_construct",
    splits=["train", "validation"],
    dataset_fn=nq_java_construct,
    text_preprocessor=[java_construct_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=num_nq_examples_java_construct,
)

<t5.data.dataset_providers.FunctionTask at 0x7f007fb0b630>

In [None]:
# Sanity check: print a few fully preprocessed (tokenized) training examples,
# with 256-token input and target lengths.
nq_task = t5.data.TaskRegistry.get("java_construct")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=256, targets=256),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'JAVA_CONSTRUCT:public void appendItems(List<String> items) { int count = getItemCount(); mValues.addAll(items); notifyItemRangeInserted( <extra_id_0>); }', 'inputs': array([    3,  7641,    15,  5071, 17558,    56,  4569,    20,     3,
         109,   757,     5,    71,    25,    31,    29,  1258,     8,
           7,    35,   436,    11,     3,  9357,    18, 10295,     4,
         771,     5,  2495,    10,  1737,   169,   594, 12306,     5,
       32099,    10,     6,     1], dtype=int32), 'targets_pretokenized': b'count, items.size()', 'targets': array([ 436,    9, 1258,    4,  134,   16,    1], dtype=int32)}
{'inputs_pretokenized': b'JAVA_CONSTRUCT:@Override public int pHashinateBytes(byte[] bytes) { ByteBuffer buf = ByteBuffer.wrap( <extra_id_0>); final int token = MurmurHash3.hash3_x64_128(buf, 0, bytes.length, 0); return partitionForToken(token); }', 'inputs': array([    3,  7641,    15,  5071, 17558,    56,  2098,

JAVA TOKEN

In [None]:
def nq_java_token(split, shuffle_files=False):
  """Load one split of the java_token TSV as {"input", "output"} examples.

  Args:
    split: key of `nq_tsv_path_java_token` ("train" or "validation").
    shuffle_files: required by the t5 `dataset_fn` signature; ignored.

  Returns:
    A `tf.data.Dataset` of dicts with string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples: "<input>\t<output>" per line.
  # use_quote_delim=False treats '"' as an ordinary character, since the code
  # snippets contain double quotes.
  # record_defaults gives each column's dtype AND the value used for an empty
  # field; "" keeps empty fields empty instead of decoding them as the literal
  # text "string".
  ds = tf.data.TextLineDataset(nq_tsv_path_java_token[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_java_token("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'private void emitSelectById(JavaWriter <extra_id_0> logger.d("emitSelectById"); javaWriter.beginMethod(getTargetClass(), $$GET_OBJECT_BY_ID, EnumSet.of(PUBLIC, STATIC), "long", "id", "SQLiteDatabase", "db") .emitStatement("Cursor cursor = db.rawQuery(\\"SELECT * FROM %s WHERE %s = id\\", null)", getTableName(), idColumn.getColumnName()) .emitStatement("%s value = %s(cursor, db).get(0)", getTargetClass(), $$MAP_OBJECT_FUNCTION) .emitStatement("cursor.close()") .emitStatement("return value") .endMethod(); }', 'output': b'javaWriter) throws IOException {'}
{'input': b'private void emitSelectById(JavaWriter javaWriter) throws IOException { logger.d("emitSelectById" <extra_id_0> javaWriter.beginMethod(getTargetClass(), $$GET_OBJECT_BY_ID, EnumSet.of(PUBLIC, STATIC), "long", "id", "SQLiteDatabase", "db") .emitStatement("Cursor cursor = db.rawQuery(\\"SELECT * FROM %s WHERE %s = id\\", null)", getTableName(), idColumn.getColumnName()) .emitStatement("%s

In [None]:
def java_token_preprocessing(ds):
  """Map raw {"input", "output"} examples to T5 {"inputs", "targets"} features.

  Each input is prefixed with the tag "JAVA_TOKEN:" (no separating space);
  the output is used unchanged as the target.
  """
  def to_inputs_and_targets(ex):
    return {
        'inputs': tf.strings.join(['JAVA_TOKEN:', ex['input']]),
        'targets': tf.strings.join([ex['output']]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Register the task; the preceding remove presumably lets this cell be re-run
# without a duplicate-registration error.
t5.data.TaskRegistry.remove('java_token')
t5.data.TaskRegistry.add(
    "java_token",
    splits=["train", "validation"],
    dataset_fn=nq_java_token,
    text_preprocessor=[java_token_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=num_nq_examples_java_token,
)

<t5.data.dataset_providers.FunctionTask at 0x7feee8081fd0>

In [None]:
# Sanity check: print a few fully preprocessed (tokenized) training examples,
# with 256-token input and target lengths.
nq_task = t5.data.TaskRegistry.get("java_token")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=256, targets=256),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b"JAVA_TOKEN:private String scanBlockScalarBreaks (int indent) { StringBuilder chunks = new StringBuilder(); while (column < indent && peek() == ' ') forward(); while (FULL_LINEBR.indexOf(peek()) != -1) { chunks.append(scanLineBreak()); while (column < indent && peek() == ' ') forward(); } return <extra_id_0> }", 'inputs': array([    3,  7641,    15,  2591,    56,  8797,    26,  2594,   326,
        6388,  3932,    22,    17,    53,  4387,     8,     7,   375,
       11899,    11,    24,   375,    18,   317,    17,  1214,   136,
        4387,    91,  8587,    16,    40,     3,     2,     3,     2,
           8,  5395,    18,   317,    17,  5996,    15,  3104,  7597,
           4,  1104,     5,  3333,    60,    49,  1324,     7, 11899,
           4,   109,     5,  3516, 25650,    39,   317,    17,  1214,
         136,  4387,    91,  8587,    16,    40,     3,     2,     3,
           2,     8,  5395,    18,     6,    14, 32

JAVA BLOCK

In [None]:
def nq_java_block(split, shuffle_files=False):
  """Load one split of the java_block TSV as {"input", "output"} examples.

  Args:
    split: key of `nq_tsv_path_java_block` ("train" or "validation").
    shuffle_files: required by the t5 `dataset_fn` signature; ignored.

  Returns:
    A `tf.data.Dataset` of dicts with string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples: "<input>\t<output>" per line.
  # use_quote_delim=False treats '"' as an ordinary character, since the code
  # snippets contain double quotes.
  # record_defaults gives each column's dtype AND the value used for an empty
  # field; "" keeps empty fields empty instead of decoding them as the literal
  # text "string".
  ds = tf.data.TextLineDataset(nq_tsv_path_java_block[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_java_block("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public String capitalizeWords(final Object target) { if (target == null) <extra_id_0> return StringUtils.capitalizeWords(target); }', 'output': b'{ return null; }'}
{'input': b'@Deprecated public static byte[] convertLongToVarInt(long value) { ByteBuffer longBB = ByteBuffer.allocate(EthereumUtil.LONG_SIZE); longBB.putLong(value); byte[] result = longBB.array(); int leadingZeros=0; for (int i=0;i<result.length;i++) { if (result[i]==0) <extra_id_0> else { break; } } return Arrays.copyOfRange(result, leadingZeros, result.length); }', 'output': b'{ leadingZeros++; }'}
{'input': b'@Deprecated public static byte[] convertLongToVarInt(long value) { ByteBuffer longBB = ByteBuffer.allocate(EthereumUtil.LONG_SIZE); longBB.putLong(value); byte[] result = longBB.array(); int leadingZeros=0; for (int i=0;i<result.length;i++) { if (result[i]==0) { leadingZeros++; } else <extra_id_0> } return Arrays.copyOfRange(result, leadingZeros, result.length); }', 'output'

In [None]:
def java_block_preprocessing(ds):
  """Map raw {"input", "output"} examples to T5 {"inputs", "targets"} features.

  Each input is prefixed with the tag "JAVA_BLOCK:" (no separating space);
  the output is used unchanged as the target.
  """
  def to_inputs_and_targets(ex):
    return {
        'inputs': tf.strings.join(['JAVA_BLOCK:', ex['input']]),
        'targets': tf.strings.join([ex['output']]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Register the task; the preceding remove presumably lets this cell be re-run
# without a duplicate-registration error.
t5.data.TaskRegistry.remove('java_block')
t5.data.TaskRegistry.add(
    "java_block",
    splits=["train", "validation"],
    dataset_fn=nq_java_block,
    text_preprocessor=[java_block_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=num_nq_examples_java_block,
)

<t5.data.dataset_providers.FunctionTask at 0x7feee80b2da0>

In [None]:
# Sanity check: print a few fully preprocessed (tokenized) training examples,
# with 256-token input and target lengths.
nq_task = t5.data.TaskRegistry.get("java_block")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=256, targets=256),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'JAVA_BLOCK:public BigMoney plusMinor(long amountToAdd) { if (amountToAdd == 0) <extra_id_0> BigDecimal newAmount = amount.add(BigDecimal.valueOf(amountToAdd, currency.getDecimalPlaces())); return BigMoney.of(currency, newAmount); }', 'inputs': array([    3,  7641,    15,  3517,    56,  4569, 30085,  8940,  8170,
           5,   288,     3, 28214,     8,     7,    21,    17, 28214,
          40,   178, 32099,  2322,    24,  1539,    11,  1453,     4,
          67,     5,  3191,     4,   510,     5, 28214,     9,  5599,
           4,    33,  4899, 15331,   366,    14, 30085,     4,   579,
           5,  7110,     9,    24,  1539,    10,     6,     1],
      dtype=int32), 'targets_pretokenized': b'{ return this; }', 'targets': array([ 7, 14, 23, 13,  6,  1], dtype=int32)}
{'inputs_pretokenized': b'JAVA_BLOCK:@Override protected ExceptionAction onException(Throwable t) { raftInvocationContext.updateKnownLeaderOnFailure(group

ANDROID CONSTRUCT

In [None]:
def nq_android_construct(split, shuffle_files=True):
  """Load one split of the android_construct TSV as {"input", "output"} examples.

  Args:
    split: key of `nq_tsv_path_android_construct` ("train" or "validation").
    shuffle_files: required by the t5 `dataset_fn` signature; ignored.

  Returns:
    A `tf.data.Dataset` of dicts with string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples: "<input>\t<output>" per line.
  # use_quote_delim=False treats '"' as an ordinary character, since the code
  # snippets contain double quotes.
  # record_defaults gives each column's dtype AND the value used for an empty
  # field; "" keeps empty fields empty instead of decoding them as the literal
  # text "string".
  ds = tf.data.TextLineDataset(nq_tsv_path_android_construct[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_android_construct("train").take(5)):
  print(ex)

A few raw train examples...
{'input': b'private void writeToFile(final String content) { if ( <extra_id_0>) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch (Exception e) { e.printStackTrace(); } }', 'output': b'!WRITE_TO_FILE'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch ( <extra_id_0>) { e.printStackTrace(); } }', 'output': b'Exception e'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File( <extra_id_0>); f.delete(); FileWriter fstream = new FileW

In [None]:
def android_construct_preprocessing(ds):
  """Map raw {"input", "output"} examples to T5 {"inputs", "targets"} features.

  Each input is prefixed with the tag "ANDROID_CONSTRUCT:" (no separating
  space); the output is used unchanged as the target.
  """
  def to_inputs_and_targets(ex):
    return {
        'inputs': tf.strings.join(['ANDROID_CONSTRUCT:', ex['input']]),
        'targets': tf.strings.join([ex['output']]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Register the task; the preceding remove presumably lets this cell be re-run
# without a duplicate-registration error.
t5.data.TaskRegistry.remove('android_construct')
t5.data.TaskRegistry.add(
    "android_construct",
    splits=["train", "validation"],
    dataset_fn=nq_android_construct,
    text_preprocessor=[android_construct_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=num_nq_examples_android_construct,
)

<t5.data.dataset_providers.FunctionTask at 0x7feee80c6dd8>

In [None]:
# Sanity check: print a few fully preprocessed (tokenized) training examples,
# with 256-token input and target lengths.
nq_task = t5.data.TaskRegistry.get("android_construct")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=256, targets=256),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_CONSTRUCT:public static String getCoordinateChars(int x) { int max = BOARD_LETTERS.length(); int firstCharIndex = Math.min(max, x / max) - 1; int secondCharIndex = x % max; if (firstCharIndex >= 0) return Character.toString(BOARD_LETTERS.charAt(firstCharIndex)) + BOARD_LETTERS.charAt( <extra_id_0>); else return Character.toString(BOARD_LETTERS.charAt(secondCharIndex)); }', 'inputs': array([    3, 16446,    15,  5071, 17558,    56,  4569,    48,    26,
           3, 19015,  3582,     5,    53,   205,     8,     7,    35,
         350,    11,     3, 13066,    15, 28100,   113,     4,   105,
          18,    35,   607,  1234,   163,    11,   608,     4,   769,
           5,   532,     9,   205,   260,   350,     8,   139,   498,
          35,  1959,  1234,   163,    11,   205,     3,     2,   350,
          13,    21,    17,  1089,  1234,   163,   453,   178,    14,
        3180,     4,   123,     5, 13066,    15, 2

ANDROID TOKEN

In [None]:
def nq_android_token(split, shuffle_files=False):
  """Load one split of the android_token TSV as {"input", "output"} examples.

  Args:
    split: key of `nq_tsv_path_android_token` ("train" or "validation").
    shuffle_files: required by the t5 `dataset_fn` signature; ignored.

  Returns:
    A `tf.data.Dataset` of dicts with string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples: "<input>\t<output>" per line.
  # use_quote_delim=False treats '"' as an ordinary character, since the code
  # snippets contain double quotes.
  # record_defaults gives each column's dtype AND the value used for an empty
  # field; "" keeps empty fields empty instead of decoding them as the literal
  # text "string".
  ds = tf.data.TextLineDataset(nq_tsv_path_android_token[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_android_token("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public boolean initialize <extra_id_0> if (bluetoothManager == null) { bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); if (bluetoothManager == null) { Log.e(TAG, "Unable to initialize BluetoothManager."); return false; } } adapter = bluetoothManager.getAdapter(); if (adapter == null) { Log.e(TAG, "Unable to obtain a BluetoothAdapter."); return false; } return true; }', 'output': b'() {'}
{'input': b'public boolean initialize() { if (bluetoothManager == <extra_id_0> bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); if (bluetoothManager == null) { Log.e(TAG, "Unable to initialize BluetoothManager."); return false; } } adapter = bluetoothManager.getAdapter(); if (adapter == null) { Log.e(TAG, "Unable to obtain a BluetoothAdapter."); return false; } return true; }', 'output': b'null) {'}
{'input': b'public boolean initialize() { if (bluetoothManager == null) { bluetoothManager = (Blu

In [None]:
def android_token_preprocessing(ds):
  """Map raw {"input", "output"} examples to T5 {"inputs", "targets"} features.

  Each input is prefixed with the tag "ANDROID_TOKEN:" (no separating space);
  the output is used unchanged as the target.
  """
  def to_inputs_and_targets(ex):
    return {
        'inputs': tf.strings.join(['ANDROID_TOKEN:', ex['input']]),
        'targets': tf.strings.join([ex['output']]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Register the task; the preceding remove presumably lets this cell be re-run
# without a duplicate-registration error.
t5.data.TaskRegistry.remove('android_token')
t5.data.TaskRegistry.add(
    "android_token",
    splits=["train", "validation"],
    dataset_fn=nq_android_token,
    text_preprocessor=[android_token_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=num_nq_examples_android_token,
)

<t5.data.dataset_providers.FunctionTask at 0x7feee7feb390>

In [None]:
# Sanity check: print a few fully preprocessed (tokenized) training examples,
# with 256-token input and target lengths.
nq_task = t5.data.TaskRegistry.get("android_token")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=256, targets=256),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_TOKEN:protected void onResume() { super.onResume(); IntentFilter filter = new IntentFilter(); filter.addAction(Intent.ACTION_BATTERY_CHANGED); Log.d(TAG, "Register battery status receiver." <extra_id_0> registerReceiver(mBroadcastReceiver, filter); }', 'inputs': array([    3, 16446,    15,  2591,    56, 18728,    20,  4984,    16,
           7,    52,     4,  4409,    18,  8183,   531,    11,    24,
        8183,    18,   531,     4,  7108,     5,   527,     4,  1023,
          15,   285, 20849,     2,    15,  4226,    10,   319,     4,
         101,     5,   356,     9,    32,  2341, 18836,   585,  3560,
           4,    83, 32099,     3,  8280,     5,    87,  2737,  1550,
           9,   531,    10,     6,     1], dtype=int32), 'targets_pretokenized': b');', 'targets': array([ 3, 10,  1], dtype=int32)}
{'inputs_pretokenized': b'ANDROID_TOKEN:protected LocalTime getItemBeginTime(int position) { if (!today) retur

ANDROID BLOCK

In [None]:
def nq_android_block(split, shuffle_files=False):
  """Load one split of the android_block TSV as {"input", "output"} examples.

  Args:
    split: key of `nq_tsv_path_android_block` ("train" or "validation").
    shuffle_files: required by the t5 `dataset_fn` signature; ignored.

  Returns:
    A `tf.data.Dataset` of dicts with string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples: "<input>\t<output>" per line.
  # use_quote_delim=False treats '"' as an ordinary character, since the code
  # snippets contain double quotes.
  # record_defaults gives each column's dtype AND the value used for an empty
  # field; "" keeps empty fields empty instead of decoding them as the literal
  # text "string".
  ds = tf.data.TextLineDataset(nq_tsv_path_android_block[split])
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_android_block("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public void onStackEmpty() { cancelAutoLove(); if (progressDialog == null || !progressDialog.isShowing()) <extra_id_0> }', 'output': b'{ progressDialog = ProgressDialog.show(this, "Out of Cats", "Searching for more, Please wait..."); startLoading(); }'}
{'input': b'private static List<HeroAndAdvantages> loadHeroes(SQLiteDatabase db) { List<HeroAndAdvantages> heroes = new ArrayList<>(); Cursor c = db.rawQuery("SELECT * FROM Heroes", null); c.moveToFirst(); while (!c.isAfterLast()) <extra_id_0> c.close(); return heroes; }', 'output': b'{ heroes.add(new HeroAndAdvantages(c)); c.moveToNext(); }'}
{'input': b'private String getLinkByRelation(String relation) { for (Link l : link) { if (l.getRel().equals(relation)) <extra_id_0> } return null; }', 'output': b'{ return l.getUri(); }'}
{'input': b'public static boolean isBundleValid(final Bundle bundle) { if (null == bundle) <extra_id_0> if (bundle.getInt(BUNDLE_EXTRA_INT_VERSION_CODE, -1) == -1) { return

In [None]:
def android_block_preprocessing(ds):
  """Map raw {"input", "output"} examples to T5 {"inputs", "targets"} features.

  Each input is prefixed with the tag "ANDROID_BLOCK:" (no separating space);
  the output is used unchanged as the target.
  """
  def to_inputs_and_targets(ex):
    return {
        'inputs': tf.strings.join(['ANDROID_BLOCK:', ex['input']]),
        'targets': tf.strings.join([ex['output']]),
    }

  return ds.map(to_inputs_and_targets,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Register the task; the preceding remove presumably lets this cell be re-run
# without a duplicate-registration error.
t5.data.TaskRegistry.remove('android_block')
t5.data.TaskRegistry.add(
    "android_block",
    splits=["train", "validation"],
    dataset_fn=nq_android_block,
    text_preprocessor=[android_block_preprocessing],
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
    num_input_examples=num_nq_examples_android_block,
)

<t5.data.dataset_providers.FunctionTask at 0x7feee7f0b4e0>

In [None]:
# Sanity check: print a few fully preprocessed (tokenized) training examples,
# with 256-token input and target lengths.
nq_task = t5.data.TaskRegistry.get("android_block")
ds = nq_task.get_dataset(
    split="train",
    sequence_length=dict(inputs=256, targets=256),
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_BLOCK:private void removeOld(final long now) { while (!mQueue.isEmpty()) { final Sample sample = mQueue.get(0); if (now - sample.getTimeStamp() > WINDOW_SIZE) { mQueue.remove(0); } else <extra_id_0> } }', 'inputs': array([    3, 16446,    15,  3517,    56,  8797,    20,   424,  2796,
           5,    64,   126,  1426,     8,     7,   317,   124,    87,
           2,   514,     4,   280,    60,     7,    44,  7856,  2057,
          11,    54,     2,   514,     4,    33,   659,    21,    17,
        2674,   139,  2057,     4, 23553,    16,     3,    29,     3,
       10356,    15,   991,     2,   146,     8,     7,    54,     2,
         514,     4,   252,   659,     6,    77, 32099,     6,     6,
           1], dtype=int32), 'targets_pretokenized': b'{ break; }', 'targets': array([  7, 591,  13,   6,   1], dtype=int32)}
{'inputs_pretokenized': b'ANDROID_BLOCK:public static boolean isEntered(Activity activity) { if

### Evaluation
You can run the evaluation using the following cells.  
Please set the variable *MODEL_DIR* to the correct path (the folder containing the checkpoints of the fine-tuned model).

Change the mixture by choosing the task you want to run (e.g. you can associate "all_tasks" with ["android_token"] if you want to evaluate android token).

Please be sure to run only the cell under the specific model you want to evaluate (e.g. all cells under the **ANDROID TOKEN** section).

In [None]:
## RUN THIS CELL

def _rate_num_input_examples(task):
  if "train" in task.splits:
    return float(task.num_input_examples("train"))
  elif "validation" in task.splits:
    return float(task.num_input_examples("validation"))
  else:
    raise ValueError("Task %s does not have a train or validation split." % (task.name))


# (Re-)register the "all_tasks" mixture. Edit the task list to choose what to
# run; the available tasks are:
#   "java_construct", "java_token", "java_block",
#   "android_construct", "android_token", "android_block"
# Tasks are weighted by _rate_num_input_examples (an alternative that was
# tried is default_rate=1.0).
t5.data.MixtureRegistry.remove("all_tasks")
t5.data.MixtureRegistry.add(
    "all_tasks",
    ["java_construct"],
    default_rate=_rate_num_input_examples,
)

<t5.seqio.dataset_providers.Mixture at 0x7f007f875550>

JAVA CONSTRUCT

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular

# Size of the model to evaluate; selects a row of the settings table below.
MODEL_SIZE = "small" 

# Folder holding the checkpoints and all the other files of this model.
MODEL_DIR = 'gs://bucket_code_completion/T5_extension/single_finetuning/java_construct/model'

# Per-size settings: (model_parallelism, batch size, keep_checkpoint_max).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 256, "targets": 256},
    learning_rate_schedule=slanted_triangular,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

vocabulary_predict = get_default_vocabulary()

# Predict from the latest checkpoint (checkpoint_steps=-1) with beam_size=1 and
# temperature=0.0, i.e. greedy decoding.
model.predict(
    input_file='gs://bucket_code_completion/T5_extension/single_finetuning/predict/java_construct_inputs.dms',
    output_file='gs://bucket_code_completion/T5_extension/single_finetuning/predict/java_construct_predictions.dms',
    checkpoint_steps=-1,
    beam_size=1,
    temperature=0.0,
    keep_top_k=-1,
    vocabulary=vocabulary_predict,
)

INFO:root:system_path_file_exists:gs://bucket_comment_completion/Matteo/single_finetuning/java_construct/model/operative_config.gin
ERROR:root:Path not found: gs://bucket_comment_completion/Matteo/single_finetuning/java_construct/model/operative_config.gin


INFO:tensorflow:Using config: {'_model_dir': 'gs://bucket_comment_completion/Matteo/single_finetuning/java_construct/model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.121.166.66:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.121.166.66:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.121.166.66:8470', '_evaluation

ANDROID CONSTRUCT

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular

# Inference for the single-task model fine-tuned on ANDROID CONSTRUCT: load the
# latest checkpoint from MODEL_DIR and greedily decode PREDICT_INPUT_FILE.
MODEL_SIZE = "small"

# Name of the fine-tuned task. It determines the checkpoint folder and the
# input/output files, so the three paths cannot drift apart.
TASK_NAME = "android_construct"
FINETUNING_DIR = BASE_DIR + "/T5_extension/single_finetuning"

# Folder that must already contain this task's fine-tuned checkpoints.
MODEL_DIR = "{}/{}/model".format(FINETUNING_DIR, TASK_NAME)
PREDICT_INPUT_FILE = "{}/predict/{}_inputs.dms".format(FINETUNING_DIR, TASK_NAME)
PREDICT_OUTPUT_FILE = "{}/predict/{}_predictions.dms".format(FINETUNING_DIR, TASK_NAME)

# (model_parallelism, batch_size, keep_checkpoint_max) for each model size.
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

# This cell only predicts, so it must not create MODEL_DIR: a mistyped path
# would silently leave an empty folder in the bucket. Fail fast instead.
if tf.train.latest_checkpoint(MODEL_DIR) is None:
  raise FileNotFoundError("No checkpoint found in {}".format(MODEL_DIR))
if not tf.io.gfile.exists(PREDICT_INPUT_FILE):
  raise FileNotFoundError("Prediction input not found: {}".format(PREDICT_INPUT_FILE))

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule = slanted_triangular,
    sequence_length={"inputs": 256, "targets": 256},
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

vocabulary_predict = get_default_vocabulary()

# checkpoint_steps=-1 selects the latest checkpoint; beam_size=1 with
# temperature=0.0 gives greedy decoding.
model.predict(input_file=PREDICT_INPUT_FILE,
              output_file=PREDICT_OUTPUT_FILE,
              checkpoint_steps=-1, beam_size=1, temperature=0.0, keep_top_k=-1,
              vocabulary=vocabulary_predict)

### Java block

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular

# Inference for the single-task model fine-tuned on JAVA BLOCK: load the
# latest checkpoint from MODEL_DIR and greedily decode PREDICT_INPUT_FILE.
MODEL_SIZE = "small"

# Name of the fine-tuned task. It determines the checkpoint folder and the
# input/output files, so the three paths cannot drift apart.
TASK_NAME = "java_block"
FINETUNING_DIR = BASE_DIR + "/T5_extension/single_finetuning"

# Folder that must already contain this task's fine-tuned checkpoints.
MODEL_DIR = "{}/{}/model".format(FINETUNING_DIR, TASK_NAME)
PREDICT_INPUT_FILE = "{}/predict/{}_inputs.dms".format(FINETUNING_DIR, TASK_NAME)
PREDICT_OUTPUT_FILE = "{}/predict/{}_predictions.dms".format(FINETUNING_DIR, TASK_NAME)

# (model_parallelism, batch_size, keep_checkpoint_max) for each model size.
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

# This cell only predicts, so it must not create MODEL_DIR: a mistyped path
# would silently leave an empty folder in the bucket. Fail fast instead.
if tf.train.latest_checkpoint(MODEL_DIR) is None:
  raise FileNotFoundError("No checkpoint found in {}".format(MODEL_DIR))
if not tf.io.gfile.exists(PREDICT_INPUT_FILE):
  raise FileNotFoundError("Prediction input not found: {}".format(PREDICT_INPUT_FILE))

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule = slanted_triangular,
    sequence_length={"inputs": 256, "targets": 256},
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

vocabulary_predict = get_default_vocabulary()

# checkpoint_steps=-1 selects the latest checkpoint; beam_size=1 with
# temperature=0.0 gives greedy decoding.
model.predict(input_file=PREDICT_INPUT_FILE,
              output_file=PREDICT_OUTPUT_FILE,
              checkpoint_steps=-1, beam_size=1, temperature=0.0, keep_top_k=-1,
              vocabulary=vocabulary_predict)

### Android block

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular

# Inference for the single-task model fine-tuned on ANDROID BLOCK: load the
# latest checkpoint from MODEL_DIR and greedily decode PREDICT_INPUT_FILE.
MODEL_SIZE = "small"

# Name of the fine-tuned task. It determines the checkpoint folder and the
# input/output files, so the three paths cannot drift apart.
TASK_NAME = "android_block"
FINETUNING_DIR = BASE_DIR + "/T5_extension/single_finetuning"

# Folder that must already contain this task's fine-tuned checkpoints.
MODEL_DIR = "{}/{}/model".format(FINETUNING_DIR, TASK_NAME)
PREDICT_INPUT_FILE = "{}/predict/{}_inputs.dms".format(FINETUNING_DIR, TASK_NAME)
PREDICT_OUTPUT_FILE = "{}/predict/{}_predictions.dms".format(FINETUNING_DIR, TASK_NAME)

# (model_parallelism, batch_size, keep_checkpoint_max) for each model size.
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

# This cell only predicts, so it must not create MODEL_DIR: a mistyped path
# would silently leave an empty folder in the bucket. Fail fast instead.
if tf.train.latest_checkpoint(MODEL_DIR) is None:
  raise FileNotFoundError("No checkpoint found in {}".format(MODEL_DIR))
if not tf.io.gfile.exists(PREDICT_INPUT_FILE):
  raise FileNotFoundError("Prediction input not found: {}".format(PREDICT_INPUT_FILE))

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule = slanted_triangular,
    sequence_length={"inputs": 256, "targets": 256},
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

vocabulary_predict = get_default_vocabulary()

# checkpoint_steps=-1 selects the latest checkpoint; beam_size=1 with
# temperature=0.0 gives greedy decoding.
model.predict(input_file=PREDICT_INPUT_FILE,
              output_file=PREDICT_OUTPUT_FILE,
              checkpoint_steps=-1, beam_size=1, temperature=0.0, keep_top_k=-1,
              vocabulary=vocabulary_predict)

### Java token

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular

# Inference for the single-task model fine-tuned on JAVA TOKEN: load the
# latest checkpoint from MODEL_DIR and greedily decode PREDICT_INPUT_FILE.
MODEL_SIZE = "small"

# Name of the fine-tuned task. It determines the checkpoint folder and the
# input/output files, so the three paths cannot drift apart.
TASK_NAME = "java_token"
FINETUNING_DIR = BASE_DIR + "/T5_extension/single_finetuning"

# Folder that must already contain this task's fine-tuned checkpoints.
MODEL_DIR = "{}/{}/model".format(FINETUNING_DIR, TASK_NAME)
PREDICT_INPUT_FILE = "{}/predict/{}_inputs.dms".format(FINETUNING_DIR, TASK_NAME)
PREDICT_OUTPUT_FILE = "{}/predict/{}_predictions.dms".format(FINETUNING_DIR, TASK_NAME)

# (model_parallelism, batch_size, keep_checkpoint_max) for each model size.
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

# This cell only predicts, so it must not create MODEL_DIR: a mistyped path
# would silently leave an empty folder in the bucket. Fail fast instead.
if tf.train.latest_checkpoint(MODEL_DIR) is None:
  raise FileNotFoundError("No checkpoint found in {}".format(MODEL_DIR))
if not tf.io.gfile.exists(PREDICT_INPUT_FILE):
  raise FileNotFoundError("Prediction input not found: {}".format(PREDICT_INPUT_FILE))

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule = slanted_triangular,
    sequence_length={"inputs": 256, "targets": 256},
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

vocabulary_predict = get_default_vocabulary()

# checkpoint_steps=-1 selects the latest checkpoint; beam_size=1 with
# temperature=0.0 gives greedy decoding.
model.predict(input_file=PREDICT_INPUT_FILE,
              output_file=PREDICT_OUTPUT_FILE,
              checkpoint_steps=-1, beam_size=1, temperature=0.0, keep_top_k=-1,
              vocabulary=vocabulary_predict)

### Android token

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular

# Inference for the single-task model fine-tuned on ANDROID TOKEN: load the
# latest checkpoint from MODEL_DIR and greedily decode PREDICT_INPUT_FILE.
MODEL_SIZE = "small"

# Name of the fine-tuned task. It determines the checkpoint folder and the
# input/output files, so the three paths cannot drift apart.
TASK_NAME = "android_token"
FINETUNING_DIR = BASE_DIR + "/T5_extension/single_finetuning"

# Folder that must already contain this task's fine-tuned checkpoints.
MODEL_DIR = "{}/{}/model".format(FINETUNING_DIR, TASK_NAME)
PREDICT_INPUT_FILE = "{}/predict/{}_inputs.dms".format(FINETUNING_DIR, TASK_NAME)
PREDICT_OUTPUT_FILE = "{}/predict/{}_predictions.dms".format(FINETUNING_DIR, TASK_NAME)

# (model_parallelism, batch_size, keep_checkpoint_max) for each model size.
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

# This cell only predicts, so it must not create MODEL_DIR: a mistyped path
# would silently leave an empty folder in the bucket. Fail fast instead.
if tf.train.latest_checkpoint(MODEL_DIR) is None:
  raise FileNotFoundError("No checkpoint found in {}".format(MODEL_DIR))
if not tf.io.gfile.exists(PREDICT_INPUT_FILE):
  raise FileNotFoundError("Prediction input not found: {}".format(PREDICT_INPUT_FILE))

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule = slanted_triangular,
    sequence_length={"inputs": 256, "targets": 256},
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

vocabulary_predict = get_default_vocabulary()

# checkpoint_steps=-1 selects the latest checkpoint; beam_size=1 with
# temperature=0.0 gives greedy decoding.
model.predict(input_file=PREDICT_INPUT_FILE,
              output_file=PREDICT_OUTPUT_FILE,
              checkpoint_steps=-1, beam_size=1, temperature=0.0, keep_top_k=-1,
              vocabulary=vocabulary_predict)

INFO:root:system_path_file_exists:gs://bucket_comment_completion/Matteo/single_finetuning/android_token/operative_config.gin
ERROR:root:Path not found: gs://bucket_comment_completion/Matteo/single_finetuning/android_token/operative_config.gin
INFO:root:system_path_file_exists:gs://bucket_comment_completion/Matteo/pretrained_with_masking/operative_config.gin
ERROR:root:Path not found: gs://bucket_comment_completion/Matteo/pretrained_with_masking/operative_config.gin


INFO:tensorflow:Using config: {'_model_dir': 'gs://bucket_comment_completion/Matteo/single_finetuning/android_token/model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.107.26.122:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.107.26.122:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.107.26.122:8470', '_evaluation_

  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=8] LayoutRules{('d_ff', 'model'), ('batch', 'batch'), ('experts', 'batch'), ('vocab', 'model'), ('heads', 'model'), ('ensemble', 'ensemble')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f74efedfbe0>
INF