<a href="https://colab.research.google.com/github/freddejn/summarization-transformer-cnn-dailymail/blob/master/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os
!pip install -q -U tensor2tensor
!pip install -q -U tensorflow
#!pip install -q tensor2tensor==1.13.2 --force-reinstall

from google.colab import auth
import tensorflow as tf
from tensor2tensor.utils import registry

PROJECT_ID = 'transformer-233711'
!gcloud config set project {PROJECT_ID}
auth.authenticate_user()
!mkdir user_dir

# Main program

In [0]:
%%writefile user_dir/registrations.py
from tensor2tensor.utils import registry
from tensor2tensor.models import transformer
from tensor2tensor.models import lstm
from tensor2tensor.data_generators import problem

# transformer_base_v1
# Used for machine translation in the original paper.
# Does not learn on the summarization task, due to layer_postprocess_sequence and
# layer_preprocess_sequence being set to "dan" and "none", respectively.
@registry.register_hparams
def transformer_base_v1_extra():
    """Return transformer_base_v1 hparams with summarization-sized limits.

    Every setting except the batch size and the sequence-length limits below
    is inherited unchanged from ``transformer.transformer_base_v1()``.
    """
    overrides = {
        "batch_size": 4096,
        "max_length": 2048,
        "max_input_seq_length": 512,   # truncation limit for the source article
        "max_target_seq_length": 128,  # truncation limit for the summary
    }
    hp = transformer.transformer_base_v1()
    for attr_name, attr_value in overrides.items():
        setattr(hp, attr_name, attr_value)
    return hp

# transformer_base_v3
# Default for the transformer base model.
# Learns on the summarization task, but more slowly than transformer_tpu.
@registry.register_hparams
def transformer_base_v3_extra():
    """Return transformer_base_v3 hparams with summarization-sized limits.

    Only the batch size and the sequence-length limits are overridden; the
    rest comes from ``transformer.transformer_base_v3()``.
    """
    overrides = {
        "batch_size": 4096,
        "max_length": 2048,  # Use 512 for hparam-testing
        "max_input_seq_length": 512,
        "max_target_seq_length": 128,
    }
    hp = transformer.transformer_base_v3()
    for attr_name, attr_value in overrides.items():
        setattr(hp, attr_name, attr_value)
    return hp


# transformer_tpu
# Optimized for running on TPUs.
# Trains fastest of the hyperparameter sets tested.
@registry.register_hparams
def transformer_tpu_extra():
    """Return transformer_tpu hparams with summarization-sized limits.

    Starts from ``transformer.transformer_tpu()`` and overrides the batch
    size plus the overall, input and target length limits.
    """
    hp = transformer.transformer_tpu()
    hp.batch_size, hp.max_length = 4096, 2048
    # Separate truncation limits for the source text and the target summary.
    hp.max_input_seq_length, hp.max_target_seq_length = 512, 128
    return hp


# transformer_prepend
# Yields good ROUGE scores, but it copies the input and uses it as the output.
@registry.register_hparams
def transformer_prepend_extra():
    """Return transformer_prepend hparams with a fixed batch size and max_length.

    Unlike the other ``*_extra`` sets in this file, max_input_seq_length and
    max_target_seq_length are not overridden here, so they keep whatever
    ``transformer.transformer_prepend()`` sets.
    NOTE(review): presumably intentional because prepend mode feeds the input
    through the target side -- confirm.
    """
    hparams = transformer.transformer_prepend()
    # Previously written as `hparams_batch_size=4096`, which only bound an
    # unused local variable and never changed the hparams' batch size.
    hparams.batch_size = 4096
    hparams.max_length = 2048
    return hparams

# lstm_bahdanau_attention
# Uses more memory than luong_attention:
#   ~9 GB with batch size 4096 and max_length 1024,
#   8+ GB with batch size 1024 and max_length 1024.
@registry.register_hparams
def lstm_bahdanau_extra():
    """Return LSTM + Bahdanau-attention hparams for CNN/DailyMail summarization.

    Builds on ``lstm.lstm_bahdanau_attention()`` and overrides the batch size,
    length limits and hidden size.
    """
    hp = lstm.lstm_bahdanau_attention()
    # Recorded memory use as (batch_size, max_length): ~9 GB for (4096, 1024)
    # and 8+ GB for (1024, 1024). The values below use max_length 2048.
    hp.batch_size, hp.max_length = 4096, 2048
    hp.hidden_size = 128
    hp.max_input_seq_length, hp.max_target_seq_length = 512, 128
    return hp

# lstm_luong_attention
# More memory-efficient than bahdanau_attention, with no apparent difference in
# ROUGE score.
# Works best with a bidirectional encoder and hidden size 256.
@registry.register_hparams
def lstm_luong_extra():
    """Return LSTM + Luong-attention hparams for CNN/DailyMail summarization.

    Builds on ``lstm.lstm_luong_attention()``; the overrides are applied in
    the order listed below.
    """
    hp = lstm.lstm_luong_attention()
    for attr_name, attr_value in (
        ("batch_size", 4096),
        ("max_length", 2048),
        ("hidden_size", 256),
        ("max_input_seq_length", 512),
        ("max_target_seq_length", 128),
    ):
        setattr(hp, attr_name, attr_value)
    return hp

@registry.register_ranged_hparams
def ranged_lstm(hparams):
    """Declare the hyperparameter search space for LSTM tuning runs.

    Only ``dropout`` is searched, as a float between 0.2 and 0.4.
    Mutates ``hparams`` in place and returns nothing.
    """
    dropout_low, dropout_high = 0.2, 0.4
    hparams.set_float("dropout", dropout_low, dropout_high)

In [0]:
%%writefile user_dir/__init__.py
from . import registrations as reg

In [0]:
from user_dir import registrations as reg

# gRPC address of the Colab TPU. COLAB_TPU_ADDR is presumably only set on a
# TPU runtime, so fail with a readable message instead of a bare KeyError.
_tpu_addr = os.environ.get('COLAB_TPU_ADDR')
if not _tpu_addr:
    raise RuntimeError(
        'COLAB_TPU_ADDR is not set; switch the notebook to a TPU runtime '
        'before running this cell.')
TPU_WORKER = 'grpc://' + _tpu_addr

# Pick the hparams set to train with; OUTPUT_DIR below is derived from it.
run = reg.transformer_tpu_extra
# Build the hparams once instead of calling run() for every field.
hparams = run()
# Passed to --hparams_set; relies on the set being registered under its
# function name (tensor2tensor's default for register_hparams).
HPARAMS_SET = run.__name__
BATCH_SIZE = hparams.batch_size
# Previously this read batch_size, so the `-ml` part of OUTPUT_DIR showed the
# batch size. Runs made before this fix live under the old directory name.
MAX_LENGTH = hparams.max_length
MAX_INPUT_SEQ_LENGTH = hparams.max_input_seq_length
MAX_TARGET_SEQ_LENGTH = hparams.max_target_seq_length
# Encode the key settings in the directory name so runs don't overwrite each other.
OUTPUT_DIR = (
    f'gs://tensor2tensor-test-bucket/{HPARAMS_SET}-b{BATCH_SIZE}'
    f'-ml{MAX_LENGTH}-mi{MAX_INPUT_SEQ_LENGTH}-mt{MAX_TARGET_SEQ_LENGTH}')
# NOTE(review): MODEL must match the hparams set chosen above; the lstm_*
# sets presumably need an LSTM model (e.g. lstm_seq2seq_attention) -- confirm.
MODEL = 'transformer'


!t2t-trainer \
  --data_dir='gs://tensor2tensor-test-bucket/data'\
  --output_dir=$OUTPUT_DIR \
  --t2t_usr_dir='/content/user_dir' \
  --problem='summarize_cnn_dailymail32k' \
  --model=$MODEL \
  --hparams_set=$HPARAMS_SET \
  --train_steps=1000000 \
  --eval_steps=10 \
  --local_eval_frequency=10000 \
  --use_tpu \
  --keep_checkpoint_max=300 \
  --cloud_tpu_name=$TPU_WORKER \