In [None]:
import gc
import os
import sys
import time
import urllib.request as libreq

sys.path.append("../")  # make the project's local modules (training_helpers, generate) importable

import datasets
import evaluate
import matplotlib.pyplot as plt
import nltk
import numpy as np
import torch

nltk.download("punkt")

# Local imports stay ahead of the transformers import so that the explicit
# transformers names below keep taking precedence over anything the star import exports.
from training_helpers import *
from generate import Generator
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel, \
Seq2SeqTrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, PegasusModel


In [None]:
# --- Notebook configuration: model registry and dataset locations ---
WORKING_DIR = os.getcwd()

# Root folders holding the locally stored checkpoints and tokenizers.
PATH_TO_MODELS = "../res/models/"
PATH_TO_TOKENIZERS = "../res/tokenizers"

# Every model known to this notebook; the names double as folder names under PATH_TO_MODELS.
MODEL_NAMES = [
    "bart-base",
    "pegasus-xsum",
    "pegasus-x-base-arxiv",
    "pegasus-x-large",
    "longformer-base-4096",
    "led-large-16384-arxiv",
    "pegasus-x-base",
]
MODEL_PATHS = {model_name: os.path.join(PATH_TO_MODELS, model_name) for model_name in MODEL_NAMES}

# Model name -> tokenizer folder name (several models share one tokenizer).
TOKENIZER_MAPS = {
    "bart-base": "bart-base",
    "pegasus-xsum": "pegasus",
    "pegasus-x-base-arxiv": "pegasus",
    "pegasus-x-base": "pegasus",
    "pegasus-x-large": "pegasus",
    "longformer-base-4096": "longformer-base-4096",
    "led-large-16384-arxiv": "longformer-base-4096",
}
TOKENIZER_PATHS = {
    model_name: os.path.join(PATH_TO_TOKENIZERS, TOKENIZER_MAPS[model_name])
    for model_name in MODEL_NAMES
}

# Dataset locations.
ARXIV_DATA_PATH = "../res/datasets/arxiv_data/arxiv_data_txt/1001"
ARXIV_METADATA_PATH = "../res/datasets/arxiv_data/arxiv-metadata-oai-snapshot.json"
WIKI_PATH = "../res/datasets/wiki_aligned/aligned_wiki_ds"


In [None]:

    # process_arxiv_src("all_articles.txt", ARXIV_DATA_PATH, article_ids)

In [None]:
# Prepare the arXiv dataset: pair each article's text with its abstract as the label.
metadata = datasets.load_dataset("json", data_files=ARXIV_METADATA_PATH, split="train")  # metadata snapshot (source of the abstracts)
text_and_id = get_text_and_id(ARXIV_DATA_PATH)  # raw text dataset with article ids

# Index the abstracts by article id so each article can be matched with a dict lookup.
abstract_columns = metadata.select_columns(["abstract", "id"])
abstracts = dict(zip(abstract_columns["id"], abstract_columns["abstract"]))

# One abstract per article, in dataset order. An id that is absent from the
# metadata raises KeyError here.
labels = [abstracts[article_id] for article_id in text_and_id["id"]]

text_and_id = text_and_id.add_column("labels", labels)

In [None]:
# Echo text_and_id as the cell's output for a quick visual check of the prepared dataset.
text_and_id

In [None]:
# run training on all models
# TODO: clean up all of the stuff with the splits
MODEL_NAMES = [""]

for name in MODEL_NAMES:
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATHS[name], local_files_only=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATHS[name], local_files_only=True).to(device)
    # pipe = transformers.pipeline("summarization", model=model, tokenizer=tokenizer)

    
    #data
    ds = text_to_inputs(text_and_id, tokenizer, "text", "labels").remove_columns(["id", "text"])
    print(ds.features)
    splits = ds.train_test_split(test_size=0.1)
    train_ds = splits["train"]
    eval_ds = splits["test"]
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
    
    #prep model
    freeze_base_model_weights(model)
    model.enable_input_require_grads()
    model.train()
    
    #trainer
    train_arg_dict = {"output_dir": "../out/" + str(name) + "/", "save_steps": 100, "evaluation_strategy": "steps",\
                      "eval_steps": 100, "logging_steps": 50,  "max_steps": 600, "per_device_train_batch_size": 1}
    train_args = Seq2SeqTrainingArguments(**train_arg_dict)
    trainer = ArxivWikiTrainer(model, args=train_args, train_dataset=train_ds, eval_dataset = eval_ds, tokenizer=tokenizer, data_collator=data_collator)
    trainer.train()
    
    del tokenizer, model, ds, splits, train_ds, eval_ds, data_collator, trainer #something causing data leak, not sure what (although this line doenst even help...)
    torch.cuda.empty_cache()
    !nvidia-smi

In [None]:
# Load the pegasus-x-base model and tokenizer, then tokenize the aligned wiki dataset.
# TODO: clean up all of the stuff with the splits

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Move the model to the detected device; a hard-coded "cuda:0" fails on CPU-only machines.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATHS["pegasus-x-base"], local_files_only=True).to(device)

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATHS["pegasus-x-base"], local_files_only=True)

# data
# NOTE(review): presumably "normal_articles" is the input column and "simple_articles"
# the target column, mirroring the ("text", "labels") call used for arXiv -- confirm.
wiki_ds = datasets.load_from_disk(WIKI_PATH)
ds = text_to_inputs(wiki_ds, tokenizer, "normal_articles", "simple_articles")


In [None]:
#for wiki, basically the same as arxiv
splits = ds.train_test_split(test_size=0.1)
train_ds = splits["train"]
eval_ds = splits["test"].select([i for i in range(20)])
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")

#prep model
freeze_base_model_weights(model)
model.enable_input_require_grads()
model.train()

#trainer
train_arg_dict = {"output_dir": "../out/wiki-pegasus-x-base"  + "/", "save_steps": 100, "evaluation_strategy": "steps",\
                  "eval_steps": 25, "logging_steps": 50,  "max_steps": 500, "per_device_train_batch_size": 1}
train_args = Seq2SeqTrainingArguments(**train_arg_dict)
trainer = ArxivWikiTrainer(model, args=train_args, train_dataset=train_ds, eval_dataset = eval_ds, tokenizer=tokenizer, data_collator=data_collator)
trainer.train()

torch.cuda.empty_cache()
!nvidia-smi