In [None]:
!pip install -q transformers einops accelerate langchain bitsandbytes xmltodict

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

In [None]:
import sys

from google.colab import drive

# Mount Google Drive under /content/drive; the corpus and the generated
# summaries are read from / written to paths below this mount point.
# `force_remount=True` is passed so the mount is requested again on every run
# of this cell.
drive.mount('/content/drive', force_remount=True)

In [None]:
import os
import xmltodict
import sys

from tqdm import tqdm


def replace_marks(text: str) -> str:
    """Decode the character entities that are still escaped in corpus text.

    Handles ``&amp;``, ``&quot;`` and the apostrophe entity. Both the standard
    XML spelling ``&apos;`` and the non-standard ``&apost;`` are decoded; the
    latter is kept because the previous implementation handled (only) it.

    ``&amp;`` is decoded first, exactly as before, so a double-escaped entity
    such as ``&amp;quot;`` ends up as ``"``.

    Args:
        text: Text that may contain escaped entities.

    Returns:
        The text with the entities above replaced by their characters.
    """
    replacements = (
        ('&amp;', '&'),
        ('&quot;', '"'),
        ('&apost;', "'"),
        ('&apos;', "'"),
    )
    for entity, character in replacements:
        text = text.replace(entity, character)
    return text

def _as_list(value) -> list:
    """Return `value` as a list: None -> [], list -> unchanged, other -> [value].

    Needed because a child element that occurs once is parsed as a single
    value, while a repeated one is parsed as a list.
    """
    if value is None:
        return []
    if isinstance(value, list):
        return value
    return [value]


def _join_sentence_texts(container: dict) -> str:
    """Join the '#text' of every dict under container['sentence'] with newlines.

    Entries that are not dicts are skipped (same rule the gold-standard
    parsing has always applied).
    """
    texts = [sentence['#text']
             for sentence in _as_list(container['sentence'])
             if isinstance(sentence, dict)]
    return '\n'.join(texts).strip()


def _join_article_text(article: dict) -> str:
    """Join the 'content' of every sentence of every paragraph with spaces."""
    contents = []
    for paragraph in _as_list(article['paragraph']):
        for sentence in _as_list(paragraph['sentences']['sentence']):
            contents.append(sentence['content'])
    return ' '.join(contents).strip()


def _clean_document_name(file_name: str) -> str:
    """Derive the lower-cased document name from its file name.

    Drops '.xml' and the characters ' ; & % and surrounding whitespace.
    """
    name = file_name.replace('.xml', '').replace("'", '').lower()
    return name.replace(';', '').replace('&', '').replace('%', '').strip()


def read_cnn_corpus(corpus_dir: str) -> list:
    """Read every document file of the CNN corpus directory.

    Each file is parsed with ``xmltodict``; the title, highlights, gold
    standard and article text are extracted and passed through
    ``replace_marks``.

    Elements that occur only once (a single highlight sentence, a single
    paragraph, a single sentence) are handled the same way as repeated ones.
    File names are processed in sorted order so the resulting list has a
    reproducible order (``os.listdir`` gives no ordering guarantee).

    Args:
        corpus_dir: Directory that contains one document file per article.

    Returns:
        A list with one dict per document, with the keys ``name``, ``title``,
        ``highlights``, ``gold_standard`` and ``text``. Highlight and
        gold-standard sentences are newline-separated; article sentences are
        space-separated.

    Raises:
        AssertionError: If `corpus_dir` does not exist.
    """
    assert os.path.exists(corpus_dir), f'Corpus directory not found: {corpus_dir}'
    documents_names = sorted(os.listdir(corpus_dir))
    corpus = []
    with tqdm(total=len(documents_names), file=sys.stdout, colour='blue',
              desc='Reading Corpus') as pbar:
        for doc_name in documents_names:
            document_file = os.path.join(corpus_dir, doc_name)
            with open(document_file, encoding='utf-8') as file:
                xml_doc = xmltodict.parse(file.read())['document']
            summaries = xml_doc['summaries']
            corpus.append({
                'name': _clean_document_name(doc_name),
                'title': replace_marks(xml_doc['title']),
                'highlights': replace_marks(
                    _join_sentence_texts(summaries['highlights'])),
                'gold_standard': replace_marks(
                    _join_sentence_texts(summaries['gold_standard'])),
                'text': replace_marks(_join_article_text(xml_doc['article'])),
            })
            pbar.update(1)
    return corpus


def save_summary(document_name: str, summary: str, summaries_dir: str,
                 summary_name: str):
    """Write `summary` to ``<summaries_dir>/<document_name lower-cased>/<summary_name>.txt``.

    The per-document directory is created when missing; an existing summary
    file with the same name is overwritten. The file is written as UTF-8.

    Args:
        document_name: Name of the summarized document (lower-cased for the
            directory name).
        summary: Text to store.
        summaries_dir: Root directory that holds one sub-directory per document.
        summary_name: File name (without extension) of the summary.
    """
    target_dir = os.path.join(summaries_dir, document_name.lower())
    os.makedirs(target_dir, exist_ok=True)
    target_file = os.path.join(target_dir, '{}.txt'.format(summary_name))
    with open(target_file, mode='w', encoding='utf-8') as handle:
        handle.write(summary)

In [None]:
# Directory the CNN corpus documents are read from.
corpus_dir = '/content/drive/My Drive/Experimentos/abs_summ_benchmark/corpora/cnn'

# Root directory the generated summaries are written to; created if missing.
summaries_dir = '/content/drive/My Drive/Experimentos/abs_summ_benchmark/summaries/cnn'
os.makedirs(summaries_dir, exist_ok=True)

# Parse the whole corpus into a list of document dicts.
corpus_cnn = read_cnn_corpus(corpus_dir)

In [None]:
# Sanity check: corpus size and the article text of the first document.
total_documents = len(corpus_cnn)
print(f'Total of Documents: {total_documents}')

first_document = corpus_cnn[0]
print(first_document['text'])

In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face Hub login, so no access token is hard-coded in
# the notebook.
# NOTE(review): presumably required because the model checkpoint used below
# is access-restricted — confirm.
notebook_login()

In [None]:
import transformers
import torch

from langchain import HuggingFacePipeline
from transformers import AutoTokenizer

# Short label of the model; used as the file name of each saved summary.
model_name = 'gemma_7b'

# Checkpoint identifier handed to the tokenizer and to the generation pipeline.
model_path = 'google/gemma-7b-it'

print(f'Model Name: {model_name} -- {model_path}')

In [None]:
# Tokenizer of the selected checkpoint; also supplies the end-of-sequence id
# passed to the pipeline below.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Text-generation pipeline for the checkpoint, loaded in bfloat16 with
# automatic device placement.
# NOTE(review): `trust_remote_code=True` permits running code shipped with the
# checkpoint repository — confirm it is actually needed for this model.
# NOTE(review): `max_length=2048` presumably bounds prompt + generated tokens
# together, so long articles may leave little or no room for the summary —
# confirm, and consider `max_new_tokens` / input truncation.
pipeline = transformers.pipeline(
    'text-generation',
    model=model_path,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map='auto',
    max_length=2048,
    eos_token_id=tokenizer.eos_token_id)

In [None]:
# Wrap the transformers pipeline so it can be used as a LangChain LLM.
# NOTE(review): confirm that `model_kwargs` is forwarded to generation when a
# ready-made pipeline is supplied; the temperature may need to be passed via
# `pipeline_kwargs` (with sampling enabled) to have any effect — verify
# against the installed langchain version.
llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': 0.3})

In [None]:
# Prompt template with two placeholders: `{text}` (the article) and
# `{num_sentences}` (requested summary length in sentences). The trailing
# 'SUMMARY:' marker is what the summarization loop later splits the model
# output on.
template = """
ARTICLE: ```{text}```. Summarize the article in ```{num_sentences}``` SENTENCES.
SUMMARY:
"""

print(f'Template: {template}')

In [None]:
from langchain_core.prompts import PromptTemplate
from langchain.chains import LLMChain

# Bind the template to its two input variables.
prompt = PromptTemplate(template=template, input_variables=['text', 'num_sentences'])

print(f'Prompt Template: {prompt}')

# Chain that fills the prompt and sends it to the wrapped model.
llm_chain = LLMChain(prompt=prompt, llm=llm)

In [None]:
import warnings

# Silence all warnings for the rest of the session.
warnings.simplefilter('ignore')

# Upper bound on how many NEW summaries are generated per execution of this
# cell. The loop used to end with an unconditional `break`, i.e. exactly one
# new summary per run; that default is preserved here. Set it to None to
# summarize the whole corpus in a single run.
max_new_summaries = 1
generated_count = 0

with tqdm(total=len(corpus_cnn), file=sys.stdout,
          colour='green', desc='Summarizing documents') as pbar:

  for document in corpus_cnn:

    # Ask for as many sentences as the reference highlights contain
    # (highlights are stored one sentence per line).
    num_sentences = len(document['highlights'].split('\n'))

    input_text = document['text']

    document_dir = os.path.join(summaries_dir, document['name'].lower())
    summary_path = os.path.join(document_dir, f'{model_name}.txt')

    # Resume support: documents that already have a summary from this model
    # are skipped.
    if os.path.exists(summary_path):
      pbar.update(1)
      continue

    summary = llm_chain.run(
        {
            'num_sentences': num_sentences,
            'text': input_text
            }
        )

    # Keep only what follows the LAST 'SUMMARY:' marker. Depending on the
    # installed LangChain version the returned text may or may not echo the
    # prompt: taking the last segment avoids an IndexError when the marker is
    # absent and is not confused by an article that itself contains
    # 'SUMMARY:'.
    summary = summary.split('SUMMARY:')[-1].strip()

    save_summary(document['name'], summary, summaries_dir, model_name)

    pbar.update(1)

    generated_count += 1
    if max_new_summaries is not None and generated_count >= max_new_summaries:
      break