<a href="https://colab.research.google.com/github/namratakatti10/COMP8240/blob/main/Replication_dataset_t5_trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


<a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/main/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set Up

In [14]:
print("Installing dependencies...")
%tensorflow_version 2.x
!pip install -q t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5
import t5.models
import seqio

BASE_DIR = "gs://project-comp8240" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "v2-8"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  tf.enable_eager_execution()
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

@contextmanager
def tf_verbosity_level(level):
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  yield
  tf.logging.set_verbosity(og_level)

Installing dependencies...
Setting up GCS access...
Running on TPU: grpc://10.20.162.234:8470


ValueError: ignored

Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [15]:
ds = tfds.load(
    "glue/sst2",
    data_dir=DATA_DIR,
    # Download data locally for preprocessing to avoid using GCS space.
    download_and_prepare_kwargs={"download_dir": "./glue"})
print("A few raw validation examples...")
for ex in tfds.as_numpy(ds["validation"].take(2)):
  print(ex)

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split None, from gs://project-comp8240/data/glue/sst2/2.0.0


A few raw validation examples...
{'idx': 1, 'label': 0, 'sentence': b'oh no so sorry about your pets '}
{'idx': 2, 'label': 1, 'sentence': b'you should i love prison break i only have till ep 5 '}


In [None]:
# tfds.as_dataframe(ds["validation"])

In [16]:
# for b in tfds.text.glue.Glue.builder_configs.values():
#   print(b)

def get_config(name):
  for b in tfds.text.glue.Glue.builder_configs.values():
    if b.name == name:
      return b

config = get_config('sst2')

In [17]:
import seqio
import t5.data
from t5.data import postprocessors
from t5.data import preprocessors
from t5.data.glue_utils import get_glue_metric
from t5.data.glue_utils import get_glue_postprocess_fn
from t5.data.glue_utils import get_glue_text_preprocessor
from t5.data.glue_utils import get_super_glue_metric
from t5.evaluation import metrics
import tensorflow_datasets as tfds

DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(), add_eos=True,
        required=False),
    "targets": seqio.Feature(
        vocabulary=t5.data.get_default_vocabulary(), add_eos=True)
}


seqio.TaskRegistry.add(
  "all_mix3",
  source=seqio.TfdsDataSource(
      tfds_name="glue/sst2:2.0.0",
      tfds_data_dir=DATA_DIR),
  preprocessors=[
  get_glue_text_preprocessor(config),
  seqio.preprocessors.tokenize,
  seqio.CacheDatasetPlaceholder(),
  seqio.preprocessors.append_eos_after_trim,
      ],
metric_fns=get_glue_metric(config.name),
output_features=DEFAULT_OUTPUT_FEATURES,
postprocess_fn=get_glue_postprocess_fn(config))


<seqio.dataset_providers.Task at 0x7f2b02e3eb50>

## Define Model

In [9]:
import os


MODEL_SIZE = "base" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warning(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# The models from our paper are based on the Mesh Tensorflow Transformer.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)

Before we continue, let's load a [TensorBoard](https://www.tensorflow.org/tensorboard) visualizer so that we can keep monitor our progress. The page should automatically update as fine-tuning and evaluation proceed.

In [None]:
if ON_CLOUD:
  %reload_ext tensorboard
%tensorboard --logdir="$MODEL_DIR" --port=0

## Fine-tune

We are now ready to fine-tune our model. This will take a while (~2 hours with default settings), so please be patient! The larger the model and more `FINETUNE_STEPS` you use, the longer it will take.

Don't worry, you can always come back later and increase the number of steps, and it will automatically pick up where you left off.

In [10]:
FINETUNE_STEPS =  1000#@param {type: "integer"}

model.finetune(
    mixture_or_task_name="all_mix",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)

INFO:root:system_path_file_exists:gs://t5-data/pretrained_models/base/operative_config.gin
ERROR:root:Path not found: gs://t5-data/pretrained_models/base/operative_config.gin
INFO:root:Skipping import of unknown module `t5.data.sentencepiece_vocabulary` (skip_unknown=True).


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

## Evaluate

We now evaluate on the validation sets of the tasks in our mixture. Accuracy results will be logged and added to the TensorBoard above.

In [19]:
# Use a larger batch size for evaluation, which requires less memory.
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="all_mix3",
    checkpoint_steps="all"
)

INFO:root:system_path_file_exists:gs://project-comp8240/models/base/operative_config.gin
ERROR:root:Path not found: gs://project-comp8240/models/base/operative_config.gin
INFO:absl:Adding task 'all_mix3' with predict metric_fn(s).
INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Skipping packing/padding for 'all_mix3' since sequence length is None.
INFO:absl:Setting sequence lengths to {'inputs': 54, 'targets': 2}
INFO:absl:Evaluating checkpoint step: 1000900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 1000900: 52.920
INFO:absl:Evaluating checkpoint step: 1000900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 1000900: 52.920
INFO:absl:Evaluating checkpoint step: 1000900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 1000900: 52.920
INFO:absl:Evaluating checkpoint step: 1000900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 1000900: 52.920
INFO:absl:Evaluating checkpoint step: 999900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 999900: 53.285
INFO:absl:Evaluating checkpoint step: 999900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 999900: 53.285
INFO:absl:Evaluating checkpoint step: 999900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 999900: 53.285
INFO:absl:Evaluating checkpoint step: 999900


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': graph_options {
  rewrite_options {
    disable_meta_optimizer: true
  }
}
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.20.162.234:8470"
    }
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.20.162.234:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.20.162.234:8470', '_evaluation_master': 'grpc://10.20.162.234:8470', '_is

INFO:absl:Load dataset info from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Reusing dataset glue (gs://project-comp8240/data/glue/sst2/2.0.0)
INFO:absl:Constructing tf.data.Dataset glue for split validation, from gs://project-comp8240/data/glue/sst2/2.0.0
INFO:absl:Automatically caching small dataset in memory: 'all_mix3:validation'
INFO:absl:Padding 'all_mix3' with sequence lengths: {'inputs': 54, 'targets': 2}


INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0 0]
  [0 0 0 1]
  [1 0 0 0]
  [1 0 0 1]
  [0 1 0 0]
  [0 1 0 1]
  [1 1 0 0]
  [1 1 0 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0 0]]

 [[0 0 0 1]]

 [[1 0 0 0]]

 [[1 0 0 1]]

 [[0 1 0 0]]

 [[0 1 0 1]]

 [[1 1 0 0]]

 [[1 1 0 1]]]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[4, 2] physical_shape=[2, 2, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_shape=[2] physical_shape=[1, 1, 2]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1)]
INFO:tensorflow:auto_logical_to_physical_tpu logical_to_physical = [(0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), (1, 1, 0), (1, 1, 1), (1, 0, 0), (1, 0, 1)]
INFO:tensorflow:SimdMeshImpl init: Shape[batch=4, model=2] LayoutRules{('heads', 'model'), ('batch', 'batch'), ('vocab', 'model'),

INFO:absl:eval/all_mix3/accuracy at step 999900: 53.285


## Export SavedModel

We first export the SavedModel. We set a batch size of 1 for simplicity, but it may be more efficient to use a larger batch size if you want to handle multiple requests per call.

For 3B and 11B models the export will take approximately 30-45 minutes.

In [None]:
export_dir = os.path.join(MODEL_DIR, "export")

model.batch_size = 1 # make one prediction per call
saved_model_path = model.export(
    export_dir,
    checkpoint_step=-1,  # use most recent
    beam_size=1,  # no beam search
    temperature=1.0,  # sample according to predicted distribution
)
print("Model saved to:", saved_model_path)

INFO:root:system_path_file_exists:gs://project-comp8240/models/base/operative_config.gin
ERROR:root:Path not found: gs://project-comp8240/models/base/operative_config.gin


INFO:tensorflow:Using config: {'_model_dir': 'gs://project-comp8240/models/base', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 5000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.79.37.170:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.79.37.170:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.79.37.170:8470', '_evaluation_master': 'grpc://10.79.37.170:8470', '_is_chief': True, '_num_ps_