### Creation of the environment

In [1]:
%tensorflow_version 2.x
!pip3 install --upgrade pip
#!pip install -qU t5
!pip3 install git+https://github.com/google-research/text-to-text-transfer-transformer.git #extra_id_x support

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

# Root of the Google Cloud Storage bucket that holds the datasets and checkpoints.
BASE_DIR = "gs://bucket_code_completion"

# Refuse to continue with an empty or placeholder bucket path.
if BASE_DIR == "gs://" or not BASE_DIR:
  raise ValueError("You must enter a BASE_DIR.")

# When True, the notebook sets up Colab authentication, the TPU and GCS access.
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  # TPU topology string handed to the model constructor later in the notebook.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise an ordinary exception (not a bare BaseException) and keep the
    # resolver's ValueError attached as the cause for easier debugging.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# The rest of the notebook uses the TF1-style API exposed by tensorflow.compat.v1.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Stop TensorFlow's logger from passing its records up to ancestor loggers,
  # and let the root logger emit messages at INFO level and above.
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set TensorFlow's logging verbosity inside a `with` block.

  Args:
    level: verbosity passed to `tf.logging.set_verbosity` for the duration of
      the block.

  Yields:
    None. The verbosity that was active on entry is restored on exit, including
    when the body of the `with` block raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without the finally clause an exception in the body would leave the
    # temporary verbosity in place for the rest of the session.
    tf.logging.set_verbosity(og_level)

Collecting pip
[?25l  Downloading https://files.pythonhosted.org/packages/fe/ef/60d7ba03b5c442309ef42e7d69959f73aacccd0d86008362a681c4698e83/pip-21.0.1-py3-none-any.whl (1.5MB)
[K     |████████████████████████████████| 1.5MB 6.9MB/s 
[?25hInstalling collected packages: pip
  Found existing installation: pip 19.3.1
    Uninstalling pip-19.3.1:
      Successfully uninstalled pip-19.3.1
Successfully installed pip-21.0.1
Collecting git+https://github.com/google-research/text-to-text-transfer-transformer.git
  Cloning https://github.com/google-research/text-to-text-transfer-transformer.git to /tmp/pip-req-build-4qw6o801
  Running command git clone -q https://github.com/google-research/text-to-text-transfer-transformer.git /tmp/pip-req-build-4qw6o801
Collecting mesh-tensorflow[transformer]>=0.1.13
  Downloading mesh_tensorflow-0.1.18-py3-none-any.whl (361 kB)
[K     |████████████████████████████████| 361 kB 6.1 MB/s 
Collecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any

Instructions for updating:
non-resource variables are not supported in the long term


### Loading of tsv files
The following cells define the path of each TSV file used for fine-tuning.
Please make sure that the paths to all TSV files are correct.

In [2]:
# Java "construct" task: the train split reads train_java_construct.tsv and the
# validation split reads test_java_construct.tsv.
nq_tsv_path_java_construct = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_java_construct.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_java_construct.tsv',
)

# Example counts per split, handed to the task registration as `num_input_examples`.
num_nq_examples_java_construct = {"train": 750000, "validation": 106237}

In [3]:
# Android "construct" task: the train split reads train_android_construct.tsv and
# the validation split reads test_android_construct.tsv.
nq_tsv_path_android_construct = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_android_construct.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_android_construct.tsv',
)

# Example counts per split, handed to the task registration as `num_input_examples`.
num_nq_examples_android_construct = {"train": 750000, "validation": 100536}

In [4]:
# Java "block" task: the train split reads train_java_block.tsv and the
# validation split reads test_java_block.tsv.
nq_tsv_path_java_block = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_java_block.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_java_block.tsv',
)

# Example counts per split, handed to the task registration as `num_input_examples`.
num_nq_examples_java_block = {"train": 298470, "validation": 40008}

In [5]:
# Android "block" task: the train split reads train_android_block.tsv and the
# validation split reads test_android_block.tsv.
nq_tsv_path_android_block = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_android_block.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_android_block.tsv',
)

# Example counts per split, handed to the task registration as `num_input_examples`.
num_nq_examples_android_block = {"train": 204580, "validation": 26978}

In [6]:
# Java "token" task: the train split reads train_java_token.tsv and the
# validation split reads test_java_token.tsv.
nq_tsv_path_java_token = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_java_token.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_java_token.tsv',
)

# Example counts per split, handed to the task registration as `num_input_examples`.
num_nq_examples_java_token = {"train": 750000, "validation": 219486}

In [7]:
# Android "token" task: the train split reads train_android_token.tsv and the
# validation split reads test_android_token.tsv.
nq_tsv_path_android_token = dict(
    train='gs://bucket_code_completion/T5_extension/ft_datasets/train_android_token.tsv',
    validation='gs://bucket_code_completion/T5_extension/ft_datasets/test_android_token.tsv',
)

# Example counts per split, handed to the task registration as `num_input_examples`.
num_nq_examples_android_token = {"train": 750000, "validation": 200504}

### Preprocess of the dataset
In this step we preprocess the datasets.  
You have to change the paths to the vocabulary files (*vocab_model_path* and *vocab_path*).  
We preprocess all the TSV files so that T5 can use them for evaluation.

In [8]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# Locations of the SentencePiece vocabulary and model files.
# They must be the same files that were used in the pre-training phase.
vocab_path = 'gs://bucket_code_completion/T5_extension/code.vocab'
vocab_model_path = 'gs://bucket_code_completion/T5_extension/code.model'


# Short aliases for the T5 task class and task registry.
TfdsTask = t5.data.TfdsTask
TaskRegistry = t5.data.TaskRegistry


def get_default_vocabulary():
  """Return a new SentencePiece vocabulary built from `vocab_model_path`.

  The vocabulary is created with 100 extra ids.
  """
  return SentencePieceVocabulary(vocab_model_path, extra_ids=100)

# Feature specs shared by every task: both sides use the default vocabulary and
# get an EOS token appended; only "targets" is left as a required feature.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(vocabulary=get_default_vocabulary(), add_eos=True, required=False),
    targets=Feature(vocabulary=get_default_vocabulary(), add_eos=True),
)

JAVA CONSTRUCT

In [9]:
def nq_java_construct(split, shuffle_files=True):
  """Load one split of the Java construct TSV as raw string pairs.

  Args:
    split: key of `nq_tsv_path_java_construct` ("train" or "validation").
    shuffle_files: accepted for the `dataset_fn` call signature but ignored,
      because each split is a single file.

  Returns:
    A `tf.data.Dataset` of dicts holding the two tab-separated columns of each
    line under the keys "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_java_construct[split])
  # `record_defaults` supplies per-column default VALUES (their dtype fixes the
  # column dtype). Empty strings keep an empty field empty; the former
  # "string" entries made an empty field decode as the literal text "string".
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Turn each (input, output) tuple into a dict keyed by column name.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_java_construct("train").take(5)):
  print(ex)

A few raw train examples...
{'input': b'public void updateLockdownExceptions(ManagedObjectReference _this, String[] users) throws AuthMinimumAdminPermission, RemoteException, RuntimeFault, UserNotFound { Argument[] params = new Argument[2]; params[0] = new Argument("_this", "ManagedObjectReference", _this); params[1] = new Argument("users", "String[]", users); getWsc().invoke( <extra_id_0>); }', 'output': b'"UpdateLockdownExceptions", params, null'}
{'input': b'@Override public Collection<AgentProjectInfo> collectDependencies(String folder) { if ( <extra_id_0>){ devDependencies = findDevDependencies(folder); } File yarnLock = new File(folder + fileSeparator + YARN_LOCK); boolean yarnLockFound = yarnLock.isFile(); Collection<DependencyInfo> dependencies = new ArrayList<>(); if (yarnLockFound){ dependencies = parseYarnLock(yarnLock); } else { npmLsFailureStatus = true; } return getSingleProjectList(dependencies); }', 'output': b'!includeDevDependencies'}
{'input': b'@Override public Coll

In [10]:
def java_construct_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Every input is prefixed with the task tag 'JAVA_CONSTRUCT:'; the target is
  the raw output string.
  """
  def add_task_prefix(example):
    # Joining with the default empty separator concatenates tag and input.
    prefixed = tf.strings.join(['JAVA_CONSTRUCT:', example['input']])
    target = tf.strings.join([example['output']])
    return dict(inputs=prefixed, targets=target)

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [11]:
# Drop any earlier registration so the cell can be re-run, then register the
# task again; the task object returned by `add` is the cell's displayed value.
TaskRegistry.remove("java_construct")
TaskRegistry.add(
    "java_construct",
    dataset_fn=nq_java_construct,
    splits=["train", "validation"],
    text_preprocessor=[java_construct_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples_java_construct,
)

<t5.data.dataset_providers.FunctionTask at 0x7f3147a41ac8>

In [12]:
# Look up the registered task and build its "train" split, requesting a
# sequence length of 256 for both inputs and targets.
nq_task = t5.data.TaskRegistry.get("java_construct")
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 256, "targets": 256})
print("A few preprocessed training examples...")
# Sanity check: print the first five preprocessed examples as numpy values.
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'JAVA_CONSTRUCT:private static List<String> getNames(List<Charset> supportedCharsets) { Builder<String> builder=ImmutableList.<String>builder(); for(Charset supportedCharset:supportedCharsets) { builder.add( <extra_id_0>); } return builder.build(); }', 'inputs': array([    3,  7641,    15,  5071, 17558,    56,  8797,    48,    85,
          25,    31,    29,    41,   572,     5,    71,    25,  3662,
          29,  1269,  3662,    22,     8,     7,  1018,    25,    31,
          29,   259,   161, 10785,     4,    25,    31,    29,   534,
          18,    50,     5,  3662,  1269,  3662,    56,  8126,  3662,
          22,     8,     7,   259,     4,    67,     5, 32099,    10,
           6,    14,   259,     4,   352,    18,     6,     1],
      dtype=int32), 'targets_pretokenized': b'supportedCharset.name()', 'targets': array([1269, 3662,    4,   98,   16,    1], dtype=int32)}
{'inputs_pretokenized': b'JAVA_CONSTRUCT:public

JAVA TOKEN

In [13]:
def nq_java_token(split, shuffle_files=False):
  """Load one split of the Java token TSV as raw string pairs.

  Args:
    split: key of `nq_tsv_path_java_token` ("train" or "validation").
    shuffle_files: accepted for the `dataset_fn` call signature but ignored,
      because each split is a single file.

  Returns:
    A `tf.data.Dataset` of dicts holding the two tab-separated columns of each
    line under the keys "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_java_token[split])
  # `record_defaults` supplies per-column default VALUES (their dtype fixes the
  # column dtype). Empty strings keep an empty field empty; the former
  # "string" entries made an empty field decode as the literal text "string".
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Turn each (input, output) tuple into a dict keyed by column name.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_java_token("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'private void emitSelectById(JavaWriter <extra_id_0> logger.d("emitSelectById"); javaWriter.beginMethod(getTargetClass(), $$GET_OBJECT_BY_ID, EnumSet.of(PUBLIC, STATIC), "long", "id", "SQLiteDatabase", "db") .emitStatement("Cursor cursor = db.rawQuery(\\"SELECT * FROM %s WHERE %s = id\\", null)", getTableName(), idColumn.getColumnName()) .emitStatement("%s value = %s(cursor, db).get(0)", getTargetClass(), $$MAP_OBJECT_FUNCTION) .emitStatement("cursor.close()") .emitStatement("return value") .endMethod(); }', 'output': b'javaWriter) throws IOException {'}
{'input': b'private void emitSelectById(JavaWriter javaWriter) throws IOException { logger.d("emitSelectById" <extra_id_0> javaWriter.beginMethod(getTargetClass(), $$GET_OBJECT_BY_ID, EnumSet.of(PUBLIC, STATIC), "long", "id", "SQLiteDatabase", "db") .emitStatement("Cursor cursor = db.rawQuery(\\"SELECT * FROM %s WHERE %s = id\\", null)", getTableName(), idColumn.getColumnName()) .emitStatement("%s

In [14]:
def java_token_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Every input is prefixed with the task tag 'JAVA_TOKEN:'; the target is the
  raw output string.
  """
  def add_task_prefix(example):
    # Joining with the default empty separator concatenates tag and input.
    prefixed = tf.strings.join(['JAVA_TOKEN:', example['input']])
    target = tf.strings.join([example['output']])
    return dict(inputs=prefixed, targets=target)

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [15]:
# Drop any earlier registration so the cell can be re-run, then register the
# task again; the task object returned by `add` is the cell's displayed value.
TaskRegistry.remove("java_token")
TaskRegistry.add(
    "java_token",
    dataset_fn=nq_java_token,
    splits=["train", "validation"],
    text_preprocessor=[java_token_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples_java_token,
)

<t5.data.dataset_providers.FunctionTask at 0x7f30060f95c0>

In [16]:
# Look up the registered task and build its "train" split, requesting a
# sequence length of 256 for both inputs and targets.
nq_task = t5.data.TaskRegistry.get("java_token")
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 256, "targets": 256})
print("A few preprocessed training examples...")
# Sanity check: print the first five preprocessed examples as numpy values.
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


{'inputs_pretokenized': b'JAVA_TOKEN:@Override public boolean setResponseProperty( HttpServletRequest portletRequest, IPortletWindow portletWindow, String property, String value) { if (IPortletRenderer.EXTERNAL_PORTLET_LINK_PROPERTY.equals(property) && StringUtils.isNotBlank(value)) { portletRequest.setAttribute <extra_id_0> return true; } return false; }', 'inputs': array([    3,  7641,    15,  2591,    56,  2098,    27,    12,    45,
          55,   164,   220,     5,  1197,   125, 12175,   125,     9,
         266, 28300, 12175,  1114,     9,    26,   698,     9,    26,
          82,     8,     7,    21,    17,   183,  7592,  1107,     4,
         146,     2,  7256,    15,  2606,  2104,    70,    15,  5618,
          15,  1527,     2,     4,   117,     5,  1238,     8,    91,
        1756,     4,  4994,     5,   122,     8,     8,     7, 12175,
         125,     4,  2413, 32099,    14,    89,    13,     6,    14,
          76,    13,     6,     1], dtype=int32), 'targets_pretokenize

JAVA BLOCK

In [17]:
def nq_java_block(split, shuffle_files=False):
  """Load one split of the Java block TSV as raw string pairs.

  Args:
    split: key of `nq_tsv_path_java_block` ("train" or "validation").
    shuffle_files: accepted for the `dataset_fn` call signature but ignored,
      because each split is a single file.

  Returns:
    A `tf.data.Dataset` of dicts holding the two tab-separated columns of each
    line under the keys "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_java_block[split])
  # `record_defaults` supplies per-column default VALUES (their dtype fixes the
  # column dtype). Empty strings keep an empty field empty; the former
  # "string" entries made an empty field decode as the literal text "string".
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Turn each (input, output) tuple into a dict keyed by column name.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_java_block("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public String capitalizeWords(final Object target) { if (target == null) <extra_id_0> return StringUtils.capitalizeWords(target); }', 'output': b'{ return null; }'}
{'input': b'@Deprecated public static byte[] convertLongToVarInt(long value) { ByteBuffer longBB = ByteBuffer.allocate(EthereumUtil.LONG_SIZE); longBB.putLong(value); byte[] result = longBB.array(); int leadingZeros=0; for (int i=0;i<result.length;i++) { if (result[i]==0) <extra_id_0> else { break; } } return Arrays.copyOfRange(result, leadingZeros, result.length); }', 'output': b'{ leadingZeros++; }'}
{'input': b'@Deprecated public static byte[] convertLongToVarInt(long value) { ByteBuffer longBB = ByteBuffer.allocate(EthereumUtil.LONG_SIZE); longBB.putLong(value); byte[] result = longBB.array(); int leadingZeros=0; for (int i=0;i<result.length;i++) { if (result[i]==0) { leadingZeros++; } else <extra_id_0> } return Arrays.copyOfRange(result, leadingZeros, result.length); }', 'output'

In [18]:
def java_block_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Every input is prefixed with the task tag 'JAVA_BLOCK:'; the target is the
  raw output string.
  """
  def add_task_prefix(example):
    # Joining with the default empty separator concatenates tag and input.
    prefixed = tf.strings.join(['JAVA_BLOCK:', example['input']])
    target = tf.strings.join([example['output']])
    return dict(inputs=prefixed, targets=target)

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [19]:
# Drop any earlier registration so the cell can be re-run, then register the
# task again; the task object returned by `add` is the cell's displayed value.
TaskRegistry.remove("java_block")
TaskRegistry.add(
    "java_block",
    dataset_fn=nq_java_block,
    splits=["train", "validation"],
    text_preprocessor=[java_block_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples_java_block,
)

<t5.data.dataset_providers.FunctionTask at 0x7f3006121eb8>

In [20]:
# Look up the registered task and build its "train" split, requesting a
# sequence length of 256 for both inputs and targets.
nq_task = t5.data.TaskRegistry.get("java_block")
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 256, "targets": 256})
print("A few preprocessed training examples...")
# Sanity check: print the first five preprocessed examples as numpy values.
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


{'inputs_pretokenized': b'JAVA_BLOCK:private Set<String> getGroups(DirContext context, String callerDn) { Set<String> groups = null; String groupSearchBase = idStoreDefinition.getGroupSearchBase(); String groupSearchFilter = idStoreDefinition.getGroupSearchFilter(); if (groupSearchBase.isEmpty() || groupSearchFilter.isEmpty()) { groups = getGroupsByMembership(context, callerDn); } else <extra_id_0> return groups; }', 'inputs': array([    3,  7641,    15,  3517,    56,  8797,   300,    25,    31,
          29,     3,  9471,     5,   490,    92,   130,     9,    26,
        5048,  6327,     8,     7,   300,    25,    31,    29,  3008,
          11,    30,    13,    26,   718,   758,   431,    11,   176,
         475,   530,     4,  4957,   758,   431,    18,    26,   718,
       11890,    11,   176,   475,   530,     4,  4957, 11890,    18,
          21,    17,   780,   758,   431,     4,   280,    16,     3,
           2,   718, 11890,     4,   280,    60,     7,  3008,    11,
        6

ANDROID CONSTRUCT

In [21]:
def nq_android_construct(split, shuffle_files=True):
  """Load one split of the Android construct TSV as raw string pairs.

  Args:
    split: key of `nq_tsv_path_android_construct` ("train" or "validation").
    shuffle_files: accepted for the `dataset_fn` call signature but ignored,
      because each split is a single file.

  Returns:
    A `tf.data.Dataset` of dicts holding the two tab-separated columns of each
    line under the keys "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_android_construct[split])
  # `record_defaults` supplies per-column default VALUES (their dtype fixes the
  # column dtype). Empty strings keep an empty field empty; the former
  # "string" entries made an empty field decode as the literal text "string".
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Turn each (input, output) tuple into a dict keyed by column name.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_android_construct("train").take(5)):
  print(ex)

A few raw train examples...
{'input': b'private void writeToFile(final String content) { if ( <extra_id_0>) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch (Exception e) { e.printStackTrace(); } }', 'output': b'!WRITE_TO_FILE'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch ( <extra_id_0>) { e.printStackTrace(); } }', 'output': b'Exception e'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File( <extra_id_0>); f.delete(); FileWriter fstream = new FileW

In [22]:
def android_construct_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Every input is prefixed with the task tag 'ANDROID_CONSTRUCT:'; the target is
  the raw output string.
  """
  def add_task_prefix(example):
    # Joining with the default empty separator concatenates tag and input.
    prefixed = tf.strings.join(['ANDROID_CONSTRUCT:', example['input']])
    target = tf.strings.join([example['output']])
    return dict(inputs=prefixed, targets=target)

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [23]:
# Drop any earlier registration so the cell can be re-run, then register the
# task again; the task object returned by `add` is the cell's displayed value.
TaskRegistry.remove("android_construct")
TaskRegistry.add(
    "android_construct",
    dataset_fn=nq_android_construct,
    splits=["train", "validation"],
    text_preprocessor=[android_construct_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples_android_construct,
)

<t5.data.dataset_providers.FunctionTask at 0x7f300608bf28>

In [24]:
# Look up the registered task and build its "train" split, requesting a
# sequence length of 256 for both inputs and targets.
nq_task = t5.data.TaskRegistry.get("android_construct")
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 256, "targets": 256})
print("A few preprocessed training examples...")
# Sanity check: print the first five preprocessed examples as numpy values.
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


{'inputs_pretokenized': b'ANDROID_CONSTRUCT:private void arrancaPrimeraVez() { mDrawerLayout = (DrawerLayout) findViewById(R.id.drawer_layout); mDrawerLayout.setDrawerLockMode(DrawerLayout.LOCK_MODE_LOCKED_CLOSED); getSupportActionBar().setDisplayHomeAsUpEnabled( <extra_id_0>); FragmentTransaction transaction = getSupportFragmentManager().beginTransaction(); transaction.add(R.id.main_container, new InitialFragment()); transaction.commit(); doUpdateTitle(getString(R.string.titulo_inicial)); }', 'inputs': array([    3, 16446,    15,  5071, 17558,    56,  8797,    20,  3971,
        9624,   184, 10010,  3474,   380,   110,   652,    16,     7,
          54, 18275,    11,    17, 18275,     8,  3510,     5,   144,
           4,   111,     4, 17964,    15,  1179,    10,    54, 18275,
           4,    63,  5208, 21152,     5, 18275,     4,  4776,    15,
        1493,    15, 12403,    15, 12664,    10,    41, 11422,    37,
        8090, 18353,   456,     5, 32099,    10,  7272,   737,  1678,
 

ANDROID TOKEN

In [25]:
def nq_android_token(split, shuffle_files=False):
  """Load one split of the Android token TSV as raw string pairs.

  Args:
    split: key of `nq_tsv_path_android_token` ("train" or "validation").
    shuffle_files: accepted for the `dataset_fn` call signature but ignored,
      because each split is a single file.

  Returns:
    A `tf.data.Dataset` of dicts holding the two tab-separated columns of each
    line under the keys "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_android_token[split])
  # `record_defaults` supplies per-column default VALUES (their dtype fixes the
  # column dtype). Empty strings keep an empty field empty; the former
  # "string" entries made an empty field decode as the literal text "string".
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Turn each (input, output) tuple into a dict keyed by column name.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_android_token("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public boolean initialize <extra_id_0> if (bluetoothManager == null) { bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); if (bluetoothManager == null) { Log.e(TAG, "Unable to initialize BluetoothManager."); return false; } } adapter = bluetoothManager.getAdapter(); if (adapter == null) { Log.e(TAG, "Unable to obtain a BluetoothAdapter."); return false; } return true; }', 'output': b'() {'}
{'input': b'public boolean initialize() { if (bluetoothManager == <extra_id_0> bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); if (bluetoothManager == null) { Log.e(TAG, "Unable to initialize BluetoothManager."); return false; } } adapter = bluetoothManager.getAdapter(); if (adapter == null) { Log.e(TAG, "Unable to obtain a BluetoothAdapter."); return false; } return true; }', 'output': b'null) {'}
{'input': b'public boolean initialize() { if (bluetoothManager == null) { bluetoothManager = (Blu

In [26]:
def android_token_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Every input is prefixed with the task tag 'ANDROID_TOKEN:'; the target is the
  raw output string.
  """
  def add_task_prefix(example):
    # Joining with the default empty separator concatenates tag and input.
    prefixed = tf.strings.join(['ANDROID_TOKEN:', example['input']])
    target = tf.strings.join([example['output']])
    return dict(inputs=prefixed, targets=target)

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [27]:
# Drop any earlier registration so the cell can be re-run, then register the
# task again; the task object returned by `add` is the cell's displayed value.
TaskRegistry.remove("android_token")
TaskRegistry.add(
    "android_token",
    dataset_fn=nq_android_token,
    splits=["train", "validation"],
    text_preprocessor=[android_token_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples_android_token,
)

<t5.data.dataset_providers.FunctionTask at 0x7f30060a7518>

In [28]:
# Look up the registered task and build its "train" split, requesting a
# sequence length of 256 for both inputs and targets.
nq_task = t5.data.TaskRegistry.get("android_token")
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 256, "targets": 256})
print("A few preprocessed training examples...")
# Sanity check: print the first five preprocessed examples as numpy values.
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


{'inputs_pretokenized': b'ANDROID_TOKEN:public Builder setCongestionLevel(com.echo5bravo.gtfs.GtfsRealtime.VehiclePosition.CongestionLevel value) { if (value == null) { throw new NullPointerException <extra_id_0> } bitField0_ |= 0x00000080; congestionLevel_ = value; onChanged(); return this; }', 'inputs': array([    3, 16446,    15,  2591,    56,  4569,  1018,    55, 31663,
         377,     5,   653,     4, 15279,     2, 21290,  6549,     4,
       27697,     4,  7435,     4,  5421,   392,     4, 31663,   377,
          82,     8,     7,    21,    17,   122,    40,    30,     8,
           7,    78,    24,  2024,    38, 32099,     6,  4787,  2022,
           3,     2,   161,   157,   138,  7025,     2,  2986,  2059,
        5963,    22,  9474,   377,    15,    11,    82,    13,  3533,
          18,    14,    23,    13,     6,     1], dtype=int32), 'targets_pretokenized': b'();', 'targets': array([ 3, 18,  1], dtype=int32)}
{'inputs_pretokenized': b'ANDROID_TOKEN:private void startNewI

ANDROID BLOCK

In [29]:
def nq_android_block(split, shuffle_files=False):
  """Load one split of the Android block TSV as raw string pairs.

  Args:
    split: key of `nq_tsv_path_android_block` ("train" or "validation").
    shuffle_files: accepted for the `dataset_fn` call signature but ignored,
      because each split is a single file.

  Returns:
    A `tf.data.Dataset` of dicts holding the two tab-separated columns of each
    line under the keys "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_android_block[split])
  # `record_defaults` supplies per-column default VALUES (their dtype fixes the
  # column dtype). Empty strings keep an empty field empty; the former
  # "string" entries made an empty field decode as the literal text "string".
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Turn each (input, output) tuple into a dict keyed by column name.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_android_block("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public void onStackEmpty() { cancelAutoLove(); if (progressDialog == null || !progressDialog.isShowing()) <extra_id_0> }', 'output': b'{ progressDialog = ProgressDialog.show(this, "Out of Cats", "Searching for more, Please wait..."); startLoading(); }'}
{'input': b'private static List<HeroAndAdvantages> loadHeroes(SQLiteDatabase db) { List<HeroAndAdvantages> heroes = new ArrayList<>(); Cursor c = db.rawQuery("SELECT * FROM Heroes", null); c.moveToFirst(); while (!c.isAfterLast()) <extra_id_0> c.close(); return heroes; }', 'output': b'{ heroes.add(new HeroAndAdvantages(c)); c.moveToNext(); }'}
{'input': b'private String getLinkByRelation(String relation) { for (Link l : link) { if (l.getRel().equals(relation)) <extra_id_0> } return null; }', 'output': b'{ return l.getUri(); }'}
{'input': b'public static boolean isBundleValid(final Bundle bundle) { if (null == bundle) <extra_id_0> if (bundle.getInt(BUNDLE_EXTRA_INT_VERSION_CODE, -1) == -1) { return

In [30]:
def android_block_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Every input is prefixed with the task tag 'ANDROID_BLOCK:'; the target is the
  raw output string.
  """
  def add_task_prefix(example):
    # Joining with the default empty separator concatenates tag and input.
    prefixed = tf.strings.join(['ANDROID_BLOCK:', example['input']])
    target = tf.strings.join([example['output']])
    return dict(inputs=prefixed, targets=target)

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [31]:
# Drop any earlier registration so the cell can be re-run, then register the
# task again; the task object returned by `add` is the cell's displayed value.
TaskRegistry.remove("android_block")
TaskRegistry.add(
    "android_block",
    dataset_fn=nq_android_block,
    splits=["train", "validation"],
    text_preprocessor=[android_block_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
    num_input_examples=num_nq_examples_android_block,
)

<t5.data.dataset_providers.FunctionTask at 0x7f300612e2e8>

In [32]:
# Look up the registered task and build its "train" split, requesting a
# sequence length of 256 for both inputs and targets.
nq_task = t5.data.TaskRegistry.get("android_block")
ds = nq_task.get_dataset(split="train", sequence_length={"inputs": 256, "targets": 256})
print("A few preprocessed training examples...")
# Sanity check: print the first five preprocessed examples as numpy values.
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


{'inputs_pretokenized': b'ANDROID_BLOCK:private void updateAttrs(Context context, AttributeSet attrs){ TypedArray styledAttrs = context.getTheme().obtainStyledAttributes( attrs, R.styleable.ValidableTextInputLayout, 0, 0); try <extra_id_0> finally { styledAttrs.recycle(); } initialize(); }', 'inputs': array([    3, 16446,    15,  3517,    56,  8797,    20,   233,  6046,
           5,    92,   130,     9,  2763,  1085,   212, 18925,  2277,
         101,  6046,    11,   130,     4, 16461,    37,  6425, 20118,
           5,  1085,     9,   544,     4,  9599,     4,  1538,   367,
       30553,     9,   312,   522,    93, 32099,   658,     7,  2277,
         101,  6046,     4,  8202,    18,     6,  1196,    18,     6,
           1], dtype=int32), 'targets_pretokenized': b'{ required = styledAttrs.getBoolean(R.styleable.ValidableTextInputLayout_required, false); cantContainSpaces = styledAttrs.getBoolean(R.styleable.ValidableTextInputLayout_cantContainSpaces, false); }', 'targets': array([  

### Evaluation of the score value
You can run the evaluation using the following cells.  
Please set the variable *MODEL_DIR* to the correct path (the folder where the model checkpoints are stored).


In [33]:
def _rate_num_input_examples(task):
  if "train" in task.splits:
    return float(task.num_input_examples("train"))
  elif "validation" in task.splits:
    return float(task.num_input_examples("validation"))
  else:
    raise ValueError("Task %s does not have a train or validation split." % (task.name))


# Drop any earlier "all_tasks" mixture so the cell can be re-run, then register
# it over the six tasks, using each task's example count as its mixing rate.
# Alternative left by the author: default_rate=1.0
t5.data.MixtureRegistry.remove("all_tasks")
t5.data.MixtureRegistry.add(
    "all_tasks",
    [
        "java_construct",
        "java_token",
        "java_block",
        "android_construct",
        "android_token",
        "android_block",
    ],
    default_rate=_rate_num_input_examples,
)

<t5.seqio.dataset_providers.Mixture at 0x7f3005fe4320>

In [34]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular

# Selects which row of the settings table below is used.
MODEL_SIZE = "small"

# Folder where the checkpoints and all other run information are written.
MODEL_DIR = 'gs://bucket_code_completion/T5_extension/finetuning'

# Per-size settings: (model_parallelism, train_batch_size, keep_checkpoint_max).
_SETTINGS_BY_MODEL_SIZE = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}
model_parallelism, train_batch_size, keep_checkpoint_max = _SETTINGS_BY_MODEL_SIZE[MODEL_SIZE]

# Make sure the model folder exists before the model object is pointed at it.
tf.io.gfile.makedirs(MODEL_DIR)

# Build the Mesh TensorFlow T5 model on the TPU detected in the setup cell.
# TPU_ADDRESS and TPU_TOPOLOGY are only defined when ON_CLOUD is True.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    learning_rate_schedule = slanted_triangular,
    sequence_length={"inputs": 256, "targets": 256},
    save_checkpoints_steps=5000,
    # The checkpoint limit is only applied when running on the cloud.
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

In [35]:
# NOTE(review): the original note here said "we used model.predict function
# (setting beam_size)", but this cell calls model.score; presumably the
# inputs/targets files come from an earlier predict run (confirm).

vocabulary_predict=get_default_vocabulary()

# Score each line of targets.txt against the matching line of inputs.txt and
# write the results to scores.txt; the returned scores are also the cell's
# displayed value. checkpoint_steps=-1 presumably selects the latest checkpoint
# in MODEL_DIR (TODO confirm).
model.score(inputs='gs://bucket_code_completion/T5_extension/finetuning/score/inputs.txt', 
            targets='gs://bucket_code_completion/T5_extension/finetuning/score/targets.txt',
              scores_file='gs://bucket_code_completion/T5_extension/finetuning/score/scores.txt',
              checkpoint_steps=-1, vocabulary=vocabulary_predict)

INFO:root:system_path_file_exists:gs://bucket_code_completion/T5_extension/finetuning/operative_config.gin
ERROR:root:Path not found: gs://bucket_code_completion/T5_extension/finetuning/operative_config.gin


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of data from outfeed.
INFO:tensorflow:Enqueue next (1) batch(es) of data to infeed.
INFO:tensorflow:Dequeue next (1) batch(es) of 

([-92.469124,
  -1.0895462,
  -1.9226506,
  -10.211275,
  -5.6445794,
  -33.480736,
  -19.27973,
  -11.465764,
  -11.361916,
  -13.477501,
  -2.7581952,
  -26.8485,
  -12.315093,
  -5.2809324,
  -2.1096609,
  -8.237902,
  -0.8282796,
  -29.518497,
  -0.04651332,
  -0.056794167,
  -21.163338,
  -10.332344,
  -40.618916,
  -41.889816,
  -0.75122786,
  -0.0033997297,
  -7.2741914,
  -31.898897,
  -31.158138,
  -14.288792,
  -17.184597,
  -6.234426,
  -0.4696663,
  -21.011288,
  -55.478657,
  -25.661575,
  -0.006794691,
  -0.37967145,
  -0.5290384,
  -28.02265,
  -34.9584,
  -12.8206215,
  -10.285417,
  -0.28567445,
  -1.4702772,
  -33.25132,
  -0.92343485,
  -62.135723,
  -6.74185,
  -36.533012,
  -6.8095565,
  -3.3433206,
  -10.4408655,
  -9.679232,
  -0.10669875,
  -0.5959325,
  -2.4551036,
  -9.083405,
  -0.17639273,
  -5.8118696,
  -66.314026,
  -26.273964,
  -104.179596,
  -88.14816,
  -6.404502,
  -16.298004,
  -34.936466,
  -31.892698,
  -2.256744,
  -33.48571,
  -0.2229507,
  -2.0