In [1]:
# Move the working directory up one level (the output below shows the resulting
# path). Presumably the notebook lives in a subdirectory of the repo, and this is
# needed so `code_t5` imports and the relative "../data/..." paths resolve.
%cd ..

/Users/Egor.Spirin/JetBrains/code-t5


In [2]:
import os
import gin
import seqio
import t5
import tensorflow as tf
import tensorflow_datasets as tfds
from contextlib import contextmanager
from tensorflow.compat.v1 import logging

from code_t5.constants import NEWLINE
from code_t5.tasks import register_dev_task

In [3]:
# Improve logging: temporarily change TensorFlow's verbosity inside a `with` block.
@contextmanager
def tf_verbosity_level(level):
    """Temporarily set the TensorFlow (compat.v1) logging verbosity.

    Args:
        level: verbosity level to apply while the `with` block runs
            (e.g. `logging.ERROR`).

    The previous verbosity is restored on exit, including when the body
    of the `with` block raises.
    """
    previous_level = logging.get_verbosity()
    logging.set_verbosity(level)
    try:
        yield
    finally:
        # Without `finally`, an exception in the wrapped code would leave the
        # whole session stuck at `level`.
        logging.set_verbosity(previous_level)

In [4]:
def decode(vocab, s):
    """Decode a sequence of token ids into readable text.

    Args:
        vocab: vocabulary object exposing `decode(list_of_ids)`.
        s: token ids; must support `.tolist()` (e.g. a numpy array).

    Returns:
        The decoded string, with every occurrence of the project's NEWLINE
        token replaced by a real newline character.
    """
    token_ids = s.tolist()
    text = vocab.decode(token_ids)
    return text.replace(NEWLINE, "\n")

# Define paths and model properties

In [5]:
# Experiment root; paths are relative to the working directory set at the top
# of the notebook.
BASE_DIR = "../data/mlp-8"
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")

# SentencePiece model file used to build the vocabulary for the dev dataset.
VOCAB_PATH = os.path.join(DATA_DIR, "dataset-dev", "dev.model")

# Model variant (Colab form dropdown); its directory is MODELS_DIR/<MODEL_SIZE>.
MODEL_SIZE = "arch-lm_v1-lm" #@param["small", "base", "base-t5.1.1", "base_shared", "base_shared_1k", "base-top5k", "lm_ifa_1k", "arch-lm_v1-lm", "large", "3B", "11B"]
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)
CACHE_DIR = os.path.join(BASE_DIR, "cache")

TRAIN_STEPS = 200

In [6]:
# Configure the unsupervised preprocessing pipeline through gin. The bindings
# are made inside `gin.unlock_config()` so they are accepted even if the gin
# config is currently locked.
with gin.unlock_config():
    # Preprocessors run, in order, by `preprocessors.unsupervised`.
    gin.bind_parameter(
        "preprocessors.unsupervised.preprocessors",
        [
            # Take a random chunk of each example (max length bound below).
            t5.data.preprocessors.select_random_chunk,
            # Alternative steps, kept for reference but not part of the pipeline:
            # t5.data.preprocessors.reduce_concat_tokens,
            # t5.data.preprocessors.split_tokens_to_targets_length,
            # Split the token stream into pieces sized by the inputs length.
            t5.data.preprocessors.split_tokens_to_inputs_length,
            # Build (inputs, targets) using the denoise settings bound below.
            t5.data.preprocessors.denoise
        ],
    )
    gin.bind_parameter("preprocessors.select_random_chunk.max_length", 65536)
    # Inputs keep the non-noise tokens, targets keep the noise tokens.
    gin.bind_parameter("preprocessors.denoise.inputs_fn", t5.data.preprocessors.drop_noise_tokens)
    # NOTE(review): t5's random_prefix_noise_mask appears to require
    # noise_density == 0.5 -- confirm for the installed t5 version.
    gin.bind_parameter("preprocessors.denoise.noise_density", 0.5)
    # Presumably a prefix-LM setup: the mask splits each piece at a random
    # point, so inputs are one contiguous part and targets the other -- confirm
    # against the t5 docs for random_prefix_noise_mask.
    gin.bind_parameter("preprocessors.denoise.noise_mask_fn", t5.data.preprocessors.random_prefix_noise_mask)
    gin.bind_parameter("preprocessors.denoise.targets_fn", t5.data.preprocessors.drop_nonnoise_tokens)

In [7]:
# Vocabulary backed by the dev SentencePiece model, extended with T5's default
# number of extra ids.
vocab = seqio.SentencePieceVocabulary(
    VOCAB_PATH,
    extra_ids=t5.data.DEFAULT_EXTRA_IDS,
)

In [8]:
# Clear the task registry first, presumably so re-running this cell does not
# trip over an already-registered "dev" task.
seqio.TaskRegistry.reset()
# Register the project's "dev" task for the data in DATA_DIR with `vocab`.
register_dev_task(DATA_DIR, vocab)
# Display the registered task names as the cell output.
seqio.TaskRegistry.names()

Registering dev task...


dict_keys(['dev'])

In [9]:
# Per-feature sequence lengths (in tokens) passed to the task when building
# the dataset.
SEQUENCE_LENGTH = {"inputs": 128, "targets": 128}

dev_task = seqio.TaskRegistry.get("dev")
# use_cached=False: run the task's preprocessing on the fly rather than reading
# a pre-cached copy of the task.
dataset = dev_task.get_dataset(
    split="validation", sequence_length=SEQUENCE_LENGTH, use_cached=False
)

2021-12-17 14:32:17.310409: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Instructions for updating:
Use `tf.data.Dataset.random(...)`.


Instructions for updating:
Use `tf.data.Dataset.random(...)`.


In [10]:
# Print the first few preprocessed examples: raw ids, then the decoded text.
# `tfds.as_numpy` yields numpy arrays, so their lengths are read with
# `ndarray.size`; `tf.size` would print a `tf.Tensor(...)` repr instead of a
# plain count.
for ex in tfds.as_numpy(dataset.take(5)):
    print("----")
    print(ex)
    if "inputs" in ex:
        print(f"Inputs: {ex['inputs'].size}\n'{decode(vocab, ex['inputs'])}'")
    print(f"Targets: {ex['targets'].size}\n'{decode(vocab, ex['targets'])}'")
    print("----")
    print()

----
{'inputs': array([   14, 25217,  2702,    14, 25217,   877,    14,  2331,  3650,
          14,  2331,  1403,    14, 25217,  3650,    14, 25217,  1403,
         183,     3,    11,    10,    17,     6,    33,   553,    37,
       18667,  8591,    85,  2685,   504,  1298,    38,   549,    37,
        6917,  4231,    51,  5802,     1], dtype=int32), 'targets': array([25217,   360,   388,  3535, 18921,     4,     3,    11,    16,
           4,   251,    30, 25217,  2685,  1276,    68,   276,   186,
        1566,   183,    68,    70,   186, 25217,   208,  3291,  1175,
          26,    15,     4,   249,     5,  1444,   256,    17,   320,
        2376,  2276,     7,    31,     5,   560,    45,   103,     5,
         131,     5,  9414,     5,   588,     5,  2421,    35,     3,
          25,    72,     5,  2331,     5,  9414,     5,   588,     5,
        2421,     7,    15,     8,   466,     5,   171,    13,    17,
           6,    16,     4,   135,     4,   341,     7,     3,    11,
      

2021-12-17 14:32:20.825357: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
