#### Summarization with LLMs

In [1]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the DialogSum dataset from the Hugging Face Hub; the result is a DatasetDict keyed by split name.
data_2 = load_dataset(path="knkarthick/dialogsum")

In [3]:
# Load FLAN-T5 (large) with its fast tokenizer and put the model on the GPU when one is available.
model_checkpoint = "google/flan-t5-large"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
# nn.Module.to() moves the parameters and returns the same module object.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)

In [6]:
from transformers.generation.stopping_criteria import MaxLengthCriteria

def greedy_search(attention_mask, encoder_outputs, max_new_tokens=50, batch_size=1):
    """Greedily decode token ids with the notebook-level seq2seq `model`.

    The decoder runs one token per step. It reuses its key/value cache
    (`past_key_values`), so each step only feeds the newest token.

    Args:
        attention_mask: encoder attention mask from the tokenizer. It is passed
            to the model at every step together with `encoder_outputs`.
        encoder_outputs: output of `model.get_encoder()(...)`. It is computed
            once and reused at every decoding step.
        max_new_tokens: maximum number of tokens to generate after the decoder
            start token. 0 returns just the start token.
        batch_size: number of sequences to decode. It must match the batch
            dimension of `attention_mask` / `encoder_outputs`.

    Returns:
        LongTensor of shape (batch_size, 1 + n_generated) on `device`: the
        decoder start token followed by the greedily chosen tokens. A sequence
        keeps its EOS token. Once it has finished while others are still
        running, it is padded with the pad token. Decoding stops when every
        sequence has emitted EOS or `max_new_tokens` tokens were produced.
    """
    gen_config = model.generation_config

    # generation_config.eos_token_id may be None, a single id, or a list of ids.
    eos_token_ids = gen_config.eos_token_id
    if eos_token_ids is None:
        eos_token_ids = []
    elif isinstance(eos_token_ids, int):
        eos_token_ids = [eos_token_ids]
    eos_token_id_tensor = torch.tensor(list(eos_token_ids), dtype=torch.long, device=device)

    # Read the start token from the model config instead of hard-coding it (T5 uses 0).
    decoder_start_token_id = model.config.decoder_start_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = 0

    # Filler for rows that already finished while other rows keep decoding.
    pad_token_id = gen_config.pad_token_id
    if pad_token_id is None:
        pad_token_id = eos_token_ids[0] if eos_token_ids else decoder_start_token_id

    input_ids = torch.full(
        (batch_size, 1), decoder_start_token_id, dtype=torch.long, device=device
    )
    max_length = max_new_tokens + input_ids.shape[-1]
    # Per-row flag: True while that sequence has not produced EOS yet.
    unfinished = torch.ones(batch_size, dtype=torch.bool, device=device)
    past_key_values = None

    # Inference only: without no_grad every step would keep an autograd graph
    # alive through the cached key/values.
    with torch.no_grad():
        while input_ids.shape[-1] < max_length and bool(unfinished.any()):
            outputs = model(
                # With a KV cache only the newest token has to be fed in.
                decoder_input_ids=input_ids[:, -1:],
                past_key_values=past_key_values,
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                use_cache=True,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_tokens = torch.argmax(outputs.logits[:, -1, :], dim=-1)
            # Rows that already emitted EOS only receive padding from now on.
            next_tokens = torch.where(
                unfinished, next_tokens, torch.full_like(next_tokens, pad_token_id)
            )
            past_key_values = outputs.past_key_values

            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            unfinished &= ~torch.isin(next_tokens, eos_token_id_tensor)

    return input_ids


def do_generate(inputs_from_tokenizer):
    """Encode a tokenized batch once, then greedily decode it.

    Args:
        inputs_from_tokenizer: mapping with "input_ids" and "attention_mask"
            tensors, such as the dict built from `tokenizer(..., return_tensors="pt")`.
            Assumed to be on the same device as `model` (TODO confirm at call sites).

    Returns:
        The token-id tensor produced by `greedy_search`. Decode it with
        `tokenizer.batch_decode(..., skip_special_tokens=True)`.
    """
    input_ids = inputs_from_tokenizer["input_ids"]
    # Take the mask from this function's argument, never from a notebook global,
    # so the function works for any input it is given.
    attention_mask = inputs_from_tokenizer["attention_mask"]

    # Pure inference: skip autograd bookkeeping for both encoder and decoder.
    with torch.no_grad():
        # First run the encoder once; its outputs are reused at every decoding step.
        encoder_outputs = model.get_encoder()(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )

        # Then let the decoder generate the output sequence for every row of the batch.
        result_tokens = greedy_search(
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            batch_size=input_ids.shape[0],
        )

    return result_tokens

In [10]:
# Summarize a few hand-picked training dialogues and print each prompt followed by
# its generated summary (decoded without special tokens).
for i in (0, 1, 2, 45, 50, 60):
    print("Text:")
    dialogue = data_2["train"][i]["dialogue"]
    text = "summarize: \n" + dialogue
    # Tokenize, then move every resulting tensor onto the model's device.
    inputs = {
        name: tensor.to(device)
        for name, tensor in tokenizer(text, return_tensors="pt").items()
    }
    generation = do_generate(inputs)
    gen_text = tokenizer.batch_decode(generation, skip_special_tokens=True)
    print(text)
    print("Generation:")
    # Summary line followed by two blank lines as a separator between samples.
    print(gen_text[0], end="\n\n\n")

Text:
summarize: 
#Person1#: Hi, Mr. Smith. I'm Doctor Hawkins. Why are you here today?
#Person2#: I found it would be a good idea to get a check-up.
#Person1#: Yes, well, you haven't had one for 5 years. You should have one every year.
#Person2#: I know. I figure as long as there is nothing wrong, why go see the doctor?
#Person1#: Well, the best way to avoid serious illnesses is to find out about them early. So try to come at least once a year for your own good.
#Person2#: Ok.
#Person1#: Let me see here. Your eyes and ears look fine. Take a deep breath, please. Do you smoke, Mr. Smith?
#Person2#: Yes.
#Person1#: Smoking is the leading cause of lung cancer and heart disease, you know. You really should quit.
#Person2#: I've tried hundreds of times, but I just can't seem to kick the habit.
#Person1#: Well, we have classes and some medications that might help. I'll give you more information before you leave.
#Person2#: Ok, thanks doctor.
Generation:
Person2 is here for a check-up. He has