In [115]:
# ! pip install transformers datasets

In [116]:
from datasets import load_dataset

dataset = load_dataset("../cnn_dailymail_suenes/cnn_dailymail_suenes.py")
dataset

No config specified, defaulting to: cnn_dailymail_suenes/1.0.0
Found cached dataset cnn_dailymail_suenes (/home/jobayer/.cache/huggingface/datasets/cnn_dailymail_suenes/1.0.0/1.0.0/066f2a33d28679b436a34f9ac68a5860f9a0a9ebfe420002c59c849e6dde9337)
100%|██████████| 3/3 [00:00<00:00, 269.39it/s]


DatasetDict({
    train: Dataset({
        features: ['text', 'summary', 'score'],
        num_rows: 138502
    })
    validation: Dataset({
        features: ['text', 'summary', 'score'],
        num_rows: 17917
    })
    test: Dataset({
        features: ['text', 'summary', 'score'],
        num_rows: 17722
    })
})

In [117]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny", model_max_length=512)

def tokenize_function(examples):
    return tokenizer(examples["text"], examples["summary"], padding=True, truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

Loading cached processed dataset at /home/jobayer/.cache/huggingface/datasets/cnn_dailymail_suenes/1.0.0/1.0.0/066f2a33d28679b436a34f9ac68a5860f9a0a9ebfe420002c59c849e6dde9337/cache-ed6da946a25afc4e.arrow
Loading cached processed dataset at /home/jobayer/.cache/huggingface/datasets/cnn_dailymail_suenes/1.0.0/1.0.0/066f2a33d28679b436a34f9ac68a5860f9a0a9ebfe420002c59c849e6dde9337/cache-1c6ed230569349fc.arrow
Loading cached processed dataset at /home/jobayer/.cache/huggingface/datasets/cnn_dailymail_suenes/1.0.0/1.0.0/066f2a33d28679b436a34f9ac68a5860f9a0a9ebfe420002c59c849e6dde9337/cache-7911509133935c24.arrow


In [118]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["validation"].shuffle(seed=42).select(range(1000))

Loading cached shuffled indices for dataset at /home/jobayer/.cache/huggingface/datasets/cnn_dailymail_suenes/1.0.0/1.0.0/066f2a33d28679b436a34f9ac68a5860f9a0a9ebfe420002c59c849e6dde9337/cache-88432b43caf41e32.arrow
Loading cached shuffled indices for dataset at /home/jobayer/.cache/huggingface/datasets/cnn_dailymail_suenes/1.0.0/1.0.0/066f2a33d28679b436a34f9ac68a5860f9a0a9ebfe420002c59c849e6dde9337/cache-62b101985c52c6c7.arrow


In [119]:
tf_train_dataset = tokenized_datasets["train"].to_tf_dataset(
# tf_train_dataset = small_train_dataset.to_tf_dataset(
    columns=["input_ids", "token_type_ids", "attention_mask"],
    label_cols=["score"],
    shuffle=True,
    batch_size=8)

tf_validation_dataset = tokenized_datasets["validation"].to_tf_dataset(
# tf_validation_dataset = small_eval_dataset.to_tf_dataset(
    columns=["input_ids", "token_type_ids", "attention_mask"],
    label_cols=["score"],
    shuffle=False,
    batch_size=8)

In [120]:
from transformers import TFAutoModelForSequenceClassification

model = TFAutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=1, from_pt=True)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertForSequenceClassification: ['bert.embeddings.position_ids']
- This IS expected if you are initializing TFBertForSequenceClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFBertForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [121]:
import tensorflow as tf

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
    loss=tf.keras.losses.MeanSquaredError())

In [122]:
model_checkpoint = "./tf_model_checkpoint"
checkpoint_path = model_checkpoint + "/checkpoints/cp-{epoch}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

callbacks = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq="epoch")

In [123]:
latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
if latest_checkpoint:
    model.load_weights(latest_checkpoint)

In [124]:
model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3, callbacks=callbacks)

tokenizer.save_pretrained(model_checkpoint)
model.save_pretrained(model_checkpoint)

Epoch 1/3
Epoch 1: saving model to ./tf_model_checkpoint/checkpoints/cp-1.ckpt
Epoch 2/3
Epoch 2: saving model to ./tf_model_checkpoint/checkpoints/cp-2.ckpt
Epoch 3/3
Epoch 3: saving model to ./tf_model_checkpoint/checkpoints/cp-3.ckpt


In [125]:
model.evaluate(tf_validation_dataset)



0.018575478345155716

In [126]:
model.predict(tf_validation_dataset)



TFSequenceClassifierOutput(loss=None, logits=array([[0.95998573],
       [0.89672196],
       [0.92698634],
       ...,
       [0.24205008],
       [0.5289012 ],
       [0.24337558]], dtype=float32), hidden_states=None, attentions=None)