# Testing Custom Transformer
https://huggingface.co/docs/datasets/quickstart

In [1]:
#@formatter:off
%load_ext autoreload
%autoreload 2
#@formatter:on

In [2]:
# this need to point to your env with hugging face package installed
!which python

/opt/homebrew/Caskroom/miniforge/base/envs/huggingface/bin/python


In [3]:
import sys
import pathlib
import os
sys.path.append(os.path.join(pathlib.Path('.').parent.resolve(),'..'))
from src import Sampling
from src import SamplingEnums as ENUMS

import tensorflow as tf
from transformers import GPT2Tokenizer, SmallTransformerConfig
from transformers.models.small_transformer.modeling_tf_small_transformer import (
    TF_SMALL_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
    TFSmallTransformerLMHeadModel,
)
from transformers.tf_utils import shape_list
from datasets import load_dataset

In [68]:
train_dataset_raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
test_dataset_raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

Reusing dataset wikitext (/Users/hp/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Reusing dataset wikitext (/Users/hp/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [69]:
train_dataset_raw

Dataset({
    features: ['text'],
    num_rows: 36718
})

# Tokenize the dataset

In [90]:
from transformers import DataCollatorWithPadding
BATCH_SIZE = 32
dataset = train_dataset_raw.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length'), batched=True)
edataset = test_dataset_raw.map(lambda e: tokenizer(e['text'], truncation=True, padding='max_length'), batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="tf")
train_dataset = dataset.to_tf_dataset(
    columns=['input_ids', 'attention_mask'],
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)
test_dataset = dataset.to_tf_dataset(
    columns=['input_ids', 'attention_mask'],
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)

Loading cached processed dataset at /Users/hp/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-eb43a12640473c0e.arrow


  0%|          | 0/5 [00:00<?, ?ba/s]

In [79]:
for i in train_dataset.take(1).as_numpy_iterator():
    print(i['input_ids'])
    print(i['attention_mask'])

[[50256 50256 50256 ... 50256 50256 50256]
 [  796   796 45205 ... 50256 50256 50256]
 [50256 50256 50256 ... 50256 50256 50256]
 ...
 [  383 24933   319 ... 50256 50256 50256]
 [50256 50256 50256 ... 50256 50256 50256]
 [ 9415   358 23490 ... 50256 50256 50256]]
[[0 0 0 ... 0 0 0]
 [1 1 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [1 1 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [1 1 1 ... 0 0 0]]


2022-05-16 00:00:07.608248: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz


# Format the Dataset

In [80]:
# CONFIGURATION VARIABLES ( CHANGE THESE - THESE ARE USED IN THE MODEL TESTING FILE)
batch_size = 13
seq_length = 7
is_training = True
use_token_type_ids = True
use_input_mask = True
use_labels = True
use_mc_token_ids = True
vocab_size = 99
hidden_size = 32
num_hidden_layers = 5
num_attention_heads = 4
intermediate_size = 37
hidden_act = "gelu"
hidden_dropout_prob = 0.1
attention_probs_dropout_prob = 0.1
max_position_embeddings = 512
type_vocab_size = 16
type_sequence_label_size = 2
initializer_range = 0.02
num_labels = 3
num_choices = 4
scope = None
bos_token_id = vocab_size - 1
eos_token_id = vocab_size - 1
pad_token_id = vocab_size - 1

In [81]:
config = SmallTransformerConfig(
    vocab_size=vocab_size,
    n_embd=hidden_size,
    n_layer=num_hidden_layers,
    n_head=num_attention_heads,
    # intermediate_size=intermediate_size,
    # hidden_act=hidden_act,
    # hidden_dropout_prob=hidden_dropout_prob,
    # attention_probs_dropout_prob=attention_probs_dropout_prob,
    n_positions=max_position_embeddings,
    # type_vocab_size=type_vocab_size,
    # initializer_range=initializer_range
    bos_token_id=bos_token_id,
    eos_token_id=eos_token_id,
    pad_token_id=pad_token_id,
    return_dict=True,
)

In [None]:
model = TFSmallTransformerLMHeadModel(config=config)


In [103]:
import numpy as np
from datasets import load_metric
from transformers import Seq2SeqTrainingArguments, TFTrainer


In [104]:
training_args = Seq2SeqTrainingArguments(
    output_dir="small_transformer_trainer",
    evaluation_strategy="steps",
    # place_model_on_device=False,
)
metric = load_metric("accuracy")
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
trainer = TFTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

using `logging_steps` to initialize `eval_steps` to 500
PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
You are instantiating a Trainer but W&B is not installed. To use wandb logging, run `pip install wandb && wandb login` see https://docs.wandb.com/huggingface.
To use comet_ml logging, run `pip/conda install comet_ml` see https://www.comet.ml/docs/python-sdk/huggingface/


In [106]:
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(test_dataset)


ValueError: in user code:

    File "/opt/homebrew/Caskroom/miniforge/base/envs/huggingface/lib/python3.10/site-packages/keras/engine/training.py", line 878, in train_function  *
        return step_function(self, iterator)
    File "/opt/homebrew/Caskroom/miniforge/base/envs/huggingface/lib/python3.10/site-packages/keras/engine/training.py", line 867, in step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "/opt/homebrew/Caskroom/miniforge/base/envs/huggingface/lib/python3.10/site-packages/keras/engine/training.py", line 860, in run_step  **
        outputs = model.train_step(data)
    File "/Users/hp/github/classes/cosc525/Implementation-Of-A-Lightweight-Transformer-And-Analysis-Of-Text-Generation-Sampling-Techniques/submodules/transformers/src/transformers/modeling_tf_utils.py", line 1009, in train_step
        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "/opt/homebrew/Caskroom/miniforge/base/envs/huggingface/lib/python3.10/site-packages/keras/optimizer_v2/optimizer_v2.py", line 532, in minimize
        return self.apply_gradients(grads_and_vars, name=name)
    File "/opt/homebrew/Caskroom/miniforge/base/envs/huggingface/lib/python3.10/site-packages/keras/optimizer_v2/optimizer_v2.py", line 633, in apply_gradients
        grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    File "/opt/homebrew/Caskroom/miniforge/base/envs/huggingface/lib/python3.10/site-packages/keras/optimizer_v2/utils.py", line 73, in filter_empty_gradients
        raise ValueError(f"No gradients provided for any variable: {variable}. "

    ValueError: No gradients provided for any variable: (['tf_small_transformer_lm_head_model/transformer/wpe/embeddings:0', 'tf_small_transformer_lm_head_model/transformer/wte/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_1/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_1/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_attn/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_attn/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_2/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_2/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_fc/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_fc/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_1/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_1/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_attn/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_attn/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_2/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_2/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_fc/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_fc/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_1/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_1/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_attn/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_attn/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_2/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_2/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_fc/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_fc/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_1/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_1/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_attn/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_attn/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_2/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_2/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_fc/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_fc/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_1/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_1/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_attn/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_attn/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_2/gamma:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_2/beta:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_fc/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_fc/bias:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_proj/weight:0', 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_proj/bias:0', 'tf_small_transformer_lm_head_model/transformer/ln_f/gamma:0', 'tf_small_transformer_lm_head_model/transformer/ln_f/beta:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/wpe/embeddings:0' shape=(512, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/wte/weight:0' shape=(99, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_1/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_1/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_attn/weight:0' shape=(32, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_attn/bias:0' shape=(1, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_proj/weight:0' shape=(32, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/attn/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_2/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/ln_2/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_fc/weight:0' shape=(32, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_fc/bias:0' shape=(1, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_proj/weight:0' shape=(128, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._0/mlp/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_1/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_1/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_attn/weight:0' shape=(32, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_attn/bias:0' shape=(1, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_proj/weight:0' shape=(32, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/attn/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_2/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/ln_2/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_fc/weight:0' shape=(32, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_fc/bias:0' shape=(1, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_proj/weight:0' shape=(128, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._1/mlp/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_1/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_1/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_attn/weight:0' shape=(32, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_attn/bias:0' shape=(1, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_proj/weight:0' shape=(32, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/attn/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_2/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/ln_2/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_fc/weight:0' shape=(32, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_fc/bias:0' shape=(1, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_proj/weight:0' shape=(128, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._2/mlp/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_1/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_1/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_attn/weight:0' shape=(32, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_attn/bias:0' shape=(1, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_proj/weight:0' shape=(32, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/attn/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_2/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/ln_2/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_fc/weight:0' shape=(32, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_fc/bias:0' shape=(1, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_proj/weight:0' shape=(128, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._3/mlp/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_1/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_1/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_attn/weight:0' shape=(32, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_attn/bias:0' shape=(1, 96) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_proj/weight:0' shape=(32, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/attn/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_2/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/ln_2/beta:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_fc/weight:0' shape=(32, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_fc/bias:0' shape=(1, 128) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_proj/weight:0' shape=(128, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/h_._4/mlp/c_proj/bias:0' shape=(1, 32) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/ln_f/gamma:0' shape=(32,) dtype=float32>), (None, <tf.Variable 'tf_small_transformer_lm_head_model/transformer/ln_f/beta:0' shape=(32,) dtype=float32>)).


TypeError: 'PrefetchDataset' object is not subscriptable

----------------------------------------------------------------------------------------------------
0: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.

I'm not sure if I'll
