In [None]:
import tensorflow.compat.v2 as tf
tf.enable_v2_behavior()
import multiprocessing.dummy as multiprocessing
from multiprocessing import Lock
from tqdm import tqdm
from nltk import pos_tag
import json


import numpy as np

import tensorflow as tf
import tensorflow_hub as hub

In [None]:
# Fixed seed so that the index sampling done in the main thread is repeatable
# after a kernel restart. The generator is also used from the worker threads
# of `pool`, where the order of draws may still interleave differently
# between runs.
SEED = 0
rng = np.random.default_rng(SEED)

llm_model = None
# TO-DO: Load the model using the instructions in the link below.
# Instructions for using PaLM 2 Bison Chat: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/chat-bison?inv=1&invt=AbytLg

num_replicas = 1 # Parallelization recommended
batch_size_per_replica = 16
# One worker thread per in-flight LLM request.
num_threads = batch_size_per_replica * num_replicas
# Upper bound on LLM calls per summary in the word-replacement generator.
MAX_RETRIES = 4

# `multiprocessing` here is `multiprocessing.dummy`, i.e. a thread pool.
pool = multiprocessing.Pool(num_threads)

data_dir = '/data/preprocessed_wikicatsum' # Add the correct path (eg: path of the file '{data_dir}/{domain}/fake_sentence_replacement_{split}.tgt.json')

In [None]:
# Instruction appended after the paragraph in the next-sentence prompt.
PROMPT_SENTENCE_REPLACEMENT = '\n Given this paragraph, predict the next sentence and only print the predicted sentence. Print only the predicted sentence.'
def prompt_sentence_replacement(left_context):
  """Build the next-sentence prediction prompt.

  Args:
    left_context: text of the paragraph preceding the sentence to predict.

  Returns:
    The prompt string: a fixed preamble, `left_context`, then
    `PROMPT_SENTENCE_REPLACEMENT`.
  """
  prompt_parts = (
      'You are provided with the following paragraph: \n',
      left_context,
      PROMPT_SENTENCE_REPLACEMENT,
  )
  return ''.join(prompt_parts)

# Instruction appended after the masked paragraph in the mask-filling prompt.
PROMPT_WORD_REPLACEMENT = """\n In this paragraph, certain words have been masked. They are indicated as '[MASK]'.
Please predict the masked words to complete the paragraph. Each '[MASK]' should be replaced with a single word.
Print the paragraph after replacing all the "[MASK]" with their respective predicted words. No other text should be printed. Print only the completed paragraph.
Make sure no '[MASK]' remains in the output."""
def prompt_word_replacement(masked_summary):
  """Build the mask-filling prompt.

  Args:
    masked_summary: paragraph text in which some words are '[MASK]'.

  Returns:
    The prompt string: a fixed preamble, `masked_summary`, then
    `PROMPT_WORD_REPLACEMENT`.
  """
  prompt_parts = (
      'You are provided with the following paragraph: \n',
      masked_summary,
      PROMPT_WORD_REPLACEMENT,
  )
  return ''.join(prompt_parts)

In [None]:
# TO-DO: Add function to generate text after passing the prompt to the llm_model

def generate_word_replacement_summary(reference_summary_sentence_split):
  """Create a fake summary by masking content words and letting the LLM refill them.

  Tokens whose POS tag is in `replace_pos` are candidates; a random subset is
  replaced with '[MASK]' and the masked text is sent to `llm_model`. The call
  is repeated (at most `MAX_RETRIES` times) while the response still contains
  '[MASK]', has a different number of sentences than the reference, or does
  not differ from the reference in any sentence.

  Args:
    reference_summary_sentence_split: the reference summary, either a list of
      sentences or a single string with sentences separated by ' <SNT> '.

  Returns:
    A tuple `(predicted_summary, sent_idx_replaced)`:
      predicted_summary: list of sentences from the last LLM response, with
        any leftover '[MASK]' removed. If the summary has no maskable token
        the reference sentences are returned unchanged.
      sent_idx_replaced: ascending indices of sentences whose lower-cased
        tokens differ from the reference. Empty when the last response had a
        different sentence count than the reference, or nothing was masked.
  """
  # Reference summary is already split into sentences
  if isinstance(reference_summary_sentence_split, str):
    reference_summary_sentence_split = reference_summary_sentence_split.split(' <SNT> ')
  reference_summary = ' <SNT> '.join(reference_summary_sentence_split)
  reference_summary_word_split = reference_summary.split()
  tokens_tag = pos_tag(reference_summary_word_split)
  # 'CD' is the Penn Treebank cardinal-number tag; the list previously
  # contained 'CB', which is not a Penn Treebank tag and never matched.
  replace_pos = {'VERB', 'NOUN', 'PROPN', 'NUM', 'VB', 'VBG', 'VBD', 'VBN', 'VBP', 'VBZ', 'NN', 'NNS', 'NNP', 'NNPS', 'CD'}
  # Positions of maskable tokens; the '<SNT>' separator is never masked.
  token_idx = [
      t_id for t_id, (_, tag) in enumerate(tokens_tag)
      if reference_summary_word_split[t_id] != '<SNT>' and tag in replace_pos
  ]
  if not token_idx:
    # Nothing can be masked: rng.choice would raise on an empty population.
    return list(reference_summary_sentence_split), []

  if len(token_idx) >= 5:
    # Mask between 1 and 20% of the candidate tokens (inclusive).
    num_words_to_replace = rng.integers(1, len(token_idx)//5, endpoint = True)
  else:
    num_words_to_replace = 1
  word_idx_to_replace = np.sort(rng.choice(token_idx, num_words_to_replace, replace = False))
  for word_id in word_idx_to_replace:
    reference_summary_word_split[word_id] = '[MASK]'

  # Single-spaced masked text; sentences stay separated by ' <SNT> '.
  predicted = ' '.join(reference_summary_word_split)
  predicted_summary = predicted.split(' <SNT> ')

  sent_idx_replaced = []
  retries = 0
  needs_retry = True
  while needs_retry and retries < MAX_RETRIES:
    retries += 1
    # NOTE(review): every retry re-prompts with the previous response rather
    # than the original masked text - confirm this is intended when the
    # previous response had the wrong sentence count.
    prompt = prompt_word_replacement(predicted)
    predicted, _ = llm_model.Generate(prompt)[0] # Add the function for generation using LLM
    predicted_summary = predicted.split(' <SNT> ')
    # Evaluated afresh on every attempt so an earlier mismatch does not force
    # all remaining retries.
    sentence_count_mismatch = len(predicted_summary) != len(reference_summary_sentence_split)
    if sentence_count_mismatch:
      sent_idx_replaced = []
    else:
      sent_idx_replaced = [
          p for p, (predicted_sentence, reference_sentence)
          in enumerate(zip(predicted_summary, reference_summary_sentence_split))
          if predicted_sentence.lower().split() != reference_sentence.lower().split()
      ]
    needs_retry = '[MASK]' in predicted or not sent_idx_replaced or sentence_count_mismatch

  if '[MASK]' in predicted:
    # Retries exhausted with masks left over: drop them from the returned
    # sentences (str.replace returns a new string, so the result must be kept).
    predicted_summary = [sentence.replace('[MASK]', '') for sentence in predicted_summary]

  return predicted_summary, sent_idx_replaced

# TO-DO: Add function to generate text after passing the prompt to the llm_model

def generate_sentence_replacement_summary(reference_summary):
  """Create a fake summary by replacing whole sentences with LLM predictions.

  Between 1 and `num_sentences // 2` sentences (never the first one) are
  chosen at random. In ascending order, each chosen sentence is replaced by
  the LLM's prediction of the next sentence given all preceding sentences of
  the summary built so far.

  Args:
    reference_summary: the reference summary, either a list of sentences or
      a single string with sentences separated by ' <SNT> '.

  Returns:
    A tuple `(fake_summary, sent_idx_replaced)`: the list of sentences after
    replacement and the ascending list of replaced sentence indices. A
    summary with fewer than two sentences is returned unchanged with an
    empty index list.
  """
  if isinstance(reference_summary, str):
    reference_summary = reference_summary.split(' <SNT> ')
  # Reference summary is already split into sentences
  num_sentences = len(reference_summary)
  # Work on a copy so a list passed in by the caller is not modified.
  fake_summary = list(reference_summary)
  if num_sentences < 2:
    return fake_summary, []

  num_sentences_replace = rng.integers(1, num_sentences//2, endpoint = True)
  # Index 0 is excluded: the first sentence has no left context to predict from.
  sent_idx_replaced = np.sort(rng.choice(np.arange(1, num_sentences), num_sentences_replace, replace = False))
  for s in sent_idx_replaced:
    left_context = ' '.join(fake_summary[:s])
    prompt = prompt_sentence_replacement(left_context)
    predicted_sentence, _ = llm_model.Generate(prompt)[0] # Add the function for generation using LLM
    fake_summary[s] = predicted_sentence
  return fake_summary, sent_idx_replaced.tolist()

In [None]:
def sample_idx(len_ds, num_split_sample_list):
  """Draw three disjoint, sorted index samples from `range(len_ds)`.

  Args:
    len_ds: number of examples in the dataset.
    num_split_sample_list: sample sizes, read at positions 0, 1 and 2 for the
      positive, sentence-replacement and word-replacement groups.

  Returns:
    A tuple of three lists of indices, in the same order as the sizes.
  """
  remaining = np.arange(len_ds)
  groups = []
  for group_no in range(3):
    chosen = np.sort(rng.choice(remaining, num_split_sample_list[group_no], replace = False))
    groups.append(chosen.tolist())
    if group_no < 2:
      # Remove what was just drawn so the next group cannot reuse it.
      remaining = np.setdiff1d(remaining, chosen)
  return groups[0], groups[1], groups[2]

def chunker(iterable, chunk_size, fill=None):
  """Return a generator over consecutive slices of `iterable`.

  Each slice holds `chunk_size` items except possibly the last, which holds
  the remainder. `fill` is accepted but not used.
  """
  # Computed up front so an invalid chunk_size fails at call time.
  starts = range(0, len(iterable), chunk_size)

  def _slices():
    for begin in starts:
      yield iterable[begin:begin + chunk_size]

  return _slices()

In [None]:
import json
from tqdm import tqdm

domains = ['film', 'company', 'animal']
splits = ['test','valid','train']

# Per-split sample counts, in the order
# [num_split_pos_sample, num_split_sentence_replacement_sample, num_split_word_replacement_sample]
num_sample = {
  'train':[7500, 5500, 2000],
  'test': [400, 250, 150],
  'valid': [400, 250, 150]
}

def _map_in_chunks(generate_fn, summaries):
  """Apply `generate_fn` to `summaries` on the thread pool, `num_threads` items at a time."""
  # Integer ceiling so tqdm receives a whole number of chunks as its total.
  num_chunks = (len(summaries) + num_threads - 1) // num_threads
  results = []
  for reference_summaries in tqdm(chunker(summaries, chunk_size=num_threads), total=num_chunks):
    results += pool.map(generate_fn, reference_summaries)
  return results

for split in splits:
  for domain in domains:
    print('domain:', domain, 'split:', split)
    sentence_output_file = f'{data_dir}/{domain}/fake_sentence_replacement_{split}.tgt.json'
    word_output_file = f'{data_dir}/{domain}/fake_word_replacement_{split}.tgt.json'
    pos_output_file = f'{data_dir}/{domain}/pos_{split}.tgt.json'
    with open(f'{data_dir}/{domain}/{split}.tgt' , 'r') as f:
      # Drop the line terminator so it does not end up inside the last
      # sentence of positive / sentence-replacement summaries (word-replacement
      # summaries are rebuilt from whitespace-split tokens and never had it).
      f_readlines = [line.rstrip('\n') for line in f]

    pos_sample_idx, sentence_replacement_idx, word_replacement_idx = sample_idx(len(f_readlines), num_sample[split])
    pos_indexed_summaries = [f_readlines[i] for i in pos_sample_idx]
    sentence_replacement_indexed_summaries = [f_readlines[i] for i in sentence_replacement_idx]
    word_replacement_indexed_summaries = [f_readlines[i] for i in word_replacement_idx]

    # Positive samples: untouched reference summaries.
    pos_summaries = [{'ds_sample_id': ind, 'summary':reference_summary, 'sent_idx_replaced': []}
                     for ind, reference_summary in zip(pos_sample_idx, pos_indexed_summaries)]

    sent_results = _map_in_chunks(generate_sentence_replacement_summary, sentence_replacement_indexed_summaries)
    sentence_summaries = [{'ds_sample_id': int(ind), 'summary':' <SNT> '.join(fake_summary), 'sent_idx_replaced': sent_idx_replaced}
                          for ind, (fake_summary, sent_idx_replaced) in zip(sentence_replacement_idx, sent_results)]

    word_results = _map_in_chunks(generate_word_replacement_summary, word_replacement_indexed_summaries)
    word_summaries = [{'ds_sample_id': int(ind), 'summary':' <SNT> '.join(fake_summary), 'sent_idx_replaced': sent_idx_replaced}
                      for ind, (fake_summary, sent_idx_replaced) in zip(word_replacement_idx, word_results)]

    # Output files are opened only after generation has finished, so a failure
    # during the LLM calls does not leave truncated JSON files behind.
    for output_file, summaries in ((sentence_output_file, sentence_summaries),
                                   (word_output_file, word_summaries),
                                   (pos_output_file, pos_summaries)):
      with open(output_file, 'w') as out:
        json.dump(summaries, out)


In [None]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wrap one string / bytes value in a tf.train.Feature holding a BytesList."""
  eager_tensor_type = type(tf.constant(0))
  if isinstance(value, eager_tensor_type):
    # BytesList cannot unpack a string from an EagerTensor; take its numpy value.
    value = value.numpy()
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wrap one float / double value in a tf.train.Feature holding a FloatList."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap one bool / enum / int / uint value in a tf.train.Feature holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)
def _float_array_feature(value):
  """Wrap a sequence of floats in a tf.train.Feature holding a FloatList.

  Unlike `_float_feature`, `value` is passed through as the whole list rather
  than being wrapped in a one-element list.
  """
  float_list = tf.train.FloatList(value=value)
  return tf.train.Feature(float_list=float_list)

@tf.py_function(Tout=tf.string)
def serialize_example(idx, documents, summary, aggregate_label, instance_labels):
  """Serialize one sample as a tf.train.Example message.

  Args:
    idx: sample id, stored as an int64 feature under 'id'.
    documents: source text, stored as a bytes feature under 'documents'.
    summary: summary text, stored as a bytes feature under 'summary'.
    aggregate_label: stored as a single float under 'aggregate_label'.
    instance_labels: stored as a float list under 'instance_labels'.

  Returns:
    The serialized tf.train.Example.
  """
  features = tf.train.Features(feature={
      'id': _int64_feature(idx),
      'documents': _bytes_feature(documents),
      'summary': _bytes_feature(summary),
      'aggregate_label': _float_feature(aggregate_label),
      'instance_labels': _float_array_feature(instance_labels),
  })
  return tf.train.Example(features=features).SerializeToString()

In [None]:
# Write the synthetic dataset to one TFRecord file per split.
domains = ['animal', 'company', 'film']
splits = ['test','valid','train']

# Writing TFRecord file
tfrecord_filenames = {'train':'train.tfrecord', 'test':'test.tfrecord', 'valid':'valid.tfrecord'}

# Summary types whose samples are fakes; 'pos' samples are untouched references.
fake_summary_types = ('sentence', 'word')

# Write the `tf.train.Example` observations to the file.
for split in tqdm(splits):
  filename_output = f'{data_dir}/{tfrecord_filenames[split]}'
  with tf.io.TFRecordWriter(filename_output) as writer:
    for domain in tqdm(domains):
      print('domain:', domain, 'split:', split)
      sample_files = {
          'sentence': f'{data_dir}/{domain}/fake_sentence_replacement_{split}.tgt.json',
          'word': f'{data_dir}/{domain}/fake_word_replacement_{split}.tgt.json',
          'pos': f'{data_dir}/{domain}/pos_{split}.tgt.json',
      }
      document_file = f'{data_dir}/{domain}/{split}.src'
      with open(document_file , 'r') as df:
        df_readlines = df.readlines()

      # Load all three sample files up front, closing each handle right away.
      summary_types_dict = {}
      for summary_type, samples_file in sample_files.items():
        with open(samples_file, 'r') as samples_fh:
          summary_types_dict[summary_type] = json.load(samples_fh)

      for summary_type, samples in summary_types_dict.items():
        for sample in samples:
          replaced_idx = sample['sent_idx_replaced']
          # One label per sentence: 1 = kept from the reference, 0 = replaced.
          instance_labels = np.ones(len(sample['summary'].split('<SNT>')))
          if len(replaced_idx) > 0:
            instance_labels[np.array(replaced_idx)] = 0
          # A sample is labelled 0 only if it is a fake type AND at least one
          # sentence was recorded as replaced; otherwise it is labelled 1.
          if summary_type in fake_summary_types and len(replaced_idx) > 0:
            agg_label = 0
          else:
            agg_label = 1
          # 'ds_sample_id' is used as the line number into the .src file.
          example = serialize_example(sample['ds_sample_id'], df_readlines[sample['ds_sample_id']], sample['summary'], agg_label, instance_labels)
          writer.write(example.numpy())


# sentence_output_file = f'{data_dir}/{domain}/fake_sentence_replacement_{split}.tgt.json'
# word_output_file = f'{data_dir}/{domain}/fake_word_replacement_{split}.tgt.json'
# pos_output_file = f'{data_dir}/{domain}/pos_{split}.tgt.json'

In [None]:
# Parsing spec for the serialized tf.train.Example records: four scalar
# features with defaults, plus a variable-length float list.
feature_description = dict(
  id=tf.io.FixedLenFeature([], tf.int64, default_value=0),
  documents=tf.io.FixedLenFeature([], tf.string, default_value=''),
  summary=tf.io.FixedLenFeature([], tf.string, default_value=''),
  aggregate_label=tf.io.FixedLenFeature([], tf.float32, default_value=0.0),
  instance_labels=tf.io.RaggedFeature(tf.float32),
)

# NOTE(review): this repeats the `_bytes_feature` helper defined in an earlier
# cell and is not called anywhere in this cell.
def _bytes_feature(value):
  """Wrap one string / bytes value in a tf.train.Feature holding a BytesList."""
  eager_tensor_type = type(tf.constant(0))
  if isinstance(value, eager_tensor_type):
    # BytesList cannot unpack a string from an EagerTensor; take its numpy value.
    value = value.numpy()
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

def _parse_function(example_proto):
  """Decode one serialized tf.train.Example using `feature_description`."""
  parsed_features = tf.io.parse_single_example(example_proto, feature_description)
  return parsed_features

def process_dataset(dataset_path, encoder_type = 'sentence-t5', aggregation_type = 'min'):
  """Embed a TFRecord split with a sentence encoder and save it as a tf.data dataset.

  For every record, the summary is split on '<SNT>' into sentences and the
  document into chunks of 1000 whitespace-separated words. All sentences and
  chunks are embedded in batches of 500, regrouped per record as ragged
  tensors, shuffled, and saved to `{dataset_path}/{encoder_type}_{split}`.

  Args:
    dataset_path: directory containing `train.tfrecord`; the processed
      dataset is saved under the same directory.
    encoder_type: name of the encoder; only 'sentence-t5' is supported.
    aggregation_type: currently unused by this function.

  Raises:
    ValueError: if `encoder_type` is not supported.
  """
  if encoder_type != 'sentence-t5':
    # Fail before reading any data instead of calling a `None` encoder later.
    raise ValueError(f"Unsupported encoder_type {encoder_type!r}; expected 'sentence-t5'.")
  sentence_t5_url = "https://tfhub.dev/google/sentence-t5/st5-large/1"
  encoder = hub.KerasLayer(sentence_t5_url)

  words_per_document_split = 1000  # Words per document chunk fed to the encoder.
  encoder_batch_size = 500  # Texts per encoder call.

  def _embed(texts):
    """Encode `texts` batch by batch and concatenate the first encoder output."""
    return tf.concat(
        [encoder(tf.constant(texts[x:x + encoder_batch_size]))[0]
         for x in range(0, len(texts), encoder_batch_size)],
        axis=0)

  # tfrecord_filenames = {'train':f'{dataset_path}/train.tfrecord', 'test':f'{dataset_path}/test.tfrecord', 'valid':f'{dataset_path}/valid.tfrecord'}
  tfrecord_filenames = {'train':f'{dataset_path}/train.tfrecord'}
  for split_name, tfrecord_path in tfrecord_filenames.items():
    parsed_dataset = tf.data.TFRecordDataset(tfrecord_path).map(_parse_function)
    num_document_splits = [] # Number of splits for a document
    num_summary_sentences = [] # Number of instances in a bag
    document_splits = [] # Document splits for the entire dataset
    summary_sentences = [] # Summary sentences for the entire dataset
    target_agg_score = []
    target_instance_score = []

    for p in tqdm(parsed_dataset):
      sample_summary_sentences = p['summary'].numpy().decode("utf-8").split('<SNT>')
      num_summary_sentences.append(len(sample_summary_sentences))
      summary_sentences += sample_summary_sentences

      document = p['documents'].numpy().decode("utf-8").split()
      document_split_sample = [' '.join(document[x:x + words_per_document_split])
                               for x in range(0, len(document), words_per_document_split)]
      num_document_splits.append(len(document_split_sample))
      document_splits += document_split_sample

      sample_instance_score = list(p['instance_labels'].numpy())
      # There must be exactly one instance label per summary sentence.
      assert len(sample_instance_score) == len(sample_summary_sentences)
      target_instance_score += sample_instance_score

      sample_agg_score = p['aggregate_label'].numpy()
      target_agg_score.append(sample_agg_score)

    # Embed the flat lists, then regroup rows per record using the recorded counts.
    document_splits_embedding = tf.RaggedTensor.from_row_lengths(_embed(document_splits), num_document_splits)
    summary_sentences_embedding = tf.RaggedTensor.from_row_lengths(_embed(summary_sentences), num_summary_sentences)
    target_instance_score = tf.RaggedTensor.from_row_lengths(target_instance_score, num_summary_sentences)
    ds = tf.data.Dataset.from_tensor_slices((document_splits_embedding, summary_sentences_embedding, num_document_splits, num_summary_sentences, target_agg_score, target_instance_score))
    ds = ds.shuffle(len(ds))
    # Save next to the input records rather than under the module-level `data_dir`.
    ds.save(f'{dataset_path}/{encoder_type}_{split_name}')
    


# Embed the TFRecords found under `data_dir` with Sentence-T5.
process_dataset(
    dataset_path=data_dir,
    encoder_type='sentence-t5',
    aggregation_type='min',
)