In [None]:
# Mount Google Drive so the dataset pickle, fine-tuned checkpoints and the
# summaries output folder are reachable under /content/drive.
# The module is imported under an alias because the name `drive` is used
# later in this notebook for the Drive root *path string*; sharing one name
# for both makes the notebook break when this cell is re-run out of order.
from google.colab import drive as colab_drive
colab_drive.mount('/content/drive')

In [None]:
# !pip install sentencepiece
!pip install transformers==4.17 --quiet

In [None]:
from transformers import (
    T5Tokenizer,
    TFT5ForConditionalGeneration,
    PegasusTokenizer,
    TFPegasusForConditionalGeneration,
    logging,
    T5ForConditionalGeneration
)
# Only surface errors from the transformers library (hides its info/warning output).
logging.set_verbosity_error()

# Used to load the dataset split and to save the generated summaries.
import pickle

# Root folder of the mounted Google Drive; every data, checkpoint and output
# path in this notebook is built from it.
drive = '/content/drive/MyDrive'

In [None]:
# Load the pre-computed train / validation / test split from Drive.
# NOTE: pickle can execute arbitrary code while loading; only open files
# that you produced yourself.
split_path = f'{drive}/train_test_split.pkl'
with open(split_path, 'rb') as file:
    splits = pickle.load(file)

# The pickle holds six objects in this fixed order.
X_train, X_val, X_test, y_train, y_val, y_test = splits

In [None]:
def generate_summaries(model, tokenizer, X, y, batch_size, max_length=1024, max_length_output=150, min_length=0, num_beams=4):
  """Summarize the texts in ``X`` batch by batch and keep them paired with ``y``.

  Args:
    model: Object exposing ``generate(input_ids, attention_mask=..., ...)``.
    tokenizer: Callable tokenizer that also exposes ``batch_decode``; it is
      asked for TensorFlow tensors (``return_tensors="tf"``).
    X: Sliceable sequence of input texts.
    y: Sliceable sequence of reference labels, aligned with ``X`` by position.
      Extra trailing labels are ignored.
    batch_size: Number of texts processed per ``generate`` call (>= 1).
    max_length: Maximum number of *input* tokens; longer texts are truncated.
    max_length_output: Maximum length of each *generated* summary, passed to
      ``generate`` as its ``max_length``.
    min_length: Minimum length of each generated summary.
    num_beams: Beam width used for beam search.

  Returns:
    A list with one dict per input text, holding the keys ``input_text``,
    ``input_ids``, ``attention_mask``, ``summary_ids``, ``summaries`` and
    ``label``. Tensor values are converted to plain Python lists.

  Raises:
    ValueError: If ``batch_size`` is smaller than 1.
  """
  if batch_size < 1:
    raise ValueError(f"batch_size must be >= 1, got {batch_size}")

  all_data = []
  for start in range(0, len(X), batch_size):
    end = start + batch_size

    # Tokenizers expect a plain list of strings, so coerce whatever slice
    # type X produces.
    batch = list(X[start:end])

    # Slice the labels with the same bounds to keep them synchronized.
    batch_labels = y[start:end]

    # The input limit (max_length) is kept large because the source texts
    # are long.
    inputs = tokenizer(batch, max_length=max_length, padding=True, truncation=True, return_tensors="tf")

    # Bound the summary with max_length_output. Previously the input limit
    # was reused here together with a competing max_new_tokens value, which
    # left max_length_output unused and the output length ambiguous.
    summary_ids = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=max_length_output,
        min_length=min_length,
        num_beams=num_beams,
        early_stopping=True
    )

    summaries = tokenizer.batch_decode(summary_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # One record per example; zip keeps every field aligned by position.
    rows = zip(batch, inputs['input_ids'], inputs['attention_mask'], summary_ids, summaries, batch_labels)
    for input_text, input_ids, attention_mask, summary_id, summary, label in rows:
      all_data.append({
          "input_text": input_text,
          "input_ids": input_ids.numpy().tolist(),
          "attention_mask": attention_mask.numpy().tolist(),
          "summary_ids": summary_id.numpy().tolist(),
          "summaries": summary,
          "label": label
      })

  return all_data

In [None]:
# For full models only, since they were built with a different tensorflow version.
!pip install tensorflow==2.13.1

In [None]:
# Choose which model family and which checkpoint to load.
model_type = 't5'
# model_type = 'pegasus'

# model_version = "base"
# model_version = '10k'
# model_version = '100k'
model_version = 'full'
# model_version = '20Epochs'

# Reset first so that a failed selection never leaves a model from an
# earlier run of this cell in scope.
model = None
tokenizer = None

# Hub identifiers of the pretrained checkpoints. The tokenizer always comes
# from these, even when the model weights are a fine-tuned copy on Drive.
t5_model_name = 't5-base'
pegasus_model_name = 'google/pegasus-xsum'

# model_type -> (tokenizer class, model class, hub name)
MODEL_FAMILIES = {
  't5': (T5Tokenizer, TFT5ForConditionalGeneration, t5_model_name),
  'pegasus': (PegasusTokenizer, TFPegasusForConditionalGeneration, pegasus_model_name),
}

# model_version -> checkpoint folder relative to the Drive root.
# 'base' is absent on purpose: it loads the hub checkpoint instead.
FINE_TUNED_DIRS = {
  '10k': 'Training_10k/{model_type}_fine_tuned_10k',
  '100k': 'Training_100k/{model_type}_fine_tuned_100k',
  'full': 'Training_full/{model_type}_fine_tuned',
  '20Epochs': 'Training_20Epochs/{model_type}_fine_tuned_20Epochs',
}

# Fail loudly on a typo instead of silently continuing with model = None.
if model_type not in MODEL_FAMILIES:
  raise ValueError(f"Unknown model_type {model_type!r}; expected one of {sorted(MODEL_FAMILIES)}")
if model_version != 'base' and model_version not in FINE_TUNED_DIRS:
  raise ValueError(f"Unknown model_version {model_version!r}; expected 'base' or one of {sorted(FINE_TUNED_DIRS)}")

tokenizer_cls, model_cls, base_model_name = MODEL_FAMILIES[model_type]
tokenizer = tokenizer_cls.from_pretrained(base_model_name)

if model_version == 'base':
  model_source = base_model_name
else:
  model_source = f'{drive}/' + FINE_TUNED_DIRS[model_version].format(model_type=model_type)
model = model_cls.from_pretrained(model_source)

print(f"Tokenizer = {tokenizer}")
print(f"Model = {model}")
print(f"Model Version = {model_version}")
print(f"Model Type = {model_type}")

In [None]:
batch_size = 16
subset_size = 500

# Take the same leading slice of texts and labels so the two stay aligned by
# position, and turn both into plain lists (the tokenizer expects a list of
# strings). Assumes X_test / y_test are 1-D sequences -- TODO confirm.
inputs = list(X_test[:subset_size])
labels = list(y_test[:subset_size])

# Add prefix for T5 models.
if model_type == "t5":
  prefix = "summarize: "
  inputs = [f"{prefix}{content}" for content in inputs]

summaries = generate_summaries(model, tokenizer, inputs, labels, batch_size)

# Base-model results use a different file name pattern than fine-tuned ones.
if model_version == 'base':
  filename = f'{drive}/Summaries/{model_type}_{model_version}_summaries_{subset_size}_att.pkl'
else:
  filename = f'{drive}/Summaries/{model_type}_fine_tuned_{model_version}_summaries_{subset_size}.pkl'

# Persist the list of per-example records returned by generate_summaries.
with open(filename, 'wb') as f:
    pickle.dump(summaries, f)