In [40]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

In [41]:
# !pip install sentencepiece

# Option 1: Using T5 Model

In [42]:
#initialize the model and tokenizer
model_name = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(model_name)
config = T5Config.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name, config=config)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


T5 models expect the input text to be prefixed with a task-specific prompt, such as "summarize: " for summarization tasks. This helps the model understand what kind of output is expected.

In [43]:
text="""We all know that OpenAI actually started the trend of AI tools after releasing ChatGPT.

After that, we saw everyone shift to AI. Developers and big companies began building AI tools, and even individuals started learning about artificial intelligence.

Thanks to that, we have tons of popular AI tools like RunwayML, Lovable, Claude, Gemini, Perplexity, Cursor, Stitch, NotebookLM, Leonardo AI, Framer AI, and the list goes on.

That's not all. Every day, tons of new AI tools are launched, which makes it difficult for people to find the best ones for their needs.

That's why I spend a lot of time testing some of the best new AI tools and write a couple of posts every month to share the ones that truly stand out."""

In [44]:
import re
import html
# modified with GPT for best practices
def clean_text(text):
    text = html.unescape(text)
    text = text.lower()
    text = re.sub(r'<.*?>', ' ', text)             # remove HTML
    text = re.sub(r'http\S+|www\.\S+', ' ', text)  # remove URLs
    text = re.sub(r'\S+@\S+', ' ', text)           # remove emails
    text = re.sub(r"(.)\1{2,}", r"\1\1", text)     # soooo -> soo

    # keep ! and ? (sentiment)
    text = re.sub(r'[^a-z!? ]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text

In [45]:
cleaned_text=clean_text(text).split()
len(cleaned_text)

127

In [46]:
tokenized_text=tokenizer.encode("summarize: " + clean_text(text), return_tensors="pt", max_length=512, truncation=True)
print(tokenized_text)

tensor([[21603,    10,    62,    66,   214,    24,   539,     9,    23,   700,
           708,     8,  4166,    13,     3,     9,    23,  1339,   227,     3,
         16306,  3582,   122,   102,    17,   227,    24,    62,  1509,   921,
          4108,    12,     3,     9,    23,  5564,    11,   600,   688,  1553,
           740,     3,     9,    23,  1339,    11,   237,  1742,   708,  1036,
            81,  7353,  6123,  2049,    12,    24,    62,    43,  8760,    13,
          1012,     3,     9,    23,  1339,   114, 22750,    51,    40,     3,
          5850,   179,     3,    75, 12513,    15,   873,  7619,   399,  9247,
           485,  8385,   127, 12261, 16638,    40,    51,    90,   106,   986,
            32,     3,     9,    23,  2835,    52,     3,     9,    23,    11,
             8,   570,  1550,    30,    24,     3,     7,    59,    66,   334,
           239,  8760,    13,   126,     3,     9,    23,  1339,    33,  3759,
            84,   656,    34,  1256,    21,   151,  

In [47]:
summary_ids=model.generate(tokenized_text, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
summary=tokenizer.decode(summary_ids[0], skip_special_tokens=True)

print("Summary:")
print(summary)
print(len(summary.split()))

Summary:
openai has tons of popular ai tools like runwayml lovable claude gemini perplexity cursor stitch notebooklm leonardo ai framer ai. the list goes on that s not all every day tons of new ai tools are launched which makes it difficult for people to find the best ones for their needs.
51


# Option 2 : Using Pipeline

In [48]:
from transformers import pipeline

summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
summary = summarizer(text, max_length=130, min_length=30)

print(summary)




Device set to use cpu


[{'summary_text': "Every day, tons of new AI tools are launched, which makes it difficult for people to find the best ones for their needs. That's why I spend a lot of time testing some of the best newAI tools and write a couple of posts every month to share the ones that truly stand out."}]


# Option 3: LSTM/GRU/RNN

In [49]:
import os
from glob import glob
from pathlib import Path

news_files = Path("E:\\70 Days 70 Project\\Text Summarization\\data\\BBC News Summary\\News Articles")
summaries_files = Path("E:\\70 Days 70 Project\\Text Summarization\\data\\BBC News Summary\\Summaries")

news_categories=os.listdir(news_files)
summary_categories=os.listdir(summaries_files)
news_categories,summary_categories

(['business', 'entertainment', 'politics', 'sport', 'tech'],
 ['business', 'entertainment', 'politics', 'sport', 'tech'])

In [50]:
for files in os.listdir(Path(news_files/news_categories[0])):
    print(files)

001.txt
002.txt
003.txt
004.txt
005.txt
006.txt
007.txt
008.txt
009.txt
010.txt
011.txt
012.txt
013.txt
014.txt
015.txt
016.txt
017.txt
018.txt
019.txt
020.txt
021.txt
022.txt
023.txt
024.txt
025.txt
026.txt
027.txt
028.txt
029.txt
030.txt
031.txt
032.txt
033.txt
034.txt
035.txt
036.txt
037.txt
038.txt
039.txt
040.txt
041.txt
042.txt
043.txt
044.txt
045.txt
046.txt
047.txt
048.txt
049.txt
050.txt
051.txt
052.txt
053.txt
054.txt
055.txt
056.txt
057.txt
058.txt
059.txt
060.txt
061.txt
062.txt
063.txt
064.txt
065.txt
066.txt
067.txt
068.txt
069.txt
070.txt
071.txt
072.txt
073.txt
074.txt
075.txt
076.txt
077.txt
078.txt
079.txt
080.txt
081.txt
082.txt
083.txt
084.txt
085.txt
086.txt
087.txt
088.txt
089.txt
090.txt
091.txt
092.txt
093.txt
094.txt
095.txt
096.txt
097.txt
098.txt
099.txt
100.txt
101.txt
102.txt
103.txt
104.txt
105.txt
106.txt
107.txt
108.txt
109.txt
110.txt
111.txt
112.txt
113.txt
114.txt
115.txt
116.txt
117.txt
118.txt
119.txt
120.txt
121.txt
122.txt
123.txt
124.txt
125.txt


Ref: https://www.kaggle.com/code/mallaavinash/text-summarization

In [51]:
contraction_mapping = {"ain't": "is not", "aren't": "are not","can't": "cannot", "'cause": "because", "could've": "could have", "couldn't": "could not",

                           "didn't": "did not", "doesn't": "does not", "don't": "do not", "hadn't": "had not", "hasn't": "has not", "haven't": "have not",

                           "he'd": "he would","he'll": "he will", "he's": "he is", "how'd": "how did", "how'd'y": "how do you", "how'll": "how will", "how's": "how is",

                           "I'd": "I would", "I'd've": "I would have", "I'll": "I will", "I'll've": "I will have","I'm": "I am", "I've": "I have", "i'd": "i would",

                           "i'd've": "i would have", "i'll": "i will",  "i'll've": "i will have","i'm": "i am", "i've": "i have", "isn't": "is not", "it'd": "it would",

                           "it'd've": "it would have", "it'll": "it will", "it'll've": "it will have","it's": "it is", "let's": "let us", "ma'am": "madam",

                           "mayn't": "may not", "might've": "might have","mightn't": "might not","mightn't've": "might not have", "must've": "must have",
                            "mustn't": "must not", "mustn't've": "must not have", "needn't": "need not", "needn't've": "need not have","o'clock": "of the clock",

                           "oughtn't": "ought not", "oughtn't've": "ought not have", "shan't": "shall not", "sha'n't": "shall not", "shan't've": "shall not have",

                           "she'd": "she would", "she'd've": "she would have", "she'll": "she will", "she'll've": "she will have", "she's": "she is",

                           "should've": "should have", "shouldn't": "should not", "shouldn't've": "should not have", "so've": "so have","so's": "so as",

                           "this's": "this is","that'd": "that would", "that'd've": "that would have", "that's": "that is", "there'd": "there would",

                           "there'd've": "there would have", "there's": "there is", "here's": "here is","they'd": "they would", "they'd've": "they would have",
                           "they'll": "they will", "they'll've": "they will have", "they're": "they are", "they've": "they have", "to've": "to have",

                           "wasn't": "was not", "we'd": "we would", "we'd've": "we would have", "we'll": "we will", "we'll've": "we will have", "we're": "we are",

                           "we've": "we have", "weren't": "were not", "what'll": "what will", "what'll've": "what will have", "what're": "what are",

                           "what's": "what is", "what've": "what have", "when's": "when is", "when've": "when have", "where'd": "where did", "where's": "where is",

                           "where've": "where have", "who'll": "who will", "who'll've": "who will have", "who's": "who is", "who've": "who have",

                           "why's": "why is", "why've": "why have", "will've": "will have", "won't": "will not", "won't've": "will not have",

                           "would've": "would have", "wouldn't": "would not", "wouldn't've": "would not have", "y'all": "you all",
                           "y'all'd": "you all would","y'all'd've": "you all would have","y'all're": "you all are","y'all've": "you all have",

                           "you'd": "you would", "you'd've": "you would have", "you'll": "you will", "you'll've": "you will have",

                           "you're": "you are", "you've": "you have"}


In [52]:
import re
import html
# modified with GPT for best practices
def clean_text(text):
    text = html.unescape(text)
    text = text.lower()
    text=' '.join([contraction_mapping[i] if i in contraction_mapping.keys() else i for i in text.split()])
    text = re.sub(r'<.*?>', ' ', text)             # remove HTML
    text = re.sub(r'http\S+|www\.\S+', ' ', text)  # remove URLs
    text = re.sub(r'\S+@\S+', ' ', text)           # remove emails
    text = re.sub(r"(.)\1{2,}", r"\1\1", text)     # soooo -> soo

    # keep ! and ? (sentiment)
    text = re.sub(r'[^a-z!? ]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


#### Run once

In [53]:
# import pandas as pd
# from tqdm import tqdm

# dataframe={'news':[], 'summary':[]}

# for category in news_categories:
#     for file in tqdm(os.listdir(Path(news_files/category)),desc=f"News Category: {category}"):
#         with open(Path(news_files/category/file), 'r', encoding='utf-8', errors='replace') as news_file:
#             news_content=news_file.read()
#             dataframe['news'].append(news_content)
#     for file in tqdm(os.listdir(Path(summaries_files/category)),desc=f"Summary Category: {category}"):
#         with open(Path(summaries_files/category/file), 'r', encoding='utf-8', errors='replace') as summary_file:
#             summary_content=summary_file.read()
#             dataframe['summary'].append(summary_content)
# df=pd.DataFrame(dataframe)
# df.head()

In [54]:
# df.shape

In [55]:
# df.to_csv("bbc_news_summary_dataset.csv", index=False)

In [56]:
import pandas as pd

df=pd.read_csv("bbc_news_summary_dataset.csv")
df.head()

Unnamed: 0,news,summary
0,Ad sales boost Time Warner profit\n\nQuarterly...,TimeWarner said fourth quarter sales rose 2% t...
1,Dollar gains on Greenspan speech\n\nThe dollar...,The dollar has hit its highest level against t...
2,Yukos unit buyer faces loan claim\n\nThe owner...,Yukos' owner Menatep Group says it will ask Ro...
3,High fuel prices hit BA's profits\n\nBritish A...,"Rod Eddington, BA's chief executive, said the ..."
4,Pernod takeover talk lifts Domecq\n\nShares in...,Pernod has reduced the debt it took on to fund...


In [57]:
df["news"]=df["news"].apply(lambda x: clean_text(x))
df["summary"]=df["summary"].apply(lambda x: clean_text(x))
df.head()

Unnamed: 0,news,summary
0,ad sales boost time warner profit quarterly pr...,timewarner said fourth quarter sales rose to b...
1,dollar gains on greenspan speech the dollar ha...,the dollar has hit its highest level against t...
2,yukos unit buyer faces loan claim the owners o...,yukos owner menatep group says it will ask ros...
3,high fuel prices hit ba s profits british airw...,rod eddington ba s chief executive said the re...
4,pernod takeover talk lifts domecq shares in uk...,pernod has reduced the debt it took on to fund...


In [58]:
df["summary"]='<sos> '+df["summary"]+' <eos>'
df.head()

Unnamed: 0,news,summary
0,ad sales boost time warner profit quarterly pr...,<sos> timewarner said fourth quarter sales ros...
1,dollar gains on greenspan speech the dollar ha...,<sos> the dollar has hit its highest level aga...
2,yukos unit buyer faces loan claim the owners o...,<sos> yukos owner menatep group says it will a...
3,high fuel prices hit ba s profits british airw...,<sos> rod eddington ba s chief executive said ...
4,pernod takeover talk lifts domecq shares in uk...,<sos> pernod has reduced the debt it took on t...


In [59]:
PAD = "<pad>"
SOS = "<sos>"
EOS = "<eos>"
UNK = "<unk>"

Ref: GPT

In [60]:
from collections import Counter

class Vocab:
    def __init__(self, texts, max_size=30000, min_freq=2):
        counter = Counter()
        for text in texts:
            counter.update(text.split())

        self.itos = [PAD, SOS, EOS, UNK]
        for word, freq in counter.most_common():
            if freq >= min_freq and len(self.itos) < max_size:
                self.itos.append(word)

        self.stoi = {word: idx for idx, word in enumerate(self.itos)}

    def encode(self, text):
        return [self.stoi.get(w, self.stoi[UNK]) for w in text.split()]

    def __len__(self):
        return len(self.itos)


In [61]:
article_vocab = Vocab(df["news"], max_size=30000)
summary_vocab = Vocab(df["summary"], max_size=15000)
print(f"Article Vocab Size: {len(article_vocab)}")
print(f"Summary Vocab Size: {len(summary_vocab)}")

Article Vocab Size: 18730
Summary Vocab Size: 12388


In [62]:
from torch.utils.data import Dataset,DataLoader
import torch
class SummaryDataset(Dataset):
    def __init__(self,df,article_vocab,summary_vocab,max_article_len=512,max_summary_len=50):
        super().__init__()
        self.df=df
        self.article_vocab=article_vocab
        self.summary_vocab=summary_vocab
        self.max_article_len=max_article_len
        self.max_summary_len=max_summary_len
        
    def __len__(self):
        return len(self.df)
    
    def pad(self, sequence, max_length, pad_idx):
        return sequence[:max_length] + [pad_idx] * (max_length - len(sequence))
    
    def __getitem__(self, idx):
        article=self.df.iloc[idx]["news"]
        summary=self.df.iloc[idx]["summary"]
        
        enc=self.article_vocab.encode(article)
        dec=self.summary_vocab.encode(summary)
        
        enc=self.pad(enc,self.max_article_len,self.article_vocab.stoi[PAD])
        dec=self.pad(dec,self.max_summary_len,self.summary_vocab.stoi[PAD])
        
        return torch.tensor(enc), torch.tensor(dec)

In [63]:
dataset=SummaryDataset(df,article_vocab,summary_vocab)
loader=DataLoader(dataset,batch_size=2,shuffle=True)

In [64]:
next(iter(dataset))

(tensor([ 4052,   187,   700,    71,  3353,  1026,  3787,   623,    25,    51,
           299,   694,  7877,  2857,     5,    89,    57,    11,     4,   104,
           197,     5,   324,    31,    57,    44,   349,     4,   146,    40,
            10,    76,    53,     6,     4,   326,   846,     9,   923,  7878,
            31,   187,     6,   161,   803,   315,  1846,     7,   586,  5340,
           187,  7877,    16,   627,   394,   187,   636,     5,    89,    31,
            89,    45,   623,    47,  9466,    28,    53,   130,  2726,    40,
          3921,     8,  1026,  3682,    25,  3353,  6344,     7,   353,   226,
            11,  3020,    71,  3353,    16,    15,   543,    13,    14,    76,
          2265,     6,   482,  1881,   923,    29,    45,   202,   315,   239,
          3020,    39,    23,  2858,  4053,    14,   321,  2786,     9,     4,
           627,   394,   623,    47,   917,    60,     9,     4, 10595,   104,
          2787,   179,     4,   125,    16,  3020,  

In [65]:
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, x):
        emb = self.embedding(x)
        outputs, (h, c) = self.lstm(emb)
        return h, c

In [66]:
class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, h, c):
        x = x.unsqueeze(1) # Add sequence dimension
        emb = self.embedding(x)
        output, (h, c) = self.lstm(emb, (h, c))
        pred = self.fc(output.squeeze(1))
        return pred, h, c


`x = x.unsqueeze(1)`

Why?

PyTorch LSTM expects:
- (batch, seq_len, features)

But x is:
- (batch,)

So we add a fake time dimension:
- (batch, 1)

In [67]:
class SeqtoSeqModel(nn.Module):
    def __init__(self, encoder_model ,decoder_model , computation_device):
        super(SeqtoSeqModel, self).__init__()
        
        self.encoder_model  = encoder_model
        self.decoder_model = decoder_model
        self.device = computation_device
        
    def forward(self, enc_in_seq, dec_in_seq, teacher_forcing_ratio=0.5):
        batch_size, decoder_sequence_length = dec_in_seq.shape
        decoder_vocabulary_size = self.decoder_model.fc.out_features
        
        all_decoder_outputs = torch.zeros(batch_size, decoder_sequence_length, decoder_vocabulary_size, device=self.device)
        
        encoder_hidden, encoder_cell = self.encoder_model(enc_in_seq)
        current_decoder_input_token = dec_in_seq[:, 0]  # the <sos> token
        
        for time_stamp in range(decoder_sequence_length):
            predicted_token_logits, encoder_hidden, encoder_cell = self.decoder_model(current_decoder_input_token, encoder_hidden, encoder_cell)
            all_decoder_outputs[:, time_stamp, :] = predicted_token_logits
            use_teacher_forcing = (torch.rand(1).item() < teacher_forcing_ratio) and (time_stamp + 1 < decoder_sequence_length)
            if use_teacher_forcing:
                current_decoder_input_token = dec_in_seq[:, time_stamp + 1]
            else:
                current_decoder_input_token = predicted_token_logits.argmax(dim=1)
        return all_decoder_outputs
        

In [68]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

embed_dim = 256
hidden_dim = 512

encoder = Encoder(len(article_vocab), embed_dim, hidden_dim)
decoder = Decoder(len(summary_vocab), embed_dim, hidden_dim)

sequence_to_sequence_model = SeqtoSeqModel(encoder, decoder, device).to(device)

In [69]:
pad_idx = summary_vocab.stoi[PAD]
loss_function = nn.CrossEntropyLoss(ignore_index=pad_idx)
optimizer = torch.optim.Adam(
    sequence_to_sequence_model.parameters(),
    lr=0.001
)

In [None]:
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

num_epochs = 20
teacher_forcing_ratio = 0.6

sequence_to_sequence_model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for enc_batch, dec_batch in tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        enc_batch = enc_batch.to(device)
        dec_batch = dec_batch.to(device)

        dec_input = dec_batch[:, :-1]
        targets = dec_batch[:, 1:]

        optimizer.zero_grad()
        outputs = sequence_to_sequence_model(enc_batch, dec_input, teacher_forcing_ratio=teacher_forcing_ratio)
        outputs = outputs.reshape(-1, outputs.size(-1))
        targets = targets.reshape(-1)

        loss = loss_function(outputs, targets)
        loss.backward()
        clip_grad_norm_(sequence_to_sequence_model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / len(loader)
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {avg_loss:.4f}")

In [72]:
def summarize_text(model, article, article_vocab, summary_vocab, device, max_summary_len=50):
    model.eval()
    with torch.no_grad():
        enc_input = torch.tensor(article_vocab.encode(article)).unsqueeze(0).to(device)
        encoder_hidden, encoder_cell = model.encoder_model(enc_input)

        current_decoder_input_token = torch.tensor([summary_vocab.stoi[SOS]]).to(device)
        summary_tokens = []

        for _ in range(max_summary_len):
            predicted_token_logits, encoder_hidden, encoder_cell = model.decoder_model(current_decoder_input_token, encoder_hidden, encoder_cell)
            predicted_token_id = predicted_token_logits.argmax(dim=1).item()
            if predicted_token_id == summary_vocab.stoi[EOS]:
                break
            summary_tokens.append(predicted_token_id)
            current_decoder_input_token = torch.tensor([predicted_token_id]).to(device)

        summary_words = [summary_vocab.itos[token_id] for token_id in summary_tokens]
        return ' '.join(summary_words)

def summarize_text_beam_search(model, article, article_vocab, summary_vocab, device, max_summary_len=50, beam_width=3):
    """Beam search with repetition penalty and length normalization"""
    model.eval()
    with torch.no_grad():
        enc_input = torch.tensor(article_vocab.encode(article)).unsqueeze(0).to(device)
        encoder_hidden, encoder_cell = model.encoder_model(enc_input)
        
        # Initialize beams: (log_prob, tokens, hidden, cell, is_finished)
        beams = [(0.0, [summary_vocab.stoi[SOS]], encoder_hidden, encoder_cell, False)]
        finished_beams = []
        
        for step in range(max_summary_len):
            new_beams = []
            
            for log_prob, tokens, h, c, is_finished in beams:
                if is_finished:
                    finished_beams.append((log_prob, tokens))
                    continue
                    
                current_token = torch.tensor([tokens[-1]]).to(device)
                logits, h_new, c_new = model.decoder_model(current_token, h, c)
                log_probs = torch.log_softmax(logits, dim=1)[0]
                
                # Repetition penalty: penalize tokens that appear frequently in the sequence
                for token_id in set(tokens[1:]):  # Skip SOS token
                    count = tokens[1:].count(token_id)
                    if count > 2:
                        log_probs[token_id] -= 1.0 * (count - 1)  # Penalty increases with frequency
                
                # Block PAD token
                log_probs[summary_vocab.stoi[PAD]] = float('-inf')
                
                # Get top-k candidates
                top_k = torch.topk(log_probs, min(beam_width, log_probs.numel()), largest=True)
                
                for candidate_log_prob, candidate_id in zip(top_k.values, top_k.indices):
                    if torch.isinf(candidate_log_prob):
                        continue
                    new_log_prob = log_prob + candidate_log_prob.item()
                    new_tokens = tokens + [candidate_id.item()]
                    is_eos = (candidate_id.item() == summary_vocab.stoi[EOS])
                    
                    new_beams.append((new_log_prob, new_tokens, h_new, c_new, is_eos))
            
            # Sort by normalized log probability (length penalty)
            new_beams.sort(reverse=True, key=lambda x: x[0] / max(len(x[1]), 1) ** 0.7)
            beams = new_beams[:beam_width]
            
            # Separate finished and unfinished beams
            finished_beams.extend([b for b in beams if b[4]])
            beams = [b for b in beams if not b[4]]
            
            if not beams:
                break
        
        # Combine and sort all beams
        all_beams = finished_beams + beams
        all_beams.sort(reverse=True, key=lambda x: x[0] / max(len(x[1]), 1) ** 0.7)
        
        # Return best beam
        if all_beams:
            best_tokens = all_beams[0][1][1:]  # Remove SOS token
            summary_words = [summary_vocab.itos[token_id] for token_id in best_tokens if token_id != summary_vocab.stoi[EOS]]
            return ' '.join(summary_words)
        return ""

In [None]:
print("Greedy Decoding:")
print(summarize_text(sequence_to_sequence_model, df.iloc[0]["news"], article_vocab, summary_vocab, device))

print("\nBeam Search Decoding (beam_width=3):")
print(summarize_text_beam_search(sequence_to_sequence_model, df.iloc[0]["news"], article_vocab, summary_vocab, device, beam_width=3))