# Summarisation notebook

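This notebook fine-tunes a pretrained sequence-to-sequence language model (BART-base by default; LED-base is also supported) for summarisation on the Multi-News dataset using PyTorch Lightning, then tests the model and summarises a single document.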

(More descriptions)

### 1. Import packages

In [None]:
import os
os.getcwd()
import random

import nlp
import numpy as np
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import TestTubeLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from src.summarisation_lightning_model import LmForSummarisation

### 2. Define parameters

In [None]:
# Hyper-parameters and paths shared by every later cell.
# Defaults target BART-base; the LED-base alternatives are noted per entry.
args = dict(
    # --- sequence lengths (in tokens) ---
    max_input_len=512,    # source document limit: 512 for BART-base, 8192 for LED-base
    max_output_len=256,   # summary limit
    # --- paths ---
    # Where the model and logs are written:
    # 'models/summarisation_bart' for BART, 'models/summarisation_led' for LED
    save_dir='../models/summarisation_bart',
    tokenizer='../pretrained_lms/facebook-bart-base',  # pretrained tokenizer
    # Pretrained model (facebook-bart-base for BART, allenai-led-base-16384 for LED)
    model='../pretrained_lms/facebook-bart-base',
    # --- optimisation ---
    label_smoothing=0.0,  # label smoothing (not required)
    epochs=1,             # number of training epochs
    batch_size=4,         # 1 for LED, 4 for BART
    grad_accum=1,         # gradient accumulation: 4 for LED (effective batch size), 1 for BART
    lr=3e-5,              # training learning rate
    warmup=1000,          # number of warmup steps
    # --- hardware ---
    gpus=1,               # number of GPUs; 0 for CPU
    precision=32,         # 64 = double, 32 = full, 16 = half precision (CPU, GPU or TPU)
    # --- data ---
    cache_dir='../datasets/cache/',  # dataset cache, where the converted dataset is stored
)

### 3. Initialize Lightning module

In [None]:
# Fix one seed for every random number generator used in this notebook
seed = 1234
for _seeder in (random.seed, np.random.seed, torch.manual_seed):
    _seeder(seed)
# Seed all CUDA devices as well when a GPU is present
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    
# Dataset size (entered by hand); the lr scheduler needs it to compute the number of steps
args.update(dataset_size=50594.)

# Build the PyTorch Lightning module from the hyper-parameters
model = LmForSummarisation(args)
# Attach the Multi-News dataset, loaded through the configured cache directory
model.hf_datasets = nlp.load_dataset('multi_news', cache_dir=args['cache_dir'])

# Experiment logger: writes under args['save_dir'] in the 'training' experiment
logger = TestTubeLogger(
    name='training',
    version=0,  # deliberately fixed: always use version=0
    save_dir=args['save_dir'],
)

# Checkpointing: keep only the single best checkpoint, judged by validation loss
checkpoint_dir = os.path.join(args['save_dir'], "training", "checkpoints")
checkpoint_callback = ModelCheckpoint(
    monitor='avg_val_loss',  # Quantity compared between checkpoints (average validation loss)
    mode='min',              # Lower is better: keep the checkpoint that minimises the loss
    save_top_k=1,            # Maximum number of checkpoints to keep
    period=1,                # Save every epoch
    verbose=True,            # Report when a checkpoint is written
    dirpath=checkpoint_dir,  # Directory the checkpoints are saved to
)

# Echo the final hyper-parameters (including dataset_size) in the cell output
print(args)


# Define lightning trainer
# NOTE(review): gpus is taken from args even when CUDA is unavailable;
# set args['gpus'] to 0 for CPU-only runs.
trainer = pl.Trainer(
    gpus=args['gpus'],
    # DataParallel when CUDA is available, otherwise no distributed backend
    distributed_backend='dp' if torch.cuda.is_available() else None,
    track_grad_norm=-1,  # -1 turns gradient-norm tracking off
    max_epochs=args['epochs'],
    max_steps=None,  # No step limit; training length is set by max_epochs
    replace_sampler_ddp=False,
    accumulate_grad_batches=args['grad_accum'],
    gradient_clip_val=1.0,  # Max grad_norm
    val_check_interval=1.0,  # Float = fraction of a training epoch; 1.0 validates once per epoch
    num_sanity_val_steps=2,  # Validation steps for sanity check
    check_val_every_n_epoch=1,  # Check validation every N epochs
    logger=logger,
    # Pass callbacks as a list: older Lightning 1.x releases only accept a list,
    # and newer ones accept both a list and a single callback.
    callbacks=[checkpoint_callback],
    progress_bar_refresh_rate=10,  # Progress bar for printing (updates every N)
    precision=args['precision'],
    # NOTE(review): amp_level is an apex optimisation level; with
    # amp_backend='native' it appears to be unused — confirm before relying on it.
    amp_backend='native', amp_level='O2',
)

### 4. Train model

In [None]:
# Train model: run the fitting loop with the trainer configured in section 3.
# No dataloaders are passed here, so the module presumably builds them itself
# (e.g. from model.hf_datasets) — confirm in LmForSummarisation.
trainer.fit(model)

### 5. Test model

In [None]:
# Test model: run the trainer's test loop on the in-memory `model` object.
# NOTE(review): presumably this uses the weights left by trainer.fit rather than
# the best saved checkpoint — confirm if checkpoint-based evaluation is intended.
trainer.test(model)

### 6. Inference

In [None]:
# Restore a fine-tuned model from a saved checkpoint (replace the placeholder path)
model = LmForSummarisation.load_from_checkpoint('../models/<path_to_model>.ckpt')

# Text to summarise (replace the placeholder)
document = '<ADD A DOCUMENT TO SUMMARISE>'
# Summarise `document` using the token limits from args.
# The call must receive `document`; no variable named `sentence` exists in this notebook.
summary = model.summarise_example(document, args['max_input_len'], args['max_output_len'])
# Last expression of the cell, so the notebook displays the summary
summary