In [1]:
import os
import pandas as pd

from trl import SFTTrainer
from datasets import Dataset
from sklearn.model_selection import train_test_split
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,TrainingArguments

In [2]:
import yaml

with open("config.yaml", 'r') as stream:
    try:
        config = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        print(exc)

In [3]:
from utils import import_data_from_json

airbus_datapath = os.path.join("./data/", "airbus_helicopters_train_set.json")
train_dataset, val_dataset, test_dataset = import_data_from_json(airbus_datapath)

model_name = "google/flan-t5-small"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

#setting padding instructions for tokenizer
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [4]:
test_data

Dataset({
    features: ['original_text', 'reference_summary', '__index_level_0__'],
    num_rows: 21
})

In [5]:
test_dataset

Dataset({
    features: ['original_text', 'reference_summary'],
    num_rows: 21
})

In [22]:
# Create the trainer
from utils import prompt_instruction_format

trainingArgs = TrainingArguments(**config['parameters_ft'])

peft_config = LoraConfig(**config['parameters_LoRA'])

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset = val_dataset,
    peft_config=peft_config,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=prompt_instruction_format,
    args=trainingArgs,
    max_seq_length=512
)

trainer.train()

{'learning_rate': 0.0001, 'num_train_epochs': 10, 'per_device_train_batch_size': 4, 'save_strategy': 'steps', 'save_steps': 0.1, 'output_dir': 'output'}


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/240 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [7]:
airbus_dataset

DatasetDict({
    train: Dataset({
        features: ['original_text', 'reference_summary'],
        num_rows: 371
    })
    test: Dataset({
        features: ['original_text', 'reference_summary'],
        num_rows: 42
    })
})