In [None]:
from transformers import T5ForConditionalGeneration, AutoTokenizer
import torch
from tqdm import tqdm

In [None]:
# Load the pretrained model by local path or by Hugging Face Hub model name.
MODEL_NAME = 'ndtran/t5-small_cnn-daily-mails'
# NOTE(review): the tokenizer is taken from the base `t5-small` repo, not from MODEL_NAME;
# confirm the checkpoint was trained with the same vocabulary.
TOKENIZER_NAME = 't5-small'

# `summarize` reads `configs['task_prefix']`, so it must exist before that cell runs.
# "summarize: " is the standard T5 summarization prefix — TODO confirm it matches the
# prefix used when MODEL_NAME was fine-tuned.
configs = {'task_prefix': 'summarize: '}

model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
model.eval()  # inference only: disables dropout

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [None]:
! pip install rouge-score, sentencepiece, spacy

In [None]:
import string, re, spacy

# spaCy pipeline used by the summarization helpers to split text into sentences (`.sents`).
SPACY_MODEL_NAME = "en_core_web_sm"
nlp = spacy.load(SPACY_MODEL_NAME)

In [None]:
def summarize(text, max_length=256):
    """Summarize `text` with the loaded T5 model and tidy up the output sentences.

    Reads the notebook-level `model`, `tokenizer`, `device`, `configs` and `nlp`.

    Args:
        text: raw input document. It is not truncated here; for long documents
            prefer `multipart_summarize`.
        max_length: maximum length, in tokens, of the generated summary.

    Returns:
        The summary as a single string. spaCy splits it into sentences; each
        sentence has leading punctuation/spaces and trailing whitespace stripped
        and its first character upper-cased. Sentences that are empty after
        stripping are dropped, and the rest are joined with single spaces.
    """
    input_ids = tokenizer(configs['task_prefix'] + text, return_tensors='pt').input_ids

    # With top_k=1 only the most likely token can be sampled, so this is effectively
    # greedy decoding; `temperature` does not change which token wins.
    generated_ids = model.generate(
        input_ids.to(device),
        do_sample=True,
        max_length=max_length,
        top_k=1,
        temperature=0.8,
    )

    # Batch of one: take the single generated sequence.
    doc = nlp(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

    sentences = []
    for sent in doc.sents:
        cleaned = str(sent).lstrip(string.punctuation + ' ').rstrip()
        # Skip fragments such as a lone "." that would otherwise leave double spaces.
        if cleaned:
            sentences.append(cleaned[0].upper() + cleaned[1:])

    return " ".join(sentences)

In [None]:
def _pack_sentences(sentences, count_tokens, max_tokens):
    """Greedily group consecutive sentences into blocks of at most `max_tokens` tokens.

    Sentences inside a block are joined with a single space. A sentence that on its
    own exceeds `max_tokens` becomes a block by itself and is never duplicated.
    """
    blocks, current, current_tokens = [], [], 0
    for sentence in sentences:
        n_tokens = count_tokens(sentence)
        # Flush *before* adding, so the sentence that would overflow the limit
        # starts the next block instead of inflating the current one.
        if current and current_tokens + n_tokens > max_tokens:
            blocks.append(" ".join(current))
            current, current_tokens = [], 0
        current.append(sentence)
        current_tokens += n_tokens
    if current:
        blocks.append(" ".join(current))
    return blocks


def multipart_summarize(text, max_tokens=512):
    """Summarize a long document chunk by chunk and join the partial summaries.

    spaCy (`nlp`) splits the text into sentences. The sentences are then packed into
    blocks of at most `max_tokens` tokenizer tokens, and each block is passed to
    `summarize`. The partial summaries are joined with single spaces. Empty input
    yields an empty string.

    Note: the task prefix and special tokens added inside `summarize` are not
    counted, so an encoded block can be slightly longer than `max_tokens`.
    """
    sentences = [str(sent).strip() for sent in nlp(text).sents]
    sentences = [s for s in sentences if s]  # drop whitespace-only "sentences"

    blocks = _pack_sentences(
        sentences, lambda s: len(tokenizer.tokenize(s)), max_tokens)

    return " ".join(summarize(block) for block in blocks)

In [None]:
## How to use the above functions

# Read explicitly as UTF-8 instead of relying on the platform's default encoding.
with open('very-long-document.txt', 'r', encoding='utf-8') as fp:
    text = fp.read()

# `summarize` feeds the whole document to the model in one pass (no truncation), while
# `multipart_summarize` splits it into ~512-token chunks and summarizes each chunk.
print(summarize(text))
# or
print(multipart_summarize(text))  # recommended for long documents

In [None]:
import json
from torch.utils.data import Dataset, DataLoader

TEST_DATA_PATH = '/kaggle/input/t5-base-tokens-cnn-daily/test_ds_encoded.json'

# Pre-tokenized test split. Each record is expected to carry 'input_ids',
# 'attention_mask' and 'labels' (the keys CNNDaily reads).
with open(TEST_DATA_PATH, 'r', encoding='utf-8') as fp:
    test_list = json.load(fp)

In [None]:
class CNNDaily(Dataset):
    """Map-style dataset over pre-tokenized CNN/DailyMail records.

    Each element is expected to be a mapping with 'input_ids', 'attention_mask' and
    'labels' integer lists. Batching with the default DataLoader collate also needs
    each field to have the same length across records (presumably padded upstream
    — TODO confirm).
    """

    def __init__(self, elements):
        # elements: indexable sequence of record dicts, e.g. the list loaded from JSON.
        self.elements = elements

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, index):
        """Return (input_ids, attention_mask, labels) as int64 tensors.

        Raises:
            IndexError: `index` is out of range. Propagated unchanged so that plain
                iteration over the dataset terminates.
            ValueError: the record lacks a field or cannot be converted to a LongTensor.
        """
        record = self.elements[index]
        try:
            return (
                torch.LongTensor(record['input_ids']),
                torch.LongTensor(record['attention_mask']),
                torch.LongTensor(record['labels']),
            )
        except (KeyError, TypeError, ValueError) as err:
            # Fail loudly with the offending index. Returning None would only surface
            # later as an opaque collate error inside the DataLoader.
            raise ValueError(f'Malformed record at index {index}: {err!r}') from err

In [None]:
EVAL_BATCH_SIZE = 128

# shuffle=False keeps batches in dataset order, so predictions follow the order of test_list.
test_ds = CNNDaily(test_list)
test_loader = DataLoader(test_ds, shuffle=False, batch_size=EVAL_BATCH_SIZE)

In [None]:
# to get raw text from tokens
def unpack(inputs_ids, labels_ids, outputs_ids):
    """Decode a batch of inputs, reference labels and generated outputs to text.

    Args:
        inputs_ids: batch of encoder input token ids.
        labels_ids: batch of reference-summary token ids. Entries equal to -100 are
            replaced by the pad id before decoding; the caller's tensor is NOT modified.
        outputs_ids: batch of generated token ids.

    Returns:
        A list with one dict per example, with keys 'input', 'label' and 'output'.
        Its length is the smallest of the three batch sizes.
    """
    inputs = tokenizer.batch_decode(inputs_ids, skip_special_tokens=True)

    # -100 is not a valid token id. Map it to the pad id (0 if the tokenizer has none)
    # on a copy, so skip_special_tokens can drop it without mutating the caller's labels.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    labels = tokenizer.batch_decode(
        labels_ids.masked_fill(labels_ids == -100, pad_id), skip_special_tokens=True)

    outputs = tokenizer.batch_decode(outputs_ids, skip_special_tokens=True)

    # zip stops at the shortest list, matching the original min(...) over batch sizes.
    return [
        {'input': inp, 'label': lab, 'output': out}
        for inp, lab, out in zip(inputs, labels, outputs)
    ]

In [None]:
# Accumulates one {'input', 'label', 'output'} dict per test example (built by unpack).
results = list()

In [None]:
# Rebuilt from scratch so that re-running this cell does not duplicate predictions.
results = []

for input_ids, attention_mask, labels in tqdm(
    test_loader, total=len(test_loader), unit='Batch', desc='Generating'
):
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Pass the attention mask explicitly so padded positions are ignored by the encoder.
    # top_k=1 makes sampling effectively greedy (temperature cannot change the argmax).
    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        do_sample=True,
        max_length=128,
        top_k=1,
        temperature=0.7,
    )

    # Labels are only decoded, so they can stay on the CPU.
    results += unpack(input_ids, labels, generated_ids)

In [None]:
PREDICTIONS_PATH = 'predictions.json'

# Persist the decoded predictions for later evaluation or analysis.
with open(PREDICTIONS_PATH, 'w') as predictions_file:
    json.dump(results, predictions_file)

#### Last step: save the predictions

- Save all the results to a file so they can easily be reused later for evaluation or analysis.
- Our sample output for the pre-trained model: [Link](https://www.kaggle.com/code/ndtran/t5-small-inference/notebook?scriptVersionId=135131238)
- Results for the original model: [Link](https://www.kaggle.com/ndtran/t5-small-inference?scriptVersionId=135131482)