## Model 2: Fine-Tuning Pre-trained BART

In [None]:
# Training minimum config: 1x RTX A6000

In [None]:
from datasets import load_dataset
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq
from transformers import BartForConditionalGeneration, BartTokenizer

In [None]:
# Tokenizer for the same checkpoint that is fine-tuned below.
tokenizer = BartTokenizer.from_pretrained(
    "facebook/bart-large",
)

In [None]:
# Root directory holding the `data/` folder; empty means the current
# working directory.
base_dir = ""
# If colab
# base_dir = '/content/'

# Location of the training CSV, relative to base_dir.
data_path = f"{base_dir}data/time_seed_fhir.csv"

In [None]:
# Load the whole CSV as a single split, then hold out 20% of it for evaluation.
fhir_dataset = load_dataset('csv', data_files=data_path, split='train')
# A fixed seed keeps the train/test partition identical across runs, so
# evaluation results from different runs are comparable.
fhir_dataset = fhir_dataset.train_test_split(test_size=0.2, seed=42)
print(f"Training Samples: {fhir_dataset['train'].num_rows}, Test Samples: {fhir_dataset['test'].num_rows}")

In [None]:
# Instruction text prepended to every seed before tokenization.
prefix = "Generate a FHIR data row using the seed: "


# CSV columns: FHIR,seed
def preprocess_function(examples):
    """Tokenize a batch of seed/FHIR pairs for seq2seq fine-tuning.

    Args:
        examples: Batch mapping with a "seed" column (source strings) and a
            "FHIR" column (target strings).

    Returns:
        The tokenizer output for the prefixed seeds, with the token ids of
        the targets stored under the "labels" key.
    """
    inputs = [prefix + doc for doc in examples["seed"]]

    # Deliberately no padding here. Padding the targets at this stage would
    # write pad-token ids into the labels, and the loss would then be computed
    # on padding as if it were real output. The DataCollatorForSeq2Seq used at
    # training time pads each batch instead and fills label padding with
    # -100, which the loss ignores. Padding per training batch also avoids
    # stretching every sample to the longest one in the whole map batch.
    model_inputs = tokenizer(inputs, truncation=True)

    # Tokenize the targets with the tokenizer's target-side settings.
    labels = tokenizer(text_target=examples["FHIR"], truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize both splits batch-wise, then drop the raw text columns so that
# only the tokenizer outputs and labels remain.
tokenized_dataset = fhir_dataset.map(preprocess_function, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns(["seed", "FHIR"])

In [None]:
# Pre-trained BART checkpoint to fine-tune; device placement is delegated
# to the library through device_map="auto".
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large",
    device_map="auto",
)

In [None]:
# Collator that assembles tokenized samples into seq2seq batches.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

# Hyper-parameters for a single-epoch, mixed-precision fine-tuning run.
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",  # evaluate at the end of each epoch
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    save_total_limit=3,  # cap on the number of checkpoints kept on disk
    num_train_epochs=1,
    fp16=True,
)

# Wire model, data splits, tokenizer and collator into the trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

In [None]:
# Run fine-tuning with the trainer configured above. Left as a bare
# expression so the notebook displays the value returned by train().
trainer.train()