# Requirements

In [2]:
!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.8/67.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m38.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.8/413.8 kB[0m [31m40.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.9/138.9 kB[0m [31m17.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.7/49.7 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m14.2 MB/s[0m eta 

# Import Libraries

In [3]:
import py7zr
import time

import keras_nlp
import keras
import tensorflow as tf
import tensorflow_datasets as tfds

# Loading Dataset

In [4]:
# Download the dataset.
filename = keras.utils.get_file(
    "corpus.7z",
    origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
)

# Extract the `.7z` file.
with py7zr.SevenZipFile(filename, mode="r") as z:
    z.extractall(path="/root/tensorflow_datasets/downloads/manual")

# Load data using TFDS.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)

Downloading data from https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z
[1m2944100/2944100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...


Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/14732 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/incomplete.0THOPB_1.0.0/samsum-train.tfrecord*...:   0%|          |…

Generating validation examples...:   0%|          | 0/818 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/incomplete.0THOPB_1.0.0/samsum-validation.tfrecord*...:   0%|      …

Generating test examples...:   0%|          | 0/819 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/incomplete.0THOPB_1.0.0/samsum-test.tfrecord*...:   0%|          | …

Dataset samsum downloaded and prepared to /root/tensorflow_datasets/samsum/1.0.0. Subsequent calls will reuse this data.


In [5]:
for dialogue, summary in samsum_ds:
    print(dialogue.numpy())
    print(summary.numpy())
    break

b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'


In [6]:
for dialogue, summary in samsum_ds:
    print(dialogue.numpy())
    print(summary.numpy())
    break

b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'


# Define Hyperparameters

In [7]:
BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 10  # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40

In [8]:
train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    .cache()
)
train_ds = train_ds.take(NUM_BATCHES)

# Loading BART Model

In [9]:
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

bart_lm.summary()

Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/model.safetensors...
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/model.safetensors.index.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/metadata.json...
100%|██████████| 141/141 [00:00<00:00, 398kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/preprocessor.json...
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/tokenizer.json...
100%|██████████| 448/448 [00:00<00:00, 311kB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/assets/tokenizer/vocabulary.json...
100%|██████████| 0.99M/0.99M [00:00<00:00, 2.61MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/assets/tokenizer/merges.txt...
100%|█████

# Fine-tune BART

In [10]:
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

In [11]:
bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)

In [12]:
bart_lm.fit(train_ds, epochs=EPOCHS)

Epoch 1/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m618s[0m 617ms/step - accuracy: 0.5520 - loss: 0.4338
Epoch 2/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m373s[0m 617ms/step - accuracy: 0.6367 - loss: 0.3176
Epoch 3/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m373s[0m 618ms/step - accuracy: 0.6976 - loss: 0.2482
Epoch 4/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m382s[0m 617ms/step - accuracy: 0.7502 - loss: 0.1940
Epoch 5/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m373s[0m 617ms/step - accuracy: 0.7974 - loss: 0.1508
Epoch 6/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m378s[0m 626ms/step - accuracy: 0.8348 - loss: 0.1194
Epoch 7/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m372s[0m 616ms/step - accuracy: 0.8680 - loss: 0.0924
Epoch 8/10
[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m382s[0m 617ms/step - accuracy: 0.8950 - loss: 0.0737
Epoch 9/

<keras.src.callbacks.history.History at 0x7e1e00363100>

# Generate summaries and evaluate them!

In [13]:
def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    print(f"Total Time Elapsed: {end - start:.2f}s")
    return output


# Load the dataset.
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)

dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
    dialogues.append(dialogue.numpy())
    ground_truth_summaries.append(summary.numpy())

# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)

# Generate summaries.
generated_summaries = generate_text(
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)

Total Time Elapsed: 38.58s
Total Time Elapsed: 92.76s


In [14]:
for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:10], generated_summaries[:10], ground_truth_summaries[:10]
):
    print("Dialogue:", dialogue)
    print("Generated Summary:", generated_summary)
    print("Ground Truth Summary:", ground_truth_summary)
    print("=============================")

Dialogue: b'Tony: Is the boss in?\r\nClaire: Not yet.\r\nTony: Could let me know when he comes, please? \r\nClaire: Of course.\r\nTony: Thank you.'
Generated Summary: Tony will let Claire know when his boss comes.
Ground Truth Summary: b"The boss isn't in yet. Claire will let Tony know when he comes."
Dialogue: b"James: What shouldl I get her?\r\nTim: who?\r\nJames: gees Mary my girlfirend\r\nTim: Am I really the person you should be asking?\r\nJames: oh come on it's her birthday on Sat\r\nTim: ask Sandy\r\nTim: I honestly am not the right person to ask this\r\nJames: ugh fine!"
Generated Summary: James is waiting  for Mary's birthday on Saturday or for Tim to ask Sandy.
Ground Truth Summary: b"Mary's birthday is on Saturday. Her boyfriend, James, is looking for gift ideas. Tim suggests that he ask Sandy."
Dialogue: b"Mary: So, how's Israel? Have you been on the beach?\r\nKate: It's so expensive! But they say, it's Tel Aviv... Tomorrow we are going to Jerusalem.\r\nMary: I've heard Isr