In [1]:
from datasets import load_dataset
from transformers import (AutoModelForSeq2SeqLM, 
                          AutoTokenizer, 
                          GenerationConfig, 
                          TrainingArguments, 
                          Trainer)
import torch
import time
import os
import evaluate
import pandas as pd
import numpy as np
from math import ceil

2023-09-05 19:05:32.259120: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-09-05 19:05:33.735484: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory
2023-09-05 19:05:33.735611: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory


In [2]:
# import transformers
# transformers.__version__ == '4.28.1'

# import peft
# peft.__version__ == '0.3.0'

# import torch
# torch.__version__ == "2.0.0+cu117"

In [3]:
# Expose only physical GPU 2 to this process; it then shows up as the single
# visible device. This only takes effect if CUDA has not been initialised yet
# in this kernel, so keep it before any model is moved to the GPU.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})
torch.cuda.device_count()

1

In [4]:
from peft import PeftModel


class PeftModelUtils:
    """Static helpers for loading, merging and saving seq2seq models with PEFT adapters."""

    @staticmethod
    def load_base_model(model_path="google/flan-t5-base"):
        """Load a seq2seq base model in bfloat16 (``device_map="auto"``) and its tokenizer.

        Args:
            model_path: Hugging Face hub id or local directory of the base model.

        Returns:
            ``(model, tokenizer)``.
        """
        model = AutoModelForSeq2SeqLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        return model, tokenizer

    @staticmethod
    def load_from_peft_adapter(
        base_model_path, peft_model_path, train=False, merge_adapter=True
    ):
        """Load a fresh base model and attach the PEFT adapter at ``peft_model_path``.

        Note: every call loads a new copy of the base model via ``load_base_model``.

        Args:
            base_model_path: hub id / directory of the base model.
            peft_model_path: location of the saved adapter.
            train: passed to PEFT as ``is_trainable``; when combined with
                ``merge_adapter`` all parameters of the merged model are unfrozen.
            merge_adapter: if True, merge the adapter into the base weights and
                return the unwrapped model (``merge_and_unload``).

        Returns:
            ``(model, tokenizer)``; ``model`` is the PEFT-wrapped model unless
            ``merge_adapter`` is True.
        """
        model, tokenizer = PeftModelUtils.load_base_model(base_model_path)
        model = PeftModel.from_pretrained(
            model,
            peft_model_path,
            torch_dtype=torch.bfloat16,
            is_trainable=train,
            device_map="auto",
        )

        if merge_adapter:
            # Merge the adapter into the main model and drop the PEFT wrapper.
            model = model.merge_and_unload()

            if train:
                # The merged result is a plain model; re-enable gradients on every
                # parameter so it can be fine-tuned in full.
                for param in model.parameters():
                    param.requires_grad = True

        return model, tokenizer

    @staticmethod
    def save_peft_adapter(model, model_path):
        """Save ``model`` to ``model_path`` with ``save_pretrained``."""
        model.save_pretrained(model_path)

    @staticmethod
    def merge_peft_and_save(model, model_path):
        """Merge the adapter into the base model, save the merged model, and return it.

        Args:
            model: PEFT-wrapped model exposing ``merge_and_unload``.
            model_path: output directory for ``save_pretrained``.

        Returns:
            The merged (unwrapped) model.
        """
        merged_model = model.merge_and_unload()
        merged_model.save_pretrained(model_path)
        return merged_model

    @staticmethod
    def save_tokenizer(tokenizer, model_path):
        """Save ``tokenizer`` to ``model_path`` with ``save_pretrained``.

        ``model_path`` is an explicit argument: the previous version read an
        undefined name and always raised NameError.
        """
        tokenizer.save_pretrained(model_path)


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 113
CUDA SETUP: Loading binary /home/qblocks/.local/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda113.so...


  warn(msg)


# Load the pretrained FLAN-T5 base model (optional; the cell below is commented out)

In [5]:
# # load original model
# model_name='google/flan-t5-base'
# original_model, original_tokenizer = PeftModelUtils.load_base_model(model_path=model_name)

# load summary adapter

In [6]:
# Summary model: FLAN-T5 base with the summary adapter attached (left unmerged).
model_name = 'google/flan-t5-base'
adapter_path = "./checkpoint/adapter/"
peft_model, tokenizer = PeftModelUtils.load_from_peft_adapter(
    base_model_path=model_name,
    peft_model_path=adapter_path,
    merge_adapter=False,
)

# load title adapter

In [7]:
# Title model: FLAN-T5 base with the title adapter attached (left unmerged).
# NOTE(review): load_from_peft_adapter loads a fresh base model on every call, so
# this keeps a second full copy of the base weights in memory; attaching both
# adapters to one PEFT model would avoid that.
model_name = 'google/flan-t5-base'
adapter_path = "./checkpoint/title_adapter/"
title_model, tokenizer = PeftModelUtils.load_from_peft_adapter(
    base_model_path=model_name,
    peft_model_path=adapter_path,
    merge_adapter=False,
)

# Inference: scrape live news articles, then generate summaries and titles

In [8]:
# bring live news
from newspaper import Article
import re
import nltk
from datetime import datetime


# NOTE(review): punkt is presumably only needed by newspaper's Article.nlp();
# this download can be dropped if nothing calls nlp().
nltk.download('punkt')

# Stop phrases (paywall / sign-in prompts): scraped text is cut at the first one.
phrases_to_remove = [
    "Sign In",
    "Want to read more?",
    "Already have an account?",
    "To continue reading",
]

def remove_phrases(string, phrases):
    """Truncate ``string`` at the first occurrence of any phrase in ``phrases``.

    Phrases are matched literally (regex metacharacters are escaped). Empty
    phrases are ignored. Previously an empty list (or an empty phrase) produced
    the pattern ``''``, which matches at position 0 and erased the whole text.

    Args:
        string: text to truncate.
        phrases: iterable of literal stop phrases.

    Returns:
        Everything before the earliest stop phrase, or ``string`` unchanged when
        none occurs.
    """
    phrases = [phrase for phrase in phrases if phrase]
    if not phrases:
        return string

    pattern = '|'.join(re.escape(phrase) for phrase in phrases)
    # re.search finds the leftmost match, i.e. exactly what re.split(...)[0] kept,
    # without splitting the rest of the text.
    match = re.search(pattern, string)
    return string[:match.start()] if match else string


def curate_article(article, phrases=None):
    """Clean scraped article text.

    Steps:
    1. Drop every "Advertisement" marker.
    2. Collapse all whitespace runs to a single space.
    3. Truncate at the first stop phrase.
    4. Strip both ends.

    Args:
        article: raw article text.
        phrases: stop phrases to truncate at; defaults to the notebook-level
            ``phrases_to_remove``.

    Returns:
        The cleaned text.
    """
    if phrases is None:
        phrases = phrases_to_remove

    # Remove "Advertisement" sections
    curated_article = re.sub(r'Advertisement', '', article)

    # Normalise whitespace *before* matching stop phrases so a phrase that is
    # broken across a line break (e.g. "Sign\nIn") is still found. This also
    # replaces the old "\n{3,} -> \n\n" step, whose effect the final whitespace
    # collapse always erased anyway.
    curated_article = re.sub(r'\s+', ' ', curated_article)

    # Remove everything after the first stop phrase
    curated_article = remove_phrases(curated_article, phrases)

    return curated_article.strip()


def _publish_datetime(article):
    """Best-effort publication datetime of a parsed newspaper ``Article``, or None.

    Tries the page's ``pdate`` meta tag (``YYYYMMDD``) first, then falls back to
    ``article.publish_date``. Pages without a usable ``pdate`` used to crash
    ``strptime``.
    """
    meta = getattr(article, "meta_data", None) or {}
    pdate = meta.get("pdate")
    if pdate:
        try:
            return datetime.strptime(str(pdate), "%Y%m%d")
        except ValueError:
            pass  # malformed tag: fall through to the library's own guess
    return getattr(article, "publish_date", None)


def get_news(news):
    """Download and parse ``news["link"]``, filling in fields of ``news`` in place.

    ``full_text`` is always (re)set from the cleaned article body. ``image_url``,
    ``date``, ``datetime`` and ``title`` are only set when not already present.
    ``date``/``datetime`` may be None if no publication date can be determined.

    Args:
        news: dict with at least a ``"link"`` key.

    Returns:
        The same dict, mutated.
    """
    article = Article(news["link"])
    article.download()
    article.parse()
    # Article.nlp() is not called: nothing here used its keywords/summary output.

    news["full_text"] = curate_article(article.text)
    news.setdefault("image_url", article.top_image)

    if "date" not in news or "datetime" not in news:
        published = _publish_datetime(article)  # parse once, reuse for both keys
        news.setdefault("date", published)
        news.setdefault("datetime", published)

    news.setdefault("title", article.title)
    return news

[nltk_data] Downloading package punkt to /home/qblocks/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [9]:
import re

def clean_generated_text(text):
    """Normalise whitespace in generated text and drop repeated sentences.

    Sentences are split on whitespace that follows ``.``, ``?`` or ``!``. The
    split regex avoids splitting after patterns like "e.g." or "Mr.". Only the
    first occurrence of each exact sentence is kept, in original order.

    Args:
        text: decoded model output.

    Returns:
        The cleaned text, with sentences joined by single spaces.
    """
    # Collapse every whitespace run (including newlines) into one space.
    text = ' '.join(text.split())

    sentences = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', text)

    # dict.fromkeys keeps first-seen order with O(1) duplicate checks; the old
    # list-membership test was quadratic in the number of sentences.
    return ' '.join(dict.fromkeys(sentences))

In [22]:
class GetTokens:
    """Build summary/title prompts for news records and tokenize them as a batch.

    Args:
        tokenizer: Hugging Face tokenizer; called with a list of prompts and
            expected to return an object with ``input_ids``.
        summary_word_count: target length (in words) quoted in the summary prompt.
        title_word_count: maximum length (in words) quoted in the title prompt.
    """

    def __init__(self, tokenizer, summary_word_count=60, title_word_count=5):
        # Token tensors are moved to CUDA when available, otherwise CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = tokenizer

        # Summary prompt: <start><full_text><end>
        self.summary_word_count = summary_word_count
        self.summary_start_prompt = f'Summarize this news article in {summary_word_count} words.\n\n'
        self.summary_end_prompt = '\n\nSummary: '

        # Title prompt: <start><full_text><mid><summary><end>
        self.title_word_count = title_word_count
        self.title_start_prompt = f'Give a title to the given news article in not more than {title_word_count} words.\n\n'
        self.title_mid_prompt = '\n\nSummary: '
        self.title_end_prompt = '\n\nTitle: '

    def get_title_prompt(self, news):
        """Title prompt for one record; requires ``news["full_text"]`` and ``news["summary"]``."""
        return (
            self.title_start_prompt
            + news["full_text"]
            + self.title_mid_prompt
            + news["summary"]
            + self.title_end_prompt
        )

    def get_summary_prompt(self, news):
        """Summary prompt for one record; requires ``news["full_text"]``."""
        return self.summary_start_prompt + news["full_text"] + self.summary_end_prompt

    def tokenize(self, news: list, mode="summary"):
        """Tokenize the ``mode`` prompt of every record in ``news``.

        Args:
            news: list of record dicts (keys required by ``get_<mode>_prompt``).
            mode: ``"summary"`` or ``"title"``; selects ``get_<mode>_prompt``.

        Returns:
            ``input_ids`` padded to the longest prompt, truncated to
            ``self.tokenizer.model_max_length``, on ``self.device``.

        Raises:
            AttributeError: if there is no prompt builder for ``mode``.
        """
        prompt_func = getattr(self, f'get_{mode}_prompt', None)
        if prompt_func is None:
            raise AttributeError(f"Unsupported mode {mode!r}; expected 'summary' or 'title'")
        prompts = [prompt_func(news=item) for item in news]

        encoded = self.tokenizer(
            prompts,
            return_tensors="pt",
            # Use this instance's tokenizer, not a notebook-level global.
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            padding=True,
        )
        return encoded.input_ids.to(self.device)
    
# Shared prompt builder used by infer(); wraps the tokenizer from the most recent load cell.
get_tokens = GetTokens(tokenizer=tokenizer)

In [23]:
# Articles to process: only "link" is required; get_news fills in the rest.
news = [
    {"link": "https://www.ndtv.com/india-news/if-we-name-alliance-bharat-will-they-call-country-bjp-arvind-kejriwal-4361334"},
    # {"link": "https://www.bbc.com/news/world-europe-66712477"},
    # {"link": "https://www.the-independent.com/news/world/americas/us-politics/trump-mugshot-indictment-latest-news-b2404959.html"},
]

news = [get_news(item) for item in news]

In [24]:
def infer(model, generation_config, data, mode):
    """Generate a summary or title for every record in ``data``, in place.

    for summary: data = [{"full_text": "..."}]
    for title: data = [{"full_text": "...", "summary": "..."}]

    Args:
        model: seq2seq model (here a PEFT-wrapped FLAN-T5) exposing ``generate``.
        generation_config: ``GenerationConfig`` passed through to ``generate``.
        data: list of record dicts; each gets ``record[mode]`` set/overwritten.
        mode: ``"summary"`` or ``"title"``; selects the prompt and the output key.

    Returns:
        ``data`` itself, after mutation.

    Raises:
        ValueError: if ``generate`` returns a different number of sequences than
            there are records (e.g. ``num_return_sequences > 1``).
    """
    if not data:
        return data

    input_ids = get_tokens.tokenize(news=data, mode=mode)
    tok = get_tokens.tokenizer

    generate_kwargs = {}
    if tok.pad_token_id is not None:
        # tokenize() pads the batch but returns only input_ids; rebuild the mask
        # so the encoder does not attend to pad positions.
        generate_kwargs["attention_mask"] = input_ids.ne(tok.pad_token_id).long()

    # peft flan T5
    outputs = model.generate(
        input_ids=input_ids,
        generation_config=generation_config,
        **generate_kwargs,
    )

    # Decode with the same tokenizer that encoded the prompts.
    decoded = tok.batch_decode(outputs, skip_special_tokens=True)
    if len(decoded) != len(data):
        raise ValueError(
            f"generate returned {len(decoded)} sequences for {len(data)} records"
        )

    for record, text in zip(data, decoded):
        record[mode] = clean_generated_text(text)

    return data

In [31]:
# summary generation config
summary_generation_config = GenerationConfig(max_new_tokens=200, 
                                     num_beams=8,
                                     do_sample=False,
                                     temperature=1.5,
                                     top_k=30,
                                     top_p=0.8)

news = infer(model=peft_model, 
             generation_config=summary_generation_config, 
             mode="summary", 
             data=news)

In [32]:
# title generation config
title_generation_config = GenerationConfig(max_new_tokens=50, 
                                         num_beams=8,
                                         do_sample=False,
                                         temperature=1.5,
                                         top_k=30,
                                         top_p=0.8)

news = infer(model=title_model, 
             generation_config=title_generation_config, 
             mode="title", 
             data=news)

In [33]:
# Final records. Note: "title" now holds the *generated* title; infer(mode="title")
# overwrites the scraped article title that get_news stored under the same key.
news

[{'link': 'https://www.ndtv.com/india-news/if-we-name-alliance-bharat-will-they-call-country-bjp-arvind-kejriwal-4361334',
  'full_text': 'The Congress has accused the BJP government of "distorting history and dividing India". The use of \'President of Bharat\' in place of the traditional \'President of India\' in an official invite to foreign leaders attending the G20 summit has sparked a flurry of political reactions. While opposition leaders have slammed the move and linked it to their 28-party alliance naming itself INDIA, the BJP has questioned why some parties "object to every issue related to the honour and pride of the country". One of the sharpest reactions to the wording used in the invite came from AAP chief Arvind Kejriwal who asked whether the ruling party would change the country\'s name to \'BJP\' if the opposition alliance decided to call itself \'Bharat\'. Addressing a press conference on Tuesday, the Delhi chief minister said in Hindi, "I have no official information 