# Text Summarization: 03a. Modeling with a pretrained model

## Import libraries

In [1]:
import os
from datasets import DatasetDict, load_from_disk
from matplotlib import pyplot as plt
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
import torch

In [2]:
# set variables
# random seed — not referenced in the cells shown here; presumably used elsewhere
SEED = 1234
curr_path = os.getcwd()
# raw dataset location — not read in the cells shown here
path_raw = os.path.join(curr_path, '../data/raw', 'bill_summary_us')
# dataset_name = 'vgoldberg/longform_article_summarization'
dataset_name = 'dreamproit/bill_summary_us'
# preprocessed (single-section) DatasetDict, loaded below with load_from_disk
path_preprocessed = os.path.join(curr_path, '../data/preprocessed', 'bill_summary_us_single_section')

# output directory for the generated predictions (written with save_to_disk below)
path_predictions = os.path.join(
    curr_path, '../models/', 
    'flan_t5_base_pretrained', 
    'predictions', 'bill_summary_us_single_section'
)

# maximum token lengths for the model input (prompt) and target (summary)
max_input_length = 1024
max_target_length = 512

# beam width used by beam-search generation
num_beams = 5
# generated-token budget relative to the input token count
# (1.0 -> at most as many new tokens as the input has, also capped by max_target_length)
max_summary_ratio = 1.0

## Helper Functions

In [13]:
def print_example(example: dict) -> None:
    """Print an example's text and summary, each preceded by a banner line.

    :param example: dict, example with 'text' and 'summary' keys.
    """
    for field in ('text', 'summary'):
        banner = ' '.join(['=' * 10, field, '=' * 10])
        print(banner)
        print(example[field])


def remove_missing_data(example: dict, col_list: list[str]) -> bool:
    """Predicate for `Dataset.filter`: keep examples with no None in the listed columns.

    :param example: dict, example.
    :param col_list: list[str], columns that must not be None.
    :return: bool, True when every listed column is not None (True for an empty
        list), False as soon as one None is found.
    """
    return all(example[col] is not None for col in col_list)


def count_words(example: dict, col_list: list[str]) -> dict:
    """Count whitespace-separated words in each listed column.

    :param example: dict, example whose listed columns hold strings.
    :param col_list: list[str], list of columns.
    :return: dict, mapping 'count_<col>' to the word count of example[col].
    """
    return {f'count_{col_name}': len(example[col_name].split()) for col_name in col_list}


def process_tab(example: dict, col: str) -> dict:
    """Clean tab characters in one column, in place.

    Every run of three tabs is deleted first, then each remaining single tab
    becomes a space.
    :param example: dict, example (mutated in place).
    :param col: str, column name to process.
    :return: dict, the same example object with the column cleaned.
    """
    cleaned = example[col].replace('\t\t\t', '')
    cleaned = cleaned.replace('\t', ' ')
    example[col] = cleaned
    return example


def plot_text_length(df: pd.DataFrame, col_list: list[str]) -> tuple:
    """Plot a histogram (100 bins) of each listed column, side by side.

    :param df: pd.DataFrame, input dataframe.
    :param col_list: list[str], a list of column names; one subplot per column.
    :return: tuple, (fig, axis_tuple).
        - fig: matplotlib.figure.Figure, figure handle.
        - axis_tuple: 1-D array of matplotlib.axes.Axes, one per entry in `col_list`.
    """
    num_col = len(col_list)
    # squeeze=False always returns a (1, num_col) grid, so a single column still
    # yields an indexable row of axes instead of a bare Axes object.
    fig, axis_grid = plt.subplots(
        nrows=1, ncols=num_col, figsize=[6.4 * num_col, 4.8], squeeze=False
    )
    axis_tuple = axis_grid[0]
    for ax, col in zip(axis_tuple, col_list):
        df.plot(y=col, kind='hist', bins=100, ax=ax)

    return fig, axis_tuple


def find_outliers_cutoff(df: pd.DataFrame, scale: float = 1.5) -> dict[tuple]:
    """Compute IQR-based (Tukey fence) outlier cutoffs for every column.

    lower = Q1 - scale * IQR and upper = Q3 + scale * IQR, with IQR = Q3 - Q1.
    :param df: pd.DataFrame, input dataframe.
    :param scale: float, IQR multiplier that sets the whisker length. Default: 1.5.
    :return: dict[tuple], column name -> (lower cutoff, upper cutoff).
    """
    first_quartile = df.quantile(q=0.25)
    third_quartile = df.quantile(q=0.75)
    whisker = scale * (third_quartile - first_quartile)

    lower_by_col = (first_quartile - whisker).to_dict()
    upper_by_col = (third_quartile + whisker).to_dict()
    return {name: (lower_by_col[name], upper_by_col[name]) for name in upper_by_col}


def remove_outliers(example: dict, col_list: list, cutoff_dict: dict[tuple]) -> bool:
    """Predicate for `Dataset.filter`: keep examples within every column's cutoffs.

    :param example: dict, example.
    :param col_list: list[str], list of columns to check.
    :param cutoff_dict: dict[tuple], column name -> (lower, upper) cutoff, both inclusive.
    :return: bool, True when every listed value lies in its [lower, upper] range
        (True for an empty list), False otherwise.
    """
    is_valid = True
    for col in col_list:
        lower_bound, upper_bound = cutoff_dict[col]
        # once an outlier is found, remaining values are no longer compared
        if is_valid:
            is_valid = lower_bound <= example[col] <= upper_bound
    return is_valid


def create_prompt(example: dict, col_text: str, prompt_prefix: str) -> dict:
    """Create an instruction prompt by prefixing the text column.

    Writes '<prompt_prefix>: <text>' to example['prompt'] (in place).
    :param example: dict, example.
    :param col_text: str, column holding the text.
    :param prompt_prefix: str, instruction prefix, e.g. 'summarize'.
    :return: dict, the same example object with a 'prompt' entry.
    """
    example['prompt'] = '{}: {}'.format(prompt_prefix, example[col_text])
    return example


# tokenize prompts (model input) and summaries (label)
def tokenize_example(
    example: dict, tokenizer: AutoTokenizer, col_prompt: str, col_summary: str,
    max_input_length: int, max_target_length: int
) -> dict:
    """Tokenize the prompt as model input and the summary as the target.

    Works for a single example or a batch (as used with `Dataset.map(batched=True)`),
    since both tokenizer calls receive whatever `example[...]` holds.
    :param example: dict, example (or batch of examples).
    :param tokenizer: transformers.AutoTokenizer, tokenizer.
    :param col_prompt: str, column name for prompt.
    :param col_summary: str, column name for summary (target).
    :param max_input_length: int, prompts are truncated to this many tokens.
    :param max_target_length: int, summaries are truncated to this many tokens.
    :return: dict, tokenizer output for the prompt plus 'label' holding the
        summary's 'input_ids'.
    """
    encoded_prompt = tokenizer(
        example[col_prompt], max_length=max_input_length, truncation=True
    )
    # text_target= tokenizes in target mode; only its token ids are kept
    encoded_summary = tokenizer(
        text_target=example[col_summary], max_length=max_target_length, truncation=True
    )
    encoded_prompt['label'] = encoded_summary['input_ids']
    return encoded_prompt


def count_tokens(example: dict, col_list: list[str]) -> dict:
    """Count the token ids stored in each listed column.

    :param example: dict, example whose listed columns hold token-id sequences.
    :param col_list: list[str], a list of column names.
    :return: dict, mapping 'count_tokens_<col>' to len(example[col]).
    """
    return {f'count_tokens_{name}': len(example[name]) for name in col_list}


def preprocessing(
    ds, col_names, cutoff_dict, col_for_count, tokenizer, max_input_length, max_target_length,
    *, num_proc: int = 12, prompt_prefix: str = 'summarize',
):
    """Run the cleaning -> filtering -> prompting -> tokenization pipeline.

    Steps, in order:
      1. drop examples where any column in `col_names` is None;
      2. keep single-section examples (`sections_length == 1`) with non-zero
         `text_length` and `summary_length`, and drop examples outside the
         inclusive (lower, upper) bounds in `cutoff_dict`;
      3. add word counts 'count_<col>' for `col_for_count`;
      4. delete triple tabs and replace remaining tabs with spaces in 'text';
      5. add 'prompt' = '<prompt_prefix>: <text>';
      6. tokenize 'prompt' (model input) and 'summary' (stored as 'label');
      7. add 'count_tokens_input_ids' and 'count_tokens_label'.

    :param ds: dataset exposing `.filter` / `.map` (presumably a `datasets.Dataset`
        or `DatasetDict`).
    :param col_names: list[str], columns that must not be None.
    :param cutoff_dict: dict[tuple], column -> (lower, upper) outlier cutoffs.
    :param col_for_count: list[str], columns to count words in.
    :param tokenizer: transformers tokenizer used by `tokenize_example`.
    :param max_input_length: int, maximum prompt token length.
    :param max_target_length: int, maximum summary token length.
    :param num_proc: int, worker processes for every filter/map step. Default: 12.
    :param prompt_prefix: str, instruction prefix for the prompt. Default: 'summarize'.
    :return: the processed dataset.
    """
    ds_preprocessed = (
        # 1. remove missing
        ds.filter(remove_missing_data, fn_kwargs={'col_list': col_names}, num_proc=num_proc)

        # 2. filtering
        # select only 1 section
        .filter(lambda example: example['sections_length'] == 1, num_proc=num_proc)
        # remove any zero text length, summary_length
        .filter(lambda example: example['text_length'] > 0, num_proc=num_proc)
        .filter(lambda example: example['summary_length'] > 0, num_proc=num_proc)
        # text_length_outliers, summary_length_outliers
        .filter(
            remove_outliers,
            fn_kwargs={'col_list': list(cutoff_dict.keys()), 'cutoff_dict': cutoff_dict},
            num_proc=num_proc,
        )

        # 3. count words
        # NOTE: counted before tab cleanup; process_tab deletes '\t\t\t' outright,
        # so words separated only by a triple tab are merged afterwards.
        .map(count_words, fn_kwargs={'col_list': col_for_count}, num_proc=num_proc)

        # 4. remove \t\t\t, and replace \t with ' '
        .map(process_tab, fn_kwargs={'col': 'text'}, num_proc=num_proc)

        # 5. create prompt
        .map(
            create_prompt,
            fn_kwargs={'col_text': 'text', 'prompt_prefix': prompt_prefix},
            num_proc=num_proc,
        )

        # 6. tokenize: need to specify max_input_length, when number of tokens is greater than default limit (512)
        .map(
            tokenize_example,
            fn_kwargs={
                'tokenizer': tokenizer, 'col_prompt': 'prompt', 'col_summary': 'summary',
                'max_input_length': max_input_length, 'max_target_length': max_target_length,
            },
            batched=True,
            batch_size=500,
            num_proc=num_proc,
        )

        # 7. count tokens
        .map(count_tokens, fn_kwargs={'col_list': ['input_ids', 'label']}, num_proc=num_proc)
    )
    return ds_preprocessed


def inference(
    example: dict, tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM,
    num_beams: int = 5, max_summary_ratio: float = 1.0, device: str = 'cpu',
    max_target_length: int = 512,
) -> dict:
    """Generate a summary for one tokenized example with beam search.

    The generation budget is
    ``max(1, min(int(max_summary_ratio * count_tokens_input_ids), max_target_length))``
    new tokens; the floor of 1 keeps a small ratio from producing a 0-token budget.

    :param example: dict, example with 'input_ids' and 'count_tokens_input_ids'.
    :param tokenizer: transformers.AutoTokenizer, tokenizer used for decoding.
    :param model: transformers.AutoModelForSeq2SeqLM, model (already on `device`).
    :param num_beams: int, the number of beams for beam search. Default: 5.
    :param max_summary_ratio: float, maximum generated-token to input-token ratio. Default: 1.0.
    :param device: str, device to make an inference on. Default: 'cpu'.
    :param max_target_length: int, hard cap on generated tokens. Default: 512, the
        value of the notebook's `max_target_length` setting.
    :return: dict, the example with 'predicted_tokens' (list[int]) and
        'prediction' (str) added.
    """
    # Add a batch dimension of 1. Examples are generated one at a time because
    # sequence lengths vary across examples and no padding is applied here.
    input_ids = (
        torch.tensor(example['input_ids'], dtype=torch.long)
        .reshape((1, -1))
        .to(device)
    )

    max_new_tokens = max(
        1,
        min(int(max_summary_ratio * example['count_tokens_input_ids']), max_target_length),
    )
    # NOTE(review): skip_special_tokens is a decoding option; it is kept here as in the
    # original config, but the decode call below is what actually applies it.
    generation_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
        skip_special_tokens=True, early_stopping=False,
    )

    with torch.no_grad():
        outputs = model.generate(input_ids, generation_config=generation_config)
    # back to host memory before converting to Python objects
    predicted_ids = outputs[0].cpu()

    example['predicted_tokens'] = predicted_ids.tolist()
    # decode and remove the initial <pad> and the ending </s>
    example['prediction'] = tokenizer.decode(predicted_ids, skip_special_tokens=True)

    # release cached GPU memory between examples (only meaningful on CUDA devices)
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()

    return example

## Load dataset

In [4]:
# load dataset: the preprocessed DatasetDict (train/dev/test) from path_preprocessed
ds = load_from_disk(dataset_path=path_preprocessed)
print(ds)
# delete stale cache files; the returned dict holds the number removed per split
ds.cleanup_cache_files()

DatasetDict({
    train: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'label', 'count_tokens_input_ids', 'count_tokens_label'],
        num_rows: 24861
    })
    dev: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'label', 'count_tokens_input_ids', 'count_tokens_label'],
        num_rows: 3116
    })
    test: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'label', 'coun

{'train': 0, 'dev': 0, 'test': 0}

In [5]:
# access record: show the first training example raw, then its text/summary pair
print(ds['train'][0])
print_example(example=ds['train'][0])

{'id': '110s1149is', 'congress': 110, 'bill_type': 's', 'bill_number': 1149, 'bill_version': 'is', 'sections': [{'text': '1. Interstate distribution of State inspected\t\t\t meat and poultry products \n(a) Meat products \nSection 301(a)(1) of the Federal Meat\t\t\t Inspection Act (21 U.S.C. 661(a)(1)) is amended— (1) by striking (1) The\t\t\t Secretary and inserting (1)(A) The Secretary ; (2) in subparagraph (A) (as designated by\t\t\t paragraph (1)), by striking solely for distribution within such\t\t\t State ; and (3) by adding at the end the following: (B) The Secretary shall reimburse a State\t\t\t\tagency administering inspections pursuant to subparagraph (A) for not less than\t\t\t\t50 percent and not more than 60 percent of the costs of administering the\t\t\t\tinspections.. (b) Poultry products \nSection 5(a)(1) of the Poultry Products\t\t\t Inspection Act (21 U.S.C. 454(a)(1)) is amended— (1) by striking (1) The\t\t\t Secretary and inserting (1)(A) The Secretary ; (2) in subpa

In [6]:
# inspect the column schema of the train split and its Python type
print(ds['train'].features)
print(type(ds['train']))

{'id': Value(dtype='string', id=None), 'congress': Value(dtype='int64', id=None), 'bill_type': Value(dtype='string', id=None), 'bill_number': Value(dtype='int64', id=None), 'bill_version': Value(dtype='string', id=None), 'sections': [{'text': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'header': Value(dtype='string', id=None)}], 'sections_length': Value(dtype='int64', id=None), 'text': Value(dtype='string', id=None), 'text_length': Value(dtype='int64', id=None), 'summary': Value(dtype='string', id=None), 'summary_length': Value(dtype='int64', id=None), 'title': Value(dtype='string', id=None), 'count_text': Value(dtype='int64', id=None), 'count_summary': Value(dtype='int64', id=None), 'prompt': Value(dtype='string', id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'label': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=

## Load tokenizer and model

In [7]:
# check whether a CUDA GPU is available for inference
torch.cuda.is_available()

True

In [8]:
# load the pretrained (not fine-tuned) tokenizer and seq2seq model
# model_checkpoint = "google/mt5-small"
model_checkpoint = 'google/flan-t5-base'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# first GPU when available, otherwise CPU
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model.to(device)
# set model to eval mode (turns off the Dropout layers listed in the model repr)
model.eval()
# to train the model: model.train()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=768, out_features=2048, bias=False)
              (wi_1): Linear(in_features=768, out_features=2048, bias=False)
              (wo):

### Sample test

In [9]:
# Build a (1, seq_len) input batch from one training example.
# Index the row first so only that example is read, instead of converting the
# whole 'input_ids' (and 'prompt') column and then picking one entry.
idx = 0
sample = ds['train'][idx]
input_ids = torch.tensor(sample['input_ids'], dtype=torch.long)
input_ids = torch.reshape(input_ids, (1, -1)).to(device)
print(type(input_ids))
print(input_ids)
print(sample['prompt'])


<class 'torch.Tensor'>
tensor([[21603,    10,  1300, 28281,  3438,    13,  1015,     3, 26300,  3604,
            11, 27254,   494,    41,     9,    61,  1212,   144,   494,  5568,
             3, 25626,   599,     9,    61, 14296,    13,     8,  5034,  1212,
           144, 24996,  1983,    41,  2658,   412,     5,   134,     5,   254,
             5,   431,  4241,   599,     9,    61, 14296,    61,    19, 21012,
           318,  5637,    57, 11214,  5637,    37,  7471,    11,  8722,    53,
          5637,   599,   188,    61,    37,  7471,     3,   117,  6499,    16,
           769,  6583,  9413,    41,   188,    61,    41,     9,     7,  9943,
            57,  8986,  5637,   201,    57, 11214,  4199,   120,    21,  3438,
           441,   224,  1015,     3,   117,    11, 10153,    57,  2651,    44,
             8,   414,     8,   826,    10,    41,   279,    61,    37,  7471,
          1522, 29560,     3,     9,  1015,  3193, 24235,    53,  6082,     7,
         19890,   288,    12,

In [10]:
# Generation budget: as many new tokens as the input has (ratio 1.0), 5-beam search.
# Read the count from the single row rather than materializing the whole column.
generation_config = GenerationConfig(
    max_new_tokens=int(1.0 * ds['train'][idx]['count_tokens_input_ids']), num_beams=5, skip_special_tokens=True,
    early_stopping=False
)
print(generation_config.max_new_tokens)

# model to device
model.to(device)
# output type: torch.Tensor
outputs = model.generate(input_ids, generation_config=generation_config)
# back to cpu (no-op when already on the CPU)
outputs = outputs.cpu()
# remove the initial <pad> and the ending </s>
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
torch.cuda.empty_cache()

# convert token ids to tokens
token = tokenizer.convert_ids_to_tokens(input_ids[0])
print(token)
# convert tokens to text
print(tokenizer.convert_tokens_to_string(token))

293
1. Interstate distribution of State inspected meat and poultry products (a) Meat products Section 301(a)(1) of the Federal Meat Inspection Act (21 U.S.C. 661(a)(1)) is amended— (1) by striking (1) The Secretary and inserting (1)(A) The Secretary ; (2) in subparagraph (A) (a) (a) (a) (a) (a) (a) (a) (a) (a) (a) (b) Poultry products Section 5(a)(1) of the Poultry Products Inspection Act (21 U.S.C. 454(a)(1)) is amended— (1) by striking (1) The Secretary and inserting (1)(A) The Secretary ; (2) in subparagraph (A) (a) (a) (a) (a) (b) (a) (a) (a) (b) (a) (a) (a) (b) (a) (a) (a) (b) (a) (a) (a) (b) (a) (a) (a) (b) (a) (a) (b) (a) (a) (b) (a) (a) (b) (a) (a) (b) (a) (a) (b) (a) (a
['▁summarize', ':', '▁1.', '▁Interstate', '▁distribution', '▁of', '▁State', '▁', 'inspected', '▁meat', '▁and', '▁poultry', '▁products', '▁(', 'a', ')', '▁Me', 'at', '▁products', '▁Section', '▁', '301', '(', 'a', ')', '(1)', '▁of', '▁the', '▁Federal', '▁Me', 'at', '▁Inspection', '▁Act', '▁(', '21', '▁U', '.', 'S

## Inference on entire data
* This takes a long time (hours): each dev example is generated one at a time with beam search in a single process.

In [11]:
# inspect the dev split that predictions are generated for
ds['dev']

Dataset({
    features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'label', 'count_tokens_input_ids', 'count_tokens_label'],
    num_rows: 3116
})

In [14]:
# Generate a prediction for every dev example; adds 'predicted_tokens' and 'prediction'.
# ds_output = ds['train'].select(range(3)).map(
ds_predict = ( 
    # ds['train']
    # ds
    ds['dev']
    .map(
        inference, 
        fn_kwargs={
            'tokenizer': tokenizer, 'model': model, 'num_beams': num_beams, 
            'max_summary_ratio': max_summary_ratio, 'device': device,
        }, 
        # single process: multiprocessing was found not to work with the GPU model
        num_proc=1,
    )
)

Map:   0%|          | 0/3116 [00:00<?, ? examples/s]

In [15]:
# combined with tokenized_train
ds_predict_all = DatasetDict(
    {
        # 'train': ds_predict_train, 
        'dev': ds_predict, 
        # 'test': ds_predict_test
    }
)
print(ds_predict_all)

DatasetDict({
    dev: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'label', 'count_tokens_input_ids', 'count_tokens_label', 'predicted_tokens', 'prediction'],
        num_rows: 3116
    })
})


## Save predictions

In [16]:
# persist the predictions, then list the output directory (IPython shell escape)
ds_predict_all.save_to_disk(path_predictions)
!ls {path_predictions}

Saving the dataset (0/1 shards):   0%|          | 0/3116 [00:00<?, ? examples/s]

dataset_dict.json  dev


In [29]:
# move the model back to the CPU; when it already lives there this is a no-op
model.to('cpu')
print(model.device)

cpu


## Verify predictions

In [30]:
# Compare prompt, generated prediction and reference summary for the first dev examples.
# Fetch each row once (ds_predict[idx]) instead of materializing the full
# 'prompt', 'prediction' and 'summary' columns on every iteration.
for idx in range(3):
    row = ds_predict[idx]
    print(f'idx: {idx}')
    print(f"prompt\n{row['prompt']}\n")
    print(f"prediction\n{row['prediction']}\n")
    print(f"summary\n{row['summary']}\n")

idx: 0
prompt
summarize: That the House of Representatives supports the goals and ideals of National Community Gardening Awareness Month, including— (1) raising awareness about the importance of community gardens and urban agriculture; (2) improving access to public land for the creation of sustainable food projects; (3) encouraging further growth of community gardens and other opportunities that increase food self-reliance, improve fitness, contribute to a cleaner environment, and enhance community development; and (4) supporting cooperative efforts among Federal, State, and local governments and nonprofit organizations to promote the development and expansion of community gardens and to increase their accessibility to disadvantaged population groups.

prediction
To support the goals and ideals of National Community Gardening Awareness Month, including— (1) raising awareness about the importance of community gardens and urban agriculture; (2) improving access to public land for the cr