In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os
project_dir = '/content/drive/My Drive/cnn-dailymail-summarizer'
os.chdir(project_dir)

!pip install -r requirements.txt

In [None]:

import pandas as pd
from cnn_dailymail_news_text_summarizer.dataset import load_datasets
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
import nltk
from collections import Counter
from collections import defaultdict
import re
from sklearn.feature_extraction.text import CountVectorizer
from transformers import BartForConditionalGeneration, BartTokenizer, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from datasets import Dataset
import torch
import evaluate

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device_name = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using {}.".format(device_name))

## Loading Data

In [None]:
# Locations of the raw CSV splits, relative to the project folder.
raw_split_files = (
    'data/raw/cnn_dailymail/train.csv',
    'data/raw/cnn_dailymail/test.csv',
    'data/raw/cnn_dailymail/validation.csv',
)
train_path, test_path, val_path = (
    os.path.join(project_dir, relative_path) for relative_path in raw_split_files
)


In [None]:
# Load the three splits; note the argument (and result) order is train, test, validation.
train_data, test_data, val_data = load_datasets(train_path, test_path, val_path)

In [None]:
# Peek at the first rows of the training split.
train_data.head()

## Exploratory Data Analysis

In [None]:
# Draw one random training row (unseeded, so it differs on every run) ...
sample = train_data.sample()

# ... and show its full article text.
list(sample['article'])

In [None]:
# Reference summary ("highlights") of the same randomly drawn row.
list(sample['highlights'])

In [None]:
# Number of rows in the training split.
len(train_data)

### Counts and Lengths

In [None]:
# Explore a 10% subset of the training split; the fixed seed keeps every
# plot and count below reproducible across kernel restarts.
eda_data = train_data.sample(frac=0.1, random_state=42)

In [None]:
# Histogram of article length measured in characters.
article_char_counts = eda_data['article'].str.len()
fig, ax = plt.subplots()
ax.hist(article_char_counts, bins=50, edgecolor='white')
ax.set_xlabel("Number of Characters in Article")
ax.set_ylabel("Number of Articles")
ax.set_title("Distribution of Characters per Article")
plt.show()

In [None]:
# Histogram of article length measured in whitespace-separated words.
article_word_counts = eda_data['article'].str.split().map(len)
fig, ax = plt.subplots()
ax.hist(article_word_counts, bins=50, edgecolor='white')
ax.set_xlabel("Number of Words in Article")
ax.set_ylabel("Number of Articles")
ax.set_title("Distribution of Words per Article")
plt.show()

In [None]:
# Sentence-tokenizer data used by nltk.sent_tokenize below. Recent NLTK
# releases look up 'punkt_tab' instead of 'punkt', so fetch both to work
# with whichever version requirements.txt installs.
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
# Histogram of article length measured in sentences (NLTK sentence splitter).
article_sentence_counts = eda_data['article'].map(nltk.sent_tokenize).map(len)
fig, ax = plt.subplots()
ax.hist(article_sentence_counts, bins=50, edgecolor='white')
ax.set_xlabel("Number of Sentences in Article")
ax.set_ylabel("Number of Articles")
ax.set_title("Distribution of Sentences per Article")
plt.show()

In [None]:
def _mean_word_length(article):
    """Average character length of the whitespace-separated words in ``article``."""
    return np.mean([len(token) for token in article.split()])


# Temporary column; it is removed again before the term-frequency analysis.
eda_data['mean_word_length'] = eda_data['article'].map(_mean_word_length)
eda_data.head(10)

In [None]:
# Spread of the per-article mean word length.
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(data=eda_data, y='mean_word_length', ax=ax)
ax.set_ylabel("Mean Word Length")
ax.set_title("Boxplot of Mean Word Length per Article")
plt.show()

In [None]:
# Histogram of summary length measured in characters.
summary_char_counts = eda_data['highlights'].str.len()
fig, ax = plt.subplots()
ax.hist(summary_char_counts, bins=50, edgecolor='white')
ax.set_xlabel("Number of Characters in Article Summary")
ax.set_ylabel("Number of Articles")
ax.set_title("Distribution of Characters per Article Summary")
plt.show()

In [None]:
# Histogram of summary length measured in whitespace-separated words.
summary_word_counts = eda_data['highlights'].str.split().map(len)
fig, ax = plt.subplots()
ax.hist(summary_word_counts, bins=50, edgecolor='white')
ax.set_xlabel("Number of Words in Article Summary")
ax.set_ylabel("Number of Articles")
ax.set_title("Distribution of Words per Article Summary")
plt.show()

In [None]:
# Histogram of summary length measured in sentences (fewer bins: summaries are short).
summary_sentence_counts = eda_data['highlights'].map(nltk.sent_tokenize).map(len)
fig, ax = plt.subplots()
ax.hist(summary_sentence_counts, bins=20, edgecolor='white')
ax.set_xlabel("Number of Sentences in Article Summary")
ax.set_ylabel("Number of Articles")
ax.set_title("Distribution of Sentences per Article Summary")
plt.show()

### Term Frequency

In [None]:
# Remove the helper column added for the boxplot. Rebinding (instead of
# inplace=True) with errors='ignore' keeps this cell safe to re-run: a second
# run no longer raises KeyError for the already-missing column.
eda_data = eda_data.drop(columns='mean_word_length', errors='ignore')

In [None]:
# Fetch NLTK's stop-word lists, read in the next cell.
nltk.download('stopwords')

In [None]:
# English stop words as a set, for fast membership tests in the loops below.
stop = set(nltk.corpus.stopwords.words('english'))

In [None]:
def remove_punctuation(text):
    """Return ``text`` with every character that is neither a word character
    (letters, digits, underscore) nor whitespace removed."""
    non_word_non_space = re.compile(r'[^\w\s]')
    return non_word_non_space.sub('', text)

In [None]:
# Lower-case each article, strip punctuation and split on whitespace ...
tokenized_articles = eda_data['article'].str.lower().apply(remove_punctuation).str.split()
words = tokenized_articles.values.tolist()
# ... then flatten the per-article token lists into one list of tokens.
corpus = [token for article_tokens in words for token in article_tokens]



In [None]:
# Tally how often each stop word occurs in the flattened corpus.
dic = defaultdict(int)
for word in corpus:
    if word not in stop:
        continue
    dic[word] += 1

In [None]:
# Stop words ordered by descending count, split into parallel label/value tuples.
top = sorted(dic.items(), key=lambda pair: pair[1], reverse=True)
x, y = zip(*top)

fig, ax = plt.subplots(figsize=(14, 7))
ax.bar(x[:40], y[:40], color='blue')
ax.set_xlabel('Words')
ax.set_ylabel('Frequency')
ax.set_title('Top 40 Most Frequent Stopwords')
plt.xticks(rotation=45)
plt.show()

In [None]:
counter = Counter(corpus)
most = counter.most_common()

# Keep (word, count) pairs in descending-frequency order, skipping stop words.
non_stop_pairs = [(word, count) for word, count in most if word not in stop]
x = [word for word, _ in non_stop_pairs]
y = [count for _, count in non_stop_pairs]

fig, ax = plt.subplots(figsize=(14, 7))
ax.bar(x[:40], y[:40], color='blue')
ax.set_xlabel('Words')
ax.set_ylabel('Frequency')
ax.set_title('Top 40 Most Frequent Non-Stopwords')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Shrink the EDA sample once more (10% of the current sample) before the
# n-gram analysis; the fixed seed makes the subset reproducible.
# NOTE: this cell rebinds `eda_data` from itself, so running it again
# shrinks the sample further -- run it exactly once per kernel session.
eda_data = eda_data.sample(frac=0.1, random_state=42)
eda_data['article'] = eda_data['article'].str.lower().apply(remove_punctuation)

### N-gram Frequency

In [None]:
def get_top_ngram(corpus, stop_words, n=2):
    """Plot the 20 most frequent n-grams of ``corpus`` as a horizontal bar chart.

    Args:
        corpus: iterable of text documents, passed to
            ``CountVectorizer.fit_transform``.
        stop_words: forwarded unchanged to ``CountVectorizer(stop_words=...)``.
        n: n-gram size; ``ngram_range=(n, n)`` counts n-grams of exactly
            this many tokens.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    cv = CountVectorizer(ngram_range=(n, n), stop_words=stop_words)
    ngrams = cv.fit_transform(corpus)
    # Sum the columns on the sparse matrix itself. Calling .toarray() first
    # would allocate a dense documents x vocabulary array just to add it up.
    count_values = np.asarray(ngrams.sum(axis=0)).ravel()
    # vocabulary_ maps each n-gram to its column index in the count matrix.
    ngram_freq = pd.DataFrame(
        sorted(((count_values[i], k) for k, i in cv.vocabulary_.items()), reverse=True)
    )
    ngram_freq.columns = ['frequency', 'ngram']
    sns.barplot(x=ngram_freq['frequency'][:20], y=ngram_freq['ngram'][:20])
    if n == 2:
        plt.title('Top 20 Most Frequent Bigrams')
    elif n == 3:
        plt.title('Top 20 Most Frequent Trigrams')
    else:
        plt.title('Top 20 Most Frequent Ngrams')
    plt.show()

In [None]:
# Bigram frequencies over the EDA sample; the stop words are handed to the vectorizer.
get_top_ngram(eda_data['article'], list(stop), 2)

In [None]:
# Trigram frequencies over the same sample.
get_top_ngram(eda_data['article'], list(stop), 3)

## Data Preprocessing and Model Fine-tuning

In [None]:
# Pretrained checkpoint name, used below for both the tokenizer and the model.
checkpoint = "facebook/bart-base"


In [None]:
# Tokenizer matching the checkpoint; shared by preprocessing, collation and metric decoding.
tokenizer = BartTokenizer.from_pretrained(checkpoint)

In [None]:
# Sanity check: tokenize the first training article. `.iloc[0]` selects by
# position, so it works whatever index labels the loaded frame carries
# (a plain `[0]` is a label lookup and fails when no row is labelled 0).
tokenizer(train_data['article'].iloc[0])


In [None]:
# Text prepended to every article before it is tokenized.
prefix = "summarize: "


def tokenization(examples):
    """Tokenize a batch of articles together with their reference summaries.

    Args:
        examples: mapping with an ``'article'`` and a ``'highlights'`` entry,
            each a sequence of strings (the batch passed in by ``Dataset.map``).

    Returns:
        The tokenizer output for the prefixed articles (truncated to 512
        tokens) with an added ``"labels"`` entry holding the token ids of the
        summaries (truncated to 128 tokens).
    """
    prefixed_articles = [prefix + article for article in examples['article']]
    encoded = tokenizer(prefixed_articles, max_length=512, truncation=True)

    summary_encoding = tokenizer(text_target=examples["highlights"], max_length=128, truncation=True)
    encoded["labels"] = summary_encoding["input_ids"]

    return encoded

In [None]:
# Fine-tune on a 5% subset of the training split; the fixed seed makes the
# selected rows (and therefore the training run) reproducible.
train_data_decr = train_data.sample(frac=0.05, random_state=42)

In [None]:
# Convert the sampled frame to a `datasets.Dataset`. The name is rebound to a
# different type here, so this cell cannot be re-run without re-running the sampling cell.
train_data_decr = Dataset.from_pandas(train_data_decr)

In [None]:
# Tokenize in batches: `tokenization` receives a dict of column lists per call.
train_data_decr = train_data_decr.map(tokenization, batched=True)

In [None]:
# Show the resulting dataset summary to confirm the tokenized columns were added.
print(train_data_decr)

In [None]:
# Evaluate on a 5% subset of the validation split; the fixed seed keeps the
# evaluation set identical across runs so metrics are comparable.
val_data_decr = val_data.sample(frac=0.05, random_state=42)

In [None]:
# Same conversion as for the training subset.
val_data_decr = Dataset.from_pandas(val_data_decr)

In [None]:
# Same batched tokenization as for the training subset.
val_data_decr = val_data_decr.map(tokenization, batched=True)

In [None]:
# Pads inputs and labels dynamically per batch. The collator's `model`
# argument expects a loaded model instance, not a checkpoint name: a string
# has no `prepare_decoder_input_ids_from_labels`, so it was silently ignored.
# It is omitted here because the model is only loaded in a later cell; BART
# derives decoder_input_ids from the labels itself when none are supplied.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True)

In [None]:
# ROUGE metric object used by compute_metrics below.
rouge = evaluate.load("rouge")

In [None]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
model.gradient_checkpointing_enable()

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir='./model',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=4,
    warmup_steps=100,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    evaluation_strategy="epoch",
    save_total_limit=2,
    save_strategy="epoch",
    load_best_model_at_end=True,
    predict_with_generate=True,
    fp16=True,
    remove_unused_columns=True
)


In [None]:
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data_decr,
    eval_dataset=val_data_decr,
    data_collator=data_collator,
    tokenizer=tokenizer
)

In [None]:
# Run fine-tuning; checkpoints go to ./model and logs to ./logs (see training_args).
trainer.train()