In [None]:
import logging
import pandas as pd
from pprint import pprint
from simpletransformers.seq2seq import Seq2SeqModel, Seq2SeqArgs

In [None]:
# Show INFO-level messages from the root logger, but only let ERROR and
# above through from the `transformers` library.
transformers_logger = logging.getLogger("transformers")
logging.basicConfig(level=logging.INFO)
transformers_logger.setLevel(level=logging.ERROR)

In [None]:
# Load the train and eval splits of the selected SSCORPUS sub-dataset.

# Sub-dataset to use: a directory name under DATA_DIR (e.g. "All").
dataset = "All"
# simpletransformers' Seq2SeqModel expects a DataFrame with exactly these columns.
columns = ["input_text", "target_text"]
DATA_DIR = "4_sscorpus_final_datasets/multi_label_false"


def load_split(name, split):
    """Read ``<DATA_DIR>/<name>/<split>/<split>.csv`` into a DataFrame.

    The CSV is read without a header row; its two fields are named
    ``input_text`` and ``target_text``.

    Args:
        name: Sub-dataset directory name (e.g. "All").
        split: Split name, used as both directory and file stem
            ("train" or "eval").

    Returns:
        ``(path, df)``: the CSV path that was read and the loaded frame.

    Raises:
        ValueError: if any row has a missing input or target text.
    """
    path = f"{DATA_DIR}/{name}/{split}/{split}.csv"
    df = pd.read_csv(path, sep=",", header=None, names=columns)
    # pandas turns empty and NA-like fields (e.g. "NA", "null") into NaN floats,
    # which the tokenizer cannot encode. Fail here with a readable message
    # instead of deep inside training.
    n_missing = int(df[columns].isna().any(axis=1).sum())
    if n_missing:
        raise ValueError(
            f"{path}: {n_missing} row(s) with missing input_text/target_text"
        )
    return path, df


path_train, df_train = load_split(dataset, "train")
print("train:", df_train.shape)

path_eval, df_eval = load_split(dataset, "eval")
print("eval:", df_eval.shape)

In [None]:
# The model arguments (simpletransformers Seq2SeqArgs).

model_args = Seq2SeqArgs()

# Reproducibility: Seq2SeqModel seeds the Python, NumPy and PyTorch RNGs from
# this value when the model is constructed, so it must be set before that.
model_args.manual_seed = 42

# --- Training ---
model_args.num_train_epochs = 4
model_args.train_batch_size = 16
model_args.learning_rate = 5e-5
model_args.max_seq_length = 256
model_args.do_lower_case = True
model_args.fp16 = False

# --- Data processing / I/O ---
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.dataloader_num_workers = 1
model_args.process_count = 1
model_args.use_multiprocessing = False

# --- Text generation (used when the model generates text, e.g. model.predict) ---
model_args.max_length = 256
model_args.do_sample = True
model_args.top_k = 50
model_args.top_p = 0.95
# NOTE(review): None leaves the beam count to the generate()/model-config
# default instead of fixing it here. Confirm this is intended.
model_args.num_beams = None
model_args.num_return_sequences = 1
# Beam-search stopping flag forwarded to generate(); this is NOT training-time
# early stopping (simpletransformers uses `use_early_stopping` for that).
model_args.early_stopping = True

# --- Logging / outputs ---
model_args.wandb_project = "Finetuning "+dataset+" dataset of SSCORPUS with BART model"
# Shared scratch directory, overwritten on every run (overwrite_output_dir=True).
# To keep runs per dataset, use "outputs/multi_label_false/" + dataset + "/" instead.
model_args.output_dir = "outputs/multi_label_false/tmp/"

In [None]:
# Build the seq2seq model from a pretrained BART checkpoint, configured by
# `model_args`. use_cuda=True requires a CUDA GPU; simpletransformers raises
# a ValueError if none is available.
ENCODER_DECODER_TYPE = "bart"
BASE_CHECKPOINT = "facebook/bart-large-cnn"

model = Seq2SeqModel(
    args=model_args,
    encoder_decoder_type=ENCODER_DECODER_TYPE,
    encoder_decoder_name=BASE_CHECKPOINT,
    use_cuda=True,
)

In [None]:
# Fine-tune on the training split.
# NOTE: simpletransformers only uses `eval_data` during training when
# `model_args.evaluate_during_training` is enabled; the model-arguments cell
# does not enable it, so df_eval is currently ignored here.
model.train_model(
    train_data=df_train,
    eval_data=df_eval,
)

In [None]:
# Evaluate the fine-tuned model on the eval split; the returned results are
# displayed as this cell's output.
model.eval_model(
    eval_data=df_eval,
)