<a href="https://colab.research.google.com/github/freddejn/summarization-transformer-cnn-dailymail/blob/master/generate_targets_and_inputs.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import warnings
!pip install -q -U tensor2tensor
import tensorflow as tf
from tensor2tensor.utils import trainer_lib, registry
from tensor2tensor import problems
from tensor2tensor.data_generators import problem
import numpy as np
import datetime as dt
from google.colab import auth
import os
# Enable TF1 eager execution so the dataset can be iterated directly with
# tfe.Iterator further down in the notebook.
tfe = tf.contrib.eager
tfe.enable_eager_execution()

# Authenticate the Colab user; presumably needed for the gcloud/gsutil commands
# and gs:// paths used below -- confirm.
auth.authenticate_user()
Modes = tf.estimator.ModeKeys  # Shorthand alias; not referenced in the visible cells.

# Google Cloud project and bucket that hold the generated problem data.
PROJECT_ID = 'transformer-233711'
!gcloud config set project {PROJECT_ID}
BUCKET = 'tensor2tensor-test-bucket'
DATA_DIR = f'gs://{BUCKET}/data'  # Passed to feature_encoders() and dataset().
PROBLEM_NAME = 'summarize_cnn_dailymail32k'  # Looked up with problems.problem().

In [0]:
%%time
# Generates decoded input and target text files for later evaluation.
import re
DATASET = 'eval'        # Split to read from: 'test' or 'eval'.
MAX_LENGTH = 512        # Max input length in encoded ids: truncate to it (trunk=True) or drop longer inputs (trunk=False).
MAX_TARGET_LENGTH = 128 # Examples whose target has more encoded ids than this are dropped.
MAX_TESTS = 10          # Number of samples randomly drawn from the collected examples.
trunk = True            # True: truncate over-long inputs; False: drop examples with over-long inputs.
all_inputs = []         # Decoded input texts, filled by the collection loop below.
all_targets = []        # Decoded target texts, index-aligned with all_inputs.

# Draws a reproducible random subset of an array; the sample size doubles as the seed
def get_random_samples(arr, num):
    """Return ``num`` distinct elements drawn at random from ``arr``.

    The global NumPy random state is re-seeded with ``num`` on every call, so
    repeated calls with the same ``num`` are reproducible. The notebook relies
    on this to pick matching positions from its equally long input and target
    lists.

    Args:
        arr: 1-D sequence to sample from (anything ``np.random.choice`` accepts).
        num: Number of elements to draw; also used as the random seed.

    Returns:
        A NumPy array containing the ``num`` sampled elements.

    Raises:
        ValueError: Propagated from ``np.random.choice`` when ``num`` exceeds
            the number of available elements.
    """
    sample_size = num
    np.random.seed(sample_size)
    chosen = np.random.choice(arr, size=sample_size, replace=False)
    return chosen

# Decoding input of integers to text for the t2t-decoder
def decode(integers, encoders, eos_id=1):
    """Decode a sequence of encoded ids into text, stopping at end-of-sequence.

    Size-1 dimensions are squeezed away, everything from the first ``eos_id``
    onwards is discarded, and the remaining ids are decoded with the "inputs"
    encoder.

    Args:
        integers: Array-like of ids (e.g. an eager tensor or NumPy array);
            expected to be one-dimensional after squeezing.
        encoders: Mapping of feature name to encoder. Only ``encoders["inputs"]``
            is used, also when the caller passes target ids -- presumably fine
            because inputs and targets share a vocabulary; verify for other
            problems.
        eos_id: Id that terminates the sequence. Defaults to 1, the reserved
            EOS id of tensor2tensor's text encoders.

    Returns:
        The decoded text as returned by the encoder's ``decode`` method.
    """
    # np.squeeze turns a single-id sequence into a 0-d array, which cannot be
    # iterated; np.atleast_1d restores one dimension so short sequences work.
    ids = np.atleast_1d(np.squeeze(integers)).tolist()
    if eos_id in ids:
        ids = ids[:ids.index(eos_id)]
    # Pass the plain list on: squeezing again would collapse a one-element
    # result back to a 0-d array and break the encoder's iteration.
    return encoders["inputs"].decode(ids)

# Save array to file, line by line
def save_arr(arr, filename):
    """Write every element of ``arr`` to ``filename``, one element per line.

    Any existing file is overwritten. Elements are formatted with ``str()``
    semantics and each is followed by a newline.

    Args:
        arr: Iterable of items to write (typically decoded text strings).
        filename: Path of the file to create or overwrite.
    """
    # Explicit UTF-8 so non-ASCII characters in the text are written the same
    # way regardless of the locale's default encoding.
    with open(filename, 'w', encoding='utf-8') as out_file:
        out_file.writelines(f'{txt}\n' for txt in arr)
            
# Load the problem and its encoders, then open the requested split.
# The iterator yields examples one by one.
summarize_problem = problems.problem(PROBLEM_NAME)
encoders = summarize_problem.feature_encoders(DATA_DIR)
if DATASET == 'test':
    eval_data_dir = f'gs://{BUCKET}/data_for_test'
    dataset_split = problem.DatasetSplit.TEST
elif DATASET == 'eval':
    eval_data_dir = f'gs://{BUCKET}/data_for_evaluation'
    dataset_split = problem.DatasetSplit.EVAL
else:
    # Fail here instead of with a NameError further down when the setting is mistyped.
    raise ValueError(f"DATASET must be 'test' or 'eval', got {DATASET!r}")
predict_data_examples = tfe.Iterator(summarize_problem.dataset(dataset_split, DATA_DIR))

# Walk through the whole split and collect decoded (input, target) pairs.
# Examples with a target longer than MAX_TARGET_LENGTH are always skipped.
# Over-long inputs are truncated to MAX_LENGTH when `trunk` is set, otherwise
# the whole example is skipped.
count = 0
for data in predict_data_examples:
    inputs, targets = data['inputs'], data['targets']
    if targets.shape[0] > MAX_TARGET_LENGTH:
        continue
    if trunk:
        inputs = inputs[:MAX_LENGTH]
    elif inputs.shape[0] > MAX_LENGTH:
        continue
    count += 1
    all_inputs.append(decode(inputs, encoders))
    all_targets.append(decode(targets, encoders))
    # Short progress preview of the pair that was just added.
    print(f'{count}.({all_inputs[count-1][0:20]}... ,{all_targets[count-1][0:10]}...)', end="\t")

        
# Randomize samples: draw the positions once and apply them to both lists so
# every input stays paired with its own target.
num_samples = min(MAX_TESTS, len(all_inputs))
if num_samples < MAX_TESTS:
    # Without the clamp the sampling would raise only after the full dataset pass.
    print(f'\nWarning: only {num_samples} of the requested {MAX_TESTS} samples are available.')
if num_samples:
    sample_positions = get_random_samples(np.arange(len(all_inputs)), num_samples)
else:
    sample_positions = []
all_inputs = [all_inputs[i] for i in sample_positions]
all_targets = [all_targets[i] for i in sample_positions]

# Print out how many samples generated
print(f'\nNum Inputs: {len(all_inputs)}\nNum Targets: {len(all_targets)}')

# Store locally before copying to gcs bucket
save_arr(all_inputs, 'all_inputs.txt')
save_arr(all_targets, 'all_targets.txt')

# Copy the locally saved inputs and targets to the GCS directory chosen for the
# selected split. Each file holds one example per line; the file name encodes
# the mode ('trunk' = truncated inputs, 'len' = length-filtered inputs),
# MAX_LENGTH and MAX_TESTS.
if trunk:
    !gsutil cp 'all_inputs.txt' '{eval_data_dir}/trunk_{MAX_LENGTH}_num_{MAX_TESTS}_inputs.txt'
    !gsutil cp 'all_targets.txt' '{eval_data_dir}/trunk_{MAX_LENGTH}_num_{MAX_TESTS}_targets.txt'
    
else:
    !gsutil cp 'all_inputs.txt' '{eval_data_dir}/len_{MAX_LENGTH}_num_{MAX_TESTS}_inputs.txt'
    !gsutil cp 'all_targets.txt' '{eval_data_dir}/len_{MAX_LENGTH}_num_{MAX_TESTS}_targets.txt'