In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os
project_dir = '/content/drive/My Drive/cnn-dailymail-summarizer'
os.chdir(project_dir)

!pip install -r requirements.txt

In [None]:

import pandas as pd
from cnn_dailymail_news_text_summarizer.dataset import load_datasets, remove_punctuation, preprocess_text, tokenize, save_tokenized_datasets, load_tokenized_datasets
from cnn_dailymail_news_text_summarizer.plots import plot_num_characters, plot_num_words, plot_num_sentences, plot_mean_word_length, create_corpus, plot_most_frequent_stopwords, plot_most_frequent_words, get_top_ngram
from cnn_dailymail_news_text_summarizer.training import train_model
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
import nltk
from collections import Counter
from collections import defaultdict
import re
from sklearn.feature_extraction.text import CountVectorizer
from transformers import BartForConditionalGeneration, BartTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset, concatenate_datasets
import torch
import evaluate

In [None]:
# Prefer the GPU when PyTorch can see a CUDA device; otherwise fall back to the CPU.
device_name = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device_name}.")

## Loading Data

In [None]:

# Absolute locations of the three raw CSV splits, all anchored at the project root.
train_path, test_path, val_path = (
    os.path.join(project_dir, relative_csv)
    for relative_csv in (
        'data/raw/cnn_dailymail/train.csv',
        'data/raw/cnn_dailymail/test.csv',
        'data/raw/cnn_dailymail/validation.csv',
    )
)


In [None]:
# Load the three splits with the project's load_datasets helper; results come back in
# (train, test, val) order. They are used as pandas DataFrames below (.head(), .sample()).
train_data, test_data, val_data = load_datasets(train_path, test_path, val_path)

In [None]:
# Preview the first rows of the training split to check the columns loaded as expected.
train_data.head()

## Exploratory Data Analysis



In [None]:
# Draw a single random training row; `sample` is reused by the next cell to show its highlights.
sample = train_data.sample()
sample['article'].tolist()

In [None]:
# Reference summary ("highlights") belonging to the article sampled in the previous cell.
sample['highlights'].tolist()

In [None]:
# Number of rows in the full training split.
len(train_data)

### Counts and Lengths

In [None]:
# Run the EDA on a 10% sample of the training split; the seed keeps the plots below
# reproducible across kernel restarts (same seed as the subsampling done before tokenization).
eda_data = train_data.sample(frac=0.1, random_state=42)

In [None]:
# Character-count view of the 'article' column of the EDA sample (project plots helper).
plot_num_characters(eda_data, 'article')

In [None]:
# Word-count view of the 'article' column of the EDA sample (project plots helper).
plot_num_words(eda_data, 'article')

In [None]:
# Fetch NLTK's 'punkt' tokenizer data — presumably required by the sentence-count plots below; confirm in plots.py.
nltk.download('punkt')

In [None]:
# Sentence-count view of the 'article' column of the EDA sample (project plots helper).
plot_num_sentences(eda_data, 'article')

In [None]:
# Mean-word-length view of the 'article' column.
# NOTE(review): a later cell drops a 'mean_word_length' column from eda_data, so this helper
# presumably adds that column to the frame as a side effect — confirm in plots.py.
plot_mean_word_length(eda_data, 'article')

In [None]:
# Same character-count view, now for the reference summaries ('highlights').
plot_num_characters(eda_data, 'highlights')

In [None]:
# Same word-count view, now for the reference summaries ('highlights').
plot_num_words(eda_data, 'highlights')

In [None]:
# Same sentence-count view, now for the reference summaries ('highlights').
plot_num_sentences(eda_data, 'highlights')

### Term Frequency

In [None]:
# Remove the helper column 'mean_word_length' before the term-frequency analysis.
# Reassigning (instead of inplace=True) together with errors='ignore' keeps this cell
# idempotent: re-running it no longer raises KeyError once the column is already gone.
eda_data = eda_data.drop(columns='mean_word_length', errors='ignore')

In [None]:
# Fetch NLTK's stopword lists; the English list is read in the next cell.
nltk.download('stopwords')

In [None]:
# English stopwords collected into a set (duplicates removed, constant-time membership checks).
stop = {word for word in nltk.corpus.stopwords.words('english')}

In [None]:
# Most frequent stopwords in the EDA sample, using the English stopword set built above.
plot_most_frequent_stopwords(eda_data, stop)

In [None]:
# Most frequent words in the EDA sample; `stop` is passed so the helper can presumably exclude stopwords — confirm in plots.py.
plot_most_frequent_words(eda_data, stop)

### N-gram Frequency

In [None]:
# Shrink the EDA frame to 10% of its current size for the n-gram cells below; the seed
# makes the subsample reproducible.
# CAUTION: this rebinds eda_data to a subset of itself, so every re-run of this cell
# shrinks the frame by another factor of ten — re-create eda_data first if re-running.
eda_data = eda_data.sample(frac=0.1, random_state=42)

In [None]:
# Top bigrams (n=2); the stopword set is converted to a list before being handed to the helper.
get_top_ngram(eda_data, list(stop), 2)

In [None]:
# Top trigrams (n=3), same inputs as the bigram cell.
get_top_ngram(eda_data, list(stop), 3)

In [None]:
# Top 5-grams (n=5), same inputs as the bigram cell.
get_top_ngram(eda_data, list(stop), 5)

## Data Preprocessing

In [None]:
# Identifier of the pretrained BART checkpoint; passed to BartTokenizer.from_pretrained in the next cell.
checkpoint = "facebook/bart-base"

In [None]:
# Load the tokenizer that belongs to the chosen checkpoint.
tokenizer = BartTokenizer.from_pretrained(checkpoint)

In [None]:
# Smoke-test the tokenizer on the first article. .iloc[0] selects by position, so this
# works regardless of the frame's index labels (plain [0] is a label lookup and raises
# KeyError when no row is labelled 0).
tokenizer(train_data['article'].iloc[0])

In [None]:
# Keep a seeded 20% subsample of every split to bound tokenization and training cost.
# The names are rebound, so re-running this cell subsamples the already reduced frames again.
train_data, val_data, test_data = (
    split.sample(frac=0.2, random_state=42)
    for split in (train_data, val_data, test_data)
)

In [None]:
# Tokenize the training subsample with the project's `tokenize` helper. train_data is rebound
# to the helper's return value, so the raw frame is no longer reachable under this name.
train_data = tokenize(train_data, tokenizer)

In [None]:
# Debug aid: uncomment to inspect the tokenized training split.
#print(train_data)

In [None]:
# Tokenize the validation subsample the same way as the training split.
val_data = tokenize(val_data, tokenizer)

In [None]:
# Tokenize the test subsample the same way as the training split.
test_data = tokenize(test_data, tokenizer)

In [None]:
# Directory (inside the project root) handed to the save/load helpers for tokenized splits.
processed_path = os.path.join(project_dir, 'data/processed/')


In [None]:
# Hand the tokenized splits to the project's save helper (argument order: train, test, val).
save_tokenized_datasets(train_data, test_data, val_data, processed_path)
# Alternative for later sessions: load previously saved splits instead of re-tokenizing.
#train_data, test_data, val_data = load_tokenized_datasets(processed_path)

## Fine-tuning

In [None]:
# Collator that dynamically pads inputs and labels per batch.
# `model` is intentionally omitted: the collator expects a model *instance* there (it is only
# used when the object provides `prepare_decoder_input_ids_from_labels`). A checkpoint name
# string has no such method and is silently ignored, so passing one is misleading.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

In [None]:
# Load the ROUGE metric through the `evaluate` library; it is passed to train_model below.
rouge = evaluate.load("rouge")

In [None]:
# Hyper-parameters and bookkeeping for Seq2SeqTrainer.
training_args = Seq2SeqTrainingArguments(
    # Anchor checkpoints and logs at the project root. The working directory is project_dir,
    # so the former '../model' / '../logs' resolved to the *parent* of the project, while the
    # evaluation cell loads checkpoints from './model'.
    output_dir=os.path.join(project_dir, 'model'),
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    # Effective train batch per device = 16 * 4 = 64 samples per optimizer step.
    gradient_accumulation_steps=4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir=os.path.join(project_dir, 'logs'),
    logging_steps=10,
    # Evaluate and save once per epoch; the two strategies must match for load_best_model_at_end.
    evaluation_strategy="epoch",
    save_total_limit=2,
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Use generate() during evaluation so text metrics are computed on generated summaries.
    predict_with_generate=True,
    # Mixed precision needs a CUDA device; fall back to full precision on CPU-only runtimes
    # instead of failing when the arguments are constructed.
    fp16=torch.cuda.is_available(),
    remove_unused_columns=True,
    save_safetensors=False
)


In [None]:
# Launch fine-tuning through the project's train_model helper.
# NOTE(review): checkpoint=False presumably means "start from the base weights rather than
# resuming from a saved checkpoint" — confirm against training.train_model.
train_model(train_data, val_data, tokenizer, data_collator, training_args, device_name, rouge, checkpoint=False)

In [None]:
# Release GPU memory cached by PyTorch's allocator after training, before the evaluation model is loaded.
torch.cuda.empty_cache()

## Evaluation

In [None]:
# Load a fine-tuned checkpoint and place it on the device selected at the top of the notebook.
# The path is relative to the current working directory.
# NOTE(review): the step suffix (44) is specific to one training run — adjust it to a
# checkpoint that actually exists.
model = BartForConditionalGeneration.from_pretrained('./model/checkpoint-44').to(device_name)
example_text = "The Philadelphia 76ers look forward to watching the team’s All-Star center Joel Embiid compete at the Paris Olympics over the next couple of weeks. Sunday morning would mark the first time Embiid will officially compete in a meaningful game on the international stage since his basketball career started. Leading up to the matchup, Embiid’s playing status went into question, however. On Saturday, Team USA prepared for its battle against Serbia with a practice session. Multiple players were noticeably absent and Embiid was one of them. According to ESPN’s Brian Windhorst, Embiid was dealing with an illness. The big man joined his temporary teammate by missing practice as Los Angeles Lakers big man Anthony Davis was a non-participant as well. According to the report, Davis was dealing with an illness for several days leading up to Saturday’s session. Following Team USA’s practice session, Golden State Warriors head coach Steve Kerr addressed the media and acknowledged the absences. He wasn’t concerned and believed that Team USA would have its full roster ready to go for the debut against Serbia. I'm confident we'll have everybody ready, Kerr told reporters. As USA boarded the bus to head to the arena for their first game, Embiid was spotted with the team, along with Davis. All signs point to the two NBA veterans suiting up and competing."
# Truncate to 1024 tokens and move the encoded tensors to the same device as the model.
inputs = tokenizer(example_text, max_length=1024, return_tensors='pt', truncation=True).to(device_name)
# Beam search with 4 beams, summaries capped at 128 tokens; the attention mask is passed
# explicitly so generation does not have to infer it from the input ids.
summary_ids = model.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    num_beams=4,
    max_length=128,
    early_stopping=True,
)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print("Summary:", summary)
