In [87]:
from datasets import load_dataset

dataset = load_dataset("ccdv/arxiv-summarization")
dataset

No config specified, defaulting to: arxiv-summarization/section
Found cached dataset arxiv-summarization (C:/Users/Justin Du/.cache/huggingface/datasets/ccdv___arxiv-summarization/section/1.0.0/fa2c9abf4312afb8660ef8e041d576b8e3943ea96ae771bd3cd091b5798e7cc3)
100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


DatasetDict({
    train: Dataset({
        features: ['article', 'abstract'],
        num_rows: 203037
    })
    validation: Dataset({
        features: ['article', 'abstract'],
        num_rows: 6436
    })
    test: Dataset({
        features: ['article', 'abstract'],
        num_rows: 6440
    })
})

In [88]:
from datasets.dataset_dict import DatasetDict

data_train = dataset['train']
data_test = dataset['test']
data_val = dataset['validation']

In [89]:
data_test

Dataset({
    features: ['article', 'abstract'],
    num_rows: 6440
})

In [90]:
from transformers import AutoTokenizer

checkpoint = 'facebook/bart-large-cnn'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [91]:
max_length = 1024

def tokenize_function(data):

    model_inputs = tokenizer(
        data["article"],
        truncation=True,
        max_length=max_length,
    )


    labels = tokenizer(
        data['abstract'],
        truncation=True,
        max_length=max_length
    )

    model_inputs["decoder_input_ids"] = labels["input_ids"]
    model_inputs["decoder_attention_mask"] = labels["attention_mask"]
    return model_inputs

In [92]:
tok = data_test.map(tokenize_function, batched=True)

Loading cached processed dataset at C:/Users/Justin Du/.cache/huggingface/datasets/ccdv___arxiv-summarization/section/1.0.0/fa2c9abf4312afb8660ef8e041d576b8e3943ea96ae771bd3cd091b5798e7cc3\cache-aa89378e5dcd8e81.arrow


In [93]:
print(tok)

Dataset({
    features: ['article', 'abstract', 'input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask'],
    num_rows: 6440
})


In [94]:
testx = tok['input_ids'][0]

In [95]:
print(type(testx))

<class 'list'>


In [None]:
print(tokenizer.decode(testx))

In [98]:
import torch

In [99]:
asdf = [list(testx)]
asdf = torch.as_tensor(asdf)

In [100]:
print(asdf)

tensor([[   0, 1990,   59,  ...,   32, 4756,    2]])


In [102]:
outputs = model.generate(asdf, max_length=150, length_penalty=2.0, num_beams=4, early_stopping=True)

In [104]:
print(tokenizer.decode(outputs[0]))

</s><s>The problem of properties of short - term changes of solar activity has been considered extensively. Several periodicities were detected, but the periodicities about 155 days and from the interval of @xmath3 $ ] days are mentioned most often. The power of this periodicity started growing at cycle 19, decreased in cycles 20 and 21 and disappered after cycle 21.</s>


In [None]:
import evaluate
rouge_score = evaluate.load('rouge')

In [97]:
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [None]:
from transformers import Seq2SeqTrainingArguments

batch_size = 8
num_train_epochs = 8
logging_steps = len(data_val)
training_args = Seq2SeqTrainingArguments(
    output_dir='test-trainer',
    evaluation_strategy="epoch",
    learning_rate=5.6e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=num_train_epochs,
    predict_with_generate=True,
    logging_steps=logging_steps
)

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Decode generated summaries into text
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    # Decode reference summaries into text
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # ROUGE expects a newline after each sentence
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    # Compute ROUGE scores
    result = rouge_score.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract the median scores
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    return {k: round(v, 4) for k, v in result.items()}

In [None]:
from transformers import DataCollatorForSeq2Seq

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    model,
    training_args,
    train_dataset=tok['test'],
    eval_dataset=tok['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()

In [None]:
# samples = tok['test'][:2000]
# samples = {k: v for k, v in samples.items() if k not in ['abstract', 'article']}
# batch = data_collator(samples)

In [None]:
# chunk = 500
# tok_abs = []

# for i in (data_test[pos:pos + chunk] for pos in range(0, len(data_test), chunk)):
#     tok_abs.append(tokenizer(i['abstract'], truncation=True, padding=True, return_tensors='tf'))


In [None]:
# tok_art = []
# for i in (data_test[pos:pos + chunk] for pos in range(0, len(data_test), chunk)):
#     tok_art.append(tokenizer(i['article'], truncation=True, ))

In [None]:
# import torch
# tok = torch.cat(tok_abs, dim=1)


In [None]:
# test_tok_abs = tokenizer(data_train['article'], truncation=True, padding='max_length')