### Configuration of the environment

In [None]:
%tensorflow_version 2.x
!pip3 install --upgrade pip
#!pip install -qU t5
!pip3 install git+https://github.com/google-research/text-to-text-transfer-transformer.git #extra_id_x support

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

# Set the base dir (Google Cloud Storage bucket).
BASE_DIR = "gs://bucket_comment_completion"

if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
ON_CLOUD = True


if ON_CLOUD:
  import tensorflow_gcs_config
  from google.colab import auth
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError as err:
    # Raise a regular exception (not a bare BaseException, which escapes
    # `except Exception` handlers) and keep the original error as the cause.
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!') from err
  # Set credentials for GCS reading/writing from Colab and TPU.
  auth.authenticate_user()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  # Let the root Python logger emit INFO records, and stop the TensorFlow
  # logger from propagating its records to ancestor loggers.
  py_logging.getLogger().setLevel('INFO')
  tf.get_logger().propagate = False

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the TensorFlow logging verbosity to `level`.

  Use as `with tf_verbosity_level(level): ...`. The verbosity that was active
  on entry is restored when the block exits, including when it exits because
  of an exception.

  Args:
    level: verbosity value passed to `tf.logging.set_verbosity`.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without the finally clause an exception raised inside the `with` body
    # would leave the temporary verbosity in place for the rest of the session.
    tf.logging.set_verbosity(og_level)

Collecting pip
[?25l  Downloading https://files.pythonhosted.org/packages/fe/ef/60d7ba03b5c442309ef42e7d69959f73aacccd0d86008362a681c4698e83/pip-21.0.1-py3-none-any.whl (1.5MB)
[K     |████████████████████████████████| 1.5MB 6.0MB/s 
[?25hInstalling collected packages: pip
  Found existing installation: pip 19.3.1
    Uninstalling pip-19.3.1:
      Successfully uninstalled pip-19.3.1
Successfully installed pip-21.0.1
Collecting git+https://github.com/google-research/text-to-text-transfer-transformer.git
  Cloning https://github.com/google-research/text-to-text-transfer-transformer.git to /tmp/pip-req-build-2rlzfg82
  Running command git clone -q https://github.com/google-research/text-to-text-transfer-transformer.git /tmp/pip-req-build-2rlzfg82
Collecting mesh-tensorflow[transformer]>=0.1.13
  Downloading mesh_tensorflow-0.1.18-py3-none-any.whl (361 kB)
[K     |████████████████████████████████| 361 kB 7.0 MB/s 
Collecting rouge-score
  Downloading rouge_score-0.0.4-py2.py3-none-any

Instructions for updating:
non-resource variables are not supported in the long term


### Loading of the tsv files
We load the six TSV datasets (each one has a train file and an evaluation file). Please make sure you have uploaded them to the bucket and that the paths below point to the correct locations.

In [None]:
#Validation(train and test on the same dataset)

# GCS paths of the Java "construct" TSV files, one per split.
nq_tsv_path_java_construct = dict(
    train='gs://bucket_comment_completion/Matteo/ft_datasets/train_java_construct.tsv',
    validation='gs://bucket_comment_completion/Matteo/ft_datasets/eval_java_construct.tsv',
)

# Number of examples per split (passed to the task as num_input_examples).
num_nq_examples_java_construct = {"train": 750000, "validation": 104037}

In [None]:
#Validation(train and test on the same dataset)

# GCS paths of the Android "construct" TSV files, one per split.
nq_tsv_path_android_construct = dict(
    train='gs://bucket_comment_completion/Matteo/ft_datasets/train_android_construct.tsv',
    validation='gs://bucket_comment_completion/Matteo/ft_datasets/eval_android_construct.tsv',
)

# Number of examples per split (passed to the task as num_input_examples).
num_nq_examples_android_construct = {"train": 750000, "validation": 98793}

In [None]:
#Validation(train and test on the same dataset)

# GCS paths of the Java "block" TSV files, one per split.
nq_tsv_path_java_block = dict(
    train='gs://bucket_comment_completion/Matteo/ft_datasets/train_java_block.tsv',
    validation='gs://bucket_comment_completion/Matteo/ft_datasets/eval_java_block.tsv',
)

# Number of examples per split (passed to the task as num_input_examples).
num_nq_examples_java_block = {"train": 298470, "validation": 38840}

In [None]:
#Validation(train and test on the same dataset)

# GCS paths of the Android "block" TSV files, one per split.
nq_tsv_path_android_block = dict(
    train='gs://bucket_comment_completion/Matteo/ft_datasets/train_android_block.tsv',
    validation='gs://bucket_comment_completion/Matteo/ft_datasets/eval_android_block.tsv',
)

# Number of examples per split (passed to the task as num_input_examples).
num_nq_examples_android_block = {"train": 204580, "validation": 26503}

In [None]:
#Validation(train and test on the same dataset)

# GCS paths of the Java "token" TSV files, one per split.
nq_tsv_path_java_token = dict(
    train='gs://bucket_comment_completion/Matteo/ft_datasets/train_java_token.tsv',
    validation='gs://bucket_comment_completion/Matteo/ft_datasets/eval_java_token.tsv',
)

# Number of examples per split (passed to the task as num_input_examples).
num_nq_examples_java_token = {"train": 750000, "validation": 214682}

In [None]:
#Validation(train and test on the same dataset)

# GCS paths of the Android "token" TSV files, one per split.
nq_tsv_path_android_token = dict(
    train='gs://bucket_comment_completion/Matteo/ft_datasets/train_android_token.tsv',
    validation='gs://bucket_comment_completion/Matteo/ft_datasets/eval_android_token.tsv',
)

# Number of examples per split (passed to the task as num_input_examples).
num_nq_examples_android_token = {"train": 750000, "validation": 198281}

### Preprocess of the dataset
In this step we preprocess the dataset.  
You have to change the paths of the vocab files (*vocab_model_path* and *vocab_path*).  
We preprocess all the TSV files so that T5 can use them for HP tuning.

In [None]:
from t5.data import postprocessors as t5_postprocessors
from t5.seqio import Feature,SentencePieceVocabulary


# Paths of the SentencePiece model file and of its vocab file.
# They must be the same files that were used for the pre-training phase.
vocab_model_path, vocab_path = (
    'gs://bucket_comment_completion/Matteo/code.model',
    'gs://bucket_comment_completion/Matteo/code.vocab',
)


# Short module-level aliases for two names from t5.data.
TfdsTask = t5.data.TfdsTask
TaskRegistry = t5.data.TaskRegistry


def get_default_vocabulary():
  """Build the SentencePiece vocabulary used for the task features.

  Returns:
    A new `SentencePieceVocabulary` constructed from `vocab_model_path`.
  """
  # NOTE(review): 100 is presumably the number of extra sentinel ids (such as
  # <extra_id_0>) added on top of the model's vocabulary; confirm against the
  # SentencePieceVocabulary signature.
  vocabulary = SentencePieceVocabulary(vocab_model_path, 100)
  return vocabulary

# Feature specs shared by all the tasks below: both "inputs" and "targets" use
# the SentencePiece vocabulary with add_eos=True; "inputs" is not required.
DEFAULT_OUTPUT_FEATURES = dict(
    inputs=Feature(
        vocabulary=get_default_vocabulary(), required=False, add_eos=True),
    targets=Feature(
        vocabulary=get_default_vocabulary(), add_eos=True),
)

JAVA CONSTRUCT

In [None]:
def nq_java_construct(split, shuffle_files=True):
  """Load the raw Java "construct" examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_java_construct` selecting the file to read.
    shuffle_files: unused; it is deleted immediately.

  Returns:
    A dataset of dicts with the keys "input" and "output", holding the first
    and the second tab-separated column of every line.
  """
  # There is a single file per split, so there is nothing to shuffle.
  del shuffle_files

  # NOTE(review): record_defaults gives the default *value* of an empty field,
  # so an empty column is read as the literal text "string"; confirm that this
  # is intended rather than an empty-string default.
  parse_tsv_line = functools.partial(
      tf.io.decode_csv, record_defaults=["string","string"],
      field_delim="\t", use_quote_delim=False)

  def to_example(*columns):
    # Name the two parsed columns.
    return dict(zip(["input", "output"], columns))

  # Read the file line by line, then split every line on the tab character.
  lines = tf.data.TextLineDataset(nq_tsv_path_java_construct[split])
  split_lines = lines.map(
      parse_tsv_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return split_lines.map(to_example)

# Show the first five raw examples of the train split.
print("A few raw train examples...")
raw_java_construct_examples = nq_java_construct("train").take(5)
for ex in tfds.as_numpy(raw_java_construct_examples):
  print(ex)

A few raw train examples...
{'input': b'public void updateLockdownExceptions(ManagedObjectReference _this, String[] users) throws AuthMinimumAdminPermission, RemoteException, RuntimeFault, UserNotFound { Argument[] params = new Argument[2]; params[0] = new Argument("_this", "ManagedObjectReference", _this); params[1] = new Argument("users", "String[]", users); getWsc().invoke( <extra_id_0>); }', 'output': b'"UpdateLockdownExceptions", params, null'}
{'input': b'@Override public Collection<AgentProjectInfo> collectDependencies(String folder) { if ( <extra_id_0>){ devDependencies = findDevDependencies(folder); } File yarnLock = new File(folder + fileSeparator + YARN_LOCK); boolean yarnLockFound = yarnLock.isFile(); Collection<DependencyInfo> dependencies = new ArrayList<>(); if (yarnLockFound){ dependencies = parseYarnLock(yarnLock); } else { npmLsFailureStatus = true; } return getSingleProjectList(dependencies); }', 'output': b'!includeDevDependencies'}
{'input': b'@Override public Coll

In [None]:
def java_construct_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Args:
    ds: dataset of dicts with the entries "input" and "output".

  Returns:
    The mapped dataset: "inputs" is the raw input prefixed with
    "JAVA_CONSTRUCT:" and "targets" is the raw output.
  """

  def to_inputs_and_targets(example):
    prefixed_input = 'JAVA_CONSTRUCT:' + example['input']
    # Each join receives a single element, so the ' ' separator is never
    # inserted into the text.
    return {
        'inputs': tf.strings.join([prefixed_input], separator=' '),
        'targets': tf.strings.join([example['output']], separator=' '),
    }

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove any previous registration of the name, then register the task.
t5.data.TaskRegistry.remove("java_construct")
t5.data.TaskRegistry.add(
    "java_construct",
    # Raw examples come from the TSV loader and go through the preprocessor.
    dataset_fn=nq_java_construct,
    text_preprocessor=[java_construct_preprocessing],
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_java_construct,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f7881812b70>

In [None]:
# Build the train split of the registered task and show five of its examples.
nq_task = t5.data.TaskRegistry.get("java_construct")
ds = nq_task.get_dataset(
    sequence_length={"inputs": 256, "targets": 256}, split="train")
print("A few preprocessed training examples...")
preprocessed_examples = ds.take(5)
for ex in tfds.as_numpy(preprocessed_examples):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'JAVA_CONSTRUCT:private void addJavaDoc(JDocCommentable docCommentable) { JDocComment javadoc = docCommentable.javadoc(); javadoc.append( <extra_id_0>); }', 'inputs': array([    3,  7641,    15,  5071, 17558,    56,  8797,    20,   162,
       28631,     5,   808,  1361,  1341,   367,  1009,  1341,   367,
           8,     7,  1570,  1361,  1341,     3, 14480,    11,  1009,
        1341,   367,     4, 14480,    18,     3, 14480,     4,   109,
           5, 32099,    10,     6,     1], dtype=int32), 'targets_pretokenized': b'REQUIRED_COMMENT_TEXT', 'targets': array([ 7572,     2,  8424,    15, 10867,    15,   989,     2,    70,
           1], dtype=int32)}
{'inputs_pretokenized': b'JAVA_CONSTRUCT:@Bean public NativeEnvironmentRepository nativeEnvironmentRepository( NativeEnvironmentRepositoryFactory factory, NativeEnvironmentProperties environmentProperties) { return factory.build( <extra_id_0>); }', 'inputs': array([    3

JAVA TOKEN

In [None]:
def nq_java_token(split, shuffle_files=False):
  """Load the raw Java "token" examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_java_token` selecting the file to read.
    shuffle_files: unused; it is deleted immediately.

  Returns:
    A dataset of dicts with the keys "input" and "output", holding the first
    and the second tab-separated column of every line.
  """
  # There is a single file per split, so there is nothing to shuffle.
  del shuffle_files

  # NOTE(review): record_defaults gives the default *value* of an empty field,
  # so an empty column is read as the literal text "string"; confirm that this
  # is intended rather than an empty-string default.
  parse_tsv_line = functools.partial(
      tf.io.decode_csv, record_defaults=["string","string"],
      field_delim="\t", use_quote_delim=False)

  def to_example(*columns):
    # Name the two parsed columns.
    return dict(zip(["input", "output"], columns))

  # Read the file line by line, then split every line on the tab character.
  lines = tf.data.TextLineDataset(nq_tsv_path_java_token[split])
  split_lines = lines.map(
      parse_tsv_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return split_lines.map(to_example)

# Show the first five raw examples of the validation split.
print("A few raw valid examples...")
raw_java_token_examples = nq_java_token("validation").take(5)
for ex in tfds.as_numpy(raw_java_token_examples):
  print(ex)

A few raw valid examples...
{'input': b'public void addThirdPartyPaymentWorkflow(com.mozu.api.contracts.sitesettings.order.ExternalPaymentWorkflowDefinition definition) throws Exception { MozuClient client = com.mozu.api.clients.commerce.settings.checkout.PaymentSettingsClient <extra_id_0> client.setContext(_apiContext); client.executeRequest(); client.cleanupHttpConnection(); }', 'output': b'.addThirdPartyPaymentWorkflowClient( definition);'}
{'input': b'public void addThirdPartyPaymentWorkflow(com.mozu.api.contracts.sitesettings.order.ExternalPaymentWorkflowDefinition definition) throws Exception { MozuClient client = com.mozu.api.clients.commerce.settings.checkout.PaymentSettingsClient.addThirdPartyPaymentWorkflowClient( definition); client.setContext(_apiContext <extra_id_0> client.executeRequest(); client.cleanupHttpConnection(); }', 'output': b');'}
{'input': b'public void addThirdPartyPaymentWorkflow(com.mozu.api.contracts.sitesettings.order.ExternalPaymentWorkflowDefinition def

In [None]:
def java_token_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Args:
    ds: dataset of dicts with the entries "input" and "output".

  Returns:
    The mapped dataset: "inputs" is the raw input prefixed with
    "JAVA_TOKEN:" and "targets" is the raw output.
  """

  def to_inputs_and_targets(example):
    prefixed_input = 'JAVA_TOKEN:' + example['input']
    # Each join receives a single element, so the ' ' separator is never
    # inserted into the text.
    return {
        'inputs': tf.strings.join([prefixed_input], separator=' '),
        'targets': tf.strings.join([example['output']], separator=' '),
    }

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove any previous registration of the name, then register the task.
t5.data.TaskRegistry.remove("java_token")
t5.data.TaskRegistry.add(
    "java_token",
    # Raw examples come from the TSV loader and go through the preprocessor.
    dataset_fn=nq_java_token,
    text_preprocessor=[java_token_preprocessing],
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_java_token,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f7761206048>

In [None]:
# Build the train split of the registered task and show five of its examples.
nq_task = t5.data.TaskRegistry.get("java_token")
ds = nq_task.get_dataset(
    sequence_length={"inputs": 256, "targets": 256}, split="train")
print("A few preprocessed training examples...")
preprocessed_examples = ds.take(5)
for ex in tfds.as_numpy(preprocessed_examples):
  print(ex)


A few preprocessed training examples...


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


{'inputs_pretokenized': b'JAVA_TOKEN:public T withMatcher(final String matcherName, final Matcher matcher) { this.configuration = configuration.withMatcher(matcherName, matcher); return (T) this <extra_id_0> }', 'inputs': array([    3,  7641,    15,  2591,    56,  4569,   299,   273,  1588,
           5,    64,    26,  2116,    66,     9,    44,  2908,  2116,
           8,     7,    23,     4,  1382,    11,   739,     4,   616,
        1588,     5,  1311,    66,     9,  2116,    10,    14,    17,
          70,     8,    23, 32099,     6,     1], dtype=int32), 'targets_pretokenized': b';', 'targets': array([ 3, 13,  1], dtype=int32)}
{'inputs_pretokenized': b'JAVA_TOKEN:@Override public <extra_id_0> if (this.rendition != null) { return ModificationDate.get(this.rendition.getRendition().adaptTo(Resource.class)); } else { return null; } }', 'inputs': array([    3,  7641,    15,  2591,    56,  2098,    27,    12, 32099,
          21,    17,    75,     4,   185,   852, 18717,    49,    30,


JAVA BLOCK

In [None]:
def nq_java_block(split, shuffle_files=False):
  """Load the raw Java "block" examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_java_block` selecting the file to read.
    shuffle_files: unused; it is deleted immediately.

  Returns:
    A dataset of dicts with the keys "input" and "output", holding the first
    and the second tab-separated column of every line.
  """
  # There is a single file per split, so there is nothing to shuffle.
  del shuffle_files

  # NOTE(review): record_defaults gives the default *value* of an empty field,
  # so an empty column is read as the literal text "string"; confirm that this
  # is intended rather than an empty-string default.
  parse_tsv_line = functools.partial(
      tf.io.decode_csv, record_defaults=["string","string"],
      field_delim="\t", use_quote_delim=False)

  def to_example(*columns):
    # Name the two parsed columns.
    return dict(zip(["input", "output"], columns))

  # Read the file line by line, then split every line on the tab character.
  lines = tf.data.TextLineDataset(nq_tsv_path_java_block[split])
  split_lines = lines.map(
      parse_tsv_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return split_lines.map(to_example)

# Show the first five raw examples of the validation split.
print("A few raw valid examples...")
raw_java_block_examples = nq_java_block("validation").take(5)
for ex in tfds.as_numpy(raw_java_block_examples):
  print(ex)

A few raw valid examples...
{'input': b'public Object createField(String name) { Object reader = _fieldMap.get(name); if (reader == null) <extra_id_0> return reader; }', 'output': b'reader = NullFieldDeserializer.DESER;'}
{'input': b'@Nonnull public static JSInvocation invoke (@Nonnull final JQueryInvocation aJQueryInvocation, @Nonnull final JSAssocArray aOptions) <extra_id_0>', 'output': b'{ return invoke (aJQueryInvocation).arg (aOptions); }'}
{'input': b'public static String presentMinMaxCount(long minmax) { if (minmax == Long.MAX_VALUE || minmax == Long.MIN_VALUE) <extra_id_0> return String.valueOf(minmax); }', 'output': b'{ return UNDEF_STRING; }'}
{'input': b'@Override public NonBottomTypeNode<ElkClass, ElkNamedIndividual> getCreateNode( final Collection<? extends ElkClass> members) <extra_id_0>', 'output': b'{ return getCreateUpdateableTypeNode( classTaxonomy_.getCreateNode(members)); }'}
{'input': b'public static JsonException typeMismatch(Object indexOrName, Object actual, Str

In [None]:
def java_block_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Args:
    ds: dataset of dicts with the entries "input" and "output".

  Returns:
    The mapped dataset: "inputs" is the raw input prefixed with
    "JAVA_BLOCK:" and "targets" is the raw output.
  """

  def to_inputs_and_targets(example):
    prefixed_input = 'JAVA_BLOCK:' + example['input']
    # Each join receives a single element, so the ' ' separator is never
    # inserted into the text.
    return {
        'inputs': tf.strings.join([prefixed_input], separator=' '),
        'targets': tf.strings.join([example['output']], separator=' '),
    }

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove any previous registration of the name, then register the task.
t5.data.TaskRegistry.remove("java_block")
t5.data.TaskRegistry.add(
    "java_block",
    # Raw examples come from the TSV loader and go through the preprocessor.
    dataset_fn=nq_java_block,
    text_preprocessor=[java_block_preprocessing],
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_java_block,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f776116db00>

In [None]:
# Build the train split of the registered task and show five of its examples.
nq_task = t5.data.TaskRegistry.get("java_block")
ds = nq_task.get_dataset(
    sequence_length={"inputs": 256, "targets": 256}, split="train")
print("A few preprocessed training examples...")
preprocessed_examples = ds.take(5)
for ex in tfds.as_numpy(preprocessed_examples):
  print(ex)


A few preprocessed training examples...


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


{'inputs_pretokenized': b'JAVA_BLOCK:@Override public Optional<TimeSeriesCollection> getPreviousCollection(int n) { if (n < 0) throw new IllegalArgumentException("cannot look into the future"); if (n == 0) return Optional.of(getCurrentCollection()); if (n - 1 >= previous_.size()) <extra_id_0> return Optional.of(previous_.get(n - 1)); }', 'inputs': array([    3,  7641,    15,  3517,    56,  2098,    27,    12,   730,
          25, 16445,   387,    29, 15550,   387,     5,    53,   446,
           8,     7,    21,    17,   127,   136,   178,    78,    24,
         381,    38,    28, 10153,  5972,  2378,    62,  2639,    46,
          21,    17,   127,    40,   178,    14,   730,     4,   579,
           5,  1134,   387,    39,    21,    17,   127,   139,   279,
         453,  1805,    15,     4,   134,    60, 32099,    14,   730,
           4,   579,     5,  3500,    15,     4,    33,     5,   127,
         139,  3792,     6,     1], dtype=int32), 'targets_pretokenized': b'return Optiona

ANDROID CONSTRUCT

In [None]:
def nq_android_construct(split, shuffle_files=True):
  """Load the raw Android "construct" examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_android_construct` selecting the file to read.
    shuffle_files: unused; it is deleted immediately.

  Returns:
    A dataset of dicts with the keys "input" and "output", holding the first
    and the second tab-separated column of every line.
  """
  # There is a single file per split, so there is nothing to shuffle.
  del shuffle_files

  # NOTE(review): record_defaults gives the default *value* of an empty field,
  # so an empty column is read as the literal text "string"; confirm that this
  # is intended rather than an empty-string default.
  parse_tsv_line = functools.partial(
      tf.io.decode_csv, record_defaults=["string","string"],
      field_delim="\t", use_quote_delim=False)

  def to_example(*columns):
    # Name the two parsed columns.
    return dict(zip(["input", "output"], columns))

  # Read the file line by line, then split every line on the tab character.
  lines = tf.data.TextLineDataset(nq_tsv_path_android_construct[split])
  split_lines = lines.map(
      parse_tsv_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return split_lines.map(to_example)

# Show the first five raw examples of the train split.
print("A few raw train examples...")
raw_android_construct_examples = nq_android_construct("train").take(5)
for ex in tfds.as_numpy(raw_android_construct_examples):
  print(ex)

A few raw train examples...
{'input': b'private void writeToFile(final String content) { if ( <extra_id_0>) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch (Exception e) { e.printStackTrace(); } }', 'output': b'!WRITE_TO_FILE'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File(OUTPUT_FILE); f.delete(); FileWriter fstream = new FileWriter(OUTPUT_FILE); BufferedWriter out = new BufferedWriter(fstream); out.write(content); out.close(); } catch ( <extra_id_0>) { e.printStackTrace(); } }', 'output': b'Exception e'}
{'input': b'private void writeToFile(final String content) { if (!WRITE_TO_FILE) { return; } try { System.err.println(content); File f = new File( <extra_id_0>); f.delete(); FileWriter fstream = new FileW

In [None]:
def android_construct_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Args:
    ds: dataset of dicts with the entries "input" and "output".

  Returns:
    The mapped dataset: "inputs" is the raw input prefixed with
    "ANDROID_CONSTRUCT:" and "targets" is the raw output.
  """

  def to_inputs_and_targets(example):
    prefixed_input = 'ANDROID_CONSTRUCT:' + example['input']
    # Each join receives a single element, so the ' ' separator is never
    # inserted into the text.
    return {
        'inputs': tf.strings.join([prefixed_input], separator=' '),
        'targets': tf.strings.join([example['output']], separator=' '),
    }

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove any previous registration of the name, then register the task.
t5.data.TaskRegistry.remove("android_construct")
t5.data.TaskRegistry.add(
    "android_construct",
    # Raw examples come from the TSV loader and go through the preprocessor.
    dataset_fn=nq_android_construct,
    text_preprocessor=[android_construct_preprocessing],
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_android_construct,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f776119c390>

In [None]:
# Build the train split of the registered task and show five of its examples.
nq_task = t5.data.TaskRegistry.get("android_construct")
ds = nq_task.get_dataset(
    sequence_length={"inputs": 256, "targets": 256}, split="train")
print("A few preprocessed training examples...")
preprocessed_examples = ds.take(5)
for ex in tfds.as_numpy(preprocessed_examples):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_CONSTRUCT:public void onReceive(Context context, Intent intent) { if(GameActivity.ACTION_PAGE_SELECTED.equals(intent.getAction())) if(action != null && action.isActive()) action.finish(); if(LocalEvents.ACTION_PLAYER_ADD.equals(intent.getAction()) || LocalEvents.ACTION_PLAYER_EDIT.equals( <extra_id_0>) || LocalEvents.ACTION_PLAYER_REMOVE.equals(intent.getAction())) playersAdapter.notifyDataSetChanged(); }', 'inputs': array([    3, 16446,    15,  5071, 17558,    56,  4569,    20,  6431,
           5,    92,   130,     9,   604,   576,     8,     7,    21,
           5,  2040,   435,     4,  1023,    15,  3497,    15,  9738,
           4,   117,     5,  1482,     4,  2677,   459,    21,     5,
         915,    49,    30,    91,   647,     4,  5649,    60,   647,
           4,  3835,    18,    21,     5,   959,  1314,     4,  1023,
          15,  3882,     2,  1387,    15,  4735,     4,   117,     5,
        1482,  

ANDROID TOKEN

In [None]:
def nq_android_token(split, shuffle_files=False):
  """Load the raw Android "token" examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_android_token` selecting the file to read.
    shuffle_files: unused; it is deleted immediately.

  Returns:
    A dataset of dicts with the keys "input" and "output", holding the first
    and the second tab-separated column of every line.
  """
  # There is a single file per split, so there is nothing to shuffle.
  del shuffle_files

  # NOTE(review): record_defaults gives the default *value* of an empty field,
  # so an empty column is read as the literal text "string"; confirm that this
  # is intended rather than an empty-string default.
  parse_tsv_line = functools.partial(
      tf.io.decode_csv, record_defaults=["string","string"],
      field_delim="\t", use_quote_delim=False)

  def to_example(*columns):
    # Name the two parsed columns.
    return dict(zip(["input", "output"], columns))

  # Read the file line by line, then split every line on the tab character.
  lines = tf.data.TextLineDataset(nq_tsv_path_android_token[split])
  split_lines = lines.map(
      parse_tsv_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return split_lines.map(to_example)

# Show the first five raw examples of the validation split.
print("A few raw valid examples...")
raw_android_token_examples = nq_android_token("validation").take(5)
for ex in tfds.as_numpy(raw_android_token_examples):
  print(ex)

A few raw valid examples...
{'input': b'public void register(Context context <extra_id_0> IntentFilter intentFilter = buildFilter(); try { register(context, intentFilter); }catch (IllegalArgumentException e){ Log.w(TAG, "Error registering receiver. Receiver maybe already registered, unregistering to register again."); unregister(context); register(context, intentFilter); } }', 'output': b') {'}
{'input': b'public void register(Context context) { IntentFilter intentFilter <extra_id_0> try { register(context, intentFilter); }catch (IllegalArgumentException e){ Log.w(TAG, "Error registering receiver. Receiver maybe already registered, unregistering to register again."); unregister(context); register(context, intentFilter); } }', 'output': b'= buildFilter();'}
{'input': b'public void register(Context context) { IntentFilter intentFilter = buildFilter(); try <extra_id_0> register(context, intentFilter); }catch (IllegalArgumentException e){ Log.w(TAG, "Error registering receiver. Receiver ma

In [None]:
def android_token_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Args:
    ds: dataset of dicts with the entries "input" and "output".

  Returns:
    The mapped dataset: "inputs" is the raw input prefixed with
    "ANDROID_TOKEN:" and "targets" is the raw output.
  """

  def to_inputs_and_targets(example):
    prefixed_input = 'ANDROID_TOKEN:' + example['input']
    # Each join receives a single element, so the ' ' separator is never
    # inserted into the text.
    return {
        'inputs': tf.strings.join([prefixed_input], separator=' '),
        'targets': tf.strings.join([example['output']], separator=' '),
    }

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove any previous registration of the name, then register the task.
t5.data.TaskRegistry.remove("android_token")
t5.data.TaskRegistry.add(
    "android_token",
    # Raw examples come from the TSV loader and go through the preprocessor.
    dataset_fn=nq_android_token,
    text_preprocessor=[android_token_preprocessing],
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_android_token,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f776117df98>

In [None]:
# Build the train split of the registered task and show five of its examples.
nq_task = t5.data.TaskRegistry.get("android_token")
ds = nq_task.get_dataset(
    sequence_length={"inputs": 256, "targets": 256}, split="train")
print("A few preprocessed training examples...")
preprocessed_examples = ds.take(5)
for ex in tfds.as_numpy(preprocessed_examples):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_TOKEN:private void onDrag(MotionEvent event) { final int pointerIndex = event.findPointerIndex(_activePointerId); if (pointerIndex == -1) return; final float x = event.getX <extra_id_0> final float y = event.getY(pointerIndex); _translateX += (x - _lastTouchX); _translateY += (y - _lastTouchY); invalidate(); _lastTouchX = x; _lastTouchY = y; }', 'inputs': array([    3, 16446,    15,  2591,    56,  8797,    20,   170,  2943,
           5,  6119,   209,     8,     7,    44,    35,  4635,   163,
          11,   209,     4,   714,  1521,   163,     5,    15,  2231,
        1521,    68,    10,    21,    17,  7358,   163,    40,  1324,
          14,    13,    44,   245,   205,    11,   209,     4,    33,
           2, 32099,    44,   245,   240,    11,   209,     4,    33,
           2,     5,  7358,   163,    10,     3,    15,  3625,     2,
         470,    17,   138,   139,     3,    15,   944,  3447,     2,
        

ANDROID BLOCK

In [None]:
def nq_android_block(split, shuffle_files=False):
  """Load the raw Android "block" examples of one split from its TSV file.

  Args:
    split: key of `nq_tsv_path_android_block` selecting the file to read.
    shuffle_files: unused; it is deleted immediately.

  Returns:
    A dataset of dicts with the keys "input" and "output", holding the first
    and the second tab-separated column of every line.
  """
  # There is a single file per split, so there is nothing to shuffle.
  del shuffle_files

  # NOTE(review): record_defaults gives the default *value* of an empty field,
  # so an empty column is read as the literal text "string"; confirm that this
  # is intended rather than an empty-string default.
  parse_tsv_line = functools.partial(
      tf.io.decode_csv, record_defaults=["string","string"],
      field_delim="\t", use_quote_delim=False)

  def to_example(*columns):
    # Name the two parsed columns.
    return dict(zip(["input", "output"], columns))

  # Read the file line by line, then split every line on the tab character.
  lines = tf.data.TextLineDataset(nq_tsv_path_android_block[split])
  split_lines = lines.map(
      parse_tsv_line, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return split_lines.map(to_example)

# Show the first five raw examples of the validation split.
print("A few raw valid examples...")
raw_android_block_examples = nq_android_block("validation").take(5)
for ex in tfds.as_numpy(raw_android_block_examples):
  print(ex)

A few raw valid examples...
{'input': b'public void update() { if(Gdx.input.isKeyPressed(Keys.LEFT))<extra_id_0> if(Gdx.input.isKeyPressed(Keys.RIGHT)){ bucket.x += 200 * Gdx.graphics.getDeltaTime(); } if (bucket.x < 0) bucket.x = 0; if (bucket.x > 320 - 64) bucket.x = 320 - 64; }', 'output': b'{ bucket.x -= 200 * Gdx.graphics.getDeltaTime(); }'}
{'input': b'public void update() { if(Gdx.input.isKeyPressed(Keys.LEFT)){ bucket.x -= 200 * Gdx.graphics.getDeltaTime(); } if(Gdx.input.isKeyPressed(Keys.RIGHT))<extra_id_0> if (bucket.x < 0) bucket.x = 0; if (bucket.x > 320 - 64) bucket.x = 320 - 64; }', 'output': b'{ bucket.x += 200 * Gdx.graphics.getDeltaTime(); }'}
{'input': b'public void update() { if(Gdx.input.isKeyPressed(Keys.LEFT)){ bucket.x -= 200 * Gdx.graphics.getDeltaTime(); } if(Gdx.input.isKeyPressed(Keys.RIGHT)){ bucket.x += 200 * Gdx.graphics.getDeltaTime(); } if (bucket.x < 0) <extra_id_0> if (bucket.x > 320 - 64) bucket.x = 320 - 64; }', 'output': b'bucket.x = 0;'}
{'input':

In [None]:
def android_block_preprocessing(ds):
  """Convert raw {"input", "output"} examples into {"inputs", "targets"}.

  Args:
    ds: dataset of dicts with the entries "input" and "output".

  Returns:
    The mapped dataset: "inputs" is the raw input prefixed with
    "ANDROID_BLOCK:" and "targets" is the raw output.
  """

  def to_inputs_and_targets(example):
    prefixed_input = 'ANDROID_BLOCK:' + example['input']
    # Each join receives a single element, so the ' ' separator is never
    # inserted into the text.
    return {
        'inputs': tf.strings.join([prefixed_input], separator=' '),
        'targets': tf.strings.join([example['output']], separator=' '),
    }

  return ds.map(
      to_inputs_and_targets,
      num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [None]:
# Remove any previous registration of the name, then register the task.
t5.data.TaskRegistry.remove("android_block")
t5.data.TaskRegistry.add(
    "android_block",
    # Raw examples come from the TSV loader and go through the preprocessor.
    dataset_fn=nq_android_block,
    text_preprocessor=[android_block_preprocessing],
    splits=["train", "validation"],
    num_input_examples=num_nq_examples_android_block,
    output_features=DEFAULT_OUTPUT_FEATURES,
    metric_fns=[t5.evaluation.metrics.accuracy],
)

<t5.data.dataset_providers.FunctionTask at 0x7f7761206c18>

In [None]:
# Build the train split of the registered task and show five of its examples.
nq_task = t5.data.TaskRegistry.get("android_block")
ds = nq_task.get_dataset(
    sequence_length={"inputs": 256, "targets": 256}, split="train")
print("A few preprocessed training examples...")
preprocessed_examples = ds.take(5)
for ex in tfds.as_numpy(preprocessed_examples):
  print(ex)


  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


A few preprocessed training examples...
{'inputs_pretokenized': b'ANDROID_BLOCK:@Override public void onPrepared(MediaPlayer mp) { progressDialog.dismiss(); if (-1 != mCurrentPosition) <extra_id_0> mVideoView.start(); }', 'inputs': array([    3, 16446,    15,  3517,    56,  2098,    27,    12,    20,
         170,  5609,     5,  6238,  8212,     8,     7, 12898,     4,
        3875,    18,    21,    17,   802,    49,  2982,   392,     8,
       32099,    54, 17649,     4,   373,    18,     6,     1],
      dtype=int32), 'targets_pretokenized': b'{ mVideoView.seekTo(mCurrentPosition); }', 'targets': array([    7,    54, 17649,     4, 14206,     5,  4399,   392,    10,
           6,     1], dtype=int32)}
{'inputs_pretokenized': b'ANDROID_BLOCK:private void spinnerSelect(LibOpenConnect.FormOpt opt, int index) { LibOpenConnect.FormChoice fc = opt.choices.get((int)index); String s = fc.name != null ? fc.name : ""; if (opt.userData == null) <extra_id_0> else if (!s.equals(opt.userData)) { op

### Hyper Parameter tuning
You can run the HP tuning using the following cells.  
Please set the correct paths for the variables *MODEL_DIR* (the folder in which the checkpoints of this run are saved), *PATH_GIN_FILE* (the gin configuration file for this HP tuning) and *PRETRAINED_DIR* (the folder that contains the pre-trained model).  
**Pay attention** to the *pretrained_model_dir* argument in the fine-tuning step: if you are starting the HP tuning from scratch, set it to *PRETRAINED_DIR*; if you are resuming the HP tuning from a previously saved checkpoint, set it to *MODEL_DIR*.

In [None]:
def _rate_num_input_examples(task):
  if "train" in task.splits:
    return float(task.num_input_examples("train"))
  elif "validation" in task.splits:
    return float(task.num_input_examples("validation"))
  else:
    raise ValueError("Task %s does not have a train or validation split." % (task.name))


# Remove any previous registration of the mixture, then register it again with
# the six tasks; the rate of each task is computed by _rate_num_input_examples.
t5.data.MixtureRegistry.remove("all_tasks")
t5.data.MixtureRegistry.add(
    "all_tasks",
    [
        "java_construct", "java_token", "java_block",
        "android_construct", "android_token", "android_block",
    ],
    default_rate=_rate_num_input_examples,
    # Alternative kept from the original cell: default_rate=1.0
)

<t5.seqio.dataset_providers.Mixture at 0x7f7761190748>

In [None]:
from mesh_tensorflow.transformer.learning_rate_schedules import truncated_rsqrt

MODEL_SIZE = "small"

# Folder where the checkpoints and all the other information will be written.
MODEL_DIR = 'gs://bucket_comment_completion/Matteo/HP_TUNING/ISR/model'

# Pre-trained folder; it must contain the pre-trained models, the
# operative_config.gin file and the checkpoint file as well.
PRETRAINED_DIR='gs://bucket_comment_completion/Matteo/pretrained_with_masking'

# (model_parallelism, train_batch_size, keep_checkpoint_max) per model size.
settings_by_model_size = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1),
}
model_parallelism, train_batch_size, keep_checkpoint_max = (
    settings_by_model_size[MODEL_SIZE])

# Make sure the output folder exists before the model is created.
tf.io.gfile.makedirs(MODEL_DIR)

model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 256, "targets": 256},
    learning_rate_schedule=truncated_rsqrt,
    iterations_per_loop=100,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
)

In [None]:
# Gin configuration file for this HP-tuning run.
PATH_GIN_FILE = 'gs://bucket_comment_completion/Matteo/HP_TUNING/ISR/operative_config.gin'
import gin

# Parse the gin file and launch the fine-tuning while the gin config is unlocked.
with gin.unlock_config():
    gin.parse_config_file(PATH_GIN_FILE)
    # Run the fine-tuning on the mixture of all six tasks.
    FINETUNE_STEPS = 100000
    model.finetune(
        mixture_or_task_name="all_tasks",
        # MODEL_DIR restarts from a checkpoint previously saved by this run;
        # pass PRETRAINED_DIR instead when starting the HP tuning from scratch.
        pretrained_model_dir=MODEL_DIR,
        finetune_steps=FINETUNE_STEPS
    )

INFO:root:system_path_file_exists:gs://bucket_comment_completion/Matteo/HP_TUNING/ISR/operative_config.gin
ERROR:root:Path not found: gs://bucket_comment_completion/Matteo/HP_TUNING/ISR/operative_config.gin
INFO:root:system_path_file_exists:gs://bucket_comment_completion/Matteo/HP_TUNING/ISR/model/operative_config.gin
ERROR:root:Path not found: gs://bucket_comment_completion/Matteo/HP_TUNING/ISR/model/operative_config.gin


INFO:tensorflow:Using config: {'_model_dir': 'gs://bucket_comment_completion/Matteo/HP_TUNING/ISR/model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.116.164.242:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.116.164.242:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.116.164.242:8470', '_evaluation_master': 'grpc:

  _tokenize, num_parallel_calls=tf.data.experimental.AUTOTUNE)


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[8] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=8] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('ensemble', 'ensemble'), ('d_ff', 'model'), ('experts', 'batch'), ('vocab', 'model')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7f7762f91c18>
INF

### Evaluate the performances

In [None]:
# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = 512
model.eval(
    checkpoint_steps=-1,  # evaluate only the last checkpoint
    mixture_or_task_name="all_tasks",
)