### Creation of the environment

In [None]:
import os

# NOTE(review): presumably selects Colab's legacy (non-ephemeral) auth flow so
# that the credentials can later be shared with the TPU/GCS — confirm.
os.environ.update(USE_AUTH_EPHEM='0')

In [None]:
%tensorflow_version 2.x
!pip3 install --upgrade pip
#!pip install -qU t5
!pip3 install git+https://github.com/google-research/text-to-text-transfer-transformer.git #extra_id_x support

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

# Base directory (Google Cloud Storage bucket) used for every dataset,
# vocabulary, configuration and checkpoint path in this notebook.
BASE_DIR = "gs://code-generation"

# C.B.: SEQ_LENGTH is the number of tokens used for both inputs and targets.
SEQ_LENGTH = 512

if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise an ordinary, catchable exception (not a bare BaseException) and
    # keep the resolver's original error attached as the cause.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

# `tf` is tensorflow.compat.v1 here; turn off the TensorFlow 2.x behaviour.
tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Let INFO-level records through on the root logger and stop the TensorFlow
  # logger from propagating its records to it.
  py_logging.getLogger().setLevel(py_logging.INFO)
  tf.get_logger().propagate = False

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the TensorFlow logging verbosity inside a `with` block.

  Args:
    level: verbosity passed to `tf.logging.set_verbosity` for the duration
      of the block.

  Yields:
    None. The verbosity that was active on entry is restored on exit, also
    when the body of the `with` block raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without the finally clause an exception in the body would leave the
    # temporary verbosity in place for the rest of the session.
    tf.logging.set_verbosity(og_level)

Collecting pip
  Downloading pip-22.0.4-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 5.3 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.0.4
Collecting git+https://github.com/google-research/text-to-text-transfer-transformer.git
  Cloning https://github.com/google-research/text-to-text-transfer-transformer.git to /tmp/pip-req-build-af9z7rqw
  Running command git clone --filter=blob:none --quiet https://github.com/google-research/text-to-text-transfer-transformer.git /tmp/pip-req-build-af9z7rqw
  Resolved https://github.com/google-research/text-to-text-transfer-transformer.git to commit 9bad27b7dfc0d1de27630bc98d5ad837da0b9170
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mesh-tensorflow[transformer]>=0.1.13
  Downloading mesh_tensorflow-0.1.19-py3-none-any.whl (366 kB)

Instructions for updating:
non-resource variables are not supported in the long term


### Loading of tsv files
With this script you can load each TSV file used for fine-tuning.
Please make sure that the paths to all TSV files are correct.

In [None]:
# Number of examples in each fine-tuning TSV file (train / test per task).

# Construct-level tasks
train_java_construct_length = 732_832
test__java_construct_length = 103_766
train_android_construct_length = 681_318
test__android_construct_length = 91_317

# Block-level tasks
train_java_block_length = 291_949
test__java_block_length = 39_101
train_android_block_length = 185_960
test__android_block_length = 24_482

# Token-level tasks
train_java_token_length = 733_004
test__java_token_length = 214_451
train_android_token_length = 682_918
test__android_token_length = 182_440

In [None]:
# Java construct task: the "train" split reads the train TSV and the
# "validation" split reads the test TSV of the same dataset.
nq_tsv_path_java_construct = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/train_java_construct.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/test_java_construct.tsv',
)

# Number of examples in each split.
num_nq_examples_java_construct = {
    "train": train_java_construct_length,
    "validation": test__java_construct_length,
}

In [None]:
# Android construct task: the "train" split reads the train TSV and the
# "validation" split reads the test TSV of the same dataset.
nq_tsv_path_android_construct = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/train_android_construct.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/test_android_construct.tsv',
)

# Number of examples in each split.
num_nq_examples_android_construct = {
    "train": train_android_construct_length,
    "validation": test__android_construct_length,
}

In [None]:
# Java block task: the "train" split reads the train TSV and the
# "validation" split reads the test TSV of the same dataset.
nq_tsv_path_java_block = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/train_java_block.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/test_java_block.tsv',
)

# Number of examples in each split.
num_nq_examples_java_block = {
    "train": train_java_block_length,
    "validation": test__java_block_length,
}

In [None]:
# Android block task: the "train" split reads the train TSV and the
# "validation" split reads the test TSV of the same dataset.
nq_tsv_path_android_block = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/train_android_block.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/test_android_block.tsv',
)

# Number of examples in each split.
num_nq_examples_android_block = {
    "train": train_android_block_length,
    "validation": test__android_block_length,
}

In [None]:
# Java token task: the "train" split reads the train TSV and the
# "validation" split reads the test TSV of the same dataset.
nq_tsv_path_java_token = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/train_java_token.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/test_java_token.tsv',
)

# Number of examples in each split.
num_nq_examples_java_token = {
    "train": train_java_token_length,
    "validation": test__java_token_length,
}

In [None]:
# Android token task: the "train" split reads the train TSV and the
# "validation" split reads the test TSV of the same dataset.
nq_tsv_path_android_token = dict(
    train=BASE_DIR + '/T5_extension/ft_datasets/train_android_token.tsv',
    validation=BASE_DIR + '/T5_extension/ft_datasets/test_android_token.tsv',
)

# Number of examples in each split.
num_nq_examples_android_token = {
    "train": train_android_token_length,
    "validation": test__android_token_length,
}

### Preprocess of the dataset
In this step we preprocess the dataset.  
You have to change the paths to the vocabulary files (*vocab_model_path* and *vocab_path*).  
We preprocess all the TSV files so that T5 can use them for fine-tuning.

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# Short aliases for the T5 registry / task classes.
TfdsTask = t5.data.TfdsTask
TaskRegistry = t5.data.TaskRegistry


# Paths of the SentencePiece model and vocabulary files.
# They must be the same files that were used for the pre-training phase.
vocab_path = BASE_DIR + "/T5_extension/code.vocab"
vocab_model_path = BASE_DIR + "/T5_extension/code.model"


def get_default_vocabulary():
  """Build the SentencePiece vocabulary from the model file at `vocab_model_path`.

  Returns:
    A new `SentencePieceVocabulary` instance on every call.
  """
  # NOTE(review): the second positional argument is presumably the number of
  # extra ids (<extra_id_N> sentinel tokens) — confirm against the seqio API.
  extra_ids = 100
  return SentencePieceVocabulary(vocab_model_path, extra_ids)

# Feature specification passed as `output_features` to every task registration:
# both features use the default vocabulary and append EOS; "inputs" is optional.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(
        vocabulary=get_default_vocabulary(), required=False, add_eos=True),
    targets=Feature(
        vocabulary=get_default_vocabulary(), add_eos=True),
)

JAVA CONSTRUCT

In [None]:
def nq_java_construct(split, shuffle_files=True):
  """Load the raw Java construct examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_java_construct` ("train" or "validation").
    shuffle_files: ignored, there is only one file per split.

  Returns:
    A `tf.data.Dataset` of dicts with the string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_java_construct[split])
  ds = ds.map(
      # `record_defaults` holds the per-column default *values* (their dtype
      # fixes the column type). Empty strings keep empty fields empty; a
      # literal such as "string" would be substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two tab-separated columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_java_construct("train").take(5)):
  print(ex)

A few raw train examples...
{'input': b'public void updateLockdownExceptions(ManagedObjectReference _this, String[] users) throws AuthMinimumAdminPermission, RemoteException, RuntimeFault, UserNotFound { Argument[] params = new Argument[2]; params[0] = new Argument("_this", "ManagedObjectReference", _this); params[1] = new Argument("users", "String[]", users); getWsc().invoke( <extra_id_0>); } <CONST> VimStub(String, boolean), VimStub(String, TrustManager), VimStub(Client) <INV> Client getWsc() <OTH> void destroyPropertyFilter(ManagedObjectReference), ManagedObjectReference createFilter(ManagedObjectReference, PropertyFilterSpec, boolean), ObjectContent retrieveProperties(ManagedObjectReference, PropertyFilterSpec), UpdateSet checkForUpdates(ManagedObjectReference, String), UpdateSet waitForUpdates(ManagedObjectReference, String), void cancelWaitForUpdates(ManagedObjectReference), UpdateSet waitForUpdatesEx(ManagedObjectReference, String, WaitOptions), RetrieveResult retrieveProperties

In [None]:
def java_construct_preprocessing(ds):
  """Convert raw {"input", "output"} examples into T5 {"inputs", "targets"}.

  The task prefix 'JAVA_CONSTRUCT:' is prepended to every input; the output
  column is passed through unchanged as the target.

  Args:
    ds: dataset of dicts with the string features "input" and "output".

  Returns:
    The mapped dataset of dicts with the keys "inputs" and "targets".
  """
  prefix = 'JAVA_CONSTRUCT:'

  def add_task_prefix(example):
    return {
        'inputs': tf.strings.join([prefix, example['input']]),
        'targets': example['output'],
    }

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first (presumably so that the cell can be
# re-run without a duplicate-name error), then register the task.
t5.data.TaskRegistry.remove("java_construct")
t5.data.TaskRegistry.add(
    "java_construct",
    splits=["train", "validation"],
    dataset_fn=nq_java_construct,
    num_input_examples=num_nq_examples_java_construct,
    text_preprocessor=[java_construct_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f0bf11f3fd0>

In [None]:
# Fetch the registered task and print five examples of its preprocessed train split.
nq_task = t5.data.TaskRegistry.get("java_construct")
ds = nq_task.get_dataset(
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    split="train",
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'JAVA_CONSTRUCT:@Override public <K, V> TransactionalMapBuilder<K, V> mapBuilder(String name) { checkState( <extra_id_0>); return new DefaultTransactionalMapBuilder<>(name, new TransactionalMapConfig(), managementService, this); } <CONST> <INV> <OTH>', 'inputs': array([    3,     2,  8699,    16,  6157, 21656,    61,  1868,    22,
          10,   118,   423,     9,   800,    31,     3,  4466,   110,
         153,    29,   423,     9,   800,    31,   409,   153,     5,
          28,    85,     8,     7,  9778,     5, 32099,    11,    14,
          26,  1013,  4466,   110,   153,   881,    91,     9,    26,
           3,  4466,   110,   184,   101, 15058,   114,     9,    23,
          11,     6,   118, 14872,    31,   118, 23933,    31,   118,
         889,  7610,    31,     1], dtype=int32), 'targets_pretokenized': b'isOpen(), "transaction not open"', 'targets': array([13093,   101,    33,  3704,   174,   992,   105,     

JAVA TOKEN

In [None]:
def nq_java_token(split, shuffle_files=False):
  """Load the raw Java token examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_java_token` ("train" or "validation").
    shuffle_files: ignored, there is only one file per split.

  Returns:
    A `tf.data.Dataset` of dicts with the string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_java_token[split])
  ds = ds.map(
      # `record_defaults` holds the per-column default *values* (their dtype
      # fixes the column type). Empty strings keep empty fields empty; a
      # literal such as "string" would be substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two tab-separated columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_java_token("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'private void emitSelectById(JavaWriter <extra_id_0> logger.d("emitSelectById"); javaWriter.beginMethod(getTargetClass(), $$GET_OBJECT_BY_ID, EnumSet.of(PUBLIC, STATIC), "long", "id", "SQLiteDatabase", "db") .emitStatement("Cursor cursor = db.rawQuery(\\"SELECT * FROM %s WHERE %s = id\\", null)", getTableName(), idColumn.getColumnName()) .emitStatement("%s value = %s(cursor, db).get(0)", getTargetClass(), $$MAP_OBJECT_FUNCTION) .emitStatement("cursor.close()") .emitStatement("return value") .endMethod(); } <CONST> TableObject(Element, String, String, ShillelaghLogger, String) <INV> String getTableName(), String getTargetClass(), void brewJava(Writer) <OTH> void setIdColumn(TableColumn), TableColumn getIdColumn(), void setIsChildTable(boolean), Element getOriginatingElement(), void addColumn(TableColumn), String getSchema(), String getFqcn(), void emitGetId(JavaWriter), void emitCreateTable(JavaWriter), void emitDropTable(JavaWriter), void emitInse

In [None]:
def java_token_preprocessing(ds):
  """Convert raw {"input", "output"} examples into T5 {"inputs", "targets"}.

  The task prefix 'JAVA_TOKEN:' is prepended to every input; the output
  column is passed through unchanged as the target.

  Args:
    ds: dataset of dicts with the string features "input" and "output".

  Returns:
    The mapped dataset of dicts with the keys "inputs" and "targets".
  """
  prefix = 'JAVA_TOKEN:'

  def add_task_prefix(example):
    return {
        'inputs': tf.strings.join([prefix, example['input']]),
        'targets': example['output'],
    }

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first (presumably so that the cell can be
# re-run without a duplicate-name error), then register the task.
t5.data.TaskRegistry.remove("java_token")
t5.data.TaskRegistry.add(
    "java_token",
    splits=["train", "validation"],
    dataset_fn=nq_java_token,
    num_input_examples=num_nq_examples_java_token,
    text_preprocessor=[java_token_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f0bdce52a50>

In [None]:
# Fetch the registered task and print five examples of its preprocessed train split.
nq_task = t5.data.TaskRegistry.get("java_token")
ds = nq_task.get_dataset(
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    split="train",
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'JAVA_TOKEN:public void describeTo(Description description) { description.appendText(getMessage() <extra_id_0> if (invocation != null) { description.appendText(": "); invocation.describeTo(description); } if (expectations != null) { description.appendText(" "); expectations.describeTo(description); } } <CONST> ExpectationError(String, SelfDescribing, Invocation) <INV> <OTH> ExpectationError unexpected(String, Invocation), ExpectationError notAllSatisfied(SelfDescribing), String toString()', 'inputs': array([    3,     2,  8699,    16,  3279,    61,  5371,    17, 18452,
           5,   569,   573,     8,     7,   573,     4, 10633,     5,
         502,    15, 32099,    21,    20, 15190,    46,    25,     8,
           7,   573,     4, 10633,    32,    61,     3,    51,  3593,
           4,  9685,   135,     5,   980,    11,     6,    21,    20,
        4919, 11639,    46,    25,     8,     7,   573,     4, 10633,
         

JAVA BLOCK

In [None]:
def nq_java_block(split, shuffle_files=False):
  """Load the raw Java block examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_java_block` ("train" or "validation").
    shuffle_files: ignored, there is only one file per split.

  Returns:
    A `tf.data.Dataset` of dicts with the string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_java_block[split])
  ds = ds.map(
      # `record_defaults` holds the per-column default *values* (their dtype
      # fixes the column type). Empty strings keep empty fields empty; a
      # literal such as "string" would be substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two tab-separated columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_java_block("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public String capitalizeWords(final Object target) { if (target == null) <extra_id_0> return StringUtils.capitalizeWords(target); } <CONST> Strings(Locale) <INV> String arrayCapitalizeWords(Object), List listCapitalizeWords(List), Set setCapitalizeWords(Set) <OTH> String toString(Object), String arrayToString(Object), List listToString(List), Set setToString(Set), String abbreviate(Object, int), String arrayAbbreviate(Object, int), List listAbbreviate(List, int), Set setAbbreviate(Set, int), Boolean equals(Object, Object), Boolean equalsIgnoreCase(Object, Object), Boolean contains(Object, String), Boolean arrayContains(Object, String), List listContains(List, String), Set setContains(Set, String), Boolean containsIgnoreCase(Object, String), Boolean arrayContainsIgnoreCase(Object, String), List listContainsIgnoreCase(List, String), Set setContainsIgnoreCase(Set, String), Boolean startsWith(Object, String), Boolean arrayStartsWith(Object, String), 

In [None]:
def java_block_preprocessing(ds):
  """Convert raw {"input", "output"} examples into T5 {"inputs", "targets"}.

  The task prefix 'JAVA_BLOCK:' is prepended to every input; the output
  column is passed through unchanged as the target.

  Args:
    ds: dataset of dicts with the string features "input" and "output".

  Returns:
    The mapped dataset of dicts with the keys "inputs" and "targets".
  """
  prefix = 'JAVA_BLOCK:'

  def add_task_prefix(example):
    return {
        'inputs': tf.strings.join([prefix, example['input']]),
        'targets': example['output'],
    }

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first (presumably so that the cell can be
# re-run without a duplicate-name error), then register the task.
t5.data.TaskRegistry.remove("java_block")
t5.data.TaskRegistry.add(
    "java_block",
    splits=["train", "validation"],
    dataset_fn=nq_java_block,
    num_input_examples=num_nq_examples_java_block,
    text_preprocessor=[java_block_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f0bdcd24810>

In [None]:
# Fetch the registered task and print five examples of its preprocessed train split.
nq_task = t5.data.TaskRegistry.get("java_block")
ds = nq_task.get_dataset(
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    split="train",
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'JAVA_BLOCK:@MemoryChunkType private static int getMemoryChunkType( final Builder builder, final ImagePipelineExperiments imagePipelineExperiments) { if (builder.mMemoryChunkType != null) <extra_id_0> else if (imagePipelineExperiments.isNativeCodeDisabled()) { return MemoryChunkType.BUFFER_MEMORY; } else { return MemoryChunkType.NATIVE_MEMORY; } } <CONST> ImagePipelineConfig(Builder) <INV> <OTH> void setWebpBitmapFactory(WebpBitmapFactory, ImagePipelineExperiments, BitmapCreator), DiskCacheConfig getDefaultMainDiskCacheConfig(Context), void resetDefaultRequestConfig(), Bitmap getBitmapConfig(), Supplier getBitmapMemoryCacheParamsSupplier(), CountingMemoryCache getBitmapMemoryCacheTrimStrategy(), CacheKeyFactory getCacheKeyFactory(), Context getContext(), DefaultImageRequestConfig getDefaultImageRequestConfig(), FileCacheFactory getFileCacheFactory(), boolean isDownsampleEnabled(), boolean isDiskCacheEnabled(), Supplier ge

ANDROID CONSTRUCT

In [None]:
def nq_android_construct(split, shuffle_files=True):
  """Load the raw Android construct examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_android_construct` ("train" or "validation").
    shuffle_files: ignored, there is only one file per split.

  Returns:
    A `tf.data.Dataset` of dicts with the string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_android_construct[split])
  ds = ds.map(
      # `record_defaults` holds the per-column default *values* (their dtype
      # fixes the column type). Empty strings keep empty fields empty; a
      # literal such as "string" would be substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two tab-separated columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw train examples...")
for ex in tfds.as_numpy(nq_android_construct("train").take(5)):
  print(ex)

A few raw train examples...
{'input': b'private void writeToFile(final String content) { if ( <extra_id_0>) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch (Exception e) { e.printStackTrace(); } } <CONST> <INV> <OTH>', 'output': b'!WRITE_TO_FILE'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch ( <extra_id_0>) { e.printStackTrace(); } } <CONST> <INV> <OTH>', 'output': b'Exception e'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File( <extra_id_0>); f.

In [None]:
def android_construct_preprocessing(ds):
  """Convert raw {"input", "output"} examples into T5 {"inputs", "targets"}.

  The task prefix 'ANDROID_CONSTRUCT:' is prepended to every input; the output
  column is passed through unchanged as the target.

  Args:
    ds: dataset of dicts with the string features "input" and "output".

  Returns:
    The mapped dataset of dicts with the keys "inputs" and "targets".
  """
  prefix = 'ANDROID_CONSTRUCT:'

  def add_task_prefix(example):
    return {
        'inputs': tf.strings.join([prefix, example['input']]),
        'targets': example['output'],
    }

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first (presumably so that the cell can be
# re-run without a duplicate-name error), then register the task.
t5.data.TaskRegistry.remove("android_construct")
t5.data.TaskRegistry.add(
    "android_construct",
    splits=["train", "validation"],
    dataset_fn=nq_android_construct,
    num_input_examples=num_nq_examples_android_construct,
    text_preprocessor=[android_construct_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f0bdd137990>

In [None]:
# Fetch the registered task and print five examples of its preprocessed train split.
nq_task = t5.data.TaskRegistry.get("android_construct")
ds = nq_task.get_dataset(
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    split="train",
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_CONSTRUCT:public void onViewCreated(View view, Bundle savedInstanceState) { view.findViewById(R.id.submit_color).setOnClickListener(new OnClickListener() { @Override public void onClick( <extra_id_0>) { onColorChosen(mColor); dismiss(); } }); view.findViewById(R.id.cancel).setOnClickListener(new OnClickListener() { @Override public void onClick(View v) { onColorChosen(mOriginalColor); dismiss(); } }); } <CONST> ColorDialog() <INV> void onColorChosen(int) <OTH> ColorDialog getInstance(String, int), void onColorChanged(int), View onCreateView(LayoutInflater, ViewGroup, Bundle), void onSaveInstanceState(Bundle), void onCancel(DialogInterface), void onDismiss(DialogInterface), void finishHostActivity()', 'inputs': array([    3, 17198,    16,  6157, 21656,    61,  5371,    17,     3,
       12631,     5,   102,   255,     9,   824,  1159,   113,     8,
           7,   255,     4,  1241,     5,   117,     4,   103,    

ANDROID TOKEN

In [None]:
def nq_android_token(split, shuffle_files=False):
  """Load the raw Android token examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_android_token` ("train" or "validation").
    shuffle_files: ignored, there is only one file per split.

  Returns:
    A `tf.data.Dataset` of dicts with the string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_android_token[split])
  ds = ds.map(
      # `record_defaults` holds the per-column default *values* (their dtype
      # fixes the column type). Empty strings keep empty fields empty; a
      # literal such as "string" would be substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two tab-separated columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_android_token("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'public boolean initialize <extra_id_0> if (bluetoothManager == null) { bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); if (bluetoothManager == null) { Log.e(TAG, "Unable to initialize BluetoothManager."); return false; } } adapter = bluetoothManager.getAdapter(); if (adapter == null) { Log.e(TAG, "Unable to obtain a BluetoothAdapter."); return false; } return true; } <CONST> <INV> <OTH> void broadcastUpdate(String), void broadcastUpdate(String, BluetoothGattCharacteristic), IBinder onBind(Intent), boolean onUnbind(Intent), void enableSensor(BleSensor, boolean), boolean connect(String), void disconnect(), void close(), void readCharacteristic(BluetoothGattCharacteristic), void updateSensor(BleSensor), List getSupportedGattServices()', 'output': b'() {'}
{'input': b'public boolean initialize() { if (bluetoothManager == <extra_id_0> bluetoothManager = (BluetoothManager) getSystemService(Context.BLUETOOTH_SERVICE); 

In [None]:
def android_token_preprocessing(ds):
  """Convert raw {"input", "output"} examples into T5 {"inputs", "targets"}.

  The task prefix 'ANDROID_TOKEN:' is prepended to every input; the output
  column is passed through unchanged as the target.

  Args:
    ds: dataset of dicts with the string features "input" and "output".

  Returns:
    The mapped dataset of dicts with the keys "inputs" and "targets".
  """
  prefix = 'ANDROID_TOKEN:'

  def add_task_prefix(example):
    return {
        'inputs': tf.strings.join([prefix, example['input']]),
        'targets': example['output'],
    }

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first (presumably so that the cell can be
# re-run without a duplicate-name error), then register the task.
t5.data.TaskRegistry.remove("android_token")
t5.data.TaskRegistry.add(
    "android_token",
    splits=["train", "validation"],
    dataset_fn=nq_android_token,
    num_input_examples=num_nq_examples_android_token,
    text_preprocessor=[android_token_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f0bdd0cd8d0>

In [None]:
# Fetch the registered task and print five examples of its preprocessed train split.
nq_task = t5.data.TaskRegistry.get("android_token")
ds = nq_task.get_dataset(
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    split="train",
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_TOKEN:public void onPageSelected(int position) { showButtonsContainer(); switch (position) { case 0: browseButton.setSelected(true); favoritesButton.setSelected(false); break <extra_id_0> case 1: browseButton.setSelected(false); favoritesButton.setSelected(true); break; default: } } <CONST> <INV> void showButtonsContainer() <OTH> void onCreate(Bundle), void onClick(View), void onPageScrolled(int, float, int), void onPageScrollStateChanged(int), void onRecyclerScrolled(int, int), void hideButtonsContainer()', 'inputs': array([    3, 17198,    16,  3279,    61,  5371,    17,     3, 19322,
           5,    44,   363,     8,     7,   528,  4474,   470,    19,
         489,    20,   459,     8,     7,   199,  5629,     3, 10979,
         427,     4,  4444,     5,   212,    11,     3, 21045,   427,
           4,  4444,     5,   282,    11,   476, 32099,   199,  4976,
           3, 10979,   427,     4,  4444,     5,   2

ANDROID BLOCK

In [None]:
def nq_android_block(split, shuffle_files=False):
  """Load the raw Android block examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_android_block` ("train" or "validation").
    shuffle_files: ignored, there is only one file per split.

  Returns:
    A `tf.data.Dataset` of dicts with the string features "input" and "output".
  """
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path_android_block[split])
  ds = ds.map(
      # `record_defaults` holds the per-column default *values* (their dtype
      # fixes the column type). Empty strings keep empty fields empty; a
      # literal such as "string" would be substituted for every empty field.
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Name the two tab-separated columns.
  ds = ds.map(lambda *ex: dict(zip(["input", "output"], ex)))
  return ds

print("A few raw valid examples...")
for ex in tfds.as_numpy(nq_android_block("validation").take(5)):
  print(ex)

A few raw valid examples...
{'input': b'private static List<HeroAndAdvantages> loadHeroes(SQLiteDatabase db) { List<HeroAndAdvantages> heroes = new ArrayList<>(); Cursor c = db.rawQuery("SELECT * FROM Heroes", null); c.moveToFirst(); while (!c.isAfterLast()) <extra_id_0> c.close(); return heroes; } <CONST> SqlLoader(Context) <INV> <OTH> List deepCopyOfHeroes(List), List calculateAdvantages(List), void loadAllAdvantages(List, SQLiteDatabase), void updateOneAdvantage(String, int, SQLiteDatabase), double findAdvantage(HeroAndAdvantages, String, SQLiteDatabase)', 'output': b'{ heroes.add(new HeroAndAdvantages(c)); c.moveToNext(); }'}
{'input': b'public void onCreate(@Nullable Bundle savedInstanceState) { super.onCreate(savedInstanceState); setRetainInstance(true); Bundle bundle = getArguments(); if (bundle != null) <extra_id_0> } <CONST> <INV> <OTH> void onViewClicked(View), void onAttach(Context), View onCreateView(LayoutInflater, ViewGroup, Bundle), void onStart(), void onDetach()', 'out

In [None]:
def android_block_preprocessing(ds):
  """Convert raw {"input", "output"} examples into T5 {"inputs", "targets"}.

  The task prefix 'ANDROID_BLOCK:' is prepended to every input; the output
  column is passed through unchanged as the target.

  Args:
    ds: dataset of dicts with the string features "input" and "output".

  Returns:
    The mapped dataset of dicts with the keys "inputs" and "targets".
  """
  prefix = 'ANDROID_BLOCK:'

  def add_task_prefix(example):
    return {
        'inputs': tf.strings.join([prefix, example['input']]),
        'targets': example['output'],
    }

  return ds.map(add_task_prefix,
                num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Drop any earlier registration first (presumably so that the cell can be
# re-run without a duplicate-name error), then register the task.
t5.data.TaskRegistry.remove("android_block")
t5.data.TaskRegistry.add(
    "android_block",
    splits=["train", "validation"],
    dataset_fn=nq_android_block,
    num_input_examples=num_nq_examples_android_block,
    text_preprocessor=[android_block_preprocessing],
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f0bdcf2a990>

In [None]:
# Fetch the registered task and print five examples of its preprocessed train split.
nq_task = t5.data.TaskRegistry.get("android_block")
ds = nq_task.get_dataset(
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    split="train",
)
print("A few preprocessed training examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_BLOCK:public synchronized List<FeedItem> getQueue() throws InterruptedException { try { return queueFuture.get(); } catch (ExecutionException e) <extra_id_0> } <CONST> PlaybackServiceTaskManager(Context, PSTMCallback) <INV> <OTH> void onEvent(QueueEvent), boolean isQueueLoaderActive(), void cancelQueueLoader(), void loadQueue(), void onEvent(FeedItemEvent), boolean isItemInQueue(long), List getQueueIfLoaded(), void startPositionSaver(), boolean isPositionSaverActive(), void cancelPositionSaver(), void startWidgetUpdater(), void setSleepTimer(long, boolean, boolean), boolean isSleepTimerActive(), void disableSleepTimer(), void restartSleepTimer(), long getSleepTimerTimeLeft(), boolean isWidgetUpdaterActive(), void cancelWidgetUpdater(), void cancelChapterLoader(), boolean isChapterLoaderActive(), void startChapterLoader(Playable), void cancelAllTasks(), void shutdown(), Runnable useMainThreadIfNecessary(Runnable)'

### Finetuning
You can run the fine-tuning using the following cells.  
Please set the correct paths for the variables *MODEL_DIR* (the folder in which the fine-tuned model checkpoints are saved), *PATH_GIN_FILE* (the gin configuration file for this fine-tuning) and *PRETRAINED_DIR* (the folder that contains the pre-trained model).  
**Pay attention** to the *pretrained_model_dir* argument in the finetune step: if you are starting the fine-tuning from scratch, set it to *PRETRAINED_DIR*; if you are resuming the fine-tuning from a previously saved checkpoint, set it to *MODEL_DIR*.

In [None]:
def _rate_num_input_examples(task):
  if "train" in task.splits:
    return float(task.num_input_examples("train"))
  elif "validation" in task.splits:
    return float(task.num_input_examples("validation"))
  else:
    raise ValueError("Task %s does not have a train or validation split." % (task.name))


# (Re-)register the mixture of all six tasks (java_* first, then android_*),
# each weighted by the rate returned by _rate_num_input_examples.
t5.data.MixtureRegistry.remove("all_tasks")
t5.data.MixtureRegistry.add(
    "all_tasks",
    [language + "_" + level
     for language in ("java", "android")
     for level in ("construct", "token", "block")],
    default_rate=_rate_num_input_examples
     #default_rate=1.0
)

<seqio.dataset_providers.Mixture at 0x7f0bdd0b0b10>

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import slanted_triangular
# C.B.: Added import
import t5.models

MODEL_SIZE = "small"

# Folder where the checkpoints and all the other information will be written.
MODEL_DIR = BASE_DIR + '/T5_extension/finetuning'

# Pre-trained folder: it must contain the pre-trained models, the
# operative_config.gin file and the checkpoint file as well.
PRETRAINED_DIR = BASE_DIR + '/T5_extension/pretrained_model'

# C.B.: Change to 100 for last phase, default 5000
SAVE_CHECKPOINTS_STEPS = 5000

# Per-size settings: (model_parallelism, train_batch_size, keep_checkpoint_max).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16000),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}[MODEL_SIZE]

# Make sure the output folder exists before the model is created.
tf.io.gfile.makedirs(MODEL_DIR)

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length=dict(inputs=SEQ_LENGTH, targets=SEQ_LENGTH),
    learning_rate_schedule=slanted_triangular,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    keep_checkpoint_max=(keep_checkpoint_max if ON_CLOUD else None),
    iterations_per_loop=100,
)

In [None]:
import gin

# gin configuration file for this fine-tuning run.
PATH_GIN_FILE = BASE_DIR + '/T5_extension/configuration_file/operative_config.gin'
# Step count passed to model.finetune as `finetune_steps`.
FINETUNE_STEPS = 400000

with gin.unlock_config():
    gin.parse_config_file(PATH_GIN_FILE)
    # Run fine-tuning.
    # Starting from scratch: pass PRETRAINED_DIR as pretrained_model_dir.
    # Resuming from a previously saved checkpoint: keep MODEL_DIR.
    model.finetune(
        finetune_steps=FINETUNE_STEPS,
        mixture_or_task_name="all_tasks",
        # pretrained_model_dir=PRETRAINED_DIR,
        pretrained_model_dir=MODEL_DIR,
    )

INFO:root:system_path_file_exists:gs://code-generation/T5_extension/configuration_file/operative_config.gin
ERROR:root:Path not found: gs://code-generation/T5_extension/configuration_file/operative_config.gin
INFO:root:system_path_file_exists:gs://code-generation/T5_extension/finetuning/operative_config.gin
ERROR:root:Path not found: gs://code-generation/T5_extension/finetuning/operative_config.gin


INFO:tensorflow:Using config: {'_model_dir': 'gs://code-generation/T5_extension/finetuning', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 100, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.31.160.10:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.31.160.10:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.31.160.10:8470', '_evaluation_master': 'grpc://10.31.160.10:8470'