# First Full LoRA Trial with Transformer

## peft (for LoRA) and FLAN-T5-small for the LLM

I'm following what seems to be a great tutorial from Mehul Gupta,

> https://medium.com/data-science-in-your-pocket/lora-for-fine-tuning-llms-explained-with-codes-and-example-62a7ac5a3578
> 
> https://web.archive.org/web/20240522140323/https://medium.com/data-science-in-your-pocket/lora-for-fine-tuning-llms-explained-with-codes-and-example-62a7ac5a3578

I'm doing this to prepare for creating a LoRA for RWKV ( @todo  put links in here ) so as to fine-tune it for Pat's OLECT-LM stuff.

In [None]:
# #  No need to run this again
# !powershell -c (Get-Date -UFormat \"%s_%Y%m%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'

Output was:

`1716367147_20240522T083907-0600`
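
For anyone without PowerShell handy, a similar stamp can be sketched in plain Python. This is a minimal sketch, not what was run above, and `make_stamp` is a hypothetical helper name. One difference worth knowing: `int(now.timestamp())` is true UTC-based Unix time, whereas the epoch part of the PowerShell output above appears to track local wall-clock time (1716367147 corresponds to 08:39:07 UTC on 2024-05-22, the same digits as the local time shown).

```python
from datetime import datetime, timedelta, timezone


def make_stamp(now=None):
    """Return '<epoch>_<YYYYmmddTHHMMSS><utc-offset>' for a timezone-aware datetime."""
    if now is None:
        # astimezone() with no argument attaches the local timezone
        now = datetime.now().astimezone()
    return f"{int(now.timestamp())}_{now.strftime('%Y%m%dT%H%M%S%z')}"


# Fixed example so the output is reproducible (UTC-06:00, as in the stamp above)
example = datetime(2024, 5, 22, 8, 39, 7, tzinfo=timezone(timedelta(hours=-6)))
print(make_stamp(example))  # 1716388747_20240522T083907-0600
```

Calling `make_stamp()` with no argument uses the current local time.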

## Imports

In [None]:
from datasets import load_dataset
import random  # needed for random.seed() in the baseline cell
from random import randrange
import torch
from transformers import AutoTokenizer, \
                         AutoModelForSeq2SeqLM, \
                         TrainingArguments, \
                         pipeline
from transformers.utils import logging
from peft import LoraConfig, \
                 prepare_model_for_kbit_training, \
                 get_peft_model, \
                 AutoPeftModelForCausalLM
from trl import SFTTrainer
from huggingface_hub import login, notebook_login

from datasets import load_metric
import nltk
import rouge_score

import pickle
import pprint
import timeit
from humanfriendly import format_timespan
import os

## Load the training and test dataset along with the LLM with its tokenizer

The LLM will be fine-tuned. It seems the tokenizer will also be fine-tuned, 
but I'm not sure.

<b>Why aren't we loading the validation set?</b> (I don't know; that's not a teaching question.)

I've tried to make use of it with the `trainer`. We'll see how it goes.

In [None]:
#  Need to install  datasets  from pip, not conda. I'll do all from pip. 
#+ I'll get rid of the current conda environment and make it anew.
#+ Actually, I'll make sure  conda  and  pip  are updated, then do what
#+ I discussed above.
#+
#+ cf. 
#+     arch_ref_1 = "https://web.archive.org/web/20240522150357/" + \
#+                  "https://stackoverflow.com/questions/77433096/" + \
#+                  "notimplementederror-loading-a-dataset-" + \
#+                  "cached-in-a-localfilesystem-is-not-suppor"
#+
#+ Also useful might be
#+     arch_ref_2 = "https://web.archive.org/web/20240522150310/" + \
#+                  "https://stackoverflow.com/questions/76340743/" + \
#+                  "huggingface-load-datasets-gives-" + \
#+                  "notimplementederror-cannot-error"
#
data_files = {'train':'samsum-train.json', 
              'evaluation':'samsum-validation.json',
              'test':'samsum-test.json'}
dataset = load_dataset('json', data_files=data_files)

model_name = "google/flan-t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

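#  The tutorial says the next line makes training faster but a little
#+ less accurate. `pretraining_tp` is documented for Llama configs, so
#+ it may have no effect on a T5 model.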
model.config.pretraining_tp = 1

tokenizer = AutoTokenizer.from_pretrained(model_name, 
                                          trust_remote_code=True)

#  padding instructions for the tokenizer
#+   ??? !!! What about for RWKV !!! ???
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

#### Trying some things I've been learning

In [None]:
print(model)

## Prompt and Trainer

For our SFT (<b>S</b>upervised <b>F</b>ine <b>T</b>uning) model, we use the class `trl.SFTTrainer`.

I want to research this a bit, especially the `formatting_func` that we'll be passing to the `SFTTrainer`.

First, though, some information about SFT. From the Hugging Face Documentation at https://huggingface.co/docs/trl/en/sft_trainer ([archived](https://web.archive.org/web/20240529140717/https://huggingface.co/docs/trl/en/sft_trainer))

> Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.

Though I won't be using the examples unless I get even more stuck, the next paragraph _has_ examples, and I'll put the paragraph here.

> Check out a complete flexible example at [examples/scripts/sft.py](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) \[[archived](https://web.archive.org/web/20240529140740/https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py)\]. Experimental support for Vision Language Models is also included in the example [examples/scripts/vsft_llava.py](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) \[[archived](https://web.archive.org/web/20240529140738/https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py)\].

RLHF ([archived wikipedia page](https://web.archive.org/web/20240529142205/https://en.wikipedia.org/wiki/Reinforcement_learning_from_human_feedback)) is <b>R</b>einforcement <b>L</b>earning from <b>H</b>uman <b>F</b>eedback. [TRL](https://huggingface.co/docs/trl/en/index#:~:text=TRL%20is%20a%20full%20stack,Policy%20Optimization%20(PPO)%20step.) ([archived]()) is <b>T</b>ransformer <b>R</b>einforcement <b>L</b>earning, a library from Hugging Face.

For the parameter, `formatting_func`, I can look at the documentation site above (specifically [here](https://huggingface.co/docs/trl/en/sft_trainer#:~:text=formatting_func%20(Optional))), at the GitHub repo for [the code](https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py) (in the docstrings), or in my local `conda` environment, at `C:\Users\bballdave025\.conda\envs\rwkv-lora-pat\Lib\site-packages\trl\trainer\sft_trainer.py`.

Pulling code from the last one, I get

>         formatting_func (`Optional[Callable]`):
>            The formatting function to be used for creating the `ConstantLengthDataset`.

That matches the first very well

> <b>formatting_func</b> (`Optional[Callable]`) — The formatting function to be used for creating the `ConstantLengthDataset`.

(A quick note: In this Jupyter Notebook environment, I could have typed `trainer = SFTTrainer(` and then <kbd>Shift</kbd> + <kbd>Tab</kbd> to find that same documentation.)

However, I think that more clarity is found at the [documentation for `ConstantLengthDataset`](https://huggingface.co/docs/trl/en/sft_trainer#:~:text=class%20trl.trainer.ConstantLengthDataset)

> <b>formatting_func</b> (`Callable`, <b>optional</b>) — Function that formats the text before tokenization. Usually it is recommended to have follows a certain pattern such as `"### Question: {question} ### Answer: {answer}"`

So, as we'll see in the next code from the tutorial, it's basically a prompt templater/formatter that matches the JSON. For example, we use `sample['dialogue']` to access the value stored under the `dialogue` key. That's what I got from all this stuff.

Mehul Gupta himself stated

> Next, using the Input and Output, we will create a prompt template which is a requirement by the SFTTrainer we will be using later

### Prompt

In [None]:
def prompt_instruction_format(sample):
    return f""" Instruction:
      Use the Task below and the Input given to write the Response:

      ### Task:
      Summarize the Input

      ### Input:
      {sample['dialogue']}

      ### Response:
      {sample['summary']}
      """
##endof:  prompt_instruction_format(sample)
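
To see what the `SFTTrainer` will actually receive, the template can be applied to a single record. A minimal sketch: the sample dict below is made up for illustration (it is not from SAMSum), and the function is repeated only so the block runs on its own.

```python
def prompt_instruction_format(sample):
    # Same template as the cell above, repeated so this block is self-contained
    return f""" Instruction:
      Use the Task below and the Input given to write the Response:

      ### Task:
      Summarize the Input

      ### Input:
      {sample['dialogue']}

      ### Response:
      {sample['summary']}
      """


# Hypothetical record with the two keys the template reads
fake_sample = {
    "dialogue": "A: Lunch at noon?\nB: Sure, see you then.",
    "summary": "A and B will have lunch at noon.",
}

prompt = prompt_instruction_format(fake_sample)
print(prompt)
```

Note that the indentation inside the triple-quoted string is part of the prompt, and that only the first line of a multi-line dialogue picks up that indentation.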

### Trainer - the LoRA Setup Part

#### Arguments and Configuration

In [None]:
#  Some arguments to pass to the trainer
training_args = TrainingArguments( output_dir='output',
                                   num_train_epochs=1,
                                   per_device_train_batch_size=4,
                                   save_strategy='epoch',
                                   learning_rate=2e-4
)

# the fine-tuning (peft for LoRA) stuff
peft_config = LoraConfig( lora_alpha=16,
                          lora_dropout=0.1,
                          r=64,
                          bias='none',
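                          #  'CAUSAL_LM' is the tutorial's value. FLAN-T5 is
                          #+ loaded as a seq2seq model, so 'SEQ_2_SEQ_LM' may
                          #+ be the better fit (see the TaskType notes below).
                          task_type='CAUSAL_LM'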
)
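
As a rough feel for what `r` and `lora_alpha` mean: LoRA adds two small matrices, `A` (r × d_in) and `B` (d_out × r), next to a frozen weight of shape d_out × d_in, and scales their product by `lora_alpha / r`. The sketch below just does that arithmetic. The 512 × 512 layer shape is an illustrative assumption, not read from FLAN-T5-small, and `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer."""
    # A is (r x d_in), B is (d_out x r)
    return r * d_in + d_out * r


r = 64
lora_alpha = 16
d_in, d_out = 512, 512  # assumed layer shape, for illustration only

added = lora_param_count(d_in, d_out, r)
frozen = d_in * d_out
scaling = lora_alpha / r

print(f"frozen weights in the layer : {frozen}")
print(f"trainable LoRA weights added: {added}")
print(f"ratio added/frozen          : {added / frozen:.3f}")
print(f"scaling (lora_alpha / r)    : {scaling}")
```

With these assumed shapes, `r=64` adds a quarter as many weights as the layer already has; a smaller `r` would make the adapter lighter.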

`task_type`, cf. https://github.com/huggingface/peft/blob/main/src/peft/config.py#L222

>        Args:
>            peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use.
>            task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform.
>            inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.

After some searching using Cygwin

```
bballdave025@MYMACHINE /cygdrive/c/Users/bballdave025/.conda/envs/rwkv-lora-pat/Lib/site-packages/peft/utils
$ ls -lah
total 116K
drwx------+ 1 bballdave025 bballdave025    0 May 28 21:09 .
drwx------+ 1 bballdave025 bballdave025    0 May 28 21:09 ..
-rwx------+ 1 bballdave025 bballdave025 2.0K May 28 21:09 __init__.py
drwx------+ 1 bballdave025 bballdave025    0 May 28 21:09 __pycache__
-rwx------+ 1 bballdave025 bballdave025 8.0K May 28 21:09 constants.py
-rwx------+ 1 bballdave025 bballdave025 3.8K May 28 21:09 integrations.py
-rwx------+ 1 bballdave025 bballdave025  17K May 28 21:09 loftq_utils.py
-rwx------+ 1 bballdave025 bballdave025 9.7K May 28 21:09 merge_utils.py
-rwx------+ 1 bballdave025 bballdave025  25K May 28 21:09 other.py
-rwx------+ 1 bballdave025 bballdave025 2.2K May 28 21:09 peft_types.py
-rwx------+ 1 bballdave025 bballdave025  21K May 28 21:09 save_and_load.py

bballdave025@MYMACHINE /cygdrive/c/Users/bballdave025/.conda/envs/rwkv-lora-pat/Lib/site-packages/peft/utils
$ grep -iIRHn "TaskType" .
peft_types.py:60:class TaskType(str, enum.Enum):
__init__.py:20:# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
__init__.py:22:from .peft_types import PeftType, TaskType

bballdave025@MYMACHINE /cygdrive/c/Users/bballdave025/.conda/envs/rwkv-lora-pat/Lib/site-packages/peft/utils
$
```

So, let's look at the `peft_types.py` file.

The docstring for `class TaskType(str, enum.Enum)` is

```
    Enum class for the different types of tasks supported by PEFT.
    
    Overview of the supported task types:
    - SEQ_CLS: Text classification.
    - SEQ_2_SEQ_LM: Sequence-to-sequence language modeling.
    - CAUSAL_LM: Causal language modeling.
    - TOKEN_CLS: Token classification.
    - QUESTION_ANS: Question answering.
    - FEATURE_EXTRACTION: Feature extraction. Provides the hidden states which can be used as embeddings or features
      for downstream tasks.
```
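
Because `TaskType` subclasses both `str` and `enum.Enum`, a plain string such as `'CAUSAL_LM'` and the enum member are interchangeable in comparisons, which is presumably why passing the string to `LoraConfig` works. The sketch below uses a stand-in class (`TaskTypeSketch`, a hypothetical name) rather than importing peft; the member values are assumed to equal the member names.

```python
import enum


class TaskTypeSketch(str, enum.Enum):
    """Stand-in for peft's TaskType; values assumed equal to the names."""
    SEQ_CLS = "SEQ_CLS"
    SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM"
    CAUSAL_LM = "CAUSAL_LM"
    TOKEN_CLS = "TOKEN_CLS"
    QUESTION_ANS = "QUESTION_ANS"
    FEATURE_EXTRACTION = "FEATURE_EXTRACTION"


# A str-based enum member compares equal to its string value
print(TaskTypeSketch.CAUSAL_LM == "CAUSAL_LM")

# Looking up by value returns the member itself
print(TaskTypeSketch("SEQ_2_SEQ_LM") is TaskTypeSketch.SEQ_2_SEQ_LM)

# An unknown task name raises ValueError
try:
    TaskTypeSketch("SUMMARIZATION")
except ValueError as err:
    print(f"rejected: {err}")
```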



### We're going to start timing stuff, so here's some system info

`win_system_info_as_script.py` is a script I wrote with the help
of a variety of StackOverflow and documentation sources.
It should be in the working directory.

In [None]:
import win_system_info_as_script as winsysinfo
winsysinfo.run()

### Try for a baseline

#### Just one summarization to begin with, randomly picked

In [None]:
!powershell -c (Get-Date -UFormat \"%s_%Y%m%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'

Output was:

`timestamp`

In [None]:
#  Just one summarization to begin with, randomly picked ... but
#+ now with the possibility of a known seed, to allow visual 
#+ comparison with after-training results.
#+ I'M NOT GOING TO USE THIS REPEATED SEED, I'm just going to
#+ use the datum at the first index to compare.

do_seed_for_repeatable = False

summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)

if do_seed_for_repeatable:
  rand_seed_for_randrange = 137
  random.seed(rand_seed_for_randrange)
##endof:  if do_seed_for_repeatable

sample = dataset['test'][randrange(len(dataset["test"]))]
print(f"dialogue: \n{sample['dialogue']}\n---------------")

res = summarizer(sample["dialogue"])

print(f"flan-t5-small summary:\n{res[0]['summary_text']}")

#### Now, one summarization with comparison to ground truth

In [None]:
# Now one summarization with comparison to ground truth

summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)

pred_test_list = []
ref_test_list = []

sample_num = 0

this_sample = dataset['test'][sample_num]

print(f"dialogue: \n{this_sample['dialogue']}\n---------------")

grnd_summary = this_sample['summary']
res = summarizer(this_sample['dialogue'])
res_summary = res[0]['summary_text']

# grnd is for ground truth, i.e. the human-generated summary

print(f"human-genratd summary:\n{grnd_summary}")
print(f"flan-t5-small summary:\n{res_summary}")

ref_test_list.append(grnd_summary)
pred_test_list.append(res_summary)

print("\n\n---------- ROUGE SCORES ----------")

rouge = load_metric('rouge', trust_remote_code=True)

results = rouge.compute(predictions=pred_test_list,
                        references=ref_test_list,
                        use_aggregator=True)

# >>> print(list(results.keys()))
# ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

print()
print("ROUGE-1 results")
pprint.pp(results['rouge1'])
print()
print("ROUGE-2 results")
pprint.pp(results['rouge2'])
print()
print("ROUGE-L results")
pprint.pp(results['rougeL'])
print()
print("ROUGE-Lsum results")
pprint.pp(results['rougeLsum'])
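
For intuition about what the scores mean, ROUGE-1 is unigram overlap between prediction and reference, reported as precision, recall, and F-measure. Below is a deliberately simplified sketch (lowercasing and whitespace splitting only, no stemming), so its numbers will not necessarily match the `rouge` metric used above. `rouge1_sketch` is a hypothetical helper.

```python
from collections import Counter


def rouge1_sketch(prediction, reference):
    """Simplified ROUGE-1: clipped unigram overlap -> (precision, recall, f1)."""
    pred_counts = Counter(prediction.lower().split())
    ref_counts = Counter(reference.lower().split())
    # Counter & Counter keeps the minimum count per token (clipped overlap)
    overlap = sum((pred_counts & ref_counts).values())
    pred_total = sum(pred_counts.values())
    ref_total = sum(ref_counts.values())
    precision = overlap / pred_total if pred_total else 0.0
    recall = overlap / ref_total if ref_total else 0.0
    if precision + recall == 0:
        return 0.0, 0.0, 0.0
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


p, r, f = rouge1_sketch("the cat on the mat", "the cat sat on the mat")
print(f"precision={p:.3f} recall={r:.3f} f1={f:.3f}")
```

Here every predicted word appears in the reference (precision 1.0), but the prediction misses "sat", so recall is 5/6.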

#### Verbosity stuff - get rid of the nice advice

In [None]:
!powershell -c (Get-Date -UFormat \"%s_%Y-%m-%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'

Output was:

`timestamp`

In [None]:
# bballdave025@MYMACHINE /cygdrive/c/Users/bballdave025/.conda/envs/rwkv-lora-pat/Lib/site-packages/peft/utils
# $ date +'%s_%Y-%m-%dT%H%M%S%z'
# 1717049876_2024-05-30T001756-0600

log_verbosity_is_critical = \
  logging.get_verbosity() == logging.CRITICAL # alias FATAL, 50
log_verbosity_is_error = \
  logging.get_verbosity() == logging.ERROR # 40
log_verbosity_is_warn = \
  logging.get_verbosity() == logging.WARNING # alias WARN, 30
log_verbosity_is_info = \
  logging.get_verbosity() == logging.INFO # 20
log_verbosity_is_debug = \
  logging.get_verbosity() == logging.DEBUG # 10

print( "The statement, 'logging verbosity is CRITICAL' " + \
      f"is {log_verbosity_is_critical}")
print( "The statement, 'logging verbosity is    ERROR' " + \
      f"is {log_verbosity_is_error}")
print( "The statement, 'logging verbosity is  WARNING' " + \
      f"is {log_verbosity_is_warn}")
print( "The statement, 'logging verbosity is     INFO' " + \
      f"is {log_verbosity_is_info}")
print( "The statement, 'logging verbosity is    DEBUG' " + \
      f"is {log_verbosity_is_debug}")

print()

init_log_verbosity = logging.get_verbosity()
print(f"The value of logging.get_verbosity() is: {init_log_verbosity}")

print()

init_t_n_a_w = os.environ.get('TRANSFORMERS_NO_ADVISORY_WARNINGS')
print(f"TRANSFORMERS_NO_ADVISORY_WARNINGS: {init_t_n_a_w}")

### Actual Baseline

In [None]:
!powershell -c (Get-Date -UFormat \"%s_%Y-%m-%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'

Output was:

`timestamp`

In [None]:
#  ref1 = "https://web.archive.org/web/20240530051418/" + \
#+        "https://stackoverflow.com/questions/73221277/" + \
#+        "python-hugging-face-warning"
#  ref2 = "https://web.archive.org/web/20240530051559/" + \
#+        "https://huggingface.co/docs/transformers/en/" + \
#+        "main_classes/logging"

##  Haven't tried this, because the logging seemed easier,
##+ and the logging worked
#os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

logging.set_verbosity_error()

summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)

prediction_list = []
reference_list = []

tic = timeit.default_timer()

for sample_num in range(len(dataset['test'])):
  this_sample = dataset['test'][sample_num]
  
  #print(f"dialogue: \n{this_sample['dialogue']}\n---------------")

  grnd_summary = this_sample['summary']
  res = summarizer(this_sample['dialogue'])
  res_summary = res[0]['summary_text']
  
  #print(f"human-genratd summary:\n{grnd_summary}")
  #print(f"flan-t5-small summary:\n{res_summary}")
  
  reference_list.append(grnd_summary)
  prediction_list.append(res_summary)
##endof:  for sample_num in range(len(dataset['test']))

toc = timeit.default_timer()

baseline_duration = toc - tic

print( "Getting things ready for scoring")
print(f"took {toc - tic:0.4f} seconds.")

print("\n\n---------- ROUGE SCORES ----------")

rouge = load_metric('rouge', trust_remote_code=True)
  #  Set trust_remote_code=False to see the warning,
  #+ deprecation, and what to change to.

results = rouge.compute(predictions=prediction_list,
                        references=reference_list,
                        use_aggregator=True)

# >>> print(list(results.keys()))
# ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

print()
print("ROUGE-1 results")
pprint.pp(results['rouge1'])
print()
print("ROUGE-2 results")
pprint.pp(results['rouge2'])
print()
print("ROUGE-L results")
pprint.pp(results['rougeL'])
print()
print("ROUGE-Lsum results")
pprint.pp(results['rougeLsum'])

##  Haven't tried this, because the logging seemed easier,
##+ and the logging worked
# if init_t_n_a_w is None:
#     os.environ.pop("TRANSFORMERS_NO_ADVISORY_WARNINGS", None)
# else:
#     os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = init_t_n_a_w

logging.set_verbosity(init_log_verbosity)
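
One thing the cell above leaves open: if the loop raises, the verbosity is never restored. A context manager with `try`/`finally` handles that. The sketch below uses Python's standard `logging` module on a logger named `"transformers"`, on the assumption that this is the logger `transformers.utils.logging` controls; `quiet_logger` is a hypothetical helper, and the same shape would work with the `logging.get_verbosity()` / `logging.set_verbosity()` calls used in the cells above.

```python
import logging as py_logging  # stdlib logging, renamed to avoid clashing with transformers' logging
from contextlib import contextmanager


@contextmanager
def quiet_logger(name="transformers", level=py_logging.ERROR):
    """Temporarily set a logger's level, restoring it even if the body raises."""
    logger = py_logging.getLogger(name)
    previous_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        logger.setLevel(previous_level)


before = py_logging.getLogger("transformers").level
with quiet_logger() as lg:
    inside = lg.level  # 40 == ERROR while in the block
after = py_logging.getLogger("transformers").level

print(before, inside, after)
```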

In [None]:
#  this time, put it in manually
do_enter_duration_manually = True

if do_enter_duration_manually:
    baseline_duration = 1161.2583 # remember to type in your number, if needed
##endof:  if do_enter_duration_manually

print("Running baseline inference (using the test set)")
print(f"took {format_timespan(baseline_duration)}")
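
Rather than retyping the duration, the measured value could be written to a small JSON file right after timing and read back in later sessions. A minimal sketch using only the standard library; the file name `durations.json` and the helper names are hypothetical, and a temporary directory is used here so nothing lands in the working directory.

```python
import json
import os
import tempfile


def save_duration(path, key, seconds):
    """Add or update one named duration (in seconds) in a JSON file."""
    durations = {}
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            durations = json.load(fh)
    durations[key] = seconds
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(durations, fh, indent=2)


def load_duration(path, key):
    """Read one named duration back from the JSON file."""
    with open(path, "r", encoding="utf-8") as fh:
        return json.load(fh)[key]


with tempfile.TemporaryDirectory() as tmp_dir:
    durations_path = os.path.join(tmp_dir, "durations.json")
    save_duration(durations_path, "baseline", 1161.2583)  # value from the cell above
    restored = load_duration(durations_path, "baseline")

print(restored)
```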

### Trainer - the Actual Trainer Part

In [None]:
trainer = SFTTrainer( model=model,
                      train_dataset=dataset['train'],
                      eval_dataset=dataset['evaluation'],
                      peft_config=peft_config,
                      tokenizer=tokenizer,
                      packing=True,
                      formatting_func=prompt_instruction_format,
                      args=training_args,
                    )
##  Warnings are below output.

##  Ended up not using this.
#                      max_seq_length=675
#          )


First time warnings from the code above (as it still is).

        
>        WARNING:bitsandbytes.cextension:The installed version of bitsandbytes \
>         was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, \
>         and GPU quantization are unavailable.
>        C:\Users\bballdave025\.conda\envs\rwkv-lora-pat\lib\site-packages\trl\\
>         trainer\sft_trainer.py:246: UserWarning: You didn't pass a `max_seq_length` \
>        argument to the SFTTrainer, this will default to 512
>         warnings.warn(
>        
>        [ > Generating train split: 6143/0 [00:04<00:00, 2034.36 examples/s] ]
>        
>        Token indices sequence length is longer than the specified maximum sequence \
>         length for this model (657 > 512). Running this sequence through the model \
>         will result in indexing errors
>        
>        [ > Generating train split: 355/0 [00:00<00:00, 6.10 examples/s] ]

DWB Note

<strike>So, I'm changing the `max_seq_length`.</strike> 
Maybe I should just throw out the offender(s) 
(along with the blank one that's in there somewhere),
but I'll just continue as is.

Actually, it appears I didn't run the updated cell, 
(with `max_seq_length=675`), since the
Warning and Advice are still there.
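
If the offenders were to be thrown out, one way is to filter on emptiness and length before handing the data to the trainer. The sketch below works on a plain list of dicts and counts whitespace-separated words as a stand-in for real token counts (the tokenizer's count will generally be higher); `keep_sample` and the tiny records are hypothetical.

```python
def keep_sample(sample, max_len=512, count_tokens=None):
    """True if the dialogue is non-blank and within max_len 'tokens'."""
    if count_tokens is None:
        # Stand-in for a real tokenizer: count whitespace-separated words
        count_tokens = lambda text: len(text.split())
    dialogue = sample.get("dialogue") or ""
    if not dialogue.strip():
        return False
    return count_tokens(dialogue) <= max_len


# Made-up records: one fine, one blank, one too long
records = [
    {"dialogue": "A: Hi\nB: Hello", "summary": "A and B greet each other."},
    {"dialogue": "", "summary": "Blank dialogue."},
    {"dialogue": "word " * 600, "summary": "Too long."},
]

kept = [rec for rec in records if keep_sample(rec)]
print(len(kept))
```

With the real data, the same predicate could presumably be passed to the dataset's `filter` method, with `count_tokens` swapped for something based on the actual tokenizer.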

## Let's Train This LoRA Thing and See How It Does!

In [None]:
# # No need to do this again
# !powershell -c (Get-Date -UFormat \"%s_%Y-%m-%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'
###  DWB went in and renamed `profile.ps1` to `NOT-USING_-_pro_file_-_now.ps1.bak`
###+ That should get rid of our errors from `powershell`

Output was:

`timestamp`

At about `1717063394_2024-05-30T100314-0600`, DWB went in and 
renamed `profile.ps1` to `NOT-USING_-_pro_file_-_now.ps1.bak`.
That should get rid of our errors from `powershell`

### The long-time-taking training code is just below.

In [None]:
tic = timeit.default_timer()
trainer.train()
toc = timeit.default_timer()
print(f"tic: {tic}")
print(f"toc: {toc}")
training_duration = toc - tic
print(f"Training took {toc - tic:0.4f} seconds.")

In [None]:
# this time, put it in by hand
do_by_hand = True

if do_by_hand:
  training_duration = 11087.7452 ## make sure to enter your value
##endof:  if do_by_hand
print( "Training with LoRA (and with the other info as above)")
print(f"took {format_timespan(training_duration)}.")

#### @todo : consolidate "the other info as above"

I'm talking about the numbers of data points, tokens, whatever.

#### Any Comments / Things to Try (?)

We passed an evaluation set (parameter `eval_dataset`) to the `trainer`.
How can we see information about that?
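
A sketch of where to look, under a few assumptions. In `transformers`, `trainer.state.log_history` is a list of dicts, and evaluation entries carry keys such as `eval_loss`; `trainer.evaluate()` runs an evaluation on demand. Also, `TrainingArguments` defaults to not evaluating during training, so with the arguments above there may be no evaluation entries until `trainer.evaluate()` is called. The log entries below are made-up placeholders with the assumed shape, not output from this run, and `eval_entries` is a hypothetical helper.

```python
def eval_entries(log_history):
    """Pick out the evaluation records from a Trainer-style log history."""
    return [entry for entry in log_history if "eval_loss" in entry]


# Placeholder entries with the assumed shape (NOT real results)
fake_log_history = [
    {"loss": 1.0, "learning_rate": 2e-4, "epoch": 0.5, "step": 500},
    {"eval_loss": 0.9, "eval_runtime": 12.3, "epoch": 1.0, "step": 1000},
    {"train_runtime": 100.0, "train_loss": 1.0, "epoch": 1.0, "step": 1000},
]

for entry in eval_entries(fake_log_history):
    print(f"step {entry['step']}: eval_loss = {entry['eval_loss']}")
```

In the notebook, the call would be `eval_entries(trainer.state.log_history)`.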

In [None]:
# space for doing trainer stuff

In [None]:
# more space for doing trainer stuff, if that happens

## Save the Trainer to Hugging Face and Get Our Updated Model

In [None]:
# #Don't need this again
# !powershell -c (Get-Date -UFormat \"%s_%Y-%m-%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'

Output was:

`1717078324_2024-05-30T141204-0600`

I'm following the [(archived) tutorial from Mehul Gupta on Medium](https://web.archive.org/web/20240522140323/https://medium.com/data-science-in-your-pocket/lora-for-fine-tuning-llms-explained-with-codes-and-example-62a7ac5a3578); since it's archived, you can follow exactly what I'm doing.

In [None]:
# This will come up with a dialog box with text entry.

# Use the write token, here.
notebook_login()

In [None]:
# Save tokenizer and create a tokenizer model card
tokenizer.save_pretrained('testing')
  #  used 'testing' first - I don't like how vague that is.
  #+ I deleted the repo
  # using 'dwb-first-lora-test-flan-t5-guptal' broke it
  #  actually, I think deleting output from my main model
  #+ page broke it.

# Create the trainer model card
trainer.create_model_card()

# Push the results to the Hugging Face Hub
trainer.push_to_hub()

## Evaluation on the Test Set and Comparison to Baseline

#### Verbosity stuff - get rid of the nice advice

In [None]:
!powershell -c (Get-Date -UFormat \"%s_%Y-%m-%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'

Output was:

`timestamp`

In [None]:
# bballdave025@MYMACHINE /cygdrive/c/Users/bballdave025/.conda/envs/rwkv-lora-pat/Lib/site-packages/peft/utils
# $ date +'%s_%Y-%m-%dT%H%M%S%z'
# 1717049876_2024-05-30T001756-0600

log_verbosity_is_critical = \
  logging.get_verbosity() == logging.CRITICAL # alias FATAL, 50
log_verbosity_is_error = \
  logging.get_verbosity() == logging.ERROR # 40
log_verbosity_is_warn = \
  logging.get_verbosity() == logging.WARNING # alias WARN, 30
log_verbosity_is_info = \
  logging.get_verbosity() == logging.INFO # 20
log_verbosity_is_debug = \
  logging.get_verbosity() == logging.DEBUG # 10

print( "The statement, 'logging verbosity is CRITICAL' " + \
      f"is {log_verbosity_is_critical}")
print( "The statement, 'logging verbosity is    ERROR' " + \
      f"is {log_verbosity_is_error}")
print( "The statement, 'logging verbosity is  WARNING' " + \
      f"is {log_verbosity_is_warn}")
print( "The statement, 'logging verbosity is     INFO' " + \
      f"is {log_verbosity_is_info}")
print( "The statement, 'logging verbosity is    DEBUG' " + \
      f"is {log_verbosity_is_debug}")

print()

init_log_verbosity = logging.get_verbosity()
print(f"The value of logging.get_verbosity() is: {init_log_verbosity}")

print()

init_t_n_a_w = os.environ.get('TRANSFORMERS_NO_ADVISORY_WARNINGS')
print(f"TRANSFORMERS_NO_ADVISORY_WARNINGS: {init_t_n_a_w}")

### Here's the actual evaluation

In [None]:
!powershell -c (Get-Date -UFormat \"%s_%Y-%m-%dT%H%M%S%Z00\") -replace '[.][0-9]*_', '_'

Output was:

`timestamp`

<b>!!! NOTE !!!</b> I'm going to use `tat` (with an underscore
or underscores before, after, or surrounding the variable names)
to indicate 'testing-after-training'.

In [None]:
#  I'm going to use 'tat' for testing-after-training

logging.set_verbosity_error()

summarizer = pipeline('summarization', model=model, tokenizer=tokenizer)

prediction_tat_list = []
reference_tat_list = []

tic = timeit.default_timer()

for sample_num in range(len(dataset['test'])):
  this_sample = dataset['test'][sample_num]
  
  #print(f"dialogue: \n{this_sample['dialogue']}\n---------------")

  grnd_tat_summary = this_sample['summary']
  res_tat = summarizer(this_sample['dialogue'])
  res_tat_summary = res_tat[0]['summary_text']
  
  #print(f"human-genratd summary:\n{grnd_tat_summary}")
  #print(f"flan-t5-small summary:\n{res_tat_summary}")
  
  reference_tat_list.append(grnd_tat_summary)
  prediction_tat_list.append(res_tat_summary)
##endof:  for sample_num in range(len(dataset['test']))

toc = timeit.default_timer()

print( "Getting things ready for scoring (after training)")
print(f"took {toc - tic:0.4f} seconds.")

print("\n\n---------- ROUGE SCORES ----------")

rouge = load_metric('rouge', trust_remote_code=True)
  #  Set trust_remote_code=False to see the warning,
  #+ deprecation, and what to change to.

results_tat = rouge.compute(
                  predictions=prediction_tat_list,
                  references=reference_tat_list,
                  use_aggregator=True
)

# >>> print(list(results_tat.keys()))
# ['rouge1', 'rouge2', 'rougeL', 'rougeLsum']

print()
print("ROUGE-1 results")
pprint.pp(results_tat['rouge1'])
print()
print("ROUGE-2 results")
pprint.pp(results_tat['rouge2'])
print()
print("ROUGE-L results")
pprint.pp(results_tat['rougeL'])
print()
print("ROUGE-Lsum results")
pprint.pp(results_tat['rougeLsum'])

logging.set_verbosity(init_log_verbosity)
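
To put the two runs side by side, the F-measures from `results` (baseline) and `results_tat` (after training) can be differenced. The sketch assumes each entry can be read as `results['rouge1'].mid.fmeasure`, which is the aggregate shape I'd expect from the `rouge` metric with `use_aggregator=True`; the named tuples and numbers below are stand-ins to make the block runnable, not scores from this notebook, and `fmeasure_deltas` is a hypothetical helper.

```python
from collections import namedtuple

# Stand-ins mimicking the assumed result structure
Score = namedtuple("Score", ["precision", "recall", "fmeasure"])
AggregateScore = namedtuple("AggregateScore", ["low", "mid", "high"])


def fmeasure_deltas(before, after, keys=("rouge1", "rouge2", "rougeL", "rougeLsum")):
    """Return {metric: after.mid.fmeasure - before.mid.fmeasure}."""
    return {k: after[k].mid.fmeasure - before[k].mid.fmeasure for k in keys}


def _fake(f):
    # Placeholder aggregate with the same value in every slot (NOT a real score)
    s = Score(f, f, f)
    return AggregateScore(s, s, s)


keys = ("rouge1", "rouge2", "rougeL", "rougeLsum")
fake_before = {k: _fake(0.25) for k in keys}
fake_after = {k: _fake(0.5) for k in keys}

for name, delta in fmeasure_deltas(fake_before, fake_after).items():
    print(f"{name:10s} change in F-measure: {delta:+.4f}")
```

In the notebook, the call would be `fmeasure_deltas(results, results_tat)`, provided `results` from the baseline cell is still in memory.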