# V 0.3

This is the next version of [simple finetuning](https://www.kaggle.com/code/yannchikk/bart-large-cnn-dialoguesum-booksum-full-finetuning). Here I train bart-large-cnn with the LoRA fine-tuning method, using the PEFT library.

I am going to train LoRA adapters on both preprocessed text and raw (unprocessed) text, then make a Model Soup from these two versions.
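
A model soup averages the weights of several fine-tuned versions of the same model, parameter by parameter. The sketch below shows a uniform soup on plain Python dictionaries. The function name, key names, and values are made up for illustration; in practice the averaged values would be the LoRA adapter tensors of the two runs.

```python
def uniform_soup(state_dicts):
    """Average several state dicts key by key (a uniform model soup)."""
    keys = state_dicts[0].keys()
    # Every ingredient must have exactly the same parameter names.
    if any(sd.keys() != keys for sd in state_dicts):
        raise ValueError("all state dicts must share the same keys")
    n = len(state_dicts)
    return {
        key: [sum(values) / n for values in zip(*(sd[key] for sd in state_dicts))]
        for key in keys
    }


# Hypothetical adapter weights from the two runs (raw text vs. preprocessed text).
raw_run = {"lora_A": [0.5, 1.0], "lora_B": [1.0, -1.0]}
processed_run = {"lora_A": [1.5, 0.0], "lora_B": [0.0, 1.0]}

soup = uniform_soup([raw_run, processed_run])
print(soup)
```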

In this Notebook: 

- ### Train the original '[bart-large-cnn](https://huggingface.co/facebook/bart-large-cnn)'

- ### Use the dataset without text processing + the preprocessed dataset

- ### Use the peft library to add LoRA layers to the model

- ### Fine-tune only the 'target_modules'

- ### Use DataParallel to speed up training

- ### Custom checkpointing

In [None]:
class Config:
    
    max_length = 1024
    target_max_length = 512

    epochs = 10
    
    batch_size = 8

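    # Uncomment the next line to continue from the previously fine-tuned checkpoint.
#     model_preset_trained = "doublecringe123/bardt-large-cnn-dialoguesum-booksum"
    
    try: 
        model_preset = model_preset_trained 
    except NameError:  # not defined above, so fall back to the original checkpoint
        model_preset = "facebook/bart-large-cnn"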
    
    
    lora_params = {
        # BART names its feed-forward layers 'fc1' and 'fc2'.
        'target_modules':['out_proj', 'v_proj', 'q_proj', 'fc1', 'fc2'], 
        'r':8, 
        'lora_alpha': 16, 
    }
    
    save_frecuency = 2

    inp = 'input_content'
    target = 'target'

cfg = Config()

# First, let's load the datasets



In [None]:
! pip install -q --upgrade pip
! pip install -q transformers[torch]
! pip install -q -U transformers==4.38.2 datasets==2.18.0 evaluate rouge_score

# Disable Weights & Biases logging if wandb is installed.
try: 
    import wandb
    wandb.init(mode='disabled')
except Exception: 
    pass

The dataset helper below is my own code from this [GitHub repo](https://github.com/goin2crazy/multy-dataset/blob/main/main.py).

In [None]:
! wget -O "mds.py" "https://raw.githubusercontent.com/goin2crazy/multy-dataset/main/main.py" 

# Prepare Dataset

In [None]:
from mds import NewDataset
dataset_params = {
    "knkarthick/dialogsum": ("dialogue", "summary"), 
    "doublecringe123/dialoguesum-npc-dialoguesum-stemmed-augmented": ('inp', 'target')
}

dataset = NewDataset(dataset_params, input_col_name = cfg.inp, target_col_name = cfg.target)

# Load model and tokenizer

In [None]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained(cfg.model_preset)
model = AutoModelForSeq2SeqLM.from_pretrained(cfg.model_preset)

# Creating LoRA Model with PEFT

PEFT (Parameter-Efficient Fine-Tuning) refers to a group of techniques that enable efficient adaptation of large language models (LLMs) to specific tasks or domains. It fine-tunes only a small subset of parameters rather than modifying the entire model, which reduces memory use during training and produces small adapter checkpoints.

I learned how to build PEFT and LoRA models from [this notebook](https://www.kaggle.com/code/ajinkyabhandare2002/fine-tune-flan-t5-base-for-chat-with-peft-lora#Setup-the-PEFT/LoRA-model-for-Fine-Tuning).
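
To make the "small subset of parameters" concrete: LoRA freezes a weight matrix of shape d_out x d_in and trains two low-rank factors with r * (d_in + d_out) parameters in total. The sketch below does this arithmetic for one attention projection and one feed-forward layer. It assumes the bart-large sizes (hidden size 1024, feed-forward size 4096) and the rank r = 8 used in this notebook; the function names are only illustrative.

```python
def lora_param_count(d_in, d_out, r):
    """Parameters added by LoRA: A has shape (r, d_in), B has shape (d_out, r)."""
    return r * d_in + d_out * r


def dense_param_count(d_in, d_out):
    """Weights of the frozen linear layer (bias ignored)."""
    return d_in * d_out


R = 8  # same rank as cfg.lora_params['r']

# Assumed bart-large shapes: attention projections 1024 -> 1024, fc1 1024 -> 4096.
for name, d_in, d_out in [("q_proj", 1024, 1024), ("fc1", 1024, 4096)]:
    added = lora_param_count(d_in, d_out, R)
    frozen = dense_param_count(d_in, d_out)
    print(f"{name}: {added} trainable vs {frozen} frozen ({added / frozen:.2%})")
```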

In [None]:
! pip install -q peft

In [None]:
from peft import LoraConfig, get_peft_model, TaskType

lora_conf = LoraConfig(
    **cfg.lora_params, 
    lora_dropout = 0.05,
    bias = 'none', 
    task_type = TaskType.SEQ_2_SEQ_LM,  # BART is an encoder-decoder (seq2seq) model
    init_lora_weights = 'gaussian', 
)

In [None]:
lora_model = get_peft_model(model=model, peft_config=lora_conf)

lora_model.print_trainable_parameters()

In [None]:
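from torch import nn
import torch 

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seq2SeqTrainer already wraps the model in nn.DataParallel when more than
# one GPU is visible, so the model is not wrapped manually here.
# lora_model.model = nn.DataParallel(lora_model.model)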

In [None]:
def preprocess_function(examples):
    try: 
        inputs = [doc for doc in examples[cfg.inp]]
        model_inputs = tokenizer(inputs, max_length=cfg.max_length, truncation=True)

        labels = tokenizer(text_target=examples[cfg.target], max_length=cfg.target_max_length, truncation=True)

        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
    except TypeError as e:
        print(e)
        print(examples[cfg.inp])

dataset = dataset.map(preprocess_function, batched = True)
tokenized_train, tokenized_val, tokenized_test = dataset.splits

In [None]:
from transformers import DataCollatorForSeq2Seq

# Pass the model object (not its name) so the collator can build decoder_input_ids from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
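
DataCollatorForSeq2Seq pads each batch dynamically: inputs are padded with the tokenizer's pad token, while labels are padded with -100 so the loss ignores those positions. Below is a simplified pure-Python sketch of that behaviour. The helper name, the token ids, and the pad id of 1 are illustrative assumptions, not the library's implementation.

```python
def pad_batch(features, pad_token_id=1, label_pad_token_id=-100):
    """Pad input_ids and labels to the longest sequence in the batch."""
    max_inp = max(len(f["input_ids"]) for f in features)
    max_lab = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_inp_pad = max_inp - len(f["input_ids"])
        n_lab_pad = max_lab - len(f["labels"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_inp_pad)
        # 1 marks real tokens, 0 marks padding.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_inp_pad)
        # -100 is the label value that the cross-entropy loss skips.
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_lab_pad)
    return batch


# Made-up token ids for two examples of different length.
features = [
    {"input_ids": [0, 11, 12, 13, 2], "labels": [0, 21, 2]},
    {"input_ids": [0, 14, 2], "labels": [0, 22, 23, 24, 2]},
]
batch = pad_batch(features)
print(batch["input_ids"])
print(batch["labels"])
```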

# Define Metrics

In [None]:
import evaluate

rouge = evaluate.load("rouge")

In [None]:
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
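
For intuition about what the ROUGE scores measure, here is a simplified ROUGE-1 F1 in plain Python. It assumes lowercase whitespace tokenization and no stemming, so its numbers will not exactly match those from `evaluate.load("rouge")`; the function name and the example sentences are made up.

```python
from collections import Counter


def rouge1_f1(prediction, reference):
    """Unigram-overlap F1 between a predicted and a reference summary."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    # Clipped overlap: each word counts at most as often as it appears in both.
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge1_f1("tom asks anna for help", "tom asks anna to help him")
print(round(score, 4))
```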

# Define training arguments

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer

In [None]:
eps = cfg.epochs // cfg.save_frecuency

for i in range(1, eps + 1): 
    
    print(f"{i}/{eps} Training initialization...")
    training_args = Seq2SeqTrainingArguments(
        output_dir="bardt-large-cnn-dialoguesum-booksum-lora",
        evaluation_strategy="epoch",
        save_strategy='no',
    #     save_safetensors = True,
    #     save_steps = 100, 
        learning_rate=2e-5,
        per_device_train_batch_size=cfg.batch_size,
        per_device_eval_batch_size=cfg.batch_size,
        weight_decay=0.01,
        num_train_epochs=cfg.save_frecuency,
        predict_with_generate=True,

        fp16=True,
    )

    trainer = Seq2SeqTrainer(
        model=lora_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    print(f"{i}/{eps} start Training...")    
    trainer.train()
    
    print(f"{i}/{eps} Saving model...")    
    # Save and push through the PEFT wrapper so that the LoRA adapter is what gets stored.
    lora_model.save_pretrained("bardt-large-cnn-dialoguesum-booksum-lora")
    lora_model.push_to_hub("bardt-large-cnn-dialoguesum-booksum-lora", commit_message = f"Original+Augmented+Stemmed Dataset, {i * cfg.save_frecuency} epochs")
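
The custom checkpointing above splits training into rounds: each round trains for `save_frecuency` epochs, then saves and pushes the model. The sketch below only reproduces that schedule arithmetic with the values from `Config`; the function name is illustrative.

```python
def checkpoint_schedule(epochs, save_every):
    """Return the cumulative epoch count after each training round."""
    # Integer division: leftover epochs (epochs % save_every) are never trained.
    rounds = epochs // save_every
    return [round_idx * save_every for round_idx in range(1, rounds + 1)]


print(checkpoint_schedule(epochs=10, save_every=2))  # [2, 4, 6, 8, 10]
print(checkpoint_schedule(epochs=10, save_every=3))  # [3, 6, 9]
```

Note that the loop builds a new `Seq2SeqTrainer` in every round, so the optimizer state and the learning-rate schedule restart each time.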