# Text Summarization: 03b. modeling with instruction fine tuning and LoRA

* QLORA paper: https://arxiv.org/abs/2305.14314
* example notebook: https://blog.ovhcloud.com/fine-tuning-llama-2-models-using-a-single-gpu-qlora-and-ai-notebooks/
* from bits and bytes https://huggingface.co/blog/4bit-transformers-bitsandbytes
    * colab notebook: https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing#scrollTo=gkIcwsSU01EB 
    * Old blog plost: https://huggingface.co/blog/hf-bitsandbytes-integration
    * Old LLM.int8(): https://arxiv.org/abs/2208.07339
 
* When cuda error after suspended: https://discuss.pytorch.org/t/userwarning-cuda-initialization-cuda-unknown-error-this-may-be-due-to-an-incorrectly-set-up-environment-e-g-changing-env-variable-cuda-visible-devices-after-program-start-setting-the-available-devices-to-be-zero/129335/4

```sudo rmmod nvidia_uvm```

```sudo modprobe nvidia_uvm```

* **steps**
1. Load dataset
2. create bnb config
3. Load model with bnb config on quantization info
4. find modules for LORA
5. Create LORA config and wrap the model
6. Config train
7. Train model
8. Save model
9. clear memory

## Import libraries

In [1]:
import os
from datasets import DatasetDict, load_from_disk
from matplotlib import pyplot as plt
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig, PreTrainedModel
import torch
import evaluate
import re
import numpy as np
import time
from nltk.tokenize import sent_tokenize
import nltk 
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/stephen/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
# Need to install: bitsandbytes, peft, scipy (bitsandbytes needs it)

import bitsandbytes as bnb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, AutoPeftModelForSeq2SeqLM, TaskType
from transformers import ( 
    set_seed, Seq2SeqTrainingArguments, BitsAndBytesConfig, 
    DataCollatorForLanguageModeling, Seq2SeqTrainer, DataCollatorForSeq2Seq
)
import accelerate

In [3]:
# set variables
SEED = 1234
curr_path = os.getcwd()
path_raw = os.path.join(curr_path, '../data/raw', 'bill_summary_us')
# dataset_name = 'vgoldberg/longform_article_summarization'
dataset_name = 'dreamproit/bill_summary_us'
path_preprocessed = os.path.join(curr_path, '../data/preprocessed', 'bill_summary_us_single_section')

path_predictions = os.path.join(
    curr_path, '../models/', 
    'flan_t5_base_pretrained', 
    'predictions', 'bill_summary_us_single_section'
)

path_predictions_lora = os.path.join(
    curr_path, '../models/', 
    'flan_t5_finetuned_lora', 
    'predictions', 'bill_summary_us_single_section'
)

max_input_length = 1024
max_target_length = 512

num_beams = 5
max_summary_ratio = 1.0

model_checkpoint = 'google/flan-t5-base'

## Helper Functions

In [4]:
def print_example(example: dict) -> None:
    """Print example.
    :param example: dict, example.
    """
    print(' '.join(['=' * 10, 'text', '=' * 10]))
    print(example['text'])
    print(' '.join(['=' * 10, 'summary', '=' * 10]))
    print(example['summary'])    


def remove_missing_data(example: dict, col_list: list[str]) -> bool:
    """Remove missing data.
    :param example: dict, example.
    :param col_list, list[str], list of columns.
    :return: bool, whether valid (True) or missing data (False). 
    """
    flag_value = True
    for col in col_list:
        flag_value = flag_value and example[col] is not None
    return flag_value


def count_words(example: dict, col_list: list[str]) -> dict:    
    """Count words.
    :param example: dict, example.
    :param col_list, list[str], list of columns.
    :return: dict, count dictionary with prefix `count_`.
    """
    count_dict = {}
    for col_name in col_list:
        col_count = '_'.join(['count', col_name])
        word_count = len(example[col_name].split())
        count_dict[col_count] = word_count
    return count_dict


def process_tab(example: dict, col: str) -> dict:    
    """Process tab. Remove triple tabs then replace a tab with space.
    :param example: dict, example.
    :param col: str, column name to process.
    :return, dict, processed example.
    """
    example[col] = example[col].replace('\t\t\t', '').replace('\t', ' ')
    return example


def plot_text_length(df: pd.DataFrame, col_list: list[str]) -> tuple:
    """Plot historgram of text length.
    :param df, pd.DataFrame, input dataframe.
    :param col_list: list[str], a list of column names.
    :return: tuple, (fig, axis_tuple).
        - fig: matplotlib.figure, figure handle.
        - axis_tuple: tuple[matplotlib.axes.Axes], a tuple of axes.
    """
    num_col = len(col_list)
    fig, axis_tuple = plt.subplots(nrows=1, ncols=num_col, figsize=[6.4 * num_col, 4.8])
    for idx, col in enumerate(col_list):
        ax = axis_tuple[idx]
        _ = df.plot(y=col, kind='hist', bins=100, ax=ax)  

    return fig, axis_tuple


def find_outliers_cutoff(df: pd.DataFrame, scale: float = 1.5) -> dict[tuple]:
    """Find outliers cutoff.
    :param df, pd.DataFrame, input dataframe.
    :param scale: float, scale factor for IQR to define whiskter. Default: 1.5.
    :return: dict[tuple], a dictionary of lower and upper cutoff for each column.
    """
    # upper, lower whisker: pd.Series
    q3 = df.quantile(q=0.75)
    q1 = df.quantile(q=0.25)
    iqr = q3 - q1
    upper = q3 + scale * iqr
    lower = q1 - scale * iqr
    
    upper_dict = upper.to_dict()
    lower_dict = lower.to_dict()
    cutoff_dict = {}
    for key in upper_dict:
        cutoff_dict[key] = (lower_dict[key], upper_dict[key])
    return cutoff_dict    


def remove_outliers(example: dict, col_list: list, cutoff_dict: dict[tuple]) -> bool:
    """Remove outliers.
    :param example: dict, example.
    :param col_list, list[str], list of columns.
    :param cutoff_dict: dict[tuple], a dictionary of lower and upper cutoff for each column.
    :return: bool, whether valid (True) or outliers (False).
    """
    flag_val = True
    for col in col_list:
        cutoff_low, cutoff_high = cutoff_dict[col]
        flag_val = flag_val and cutoff_low <= example[col] <= cutoff_high
    return flag_val


def create_prompt(example: dict, col_text: str, prompt_prefix: str) -> dict:
    """Creat prompt with prompt prefix.
    :param example: dict, example.
    :param col_text: str, column for text.
    :param prompt_prefix: str, prompt prefix for instruction.
    :return: dict, example with prompt column.
    """
    example['prompt'] = f'{prompt_prefix}: {example[col_text]}'
    return example


#  build system to tokenize
def tokenize_example(
    example: dict, tokenizer: AutoTokenizer, col_prompt: str, col_summary: str, 
    max_input_length: int, max_target_length: int
) -> dict: 
    """Tokenize example.
    :param example: dict, example.
    :param tokenizer: transformers.AutoTokenizer, tokenizer.
    :param col_prompt: str, column name for prompt.
    :param col_summary: str, column name for summary (target)
    :param max_input_length: int, maximum input token length. 
    :param max_target_length: int, maximum target token length. 
    :return: dict, tokens for input ('input_ids') and target ('label').
    """
    model_input = tokenizer(
        example[col_prompt], 
        max_length=max_input_length, 
        truncation=True,
        # return_tensors='pt',
    )
    label = tokenizer(
        # example[col_summary], 
        text_target=example[col_summary],  
        max_length=max_target_length, 
        truncation=True,
        # return_tensors='pt',
    )
    model_input['label'] = label['input_ids']
    return model_input


def count_tokens(example: dict, col_list: list[str]) -> dict:    
    """Count tokens.
    :param example: dict, example.
    :param col_list: list[str], a list of column names.
    :return: dict, token counts with prefix 'count_tokens_'
    """
    count_dict = {}
    for col_name in col_list:
        col_count = '_'.join(['count_tokens', col_name])
        word_count = len(example[col_name])
        # print(type(example['text']))  # str
        # example[col_count] = len(example['text'].split())
        count_dict[col_count] = word_count
    return count_dict


def preprocessing(ds, col_names, cutoff_dict, col_for_count, tokenizer, max_input_length, max_target_length):   
    ds_preprocessed = ( 
        # 1. remove missing
        ds.filter(remove_missing_data, fn_kwargs={'col_list': col_names})
        
        # 2. filtering
        # select only 1 section
        .filter(lambda example: example['sections_length'] == 1, num_proc=12)
        # remove any zero text length, summary_length
        .filter(lambda example: example['text_length'] > 0, num_proc=12)
        .filter(lambda example: example['summary_length'] > 0, num_proc=12)
        # text_length_outliers, summary_length_outliers
        .filter(remove_outliers, fn_kwargs={'col_list': list(cutoff_dict.keys()), 'cutoff_dict': cutoff_dict}, num_proc=12)

        # count words
        .map(count_words, fn_kwargs={'col_list': col_for_count}, num_proc=12)
        
        # remove \t\t\t, and replace \t with ' '
        .map(process_tab, fn_kwargs={'col': 'text'}, num_proc=12)
        
        # create prompt
        .map(create_prompt, fn_kwargs={'col_text': 'text', 'prompt_prefix': 'summarize'}, num_proc=12)
        
        # tokenize: need to specify max_input_length, when number of tokens is greater than default limit (512)
        .map(
            tokenize_example, 
            fn_kwargs={
                'tokenizer': tokenizer, 'col_prompt': 'prompt', 'col_summary': 'summary', 
                'max_input_length': max_input_length, 'max_target_length': max_target_length, 
            },
            batched=True, 
            batch_size=500,
            num_proc=12,
        )

        # count tokens
        .map(count_tokens, fn_kwargs={'col_list': ['input_ids', 'label']}, num_proc=12)
    )
    return ds_preprocessed


def inference(
    example: dict, tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, 
    num_beams: int = 5, max_summary_ratio: float = 1.0, device: str = 'cpu'
) -> dict:
    """Make inference.
    :param example: dict, example.
    :param tokenizer: transformers.AutoTokenizer, tokenizer.
    :param model: transformers.AutoModelForSeq2SeqLM, model.
    :param num_beams: int, the number of beams for beam search. Default: 5.
    :param max_summary_ratio: float, the maximum summary to text ratio. Default: 1.
    :param device: str, device to make an inference.
    :return: dict, prediction with 'predicted_tokens' and 'prediction' columns.
    """
    # convert dataset into tensor, need to put another dimension
    # type torch.LongTensor or torch.int64
    # torch batch does not work because length of the sequence varies with examples
    input_ids = ( 
        torch.tensor(example['input_ids'])
        .reshape((1, -1))
        .to(device)
    )

    # generate config
    generation_config = GenerationConfig(
        max_new_tokens=min(
            int(max_summary_ratio * example['count_tokens_input_ids']),
            max_target_length
        ),
        num_beams=num_beams, 
        skip_special_tokens=True, early_stopping=False
    )
    
    # output type: torch.Tensor
    with torch.no_grad():
        outputs = model.generate(input_ids, generation_config=generation_config)
    # bacl to cpu
    if device != 'cpu':
        outputs = outputs.to(device)
    # decode and remove the initial <pad> and the ending </s>
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # assign
    example['predicted_tokens'] = outputs[0].tolist()
    example['prediction'] = decoded_outputs
    # clear cache
    torch.cuda.empty_cache()
    
    return example


def compute_metrics(eval_pred):
    # text
    predictions, labels = eval_pred
    if isinstance(predictions, str):
        result = rouge.compute(predictions=[predictions], references=[labels], use_stemmer=True)
    else:
        result = rouge.compute(predictions=predictions, references=labels, use_stemmer=True)
    # round to the 4th decimal
    result = {key: round(val, 4) for key, val in result.items()}
    return result


def compute_metrics_for_training(eval_pred):
    # tokens_id
    predictions, labels = eval_pred
    # Decode generated summaries into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # Decode reference summaries into text
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE expects a newline after each sentence. So, add \n for new sentence 
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    # Compute ROUGE scores
    rouge_score = evaluate.load('rouge')
    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # round to the 4th decimal
    result = {key: round(val, 4) for key, val in result.items()}
    return result
    
    # # Extract the median scores
    # result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    # return {k: round(v, 4) for k, v in result.items()}


## Load dataset

In [5]:
# load dataset
ds = load_from_disk(dataset_path=path_preprocessed)
print(ds)
ds.cleanup_cache_files()

DatasetDict({
    train: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'labels', 'count_tokens_input_ids', 'count_tokens_labels'],
        num_rows: 24861
    })
    dev: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'labels', 'count_tokens_input_ids', 'count_tokens_labels'],
        num_rows: 3116
    })
    test: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'labels', 

{'train': 0, 'dev': 0, 'test': 3}

In [6]:
# small sample
# ds_train = ds['train'].select_columns(['input_ids', 'attention_mask', 'labels']).select(range(20))
# ds_dev = ds['dev'].select_columns(['input_ids', 'attention_mask', 'labels']).select(range(20))

# ds_train = ds['train'].select_columns(['input_ids', 'attention_mask', 'labels']).select(range(500))
# ds_dev = ds['dev'].select_columns(['input_ids', 'attention_mask', 'labels']).select(range(100))

# ds_train = ds['train'].select_columns(['input_ids', 'attention_mask', 'labels']).select(range(5000))
# ds_dev = ds['dev'].select_columns(['input_ids', 'attention_mask', 'labels']).select(range(1000))

ds_train = ds['train'].select_columns(['input_ids', 'attention_mask', 'labels'])
ds_dev = ds['dev'].select_columns(['input_ids', 'attention_mask', 'labels'])

In [7]:
print(ds_train)
print(ds_dev)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 24861
})
Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 3116
})


In [8]:
# access record
print(ds_train[0])
print_example(example=ds['train'][0])

{'input_ids': [21603, 10, 1300, 28281, 3438, 13, 1015, 3, 26300, 3604, 11, 27254, 494, 41, 9, 61, 1212, 144, 494, 5568, 3, 25626, 599, 9, 61, 14296, 13, 8, 5034, 1212, 144, 24996, 1983, 41, 2658, 412, 5, 134, 5, 254, 5, 431, 4241, 599, 9, 61, 14296, 61, 19, 21012, 318, 5637, 57, 11214, 5637, 37, 7471, 11, 8722, 53, 5637, 599, 188, 61, 37, 7471, 3, 117, 6499, 16, 769, 6583, 9413, 41, 188, 61, 41, 9, 7, 9943, 57, 8986, 5637, 201, 57, 11214, 4199, 120, 21, 3438, 441, 224, 1015, 3, 117, 11, 10153, 57, 2651, 44, 8, 414, 8, 826, 10, 41, 279, 61, 37, 7471, 1522, 29560, 3, 9, 1015, 3193, 24235, 53, 6082, 7, 19890, 288, 12, 769, 6583, 9413, 41, 188, 61, 21, 59, 705, 145, 943, 1093, 11, 59, 72, 145, 1640, 1093, 13, 8, 1358, 13, 24235, 53, 8, 6082, 7, 5, 5, 41, 115, 61, 1908, 83, 8224, 494, 5568, 305, 599, 9, 61, 14296, 13, 8, 1908, 83, 8224, 7554, 24996, 1983, 41, 2658, 412, 5, 134, 5, 254, 5, 314, 5062, 599, 9, 61, 14296, 61, 19, 21012, 318, 5637, 57, 11214, 5637, 37, 7471, 11, 8722, 53, 5637, 

In [9]:
print(ds['train'].features)
print(type(ds['train']))

{'id': Value(dtype='string', id=None), 'congress': Value(dtype='int64', id=None), 'bill_type': Value(dtype='string', id=None), 'bill_number': Value(dtype='int64', id=None), 'bill_version': Value(dtype='string', id=None), 'sections': [{'text': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'header': Value(dtype='string', id=None)}], 'sections_length': Value(dtype='int64', id=None), 'text': Value(dtype='string', id=None), 'text_length': Value(dtype='int64', id=None), 'summary': Value(dtype='string', id=None), 'summary_length': Value(dtype='int64', id=None), 'title': Value(dtype='string', id=None), 'count_text': Value(dtype='int64', id=None), 'count_summary': Value(dtype='int64', id=None), 'prompt': Value(dtype='string', id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'labels': Sequence(feature=Value(dtype='int64', id=None), length=-1, id

In [10]:
ds['train'].features

{'id': Value(dtype='string', id=None),
 'congress': Value(dtype='int64', id=None),
 'bill_type': Value(dtype='string', id=None),
 'bill_number': Value(dtype='int64', id=None),
 'bill_version': Value(dtype='string', id=None),
 'sections': [{'text': Value(dtype='string', id=None),
   'id': Value(dtype='string', id=None),
   'header': Value(dtype='string', id=None)}],
 'sections_length': Value(dtype='int64', id=None),
 'text': Value(dtype='string', id=None),
 'text_length': Value(dtype='int64', id=None),
 'summary': Value(dtype='string', id=None),
 'summary_length': Value(dtype='int64', id=None),
 'title': Value(dtype='string', id=None),
 'count_text': Value(dtype='int64', id=None),
 'count_summary': Value(dtype='int64', id=None),
 'prompt': Value(dtype='string', id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'labels': Sequence(feature=Value(dtype='int64', 

## Load model and set up for QLoRA

In [11]:
# changing the compute dtype
def create_bnb_config() -> BitsAndBytesConfig:
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type='nf4',
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return bnb_config

In [12]:
# TODO: later check data types for function signature
def load_model(model_checkpoint: str, bnb_config: BitsAndBytesConfig) -> tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
    n_gpus = torch.cuda.device_count()
    # specific to GPU model
    max_memory = f'{4096*1}MB'
    
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_checkpoint, 
        quantization_config=bnb_config,
        # dispatch efficiently the model on the available resources
        device_map='auto',
        max_memory={idx: max_memory for idx in range(n_gpus)}        
    )
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer

## Create bitsandbytes configuration

In [13]:
# bnb_config = create_bnb_config()
# print(bnb_config)

## LoRA configuration

In [13]:
def create_lora_config(
    modules: torch.nn.Module, 
    rank=8, lora_alpha=16, lora_dropout=0.05, task_type=TaskType.SEQ_2_SEQ_LM
) -> LoraConfig:
    """Apply LoRA to modules"""
    # TODO: parameter, 8, 16, drop_out task_type
    config = LoraConfig(
        # rank of the update matrics, independent of r when use all linear layers in q-lora. 
        # may work on 8 then increase to 16 to see improvement
        r=rank,  # 
        # parameter for scaling
        lora_alpha=lora_alpha,  # 64,
        target_modules=modules,
        # dropout probability for layers. 0.05 for small models (7B, 13B), 0.1 for larger models (33B, 65B) from Q-LORA paper
        lora_dropout=lora_dropout,  
        bias='none',
        # 'SEQ_2_SEQ_LM'
        task_type=task_type,        
    )
    return config


In [14]:
# Find all linear modules to apply lora, qlora all moduls as opposed to self-attention for lora only.
def find_all_linear_names(model: PreTrainedModel) -> list:
    # Q-LORA paper suggests that include all linear modules
    model_modules = str(model.modules)
    pattern = r'\((\w+)\): Linear'
    linear_layer_names = re.findall(pattern, model_modules)
    
    lora_module_names = set()
    # Print the names of the Linear layers
    for name in linear_layer_names:
        lora_module_names.add(name)
    target_modules = list(lora_module_names)
    return target_modules   
 

In [15]:
# print trainable parameters
def print_trainable_parameters(model: torch.nn.Module, use_4bit=False):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        # numel method inherited from torch.Tensor
        num_params = param.numel()
        # if using DS (deepspeed) Zero 3 and the weights are initialized empty
        if num_params == 0 and hasattr(param, 'ds_numel'):
            num_params = param.ds_numel
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    # # why devide by 2?
    # if use_4bit:
    #     trainable_params /= 2
    trainable_perc = 100 * trainable_params / all_param
    print(
        f'all params: {all_param: d}, trainable params: {int(trainable_params): d}, trainable: {trainable_perc: .4f}%.'
    )
# Don't merge LORA adapter into a 4-bit LLM with Q-lora

In [16]:
# torch.cuda.is_available() should be True; otherwise, import error
# model, tokenizer = load_model(model_checkpoint=model_checkpoint, bnb_config=bnb_config)

# for LORA: load raw model
def load_model_for_lora(model_checkpoint: str) -> tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
    n_gpus = torch.cuda.device_count()
    # specific to GPU model
    max_memory = f'{4096*1}MB'
    
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_checkpoint, 
        # quantization_config=bnb_config,
        # dispatch efficiently the model on the available resources
        device_map='auto',
        max_memory={idx: max_memory for idx in range(n_gpus)}, 
        torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    return model, tokenizer

# load original model and tokenizer
original_model, tokenizer = load_model_for_lora(model_checkpoint=model_checkpoint)
    
# torch.nn.modules.module

In [17]:
# lora_module_names = find_all_linear_names(model)

# for lora
lora_module_names = ['q', 'v']  
print(lora_module_names)

['q', 'v']


In [18]:
# print_trainable_parameters(model, use_4bit=True)
print_trainable_parameters(original_model, use_4bit=False)

all params:  247577856, trainable params:  247577856, trainable:  100.0000%.


In [19]:
print(len(ds_train) //2 //10)
print(len(ds_dev) //2 //10)

1243
155


In [20]:
# torch.nn.modules.module
# transformers.PretrainedModel
def train(model, tokenizer, train_dataset, eval_dataset, checkpoint_dir, output_dir, max_target_length=512):
    # note: model changes dynamically, modify input model
    # Apply preprocessing to the model to prepare it by
    # # 1 - Enabling gradient checkpointing to reduce memory usage during fine-tuning.
    # # added use_reentrant = False to avoid warning
    # model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

    # 2 - Using the prepare_model_fot_kbit_training method from PEFT
    # does not need this for LORA
    # model = prepare_model_for_kbit_training(model)

    # Get lora module names
    # modules = find_all_linear_names(model)
    modules = ['q', 'v']  # attention layer
    print(modules)

    # Create PEFT config for these modules and wrap the model to PEFT
    # peft_config = create_lora_config(modules)
    # LoRA
    peft_config = create_lora_config(modules, rank=16, lora_alpha=16)  
    # peft_config = create_lora_config(modules, lora_dropout=0)

    # note that this change model: running twice will make an error
    peft_model = get_peft_model(model, peft_config)

    # print information about the percentage of trainable parameters
    print_trainable_parameters(peft_model)

    # 1 - Enabling gradient checkpointing to reduce memory usage during fine-tuning.
    # added use_reentrant = False to avoid warning
    # model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False})

    
    # add trainining part
    # Training parameters
    batch_size = 8  # based on memory constraint, reduce batch size from 8 to 4
    num_train_epochs = 8 * 3 # 8 * 5 # 2 # 8  # 16: 3 3.97 epochs
    # show the training loss with every epoch
    num_steps_for_epoch = len(train_dataset) // batch_size
    num_steps_for_epoch_eval = len(eval_dataset) // batch_size
    print(f'num_steps_for_epoch: {num_steps_for_epoch}')
    print(f'num_steps_for_epoch_eval: {num_steps_for_epoch_eval}')

    peft_training_args = Seq2SeqTrainingArguments(
        output_dir=checkpoint_dir,
        evaluation_strategy='epoch',
        num_train_epochs=num_train_epochs,            
        # evaluation_strategy='steps',
        # save_strategy='steps',            
        max_steps=num_train_epochs * num_steps_for_epoch,
        # per_device_train_batch_size=batch_size,
        # per_device_eval_batch_size=batch_size,
        logging_steps=1,   # num_steps_for_epoch // 10,
        # eval_steps=num_steps_for_epoch_eval // 10,  # _eval
        learning_rate=1e-3, # Higher learning rate than full fine-tuning  # 2e-4,
        auto_find_batch_size=True,
        # weight_decay=0.01,            
        # # pass
        # gradient_accumulation_steps=4,  # effective batch size: batch_size * gradient_accumulation_steps
        # eval_accumulation_steps=4,
        # # gradient checking point
        # gradient_checkpointing=True,
        # gradient_checkpointing_kwargs={'use_reentrant': False},
        # # seq2seq model
        predict_with_generate=True,
        # # have to explicitly specify this. setting by generation_config does not work.
        # generation_max_length=max_target_length,
        # generation_num_beams=5, 
        # generation_config=GenerationConfig(
        #     max_new_tokens=max_target_length, num_beams=5,  
        #     skip_special_tokens=True, early_stopping=False), 
        # # check point
        # load_best_model_at_end=True,
        # # should be multiple of eval_steps
        # save_steps=num_steps_for_epoch_eval // 10 * 2 ,
        # # precision
        # fp16=True,            
        # # need to verify this supports this optim            
        # optim='adamw_bnb_8bit',
        # # parallel processing
        # dataloader_num_workers=4,             
        # # output_dir='outputs',            
        # # optim='paged_adamw_8bit',
        # # warmup_steps=2,
        # # max_steps=20,          
    )
    print(peft_training_args)
    
    trainer = Seq2SeqTrainer(
        model=peft_model,
        train_dataset=train_dataset, 
        eval_dataset=eval_dataset,
        # verify all these arguments
        args=peft_training_args, 
        # compute_metrics=compute_metrics_for_training,
        # data_collator=Databfloat16CollatorForSeq2Seq(tokenizer, model=model),        
    )
    
    # re-enable for inference to speed up predictions for similar inputs
    peft_model.config.use_cache = False

    ### SOURCE https://github.com/artidoro/qlora/blob/main/qlora.py
    # Verifying the datatypes before training
    dtypes = {}
    for _, p in peft_model.named_parameters():
        dtype = p.dtype
        if dtype not in dtypes: 
            dtypes[dtype] = 0
        dtypes[dtype] += p.numel()
    total = 0
    for data_type, num_param in dtypes.items():
        total += num_param
    for data_type, num_param in dtypes.items():
        print(f'data type: {data_type}, num_param: {num_param}, percentage: {num_param/total * 100: .2f}')

    do_train = True

    # Launch training
    print('training...')

    if do_train:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.log_metrics('train', metrics)
        trainer.save_metrics('train', metrics)
        trainer.save_state()
        print(metrics)

    ###
    # Saving model
    print('Saving last checkpoint of the model...')
    os.makedirs(output_dir, exist_ok=True)
    trainer.model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # Free memory for merging weights
    del peft_model
    del trainer
    torch.cuda.empty_cache()


In [21]:
torch.cuda.empty_cache()

In [22]:
checkpoint_dir = '../models/flan_t5_finetuned_lora/checkpoint'
output_dir = f'../models/flan_t5_finetuned_lora/final_checkpoint-{str(int(time.time()))}'
print(output_dir)
# train(model, tokenizer, train_dataset=None, eval_dataset=None, output_dir=output_dir)
# # only 3.2945% compared to 25%

../models/flan_t5_finetuned_lora/final_checkpoint-1703563820


## Train: train_runtime = 7:39:09.53

In [23]:
# whole sample
# max_steps some number instead of 1
# rank=16 and alpha=16
# train(model, tokenizer, train_dataset=ds_train, eval_dataset=ds_dev, output_dir=output_dir, max_target_length=max_target_length)
train(
    model=original_model, tokenizer=tokenizer, train_dataset=ds_train, eval_dataset=ds_dev, 
    checkpoint_dir=checkpoint_dir, output_dir=output_dir, max_target_length=max_target_length,
)
# only 3.2945% compared to 25%

['q', 'v']
all params:  249347328, trainable params:  1769472, trainable:  0.7096%.
num_steps_for_epoch: 3107
num_steps_for_epoch_eval: 389
Seq2SeqTrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=True,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=True,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=epoch,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
generation_config=Non

Epoch,Training Loss,Validation Loss
1,0.0177,0.094296
2,0.0513,0.084767
3,0.0222,0.081937
4,0.3203,0.081767
5,0.0228,0.07816


***** train metrics *****
  epoch                    =         6.0
  total_flos               = 191720570GF
  train_loss               =        0.12
  train_runtime            =  7:39:09.53
  train_samples_per_second =       5.413
  train_steps_per_second   =       2.707
{'train_runtime': 27549.5311, 'train_samples_per_second': 5.413, 'train_steps_per_second': 2.707, 'total_flos': 2.058583953993892e+17, 'train_loss': 0.11998937406479551, 'epoch': 6.0}
Saving last checkpoint of the model...


## Load tokenizer and trained model

In [24]:
from peft import PeftModel

# is_trainable=False for inference
peft_model = PeftModel.from_pretrained(
    original_model, output_dir, is_trainable=False, 
    # bfloat16 to save memory
    torch_dtype=torch.bfloat16,
    # dispatch efficiently the model on the available resources
    device_map='auto',
)
print_trainable_parameters(peft_model)

all params:  249347328, trainable params:  0, trainable:  0.0000%.


## Inference on entire data: it will take time

In [25]:
def inference(
    example: dict, tokenizer: AutoTokenizer, model: AutoModelForSeq2SeqLM, 
    max_target_length, 
    num_beams: int = 5, max_summary_ratio: float = 1.0, device: str = 'cpu'
) -> dict:
    """Make inference.
    :param example: dict, example.
    :param tokenizer: transformers.AutoTokenizer, tokenizer.
    :param model: transformers.AutoModelForSeq2SeqLM, model.
    :param num_beams: int, the number of beams for beam search. Default: 5.
    :param max_summary_ratio: float, the maximum summary to text ratio. Default: 1.
    :param device: str, device to make an inference.
    :return: dict, prediction with 'predicted_tokens' and 'prediction' columns.
    """
    # convert dataset into tensor, need to put another dimension
    # type torch.LongTensor or torch.int64
    # torch batch does not work because length of the sequence varies with examples
    input_ids = ( 
        torch.tensor(example['input_ids'])
        .reshape((1, -1))
        .to(device)
    )
    
    # generate config
    generation_config = GenerationConfig(
        # does not need max_summary_ratio for count_tokens_input_ids after zero_padding: 
        # max_new_tokens=min(
        #     int(max_summary_ratio * example['count_tokens_input_ids']),
        #     max_target_length
        # ),
        max_new_tokens=max_target_length,
        num_beams=num_beams, 
        skip_special_tokens=True, early_stopping=False
    )
    
    # output type: torch.Tensor
    with torch.no_grad():
        outputs = model.generate(input_ids=input_ids, generation_config=generation_config)
    # back to cpu
    if device != 'cpu':
        outputs = outputs.to('cpu')
        input_ids = input_ids.to('cpu')
    # decode and remove the initial <pad> and the ending </s>
    decoded_outputs = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # assign
    example['predicted_tokens'] = outputs[0].tolist()
    example['prediction'] = decoded_outputs
    # clear cache
    torch.cuda.empty_cache()
    
    return example

In [26]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# peft_model.to(device)
peft_model.device

device(type='cuda', index=0)

* Took 1:45:43 for test dataset with batch size 1024

In [28]:
ds_predict_test = ( 
    # ds['train']
    # ds
    ds['test']    
    # ds['test'].select(range(500))  
    .map(
        inference, 
        fn_kwargs={
            'tokenizer': tokenizer, 'model': peft_model, 'num_beams': num_beams, 
            'max_target_length': max_target_length, 
            'max_summary_ratio': max_summary_ratio, 'device': device,
        }, 
        
        # multiprocessing not working for gpu
        num_proc=1,
        batch_size=1024,
    )
)

Map:   0%|          | 0/3051 [00:00<?, ? examples/s]

In [29]:
# combined with tokenized_train
ds_predict_all = DatasetDict(
    {
        # 'train': ds_predict_train, 
        # 'dev': ds_predict_dev, 
        'test': ds_predict_test
    }
)
print(ds_predict_all)

DatasetDict({
    test: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'labels', 'count_tokens_input_ids', 'count_tokens_labels', 'predicted_tokens', 'prediction'],
        num_rows: 3051
    })
})


## Save predictions

In [30]:
ds_predict_all.save_to_disk(path_predictions)
!ls {path_predictions}

Saving the dataset (0/1 shards):   0%|          | 0/3051 [00:00<?, ? examples/s]

dataset_dict.json  dev	test


In [31]:
# load dataset
ds_predict_all = load_from_disk(dataset_path=path_predictions)
print(ds_predict_all)
ds_predict_all.cleanup_cache_files()

DatasetDict({
    test: Dataset({
        features: ['id', 'congress', 'bill_type', 'bill_number', 'bill_version', 'sections', 'sections_length', 'text', 'text_length', 'summary', 'summary_length', 'title', 'count_text', 'count_summary', 'prompt', 'input_ids', 'attention_mask', 'labels', 'count_tokens_input_ids', 'count_tokens_labels', 'predicted_tokens', 'prediction'],
        num_rows: 3051
    })
})


{'test': 0}

In [27]:
# ds_predict_test = ds_predict_all['test']

In [32]:
# model back to device
if device != 'cpu':
    peft_model.to('cpu')
print(peft_model.device)

cpu


## Verify predictions

In [33]:
# idx = 0
for idx in range(3):
    print(f'idx: {idx}')
    print(f"prompt\n{ds_predict_test['prompt'][idx]}\n")
    print(f"prediction\n{ds_predict_test['prediction'][idx]}\n")
    print(f"summary\n{ds_predict_test['summary'][idx]}\n")

idx: 0
prompt
summarize: That the House of Representatives— (1) reaffirms United States support for Georgia’s sovereignty and territorial integrity within its internationally-recognized borders, and does not recognize the independence of the Abkhazia and South Ossetia regions currently occupied by the Russian Federation; and (2) supports continued cooperation between the United States and Georgia and the efforts of the Government of Georgia to provide for the defense of its people and sovereign territory.

prediction
Reaffirms U.S. support for Georgia's sovereignty and territorial integrity within its internationally-recognized borders, and does not recognize the independence of the Abkhazia and South Ossetia regions currently occupied by the Russian Federation. Supports continued cooperation between the United States and Georgia and the efforts of the government of Georgia to provide for the defense of its people and sovereign territory.

summary
Expresses the sense that the House of 

## Evaulate performance

In [34]:
rouge = evaluate.load('rouge')

In [35]:
result = compute_metrics((ds_predict_all['test']['prediction'], ds_predict_all['test']['summary']))
# result = compute_metrics((ds_predict_test['prediction'], ds_predict_test['summary']))
print(result)

{'rouge1': 0.682, 'rouge2': 0.5868, 'rougeL': 0.645, 'rougeLsum': 0.6513}
